#!/usr/bin/env python2

""" DNS Resolver

This module contains a class for resolving hostnames.

The resolver is used by both the DNS client and the DNS server, but with a
different list of servers.
"""

import itertools
import socket
from collections import defaultdict

from dns.classes import Class
from dns.types import Type
import dns.cache
import dns.message
from dns.rcodes import RCode
from dns.zone import Zone


class ServerFailure(RuntimeError):
    """ Raised when a nameserver answers a query with RCode.ServFail.

    It derives from RuntimeError so that the resolver loop in
    Resolver.lookup, which handles RuntimeError, can drop the failing
    server and carry on with the next one.
    """


def dictappend(seq):
    """ Group a sequence of (key, value) pairs into a defaultdict of lists.

    Args:
        seq: iterable of (key, value) pairs

    Returns:
        defaultdict(list): every key mapped to the list of its values, in
        the order in which they appeared in `seq`
    """
    grouped = defaultdict(list)
    # dict.update would keep only the last value per key, hence the loop.
    for key, value in seq:
        grouped[key].append(value)
    return grouped


def ipport(name):
    """ Split a name into an (ip, port) pair, defaulting the port to 53.

    Args:
        name (str): "host" or "host:port"

    Returns:
        (str, int): the part before the first ':' and the port number

    Raises:
        ValueError: if the part after the first ':' is not an integer
    """
    parts = name.split(':')
    if len(parts) == 1:
        return parts[0], 53
    return parts[0], int(parts[1])


class Resolver(object):
    """ DNS resolver """

    def __init__(self, caching, ttl):
        """ Initialize the resolver

        The root hints are read from the master file "named.root" (path
        relative to the current working directory) into self.SBELT.

        Args:
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
        """
        self.caching = caching
        self.ttl = ttl
        self.SBELT = Zone()
        self.SBELT.read_master_file("named.root")
        if caching:
            self.cache = dns.cache.RecordCache(self.ttl)
        else:
            self.cache = dns.cache.DummyCache(self.ttl)

    @classmethod
    def query(cls, query, dest):
        """ Send a `query` to the destination address at `dest` and get a
        response back.

        Args:
            query: message object providing to_bytes()
            dest ((str, int)): (ip, port) of the server to ask

        Returns:
            the response parsed with dns.message.Message.from_bytes

        Raises:
            socket.timeout: if nothing is received within 2 seconds
            socket.error: on any other socket level failure
        """
        timeout = 2
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.settimeout(timeout)
            sock.sendto(query.to_bytes(), dest)
            # Receive response; 65530 is only a generous buffer size, it is
            # not the MTU.
            data = sock.recv(65530)
        finally:
            # Close the socket on timeouts and errors as well, so that
            # unresponsive servers do not leak file descriptors.
            sock.close()
        return dns.message.Message.from_bytes(data)

    def query_dns(self, hostname, type_, server, recursive=False):
        """ Query a nameserver directly with the arguments as specified.

        Args:
            hostname (str): the name to ask about
            type_: the record type to ask for
            server ((str, int)): (ip, port) of the nameserver
            recursive (bool): value of the RD flag of the query

        Returns:
            the parsed response message

        Raises:
            ServerFailure: if the response code is RCode.ServFail
            socket.error: propagated from Resolver.query
        """
        question = dns.message.Question(hostname, type_, Class.IN)
        query = dns.message.makepacket(
            {"qr": 0, "opcode": 0, "rd": recursive}, [question])
        print("asking %s" % (server,))
        response = Resolver.query(query, server)
        if response.header.rcode == RCode.ServFail:
            raise ServerFailure("server failure")
        return response

    def best_ns_from_cache(self, hostname):
        """ This function finds the nearest nameserver for this hostname in
        the cache.

        The leftmost label is stripped from the name until NS records are
        found or the name is exhausted.

        Returns:
            (a_list, ns_records): the cached A records of the nameservers
            that were found and their NS records, or ([], None) if the
            cache has no NS records for the name or any of its parents
        """
        hs = hostname
        a_list = []
        while hs:
            # look up NS servers in cache with decreasing fqdn
            lookedup = self.cache.lookup(hs, Type.NS, Class.IN)
            for n in lookedup:
                a_list += self.cache.lookup(n.rdata.data, Type.A, Class.IN)
            if lookedup:
                return (a_list, lookedup)
            hs = '.'.join(hs.split('.')[1:])
        return [], None

    def _initial_slist(self, hostname):
        """ Build the initial server list: a list of (is_address, server)
        pairs, taken from the cache if possible and from SBELT otherwise.
        """
        ns_a_list, lookedup = self.best_ns_from_cache(hostname)
        if lookedup:
            # TODO: use ns_a_list instead of resolving the names again
            return [(False, n.rdata.data) for n in lookedup]

        slist = []
        nameservers = [ns.rdata.data for ns in self.SBELT.lookup('.')[1]]
        for ns in nameservers:
            try:
                ips = [x.rdata.data for x in self.SBELT.lookup(ns)[1]
                       if x.type_ == Type.A]
            except KeyError:
                ips = []
            if ips:
                slist.append((True, ips[0]))
            else:
                # No A record in the hints: resolve the name when needed.
                slist.append((False, ns))
        return slist

    def lookup(self, type_, hostname, additionals=None):
        """ Look up all of the relevant resource records, either from cache
        or by going over the network.

        This is used by the server, and also by gethostbyname.
        The algorithm is specified in RFC 1034, but copied into the comments
        here.

        Args:
            type_: the record type to look up
            hostname (str): the name to look up
            additionals: records to put in front of the alias list
                (optional); the argument itself is never modified

        Returns:
            (name, aliases, records): the name the records belong to, the
            CNAME records that were followed and the records found.
            ([], [], []) is returned on NXDOMAIN and when no server is left
            to ask.
        """
        SNAME = hostname
        aliases = list(additionals) if additionals else []

        # 1. See if the answer is in local information, and if so return
        #    it to the client.
        # NOTE(review): a CNAME cycle in the cache would keep this loop
        # running forever.
        while type_ != Type.CNAME:
            # follow cnames
            lookedupcnames = self.cache.lookup(SNAME, Type.CNAME, Class.IN)
            if not lookedupcnames:
                break
            aliases.append(lookedupcnames[0])
            SNAME = lookedupcnames[0].rdata.data

        lookedup = self.cache.lookup(SNAME, type_, Class.IN)
        if lookedup:
            return SNAME, aliases, lookedup

        # 2. Find the best servers to ask.
        SLIST = self._initial_slist(hostname)

        while True:
            # do the looked up ones first
            SLIST.sort(key=lambda entry: entry[0], reverse=True)
            if not SLIST:
                return ([], [], [])

            # state 1
            (is_looked_up, server) = SLIST.pop(0)
            if not is_looked_up:
                # XXX: this can get into an infinite loop
                ns_name, port = ipport(server)
                _, _, ns_ips = self.gethostbyname(ns_name)
                ns_ips = [i + ":" + str(port) for i in ns_ips]
                if not ns_ips:
                    continue  # didn't find the server
                server = ns_ips[0]
                # The remaining addresses are already resolved, so they are
                # queued as (True, address) pairs like every other entry.
                SLIST = [(True, i) for i in ns_ips[1:]] + SLIST

            # state 2
            # 3. Send them queries until one returns a response.
            # state 3
            # NOTE(review): the query asks for `hostname`, not for SNAME.
            try:
                response = self.query_dns(hostname, type_, ipport(server),
                                          False)
            except (RuntimeError, socket.error) as e:
                print(e)
                #   d. if the response shows a servers failure or other
                #      bizarre contents, delete the server from the SLIST and
                #      go back to step 3.
                # Timeouts and socket errors are treated the same way.
                continue  # edge 3 -> 1

            # 4. Analyze the response, either:
            #   a. if the response answers the question or contains a name
            #      error, cache the data as well as returning it back to
            #      the client.
            if response.header.rcode == RCode.NXDomain:
                print("NXDOMAIN %s" % (SNAME,))
                return ([], [], [])  # edge 3 -> 6

            for ans in (response.answers + response.additionals +
                        response.authorities):
                self.cache.add_record(ans)

            if response.header.aa:
                # state 5
                answers = dictappend(
                    [(ans.name, ans) for ans in response.answers
                     if ans.type_ == type_])
                cnames = dictappend(
                    [(ans.name, ans)
                     for ans in (response.answers + response.additionals)
                     if ans.type_ == Type.CNAME])
                if not answers and not cnames:
                    return SNAME, aliases, []

                if SNAME in cnames and type_ != Type.CNAME:
                    # state 7
                    #   c. if the response shows a CNAME and that is not the
                    #      answer itself, cache the CNAME, change the SNAME
                    #      to the canonical name in the CNAME RR and go to
                    #      step 1.
                    # XXX: CNAME loops are not detected here.
                    aliases.append(cnames[SNAME][0])
                    SNAME = cnames[SNAME][0].rdata.data
                if SNAME in answers:
                    return SNAME, aliases, answers[SNAME]
                # 5 -> 0: start over with the (possibly new) SNAME
                return self.lookup(type_, SNAME, aliases)

            #   b. if the response contains a better delegation to other
            #      servers, cache the delegation information, and go to
            #      step 2.
            if response.authorities:
                # state 4
                cands = [auth.rdata.data for auth in response.authorities]
                glue = {}
                for add in response.additionals:
                    if add.type_ == Type.A:
                        glue[add.name] = add.rdata.data
                SLIST = [(True, glue[x]) if x in glue else (False, x)
                         for x in cands]
                continue  # 4 -> 1

    def gethostbyname(self, hostname, additionals=None):
        """ Translate a host name to IPv4 address.

        Args:
            hostname (str): the hostname to resolve
            additionals: records passed through to lookup (optional)

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist)
        """
        hostname, aliases, responses = self.lookup(Type.A, hostname,
                                                   additionals)
        aliases = [a.name for a in aliases]
        responses = [r.rdata.data for r in responses]
        return hostname, aliases, responses