From 25503d1c859f262074a9a9f6b16053530213789d Mon Sep 17 00:00:00 2001 From: Casey Deccio Date: Thu, 5 Aug 2021 23:39:48 -0600 Subject: Use a daemon-type thread to simplify management --- dnsviz/commands/probe.py | 75 ++++++++++++++++++++++-------------- dnsviz/transport.py | 2 +- tests/dnsviz_probe_options.py | 89 ++++++++++++++++++++----------------------- 3 files changed, 89 insertions(+), 77 deletions(-) diff --git a/dnsviz/commands/probe.py b/dnsviz/commands/probe.py index b2c058c..f916c75 100644 --- a/dnsviz/commands/probe.py +++ b/dnsviz/commands/probe.py @@ -119,6 +119,7 @@ def _cleanup_tm(): global tm if tm is not None: tm.close() + tm = None def _init_stub_resolver(): global resolver @@ -491,7 +492,11 @@ class NameServerMappingsForDomain(object): _allow_stop_at = None _handle_file_arg = None - def __init__(self, domain, stop_at, resolver): + _resolvers_initialized = False + _stub_resolver = None + _full_resolver = None + + def __init__(self, domain, stop_at): if not (self._allow_file is not None and \ self._allow_name_only is not None and \ self._allow_addr_only is not None and \ @@ -503,7 +508,6 @@ class NameServerMappingsForDomain(object): raise argparse.ArgumentTypeError('The "+" may not be specified with this option') self.domain = domain - self._resolver = resolver self._nsi = 1 self.delegation_mapping = {} @@ -513,6 +517,23 @@ class NameServerMappingsForDomain(object): self.delegation_mapping[(self.domain, dns.rdatatype.NS)] = dns.rrset.RRset(self.domain, dns.rdataclass.IN, dns.rdatatype.NS) + @classmethod + def init_resolvers(cls): + if not NameServerMappingsForDomain._resolvers_initialized: + tm = transport.DNSQueryTransportManager() + try: + NameServerMappingsForDomain._stub_resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm) + except ResolvConfError: + pass + NameServerMappingsForDomain._full_resolver = PrivateFullResolver(transport_manager=tm) + NameServerMappingsForDomain._resolvers_initialized = True + + @classmethod + def cleanup_resolvers(cls): + NameServerMappingsForDomain._stub_resolver = None + NameServerMappingsForDomain._full_resolver = None + NameServerMappingsForDomain._resolvers_initialized = False + @classmethod def _strip_port(cls, s): # Determine whether there is a port attached to the end @@ -540,10 +561,14 @@ class NameServerMappingsForDomain(object): self._handle_name_addr_mapping(name_addr) def _handle_name_no_addr(self, name, port): - if self._resolver is None: - raise argparse.ArgumentTypeError('If addresses are not provided for names, then %s must have valid nameserver entries.\n' % RESOLV_CONF) + resolver = None + self.init_resolvers() + if self._stub_resolver is not None: + resolver = self._stub_resolver + else: + resolver = self._full_resolver query_tuples = ((name, dns.rdatatype.A, dns.rdataclass.IN), (name, dns.rdatatype.AAAA, dns.rdataclass.IN)) - answer_map = self._resolver.query_multiple_for_answer(*query_tuples) + answer_map = resolver.query_multiple_for_answer(*query_tuples) found_answer = False for (n, rdtype, rdclass) in answer_map: a = answer_map[(n, rdtype, rdclass)] @@ -764,7 +789,7 @@ class RecursiveServersForDomain(NameServerMappingsForDomain): _handle_file_arg = None class DSForDomain: - def __init__(self, domain, stop_at, resolver): + def __init__(self, domain, stop_at): self.domain = domain if stop_at and not self._allow_stop_at: @@ -819,9 +844,6 @@ class DSForDomain: class DomainListArgHelper: STOP_RE = re.compile(r'^(.*)\+$') - def __init__(self, resolver): - self._resolver = resolver - @classmethod def _strip_stop_marker(cls, s): match = cls.STOP_RE.search(s) @@ -857,14 +879,14 @@ class DomainListArgHelper: if list_arg is not None: list_arg = list_arg.strip() - obj = cls(domain, stop_at, self._resolver) + obj = cls(domain, stop_at) if list_arg: obj.handle_list_arg(list_arg) return obj def _handle_list_arg(self, cls, list_arg): - obj = cls(WILDCARD_EXPLICIT_DELEGATION, False, self._resolver) + obj = cls(WILDCARD_EXPLICIT_DELEGATION, False) obj.handle_list_arg(list_arg) return obj @@ -883,7 +905,7 @@ class DomainListArgHelper: class ArgHelper: BRACKETS_RE = re.compile(r'^\[(.*)\]$') - def __init__(self, resolver, logger): + def __init__(self, logger): self.parser = None self.odd_ports = {} @@ -907,13 +929,12 @@ class ArgHelper: self.args = None self._arg_mapping = None - self._resolver = resolver self._logger = logger self._zones_to_serve = [] def build_parser(self, prog): self.parser = argparse.ArgumentParser(description='Issue diagnostic DNS queries', prog=prog) - helper = DomainListArgHelper(self._resolver) + helper = DomainListArgHelper() # python3/python2 dual compatibility stdout_buffer = io.open(sys.stdout.fileno(), 'wb', closefd=False) @@ -1238,12 +1259,14 @@ class ArgHelper: def populate_recursive_servers(self): if not self.args.authoritative_analysis and not self.args.recursive_servers: - if self._resolver is None: + try: + resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm) + except ResolvConfError: raise argparse.ArgumentTypeError('If servers are not specified with the %s option, then %s must have valid nameserver entries.\n' % \ (self._arg_mapping['recursive_servers'], RESOLV_CONF)) if (WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS) not in self.explicit_delegations: self.explicit_delegations[(WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS)] = dns.rrset.RRset(WILDCARD_EXPLICIT_DELEGATION, dns.rdataclass.IN, dns.rdatatype.NS) - for i, server in enumerate(self._resolver._servers): + for i, server in enumerate(resolver._servers): if IPAddr(server).version == 6: rdtype = dns.rdatatype.AAAA else: @@ -1451,23 +1474,16 @@ class ArgHelper: zone.serve() def build_helper(logger, cmd, subcmd): - try: - resolver = Resolver.from_file(RESOLV_CONF, StandardRecursiveQueryCD, transport_manager=tm) - except ResolvConfError: - resolver = None - - arghelper = ArgHelper(resolver, logger) + arghelper = ArgHelper(logger) arghelper.build_parser('%s %s' % (cmd, subcmd)) return arghelper def main(argv): - global tm global th_factories global explicit_delegations global odd_ports try: - _init_tm() arghelper = build_helper(logger, sys.argv[0], argv[0]) arghelper.parse_args(argv[1:]) logger.setLevel(arghelper.get_log_level()) @@ -1516,6 +1532,10 @@ def main(argv): kwargs = {} dnsviz_meta = { 'version': DNS_RAW_VERSION, 'names': [lb2s(n.to_text()) for n in arghelper.names] } + NameServerMappingsForDomain.cleanup_resolvers() + + _init_tm() + name_objs = [] if arghelper.args.input_file: cache = {} @@ -1536,6 +1556,8 @@ def main(argv): name_objs = a.analyze(arghelper.names) + _cleanup_tm() + name_objs = [x for x in name_objs if x is not None] if not name_objs: @@ -1556,10 +1578,5 @@ def main(argv): logger.error('Interrupted.') sys.exit(4) - # tm is global (because of possible multiprocessing), so we need to - # explicitly close it here - finally: - _cleanup_tm() - if __name__ == "__main__": main(sys.argv) diff --git a/dnsviz/transport.py b/dnsviz/transport.py index efa8b17..8d46c3c 100644 --- a/dnsviz/transport.py +++ b/dnsviz/transport.py @@ -1403,7 +1403,7 @@ class _DNSQueryTransportManager: self._event_map = {} self._close = threading.Event() - t = threading.Thread(target=self._loop) + t = threading.Thread(target=self._loop, daemon=True) t.start() def close(self): diff --git a/tests/dnsviz_probe_options.py b/tests/dnsviz_probe_options.py index 4519160..d24b2f4 100644 --- a/tests/dnsviz_probe_options.py +++ b/tests/dnsviz_probe_options.py @@ -12,9 +12,8 @@ import unittest import dns.name, dns.rdatatype, dns.rrset, dns.zone from dnsviz.commands.probe import ZoneFileToServe, ArgHelper, DomainListArgHelper, StandardRecursiveQueryCD, WILDCARD_EXPLICIT_DELEGATION, AnalysisInputError, CustomQueryMixin -from dnsviz import transport -from dnsviz.resolver import Resolver from dnsviz.ipaddr import IPAddr +from dnsviz import transport DATA_DIR = os.path.dirname(__file__) EXAMPLE_COM_ZONE = os.path.join(DATA_DIR, 'zone', 'example.com.zone') @@ -23,9 +22,7 @@ EXAMPLE_AUTHORITATIVE = os.path.join(DATA_DIR, 'data', 'example-authoritative.js class DNSVizProbeOptionsTestCase(unittest.TestCase): def setUp(self): - self.tm = transport.DNSQueryTransportManager() - self.resolver = Resolver.from_file('/etc/resolv.conf', StandardRecursiveQueryCD, transport_manager=self.tm) - self.helper = DomainListArgHelper(self.resolver) + self.helper = DomainListArgHelper() self.logger = logging.getLogger() for handler in self.logger.handlers: self.logger.removeHandler(handler) @@ -41,8 +38,6 @@ class DNSVizProbeOptionsTestCase(unittest.TestCase): def tearDown(self): CustomQueryMixin.edns_options = self.custom_query_mixin_edns_options_orig[:] - if self.tm is not None: - self.tm.close() def test_authoritative_option(self): arg1 = 'example.com+:ns1.example.com=192.0.2.1:1234,ns1.example.com=[2001:db8::1],' + \ @@ -599,7 +594,7 @@ ns1.example 0 IN A 192.0.2.1 ZoneFileToServe._next_free_port = self.first_port - arghelper1 = ArgHelper(self.resolver, self.logger) + arghelper1 = ArgHelper(self.logger) arghelper1.build_parser('probe') arghelper1.parse_args(args1) arghelper1.aggregate_delegation_info() @@ -612,7 +607,7 @@ ns1.example 0 IN A 192.0.2.1 ZoneFileToServe._next_free_port = self.first_port - arghelper2 = ArgHelper(self.resolver, self.logger) + arghelper2 = ArgHelper(self.logger) arghelper2.build_parser('probe') arghelper2.parse_args(args2) arghelper2.aggregate_delegation_info() @@ -625,7 +620,7 @@ ns1.example 0 IN A 192.0.2.1 ZoneFileToServe._next_free_port = self.first_port - arghelper3 = ArgHelper(self.resolver, self.logger) + arghelper3 = ArgHelper(self.logger) arghelper3.build_parser('probe') arghelper3.parse_args(args3) arghelper3.aggregate_delegation_info() @@ -634,7 +629,7 @@ ns1.example 0 IN A 192.0.2.1 ZoneFileToServe._next_free_port = self.first_port - arghelper4 = ArgHelper(self.resolver, self.logger) + arghelper4 = ArgHelper(self.logger) arghelper4.build_parser('probe') arghelper4.parse_args(args4) arghelper4.aggregate_delegation_info() @@ -675,7 +670,7 @@ ns1.example 0 IN A 192.0.2.1 ZoneFileToServe._next_free_port = self.first_port - arghelper1 = ArgHelper(self.resolver, self.logger) + arghelper1 = ArgHelper(self.logger) arghelper1.build_parser('probe') arghelper1.parse_args(args1) arghelper1.aggregate_delegation_info() @@ -686,7 +681,7 @@ ns1.example 0 IN A 192.0.2.1 args1 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1,ns1.example.com=[2001:db8::1]', '-x', 'com:ns1.foo.com=192.0.2.3'] - arghelper1 = ArgHelper(self.resolver, self.logger) + arghelper1 = ArgHelper(self.logger) arghelper1.build_parser('probe') arghelper1.parse_args(args1) @@ -718,7 +713,7 @@ ns1.example 0 IN A 192.0.2.1 odd_ports1 = {} - arghelper1 = ArgHelper(self.resolver, self.logger) + arghelper1 = ArgHelper(self.logger) arghelper1.build_parser('probe') arghelper1.parse_args(args1) arghelper1.aggregate_delegation_info() @@ -729,7 +724,7 @@ ns1.example 0 IN A 192.0.2.1 # Names, input file, or names file required args = [] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(argparse.ArgumentTypeError): @@ -737,7 +732,7 @@ ns1.example 0 IN A 192.0.2.1 # Names file and command-line domain names are mutually exclusive args = ['-f', '/dev/null', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(argparse.ArgumentTypeError): @@ -746,7 +741,7 @@ ns1.example 0 IN A 192.0.2.1 # Authoritative analysis and recursive servers args = ['-A', '-s', '192.0.2.1', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(argparse.ArgumentTypeError): @@ -754,7 +749,7 @@ ns1.example 0 IN A 192.0.2.1 # Authoritative servers with recursive analysis args = ['-x', 'example.com:ns1.example.com=192.0.2.1', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(argparse.ArgumentTypeError): @@ -762,7 +757,7 @@ ns1.example 0 IN A 192.0.2.1 # Delegation information with recursive analysis args = ['-N', 'example.com:ns1.example.com=192.0.2.1', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(argparse.ArgumentTypeError): @@ -770,7 +765,7 @@ ns1.example 0 IN A 192.0.2.1 # Delegation information with recursive analysis args = [ '-D', 'example.com:34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(argparse.ArgumentTypeError): @@ -778,21 +773,21 @@ ns1.example 0 IN A 192.0.2.1 def test_ceiling(self): args = ['-a', 'com', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() self.assertEqual(arghelper.ceiling, dns.name.from_text('com')) args = ['example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() self.assertEqual(arghelper.ceiling, dns.name.root) args = ['-A', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -800,7 +795,7 @@ ns1.example 0 IN A 192.0.2.1 def test_ip4_ipv6(self): args = [] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -808,7 +803,7 @@ ns1.example 0 IN A 192.0.2.1 self.assertEqual(arghelper.try_ipv6, True) args = ['-4', '-6'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -816,7 +811,7 @@ ns1.example 0 IN A 192.0.2.1 self.assertEqual(arghelper.try_ipv6, True) args = ['-4'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -824,7 +819,7 @@ ns1.example 0 IN A 192.0.2.1 self.assertEqual(arghelper.try_ipv6, False) args = ['-6'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -833,7 +828,7 @@ ns1.example 0 IN A 192.0.2.1 def test_client_ip(self): args = [] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -843,7 +838,7 @@ ns1.example 0 IN A 192.0.2.1 args = ['-b', '127.0.0.1'] if self.use_ipv6: args.extend(['-b', '::1']) - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -853,28 +848,28 @@ ns1.example 0 IN A 192.0.2.1 def test_th_factories(self): args = ['example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() self.assertIsNone(arghelper.th_factories) args = ['-u', 'http://example.com/', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() self.assertIsInstance(arghelper.th_factories[0], transport.DNSQueryTransportHandlerHTTPFactory) args = ['-u', 'ws:///dev/null', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() self.assertIsInstance(arghelper.th_factories[0], transport.DNSQueryTransportHandlerWebSocketServerFactory) args = ['-u', 'ssh://example.com/', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -885,7 +880,7 @@ ns1.example 0 IN A 192.0.2.1 # None args = ['-c', '', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -895,7 +890,7 @@ ns1.example 0 IN A 192.0.2.1 # Only DNS cookie args = ['example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -905,7 +900,7 @@ ns1.example 0 IN A 192.0.2.1 # All EDNS options args = ['-n', '-e', '192.0.2.0/24', 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.set_kwargs() @@ -932,14 +927,14 @@ ns1.example 0 IN A 192.0.2.1 try: args = ['-r', example_auth_out.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_input() # Bad json args = ['-r', example_bad_json.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(AnalysisInputError): @@ -947,7 +942,7 @@ ns1.example 0 IN A 192.0.2.1 # No version args = ['-r', example_no_version.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(AnalysisInputError): @@ -955,7 +950,7 @@ ns1.example 0 IN A 192.0.2.1 # Invalid version args = ['-r', example_invalid_version_1.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(AnalysisInputError): @@ -963,7 +958,7 @@ ns1.example 0 IN A 192.0.2.1 # Invalid version args = ['-r', example_invalid_version_2.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) with self.assertRaises(AnalysisInputError): @@ -976,7 +971,7 @@ ns1.example 0 IN A 192.0.2.1 def test_ingest_names(self): args = ['example.com', 'example.net'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_names() @@ -985,7 +980,7 @@ ns1.example 0 IN A 192.0.2.1 unicode_name = 'ใƒ†ใ‚นใƒˆ' args = [unicode_name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_names() @@ -1006,21 +1001,21 @@ ns1.example 0 IN A 192.0.2.1 try: args = ['-f', names_file.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_names() self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net')]) args = ['-f', names_file_unicode.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_names() self.assertEqual(list(arghelper.names), [dns.name.from_text('xn--zckzah.')]) args = ['-r', example_names_only.name] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_input() @@ -1028,7 +1023,7 @@ ns1.example 0 IN A 192.0.2.1 self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net'), dns.name.from_text('example.org')]) args = ['-r', example_names_only.name, 'example.com'] - arghelper = ArgHelper(self.resolver, self.logger) + arghelper = ArgHelper(self.logger) arghelper.build_parser('probe') arghelper.parse_args(args) arghelper.ingest_input() -- cgit v1.2.3