aboutsummaryrefslogtreecommitdiff
path: root/cmd/resolver/resolver.go
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2022-04-14 01:33:40 +0700
committerShulhan <ms@kilabit.info>2022-04-14 01:33:40 +0700
commit6ead9209bbc431b33ccbdebfffe98ffd23e4f9a7 (patch)
treefbf3cebda822e168bab1f6ee8623295be43da532 /cmd/resolver/resolver.go
parentbb7ec2c3667a45f9a534b3f568cb5148f828a1a9 (diff)
downloadrescached-6ead9209bbc431b33ccbdebfffe98ffd23e4f9a7.tar.xz
cmd/resolver: refactor the resolver as client of DNS and rescached
Previously, the resolver command only for querying DNS server. In this changes and in the future, the resolver command will be client for DNS and rescached server.
Diffstat (limited to 'cmd/resolver/resolver.go')
-rw-r--r--cmd/resolver/resolver.go243
1 files changed, 243 insertions, 0 deletions
diff --git a/cmd/resolver/resolver.go b/cmd/resolver/resolver.go
new file mode 100644
index 0000000..19cebaf
--- /dev/null
+++ b/cmd/resolver/resolver.go
@@ -0,0 +1,243 @@
+// SPDX-FileCopyrightText: 2018 M. Shulhan <ms@kilabit.info>
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package main
+
+import (
+ "fmt"
+ "log"
+ "math/rand"
+ "strings"
+ "time"
+
+ "github.com/shuLhan/share/lib/dns"
+ libnet "github.com/shuLhan/share/lib/net"
+)
+
+const (
+ defAttempts = 1
+ defQueryType = "A"
+ defQueryClass = "IN"
+ defResolvConf = "/etc/resolv.conf"
+ defTimeout = 5 * time.Second
+)
+
+type resolver struct {
+ conf *libnet.ResolvConf
+ dnsc dns.Client
+
+ cmd string
+ qname string
+ sqtype string
+ sqclass string
+
+ nameserver string
+ qtype dns.RecordType
+ qclass dns.RecordClass
+
+ insecure bool
+}
+
+func (rsol *resolver) doCmdQuery(args []string) {
+ var (
+ maxAttempts = defAttempts
+ timeout = defTimeout
+
+ res *dns.Message
+ qname string
+ queries []string
+ nAttempts int
+ err error
+ ok bool
+ )
+
+ rsol.qname = args[0]
+
+ switch len(args) {
+ case 1:
+ rsol.sqtype = defQueryType
+ rsol.sqclass = defQueryClass
+
+ case 2:
+ rsol.sqtype = args[1]
+ rsol.sqclass = defQueryClass
+
+ case 3:
+ rsol.sqtype = args[1]
+ rsol.sqclass = args[2]
+ }
+
+ rsol.sqtype = strings.ToUpper(rsol.sqtype)
+ rsol.qtype, ok = dns.RecordTypes[rsol.sqtype]
+ if !ok {
+ log.Fatalf("resolver: invalid query type: %q", rsol.sqtype)
+ }
+
+ rsol.sqclass = strings.ToUpper(rsol.sqclass)
+ rsol.qclass, ok = dns.RecordClasses[rsol.sqclass]
+ if !ok {
+ log.Fatalf("resolver: invalid query class: %q", rsol.sqclass)
+ }
+
+ fmt.Printf("= options: %+v\n", rsol)
+
+ if len(rsol.nameserver) == 0 {
+ // Use the nameserver and configuration from resolv.conf.
+ err = rsol.initSystemResolver()
+ if err != nil {
+ log.Fatalf("resolver: %s", err)
+ }
+
+ fmt.Printf("= resolv.conf: %+v\n", rsol.conf)
+
+ queries = populateQueries(rsol.conf, rsol.qname)
+ timeout = time.Duration(rsol.conf.Timeout) * time.Second
+ maxAttempts = rsol.conf.Attempts
+ } else {
+ rsol.dnsc, err = dns.NewClient(rsol.nameserver, rsol.insecure)
+ if err != nil {
+ log.Fatalf("resolver: %s", err)
+ }
+
+ queries = append(queries, rsol.qname)
+ }
+
+ for _, qname = range queries {
+ for nAttempts = 0; nAttempts < maxAttempts; nAttempts++ {
+ fmt.Printf("< Query %s at %s\n", qname, rsol.dnsc.RemoteAddr())
+
+ res, err = rsol.query(timeout, qname)
+ if err != nil {
+ log.Printf("resolver: %s", err)
+ continue
+ }
+
+ printQueryResponse(rsol.dnsc.RemoteAddr(), res)
+ return
+ }
+ }
+}
+
+//
+// initSystemResolver read the system resolv.conf to create fallback DNS
+// resolver.
+//
+func (rsol *resolver) initSystemResolver() (err error) {
+ var (
+ logp = "initSystemResolver"
+
+ ns string
+ )
+
+ rsol.conf, err = libnet.NewResolvConf(defResolvConf)
+ if err != nil {
+ return fmt.Errorf("%s: %w", logp, err)
+ }
+
+ if len(rsol.conf.NameServers) == 0 {
+ ns = "127.0.0.1:53"
+ } else {
+ ns = rsol.conf.NameServers[0]
+ }
+
+ rsol.dnsc, err = dns.NewUDPClient(ns)
+ if err != nil {
+ return fmt.Errorf("%s: %w", logp, err)
+ }
+ return nil
+}
+
+func (rsol *resolver) query(timeout time.Duration, qname string) (res *dns.Message, err error) {
+ var (
+ logp = "query"
+ req = dns.NewMessage()
+ )
+
+ rand.Seed(time.Now().Unix())
+
+ rsol.dnsc.SetTimeout(timeout)
+
+ req.Header.ID = uint16(rand.Intn(65535))
+ req.Question.Name = qname
+ req.Question.Type = rsol.qtype
+ req.Question.Class = rsol.qclass
+ _, err = req.Pack()
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s: %w", logp, qname, err)
+ }
+
+ res, err = rsol.dnsc.Query(req)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s: %w", logp, qname, err)
+ }
+
+ return res, nil
+}
+
+func populateQueries(cr *libnet.ResolvConf, qname string) (queries []string) {
+ ndots := 0
+
+ for _, c := range qname {
+ if c == '.' {
+ ndots++
+ continue
+ }
+ }
+
+ if ndots >= cr.NDots {
+ queries = append(queries, qname)
+ } else {
+ if len(cr.Domain) > 0 {
+ queries = append(queries, qname+"."+cr.Domain)
+ }
+ for _, s := range cr.Search {
+ queries = append(queries, qname+"."+s)
+ }
+ }
+
+ return
+}
+
+func printQueryResponse(nameserver string, msg *dns.Message) {
+ var b strings.Builder
+
+ fmt.Fprintf(&b, "> From: %s", nameserver)
+ fmt.Fprintf(&b, "\n> Header: %+v", msg.Header)
+ fmt.Fprintf(&b, "\n> Question: %s", msg.Question.String())
+
+ b.WriteString("\n> Status: ")
+ switch msg.Header.RCode {
+ case dns.RCodeOK:
+ b.WriteString("OK")
+ case dns.RCodeErrFormat:
+ b.WriteString("Invalid request format")
+ case dns.RCodeErrServer:
+ b.WriteString("Server internal failure")
+ case dns.RCodeErrName:
+ fmt.Fprintf(&b, "Domain name with type %s and class %s did not exist",
+ dns.RecordTypeNames[msg.Question.Type],
+ dns.RecordClassName[msg.Question.Class])
+ case dns.RCodeNotImplemented:
+ b.WriteString(" Unknown query")
+ case dns.RCodeRefused:
+ b.WriteString(" Server refused the request")
+ }
+
+ for x, rr := range msg.Answer {
+ fmt.Fprintf(&b, "\n> Answer #%d:", x+1)
+ fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String())
+ fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value)
+ }
+ for x, rr := range msg.Authority {
+ fmt.Fprintf(&b, "\n> Authority #%d:", x+1)
+ fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String())
+ fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value)
+ }
+ for x, rr := range msg.Additional {
+ fmt.Fprintf(&b, "\n> Additional #%d:", x+1)
+ fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String())
+ fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value)
+ }
+
+ fmt.Println(b.String())
+}