Welcome to mirror list, hosted at ThFree Co, Russian Federation.

dev.gajim.org/gajim/gajim-plugins.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'gotr/potr/proto.py')
-rw-r--r--gotr/potr/proto.py172
1 files changed, 108 insertions, 64 deletions
diff --git a/gotr/potr/proto.py b/gotr/potr/proto.py
index 745a53d..91904a8 100644
--- a/gotr/potr/proto.py
+++ b/gotr/potr/proto.py
@@ -19,14 +19,16 @@
from __future__ import unicode_literals
import base64
-import logging
import struct
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
OTRTAG = b'?OTR'
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
-MESSAGE_TAG_V1 = b' \t \t \t '
-MESSAGE_TAG_V2 = b' \t\t \t '
+MESSAGE_TAGS = {
+ 1:b' \t \t \t ',
+ 2:b' \t\t \t ',
+ 3:b' \t\t \t\t',
+ }
MSGTYPE_NOTOTR = 0
MSGTYPE_TAGGEDPLAINTEXT = 1
@@ -62,6 +64,8 @@ def registermessage(cls):
def registertlv(cls):
if not hasattr(cls, 'parsePayload'):
raise TypeError('registered tlv types need parsePayload()')
+ if cls.typ is None:
+ raise TypeError('registered tlv type needs type ID')
tlvClasses[cls.typ] = cls
return cls
@@ -84,16 +88,6 @@ class OTRMessage(object):
__slots__ = ['payload']
version = 0x0002
msgtype = 0
- def __init__(self, payload):
- self.payload = payload
-
- def getPayload(self):
- return self.payload
-
- def __bytes__(self):
- data = struct.pack(b'!HB', self.version, self.msgtype) \
- + self.getPayload()
- return b'?OTR:' + base64.b64encode(data) + b'.'
def __eq__(self, other):
if not isinstance(other, self.__class__):
@@ -110,6 +104,7 @@ class OTRMessage(object):
class Error(OTRMessage):
__slots__ = ['error']
def __init__(self, error):
+ super(Error, self).__init__()
self.error = error
def __repr__(self):
@@ -119,56 +114,58 @@ class Error(OTRMessage):
return b'?OTR Error:' + self.error
class Query(OTRMessage):
- __slots__ = ['v1', 'v2']
- def __init__(self, v1, v2):
- self.v1 = v1
- self.v2 = v2
+ __slots__ = ['versions']
+ def __init__(self, versions=set()):
+ super(Query, self).__init__()
+ self.versions = versions
@classmethod
def parse(cls, data):
- v2 = False
- v1 = False
- if len(data) > 0 and data[0:1] == b'?':
- data = data[1:]
- v1 = True
-
- if len(data) > 0 and data[0:1] == b'v':
- for c in data[1:]:
- if c == b'2'[0]:
- v2 = True
- return cls(v1, v2)
+ if not isinstance(data, bytes):
+ raise TypeError('can only parse bytes')
+ udata = data.decode('ascii', errors='replace')
+
+ versions = set()
+ if len(udata) > 0 and udata[0] == '?':
+ udata = udata[1:]
+ versions.add(1)
+
+ if len(udata) > 0 and udata[0] == 'v':
+ versions.update(( int(c) for c in udata if c.isdigit() ))
+ return cls(versions)
def __repr__(self):
- return '<proto.Query(v1=%r,v2=%r)>'%(self.v1,self.v2)
+ return '<proto.Query(versions=%r)>' % (self.versions)
def __bytes__(self):
d = b'?OTR'
- if self.v1:
+ if 1 in self.versions:
d += b'?'
d += b'v'
- if self.v2:
- d += b'2'
+
+ # in python3 there is only int->unicode conversion
+ # so I convert to unicode and encode it to a byte string
+ versions = [ '%d' % v for v in self.versions if v != 1 ]
+ d += ''.join(versions).encode('ascii')
+
d += b'?'
return d
class TaggedPlaintext(Query):
__slots__ = ['msg']
- def __init__(self, msg, v1, v2):
+ def __init__(self, msg, versions):
+ super(TaggedPlaintext, self).__init__(versions)
self.msg = msg
- self.v1 = v1
- self.v2 = v2
def __bytes__(self):
data = self.msg + MESSAGE_TAG_BASE
- if self.v1:
- data += MESSAGE_TAG_V1
- if self.v2:
- data += MESSAGE_TAG_V2
+ for v in self.versions:
+ data += MESSAGE_TAGS[v]
return data
def __repr__(self):
- return '<proto.TaggedPlaintext(v1={v1!r},v2={v2!r},msg={msg!r})>' \
- .format(v1=self.v1, v2=self.v2, msg=self.msg)
+ return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
+ .format(versions=self.versions, msg=self.msg)
@classmethod
def parse(cls, data):
@@ -177,21 +174,18 @@ class TaggedPlaintext(Query):
raise TypeError(
'this is not a tagged plaintext ({0!r:.20})'.format(data))
- v1 = False
- v2 = False
-
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
- for tag in tags:
- if not tag.isspace():
- break
- v1 |= tag == MESSAGE_TAG_V1
- v2 |= tag == MESSAGE_TAG_V2
+ versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
+ in tags ])
- return TaggedPlaintext(data[:tagPos], v1, v2)
+ return TaggedPlaintext(data[:tagPos], versions)
class GenericOTRMessage(OTRMessage):
__slots__ = ['data']
+ fields = []
+
def __init__(self, *args):
+ super(GenericOTRMessage, self).__init__()
if len(args) != len(self.fields):
raise TypeError('%s needs %d arguments, got %d' %
(self.__class__.__name__, len(self.fields), len(args)))
@@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage):
self.__getattr__(attr) # existence check
self.data[attr] = val
+ def __bytes__(self):
+ data = struct.pack(b'!HB', self.version, self.msgtype) \
+ + self.getPayload()
+ return b'?OTR:' + base64.b64encode(data) + b'.'
+
def __repr__(self):
name = self.__class__.__name__
data = ''
@@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage):
def parsePayload(cls, data):
data = base64.b64decode(data)
args = []
- for k, ftype in cls.fields:
+ for _, ftype in cls.fields:
if ftype == 'data':
value, data = read_data(data)
elif isinstance(ftype, bytes):
- size = int(struct.calcsize(ftype))
value, data = unpack(ftype, data)
elif isinstance(ftype, int):
value, data = data[:ftype], data[ftype:]
@@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage):
class AKEMessage(GenericOTRMessage):
__slots__ = []
- pass
@registermessage
class DHCommit(AKEMessage):
__slots__ = []
msgtype = 0x02
- fields = [('encgx','data'), ('hashgx','data'), ]
-
+ fields = [('encgx', 'data'), ('hashgx', 'data'), ]
@registermessage
class DHKey(AKEMessage):
__slots__ = []
msgtype = 0x0a
- fields = [('gy','data'), ]
+ fields = [('gy', 'data'), ]
@registermessage
class RevealSig(AKEMessage):
__slots__ = []
msgtype = 0x11
- fields = [('rkey','data'), ('encsig','data'), ('mac',20),]
+ fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
def getMacedData(self):
p = self.encsig
@@ -280,7 +276,7 @@ class RevealSig(AKEMessage):
class Signature(AKEMessage):
__slots__ = []
msgtype = 0x12
- fields = [('encsig','data'), ('mac',20)]
+ fields = [('encsig', 'data'), ('mac', 20)]
def getMacedData(self):
p = self.encsig
@@ -290,8 +286,9 @@ class Signature(AKEMessage):
class DataMessage(GenericOTRMessage):
__slots__ = []
msgtype = 0x03
- fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'),
- ('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ]
+ fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
+ ('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
+ ('oldmacs', 'data'), ]
def getMacedData(self):
return struct.pack(b'!HB', self.version, self.msgtype) + \
@@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage):
@bytesAndStrings
class TLV(object):
__slots__ = []
+ typ = None
+
+ def getPayload(self):
+ raise NotImplementedError
def __repr__(self):
val = self.getPayload()
@@ -331,10 +332,27 @@ class TLV(object):
return not self.__eq__(other)
@registertlv
+class PaddingTLV(TLV):
+ typ = 0
+
+ __slots__ = ['padding']
+
+ def __init__(self, padding):
+ super(PaddingTLV, self).__init__()
+ self.padding = padding
+
+ def getPayload(self):
+ return self.padding
+
+ @classmethod
+ def parsePayload(cls, data):
+ return cls(data)
+
+@registertlv
class DisconnectTLV(TLV):
typ = 1
def __init__(self):
- pass
+ super(DisconnectTLV, self).__init__()
def getPayload(self):
return b''
@@ -348,8 +366,14 @@ class DisconnectTLV(TLV):
class SMPTLV(TLV):
__slots__ = ['mpis']
-
- def __init__(self, mpis=[]):
+ dlen = None
+
+ def __init__(self, mpis=None):
+ super(SMPTLV, self).__init__()
+ if mpis is None:
+ mpis = []
+ if self.dlen is None:
+ raise TypeError('no amount of mpis specified in dlen')
if len(mpis) != self.dlen:
raise TypeError('expected {0} mpis, got {1}'
.format(self.dlen, len(mpis)))
@@ -366,7 +390,7 @@ class SMPTLV(TLV):
mpis = []
if cls.dlen > 0:
count, data = unpack(b'!I', data)
- for i in range(count):
+ for _ in range(count):
n, data = read_mpi(data)
mpis.append(n)
if len(data) > 0:
@@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV):
def getPayload(self):
return b''
+
+@registertlv
+class ExtraKeyTLV(TLV):
+ typ = 8
+
+ __slots__ = ['appid', 'appdata']
+
+ def __init__(self, appid, appdata):
+ super(ExtraKeyTLV, self).__init__()
+ self.appid = appid
+ self.appdata = appdata
+ if appdata is None:
+ self.appdata = b''
+
+ def getPayload(self):
+ return self.appid + self.appdata
+
+ @classmethod
+ def parsePayload(cls, data):
+ return cls(data[:4], data[4:])