Welcome to mirror list, hosted at ThFree Co, Russian Federation.

dev.gajim.org/gajim/python-nbxmpp.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'nbxmpp/sasl.py')
-rw-r--r--nbxmpp/sasl.py78
1 files changed, 58 insertions, 20 deletions
diff --git a/nbxmpp/sasl.py b/nbxmpp/sasl.py
index b4c4242..619b12d 100644
--- a/nbxmpp/sasl.py
+++ b/nbxmpp/sasl.py
@@ -15,13 +15,15 @@
# You should have received a copy of the GNU General Public License
# along with this program; If not, see <http://www.gnu.org/licenses/>.
+from typing import Any
+from typing import Optional
+
import os
import hmac
import binascii
import logging
import hashlib
from hashlib import pbkdf2_hmac
-from typing import Optional
from nbxmpp.namespaces import Namespace
from nbxmpp.protocol import Node
@@ -66,6 +68,8 @@ class SASL:
self._allowed_mechs = None
self._enabled_mechs = None
+ self._sasl_ns = None
+ self._mechanism = None
self._error = None
self._log = LogAdapter(log, {'context': client.log_context})
@@ -74,6 +78,10 @@ class SASL:
def error(self):
return self._error
+ def is_sasl2(self) -> bool:
+ assert self._sasl_ns is not None
+ return self._sasl_ns == Namespace.SASL2
+
def set_password(self, password):
self._password = password
@@ -82,8 +90,9 @@ class SASL:
return self._password
def delegate(self, stanza):
- if stanza.getNamespace() != Namespace.SASL:
+ if stanza.getNamespace() != self._sasl_ns:
return
+
if stanza.getName() == 'challenge':
self._on_challenge(stanza)
elif stanza.getName() == 'failure':
@@ -95,6 +104,11 @@ class SASL:
self._allowed_mechs = self._client.mechs
self._enabled_mechs = self._allowed_mechs
self._mechanism = None
+
+ self._sasl_ns = Namespace.SASL
+ if features.has_sasl_2():
+ self._sasl_ns = Namespace.SASL2
+
self._error = None
# -PLUS variants need TLS channel binding data
@@ -153,15 +167,13 @@ class SASL:
return
def _send_initiate(self) -> None:
+ assert self._mechanism is not None
data = self._mechanism.get_initiate_data()
- node = Node('auth',
- attrs={'xmlns': Namespace.SASL,
- 'mechanism': self._mechanism.name})
- if data is not None:
- node.setData(data)
- self._client.send_nonza(node)
+ nonza = get_initiate_nonza(self._sasl_ns, self._mechanism.name, data)
+ self._client.send_nonza(nonza)
def _on_challenge(self, stanza) -> None:
+ assert self._mechanism is not None
try:
data = self._mechanism.get_response_data(stanza.getData())
except AttributeError:
@@ -174,22 +186,21 @@ class SASL:
self._abort_auth()
return
- node = Node('response',
- attrs={'xmlns': Namespace.SASL},
- payload=[data])
- self._client.send_nonza(node)
+ nonza = get_response_nonza(self._sasl_ns, data)
+ self._client.send_nonza(nonza)
def _on_success(self, stanza):
self._log.info('Successfully authenticated with remote server')
+ data = get_success_data(stanza, self._sasl_ns)
try:
- self._mechanism.get_success_data(stanza.getData())
- except AttributeError:
- pass
- except AuthFail as error:
- self._log.error(error)
+ self._mechanism.validate_success_data(data)
+ except Exception as error:
+ self._log.error('Unable to validate success data: %s', error)
self._abort_auth()
return
+ self._log.info('Validated success data')
+
self._on_sasl_finished(True, None, None)
def _on_failure(self, stanza):
@@ -208,7 +219,7 @@ class SASL:
self._abort_auth(reason, text)
def _abort_auth(self, reason='malformed-request', text=None):
- node = Node('abort', attrs={'xmlns': Namespace.SASL})
+ node = Node('abort', attrs={'xmlns': self._sasl_ns})
self._client.send_nonza(node)
self._on_sasl_finished(False, reason, text)
@@ -220,6 +231,33 @@ class SASL:
self._client.set_state(StreamState.AUTH_SUCCESSFUL)
+def get_initiate_nonza(ns: str,
+ mechanism: str,
+ data: Optional[str]) -> Any:
+
+ if ns == Namespace.SASL:
+ node = Node('auth', attrs={'xmlns': ns, 'mechanism': mechanism})
+ if data is not None:
+ node.setData(data)
+
+ else:
+ node = Node('authenticate', attrs={'xmlns': ns, 'mechanism': mechanism})
+ if data is not None:
+ node.setTagData('initial-response', data)
+
+ return node
+
+
+def get_response_nonza(ns: str, data: str) -> Any:
+ return Node('response', attrs={'xmlns': ns}, payload=[data])
+
+
+def get_success_data(stanza: Any, ns: str) -> Optional[str]:
+ if ns == Namespace.SASL2:
+ return stanza.getTagData('additional-data')
+ return stanza.getData()
+
+
class BaseMechanism:
name: str
@@ -235,8 +273,8 @@ class BaseMechanism:
def get_response_data(self, data: str) -> str:
raise NotImplementedError
- def validate_success_data(self, data: str) -> None:
- raise NotImplementedError
+ def validate_success_data(self, _data: str) -> None:
+ return None
class PLAIN(BaseMechanism):