From dc709c767b73a6582416cc98e880ffdc96003dee Mon Sep 17 00:00:00 2001 From: Shulhan Date: Thu, 4 Jul 2019 22:44:06 +0700 Subject: dns: use channel to stop all forwarders Previously, we use a boolean condition to stop all forwarders by setting it to false. But, this method does not work because using select statement with single case, will block the process. This change, allocated a new boolean channel for each forwarders, that is stored on server fields and when we need to stop or restart all forwarders, we send the boolean value "true" to each channel. --- lib/dns/server.go | 126 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 50 deletions(-) diff --git a/lib/dns/server.go b/lib/dns/server.go index 7788eaf2..f77039e4 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -22,6 +22,11 @@ import ( "github.com/shuLhan/share/lib/debug" ) +const ( + asFallback = "fallback" + asPrimary = "primary" +) + // // Server defines DNS server. // @@ -72,19 +77,15 @@ type Server struct { tcp *net.TCPListener doh *http.Server - requestq chan *request - forwardq chan *request - fallbackq chan *request + requestq chan *request + primaryq chan *request + fallbackq chan *request + forwardStoppers []chan bool // fwGroup maintain reference counting for all forwarders. fwGroup *sync.WaitGroup hasForwarders bool - hasFallbacks bool - - // isForwarding define a state that allow forwarding to run or to - // stop. - isForwarding bool } // @@ -99,7 +100,7 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) { srv = &Server{ opts: opts, requestq: make(chan *request, 512), - forwardq: make(chan *request, 512), + primaryq: make(chan *request, 512), fallbackq: make(chan *request, 512), fwGroup: &sync.WaitGroup{}, } @@ -291,8 +292,11 @@ func (srv *Server) populateCaches(msgs []*Message) { // Empty nameservers means server will run without forwarding request. // func (srv *Server) RestartForwarders(nameServers, fallbackNS []string) { + fmt.Printf("dns: RestartForwarders: %s %s\n", nameServers, fallbackNS) + srv.opts.NameServers = nameServers srv.opts.FallbackNS = fallbackNS + srv.opts.parseNameServers() srv.stopForwarders() @@ -564,7 +568,7 @@ func (srv *Server) processRequest() { if ans == nil { if req.message.Header.IsRD && srv.hasForwarders { - srv.forwardq <- req + srv.primaryq <- req continue } @@ -573,7 +577,7 @@ func (srv *Server) processRequest() { if an == nil { if req.message.Header.IsRD && srv.hasForwarders { - srv.forwardq <- req + srv.primaryq <- req continue } @@ -594,7 +598,7 @@ func (srv *Server) processRequest() { req.message.Header.ID, req.message.Question) } - srv.forwardq <- req + srv.primaryq <- req continue } @@ -656,21 +660,22 @@ func (srv *Server) processResponse(req *request, res *Message, fallbackq chan *r } func (srv *Server) runForwarders() { - srv.isForwarding = true + srv.hasForwarders = false + srv.forwardStoppers = nil nforwarders := 0 for x := 0; x < len(srv.opts.primaryUDP); x++ { - go srv.runUDPForwarder(srv.opts.primaryUDP[x].String(), srv.forwardq, srv.fallbackq) + go srv.runUDPForwarder(srv.opts.primaryUDP[x].String(), srv.primaryq, srv.fallbackq) nforwarders++ } for x := 0; x < len(srv.opts.primaryTCP); x++ { - go srv.runTCPForwarder(srv.opts.primaryTCP[x].String(), srv.forwardq, srv.fallbackq) + go srv.runTCPForwarder(srv.opts.primaryTCP[x].String(), srv.primaryq, srv.fallbackq) nforwarders++ } for x := 0; x < len(srv.opts.primaryDoh); x++ { - go srv.runDohForwarder(srv.opts.primaryDoh[x], srv.forwardq, srv.fallbackq) + go srv.runDohForwarder(srv.opts.primaryDoh[x], srv.primaryq, srv.fallbackq) nforwarders++ } @@ -693,30 +698,32 @@ func (srv *Server) runForwarders() { go srv.runDohForwarder(srv.opts.fallbackDoh[x], srv.fallbackq, nil) nforwarders++ } - - if nforwarders > 0 { - srv.hasFallbacks = true - } } func (srv *Server) runDohForwarder(nameserver string, primaryq, fallbackq chan *request) { - var isSuccess bool + var asWhat = asPrimary + if fallbackq == nil { + asWhat = asFallback + } + + stop := make(chan bool) + srv.forwardStoppers = append(srv.forwardStoppers, stop) srv.fwGroup.Add(1) - log.Printf("dns: starting DoH forwarder at %s", nameserver) - for srv.isForwarding { + fmt.Printf("dns: starting %s DoH forwarder at %s\n", asWhat, nameserver) + + for { forwarder, err := NewDoHClient(nameserver, false) if err != nil { log.Fatal("dns: failed to create DoH forwarder: " + err.Error()) } - isSuccess = true - for srv.isForwarding && isSuccess { //nolint:gosimple + for err == nil { select { case req, ok := <-primaryq: if !ok { - goto out + break } if debug.Value >= 1 { fmt.Printf("dns: ^ DoH %s %d:%s\n", @@ -730,33 +737,41 @@ func (srv *Server) runDohForwarder(nameserver string, primaryq, fallbackq chan * if fallbackq != nil { fallbackq <- req } - isSuccess = false } else { srv.processResponse(req, res, fallbackq) } + case <-stop: + goto out } } forwarder.Close() - if srv.isForwarding { - log.Println("dns: restarting DoH forwarder for " + nameserver) - } + log.Println("dns: restarting DoH forwarder for " + nameserver) } out: srv.fwGroup.Done() - log.Printf("dns: DoH forwarder for %s has been stopped", nameserver) + fmt.Printf("dns: DoH %s forwarder %s has been stopped\n", asWhat, nameserver) } func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan *request) { + var asWhat = asPrimary + + if fallbackq == nil { + asWhat = asFallback + } + + stop := make(chan bool) + srv.forwardStoppers = append(srv.forwardStoppers, stop) srv.fwGroup.Add(1) - log.Printf("dns: starting TCP forwarder at %s", remoteAddr) - for srv.isForwarding { //nolint:gosimple + fmt.Printf("dns: starting %s TCP forwarder at %s\n", asWhat, remoteAddr) + + for { select { case req, ok := <-primaryq: if !ok { - goto out + break } if debug.Value >= 1 { fmt.Printf("dns: ^ TCP %s %d:%s\n", @@ -767,6 +782,7 @@ func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan * cl, err := NewTCPClient(remoteAddr) if err != nil { log.Println("dns: failed to create TCP client: " + err.Error()) + err = nil continue } @@ -781,11 +797,13 @@ func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan * } srv.processResponse(req, res, fallbackq) + case <-stop: + goto out } } out: srv.fwGroup.Done() - log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr) + fmt.Printf("dns: TCP %s forwarder %s has been stopped\n", asWhat, remoteAddr) } // @@ -793,26 +811,32 @@ out: // and forward it to parent server at "remoteAddr". // func (srv *Server) runUDPForwarder(remoteAddr string, primaryq, fallbackq chan *request) { - var isSuccess bool + var asWhat = asPrimary + + if fallbackq == nil { + asWhat = asFallback + } + stop := make(chan bool) + srv.forwardStoppers = append(srv.forwardStoppers, stop) srv.fwGroup.Add(1) - log.Printf("dns: starting UDP forwarder at %s", remoteAddr) - // The first loop handle broken connection of UDP client. - for srv.isForwarding { + fmt.Printf("dns: starting %s UDP forwarder at %s\n", asWhat, remoteAddr) + + // The first loop handle broken connection. + for { forwarder, err := NewUDPClient(remoteAddr) if err != nil { log.Fatal("dns: failed to create UDP forwarder: " + err.Error()) } - isSuccess = true - // The second loop consume the forward queue. - for srv.isForwarding && isSuccess { //nolint:gosimple + for err == nil { select { case req, ok := <-primaryq: if !ok { - goto out + fmt.Printf("dns: UDP break %s\n", remoteAddr) + break } if debug.Value >= 1 { fmt.Printf("dns: ^ UDP %s %d:%s\n", @@ -826,28 +850,30 @@ func (srv *Server) runUDPForwarder(remoteAddr string, primaryq, fallbackq chan * if fallbackq != nil { fallbackq <- req } - isSuccess = false } else { srv.processResponse(req, res, fallbackq) } + case <-stop: + goto out } } forwarder.Close() - - if srv.isForwarding { - log.Println("dns: restarting UDP forwarder for " + remoteAddr) - } + log.Println("dns: restarting UDP forwarder for " + remoteAddr) } out: srv.fwGroup.Done() - log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr) + fmt.Printf("dns: UDP %s forwarder %s has been stopped\n", asWhat, remoteAddr) } // // stopForwarders stop all forwarder connections. // func (srv *Server) stopForwarders() { - srv.isForwarding = false + srv.hasForwarders = false + for _, forwardStop := range srv.forwardStoppers { + forwardStop <- true + } srv.fwGroup.Wait() + fmt.Println("dns: all forwarders has been stopped") } -- cgit v1.3