aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-06-16 14:36:27 +0700
committerShulhan <ms@kilabit.info>2019-06-16 14:36:27 +0700
commit45432e60ecbc4c2bab43a9912ee7de108befd13b (patch)
tree79ae1f3a2eafd5a7702f219bdd9ba5987cb89029
parent8d5ffba2d9483f049385d37bc468f2a114548eb0 (diff)
downloadpakakeh.go-45432e60ecbc4c2bab43a9912ee7de108befd13b.tar.xz
dns: add method to restart forwarders
The RestartForwarders method allow server to change the parent nameserver address to new one. An example of use case is when system change the network through WiFi, which cause the nameserver address also change in the resolv.conf file.
-rw-r--r--lib/dns/server.go172
-rw-r--r--lib/dns/serveroptions.go16
-rw-r--r--lib/dns/serveroptions_test.go6
3 files changed, 144 insertions, 50 deletions
diff --git a/lib/dns/server.go b/lib/dns/server.go
index 248ee44c..67b62350 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -17,6 +17,7 @@ import (
"os"
"path/filepath"
"strings"
+ "sync"
"github.com/shuLhan/share/lib/debug"
)
@@ -73,7 +74,14 @@ type Server struct {
requestq chan *request
forwardq chan *request
+ // fwGroup maintain reference counting for all forwarders.
+ fwGroup *sync.WaitGroup
+
hasForwarders bool
+
+ // isForwarding define a state that allow forwarding to run or to
+ // stop.
+ isForwarding bool
}
//
@@ -89,6 +97,7 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) {
opts: opts,
requestq: make(chan *request, 512),
forwardq: make(chan *request, 512),
+ fwGroup: &sync.WaitGroup{},
}
udpAddr := opts.getUDPAddress()
@@ -273,6 +282,23 @@ func (srv *Server) populateCaches(msgs []*Message) {
}
//
+// RestartForwarders stop and start new forwarders with new nameserver address
+// and protocol.
+// Empty nameservers means server will run without forwarding request.
+//
+func (srv *Server) RestartForwarders(nameServers []string) {
+ srv.opts.NameServers = nameServers
+ srv.opts.parseNameServers()
+
+ srv.stopForwarders()
+ srv.runForwarders()
+
+ if !srv.hasForwarders {
+ log.Println("dns: no valid forward nameservers")
+ }
+}
+
+//
// Start the server, listening and serve query from clients.
//
func (srv *Server) Start() {
@@ -624,6 +650,8 @@ func (srv *Server) processResponse(req *request, res *Message, isLocal bool) {
}
func (srv *Server) runForwarders() {
+ srv.isForwarding = true
+
nforwarders := 0
for x := 0; x < len(srv.opts.udpServers); x++ {
go srv.runUDPForwarder(srv.opts.udpServers[x].String())
@@ -646,75 +674,129 @@ func (srv *Server) runForwarders() {
}
func (srv *Server) runDohForwarder(nameserver string) {
- forwarder, err := NewDoHClient(nameserver, false)
- if err != nil {
- log.Fatal("dns: failed to create DoH forwarder: " + err.Error())
- }
+ srv.fwGroup.Add(1)
+ log.Printf("dns: starting DoH forwarder at %s", nameserver)
- for req := range srv.forwardq {
- if debug.Value >= 1 {
- fmt.Printf("dns: ^ DoH %d:%s\n",
- req.message.Header.ID, req.message.Question)
- }
-
- res, err := forwarder.Query(req.message)
+ for srv.isForwarding {
+ forwarder, err := NewDoHClient(nameserver, false)
if err != nil {
- log.Println("dns: failed to query DoH: " + err.Error())
- continue
+ log.Fatal("dns: failed to create DoH forwarder: " + err.Error())
}
- srv.processResponse(req, res, false)
+ for srv.isForwarding { //nolint:gosimple
+ select {
+ case req, ok := <-srv.forwardq:
+ if !ok {
+ goto out
+ }
+ if debug.Value >= 1 {
+ fmt.Printf("dns: ^ DoH %d:%s\n",
+ req.message.Header.ID, req.message.Question)
+ }
+
+ res, err := forwarder.Query(req.message)
+ if err != nil {
+ log.Println("dns: failed to query DoH: " + err.Error())
+ continue
+ }
+
+ srv.processResponse(req, res, false)
+ }
+ }
}
+out:
+ srv.fwGroup.Done()
+ log.Printf("dns: DoH forwarder for %s has been stopped", nameserver)
}
func (srv *Server) runTCPForwarder(remoteAddr string) {
- for req := range srv.forwardq {
- if debug.Value >= 1 {
- fmt.Printf("dns: ^ TCP %d:%s\n",
- req.message.Header.ID, req.message.Question)
- }
+ srv.fwGroup.Add(1)
+ log.Printf("dns: starting TCP forwarder at %s", remoteAddr)
- cl, err := NewTCPClient(remoteAddr)
- if err != nil {
- log.Println("dns: failed to create TCP client: " + err.Error())
- continue
- }
+ for srv.isForwarding { //nolint:gosimple
+ select {
+ case req, ok := <-srv.forwardq:
+ if !ok {
+ goto out
+ }
+ if debug.Value >= 1 {
+ fmt.Printf("dns: ^ TCP %d:%s\n",
+ req.message.Header.ID, req.message.Question)
+ }
- res, err := cl.Query(req.message)
- cl.Close()
- if err != nil {
- log.Println("dns: failed to query TCP: " + err.Error())
- continue
- }
+ cl, err := NewTCPClient(remoteAddr)
+ if err != nil {
+ log.Println("dns: failed to create TCP client: " + err.Error())
+ continue
+ }
- srv.processResponse(req, res, false)
+ res, err := cl.Query(req.message)
+ cl.Close()
+ if err != nil {
+ log.Println("dns: failed to query TCP: " + err.Error())
+ continue
+ }
+
+ srv.processResponse(req, res, false)
+ }
}
+out:
+ srv.fwGroup.Done()
+ log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr)
}
+//
+// runUDPForwarder create a UDP client that consume request from forward queue
+// and forward it to parent server at "remoteAddr".
+//
func (srv *Server) runUDPForwarder(remoteAddr string) {
- for {
+ srv.fwGroup.Add(1)
+ log.Printf("dns: starting UDP forwarder at %s", remoteAddr)
+
+ // The first loop handle broken connection of UDP client.
+ for srv.isForwarding {
forwarder, err := NewUDPClient(remoteAddr)
if err != nil {
log.Fatal("dns: failed to create UDP forwarder: " + err.Error())
}
- for req := range srv.forwardq {
- if debug.Value >= 1 {
- fmt.Printf("dns: ^ UDP %d:%s\n",
- req.message.Header.ID, req.message.Question)
- }
+ // The second loop consume the forward queue.
+ for srv.isForwarding { //nolint:gosimple
+ select {
+ case req, ok := <-srv.forwardq:
+ if !ok {
+ goto out
+ }
+ if debug.Value >= 1 {
+ fmt.Printf("dns: ^ UDP %d:%s\n",
+ req.message.Header.ID, req.message.Question)
+ }
- res, err := forwarder.Query(req.message)
- if err != nil {
- log.Println("dns: failed to query UDP: " + err.Error())
- break
- }
+ res, err := forwarder.Query(req.message)
+ if err != nil {
+ log.Println("dns: failed to query UDP: " + err.Error())
+ goto brokenClient
+ }
- srv.processResponse(req, res, false)
+ srv.processResponse(req, res, false)
+ }
}
-
+ brokenClient:
forwarder.Close()
- log.Println("dns: restarting UDP forwarder for " + remoteAddr)
+ if srv.isForwarding {
+ log.Println("dns: restarting UDP forwarder for " + remoteAddr)
+ }
}
+out:
+ srv.fwGroup.Done()
+ log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr)
+}
+
+//
+// stopForwarders stop all forwarder connections.
+//
+func (srv *Server) stopForwarders() {
+ srv.isForwarding = false
+ srv.fwGroup.Wait()
}
diff --git a/lib/dns/serveroptions.go b/lib/dns/serveroptions.go
index 9440b57b..f52e80dc 100644
--- a/lib/dns/serveroptions.go
+++ b/lib/dns/serveroptions.go
@@ -167,10 +167,14 @@ func (opts *ServerOptions) getDoHAddress() *net.TCPAddr {
//
// parseNameServers parse each name server in NameServers list based on scheme
// and store the result either in udpServers, tcpServers, or dohServers.
-// If the name server format is invalid, for example no scheme, it will be
-// skipped.
+//
+// If the name server format contains no scheme, it will be assumed as "udp".
//
func (opts *ServerOptions) parseNameServers() {
+ opts.udpServers = nil
+ opts.tcpServers = nil
+ opts.dohServers = nil
+
for _, ns := range opts.NameServers {
dnsURL, err := url.Parse(ns)
if err != nil {
@@ -201,9 +205,13 @@ func (opts *ServerOptions) parseNameServers() {
opts.dohServers = append(opts.dohServers, ns)
default:
- udpAddr, err := libnet.ParseUDPAddr(dnsURL.Host, DefaultPort)
+ if len(dnsURL.Host) > 0 {
+ ns = dnsURL.Host
+ }
+
+ udpAddr, err := libnet.ParseUDPAddr(ns, DefaultPort)
if err != nil {
- log.Printf("dns: invalid UDP IP address %q", dnsURL.Host)
+ log.Printf("dns: invalid UDP IP address %q", ns)
continue
}
diff --git a/lib/dns/serveroptions_test.go b/lib/dns/serveroptions_test.go
index 63424580..63e9fee3 100644
--- a/lib/dns/serveroptions_test.go
+++ b/lib/dns/serveroptions_test.go
@@ -112,10 +112,14 @@ func TestServerOptionsParseNameServers(t *testing.T) {
"tcp://localhost:53",
},
}, {
- desc: "With invalid scheme",
+ desc: "With no scheme",
nameServers: []string{
"127.0.0.1",
},
+ expUDPServers: []*net.UDPAddr{{
+ IP: ip,
+ Port: 53,
+ }},
}, {
desc: "With valid name servers",
nameServers: []string{