diff options
| author | Shulhan <ms@kilabit.info> | 2024-03-27 02:24:41 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2024-03-27 02:24:41 +0700 |
| commit | 14325589db35cf36ed1aa71ff4f2c5ad0bb6886b (patch) | |
| tree | cd6cf46ca219c3084b354b4732e4aef44d1c41eb | |
| parent | 71eaafc5119b178be61abf6ae7b8a2fbcdfacc44 (diff) | |
| download | pakakeh.go-14325589db35cf36ed1aa71ff4f2c5ad0bb6886b.tar.xz | |
lib/dns: refactor [Message.Unpack] to [UnpackMessage]
The previous API for Message is a little bit weird.
Its provides creating Message manually, but expose the method
[UnpackHeaderQuestion], meanwhile the field packet itself is unexported.
In order to make it more clear we refactor [Message.Unpack] to
function [UnpackMessage] that accept raw DNS packet.
| -rw-r--r-- | lib/dns/answer_test.go | 8 | ||||
| -rw-r--r-- | lib/dns/caches.go | 4 | ||||
| -rw-r--r-- | lib/dns/caches_test.go | 4 | ||||
| -rw-r--r-- | lib/dns/doh_client.go | 23 | ||||
| -rw-r--r-- | lib/dns/dot_client.go | 24 | ||||
| -rw-r--r-- | lib/dns/message.go | 102 | ||||
| -rw-r--r-- | lib/dns/message_test.go | 21 | ||||
| -rw-r--r-- | lib/dns/rdata_svcb.go | 7 | ||||
| -rw-r--r-- | lib/dns/server.go | 7 | ||||
| -rw-r--r-- | lib/dns/tcp_client.go | 20 | ||||
| -rw-r--r-- | lib/dns/udp_client.go | 7 |
11 files changed, 114 insertions, 113 deletions
diff --git a/lib/dns/answer_test.go b/lib/dns/answer_test.go index 5b5df89d..d12baeaa 100644 --- a/lib/dns/answer_test.go +++ b/lib/dns/answer_test.go @@ -200,12 +200,8 @@ func TestAnswerGet(t *testing.T) { test.Assert(t, "ReceivedAt", an.ReceivedAt >= at-5, true) test.Assert(t, "AccessedAt", an.AccessedAt >= at, true) - got = &Message{ - Header: MessageHeader{}, - Question: MessageQuestion{}, - packet: gotPacket, - } - err = got.Unpack() + + got, err = UnpackMessage(gotPacket) if err != nil { t.Fatal(err) } diff --git a/lib/dns/caches.go b/lib/dns/caches.go index 99e60c50..d0c674c7 100644 --- a/lib/dns/caches.go +++ b/lib/dns/caches.go @@ -516,9 +516,7 @@ func (c *Caches) read(r io.Reader) (answers []*Answer, err error) { return nil, fmt.Errorf("%s: %w", logp, err) } - msg = NewMessage() - msg.packet = item.Packet - err = msg.Unpack() + msg, err = UnpackMessage(item.Packet) if err != nil { return nil, fmt.Errorf("%s: %w", logp, err) } diff --git a/lib/dns/caches_test.go b/lib/dns/caches_test.go index bc18969b..4f15d3d0 100644 --- a/lib/dns/caches_test.go +++ b/lib/dns/caches_test.go @@ -202,9 +202,7 @@ func TestCaches_ExternalSave(t *testing.T) { test.Assert(t, "Caches.ExternalSave", 1, n) - msg = NewMessage() - msg.packet = answer.msg.packet - err = msg.Unpack() + msg, err = UnpackMessage(answer.msg.packet) if err != nil { t.Fatal(err) } diff --git a/lib/dns/doh_client.go b/lib/dns/doh_client.go index 4a336a6b..9e6db4b9 100644 --- a/lib/dns/doh_client.go +++ b/lib/dns/doh_client.go @@ -134,6 +134,8 @@ func (cl *DoHClient) Lookup(q MessageQuestion, allowRecursion bool) (res *Messag // as unpacked message. func (cl *DoHClient) Post(msg *Message) (res *Message, err error) { var ( + logp = `Post` + httpRes *http.Response ) @@ -148,17 +150,20 @@ func (cl *DoHClient) Post(msg *Message) (res *Message, err error) { } cl.req.Body.Close() - res = NewMessage() + var packet []byte - res.packet, err = io.ReadAll(httpRes.Body) + packet, err = io.ReadAll(httpRes.Body) httpRes.Body.Close() if err != nil { return nil, err } - err = res.Unpack() + res, err = UnpackMessage(packet) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } - return res, err + return res, nil } // Get send query to name server using HTTP GET and return the response as @@ -181,19 +186,19 @@ func (cl *DoHClient) Get(msg *Message) (res *Message, err error) { return nil, fmt.Errorf(`%s: %w`, logp, err) } - res = NewMessage() + var packet []byte - res.packet, err = io.ReadAll(httpRes.Body) + packet, err = io.ReadAll(httpRes.Body) httpRes.Body.Close() if err != nil { return nil, fmt.Errorf(`%s: %w`, logp, err) } if httpRes.StatusCode != 200 { - return nil, fmt.Errorf(`%s: %s`, logp, string(res.packet)) + return nil, fmt.Errorf(`%s: %s`, logp, string(packet)) } - if len(res.packet) > 20 { - err = res.Unpack() + if len(packet) > 20 { + res, err = UnpackMessage(packet) if err != nil { return nil, fmt.Errorf(`%s: %w`, logp, err) } diff --git a/lib/dns/dot_client.go b/lib/dns/dot_client.go index 42c6a7e2..7ee2bb29 100644 --- a/lib/dns/dot_client.go +++ b/lib/dns/dot_client.go @@ -106,14 +106,14 @@ func (cl *DoTClient) Query(msg *Message) (res *Message, err error) { return nil, fmt.Errorf(`%s: %w`, logp, err) } - res = NewMessage() + var packet []byte - _, err = cl.recv(res) + packet, err = cl.recv() if err != nil { return nil, fmt.Errorf(`%s: %w`, logp, err) } - err = res.Unpack() + res, err = UnpackMessage(packet) if err != nil { return nil, fmt.Errorf(`%s: %w`, logp, err) } @@ -126,28 +126,30 @@ func (cl *DoTClient) RemoteAddr() string { return cl.conn.RemoteAddr().String() } -// recv will read DNS message from active connection in client into `msg`. -func (cl *DoTClient) recv(msg *Message) (n int, err error) { +// recv will read DNS message from active connection. +func (cl *DoTClient) recv() (packet []byte, err error) { var logp = `recv` err = cl.conn.SetReadDeadline(time.Now().Add(cl.timeout)) if err != nil { - return 0, fmt.Errorf(`%s: %w`, logp, err) + return nil, fmt.Errorf(`%s: %w`, logp, err) } - var packet = make([]byte, maxTCPPacketSize) + var n int + + packet = make([]byte, maxTCPPacketSize) n, err = cl.conn.Read(packet) if err != nil { - return 0, fmt.Errorf(`%s: %w`, logp, err) + return nil, fmt.Errorf(`%s: %w`, logp, err) } if n == 0 { - return n, nil + return packet, nil } - msg.packet = packet[2:n] + packet = packet[2:n] - return n, nil + return packet, nil } // Write raw DNS message on active connection. diff --git a/lib/dns/message.go b/lib/dns/message.go index d4861496..7a43a7c8 100644 --- a/lib/dns/message.go +++ b/lib/dns/message.go @@ -173,6 +173,60 @@ func NewMessageFromRR(rr *ResourceRecord) (msg *Message, err error) { return msg, nil } +// Unpack the raw DNS packet into a Message. +func UnpackMessage(packet []byte) (msg *Message, err error) { + var logp = `UnpackMessage` + + msg = &Message{ + packet: packet, + } + + err = msg.UnpackHeaderQuestion() + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + + var ( + startIdx = uint(sectionHeaderSize + msg.Question.size()) + + rr ResourceRecord + x uint16 + ) + for ; x < msg.Header.ANCount; x++ { + rr = ResourceRecord{} + + startIdx, err = rr.unpack(msg.packet, startIdx) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + + msg.Answer = append(msg.Answer, rr) + } + + for x = 0; x < msg.Header.NSCount; x++ { + rr = ResourceRecord{} + + startIdx, err = rr.unpack(msg.packet, startIdx) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + msg.Authority = append(msg.Authority, rr) + } + + for x = 0; x < msg.Header.ARCount; x++ { + rr = ResourceRecord{} + + startIdx, err = rr.unpack(msg.packet, startIdx) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + + msg.Additional = append(msg.Additional, rr) + } + + return msg, nil +} + // AddAnswer to the Answer field and re-pack it again. func (msg *Message) AddAnswer(rr *ResourceRecord) (err error) { switch rr.Type { @@ -1036,54 +1090,6 @@ func (msg *Message) String() string { return b.String() } -// Unpack the packet to fill the message fields. -func (msg *Message) Unpack() (err error) { - err = msg.UnpackHeaderQuestion() - if err != nil { - return fmt.Errorf(`%w: %w`, errUnpack, err) - } - - var ( - startIdx = uint(sectionHeaderSize + msg.Question.size()) - rr ResourceRecord - ) - - var x uint16 - for ; x < msg.Header.ANCount; x++ { - rr = ResourceRecord{} - - startIdx, err = rr.unpack(msg.packet, startIdx) - if err != nil { - return fmt.Errorf(`%w: %w`, errUnpack, err) - } - - msg.Answer = append(msg.Answer, rr) - } - - for x = 0; x < msg.Header.NSCount; x++ { - rr = ResourceRecord{} - - startIdx, err = rr.unpack(msg.packet, startIdx) - if err != nil { - return fmt.Errorf(`%w: %w`, errUnpack, err) - } - msg.Authority = append(msg.Authority, rr) - } - - for x = 0; x < msg.Header.ARCount; x++ { - rr = ResourceRecord{} - - startIdx, err = rr.unpack(msg.packet, startIdx) - if err != nil { - return fmt.Errorf(`%w: %w`, errUnpack, err) - } - - msg.Additional = append(msg.Additional, rr) - } - - return nil -} - // UnpackHeaderQuestion extract only DNS header and question from message // packet. This method assume that message.packet already set to DNS raw // message. diff --git a/lib/dns/message_test.go b/lib/dns/message_test.go index d97a82d9..5c6fc451 100644 --- a/lib/dns/message_test.go +++ b/lib/dns/message_test.go @@ -1210,7 +1210,7 @@ func TestMessageSetResponseCode(t *testing.T) { } } -func TestMessageUnpack(t *testing.T) { +func TestUnpackMessage(t *testing.T) { type testCase struct { exp *Message desc string @@ -2037,10 +2037,7 @@ func TestMessageUnpack(t *testing.T) { for _, c = range cases { t.Log(c.desc) - msg.Reset() - msg.packet = c.packet - - err = msg.Unpack() + msg, err = UnpackMessage(c.packet) if err != nil { t.Fatal(err) } @@ -2106,26 +2103,26 @@ func TestUnpackMessage_SVCB(t *testing.T) { } var ( - name string - msgjson []byte + name string + stream []byte + msg *Message ) for _, name = range listCase { - var msg Message - msg.packet, err = libbytes.ParseHexDump(tdata.Input[name], true) + stream, err = libbytes.ParseHexDump(tdata.Input[name], true) if err != nil { t.Fatal(logp, err) } - err = msg.Unpack() + msg, err = UnpackMessage(stream) if err != nil { t.Fatal(logp, err) } - msgjson, err = json.MarshalIndent(&msg, ``, ` `) + stream, err = json.MarshalIndent(&msg, ``, ` `) if err != nil { t.Fatal(logp, err) } - test.Assert(t, name, string(tdata.Output[name]), string(msgjson)) + test.Assert(t, name, string(tdata.Output[name]), string(stream)) } } diff --git a/lib/dns/rdata_svcb.go b/lib/dns/rdata_svcb.go index e3ac6d6b..e97eb078 100644 --- a/lib/dns/rdata_svcb.go +++ b/lib/dns/rdata_svcb.go @@ -564,7 +564,10 @@ func (svcb *RDataSVCB) unpack(packet []byte) (err error) { } packet = packet[x:] - svcb.unpackParams(packet) + err = svcb.unpackParams(packet) + if err != nil { + return err + } return nil } @@ -663,7 +666,7 @@ func (svcb *RDataSVCB) unpackParamALPN(packet []byte) ([]byte, error) { var n = int(packet[0]) packet = packet[1:] - total -= 1 + total-- if len(packet) < int(total) { return packet, fmt.Errorf(`%s: mismatch value length, want %d got %d`, logp, n, len(packet)) diff --git a/lib/dns/server.go b/lib/dns/server.go index e9e1b49a..43795c2e 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -513,13 +513,14 @@ func (srv *Server) incForwarder() { func (srv *Server) serveTCPClient(cl *TCPClient, kind connType) { var ( logp = `serveTCPClient` - req *request - err error + + req *request + err error ) for { req = newRequest() - req.message, err = cl.recv() + req.message.packet, err = cl.recv() if err != nil { if !errors.Is(err, io.EOF) { log.Printf(`%s %s: %s`, logp, connTypeNames[kind], err) diff --git a/lib/dns/tcp_client.go b/lib/dns/tcp_client.go index d047e864..d7fad15b 100644 --- a/lib/dns/tcp_client.go +++ b/lib/dns/tcp_client.go @@ -126,12 +126,14 @@ func (cl *TCPClient) Query(msg *Message) (res *Message, err error) { return nil, fmt.Errorf(`%s: %w`, logp, err) } - res, err = cl.recv() + var packet []byte + + packet, err = cl.recv() if err != nil { return nil, fmt.Errorf(`%s: %w`, logp, err) } - err = res.Unpack() + res, err = UnpackMessage(packet) if err != nil { return nil, fmt.Errorf(`%s: %w`, logp, err) } @@ -186,7 +188,7 @@ func (cl *TCPClient) Write(msg []byte) (n int, err error) { } // recv receive DNS message. -func (cl *TCPClient) recv() (res *Message, err error) { +func (cl *TCPClient) recv() (packet []byte, err error) { var logp = `recv` if cl.readTimeout > 0 { @@ -196,11 +198,9 @@ func (cl *TCPClient) recv() (res *Message, err error) { } } - var ( - packet = make([]byte, maxTCPPacketSize) + var n int - n int - ) + packet = make([]byte, maxTCPPacketSize) n, err = cl.conn.Read(packet) if err != nil { @@ -210,9 +210,7 @@ func (cl *TCPClient) recv() (res *Message, err error) { return nil, io.EOF } - res = &Message{ - packet: packet[2:n], - } + packet = packet[2:n] - return res, nil + return packet, nil } diff --git a/lib/dns/udp_client.go b/lib/dns/udp_client.go index 9f3eeb11..0769b1e8 100644 --- a/lib/dns/udp_client.go +++ b/lib/dns/udp_client.go @@ -138,11 +138,8 @@ func (cl *UDPClient) Query(req *Message) (res *Message, err error) { return nil, fmt.Errorf("%s: %w", logp, err) } - res = &Message{ - packet: packet[:n], - } - - err = res.Unpack() + packet = packet[:n] + res, err = UnpackMessage(packet) if err != nil { return nil, fmt.Errorf("%s: %w", logp, err) } |
