summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2024-03-27 02:24:41 +0700
committerShulhan <ms@kilabit.info>2024-03-27 02:24:41 +0700
commit14325589db35cf36ed1aa71ff4f2c5ad0bb6886b (patch)
treecd6cf46ca219c3084b354b4732e4aef44d1c41eb
parent71eaafc5119b178be61abf6ae7b8a2fbcdfacc44 (diff)
downloadpakakeh.go-14325589db35cf36ed1aa71ff4f2c5ad0bb6886b.tar.xz
lib/dns: refactor [Message.Unpack] to [UnpackMessage]
The previous API for Message is a little bit weird. Its provides creating Message manually, but expose the method [UnpackHeaderQuestion], meanwhile the field packet itself is unexported. In order to make it more clear we refactor [Message.Unpack] to function [UnpackMessage] that accept raw DNS packet.
-rw-r--r--lib/dns/answer_test.go8
-rw-r--r--lib/dns/caches.go4
-rw-r--r--lib/dns/caches_test.go4
-rw-r--r--lib/dns/doh_client.go23
-rw-r--r--lib/dns/dot_client.go24
-rw-r--r--lib/dns/message.go102
-rw-r--r--lib/dns/message_test.go21
-rw-r--r--lib/dns/rdata_svcb.go7
-rw-r--r--lib/dns/server.go7
-rw-r--r--lib/dns/tcp_client.go20
-rw-r--r--lib/dns/udp_client.go7
11 files changed, 114 insertions, 113 deletions
diff --git a/lib/dns/answer_test.go b/lib/dns/answer_test.go
index 5b5df89d..d12baeaa 100644
--- a/lib/dns/answer_test.go
+++ b/lib/dns/answer_test.go
@@ -200,12 +200,8 @@ func TestAnswerGet(t *testing.T) {
test.Assert(t, "ReceivedAt", an.ReceivedAt >= at-5, true)
test.Assert(t, "AccessedAt", an.AccessedAt >= at, true)
- got = &Message{
- Header: MessageHeader{},
- Question: MessageQuestion{},
- packet: gotPacket,
- }
- err = got.Unpack()
+
+ got, err = UnpackMessage(gotPacket)
if err != nil {
t.Fatal(err)
}
diff --git a/lib/dns/caches.go b/lib/dns/caches.go
index 99e60c50..d0c674c7 100644
--- a/lib/dns/caches.go
+++ b/lib/dns/caches.go
@@ -516,9 +516,7 @@ func (c *Caches) read(r io.Reader) (answers []*Answer, err error) {
return nil, fmt.Errorf("%s: %w", logp, err)
}
- msg = NewMessage()
- msg.packet = item.Packet
- err = msg.Unpack()
+ msg, err = UnpackMessage(item.Packet)
if err != nil {
return nil, fmt.Errorf("%s: %w", logp, err)
}
diff --git a/lib/dns/caches_test.go b/lib/dns/caches_test.go
index bc18969b..4f15d3d0 100644
--- a/lib/dns/caches_test.go
+++ b/lib/dns/caches_test.go
@@ -202,9 +202,7 @@ func TestCaches_ExternalSave(t *testing.T) {
test.Assert(t, "Caches.ExternalSave", 1, n)
- msg = NewMessage()
- msg.packet = answer.msg.packet
- err = msg.Unpack()
+ msg, err = UnpackMessage(answer.msg.packet)
if err != nil {
t.Fatal(err)
}
diff --git a/lib/dns/doh_client.go b/lib/dns/doh_client.go
index 4a336a6b..9e6db4b9 100644
--- a/lib/dns/doh_client.go
+++ b/lib/dns/doh_client.go
@@ -134,6 +134,8 @@ func (cl *DoHClient) Lookup(q MessageQuestion, allowRecursion bool) (res *Messag
// as unpacked message.
func (cl *DoHClient) Post(msg *Message) (res *Message, err error) {
var (
+ logp = `Post`
+
httpRes *http.Response
)
@@ -148,17 +150,20 @@ func (cl *DoHClient) Post(msg *Message) (res *Message, err error) {
}
cl.req.Body.Close()
- res = NewMessage()
+ var packet []byte
- res.packet, err = io.ReadAll(httpRes.Body)
+ packet, err = io.ReadAll(httpRes.Body)
httpRes.Body.Close()
if err != nil {
return nil, err
}
- err = res.Unpack()
+ res, err = UnpackMessage(packet)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
- return res, err
+ return res, nil
}
// Get send query to name server using HTTP GET and return the response as
@@ -181,19 +186,19 @@ func (cl *DoHClient) Get(msg *Message) (res *Message, err error) {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
- res = NewMessage()
+ var packet []byte
- res.packet, err = io.ReadAll(httpRes.Body)
+ packet, err = io.ReadAll(httpRes.Body)
httpRes.Body.Close()
if err != nil {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
if httpRes.StatusCode != 200 {
- return nil, fmt.Errorf(`%s: %s`, logp, string(res.packet))
+ return nil, fmt.Errorf(`%s: %s`, logp, string(packet))
}
- if len(res.packet) > 20 {
- err = res.Unpack()
+ if len(packet) > 20 {
+ res, err = UnpackMessage(packet)
if err != nil {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
diff --git a/lib/dns/dot_client.go b/lib/dns/dot_client.go
index 42c6a7e2..7ee2bb29 100644
--- a/lib/dns/dot_client.go
+++ b/lib/dns/dot_client.go
@@ -106,14 +106,14 @@ func (cl *DoTClient) Query(msg *Message) (res *Message, err error) {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
- res = NewMessage()
+ var packet []byte
- _, err = cl.recv(res)
+ packet, err = cl.recv()
if err != nil {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
- err = res.Unpack()
+ res, err = UnpackMessage(packet)
if err != nil {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
@@ -126,28 +126,30 @@ func (cl *DoTClient) RemoteAddr() string {
return cl.conn.RemoteAddr().String()
}
-// recv will read DNS message from active connection in client into `msg`.
-func (cl *DoTClient) recv(msg *Message) (n int, err error) {
+// recv will read DNS message from active connection.
+func (cl *DoTClient) recv() (packet []byte, err error) {
var logp = `recv`
err = cl.conn.SetReadDeadline(time.Now().Add(cl.timeout))
if err != nil {
- return 0, fmt.Errorf(`%s: %w`, logp, err)
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
}
- var packet = make([]byte, maxTCPPacketSize)
+ var n int
+
+ packet = make([]byte, maxTCPPacketSize)
n, err = cl.conn.Read(packet)
if err != nil {
- return 0, fmt.Errorf(`%s: %w`, logp, err)
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
}
if n == 0 {
- return n, nil
+ return packet, nil
}
- msg.packet = packet[2:n]
+ packet = packet[2:n]
- return n, nil
+ return packet, nil
}
// Write raw DNS message on active connection.
diff --git a/lib/dns/message.go b/lib/dns/message.go
index d4861496..7a43a7c8 100644
--- a/lib/dns/message.go
+++ b/lib/dns/message.go
@@ -173,6 +173,60 @@ func NewMessageFromRR(rr *ResourceRecord) (msg *Message, err error) {
return msg, nil
}
+// Unpack the raw DNS packet into a Message.
+func UnpackMessage(packet []byte) (msg *Message, err error) {
+ var logp = `UnpackMessage`
+
+ msg = &Message{
+ packet: packet,
+ }
+
+ err = msg.UnpackHeaderQuestion()
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+
+ var (
+ startIdx = uint(sectionHeaderSize + msg.Question.size())
+
+ rr ResourceRecord
+ x uint16
+ )
+ for ; x < msg.Header.ANCount; x++ {
+ rr = ResourceRecord{}
+
+ startIdx, err = rr.unpack(msg.packet, startIdx)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+
+ msg.Answer = append(msg.Answer, rr)
+ }
+
+ for x = 0; x < msg.Header.NSCount; x++ {
+ rr = ResourceRecord{}
+
+ startIdx, err = rr.unpack(msg.packet, startIdx)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+ msg.Authority = append(msg.Authority, rr)
+ }
+
+ for x = 0; x < msg.Header.ARCount; x++ {
+ rr = ResourceRecord{}
+
+ startIdx, err = rr.unpack(msg.packet, startIdx)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+
+ msg.Additional = append(msg.Additional, rr)
+ }
+
+ return msg, nil
+}
+
// AddAnswer to the Answer field and re-pack it again.
func (msg *Message) AddAnswer(rr *ResourceRecord) (err error) {
switch rr.Type {
@@ -1036,54 +1090,6 @@ func (msg *Message) String() string {
return b.String()
}
-// Unpack the packet to fill the message fields.
-func (msg *Message) Unpack() (err error) {
- err = msg.UnpackHeaderQuestion()
- if err != nil {
- return fmt.Errorf(`%w: %w`, errUnpack, err)
- }
-
- var (
- startIdx = uint(sectionHeaderSize + msg.Question.size())
- rr ResourceRecord
- )
-
- var x uint16
- for ; x < msg.Header.ANCount; x++ {
- rr = ResourceRecord{}
-
- startIdx, err = rr.unpack(msg.packet, startIdx)
- if err != nil {
- return fmt.Errorf(`%w: %w`, errUnpack, err)
- }
-
- msg.Answer = append(msg.Answer, rr)
- }
-
- for x = 0; x < msg.Header.NSCount; x++ {
- rr = ResourceRecord{}
-
- startIdx, err = rr.unpack(msg.packet, startIdx)
- if err != nil {
- return fmt.Errorf(`%w: %w`, errUnpack, err)
- }
- msg.Authority = append(msg.Authority, rr)
- }
-
- for x = 0; x < msg.Header.ARCount; x++ {
- rr = ResourceRecord{}
-
- startIdx, err = rr.unpack(msg.packet, startIdx)
- if err != nil {
- return fmt.Errorf(`%w: %w`, errUnpack, err)
- }
-
- msg.Additional = append(msg.Additional, rr)
- }
-
- return nil
-}
-
// UnpackHeaderQuestion extract only DNS header and question from message
// packet. This method assume that message.packet already set to DNS raw
// message.
diff --git a/lib/dns/message_test.go b/lib/dns/message_test.go
index d97a82d9..5c6fc451 100644
--- a/lib/dns/message_test.go
+++ b/lib/dns/message_test.go
@@ -1210,7 +1210,7 @@ func TestMessageSetResponseCode(t *testing.T) {
}
}
-func TestMessageUnpack(t *testing.T) {
+func TestUnpackMessage(t *testing.T) {
type testCase struct {
exp *Message
desc string
@@ -2037,10 +2037,7 @@ func TestMessageUnpack(t *testing.T) {
for _, c = range cases {
t.Log(c.desc)
- msg.Reset()
- msg.packet = c.packet
-
- err = msg.Unpack()
+ msg, err = UnpackMessage(c.packet)
if err != nil {
t.Fatal(err)
}
@@ -2106,26 +2103,26 @@ func TestUnpackMessage_SVCB(t *testing.T) {
}
var (
- name string
- msgjson []byte
+ name string
+ stream []byte
+ msg *Message
)
for _, name = range listCase {
- var msg Message
- msg.packet, err = libbytes.ParseHexDump(tdata.Input[name], true)
+ stream, err = libbytes.ParseHexDump(tdata.Input[name], true)
if err != nil {
t.Fatal(logp, err)
}
- err = msg.Unpack()
+ msg, err = UnpackMessage(stream)
if err != nil {
t.Fatal(logp, err)
}
- msgjson, err = json.MarshalIndent(&msg, ``, ` `)
+ stream, err = json.MarshalIndent(&msg, ``, ` `)
if err != nil {
t.Fatal(logp, err)
}
- test.Assert(t, name, string(tdata.Output[name]), string(msgjson))
+ test.Assert(t, name, string(tdata.Output[name]), string(stream))
}
}
diff --git a/lib/dns/rdata_svcb.go b/lib/dns/rdata_svcb.go
index e3ac6d6b..e97eb078 100644
--- a/lib/dns/rdata_svcb.go
+++ b/lib/dns/rdata_svcb.go
@@ -564,7 +564,10 @@ func (svcb *RDataSVCB) unpack(packet []byte) (err error) {
}
packet = packet[x:]
- svcb.unpackParams(packet)
+ err = svcb.unpackParams(packet)
+ if err != nil {
+ return err
+ }
return nil
}
@@ -663,7 +666,7 @@ func (svcb *RDataSVCB) unpackParamALPN(packet []byte) ([]byte, error) {
var n = int(packet[0])
packet = packet[1:]
- total -= 1
+ total--
if len(packet) < int(total) {
return packet, fmt.Errorf(`%s: mismatch value length, want %d got %d`, logp, n, len(packet))
diff --git a/lib/dns/server.go b/lib/dns/server.go
index e9e1b49a..43795c2e 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -513,13 +513,14 @@ func (srv *Server) incForwarder() {
func (srv *Server) serveTCPClient(cl *TCPClient, kind connType) {
var (
logp = `serveTCPClient`
- req *request
- err error
+
+ req *request
+ err error
)
for {
req = newRequest()
- req.message, err = cl.recv()
+ req.message.packet, err = cl.recv()
if err != nil {
if !errors.Is(err, io.EOF) {
log.Printf(`%s %s: %s`, logp, connTypeNames[kind], err)
diff --git a/lib/dns/tcp_client.go b/lib/dns/tcp_client.go
index d047e864..d7fad15b 100644
--- a/lib/dns/tcp_client.go
+++ b/lib/dns/tcp_client.go
@@ -126,12 +126,14 @@ func (cl *TCPClient) Query(msg *Message) (res *Message, err error) {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
- res, err = cl.recv()
+ var packet []byte
+
+ packet, err = cl.recv()
if err != nil {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
- err = res.Unpack()
+ res, err = UnpackMessage(packet)
if err != nil {
return nil, fmt.Errorf(`%s: %w`, logp, err)
}
@@ -186,7 +188,7 @@ func (cl *TCPClient) Write(msg []byte) (n int, err error) {
}
// recv receive DNS message.
-func (cl *TCPClient) recv() (res *Message, err error) {
+func (cl *TCPClient) recv() (packet []byte, err error) {
var logp = `recv`
if cl.readTimeout > 0 {
@@ -196,11 +198,9 @@ func (cl *TCPClient) recv() (res *Message, err error) {
}
}
- var (
- packet = make([]byte, maxTCPPacketSize)
+ var n int
- n int
- )
+ packet = make([]byte, maxTCPPacketSize)
n, err = cl.conn.Read(packet)
if err != nil {
@@ -210,9 +210,7 @@ func (cl *TCPClient) recv() (res *Message, err error) {
return nil, io.EOF
}
- res = &Message{
- packet: packet[2:n],
- }
+ packet = packet[2:n]
- return res, nil
+ return packet, nil
}
diff --git a/lib/dns/udp_client.go b/lib/dns/udp_client.go
index 9f3eeb11..0769b1e8 100644
--- a/lib/dns/udp_client.go
+++ b/lib/dns/udp_client.go
@@ -138,11 +138,8 @@ func (cl *UDPClient) Query(req *Message) (res *Message, err error) {
return nil, fmt.Errorf("%s: %w", logp, err)
}
- res = &Message{
- packet: packet[:n],
- }
-
- err = res.Unpack()
+ packet = packet[:n]
+ res, err = UnpackMessage(packet)
if err != nil {
return nil, fmt.Errorf("%s: %w", logp, err)
}