aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-04-11 23:14:25 +0700
committerShulhan <ms@kilabit.info>2019-04-12 19:14:02 +0700
commit82b55d27c2a197cee6355b34cebc0a57973fc09d (patch)
treeccaf6166d67f6e9595b4a6217d53e6cd21d90e19
parent152a82e1bba4628ee5a63c1d703d88b44aaccbee (diff)
downloadpakakeh.go-82b55d27c2a197cee6355b34cebc0a57973fc09d.tar.xz
dns: simplify server request using io.Writer
Previously, we need the UDP address of client to write back response to client on UDP connection and http ResponseWriter to write back response to client of DoH connection. This changes simplify the sender using io.Writer. On UDP connection, writer is an instance of UDPClient with connection reference to UDP server and with peer address. On TCP connection, writer is a TCP connection from accept. On Doh connection, writer is http ResponseWriter.
-rw-r--r--lib/dns/client.go1
-rw-r--r--lib/dns/dohclient.go86
-rw-r--r--lib/dns/request.go24
-rw-r--r--lib/dns/sender.go16
-rw-r--r--lib/dns/server.go58
-rw-r--r--lib/dns/tcpclient.go10
-rw-r--r--lib/dns/udpclient.go20
7 files changed, 68 insertions, 147 deletions
diff --git a/lib/dns/client.go b/lib/dns/client.go
index 08edf104..ede7dafd 100644
--- a/lib/dns/client.go
+++ b/lib/dns/client.go
@@ -18,5 +18,4 @@ type Client interface {
Query(req *Message, ns net.Addr) (*Message, error)
SetTimeout(t time.Duration)
SetRemoteAddr(addr string) error
- Sender
}
diff --git a/lib/dns/dohclient.go b/lib/dns/dohclient.go
index 6fb52047..026c193d 100644
--- a/lib/dns/dohclient.go
+++ b/lib/dns/dohclient.go
@@ -25,7 +25,12 @@ type DoHClient struct {
req *http.Request
query url.Values
conn *http.Client
- chRes chan *http.Response
+
+ // w hold the ResponseWriter on receiver side.
+ w http.ResponseWriter
+ // responded is a channel to signal the underlying receiver that the
+ // response has ready to be send to client.
+ responded chan bool
}
//
@@ -62,7 +67,6 @@ func NewDoHClient(nameserver string, allowInsecure bool) (*DoHClient, error) {
Transport: tr,
Timeout: clientTimeout,
},
- chRes: make(chan *http.Response, 1),
}
cl.req = &http.Request{
@@ -204,35 +208,6 @@ func (cl *DoHClient) Query(msg *Message, ns net.Addr) (*Message, error) {
}
//
-// recv read response from channel.
-//
-func (cl *DoHClient) recv(msg *Message) (int, error) {
- httpRes := <-cl.chRes
-
- body, err := ioutil.ReadAll(httpRes.Body)
- httpRes.Body.Close()
- if err != nil {
- return 0, err
- }
-
- if httpRes.StatusCode != 200 {
- err = fmt.Errorf("%s", string(body))
- return 0, err
- }
-
- msg.Packet = append(msg.Packet[:0], body...)
-
- if len(msg.Packet) > 20 {
- err = msg.Unpack()
- if err != nil {
- return 0, err
- }
- }
-
- return len(msg.Packet), nil
-}
-
-//
// RemoteAddr return client remote nameserver address.
//
func (cl *DoHClient) RemoteAddr() string {
@@ -240,28 +215,6 @@ func (cl *DoHClient) RemoteAddr() string {
}
//
-// Send DNS message to name server using Get method. Since HTTP client is
-// synchronous, the response is forwarded to channel to be consumed by recv().
-//
-func (cl *DoHClient) Send(msg []byte, ns net.Addr) (int, error) {
- packet := base64.RawURLEncoding.EncodeToString(msg)
-
- cl.query.Set("dns", packet)
- cl.req.Method = http.MethodGet
- cl.req.Body = nil
- cl.req.URL.RawQuery = cl.query.Encode()
-
- httpRes, err := cl.conn.Do(cl.req)
- if err != nil {
- return 0, err
- }
-
- cl.chRes <- httpRes
-
- return len(msg), nil
-}
-
-//
// SetRemoteAddr set the remote address for sending the packet.
//
func (cl *DoHClient) SetRemoteAddr(addr string) (err error) {
@@ -281,3 +234,30 @@ func (cl *DoHClient) SetRemoteAddr(addr string) (err error) {
func (cl *DoHClient) SetTimeout(t time.Duration) {
cl.conn.Timeout = t
}
+
+//
+// Write the raw DNS response message to active connection.
+// This method is only used by server to write the response of query to
+// client.
+//
+func (cl *DoHClient) Write(packet []byte) (n int, err error) {
+ n, err = cl.w.Write(packet)
+ if err != nil {
+ cl.responded <- false
+ return
+ }
+ cl.responded <- true
+ return
+}
+
+//
+// waitResponse wait for http.ResponseWriter being called by server.
+// This method is to prevent the function that process the HTTP request
+// terminated and write empty response.
+//
+func (cl *DoHClient) waitResponse() {
+ success, ok := <-cl.responded
+ if !success || !ok {
+ cl.w.WriteHeader(http.StatusGatewayTimeout)
+ }
+}
diff --git a/lib/dns/request.go b/lib/dns/request.go
index 4cc210b7..7dc266d9 100644
--- a/lib/dns/request.go
+++ b/lib/dns/request.go
@@ -5,8 +5,7 @@
package dns
import (
- "net"
- "net/http"
+ "io"
)
//
@@ -25,20 +24,13 @@ type request struct {
// Message define the DNS query.
message *Message
- // UDPAddr is address of client if connection is from UDP.
- udpAddr *net.UDPAddr
-
- // Sender is server connection that receive the query and responsible
- // to answer back to client.
- sender Sender
-
- // ResponseWriter is HTTP response writer, where answer for DoH
- // client query will be written.
- responseWriter http.ResponseWriter
-
- // ChanResponded is a channel that notify the DoH handler when answer
- // has been written to ResponseWriter.
- chanResponded chan bool
+ // writer represent client connection on server that receive the query
+ // and responsible to write the answer back.
+ // On UDP connection, writer is an instance of UDPClient with
+ // connection reference to UDP server and with peer address.
+ // On TCP connection, writer is a TCP connection from accept.
+ // On Doh connection, writer is http ResponseWriter.
+ writer io.Writer
}
//
diff --git a/lib/dns/sender.go b/lib/dns/sender.go
deleted file mode 100644
index 2c73851a..00000000
--- a/lib/dns/sender.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2018, Shulhan <ms@kilabit.info>. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package dns
-
-import (
- "net"
-)
-
-//
-// Sender is interface that for implementing sending raw DNS packet.
-//
-type Sender interface {
- Send(packet []byte, addr net.Addr) (n int, err error)
-}
diff --git a/lib/dns/server.go b/lib/dns/server.go
index 0b2bf06d..ca5c6ef9 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -352,11 +352,6 @@ func (srv *Server) serveTCP() {
// serveUDP serve DNS request from UDP connection.
//
func (srv *Server) serveUDP() {
- sender := &UDPClient{
- Timeout: clientTimeout,
- Conn: srv.udp,
- }
-
for {
req := newRequest()
@@ -370,11 +365,14 @@ func (srv *Server) serveUDP() {
}
req.kind = connTypeUDP
- req.udpAddr = raddr
req.message.Packet = req.message.Packet[:n]
req.message.UnpackHeaderQuestion()
- req.sender = sender
+ req.writer = &UDPClient{
+ Timeout: clientTimeout,
+ Conn: srv.udp,
+ Addr: raddr,
+ }
srv.requestq <- req
}
@@ -440,18 +438,18 @@ func (srv *Server) handleDoHRequest(raw []byte, w http.ResponseWriter) {
req := newRequest()
req.kind = connTypeDoH
- req.responseWriter = w
- req.chanResponded = make(chan bool, 1)
+ cl := &DoHClient{
+ w: w,
+ responded: make(chan bool, 1),
+ }
+ req.writer = cl
req.message.Packet = append(req.message.Packet[:0], raw...)
req.message.UnpackHeaderQuestion()
srv.requestq <- req
- _, ok := <-req.chanResponded
- if !ok {
- w.WriteHeader(http.StatusGatewayTimeout)
- }
+ cl.waitResponse()
}
func (srv *Server) serveTCPClient(cl *TCPClient) {
@@ -481,7 +479,7 @@ func (srv *Server) serveTCPClient(cl *TCPClient) {
req.kind = connTypeTCP
req.message.UnpackHeaderQuestion()
- req.sender = cl
+ req.writer = cl
srv.requestq <- req
}
@@ -548,34 +546,10 @@ func (srv *Server) processResponse(req *request, res *Message, isLocal bool) {
}
}
- switch req.kind {
- case connTypeUDP:
- if req.sender != nil {
- _, err := req.sender.Send(res.Packet, req.udpAddr)
- if err != nil {
- log.Println("dns: failed to send UDP reply:", err)
- return
- }
- }
-
- case connTypeTCP:
- if req.sender != nil {
- _, err := req.sender.Send(res.Packet, nil)
- if err != nil {
- log.Println("dns: failed to send TCP reply:", err)
- return
- }
- }
-
- case connTypeDoH:
- if req.responseWriter != nil {
- _, err := req.responseWriter.Write(res.Packet)
- req.chanResponded <- true
- if err != nil {
- log.Println("dns: failed to send DoH reply:", err)
- return
- }
- }
+ _, err := req.writer.Write(res.Packet)
+ if err != nil {
+ log.Println("dns: processResponse: ", err.Error())
+ return
}
if !isLocal {
diff --git a/lib/dns/tcpclient.go b/lib/dns/tcpclient.go
index 9c5e6ba3..a5f1dab9 100644
--- a/lib/dns/tcpclient.go
+++ b/lib/dns/tcpclient.go
@@ -118,7 +118,7 @@ func (cl *TCPClient) Lookup(allowRecursion bool, qtype, qclass uint16, qname []b
// The addr parameter is unused.
//
func (cl *TCPClient) Query(msg *Message, ns net.Addr) (*Message, error) {
- _, err := cl.Send(msg.Packet, ns)
+ _, err := cl.Write(msg.Packet)
if err != nil {
return nil, err
}
@@ -171,11 +171,11 @@ func (cl *TCPClient) recv(msg *Message) (n int, err error) {
}
//
-// Send DNS message to name server using active connection in client.
+// Write raw DNS response message on active connection.
+// This method is only used by server to write the response of query to
+// client.
//
-// The addr parameter is unused.
-//
-func (cl *TCPClient) Send(msg []byte, addr net.Addr) (n int, err error) {
+func (cl *TCPClient) Write(msg []byte) (n int, err error) {
err = cl.conn.SetWriteDeadline(time.Now().Add(cl.Timeout))
if err != nil {
return
diff --git a/lib/dns/udpclient.go b/lib/dns/udpclient.go
index 614e6a69..cd033614 100644
--- a/lib/dns/udpclient.go
+++ b/lib/dns/udpclient.go
@@ -119,7 +119,7 @@ func (cl *UDPClient) Query(msg *Message, ns net.Addr) (*Message, error) {
cl.Lock()
- _, err := cl.Send(msg.Packet, ns)
+ _, err := cl.Write(msg.Packet)
if err != nil {
cl.Unlock()
return nil, err
@@ -167,25 +167,17 @@ func (cl *UDPClient) recv(msg *Message) (n int, err error) {
}
//
-// Send DNS message to name server using active connection in client.
+// Write raw DNS response message on active connection.
+// This method is only used by server to write the response of query to
+// client.
//
-// The "ns" parameter must not be nil.
-//
-func (cl *UDPClient) Send(msg []byte, ns net.Addr) (n int, err error) {
- if ns == nil {
- ns = cl.Addr
- }
-
- raddr := ns.(*net.UDPAddr)
-
+func (cl *UDPClient) Write(msg []byte) (n int, err error) {
err = cl.Conn.SetWriteDeadline(time.Now().Add(cl.Timeout))
if err != nil {
return
}
- n, err = cl.Conn.WriteToUDP(msg, raddr)
-
- return
+ return cl.Conn.WriteToUDP(msg, cl.Addr)
}
//