Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/dnsviz/dnsviz.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/dnsviz
diff options
context:
space:
mode:
authorCasey Deccio <casey@deccio.net>2021-08-06 08:39:48 +0300
committerCasey Deccio <casey@deccio.net>2021-08-06 08:41:22 +0300
commit25503d1c859f262074a9a9f6b16053530213789d (patch)
tree94629e49c15fd4405cdd5f40749008479249e957 /dnsviz
parentb1dbe1a85039b765fec683207c07880db955e1c8 (diff)
Use a daemon-type thread to simplify management
Diffstat (limited to 'dnsviz')
-rw-r--r--dnsviz/commands/probe.py75
-rw-r--r--dnsviz/transport.py2
2 files changed, 47 insertions, 30 deletions
diff --git a/dnsviz/commands/probe.py b/dnsviz/commands/probe.py
index b2c058c..f916c75 100644
--- a/dnsviz/commands/probe.py
+++ b/dnsviz/commands/probe.py
@@ -119,6 +119,7 @@ def _cleanup_tm():
global tm
if tm is not None:
tm.close()
+ tm = None
def _init_stub_resolver():
global resolver
@@ -491,7 +492,11 @@ class NameServerMappingsForDomain(object):
_allow_stop_at = None
_handle_file_arg = None
- def __init__(self, domain, stop_at, resolver):
+ _resolvers_initialized = False
+ _stub_resolver = None
+ _full_resolver = None
+
+ def __init__(self, domain, stop_at):
if not (self._allow_file is not None and \
self._allow_name_only is not None and \
self._allow_addr_only is not None and \
@@ -503,7 +508,6 @@ class NameServerMappingsForDomain(object):
raise argparse.ArgumentTypeError('The "+" may not be specified with this option')
self.domain = domain
- self._resolver = resolver
self._nsi = 1
self.delegation_mapping = {}
@@ -514,6 +518,23 @@ class NameServerMappingsForDomain(object):
self.delegation_mapping[(self.domain, dns.rdatatype.NS)] = dns.rrset.RRset(self.domain, dns.rdataclass.IN, dns.rdatatype.NS)
@classmethod
+ def init_resolvers(cls):
+ if not NameServerMappingsForDomain._resolvers_initialized:
+ tm = transport.DNSQueryTransportManager()
+ try:
+ NameServerMappingsForDomain._stub_resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm)
+ except ResolvConfError:
+ pass
+ NameServerMappingsForDomain._full_resolver = PrivateFullResolver(transport_manager=tm)
+ NameServerMappingsForDomain._resolvers_initialized = True
+
+ @classmethod
+ def cleanup_resolvers(cls):
+ NameServerMappingsForDomain._stub_resolver = None
+ NameServerMappingsForDomain._full_resolver = None
+ NameServerMappingsForDomain._resolvers_initialized = False
+
+ @classmethod
def _strip_port(cls, s):
# Determine whether there is a port attached to the end
match = cls.PORT_RE.search(s)
@@ -540,10 +561,14 @@ class NameServerMappingsForDomain(object):
self._handle_name_addr_mapping(name_addr)
def _handle_name_no_addr(self, name, port):
- if self._resolver is None:
- raise argparse.ArgumentTypeError('If addresses are not provided for names, then %s must have valid nameserver entries.\n' % RESOLV_CONF)
+ resolver = None
+ self.init_resolvers()
+ if self._stub_resolver is not None:
+ resolver = self._stub_resolver
+ else:
+ resolver = self._full_resolver
query_tuples = ((name, dns.rdatatype.A, dns.rdataclass.IN), (name, dns.rdatatype.AAAA, dns.rdataclass.IN))
- answer_map = self._resolver.query_multiple_for_answer(*query_tuples)
+ answer_map = resolver.query_multiple_for_answer(*query_tuples)
found_answer = False
for (n, rdtype, rdclass) in answer_map:
a = answer_map[(n, rdtype, rdclass)]
@@ -764,7 +789,7 @@ class RecursiveServersForDomain(NameServerMappingsForDomain):
_handle_file_arg = None
class DSForDomain:
- def __init__(self, domain, stop_at, resolver):
+ def __init__(self, domain, stop_at):
self.domain = domain
if stop_at and not self._allow_stop_at:
@@ -819,9 +844,6 @@ class DSForDomain:
class DomainListArgHelper:
STOP_RE = re.compile(r'^(.*)\+$')
- def __init__(self, resolver):
- self._resolver = resolver
-
@classmethod
def _strip_stop_marker(cls, s):
match = cls.STOP_RE.search(s)
@@ -857,14 +879,14 @@ class DomainListArgHelper:
if list_arg is not None:
list_arg = list_arg.strip()
- obj = cls(domain, stop_at, self._resolver)
+ obj = cls(domain, stop_at)
if list_arg:
obj.handle_list_arg(list_arg)
return obj
def _handle_list_arg(self, cls, list_arg):
- obj = cls(WILDCARD_EXPLICIT_DELEGATION, False, self._resolver)
+ obj = cls(WILDCARD_EXPLICIT_DELEGATION, False)
obj.handle_list_arg(list_arg)
return obj
@@ -883,7 +905,7 @@ class DomainListArgHelper:
class ArgHelper:
BRACKETS_RE = re.compile(r'^\[(.*)\]$')
- def __init__(self, resolver, logger):
+ def __init__(self, logger):
self.parser = None
self.odd_ports = {}
@@ -907,13 +929,12 @@ class ArgHelper:
self.args = None
self._arg_mapping = None
- self._resolver = resolver
self._logger = logger
self._zones_to_serve = []
def build_parser(self, prog):
self.parser = argparse.ArgumentParser(description='Issue diagnostic DNS queries', prog=prog)
- helper = DomainListArgHelper(self._resolver)
+ helper = DomainListArgHelper()
# python3/python2 dual compatibility
stdout_buffer = io.open(sys.stdout.fileno(), 'wb', closefd=False)
@@ -1238,12 +1259,14 @@ class ArgHelper:
def populate_recursive_servers(self):
if not self.args.authoritative_analysis and not self.args.recursive_servers:
- if self._resolver is None:
+ try:
+ resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm)
+ except ResolvConfError:
raise argparse.ArgumentTypeError('If servers are not specified with the %s option, then %s must have valid nameserver entries.\n' % \
(self._arg_mapping['recursive_servers'], RESOLV_CONF))
if (WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS) not in self.explicit_delegations:
self.explicit_delegations[(WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS)] = dns.rrset.RRset(WILDCARD_EXPLICIT_DELEGATION, dns.rdataclass.IN, dns.rdatatype.NS)
- for i, server in enumerate(self._resolver._servers):
+ for i, server in enumerate(resolver._servers):
if IPAddr(server).version == 6:
rdtype = dns.rdatatype.AAAA
else:
@@ -1451,23 +1474,16 @@ class ArgHelper:
zone.serve()
def build_helper(logger, cmd, subcmd):
- try:
- resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm)
- except ResolvConfError:
- resolver = None
-
- arghelper = ArgHelper(resolver, logger)
+ arghelper = ArgHelper(logger)
arghelper.build_parser('%s %s' % (cmd, subcmd))
return arghelper
def main(argv):
- global tm
global th_factories
global explicit_delegations
global odd_ports
try:
- _init_tm()
arghelper = build_helper(logger, sys.argv[0], argv[0])
arghelper.parse_args(argv[1:])
logger.setLevel(arghelper.get_log_level())
@@ -1516,6 +1532,10 @@ def main(argv):
kwargs = {}
dnsviz_meta = { 'version': DNS_RAW_VERSION, 'names': [lb2s(n.to_text()) for n in arghelper.names] }
+ NameServerMappingsForDomain.cleanup_resolvers()
+
+ _init_tm()
+
name_objs = []
if arghelper.args.input_file:
cache = {}
@@ -1536,6 +1556,8 @@ def main(argv):
name_objs = a.analyze(arghelper.names)
+ _cleanup_tm()
+
name_objs = [x for x in name_objs if x is not None]
if not name_objs:
@@ -1556,10 +1578,5 @@ def main(argv):
logger.error('Interrupted.')
sys.exit(4)
- # tm is global (because of possible multiprocessing), so we need to
- # explicitly close it here
- finally:
- _cleanup_tm()
-
if __name__ == "__main__":
main(sys.argv)
diff --git a/dnsviz/transport.py b/dnsviz/transport.py
index efa8b17..8d46c3c 100644
--- a/dnsviz/transport.py
+++ b/dnsviz/transport.py
@@ -1403,7 +1403,7 @@ class _DNSQueryTransportManager:
self._event_map = {}
self._close = threading.Event()
- t = threading.Thread(target=self._loop)
+ t = threading.Thread(target=self._loop, daemon=True)
t.start()
def close(self):