diff --git a/dns/cache.py b/dns/cache.py index 0348b2f..04cce3d 100644 --- a/dns/cache.py +++ b/dns/cache.py @@ -123,8 +123,11 @@ class RecordCache(object): def read_cache_file(self): """ Read the cache file from disk """ - with open("cache.json") as f: - self.records = json.load(f, object_hook=resource_from_json) + try: + with open("cache.json") as f: + self.records = json.load(f, object_hook=resource_from_json) + except Exception, e: + print "error loading cache", repr(e) self.sweep() def write_cache_file(self): """ Write the cache file to disk """ diff --git a/dns/resolver.py b/dns/resolver.py index 091a4ef..1479e24 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -45,7 +45,6 @@ class Resolver(object): self.SBELT.read_master_file("named.root") if caching: self.cache = dns.cache.RecordCache(self.ttl) - self.cache.read_cache_file() else: self.cache = dns.cache.DummyCache(self.ttl) diff --git a/dns_client.py b/dns_client.py index ced082b..7a40b9a 100644 --- a/dns_client.py +++ b/dns_client.py @@ -20,6 +20,7 @@ if __name__ == "__main__": # Resolve hostname resolver = dns.resolver.Resolver(args.caching, args.ttl) + resolver.cache.read_cache_file() hostname, aliases, addresses = resolver.gethostbyname(args.hostname) resolver.cache.write_cache_file() # Print output diff --git a/dns_tests.py b/dns_tests.py index 66c04ed..439b46b 100644 --- a/dns_tests.py +++ b/dns_tests.py @@ -5,12 +5,49 @@ portnr = 5353 server = "localhost" +import dns.resolver +from dns.resource import ResourceRecord, RecordData +from dns.classes import Class +from dns.types import Type +import unittest +import sys +import time + class TestResolver(unittest.TestCase): - pass + # solve a FQDN, output with corresponding IP/CNAME/authoitative status generated + def test_solve(self): + resolver = dns.resolver.Resolver(False, 0) + host, alias, ip = resolver.gethostbyname("mail.polvanaubel.com") + self.assertEqual(host, "sog.polvanaubel.com") + self.assertEqual(alias, ["sog.polvanaubel.com"]) + self.assertEqual(ip, ["138.201.39.104"]) + def test_invalid(self): + resolver = dns.resolver.Resolver(False, 0) + host, alias, ip = resolver.gethostbyname("invalid.example.com") + self.assertFalse(host) + self.assertFalse(alias) + self.assertFalse(ip) class TestResolverCache(unittest.TestCase): - pass + # solve an invalid cached FQDN, output corresponds to cache + def test_solveinv(self): + resolver = dns.resolver.Resolver(True, 0) + resolver.cache.add_record(ResourceRecord("invalid.example.com", Type.A, Class.IN, 60, RecordData("1.2.3.4"))) + host, alias, ip = resolver.gethostbyname("invalid.example.com") + self.assertEqual(host, "invalid.example.com") + self.assertEqual(alias, []) + self.assertEqual(ip, ["1.2.3.4"]) + # start your server and wait configured TTL + 1 time for an invalid cached FQDN to expire, + # an empty output should be generated + def test_ttlexpire(self): + resolver = dns.resolver.Resolver(True, 0.05) + resolver.cache.add_record(ResourceRecord("invalid.example.com", Type.A, Class.IN, 60, RecordData("1.2.3.4"))) + time.sleep(0.06) + host, alias, ip = resolver.gethostbyname("invalid.example.com") + self.assertFalse(host) + self.assertFalse(alias) + self.assertFalse(ip) class TestServer(unittest.TestCase):