diff options
Diffstat (limited to 'cmd/resolver/resolver.go')
| -rw-r--r-- | cmd/resolver/resolver.go | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/cmd/resolver/resolver.go b/cmd/resolver/resolver.go new file mode 100644 index 0000000..19cebaf --- /dev/null +++ b/cmd/resolver/resolver.go @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: 2018 M. Shulhan <ms@kilabit.info> +// SPDX-License-Identifier: GPL-3.0-or-later + +package main + +import ( + "fmt" + "log" + "math/rand" + "strings" + "time" + + "github.com/shuLhan/share/lib/dns" + libnet "github.com/shuLhan/share/lib/net" +) + +const ( + defAttempts = 1 + defQueryType = "A" + defQueryClass = "IN" + defResolvConf = "/etc/resolv.conf" + defTimeout = 5 * time.Second +) + +type resolver struct { + conf *libnet.ResolvConf + dnsc dns.Client + + cmd string + qname string + sqtype string + sqclass string + + nameserver string + qtype dns.RecordType + qclass dns.RecordClass + + insecure bool +} + +func (rsol *resolver) doCmdQuery(args []string) { + var ( + maxAttempts = defAttempts + timeout = defTimeout + + res *dns.Message + qname string + queries []string + nAttempts int + err error + ok bool + ) + + rsol.qname = args[0] + + switch len(args) { + case 1: + rsol.sqtype = defQueryType + rsol.sqclass = defQueryClass + + case 2: + rsol.sqtype = args[1] + rsol.sqclass = defQueryClass + + case 3: + rsol.sqtype = args[1] + rsol.sqclass = args[2] + } + + rsol.sqtype = strings.ToUpper(rsol.sqtype) + rsol.qtype, ok = dns.RecordTypes[rsol.sqtype] + if !ok { + log.Fatalf("resolver: invalid query type: %q", rsol.sqtype) + } + + rsol.sqclass = strings.ToUpper(rsol.sqclass) + rsol.qclass, ok = dns.RecordClasses[rsol.sqclass] + if !ok { + log.Fatalf("resolver: invalid query class: %q", rsol.sqclass) + } + + fmt.Printf("= options: %+v\n", rsol) + + if len(rsol.nameserver) == 0 { + // Use the nameserver and configuration from resolv.conf. + err = rsol.initSystemResolver() + if err != nil { + log.Fatalf("resolver: %s", err) + } + + fmt.Printf("= resolv.conf: %+v\n", rsol.conf) + + queries = populateQueries(rsol.conf, rsol.qname) + timeout = time.Duration(rsol.conf.Timeout) * time.Second + maxAttempts = rsol.conf.Attempts + } else { + rsol.dnsc, err = dns.NewClient(rsol.nameserver, rsol.insecure) + if err != nil { + log.Fatalf("resolver: %s", err) + } + + queries = append(queries, rsol.qname) + } + + for _, qname = range queries { + for nAttempts = 0; nAttempts < maxAttempts; nAttempts++ { + fmt.Printf("< Query %s at %s\n", qname, rsol.dnsc.RemoteAddr()) + + res, err = rsol.query(timeout, qname) + if err != nil { + log.Printf("resolver: %s", err) + continue + } + + printQueryResponse(rsol.dnsc.RemoteAddr(), res) + return + } + } +} + +// +// initSystemResolver read the system resolv.conf to create fallback DNS +// resolver. +// +func (rsol *resolver) initSystemResolver() (err error) { + var ( + logp = "initSystemResolver" + + ns string + ) + + rsol.conf, err = libnet.NewResolvConf(defResolvConf) + if err != nil { + return fmt.Errorf("%s: %w", logp, err) + } + + if len(rsol.conf.NameServers) == 0 { + ns = "127.0.0.1:53" + } else { + ns = rsol.conf.NameServers[0] + } + + rsol.dnsc, err = dns.NewUDPClient(ns) + if err != nil { + return fmt.Errorf("%s: %w", logp, err) + } + return nil +} + +func (rsol *resolver) query(timeout time.Duration, qname string) (res *dns.Message, err error) { + var ( + logp = "query" + req = dns.NewMessage() + ) + + rand.Seed(time.Now().Unix()) + + rsol.dnsc.SetTimeout(timeout) + + req.Header.ID = uint16(rand.Intn(65535)) + req.Question.Name = qname + req.Question.Type = rsol.qtype + req.Question.Class = rsol.qclass + _, err = req.Pack() + if err != nil { + return nil, fmt.Errorf("%s: %s: %w", logp, qname, err) + } + + res, err = rsol.dnsc.Query(req) + if err != nil { + return nil, fmt.Errorf("%s: %s: %w", logp, qname, err) + } + + return res, nil +} + +func populateQueries(cr *libnet.ResolvConf, qname string) (queries []string) { + ndots := 0 + + for _, c := range qname { + if c == '.' { + ndots++ + continue + } + } + + if ndots >= cr.NDots { + queries = append(queries, qname) + } else { + if len(cr.Domain) > 0 { + queries = append(queries, qname+"."+cr.Domain) + } + for _, s := range cr.Search { + queries = append(queries, qname+"."+s) + } + } + + return +} + +func printQueryResponse(nameserver string, msg *dns.Message) { + var b strings.Builder + + fmt.Fprintf(&b, "> From: %s", nameserver) + fmt.Fprintf(&b, "\n> Header: %+v", msg.Header) + fmt.Fprintf(&b, "\n> Question: %s", msg.Question.String()) + + b.WriteString("\n> Status: ") + switch msg.Header.RCode { + case dns.RCodeOK: + b.WriteString("OK") + case dns.RCodeErrFormat: + b.WriteString("Invalid request format") + case dns.RCodeErrServer: + b.WriteString("Server internal failure") + case dns.RCodeErrName: + fmt.Fprintf(&b, "Domain name with type %s and class %s did not exist", + dns.RecordTypeNames[msg.Question.Type], + dns.RecordClassName[msg.Question.Class]) + case dns.RCodeNotImplemented: + b.WriteString(" Unknown query") + case dns.RCodeRefused: + b.WriteString(" Server refused the request") + } + + for x, rr := range msg.Answer { + fmt.Fprintf(&b, "\n> Answer #%d:", x+1) + fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String()) + fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value) + } + for x, rr := range msg.Authority { + fmt.Fprintf(&b, "\n> Authority #%d:", x+1) + fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String()) + fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value) + } + for x, rr := range msg.Additional { + fmt.Fprintf(&b, "\n> Additional #%d:", x+1) + fmt.Fprintf(&b, "\n>> Resource record: %s", rr.String()) + fmt.Fprintf(&b, "\n>> RDATA: %+v", rr.Value) + } + + fmt.Println(b.String()) +} |
