implement server+tests and write documentation

master
Yorick van Pelt 2016-05-30 20:25:32 +02:00
parent 1326fe84ff
commit d6d84edbab
9 changed files with 430 additions and 64 deletions

View File

@ -12,6 +12,13 @@ from dns.domainname import Parser, Composer
from dns.resource import ResourceRecord
def makepacket(flags, qs=[], ans=[], ns=[], ar=[], ident=9001):
header = Header(ident, 0, len(qs), len(ans), len(ns), len(ar))
for (k,v) in flags.items():
setattr(header, k, v)
return Message(header, qs, ans, ns, ar)
class Message(object):
""" DNS message """

View File

@ -23,12 +23,25 @@ import itertools
def dictappend(seq):
"""
Append the [(k,v)] list to a defaultdict.
"""
s = defaultdict(list)
# update doesn't work
# .update does not work as expected for this.
for (k,v) in seq:
s[k].append(v)
return s
def ipport(name):
"""
Split a name into an (ip, port) pair, defaulting the port to 53
"""
x = name.split(':')
if len(x) == 1:
return x[0], 53
else:
return x[0], int(x[1])
class Resolver(object):
""" DNS resolver """
@ -50,6 +63,9 @@ class Resolver(object):
@classmethod
def query(cls, query, dest):
"""
Send a `query` to the destination address at `dest` and get a response back
"""
timeout = 2
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(timeout)
@ -62,25 +78,39 @@ class Resolver(object):
sock.close()
return response
@classmethod
def makepacket(cls, flags, qs=[], ans=[], ns=[], ar=[], ident=9001):
header = dns.message.Header(ident, 0, len(qs), len(ans), len(ns), len(ar))
for (k,v) in flags.items():
setattr(header, k, v)
return dns.message.Message(header, qs, ans, ns, ar)
def query_dns(self, hostname, type_, server, recursive=False):
"""
Query a nameserver directly with the arguments as specified
"""
question = dns.message.Question(hostname, type_, Class.IN)
query = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": recursive}, [question])
print "asking", server
response = Resolver.query(query, server)
if response.header.rcode == RCode.ServFail:
raise Exception("server failure")
return response
def gethostbyname(self, hostname, additionals=[]):
""" Translate a host name to IPv4 address.
def best_ns_from_cache(self, hostname):
"""
This function finds the nearest nameserver for this hostname in the cache.
"""
hs = hostname
a_list = []
while hs: # look up NS servers in cache with decreasing fqdn
lookedup = self.cache.lookup(hs, Type.NS, Class.IN)
for n in lookedup:
a_list += self.cache.lookup(n.rdata.data, Type.A, Class.IN)
if lookedup:
return (a_list, lookedup)
hs = '.'.join(hs.split('.')[1:])
return [], None
Currently this method contains an example. You will have to replace
this example with example with the algorithm described in section
5.3.3 in RFC 1034.
Args:
hostname (str): the hostname to resolve
Returns:
(str, [str], [str]): (hostname, aliaslist, ipaddrlist)
def lookup(self, type_, hostname, additionals=[]):
"""
Look up all of the relevant resource records, either from cache
or by going over the network.
This is used by the server, and also by gethostbyname.
The algorithm is specified in RFC 1034, but copied into the comments here.
"""
SNAME = hostname
@ -89,42 +119,41 @@ class Resolver(object):
# 1. See if the answer is in local information, and if so return
# it to the client.
while True:
while type_ != Type.CNAME:
# follow cnames
lookedupcnames = self.cache.lookup(SNAME, Type.CNAME, Class.IN)
if lookedupcnames:
#print "found CNAMES", SNAME, lookedupcnames
aliases.append(lookedupcnames[0])
#aliases.append(SNAME)
SNAME = lookedupcnames[0].rdata.data
aliases.append(SNAME)
continue
break
lookedup = self.cache.lookup(SNAME, Type.A, Class.IN)
lookedup = self.cache.lookup(SNAME, type_, Class.IN)
if lookedup:
#print "got result from cache", SNAME
return SNAME, aliases, [d.rdata.data for d in lookedup]
return SNAME, aliases, lookedup
# 2. Find the best servers to ask.
hs = hostname
SLIST = []
while hs: # look up NS servers in cache with decreasing fqdn
lookedup = self.cache.lookup(hs, Type.NS, Class.IN)
for n in lookedup:
SLIST.append((False, n.rdata.data))
ns_a_list, lookedup = self.best_ns_from_cache(hostname)
if lookedup:
#print "got NS from cache", hs
break
hs = '.'.join(hs.split('.')[1:])
for n in lookedup:
# TODO: use ns_a_list
SLIST.append((False, n.rdata.data))
else:
nameservers = [ns.rdata.data for ns in self.SBELT.records['.']]
nameservers = [ns.rdata.data for ns in self.SBELT.lookup('.')[1]]
SLIST = []
for ns in nameservers:
try:
SLIST.append((True, self.SBELT.records[ns][0].rdata.data))
ips = [x.rdata.data for x in self.SBELT.lookup(ns)[1] if x.type_ == Type.A]
SLIST.append((True, ips[0]))
except KeyError:
SLIST.append((False, ns))
while True:
# do the looked up ones first
SLIST.sort(key=lambda a: a[0], reverse=True)
if len(SLIST) == 0:
return [[], [], []]
@ -135,7 +164,9 @@ class Resolver(object):
if not isLookedUp:
# XXX: this can get into an infinite loop
#print "looking up name server"
ip,port = ipport(ip)
ns_hostname, ns_aliases, ns_ips = self.gethostbyname(ip)
ns_ips = [i + ":" + str(port) for i in ns_ips]
if len(ns_ips):
ip = ns_ips[0]
SLIST = ns_ips[1:] + SLIST
@ -146,15 +177,12 @@ class Resolver(object):
# state 2
# 3. Send them queries until one returns a response.
#print "querying", ip, "with", hostname
question = dns.message.Question(hostname, Type.A, Class.IN)
query = Resolver.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question])
question = dns.message.Question(hostname, type_, Class.IN)
query = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question])
# state 3
try:
response = Resolver.query(query, (ip, 53))
# temp
if response.header.rcode == RCode.ServFail:
raise Exception("server failure")
except Exception as e:
response = self.query_dns(hostname, type_, ipport(ip), False)
except RuntimeError as e:
print e
# d. if the response shows a servers failure or other
# bizarre contents, delete the server from the SLIST and
@ -167,7 +195,7 @@ class Resolver(object):
# error, cache the data as well as returning it back to
# the client.
if response.header.rcode == dns.rcodes.RCode.NXDomain:
print "NXDOMAIN"
print "NXDOMAIN", SNAME
return ([],[],[]) # edge 3 -> 6
for ans in (response.answers + response.additionals + response.authorities):
@ -175,26 +203,28 @@ class Resolver(object):
if response.header.aa:
# state 5
answers = dictappend([(ans.name, ans.rdata.data) for ans in response.answers if ans.type_ == Type.A])
cnames = dictappend([(ans.name, ans.rdata.data) for ans in (response.answers + response.additionals) if ans.type_ == Type.CNAME])
answers = dictappend([(ans.name, ans) for ans in response.answers if ans.type_ == type_])
cnames = dictappend([(ans.name, ans) for ans in (response.answers + response.additionals) if ans.type_ == Type.CNAME])
if len(answers) + len(cnames) == 0:
return SNAME, aliases, []
while True:
if SNAME in cnames:
if SNAME in cnames and type_ != Type.CNAME:
# state 7
# c. if the response shows a CNAME and that is not the
# answer itself, cache the CNAME, change the SNAME to the
# canonical name in the CNAME RR and go to step 1.
if cnames[SNAME][0] in aliases:
#if cnames[SNAME][0] in aliases: # XXX Fix this
#print "not following cname loop"
return SNAME, aliases, []
# return SNAME, aliases, []
aliases.append(cnames[SNAME][0])
SNAME = cnames[SNAME][0]
SNAME = cnames[SNAME][0].rdata.data
#print "following CNAME", SNAME
if SNAME in answers:
return SNAME, aliases, answers[SNAME]
break # 5 -> 5'
# recursing
# 5 -> 0
return self.gethostbyname(SNAME, aliases)
return self.lookup(type_, SNAME, aliases)
# b. if the response contains a better delegation to other
@ -212,3 +242,17 @@ class Resolver(object):
continue # 4 -> 1
# NOT REACHED
#print "NOT REACHED"
def gethostbyname(self, hostname, additionals=[]):
""" Translate a host name to IPv4 address.
Args:
hostname (str): the hostname to resolve
Returns:
(str, [str], [str]): (hostname, aliaslist, ipaddrlist)
"""
hostname, aliases, responses = self.lookup(Type.A, hostname, additionals)
aliases = [a.name for a in aliases]
responses = [r.rdata.data for r in responses]
return hostname, aliases, responses

View File

@ -28,10 +28,25 @@ class ResourceRecord(object):
self.name = name
self.type_ = type_
self.class_ = class_
self.ttl = ttl
self._ttl = ttl
self.rdata = rdata
self.createDate = createDate
def copy(self):
return ResourceRecord(self.name, self.type_, self.class_, self._ttl, self.rdata, self.createDate)
@property
def ttl(self):
if self.createDate is None:
return self._ttl
else:
return int(self._ttl - (datetime.now() - self.createDate).total_seconds())
@ttl.setter
def ttl(self, value):
self.createDate = datetime.now()
self._ttl = value
def valid(self):
if self.createDate is None:
return True
@ -96,6 +111,7 @@ class RecordData(object):
if type_ in classdict:
return classdict[type_](data)
else:
print "warning,", type_, "not found"
return GenericRecordData(data)
@staticmethod

View File

@ -8,19 +8,189 @@ server using the algorithm described in section 4.3.2 of RFC 1034.
from threading import Thread
import socket
import dns.message
import dns.zone
from dns.classes import Class
from dns.types import Type
from resolver import Resolver
class RequestHandler(Thread):
""" A handler for requests to the DNS server """
def __init__(self):
def __init__(self, addr, packet, parent):
""" Initialize the handler thread """
super(RequestHandler, self).__init__()
self.daemon = True
self.packet = packet
self.addr = addr
self.parent = parent
def generate_response(self, QNAME, QTYPE, packet, rec=False, answers=None,auth=None,additional=None, response_flags = None):
"""
Generate the response records and flags for a query.
This follows the algorithm as outlined in the RFC, and copied in the comments below.
"""
if answers is None: answers = []
if auth is None: auth = []
if additional is None: additional = []
print "generating result for ", QNAME, QTYPE
# The actual algorithm used by the name server will depend on the local OS
# and data structures used to store RRs. The following algorithm assumes
# that the RRs are organized in several tree structures, one for each
# zone, and another for the cache:
if response_flags is None: response_flags = {}
# 1. Set or clear the value of recursion available in the response
# depending on whether the name server is willing to provide
# recursive service. If recursive service is available and
# requested via the RD bit in the query, go to step 5,
# otherwise step 2.
response_flags['ra'] = 1
done = False
if not packet.header.rd:
# 2. Search the available zones for the zone which is the nearest
# ancestor to QNAME. If such a zone is found, go to step 3,
# otherwise step 4.
zone = self.parent.catalog.find_nearest(QNAME)
if zone:
(matchtype, lookedup) = zone.lookup(QNAME)
# 3. Start matching down, label by label, in the zone. The
# matching process can terminate several ways:
if matchtype == 'full':
# a. If the whole of QNAME is matched, we have found the
# node.
# If the data at the node is a CNAME, and QTYPE doesn't
# match CNAME, copy the CNAME RR into the answer section
# of the response, change QNAME to the canonical name in
# the CNAME RR, and go back to step 1.
cnames = [n for n in lookedup if n.type_ == Type.CNAME]
response_flags['aa'] = True
if cnames:
#print "following cname", cnames[0].rdata.data
answers.append(cnames[0])
QNAME = cnames[0].rdata.data
return self.generate_response(cnames[0].rdata.data, QTYPE, packet, True, answers, auth, additional, response_flags)
else:
# Otherwise, copy all RRs which match QTYPE into the
# answer section and go to step 6.
answers += [n for n in lookedup if n.type_ == QTYPE]
response_flags['aa'] = True
done = True
elif matchtype == 'auth':
# b. If a match would take us out of the authoritative data,
# we have a referral. This happens when we encounter a
# node with NS RRs marking cuts along the bottom of a
# zone.
# Copy the NS RRs for the subzone into the authority
# section of the reply. Put whatever addresses are
# available into the additional section, using glue RRs
# if the addresses are not available from authoritative
# data or the cache. Go to step 4.
nameservs = [n for n in lookedup if n.type_ == Type.NS]
response_flags['aa'] = True
auth += nameservs
for serv in nameservs:
# try to find them in the zone
(match2, look2) = zone.lookup(serv.rdata.data)
if match2 == 'full':
additional += look2
# try to find them in the cache
else:
look2 = self.parent.resolver.cache.lookup(serv.rdata.data, Types.A, Class.IN)
additional += look2
# c. If at some label, a match is impossible (i.e., the
# corresponding label does not exist), look to see if a
# the "*" label exists.
elif matchtype == 'no':
# If the "*" label does not exist, check whether the name
# we are looking for is the original QNAME in the query
# or a name we have followed due to a CNAME. If the name
# is original, set an authoritative name error in the
# response and exit. Otherwise just exit.
if not rec:
response_flags['rcode'] = dns.rcodes.RCode.NXDomain
done = True
elif matchtype == 'star':
# If the "*" label does exist, match RRs at that node
# against QTYPE. If any match, copy them into the answer
# section, but set the owner of the RR to be QNAME, and
# not the node with the "*" label. Go to step 6.
new = [n.copy() for n in lookedup if n.type_ == QTYPE]
for x in new:
x.name = QNAME
answers += new
done = True
if not done:
# 4. Start matching down in the cache. If QNAME is found in the
# cache, copy all RRs attached to it that match QTYPE into the
# answer section. If there was no delegation from
# authoritative data, look for the best one from the cache, and
# put it in the authority section. Go to step 6.
lookedup = self.parent.resolver.cache.lookup(QNAME, QTYPE, Class.IN)
answers += lookedup
if zone and matchtype == 'auth' and len(auth) == 0:
(addit, ns) = self.parent.resolver.best_ns_from_cache(QNAME)
auth += ns
additional += addit
else:
# 5. Using the local resolver or a copy of its algorithm (see
# resolver section of this memo) to answer the query. Store
# the results, including any intermediate CNAMEs, in the answer
# section of the response.
hostname, alias, ips = self.parent.resolver.lookup(QTYPE, QNAME)
answers += alias + ips
# 6. Using local data only, attempt to add other RRs which may be
# useful to the additional section of the query. Exit.
pass # This happens in step 4 and 3b
response_flags['rq'] = 1
response_flags['opcode'] = 0
return response_flags,answers,auth,additional
def run(self):
""" Run the handler thread """
# TODO: Handle DNS request
pass
# parse the packet
packet = dns.message.Message.from_bytes(self.packet)
assert len(packet.questions) == 1
q = packet.questions[0]
answers = []
auth = []
additional = []
response_flags,answers,auth,additional = self.generate_response(q.qname, q.qtype, packet)
print "sending", answers, auth, additional
message = dns.message.makepacket(response_flags, packet.questions, answers, auth, additional, ident=packet.header.ident)
self.parent.sock.sendto(message.to_bytes(), self.addr)
class Server(object):
@ -38,16 +208,24 @@ class Server(object):
self.ttl = ttl
self.port = port
self.done = False
# TODO: create socket
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.resolver = Resolver(caching, ttl)
self.resolver.cache.read_cache_file()
self.catalog = dns.zone.Catalog()
def add_zone(self, name, file):
z = dns.zone.Zone()
z.read_master_file(file)
self.catalog.add_zone(name, z)
def serve(self):
""" Start serving request """
# TODO: start listening
self.sock.bind(('', self.port))
while not self.done:
# TODO: receive request and open handler
pass
(packet, addr) = self.sock.recvfrom(4096)
RequestHandler(addr, packet, self).start()
def shutdown(self):
""" Shutdown the server """
self.done = True
# TODO: shutdown socket
self.sock.close()
self.resolver.cache.write_cache_file()

View File

@ -30,14 +30,28 @@ class Catalog(object):
"""
self.zones[name] = zone
def find_nearest(self, name):
""" Find the zone that best matches 'name'
"""
name = name + '.'
while name:
if name in self.zones:
return self.zones[name]
name = '.'.join(name.split('.')[1:])
return None
from collections import defaultdict
defdict = lambda: defaultdict(defdict)
CONTENT_KEY = None # the dicts otherwise store strings, so this won't give a collision
class Zone(object):
""" A zone in the domain name space """
def __init__(self):
""" Initialize the Zone """
self.records = defaultdict(list)
self.records = defdict()
def add_node(self, name, record_set):
""" Add a record set to the zone
@ -46,8 +60,39 @@ class Zone(object):
name (str): domain name
record_set ([ResourceRecord]): resource records
"""
self.records[name].append(record_set)
rec = self.records
for n in reversed(name.split('.')):
rec = rec[n]
if CONTENT_KEY in rec: # no collisions
rec[CONTENT_KEY] += record_set
else:
rec[CONTENT_KEY] = record_set
def lookup(self, name):
""" Find the tree node that best matches name
Possible matches are:
- no
- star (the last match is a *.a.b),
- auth (the match is probably a NS record),
- full (the name was found entirely)
"""
rec = self.records
match_type = 'no'
for i,n in enumerate(reversed(name.split('.'))):
if n in rec:
rec = rec[n]
elif '*' in rec and i == len(name.split('.'))-1:
rec = rec['*']
match_type = 'star'
break
else:
match_type = 'auth'
break # there might be nameservers here
else:
match_type = 'full'
if CONTENT_KEY in rec:
return (match_type, rec[CONTENT_KEY])
else:
return ('no', [])
def read_master_file(self, filename):
""" Read the zone from a master file
@ -61,7 +106,8 @@ class Zone(object):
records = []
for line in lines:
line = line.strip()
if line.startswith(';'): continue
if line.startswith(';') or not line: continue
name, ttl, type_, data = line.split()
rdata = RecordData(data)
self.add_node(name, ResourceRecord(name, Type.by_string[type_], Class.IN, int(ttl), rdata))
t = Type.by_string[type_]
rdata = RecordData.create(t, data)
self.add_node(name, [ResourceRecord(name, t, Class.IN, int(ttl), rdata)])

View File

@ -21,6 +21,7 @@ if __name__ == "__main__":
# Start server
server = dns.server.Server(args.port, args.caching, args.ttl)
server.add_zone("test.com.", "test.zone")
try:
server.serve()
except KeyboardInterrupt:

View File

@ -19,7 +19,7 @@ class TestResolver(unittest.TestCase):
resolver = dns.resolver.Resolver(False, 0)
host, alias, ip = resolver.gethostbyname("mail.polvanaubel.com")
self.assertEqual(host, "sog.polvanaubel.com")
self.assertEqual(alias, ["sog.polvanaubel.com"])
self.assertEqual(alias, ["mail.polvanaubel.com"])
self.assertEqual(ip, ["138.201.39.104"])
def test_invalid(self):
resolver = dns.resolver.Resolver(False, 0)
@ -49,9 +49,63 @@ class TestResolverCache(unittest.TestCase):
self.assertFalse(alias)
self.assertFalse(ip)
import socket
class TestServer(unittest.TestCase):
pass
# solve a query for a FQDN for which your server has direct authority
def test_solve_auth(self):
resolver = dns.resolver.Resolver(True, 0)
resolver.cache.add_record(ResourceRecord("localhost", Type.A, Class.IN, 60, RecordData("127.0.0.1")))
resolver.cache.add_record(ResourceRecord("test.com", Type.NS, Class.IN, 60, RecordData(server + ":" + str(portnr))))
host, alias, ip = resolver.gethostbyname("cnametest.test.com")
self.assertEqual(host, "hello.test.com")
self.assertEqual(alias, ["cnametest.test.com"])
self.assertEqual(ip, ["1.2.3.4"])
# solve a query for a FQDN for which your server does not have direct authority,
# yet there is a name server in your zone which does
def test_solve_ns_inside(self):
# I had a bit of trouble figuring this one out
# I think it would require me to set up two servers
resolver = dns.resolver.Resolver(True, 0)
ans = resolver.query_dns("a.subzone.test.com", Type.A, (server, portnr), False)
self.assertEqual(ans.authorities[0].rdata.data, "nameserver.test.com")
self.assertEqual(ans.additionals[0].name, "nameserver.test.com")
# solve a query for a fqdn which points outside your zone
def test_solve_outside(self):
resolver = dns.resolver.Resolver(True, 0)
# resolver.cache.add_record(ResourceRecord("localhost", Type.A, Class.IN, 60, RecordData("127.0.0.1")))
# resolver.cache.add_record(ResourceRecord("edu", Type.NS, Class.IN, 60, RecordData(server + ":" + str(portnr))))
# host, alias, ip = resolver.gethostbyname("gaia.cs.umass.edu")
ans = resolver.query_dns("gaia.cs.umass.edu", Type.A, (server, portnr), True)
self.assertEqual(ans.answers[0].rdata.data, "128.119.245.12")
# solve parallel requests for different FQDN, their servicing should be made in parallel and correct
# responses should be generated
def test_solve_parallel(self):
resolver = dns.resolver.Resolver(True, 0)
question1 = dns.message.Question("gaia.cs.umass.edu", Type.A, Class.IN)
question2 = dns.message.Question("yori.cc", Type.A, Class.IN)
query1 = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": True}, [question1], ident=10)
query2 = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": True}, [question2], ident=20)
timeout = 2
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(timeout)
sock.sendto(query1.to_bytes(), (server, portnr))
sock.sendto(query2.to_bytes(), (server, portnr))
# Receive response
for i in range(2):
data = sock.recv(65530) # max MTU is not this value
response = dns.message.Message.from_bytes(data)
if response.header.ident == 10:
self.assertEqual(response.answers[0].rdata.data, "128.119.245.12")
elif response.header.ident == 20:
self.assertEqual(response.answers[0].rdata.data, "138.201.39.105")
else:
self.assertFalse(True)
sock.close()
if __name__ == "__main__":

14
documentation.txt Normal file
View File

@ -0,0 +1,14 @@
Name: Yorick van Pelt
student nr: s4503678
The general flow of this program is described in the RFC 1034, which has been followed as much as possible.
The DNS server sends authority flags if it is the authority (described in the RFC), and the resolver looks at the authority flags to see if it has found the authoritative nameserver.
I documented the resolver on a state machine on paper, but it is not included here.
I organized the zone into a tree, where the lookup algorithm from the server is implemented (in zone.py)
- I crafted DNS messages using your own libraries
- I did not have to generate transaction IDs, because every transaction used a different socket.
- I used threads for concurrencies
- I remove expired entries from the cache upon lookup, load and store
I had some trouble implementing the testcase with where there would be an NS record in the zone, pointing to another nameserver. I implemented it as best as I could, but I think it's only testable with at least two DNS servers.

6
test.zone Normal file
View File

@ -0,0 +1,6 @@
hello.test.com 30 A 1.2.3.4
*.test.com 30 A 2.3.4.5
cnametest.test.com 30 CNAME hello.test.com
nameserver.test.com 30 A 8.8.8.8
subzone.test.com 30 NS nameserver.test.com