aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-04-10 17:17:28 +0700
committerShulhan <ms@kilabit.info>2019-04-12 19:14:02 +0700
commit47a096dc764c4a83fbdb96b04237d5851e86fab2 (patch)
treea821a3f0587fd4fd855bac8ab4db753777ca0480
parentf43e8fd05681ee455c93b464ed5586a83de71fc3 (diff)
downloadpakakeh.go-47a096dc764c4a83fbdb96b04237d5851e86fab2.tar.xz
dns/server: implement recursion, forwarding request to parent name servers
The forwarding routines will be running only if there is at least one valid NameServers on ServerOptions. The request will be forwarded only if IsRD (is recursion desired) flag is set.
-rw-r--r--lib/dns/dns_test.go3
-rw-r--r--lib/dns/server.go192
-rw-r--r--lib/dns/udpclient_test.go28
3 files changed, 198 insertions, 25 deletions
diff --git a/lib/dns/dns_test.go b/lib/dns/dns_test.go
index e54b7e66..e6a88ff7 100644
--- a/lib/dns/dns_test.go
+++ b/lib/dns/dns_test.go
@@ -39,6 +39,9 @@ func TestMain(m *testing.M) {
DoHPort: 8443,
DoHCertificate: &cert,
DoHAllowInsecure: true,
+ NameServers: []string{
+ "https://cloudflare-dns.com/dns-query",
+ },
}
_testServer, err = NewServer(serverOptions)
diff --git a/lib/dns/server.go b/lib/dns/server.go
index eba74276..d41fd6fc 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -5,6 +5,7 @@
package dns
import (
+ "bytes"
"crypto/tls"
"encoding/base64"
"fmt"
@@ -49,6 +50,9 @@ type Server struct {
doh *http.Server
requestq chan *Request
+ forwardq chan *Request
+
+ hasForwarders bool
}
//
@@ -63,6 +67,7 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) {
srv = &Server{
opts: opts,
requestq: make(chan *Request, 512),
+ forwardq: make(chan *Request, 512),
}
udpAddr := opts.getUDPAddress()
@@ -86,6 +91,30 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) {
}
//
+// isResponseValid check if request name, type, and class match with response.
+// It will return true if both matched, otherwise it will return false.
+//
+func isResponseValid(req *Request, res *Message) bool {
+ if !bytes.Equal(req.Message.Question.Name, res.Question.Name) {
+ log.Printf("dns: unmatched response name, got %s want %s\n",
+ req.Message.Question.Name, res.Question.Name)
+ return false
+ }
+ if req.Message.Question.Type != res.Question.Type {
+ log.Printf("dns: unmatched response type, got %s want %s\n",
+ req.Message.Question, res.Question)
+ return false
+ }
+ if req.Message.Question.Class != res.Question.Class {
+ log.Printf("dns: unmatched response class, got %s want %s\n",
+ req.Message.Question, res.Question)
+ return false
+ }
+
+ return true
+}
+
+//
// LoadHostsDir populate caches with DNS record from hosts formatted files in
// directory "dir".
//
@@ -226,6 +255,8 @@ func (srv *Server) populateCaches(msgs []*Message) {
// Start the server, listening and serve query from clients.
//
func (srv *Server) Start() {
+ srv.runForwarders()
+
go srv.processRequest()
if srv.opts.DoHCertificate != nil {
@@ -466,7 +497,8 @@ func (srv *Server) serveTCPClient(cl *TCPClient) {
//
func (srv *Server) processRequest() {
var (
- resp []byte
+ res *Message
+ isLocal bool
)
for req := range srv.requestq {
@@ -475,34 +507,160 @@ func (srv *Server) processRequest() {
req.Message.Question.Class)
if ans == nil {
+ if req.Message.Header.IsRD && srv.hasForwarders {
+ srv.forwardq <- req
+ continue
+ }
+
req.Message.SetResponseCode(RCodeErrName)
}
+
+ isLocal = false
if an == nil {
+ if req.Message.Header.IsRD && srv.hasForwarders {
+ srv.forwardq <- req
+ continue
+ }
+
req.Message.SetQuery(false)
req.Message.SetAuthorativeAnswer(true)
- resp = req.Message.Packet
+ res = req.Message
+ isLocal = true
} else {
+ if an.msg.IsExpired() && srv.hasForwarders {
+ srv.forwardq <- req
+ continue
+ }
+
an.msg.SetID(req.Message.Header.ID)
- resp = an.get()
+ res = an.msg
+ isLocal = (an.receivedAt == 0)
}
- switch req.Kind {
- case ConnTypeUDP, ConnTypeTCP:
- if req.Sender != nil {
- _, err := req.Sender.Send(resp, req.UDPAddr)
- if err != nil {
- log.Println("dns: processRequest: Sender.Send:" + err.Error())
- }
+ srv.processResponse(req, res, isLocal)
+ }
+}
+
+func (srv *Server) processResponse(req *Request, res *Message, isLocal bool) {
+ if !isLocal {
+ if !isResponseValid(req, res) {
+ return
+ }
+ }
+
+ switch req.Kind {
+ case ConnTypeUDP:
+ if req.Sender != nil {
+ _, err := req.Sender.Send(res.Packet, req.UDPAddr)
+ if err != nil {
+ log.Println("dns: failed to send UDP reply:", err)
+ return
}
+ }
+
+ case ConnTypeTCP:
+ if req.Sender != nil {
+ _, err := req.Sender.Send(res.Packet, nil)
+ if err != nil {
+ log.Println("dns: failed to send TCP reply:", err)
+ return
+ }
+ }
- case ConnTypeDoH:
- if req.ResponseWriter != nil {
- _, err := req.ResponseWriter.Write(resp)
- if err != nil {
- log.Println("dns: processRequest: ResponseWriter.Write:", err)
- }
- req.ChanResponded <- true
+ case ConnTypeDoH:
+ if req.ResponseWriter != nil {
+ _, err := req.ResponseWriter.Write(res.Packet)
+ req.ChanResponded <- true
+ if err != nil {
+ log.Println("dns: failed to send DoH reply:", err)
+ return
}
}
}
+
+ if !isLocal {
+ if res.Header.RCode != 0 {
+ log.Printf("dns: response error %s, code: %s\n",
+ res.Question, rcodeNames[res.Header.RCode])
+ return
+ }
+
+ an := newAnswer(res, false)
+ srv.caches.upsert(an)
+ }
+}
+
+func (srv *Server) runForwarders() {
+ nforwarders := 0
+ for x := 0; x < len(srv.opts.udpServers); x++ {
+ go srv.runUDPForwarder(srv.opts.udpServers[x])
+ nforwarders++
+ }
+
+ for x := 0; x < len(srv.opts.tcpServers); x++ {
+ go srv.runTCPForwarder(srv.opts.tcpServers[x])
+ nforwarders++
+ }
+
+ for x := 0; x < len(srv.opts.dohServers); x++ {
+ go srv.runDohForwarder(srv.opts.dohServers[x])
+ nforwarders++
+ }
+
+ if nforwarders > 0 {
+ srv.hasForwarders = true
+ }
+}
+
+func (srv *Server) runDohForwarder(nameserver string) {
+ forwarder, err := NewDoHClient(nameserver, false)
+ if err != nil {
+ log.Fatal("dns: failed to create DoH forwarder: " + err.Error())
+ }
+
+ for req := range srv.forwardq {
+ res, err := forwarder.Query(req.Message, nil)
+ if err != nil {
+ log.Println("dns: failed to query DoH: " + err.Error())
+ continue
+ }
+
+ srv.processResponse(req, res, false)
+ }
+}
+
+func (srv *Server) runTCPForwarder(remoteAddr *net.TCPAddr) {
+ for req := range srv.forwardq {
+ cl, err := NewTCPClient(remoteAddr.String())
+ if err != nil {
+ log.Println("dns: failed to create TCP client: " + err.Error())
+ continue
+ }
+
+ res, err := cl.Query(req.Message, nil)
+ cl.Close()
+ if err != nil {
+ log.Println("dns: failed to query TCP: " + err.Error())
+ continue
+ }
+
+ srv.processResponse(req, res, false)
+ }
+}
+
+func (srv *Server) runUDPForwarder(remoteAddr *net.UDPAddr) {
+ forwarder, err := NewUDPClient(remoteAddr.String())
+ if err != nil {
+ log.Fatal("dns: failed to create UDP forwarder: " + err.Error())
+ }
+
+ for req := range srv.forwardq {
+ res, err := forwarder.Query(req.Message, remoteAddr)
+ if err != nil {
+ log.Println("dns: failed to query UDP: " + err.Error())
+ continue
+ }
+
+ srv.processResponse(req, res, false)
+ }
}
diff --git a/lib/dns/udpclient_test.go b/lib/dns/udpclient_test.go
index d029ed4e..61aa0a82 100644
--- a/lib/dns/udpclient_test.go
+++ b/lib/dns/udpclient_test.go
@@ -137,6 +137,18 @@ func TestUDPClientLookup(t *testing.T) {
Authority: []*ResourceRecord{},
Additional: []*ResourceRecord{},
},
+ }, {
+ desc: "IsRD:true QType:AAAA QClass:IN QName:kilabit.info",
+ allowRecursion: true,
+ qtype: QueryTypeAAAA,
+ qclass: QueryClassIN,
+ qname: []byte("kilabit.info"),
+ }, {
+ desc: "IsRD:true QType:A QClass:IN QName:random",
+ allowRecursion: true,
+ qtype: QueryTypeA,
+ qclass: QueryClassIN,
+ qname: []byte("random"),
}}
for _, c := range cases {
@@ -147,17 +159,17 @@ func TestUDPClientLookup(t *testing.T) {
t.Fatal(err)
}
- c.exp.Header.ID = getID()
+ if !c.allowRecursion {
+ c.exp.Header.ID = getID()
- _, err = c.exp.Pack()
- if err != nil {
- t.Fatal(err)
- }
+ _, err = c.exp.Pack()
+ if err != nil {
+ t.Fatal(err)
+ }
- if c.allowRecursion {
- t.Logf("got recursive answer: %+v\n", got)
- } else {
test.Assert(t, "Packet", c.exp.Packet, got.Packet, true)
+ } else {
+ t.Logf("Got recursive answer: %+v\n", got)
}
}
}