diff options
| author | Shulhan <ms@kilabit.info> | 2019-04-12 23:07:58 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-07-05 00:37:21 +0700 |
| commit | b412fb5853219fee7c0fcad7bab8a72d846d7bc5 (patch) | |
| tree | 59ec94fe15ad69a92920c4378a6ecc5c6109ef48 /rescached.go | |
| parent | 3bb229101b114153606f7ecec8911c811188bc30 (diff) | |
| download | rescached-b412fb5853219fee7c0fcad7bab8a72d846d7bc5.tar.xz | |
all: refactoring with latest update on dns package
All the server core functionalities (caches and forwarding) now
implemented inside "dns.Server". The main function of this package are
for reading options from configuration file (or from command line options)
and watching changes from system resolv.conf.
There are also some major changes on configuration file.
* "server.parent" option now use URI format instead of IP:PORT.
This will allow parent name servers to be UDP, TCP, and/or DoH
simultaneusly.
* "server.doh.parent" and "server.parent.connection" are removed,
redundant with new "server.parent" format.
* "cache.threshold" is renamed to "cache.prune_threshold".
Diffstat (limited to 'rescached.go')
| -rw-r--r-- | rescached.go | 473 |
1 files changed, 34 insertions, 439 deletions
diff --git a/rescached.go b/rescached.go index 31a0db4..b70a047 100644 --- a/rescached.go +++ b/rescached.go @@ -6,280 +6,73 @@ package rescached import ( - "bytes" - "errors" "fmt" "io/ioutil" "log" - "net" "os" - "path/filepath" "strconv" - libbytes "github.com/shuLhan/share/lib/bytes" "github.com/shuLhan/share/lib/debug" "github.com/shuLhan/share/lib/dns" libio "github.com/shuLhan/share/lib/io" - libnet "github.com/shuLhan/share/lib/net" -) - -const ( - _maxQueue = 512 - _maxForwarder = 4 -) - -// List of error messages. -var ( - ErrNetworkType = errors.New("invalid network type") ) // Server implement caching DNS server. type Server struct { - dnsServer *dns.Server - nsParents []*net.UDPAddr - reqQueue chan *dns.Request - fwQueue chan *dns.Request - fwDoHQueue chan *dns.Request - fwStop chan bool - cw *cacheWorker - opts *Options + dns *dns.Server + opts *Options } // // New create and initialize new rescached server. // -func New(opts *Options) *Server { +func New(opts *Options) (srv *Server, err error) { if opts == nil { opts = NewOptions() } opts.init() - srv := &Server{ - dnsServer: new(dns.Server), - reqQueue: make(chan *dns.Request, _maxQueue), - fwQueue: make(chan *dns.Request, _maxQueue), - fwDoHQueue: make(chan *dns.Request, _maxQueue), - fwStop: make(chan bool), - cw: newCacheWorker(opts.CachePruneDelay, opts.CacheThreshold), - opts: opts, - } - - if len(srv.opts.FileResolvConf) == 0 { - srv.nsParents = srv.opts.NSParents - } else { - err := srv.loadResolvConf() - if err != nil { - log.Printf("! loadResolvConf: %s\n", err) - srv.nsParents = srv.opts.NSParents - } - } - - srv.dnsServer.Handler = srv - - return srv -} - -func (srv *Server) CachesStats() string { - return fmt.Sprintf("= rescached: CachesStats{caches:%d cachesList:%d}", - srv.cw.caches.length(), srv.cw.cachesList.length()) -} - -// -// LoadHostsFile parse hosts formatted file and put it into caches. -// -func (srv *Server) LoadHostsFile(path string) { - if len(path) == 0 { - fmt.Println("= Loading system hosts file") - } else { - fmt.Printf("= Loading hosts file '%s'\n", path) - } - - msgs, err := dns.HostsLoad(path) - if err != nil { - return - } - - srv.populateCaches(msgs) -} - -// -// LoadHostsDir load all host formatted files in directory. -// -func (srv *Server) LoadHostsDir(dir string) { - if len(dir) == 0 { - return - } - - d, err := os.Open(dir) - if err != nil { - log.Println("! loadHostsDir: Open:", err) - return - } - - fis, err := d.Readdir(0) - if err != nil { - log.Println("! loadHostsDir: Readdir:", err) - err = d.Close() - if err != nil { - log.Println("! loadHostsDir: Close:", err) - } - return - } - - for x := 0; x < len(fis); x++ { - if fis[x].IsDir() { - continue - } - - hostsFile := filepath.Join(dir, fis[x].Name()) - - srv.LoadHostsFile(hostsFile) - } - - err = d.Close() - if err != nil { - log.Println("! loadHostsDir: Close:", err) - } -} - -// -// LoadMasterDir load all master formatted files in directory. -// -func (srv *Server) LoadMasterDir(dir string) { - if len(dir) == 0 { - return - } - - d, err := os.Open(dir) - if err != nil { - log.Println("! loadMasterDir: ", err) - return - } - - fis, err := d.Readdir(0) - if err != nil { - log.Println("! loadMasterDir: ", err) - err = d.Close() - if err != nil { - log.Println("! loadMasterDir: Close:", err) - } - return - } - - for x := 0; x < len(fis); x++ { - if fis[x].IsDir() { - continue - } - - masterFile := filepath.Join(dir, fis[x].Name()) - - srv.LoadMasterFile(masterFile) - } - - err = d.Close() - if err != nil { - log.Println("! loadHostsDir: Close:", err) - } -} - -// -// LostMasterFile parse master file and put the result into caches. -// -func (srv *Server) LoadMasterFile(path string) { - fmt.Printf("= Loading master file '%s'\n", path) - - msgs, err := dns.MasterLoad(path, "", 0) - if err != nil { - return - } - - srv.populateCaches(msgs) -} - -func (srv *Server) loadResolvConf() error { - rc, err := libnet.NewResolvConf(srv.opts.FileResolvConf) - if err != nil { - return err + if debug.Value >= 1 { + fmt.Printf("= config: %+v\n", opts) } - nsAddrs, err := dns.ParseNameServers(rc.NameServers) + dnsServer, err := dns.NewServer(&opts.ServerOptions) if err != nil { - return err + return nil, err } - if len(nsAddrs) > 0 { - srv.nsParents = nsAddrs - } else { - srv.nsParents = srv.opts.NSParents - } + dnsServer.LoadHostsDir(opts.DirHosts) + dnsServer.LoadMasterDir(opts.DirMaster) + dnsServer.LoadHostsFile("") - return nil -} - -func (srv *Server) populateCaches(msgs []*dns.Message) { - n := 0 - for x := 0; x < len(msgs); x++ { - ok := srv.cw.upsert(msgs[x], true) - if ok { - n++ - } - msgs[x] = nil + srv = &Server{ + dns: dnsServer, + opts: opts, } - fmt.Printf("== %d record cached\n", n) -} - -// -// ServeDNS handle DNS request from server. -// -func (srv *Server) ServeDNS(req *dns.Request) { - srv.reqQueue <- req + return srv, nil } // // Start the server, waiting for DNS query from clients, read it and response // it. // -func (srv *Server) Start() error { - fmt.Printf("= Listening on '%s:%d'\n", srv.opts.ListenAddress, - srv.opts.ListenPort) - - err := srv.runForwarders() - if err != nil { - return err - } - - if len(srv.opts.DoHCert) > 0 && len(srv.opts.DoHCertKey) > 0 { - fmt.Printf("= DoH listening on '%s:%d'\n", - srv.opts.ListenAddress, srv.opts.DoHPort) +func (srv *Server) Start() (err error) { + fmt.Printf("= Listening on '%s:%d'\n", srv.opts.IPAddress, + srv.opts.Port) - err = srv.runDoHForwarders() + if len(srv.opts.FileResolvConf) > 0 { + _, err = libio.NewWatcher(srv.opts.FileResolvConf, 0, srv.watchResolvConf) if err != nil { - return err + log.Fatal("rescached: Start:", err) } } - if len(srv.opts.FileResolvConf) > 0 { - go srv.watchResolvConf() - } - - go srv.cw.start() - go srv.processRequestQueue() - - serverOptions := &dns.ServerOptions{ - IPAddress: srv.opts.ListenAddress, - UDPPort: srv.opts.ListenPort, - TCPPort: srv.opts.ListenPort, - DoHPort: srv.opts.DoHPort, - DoHCert: srv.opts.DoHCert, - DoHCertKey: srv.opts.DoHCertKey, - DoHAllowInsecure: srv.opts.DoHAllowInsecure, - } + srv.dns.Start() + srv.dns.Wait() - err = srv.dnsServer.ListenAndServe(serverOptions) - - return err + return nil } // @@ -318,219 +111,21 @@ func (srv *Server) WritePID() error { return err } -func (srv *Server) runForwarders() (err error) { - max := _maxForwarder - - fmt.Printf("= Name servers: %v\n", srv.nsParents) - - if len(srv.nsParents) > max { - max = len(srv.nsParents) - } - - for x := 0; x < max; x++ { - var ( - cl dns.Client - raddr *net.UDPAddr - ) - - nsIdx := x % len(srv.nsParents) - raddr = srv.nsParents[nsIdx] - - if srv.opts.ConnType == dns.ConnTypeUDP { - cl, err = dns.NewUDPClient(raddr.String()) - if err != nil { - log.Fatal("runForwarders: NewUDPClient:", err) - return - } - } - - go srv.processForwardQueue(cl, raddr) - } - return -} - -func (srv *Server) runDoHForwarders() error { - fmt.Printf("= DoH name servers: %v\n", srv.opts.DoHParents) - - for x := 0; x < len(srv.opts.DoHParents); x++ { - cl, err := dns.NewDoHClient(srv.opts.DoHParents[x], srv.opts.DoHAllowInsecure) - if err != nil { - log.Fatal("runDoHForwarders: NewDoHClient:", err) - return err - } - - go srv.processDoHForwardQueue(cl) - } - - return nil -} - -func (srv *Server) stopForwarders() { - srv.fwStop <- true -} - -// -// processRequest process request from any connection, forward it to parent -// name server if no response from cache or if cache is expired; or send the -// cached response back to request. -// -func (srv *Server) processRequest(req *dns.Request) { - if req == nil { +func (srv *Server) watchResolvConf(ns *libio.NodeState) { + switch ns.State { + case libio.FileStateDeleted: + log.Printf("= ResolvConf: file %q deleted\n", srv.opts.FileResolvConf) return - } - if debug.Value >= 1 { - fmt.Printf("< request: Kind:%-4s ID:%-5d %s\n", - dns.ConnTypeNames[req.Kind], - req.Message.Header.ID, req.Message.Question) - } - - // Check if request query name exist in cache. - libbytes.ToLower(&req.Message.Question.Name) - qname := string(req.Message.Question.Name) - _, res := srv.cw.caches.get(qname, req.Message.Question.Type, req.Message.Question.Class) - if res == nil || res.isExpired() { - if req.Kind == dns.ConnTypeDoH { - srv.fwDoHQueue <- req - } else { - srv.fwQueue <- req - } - return - } - - srv.processRequestResponse(req, res.message) - - // Ignore update on local caches - if res.receivedAt == 0 { - if debug.Value >= 1 { - fmt.Printf("= local : ID:%-5d %s\n", - res.message.Header.ID, res.message.Question) - } - } else { - if debug.Value >= 1 { - fmt.Printf("= cache : Total:%-4d ID:%-5d %s\n", - srv.cw.cachesList.length(), - res.message.Header.ID, res.message.Question) - } - - srv.cw.cachesList.fix(res) - } -} - -func (srv *Server) processRequestResponse(req *dns.Request, res *dns.Message) { - res.SetID(req.Message.Header.ID) - - switch req.Kind { - case dns.ConnTypeUDP, dns.ConnTypeTCP: - if req.Sender != nil { - _, err := req.Sender.Send(res, req.UDPAddr) - if err != nil { - log.Println("! processRequest: Sender.Send:", err) - } - } - - case dns.ConnTypeDoH: - if req.ResponseWriter != nil { - _, err := req.ResponseWriter.Write(res.Packet) - if err != nil { - log.Println("! processRequest: ResponseWriter.Write:", err) - } - req.ChanResponded <- true - } - } -} - -func (srv *Server) processRequestQueue() { - for req := range srv.reqQueue { - srv.processRequest(req) - } -} - -func (srv *Server) processForwardQueue(cl dns.Client, raddr net.Addr) { - for { - select { - case req := <-srv.fwQueue: - var ( - err error - res *dns.Message - ) - - switch srv.opts.ConnType { - case dns.ConnTypeUDP: - res, err = cl.Query(req.Message, raddr) - - case dns.ConnTypeTCP: - cl, err = dns.NewTCPClient(raddr.String()) - if err != nil { - continue - } - - res, err = cl.Query(req.Message, nil) - - cl.Close() - } - if err != nil { - continue - } - - srv.processForwardResponse(req, res) - - case <-srv.fwStop: - return - } - } -} - -func (srv *Server) processDoHForwardQueue(cl *dns.DoHClient) { - for req := range srv.fwDoHQueue { - res, err := cl.Query(req.Message, nil) + default: + ok, err := srv.opts.loadResolvConf() if err != nil { - continue + log.Println("rescached: loadResolvConf: " + err.Error()) + break } - - srv.processForwardResponse(req, res) - } -} - -func (srv *Server) processForwardResponse(req *dns.Request, res *dns.Message) { - if bytes.Equal(req.Message.Question.Name, res.Question.Name) { - if req.Message.Question.Type != res.Question.Type { - return - } - } - - srv.processRequestResponse(req, res) - - srv.cw.upsertQueue <- res -} - -func (srv *Server) watchResolvConf() { - watcher, err := libio.NewWatcher(srv.opts.FileResolvConf, 0) - if err != nil { - log.Fatal("! watchResolvConf: ", err) - } - - for fi := range watcher.C { - if fi == nil { - if srv.nsParents[0] == srv.opts.NSParents[0] { - continue - } - - log.Printf("= ResolvConf: file '%s' deleted\n", - srv.opts.FileResolvConf) - - srv.nsParents = srv.opts.NSParents - } else { - err := srv.loadResolvConf() - if err != nil { - log.Printf("! loadResolvConf: %s\n", err) - srv.nsParents = srv.opts.NSParents - } + if !ok { + break } - srv.stopForwarders() - err = srv.runForwarders() - if err != nil { - log.Printf("! watchResolvConf: %s\n", err) - } + srv.dns.RestartForwarders(srv.opts.NameServers, srv.opts.FallbackNS) } } |
