diff options
| author | Shulhan <ms@kilabit.info> | 2019-04-10 17:17:28 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-04-12 19:14:02 +0700 |
| commit | 47a096dc764c4a83fbdb96b04237d5851e86fab2 (patch) | |
| tree | a821a3f0587fd4fd855bac8ab4db753777ca0480 | |
| parent | f43e8fd05681ee455c93b464ed5586a83de71fc3 (diff) | |
| download | pakakeh.go-47a096dc764c4a83fbdb96b04237d5851e86fab2.tar.xz | |
dns/server: implement recursion, forwarding request to parent name servers
The forwarding routines will be running only if there is at least one
valid NameServers on ServerOptions.
The request will be forwarded only if IsRD (is recursion desired) flag is
set.
| -rw-r--r-- | lib/dns/dns_test.go | 3 | ||||
| -rw-r--r-- | lib/dns/server.go | 192 | ||||
| -rw-r--r-- | lib/dns/udpclient_test.go | 28 |
3 files changed, 198 insertions, 25 deletions
diff --git a/lib/dns/dns_test.go b/lib/dns/dns_test.go index e54b7e66..e6a88ff7 100644 --- a/lib/dns/dns_test.go +++ b/lib/dns/dns_test.go @@ -39,6 +39,9 @@ func TestMain(m *testing.M) { DoHPort: 8443, DoHCertificate: &cert, DoHAllowInsecure: true, + NameServers: []string{ + "https://cloudflare-dns.com/dns-query", + }, } _testServer, err = NewServer(serverOptions) diff --git a/lib/dns/server.go b/lib/dns/server.go index eba74276..d41fd6fc 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -5,6 +5,7 @@ package dns import ( + "bytes" "crypto/tls" "encoding/base64" "fmt" @@ -49,6 +50,9 @@ type Server struct { doh *http.Server requestq chan *Request + forwardq chan *Request + + hasForwarders bool } // @@ -63,6 +67,7 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) { srv = &Server{ opts: opts, requestq: make(chan *Request, 512), + forwardq: make(chan *Request, 512), } udpAddr := opts.getUDPAddress() @@ -86,6 +91,30 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) { } // +// isResponseValid check if request name, type, and class match with response. +// It will return true if both matched, otherwise it will return false. +// +func isResponseValid(req *Request, res *Message) bool { + if !bytes.Equal(req.Message.Question.Name, res.Question.Name) { + log.Printf("dns: unmatched response name, got %s want %s\n", + req.Message.Question.Name, res.Question.Name) + return false + } + if req.Message.Question.Type != res.Question.Type { + log.Printf("dns: unmatched response type, got %s want %s\n", + req.Message.Question, res.Question) + return false + } + if req.Message.Question.Class != res.Question.Class { + log.Printf("dns: unmatched response class, got %s want %s\n", + req.Message.Question, res.Question) + return false + } + + return true +} + +// // LoadHostsDir populate caches with DNS record from hosts formatted files in // directory "dir". // @@ -226,6 +255,8 @@ func (srv *Server) populateCaches(msgs []*Message) { // Start the server, listening and serve query from clients. // func (srv *Server) Start() { + srv.runForwarders() + go srv.processRequest() if srv.opts.DoHCertificate != nil { @@ -466,7 +497,8 @@ func (srv *Server) serveTCPClient(cl *TCPClient) { // func (srv *Server) processRequest() { var ( - resp []byte + res *Message + isLocal bool ) for req := range srv.requestq { @@ -475,34 +507,160 @@ func (srv *Server) processRequest() { req.Message.Question.Class) if ans == nil { + if req.Message.Header.IsRD && srv.hasForwarders { + srv.forwardq <- req + continue + } + req.Message.SetResponseCode(RCodeErrName) } + + isLocal = false if an == nil { + if req.Message.Header.IsRD && srv.hasForwarders { + srv.forwardq <- req + continue + } + req.Message.SetQuery(false) req.Message.SetAuthorativeAnswer(true) - resp = req.Message.Packet + res = req.Message + isLocal = true } else { + if an.msg.IsExpired() && srv.hasForwarders { + srv.forwardq <- req + continue + } + an.msg.SetID(req.Message.Header.ID) - resp = an.get() + res = an.msg + isLocal = (an.receivedAt == 0) } - switch req.Kind { - case ConnTypeUDP, ConnTypeTCP: - if req.Sender != nil { - _, err := req.Sender.Send(resp, req.UDPAddr) - if err != nil { - log.Println("dns: processRequest: Sender.Send:" + err.Error()) - } + srv.processResponse(req, res, isLocal) + } +} + +func (srv *Server) processResponse(req *Request, res *Message, isLocal bool) { + if !isLocal { + if !isResponseValid(req, res) { + return + } + } + + switch req.Kind { + case ConnTypeUDP: + if req.Sender != nil { + _, err := req.Sender.Send(res.Packet, req.UDPAddr) + if err != nil { + log.Println("dns: failed to send UDP reply:", err) + return } + } + + case ConnTypeTCP: + if req.Sender != nil { + _, err := req.Sender.Send(res.Packet, nil) + if err != nil { + log.Println("dns: failed to send TCP reply:", err) + return + } + } - case ConnTypeDoH: - if req.ResponseWriter != nil { - _, err := req.ResponseWriter.Write(resp) - if err != nil { - log.Println("dns: processRequest: ResponseWriter.Write:", err) - } - req.ChanResponded <- true + case ConnTypeDoH: + if req.ResponseWriter != nil { + _, err := req.ResponseWriter.Write(res.Packet) + req.ChanResponded <- true + if err != nil { + log.Println("dns: failed to send DoH reply:", err) + return } } } + + if !isLocal { + if res.Header.RCode != 0 { + log.Printf("dns: response error %s, code: %s\n", + res.Question, rcodeNames[res.Header.RCode]) + return + } + + an := newAnswer(res, false) + srv.caches.upsert(an) + } +} + +func (srv *Server) runForwarders() { + nforwarders := 0 + for x := 0; x < len(srv.opts.udpServers); x++ { + go srv.runUDPForwarder(srv.opts.udpServers[x]) + nforwarders++ + } + + for x := 0; x < len(srv.opts.tcpServers); x++ { + go srv.runTCPForwarder(srv.opts.tcpServers[x]) + nforwarders++ + } + + for x := 0; x < len(srv.opts.dohServers); x++ { + go srv.runDohForwarder(srv.opts.dohServers[x]) + nforwarders++ + } + + if nforwarders > 0 { + srv.hasForwarders = true + } +} + +func (srv *Server) runDohForwarder(nameserver string) { + forwarder, err := NewDoHClient(nameserver, false) + if err != nil { + log.Fatal("dns: failed to create DoH forwarder: " + err.Error()) + } + + for req := range srv.forwardq { + res, err := forwarder.Query(req.Message, nil) + if err != nil { + log.Println("dns: failed to query DoH: " + err.Error()) + continue + } + + srv.processResponse(req, res, false) + } +} + +func (srv *Server) runTCPForwarder(remoteAddr *net.TCPAddr) { + for req := range srv.forwardq { + cl, err := NewTCPClient(remoteAddr.String()) + if err != nil { + log.Println("dns: failed to create TCP client: " + err.Error()) + continue + } + + res, err := cl.Query(req.Message, nil) + cl.Close() + if err != nil { + log.Println("dns: failed to query TCP: " + err.Error()) + continue + } + + srv.processResponse(req, res, false) + } +} + +func (srv *Server) runUDPForwarder(remoteAddr *net.UDPAddr) { + forwarder, err := NewUDPClient(remoteAddr.String()) + if err != nil { + log.Fatal("dns: failed to create UDP forwarder: " + err.Error()) + } + + for req := range srv.forwardq { + res, err := forwarder.Query(req.Message, remoteAddr) + if err != nil { + log.Println("dns: failed to query UDP: " + err.Error()) + continue + } + + srv.processResponse(req, res, false) + } } diff --git a/lib/dns/udpclient_test.go b/lib/dns/udpclient_test.go index d029ed4e..61aa0a82 100644 --- a/lib/dns/udpclient_test.go +++ b/lib/dns/udpclient_test.go @@ -137,6 +137,18 @@ func TestUDPClientLookup(t *testing.T) { Authority: []*ResourceRecord{}, Additional: []*ResourceRecord{}, }, + }, { + desc: "IsRD:true QType:AAAA QClass:IN QName:kilabit.info", + allowRecursion: true, + qtype: QueryTypeAAAA, + qclass: QueryClassIN, + qname: []byte("kilabit.info"), + }, { + desc: "IsRD:true QType:A QClass:IN QName:random", + allowRecursion: true, + qtype: QueryTypeA, + qclass: QueryClassIN, + qname: []byte("random"), }} for _, c := range cases { @@ -147,17 +159,17 @@ func TestUDPClientLookup(t *testing.T) { t.Fatal(err) } - c.exp.Header.ID = getID() + if !c.allowRecursion { + c.exp.Header.ID = getID() - _, err = c.exp.Pack() - if err != nil { - t.Fatal(err) - } + _, err = c.exp.Pack() + if err != nil { + t.Fatal(err) + } - if c.allowRecursion { - t.Logf("got recursive answer: %+v\n", got) - } else { test.Assert(t, "Packet", c.exp.Packet, got.Packet, true) + } else { + t.Logf("Got recursive answer: %+v\n", got) } } } |
