net-dnsserver/dns/resolver.py

259 lines
9.5 KiB
Python

#!/usr/bin/env python2
""" DNS Resolver
This module contains a class for resolving hostnames. You will have to implement
things in this module. This resolver will be both used by the DNS client and the
DNS server, but with a different list of servers.
"""
import socket
from dns.classes import Class
from dns.types import Type
import dns.cache
import dns.message
from dns.rcodes import RCode
from dns.zone import Zone
from collections import defaultdict
import itertools
def dictappend(seq):
    """
    Group (key, value) pairs into a defaultdict(list).

    Values are appended under their key in the order they are encountered.
    The result is a defaultdict, so indexing an absent key yields an empty
    list.
    """
    # dict.update would overwrite earlier values instead of accumulating
    # them, hence the explicit loop.
    grouped = defaultdict(list)
    for pair in seq:
        key, value = pair
        grouped[key].append(value)
    return grouped
def ipport(name):
    """
    Split "host" or "host:port" into a (host, port) tuple.

    The port is converted to int and defaults to 53 when no ':' is present.
    Any fields after a second ':' are ignored.
    """
    fields = name.split(':')
    port = int(fields[1]) if len(fields) > 1 else 53
    return fields[0], port
class Resolver(object):
    """ DNS resolver

    Resolves names iteratively (queries are sent with RD=0) following the
    algorithm of RFC 1034 section 5.3.3. It starts from nameservers found in
    the cache or, failing that, from the root servers listed in named.root.
    """

    def __init__(self, caching, ttl):
        """ Initialize the resolver

        Args:
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
        """
        self.caching = caching
        self.ttl = ttl
        # Root hints (the "safety belt" of RFC 1034), used when the cache
        # knows no nameserver for the name being resolved.
        self.SBELT = Zone()
        self.SBELT.read_master_file("named.root")
        if caching:
            self.cache = dns.cache.RecordCache(self.ttl)
        else:
            self.cache = dns.cache.DummyCache(self.ttl)

    @classmethod
    def query(cls, query, dest):
        """
        Send `query` over UDP to the (ip, port) address `dest` and return
        the parsed response.

        Raises:
            socket.timeout: if no reply arrives within 2 seconds
            socket.error: on other socket failures
        """
        timeout = 2
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Close the socket on every path. Previously a timeout or a parse
        # error leaked the descriptor.
        try:
            sock.settimeout(timeout)
            sock.sendto(query.to_bytes(), dest)
            # Larger than any IPv4 UDP payload (65507 bytes).
            data = sock.recv(65530)
            return dns.message.Message.from_bytes(data)
        finally:
            sock.close()

    def query_dns(self, hostname, type_, server, recursive=False):
        """
        Query a nameserver directly with the arguments as specified.

        Args:
            hostname (str): the name to ask about
            type_ (Type): the record type to ask for
            server ((str, int)): (ip, port) of the nameserver
            recursive (bool): value of the RD (recursion desired) flag

        Returns:
            the parsed response message

        Raises:
            RuntimeError: if the server answers with SERVFAIL
            socket.error: if the server cannot be reached (see `query`)
        """
        question = dns.message.Question(hostname, type_, Class.IN)
        query = dns.message.makepacket({"qr": 0, "opcode": 0, "rd": recursive}, [question])
        print("asking %s" % (server,))
        response = Resolver.query(query, server)
        if response.header.rcode == RCode.ServFail:
            # RuntimeError (still an Exception subclass, as before) so that
            # lookup() skips this server and tries the next one.
            raise RuntimeError("server failure")
        return response

    def best_ns_from_cache(self, hostname):
        """
        Find the cached nameservers of the closest enclosing zone of
        `hostname`.

        Labels are stripped from the left until the cache holds NS records
        for the remaining name.

        Returns:
            (list, list or None): the cached A records of those nameservers
            and the NS records themselves, or ([], None) if none are cached.
        """
        hs = hostname
        while hs:  # look up NS servers in cache with decreasing fqdn
            ns_records = self.cache.lookup(hs, Type.NS, Class.IN)
            if ns_records:
                a_list = []
                for n in ns_records:
                    a_list += self.cache.lookup(n.rdata.data, Type.A, Class.IN)
                return (a_list, ns_records)
            hs = '.'.join(hs.split('.')[1:])
        return [], None

    @staticmethod
    def _seen_alias(aliases, name):
        """
        Return True if `name` is the owner name of a record in `aliases`,
        i.e. a CNAME from that name has already been followed (a loop).
        """
        return any(alias.name == name for alias in aliases)

    def _initial_servers(self, sname):
        """
        Build the starting SLIST for `sname` (RFC 1034 step 2).

        Entries are (is_address, server) pairs. is_address is True when
        server is an address ("ip" or "ip:port"). It is False when server is
        a nameserver host name that must be resolved first.
        """
        ns_a_list, ns_records = self.best_ns_from_cache(sname)
        if ns_records:
            # Cached nameserver addresses can be queried directly. The
            # names stay in the list as a fallback.
            slist = [(True, a.rdata.data) for a in ns_a_list]
            slist += [(False, n.rdata.data) for n in ns_records]
            return slist
        # Nothing cached: start from the root servers. Zone.lookup()
        # appears to return a pair whose second item is the record list;
        # confirm against dns.zone.
        slist = []
        root_names = [rr.rdata.data for rr in self.SBELT.lookup('.')[1]]
        for ns in root_names:
            try:
                ips = [rr.rdata.data for rr in self.SBELT.lookup(ns)[1]
                       if rr.type_ == Type.A]
                slist.append((True, ips[0]))
            except (KeyError, IndexError):
                # No address in the hints file: resolve the name later.
                slist.append((False, ns))
        return slist

    def lookup(self, type_, hostname, additionals=None):
        """
        Look up all of the relevant resource records, either from cache
        or by going over the network.

        This is used by the server, and also by gethostbyname.
        The algorithm is specified in RFC 1034 (section 5.3.3). Its steps
        are quoted in the comments below.

        Args:
            type_ (Type): the record type wanted
            hostname (str): the name to resolve
            additionals (list): CNAME records the caller has already
                followed. They start the returned alias list. Defaults to
                no records.

        Returns:
            (sname, aliases, records): the name reached after following
            CNAMEs, the CNAME records followed, and the records of type
            `type_` for sname. records is [] if there are none or a CNAME
            loop was found. On NXDOMAIN, or once every server has been
            tried without success, three empty lists are returned instead.
        """
        SNAME = hostname
        aliases = list(additionals) if additionals else []

        # 1. See if the answer is in local information, and if so return
        # it to the client.
        # Follow cached CNAMEs (unless CNAMEs are what was asked for). A
        # cached CNAME cycle used to make this loop spin forever. Now a
        # repeated owner name ends the lookup with no records.
        while type_ != Type.CNAME:
            cached_cnames = self.cache.lookup(SNAME, Type.CNAME, Class.IN)
            if not cached_cnames:
                break
            cname = cached_cnames[0]
            if self._seen_alias(aliases, cname.name):
                return SNAME, aliases, []
            aliases.append(cname)
            SNAME = cname.rdata.data
        lookedup = self.cache.lookup(SNAME, type_, Class.IN)
        if lookedup:
            return SNAME, aliases, lookedup

        # 2. Find the best servers to ask (for SNAME, per RFC 1034).
        SLIST = self._initial_servers(SNAME)

        while SLIST:
            # Try servers with a known address first. The sort is stable,
            # so order within each group is kept.
            SLIST.sort(key=lambda entry: entry[0], reverse=True)
            # state 1
            is_address, server = SLIST.pop(0)
            if not is_address:
                # XXX: this can get into an infinite loop (resolving the
                # nameserver's name may need that same nameserver).
                ns_name, port = ipport(server)
                _, _, ns_ips = self.gethostbyname(ns_name)
                if not ns_ips:
                    continue  # didn't find the server
                ns_ips = [i + ":" + str(port) for i in ns_ips]
                server = ns_ips[0]
                # Remaining addresses must be (True, ip) tuples like every
                # other SLIST entry. Bare strings broke the unpacking above.
                SLIST = [(True, i) for i in ns_ips[1:]] + SLIST

            # state 2
            # 3. Send them queries until one returns a response.
            try:
                # state 3
                response = self.query_dns(SNAME, type_, ipport(server), False)
            except (RuntimeError, socket.error) as e:
                print(e)
                # d. if the response shows a servers failure or other
                # bizarre contents, delete the server from the SLIST and
                # go back to step 3. (Timeouts are treated the same way.)
                continue  # edge 3 -> 1

            # 4. Analyze the response, either:
            # a. if the response answers the question or contains a name
            # error, cache the data as well as returning it back to
            # the client.
            if response.header.rcode == RCode.NXDomain:
                print("NXDOMAIN %s" % (SNAME,))
                return ([], [], [])  # edge 3 -> 6
            for rr in response.answers + response.additionals + response.authorities:
                self.cache.add_record(rr)

            if response.header.aa:
                # state 5
                answers = dictappend([(rr.name, rr) for rr in response.answers
                                      if rr.type_ == type_])
                cnames = dictappend([(rr.name, rr)
                                     for rr in response.answers + response.additionals
                                     if rr.type_ == Type.CNAME])
                if not answers and not cnames:
                    return SNAME, aliases, []
                # state 7
                # c. if the response shows a CNAME and that is not the
                # answer itself, cache the CNAME, change the SNAME to the
                # canonical name in the CNAME RR and go to step 1.
                # The whole chain present in this response is followed.
                # A repeated owner name means a CNAME loop.
                while type_ != Type.CNAME and SNAME in cnames:
                    cname = cnames[SNAME][0]
                    if self._seen_alias(aliases, cname.name):
                        return SNAME, aliases, []
                    aliases.append(cname)
                    SNAME = cname.rdata.data
                if SNAME in answers:
                    return SNAME, aliases, answers[SNAME]
                # 5 -> 0: restart at step 1 with the new SNAME
                return self.lookup(type_, SNAME, aliases)

            # b. if the response contains a better delegation to other
            # servers, cache the delegation information, and go to
            # step 2.
            # Only NS records name servers. An SOA in the authority
            # section is not a delegation.
            ns_names = [rr.rdata.data for rr in response.authorities
                        if rr.type_ == Type.NS]
            if ns_names:
                # state 4
                glue = {}
                for rr in response.additionals:
                    if rr.type_ == Type.A:
                        glue[rr.name] = rr.rdata.data
                SLIST = [(True, glue[n]) if n in glue else (False, n)
                         for n in ns_names]
                continue  # 4 -> 1
            # Neither an answer nor a delegation: try the next server.

        return [[], [], []]

    def gethostbyname(self, hostname, additionals=None):
        """ Translate a host name to IPv4 address.

        Args:
            hostname (str): the hostname to resolve
            additionals (list): CNAME records already followed, passed on
                to lookup()

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist)
            The hostname is [] rather than a str when lookup() fails
            (NXDOMAIN or no server answered).
        """
        hostname, aliases, responses = self.lookup(Type.A, hostname, additionals)
        aliases = [a.name for a in aliases]
        responses = [r.rdata.data for r in responses]
        return hostname, aliases, responses