aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-06-18 14:40:53 +0700
committerShulhan <ms@kilabit.info>2019-06-18 14:44:44 +0700
commit2516d2d69c5535a50fc72403066b95a86bf225a4 (patch)
treed142bc056f9deceb1ffd563924dddf59dc82c259
parent9b95ec0755219146afbe7a854aa0d2cc65e15f97 (diff)
downloadpakakeh.go-2516d2d69c5535a50fc72403066b95a86bf225a4.tar.xz
dns: add fallback nameservers
The fallback nameservers is another list of parent name servers that will be queried if the primary NameServers return an error.
-rw-r--r--CHANGELOG.adoc8
-rw-r--r--lib/dns/server.go112
-rw-r--r--lib/dns/serveroptions.go58
-rw-r--r--lib/dns/serveroptions_test.go8
4 files changed, 124 insertions, 62 deletions
diff --git a/CHANGELOG.adoc b/CHANGELOG.adoc
index 8bff711b..a47c3e4b 100644
--- a/CHANGELOG.adoc
+++ b/CHANGELOG.adoc
@@ -13,7 +13,13 @@ first week of next month.
=== New Features
-* ascii: new library for working with ASCII characters.
+* ascii: new library for working with ASCII characters
+
+=== Enhancements
+
+* dns: add method to restart forwarders
+* dns: add fallback nameservers
+* ini: create new section or variable if not exist on Set
== share v0.7.0 (2019-06-14)
diff --git a/lib/dns/server.go b/lib/dns/server.go
index 67b62350..df6a728d 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -60,7 +60,8 @@ import (
// + : new answer is added to caches
// # : the expired answer is renewed and updated on caches
//
-// Following the prefix is connection type, message ID, and question.
+// Following the prefix is connection type, parent name server address,
+// message ID, and question.
//
type Server struct {
opts *ServerOptions
@@ -71,13 +72,15 @@ type Server struct {
tcp *net.TCPListener
doh *http.Server
- requestq chan *request
- forwardq chan *request
+ requestq chan *request
+ forwardq chan *request
+ fallbackq chan *request
// fwGroup maintain reference counting for all forwarders.
fwGroup *sync.WaitGroup
hasForwarders bool
+ hasFallbacks bool
// isForwarding define a state that allow forwarding to run or to
// stop.
@@ -94,10 +97,11 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) {
}
srv = &Server{
- opts: opts,
- requestq: make(chan *request, 512),
- forwardq: make(chan *request, 512),
- fwGroup: &sync.WaitGroup{},
+ opts: opts,
+ requestq: make(chan *request, 512),
+ forwardq: make(chan *request, 512),
+ fallbackq: make(chan *request, 512),
+ fwGroup: &sync.WaitGroup{},
}
udpAddr := opts.getUDPAddress()
@@ -286,8 +290,9 @@ func (srv *Server) populateCaches(msgs []*Message) {
// and protocol.
// Empty nameservers means server will run without forwarding request.
//
-func (srv *Server) RestartForwarders(nameServers []string) {
+func (srv *Server) RestartForwarders(nameServers, fallbackNS []string) {
srv.opts.NameServers = nameServers
+ srv.opts.FallbackNS = fallbackNS
srv.opts.parseNameServers()
srv.stopForwarders()
@@ -604,16 +609,17 @@ func (srv *Server) processRequest() {
}
}
- srv.processResponse(req, res, true)
+ _, err := req.writer.Write(res.Packet)
+ if err != nil {
+ log.Println("dns: processRequest: ", err.Error())
+ }
}
}
-func (srv *Server) processResponse(req *request, res *Message, isLocal bool) {
- if !isLocal {
- if !isResponseValid(req, res) {
- srv.requestq <- req
- return
- }
+func (srv *Server) processResponse(req *request, res *Message, fallbackq chan *request) {
+ if !isResponseValid(req, res) {
+ srv.requestq <- req
+ return
}
_, err := req.writer.Write(res.Packet)
@@ -622,14 +628,13 @@ func (srv *Server) processResponse(req *request, res *Message, isLocal bool) {
return
}
- if isLocal {
- return
- }
-
if res.Header.RCode != 0 {
log.Printf("dns: ! %s %s %d:%s\n",
connTypeNames[req.kind], rcodeNames[res.Header.RCode],
res.Header.ID, res.Question)
+ if fallbackq != nil {
+ fallbackq <- req
+ }
return
}
@@ -647,33 +652,54 @@ func (srv *Server) processResponse(req *request, res *Message, isLocal bool) {
res.Header.ID, res.Question)
}
}
+
}
func (srv *Server) runForwarders() {
srv.isForwarding = true
nforwarders := 0
- for x := 0; x < len(srv.opts.udpServers); x++ {
- go srv.runUDPForwarder(srv.opts.udpServers[x].String())
+ for x := 0; x < len(srv.opts.primaryUDP); x++ {
+ go srv.runUDPForwarder(srv.opts.primaryUDP[x].String(), srv.forwardq, srv.fallbackq)
nforwarders++
}
- for x := 0; x < len(srv.opts.tcpServers); x++ {
- go srv.runTCPForwarder(srv.opts.tcpServers[x].String())
+ for x := 0; x < len(srv.opts.primaryTCP); x++ {
+ go srv.runTCPForwarder(srv.opts.primaryTCP[x].String(), srv.forwardq, srv.fallbackq)
nforwarders++
}
- for x := 0; x < len(srv.opts.dohServers); x++ {
- go srv.runDohForwarder(srv.opts.dohServers[x])
+ for x := 0; x < len(srv.opts.primaryDoh); x++ {
+ go srv.runDohForwarder(srv.opts.primaryDoh[x], srv.forwardq, srv.fallbackq)
nforwarders++
}
if nforwarders > 0 {
srv.hasForwarders = true
}
+
+ nforwarders = 0
+ for x := 0; x < len(srv.opts.fallbackUDP); x++ {
+ go srv.runUDPForwarder(srv.opts.fallbackUDP[x].String(), srv.fallbackq, nil)
+ nforwarders++
+ }
+
+ for x := 0; x < len(srv.opts.fallbackTCP); x++ {
+ go srv.runTCPForwarder(srv.opts.fallbackTCP[x].String(), srv.fallbackq, nil)
+ nforwarders++
+ }
+
+ for x := 0; x < len(srv.opts.fallbackDoh); x++ {
+ go srv.runDohForwarder(srv.opts.fallbackDoh[x], srv.fallbackq, nil)
+ nforwarders++
+ }
+
+ if nforwarders > 0 {
+ srv.hasFallbacks = true
+ }
}
-func (srv *Server) runDohForwarder(nameserver string) {
+func (srv *Server) runDohForwarder(nameserver string, primaryq, fallbackq chan *request) {
srv.fwGroup.Add(1)
log.Printf("dns: starting DoH forwarder at %s", nameserver)
@@ -685,22 +711,26 @@ func (srv *Server) runDohForwarder(nameserver string) {
for srv.isForwarding { //nolint:gosimple
select {
- case req, ok := <-srv.forwardq:
+ case req, ok := <-primaryq:
if !ok {
goto out
}
if debug.Value >= 1 {
- fmt.Printf("dns: ^ DoH %d:%s\n",
+ fmt.Printf("dns: ^ DoH %s %d:%s\n",
+ nameserver,
req.message.Header.ID, req.message.Question)
}
res, err := forwarder.Query(req.message)
if err != nil {
log.Println("dns: failed to query DoH: " + err.Error())
+ if fallbackq != nil {
+ fallbackq <- req
+ }
continue
}
- srv.processResponse(req, res, false)
+ srv.processResponse(req, res, fallbackq)
}
}
}
@@ -709,18 +739,19 @@ out:
log.Printf("dns: DoH forwarder for %s has been stopped", nameserver)
}
-func (srv *Server) runTCPForwarder(remoteAddr string) {
+func (srv *Server) runTCPForwarder(remoteAddr string, primaryq, fallbackq chan *request) {
srv.fwGroup.Add(1)
log.Printf("dns: starting TCP forwarder at %s", remoteAddr)
for srv.isForwarding { //nolint:gosimple
select {
- case req, ok := <-srv.forwardq:
+ case req, ok := <-primaryq:
if !ok {
goto out
}
if debug.Value >= 1 {
- fmt.Printf("dns: ^ TCP %d:%s\n",
+ fmt.Printf("dns: ^ TCP %s %d:%s\n",
+ remoteAddr,
req.message.Header.ID, req.message.Question)
}
@@ -734,10 +765,13 @@ func (srv *Server) runTCPForwarder(remoteAddr string) {
cl.Close()
if err != nil {
log.Println("dns: failed to query TCP: " + err.Error())
+ if fallbackq != nil {
+ fallbackq <- req
+ }
continue
}
- srv.processResponse(req, res, false)
+ srv.processResponse(req, res, fallbackq)
}
}
out:
@@ -749,7 +783,7 @@ out:
// runUDPForwarder create a UDP client that consume request from forward queue
// and forward it to parent server at "remoteAddr".
//
-func (srv *Server) runUDPForwarder(remoteAddr string) {
+func (srv *Server) runUDPForwarder(remoteAddr string, primaryq, fallbackq chan *request) {
srv.fwGroup.Add(1)
log.Printf("dns: starting UDP forwarder at %s", remoteAddr)
@@ -763,22 +797,26 @@ func (srv *Server) runUDPForwarder(remoteAddr string) {
// The second loop consume the forward queue.
for srv.isForwarding { //nolint:gosimple
select {
- case req, ok := <-srv.forwardq:
+ case req, ok := <-primaryq:
if !ok {
goto out
}
if debug.Value >= 1 {
- fmt.Printf("dns: ^ UDP %d:%s\n",
+ fmt.Printf("dns: ^ UDP %s %d:%s\n",
+ remoteAddr,
req.message.Header.ID, req.message.Question)
}
res, err := forwarder.Query(req.message)
if err != nil {
log.Println("dns: failed to query UDP: " + err.Error())
+ if fallbackq != nil {
+ fallbackq <- req
+ }
goto brokenClient
}
- srv.processResponse(req, res, false)
+ srv.processResponse(req, res, fallbackq)
}
}
brokenClient:
diff --git a/lib/dns/serveroptions.go b/lib/dns/serveroptions.go
index f52e80dc..7eae5242 100644
--- a/lib/dns/serveroptions.go
+++ b/lib/dns/serveroptions.go
@@ -35,11 +35,12 @@ type ServerOptions struct {
// DoHPort port for listening DNS over HTTP, default to 443.
DoHPort uint16
+ //
// NameServers contains list of parent name servers.
//
// Answer that does not exist on local will be forwarded to parent
- // name servers. If this is empty, any query that does not have an
- // answer in local caches, will be returned with response code
+ // name servers. If this field is empty, any query that does not have
+ // an answer in local caches, will be returned with response code
// RCodeErrName (3).
//
// The name server use the URI format,
@@ -60,6 +61,14 @@ type ServerOptions struct {
//
NameServers []string
+ //
+ // FallbackNS contains list of parent name servers that will be
+ // queried if the primary NameServers return an error.
+ //
+ // This field use the same format as NameServers.
+ //
+ FallbackNS []string
+
// DoHCertificate contains certificate for serving DNS over HTTPS.
// This field is optional, if its empty, server will not listening on
// HTTPS port.
@@ -88,17 +97,21 @@ type ServerOptions struct {
ip net.IP
- // udpServers contains list of parent name server addresses using UDP
+ // primaryUDP contains list of parent name server addresses using UDP
// protocol.
- udpServers []*net.UDPAddr
+ primaryUDP []*net.UDPAddr
- // tcpServers contains list of parent name server addresses using TCP
+ // primaryTCP contains list of parent name server addresses using TCP
// protocol.
- tcpServers []*net.TCPAddr
+ primaryTCP []*net.TCPAddr
- // dohServers contains list of parent name server addresses using DoH
+ // primaryDoh contains list of parent name server addresses using DoH
// protocol.
- dohServers []string
+ primaryDoh []string
+
+ fallbackUDP []*net.UDPAddr
+ fallbackTCP []*net.TCPAddr
+ fallbackDoh []string
}
//
@@ -136,7 +149,7 @@ func (opts *ServerOptions) init() (err error) {
opts.parseNameServers()
- if len(opts.udpServers) == 0 && len(opts.tcpServers) == 0 && len(opts.dohServers) == 0 {
+ if len(opts.primaryUDP) == 0 && len(opts.primaryTCP) == 0 && len(opts.primaryDoh) == 0 {
return fmt.Errorf("dns: no valid name servers")
}
@@ -166,16 +179,14 @@ func (opts *ServerOptions) getDoHAddress() *net.TCPAddr {
//
// parseNameServers parse each name server in NameServers list based on scheme
-// and store the result either in udpServers, tcpServers, or dohServers.
+// and store the result either in udpAddrs, tcpAddrs, or dohAddrs.
//
// If the name server format contains no scheme, it will be assumed as "udp".
//
-func (opts *ServerOptions) parseNameServers() {
- opts.udpServers = nil
- opts.tcpServers = nil
- opts.dohServers = nil
-
- for _, ns := range opts.NameServers {
+func parseNameServers(nameServers []string) (
+ udpAddrs []*net.UDPAddr, tcpAddrs []*net.TCPAddr, dohAddrs []string,
+) {
+ for _, ns := range nameServers {
dnsURL, err := url.Parse(ns)
if err != nil {
log.Printf("dns: invalid name server URI %q", ns)
@@ -190,7 +201,7 @@ func (opts *ServerOptions) parseNameServers() {
continue
}
- opts.udpServers = append(opts.udpServers, udpAddr)
+ udpAddrs = append(udpAddrs, udpAddr)
case "tcp":
tcpAddr, err := libnet.ParseTCPAddr(dnsURL.Host, DefaultPort)
@@ -199,10 +210,10 @@ func (opts *ServerOptions) parseNameServers() {
continue
}
- opts.tcpServers = append(opts.tcpServers, tcpAddr)
+ tcpAddrs = append(tcpAddrs, tcpAddr)
case "https":
- opts.dohServers = append(opts.dohServers, ns)
+ dohAddrs = append(dohAddrs, ns)
default:
if len(dnsURL.Host) > 0 {
@@ -215,7 +226,14 @@ func (opts *ServerOptions) parseNameServers() {
continue
}
- opts.udpServers = append(opts.udpServers, udpAddr)
+ udpAddrs = append(udpAddrs, udpAddr)
}
}
+
+ return udpAddrs, tcpAddrs, dohAddrs
+}
+
+func (opts *ServerOptions) parseNameServers() {
+ opts.primaryUDP, opts.primaryTCP, opts.primaryDoh = parseNameServers(opts.NameServers)
+ opts.fallbackUDP, opts.fallbackTCP, opts.fallbackDoh = parseNameServers(opts.FallbackNS)
}
diff --git a/lib/dns/serveroptions_test.go b/lib/dns/serveroptions_test.go
index 63e9fee3..bca99c28 100644
--- a/lib/dns/serveroptions_test.go
+++ b/lib/dns/serveroptions_test.go
@@ -64,7 +64,7 @@ func TestServerOptionsInit(t *testing.T) {
PruneDelay: time.Hour,
PruneThreshold: -1 * time.Hour,
ip: ip,
- udpServers: []*net.UDPAddr{{
+ primaryUDP: []*net.UDPAddr{{
IP: net.ParseIP("127.0.0.1"),
Port: 53,
}},
@@ -147,8 +147,8 @@ func TestServerOptionsParseNameServers(t *testing.T) {
so.parseNameServers()
- test.Assert(t, "udpServers", c.expUDPServers, so.udpServers, true)
- test.Assert(t, "tcpServers", c.expTCPServers, so.tcpServers, true)
- test.Assert(t, "dohServers", c.expDoHServers, so.dohServers, true)
+ test.Assert(t, "primaryUDP", c.expUDPServers, so.primaryUDP, true)
+ test.Assert(t, "primaryTCP", c.expTCPServers, so.primaryTCP, true)
+ test.Assert(t, "primaryDoh", c.expDoHServers, so.primaryDoh, true)
}
}