From f1305fc956c90cbf2b20d7e5b8fa4d339585d216 Mon Sep 17 00:00:00 2001 From: lovetox Date: Sun, 28 Aug 2022 20:20:12 +0200 Subject: feat: Add support for Extensible SASL Profile (XEP-0388) --- nbxmpp/client.py | 7 +++-- nbxmpp/namespaces.py | 1 + nbxmpp/protocol.py | 9 ++++-- nbxmpp/sasl.py | 78 ++++++++++++++++++++++++++++++++++++++-------------- 4 files changed, 71 insertions(+), 24 deletions(-) diff --git a/nbxmpp/client.py b/nbxmpp/client.py index fbeda6c..d16a009 100644 --- a/nbxmpp/client.py +++ b/nbxmpp/client.py @@ -692,7 +692,10 @@ class Client(Observable): elif self.state == StreamState.AUTH_SUCCESSFUL: self._stream_authenticated = True - self._start_stream() + if self._sasl.is_sasl2(): + self.state = StreamState.WAIT_FOR_FEATURES + else: + self._start_stream() elif self.state == StreamState.AUTH_FAILED: self._disconnect_with_error(StreamError.SASL, @@ -786,7 +789,7 @@ class Client(Observable): self.state = StreamState.WAIT_FOR_TLS_PROCEED def _start_auth(self, features): - if not features.has_sasl(): + if not features.has_sasl() and not features.has_sasl_2(): self._log.error('Server does not support SASL') self._disconnect_with_error(StreamError.SASL, 'sasl-not-supported') diff --git a/nbxmpp/namespaces.py b/nbxmpp/namespaces.py index bd6301e..edd6d59 100644 --- a/nbxmpp/namespaces.py +++ b/nbxmpp/namespaces.py @@ -143,6 +143,7 @@ class _Namespaces: ROSTER_VER: str = 'urn:xmpp:features:rosterver' RSM: str = 'http://jabber.org/protocol/rsm' SASL: str = 'urn:ietf:params:xml:ns:xmpp-sasl' + SASL2: str = 'urn:xmpp:sasl:1' SEARCH: str = 'jabber:iq:search' SECLABEL: str = 'urn:xmpp:sec-label:0' SECLABEL_CATALOG: str = 'urn:xmpp:sec-label:catalog:2' diff --git a/nbxmpp/protocol.py b/nbxmpp/protocol.py index 6f6cc6e..6833284 100644 --- a/nbxmpp/protocol.py +++ b/nbxmpp/protocol.py @@ -1709,10 +1709,15 @@ class Features(Node): def has_sasl(self): return self.getTag('mechanisms', namespace=Namespace.SASL) is not None + def has_sasl_2(self): + return self.getTag('mechanisms', namespace=Namespace.SASL2) is not None + def get_mechs(self) -> set[str]: - mechanisms = self.getTag('mechanisms', namespace=Namespace.SASL) + mechanisms = self.getTag('mechanisms', namespace=Namespace.SASL2) if mechanisms is None: - return set() + mechanisms = self.getTag('mechanisms', namespace=Namespace.SASL) + if mechanisms is None: + return set() mechanisms = mechanisms.getTags('mechanism') return set(mech.getData() for mech in mechanisms) diff --git a/nbxmpp/sasl.py b/nbxmpp/sasl.py index b4c4242..619b12d 100644 --- a/nbxmpp/sasl.py +++ b/nbxmpp/sasl.py @@ -15,13 +15,15 @@ # You should have received a copy of the GNU General Public License # along with this program; If not, see . +from typing import Any +from typing import Optional + import os import hmac import binascii import logging import hashlib from hashlib import pbkdf2_hmac -from typing import Optional from nbxmpp.namespaces import Namespace from nbxmpp.protocol import Node @@ -66,6 +68,8 @@ class SASL: self._allowed_mechs = None self._enabled_mechs = None + self._sasl_ns = None + self._mechanism = None self._error = None self._log = LogAdapter(log, {'context': client.log_context}) @@ -74,6 +78,10 @@ class SASL: def error(self): return self._error + def is_sasl2(self) -> bool: + assert self._sasl_ns is not None + return self._sasl_ns == Namespace.SASL2 + def set_password(self, password): self._password = password @@ -82,8 +90,9 @@ class SASL: return self._password def delegate(self, stanza): - if stanza.getNamespace() != Namespace.SASL: + if stanza.getNamespace() != self._sasl_ns: return + if stanza.getName() == 'challenge': self._on_challenge(stanza) elif stanza.getName() == 'failure': @@ -95,6 +104,11 @@ class SASL: self._allowed_mechs = self._client.mechs self._enabled_mechs = self._allowed_mechs self._mechanism = None + + self._sasl_ns = Namespace.SASL + if features.has_sasl_2(): + self._sasl_ns = Namespace.SASL2 + self._error = None # -PLUS variants need TLS channel binding data @@ -153,15 +167,13 @@ class SASL: return def _send_initiate(self) -> None: + assert self._mechanism is not None data = self._mechanism.get_initiate_data() - node = Node('auth', - attrs={'xmlns': Namespace.SASL, - 'mechanism': self._mechanism.name}) - if data is not None: - node.setData(data) - self._client.send_nonza(node) + nonza = get_initiate_nonza(self._sasl_ns, self._mechanism.name, data) + self._client.send_nonza(nonza) def _on_challenge(self, stanza) -> None: + assert self._mechanism is not None try: data = self._mechanism.get_response_data(stanza.getData()) except AttributeError: @@ -174,22 +186,21 @@ class SASL: self._abort_auth() return - node = Node('response', - attrs={'xmlns': Namespace.SASL}, - payload=[data]) - self._client.send_nonza(node) + nonza = get_response_nonza(self._sasl_ns, data) + self._client.send_nonza(nonza) def _on_success(self, stanza): self._log.info('Successfully authenticated with remote server') + data = get_success_data(stanza, self._sasl_ns) try: - self._mechanism.get_success_data(stanza.getData()) - except AttributeError: - pass - except AuthFail as error: - self._log.error(error) + self._mechanism.validate_success_data(data) + except Exception as error: + self._log.error('Unable to validate success data: %s', error) self._abort_auth() return + self._log.info('Validated success data') + self._on_sasl_finished(True, None, None) def _on_failure(self, stanza): @@ -208,7 +219,7 @@ class SASL: self._abort_auth(reason, text) def _abort_auth(self, reason='malformed-request', text=None): - node = Node('abort', attrs={'xmlns': Namespace.SASL}) + node = Node('abort', attrs={'xmlns': self._sasl_ns}) self._client.send_nonza(node) self._on_sasl_finished(False, reason, text) @@ -220,6 +231,33 @@ class SASL: self._client.set_state(StreamState.AUTH_SUCCESSFUL) +def get_initiate_nonza(ns: str, + mechanism: str, + data: Optional[str]) -> Any: + + if ns == Namespace.SASL: + node = Node('auth', attrs={'xmlns': ns, 'mechanism': mechanism}) + if data is not None: + node.setData(data) + + else: + node = Node('authenticate', attrs={'xmlns': ns, 'mechanism': mechanism}) + if data is not None: + node.setTagData('initial-response', data) + + return node + + +def get_response_nonza(ns: str, data: str) -> Any: + return Node('response', attrs={'xmlns': ns}, payload=[data]) + + +def get_success_data(stanza: Any, ns: str) -> Optional[str]: + if ns == Namespace.SASL2: + return stanza.getTagData('additional-data') + return stanza.getData() + + class BaseMechanism: name: str @@ -235,8 +273,8 @@ class BaseMechanism: def get_response_data(self, data: str) -> str: raise NotImplementedError - def validate_success_data(self, data: str) -> None: - raise NotImplementedError + def validate_success_data(self, _data: str) -> None: + return None class PLAIN(BaseMechanism): -- cgit v1.2.3