#
# This file is a part of DNSViz, a tool suite for DNS/DNSSEC monitoring,
# analysis, and visualization.
# Created by Casey Deccio (casey@deccio.net)
#
# Copyright 2014-2016 VeriSign, Inc.
#
# Copyright 2016-2021 Casey Deccio
#
# DNSViz is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# DNSViz is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with DNSViz. If not, see .
#
from __future__ import unicode_literals
import base64
import bisect
import codecs
import errno
import fcntl
import io
import json
import os
import random
import re
import select
import socket
import ssl
import struct
import subprocess
import threading
import time
# minimal support for python2.6
try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict
# python3/python2 dual compatibility
try:
import queue
except ImportError:
import Queue as queue
try:
import urllib.parse
except ImportError:
import urlparse
import urllib
urlquote = urllib
else:
urlparse = urllib.parse
urlquote = urllib.parse
import dns.exception
from .ipaddr import IPAddr, ANY_IPV6, ANY_IPV4
from .format import latin1_binary_to_string as lb2s
DNS_TRANSPORT_VERSION = 1.0
MAX_PORT_BIND_ATTEMPTS=10
MAX_WAIT_FOR_REQUEST=30
HTTP_HEADER_END_RE = re.compile(r'(\r\n\r\n|\n\n|\r\r)')
HTTP_STATUS_RE = re.compile(r'^HTTP/\S+ (?P\d+) ')
CONTENT_LENGTH_RE = re.compile(r'^Content-Length: (?P\d+)', re.MULTILINE)
CHUNKED_ENCODING_RE = re.compile(r'^Transfer-Encoding: chunked(\r\n|\r|\n)', re.MULTILINE)
CHUNK_SIZE_RE = re.compile(r'^(?P[0-9a-fA-F]+)(;[^\r\n]+)?(\r\n|\r|\n)')
CRLF_START_RE = re.compile(r'^(\r\n|\n|\r)')
class SocketWrapper(object):
def __init__(self):
raise NotImplemented
class Socket(SocketWrapper):
def __init__(self, sock):
self.sock = sock
self.reader = sock
self.writer = sock
self.reader_fd = sock.fileno()
self.writer_fd = sock.fileno()
self.family = sock.family
self.type = sock.type
self.lock = None
def recv(self, n):
return self.sock.recv(n)
def send(self, s):
return self.sock.send(s)
def setblocking(self, b):
self.sock.setblocking(b)
def bind(self, a):
self.sock.bind(a)
def connect(self, a):
self.sock.connect(a)
def getsockname(self):
return self.sock.getsockname()
def close(self):
self.sock.close()
class ReaderWriter(SocketWrapper):
def __init__(self, reader, writer, proc=None):
self.reader = reader
self.writer = writer
self.reader_fd = self.reader.fileno()
self.writer_fd = self.writer.fileno()
self.family = socket.AF_INET
self.type = socket.SOCK_STREAM
self.lock = None
self.proc = proc
def recv(self, n):
return os.read(self.reader_fd, n)
def send(self, s):
return os.write(self.writer_fd, s)
def setblocking(self, b):
if not b:
fcntl.fcntl(self.reader_fd, fcntl.F_SETFL, os.O_NONBLOCK)
fcntl.fcntl(self.writer_fd, fcntl.F_SETFL, os.O_NONBLOCK)
def bind(self, a):
pass
def connect(self, a):
pass
def getsockname(self):
return ('localhost', 0)
def close(self):
pass
class RemoteQueryTransportError(Exception):
pass
class TransportMetaDeserializationError(Exception):
pass
class SocketInUse(Exception):
pass
class DNSQueryTransportMeta(object):
def __init__(self, req, dst, tcp, timeout, dport, src=None, sport=None):
self.req = req
self.dst = dst
self.tcp = tcp
self.timeout = timeout
self.dport = dport
self.src = src
self.sport = sport
self.res = None
self.err = None
self.start_time = None
self.end_time = None
def serialize_request(self):
d = OrderedDict()
d['req'] = lb2s(base64.b64encode(self.req))
d['dst'] = self.dst
d['dport'] = self.dport
if self.src is not None:
d['src'] = self.src
if self.sport is not None:
d['sport'] = self.sport
d['tcp'] = self.tcp
d['timeout'] = int(self.timeout*1000)
return d
@classmethod
def deserialize_request(cls, d):
if 'req' not in d or d['req'] is None:
raise TransportMetaDeserializationError('Missing "req" field in input.')
try:
req = base64.b64decode(d['req'])
except TypeError:
raise TransportMetaDeserializationError('Base64 decoding DNS request failed: %s' % d['req'])
if 'dst' not in d or d['dst'] is None:
raise TransportMetaDeserializationError('Missing "dst" field in input.')
try:
dst = IPAddr(d['dst'])
except ValueError:
raise TransportMetaDeserializationError('Invalid destination IP address: %s' % d['dst'])
if 'dport' not in d or d['dport'] is None:
raise TransportMetaDeserializationError('Missing "dport" field in input.')
try:
dport = int(d['dport'])
if dport < 0 or dport > 65535:
raise ValueError()
except ValueError:
raise TransportMetaDeserializationError('Invalid destination port: %s' % d['dport'])
if 'src' not in d or d['src'] is None:
src = None
else:
try:
src = IPAddr(d['src'])
except ValueError:
raise TransportMetaDeserializationError('Invalid source IP address: %s' % d['src'])
if 'sport' not in d or d['sport'] is None:
sport = None
else:
try:
sport = int(d['sport'])
if sport < 0 or sport > 65535:
raise ValueError()
except ValueError:
raise TransportMetaDeserializationError('Invalid source port: %s' % d['sport'])
if 'tcp' not in d or d['tcp'] is None:
raise TransportMetaDeserializationError('Missing "tcp" field in input.')
else:
tcp = bool(d['tcp'])
if 'timeout' not in d or d['timeout'] is None:
raise TransportMetaDeserializationError('Missing "timeout" field in input.')
else:
try:
timeout = int(d['timeout'])/1000.0
except ValueError:
raise TransportMetaDeserializationError('Invalid timeout value: %s' % d['timeout'])
return cls(req, dst, tcp, timeout, dport, src, sport)
def serialize_response(self):
d = OrderedDict()
if self.res is not None:
d['res'] = lb2s(base64.b64encode(self.res))
else:
d['res'] = None
if self.err is not None:
if isinstance(self.err, (socket.error, EOFError)):
d['err'] = 'NETWORK_ERROR'
elif isinstance(self.err, dns.exception.Timeout):
d['err'] = 'TIMEOUT'
else:
d['err'] = 'ERROR'
if hasattr(self.err, 'errno'):
errno_name = errno.errorcode.get(self.err.errno, None)
if errno_name is not None:
d['errno'] = errno_name
d['src'] = self.src
d['sport'] = self.sport
d['time_elapsed'] = int((self.end_time - self.start_time)*1000)
return d
def deserialize_response(self, d):
if 'err' in d and d['err'] is not None:
if d['err'] == 'NETWORK_ERROR':
self.err = socket.error()
if 'errno' in d and d['errno'] is not None:
if hasattr(errno, d['errno']):
self.err.errno = getattr(errno, d['errno'])
else:
raise TransportMetaDeserializationError('Unknown errno name: %s' % d['errno'])
elif d['err'] == 'TIMEOUT':
self.err = dns.exception.Timeout()
else:
raise TransportMetaDeserializationError('Unknown DNS response error: %s' % d['err'])
elif not ('res' in d and d['res'] is not None):
raise TransportMetaDeserializationError('Missing DNS response or response error in input.')
else:
try:
self.res = base64.b64decode(d['res'])
except TypeError:
raise TransportMetaDeserializationError('Base64 decoding of DNS response failed: %s' % d['res'])
if 'src' in d and d['src'] is not None:
try:
self.src = IPAddr(d['src'])
except ValueError:
raise TransportMetaDeserializationError('Invalid source IP address: %s' % d['src'])
elif not isinstance(self.err, socket.error):
raise TransportMetaDeserializationError('Missing "src" field in input')
if 'sport' in d and d['sport'] is not None:
try:
self.sport = int(d['sport'])
if self.sport < 0 or self.sport > 65535:
raise ValueError()
except ValueError:
raise TransportMetaDeserializationError('Invalid source port: %s' % d['sport'])
elif not isinstance(self.err, socket.error):
raise TransportMetaDeserializationError('Missing "sport" field in input.')
if 'time_elapsed' in d and d['time_elapsed'] is not None:
try:
elapsed = int(d['time_elapsed'])
if elapsed < 0:
raise ValueError()
except ValueError:
raise TransportMetaDeserializationError('Invalid time elapsed value: %s' % d['time_elapsed'])
else:
raise TransportMetaDeserializationError('Missing "time_elapsed" field in input.')
self.end_time = time.time()
self.start_time = self.end_time - (elapsed/1000.0)
QTH_MODE_WRITE_READ = 0
QTH_MODE_WRITE = 1
QTH_MODE_READ = 2
QTH_SETUP_DONE = 0
QTH_SETUP_NEED_WRITE = 1
QTH_SETUP_NEED_READ = 2
class DNSQueryTransportHandler(object):
singleton = False
allow_loopback_query = False
allow_private_query = False
timeout_baseline = 0.0
mode = QTH_MODE_WRITE_READ
def __init__(self, sock=None, recycle_sock=False, processed_queue=None, factory=None):
self.msg_send = None
self.msg_send_len = None
self.msg_send_index = None
self.msg_recv = None
self.msg_recv_len = None
self.msg_recv_buf = None
self.msg_recv_index = None
self.err = None
self.dst = None
self.dport = None
self.src = None
self.sport = None
self.transport_type = None
self.timeout = None
self._processed_queue = processed_queue
self.factory = factory
self._sock = sock
self.sock = None
self.recycle_sock = recycle_sock
self.expiration = None
self.start_time = None
self.end_time = None
self.setup_state = QTH_SETUP_DONE
self.qtms = []
def _set_timeout(self, qtm):
if self.timeout is None or qtm.timeout > self.timeout:
self.timeout = qtm.timeout
def add_qtm(self, qtm):
if self.singleton and self.qtms:
raise TypeError('Only one DNSQueryTransportMeta instance allowed for DNSQueryTransportHandlers of singleton type!')
self.qtms.append(qtm)
self._set_timeout(qtm)
def _check_source(self):
if self.src in (ANY_IPV6, ANY_IPV4):
self.src = None
def finalize(self):
assert self.mode in (QTH_MODE_WRITE_READ, QTH_MODE_READ), 'finalize() can only be called for modes QTH_MODE_READ and QTH_MODE_WRITE_READ'
assert self.msg_recv is not None or self.err is not None, 'Query must have been executed before finalize() can be called'
self._check_source()
# clear out any partial responses if there was an error
if self.err is not None:
self.msg_recv = None
if self.factory is not None:
if self.recycle_sock:
# if recycle_sock is requested, add the sock to the factory.
# Then add the lock to the sock to prevent concurrent use of
# the socket.
if self.sock is not None and self._sock is None:
self.factory.lock.acquire()
try:
if self.factory.sock is None:
self.factory.sock = self.sock
self.factory.sock.lock = self.factory.lock
finally:
self.factory.lock.release()
elif self.sock is not None and self.sock is self.factory.sock:
# if recycle_sock is not requested, and this sock is in the
# factory, then remove it.
self.factory.lock.acquire()
try:
if self.sock is self.factory.sock:
self.factory.sock = None
finally:
self.factory.lock.release()
#TODO change this and the overriding child methods to init_msg_send
def init_req(self):
raise NotImplemented
def _init_msg_recv(self):
self.msg_recv = b''
self.msg_recv_buf = b''
self.msg_recv_index = 0
self.msg_recv_len = None
def prepare(self):
if self.mode in (QTH_MODE_WRITE_READ, QTH_MODE_WRITE):
assert self.msg_send is not None, 'Request must be initialized with init_req() before be added before prepare() can be called'
if self.mode in (QTH_MODE_WRITE_READ, QTH_MODE_READ):
self._init_msg_recv()
if self.timeout is None:
self.timeout = self.timeout_baseline
if self._sock is not None:
# if a pre-existing socket is available for re-use, then use that
# instead
try:
self._reuse_socket()
self._set_start_time()
except SocketInUse as e:
self.err = e
else:
try:
self._create_socket()
self._configure_socket()
self._bind_socket()
self._set_start_time()
self._connect_socket()
except socket.error as e:
self.err = e
def _reuse_socket(self):
# wait for the lock on the socket
if not self._sock.lock.acquire(False):
raise SocketInUse()
self.sock = self._sock
def _get_af(self):
if self.dst.version == 6:
return socket.AF_INET6
else:
return socket.AF_INET
def _create_socket(self):
af = self._get_af()
self.sock = Socket(socket.socket(af, self.transport_type))
def _configure_socket(self):
self.sock.setblocking(0)
def _bind_socket(self):
if self.src is not None:
src = self.src
else:
if self.sock.family == socket.AF_INET6:
src = ANY_IPV6
else:
src = ANY_IPV4
if self.sport is not None:
self.sock.bind((src, self.sport))
else:
i = 0
while True:
sport = random.randint(1024, 65535)
try:
self.sock.bind((src, sport))
break
except socket.error as e:
i += 1
if i > MAX_PORT_BIND_ATTEMPTS or e.errno != socket.errno.EADDRINUSE:
raise
def _set_socket_info(self):
src, sport = self.sock.getsockname()[:2]
self.src = IPAddr(src)
self.sport = sport
def _get_connect_arg(self):
return (self.dst, self.dport)
def _connect_socket(self):
try:
self.sock.connect(self._get_connect_arg())
except socket.error as e:
if e.errno != socket.errno.EINPROGRESS:
raise
def _set_start_time(self):
self.expiration = self.timeout + time.time()
self.start_time = time.time()
def _set_end_time(self):
self.end_time = time.time()
if self.start_time is None:
self.start_time = self.end_time
def cleanup(self):
# set end (and start, if necessary) times, as appropriate
self._set_end_time()
# close socket
if self.sock is not None:
self._set_socket_info()
if not self.recycle_sock:
self.sock.close()
if self.sock.lock is not None:
self.sock.lock.release()
# place in processed queue, if specified
if self._processed_queue is not None:
self._processed_queue.put(self)
def do_setup(self):
pass
def do_write(self):
try:
self.msg_send_index += self.sock.send(self.msg_send[self.msg_send_index:])
if self.msg_send_index >= self.msg_send_len:
return True
except socket.error as e:
self.err = e
return True
def do_read(self):
raise NotImplemented
def do_timeout(self):
raise NotImplemented
def serialize_requests(self):
d = {
'version': DNS_TRANSPORT_VERSION,
'requests': [q.serialize_request() for q in self.qtms]
}
return d
def serialize_responses(self):
d = {
'version': DNS_TRANSPORT_VERSION,
'responses': [q.serialize_response() for q in self.qtms]
}
return d
class DNSQueryTransportHandlerDNS(DNSQueryTransportHandler):
singleton = True
require_queryid_match = True
def finalize(self):
super(DNSQueryTransportHandlerDNS, self).finalize()
qtm = self.qtms[0]
qtm.src = self.src
qtm.sport = self.sport
qtm.res = self.msg_recv
qtm.err = self.err
qtm.start_time = self.start_time
qtm.end_time = self.end_time
def init_req(self):
assert self.qtms, 'At least one DNSQueryTransportMeta must be added before init_req() can be called'
qtm = self.qtms[0]
self.dst = qtm.dst
self.dport = qtm.dport
self.src = qtm.src
self.sport = qtm.sport
self.msg_send = qtm.req
self.msg_send_len = len(qtm.req)
self.msg_send_index = 0
# python3/python2 dual compatibility
if isinstance(self.msg_send, str):
map_func = lambda x: ord(x)
else:
map_func = lambda x: x
self._queryid_wire = self.msg_send[:2]
index = 12
while map_func(self.msg_send[index]) != 0:
index += map_func(self.msg_send[index]) + 1
index += 4
self._question_wire = self.msg_send[12:index]
if qtm.tcp:
self.transport_type = socket.SOCK_STREAM
self.msg_send = struct.pack(b'!H', self.msg_send_len) + self.msg_send
self.msg_send_len += struct.calcsize(b'H')
else:
self.transport_type = socket.SOCK_DGRAM
def _check_msg_recv_consistency(self):
if self.require_queryid_match and self.msg_recv[:2] != self._queryid_wire:
return False
return True
def do_read(self):
# UDP
if self.sock.type == socket.SOCK_DGRAM:
try:
self.msg_recv = self.sock.recv(65536)
if self._check_msg_recv_consistency():
return True
else:
self.msg_recv = b''
except socket.error as e:
self.err = e
return True
# TCP
else:
try:
if self.msg_recv_len is None:
if self.msg_recv_buf:
buf = self.sock.recv(1)
else:
buf = self.sock.recv(2)
if buf == b'':
raise EOFError()
self.msg_recv_buf += buf
if len(self.msg_recv_buf) == 2:
self.msg_recv_len = struct.unpack(b'!H', self.msg_recv_buf)[0]
if self.msg_recv_len is not None:
buf = self.sock.recv(self.msg_recv_len - self.msg_recv_index)
if buf == b'':
raise EOFError()
self.msg_recv += buf
self.msg_recv_index = len(self.msg_recv)
if self.msg_recv_index >= self.msg_recv_len:
return True
except (socket.error, EOFError) as e:
if isinstance(e, socket.error) and e.errno == socket.errno.EAGAIN:
pass
else:
self.err = e
return True
def do_timeout(self):
self.err = dns.exception.Timeout()
class DNSQueryTransportHandlerDNSPrivate(DNSQueryTransportHandlerDNS):
allow_loopback_query = True
allow_private_query = True
class DNSQueryTransportHandlerDNSLoose(DNSQueryTransportHandlerDNS):
require_queryid_match = False
class DNSQueryTransportHandlerMulti(DNSQueryTransportHandler):
singleton = False
def _set_timeout(self, qtm):
if self.timeout is None:
# allow 5 seconds for looking glass overhead, as a baseline
self.timeout = self.timeout_baseline
# account for worst case, in which case queries are performed serially
# on the remote end
self.timeout += qtm.timeout
def finalize(self):
super(DNSQueryTransportHandlerMulti, self).finalize()
# if there was an error, then re-raise it here
if self.err is not None:
raise self.err
# if there is no content, raise an exception
if self.msg_recv is None:
raise RemoteQueryTransportError('No content in response')
# load the json content
try:
content = json.loads(codecs.decode(self.msg_recv, 'utf-8'))
except ValueError:
raise RemoteQueryTransportError('JSON decoding of response failed: %s' % self.msg_recv)
if 'version' not in content:
raise RemoteQueryTransportError('No version information in response.')
try:
major_vers, minor_vers = [int(x) for x in str(content['version']).split('.', 1)]
except ValueError:
raise RemoteQueryTransportError('Version of JSON input in response is invalid: %s' % content['version'])
# ensure major version is a match and minor version is no greater
# than the current minor version
curr_major_vers, curr_minor_vers = [int(x) for x in str(DNS_TRANSPORT_VERSION).split('.', 1)]
if major_vers != curr_major_vers or minor_vers > curr_minor_vers:
raise RemoteQueryTransportError('Version %d.%d of JSON input in response is incompatible with this software.' % (major_vers, minor_vers))
if 'error' in content:
raise RemoteQueryTransportError('Remote query error: %s' % content['error'])
if self.mode == QTH_MODE_WRITE_READ:
if 'responses' not in content:
raise RemoteQueryTransportError('No DNS response information in response.')
else: # self.mode == QTH_MODE_READ:
if 'requests' not in content:
raise RemoteQueryTransportError('No DNS requests information in response.')
for i in range(len(self.qtms)):
try:
if self.mode == QTH_MODE_WRITE_READ:
self.qtms[i].deserialize_response(content['responses'][i])
else: # self.mode == QTH_MODE_READ:
self.qtms[i].deserialize_request(content['requests'][i])
except IndexError:
raise RemoteQueryTransportError('DNS response or request information missing from message')
except TransportMetaDeserializationError as e:
raise RemoteQueryTransportError(str(e))
class DNSQueryTransportHandlerHTTP(DNSQueryTransportHandlerMulti):
timeout_baseline = 5.0
def __init__(self, url, insecure=False, sock=None, recycle_sock=True, processed_queue=None, factory=None):
super(DNSQueryTransportHandlerHTTP, self).__init__(sock=sock, recycle_sock=recycle_sock, processed_queue=processed_queue, factory=factory)
self.transport_type = socket.SOCK_STREAM
parse_result = urlparse.urlparse(url)
scheme = parse_result.scheme
if not scheme:
scheme = 'http'
elif scheme not in ('http', 'https'):
raise RemoteQueryTransportError('Invalid scheme: %s' % scheme)
self.use_ssl = scheme == 'https'
self.host = parse_result.hostname
self.dport = parse_result.port
if self.dport is None:
if scheme == 'http':
self.dport = 80
else: # scheme == 'https'
self.dport = 443
self.path = parse_result.path
self.username = parse_result.username
self.password = parse_result.password
self.insecure = insecure
if self.use_ssl:
self.setup_state = QTH_SETUP_NEED_WRITE
else:
self.setup_state = QTH_SETUP_DONE
af = 0
try:
addrinfo = socket.getaddrinfo(self.host, self.dport, af, self.transport_type)
except socket.gaierror:
raise RemoteQueryTransportError('Unable to resolve name of HTTP host: %s' % self.host)
self.dst = IPAddr(addrinfo[0][4][0])
self.chunked_encoding = None
def _upgrade_socket_to_ssl(self):
if isinstance(self.sock.sock, ssl.SSLSocket):
return
#XXX this is python >= 2.7.9 only
ctx = ssl.create_default_context()
if self.insecure:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
new_sock = Socket(ctx.wrap_socket(self.sock.sock, server_hostname=self.host,
do_handshake_on_connect=False))
new_sock.lock = self.sock.lock
self.sock = new_sock
def _post_data(self):
return 'content=' + urlquote.quote(json.dumps(self.serialize_requests()))
def _authentication_header(self):
if not self.username:
return ''
# set username/password
username = self.username
if self.password:
username += ':' + self.password
return 'Authorization: Basic %s\r\n' % (lb2s(base64.b64encode(codecs.encode(username, 'utf-8'))))
def init_req(self):
data = self._post_data()
self.msg_send = codecs.encode('POST %s HTTP/1.1\r\nHost: %s\r\nUser-Agent: DNSViz/0.9.3\r\nAccept: application/json\r\n%sContent-Length: %d\r\nContent-Type: application/x-www-form-urlencoded\r\n\r\n%s' % (self.path, self.host, self._authentication_header(), len(data), data), 'latin1')
self.msg_send_len = len(self.msg_send)
self.msg_send_index = 0
def prepare(self):
super(DNSQueryTransportHandlerHTTP, self).prepare()
if self.err is not None and not isinstance(self.err, SocketInUse):
self.err = RemoteQueryTransportError('Error making HTTP connection: %s' % self.err)
def do_setup(self):
if self.use_ssl:
self._upgrade_socket_to_ssl()
try:
self.sock.sock.do_handshake()
except ssl.SSLWantReadError:
self.setup_state = QTH_SETUP_NEED_READ
except ssl.SSLWantWriteError:
self.setup_state = QTH_SETUP_NEED_WRITE
except ssl.SSLError as e:
self.err = RemoteQueryTransportError('SSL Error: %s' % e)
else:
self.setup_state = QTH_SETUP_DONE
def do_write(self):
val = super(DNSQueryTransportHandlerHTTP, self).do_write()
if self.err is not None:
self.err = RemoteQueryTransportError('Error making HTTP request: %s' % self.err)
return val
def do_read(self):
try:
try:
buf = self.sock.recv(65536)
except ssl.SSLWantReadError:
return False
if buf == b'':
raise EOFError
self.msg_recv_buf += buf
# still reading status and headers
if self.chunked_encoding is None and self.msg_recv_len is None:
headers_end_match = HTTP_HEADER_END_RE.search(lb2s(self.msg_recv_buf))
if headers_end_match is not None:
headers = self.msg_recv_buf[:headers_end_match.start()]
self.msg_recv_buf = self.msg_recv_buf[headers_end_match.end():]
# check HTTP status
status_match = HTTP_STATUS_RE.search(lb2s(headers))
if status_match is None:
self.err = RemoteQueryTransportError('Malformed HTTP status line')
return True
status = int(status_match.group('status'))
if status != 200:
self.err = RemoteQueryTransportError('%d HTTP status' % status)
return True
# get content length or determine whether "chunked"
# transfer encoding is used
content_length_match = CONTENT_LENGTH_RE.search(lb2s(headers))
if content_length_match is not None:
self.chunked_encoding = False
self.msg_recv_len = int(content_length_match.group('length'))
else:
self.chunked_encoding = CHUNKED_ENCODING_RE.search(lb2s(headers)) is not None
# handle chunked encoding first
if self.chunked_encoding:
# look through as many chunks as are readily available
# (without having to read from socket again)
while self.msg_recv_buf:
if self.msg_recv_len is None:
# looking for chunk length
# strip off beginning CRLF, if any
# (this is for chunks after the first one)
crlf_start_match = CRLF_START_RE.search(lb2s(self.msg_recv_buf))
if crlf_start_match is not None:
self.msg_recv_buf = self.msg_recv_buf[crlf_start_match.end():]
# find the chunk length
chunk_len_match = CHUNK_SIZE_RE.search(lb2s(self.msg_recv_buf))
if chunk_len_match is not None:
self.msg_recv_len = int(chunk_len_match.group('length'), 16)
self.msg_recv_buf = self.msg_recv_buf[chunk_len_match.end():]
self.msg_recv_index = 0
else:
# if we don't currently know the length of the next
# chunk, and we don't have enough data to find the
# length, then break out of the loop because we
# don't have any more data to go off of.
break
if self.msg_recv_len is not None:
# we know a length of the current chunk
if self.msg_recv_len == 0:
# no chunks left, so clean up and return
return True
# read remaining bytes
bytes_remaining = self.msg_recv_len - self.msg_recv_index
if len(self.msg_recv_buf) > bytes_remaining:
self.msg_recv += self.msg_recv_buf[:bytes_remaining]
self.msg_recv_index = 0
self.msg_recv_buf = self.msg_recv_buf[bytes_remaining:]
self.msg_recv_len = None
else:
self.msg_recv += self.msg_recv_buf
self.msg_recv_index += len(self.msg_recv_buf)
self.msg_recv_buf = b''
elif self.chunked_encoding == False:
# output is not chunked, so we're either reading until we've
# read all the bytes specified by the content-length header (if
# specified) or until the server closes the connection (or we
# time out)
if self.msg_recv_len is not None:
bytes_remaining = self.msg_recv_len - self.msg_recv_index
self.msg_recv += self.msg_recv_buf[:bytes_remaining]
self.msg_recv_buf = self.msg_recv_buf[bytes_remaining:]
self.msg_recv_index = len(self.msg_recv)
if self.msg_recv_index >= self.msg_recv_len:
return True
else:
self.msg_recv += self.msg_recv_buf
self.msg_recv_buf = b''
except (socket.error, EOFError) as e:
if isinstance(e, socket.error) and e.errno == socket.errno.EAGAIN:
pass
else:
# if we weren't passed any content length header, and we're not
# using chunked encoding, then don't throw an error. If the
# content was bad, then it will be reflected in the decoding of
# the content
if self.chunked_encoding == False and self.msg_recv_len is None:
pass
else:
self.err = RemoteQueryTransportError('Error communicating with HTTP server: %s' % e)
return True
def do_timeout(self):
self.err = RemoteQueryTransportError('HTTP request timed out')
class DNSQueryTransportHandlerHTTPPrivate(DNSQueryTransportHandlerHTTP):
allow_loopback_query = True
allow_private_query = True
class DNSQueryTransportHandlerWebSocketServer(DNSQueryTransportHandlerMulti):
timeout_baseline = 5.0
unmask_on_recv = True
def __init__(self, path, sock=None, recycle_sock=True, processed_queue=None, factory=None):
super(DNSQueryTransportHandlerWebSocketServer, self).__init__(sock=sock, recycle_sock=recycle_sock, processed_queue=processed_queue, factory=factory)
self.dst = path
self.transport_type = socket.SOCK_STREAM
self.mask_mapping = []
self.has_more = None
def _get_af(self):
return socket.AF_UNIX
def _bind_socket(self):
pass
def _set_socket_info(self):
pass
def _get_connect_arg(self):
return self.dst
def prepare(self):
super(DNSQueryTransportHandlerWebSocketServer, self).prepare()
if self.err is not None and not isinstance(self.err, SocketInUse):
self.err = RemoteQueryTransportError('Error connecting to UNIX domain socket: %s' % self.err)
def do_write(self):
val = super(DNSQueryTransportHandlerWebSocketServer, self).do_write()
if self.err is not None:
self.err = RemoteQueryTransportError('Error writing to UNIX domain socket: %s' % self.err)
return val
def finalize(self):
if self.unmask_on_recv:
# python3/python2 dual compatibility
if isinstance(self.msg_recv, str):
decode_func = lambda x: struct.unpack(b'!B', x)[0]
else:
decode_func = lambda x: x
new_msg_recv = b''
for i, mask_index in enumerate(self.mask_mapping):
mask_octets = struct.unpack(b'!BBBB', self.msg_recv[mask_index:mask_index + 4])
if i >= len(self.mask_mapping) - 1:
buf = self.msg_recv[mask_index + 4:]
else:
buf = self.msg_recv[mask_index + 4:self.mask_mapping[i + 1]]
for j in range(len(buf)):
b = decode_func(buf[j])
new_msg_recv += struct.pack(b'!B', b ^ mask_octets[j % 4])
self.msg_recv = new_msg_recv
super(DNSQueryTransportHandlerWebSocketServer, self).finalize()
def init_req(self):
data = codecs.encode(json.dumps(self.serialize_requests()), 'utf-8')
header = b'\x81'
l = len(data)
if l <= 125:
header += struct.pack(b'!B', l)
elif l <= 0xffff:
header += struct.pack(b'!BH', 126, l)
else: # 0xffff < len <= 2^63
header += struct.pack(b'!BLL', 127, 0, l)
self.msg_send = header + data
self.msg_send_len = len(self.msg_send)
self.msg_send_index = 0
def init_empty_msg_send(self):
self.msg_send = b'\x81\x00'
self.msg_send_len = len(self.msg_send)
self.msg_send_index = 0
def do_read(self):
try:
buf = self.sock.recv(65536)
if buf == b'':
raise EOFError
self.msg_recv_buf += buf
# look through as many frames as are readily available
# (without having to read from socket again)
while self.msg_recv_buf:
if self.msg_recv_len is None:
# looking for frame length
if len(self.msg_recv_buf) >= 2:
byte0, byte1 = struct.unpack(b'!BB', self.msg_recv_buf[0:2])
byte1b = byte1 & 0x7f
# mask must be set
if not byte1 & 0x80:
if self.err is not None:
self.err = RemoteQueryTransportError('Mask bit not set in message from server')
return True
# check for FIN flag
self.has_more = not bool(byte0 & 0x80)
# determine the header length
if byte1b <= 125:
header_len = 2
elif byte1b == 126:
header_len = 4
else: # byte1b == 127:
header_len = 10
if len(self.msg_recv_buf) >= header_len:
if byte1b <= 125:
self.msg_recv_len = byte1b
elif byte1b == 126:
self.msg_recv_len = struct.unpack(b'!H', self.msg_recv_buf[2:4])[0]
else: # byte1b == 127:
self.msg_recv_len = struct.unpack(b'!Q', self.msg_recv_buf[2:10])[0]
if self.unmask_on_recv:
# handle mask
self.mask_mapping.append(len(self.msg_recv))
self.msg_recv_len += 4
self.msg_recv_buf = self.msg_recv_buf[header_len:]
else:
# if we don't currently know the length of the next
# frame, and we don't have enough data to find the
# length, then break out of the loop because we
# don't have any more data to go off of.
break
if self.msg_recv_len is not None:
# we know a length of the current chunk
# read remaining bytes
bytes_remaining = self.msg_recv_len - self.msg_recv_index
if len(self.msg_recv_buf) > bytes_remaining:
self.msg_recv += self.msg_recv_buf[:bytes_remaining]
self.msg_recv_index = 0
self.msg_recv_buf = self.msg_recv_buf[bytes_remaining:]
self.msg_recv_len = None
else:
self.msg_recv += self.msg_recv_buf
self.msg_recv_index += len(self.msg_recv_buf)
self.msg_recv_buf = b''
if self.msg_recv_index >= self.msg_recv_len and not self.has_more:
return True
except (socket.error, EOFError) as e:
if isinstance(e, socket.error) and e.errno == socket.errno.EAGAIN:
pass
else:
self.err = e
return True
def do_timeout(self):
self.err = RemoteQueryTransportError('Read of UNIX domain socket timed out')
class DNSQueryTransportHandlerWebSocketServerPrivate(DNSQueryTransportHandlerWebSocketServer):
allow_loopback_query = True
allow_private_query = True
class DNSQueryTransportHandlerWebSocketClient(DNSQueryTransportHandlerWebSocketServer):
unmask_on_recv = False
def __init__(self, sock, recycle_sock=True, processed_queue=None, factory=None):
super(DNSQueryTransportHandlerWebSocketClient, self).__init__(None, sock=sock, recycle_sock=recycle_sock, processed_queue=processed_queue, factory=factory)
def _init_req(self, data):
header = b'\x81'
l = len(data)
if l <= 125:
header += struct.pack(b'!B', l | 0x80)
elif l <= 0xffff:
header += struct.pack(b'!BH', 126 | 0x80, l)
else: # 0xffff < len <= 2^63
header += struct.pack(b'!BLL', 127 | 0x80, 0, l)
mask_int = random.randint(0, 0xffffffff)
mask = [(mask_int >> 24) & 0xff,
(mask_int >> 16) & 0xff,
(mask_int >> 8) & 0xff,
mask_int & 0xff]
header += struct.pack(b'!BBBB', *mask)
# python3/python2 dual compatibility
if isinstance(data, str):
map_func = lambda x: ord(x)
else:
map_func = lambda x: x
self.msg_send = header
for i, b in enumerate(data):
self.msg_send += struct.pack(b'!B', mask[i % 4] ^ map_func(b))
self.msg_send_len = len(self.msg_send)
self.msg_send_index = 0
def init_req(self):
self._init_req(codecs.encode(json.dumps(self.serialize_responses()), 'utf-8'))
def init_err_send(self, err):
self._init_req(codecs.encode(err, 'utf-8'))
class DNSQueryTransportHandlerWebSocketClientReader(DNSQueryTransportHandlerWebSocketClient):
mode = QTH_MODE_READ
class DNSQueryTransportHandlerWebSocketClientWriter(DNSQueryTransportHandlerWebSocketClient):
mode = QTH_MODE_WRITE
class DNSQueryTransportHandlerCmd(DNSQueryTransportHandlerWebSocketServer):
allow_loopback_query = True
allow_private_query = True
def __init__(self, args, sock=None, recycle_sock=True, processed_queue=None, factory=None):
super(DNSQueryTransportHandlerCmd, self).__init__(None, sock=sock, recycle_sock=recycle_sock, processed_queue=processed_queue, factory=factory)
self.args = args
def _get_af(self):
return None
def _bind_socket(self):
pass
def _set_socket_info(self):
pass
def _get_connect_arg(self):
return None
def _create_socket(self):
try:
p = subprocess.Popen(self.args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
except OSError as e:
raise socket.error(str(e))
else:
self.sock = ReaderWriter(io.open(p.stdout.fileno(), 'rb'), io.open(p.stdin.fileno(), 'wb'), p)
def _connect_socket(self):
pass
def do_write(self):
if self.sock.proc.poll() is not None:
self.err = RemoteQueryTransportError('Subprocess has ended with status %d.' % (self.sock.proc.returncode))
return True
return super(DNSQueryTransportHandlerCmd, self).do_write()
def do_read(self):
if self.sock.proc.poll() is not None:
self.err = RemoteQueryTransportError('Subprocess has ended with status %d.' % (self.sock.proc.returncode))
return True
return super(DNSQueryTransportHandlerCmd, self).do_read()
def cleanup(self):
super(DNSQueryTransportHandlerCmd, self).cleanup()
if self.sock is not None and not self.recycle_sock and self.sock.proc is not None and self.sock.proc.poll() is None:
self.sock.proc.terminate()
self.sock.proc.wait()
return True
class DNSQueryTransportHandlerRemoteCmd(DNSQueryTransportHandlerCmd):
timeout_baseline = 10.0
def __init__(self, url, sock=None, recycle_sock=True, processed_queue=None, factory=None):
parse_result = urlparse.urlparse(url)
scheme = parse_result.scheme
if not scheme:
scheme = 'ssh'
elif scheme != 'ssh':
raise RemoteQueryTransportError('Invalid scheme: %s' % scheme)
args = ['ssh', '-T']
if parse_result.port is not None:
args.extend(['-p', str(parse_result.port)])
if parse_result.username is not None:
args.append('%s@%s' % (parse_result.username, parse_result.hostname))
else:
args.append('%s' % (parse_result.hostname))
if parse_result.path and parse_result.path != '/':
args.append(parse_result.path)
else:
args.append('dnsviz lookingglass')
super(DNSQueryTransportHandlerRemoteCmd, self).__init__(args, sock=sock, recycle_sock=recycle_sock, processed_queue=processed_queue, factory=factory)
def _get_af(self):
return None
def _bind_socket(self):
pass
def _set_socket_info(self):
pass
def _get_connect_arg(self):
return None
def _create_socket(self):
try:
p = subprocess.Popen(self.args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
except OSError as e:
raise socket.error(str(e))
else:
self.sock = ReaderWriter(io.open(p.stdout.fileno(), 'rb'), io.open(p.stdin.fileno(), 'wb'), p)
def _connect_socket(self):
pass
def do_write(self):
if self.sock.proc.poll() is not None:
self.err = RemoteQueryTransportError('Subprocess has ended with status %d.' % (self.sock.proc.returncode))
return True
return super(DNSQueryTransportHandlerCmd, self).do_write()
def do_read(self):
if self.sock.proc.poll() is not None:
self.err = RemoteQueryTransportError('Subprocess has ended with status %d.' % (self.sock.proc.returncode))
return True
return super(DNSQueryTransportHandlerCmd, self).do_read()
def cleanup(self):
super(DNSQueryTransportHandlerCmd, self).cleanup()
if self.sock is not None and not self.recycle_sock and self.sock.proc is not None and self.sock.proc.poll() is None:
self.sock.proc.terminate()
self.sock.proc.wait()
return True
class DNSQueryTransportHandlerFactory(object):
cls = DNSQueryTransportHandler
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.kwargs['factory'] = self
self.lock = threading.Lock()
self.sock = None
def __del__(self):
if self.sock is not None:
self.sock.close()
def build(self, **kwargs):
if 'sock' not in kwargs and self.sock is not None:
kwargs['sock'] = self.sock
for name in self.kwargs:
if name not in kwargs:
kwargs[name] = self.kwargs[name]
return self.cls(*self.args, **kwargs)
class DNSQueryTransportHandlerDNSFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerDNS
class DNSQueryTransportHandlerDNSPrivateFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerDNSPrivate
class DNSQueryTransportHandlerHTTPFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerHTTP
class DNSQueryTransportHandlerHTTPPrivateFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerHTTPPrivate
class _DNSQueryTransportHandlerWebSocketServerFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerWebSocketServer
class DNSQueryTransportHandlerWebSocketServerFactory:
def __init__(self, *args, **kwargs):
self._f = _DNSQueryTransportHandlerWebSocketServerFactory(*args, **kwargs)
def __del__(self):
try:
qth = self._f.build()
qth.init_empty_msg_send()
qth.prepare()
qth.do_write()
except:
pass
@property
def cls(self):
return self._f.__class__.cls
def build(self, **kwargs):
return self._f.build(**kwargs)
class _DNSQueryTransportHandlerWebSocketServerPrivateFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerWebSocketServerPrivate
class DNSQueryTransportHandlerWebSocketServerPrivateFactory:
def __init__(self, *args, **kwargs):
self._f = _DNSQueryTransportHandlerWebSocketServerPrivateFactory(*args, **kwargs)
def __del__(self):
try:
qth = self._f.build()
qth.init_empty_msg_send()
qth.prepare()
qth.do_write()
except:
pass
@property
def cls(self):
return self._f.__class__.cls
def build(self, **kwargs):
return self._f.build(**kwargs)
class DNSQueryTransportHandlerCmdFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerCmd
class DNSQueryTransportHandlerRemoteCmdFactory(DNSQueryTransportHandlerFactory):
cls = DNSQueryTransportHandlerRemoteCmd
class DNSQueryTransportHandlerWrapper(object):
def __init__(self, qh):
self.qh = qh
def __eq__(self, other):
return False
def __lt__(self, other):
return False
class _DNSQueryTransportManager:
'''A class that handles'''
#TODO might need FD_SETSIZE to support lots of fds
def __init__(self):
self._notify_read_fd, self._notify_write_fd = os.pipe()
fcntl.fcntl(self._notify_read_fd, fcntl.F_SETFL, os.O_NONBLOCK)
self._msg_queue = queue.Queue()
self._event_map = {}
self._close = threading.Event()
# python3/python2 dual compatibility
try:
# python 3
t = threading.Thread(target=self._loop, daemon=True)
except TypeError:
# python 2
t = threading.Thread(target=self._loop)
t.daemon = True
t.start()
def close(self):
self._close.set()
os.write(self._notify_write_fd, struct.pack(b'!B', 0))
def handle_msg(self, qh):
self._event_map[qh] = threading.Event()
self._handle_msg(qh, True)
self._event_map[qh].wait()
del self._event_map[qh]
def handle_msg_nowait(self, qh):
self._handle_msg(qh, True)
def _handle_msg(self, qh, notify):
self._msg_queue.put(qh)
if notify:
os.write(self._notify_write_fd, struct.pack(b'!B', 0))
def _loop(self):
'''Return the data resulting from a UDP transaction.'''
query_meta = {}
expirations = []
# initialize "in" fds for select
rlist_in = [self._notify_read_fd]
wlist_in = []
xlist_in = []
while True:
# determine the new expiration
if expirations:
timeout = max(expirations[0][0] - time.time(), 0)
else:
timeout = MAX_WAIT_FOR_REQUEST
finished_fds = []
rlist_out, wlist_out, xlist_out = select.select(rlist_in, wlist_in, xlist_in, timeout)
# if we have been signalled to exit, then do that
if self._close.is_set():
break
# handle the requests
for fd in wlist_out:
qh = query_meta[fd]
if qh.setup_state == QTH_SETUP_DONE:
if qh.do_write():
if qh.err is not None or qh.mode == QTH_MODE_WRITE:
qh.cleanup()
finished_fds.append(fd)
else: # qh.mode == QTH_MODE_WRITE_READ
wlist_in.remove(fd)
rlist_in.append(qh.sock.reader_fd)
else:
qh.do_setup()
if qh.err is not None:
qh.cleanup()
finished_fds.append(fd)
elif qh.setup_state == QTH_SETUP_NEED_READ:
wlist_in.remove(fd)
rlist_in.append(qh.sock.reader_fd)
# handle the responses
for fd in rlist_out:
if fd == self._notify_read_fd:
continue
qh = query_meta[fd]
if qh.setup_state == QTH_SETUP_DONE:
if qh.do_read(): # qh.mode in (QTH_MODE_WRITE_READ, QTH_MODE_READ)
qh.cleanup()
finished_fds.append(qh.sock.reader_fd)
else:
qh.do_setup()
if qh.err is not None:
qh.cleanup()
finished_fds.append(fd)
elif qh.setup_state in (QTH_SETUP_NEED_WRITE, QTH_SETUP_DONE):
rlist_in.remove(fd)
wlist_in.append(qh.sock.writer_fd)
# handle the expired queries
future_index = bisect.bisect_right(expirations, ((time.time(), DNSQueryTransportHandlerWrapper(None))))
for i in range(future_index):
qh = expirations[i][1].qh
# this query actually finished earlier in this iteration of the
# loop, so don't indicate that it timed out
if qh.end_time is not None:
continue
qh.do_timeout()
qh.cleanup()
finished_fds.append(qh.sock.reader_fd)
expirations = expirations[future_index:]
# for any fds that need to be finished, do it now
for fd in finished_fds:
qh = query_meta[fd]
try:
rlist_in.remove(qh.sock.reader_fd)
except ValueError:
wlist_in.remove(qh.sock.writer_fd)
if qh in self._event_map:
self._event_map[qh].set()
del query_meta[fd]
if finished_fds:
# if any sockets were finished, then notify, in case any
# queued messages are waiting to be handled.
os.write(self._notify_write_fd, struct.pack(b'!B', 0))
# handle the new queries
if self._notify_read_fd in rlist_out:
# empty the pipe
os.read(self._notify_read_fd, 65536)
requeue = []
while True:
try:
qh = self._msg_queue.get_nowait()
qh.prepare()
if qh.err is not None:
if isinstance(qh.err, SocketInUse):
# if this was a SocketInUse, just requeue, and try again
qh.err = None
requeue.append(qh)
else:
qh.cleanup()
if qh in self._event_map:
self._event_map[qh].set()
else:
# if we successfully bound and connected the
# socket, then put this socket in the write fd list
query_meta[qh.sock.reader_fd] = qh
query_meta[qh.sock.writer_fd] = qh
bisect.insort(expirations, (qh.expiration, DNSQueryTransportHandlerWrapper(qh)))
if qh.setup_state == QTH_SETUP_DONE:
if qh.mode in (QTH_MODE_WRITE_READ, QTH_MODE_WRITE):
wlist_in.append(qh.sock.writer_fd)
elif qh.mode == QTH_MODE_READ:
rlist_in.append(qh.sock.reader_fd)
else:
raise Exception('Unexpected mode: %d' % qh.mode)
elif qh.setup_state == QTH_SETUP_NEED_WRITE:
wlist_in.append(qh.sock.writer_fd)
elif qh.setup_state == QTH_SETUP_NEED_READ:
rlist_in.append(qh.sock.reader_fd)
except queue.Empty:
break
for qh in requeue:
self._handle_msg(qh, False)
class DNSQueryTransportHandlerHTTPPrivate(DNSQueryTransportHandlerHTTP):
allow_loopback_query = True
allow_private_query = True
class DNSQueryTransportManager:
def __init__(self):
self._th = _DNSQueryTransportManager()
def __del__(self):
self.close()
def handle_msg(self, qh):
return self._th.handle_msg(qh)
def handle_msg_nowait(self, qh):
return self._th.handle_msg_nowait(qh)
def close(self):
return self._th.close()