diff options
| author | Shulhan <ms@kilabit.info> | 2021-11-13 14:38:59 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2021-11-13 14:41:10 +0700 |
| commit | 0520e7df91272447fdf0dd1be4b61fe75dea88fa (patch) | |
| tree | e6f66781228d25adfa78bbfda816e8da8efc62b7 | |
| parent | 6b2363eaf194328d7eab9a85f2846b3cc1c617d4 (diff) | |
| download | pakakeh.go-0520e7df91272447fdf0dd1be4b61fe75dea88fa.tar.xz | |
lib/dns: remove internal field off from ResourceRecord
The off field previously used to record the next index after parsing
domain name with offset. Since the field only used once, after calling
unpackDomainName, we return the off value instead as "end".
This changes make the unpackDomainName() become a function.
This changes rename field offTTL to idxTTL to match with their value.
This changes alos pefix all errors with the method or function names.
| -rw-r--r-- | lib/dns/message.go | 16 | ||||
| -rw-r--r-- | lib/dns/resource_record.go | 188 |
2 files changed, 104 insertions, 100 deletions
diff --git a/lib/dns/message.go b/lib/dns/message.go index a7c32323..46130051 100644 --- a/lib/dns/message.go +++ b/lib/dns/message.go @@ -366,7 +366,7 @@ func (msg *Message) packRR(rr *ResourceRecord) { } } - rr.offTTL = uint(msg.off) + rr.idxTTL = uint(msg.off) msg.packet = libbytes.AppendUint32(msg.packet, rr.TTL) msg.off += 4 @@ -809,7 +809,7 @@ func (msg *Message) SubTTL(n uint32) { } else { msg.Answer[x].TTL -= n } - libbytes.WriteUint32(msg.packet, msg.Answer[x].offTTL, + libbytes.WriteUint32(msg.packet, msg.Answer[x].idxTTL, msg.Answer[x].TTL) } for x := 0; x < len(msg.Authority); x++ { @@ -818,7 +818,7 @@ func (msg *Message) SubTTL(n uint32) { } else { msg.Authority[x].TTL -= n } - libbytes.WriteUint32(msg.packet, msg.Authority[x].offTTL, + libbytes.WriteUint32(msg.packet, msg.Authority[x].idxTTL, msg.Authority[x].TTL) } for x := 0; x < len(msg.Additional); x++ { @@ -830,7 +830,7 @@ func (msg *Message) SubTTL(n uint32) { } else { msg.Additional[x].TTL -= n } - libbytes.WriteUint32(msg.packet, msg.Additional[x].offTTL, + libbytes.WriteUint32(msg.packet, msg.Additional[x].idxTTL, msg.Additional[x].TTL) } } @@ -938,16 +938,16 @@ func (msg *Message) Unpack() (err error) { // message. // func (msg *Message) UnpackHeaderQuestion() (err error) { + if len(msg.packet) <= sectionHeaderSize { + return fmt.Errorf("UnpackHeaderQuestion: missing question") + } + msg.Header.unpack(msg.packet) if debug.Value >= 3 { log.Printf("msg.Header: %+v\n", msg.Header) } - if len(msg.packet) <= sectionHeaderSize { - return fmt.Errorf("Message.UnpackHeaderQuestion: missing question") - } - err = msg.Question.unpack(msg.packet[sectionHeaderSize:]) if err != nil { return err diff --git a/lib/dns/resource_record.go b/lib/dns/resource_record.go index d9ebc211..8ee65285 100644 --- a/lib/dns/resource_record.go +++ b/lib/dns/resource_record.go @@ -5,7 +5,6 @@ package dns import ( - "bytes" "errors" "fmt" "log" @@ -40,7 +39,7 @@ type ResourceRecord struct { // cached. TTL uint32 - // Value hold the generic value for all record types. + // Value hold the generic value based on Type. Value interface{} // An unsigned 16 bit integer that specifies the length in octets of @@ -54,20 +53,15 @@ type ResourceRecord struct { // address. rdata []byte - off uint - offTTL uint + idxTTL uint // mark the index position of TTL field inside packet. } // // String return the text representation of ResourceRecord for human. // func (rr *ResourceRecord) String() string { - var buf bytes.Buffer - - fmt.Fprintf(&buf, "{Name:%s Type:%d Class:%d TTL:%d rdlen:%d}", + return fmt.Sprintf("{Name:%s Type:%d Class:%d TTL:%d rdlen:%d}", rr.Name, rr.Type, rr.Class, rr.TTL, rr.rdlen) - - return buf.String() } // @@ -76,8 +70,12 @@ func (rr *ResourceRecord) String() string { // type is not match with its value. // func (rr *ResourceRecord) initAndValidate() (err error) { + var ( + logp = "initAndValidate" + ) + if len(rr.Name) == 0 { - return errors.New("empty RR name") + return fmt.Errorf("%s: empty Name", logp) } if rr.Class == 0 { rr.Class = QueryClassIN @@ -86,20 +84,23 @@ func (rr *ResourceRecord) initAndValidate() (err error) { rr.TTL = defaultTTL } + qtype, ok := QueryTypeNames[rr.Type] + if !ok { + return fmt.Errorf("%s: unknown type %d", logp, rr.Type) + } switch rr.Type { case QueryTypeA: v, ok := rr.Value.(string) if !ok { - return fmt.Errorf("RR A: expecting Value as string got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } - s := string(v) - ip := net.ParseIP(s) + ip := net.ParseIP(v) if ip == nil { - return fmt.Errorf("RR A: invalid or empty IP address %q", s) + return fmt.Errorf("%s: invalid or empty %s: %q", logp, qtype, v) } ipv4 := ip.To4() if ipv4 == nil { - return fmt.Errorf("RR A: invalid or empty IPv4 address %q", s) + return fmt.Errorf("%s: invalid or empty %s: %q", logp, qtype, v) } case QueryTypeNS, QueryTypeCNAME, QueryTypeMB, QueryTypeMG, @@ -107,105 +108,102 @@ func (rr *ResourceRecord) initAndValidate() (err error) { v, ok := rr.Value.(string) if !ok { - return fmt.Errorf("RR %s: expecting Value as string got %T", - QueryTypeNames[rr.Type], rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } if !libnet.IsHostnameValid([]byte(v), true) { - return fmt.Errorf("RR %s: invalid or empty value: %q", - QueryTypeNames[rr.Type], v) + return fmt.Errorf("%s: invalid or empty %s: %q", logp, qtype, v) } case QueryTypeSOA: soa, ok := rr.Value.(*RDataSOA) if !ok { - return fmt.Errorf("RR SOA: expecting RDataSOA got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } if !libnet.IsHostnameValid([]byte(soa.MName), true) { - return fmt.Errorf("RR SOA: invalid or empty MName: %q", soa.MName) + return fmt.Errorf("%s: invalid or empty %s MName: %q", logp, qtype, soa.MName) } if !libnet.IsHostnameValid([]byte(soa.RName), true) { - return fmt.Errorf("RR SOA: invalid or empty RName: %q", soa.RName) + return fmt.Errorf("%s: invalid or empty %s RName: %q", logp, qtype, soa.RName) } case QueryTypeWKS: _, ok := rr.Value.(*RDataWKS) if !ok { - return fmt.Errorf("RR WKS: expecting WKS got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } case QueryTypeHINFO: _, ok := rr.Value.(*RDataHINFO) if !ok { - return fmt.Errorf("RR HINFO: expecting HINFO got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } case QueryTypeMINFO: _, ok := rr.Value.(*RDataMINFO) if !ok { - return fmt.Errorf("RR MINFO: expecting MINFO got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } case QueryTypeMX: mx, ok := rr.Value.(*RDataMX) if !ok { - return fmt.Errorf("RR MX: expecting MX got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } err = mx.initAndValidate() if err != nil { - return err + return fmt.Errorf("%s: %w", logp, err) } case QueryTypeTXT: txt, ok := rr.Value.(string) if !ok { - return fmt.Errorf("RR TXT: expecting string got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } if len(txt) == 0 { - return errors.New("empty RR TXT value") + return fmt.Errorf("%s: empty %s value", logp, qtype) } case QueryTypeSRV: srv, ok := rr.Value.(*RDataSRV) if !ok { - return fmt.Errorf("RR SRV: expecting SRV got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } err = srv.initAndValidate() if err != nil { - return err + return fmt.Errorf("%s: %w", logp, err) } case QueryTypeAAAA: v, ok := rr.Value.(string) if !ok { - return fmt.Errorf("RR AAAA: expecting AAAA got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } - s := string(v) - ip := net.ParseIP(s) + ip := net.ParseIP(v) if ip == nil { - return fmt.Errorf("RR AAAA: invalid or empty IPv6 address: %q", s) + return fmt.Errorf("%s: invalid or empty %s value: %q", logp, qtype, v) } ipv6 := ip.To16() if ipv6 == nil { - return fmt.Errorf("RR AAAA: invalid or empty IPv6 address: %q", s) + return fmt.Errorf("%s: invalid or empty %s value: %q", logp, qtype, v) } case QueryTypeOPT: _, ok := rr.Value.(*RDataOPT) if !ok { - return fmt.Errorf("RR OPT: expecting OPT got %T", rr.Value) + return fmt.Errorf("%s: expecting %s got %T", logp, qtype, rr.Value) } - default: - return fmt.Errorf("unknown RR type %d", rr.Type) } return nil } // -// unpack the DNS resource record from DNS packet start from index `startIdx`. +// unpack the resource record from packet start from index startIdx. // func (rr *ResourceRecord) unpack(packet []byte, startIdx uint) (x uint, err error) { + var end uint + x = startIdx - rr.Name, err = rr.unpackDomainName(packet, x) + rr.Name, end, err = unpackDomainName(packet, x) if err != nil { - return x, err + return x, fmt.Errorf("unpack: %w", err) } - if rr.off > 0 { - x = rr.off + 1 + if end > 0 { + x = end } else { if len(rr.Name) == 0 { x++ @@ -218,7 +216,7 @@ func (rr *ResourceRecord) unpack(packet []byte, startIdx uint) (x uint, err erro x += 2 rr.Class = libbytes.ReadUint16(packet, x) x += 2 - rr.offTTL = x + rr.idxTTL = x rr.TTL = libbytes.ReadUint32(packet, x) x += 4 rr.rdlen = libbytes.ReadUint16(packet, x) @@ -227,41 +225,45 @@ func (rr *ResourceRecord) unpack(packet []byte, startIdx uint) (x uint, err erro rr.rdata = append(rr.rdata, packet[x:x+uint(rr.rdlen)]...) err = rr.unpackRData(packet, x) + if err != nil { + return x, fmt.Errorf("unpack: %w", err) + } x += uint(rr.rdlen) return x, err } -func (rr *ResourceRecord) unpackDomainName(packet []byte, start uint) ( - string, error, -) { - var out strings.Builder - - x := int(start) +func unpackDomainName(packet []byte, start uint) (name string, end uint, err error) { + var ( + x = int(start) + out strings.Builder + count, y byte + ) for x < len(packet) { - count := packet[x] + count = packet[x] if count == 0 { break } if (packet[x] & maskPointer) == maskPointer { offset := uint16(packet[x]&maskOffset)<<8 | uint16(packet[x+1]) - if rr.off == 0 { - rr.off = uint(x + 1) + if end == 0 { + end = uint(x + 2) } + // Jump to index defined by offset. x = int(offset) continue } if count > maxLabelSize { - return "", ErrLabelSizeLimit + return "", end, fmt.Errorf("unpackDomainName: at %d: %w", x, ErrLabelSizeLimit) } if out.Len() > 0 { out.WriteByte('.') } x++ - for y := byte(0); y < count; y++ { + for y = 0; y < count; y++ { if x >= len(packet) { break } @@ -272,7 +274,7 @@ func (rr *ResourceRecord) unpackDomainName(packet []byte, start uint) ( x++ } } - return out.String(), nil + return out.String(), end, nil } func (rr *ResourceRecord) unpackRData(packet []byte, startIdx uint) (err error) { @@ -294,7 +296,7 @@ func (rr *ResourceRecord) unpackRData(packet []byte, startIdx uint) (err error) // class protocols. // case QueryTypeNS: - rr.Value, err = rr.unpackDomainName(packet, startIdx) + rr.Value, _, err = unpackDomainName(packet, startIdx) return err // MD is obsolete. See the definition of MX and [RFC-974] for details of @@ -314,22 +316,22 @@ func (rr *ResourceRecord) unpackRData(packet []byte, startIdx uint) (err error) // cases. See the description of name server logic in [RFC-1034] for // details. case QueryTypeCNAME: - rr.Value, err = rr.unpackDomainName(packet, startIdx) + rr.Value, _, err = unpackDomainName(packet, startIdx) return err case QueryTypeSOA: return rr.unpackSOA(packet, startIdx) case QueryTypeMB: - rr.Value, err = rr.unpackDomainName(packet, startIdx) + rr.Value, _, err = unpackDomainName(packet, startIdx) return err case QueryTypeMG: - rr.Value, err = rr.unpackDomainName(packet, startIdx) + rr.Value, _, err = unpackDomainName(packet, startIdx) return err case QueryTypeMR: - rr.Value, err = rr.unpackDomainName(packet, startIdx) + rr.Value, _, err = unpackDomainName(packet, startIdx) return err // NULL records cause no additional section processing. @@ -347,7 +349,7 @@ func (rr *ResourceRecord) unpackRData(packet []byte, startIdx uint) (err error) return rrWKS.unpack(packet[startIdx:endIdx]) case QueryTypePTR: - rr.Value, err = rr.unpackDomainName(packet, startIdx) + rr.Value, _, err = unpackDomainName(packet, startIdx) return err case QueryTypeHINFO: @@ -409,26 +411,28 @@ func (rr *ResourceRecord) unpackAAAA() error { } func (rr *ResourceRecord) unpackMInfo(packet []byte, startIdx uint) (err error) { - x := startIdx - rr.off = 0 + var ( + logp = "unpackMInfo" + rrMInfo = &RDataMINFO{} + x = startIdx + end uint + ) - rrMInfo := &RDataMINFO{} rr.Value = rrMInfo - rrMInfo.RMailBox, err = rr.unpackDomainName(packet, x) + rrMInfo.RMailBox, end, err = unpackDomainName(packet, x) if err != nil { - return err + return fmt.Errorf("%s: %w", logp, err) } - if rr.off > 0 { - x = rr.off + 1 - rr.off = 0 + if end > 0 { + x = end } else { x += uint(len(rrMInfo.RMailBox) + 2) } - rrMInfo.EmailBox, err = rr.unpackDomainName(packet, x) + rrMInfo.EmailBox, _, err = unpackDomainName(packet, x) if err != nil { - return err + return fmt.Errorf("%s: %w", logp, err) } return nil @@ -440,8 +444,7 @@ func (rr *ResourceRecord) unpackMX(packet []byte, startIdx uint) (err error) { rrMX.Preference = libbytes.ReadInt16(packet, startIdx) - rr.off = 0 - rrMX.Exchange, err = rr.unpackDomainName(packet, startIdx+2) + rrMX.Exchange, _, err = unpackDomainName(packet, startIdx+2) return err } @@ -455,7 +458,7 @@ func (rr *ResourceRecord) unpackSRV(packet []byte, x uint) (err error) { y := 0 for ; y < len(rr.Name); y++ { if rr.Name[y] == '.' { - rrSRV.Service = string(rr.Name[start:y]) + rrSRV.Service = rr.Name[start:y] break } } @@ -478,7 +481,7 @@ func (rr *ResourceRecord) unpackSRV(packet []byte, x uint) (err error) { rrSRV.Port = libbytes.ReadUint16(packet, x) x += 2 - rrSRV.Target, err = rr.unpackDomainName(packet, x) + rrSRV.Target, _, err = unpackDomainName(packet, x) return } @@ -506,37 +509,38 @@ func (rr *ResourceRecord) unpackOPT(packet []byte, x uint) error { x += 2 endIdx := x + uint(rr.rdlen) if int(endIdx) >= len(packet) { - return errors.New("RR OPT length is out of range") + return errors.New("unpackOPT: data length is out of range") } rrOPT.Data = append(rrOPT.Data, packet[x:endIdx]...) return nil } func (rr *ResourceRecord) unpackSOA(packet []byte, startIdx uint) (err error) { - rrSOA := &RDataSOA{} - rr.Value = rrSOA + var ( + logp = "unpackSOA" + rrSOA = &RDataSOA{} + x = startIdx + end uint + ) - x := startIdx - rr.off = 0 + rr.Value = rrSOA - rrSOA.MName, err = rr.unpackDomainName(packet, x) + rrSOA.MName, end, err = unpackDomainName(packet, x) if err != nil { - return err + return fmt.Errorf("%s: %w", logp, err) } - if rr.off > 0 { - x = rr.off + 1 - rr.off = 0 + if end > 0 { + x = end } else { x += uint(len(rrSOA.MName) + 2) } - rrSOA.RName, err = rr.unpackDomainName(packet, x) + rrSOA.RName, end, err = unpackDomainName(packet, x) if err != nil { - return err + return fmt.Errorf("%s: %w", logp, err) } - if rr.off > 0 { - x = rr.off + 1 - rr.off = 0 + if end > 0 { + x = end } else { x += uint(len(rrSOA.RName) + 2) } |
