215 lines
7.8 KiB
Python
215 lines
7.8 KiB
Python
#!/usr/bin/env python2
|
|
|
|
""" DNS Resolver
|
|
|
|
This module contains a class for resolving hostnames. You will have to implement
|
|
things in this module. This resolver will be both used by the DNS client and the
|
|
DNS server, but with a different list of servers.
|
|
"""
|
|
|
|
import socket
|
|
|
|
from dns.classes import Class
|
|
from dns.types import Type
|
|
|
|
import dns.cache
|
|
import dns.message
|
|
from dns.rcodes import RCode
|
|
|
|
from dns.zone import Zone
|
|
|
|
from collections import defaultdict
|
|
import itertools
|
|
|
|
|
|
def dictappend(seq):
    """Group an iterable of ``(key, value)`` pairs by key.

    Values are collected in the order they appear in *seq*.

    Args:
        seq: iterable of 2-tuples ``(key, value)``.

    Returns:
        collections.defaultdict(list): maps each key to the list of its values.
    """
    grouped = defaultdict(list)
    for key, value in seq:
        bucket = grouped[key]
        bucket.append(value)
    return grouped
|
|
|
|
class Resolver(object):
    """ DNS resolver """

    # Upper bound on nested gethostbyname() calls (CNAME chasing plus
    # name-server address lookups); prevents unbounded recursion.
    MAX_DEPTH = 16

    def __init__(self, caching, ttl):
        """ Initialize the resolver

        Args:
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
        """
        self.caching = caching
        self.ttl = ttl
        # Safety belt: root hints loaded from the file "named.root"
        # (path passed as-is; presumably relative to the working directory).
        self.SBELT = Zone()
        self.SBELT.read_master_file("named.root")
        if caching:
            self.cache = dns.cache.RecordCache(self.ttl)
        else:
            self.cache = dns.cache.DummyCache(self.ttl)

    @classmethod
    def query(cls, query, dest, timeout=2):
        """ Send a query over UDP and return the parsed reply.

        Args:
            query (dns.message.Message): the query to send
            dest ((str, int)): (ip, port) of the server
            timeout (float): socket timeout in seconds

        Returns:
            dns.message.Message: the decoded response

        Raises:
            socket.timeout / socket.error on network failures; whatever
            Message.from_bytes raises on malformed data.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.settimeout(timeout)
            sock.sendto(query.to_bytes(), dest)
            # Buffer large enough for any single UDP datagram.
            data = sock.recv(65530)
        finally:
            # Always release the socket, including on timeout.
            sock.close()
        return dns.message.Message.from_bytes(data)

    @classmethod
    def makepacket(cls, flags, qs=None, ans=None, ns=None, ar=None, ident=9001):
        """ Build a DNS message.

        Args:
            flags (dict): header attribute name -> value, applied via setattr
            qs, ans, ns, ar (list): question, answer, authority and
                additional sections (default: empty, fresh per call)
            ident (int): message id

        Returns:
            dns.message.Message
        """
        # Fresh lists per call; mutable defaults would be shared.
        qs = [] if qs is None else qs
        ans = [] if ans is None else ans
        ns = [] if ns is None else ns
        ar = [] if ar is None else ar
        header = dns.message.Header(ident, 0, len(qs), len(ans), len(ns), len(ar))
        for (k, v) in flags.items():
            setattr(header, k, v)
        return dns.message.Message(header, qs, ans, ns, ar)

    @staticmethod
    def _parent_domain(name):
        """ Return *name* with its leftmost label removed ('' when exhausted). """
        return name.partition('.')[2]

    @staticmethod
    def _follow_cnames(sname, cnames, aliases):
        """ Follow the CNAME chain for *sname* through the mapping *cnames*.

        Every canonical name reached is appended to *aliases* (in place).

        Args:
            sname (str): name to start from
            cnames (dict): name -> list of CNAME targets (first one is used)
            aliases (list): names visited so far; used for loop detection

        Returns:
            (str, bool, bool): the last name reached, whether at least one
            CNAME was followed, and whether a CNAME loop was detected.
        """
        followed = False
        while sname in cnames:
            target = cnames[sname][0]
            if target in aliases:
                return sname, followed, True
            aliases.append(target)
            sname = target
            followed = True
        return sname, followed, False

    def _initial_servers(self, sname):
        """ Step 2 of RFC 1034 5.3.3: find the best servers to ask.

        Walks up *sname* looking for cached NS records. Falls back to the
        root servers from the SBELT when none are cached.

        Returns:
            list of (bool, str): (True, ip) for servers with a known address,
            (False, name) for servers whose address must still be resolved.
        """
        zone = sname
        while zone:
            cached = self.cache.lookup(zone, Type.NS, Class.IN)
            if cached:
                return [(False, rr.rdata.data) for rr in cached]
            zone = self._parent_domain(zone)

        servers = []
        for rr in self.SBELT.records['.']:
            name = rr.rdata.data
            try:
                servers.append((True, self.SBELT.records[name][0].rdata.data))
            except (KeyError, IndexError):
                # No address in the hints: resolve the name later.
                servers.append((False, name))
        return servers

    @staticmethod
    def _referral_servers(response):
        """ Build a new server list from a referral (step 4b).

        Only NS records in the authority section count as delegations (an
        SOA there is a negative answer, not a server). Glue A records from
        the additional section provide addresses where available.
        """
        glue = {}
        for rr in response.additionals:
            if rr.type_ == Type.A:
                glue[rr.name] = rr.rdata.data
        names = [rr.rdata.data for rr in response.authorities
                 if rr.type_ == Type.NS]
        return [(True, glue[n]) if n in glue else (False, n) for n in names]

    def gethostbyname(self, hostname, additionals=None, _depth=0):
        """ Translate a host name to IPv4 addresses.

        Implements the algorithm of RFC 1034 section 5.3.3 (non-recursive
        queries, following referrals and CNAMEs).

        Args:
            hostname (str): the hostname to resolve
            additionals (list): aliases already collected by a caller
                (used when chasing CNAMEs; not modified)
            _depth (int): internal recursion counter; do not pass

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist).
            On failure (NXDOMAIN, no reachable server, recursion limit)
            returns ([], [], []).
        """
        if _depth > self.MAX_DEPTH:
            return ([], [], [])

        sname = hostname
        aliases = list(additionals or [])

        # 1. See if the answer is in local information (follow cached
        #    CNAMEs, stopping on a cycle).
        while True:
            cached = self.cache.lookup(sname, Type.CNAME, Class.IN)
            if not cached:
                break
            target = cached[0].rdata.data
            if target in aliases:
                break
            sname = target
            aliases.append(sname)
        cached = self.cache.lookup(sname, Type.A, Class.IN)
        if cached:
            return sname, aliases, [rr.rdata.data for rr in cached]

        # 2. Find the best servers to ask.
        slist = self._initial_servers(sname)

        question = dns.message.Question(sname, Type.A, Class.IN)
        packet = Resolver.makepacket({"qr": 0, "opcode": 0, "rd": 0}, [question])

        while slist:
            # Prefer servers whose address is already known (stable sort).
            slist.sort(key=lambda entry: entry[0], reverse=True)
            known, server = slist.pop(0)

            if not known:
                # Resolve the name server's own address first.
                _, _, ips = self.gethostbyname(server, _depth=_depth + 1)
                if not ips:
                    continue
                server = ips[0]
                slist = [(True, ip) for ip in ips[1:]] + slist

            # 3. Send the query; on failure drop this server and try the next.
            try:
                response = Resolver.query(packet, (server, 53))
            except Exception as err:
                print(err)
                continue
            if response.header.rcode == RCode.ServFail:
                print("server failure")
                continue

            # 4a. Name error: authoritative "does not exist".
            if response.header.rcode == RCode.NXDomain:
                print("NXDOMAIN")
                return ([], [], [])

            for rr in response.answers + response.additionals + response.authorities:
                self.cache.add_record(rr)

            if response.header.aa:
                answers = dictappend([(rr.name, rr.rdata.data)
                                      for rr in response.answers
                                      if rr.type_ == Type.A])
                cnames = dictappend([(rr.name, rr.rdata.data)
                                     for rr in response.answers + response.additionals
                                     if rr.type_ == Type.CNAME])
                # 4c. Follow any CNAME chain contained in the response.
                sname, followed, looped = self._follow_cnames(sname, cnames, aliases)
                if looped:
                    return sname, aliases, []
                if sname in answers:
                    return sname, aliases, answers[sname]
                if not followed:
                    # Authoritative answer without data for this name;
                    # re-querying would recurse forever.
                    return sname, aliases, []
                # Canonical name lives elsewhere: restart at step 1.
                return self.gethostbyname(sname, aliases, _depth=_depth + 1)

            # 4b. Referral to servers closer to the answer.
            referral = self._referral_servers(response)
            if referral:
                slist = referral

        # No server produced an answer.
        return ([], [], [])
|