diff --git a/dns/message.py b/dns/message.py index 3333f43..1f75c62 100644 --- a/dns/message.py +++ b/dns/message.py @@ -12,6 +12,13 @@ from dns.domainname import Parser, Composer from dns.resource import ResourceRecord +def makepacket(flags, qs=[], ans=[], ns=[], ar=[], ident=9001): + header = Header(ident, 0, len(qs), len(ans), len(ns), len(ar)) + for (k,v) in flags.items(): + setattr(header, k, v) + return Message(header, qs, ans, ns, ar) + + class Message(object): """ DNS message """ diff --git a/dns/resolver.py b/dns/resolver.py index 1479e24..5aecc37 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -23,12 +23,25 @@ import itertools def dictappend(seq): + """ + Append the [(k,v)] list to a defaultdict. + """ s = defaultdict(list) - # update doesn't work + # .update does not work as expected for this. for (k,v) in seq: s[k].append(v) return s +def ipport(name): + """ + Split a name into an (ip, port) pair, defaulting the port to 53 + """ + x = name.split(':') + if len(x) == 1: + return x[0], 53 + else: + return x[0], int(x[1]) + class Resolver(object): """ DNS resolver """ @@ -50,6 +63,9 @@ class Resolver(object): @classmethod def query(cls, query, dest): + """ + Send a `query` to the destination address at `dest` and get a response back + """ timeout = 2 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(timeout) @@ -62,25 +78,39 @@ class Resolver(object): sock.close() return response - @classmethod - def makepacket(cls, flags, qs=[], ans=[], ns=[], ar=[], ident=9001): - header = dns.message.Header(ident, 0, len(qs), len(ans), len(ns), len(ar)) - for (k,v) in flags.items(): - setattr(header, k, v) - return dns.message.Message(header, qs, ans, ns, ar) + def query_dns(self, hostname, type_, server, recursive=False): + """ + Query a nameserver directly with the arguments as specified + """ + question = dns.message.Question(hostname, type_, Class.IN) + query = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": recursive}, [question]) + print "asking", server + response = Resolver.query(query, server) + if response.header.rcode == RCode.ServFail: + raise Exception("server failure") + return response - def gethostbyname(self, hostname, additionals=[]): - """ Translate a host name to IPv4 address. + def best_ns_from_cache(self, hostname): + """ + This function finds the nearest nameserver for this hostname in the cache. + """ + hs = hostname + a_list = [] + while hs: # look up NS servers in cache with decreasing fqdn + lookedup = self.cache.lookup(hs, Type.NS, Class.IN) + for n in lookedup: + a_list += self.cache.lookup(n.rdata.data, Type.A, Class.IN) + if lookedup: + return (a_list, lookedup) + hs = '.'.join(hs.split('.')[1:]) + return [], None - Currently this method contains an example. You will have to replace - this example with example with the algorithm described in section - 5.3.3 in RFC 1034. - - Args: - hostname (str): the hostname to resolve - - Returns: - (str, [str], [str]): (hostname, aliaslist, ipaddrlist) + def lookup(self, type_, hostname, additionals=[]): + """ + Look up all of the relevant resource records, either from cache + or by going over the network. + This is used by the server, and also by gethostbyname. + The algorithm is specified in RFC 1034, but copied into the comments here. """ SNAME = hostname @@ -89,42 +119,41 @@ class Resolver(object): # 1. See if the answer is in local information, and if so return # it to the client. - while True: + while type_ != Type.CNAME: # follow cnames lookedupcnames = self.cache.lookup(SNAME, Type.CNAME, Class.IN) if lookedupcnames: #print "found CNAMES", SNAME, lookedupcnames + aliases.append(lookedupcnames[0]) + #aliases.append(SNAME) SNAME = lookedupcnames[0].rdata.data - aliases.append(SNAME) continue break - lookedup = self.cache.lookup(SNAME, Type.A, Class.IN) + lookedup = self.cache.lookup(SNAME, type_, Class.IN) if lookedup: #print "got result from cache", SNAME - return SNAME, aliases, [d.rdata.data for d in lookedup] + return SNAME, aliases, lookedup # 2. Find the best servers to ask. - hs = hostname SLIST = [] - while hs: # look up NS servers in cache with decreasing fqdn - lookedup = self.cache.lookup(hs, Type.NS, Class.IN) + ns_a_list, lookedup = self.best_ns_from_cache(hostname) + if lookedup: for n in lookedup: + # TODO: use ns_a_list SLIST.append((False, n.rdata.data)) - if lookedup: - #print "got NS from cache", hs - break - hs = '.'.join(hs.split('.')[1:]) else: - nameservers = [ns.rdata.data for ns in self.SBELT.records['.']] + nameservers = [ns.rdata.data for ns in self.SBELT.lookup('.')[1]] SLIST = [] for ns in nameservers: try: - SLIST.append((True, self.SBELT.records[ns][0].rdata.data)) + ips = [x.rdata.data for x in self.SBELT.lookup(ns)[1] if x.type_ == Type.A] + SLIST.append((True, ips[0])) except KeyError: SLIST.append((False, ns)) while True: + # do the looked up ones first SLIST.sort(key=lambda a: a[0], reverse=True) if len(SLIST) == 0: return [[], [], []] @@ -135,7 +164,9 @@ class Resolver(object): if not isLookedUp: # XXX: this can get into an infinite loop #print "looking up name server" + ip,port = ipport(ip) ns_hostname, ns_aliases, ns_ips = self.gethostbyname(ip) + ns_ips = [i + ":" + str(port) for i in ns_ips] if len(ns_ips): ip = ns_ips[0] SLIST = ns_ips[1:] + SLIST @@ -146,15 +177,12 @@ class Resolver(object): # state 2 # 3. Send them queries until one returns a response. #print "querying", ip, "with", hostname - question = dns.message.Question(hostname, Type.A, Class.IN) - query = Resolver.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question]) + question = dns.message.Question(hostname, type_, Class.IN) + query = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question]) # state 3 try: - response = Resolver.query(query, (ip, 53)) - # temp - if response.header.rcode == RCode.ServFail: - raise Exception("server failure") - except Exception as e: + response = self.query_dns(hostname, type_, ipport(ip), False) + except RuntimeError as e: print e # d. if the response shows a servers failure or other # bizarre contents, delete the server from the SLIST and @@ -167,7 +195,7 @@ class Resolver(object): # error, cache the data as well as returning it back to # the client. if response.header.rcode == dns.rcodes.RCode.NXDomain: - print "NXDOMAIN" + print "NXDOMAIN", SNAME return ([],[],[]) # edge 3 -> 6 for ans in (response.answers + response.additionals + response.authorities): @@ -175,26 +203,28 @@ class Resolver(object): if response.header.aa: # state 5 - answers = dictappend([(ans.name, ans.rdata.data) for ans in response.answers if ans.type_ == Type.A]) - cnames = dictappend([(ans.name, ans.rdata.data) for ans in (response.answers + response.additionals) if ans.type_ == Type.CNAME]) + answers = dictappend([(ans.name, ans) for ans in response.answers if ans.type_ == type_]) + cnames = dictappend([(ans.name, ans) for ans in (response.answers + response.additionals) if ans.type_ == Type.CNAME]) + if len(answers) + len(cnames) == 0: + return SNAME, aliases, [] while True: - if SNAME in cnames: + if SNAME in cnames and type_ != Type.CNAME: # state 7 # c. if the response shows a CNAME and that is not the # answer itself, cache the CNAME, change the SNAME to the # canonical name in the CNAME RR and go to step 1. - if cnames[SNAME][0] in aliases: + #if cnames[SNAME][0] in aliases: # XXX Fix this #print "not following cname loop" - return SNAME, aliases, [] + # return SNAME, aliases, [] aliases.append(cnames[SNAME][0]) - SNAME = cnames[SNAME][0] + SNAME = cnames[SNAME][0].rdata.data #print "following CNAME", SNAME if SNAME in answers: return SNAME, aliases, answers[SNAME] break # 5 -> 5' # recursing # 5 -> 0 - return self.gethostbyname(SNAME, aliases) + return self.lookup(type_, SNAME, aliases) # b. if the response contains a better delegation to other @@ -212,3 +242,17 @@ class Resolver(object): continue # 4 -> 1 # NOT REACHED #print "NOT REACHED" + + def gethostbyname(self, hostname, additionals=[]): + """ Translate a host name to IPv4 address. + + Args: + hostname (str): the hostname to resolve + + Returns: + (str, [str], [str]): (hostname, aliaslist, ipaddrlist) + """ + hostname, aliases, responses = self.lookup(Type.A, hostname, additionals) + aliases = [a.name for a in aliases] + responses = [r.rdata.data for r in responses] + return hostname, aliases, responses diff --git a/dns/resource.py b/dns/resource.py index 8c98551..0ae08f5 100644 --- a/dns/resource.py +++ b/dns/resource.py @@ -28,10 +28,25 @@ class ResourceRecord(object): self.name = name self.type_ = type_ self.class_ = class_ - self.ttl = ttl + self._ttl = ttl self.rdata = rdata self.createDate = createDate + def copy(self): + return ResourceRecord(self.name, self.type_, self.class_, self._ttl, self.rdata, self.createDate) + + @property + def ttl(self): + if self.createDate is None: + return self._ttl + else: + return int(self._ttl - (datetime.now() - self.createDate).total_seconds()) + + @ttl.setter + def ttl(self, value): + self.createDate = datetime.now() + self._ttl = value + def valid(self): if self.createDate is None: return True @@ -96,6 +111,7 @@ class RecordData(object): if type_ in classdict: return classdict[type_](data) else: + print "warning,", type_, "not found" return GenericRecordData(data) @staticmethod diff --git a/dns/server.py b/dns/server.py index f307db3..229b5be 100644 --- a/dns/server.py +++ b/dns/server.py @@ -8,19 +8,189 @@ server using the algorithm described in section 4.3.2 of RFC 1034. from threading import Thread +import socket + +import dns.message +import dns.zone + +from dns.classes import Class +from dns.types import Type + +from resolver import Resolver class RequestHandler(Thread): """ A handler for requests to the DNS server """ - def __init__(self): + def __init__(self, addr, packet, parent): """ Initialize the handler thread """ super(RequestHandler, self).__init__() self.daemon = True + self.packet = packet + self.addr = addr + self.parent = parent + + def generate_response(self, QNAME, QTYPE, packet, rec=False, answers=None,auth=None,additional=None, response_flags = None): + """ + Generate the response records and flags for a query. + This follows the algorithm as outlined in the RFC, and copied in the comments below. + """ + if answers is None: answers = [] + if auth is None: auth = [] + if additional is None: additional = [] + + print "generating result for ", QNAME, QTYPE + + # The actual algorithm used by the name server will depend on the local OS + # and data structures used to store RRs. The following algorithm assumes + # that the RRs are organized in several tree structures, one for each + # zone, and another for the cache: + if response_flags is None: response_flags = {} + + # 1. Set or clear the value of recursion available in the response + # depending on whether the name server is willing to provide + # recursive service. If recursive service is available and + # requested via the RD bit in the query, go to step 5, + # otherwise step 2. + response_flags['ra'] = 1 + + done = False + + if not packet.header.rd: + + # 2. Search the available zones for the zone which is the nearest + # ancestor to QNAME. If such a zone is found, go to step 3, + # otherwise step 4. + + zone = self.parent.catalog.find_nearest(QNAME) + + if zone: + (matchtype, lookedup) = zone.lookup(QNAME) + + + # 3. Start matching down, label by label, in the zone. The + # matching process can terminate several ways: + + if matchtype == 'full': + # a. If the whole of QNAME is matched, we have found the + # node. + + # If the data at the node is a CNAME, and QTYPE doesn't + # match CNAME, copy the CNAME RR into the answer section + # of the response, change QNAME to the canonical name in + # the CNAME RR, and go back to step 1. + cnames = [n for n in lookedup if n.type_ == Type.CNAME] + response_flags['aa'] = True + if cnames: + #print "following cname", cnames[0].rdata.data + answers.append(cnames[0]) + QNAME = cnames[0].rdata.data + return self.generate_response(cnames[0].rdata.data, QTYPE, packet, True, answers, auth, additional, response_flags) + + else: + # Otherwise, copy all RRs which match QTYPE into the + # answer section and go to step 6. + answers += [n for n in lookedup if n.type_ == QTYPE] + response_flags['aa'] = True + done = True + + elif matchtype == 'auth': + # b. If a match would take us out of the authoritative data, + # we have a referral. This happens when we encounter a + # node with NS RRs marking cuts along the bottom of a + # zone. + + # Copy the NS RRs for the subzone into the authority + # section of the reply. Put whatever addresses are + # available into the additional section, using glue RRs + # if the addresses are not available from authoritative + # data or the cache. Go to step 4. + nameservs = [n for n in lookedup if n.type_ == Type.NS] + response_flags['aa'] = True + auth += nameservs + for serv in nameservs: + # try to find them in the zone + (match2, look2) = zone.lookup(serv.rdata.data) + if match2 == 'full': + additional += look2 + # try to find them in the cache + else: + look2 = self.parent.resolver.cache.lookup(serv.rdata.data, Types.A, Class.IN) + additional += look2 + + + # c. If at some label, a match is impossible (i.e., the + # corresponding label does not exist), look to see if a + # the "*" label exists. + + elif matchtype == 'no': + # If the "*" label does not exist, check whether the name + # we are looking for is the original QNAME in the query + # or a name we have followed due to a CNAME. If the name + # is original, set an authoritative name error in the + # response and exit. Otherwise just exit. + if not rec: + response_flags['rcode'] = dns.rcodes.RCode.NXDomain + done = True + + elif matchtype == 'star': + # If the "*" label does exist, match RRs at that node + # against QTYPE. If any match, copy them into the answer + # section, but set the owner of the RR to be QNAME, and + # not the node with the "*" label. Go to step 6. + new = [n.copy() for n in lookedup if n.type_ == QTYPE] + for x in new: + x.name = QNAME + answers += new + done = True + if not done: + # 4. Start matching down in the cache. If QNAME is found in the + # cache, copy all RRs attached to it that match QTYPE into the + # answer section. If there was no delegation from + # authoritative data, look for the best one from the cache, and + # put it in the authority section. Go to step 6. + lookedup = self.parent.resolver.cache.lookup(QNAME, QTYPE, Class.IN) + answers += lookedup + if zone and matchtype == 'auth' and len(auth) == 0: + (addit, ns) = self.parent.resolver.best_ns_from_cache(QNAME) + auth += ns + additional += addit + else: + # 5. Using the local resolver or a copy of its algorithm (see + # resolver section of this memo) to answer the query. Store + # the results, including any intermediate CNAMEs, in the answer + # section of the response. + + hostname, alias, ips = self.parent.resolver.lookup(QTYPE, QNAME) + answers += alias + ips + + # 6. Using local data only, attempt to add other RRs which may be + # useful to the additional section of the query. Exit. + pass # This happens in step 4 and 3b + + response_flags['rq'] = 1 + response_flags['opcode'] = 0 + return response_flags,answers,auth,additional + def run(self): """ Run the handler thread """ - # TODO: Handle DNS request - pass + # parse the packet + packet = dns.message.Message.from_bytes(self.packet) + assert len(packet.questions) == 1 + + q = packet.questions[0] + + answers = [] + auth = [] + additional = [] + + response_flags,answers,auth,additional = self.generate_response(q.qname, q.qtype, packet) + + print "sending", answers, auth, additional + message = dns.message.makepacket(response_flags, packet.questions, answers, auth, additional, ident=packet.header.ident) + + self.parent.sock.sendto(message.to_bytes(), self.addr) + class Server(object): @@ -38,16 +208,24 @@ class Server(object): self.ttl = ttl self.port = port self.done = False - # TODO: create socket + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.resolver = Resolver(caching, ttl) + self.resolver.cache.read_cache_file() + self.catalog = dns.zone.Catalog() + + def add_zone(self, name, file): + z = dns.zone.Zone() + z.read_master_file(file) + self.catalog.add_zone(name, z) def serve(self): """ Start serving request """ - # TODO: start listening + self.sock.bind(('', self.port)) while not self.done: - # TODO: receive request and open handler - pass - + (packet, addr) = self.sock.recvfrom(4096) + RequestHandler(addr, packet, self).start() def shutdown(self): """ Shutdown the server """ self.done = True - # TODO: shutdown socket + self.sock.close() + self.resolver.cache.write_cache_file() diff --git a/dns/zone.py b/dns/zone.py index 509b530..cf0a4b3 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -30,14 +30,28 @@ class Catalog(object): """ self.zones[name] = zone + def find_nearest(self, name): + """ Find the zone that best matches 'name' + """ + name = name + '.' + while name: + if name in self.zones: + return self.zones[name] + name = '.'.join(name.split('.')[1:]) + return None + from collections import defaultdict +defdict = lambda: defaultdict(defdict) + +CONTENT_KEY = None # the dicts otherwise store strings, so this won't give a collision + class Zone(object): """ A zone in the domain name space """ def __init__(self): """ Initialize the Zone """ - self.records = defaultdict(list) + self.records = defdict() def add_node(self, name, record_set): """ Add a record set to the zone @@ -46,8 +60,39 @@ class Zone(object): name (str): domain name record_set ([ResourceRecord]): resource records """ - self.records[name].append(record_set) - + rec = self.records + for n in reversed(name.split('.')): + rec = rec[n] + if CONTENT_KEY in rec: # no collisions + rec[CONTENT_KEY] += record_set + else: + rec[CONTENT_KEY] = record_set + def lookup(self, name): + """ Find the tree node that best matches name + Possible matches are: + - no + - star (the last match is a *.a.b), + - auth (the match is probably a NS record), + - full (the name was found entirely) + """ + rec = self.records + match_type = 'no' + for i,n in enumerate(reversed(name.split('.'))): + if n in rec: + rec = rec[n] + elif '*' in rec and i == len(name.split('.'))-1: + rec = rec['*'] + match_type = 'star' + break + else: + match_type = 'auth' + break # there might be nameservers here + else: + match_type = 'full' + if CONTENT_KEY in rec: + return (match_type, rec[CONTENT_KEY]) + else: + return ('no', []) def read_master_file(self, filename): """ Read the zone from a master file @@ -61,7 +106,8 @@ class Zone(object): records = [] for line in lines: line = line.strip() - if line.startswith(';'): continue + if line.startswith(';') or not line: continue name, ttl, type_, data = line.split() - rdata = RecordData(data) - self.add_node(name, ResourceRecord(name, Type.by_string[type_], Class.IN, int(ttl), rdata)) + t = Type.by_string[type_] + rdata = RecordData.create(t, data) + self.add_node(name, [ResourceRecord(name, t, Class.IN, int(ttl), rdata)]) diff --git a/dns_server.py b/dns_server.py index f1d4b09..c3ef998 100644 --- a/dns_server.py +++ b/dns_server.py @@ -21,6 +21,7 @@ if __name__ == "__main__": # Start server server = dns.server.Server(args.port, args.caching, args.ttl) + server.add_zone("test.com.", "test.zone") try: server.serve() except KeyboardInterrupt: diff --git a/dns_tests.py b/dns_tests.py index 439b46b..413a914 100644 --- a/dns_tests.py +++ b/dns_tests.py @@ -19,7 +19,7 @@ class TestResolver(unittest.TestCase): resolver = dns.resolver.Resolver(False, 0) host, alias, ip = resolver.gethostbyname("mail.polvanaubel.com") self.assertEqual(host, "sog.polvanaubel.com") - self.assertEqual(alias, ["sog.polvanaubel.com"]) + self.assertEqual(alias, ["mail.polvanaubel.com"]) self.assertEqual(ip, ["138.201.39.104"]) def test_invalid(self): resolver = dns.resolver.Resolver(False, 0) @@ -49,9 +49,63 @@ class TestResolverCache(unittest.TestCase): self.assertFalse(alias) self.assertFalse(ip) +import socket class TestServer(unittest.TestCase): - pass + # solve a query for a FQDN for which your server has direct authority + def test_solve_auth(self): + resolver = dns.resolver.Resolver(True, 0) + resolver.cache.add_record(ResourceRecord("localhost", Type.A, Class.IN, 60, RecordData("127.0.0.1"))) + resolver.cache.add_record(ResourceRecord("test.com", Type.NS, Class.IN, 60, RecordData(server + ":" + str(portnr)))) + host, alias, ip = resolver.gethostbyname("cnametest.test.com") + self.assertEqual(host, "hello.test.com") + self.assertEqual(alias, ["cnametest.test.com"]) + self.assertEqual(ip, ["1.2.3.4"]) + # solve a query for a FQDN for which your server does not have direct authority, + # yet there is a name server in your zone which does + def test_solve_ns_inside(self): + # I had a bit of trouble figuring this one out + # I think it would require me to set up two servers + resolver = dns.resolver.Resolver(True, 0) + ans = resolver.query_dns("a.subzone.test.com", Type.A, (server, portnr), False) + self.assertEqual(ans.authorities[0].rdata.data, "nameserver.test.com") + self.assertEqual(ans.additionals[0].name, "nameserver.test.com") + # solve a query for a fqdn which points outside your zone + def test_solve_outside(self): + resolver = dns.resolver.Resolver(True, 0) + # resolver.cache.add_record(ResourceRecord("localhost", Type.A, Class.IN, 60, RecordData("127.0.0.1"))) + # resolver.cache.add_record(ResourceRecord("edu", Type.NS, Class.IN, 60, RecordData(server + ":" + str(portnr)))) + # host, alias, ip = resolver.gethostbyname("gaia.cs.umass.edu") + ans = resolver.query_dns("gaia.cs.umass.edu", Type.A, (server, portnr), True) + self.assertEqual(ans.answers[0].rdata.data, "128.119.245.12") + # solve parallel requests for different FQDN, their servicing should be made in parallel and correct + # responses should be generated + def test_solve_parallel(self): + resolver = dns.resolver.Resolver(True, 0) + question1 = dns.message.Question("gaia.cs.umass.edu", Type.A, Class.IN) + question2 = dns.message.Question("yori.cc", Type.A, Class.IN) + query1 = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": True}, [question1], ident=10) + query2 = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": True}, [question2], ident=20) + + timeout = 2 + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(timeout) + + sock.sendto(query1.to_bytes(), (server, portnr)) + sock.sendto(query2.to_bytes(), (server, portnr)) + + # Receive response + for i in range(2): + data = sock.recv(65530) # max MTU is not this value + response = dns.message.Message.from_bytes(data) + if response.header.ident == 10: + self.assertEqual(response.answers[0].rdata.data, "128.119.245.12") + elif response.header.ident == 20: + self.assertEqual(response.answers[0].rdata.data, "138.201.39.105") + else: + self.assertFalse(True) + sock.close() + if __name__ == "__main__": diff --git a/documentation.txt b/documentation.txt new file mode 100644 index 0000000..b96f1f6 --- /dev/null +++ b/documentation.txt @@ -0,0 +1,14 @@ +Name: Yorick van Pelt +student nr: s4503678 + +The general flow of this program is described in the RFC 1034, which has been followed as much as possible. +The DNS server sends authority flags if it is the authority (described in the RFC), and the resolver looks at the authority flags to see if it has found the authoritative nameserver. +I documented the resolver on a state machine on paper, but it is not included here. +I organized the zone into a tree, where the lookup algorithm from the server is implemented (in zone.py) + +- I crafted DNS messages using your own libraries +- I did not have to generate transaction IDs, because every transaction used a different socket. +- I used threads for concurrencies +- I remove expired entries from the cache upon lookup, load and store + +I had some trouble implementing the testcase with where there would be an NS record in the zone, pointing to another nameserver. I implemented it as best as I could, but I think it's only testable with at least two DNS servers. diff --git a/test.zone b/test.zone new file mode 100644 index 0000000..06c38a4 --- /dev/null +++ b/test.zone @@ -0,0 +1,6 @@ +hello.test.com 30 A 1.2.3.4 +*.test.com 30 A 2.3.4.5 +cnametest.test.com 30 CNAME hello.test.com +nameserver.test.com 30 A 8.8.8.8 +subzone.test.com 30 NS nameserver.test.com +