Welcome to mirror list, hosted at ThFree Co, Russian Federation.

dev.gajim.org/gajim/gajim-plugins.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/gotr
diff options
context:
space:
mode:
authorKjell Braden <afflux.gajim@pentabarf.de>2013-09-22 19:04:49 +0400
committerKjell Braden <afflux.gajim@pentabarf.de>2013-09-22 19:04:49 +0400
commitd11d4e93d0ec312f74cd7559e153e3d8d07c8546 (patch)
tree5e928cd9770d42b542a2a19948db51bc6f456699 /gotr
parentf4807426cd3a1235053b3ba717fd030b610635d6 (diff)
gotr: update provided potr to 1.0.0beta7
Diffstat (limited to 'gotr')
-rw-r--r--gotr/potr/__init__.py2
-rw-r--r--gotr/potr/compatcrypto/common.py14
-rw-r--r--gotr/potr/compatcrypto/pycrypto.py25
-rw-r--r--gotr/potr/context.py156
-rw-r--r--gotr/potr/crypt.py188
-rw-r--r--gotr/potr/proto.py172
-rw-r--r--gotr/potr/utils.py5
7 files changed, 333 insertions, 229 deletions
diff --git a/gotr/potr/__init__.py b/gotr/potr/__init__.py
index 965aed2..c6e8d02 100644
--- a/gotr/potr/__init__.py
+++ b/gotr/potr/__init__.py
@@ -24,4 +24,4 @@ from potr.utils import human_hash
''' version is: (major, minor, patch, sub) with sub being one of 'alpha',
'beta', 'final' '''
-VERSION = (1, 0, 0, 'beta5')
+VERSION = (1, 0, 0, 'beta7')
diff --git a/gotr/potr/compatcrypto/common.py b/gotr/potr/compatcrypto/common.py
index 5d6af40..61f2bea 100644
--- a/gotr/potr/compatcrypto/common.py
+++ b/gotr/potr/compatcrypto/common.py
@@ -26,8 +26,8 @@ from potr.utils import human_hash, bytes_to_long, unpack, pack_mpi
DEFAULT_KEYTYPE = 0x0000
pkTypes = {}
def registerkeytype(cls):
- if not hasattr(cls, 'parsePayload'):
- raise TypeError('registered key types need parsePayload()')
+ if cls.keyType is None:
+ raise TypeError('registered key class needs a type value')
pkTypes[cls.keyType] = cls
return cls
@@ -35,12 +35,16 @@ def generateDefaultKey():
return pkTypes[DEFAULT_KEYTYPE].generate()
class PK(object):
- __slots__ = []
+ keyType = None
@classmethod
def generate(cls):
raise NotImplementedError
+ @classmethod
+ def parsePayload(cls, data, private=False):
+ raise NotImplementedError
+
def sign(self, data):
raise NotImplementedError
def verify(self, data):
@@ -80,13 +84,13 @@ class PK(object):
@classmethod
def parsePrivateKey(cls, data):
implCls, data = cls.getImplementation(data)
- logging.debug('Got privkey of type %r' % implCls)
+ logging.debug('Got privkey of type %r', implCls)
return implCls.parsePayload(data, private=True)
@classmethod
def parsePublicKey(cls, data):
implCls, data = cls.getImplementation(data)
- logging.debug('Got pubkey of type %r' % implCls)
+ logging.debug('Got pubkey of type %r', implCls)
return implCls.parsePayload(data)
def __str__(self):
diff --git a/gotr/potr/compatcrypto/pycrypto.py b/gotr/potr/compatcrypto/pycrypto.py
index dd93295..2800431 100644
--- a/gotr/potr/compatcrypto/pycrypto.py
+++ b/gotr/potr/compatcrypto/pycrypto.py
@@ -15,18 +15,16 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
-from Crypto import Cipher, Random
+from Crypto import Cipher
from Crypto.Hash import SHA256 as _SHA256
-from Crypto.Hash import SHA as _SHA1
+from Crypto.Hash import SHA as _SHA1
from Crypto.Hash import HMAC as _HMAC
from Crypto.PublicKey import DSA
+from Crypto.Random import random
from numbers import Number
from potr.compatcrypto import common
-from potr.utils import pack_mpi, read_mpi, bytes_to_long, long_to_bytes
-
-# XXX atfork?
-RNG = Random.new()
+from potr.utils import read_mpi, bytes_to_long, long_to_bytes
def SHA256(data):
return _SHA256.new(data).digest()
@@ -54,7 +52,6 @@ def AESCTR(key, counter=0):
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
class Counter(object):
- __slots__ = ['prefix', 'val']
def __init__(self, prefix):
self.prefix = prefix
self.val = 0
@@ -72,17 +69,15 @@ class Counter(object):
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
def byteprefix(self):
- return long_to_bytes(self.prefix).rjust(8, b'\0')
+ return long_to_bytes(self.prefix, 8)
def __call__(self):
- val = long_to_bytes(self.val)
- prefix = long_to_bytes(self.prefix)
+ bytesuffix = long_to_bytes(self.val, 8)
self.val += 1
- return self.byteprefix() + val.rjust(8, b'\0')
+ return self.byteprefix() + bytesuffix
@common.registerkeytype
class DSAKey(common.PK):
- __slots__ = ['priv', 'pub']
keyType = 0x0000
def __init__(self, key=None, private=False):
@@ -111,10 +106,10 @@ class DSAKey(common.PK):
return SHA1(self.getSerializedPublicPayload())
def sign(self, data):
- # 2 <= K <= q = 160bit = 20 byte
- K = bytes_to_long(RNG.read(19)) + 2
+ # 2 <= K <= q
+ K = random.randrange(2, self.priv.q)
r, s = self.priv.sign(data, K)
- return long_to_bytes(r) + long_to_bytes(s)
+ return long_to_bytes(r, 20) + long_to_bytes(s, 20)
def verify(self, data, sig):
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])
diff --git a/gotr/potr/context.py b/gotr/potr/context.py
index aa99f3a..3ec3d74 100644
--- a/gotr/potr/context.py
+++ b/gotr/potr/context.py
@@ -19,7 +19,7 @@
from __future__ import unicode_literals
try:
- basestring = basestring
+ type(basestring)
except NameError:
# all strings are unicode in python3k
basestring = str
@@ -27,7 +27,7 @@ except NameError:
# callable is not available in python 3.0 and 3.1
try:
- callable = callable
+ type(callable)
except NameError:
from collections import Callable
def callable(x):
@@ -42,6 +42,7 @@ logger = logging.getLogger(__name__)
from potr import crypt
from potr import proto
+from potr import compatcrypto
from time import time
@@ -62,16 +63,11 @@ OFFER_REJECTED = 2
OFFER_ACCEPTED = 3
class Context(object):
- __slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend',
- 'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state',
- 'inject', 'trust', 'peer', 'trustName']
-
def __init__(self, account, peername):
self.user = account
self.peer = peername
self.policy = {}
self.crypto = crypt.CryptEngine(self)
- self.discardFragment()
self.tagOffer = OFFER_NOTSENT
self.mayRetransmit = 0
self.lastSend = 0
@@ -79,6 +75,10 @@ class Context(object):
self.state = STATE_PLAINTEXT
self.trustName = self.peer
+ self.fragmentInfo = None
+ self.fragment = None
+ self.discardFragment()
+
def getPolicy(self, key):
raise NotImplementedError
@@ -100,13 +100,19 @@ class Context(object):
params = message.split(b',')
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
logger.warning('invalid formed fragmented message: %r', params)
- return None
+ self.discardFragment()
+ return message
K, N = self.fragmentInfo
+ try:
+ k = int(params[1])
+ n = int(params[2])
+ except ValueError:
+ logger.warning('invalid formed fragmented message: %r', params)
+ self.discardFragment()
+ return message
- k = int(params[1])
- n = int(params[2])
fragData = params[3]
logger.debug(params)
@@ -114,17 +120,17 @@ class Context(object):
if n >= k == 1:
# first fragment
self.discardFragment()
- self.fragmentInfo = (k,n)
+ self.fragmentInfo = (k, n)
self.fragment.append(fragData)
elif N == n >= k > 1 and k == K+1:
# accumulate
- self.fragmentInfo = (k,n)
+ self.fragmentInfo = (k, n)
self.fragment.append(fragData)
else:
# bad, discard
self.discardFragment()
logger.warning('invalid fragmented message: %r', params)
- return None
+ return message
if n == k > 0:
assembled = b''.join(self.fragment)
@@ -210,7 +216,7 @@ class Context(object):
if self.state != STATE_ENCRYPTED:
self.sendInternal(proto.Error(
'You sent encrypted to {user}, who wasn\'t expecting it.'
- .format(user=self.user.name)), appdata=appdata)
+ .format(user=self.user.name).encode('utf-8')), appdata=appdata)
if ignore:
return IGN
raise NotEncryptedError(EXC_UNREADABLE_MESSAGE)
@@ -263,12 +269,13 @@ class Context(object):
return msg
def processOutgoingMessage(self, msg, flags, tlvs=[]):
- if isinstance(self.parse(msg), proto.Query):
+ isQuery = self.parseExplicitQuery(msg) is not None
+ if isQuery:
return self.user.getDefaultQueryMessage(self.getPolicy)
if self.state == STATE_PLAINTEXT:
if self.getPolicy('REQUIRE_ENCRYPTION'):
- if not isinstance(self.parse(msg), proto.Query):
+ if not isQuery:
self.lastMessage = msg
self.lastSend = time()
self.mayRetransmit = 2
@@ -277,8 +284,12 @@ class Context(object):
return msg
if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED:
self.tagOffer = OFFER_SENT
- return proto.TaggedPlaintext(msg, self.getPolicy('ALLOW_V1'),
- self.getPolicy('ALLOW_V2'))
+ versions = set()
+ if self.getPolicy('ALLOW_V1'):
+ versions.add(1)
+ if self.getPolicy('ALLOW_V2'):
+ versions.add(2)
+ return proto.TaggedPlaintext(msg, versions)
return msg
if self.state == STATE_ENCRYPTED:
msg = self.crypto.createDataMessage(msg, flags, tlvs)
@@ -304,9 +315,9 @@ class Context(object):
def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None):
mms = self.maxMessageSize(appdata)
msgLen = len(msg)
- if mms != 0 and len(msg) > mms:
+ if mms != 0 and msgLen > mms:
fms = mms - 19
- fragments = [ msg[i:i+fms] for i in range(0, len(msg), fms) ]
+ fragments = [ msg[i:i+fms] for i in range(0, msgLen, fms) ]
fc = len(fragments)
@@ -375,9 +386,9 @@ class Context(object):
self.crypto.smpSecret(secret, question=question, appdata=appdata)
def handleQuery(self, message, appdata=None):
- if message.v2 and self.getPolicy('ALLOW_V2'):
+ if 2 in message.versions and self.getPolicy('ALLOW_V2'):
self.authStartV2(appdata=appdata)
- elif message.v1 and self.getPolicy('ALLOW_V1'):
+ elif 1 in message.versions and self.getPolicy('ALLOW_V1'):
self.authStartV1(appdata=appdata)
def authStartV1(self, appdata=None):
@@ -386,7 +397,33 @@ class Context(object):
def authStartV2(self, appdata=None):
self.crypto.startAKE(appdata=appdata)
- def parse(self, message):
+ def parseExplicitQuery(self, message):
+ otrTagPos = message.find(proto.OTRTAG)
+
+ if otrTagPos == -1:
+ return None
+
+ indexBase = otrTagPos + len(proto.OTRTAG)
+
+ if len(message) <= indexBase:
+ return None
+
+ compare = message[indexBase]
+
+ hasq = compare == b'?'[0]
+ hasv = compare == b'v'[0]
+
+ if not hasq and not hasv:
+ return None
+
+ hasv |= len(message) > indexBase+1 and message[indexBase+1] == b'v'[0]
+ if hasv:
+ end = message.find(b'?', indexBase+1)
+ else:
+ end = indexBase+1
+ return message[indexBase:end]
+
+ def parse(self, message, nofragment=False):
otrTagPos = message.find(proto.OTRTAG)
if otrTagPos == -1:
if proto.MESSAGE_TAG_BASE in message:
@@ -395,38 +432,40 @@ class Context(object):
return message
indexBase = otrTagPos + len(proto.OTRTAG)
+
+ if len(message) <= indexBase:
+ return message
+
compare = message[indexBase]
- if compare == b','[0]:
+ if nofragment is False and compare == b','[0]:
message = self.fragmentAccumulate(message[indexBase:])
if message is None:
return None
else:
- return self.parse(message)
+ return self.parse(message, nofragment=True)
else:
self.discardFragment()
- hasq = compare == b'?'[0]
- hasv = compare == b'v'[0]
- if hasq or hasv:
- hasv |= len(message) > indexBase+1 and \
- message[indexBase+1] == b'v'[0]
- if hasv:
- end = message.find(b'?', indexBase+1)
- else:
- end = indexBase+1
- payload = message[indexBase:end]
- return proto.Query.parse(payload)
+ queryPayload = self.parseExplicitQuery(message)
+ if queryPayload is not None:
+ return proto.Query.parse(queryPayload)
if compare == b':'[0] and len(message) > indexBase + 4:
- infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
- classInfo = struct.unpack(b'!HB', infoTag)
- cls = proto.messageClasses.get(classInfo, None)
- if cls is None:
+ try:
+ infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
+ classInfo = struct.unpack(b'!HB', infoTag)
+
+ cls = proto.messageClasses.get(classInfo, None)
+ if cls is None:
+ return message
+
+ logger.debug('{user} got msg {typ!r}' \
+ .format(user=self.user.name, typ=cls))
+ return cls.parsePayload(message[indexBase+5:])
+ except (TypeError, struct.error):
+ logger.exception('could not parse OTR message %s', message)
return message
- logger.debug('{user} got msg {typ!r}' \
- .format(user=self.user.name, typ=cls))
- return cls.parsePayload(message[indexBase+5:])
if message[indexBase:indexBase+7] == b' Error:':
return proto.Error(message[indexBase+7:])
@@ -437,6 +476,22 @@ class Context(object):
"""Return the max message size for this context."""
return self.user.maxMessageSize
+ def getExtraKey(self, extraKeyAppId=None, extraKeyAppData=None, appdata=None):
+ """ retrieves the generated extra symmetric key.
+
+ if extraKeyAppId is set, notifies the chat partner about intended
+ usage (additional application specific information can be supplied in
+ extraKeyAppData).
+
+ returns the 256 bit symmetric key """
+
+ if self.state != STATE_ENCRYPTED:
+ raise NotEncryptedError
+ if extraKeyAppId is not None:
+ tlvs = [proto.ExtraKeyTLV(extraKeyAppId, extraKeyAppData)]
+ self.sendInternal(b'', tlvs=tlvs, appdata=appdata)
+ return self.crypto.extraKey
+
class Account(object):
contextclass = Context
def __init__(self, name, protocol, maxMessageSize, privkey=None):
@@ -447,10 +502,10 @@ class Account(object):
self.ctxs = {}
self.trusts = {}
self.maxMessageSize = maxMessageSize
- self.defaultQuery = b'?OTRv{versions}?\n{accountname} has requested ' \
- b'an Off-the-Record private conversation. However, you ' \
- b'do not have a plugin to support that.\nSee '\
- b'http://otr.cypherpunks.ca/ for more information.';
+ self.defaultQuery = '?OTRv{versions}?\n{accountname} has requested ' \
+ 'an Off-the-Record private conversation. However, you ' \
+ 'do not have a plugin to support that.\nSee '\
+ 'http://otr.cypherpunks.ca/ for more information.'
def __repr__(self):
return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__,
@@ -461,7 +516,7 @@ class Account(object):
self.privkey = self.loadPrivkey()
if self.privkey is None:
if autogen is True:
- self.privkey = crypt.generateDefaultKey()
+ self.privkey = compatcrypto.generateDefaultKey()
self.savePrivkey()
else:
raise LookupError
@@ -484,8 +539,9 @@ class Account(object):
return self.ctxs[uid]
def getDefaultQueryMessage(self, policy):
- v = b'2' if policy('ALLOW_V2') else b''
- return self.defaultQuery.format(accountname=self.name, versions=v)
+ v = '2' if policy('ALLOW_V2') else ''
+ msg = self.defaultQuery.format(accountname=self.name, versions=v)
+ return msg.encode('ascii')
def setTrust(self, key, fingerprint, trustLevel):
if key not in self.trusts:
diff --git a/gotr/potr/crypt.py b/gotr/potr/crypt.py
index ad5d663..3a4bb4b 100644
--- a/gotr/potr/crypt.py
+++ b/gotr/potr/crypt.py
@@ -22,8 +22,8 @@ import logging
import struct
-from potr.compatcrypto import SHA256, SHA1, HMAC, SHA1HMAC, SHA256HMAC, \
- SHA256HMAC160, Counter, AESCTR, RNG, PK, generateDefaultKey
+from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
+ SHA256HMAC160, Counter, AESCTR, PK, random
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
from potr import proto
@@ -36,32 +36,31 @@ STATE_AWAITING_SIG = 4
STATE_V1_SETUP = 5
-DH1536_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
-DH1536_MODULUS_2 = DH1536_MODULUS-2
-DH1536_GENERATOR = 2
-SM_ORDER = (DH1536_MODULUS - 1) // 2
+DH_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
+DH_MODULUS_2 = DH_MODULUS-2
+DH_GENERATOR = 2
+DH_BITS = 1536
+DH_MAX = 2**DH_BITS
+SM_ORDER = (DH_MODULUS - 1) // 2
def check_group(n):
- return 2 <= n <= DH1536_MODULUS_2
+ return 2 <= n <= DH_MODULUS_2
def check_exp(n):
return 1 <= n < SM_ORDER
class DH(object):
- __slots__ = ['priv', 'pub']
@classmethod
def set_params(cls, prime, gen):
cls.prime = prime
cls.gen = gen
def __init__(self):
- self.priv = bytes_to_long(RNG.read(40))
+ self.priv = random.randrange(2, 2**320)
self.pub = pow(self.gen, self.priv, self.prime)
-DH.set_params(DH1536_MODULUS, DH1536_GENERATOR)
+DH.set_params(DH_MODULUS, DH_GENERATOR)
class DHSession(object):
- __slots__ = ['sendenc', 'sendmac', 'rcvenc', 'rcvmac', 'sendctr', 'rcvctr',
- 'sendmacused', 'rcvmacused']
def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
self.sendenc = sendenc
self.sendmac = sendmac
@@ -79,7 +78,7 @@ class DHSession(object):
@classmethod
def create(cls, dh, y):
- s = pow(y, dh.priv, DH1536_MODULUS)
+ s = pow(y, dh.priv, DH_MODULUS)
sb = pack_mpi(s)
if dh.pub > y:
@@ -96,9 +95,6 @@ class DHSession(object):
return cls(sendenc, sendmac, rcvenc, rcvmac)
class CryptEngine(object):
- __slots__ = ['ctx', 'ake', 'sessionId', 'sessionIdHalf', 'theirKeyid',
- 'theirY', 'theirOldY', 'ourOldDHKey', 'ourDHKey', 'ourKeyid',
- 'sessionkeys', 'theirPubkey', 'savedMacKeys', 'smp']
def __init__(self, ctx):
self.ctx = ctx
self.ake = None
@@ -118,6 +114,7 @@ class CryptEngine(object):
self.savedMacKeys = []
self.smp = None
+ self.extraKey = None
def revealMacs(self, ours=True):
if ours:
@@ -174,7 +171,7 @@ class CryptEngine(object):
if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()):
logger.error('HMACs don\'t match')
raise InvalidParameterError
- sesskey.rcvmacused = 1
+ sesskey.rcvmacused = True
newCtrPrefix = bytes_to_long(msg.ctr)
if newCtrPrefix <= sesskey.rcvctr.prefix:
@@ -223,11 +220,14 @@ class CryptEngine(object):
self.smp = SMPHandler(self)
self.smp.abort(appdata=appdata)
- def createDataMessage(self, message, flags=0, tlvs=[]):
+ def createDataMessage(self, message, flags=0, tlvs=None):
# check MSGSTATE
if self.theirKeyid == 0:
raise InvalidParameterError
+ if tlvs is None:
+ tlvs = []
+
sess = self.sessionkeys[1][0]
sess.sendctr.inc()
@@ -303,13 +303,16 @@ class CryptEngine(object):
self.ourKeyid = ake.ourKeyid
self.theirY = ake.gy
self.theirOldY = None
+ self.extraKey = ake.extraKey
if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub:
- # XXX is this really ok?
self.ourDHKey = ake.dh
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
self.rotateDHKeys()
+ # we don't need the AKE anymore, free the reference
+ self.ake = None
+
self.ctx._wentEncrypted()
logger.info('went encrypted with {0}'.format(self.theirPubkey))
@@ -317,10 +320,6 @@ class CryptEngine(object):
self.smp = None
class AuthKeyExchange(object):
- __slots__ = ['privkey', 'state', 'r', 'encgx', 'hashgx', 'ourKeyid',
- 'theirPubkey', 'theirKeyid', 'enc_c', 'enc_cp', 'mac_m1',
- 'mac_m1p', 'mac_m2', 'mac_m2p', 'sessionId', 'dh', 'onSuccess',
- 'gy', 'lastmsg', 'sessionIdHalf']
def __init__(self, privkey, onSuccess):
self.privkey = privkey
self.state = STATE_NONE
@@ -341,9 +340,11 @@ class AuthKeyExchange(object):
self.dh = DH()
self.onSuccess = onSuccess
self.gy = None
+ self.extraKey = None
+ self.lastmsg = None
def startAKE(self):
- self.r = RNG.read(16)
+ self.r = long_to_bytes(random.getrandbits(128))
gxmpi = pack_mpi(self.dh.pub)
@@ -444,15 +445,17 @@ class AuthKeyExchange(object):
self.state = STATE_NONE
def createAuthKeys(self):
- s = pow(self.gy, self.dh.priv, DH1536_MODULUS)
+ s = pow(self.gy, self.dh.priv, DH_MODULUS)
sbyte = pack_mpi(s)
- self.sessionId = SHA256(b'\0' + sbyte)[:8]
- enc = SHA256(b'\1' + sbyte)
- self.enc_c, self.enc_cp = enc[:16], enc[16:]
- self.mac_m1 = SHA256(b'\2' + sbyte)
- self.mac_m2 = SHA256(b'\3' + sbyte)
- self.mac_m1p = SHA256(b'\4' + sbyte)
- self.mac_m2p = SHA256(b'\5' + sbyte)
+ self.sessionId = SHA256(b'\x00' + sbyte)[:8]
+ enc = SHA256(b'\x01' + sbyte)
+ self.enc_c = enc[:16]
+ self.enc_cp = enc[16:]
+ self.mac_m1 = SHA256(b'\x02' + sbyte)
+ self.mac_m2 = SHA256(b'\x03' + sbyte)
+ self.mac_m1p = SHA256(b'\x04' + sbyte)
+ self.mac_m2p = SHA256(b'\x05' + sbyte)
+ self.extraKey = SHA256(b'\xff' + sbyte)
def calculatePubkeyAuth(self, key, mackey):
pubkey = self.privkey.serializePublicKey()
@@ -490,14 +493,15 @@ SMPPROG_FAILED = -1
SMPPROG_SUCCEEDED = 1
class SMPHandler:
- __slots__ = ['crypto', 'questionReceived', 'prog', 'state', 'g1', 'g3o',
- 'x2', 'x3', 'g2', 'g3', 'pab', 'qab', 'secret', 'p', 'q']
-
def __init__(self, crypto):
self.crypto = crypto
self.state = 1
- self.g1 = DH1536_GENERATOR
+ self.g1 = DH_GENERATOR
+ self.g2 = None
+ self.g3 = None
self.g3o = None
+ self.x2 = None
+ self.x3 = None
self.prog = SMPPROG_OK
self.pab = None
self.qab = None
@@ -539,11 +543,11 @@ class SMPHandler:
self.g3o = msg[3]
- self.x2 = bytes_to_long(RNG.read(192))
- self.x3 = bytes_to_long(RNG.read(192))
+ self.x2 = random.randrange(2, DH_MAX)
+ self.x3 = random.randrange(2, DH_MAX)
- self.g2 = pow(msg[0], self.x2, DH1536_MODULUS)
- self.g3 = pow(msg[3], self.x3, DH1536_MODULUS)
+ self.g2 = pow(msg[0], self.x2, DH_MODULUS)
+ self.g3 = pow(msg[3], self.x3, DH_MODULUS)
self.prog = SMPPROG_OK
self.state = 0
@@ -568,29 +572,29 @@ class SMPHandler:
return
self.g3o = msg[3]
- self.g2 = pow(msg[0], self.x2, DH1536_MODULUS)
- self.g3 = pow(msg[3], self.x3, DH1536_MODULUS)
+ self.g2 = pow(msg[0], self.x2, DH_MODULUS)
+ self.g3 = pow(msg[3], self.x3, DH_MODULUS)
if not self.check_equal_coords(msg[6:11], 5):
logger.error('invalid SMP2TLV received')
self.abort(appdata=appdata)
return
- r = bytes_to_long(RNG.read(192))
- self.p = pow(self.g3, r, DH1536_MODULUS)
+ r = random.randrange(2, DH_MAX)
+ self.p = pow(self.g3, r, DH_MODULUS)
msg = [self.p]
- qa1 = pow(self.g1, r, DH1536_MODULUS)
- qa2 = pow(self.g2, self.secret, DH1536_MODULUS)
- self.q = qa1*qa2 % DH1536_MODULUS
+ qa1 = pow(self.g1, r, DH_MODULUS)
+ qa2 = pow(self.g2, self.secret, DH_MODULUS)
+ self.q = qa1*qa2 % DH_MODULUS
msg.append(self.q)
msg += self.proof_equal_coords(r, 6)
inv = invMod(mp)
- self.pab = self.p * inv % DH1536_MODULUS
+ self.pab = self.p * inv % DH_MODULUS
inv = invMod(mq)
- self.qab = self.q * inv % DH1536_MODULUS
+ self.qab = self.q * inv % DH_MODULUS
- msg.append(pow(self.qab, self.x3, DH1536_MODULUS))
+ msg.append(pow(self.qab, self.x3, DH_MODULUS))
msg += self.proof_equal_logs(7)
self.state = 4
@@ -613,9 +617,9 @@ class SMPHandler:
return
inv = invMod(self.p)
- self.pab = msg[0] * inv % DH1536_MODULUS
+ self.pab = msg[0] * inv % DH_MODULUS
inv = invMod(self.q)
- self.qab = msg[1] * inv % DH1536_MODULUS
+ self.qab = msg[1] * inv % DH_MODULUS
if not self.check_equal_logs(msg[5:8], 7):
logger.error('invalid SMP3TLV received')
@@ -623,10 +627,10 @@ class SMPHandler:
return
md = msg[5]
- msg = [pow(self.qab, self.x3, DH1536_MODULUS)]
+ msg = [pow(self.qab, self.x3, DH_MODULUS)]
msg += self.proof_equal_logs(8)
- rab = pow(md, self.x3, DH1536_MODULUS)
+ rab = pow(md, self.x3, DH_MODULUS)
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
if self.prog != SMPPROG_SUCCEEDED:
@@ -654,7 +658,7 @@ class SMPHandler:
self.abort(appdata=appdata)
return
- rab = pow(msg[0], self.x3, DH1536_MODULUS)
+ rab = pow(msg[0], self.x3, DH_MODULUS)
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
@@ -679,12 +683,12 @@ class SMPHandler:
self.secret = bytes_to_long(combSecret)
- self.x2 = bytes_to_long(RNG.read(192))
- self.x3 = bytes_to_long(RNG.read(192))
+ self.x2 = random.randrange(2, DH_MAX)
+ self.x3 = random.randrange(2, DH_MAX)
- msg = [pow(self.g1, self.x2, DH1536_MODULUS)]
+ msg = [pow(self.g1, self.x2, DH_MODULUS)]
msg += proof_known_log(self.g1, self.x2, 1)
- msg.append(pow(self.g1, self.x3, DH1536_MODULUS))
+ msg.append(pow(self.g1, self.x3, DH_MODULUS))
msg += proof_known_log(self.g1, self.x3, 2)
self.prog = SMPPROG_OK
@@ -700,19 +704,19 @@ class SMPHandler:
self.secret = bytes_to_long(combSecret)
- msg = [pow(self.g1, self.x2, DH1536_MODULUS)]
+ msg = [pow(self.g1, self.x2, DH_MODULUS)]
msg += proof_known_log(self.g1, self.x2, 3)
- msg.append(pow(self.g1, self.x3, DH1536_MODULUS))
+ msg.append(pow(self.g1, self.x3, DH_MODULUS))
msg += proof_known_log(self.g1, self.x3, 4)
- r = bytes_to_long(RNG.read(192))
+ r = random.randrange(2, DH_MAX)
- self.p = pow(self.g3, r, DH1536_MODULUS)
+ self.p = pow(self.g3, r, DH_MODULUS)
msg.append(self.p)
- qb1 = pow(self.g1, r, DH1536_MODULUS)
- qb2 = pow(self.g2, self.secret, DH1536_MODULUS)
- self.q = qb1 * qb2 % DH1536_MODULUS
+ qb1 = pow(self.g1, r, DH_MODULUS)
+ qb2 = pow(self.g2, self.secret, DH_MODULUS)
+ self.q = qb1 * qb2 % DH_MODULUS
msg.append(self.q)
msg += self.proof_equal_coords(r, 5)
@@ -721,11 +725,11 @@ class SMPHandler:
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)
def proof_equal_coords(self, r, v):
- r1 = bytes_to_long(RNG.read(192))
- r2 = bytes_to_long(RNG.read(192))
- temp2 = pow(self.g1, r1, DH1536_MODULUS) \
- * pow(self.g2, r2, DH1536_MODULUS) % DH1536_MODULUS
- temp1 = pow(self.g3, r1, DH1536_MODULUS)
+ r1 = random.randrange(2, DH_MAX)
+ r2 = random.randrange(2, DH_MAX)
+ temp2 = pow(self.g1, r1, DH_MODULUS) \
+ * pow(self.g2, r2, DH_MODULUS) % DH_MODULUS
+ temp1 = pow(self.g3, r1, DH_MODULUS)
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
c = bytes_to_long(cb)
@@ -739,21 +743,21 @@ class SMPHandler:
def check_equal_coords(self, coords, v):
(p, q, c, d1, d2) = coords
- temp1 = pow(self.g3, d1, DH1536_MODULUS) * pow(p, c, DH1536_MODULUS) \
- % DH1536_MODULUS
+ temp1 = pow(self.g3, d1, DH_MODULUS) * pow(p, c, DH_MODULUS) \
+ % DH_MODULUS
- temp2 = pow(self.g1, d1, DH1536_MODULUS) \
- * pow(self.g2, d2, DH1536_MODULUS) \
- * pow(q, c, DH1536_MODULUS) % DH1536_MODULUS
+ temp2 = pow(self.g1, d1, DH_MODULUS) \
+ * pow(self.g2, d2, DH_MODULUS) \
+ * pow(q, c, DH_MODULUS) % DH_MODULUS
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
- return long_to_bytes(c) == cprime
+ return long_to_bytes(c, 32) == cprime
def proof_equal_logs(self, v):
- r = bytes_to_long(RNG.read(192))
- temp1 = pow(self.g1, r, DH1536_MODULUS)
- temp2 = pow(self.qab, r, DH1536_MODULUS)
+ r = random.randrange(2, DH_MAX)
+ temp1 = pow(self.g1, r, DH_MODULUS)
+ temp2 = pow(self.qab, r, DH_MODULUS)
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
c = bytes_to_long(cb)
@@ -763,29 +767,29 @@ class SMPHandler:
def check_equal_logs(self, logs, v):
(r, c, d) = logs
- temp1 = pow(self.g1, d, DH1536_MODULUS) \
- * pow(self.g3o, c, DH1536_MODULUS) % DH1536_MODULUS
+ temp1 = pow(self.g1, d, DH_MODULUS) \
+ * pow(self.g3o, c, DH_MODULUS) % DH_MODULUS
- temp2 = pow(self.qab, d, DH1536_MODULUS) \
- * pow(r, c, DH1536_MODULUS) % DH1536_MODULUS
+ temp2 = pow(self.qab, d, DH_MODULUS) \
+ * pow(r, c, DH_MODULUS) % DH_MODULUS
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
- return long_to_bytes(c) == cprime
+ return long_to_bytes(c, 32) == cprime
def proof_known_log(g, x, v):
- r = bytes_to_long(RNG.read(192))
- c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH1536_MODULUS))))
+ r = random.randrange(2, DH_MAX)
+ c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS))))
temp = x * c % SM_ORDER
return c, (r-temp) % SM_ORDER
def check_known_log(c, d, g, x, v):
- gd = pow(g, d, DH1536_MODULUS)
- xc = pow(x, c, DH1536_MODULUS)
- gdxc = gd * xc % DH1536_MODULUS
- return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c)
+ gd = pow(g, d, DH_MODULUS)
+ xc = pow(x, c, DH_MODULUS)
+ gdxc = gd * xc % DH_MODULUS
+ return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c, 32)
def invMod(n):
- return pow(n, DH1536_MODULUS_2, DH1536_MODULUS)
+ return pow(n, DH_MODULUS_2, DH_MODULUS)
class InvalidParameterError(RuntimeError):
pass
diff --git a/gotr/potr/proto.py b/gotr/potr/proto.py
index 745a53d..91904a8 100644
--- a/gotr/potr/proto.py
+++ b/gotr/potr/proto.py
@@ -19,14 +19,16 @@
from __future__ import unicode_literals
import base64
-import logging
import struct
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
OTRTAG = b'?OTR'
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
-MESSAGE_TAG_V1 = b' \t \t \t '
-MESSAGE_TAG_V2 = b' \t\t \t '
+MESSAGE_TAGS = {
+ 1:b' \t \t \t ',
+ 2:b' \t\t \t ',
+ 3:b' \t\t \t\t',
+ }
MSGTYPE_NOTOTR = 0
MSGTYPE_TAGGEDPLAINTEXT = 1
@@ -62,6 +64,8 @@ def registermessage(cls):
def registertlv(cls):
if not hasattr(cls, 'parsePayload'):
raise TypeError('registered tlv types need parsePayload()')
+ if cls.typ is None:
+ raise TypeError('registered tlv type needs type ID')
tlvClasses[cls.typ] = cls
return cls
@@ -84,16 +88,6 @@ class OTRMessage(object):
__slots__ = ['payload']
version = 0x0002
msgtype = 0
- def __init__(self, payload):
- self.payload = payload
-
- def getPayload(self):
- return self.payload
-
- def __bytes__(self):
- data = struct.pack(b'!HB', self.version, self.msgtype) \
- + self.getPayload()
- return b'?OTR:' + base64.b64encode(data) + b'.'
def __eq__(self, other):
if not isinstance(other, self.__class__):
@@ -110,6 +104,7 @@ class OTRMessage(object):
class Error(OTRMessage):
__slots__ = ['error']
def __init__(self, error):
+ super(Error, self).__init__()
self.error = error
def __repr__(self):
@@ -119,56 +114,58 @@ class Error(OTRMessage):
return b'?OTR Error:' + self.error
class Query(OTRMessage):
- __slots__ = ['v1', 'v2']
- def __init__(self, v1, v2):
- self.v1 = v1
- self.v2 = v2
+ __slots__ = ['versions']
+ def __init__(self, versions=set()):
+ super(Query, self).__init__()
+ self.versions = versions
@classmethod
def parse(cls, data):
- v2 = False
- v1 = False
- if len(data) > 0 and data[0:1] == b'?':
- data = data[1:]
- v1 = True
-
- if len(data) > 0 and data[0:1] == b'v':
- for c in data[1:]:
- if c == b'2'[0]:
- v2 = True
- return cls(v1, v2)
+ if not isinstance(data, bytes):
+ raise TypeError('can only parse bytes')
+ udata = data.decode('ascii', errors='replace')
+
+ versions = set()
+ if len(udata) > 0 and udata[0] == '?':
+ udata = udata[1:]
+ versions.add(1)
+
+ if len(udata) > 0 and udata[0] == 'v':
+ versions.update(( int(c) for c in udata if c.isdigit() ))
+ return cls(versions)
def __repr__(self):
- return '<proto.Query(v1=%r,v2=%r)>'%(self.v1,self.v2)
+ return '<proto.Query(versions=%r)>' % (self.versions)
def __bytes__(self):
d = b'?OTR'
- if self.v1:
+ if 1 in self.versions:
d += b'?'
d += b'v'
- if self.v2:
- d += b'2'
+
+ # in python3 there is only int->unicode conversion
+ # so I convert to unicode and encode it to a byte string
+ versions = [ '%d' % v for v in self.versions if v != 1 ]
+ d += ''.join(versions).encode('ascii')
+
d += b'?'
return d
class TaggedPlaintext(Query):
__slots__ = ['msg']
- def __init__(self, msg, v1, v2):
+ def __init__(self, msg, versions):
+ super(TaggedPlaintext, self).__init__(versions)
self.msg = msg
- self.v1 = v1
- self.v2 = v2
def __bytes__(self):
data = self.msg + MESSAGE_TAG_BASE
- if self.v1:
- data += MESSAGE_TAG_V1
- if self.v2:
- data += MESSAGE_TAG_V2
+ for v in self.versions:
+ data += MESSAGE_TAGS[v]
return data
def __repr__(self):
- return '<proto.TaggedPlaintext(v1={v1!r},v2={v2!r},msg={msg!r})>' \
- .format(v1=self.v1, v2=self.v2, msg=self.msg)
+ return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
+ .format(versions=self.versions, msg=self.msg)
@classmethod
def parse(cls, data):
@@ -177,21 +174,18 @@ class TaggedPlaintext(Query):
raise TypeError(
'this is not a tagged plaintext ({0!r:.20})'.format(data))
- v1 = False
- v2 = False
-
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
- for tag in tags:
- if not tag.isspace():
- break
- v1 |= tag == MESSAGE_TAG_V1
- v2 |= tag == MESSAGE_TAG_V2
+ versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
+ in tags ])
- return TaggedPlaintext(data[:tagPos], v1, v2)
+ return TaggedPlaintext(data[:tagPos], versions)
class GenericOTRMessage(OTRMessage):
__slots__ = ['data']
+ fields = []
+
def __init__(self, *args):
+ super(GenericOTRMessage, self).__init__()
if len(args) != len(self.fields):
raise TypeError('%s needs %d arguments, got %d' %
(self.__class__.__name__, len(self.fields), len(args)))
@@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage):
self.__getattr__(attr) # existence check
self.data[attr] = val
+ def __bytes__(self):
+ data = struct.pack(b'!HB', self.version, self.msgtype) \
+ + self.getPayload()
+ return b'?OTR:' + base64.b64encode(data) + b'.'
+
def __repr__(self):
name = self.__class__.__name__
data = ''
@@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage):
def parsePayload(cls, data):
data = base64.b64decode(data)
args = []
- for k, ftype in cls.fields:
+ for _, ftype in cls.fields:
if ftype == 'data':
value, data = read_data(data)
elif isinstance(ftype, bytes):
- size = int(struct.calcsize(ftype))
value, data = unpack(ftype, data)
elif isinstance(ftype, int):
value, data = data[:ftype], data[ftype:]
@@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage):
class AKEMessage(GenericOTRMessage):
__slots__ = []
- pass
@registermessage
class DHCommit(AKEMessage):
__slots__ = []
msgtype = 0x02
- fields = [('encgx','data'), ('hashgx','data'), ]
-
+ fields = [('encgx', 'data'), ('hashgx', 'data'), ]
@registermessage
class DHKey(AKEMessage):
__slots__ = []
msgtype = 0x0a
- fields = [('gy','data'), ]
+ fields = [('gy', 'data'), ]
@registermessage
class RevealSig(AKEMessage):
__slots__ = []
msgtype = 0x11
- fields = [('rkey','data'), ('encsig','data'), ('mac',20),]
+ fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
def getMacedData(self):
p = self.encsig
@@ -280,7 +276,7 @@ class RevealSig(AKEMessage):
class Signature(AKEMessage):
__slots__ = []
msgtype = 0x12
- fields = [('encsig','data'), ('mac',20)]
+ fields = [('encsig', 'data'), ('mac', 20)]
def getMacedData(self):
p = self.encsig
@@ -290,8 +286,9 @@ class Signature(AKEMessage):
class DataMessage(GenericOTRMessage):
__slots__ = []
msgtype = 0x03
- fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'),
- ('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ]
+ fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
+ ('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
+ ('oldmacs', 'data'), ]
def getMacedData(self):
return struct.pack(b'!HB', self.version, self.msgtype) + \
@@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage):
@bytesAndStrings
class TLV(object):
__slots__ = []
+ typ = None
+
+ def getPayload(self):
+ raise NotImplementedError
def __repr__(self):
val = self.getPayload()
@@ -331,10 +332,27 @@ class TLV(object):
return not self.__eq__(other)
@registertlv
+class PaddingTLV(TLV):
+ typ = 0
+
+ __slots__ = ['padding']
+
+ def __init__(self, padding):
+ super(PaddingTLV, self).__init__()
+ self.padding = padding
+
+ def getPayload(self):
+ return self.padding
+
+ @classmethod
+ def parsePayload(cls, data):
+ return cls(data)
+
+@registertlv
class DisconnectTLV(TLV):
typ = 1
def __init__(self):
- pass
+ super(DisconnectTLV, self).__init__()
def getPayload(self):
return b''
@@ -348,8 +366,14 @@ class DisconnectTLV(TLV):
class SMPTLV(TLV):
__slots__ = ['mpis']
-
- def __init__(self, mpis=[]):
+ dlen = None
+
+ def __init__(self, mpis=None):
+ super(SMPTLV, self).__init__()
+ if mpis is None:
+ mpis = []
+ if self.dlen is None:
+ raise TypeError('no amount of mpis specified in dlen')
if len(mpis) != self.dlen:
raise TypeError('expected {0} mpis, got {1}'
.format(self.dlen, len(mpis)))
@@ -366,7 +390,7 @@ class SMPTLV(TLV):
mpis = []
if cls.dlen > 0:
count, data = unpack(b'!I', data)
- for i in range(count):
+ for _ in range(count):
n, data = read_mpi(data)
mpis.append(n)
if len(data) > 0:
@@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV):
def getPayload(self):
return b''
+
+@registertlv
+class ExtraKeyTLV(TLV):
+ typ = 8
+
+ __slots__ = ['appid', 'appdata']
+
+ def __init__(self, appid, appdata):
+ super(ExtraKeyTLV, self).__init__()
+ self.appid = appid
+ self.appdata = appdata
+ if appdata is None:
+ self.appdata = b''
+
+ def getPayload(self):
+ return self.appid + self.appdata
+
+ @classmethod
+ def parsePayload(cls, data):
+ return cls(data[:4], data[4:])
diff --git a/gotr/potr/utils.py b/gotr/potr/utils.py
index 2bedf55..e41ca46 100644
--- a/gotr/potr/utils.py
+++ b/gotr/potr/utils.py
@@ -43,11 +43,12 @@ def bytes_to_long(b):
s += byte_to_long(b[i:i+1]) << 8*(l-i-1)
return s
-def long_to_bytes(l):
+def long_to_bytes(l, n=0):
b = b''
- while l != 0:
+ while l != 0 or n > 0:
b = long_to_byte(l & 0xff) + b
l >>= 8
+ n -= 1
return b
def byte_to_long(b):