aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-07-04 22:44:06 +0700
committerShulhan <ms@kilabit.info>2019-07-09 22:15:44 +0700
commitdc709c767b73a6582416cc98e880ffdc96003dee (patch)
tree27bb367ad92d8179a389fbc1c8b85baaf69895dd
parent9283d0f1385a63e1955150a4d2a36d5207a1f66b (diff)
downloadpakakeh.go-dc709c767b73a6582416cc98e880ffdc96003dee.tar.xz
dns: use channel to stop all forwarders
Previously, we use a boolean condition to stop all forwarders by setting it to false. But, this method does not work because using select statement with single case, will block the process. This change, allocated a new boolean channel for each forwarders, that is stored on server fields and when we need to stop or restart all forwarders, we send the boolean value "true" to each channel.
-rw-r--r--lib/dns/server.go126
1 files changed, 76 insertions, 50 deletions
diff --git a/lib/dns/server.go b/lib/dns/server.go
index 7788eaf2..f77039e4 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -22,6 +22,11 @@ import (
"github.com/shuLhan/share/lib/debug"
)
+const (
+ asFallback = "fallback"
+ asPrimary = "primary"
+)
+
//
// Server defines DNS server.
//
@@ -72,19 +77,15 @@ type Server struct {
tcp *net.TCPListener
doh *http.Server
- requestq chan *request
- forwardq chan *request
- fallbackq chan *request
+ requestq chan *request
+ primaryq chan *request
+ fallbackq chan *request
+ forwardStoppers []chan bool
// fwGroup maintain reference counting for all forwarders.
fwGroup *sync.WaitGroup
hasForwarders bool
- hasFallbacks bool
-
- // isForwarding define a state that allow forwarding to run or to
- // stop.
- isForwarding bool
}
//
@@ -99,7 +100,7 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) {
srv = &Server{
opts: opts,
requestq: make(chan *request, 512),
- forwardq: make(chan *request, 512),
+ primaryq: make(chan *request, 512),
fallbackq: make(chan *request, 512),
fwGroup: &sync.WaitGroup{},
}
@@ -291,8 +292,11 @@ func (srv *Server) populateCaches(msgs []*Message) {
// Empty nameservers means server will run without forwarding request.
//
func (srv *Server) RestartForwarders(nameServers, fallbackNS []string) {
+ fmt.Printf("dns: RestartForwarders: %s %s\n", nameServers, fallbackNS)
+
srv.opts.NameServers = nameServers
srv.opts.FallbackNS = fallbackNS
+
srv.opts.parseNameServers()
srv.stopForwarders()
@@ -564,7 +568,7 @@ func (srv *Server) processRequest() {
if ans == nil {
if req.message.Header.IsRD && srv.hasForwarders {
- srv.forwardq <- req
+ srv.primaryq <- req
continue
}
@@ -573,7 +577,7 @@ func (srv *Server) processRequest() {
if an == nil {
if req.message.Header.IsRD && srv.hasForwarders {
- srv.forwardq <- req
+ srv.primaryq <- req
continue
}
@@ -594,7 +598,7 @@ func (srv *Server) processRequest() {
req.message.Header.ID,
req.message.Question)
}
- srv.forwardq <- req
+ srv.primaryq <- req
continue
}
@@ -656,21 +660,22 @@ func (srv *Server) processResponse(req *request, res *Message, fallbackq chan *r
}
func (srv *Server) runForwarders() {
- srv.isForwarding = true
+ srv.hasForwarders = false
+ srv.forwardStoppers = nil
nforwarders := 0
for x := 0; x < len(srv.opts.primaryUDP); x++ {
- go srv.runUDPForwarder(srv.opts.primaryUDP[x].String(), srv.forwardq, srv.fallbackq)
+ go srv.runUDPForwarder(srv.opts.primaryUDP[x].String(), srv.primaryq, srv.fallbackq)
nforwarders++
}
for x := 0; x < len(srv.opts.primaryTCP); x++ {
- go srv.runTCPForwarder(srv.opts.primaryTCP[x].String(), srv.forwardq, srv.fallbackq)
+ go srv.runTCPForwarder(srv.opts.primaryTCP[x].String(), srv.primaryq, srv.fallbackq)
nforwarders++
}
for x := 0; x < len(srv.opts.primaryDoh); x++ {
- go srv.runDohForwarder(srv.opts.primaryDoh[x], srv.forwardq, srv.fallbackq)
+ go srv.runDohForwarder(srv.opts.primaryDoh[x], srv.primaryq, srv.fallbackq)
nforwarders++
}
@@ -693,30 +698,32 @@ func (srv *Server) runForwarders() {
go srv.runDohForwarder(srv.opts.fallbackDoh[x], srv.fallbackq, nil)
nforwarders++
}
-
- if nforwarders > 0 {
- srv.hasFallbacks = true
- }
}
func (srv *Server) runDohForwarder(nameserver string, primaryq, fallbackq chan *request) {
- var isSuccess bool
+ var asWhat = asPrimary
+ if fallbackq == nil {
+ asWhat = asFallback
+ }
+
+ stop := make(chan bool)
+ srv.forwardStoppers = append(srv.forwardStoppers, stop)
srv.fwGroup.Add(1)
- log.Printf("dns: starting DoH forwarder at %s", nameserver)
- for srv.isForwarding {
+ fmt.Printf("dns: starting %s DoH forwarder at %s\n", asWhat, nameserver)
+
+ for {
forwarder, err := NewDoHClient(nameserver, false)
if err != nil {
log.Fatal("dns: failed to create DoH forwarder: " + err.Error())
}
- isSuccess = true
- for srv.isForwarding && isSuccess { //nolint:gosimple
+ for err == nil {
select {
case req, ok := <-primaryq:
if !ok {
- goto out
+ break
}
if debug.Value >= 1 {
fmt.Printf("dns: ^ DoH %s %d:%s\n",
@@ -730,33 +737,41 @@ func (srv *Server) runDohForwarder(nameserver string, primaryq, fallbackq chan *
if fallbackq != nil {
fallbackq <- req
}
- isSuccess = false
} else {
srv.processResponse(req, res, fallbackq)
}
+ case <-stop:
+ goto out
}
}
forwarder.Close()
- if srv.isForwarding {
- log.Println("dns: restarting DoH forwarder for " + nameserver)
- }
+ log.Println("dns: restarting DoH forwarder for " + nameserver)
}
out:
srv.fwGroup.Done()
- log.Printf("dns: DoH forwarder for %s has been stopped", nameserver)
+ fmt.Printf("dns: DoH %s forwarder %s has been stopped\n", asWhat, nameserver)
}
func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan *request) {
+ var asWhat = asPrimary
+
+ if fallbackq == nil {
+ asWhat = asFallback
+ }
+
+ stop := make(chan bool)
+ srv.forwardStoppers = append(srv.forwardStoppers, stop)
srv.fwGroup.Add(1)
- log.Printf("dns: starting TCP forwarder at %s", remoteAddr)
- for srv.isForwarding { //nolint:gosimple
+ fmt.Printf("dns: starting %s TCP forwarder at %s\n", asWhat, remoteAddr)
+
+ for {
select {
case req, ok := <-primaryq:
if !ok {
- goto out
+ break
}
if debug.Value >= 1 {
fmt.Printf("dns: ^ TCP %s %d:%s\n",
@@ -767,6 +782,7 @@ func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan *
cl, err := NewTCPClient(remoteAddr)
if err != nil {
log.Println("dns: failed to create TCP client: " + err.Error())
+ err = nil
continue
}
@@ -781,11 +797,13 @@ func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan *
}
srv.processResponse(req, res, fallbackq)
+ case <-stop:
+ goto out
}
}
out:
srv.fwGroup.Done()
- log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr)
+ fmt.Printf("dns: TCP %s forwarder %s has been stopped\n", asWhat, remoteAddr)
}
//
@@ -793,26 +811,32 @@ out:
// and forward it to parent server at "remoteAddr".
//
func (srv *Server) runUDPForwarder(remoteAddr string, primaryq, fallbackq chan *request) {
- var isSuccess bool
+ var asWhat = asPrimary
+
+ if fallbackq == nil {
+ asWhat = asFallback
+ }
+ stop := make(chan bool)
+ srv.forwardStoppers = append(srv.forwardStoppers, stop)
srv.fwGroup.Add(1)
- log.Printf("dns: starting UDP forwarder at %s", remoteAddr)
- // The first loop handle broken connection of UDP client.
- for srv.isForwarding {
+ fmt.Printf("dns: starting %s UDP forwarder at %s\n", asWhat, remoteAddr)
+
+ // The first loop handle broken connection.
+ for {
forwarder, err := NewUDPClient(remoteAddr)
if err != nil {
log.Fatal("dns: failed to create UDP forwarder: " + err.Error())
}
- isSuccess = true
-
// The second loop consume the forward queue.
- for srv.isForwarding && isSuccess { //nolint:gosimple
+ for err == nil {
select {
case req, ok := <-primaryq:
if !ok {
- goto out
+ fmt.Printf("dns: UDP break %s\n", remoteAddr)
+ break
}
if debug.Value >= 1 {
fmt.Printf("dns: ^ UDP %s %d:%s\n",
@@ -826,28 +850,30 @@ func (srv *Server) runUDPForwarder(remoteAddr string, primaryq, fallbackq chan *
if fallbackq != nil {
fallbackq <- req
}
- isSuccess = false
} else {
srv.processResponse(req, res, fallbackq)
}
+ case <-stop:
+ goto out
}
}
forwarder.Close()
-
- if srv.isForwarding {
- log.Println("dns: restarting UDP forwarder for " + remoteAddr)
- }
+ log.Println("dns: restarting UDP forwarder for " + remoteAddr)
}
out:
srv.fwGroup.Done()
- log.Printf("dns: TCP forwarder for %s has been stopped", remoteAddr)
+ fmt.Printf("dns: UDP %s forwarder %s has been stopped\n", asWhat, remoteAddr)
}
//
// stopForwarders stop all forwarder connections.
//
func (srv *Server) stopForwarders() {
- srv.isForwarding = false
+ srv.hasForwarders = false
+ for _, forwardStop := range srv.forwardStoppers {
+ forwardStop <- true
+ }
srv.fwGroup.Wait()
+ fmt.Println("dns: all forwarders has been stopped")
}