diff options
| author | Shulhan <ms@kilabit.info> | 2022-04-14 01:33:40 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2022-04-14 01:33:40 +0700 |
| commit | 6ead9209bbc431b33ccbdebfffe98ffd23e4f9a7 (patch) | |
| tree | fbf3cebda822e168bab1f6ee8623295be43da532 /cmd/resolver/resolver.go | |
| parent | bb7ec2c3667a45f9a534b3f568cb5148f828a1a9 (diff) | |
| download | rescached-6ead9209bbc431b33ccbdebfffe98ffd23e4f9a7.tar.xz | |
cmd/resolver: refactor the resolver as client of DNS and rescached
Previously, the resolver command only for querying DNS server.
In this changes and in the future, the resolver command will be client
for DNS and rescached server.
Diffstat (limited to 'cmd/resolver/resolver.go')
| -rw-r--r-- | cmd/resolver/resolver.go | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/cmd/resolver/resolver.go b/cmd/resolver/resolver.go new file mode 100644 index 0000000..19cebaf --- /dev/null +++ b/cmd/resolver/resolver.go @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: 2018 M. Shulhan <ms@kilabit.info> +// SPDX-License-Identifier: GPL-3.0-or-later + +package main + +import ( + "fmt" + "log" + "math/rand" + "strings" + "time" + + "github.com/shuLhan/share/lib/dns" + libnet "github.com/shuLhan/share/lib/net" +) + +const ( + defAttempts = 1 + defQueryType = "A" + defQueryClass = "IN" + defResolvConf = "/etc/resolv.conf" + defTimeout = 5 * time.Second +) + +type resolver struct { + conf *libnet.ResolvConf + dnsc dns.Client + + cmd string + qname string + sqtype string + sqclass string + + nameserver string + qtype dns.RecordType + qclass dns.RecordClass + + insecure bool +} + +func (rsol *resolver) doCmdQuery(args []string) { + var ( + maxAttempts = defAttempts + timeout = defTimeout + + res *dns.Message + qname string + queries []string + nAttempts int + err error + ok bool + ) + + rsol.qname = args[0] + + switch len(args) { + case 1: + rsol.sqtype = defQueryType + rsol.sqclass = defQueryClass + + case 2: + rsol.sqtype = args[1] + rsol.sqclass = defQueryClass + + case 3: + rsol.sqtype = args[1] + rsol.sqclass = args[2] + } + + rsol.sqtype = strings.ToUpper(rsol.sqtype) + rsol.qtype, ok = dns.RecordTypes[rsol.sqtype] + if !ok { + log.Fatalf("resolver: invalid query type: %q", rsol.sqtype) + } + + rsol.sqclass = strings.ToUpper(rsol.sqclass) + rsol.qclass, ok = dns.RecordClasses[rsol.sqclass] + if !ok { + log.Fatalf("resolver: invalid query class: %q", rsol.sqclass) + } + + fmt.Printf("= options: %+v\n", rsol) + + if len(rsol.nameserver) == 0 { + // Use the nameserver and configuration from resolv.conf. + err = rsol.initSystemResolver() + if err != nil { + log.Fatalf("resolver: %s", err) + } + + fmt.Printf("= resolv.conf: %+v\n", rsol.conf) + + queries = populateQueries(rsol.conf, rsol.qname) + timeout = time.Duration(rsol.conf.Timeout) * time.Second + maxAttempts = rsol.conf.Attempts + } else { + rsol.dnsc, err = dns.NewClient(rsol.nameserver, rsol.insecure) + if err != nil { + log.Fatalf("resolver: %s", err) + } + + queries = append(queries, rsol.qname) + } + + for _, qname = range queries { + for nAttempts = 0; nAttempts < maxAttempts; nAttempts++ { + fmt.Printf("< Query %s at %s\n", qname, rsol.dnsc.RemoteAddr()) + + res, err = rsol.query(timeout, qname) + if err != nil { + log.Printf("resolver: %s", err) + continue + } + + printQueryResponse(rsol.dnsc.RemoteAddr(), res) + return + } + } +} + +// +// initSystemResolver read the system resolv.conf to create fallback DNS +// resolver. +// +func (rsol *resolver) initSystemResolver() (err error) { + var ( + logp = "initSystemResolver" + + ns string + ) + + rsol.conf, err = libnet.NewResolvConf(defResolvConf) + if err != nil { + return fmt.Errorf("%s: %w", logp, err) + } + + if len(rsol.conf.NameServers) == 0 { + ns = "127.0.0.1:53" + } else { + ns = rsol.conf.NameServers[0] + } + + rsol.dnsc, err = dns.NewUDPClient(ns) + if err != nil { + return fmt.Errorf("%s: %w", logp, err) + } + return nil +} + +func (rsol *resolver) query(timeout time.Duration, qname string) (res *dns.Message, err error) { + var ( + logp = "query" + req = dns.NewMessage() + ) + + rand.Seed(time.Now().Unix()) + + rsol.dnsc.SetTimeout(timeout) + + req.Header.ID = uint16(rand.Intn(65535)) + req.Question.Name = qname + req.Question.Type = rsol.qtype + req.Question.Class = rsol.qclass + _, err = req.Pack() + if err != nil { + return nil, fmt.Errorf("%s: %s: %w", logp, qname, err) + } + + res, err = rsol.dnsc.Query(req) + if err != nil { + return nil, fmt.Errorf("%s: %s: %w", logp, qname, err) + } + + return res, nil +} + +func populateQueries(cr *libnet.ResolvConf, qname string) (queries []string) { + ndots := 0 + + for _, c := range qname { + if c == '.' { + ndots++ + continue + } + } + + if ndots >= cr.NDots { + queries = append(queries, qname) + } else { + if len(cr.Domain) > 0 { + queries = append(queries, qname+"."+cr.Domain) + } + for _, s := range cr.Search { + queries = append(queries, qname+"."+s) + } + } + + return +} + +func printQueryResponse(nameserver string, msg *dns.Message) { + var b strings.Builder + + fmt.Fprintf(&b, "> From: %s", nameserver) + fmt.Fprintf(&b, "\n> Header: %+v", msg.Header) + fmt.Fprintf(&b, "\n> Question: %s", msg.Question.String()) + + b.WriteString("\n> Status: ") + switch msg.Header.RCode { + case dns.RCodeOK: + b.WriteString("OK") + case dns.RCodeErrFormat: + b.WriteString("Invalid request format") + case dns.RCodeErrServer: + b.WriteString("Server internal failure") + case dns.RCodeErrName: + fmt.Fprintf(&b, "Domain name with type %s and class %s did not exist", + dns.RecordTypeNames[msg.Question.Type], + dns.RecordClassName[msg.Question.Class]) + case dns.RCodeNotImplemented: + b.WriteString(" Unknown query") + case dns.RCodeRefused: + b.WriteString(" Server refused the request") + } + + for x, rr := range msg.Answer { + fmt.Fprintf(&b, "\n> Answer #%d:", x+1) + fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String()) + fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value) + } + for x, rr := range msg.Authority { + fmt.Fprintf(&b, "\n> Authority #%d:", x+1) + fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String()) + fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value) + } + for x, rr := range msg.Additional { + fmt.Fprintf(&b, "\n> Additional #%d:", x+1) + fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String()) + fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value) + } + + fmt.Println(b.String()) +} |
