diff options
| author | Shulhan <ms@kilabit.info> | 2021-11-11 01:46:35 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2021-11-11 01:51:46 +0700 |
| commit | 5e07be2c3ad34ce8b025e04958aff28f621b9b23 (patch) | |
| tree | e2707cee9b4630a8519b44175053a6ea246413fd | |
| parent | 7e8b2e42f1f558dfb7640a2565e3d2d4e3f7e59a (diff) | |
| download | pakakeh.go-5e07be2c3ad34ce8b025e04958aff28f621b9b23.tar.xz | |
lib/dns: use different packet between UDP and TCP messages
Previously, all packet size for reading and sending the message is
fixed to 4096, even on UDP.
This changes set the UDP packet size maximum to 512 bytes and others to
4096 bytes.
While at it, minimize copying packet if its not reusable inside a method.
| -rw-r--r-- | lib/dns/dns.go | 3 | ||||
| -rw-r--r-- | lib/dns/dotclient.go | 4 | ||||
| -rw-r--r-- | lib/dns/message.go | 2 | ||||
| -rw-r--r-- | lib/dns/message_test.go | 10 | ||||
| -rw-r--r-- | lib/dns/server.go | 28 | ||||
| -rw-r--r-- | lib/dns/tcpclient.go | 31 | ||||
| -rw-r--r-- | lib/dns/udpclient.go | 50 |
7 files changed, 61 insertions, 67 deletions
diff --git a/lib/dns/dns.go b/lib/dns/dns.go index 47420807..11eb479b 100644 --- a/lib/dns/dns.go +++ b/lib/dns/dns.go @@ -41,7 +41,8 @@ const ( maskOPTDO uint32 = 0x00008000 maxLabelSize = 63 - maxUDPPacketSize = 4096 + maxUdpPacketSize = 512 + maxTcpPacketSize = 4096 rdataIPv4Size = 4 rdataIPv6Size = 16 // sectionHeaderSize define the size of section header in DNS message. diff --git a/lib/dns/dotclient.go b/lib/dns/dotclient.go index 1399973e..3d43d277 100644 --- a/lib/dns/dotclient.go +++ b/lib/dns/dotclient.go @@ -139,7 +139,7 @@ func (cl *DoTClient) recv(msg *Message) (n int, err error) { return } - packet := make([]byte, maxUDPPacketSize) + packet := make([]byte, maxTcpPacketSize) n, err = cl.conn.Read(packet) if err != nil { @@ -149,7 +149,7 @@ func (cl *DoTClient) recv(msg *Message) (n int, err error) { return } - msg.packet = libbytes.Copy(packet[2:n]) + msg.packet = packet[2:n] if debug.Value >= 3 { libbytes.PrintHex(">>> DoTClient: recv: ", msg.packet, 8) diff --git a/lib/dns/message.go b/lib/dns/message.go index f68122ad..a7c32323 100644 --- a/lib/dns/message.go +++ b/lib/dns/message.go @@ -619,7 +619,7 @@ func (msg *Message) Reset() { msg.Question.Reset() msg.ResetRR() - msg.packet = append(msg.packet[:0], make([]byte, maxUDPPacketSize)...) + msg.packet = nil msg.dname = "" msg.off = 0 diff --git a/lib/dns/message_test.go b/lib/dns/message_test.go index 7082cf94..67402b4c 100644 --- a/lib/dns/message_test.go +++ b/lib/dns/message_test.go @@ -844,7 +844,7 @@ func TestMessageSetAuthoritativeAnswer(t *testing.T) { IsRD: true, }, Question: SectionQuestion{}, - packet: make([]byte, maxUDPPacketSize), + packet: make([]byte, maxUdpPacketSize), dnameOff: make(map[string]uint16), } @@ -861,7 +861,7 @@ func TestMessageSetAuthoritativeAnswer(t *testing.T) { IsRA: true, }, Question: SectionQuestion{}, - packet: make([]byte, maxUDPPacketSize), + packet: make([]byte, maxUdpPacketSize), dnameOff: make(map[string]uint16), } @@ -913,7 +913,7 @@ func TestMessageSetQuery(t *testing.T) { IsRD: true, }, Question: SectionQuestion{}, - packet: make([]byte, maxUDPPacketSize), + packet: make([]byte, maxUdpPacketSize), dnameOff: make(map[string]uint16), } @@ -956,7 +956,7 @@ func TestMessageSetRecursionDesired(t *testing.T) { IsRD: true, }, Question: SectionQuestion{}, - packet: make([]byte, maxUDPPacketSize), + packet: make([]byte, maxUdpPacketSize), dnameOff: make(map[string]uint16), } @@ -999,7 +999,7 @@ func TestMessageSetResponseCode(t *testing.T) { IsRD: true, }, Question: SectionQuestion{}, - packet: make([]byte, maxUDPPacketSize), + packet: make([]byte, maxUdpPacketSize), dnameOff: make(map[string]uint16), } diff --git a/lib/dns/server.go b/lib/dns/server.go index 7c289cdb..167feb3d 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -462,13 +462,17 @@ func (srv *Server) serveTCP() { // serveUDP serve DNS request from UDP connection. // func (srv *Server) serveUDP() { - log.Println("dns.Server: listening for DNS over UDP at", srv.udp.LocalAddr()) + var ( + n int + packet = make([]byte, maxUdpPacketSize) + raddr *net.UDPAddr + req *request + err error + ) - packet := make([]byte, maxUDPPacketSize) + log.Println("dns.Server: listening for DNS over UDP at", srv.udp.LocalAddr()) for { - req := newRequest() - - n, raddr, err := srv.udp.ReadFromUDP(packet) + n, raddr, err = srv.udp.ReadFromUDP(packet) if err != nil { if n == 0 || errors.Is(err, io.EOF) { err = nil @@ -479,6 +483,7 @@ func (srv *Server) serveUDP() { return } + req = newRequest() req.message.packet = libbytes.Copy(packet[:n]) req.kind = connTypeUDP @@ -606,18 +611,19 @@ func (srv *Server) incForwarder() { } func (srv *Server) serveTCPClient(cl *TCPClient, kind connType) { + var ( + req *request + err error + ) for { - req := newRequest() + req = newRequest() - n, err := cl.recv(req.message) + req.message, err = cl.recv() if err != nil { log.Printf("serveTCPClient: %s: %s", connTypeNames[kind], err.Error()) break } - if n == 0 || len(req.message.packet) == 0 { - break - } req.kind = kind req.writer = cl @@ -632,7 +638,7 @@ func (srv *Server) serveTCPClient(cl *TCPClient, kind connType) { srv.requestq <- req } - err := cl.conn.Close() + err = cl.conn.Close() if err != nil { log.Printf("serveTCPClient: conn.Close: %s: %s", connTypeNames[kind], err.Error()) diff --git a/lib/dns/tcpclient.go b/lib/dns/tcpclient.go index e88ad7d7..42f99fd9 100644 --- a/lib/dns/tcpclient.go +++ b/lib/dns/tcpclient.go @@ -6,6 +6,7 @@ package dns import ( "fmt" + "io" "net" "time" @@ -123,15 +124,13 @@ func (cl *TCPClient) Lookup( // Query send DNS query to name server. // The addr parameter is unused. // -func (cl *TCPClient) Query(msg *Message) (*Message, error) { - _, err := cl.Write(msg.packet) +func (cl *TCPClient) Query(msg *Message) (res *Message, err error) { + _, err = cl.Write(msg.packet) if err != nil { return nil, err } - res := NewMessage() - - _, err = cl.recv(res) + res, err = cl.recv() if err != nil { return nil, err } @@ -192,31 +191,33 @@ func (cl *TCPClient) Write(msg []byte) (n int, err error) { } // -// recv will read DNS message from active connection in client into `msg`. +// recv receive DNS message. // -func (cl *TCPClient) recv(msg *Message) (n int, err error) { +func (cl *TCPClient) recv() (res *Message, err error) { if cl.readTimeout > 0 { err = cl.conn.SetReadDeadline(time.Now().Add(cl.readTimeout)) if err != nil { - return + return nil, err } } - packet := make([]byte, maxUDPPacketSize) + packet := make([]byte, maxTcpPacketSize) - n, err = cl.conn.Read(packet) + n, err := cl.conn.Read(packet) if err != nil { - return + return nil, err } if n == 0 { - return + return nil, io.EOF } - msg.packet = libbytes.Copy(packet[2:n]) + res = &Message{ + packet: packet[2:n], + } if debug.Value >= 3 { - libbytes.PrintHex(">>> TCPClient: recv: ", msg.packet, 8) + libbytes.PrintHex(">>> TCPClient.recv: ", res.packet, 8) } - return + return res, nil } diff --git a/lib/dns/udpclient.go b/lib/dns/udpclient.go index 62322cad..c2bc6014 100644 --- a/lib/dns/udpclient.go +++ b/lib/dns/udpclient.go @@ -115,56 +115,42 @@ func (cl *UDPClient) Lookup( // // Query send DNS query to name server "ns" and return the unpacked response. // -func (cl *UDPClient) Query(msg *Message) (*Message, error) { +func (cl *UDPClient) Query(req *Message) (res *Message, err error) { + logp := "Query" cl.Lock() + defer cl.Unlock() - _, err := cl.Write(msg.packet) + _, err = cl.Write(req.packet) if err != nil { - cl.Unlock() - return nil, err + return nil, fmt.Errorf("%s: %w", logp, err) } - res := NewMessage() - - _, err = cl.recv(res) + err = cl.conn.SetReadDeadline(time.Now().Add(cl.timeout)) if err != nil { - cl.Unlock() - return nil, err + return nil, fmt.Errorf("%s: %w", logp, err) } - cl.Unlock() + packet := make([]byte, maxUdpPacketSize) - err = res.Unpack() + n, _, err := cl.conn.ReadFromUDP(packet) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: %w", logp, err) } - return res, nil -} - -// -// recv will read DNS message from active connection in client into `msg`. -// -func (cl *UDPClient) recv(msg *Message) (n int, err error) { - err = cl.conn.SetReadDeadline(time.Now().Add(cl.timeout)) - if err != nil { - return + res = &Message{ + packet: packet[:n], } - packet := make([]byte, maxUDPPacketSize) - - n, _, err = cl.conn.ReadFromUDP(packet) - if err != nil { - return + if debug.Value >= 3 { + libbytes.PrintHex(">>> UDPClient.recv:", res.packet, 8) } - msg.packet = libbytes.Copy(packet[:n]) - - if debug.Value >= 3 { - libbytes.PrintHex(">>> UDPClient: recv:", msg.packet, 8) + err = res.Unpack() + if err != nil { + return nil, fmt.Errorf("%s: %w", logp, err) } - return + return res, nil } // |
