diff options
author | lovetox <philipp@hoerist.com> | 2022-08-28 09:46:49 +0300 |
---|---|---|
committer | Philipp Hörist <philipp@hoerist.com> | 2022-08-28 10:01:40 +0300 |
commit | c9969f4fe6caafa07397fe04d940c5e08a912f33 (patch) | |
tree | c49a0e4e145949c10bc005a7139caa1fa3676161 | |
parent | 85120e7ccc04e835312b6043d54e8c4091c984e0 (diff) |
refactor: SASL: Prepare module for XEP-0388
- Move sending stanzas out of mechanism impl
- Add missing SHA-512-PLUS method
- Add type hints
-rw-r--r-- | nbxmpp/client.py | 2 | ||||
-rw-r--r-- | nbxmpp/protocol.py | 1 | ||||
-rw-r--r-- | nbxmpp/sasl.py (renamed from nbxmpp/auth.py) | 256 | ||||
-rw-r--r-- | test/unit/test_sasl_scram.py | 27 |
4 files changed, 136 insertions, 150 deletions
diff --git a/nbxmpp/client.py b/nbxmpp/client.py index cbf11bc..fbeda6c 100644 --- a/nbxmpp/client.py +++ b/nbxmpp/client.py @@ -39,7 +39,7 @@ from nbxmpp.addresses import NoMoreAddresses from nbxmpp.tcp import TCPConnection from nbxmpp.websocket import WebsocketConnection from nbxmpp.smacks import Smacks -from nbxmpp.auth import SASL +from nbxmpp.sasl import SASL from nbxmpp.const import StreamState from nbxmpp.const import StreamError from nbxmpp.const import ConnectionType diff --git a/nbxmpp/protocol.py b/nbxmpp/protocol.py index 178cf92..6f6cc6e 100644 --- a/nbxmpp/protocol.py +++ b/nbxmpp/protocol.py @@ -44,6 +44,7 @@ def ascii_upper(s): return s.upper() SASL_AUTH_MECHS = [ + 'SCRAM-SHA-512-PLUS', 'SCRAM-SHA-512', 'SCRAM-SHA-256-PLUS', 'SCRAM-SHA-256', diff --git a/nbxmpp/auth.py b/nbxmpp/sasl.py index 965a773..b4c4242 100644 --- a/nbxmpp/auth.py +++ b/nbxmpp/sasl.py @@ -21,6 +21,7 @@ import binascii import logging import hashlib from hashlib import pbkdf2_hmac +from typing import Optional from nbxmpp.namespaces import Namespace from nbxmpp.protocol import Node @@ -32,7 +33,7 @@ from nbxmpp.util import LogAdapter from nbxmpp.const import StreamState -log = logging.getLogger('nbxmpp.auth') +log = logging.getLogger('nbxmpp.sasl') try: gssapi = __import__('gssapi') @@ -51,9 +52,20 @@ class SASL: self._password = None + self._mechanism_classes = { + 'PLAIN': PLAIN, + 'EXTERNAL': EXTERNAL, + 'GSSAPI': GSSAPI, + 'SCRAM-SHA-1': SCRAM_SHA_1, + 'SCRAM-SHA-1-PLUS': SCRAM_SHA_1_PLUS, + 'SCRAM-SHA-256': SCRAM_SHA_256, + 'SCRAM-SHA-256-PLUS': SCRAM_SHA_256_PLUS, + 'SCRAM-SHA-512': SCRAM_SHA_512, + 'SCRAM-SHA-512-PLUS': SCRAM_SHA_512_PLUS + } + self._allowed_mechs = None self._enabled_mechs = None - self._method = None self._error = None self._log = LogAdapter(log, {'context': client.log_context}) @@ -82,13 +94,14 @@ class SASL: def start_auth(self, features): self._allowed_mechs = self._client.mechs self._enabled_mechs = self._allowed_mechs - self._method = None + self._mechanism = None self._error = None # -PLUS variants need TLS channel binding data # This is currently not supported via GLib self._enabled_mechs.discard('SCRAM-SHA-1-PLUS') self._enabled_mechs.discard('SCRAM-SHA-256-PLUS') + self._enabled_mechs.discard('SCRAM-SHA-512-PLUS') # channel_binding_data = None if not GSSAPI_AVAILABLE: @@ -127,69 +140,49 @@ class SASL: self._on_sasl_finished(False, 'no-password') return - # if chosen_mechanism == 'SCRAM-SHA-256-PLUS': - # self._method = SCRAM_SHA_256_PLUS(self._client, - # channel_binding_data) - # self._method.initiate(self._client.username, self._password) - - # elif chosen_mechanism == 'SCRAM-SHA-1-PLUS': - # self._method = SCRAM_SHA_1_PLUS(self._client, - # channel_binding_data) - # self._method.initiate(self._client.username, self._password) - - if chosen_mechanism == 'SCRAM-SHA-512': - self._method = SCRAM_SHA_512(self._client, None) - self._method.initiate(self._client.username, self._password) - - elif chosen_mechanism == 'SCRAM-SHA-256': - self._method = SCRAM_SHA_256(self._client, None) - self._method.initiate(self._client.username, self._password) - - elif chosen_mechanism == 'SCRAM-SHA-1': - self._method = SCRAM_SHA_1(self._client, None) - self._method.initiate(self._client.username, self._password) + mech_class = self._mechanism_classes[chosen_mechanism] + self._mechanism = mech_class(self._client.username, + self._password, + domain_based_name or self._client.domain) - elif chosen_mechanism == 'PLAIN': - self._method = PLAIN(self._client) - self._method.initiate(self._client.username, self._password) - - elif chosen_mechanism == 'ANONYMOUS': - self._method = ANONYMOUS(self._client) - self._method.initiate() # pylint: disable=E1120 - - elif chosen_mechanism == 'EXTERNAL': - self._method = EXTERNAL(self._client) - self._method.initiate(self._client.username, self._client.Server) + try: + self._send_initiate() + except AuthFail as error: + self._log.error(error) + self._abort_auth() + return - elif chosen_mechanism == 'GSSAPI': - self._method = GSSAPI(self._client) - if domain_based_name: - hostname = domain_based_name - else: - hostname = self._client.domain - try: - self._method.initiate(hostname) # pylint: disable=E1120 - except AuthFail as error: - self._log.error(error) - self._abort_auth() - return - else: - self._log.error('Unknown auth mech') + def _send_initiate(self) -> None: + data = self._mechanism.get_initiate_data() + node = Node('auth', + attrs={'xmlns': Namespace.SASL, + 'mechanism': self._mechanism.name}) + if data is not None: + node.setData(data) + self._client.send_nonza(node) - def _on_challenge(self, stanza): + def _on_challenge(self, stanza) -> None: try: - self._method.response(stanza.getData()) + data = self._mechanism.get_response_data(stanza.getData()) except AttributeError: self._log.info('Mechanism has no response method') self._abort_auth() + return + except AuthFail as error: self._log.error(error) self._abort_auth() + return + + node = Node('response', + attrs={'xmlns': Namespace.SASL}, + payload=[data]) + self._client.send_nonza(node) def _on_success(self, stanza): self._log.info('Successfully authenticated with remote server') try: - self._method.success(stanza.getData()) + self._mechanism.get_success_data(stanza.getData()) except AttributeError: pass except AuthFail as error: @@ -227,61 +220,59 @@ class SASL: self._client.set_state(StreamState.AUTH_SUCCESSFUL) -class PLAIN: +class BaseMechanism: - _mechanism = 'PLAIN' + name: str - def __init__(self, client): - self._client = client + def __init__(self, username: str, password: str, domain: str): + self._username = username + self._password = password + self._domain = domain - def initiate(self, username, password): - payload = b64encode('\x00%s\x00%s' % (username, password)) - node = Node('auth', - attrs={'xmlns': Namespace.SASL, 'mechanism': 'PLAIN'}, - payload=[payload]) - self._client.send_nonza(node) + def get_initiate_data(self) -> Optional[str]: + raise NotImplementedError + def get_response_data(self, data: str) -> str: + raise NotImplementedError -class EXTERNAL: + def validate_success_data(self, data: str) -> None: + raise NotImplementedError - _mechanism = 'EXTERNAL' - def __init__(self, client): - self._client = client +class PLAIN(BaseMechanism): - def initiate(self, username, server): - payload = b64encode('%s@%s' % (username, server)) - node = Node('auth', - attrs={'xmlns': Namespace.SASL, 'mechanism': 'EXTERNAL'}, - payload=[payload]) - self._client.send_nonza(node) + name = 'PLAIN' + def get_initiate_data(self) -> str: + return b64encode('\x00%s\x00%s' % (self._username, self._password)) -class ANONYMOUS: - _mechanism = 'ANONYMOUS' +class EXTERNAL(BaseMechanism): - def __init__(self, client): - self._client = client + name = 'EXTERNAL' - def initiate(self): - node = Node('auth', attrs={'xmlns': Namespace.SASL, - 'mechanism': 'ANONYMOUS'}) - self._client.send_nonza(node) + def get_initiate_data(self) -> str: + return b64encode('%s@%s' % (self._username, self._domain)) -class GSSAPI: +class ANONYMOUS(BaseMechanism): - # See https://tools.ietf.org/html/rfc4752#section-3.1 + name = 'ANONYMOUS' - _mechanism = 'GSSAPI' + def get_initiate_data(self) -> None: + return None - def __init__(self, client): - self._client = client - def initiate(self, hostname): +class GSSAPI(BaseMechanism): + + # See https://tools.ietf.org/html/rfc4752#section-3.1 + + name = 'GSSAPI' + + def get_initiate_data(self) -> str: service = gssapi.Name( - 'xmpp@%s' % hostname, name_type=gssapi.NameType.hostbased_service) + 'xmpp@%s' % self._domain, + name_type=gssapi.NameType.hostbased_service) try: self.ctx = gssapi.SecurityContext( name=service, usage="initiate", @@ -289,75 +280,66 @@ class GSSAPI: token = self.ctx.step() except (gssapi.exceptions.GeneralError, gssapi.raw.misc.GSSError) as e: raise AuthFail(e) - node = Node('auth', - attrs={'xmlns': Namespace.SASL, 'mechanism': 'GSSAPI'}, - payload=b64encode(token)) - self._client.send_nonza(node) - def response(self, server_message, *args, **kwargs): - server_message = b64decode(server_message) + return b64encode(token) + + def get_response_data(self, data: str) -> str: + byte_data = b64decode(data) try: if not self.ctx.complete: - output_token = self.ctx.step(server_message) + output_token = self.ctx.step(byte_data) else: - _result = self.ctx.unwrap(server_message) + _result = self.ctx.unwrap(byte_data) # TODO(jelmer): Log result.message data = b'\x00\x00\x00\x00' + bytes(self.ctx.initiator_name) output_token = self.ctx.wrap(data, False).message except (gssapi.exceptions.GeneralError, gssapi.raw.misc.GSSError) as e: raise AuthFail(e) - response = b64encode(output_token) - node = Node('response', - attrs={'xmlns': Namespace.SASL}, - payload=response) - self._client.send_nonza(node) + + return b64encode(output_token) -class SCRAM: +class SCRAM(BaseMechanism): - _mechanism = '' + name = '' _channel_binding = '' _hash_method = '' - def __init__(self, client, channel_binding): - self._client = client - self._channel_binding_data = channel_binding + def __init__(self, *args, **kwargs) -> None: + BaseMechanism.__init__(self, *args, **kwargs) + self._channel_binding_data = None self._client_nonce = '%x' % int(binascii.hexlify(os.urandom(24)), 16) self._client_first_message_bare = None self._server_signature = None - self._password = None + + def set_channel_binding_data(self, data: bytes) -> None: + self._channel_binding_data = data @property - def nonce_length(self): + def nonce_length(self) -> int: return len(self._client_nonce) @property - def _b64_channel_binding_data(self): - if self._mechanism.endswith('PLUS'): + def _b64_channel_binding_data(self) -> str: + if self.name.endswith('PLUS'): return b64encode(b'%s%s' % (self._channel_binding.encode(), self._channel_binding_data)) return b64encode(self._channel_binding) @staticmethod - def _scram_parse(scram_data): + def _scram_parse(scram_data: str) -> dict[str, str]: return dict(s.split('=', 1) for s in scram_data.split(',')) - def initiate(self, username, password): - self._password = password - self._client_first_message_bare = 'n=%s,r=%s' % (username, + def get_initiate_data(self) -> str: + self._client_first_message_bare = 'n=%s,r=%s' % (self._username, self._client_nonce) client_first_message = '%s%s' % (self._channel_binding, self._client_first_message_bare) - payload = b64encode(client_first_message) - node = Node('auth', - attrs={'xmlns': Namespace.SASL, - 'mechanism': self._mechanism}, - payload=[payload]) - self._client.send_nonza(node) + return b64encode(client_first_message) - def response(self, server_first_message): - server_first_message = b64decode(server_first_message).decode() + def get_response_data(self, data) -> str: + server_first_message = b64decode(data).decode() challenge = self._scram_parse(server_first_message) client_nonce = challenge['r'][:self.nonce_length] @@ -397,64 +379,66 @@ class SCRAM: server_key = self._hmac(salted_password, 'Server Key') self._server_signature = self._hmac(server_key, auth_message) - payload = b64encode(client_finale_message) - node = Node('response', - attrs={'xmlns': Namespace.SASL}, - payload=[payload]) - self._client.send_nonza(node) + return b64encode(client_finale_message) - def success(self, server_last_message): - server_last_message = b64decode(server_last_message).decode() + def validate_success_data(self, data: str) -> None: + server_last_message = b64decode(data).decode() success = self._scram_parse(server_last_message) server_signature = b64decode(success['v']) if server_signature != self._server_signature: raise AuthFail('Invalid server signature') - def _hmac(self, key, message): + def _hmac(self, key: bytes, message: str) -> bytes: return hmac.new(key=key, msg=message.encode(), digestmod=self._hash_method).digest() @staticmethod - def _xor(x, y): + def _xor(x: bytes, y: bytes) -> bytes: return bytes([px ^ py for px, py in zip(x, y)]) - def _h(self, data): + def _h(self, data: bytes) -> bytes: return hashlib.new(self._hash_method, data).digest() class SCRAM_SHA_1(SCRAM): - _mechanism = 'SCRAM-SHA-1' + name = 'SCRAM-SHA-1' _channel_binding = 'n,,' _hash_method = 'sha1' class SCRAM_SHA_1_PLUS(SCRAM_SHA_1): - _mechanism = 'SCRAM-SHA-1-PLUS' + name = 'SCRAM-SHA-1-PLUS' _channel_binding = 'p=tls-unique,,' class SCRAM_SHA_256(SCRAM): - _mechanism = 'SCRAM-SHA-256' + name = 'SCRAM-SHA-256' _channel_binding = 'n,,' _hash_method = 'sha256' class SCRAM_SHA_256_PLUS(SCRAM_SHA_256): - _mechanism = 'SCRAM-SHA-256-PLUS' + name = 'SCRAM-SHA-256-PLUS' _channel_binding = 'p=tls-unique,,' class SCRAM_SHA_512(SCRAM): - _mechanism = 'SCRAM-SHA-512' + name = 'SCRAM-SHA-512' _channel_binding = 'n,,' _hash_method = 'sha512' +class SCRAM_SHA_512_PLUS(SCRAM_SHA_512): + + name = 'SCRAM-SHA-512-PLUS' + _channel_binding = 'p=tls-unique,,' + + class AuthFail(Exception): pass diff --git a/test/unit/test_sasl_scram.py b/test/unit/test_sasl_scram.py index d41e48e..2c14e32 100644 --- a/test/unit/test_sasl_scram.py +++ b/test/unit/test_sasl_scram.py @@ -1,33 +1,34 @@ import unittest from unittest.mock import Mock -from nbxmpp.auth import SCRAM_SHA_1 +from nbxmpp.sasl import SCRAM_SHA_1 from nbxmpp.util import b64encode # Test vector from https://wiki.xmpp.org/web/SASL_and_SCRAM-SHA-1 + class SCRAM(unittest.TestCase): def setUp(self): self.con = Mock() - self._method = SCRAM_SHA_1(self.con, None) - self._method._client_nonce = 'fyko+d2lbbFgONRv9qkxdawL' self.maxDiff = None self._username = 'user' self._password = 'pencil' - - self.auth = '<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="SCRAM-SHA-1">%s</auth>' % b64encode('n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL') - self.challenge = b64encode('r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096') - self.response = '<response xmlns="urn:ietf:params:xml:ns:xmpp-sasl">%s</response>' % b64encode('c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=') - self.success = b64encode('v=rmF9pqV8S7suAoZWja4dJRkFsKQ=') + self._mechanism = SCRAM_SHA_1(self._username, self._password, None) + self._mechanism._client_nonce = 'fyko+d2lbbFgONRv9qkxdawL' def test_auth(self): - self._method.initiate(self._username, self._password) - self.assertEqual(self.auth, str(self.con.send_nonza.call_args[0][0])) + initial = b64encode('n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL') + data = self._mechanism.get_initiate_data() + self.assertEqual(data, initial) + + challenge = b64encode('r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096') + data = self._mechanism.get_response_data(challenge) - self._method.response(self.challenge) - self.assertEqual(self.response, str(self.con.send_nonza.call_args[0][0])) + response = b64encode('c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=') + self.assertEqual(data, response) - self._method.success(self.success) + success = b64encode('v=rmF9pqV8S7suAoZWja4dJRkFsKQ=') + self._mechanism.validate_success_data(success) if __name__ == '__main__': |