diff --git a/dns/cache.py b/dns/cache.py index 57a821a..ea5b2b0 100644 --- a/dns/cache.py +++ b/dns/cache.py @@ -46,6 +46,23 @@ def resource_from_json(dct): rdata = RecordData.create(type_, dct["rdata"]) return ResourceRecord(name, type_, class_, ttl, rdata) +class DummyCache(object): + """ Cache for ResourceRecords """ + + def __init__(self, ttl): + pass + def gc(self): + pass + def lookup(self, dname, type_, class_): + return [] + def add_record(self, record): + pass + def read_cache_file(self): + """ Read the cache file from disk """ + pass + def write_cache_file(self): + """ Write the cache file to disk """ + pass class RecordCache(object): """ Cache for ResourceRecords """ @@ -59,6 +76,9 @@ class RecordCache(object): self.records = [] self.ttl = ttl + def gc(self): + pass + def lookup(self, dname, type_, class_): """ Lookup resource records in cache @@ -70,7 +90,9 @@ class RecordCache(object): type_ (Type): type class_ (Class): class """ - pass + + matches = lambda rec: rec.type_ == type_ and rec.class_ == class_ and rec.name.lower() == dname.lower() + return filter(matches, self.records) def add_record(self, record): """ Add a new Record to the cache @@ -78,12 +100,24 @@ class RecordCache(object): Args: record (ResourceRecord): the record added to the cache """ - pass + if self.ttl > 0: + record.ttl = self.ttl + for rec in self.records: + if rec.type_ == record.type_ and rec.class_ == record.class_ and rec.name.lower() == record.name.lower() \ + and rec.rdata.data == record.rdata.data: + rec.ttl = record.ttl + # update last seen time + break + else: + self.records.append(record) def read_cache_file(self): """ Read the cache file from disk """ - pass - + with open("cache.json") as f: + self.records = json.load(f, object_hook=resource_from_json) def write_cache_file(self): """ Write the cache file to disk """ - pass + with open("cache.json", 'w') as f: + json.dump(self.records, f, cls=ResourceEncoder, indent=4) + + diff --git a/dns/resolver.py b/dns/resolver.py index 12a9ec4..091a4ef 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -16,26 +16,15 @@ import dns.cache import dns.message from dns.rcodes import RCode -from collections import defaultdict +from dns.zone import Zone + +from collections import defaultdict +import itertools -SBELT = { - "A.ROOT-SERVERS.NET": (True, "198.41.0.4"), - "B.ROOT-SERVERS.NET": (True, "192.228.79.201"), - "C.ROOT-SERVERS.NET": (True, "192.33.4.12"), - "D.ROOT-SERVERS.NET": (True, "199.7.91.13"), - "E.ROOT-SERVERS.NET": (True, "192.203.230.10"), - "F.ROOT-SERVERS.NET": (True, "192.5.5.241"), - "G.ROOT-SERVERS.NET": (True, "192.112.36.4"), - "H.ROOT-SERVERS.NET": (True, "128.63.2.53"), - "I.ROOT-SERVERS.NET": (True, "192.36.148.17"), - "J.ROOT-SERVERS.NET": (True, "192.58.128.30"), - "K.ROOT-SERVERS.NET": (True, "193.0.14.129"), - "L.ROOT-SERVERS.NET": (True, "199.7.83.42"), - "M.ROOT-SERVERS.NET": (True, "202.12.27.33") -} def dictappend(seq): s = defaultdict(list) + # update doesn't work for (k,v) in seq: s[k].append(v) return s @@ -52,6 +41,13 @@ class Resolver(object): """ self.caching = caching self.ttl = ttl + self.SBELT = Zone() + self.SBELT.read_master_file("named.root") + if caching: + self.cache = dns.cache.RecordCache(self.ttl) + self.cache.read_cache_file() + else: + self.cache = dns.cache.DummyCache(self.ttl) @classmethod def query(cls, query, dest): @@ -87,17 +83,47 @@ class Resolver(object): Returns: (str, [str], [str]): (hostname, aliaslist, ipaddrlist) """ - # Create and send query + SNAME = hostname + + aliases = [] + additionals + # 1. See if the answer is in local information, and if so return # it to the client. - pass + while True: + # follow cnames + lookedupcnames = self.cache.lookup(SNAME, Type.CNAME, Class.IN) + if lookedupcnames: + #print "found CNAMES", SNAME, lookedupcnames + SNAME = lookedupcnames[0].rdata.data + aliases.append(SNAME) + continue + break + lookedup = self.cache.lookup(SNAME, Type.A, Class.IN) + if lookedup: + #print "got result from cache", SNAME + return SNAME, aliases, [d.rdata.data for d in lookedup] # 2. Find the best servers to ask. - SLIST = SBELT.values() - SNAME = hostname - aliases = [] + additionals + hs = hostname + SLIST = [] + while hs: # look up NS servers in cache with decreasing fqdn + lookedup = self.cache.lookup(hs, Type.NS, Class.IN) + for n in lookedup: + SLIST.append((False, n.rdata.data)) + if lookedup: + #print "got NS from cache", hs + break + hs = '.'.join(hs.split('.')[1:]) + else: + nameservers = [ns.rdata.data for ns in self.SBELT.records['.']] + SLIST = [] + for ns in nameservers: + try: + SLIST.append((True, self.SBELT.records[ns][0].rdata.data)) + except KeyError: + SLIST.append((False, ns)) while True: SLIST.sort(key=lambda a: a[0], reverse=True) @@ -120,6 +146,7 @@ class Resolver(object): #print "asking " + ip # state 2 # 3. Send them queries until one returns a response. + #print "querying", ip, "with", hostname question = dns.message.Question(hostname, Type.A, Class.IN) query = Resolver.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question]) # state 3 @@ -144,6 +171,9 @@ class Resolver(object): print "NXDOMAIN" return ([],[],[]) # edge 3 -> 6 + for ans in (response.answers + response.additionals + response.authorities): + self.cache.add_record(ans) + if response.header.aa: # state 5 answers = dictappend([(ans.name, ans.rdata.data) for ans in response.answers if ans.type_ == Type.A]) @@ -173,7 +203,6 @@ class Resolver(object): # step 2. if response.authorities: # state 4 - # cache cands = [auth.rdata.data for auth in response.authorities] d = {} for add in response.additionals: diff --git a/dns/resource.py b/dns/resource.py index 6fe3aad..083969c 100644 --- a/dns/resource.py +++ b/dns/resource.py @@ -82,6 +82,7 @@ class RecordData(object): Type.NS: NSRecordData, Type.AAAA: AAAARecordData } + data = str(data) if type_ in classdict: return classdict[type_](data) else: diff --git a/dns/zone.py b/dns/zone.py index b7254b9..509b530 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -9,6 +9,10 @@ zones or record sets. These classes are merely a suggestion, feel free to use something else. """ +from resource import ResourceRecord, RecordData + +from types import Type +from classes import Class class Catalog(object): """ A catalog of zones """ @@ -26,13 +30,14 @@ class Catalog(object): """ self.zones[name] = zone +from collections import defaultdict class Zone(object): """ A zone in the domain name space """ def __init__(self): """ Initialize the Zone """ - self.records = {} + self.records = defaultdict(list) def add_node(self, name, record_set): """ Add a record set to the zone @@ -41,7 +46,7 @@ class Zone(object): name (str): domain name record_set ([ResourceRecord]): resource records """ - self.records[name] = record_set + self.records[name].append(record_set) def read_master_file(self, filename): """ Read the zone from a master file @@ -51,4 +56,12 @@ class Zone(object): Args: filename (str): the filename of the master file """ - pass + with open(filename) as f: + lines = f.readlines() + records = [] + for line in lines: + line = line.strip() + if line.startswith(';'): continue + name, ttl, type_, data = line.split() + rdata = RecordData(data) + self.add_node(name, ResourceRecord(name, Type.by_string[type_], Class.IN, int(ttl), rdata)) diff --git a/dns_client.py b/dns_client.py index 710a785..ced082b 100644 --- a/dns_client.py +++ b/dns_client.py @@ -21,7 +21,7 @@ if __name__ == "__main__": # Resolve hostname resolver = dns.resolver.Resolver(args.caching, args.ttl) hostname, aliases, addresses = resolver.gethostbyname(args.hostname) - + resolver.cache.write_cache_file() # Print output print(hostname) print(aliases)