diff options
author | Casey Deccio <casey@deccio.net> | 2020-11-17 01:11:08 +0300 |
---|---|---|
committer | Casey Deccio <casey@deccio.net> | 2020-11-17 01:11:08 +0300 |
commit | 9a708cbc8ce110a8a610029ea02491d6c316c9ee (patch) | |
tree | 2e43936dac9f88adf0b7ac3a10425e3f76bf3a03 | |
parent | a19fe6ffe87df408284955cc144ec46902e33616 (diff) |
Use argparse for dnsviz graph
-rw-r--r-- | tests/dnsviz_graph_options.py | 321 |
1 files changed, 321 insertions, 0 deletions
diff --git a/tests/dnsviz_graph_options.py b/tests/dnsviz_graph_options.py new file mode 100644 index 0000000..9c60036 --- /dev/null +++ b/tests/dnsviz_graph_options.py @@ -0,0 +1,321 @@ +import argparse +import binascii +import datetime +import gzip +import importlib +import io +import logging +import os +import subprocess +import tempfile +import unittest + +import dns.name, dns.rdatatype, dns.rrset, dns.zone + +from dnsviz.format import utc +from dnsviz.util import get_default_trusted_keys + +mod = importlib.import_module('dnsviz.commands.graph') +GraphArgHelper = getattr(mod, 'GraphArgHelper') +AnalysisInputError = getattr(mod, 'AnalysisInputError') + +DATA_DIR = os.path.dirname(__file__) +EXAMPLE_AUTHORITATIVE = os.path.join(DATA_DIR, 'data', 'example-authoritative.json.gz') + + +class DNSVizGraphOptionsTestCase(unittest.TestCase): + def setUp(self): + self.logger = logging.getLogger() + for handler in self.logger.handlers: + self.logger.removeHandler(handler) + self.logger.addHandler(logging.NullHandler()) + + def test_rrtype_list(self): + arg1 = 'A,AAAA,MX,CNAME' + arg1_with_spaces = ' A , AAAA , MX , CNAME ' + arg2 = 'A' + arg3 = 'A,BLAH' + arg4_empty = '' + arg4_empty_spaces = ' ' + + type_list1 = [dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.MX, dns.rdatatype.CNAME] + type_list2 = [dns.rdatatype.A] + empty_list = [] + + self.assertEqual(GraphArgHelper.comma_separated_dns_types(arg1), type_list1) + self.assertEqual(GraphArgHelper.comma_separated_dns_types(arg1_with_spaces), type_list1) + self.assertEqual(GraphArgHelper.comma_separated_dns_types(arg2), type_list2) + self.assertEqual(GraphArgHelper.comma_separated_dns_types(arg4_empty), empty_list) + self.assertEqual(GraphArgHelper.comma_separated_dns_types(arg4_empty_spaces), empty_list) + + # invalid schema + with self.assertRaises(argparse.ArgumentTypeError): + GraphArgHelper.comma_separated_dns_types(arg3) + + def test_integer_list(self): + arg1 = '1,2,3,4,5' + arg1_with_spaces = ' 1 , 2 , 3 , 4 , 5 ' + arg2 = '1' + arg3 = '1,A' + arg4_empty = '' + arg4_empty_spaces = ' ' + + int_list1 = [1,2,3,4,5] + int_list2 = [1] + empty_list = [] + + int_set1 = set([1,2,3,4,5]) + int_set2 = set([1]) + empty_set = set([]) + + self.assertEqual(GraphArgHelper.comma_separated_ints(arg1), int_list1) + self.assertEqual(GraphArgHelper.comma_separated_ints(arg1_with_spaces), int_list1) + self.assertEqual(GraphArgHelper.comma_separated_ints(arg2), int_list2) + self.assertEqual(GraphArgHelper.comma_separated_ints(arg4_empty), empty_list) + self.assertEqual(GraphArgHelper.comma_separated_ints(arg4_empty_spaces), empty_list) + + self.assertEqual(GraphArgHelper.comma_separated_ints_set(arg1), int_set1) + self.assertEqual(GraphArgHelper.comma_separated_ints_set(arg1_with_spaces), int_set1) + self.assertEqual(GraphArgHelper.comma_separated_ints_set(arg2), int_set2) + self.assertEqual(GraphArgHelper.comma_separated_ints_set(arg4_empty), empty_set) + self.assertEqual(GraphArgHelper.comma_separated_ints_set(arg4_empty_spaces), empty_set) + + # invalid schema + with self.assertRaises(argparse.ArgumentTypeError): + GraphArgHelper.comma_separated_ints(arg3) + + def test_valid_domain_name(self): + arg1 = '.' + arg2 = 'www.example.com' + arg3 = 'www..example.com' + + self.assertEqual(GraphArgHelper.valid_domain_name(arg1), dns.name.from_text(arg1)) + self.assertEqual(GraphArgHelper.valid_domain_name(arg2), dns.name.from_text(arg2)) + + # invalid domain name + with self.assertRaises(argparse.ArgumentTypeError): + GraphArgHelper.valid_domain_name(arg3) + + def test_ingest_input(self): + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_bad_json: + example_bad_json.write(b'{') + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_no_version: + example_no_version.write(b'{}') + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_invalid_version_1: + example_invalid_version_1.write(b'{ "_meta._dnsviz.": { "version": 1.11 } }') + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_invalid_version_2: + example_invalid_version_2.write(b'{ "_meta._dnsviz.": { "version": 5.0 } }') + + with gzip.open(EXAMPLE_AUTHORITATIVE, 'rb') as example_auth_in: + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_auth_out: + example_auth_out.write(example_auth_in.read()) + + try: + args = ['-r', example_auth_out.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.ingest_input() + + # Bad json + args = ['-r', example_bad_json.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(AnalysisInputError): + arghelper.ingest_input() + + # No version + args = ['-r', example_no_version.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(AnalysisInputError): + arghelper.ingest_input() + + # Invalid version + args = ['-r', example_invalid_version_1.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(AnalysisInputError): + arghelper.ingest_input() + + # Invalid version + args = ['-r', example_invalid_version_2.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(AnalysisInputError): + arghelper.ingest_input() + + finally: + for tmpfile in (example_auth_out, example_bad_json, example_no_version, \ + example_invalid_version_1, example_invalid_version_2): + os.remove(tmpfile.name) + + def test_ingest_names(self): + args = ['example.com', 'example.net'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.ingest_names() + self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net')]) + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as names_file: + names_file.write(b'example.com\nexample.net\n') + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_names_only: + example_names_only.write(b'{ "_meta._dnsviz.": { "version": 1.2, "names": [ "example.com.", "example.net.", "example.org." ] } }') + + try: + args = ['-f', names_file.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.ingest_names() + self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net')]) + + args = ['-r', example_names_only.name] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.ingest_input() + arghelper.ingest_names() + self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net'), dns.name.from_text('example.org')]) + + args = ['-r', example_names_only.name, 'example.com'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.ingest_input() + arghelper.ingest_names() + self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com')]) + finally: + for tmpfile in (names_file, example_names_only): + os.remove(tmpfile.name) + + def test_trusted_keys_file(self): + tk1 = 'example.com. IN DNSKEY 256 3 7 AwEAAZ2YEuBl4X58v1CezDfZjT1viYn5kY3MF3lSDjvHjMZ6gJlYt4Qq oIdpChifmeJldEX9/wPc04Tg7MlEfV3m0x2j80dMyObM0FZTxzMgbTFk Zs0AWrDXELieGkFZv1FB9YoxSX2XqvpFxwvPyyszUtCy/c5hrb6vfKRB Jh+qIO+NsNrl6O8NiYjWWNjdiFw+c2BxzpArQoaA+rcoyDYwH4xGpvTw YLnE9HmkwTSQuwASkgWgX3KgTmsDEw4I0P5Tk+wvmNnaqDhmFMHJK5Oh 92wUX+ppxxSgUx4UIJmftzi7sCg0qekIYUf99Dkn7OlC8X0rjj+xO4cD hbTjGkxmsD0=' + tk2 = 'example.com. IN DNSKEY 256 3 7 AwEAAaerI6CXvvG6U3UxkB0PXj+ORyGFtABYJ6JG3NL6w1KKlZl+73AS aPEEa7SXeuWmAWE1N3rsbnrMBvepBXkCbP609eoo2mJ8bsozT/NNwSSc FP1Ddw4wxpZAC/+/K736rF1HbI3ROS/rBTr7RW6rWzcyPbYFuUMVzrAM ZSJNJsTDcmyGc5Is3cFzNcrd3/Gmcjt8TKMmGq51HXWzFvxro7EH6aOl K6G4O4+mzaUKp91mg7DAVhX8yXnadXUZQ4yDfLzSleYQ2TroQqeSgI3X m/gUoACm3ELUOr84TmIKZ67X/zBTx8tHC5iBWY2tbIKqiJY7I4/aW4S4 NraCSRbDpbM=' + tk1_rdata = ' '.join(tk1.split()[3:]) + tk2_rdata = ' '.join(tk2.split()[3:]) + tk_explicit = [(dns.name.from_text('example.com'), dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DNSKEY, tk1_rdata)), + (dns.name.from_text('example.com'), dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DNSKEY, tk2_rdata))] + + now = datetime.datetime.now(utc) + tk_default = get_default_trusted_keys(now) + + args = ['example.com'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.aggregate_trusted_key_info() + self.assertEqual(arghelper.trusted_keys, None) + arghelper.update_trusted_key_info(now) + self.assertEqual(arghelper.trusted_keys, tk_default) + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as tk1_file: + tk1_file.write(tk1.encode('utf-8')) + + with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as tk2_file: + tk2_file.write(tk2.encode('utf-8')) + + try: + args = ['-t', tk1_file.name, '-t', tk2_file.name, 'example.com'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.aggregate_trusted_key_info() + arghelper.update_trusted_key_info(now) + self.assertEqual(arghelper.trusted_keys, tk_explicit) + + args = ['-t', '/dev/null', 'example.com'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.aggregate_trusted_key_info() + arghelper.update_trusted_key_info(now) + self.assertEqual(arghelper.trusted_keys, []) + + finally: + for tmpfile in (tk1_file, tk2_file): + os.remove(tmpfile.name) + + def test_option_combination_errors(self): + + # Names file and command-line domain names are mutually exclusive + args = ['-f', '/dev/null', 'example.com'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(argparse.ArgumentTypeError): + arghelper.check_args() + + # Names file and command-line domain names are mutually exclusive + args = ['-O', '-o', '/dev/null'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(argparse.ArgumentTypeError): + arghelper.check_args() + + # But this is allowed + args = ['-o', '/dev/null'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.check_args() + + # So is this + args = ['-O'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.check_args() + + def test_output_format(self): + + args = ['-T', 'png', '-o', 'foo.dot'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'png') + + args = ['-o', 'foo.dot'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'dot') + + args = ['-o', 'foo.png'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'png') + + args = ['-o', 'foo.html'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'html') + + args = ['-o', 'foo.svg'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'svg') + + args = ['-o', 'foo.xyz'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(argparse.ArgumentTypeError): + arghelper.set_kwargs() + + args = ['-o', 'png'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + with self.assertRaises(argparse.ArgumentTypeError): + arghelper.set_kwargs() + + args = ['-o', '-'] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'dot') + + args = [] + arghelper = GraphArgHelper(self.logger) + arghelper.build_parser('graph', args) + arghelper.set_kwargs() + self.assertEqual(arghelper.output_format, 'dot') + +if __name__ == '__main__': + unittest.main() |