diff options
Diffstat (limited to 'nbxmpp/sasl.py')
-rw-r--r-- | nbxmpp/sasl.py | 67 |
1 files changed, 43 insertions, 24 deletions
diff --git a/nbxmpp/sasl.py b/nbxmpp/sasl.py index 3fbd935..1470ec3 100644 --- a/nbxmpp/sasl.py +++ b/nbxmpp/sasl.py @@ -25,10 +25,13 @@ import logging import hashlib from hashlib import pbkdf2_hmac +from gi.repository import Gio + from nbxmpp.namespaces import Namespace from nbxmpp.protocol import Node from nbxmpp.protocol import SASL_ERROR_CONDITIONS from nbxmpp.protocol import SASL_AUTH_MECHS +from nbxmpp.structs import ChannelBindingData from nbxmpp.util import b64decode from nbxmpp.util import b64encode from nbxmpp.util import LogAdapter @@ -100,23 +103,41 @@ class SASL: elif stanza.getName() == 'success': self._on_success(stanza) + def _get_channel_binding_data(self, features) -> Optional[ChannelBindingData]: + if self._client.tls_version != Gio.TlsProtocolVersion.TLS_1_3: + return None + + binding_type = features.get_channel_binding_type() + if binding_type is None: + return None + + channel_binding_data = self._client.get_channel_binding_data(binding_type) + if channel_binding_data is None: + return None + + return ChannelBindingData(binding_type, channel_binding_data) + def start_auth(self, features): + self._mechanism = None self._allowed_mechs = self._client.mechs self._enabled_mechs = self._allowed_mechs - self._mechanism = None self._sasl_ns = Namespace.SASL if features.has_sasl_2(): self._sasl_ns = Namespace.SASL2 + self._log.info('Using %s', self._sasl_ns) + self._error = None - # -PLUS variants need TLS channel binding data - # This is currently not supported via GLib - self._enabled_mechs.discard('SCRAM-SHA-1-PLUS') - self._enabled_mechs.discard('SCRAM-SHA-256-PLUS') - self._enabled_mechs.discard('SCRAM-SHA-512-PLUS') - # channel_binding_data = None + channel_binding_data = None + # Segfaults see https://gitlab.gnome.org/GNOME/pygobject/-/issues/603 + # So for now channel binding is deactivated + # channel_binding_data = self._get_channel_binding_data(features) + if channel_binding_data is None: + self._enabled_mechs.discard('SCRAM-SHA-1-PLUS') + self._enabled_mechs.discard('SCRAM-SHA-256-PLUS') + self._enabled_mechs.discard('SCRAM-SHA-512-PLUS') if not GSSAPI_AVAILABLE: self._enabled_mechs.discard('GSSAPI') @@ -146,10 +167,7 @@ class SASL: self._log.info('Chosen auth mechanism: %s', chosen_mechanism) - if chosen_mechanism in ('SCRAM-SHA-512', - 'SCRAM-SHA-256', - 'SCRAM-SHA-1', - 'PLAIN'): + if chosen_mechanism.startswith(('SCRAM', 'PLAIN')): if not self._password: self._on_sasl_finished(False, 'no-password') return @@ -159,6 +177,10 @@ class SASL: self._password, domain_based_name or self._client.domain) + if (isinstance(self._mechanism, SCRAM) and + channel_binding_data is not None): + self._mechanism.set_channel_binding_data(channel_binding_data) + try: self._send_initiate() except AuthFail as error: @@ -340,18 +362,19 @@ class GSSAPI(BaseMechanism): class SCRAM(BaseMechanism): name = '' - _channel_binding = '' _hash_method = '' def __init__(self, *args, **kwargs) -> None: BaseMechanism.__init__(self, *args, **kwargs) - self._channel_binding_data = None + self._channel_binding_data: ChannelBindingData | None = None + self._gs2_header = 'n,,' self._client_nonce = '%x' % int(binascii.hexlify(os.urandom(24)), 16) self._client_first_message_bare = None self._server_signature = None - def set_channel_binding_data(self, data: bytes) -> None: + def set_channel_binding_data(self, data: ChannelBindingData) -> None: self._channel_binding_data = data + self._gs2_header = f'p={data.type},,' @property def nonce_length(self) -> int: @@ -360,9 +383,10 @@ class SCRAM(BaseMechanism): @property def _b64_channel_binding_data(self) -> str: if self.name.endswith('PLUS'): - return b64encode(b'%s%s' % (self._channel_binding.encode(), - self._channel_binding_data)) - return b64encode(self._channel_binding) + assert self._channel_binding_data is not None + return b64encode(b'%s%s' % (self._gs2_header.encode(), + self._channel_binding_data.data)) + return b64encode(self._gs2_header) @staticmethod def _scram_parse(scram_data: str) -> dict[str, str]: @@ -371,7 +395,8 @@ class SCRAM(BaseMechanism): def get_initiate_data(self) -> str: self._client_first_message_bare = 'n=%s,r=%s' % (self._username, self._client_nonce) - client_first_message = '%s%s' % (self._channel_binding, + + client_first_message = '%s%s' % (self._gs2_header, self._client_first_message_bare) return b64encode(client_first_message) @@ -442,40 +467,34 @@ class SCRAM(BaseMechanism): class SCRAM_SHA_1(SCRAM): name = 'SCRAM-SHA-1' - _channel_binding = 'n,,' _hash_method = 'sha1' class SCRAM_SHA_1_PLUS(SCRAM_SHA_1): name = 'SCRAM-SHA-1-PLUS' - _channel_binding = 'p=tls-unique,,' class SCRAM_SHA_256(SCRAM): name = 'SCRAM-SHA-256' - _channel_binding = 'n,,' _hash_method = 'sha256' class SCRAM_SHA_256_PLUS(SCRAM_SHA_256): name = 'SCRAM-SHA-256-PLUS' - _channel_binding = 'p=tls-unique,,' class SCRAM_SHA_512(SCRAM): name = 'SCRAM-SHA-512' - _channel_binding = 'n,,' _hash_method = 'sha512' class SCRAM_SHA_512_PLUS(SCRAM_SHA_512): name = 'SCRAM-SHA-512-PLUS' - _channel_binding = 'p=tls-unique,,' class AuthFail(Exception): |