diff options
author | Casey Deccio <casey@deccio.net> | 2021-08-06 08:39:48 +0300 |
---|---|---|
committer | Casey Deccio <casey@deccio.net> | 2021-08-06 08:41:22 +0300 |
commit | 25503d1c859f262074a9a9f6b16053530213789d (patch) | |
tree | 94629e49c15fd4405cdd5f40749008479249e957 /dnsviz | |
parent | b1dbe1a85039b765fec683207c07880db955e1c8 (diff) |
Use a daemon-type thread to simplify management
Diffstat (limited to 'dnsviz')
-rw-r--r-- | dnsviz/commands/probe.py | 75 | ||||
-rw-r--r-- | dnsviz/transport.py | 2 |
2 files changed, 47 insertions, 30 deletions
diff --git a/dnsviz/commands/probe.py b/dnsviz/commands/probe.py index b2c058c..f916c75 100644 --- a/dnsviz/commands/probe.py +++ b/dnsviz/commands/probe.py @@ -119,6 +119,7 @@ def _cleanup_tm(): global tm if tm is not None: tm.close() + tm = None def _init_stub_resolver(): global resolver @@ -491,7 +492,11 @@ class NameServerMappingsForDomain(object): _allow_stop_at = None _handle_file_arg = None - def __init__(self, domain, stop_at, resolver): + _resolvers_initialized = False + _stub_resolver = None + _full_resolver = None + + def __init__(self, domain, stop_at): if not (self._allow_file is not None and \ self._allow_name_only is not None and \ self._allow_addr_only is not None and \ @@ -503,7 +508,6 @@ class NameServerMappingsForDomain(object): raise argparse.ArgumentTypeError('The "+" may not be specified with this option') self.domain = domain - self._resolver = resolver self._nsi = 1 self.delegation_mapping = {} @@ -514,6 +518,23 @@ class NameServerMappingsForDomain(object): self.delegation_mapping[(self.domain, dns.rdatatype.NS)] = dns.rrset.RRset(self.domain, dns.rdataclass.IN, dns.rdatatype.NS) @classmethod + def init_resolvers(cls): + if not NameServerMappingsForDomain._resolvers_initialized: + tm = transport.DNSQueryTransportManager() + try: + NameServerMappingsForDomain._stub_resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm) + except ResolvConfError: + pass + NameServerMappingsForDomain._full_resolver = PrivateFullResolver(transport_manager=tm) + NameServerMappingsForDomain._resolvers_initialized = True + + @classmethod + def cleanup_resolvers(cls): + NameServerMappingsForDomain._stub_resolver = None + NameServerMappingsForDomain._full_resolver = None + NameServerMappingsForDomain._resolvers_initialized = False + + @classmethod def _strip_port(cls, s): # Determine whether there is a port attached to the end match = cls.PORT_RE.search(s) @@ -540,10 +561,14 @@ class NameServerMappingsForDomain(object): self._handle_name_addr_mapping(name_addr) def _handle_name_no_addr(self, name, port): - if self._resolver is None: - raise argparse.ArgumentTypeError('If addresses are not provided for names, then %s must have valid nameserver entries.\n' % RESOLV_CONF) + resolver = None + self.init_resolvers() + if self._stub_resolver is not None: + resolver = self._stub_resolver + else: + resolver = self._full_resolver query_tuples = ((name, dns.rdatatype.A, dns.rdataclass.IN), (name, dns.rdatatype.AAAA, dns.rdataclass.IN)) - answer_map = self._resolver.query_multiple_for_answer(*query_tuples) + answer_map = resolver.query_multiple_for_answer(*query_tuples) found_answer = False for (n, rdtype, rdclass) in answer_map: a = answer_map[(n, rdtype, rdclass)] @@ -764,7 +789,7 @@ class RecursiveServersForDomain(NameServerMappingsForDomain): _handle_file_arg = None class DSForDomain: - def __init__(self, domain, stop_at, resolver): + def __init__(self, domain, stop_at): self.domain = domain if stop_at and not self._allow_stop_at: @@ -819,9 +844,6 @@ class DSForDomain: class DomainListArgHelper: STOP_RE = re.compile(r'^(.*)\+$') - def __init__(self, resolver): - self._resolver = resolver - @classmethod def _strip_stop_marker(cls, s): match = cls.STOP_RE.search(s) @@ -857,14 +879,14 @@ class DomainListArgHelper: if list_arg is not None: list_arg = list_arg.strip() - obj = cls(domain, stop_at, self._resolver) + obj = cls(domain, stop_at) if list_arg: obj.handle_list_arg(list_arg) return obj def _handle_list_arg(self, cls, list_arg): - obj = cls(WILDCARD_EXPLICIT_DELEGATION, False, self._resolver) + obj = cls(WILDCARD_EXPLICIT_DELEGATION, False) obj.handle_list_arg(list_arg) return obj @@ -883,7 +905,7 @@ class DomainListArgHelper: class ArgHelper: BRACKETS_RE = re.compile(r'^\[(.*)\]$') - def __init__(self, resolver, logger): + def __init__(self, logger): self.parser = None self.odd_ports = {} @@ -907,13 +929,12 @@ class ArgHelper: self.args = None self._arg_mapping = None - self._resolver = resolver self._logger = logger self._zones_to_serve = [] def build_parser(self, prog): self.parser = argparse.ArgumentParser(description='Issue diagnostic DNS queries', prog=prog) - helper = DomainListArgHelper(self._resolver) + helper = DomainListArgHelper() # python3/python2 dual compatibility stdout_buffer = io.open(sys.stdout.fileno(), 'wb', closefd=False) @@ -1238,12 +1259,14 @@ class ArgHelper: def populate_recursive_servers(self): if not self.args.authoritative_analysis and not self.args.recursive_servers: - if self._resolver is None: + try: + resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm) + except ResolvConfError: raise argparse.ArgumentTypeError('If servers are not specified with the %s option, then %s must have valid nameserver entries.\n' % \ (self._arg_mapping['recursive_servers'], RESOLV_CONF)) if (WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS) not in self.explicit_delegations: self.explicit_delegations[(WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS)] = dns.rrset.RRset(WILDCARD_EXPLICIT_DELEGATION, dns.rdataclass.IN, dns.rdatatype.NS) - for i, server in enumerate(self._resolver._servers): + for i, server in enumerate(resolver._servers): if IPAddr(server).version == 6: rdtype = dns.rdatatype.AAAA else: @@ -1451,23 +1474,16 @@ class ArgHelper: zone.serve() def build_helper(logger, cmd, subcmd): - try: - resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm) - except ResolvConfError: - resolver = None - - arghelper = ArgHelper(resolver, logger) + arghelper = ArgHelper(logger) arghelper.build_parser('%s %s' % (cmd, subcmd)) return arghelper def main(argv): - global tm global th_factories global explicit_delegations global odd_ports try: - _init_tm() arghelper = build_helper(logger, sys.argv[0], argv[0]) arghelper.parse_args(argv[1:]) logger.setLevel(arghelper.get_log_level()) @@ -1516,6 +1532,10 @@ def main(argv): kwargs = {} dnsviz_meta = { 'version': DNS_RAW_VERSION, 'names': [lb2s(n.to_text()) for n in arghelper.names] } + NameServerMappingsForDomain.cleanup_resolvers() + + _init_tm() + name_objs = [] if arghelper.args.input_file: cache = {} @@ -1536,6 +1556,8 @@ def main(argv): name_objs = a.analyze(arghelper.names) + _cleanup_tm() + name_objs = [x for x in name_objs if x is not None] if not name_objs: @@ -1556,10 +1578,5 @@ def main(argv): logger.error('Interrupted.') sys.exit(4) - # tm is global (because of possible multiprocessing), so we need to - # explicitly close it here - finally: - _cleanup_tm() - if __name__ == "__main__": main(sys.argv) diff --git a/dnsviz/transport.py b/dnsviz/transport.py index efa8b17..8d46c3c 100644 --- a/dnsviz/transport.py +++ b/dnsviz/transport.py @@ -1403,7 +1403,7 @@ class _DNSQueryTransportManager: self._event_map = {} self._close = threading.Event() - t = threading.Thread(target=self._loop) + t = threading.Thread(target=self._loop, daemon=True) t.start() def close(self): |