add cache

master
Yorick van Pelt 2016-05-29 17:02:34 +02:00
parent 4afb687939
commit dbfe970da2
5 changed files with 108 additions and 31 deletions

View File

@ -46,6 +46,23 @@ def resource_from_json(dct):
rdata = RecordData.create(type_, dct["rdata"])
return ResourceRecord(name, type_, class_, ttl, rdata)
class DummyCache(object):
""" Cache for ResourceRecords """
def __init__(self, ttl):
pass
def gc(self):
pass
def lookup(self, dname, type_, class_):
return []
def add_record(self, record):
pass
def read_cache_file(self):
""" Read the cache file from disk """
pass
def write_cache_file(self):
""" Write the cache file to disk """
pass
class RecordCache(object):
""" Cache for ResourceRecords """
@ -59,6 +76,9 @@ class RecordCache(object):
self.records = []
self.ttl = ttl
def gc(self):
pass
def lookup(self, dname, type_, class_):
""" Lookup resource records in cache
@ -70,7 +90,9 @@ class RecordCache(object):
type_ (Type): type
class_ (Class): class
"""
pass
matches = lambda rec: rec.type_ == type_ and rec.class_ == class_ and rec.name.lower() == dname.lower()
return filter(matches, self.records)
def add_record(self, record):
""" Add a new Record to the cache
@ -78,12 +100,24 @@ class RecordCache(object):
Args:
record (ResourceRecord): the record added to the cache
"""
pass
if self.ttl > 0:
record.ttl = self.ttl
for rec in self.records:
if rec.type_ == record.type_ and rec.class_ == record.class_ and rec.name.lower() == record.name.lower() \
and rec.rdata.data == record.rdata.data:
rec.ttl = record.ttl
# update last seen time
break
else:
self.records.append(record)
def read_cache_file(self):
""" Read the cache file from disk """
pass
with open("cache.json") as f:
self.records = json.load(f, object_hook=resource_from_json)
def write_cache_file(self):
""" Write the cache file to disk """
pass
with open("cache.json", 'w') as f:
json.dump(self.records, f, cls=ResourceEncoder, indent=4)

View File

@ -16,26 +16,15 @@ import dns.cache
import dns.message
from dns.rcodes import RCode
from collections import defaultdict
from dns.zone import Zone
from collections import defaultdict
import itertools
SBELT = {
"A.ROOT-SERVERS.NET": (True, "198.41.0.4"),
"B.ROOT-SERVERS.NET": (True, "192.228.79.201"),
"C.ROOT-SERVERS.NET": (True, "192.33.4.12"),
"D.ROOT-SERVERS.NET": (True, "199.7.91.13"),
"E.ROOT-SERVERS.NET": (True, "192.203.230.10"),
"F.ROOT-SERVERS.NET": (True, "192.5.5.241"),
"G.ROOT-SERVERS.NET": (True, "192.112.36.4"),
"H.ROOT-SERVERS.NET": (True, "128.63.2.53"),
"I.ROOT-SERVERS.NET": (True, "192.36.148.17"),
"J.ROOT-SERVERS.NET": (True, "192.58.128.30"),
"K.ROOT-SERVERS.NET": (True, "193.0.14.129"),
"L.ROOT-SERVERS.NET": (True, "199.7.83.42"),
"M.ROOT-SERVERS.NET": (True, "202.12.27.33")
}
def dictappend(seq):
s = defaultdict(list)
# update doesn't work
for (k,v) in seq:
s[k].append(v)
return s
@ -52,6 +41,13 @@ class Resolver(object):
"""
self.caching = caching
self.ttl = ttl
self.SBELT = Zone()
self.SBELT.read_master_file("named.root")
if caching:
self.cache = dns.cache.RecordCache(self.ttl)
self.cache.read_cache_file()
else:
self.cache = dns.cache.DummyCache(self.ttl)
@classmethod
def query(cls, query, dest):
@ -87,17 +83,47 @@ class Resolver(object):
Returns:
(str, [str], [str]): (hostname, aliaslist, ipaddrlist)
"""
# Create and send query
SNAME = hostname
aliases = [] + additionals
# 1. See if the answer is in local information, and if so return
# it to the client.
pass
while True:
# follow cnames
lookedupcnames = self.cache.lookup(SNAME, Type.CNAME, Class.IN)
if lookedupcnames:
#print "found CNAMES", SNAME, lookedupcnames
SNAME = lookedupcnames[0].rdata.data
aliases.append(SNAME)
continue
break
lookedup = self.cache.lookup(SNAME, Type.A, Class.IN)
if lookedup:
#print "got result from cache", SNAME
return SNAME, aliases, [d.rdata.data for d in lookedup]
# 2. Find the best servers to ask.
SLIST = SBELT.values()
SNAME = hostname
aliases = [] + additionals
hs = hostname
SLIST = []
while hs: # look up NS servers in cache with decreasing fqdn
lookedup = self.cache.lookup(hs, Type.NS, Class.IN)
for n in lookedup:
SLIST.append((False, n.rdata.data))
if lookedup:
#print "got NS from cache", hs
break
hs = '.'.join(hs.split('.')[1:])
else:
nameservers = [ns.rdata.data for ns in self.SBELT.records['.']]
SLIST = []
for ns in nameservers:
try:
SLIST.append((True, self.SBELT.records[ns][0].rdata.data))
except KeyError:
SLIST.append((False, ns))
while True:
SLIST.sort(key=lambda a: a[0], reverse=True)
@ -120,6 +146,7 @@ class Resolver(object):
#print "asking " + ip
# state 2
# 3. Send them queries until one returns a response.
#print "querying", ip, "with", hostname
question = dns.message.Question(hostname, Type.A, Class.IN)
query = Resolver.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question])
# state 3
@ -144,6 +171,9 @@ class Resolver(object):
print "NXDOMAIN"
return ([],[],[]) # edge 3 -> 6
for ans in (response.answers + response.additionals + response.authorities):
self.cache.add_record(ans)
if response.header.aa:
# state 5
answers = dictappend([(ans.name, ans.rdata.data) for ans in response.answers if ans.type_ == Type.A])
@ -173,7 +203,6 @@ class Resolver(object):
# step 2.
if response.authorities:
# state 4
# cache
cands = [auth.rdata.data for auth in response.authorities]
d = {}
for add in response.additionals:

View File

@ -82,6 +82,7 @@ class RecordData(object):
Type.NS: NSRecordData,
Type.AAAA: AAAARecordData
}
data = str(data)
if type_ in classdict:
return classdict[type_](data)
else:

View File

@ -9,6 +9,10 @@ zones or record sets.
These classes are merely a suggestion, feel free to use something else.
"""
from resource import ResourceRecord, RecordData
from types import Type
from classes import Class
class Catalog(object):
""" A catalog of zones """
@ -26,13 +30,14 @@ class Catalog(object):
"""
self.zones[name] = zone
from collections import defaultdict
class Zone(object):
""" A zone in the domain name space """
def __init__(self):
""" Initialize the Zone """
self.records = {}
self.records = defaultdict(list)
def add_node(self, name, record_set):
""" Add a record set to the zone
@ -41,7 +46,7 @@ class Zone(object):
name (str): domain name
record_set ([ResourceRecord]): resource records
"""
self.records[name] = record_set
self.records[name].append(record_set)
def read_master_file(self, filename):
""" Read the zone from a master file
@ -51,4 +56,12 @@ class Zone(object):
Args:
filename (str): the filename of the master file
"""
pass
with open(filename) as f:
lines = f.readlines()
records = []
for line in lines:
line = line.strip()
if line.startswith(';'): continue
name, ttl, type_, data = line.split()
rdata = RecordData(data)
self.add_node(name, ResourceRecord(name, Type.by_string[type_], Class.IN, int(ttl), rdata))

View File

@ -21,7 +21,7 @@ if __name__ == "__main__":
# Resolve hostname
resolver = dns.resolver.Resolver(args.caching, args.ttl)
hostname, aliases, addresses = resolver.gethostbyname(args.hostname)
resolver.cache.write_cache_file()
# Print output
print(hostname)
print(aliases)