diff options
| author | Shulhan <ms@kilabit.info> | 2019-04-10 17:32:19 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-04-12 19:14:02 +0700 |
| commit | 9d2d31c8b5502356b526ce6dceaf056ba5fd5a7f (patch) | |
| tree | 4c4bc0506b85ef7d25042af162a71ab3f0f3dbfd | |
| parent | 47a096dc764c4a83fbdb96b04237d5851e86fab2 (diff) | |
| download | pakakeh.go-9d2d31c8b5502356b526ce6dceaf056ba5fd5a7f.tar.xz | |
dns: unexport request
Since the caches and forwarding now is handled internally, there is no
need for exporting the request anymore.
| -rw-r--r-- | lib/dns/request.go | 34 | ||||
| -rw-r--r-- | lib/dns/server.go | 106 |
2 files changed, 65 insertions, 75 deletions
diff --git a/lib/dns/request.go b/lib/dns/request.go index 00a8f909..a5e6d4c9 100644 --- a/lib/dns/request.go +++ b/lib/dns/request.go @@ -10,52 +10,42 @@ import ( ) // -// Request contains UDP address and DNS query message from client. +// request contains UDP address and DNS query message from client. // // If Kind is UDP, Sender and UDPAddr must be non nil. // If Kind is TCP, Sender must be non nil. // If Kind is DoH, both Sender and UDPAddr must be nil and ResponseWriter and // ChanResponded must be non nil and initialized. // -type Request struct { +type request struct { // Kind define the connection type that this request is belong to, // e.g. UDP, TCP, or DoH. - Kind ConnType + kind ConnType // Message define the DNS query. - Message *Message + message *Message // UDPAddr is address of client if connection is from UDP. - UDPAddr *net.UDPAddr + udpAddr *net.UDPAddr // Sender is server connection that receive the query and responsible // to answer back to client. - Sender Sender + sender Sender // ResponseWriter is HTTP response writer, where answer for DoH // client query will be written. - ResponseWriter http.ResponseWriter + responseWriter http.ResponseWriter // ChanResponded is a channel that notify the DoH handler when answer // has been written to ResponseWriter. - ChanResponded chan bool + chanResponded chan bool } // -// NewRequest create and initialize request. +// newRequest create and initialize request. // -func NewRequest() *Request { - return &Request{ - Message: NewMessage(), +func newRequest() *request { + return &request{ + message: NewMessage(), } } - -// -// Reset message and UDP address in request. -// -func (req *Request) Reset() { - req.Message.Reset() - req.UDPAddr = nil - req.Sender = nil - req.ResponseWriter = nil -} diff --git a/lib/dns/server.go b/lib/dns/server.go index d41fd6fc..c727cb89 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -49,8 +49,8 @@ type Server struct { tcp *net.TCPListener doh *http.Server - requestq chan *Request - forwardq chan *Request + requestq chan *request + forwardq chan *request hasForwarders bool } @@ -66,8 +66,8 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) { srv = &Server{ opts: opts, - requestq: make(chan *Request, 512), - forwardq: make(chan *Request, 512), + requestq: make(chan *request, 512), + forwardq: make(chan *request, 512), } udpAddr := opts.getUDPAddress() @@ -94,20 +94,20 @@ func NewServer(opts *ServerOptions) (srv *Server, err error) { // isResponseValid check if request name, type, and class match with response. // It will return true if both matched, otherwise it will return false. // -func isResponseValid(req *Request, res *Message) bool { - if !bytes.Equal(req.Message.Question.Name, res.Question.Name) { +func isResponseValid(req *request, res *Message) bool { + if !bytes.Equal(req.message.Question.Name, res.Question.Name) { log.Printf("dns: unmatched response name, got %s want %s\n", - req.Message.Question.Name, res.Question.Name) + req.message.Question.Name, res.Question.Name) return false } - if req.Message.Question.Type != res.Question.Type { + if req.message.Question.Type != res.Question.Type { log.Printf("dns: unmatched response type, got %s want %s\n", - req.Message.Question, res.Question) + req.message.Question, res.Question) return false } - if req.Message.Question.Class != res.Question.Class { + if req.message.Question.Class != res.Question.Class { log.Printf("dns: unmatched response class, got %s want %s\n", - req.Message.Question, res.Question) + req.message.Question, res.Question) return false } @@ -358,9 +358,9 @@ func (srv *Server) serveUDP() { } for { - req := NewRequest() + req := newRequest() - n, raddr, err := srv.udp.ReadFromUDP(req.Message.Packet) + n, raddr, err := srv.udp.ReadFromUDP(req.message.Packet) if err != nil { if err != io.EOF { err = fmt.Errorf("dns: error when reading from UDP: " + err.Error()) @@ -369,12 +369,12 @@ func (srv *Server) serveUDP() { return } - req.Kind = ConnTypeUDP - req.UDPAddr = raddr - req.Message.Packet = req.Message.Packet[:n] + req.kind = ConnTypeUDP + req.udpAddr = raddr + req.message.Packet = req.message.Packet[:n] - req.Message.UnpackHeaderQuestion() - req.Sender = sender + req.message.UnpackHeaderQuestion() + req.sender = sender srv.requestq <- req } @@ -437,18 +437,18 @@ func (srv *Server) handleDoHPost(w http.ResponseWriter, r *http.Request) { } func (srv *Server) handleDoHRequest(raw []byte, w http.ResponseWriter) { - req := NewRequest() + req := newRequest() - req.Kind = ConnTypeDoH - req.ResponseWriter = w - req.ChanResponded = make(chan bool, 1) + req.kind = ConnTypeDoH + req.responseWriter = w + req.chanResponded = make(chan bool, 1) - req.Message.Packet = append(req.Message.Packet[:0], raw...) - req.Message.UnpackHeaderQuestion() + req.message.Packet = append(req.message.Packet[:0], raw...) + req.message.UnpackHeaderQuestion() srv.requestq <- req - _, ok := <-req.ChanResponded + _, ok := <-req.chanResponded if !ok { w.WriteHeader(http.StatusGatewayTimeout) } @@ -460,9 +460,9 @@ func (srv *Server) serveTCPClient(cl *TCPClient) { err error ) for { - req := NewRequest() + req := newRequest() for { - n, err = cl.Recv(req.Message) + n, err = cl.Recv(req.message) if err == nil { break } @@ -471,7 +471,7 @@ func (srv *Server) serveTCPClient(cl *TCPClient) { } if n != 0 { log.Println("serveTCPClient:", err) - req.Message.Reset() + req.message.Reset() } continue } @@ -479,9 +479,9 @@ func (srv *Server) serveTCPClient(cl *TCPClient) { break } - req.Kind = ConnTypeTCP - req.Message.UnpackHeaderQuestion() - req.Sender = cl + req.kind = ConnTypeTCP + req.message.UnpackHeaderQuestion() + req.sender = cl srv.requestq <- req } @@ -502,29 +502,29 @@ func (srv *Server) processRequest() { ) for req := range srv.requestq { - ans, an := srv.caches.get(string(req.Message.Question.Name), - req.Message.Question.Type, - req.Message.Question.Class) + ans, an := srv.caches.get(string(req.message.Question.Name), + req.message.Question.Type, + req.message.Question.Class) if ans == nil { - if req.Message.Header.IsRD && srv.hasForwarders { + if req.message.Header.IsRD && srv.hasForwarders { srv.forwardq <- req continue } - req.Message.SetResponseCode(RCodeErrName) + req.message.SetResponseCode(RCodeErrName) } isLocal = false if an == nil { - if req.Message.Header.IsRD && srv.hasForwarders { + if req.message.Header.IsRD && srv.hasForwarders { srv.forwardq <- req continue } - req.Message.SetQuery(false) - req.Message.SetAuthorativeAnswer(true) - res = req.Message + req.message.SetQuery(false) + req.message.SetAuthorativeAnswer(true) + res = req.message isLocal = true } else { if an.msg.IsExpired() && srv.hasForwarders { @@ -532,7 +532,7 @@ func (srv *Server) processRequest() { continue } - an.msg.SetID(req.Message.Header.ID) + an.msg.SetID(req.message.Header.ID) res = an.msg isLocal = (an.receivedAt == 0) } @@ -541,17 +541,17 @@ func (srv *Server) processRequest() { } } -func (srv *Server) processResponse(req *Request, res *Message, isLocal bool) { +func (srv *Server) processResponse(req *request, res *Message, isLocal bool) { if !isLocal { if !isResponseValid(req, res) { return } } - switch req.Kind { + switch req.kind { case ConnTypeUDP: - if req.Sender != nil { - _, err := req.Sender.Send(res.Packet, req.UDPAddr) + if req.sender != nil { + _, err := req.sender.Send(res.Packet, req.udpAddr) if err != nil { log.Println("dns: failed to send UDP reply:", err) return @@ -559,8 +559,8 @@ func (srv *Server) processResponse(req *Request, res *Message, isLocal bool) { } case ConnTypeTCP: - if req.Sender != nil { - _, err := req.Sender.Send(res.Packet, nil) + if req.sender != nil { + _, err := req.sender.Send(res.Packet, nil) if err != nil { log.Println("dns: failed to send TCP reply:", err) return @@ -568,9 +568,9 @@ func (srv *Server) processResponse(req *Request, res *Message, isLocal bool) { } case ConnTypeDoH: - if req.ResponseWriter != nil { - _, err := req.ResponseWriter.Write(res.Packet) - req.ChanResponded <- true + if req.responseWriter != nil { + _, err := req.responseWriter.Write(res.Packet) + req.chanResponded <- true if err != nil { log.Println("dns: failed to send DoH reply:", err) return @@ -619,7 +619,7 @@ func (srv *Server) runDohForwarder(nameserver string) { } for req := range srv.forwardq { - res, err := forwarder.Query(req.Message, nil) + res, err := forwarder.Query(req.message, nil) if err != nil { log.Println("dns: failed to query DoH: " + err.Error()) continue @@ -637,7 +637,7 @@ func (srv *Server) runTCPForwarder(remoteAddr *net.TCPAddr) { continue } - res, err := cl.Query(req.Message, nil) + res, err := cl.Query(req.message, nil) cl.Close() if err != nil { log.Println("dns: failed to query TCP: " + err.Error()) @@ -655,7 +655,7 @@ func (srv *Server) runUDPForwarder(remoteAddr *net.UDPAddr) { } for req := range srv.forwardq { - res, err := forwarder.Query(req.Message, remoteAddr) + res, err := forwarder.Query(req.message, remoteAddr) if err != nil { log.Println("dns: failed to query UDP: " + err.Error()) continue |
