From d11d4e93d0ec312f74cd7559e153e3d8d07c8546 Mon Sep 17 00:00:00 2001 From: Kjell Braden Date: Sun, 22 Sep 2013 17:04:49 +0200 Subject: gotr: update provided potr to 1.0.0beta7 --- gotr/potr/__init__.py | 2 +- gotr/potr/compatcrypto/common.py | 14 ++- gotr/potr/compatcrypto/pycrypto.py | 25 ++--- gotr/potr/context.py | 156 ++++++++++++++++++++---------- gotr/potr/crypt.py | 188 +++++++++++++++++++------------------ gotr/potr/proto.py | 172 ++++++++++++++++++++------------- gotr/potr/utils.py | 5 +- 7 files changed, 333 insertions(+), 229 deletions(-) (limited to 'gotr') diff --git a/gotr/potr/__init__.py b/gotr/potr/__init__.py index 965aed2..c6e8d02 100644 --- a/gotr/potr/__init__.py +++ b/gotr/potr/__init__.py @@ -24,4 +24,4 @@ from potr.utils import human_hash ''' version is: (major, minor, patch, sub) with sub being one of 'alpha', 'beta', 'final' ''' -VERSION = (1, 0, 0, 'beta5') +VERSION = (1, 0, 0, 'beta7') diff --git a/gotr/potr/compatcrypto/common.py b/gotr/potr/compatcrypto/common.py index 5d6af40..61f2bea 100644 --- a/gotr/potr/compatcrypto/common.py +++ b/gotr/potr/compatcrypto/common.py @@ -26,8 +26,8 @@ from potr.utils import human_hash, bytes_to_long, unpack, pack_mpi DEFAULT_KEYTYPE = 0x0000 pkTypes = {} def registerkeytype(cls): - if not hasattr(cls, 'parsePayload'): - raise TypeError('registered key types need parsePayload()') + if cls.keyType is None: + raise TypeError('registered key class needs a type value') pkTypes[cls.keyType] = cls return cls @@ -35,12 +35,16 @@ def generateDefaultKey(): return pkTypes[DEFAULT_KEYTYPE].generate() class PK(object): - __slots__ = [] + keyType = None @classmethod def generate(cls): raise NotImplementedError + @classmethod + def parsePayload(cls, data, private=False): + raise NotImplementedError + def sign(self, data): raise NotImplementedError def verify(self, data): @@ -80,13 +84,13 @@ class PK(object): @classmethod def parsePrivateKey(cls, data): implCls, data = cls.getImplementation(data) - logging.debug('Got privkey of type %r' % implCls) + logging.debug('Got privkey of type %r', implCls) return implCls.parsePayload(data, private=True) @classmethod def parsePublicKey(cls, data): implCls, data = cls.getImplementation(data) - logging.debug('Got pubkey of type %r' % implCls) + logging.debug('Got pubkey of type %r', implCls) return implCls.parsePayload(data) def __str__(self): diff --git a/gotr/potr/compatcrypto/pycrypto.py b/gotr/potr/compatcrypto/pycrypto.py index dd93295..2800431 100644 --- a/gotr/potr/compatcrypto/pycrypto.py +++ b/gotr/potr/compatcrypto/pycrypto.py @@ -15,18 +15,16 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from Crypto import Cipher, Random +from Crypto import Cipher from Crypto.Hash import SHA256 as _SHA256 -from Crypto.Hash import SHA as _SHA1 +from Crypto.Hash import SHA as _SHA1 from Crypto.Hash import HMAC as _HMAC from Crypto.PublicKey import DSA +from Crypto.Random import random from numbers import Number from potr.compatcrypto import common -from potr.utils import pack_mpi, read_mpi, bytes_to_long, long_to_bytes - -# XXX atfork? -RNG = Random.new() +from potr.utils import read_mpi, bytes_to_long, long_to_bytes def SHA256(data): return _SHA256.new(data).digest() @@ -54,7 +52,6 @@ def AESCTR(key, counter=0): return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter) class Counter(object): - __slots__ = ['prefix', 'val'] def __init__(self, prefix): self.prefix = prefix self.val = 0 @@ -72,17 +69,15 @@ class Counter(object): return ''.format(p=self.prefix, v=self.val) def byteprefix(self): - return long_to_bytes(self.prefix).rjust(8, b'\0') + return long_to_bytes(self.prefix, 8) def __call__(self): - val = long_to_bytes(self.val) - prefix = long_to_bytes(self.prefix) + bytesuffix = long_to_bytes(self.val, 8) self.val += 1 - return self.byteprefix() + val.rjust(8, b'\0') + return self.byteprefix() + bytesuffix @common.registerkeytype class DSAKey(common.PK): - __slots__ = ['priv', 'pub'] keyType = 0x0000 def __init__(self, key=None, private=False): @@ -111,10 +106,10 @@ class DSAKey(common.PK): return SHA1(self.getSerializedPublicPayload()) def sign(self, data): - # 2 <= K <= q = 160bit = 20 byte - K = bytes_to_long(RNG.read(19)) + 2 + # 2 <= K <= q + K = random.randrange(2, self.priv.q) r, s = self.priv.sign(data, K) - return long_to_bytes(r) + long_to_bytes(s) + return long_to_bytes(r, 20) + long_to_bytes(s, 20) def verify(self, data, sig): r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:]) diff --git a/gotr/potr/context.py b/gotr/potr/context.py index aa99f3a..3ec3d74 100644 --- a/gotr/potr/context.py +++ b/gotr/potr/context.py @@ -19,7 +19,7 @@ from __future__ import unicode_literals try: - basestring = basestring + type(basestring) except NameError: # all strings are unicode in python3k basestring = str @@ -27,7 +27,7 @@ except NameError: # callable is not available in python 3.0 and 3.1 try: - callable = callable + type(callable) except NameError: from collections import Callable def callable(x): @@ -42,6 +42,7 @@ logger = logging.getLogger(__name__) from potr import crypt from potr import proto +from potr import compatcrypto from time import time @@ -62,16 +63,11 @@ OFFER_REJECTED = 2 OFFER_ACCEPTED = 3 class Context(object): - __slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend', - 'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state', - 'inject', 'trust', 'peer', 'trustName'] - def __init__(self, account, peername): self.user = account self.peer = peername self.policy = {} self.crypto = crypt.CryptEngine(self) - self.discardFragment() self.tagOffer = OFFER_NOTSENT self.mayRetransmit = 0 self.lastSend = 0 @@ -79,6 +75,10 @@ class Context(object): self.state = STATE_PLAINTEXT self.trustName = self.peer + self.fragmentInfo = None + self.fragment = None + self.discardFragment() + def getPolicy(self, key): raise NotImplementedError @@ -100,13 +100,19 @@ class Context(object): params = message.split(b',') if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit(): logger.warning('invalid formed fragmented message: %r', params) - return None + self.discardFragment() + return message K, N = self.fragmentInfo + try: + k = int(params[1]) + n = int(params[2]) + except ValueError: + logger.warning('invalid formed fragmented message: %r', params) + self.discardFragment() + return message - k = int(params[1]) - n = int(params[2]) fragData = params[3] logger.debug(params) @@ -114,17 +120,17 @@ class Context(object): if n >= k == 1: # first fragment self.discardFragment() - self.fragmentInfo = (k,n) + self.fragmentInfo = (k, n) self.fragment.append(fragData) elif N == n >= k > 1 and k == K+1: # accumulate - self.fragmentInfo = (k,n) + self.fragmentInfo = (k, n) self.fragment.append(fragData) else: # bad, discard self.discardFragment() logger.warning('invalid fragmented message: %r', params) - return None + return message if n == k > 0: assembled = b''.join(self.fragment) @@ -210,7 +216,7 @@ class Context(object): if self.state != STATE_ENCRYPTED: self.sendInternal(proto.Error( 'You sent encrypted to {user}, who wasn\'t expecting it.' - .format(user=self.user.name)), appdata=appdata) + .format(user=self.user.name).encode('utf-8')), appdata=appdata) if ignore: return IGN raise NotEncryptedError(EXC_UNREADABLE_MESSAGE) @@ -263,12 +269,13 @@ class Context(object): return msg def processOutgoingMessage(self, msg, flags, tlvs=[]): - if isinstance(self.parse(msg), proto.Query): + isQuery = self.parseExplicitQuery(msg) is not None + if isQuery: return self.user.getDefaultQueryMessage(self.getPolicy) if self.state == STATE_PLAINTEXT: if self.getPolicy('REQUIRE_ENCRYPTION'): - if not isinstance(self.parse(msg), proto.Query): + if not isQuery: self.lastMessage = msg self.lastSend = time() self.mayRetransmit = 2 @@ -277,8 +284,12 @@ class Context(object): return msg if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED: self.tagOffer = OFFER_SENT - return proto.TaggedPlaintext(msg, self.getPolicy('ALLOW_V1'), - self.getPolicy('ALLOW_V2')) + versions = set() + if self.getPolicy('ALLOW_V1'): + versions.add(1) + if self.getPolicy('ALLOW_V2'): + versions.add(2) + return proto.TaggedPlaintext(msg, versions) return msg if self.state == STATE_ENCRYPTED: msg = self.crypto.createDataMessage(msg, flags, tlvs) @@ -304,9 +315,9 @@ class Context(object): def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None): mms = self.maxMessageSize(appdata) msgLen = len(msg) - if mms != 0 and len(msg) > mms: + if mms != 0 and msgLen > mms: fms = mms - 19 - fragments = [ msg[i:i+fms] for i in range(0, len(msg), fms) ] + fragments = [ msg[i:i+fms] for i in range(0, msgLen, fms) ] fc = len(fragments) @@ -375,9 +386,9 @@ class Context(object): self.crypto.smpSecret(secret, question=question, appdata=appdata) def handleQuery(self, message, appdata=None): - if message.v2 and self.getPolicy('ALLOW_V2'): + if 2 in message.versions and self.getPolicy('ALLOW_V2'): self.authStartV2(appdata=appdata) - elif message.v1 and self.getPolicy('ALLOW_V1'): + elif 1 in message.versions and self.getPolicy('ALLOW_V1'): self.authStartV1(appdata=appdata) def authStartV1(self, appdata=None): @@ -386,7 +397,33 @@ class Context(object): def authStartV2(self, appdata=None): self.crypto.startAKE(appdata=appdata) - def parse(self, message): + def parseExplicitQuery(self, message): + otrTagPos = message.find(proto.OTRTAG) + + if otrTagPos == -1: + return None + + indexBase = otrTagPos + len(proto.OTRTAG) + + if len(message) <= indexBase: + return None + + compare = message[indexBase] + + hasq = compare == b'?'[0] + hasv = compare == b'v'[0] + + if not hasq and not hasv: + return None + + hasv |= len(message) > indexBase+1 and message[indexBase+1] == b'v'[0] + if hasv: + end = message.find(b'?', indexBase+1) + else: + end = indexBase+1 + return message[indexBase:end] + + def parse(self, message, nofragment=False): otrTagPos = message.find(proto.OTRTAG) if otrTagPos == -1: if proto.MESSAGE_TAG_BASE in message: @@ -395,38 +432,40 @@ class Context(object): return message indexBase = otrTagPos + len(proto.OTRTAG) + + if len(message) <= indexBase: + return message + compare = message[indexBase] - if compare == b','[0]: + if nofragment is False and compare == b','[0]: message = self.fragmentAccumulate(message[indexBase:]) if message is None: return None else: - return self.parse(message) + return self.parse(message, nofragment=True) else: self.discardFragment() - hasq = compare == b'?'[0] - hasv = compare == b'v'[0] - if hasq or hasv: - hasv |= len(message) > indexBase+1 and \ - message[indexBase+1] == b'v'[0] - if hasv: - end = message.find(b'?', indexBase+1) - else: - end = indexBase+1 - payload = message[indexBase:end] - return proto.Query.parse(payload) + queryPayload = self.parseExplicitQuery(message) + if queryPayload is not None: + return proto.Query.parse(queryPayload) if compare == b':'[0] and len(message) > indexBase + 4: - infoTag = base64.b64decode(message[indexBase+1:indexBase+5]) - classInfo = struct.unpack(b'!HB', infoTag) - cls = proto.messageClasses.get(classInfo, None) - if cls is None: + try: + infoTag = base64.b64decode(message[indexBase+1:indexBase+5]) + classInfo = struct.unpack(b'!HB', infoTag) + + cls = proto.messageClasses.get(classInfo, None) + if cls is None: + return message + + logger.debug('{user} got msg {typ!r}' \ + .format(user=self.user.name, typ=cls)) + return cls.parsePayload(message[indexBase+5:]) + except (TypeError, struct.error): + logger.exception('could not parse OTR message %s', message) return message - logger.debug('{user} got msg {typ!r}' \ - .format(user=self.user.name, typ=cls)) - return cls.parsePayload(message[indexBase+5:]) if message[indexBase:indexBase+7] == b' Error:': return proto.Error(message[indexBase+7:]) @@ -437,6 +476,22 @@ class Context(object): """Return the max message size for this context.""" return self.user.maxMessageSize + def getExtraKey(self, extraKeyAppId=None, extraKeyAppData=None, appdata=None): + """ retrieves the generated extra symmetric key. + + if extraKeyAppId is set, notifies the chat partner about intended + usage (additional application specific information can be supplied in + extraKeyAppData). + + returns the 256 bit symmetric key """ + + if self.state != STATE_ENCRYPTED: + raise NotEncryptedError + if extraKeyAppId is not None: + tlvs = [proto.ExtraKeyTLV(extraKeyAppId, extraKeyAppData)] + self.sendInternal(b'', tlvs=tlvs, appdata=appdata) + return self.crypto.extraKey + class Account(object): contextclass = Context def __init__(self, name, protocol, maxMessageSize, privkey=None): @@ -447,10 +502,10 @@ class Account(object): self.ctxs = {} self.trusts = {} self.maxMessageSize = maxMessageSize - self.defaultQuery = b'?OTRv{versions}?\n{accountname} has requested ' \ - b'an Off-the-Record private conversation. However, you ' \ - b'do not have a plugin to support that.\nSee '\ - b'http://otr.cypherpunks.ca/ for more information.'; + self.defaultQuery = '?OTRv{versions}?\n{accountname} has requested ' \ + 'an Off-the-Record private conversation. However, you ' \ + 'do not have a plugin to support that.\nSee '\ + 'http://otr.cypherpunks.ca/ for more information.' def __repr__(self): return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__, @@ -461,7 +516,7 @@ class Account(object): self.privkey = self.loadPrivkey() if self.privkey is None: if autogen is True: - self.privkey = crypt.generateDefaultKey() + self.privkey = compatcrypto.generateDefaultKey() self.savePrivkey() else: raise LookupError @@ -484,8 +539,9 @@ class Account(object): return self.ctxs[uid] def getDefaultQueryMessage(self, policy): - v = b'2' if policy('ALLOW_V2') else b'' - return self.defaultQuery.format(accountname=self.name, versions=v) + v = '2' if policy('ALLOW_V2') else '' + msg = self.defaultQuery.format(accountname=self.name, versions=v) + return msg.encode('ascii') def setTrust(self, key, fingerprint, trustLevel): if key not in self.trusts: diff --git a/gotr/potr/crypt.py b/gotr/potr/crypt.py index ad5d663..3a4bb4b 100644 --- a/gotr/potr/crypt.py +++ b/gotr/potr/crypt.py @@ -22,8 +22,8 @@ import logging import struct -from potr.compatcrypto import SHA256, SHA1, HMAC, SHA1HMAC, SHA256HMAC, \ - SHA256HMAC160, Counter, AESCTR, RNG, PK, generateDefaultKey +from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \ + SHA256HMAC160, Counter, AESCTR, PK, random from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi from potr import proto @@ -36,32 +36,31 @@ STATE_AWAITING_SIG = 4 STATE_V1_SETUP = 5 -DH1536_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919 -DH1536_MODULUS_2 = DH1536_MODULUS-2 -DH1536_GENERATOR = 2 -SM_ORDER = (DH1536_MODULUS - 1) // 2 +DH_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919 +DH_MODULUS_2 = DH_MODULUS-2 +DH_GENERATOR = 2 +DH_BITS = 1536 +DH_MAX = 2**DH_BITS +SM_ORDER = (DH_MODULUS - 1) // 2 def check_group(n): - return 2 <= n <= DH1536_MODULUS_2 + return 2 <= n <= DH_MODULUS_2 def check_exp(n): return 1 <= n < SM_ORDER class DH(object): - __slots__ = ['priv', 'pub'] @classmethod def set_params(cls, prime, gen): cls.prime = prime cls.gen = gen def __init__(self): - self.priv = bytes_to_long(RNG.read(40)) + self.priv = random.randrange(2, 2**320) self.pub = pow(self.gen, self.priv, self.prime) -DH.set_params(DH1536_MODULUS, DH1536_GENERATOR) +DH.set_params(DH_MODULUS, DH_GENERATOR) class DHSession(object): - __slots__ = ['sendenc', 'sendmac', 'rcvenc', 'rcvmac', 'sendctr', 'rcvctr', - 'sendmacused', 'rcvmacused'] def __init__(self, sendenc, sendmac, rcvenc, rcvmac): self.sendenc = sendenc self.sendmac = sendmac @@ -79,7 +78,7 @@ class DHSession(object): @classmethod def create(cls, dh, y): - s = pow(y, dh.priv, DH1536_MODULUS) + s = pow(y, dh.priv, DH_MODULUS) sb = pack_mpi(s) if dh.pub > y: @@ -96,9 +95,6 @@ class DHSession(object): return cls(sendenc, sendmac, rcvenc, rcvmac) class CryptEngine(object): - __slots__ = ['ctx', 'ake', 'sessionId', 'sessionIdHalf', 'theirKeyid', - 'theirY', 'theirOldY', 'ourOldDHKey', 'ourDHKey', 'ourKeyid', - 'sessionkeys', 'theirPubkey', 'savedMacKeys', 'smp'] def __init__(self, ctx): self.ctx = ctx self.ake = None @@ -118,6 +114,7 @@ class CryptEngine(object): self.savedMacKeys = [] self.smp = None + self.extraKey = None def revealMacs(self, ours=True): if ours: @@ -174,7 +171,7 @@ class CryptEngine(object): if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()): logger.error('HMACs don\'t match') raise InvalidParameterError - sesskey.rcvmacused = 1 + sesskey.rcvmacused = True newCtrPrefix = bytes_to_long(msg.ctr) if newCtrPrefix <= sesskey.rcvctr.prefix: @@ -223,11 +220,14 @@ class CryptEngine(object): self.smp = SMPHandler(self) self.smp.abort(appdata=appdata) - def createDataMessage(self, message, flags=0, tlvs=[]): + def createDataMessage(self, message, flags=0, tlvs=None): # check MSGSTATE if self.theirKeyid == 0: raise InvalidParameterError + if tlvs is None: + tlvs = [] + sess = self.sessionkeys[1][0] sess.sendctr.inc() @@ -303,13 +303,16 @@ class CryptEngine(object): self.ourKeyid = ake.ourKeyid self.theirY = ake.gy self.theirOldY = None + self.extraKey = ake.extraKey if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub: - # XXX is this really ok? self.ourDHKey = ake.dh self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY) self.rotateDHKeys() + # we don't need the AKE anymore, free the reference + self.ake = None + self.ctx._wentEncrypted() logger.info('went encrypted with {0}'.format(self.theirPubkey)) @@ -317,10 +320,6 @@ class CryptEngine(object): self.smp = None class AuthKeyExchange(object): - __slots__ = ['privkey', 'state', 'r', 'encgx', 'hashgx', 'ourKeyid', - 'theirPubkey', 'theirKeyid', 'enc_c', 'enc_cp', 'mac_m1', - 'mac_m1p', 'mac_m2', 'mac_m2p', 'sessionId', 'dh', 'onSuccess', - 'gy', 'lastmsg', 'sessionIdHalf'] def __init__(self, privkey, onSuccess): self.privkey = privkey self.state = STATE_NONE @@ -341,9 +340,11 @@ class AuthKeyExchange(object): self.dh = DH() self.onSuccess = onSuccess self.gy = None + self.extraKey = None + self.lastmsg = None def startAKE(self): - self.r = RNG.read(16) + self.r = long_to_bytes(random.getrandbits(128)) gxmpi = pack_mpi(self.dh.pub) @@ -444,15 +445,17 @@ class AuthKeyExchange(object): self.state = STATE_NONE def createAuthKeys(self): - s = pow(self.gy, self.dh.priv, DH1536_MODULUS) + s = pow(self.gy, self.dh.priv, DH_MODULUS) sbyte = pack_mpi(s) - self.sessionId = SHA256(b'\0' + sbyte)[:8] - enc = SHA256(b'\1' + sbyte) - self.enc_c, self.enc_cp = enc[:16], enc[16:] - self.mac_m1 = SHA256(b'\2' + sbyte) - self.mac_m2 = SHA256(b'\3' + sbyte) - self.mac_m1p = SHA256(b'\4' + sbyte) - self.mac_m2p = SHA256(b'\5' + sbyte) + self.sessionId = SHA256(b'\x00' + sbyte)[:8] + enc = SHA256(b'\x01' + sbyte) + self.enc_c = enc[:16] + self.enc_cp = enc[16:] + self.mac_m1 = SHA256(b'\x02' + sbyte) + self.mac_m2 = SHA256(b'\x03' + sbyte) + self.mac_m1p = SHA256(b'\x04' + sbyte) + self.mac_m2p = SHA256(b'\x05' + sbyte) + self.extraKey = SHA256(b'\xff' + sbyte) def calculatePubkeyAuth(self, key, mackey): pubkey = self.privkey.serializePublicKey() @@ -490,14 +493,15 @@ SMPPROG_FAILED = -1 SMPPROG_SUCCEEDED = 1 class SMPHandler: - __slots__ = ['crypto', 'questionReceived', 'prog', 'state', 'g1', 'g3o', - 'x2', 'x3', 'g2', 'g3', 'pab', 'qab', 'secret', 'p', 'q'] - def __init__(self, crypto): self.crypto = crypto self.state = 1 - self.g1 = DH1536_GENERATOR + self.g1 = DH_GENERATOR + self.g2 = None + self.g3 = None self.g3o = None + self.x2 = None + self.x3 = None self.prog = SMPPROG_OK self.pab = None self.qab = None @@ -539,11 +543,11 @@ class SMPHandler: self.g3o = msg[3] - self.x2 = bytes_to_long(RNG.read(192)) - self.x3 = bytes_to_long(RNG.read(192)) + self.x2 = random.randrange(2, DH_MAX) + self.x3 = random.randrange(2, DH_MAX) - self.g2 = pow(msg[0], self.x2, DH1536_MODULUS) - self.g3 = pow(msg[3], self.x3, DH1536_MODULUS) + self.g2 = pow(msg[0], self.x2, DH_MODULUS) + self.g3 = pow(msg[3], self.x3, DH_MODULUS) self.prog = SMPPROG_OK self.state = 0 @@ -568,29 +572,29 @@ class SMPHandler: return self.g3o = msg[3] - self.g2 = pow(msg[0], self.x2, DH1536_MODULUS) - self.g3 = pow(msg[3], self.x3, DH1536_MODULUS) + self.g2 = pow(msg[0], self.x2, DH_MODULUS) + self.g3 = pow(msg[3], self.x3, DH_MODULUS) if not self.check_equal_coords(msg[6:11], 5): logger.error('invalid SMP2TLV received') self.abort(appdata=appdata) return - r = bytes_to_long(RNG.read(192)) - self.p = pow(self.g3, r, DH1536_MODULUS) + r = random.randrange(2, DH_MAX) + self.p = pow(self.g3, r, DH_MODULUS) msg = [self.p] - qa1 = pow(self.g1, r, DH1536_MODULUS) - qa2 = pow(self.g2, self.secret, DH1536_MODULUS) - self.q = qa1*qa2 % DH1536_MODULUS + qa1 = pow(self.g1, r, DH_MODULUS) + qa2 = pow(self.g2, self.secret, DH_MODULUS) + self.q = qa1*qa2 % DH_MODULUS msg.append(self.q) msg += self.proof_equal_coords(r, 6) inv = invMod(mp) - self.pab = self.p * inv % DH1536_MODULUS + self.pab = self.p * inv % DH_MODULUS inv = invMod(mq) - self.qab = self.q * inv % DH1536_MODULUS + self.qab = self.q * inv % DH_MODULUS - msg.append(pow(self.qab, self.x3, DH1536_MODULUS)) + msg.append(pow(self.qab, self.x3, DH_MODULUS)) msg += self.proof_equal_logs(7) self.state = 4 @@ -613,9 +617,9 @@ class SMPHandler: return inv = invMod(self.p) - self.pab = msg[0] * inv % DH1536_MODULUS + self.pab = msg[0] * inv % DH_MODULUS inv = invMod(self.q) - self.qab = msg[1] * inv % DH1536_MODULUS + self.qab = msg[1] * inv % DH_MODULUS if not self.check_equal_logs(msg[5:8], 7): logger.error('invalid SMP3TLV received') @@ -623,10 +627,10 @@ class SMPHandler: return md = msg[5] - msg = [pow(self.qab, self.x3, DH1536_MODULUS)] + msg = [pow(self.qab, self.x3, DH_MODULUS)] msg += self.proof_equal_logs(8) - rab = pow(md, self.x3, DH1536_MODULUS) + rab = pow(md, self.x3, DH_MODULUS) self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED if self.prog != SMPPROG_SUCCEEDED: @@ -654,7 +658,7 @@ class SMPHandler: self.abort(appdata=appdata) return - rab = pow(msg[0], self.x3, DH1536_MODULUS) + rab = pow(msg[0], self.x3, DH_MODULUS) self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED @@ -679,12 +683,12 @@ class SMPHandler: self.secret = bytes_to_long(combSecret) - self.x2 = bytes_to_long(RNG.read(192)) - self.x3 = bytes_to_long(RNG.read(192)) + self.x2 = random.randrange(2, DH_MAX) + self.x3 = random.randrange(2, DH_MAX) - msg = [pow(self.g1, self.x2, DH1536_MODULUS)] + msg = [pow(self.g1, self.x2, DH_MODULUS)] msg += proof_known_log(self.g1, self.x2, 1) - msg.append(pow(self.g1, self.x3, DH1536_MODULUS)) + msg.append(pow(self.g1, self.x3, DH_MODULUS)) msg += proof_known_log(self.g1, self.x3, 2) self.prog = SMPPROG_OK @@ -700,19 +704,19 @@ class SMPHandler: self.secret = bytes_to_long(combSecret) - msg = [pow(self.g1, self.x2, DH1536_MODULUS)] + msg = [pow(self.g1, self.x2, DH_MODULUS)] msg += proof_known_log(self.g1, self.x2, 3) - msg.append(pow(self.g1, self.x3, DH1536_MODULUS)) + msg.append(pow(self.g1, self.x3, DH_MODULUS)) msg += proof_known_log(self.g1, self.x3, 4) - r = bytes_to_long(RNG.read(192)) + r = random.randrange(2, DH_MAX) - self.p = pow(self.g3, r, DH1536_MODULUS) + self.p = pow(self.g3, r, DH_MODULUS) msg.append(self.p) - qb1 = pow(self.g1, r, DH1536_MODULUS) - qb2 = pow(self.g2, self.secret, DH1536_MODULUS) - self.q = qb1 * qb2 % DH1536_MODULUS + qb1 = pow(self.g1, r, DH_MODULUS) + qb2 = pow(self.g2, self.secret, DH_MODULUS) + self.q = qb1 * qb2 % DH_MODULUS msg.append(self.q) msg += self.proof_equal_coords(r, 5) @@ -721,11 +725,11 @@ class SMPHandler: self.sendTLV(proto.SMP2TLV(msg), appdata=appdata) def proof_equal_coords(self, r, v): - r1 = bytes_to_long(RNG.read(192)) - r2 = bytes_to_long(RNG.read(192)) - temp2 = pow(self.g1, r1, DH1536_MODULUS) \ - * pow(self.g2, r2, DH1536_MODULUS) % DH1536_MODULUS - temp1 = pow(self.g3, r1, DH1536_MODULUS) + r1 = random.randrange(2, DH_MAX) + r2 = random.randrange(2, DH_MAX) + temp2 = pow(self.g1, r1, DH_MODULUS) \ + * pow(self.g2, r2, DH_MODULUS) % DH_MODULUS + temp1 = pow(self.g3, r1, DH_MODULUS) cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) c = bytes_to_long(cb) @@ -739,21 +743,21 @@ class SMPHandler: def check_equal_coords(self, coords, v): (p, q, c, d1, d2) = coords - temp1 = pow(self.g3, d1, DH1536_MODULUS) * pow(p, c, DH1536_MODULUS) \ - % DH1536_MODULUS + temp1 = pow(self.g3, d1, DH_MODULUS) * pow(p, c, DH_MODULUS) \ + % DH_MODULUS - temp2 = pow(self.g1, d1, DH1536_MODULUS) \ - * pow(self.g2, d2, DH1536_MODULUS) \ - * pow(q, c, DH1536_MODULUS) % DH1536_MODULUS + temp2 = pow(self.g1, d1, DH_MODULUS) \ + * pow(self.g2, d2, DH_MODULUS) \ + * pow(q, c, DH_MODULUS) % DH_MODULUS cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) - return long_to_bytes(c) == cprime + return long_to_bytes(c, 32) == cprime def proof_equal_logs(self, v): - r = bytes_to_long(RNG.read(192)) - temp1 = pow(self.g1, r, DH1536_MODULUS) - temp2 = pow(self.qab, r, DH1536_MODULUS) + r = random.randrange(2, DH_MAX) + temp1 = pow(self.g1, r, DH_MODULUS) + temp2 = pow(self.qab, r, DH_MODULUS) cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) c = bytes_to_long(cb) @@ -763,29 +767,29 @@ class SMPHandler: def check_equal_logs(self, logs, v): (r, c, d) = logs - temp1 = pow(self.g1, d, DH1536_MODULUS) \ - * pow(self.g3o, c, DH1536_MODULUS) % DH1536_MODULUS + temp1 = pow(self.g1, d, DH_MODULUS) \ + * pow(self.g3o, c, DH_MODULUS) % DH_MODULUS - temp2 = pow(self.qab, d, DH1536_MODULUS) \ - * pow(r, c, DH1536_MODULUS) % DH1536_MODULUS + temp2 = pow(self.qab, d, DH_MODULUS) \ + * pow(r, c, DH_MODULUS) % DH_MODULUS cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) - return long_to_bytes(c) == cprime + return long_to_bytes(c, 32) == cprime def proof_known_log(g, x, v): - r = bytes_to_long(RNG.read(192)) - c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH1536_MODULUS)))) + r = random.randrange(2, DH_MAX) + c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS)))) temp = x * c % SM_ORDER return c, (r-temp) % SM_ORDER def check_known_log(c, d, g, x, v): - gd = pow(g, d, DH1536_MODULUS) - xc = pow(x, c, DH1536_MODULUS) - gdxc = gd * xc % DH1536_MODULUS - return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c) + gd = pow(g, d, DH_MODULUS) + xc = pow(x, c, DH_MODULUS) + gdxc = gd * xc % DH_MODULUS + return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c, 32) def invMod(n): - return pow(n, DH1536_MODULUS_2, DH1536_MODULUS) + return pow(n, DH_MODULUS_2, DH_MODULUS) class InvalidParameterError(RuntimeError): pass diff --git a/gotr/potr/proto.py b/gotr/potr/proto.py index 745a53d..91904a8 100644 --- a/gotr/potr/proto.py +++ b/gotr/potr/proto.py @@ -19,14 +19,16 @@ from __future__ import unicode_literals import base64 -import logging import struct from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack OTRTAG = b'?OTR' MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t ' -MESSAGE_TAG_V1 = b' \t \t \t ' -MESSAGE_TAG_V2 = b' \t\t \t ' +MESSAGE_TAGS = { + 1:b' \t \t \t ', + 2:b' \t\t \t ', + 3:b' \t\t \t\t', + } MSGTYPE_NOTOTR = 0 MSGTYPE_TAGGEDPLAINTEXT = 1 @@ -62,6 +64,8 @@ def registermessage(cls): def registertlv(cls): if not hasattr(cls, 'parsePayload'): raise TypeError('registered tlv types need parsePayload()') + if cls.typ is None: + raise TypeError('registered tlv type needs type ID') tlvClasses[cls.typ] = cls return cls @@ -84,16 +88,6 @@ class OTRMessage(object): __slots__ = ['payload'] version = 0x0002 msgtype = 0 - def __init__(self, payload): - self.payload = payload - - def getPayload(self): - return self.payload - - def __bytes__(self): - data = struct.pack(b'!HB', self.version, self.msgtype) \ - + self.getPayload() - return b'?OTR:' + base64.b64encode(data) + b'.' def __eq__(self, other): if not isinstance(other, self.__class__): @@ -110,6 +104,7 @@ class OTRMessage(object): class Error(OTRMessage): __slots__ = ['error'] def __init__(self, error): + super(Error, self).__init__() self.error = error def __repr__(self): @@ -119,56 +114,58 @@ class Error(OTRMessage): return b'?OTR Error:' + self.error class Query(OTRMessage): - __slots__ = ['v1', 'v2'] - def __init__(self, v1, v2): - self.v1 = v1 - self.v2 = v2 + __slots__ = ['versions'] + def __init__(self, versions=set()): + super(Query, self).__init__() + self.versions = versions @classmethod def parse(cls, data): - v2 = False - v1 = False - if len(data) > 0 and data[0:1] == b'?': - data = data[1:] - v1 = True - - if len(data) > 0 and data[0:1] == b'v': - for c in data[1:]: - if c == b'2'[0]: - v2 = True - return cls(v1, v2) + if not isinstance(data, bytes): + raise TypeError('can only parse bytes') + udata = data.decode('ascii', errors='replace') + + versions = set() + if len(udata) > 0 and udata[0] == '?': + udata = udata[1:] + versions.add(1) + + if len(udata) > 0 and udata[0] == 'v': + versions.update(( int(c) for c in udata if c.isdigit() )) + return cls(versions) def __repr__(self): - return ''%(self.v1,self.v2) + return '' % (self.versions) def __bytes__(self): d = b'?OTR' - if self.v1: + if 1 in self.versions: d += b'?' d += b'v' - if self.v2: - d += b'2' + + # in python3 there is only int->unicode conversion + # so I convert to unicode and encode it to a byte string + versions = [ '%d' % v for v in self.versions if v != 1 ] + d += ''.join(versions).encode('ascii') + d += b'?' return d class TaggedPlaintext(Query): __slots__ = ['msg'] - def __init__(self, msg, v1, v2): + def __init__(self, msg, versions): + super(TaggedPlaintext, self).__init__(versions) self.msg = msg - self.v1 = v1 - self.v2 = v2 def __bytes__(self): data = self.msg + MESSAGE_TAG_BASE - if self.v1: - data += MESSAGE_TAG_V1 - if self.v2: - data += MESSAGE_TAG_V2 + for v in self.versions: + data += MESSAGE_TAGS[v] return data def __repr__(self): - return '' \ - .format(v1=self.v1, v2=self.v2, msg=self.msg) + return '' \ + .format(versions=self.versions, msg=self.msg) @classmethod def parse(cls, data): @@ -177,21 +174,18 @@ class TaggedPlaintext(Query): raise TypeError( 'this is not a tagged plaintext ({0!r:.20})'.format(data)) - v1 = False - v2 = False - tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ] - for tag in tags: - if not tag.isspace(): - break - v1 |= tag == MESSAGE_TAG_V1 - v2 |= tag == MESSAGE_TAG_V2 + versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag + in tags ]) - return TaggedPlaintext(data[:tagPos], v1, v2) + return TaggedPlaintext(data[:tagPos], versions) class GenericOTRMessage(OTRMessage): __slots__ = ['data'] + fields = [] + def __init__(self, *args): + super(GenericOTRMessage, self).__init__() if len(args) != len(self.fields): raise TypeError('%s needs %d arguments, got %d' % (self.__class__.__name__, len(self.fields), len(args))) @@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage): self.__getattr__(attr) # existence check self.data[attr] = val + def __bytes__(self): + data = struct.pack(b'!HB', self.version, self.msgtype) \ + + self.getPayload() + return b'?OTR:' + base64.b64encode(data) + b'.' + def __repr__(self): name = self.__class__.__name__ data = '' @@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage): def parsePayload(cls, data): data = base64.b64decode(data) args = [] - for k, ftype in cls.fields: + for _, ftype in cls.fields: if ftype == 'data': value, data = read_data(data) elif isinstance(ftype, bytes): - size = int(struct.calcsize(ftype)) value, data = unpack(ftype, data) elif isinstance(ftype, int): value, data = data[:ftype], data[ftype:] @@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage): class AKEMessage(GenericOTRMessage): __slots__ = [] - pass @registermessage class DHCommit(AKEMessage): __slots__ = [] msgtype = 0x02 - fields = [('encgx','data'), ('hashgx','data'), ] - + fields = [('encgx', 'data'), ('hashgx', 'data'), ] @registermessage class DHKey(AKEMessage): __slots__ = [] msgtype = 0x0a - fields = [('gy','data'), ] + fields = [('gy', 'data'), ] @registermessage class RevealSig(AKEMessage): __slots__ = [] msgtype = 0x11 - fields = [('rkey','data'), ('encsig','data'), ('mac',20),] + fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),] def getMacedData(self): p = self.encsig @@ -280,7 +276,7 @@ class RevealSig(AKEMessage): class Signature(AKEMessage): __slots__ = [] msgtype = 0x12 - fields = [('encsig','data'), ('mac',20)] + fields = [('encsig', 'data'), ('mac', 20)] def getMacedData(self): p = self.encsig @@ -290,8 +286,9 @@ class Signature(AKEMessage): class DataMessage(GenericOTRMessage): __slots__ = [] msgtype = 0x03 - fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'), - ('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ] + fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'), + ('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20), + ('oldmacs', 'data'), ] def getMacedData(self): return struct.pack(b'!HB', self.version, self.msgtype) + \ @@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage): @bytesAndStrings class TLV(object): __slots__ = [] + typ = None + + def getPayload(self): + raise NotImplementedError def __repr__(self): val = self.getPayload() @@ -330,11 +331,28 @@ class TLV(object): def __neq__(self, other): return not self.__eq__(other) +@registertlv +class PaddingTLV(TLV): + typ = 0 + + __slots__ = ['padding'] + + def __init__(self, padding): + super(PaddingTLV, self).__init__() + self.padding = padding + + def getPayload(self): + return self.padding + + @classmethod + def parsePayload(cls, data): + return cls(data) + @registertlv class DisconnectTLV(TLV): typ = 1 def __init__(self): - pass + super(DisconnectTLV, self).__init__() def getPayload(self): return b'' @@ -348,8 +366,14 @@ class DisconnectTLV(TLV): class SMPTLV(TLV): __slots__ = ['mpis'] - - def __init__(self, mpis=[]): + dlen = None + + def __init__(self, mpis=None): + super(SMPTLV, self).__init__() + if mpis is None: + mpis = [] + if self.dlen is None: + raise TypeError('no amount of mpis specified in dlen') if len(mpis) != self.dlen: raise TypeError('expected {0} mpis, got {1}' .format(self.dlen, len(mpis))) @@ -366,7 +390,7 @@ class SMPTLV(TLV): mpis = [] if cls.dlen > 0: count, data = unpack(b'!I', data) - for i in range(count): + for _ in range(count): n, data = read_mpi(data) mpis.append(n) if len(data) > 0: @@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV): def getPayload(self): return b'' + +@registertlv +class ExtraKeyTLV(TLV): + typ = 8 + + __slots__ = ['appid', 'appdata'] + + def __init__(self, appid, appdata): + super(ExtraKeyTLV, self).__init__() + self.appid = appid + self.appdata = appdata + if appdata is None: + self.appdata = b'' + + def getPayload(self): + return self.appid + self.appdata + + @classmethod + def parsePayload(cls, data): + return cls(data[:4], data[4:]) diff --git a/gotr/potr/utils.py b/gotr/potr/utils.py index 2bedf55..e41ca46 100644 --- a/gotr/potr/utils.py +++ b/gotr/potr/utils.py @@ -43,11 +43,12 @@ def bytes_to_long(b): s += byte_to_long(b[i:i+1]) << 8*(l-i-1) return s -def long_to_bytes(l): +def long_to_bytes(l, n=0): b = b'' - while l != 0: + while l != 0 or n > 0: b = long_to_byte(l & 0xff) + b l >>= 8 + n -= 1 return b def byte_to_long(b): -- cgit v1.2.3