Welcome to mirror list, hosted at ThFree Co, Russian Federation.

dev.gajim.org/gajim/python-nbxmpp.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlovetox <philipp@hoerist.com>2022-08-28 09:46:49 +0300
committerPhilipp Hörist <philipp@hoerist.com>2022-08-28 10:01:40 +0300
commitc9969f4fe6caafa07397fe04d940c5e08a912f33 (patch)
treec49a0e4e145949c10bc005a7139caa1fa3676161
parent85120e7ccc04e835312b6043d54e8c4091c984e0 (diff)
refactor: SASL: Prepare module for XEP-0388
- Move sending stanzas out of mechanism impl - Add missing SHA-512-PLUS method - Add type hints
-rw-r--r--nbxmpp/client.py2
-rw-r--r--nbxmpp/protocol.py1
-rw-r--r--nbxmpp/sasl.py (renamed from nbxmpp/auth.py)256
-rw-r--r--test/unit/test_sasl_scram.py27
4 files changed, 136 insertions, 150 deletions
diff --git a/nbxmpp/client.py b/nbxmpp/client.py
index cbf11bc..fbeda6c 100644
--- a/nbxmpp/client.py
+++ b/nbxmpp/client.py
@@ -39,7 +39,7 @@ from nbxmpp.addresses import NoMoreAddresses
from nbxmpp.tcp import TCPConnection
from nbxmpp.websocket import WebsocketConnection
from nbxmpp.smacks import Smacks
-from nbxmpp.auth import SASL
+from nbxmpp.sasl import SASL
from nbxmpp.const import StreamState
from nbxmpp.const import StreamError
from nbxmpp.const import ConnectionType
diff --git a/nbxmpp/protocol.py b/nbxmpp/protocol.py
index 178cf92..6f6cc6e 100644
--- a/nbxmpp/protocol.py
+++ b/nbxmpp/protocol.py
@@ -44,6 +44,7 @@ def ascii_upper(s):
return s.upper()
SASL_AUTH_MECHS = [
+ 'SCRAM-SHA-512-PLUS',
'SCRAM-SHA-512',
'SCRAM-SHA-256-PLUS',
'SCRAM-SHA-256',
diff --git a/nbxmpp/auth.py b/nbxmpp/sasl.py
index 965a773..b4c4242 100644
--- a/nbxmpp/auth.py
+++ b/nbxmpp/sasl.py
@@ -21,6 +21,7 @@ import binascii
import logging
import hashlib
from hashlib import pbkdf2_hmac
+from typing import Optional
from nbxmpp.namespaces import Namespace
from nbxmpp.protocol import Node
@@ -32,7 +33,7 @@ from nbxmpp.util import LogAdapter
from nbxmpp.const import StreamState
-log = logging.getLogger('nbxmpp.auth')
+log = logging.getLogger('nbxmpp.sasl')
try:
gssapi = __import__('gssapi')
@@ -51,9 +52,20 @@ class SASL:
self._password = None
+ self._mechanism_classes = {
+ 'PLAIN': PLAIN,
+ 'EXTERNAL': EXTERNAL,
+ 'GSSAPI': GSSAPI,
+ 'SCRAM-SHA-1': SCRAM_SHA_1,
+ 'SCRAM-SHA-1-PLUS': SCRAM_SHA_1_PLUS,
+ 'SCRAM-SHA-256': SCRAM_SHA_256,
+ 'SCRAM-SHA-256-PLUS': SCRAM_SHA_256_PLUS,
+ 'SCRAM-SHA-512': SCRAM_SHA_512,
+ 'SCRAM-SHA-512-PLUS': SCRAM_SHA_512_PLUS
+ }
+
self._allowed_mechs = None
self._enabled_mechs = None
- self._method = None
self._error = None
self._log = LogAdapter(log, {'context': client.log_context})
@@ -82,13 +94,14 @@ class SASL:
def start_auth(self, features):
self._allowed_mechs = self._client.mechs
self._enabled_mechs = self._allowed_mechs
- self._method = None
+ self._mechanism = None
self._error = None
# -PLUS variants need TLS channel binding data
# This is currently not supported via GLib
self._enabled_mechs.discard('SCRAM-SHA-1-PLUS')
self._enabled_mechs.discard('SCRAM-SHA-256-PLUS')
+ self._enabled_mechs.discard('SCRAM-SHA-512-PLUS')
# channel_binding_data = None
if not GSSAPI_AVAILABLE:
@@ -127,69 +140,49 @@ class SASL:
self._on_sasl_finished(False, 'no-password')
return
- # if chosen_mechanism == 'SCRAM-SHA-256-PLUS':
- # self._method = SCRAM_SHA_256_PLUS(self._client,
- # channel_binding_data)
- # self._method.initiate(self._client.username, self._password)
-
- # elif chosen_mechanism == 'SCRAM-SHA-1-PLUS':
- # self._method = SCRAM_SHA_1_PLUS(self._client,
- # channel_binding_data)
- # self._method.initiate(self._client.username, self._password)
-
- if chosen_mechanism == 'SCRAM-SHA-512':
- self._method = SCRAM_SHA_512(self._client, None)
- self._method.initiate(self._client.username, self._password)
-
- elif chosen_mechanism == 'SCRAM-SHA-256':
- self._method = SCRAM_SHA_256(self._client, None)
- self._method.initiate(self._client.username, self._password)
-
- elif chosen_mechanism == 'SCRAM-SHA-1':
- self._method = SCRAM_SHA_1(self._client, None)
- self._method.initiate(self._client.username, self._password)
+ mech_class = self._mechanism_classes[chosen_mechanism]
+ self._mechanism = mech_class(self._client.username,
+ self._password,
+ domain_based_name or self._client.domain)
- elif chosen_mechanism == 'PLAIN':
- self._method = PLAIN(self._client)
- self._method.initiate(self._client.username, self._password)
-
- elif chosen_mechanism == 'ANONYMOUS':
- self._method = ANONYMOUS(self._client)
- self._method.initiate() # pylint: disable=E1120
-
- elif chosen_mechanism == 'EXTERNAL':
- self._method = EXTERNAL(self._client)
- self._method.initiate(self._client.username, self._client.Server)
+ try:
+ self._send_initiate()
+ except AuthFail as error:
+ self._log.error(error)
+ self._abort_auth()
+ return
- elif chosen_mechanism == 'GSSAPI':
- self._method = GSSAPI(self._client)
- if domain_based_name:
- hostname = domain_based_name
- else:
- hostname = self._client.domain
- try:
- self._method.initiate(hostname) # pylint: disable=E1120
- except AuthFail as error:
- self._log.error(error)
- self._abort_auth()
- return
- else:
- self._log.error('Unknown auth mech')
+ def _send_initiate(self) -> None:
+ data = self._mechanism.get_initiate_data()
+ node = Node('auth',
+ attrs={'xmlns': Namespace.SASL,
+ 'mechanism': self._mechanism.name})
+ if data is not None:
+ node.setData(data)
+ self._client.send_nonza(node)
- def _on_challenge(self, stanza):
+ def _on_challenge(self, stanza) -> None:
try:
- self._method.response(stanza.getData())
+ data = self._mechanism.get_response_data(stanza.getData())
except AttributeError:
self._log.info('Mechanism has no response method')
self._abort_auth()
+ return
+
except AuthFail as error:
self._log.error(error)
self._abort_auth()
+ return
+
+ node = Node('response',
+ attrs={'xmlns': Namespace.SASL},
+ payload=[data])
+ self._client.send_nonza(node)
def _on_success(self, stanza):
self._log.info('Successfully authenticated with remote server')
try:
- self._method.success(stanza.getData())
+ self._mechanism.get_success_data(stanza.getData())
except AttributeError:
pass
except AuthFail as error:
@@ -227,61 +220,59 @@ class SASL:
self._client.set_state(StreamState.AUTH_SUCCESSFUL)
-class PLAIN:
+class BaseMechanism:
- _mechanism = 'PLAIN'
+ name: str
- def __init__(self, client):
- self._client = client
+ def __init__(self, username: str, password: str, domain: str):
+ self._username = username
+ self._password = password
+ self._domain = domain
- def initiate(self, username, password):
- payload = b64encode('\x00%s\x00%s' % (username, password))
- node = Node('auth',
- attrs={'xmlns': Namespace.SASL, 'mechanism': 'PLAIN'},
- payload=[payload])
- self._client.send_nonza(node)
+ def get_initiate_data(self) -> Optional[str]:
+ raise NotImplementedError
+ def get_response_data(self, data: str) -> str:
+ raise NotImplementedError
-class EXTERNAL:
+ def validate_success_data(self, data: str) -> None:
+ raise NotImplementedError
- _mechanism = 'EXTERNAL'
- def __init__(self, client):
- self._client = client
+class PLAIN(BaseMechanism):
- def initiate(self, username, server):
- payload = b64encode('%s@%s' % (username, server))
- node = Node('auth',
- attrs={'xmlns': Namespace.SASL, 'mechanism': 'EXTERNAL'},
- payload=[payload])
- self._client.send_nonza(node)
+ name = 'PLAIN'
+ def get_initiate_data(self) -> str:
+ return b64encode('\x00%s\x00%s' % (self._username, self._password))
-class ANONYMOUS:
- _mechanism = 'ANONYMOUS'
+class EXTERNAL(BaseMechanism):
- def __init__(self, client):
- self._client = client
+ name = 'EXTERNAL'
- def initiate(self):
- node = Node('auth', attrs={'xmlns': Namespace.SASL,
- 'mechanism': 'ANONYMOUS'})
- self._client.send_nonza(node)
+ def get_initiate_data(self) -> str:
+ return b64encode('%s@%s' % (self._username, self._domain))
-class GSSAPI:
+class ANONYMOUS(BaseMechanism):
- # See https://tools.ietf.org/html/rfc4752#section-3.1
+ name = 'ANONYMOUS'
- _mechanism = 'GSSAPI'
+ def get_initiate_data(self) -> None:
+ return None
- def __init__(self, client):
- self._client = client
- def initiate(self, hostname):
+class GSSAPI(BaseMechanism):
+
+ # See https://tools.ietf.org/html/rfc4752#section-3.1
+
+ name = 'GSSAPI'
+
+ def get_initiate_data(self) -> str:
service = gssapi.Name(
- 'xmpp@%s' % hostname, name_type=gssapi.NameType.hostbased_service)
+ 'xmpp@%s' % self._domain,
+ name_type=gssapi.NameType.hostbased_service)
try:
self.ctx = gssapi.SecurityContext(
name=service, usage="initiate",
@@ -289,75 +280,66 @@ class GSSAPI:
token = self.ctx.step()
except (gssapi.exceptions.GeneralError, gssapi.raw.misc.GSSError) as e:
raise AuthFail(e)
- node = Node('auth',
- attrs={'xmlns': Namespace.SASL, 'mechanism': 'GSSAPI'},
- payload=b64encode(token))
- self._client.send_nonza(node)
- def response(self, server_message, *args, **kwargs):
- server_message = b64decode(server_message)
+ return b64encode(token)
+
+ def get_response_data(self, data: str) -> str:
+ byte_data = b64decode(data)
try:
if not self.ctx.complete:
- output_token = self.ctx.step(server_message)
+ output_token = self.ctx.step(byte_data)
else:
- _result = self.ctx.unwrap(server_message)
+ _result = self.ctx.unwrap(byte_data)
# TODO(jelmer): Log result.message
data = b'\x00\x00\x00\x00' + bytes(self.ctx.initiator_name)
output_token = self.ctx.wrap(data, False).message
except (gssapi.exceptions.GeneralError, gssapi.raw.misc.GSSError) as e:
raise AuthFail(e)
- response = b64encode(output_token)
- node = Node('response',
- attrs={'xmlns': Namespace.SASL},
- payload=response)
- self._client.send_nonza(node)
+
+ return b64encode(output_token)
-class SCRAM:
+class SCRAM(BaseMechanism):
- _mechanism = ''
+ name = ''
_channel_binding = ''
_hash_method = ''
- def __init__(self, client, channel_binding):
- self._client = client
- self._channel_binding_data = channel_binding
+ def __init__(self, *args, **kwargs) -> None:
+ BaseMechanism.__init__(self, *args, **kwargs)
+ self._channel_binding_data = None
self._client_nonce = '%x' % int(binascii.hexlify(os.urandom(24)), 16)
self._client_first_message_bare = None
self._server_signature = None
- self._password = None
+
+ def set_channel_binding_data(self, data: bytes) -> None:
+ self._channel_binding_data = data
@property
- def nonce_length(self):
+ def nonce_length(self) -> int:
return len(self._client_nonce)
@property
- def _b64_channel_binding_data(self):
- if self._mechanism.endswith('PLUS'):
+ def _b64_channel_binding_data(self) -> str:
+ if self.name.endswith('PLUS'):
return b64encode(b'%s%s' % (self._channel_binding.encode(),
self._channel_binding_data))
return b64encode(self._channel_binding)
@staticmethod
- def _scram_parse(scram_data):
+ def _scram_parse(scram_data: str) -> dict[str, str]:
return dict(s.split('=', 1) for s in scram_data.split(','))
- def initiate(self, username, password):
- self._password = password
- self._client_first_message_bare = 'n=%s,r=%s' % (username,
+ def get_initiate_data(self) -> str:
+ self._client_first_message_bare = 'n=%s,r=%s' % (self._username,
self._client_nonce)
client_first_message = '%s%s' % (self._channel_binding,
self._client_first_message_bare)
- payload = b64encode(client_first_message)
- node = Node('auth',
- attrs={'xmlns': Namespace.SASL,
- 'mechanism': self._mechanism},
- payload=[payload])
- self._client.send_nonza(node)
+ return b64encode(client_first_message)
- def response(self, server_first_message):
- server_first_message = b64decode(server_first_message).decode()
+ def get_response_data(self, data) -> str:
+ server_first_message = b64decode(data).decode()
challenge = self._scram_parse(server_first_message)
client_nonce = challenge['r'][:self.nonce_length]
@@ -397,64 +379,66 @@ class SCRAM:
server_key = self._hmac(salted_password, 'Server Key')
self._server_signature = self._hmac(server_key, auth_message)
- payload = b64encode(client_finale_message)
- node = Node('response',
- attrs={'xmlns': Namespace.SASL},
- payload=[payload])
- self._client.send_nonza(node)
+ return b64encode(client_finale_message)
- def success(self, server_last_message):
- server_last_message = b64decode(server_last_message).decode()
+ def validate_success_data(self, data: str) -> None:
+ server_last_message = b64decode(data).decode()
success = self._scram_parse(server_last_message)
server_signature = b64decode(success['v'])
if server_signature != self._server_signature:
raise AuthFail('Invalid server signature')
- def _hmac(self, key, message):
+ def _hmac(self, key: bytes, message: str) -> bytes:
return hmac.new(key=key,
msg=message.encode(),
digestmod=self._hash_method).digest()
@staticmethod
- def _xor(x, y):
+ def _xor(x: bytes, y: bytes) -> bytes:
return bytes([px ^ py for px, py in zip(x, y)])
- def _h(self, data):
+ def _h(self, data: bytes) -> bytes:
return hashlib.new(self._hash_method, data).digest()
class SCRAM_SHA_1(SCRAM):
- _mechanism = 'SCRAM-SHA-1'
+ name = 'SCRAM-SHA-1'
_channel_binding = 'n,,'
_hash_method = 'sha1'
class SCRAM_SHA_1_PLUS(SCRAM_SHA_1):
- _mechanism = 'SCRAM-SHA-1-PLUS'
+ name = 'SCRAM-SHA-1-PLUS'
_channel_binding = 'p=tls-unique,,'
class SCRAM_SHA_256(SCRAM):
- _mechanism = 'SCRAM-SHA-256'
+ name = 'SCRAM-SHA-256'
_channel_binding = 'n,,'
_hash_method = 'sha256'
class SCRAM_SHA_256_PLUS(SCRAM_SHA_256):
- _mechanism = 'SCRAM-SHA-256-PLUS'
+ name = 'SCRAM-SHA-256-PLUS'
_channel_binding = 'p=tls-unique,,'
class SCRAM_SHA_512(SCRAM):
- _mechanism = 'SCRAM-SHA-512'
+ name = 'SCRAM-SHA-512'
_channel_binding = 'n,,'
_hash_method = 'sha512'
+class SCRAM_SHA_512_PLUS(SCRAM_SHA_512):
+
+ name = 'SCRAM-SHA-512-PLUS'
+ _channel_binding = 'p=tls-unique,,'
+
+
class AuthFail(Exception):
pass
diff --git a/test/unit/test_sasl_scram.py b/test/unit/test_sasl_scram.py
index d41e48e..2c14e32 100644
--- a/test/unit/test_sasl_scram.py
+++ b/test/unit/test_sasl_scram.py
@@ -1,33 +1,34 @@
import unittest
from unittest.mock import Mock
-from nbxmpp.auth import SCRAM_SHA_1
+from nbxmpp.sasl import SCRAM_SHA_1
from nbxmpp.util import b64encode
# Test vector from https://wiki.xmpp.org/web/SASL_and_SCRAM-SHA-1
+
class SCRAM(unittest.TestCase):
def setUp(self):
self.con = Mock()
- self._method = SCRAM_SHA_1(self.con, None)
- self._method._client_nonce = 'fyko+d2lbbFgONRv9qkxdawL'
self.maxDiff = None
self._username = 'user'
self._password = 'pencil'
-
- self.auth = '<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="SCRAM-SHA-1">%s</auth>' % b64encode('n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL')
- self.challenge = b64encode('r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096')
- self.response = '<response xmlns="urn:ietf:params:xml:ns:xmpp-sasl">%s</response>' % b64encode('c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=')
- self.success = b64encode('v=rmF9pqV8S7suAoZWja4dJRkFsKQ=')
+ self._mechanism = SCRAM_SHA_1(self._username, self._password, None)
+ self._mechanism._client_nonce = 'fyko+d2lbbFgONRv9qkxdawL'
def test_auth(self):
- self._method.initiate(self._username, self._password)
- self.assertEqual(self.auth, str(self.con.send_nonza.call_args[0][0]))
+ initial = b64encode('n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL')
+ data = self._mechanism.get_initiate_data()
+ self.assertEqual(data, initial)
+
+ challenge = b64encode('r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096')
+ data = self._mechanism.get_response_data(challenge)
- self._method.response(self.challenge)
- self.assertEqual(self.response, str(self.con.send_nonza.call_args[0][0]))
+ response = b64encode('c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=')
+ self.assertEqual(data, response)
- self._method.success(self.success)
+ success = b64encode('v=rmF9pqV8S7suAoZWja4dJRkFsKQ=')
+ self._mechanism.validate_success_data(success)
if __name__ == '__main__':