diff options
| author | Shulhan <ms@kilabit.info> | 2019-06-16 14:36:27 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-06-16 14:36:27 +0700 |
| commit | 45432e60ecbc4c2bab43a9912ee7de108befd13b (patch) | |
| tree | 79ae1f3a2eafd5a7702f219bdd9ba5987cb89029 | |
| parent | 8d5ffba2d9483f049385d37bc468f2a114548eb0 (diff) | |
| download | pakakeh.go-45432e60ecbc4c2bab43a9912ee7de108befd13b.tar.xz | |
dns: add method to restart forwarders
The RestartForwarders method allow server to change the parent nameserver
address to new one. An example of use case is when system change the
network through WiFi, which cause the nameserver address also change
in the resolv.conf file.
| -rw-r--r-- | lib/dns/server.go | 172 | ||||
| -rw-r--r-- | lib/dns/serveroptions.go | 16 | ||||
| -rw-r--r-- | lib/dns/serveroptions_test.go | 6 |
3 files changed, 144 insertions, 50 deletions
diff --git a/lib/dns/server.go b/lib/dns/server.go index 248ee44c..67b62350 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -17,6 +17,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/shuLhan/share/lib/debug" ) @@ -73,7 +74,14 @@ type Server struct { requestq chan *request forwardq chan *request + // fwGroup maintain reference counting for all forwarders. + fwGroup *sync.WaitGroup + hasForwarders bool + + // isForwarding define a state that allow forwarding to run or to + // stop. + isForwarding bool } // @@ -89,6 +97,7 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) { opts: opts, requestq: make(chan *request, 512), forwardq: make(chan *request, 512), + fwGroup: &sync.WaitGroup{}, } udpAddr := opts.getUDPAddress() @@ -273,6 +282,23 @@ func (srv *Server) populateCaches(msgs []*Message) { } // +// RestartForwarders stop and start new forwarders with new nameserver address +// and protocol. +// Empty nameservers means server will run without forwarding request. +// +func (srv *Server) RestartForwarders(nameServers []string) { + srv.opts.NameServers = nameServers + srv.opts.parseNameServers() + + srv.stopForwarders() + srv.runForwarders() + + if !srv.hasForwarders { + log.Println("dns: no valid forward nameservers") + } +} + +// // Start the server, listening and serve query from clients. // func (srv *Server) Start() { @@ -624,6 +650,8 @@ func (srv *Server) processResponse(req *request, res *Message, isLocal bool) { } func (srv *Server) runForwarders() { + srv.isForwarding = true + nforwarders := 0 for x := 0; x < len(srv.opts.udpServers); x++ { go srv.runUDPForwarder(srv.opts.udpServers[x].String()) @@ -646,75 +674,129 @@ func (srv *Server) runForwarders() { } func (srv *Server) runDohForwarder(nameserver string) { - forwarder, err := NewDoHClient(nameserver, false) - if err != nil { - log.Fatal("dns: failed to create DoH forwarder: " + err.Error()) - } + srv.fwGroup.Add(1) + log.Printf("dns: starting DoH forwarder at %s", nameserver) - for req := range srv.forwardq { - if debug.Value >= 1 { - fmt.Printf("dns: ^ DoH %d:%s\n", - req.message.Header.ID, req.message.Question) - } - - res, err := forwarder.Query(req.message) + for srv.isForwarding { + forwarder, err := NewDoHClient(nameserver, false) if err != nil { - log.Println("dns: failed to query DoH: " + err.Error()) - continue + log.Fatal("dns: failed to create DoH forwarder: " + err.Error()) } - srv.processResponse(req, res, false) + for srv.isForwarding { //nolint:gosimple + select { + case req, ok := <-srv.forwardq: + if !ok { + goto out + } + if debug.Value >= 1 { + fmt.Printf("dns: ^ DoH %d:%s\n", + req.message.Header.ID, req.message.Question) + } + + res, err := forwarder.Query(req.message) + if err != nil { + log.Println("dns: failed to query DoH: " + err.Error()) + continue + } + + srv.processResponse(req, res, false) + } + } } +out: + srv.fwGroup.Done() + log.Printf("dns: DoH forwarder for %s has been stopped", nameserver) } func (srv *Server) runTCPForwarder(remoteAddr string) { - for req := range srv.forwardq { - if debug.Value >= 1 { - fmt.Printf("dns: ^ TCP %d:%s\n", - req.message.Header.ID, req.message.Question) - } + srv.fwGroup.Add(1) + log.Printf("dns: starting TCP forwarder at %s", remoteAddr) - cl, err := NewTCPClient(remoteAddr) - if err != nil { - log.Println("dns: failed to create TCP client: " + err.Error()) - continue - } + for srv.isForwarding { //nolint:gosimple + select { + case req, ok := <-srv.forwardq: + if !ok { + goto out + } + if debug.Value >= 1 { + fmt.Printf("dns: ^ TCP %d:%s\n", + req.message.Header.ID, req.message.Question) + } - res, err := cl.Query(req.message) - cl.Close() - if err != nil { - log.Println("dns: failed to query TCP: " + err.Error()) - continue - } + cl, err := NewTCPClient(remoteAddr) + if err != nil { + log.Println("dns: failed to create TCP client: " + err.Error()) + continue + } - srv.processResponse(req, res, false) + res, err := cl.Query(req.message) + cl.Close() + if err != nil { + log.Println("dns: failed to query TCP: " + err.Error()) + continue + } + + srv.processResponse(req, res, false) + } } +out: + srv.fwGroup.Done() + log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr) } +// +// runUDPForwarder create a UDP client that consume request from forward queue +// and forward it to parent server at "remoteAddr". +// func (srv *Server) runUDPForwarder(remoteAddr string) { - for { + srv.fwGroup.Add(1) + log.Printf("dns: starting UDP forwarder at %s", remoteAddr) + + // The first loop handle broken connection of UDP client. + for srv.isForwarding { forwarder, err := NewUDPClient(remoteAddr) if err != nil { log.Fatal("dns: failed to create UDP forwarder: " + err.Error()) } - for req := range srv.forwardq { - if debug.Value >= 1 { - fmt.Printf("dns: ^ UDP %d:%s\n", - req.message.Header.ID, req.message.Question) - } + // The second loop consume the forward queue. + for srv.isForwarding { //nolint:gosimple + select { + case req, ok := <-srv.forwardq: + if !ok { + goto out + } + if debug.Value >= 1 { + fmt.Printf("dns: ^ UDP %d:%s\n", + req.message.Header.ID, req.message.Question) + } - res, err := forwarder.Query(req.message) - if err != nil { - log.Println("dns: failed to query UDP: " + err.Error()) - break - } + res, err := forwarder.Query(req.message) + if err != nil { + log.Println("dns: failed to query UDP: " + err.Error()) + goto brokenClient + } - srv.processResponse(req, res, false) + srv.processResponse(req, res, false) + } } - + brokenClient: forwarder.Close() - log.Println("dns: restarting UDP forwarder for " + remoteAddr) + if srv.isForwarding { + log.Println("dns: restarting UDP forwarder for " + remoteAddr) + } } +out: + srv.fwGroup.Done() + log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr) +} + +// +// stopForwarders stop all forwarder connections. +// +func (srv *Server) stopForwarders() { + srv.isForwarding = false + srv.fwGroup.Wait() } diff --git a/lib/dns/serveroptions.go b/lib/dns/serveroptions.go index 9440b57b..f52e80dc 100644 --- a/lib/dns/serveroptions.go +++ b/lib/dns/serveroptions.go @@ -167,10 +167,14 @@ func (opts *ServerOptions) getDoHAddress() *net.TCPAddr { // // parseNameServers parse each name server in NameServers list based on scheme // and store the result either in udpServers, tcpServers, or dohServers. -// If the name server format is invalid, for example no scheme, it will be -// skipped. +// +// If the name server format contains no scheme, it will be assumed as "udp". // func (opts *ServerOptions) parseNameServers() { + opts.udpServers = nil + opts.tcpServers = nil + opts.dohServers = nil + for _, ns := range opts.NameServers { dnsURL, err := url.Parse(ns) if err != nil { @@ -201,9 +205,13 @@ func (opts *ServerOptions) parseNameServers() { opts.dohServers = append(opts.dohServers, ns) default: - udpAddr, err := libnet.ParseUDPAddr(dnsURL.Host, DefaultPort) + if len(dnsURL.Host) > 0 { + ns = dnsURL.Host + } + + udpAddr, err := libnet.ParseUDPAddr(ns, DefaultPort) if err != nil { - log.Printf("dns: invalid UDP IP address %q", dnsURL.Host) + log.Printf("dns: invalid UDP IP address %q", ns) continue } diff --git a/lib/dns/serveroptions_test.go b/lib/dns/serveroptions_test.go index 63424580..63e9fee3 100644 --- a/lib/dns/serveroptions_test.go +++ b/lib/dns/serveroptions_test.go @@ -112,10 +112,14 @@ func TestServerOptionsParseNameServers(t *testing.T) { "tcp://localhost:53", }, }, { - desc: "With invalid scheme", + desc: "With no scheme", nameServers: []string{ "127.0.0.1", }, + expUDPServers: []*net.UDPAddr{{ + IP: ip, + Port: 53, + }}, }, { desc: "With valid name servers", nameServers: []string{ |
