summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2021-11-11 01:46:35 +0700
committerShulhan <ms@kilabit.info>2021-11-11 01:51:46 +0700
commit5e07be2c3ad34ce8b025e04958aff28f621b9b23 (patch)
treee2707cee9b4630a8519b44175053a6ea246413fd
parent7e8b2e42f1f558dfb7640a2565e3d2d4e3f7e59a (diff)
downloadpakakeh.go-5e07be2c3ad34ce8b025e04958aff28f621b9b23.tar.xz
lib/dns: use different packet between UDP and TCP messages
Previously, all packet size for reading and sending the message is fixed to 4096, even on UDP. This changes set the UDP packet size maximum to 512 bytes and others to 4096 bytes. While at it, minimize copying packet if its not reusable inside a method.
-rw-r--r--lib/dns/dns.go3
-rw-r--r--lib/dns/dotclient.go4
-rw-r--r--lib/dns/message.go2
-rw-r--r--lib/dns/message_test.go10
-rw-r--r--lib/dns/server.go28
-rw-r--r--lib/dns/tcpclient.go31
-rw-r--r--lib/dns/udpclient.go50
7 files changed, 61 insertions, 67 deletions
diff --git a/lib/dns/dns.go b/lib/dns/dns.go
index 47420807..11eb479b 100644
--- a/lib/dns/dns.go
+++ b/lib/dns/dns.go
@@ -41,7 +41,8 @@ const (
maskOPTDO uint32 = 0x00008000
maxLabelSize = 63
- maxUDPPacketSize = 4096
+ maxUdpPacketSize = 512
+ maxTcpPacketSize = 4096
rdataIPv4Size = 4
rdataIPv6Size = 16
// sectionHeaderSize define the size of section header in DNS message.
diff --git a/lib/dns/dotclient.go b/lib/dns/dotclient.go
index 1399973e..3d43d277 100644
--- a/lib/dns/dotclient.go
+++ b/lib/dns/dotclient.go
@@ -139,7 +139,7 @@ func (cl *DoTClient) recv(msg *Message) (n int, err error) {
return
}
- packet := make([]byte, maxUDPPacketSize)
+ packet := make([]byte, maxTcpPacketSize)
n, err = cl.conn.Read(packet)
if err != nil {
@@ -149,7 +149,7 @@ func (cl *DoTClient) recv(msg *Message) (n int, err error) {
return
}
- msg.packet = libbytes.Copy(packet[2:n])
+ msg.packet = packet[2:n]
if debug.Value >= 3 {
libbytes.PrintHex(">>> DoTClient: recv: ", msg.packet, 8)
diff --git a/lib/dns/message.go b/lib/dns/message.go
index f68122ad..a7c32323 100644
--- a/lib/dns/message.go
+++ b/lib/dns/message.go
@@ -619,7 +619,7 @@ func (msg *Message) Reset() {
msg.Question.Reset()
msg.ResetRR()
- msg.packet = append(msg.packet[:0], make([]byte, maxUDPPacketSize)...)
+ msg.packet = nil
msg.dname = ""
msg.off = 0
diff --git a/lib/dns/message_test.go b/lib/dns/message_test.go
index 7082cf94..67402b4c 100644
--- a/lib/dns/message_test.go
+++ b/lib/dns/message_test.go
@@ -844,7 +844,7 @@ func TestMessageSetAuthoritativeAnswer(t *testing.T) {
IsRD: true,
},
Question: SectionQuestion{},
- packet: make([]byte, maxUDPPacketSize),
+ packet: make([]byte, maxUdpPacketSize),
dnameOff: make(map[string]uint16),
}
@@ -861,7 +861,7 @@ func TestMessageSetAuthoritativeAnswer(t *testing.T) {
IsRA: true,
},
Question: SectionQuestion{},
- packet: make([]byte, maxUDPPacketSize),
+ packet: make([]byte, maxUdpPacketSize),
dnameOff: make(map[string]uint16),
}
@@ -913,7 +913,7 @@ func TestMessageSetQuery(t *testing.T) {
IsRD: true,
},
Question: SectionQuestion{},
- packet: make([]byte, maxUDPPacketSize),
+ packet: make([]byte, maxUdpPacketSize),
dnameOff: make(map[string]uint16),
}
@@ -956,7 +956,7 @@ func TestMessageSetRecursionDesired(t *testing.T) {
IsRD: true,
},
Question: SectionQuestion{},
- packet: make([]byte, maxUDPPacketSize),
+ packet: make([]byte, maxUdpPacketSize),
dnameOff: make(map[string]uint16),
}
@@ -999,7 +999,7 @@ func TestMessageSetResponseCode(t *testing.T) {
IsRD: true,
},
Question: SectionQuestion{},
- packet: make([]byte, maxUDPPacketSize),
+ packet: make([]byte, maxUdpPacketSize),
dnameOff: make(map[string]uint16),
}
diff --git a/lib/dns/server.go b/lib/dns/server.go
index 7c289cdb..167feb3d 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -462,13 +462,17 @@ func (srv *Server) serveTCP() {
// serveUDP serve DNS request from UDP connection.
//
func (srv *Server) serveUDP() {
- log.Println("dns.Server: listening for DNS over UDP at", srv.udp.LocalAddr())
+ var (
+ n int
+ packet = make([]byte, maxUdpPacketSize)
+ raddr *net.UDPAddr
+ req *request
+ err error
+ )
- packet := make([]byte, maxUDPPacketSize)
+ log.Println("dns.Server: listening for DNS over UDP at", srv.udp.LocalAddr())
for {
- req := newRequest()
-
- n, raddr, err := srv.udp.ReadFromUDP(packet)
+ n, raddr, err = srv.udp.ReadFromUDP(packet)
if err != nil {
if n == 0 || errors.Is(err, io.EOF) {
err = nil
@@ -479,6 +483,7 @@ func (srv *Server) serveUDP() {
return
}
+ req = newRequest()
req.message.packet = libbytes.Copy(packet[:n])
req.kind = connTypeUDP
@@ -606,18 +611,19 @@ func (srv *Server) incForwarder() {
}
func (srv *Server) serveTCPClient(cl *TCPClient, kind connType) {
+ var (
+ req *request
+ err error
+ )
for {
- req := newRequest()
+ req = newRequest()
- n, err := cl.recv(req.message)
+ req.message, err = cl.recv()
if err != nil {
log.Printf("serveTCPClient: %s: %s",
connTypeNames[kind], err.Error())
break
}
- if n == 0 || len(req.message.packet) == 0 {
- break
- }
req.kind = kind
req.writer = cl
@@ -632,7 +638,7 @@ func (srv *Server) serveTCPClient(cl *TCPClient, kind connType) {
srv.requestq <- req
}
- err := cl.conn.Close()
+ err = cl.conn.Close()
if err != nil {
log.Printf("serveTCPClient: conn.Close: %s: %s",
connTypeNames[kind], err.Error())
diff --git a/lib/dns/tcpclient.go b/lib/dns/tcpclient.go
index e88ad7d7..42f99fd9 100644
--- a/lib/dns/tcpclient.go
+++ b/lib/dns/tcpclient.go
@@ -6,6 +6,7 @@ package dns
import (
"fmt"
+ "io"
"net"
"time"
@@ -123,15 +124,13 @@ func (cl *TCPClient) Lookup(
// Query send DNS query to name server.
// The addr parameter is unused.
//
-func (cl *TCPClient) Query(msg *Message) (*Message, error) {
- _, err := cl.Write(msg.packet)
+func (cl *TCPClient) Query(msg *Message) (res *Message, err error) {
+ _, err = cl.Write(msg.packet)
if err != nil {
return nil, err
}
- res := NewMessage()
-
- _, err = cl.recv(res)
+ res, err = cl.recv()
if err != nil {
return nil, err
}
@@ -192,31 +191,33 @@ func (cl *TCPClient) Write(msg []byte) (n int, err error) {
}
//
-// recv will read DNS message from active connection in client into `msg`.
+// recv receive DNS message.
//
-func (cl *TCPClient) recv(msg *Message) (n int, err error) {
+func (cl *TCPClient) recv() (res *Message, err error) {
if cl.readTimeout > 0 {
err = cl.conn.SetReadDeadline(time.Now().Add(cl.readTimeout))
if err != nil {
- return
+ return nil, err
}
}
- packet := make([]byte, maxUDPPacketSize)
+ packet := make([]byte, maxTcpPacketSize)
- n, err = cl.conn.Read(packet)
+ n, err := cl.conn.Read(packet)
if err != nil {
- return
+ return nil, err
}
if n == 0 {
- return
+ return nil, io.EOF
}
- msg.packet = libbytes.Copy(packet[2:n])
+ res = &Message{
+ packet: packet[2:n],
+ }
if debug.Value >= 3 {
- libbytes.PrintHex(">>> TCPClient: recv: ", msg.packet, 8)
+ libbytes.PrintHex(">>> TCPClient.recv: ", res.packet, 8)
}
- return
+ return res, nil
}
diff --git a/lib/dns/udpclient.go b/lib/dns/udpclient.go
index 62322cad..c2bc6014 100644
--- a/lib/dns/udpclient.go
+++ b/lib/dns/udpclient.go
@@ -115,56 +115,42 @@ func (cl *UDPClient) Lookup(
//
// Query send DNS query to name server "ns" and return the unpacked response.
//
-func (cl *UDPClient) Query(msg *Message) (*Message, error) {
+func (cl *UDPClient) Query(req *Message) (res *Message, err error) {
+ logp := "Query"
cl.Lock()
+ defer cl.Unlock()
- _, err := cl.Write(msg.packet)
+ _, err = cl.Write(req.packet)
if err != nil {
- cl.Unlock()
- return nil, err
+ return nil, fmt.Errorf("%s: %w", logp, err)
}
- res := NewMessage()
-
- _, err = cl.recv(res)
+ err = cl.conn.SetReadDeadline(time.Now().Add(cl.timeout))
if err != nil {
- cl.Unlock()
- return nil, err
+ return nil, fmt.Errorf("%s: %w", logp, err)
}
- cl.Unlock()
+ packet := make([]byte, maxUdpPacketSize)
- err = res.Unpack()
+ n, _, err := cl.conn.ReadFromUDP(packet)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("%s: %w", logp, err)
}
- return res, nil
-}
-
-//
-// recv will read DNS message from active connection in client into `msg`.
-//
-func (cl *UDPClient) recv(msg *Message) (n int, err error) {
- err = cl.conn.SetReadDeadline(time.Now().Add(cl.timeout))
- if err != nil {
- return
+ res = &Message{
+ packet: packet[:n],
}
- packet := make([]byte, maxUDPPacketSize)
-
- n, _, err = cl.conn.ReadFromUDP(packet)
- if err != nil {
- return
+ if debug.Value >= 3 {
+ libbytes.PrintHex(">>> UDPClient.recv:", res.packet, 8)
}
- msg.packet = libbytes.Copy(packet[:n])
-
- if debug.Value >= 3 {
- libbytes.PrintHex(">>> UDPClient: recv:", msg.packet, 8)
+ err = res.Unpack()
+ if err != nil {
+ return nil, fmt.Errorf("%s: %w", logp, err)
}
- return
+ return res, nil
}
//