aboutsummaryrefslogtreecommitdiff
path: root/lib/dns
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-04-08 12:09:22 +0700
committerShulhan <ms@kilabit.info>2019-04-12 19:13:07 +0700
commitfaeaebe07acecd3fa18cc1fdb8e868108d318658 (patch)
treee09e300bc653588be538b5f13b830989a649cbc4 /lib/dns
parentce558fb263d4bdef58ebfd801331c6b074349f33 (diff)
downloadpakakeh.go-faeaebe07acecd3fa18cc1fdb8e868108d318658.tar.xz
dns: add caches for server
There are two type of answer: local and non-local. Local answer is a DNS record that is loaded from hosts file or master zone file. Non-local answer is a DNS record that is received from parent name servers. Server caches the DNS answers in two storages: map and list. The map caches store local and non local answers, using domain name as a key and list of answers as value, domain-name -> [{A,IN,...},{AAAA,IN,...}] The list caches store non-local answers, ordered by last accessed time, it is used to prune least frequently accessed answers. Local caches will never get pruned.
Diffstat (limited to 'lib/dns')
-rw-r--r--lib/dns/answer.go104
-rw-r--r--lib/dns/answer_test.go279
-rw-r--r--lib/dns/answers.go76
-rw-r--r--lib/dns/answers_test.go237
-rw-r--r--lib/dns/caches.go172
-rw-r--r--lib/dns/caches_test.go283
-rw-r--r--lib/dns/dns_test.go11
-rw-r--r--lib/dns/example_server_test.go12
-rw-r--r--lib/dns/server.go219
-rw-r--r--lib/dns/serveroptions.go22
10 files changed, 1379 insertions, 36 deletions
diff --git a/lib/dns/answer.go b/lib/dns/answer.go
new file mode 100644
index 00000000..80d1dbb8
--- /dev/null
+++ b/lib/dns/answer.go
@@ -0,0 +1,104 @@
+// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "container/list"
+ "time"
+)
+
+//
+// answer maintain the record of DNS response for cache.
+//
+type answer struct {
+ // receivedAt contains time when message is received. If answer is
+ // from local cache (host or zone file), its value is 0.
+ receivedAt int64
+
+ // accessedAt contains time when message last accessed. This field
+ // is used to prune old answer from caches.
+ accessedAt int64
+
+ // qname contains DNS question name, a copy of msg.Question.Name.
+ qname string
+ // qtype contains DNS question type, a copy of msg.Question.Type.
+ qtype uint16
+ // qclass contains DNS question class, a copy of msg.Question.Class.
+ qclass uint16
+
+ // msg contains the unpacked DNS message.
+ msg *Message
+
+ // el contains pointer to the cache in LRU.
+ el *list.Element
+}
+
+//
+// newAnswer create new answer from Message.
+// If is not local (isLocal=false), the received and accessed time will be set
+// to current timestamp.
+//
+func newAnswer(msg *Message, isLocal bool) (an *answer) {
+ if msg == nil || msg.Question == nil || len(msg.Answer) == 0 {
+ return nil
+ }
+
+ an = &answer{
+ qname: string(msg.Question.Name),
+ qtype: msg.Question.Type,
+ qclass: msg.Question.Class,
+ msg: msg,
+ }
+ if isLocal {
+ return
+ }
+ at := time.Now().Unix()
+ an.receivedAt = at
+ an.accessedAt = at
+ return
+}
+
+//
+// clear the answer fields.
+//
+func (an *answer) clear() {
+ an.msg = nil
+ an.el = nil
+}
+
+//
+// get the raw packet in the message.
+// Before the raw packet is returned, the answer accessed time will be updated
+// to current time and each resource record's TTL in message is subtracted
+// based on received time.
+//
+func (an *answer) get() (packet []byte) {
+ if an.receivedAt > 0 {
+ an.accessedAt = time.Now().Unix()
+ ttl := uint32(an.accessedAt - an.receivedAt)
+ an.msg.SubTTL(ttl)
+ }
+
+ packet = make([]byte, len(an.msg.Packet))
+ copy(packet, an.msg.Packet)
+ return
+}
+
+//
+// update the answer with new message.
+//
+func (an *answer) update(nu *answer) {
+ if nu == nil || nu.msg == nil {
+ return
+ }
+
+ if an.receivedAt > 0 {
+ an.receivedAt = nu.receivedAt
+ an.accessedAt = nu.accessedAt
+ }
+
+ an.msg = nu.msg
+ nu.msg = nil
+}
diff --git a/lib/dns/answer_test.go b/lib/dns/answer_test.go
new file mode 100644
index 00000000..03e64a9f
--- /dev/null
+++ b/lib/dns/answer_test.go
@@ -0,0 +1,279 @@
+// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "container/list"
+ "testing"
+ "time"
+
+ "github.com/shuLhan/share/lib/test"
+)
+
+func TestNewAnswer(t *testing.T) {
+ at := time.Now().Unix()
+
+ msg1 := &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ Question: &SectionQuestion{
+ Name: []byte("test"),
+ Type: 1,
+ Class: 1,
+ },
+ Answer: []*ResourceRecord{{
+ Name: []byte("test"),
+ Type: QueryTypeA,
+ Class: QueryClassIN,
+ TTL: 3600,
+ rdlen: 4,
+ Text: &RDataText{
+ Value: []byte("127.0.0.1"),
+ },
+ }},
+ }
+
+ cases := []struct {
+ desc string
+ msg *Message
+ exp *answer
+ expMsg *Message
+ expQName string
+ expQType uint16
+ expQClass uint16
+ isLocal bool
+ }{{
+ desc: "With nil msg",
+ }, {
+ desc: "With local message",
+ msg: msg1,
+ isLocal: true,
+ exp: &answer{
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ msg: msg1,
+ },
+ expQName: "test",
+ expQType: 1,
+ expQClass: 1,
+ expMsg: msg1,
+ }, {
+ desc: "With non local message",
+ msg: msg1,
+ exp: &answer{
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ msg: msg1,
+ },
+ expQName: "test",
+ expQType: 1,
+ expQClass: 1,
+ expMsg: msg1,
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ got := newAnswer(c.msg, c.isLocal)
+
+ if got == nil {
+ test.Assert(t, "newAnswer", got, c.exp, true)
+ continue
+ }
+
+ if c.isLocal {
+ test.Assert(t, "newAnswer.receivedAt", int64(0), got.receivedAt, true)
+ test.Assert(t, "newAnswer.accessedAt", int64(0), got.accessedAt, true)
+ } else {
+ test.Assert(t, "newAnswer.receivedAt", true, got.receivedAt >= at, true)
+ test.Assert(t, "newAnswer.accessedAt", true, got.accessedAt >= at, true)
+ }
+
+ test.Assert(t, "newAnswer.qname", c.expQName, got.qname, true)
+ test.Assert(t, "newAnswer.qtype", c.expQType, got.qtype, true)
+ test.Assert(t, "newAnswer.qclass", c.expQClass, got.qclass, true)
+ test.Assert(t, "newAnswer.msg", c.expMsg, got.msg, true)
+ }
+}
+
+func TestAnswerClear(t *testing.T) {
+ msg := NewMessage()
+ el := &list.Element{
+ Value: 1,
+ }
+
+ an := &answer{
+ msg: msg,
+ el: el,
+ }
+
+ an.clear()
+
+ var expMsg *Message
+ var expEl *list.Element
+
+ test.Assert(t, "answer.msg", expMsg, an.msg, true)
+ test.Assert(t, "answer.el", expEl, an.el, true)
+}
+
+func TestAnswerGet(t *testing.T) {
+ // kilabit.info A
+ res := &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ QDCount: 1,
+ ANCount: 1,
+ },
+ Question: &SectionQuestion{
+ Name: []byte("kilabit.info"),
+ Type: QueryTypeA,
+ Class: QueryClassIN,
+ },
+ Answer: []*ResourceRecord{{
+ Name: []byte("kilabit.info"),
+ Type: QueryTypeA,
+ Class: QueryClassIN,
+ TTL: 3600,
+ rdlen: 4,
+ Text: &RDataText{
+ Value: []byte("127.0.0.1"),
+ },
+ }},
+ Authority: []*ResourceRecord{},
+ Additional: []*ResourceRecord{},
+ }
+
+ _, err := res.Pack()
+ if err != nil {
+ t.Fatal("Pack: ", err)
+ }
+
+ at := time.Now().Unix()
+
+ cases := []struct {
+ desc string
+ msg *Message
+ isLocal bool
+ }{{
+ desc: "With local answer",
+ msg: res,
+ isLocal: true,
+ }, {
+ desc: "With non local answer",
+ msg: res,
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ an := newAnswer(c.msg, c.isLocal)
+
+ if !c.isLocal {
+ an.receivedAt -= 5
+ }
+
+ gotPacket := an.get()
+
+ if c.isLocal {
+ test.Assert(t, "receivedAt", int64(0), an.receivedAt, true)
+ test.Assert(t, "accessedAt", int64(0), an.accessedAt, true)
+ test.Assert(t, "packet", c.msg.Packet, gotPacket, true)
+ continue
+ }
+
+ test.Assert(t, "receivedAt", an.receivedAt >= at-5, true, true)
+ test.Assert(t, "accessedAt", an.accessedAt >= at, true, true)
+ got := &Message{
+ Header: &SectionHeader{},
+ Question: &SectionQuestion{},
+ Packet: gotPacket,
+ }
+ err := got.Unpack()
+ if err != nil {
+ t.Fatal(err)
+ }
+ test.Assert(t, "Message.Header", c.msg.Header, got.Header, true)
+ test.Assert(t, "Message.Question", c.msg.Question, got.Question, true)
+ test.Assert(t, "Answer.TTL", c.msg.Answer[0].TTL, got.Answer[0].TTL, true)
+ }
+}
+
+func TestAnswerUpdate(t *testing.T) {
+ at := time.Now().Unix() - 5
+ msg1 := &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ }
+ msg2 := &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ }
+
+ cases := []struct {
+ desc string
+ an *answer
+ nu *answer
+ expReceivedAt int64
+ expAccessedAt int64
+ expMsg *Message
+ }{{
+ desc: "With nil parameter",
+ an: &answer{
+ receivedAt: 1,
+ accessedAt: 1,
+ msg: msg1,
+ },
+ expReceivedAt: 1,
+ expAccessedAt: 1,
+ expMsg: msg1,
+ }, {
+ desc: "With local answer",
+ an: &answer{
+ receivedAt: 0,
+ accessedAt: 0,
+ msg: msg1,
+ },
+ nu: &answer{
+ receivedAt: at,
+ accessedAt: at,
+ msg: msg2,
+ },
+ expReceivedAt: 0,
+ expAccessedAt: 0,
+ expMsg: nil,
+ }, {
+ desc: "With non local answer",
+ an: &answer{
+ receivedAt: 1,
+ accessedAt: 1,
+ msg: msg1,
+ },
+ nu: &answer{
+ receivedAt: at,
+ accessedAt: at,
+ msg: msg2,
+ },
+ expReceivedAt: at,
+ expAccessedAt: at,
+ expMsg: nil,
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ c.an.update(c.nu)
+
+ test.Assert(t, "receivedAt", c.expReceivedAt, c.an.receivedAt, true)
+ test.Assert(t, "accessedAt", c.expAccessedAt, c.an.accessedAt, true)
+ if c.nu != nil {
+ test.Assert(t, "c.nu.msg", c.expMsg, c.nu.msg, true)
+ }
+ }
+}
diff --git a/lib/dns/answers.go b/lib/dns/answers.go
new file mode 100644
index 00000000..01813ec4
--- /dev/null
+++ b/lib/dns/answers.go
@@ -0,0 +1,76 @@
+// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+//
+// answers maintain list of answer.
+//
+type answers struct {
+ v []*answer
+}
+
+//
+// newAnswers create and initialize list of answer with one element.
+//
+func newAnswers(an *answer) (ans *answers) {
+ ans = &answers{
+ v: make([]*answer, 0, 1),
+ }
+ if an != nil && an.msg != nil {
+ ans.v = append(ans.v, an)
+ }
+ return
+}
+
+//
+// get an answer with specific query type and class in slice.
+// If found, it will return its element and index in slice; otherwise it will
+// return nil on answer.
+//
+func (ans *answers) get(qtype, qclass uint16) (an *answer, x int) {
+ for x = 0; x < len(ans.v); x++ {
+ if ans.v[x].qtype != qtype {
+ continue
+ }
+ if ans.v[x].qclass != qclass {
+ continue
+ }
+
+ an = ans.v[x]
+ return
+ }
+ return
+}
+
+//
+// remove the answer from list.
+//
+func (ans *answers) remove(qtype, qclass uint16) {
+ an, x := ans.get(qtype, qclass)
+ if an != nil {
+ ans.v[x] = ans.v[len(ans.v)-1]
+ ans.v[len(ans.v)-1] = nil
+ ans.v = ans.v[:len(ans.v)-1]
+ }
+}
+
+//
+// upsert update or insert new answer to list.
+// If answer is updated, it will return the old answer; otherwise new insert
+// is inserted to list and it will return nil instead.
+//
+func (ans *answers) upsert(nu *answer) (an *answer) {
+ if nu == nil || nu.msg == nil {
+ return
+ }
+ var x int
+ an, x = ans.get(nu.qtype, nu.qclass)
+ if an != nil {
+ ans.v[x].update(nu)
+ } else {
+ ans.v = append(ans.v, nu)
+ }
+ return
+}
diff --git a/lib/dns/answers_test.go b/lib/dns/answers_test.go
new file mode 100644
index 00000000..4990763c
--- /dev/null
+++ b/lib/dns/answers_test.go
@@ -0,0 +1,237 @@
+// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "testing"
+
+ "github.com/shuLhan/share/lib/test"
+)
+
+func TestNewAnswers(t *testing.T) {
+ cases := []struct {
+ desc string
+ an *answer
+ expLen int
+ expV []*answer
+ }{{
+ desc: "With nil parameter",
+ expV: make([]*answer, 0, 1),
+ }, {
+ desc: "With nil message",
+ an: &answer{},
+ expLen: 0,
+ expV: []*answer{},
+ }, {
+ desc: "With valid answer",
+ an: &answer{
+ msg: &Message{},
+ },
+ expLen: 1,
+ expV: []*answer{{
+ msg: &Message{},
+ }},
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ got := newAnswers(c.an)
+
+ test.Assert(t, "len(answers.v)", len(got.v), c.expLen, true)
+ test.Assert(t, "answers.v", got.v, c.expV, true)
+ }
+}
+
+func TestAnswersGet(t *testing.T) {
+ msg := &Message{
+ Question: &SectionQuestion{
+ Name: []byte("test"),
+ Type: 1,
+ Class: 1,
+ },
+ Answer: []*ResourceRecord{{
+ Name: []byte("test"),
+ Type: QueryTypeA,
+ Class: QueryClassIN,
+ }},
+ }
+ an := newAnswer(msg, true)
+ ans := newAnswers(an)
+
+ cases := []struct {
+ desc string
+ qtype uint16
+ qclass uint16
+ exp *answer
+ expIndex int
+ }{{
+ desc: "With query type and class not found",
+ expIndex: 1,
+ }, {
+ desc: "With query type not found",
+ qclass: 1,
+ expIndex: 1,
+ }, {
+ desc: "With query class not found",
+ qtype: 1,
+ expIndex: 1,
+ }, {
+ desc: "With valid query type and class",
+ qtype: 1,
+ qclass: 1,
+ exp: an,
+ expIndex: 0,
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ got, x := ans.get(c.qtype, c.qclass)
+
+ test.Assert(t, "answers.get", c.exp, got, true)
+ test.Assert(t, "answers.get index", c.expIndex, x, true)
+ }
+}
+
+func TestAnswersRemove(t *testing.T) {
+ msg := &Message{
+ Question: &SectionQuestion{
+ Name: []byte("test"),
+ Type: 1,
+ Class: 1,
+ },
+ Answer: []*ResourceRecord{{
+ Name: []byte("test"),
+ Type: QueryTypeA,
+ Class: QueryClassIN,
+ }},
+ }
+
+ an := newAnswer(msg, true)
+ ans := newAnswers(an)
+
+ cases := []struct {
+ desc string
+ qtype, qclass uint16
+ exp *answers
+ expLen int
+ }{{
+ desc: "With query type and class not found",
+ exp: ans,
+ expLen: 1,
+ }, {
+ desc: "With query type not found",
+ qclass: 1,
+ exp: ans,
+ expLen: 1,
+ }, {
+ desc: "With query class not found",
+ qtype: 1,
+ exp: ans,
+ expLen: 1,
+ }, {
+ desc: "With valid query type and class",
+ qtype: 1,
+ qclass: 1,
+ exp: &answers{
+ v: make([]*answer, 0, 1),
+ },
+ expLen: 0,
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ ans.remove(c.qtype, c.qclass)
+
+ test.Assert(t, "len(answers.v)", c.expLen, len(ans.v), true)
+ test.Assert(t, "cap(answers.v)", 1, cap(ans.v), true)
+ test.Assert(t, "answers", c.exp, ans, true)
+ }
+}
+
+func TestAnswersUpdate(t *testing.T) {
+ an1 := &answer{
+ qtype: 1,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ },
+ }
+ an2 := &answer{
+ qtype: 2,
+ qclass: 1,
+ msg: &Message{},
+ }
+ an3 := &answer{
+ qtype: 1,
+ qclass: 2,
+ msg: &Message{},
+ }
+ an4 := &answer{
+ qtype: 1,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 2,
+ },
+ },
+ }
+
+ ans := newAnswers(an1)
+
+ cases := []struct {
+ desc string
+ nu *answer
+ exp *answers
+ }{{
+ desc: "With nil parameter",
+ exp: ans,
+ }, {
+ desc: "With query type not found",
+ nu: an2,
+ exp: &answers{
+ v: []*answer{
+ an1,
+ an2,
+ },
+ },
+ }, {
+ desc: "With query class not found",
+ nu: an3,
+ exp: &answers{
+ v: []*answer{
+ an1,
+ an2,
+ an3,
+ },
+ },
+ }, {
+ desc: "With query found",
+ nu: an4,
+ exp: &answers{
+ v: []*answer{
+ {
+ qtype: 1,
+ qclass: 1,
+ msg: an4.msg,
+ },
+ an2,
+ an3,
+ },
+ },
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ ans.upsert(c.nu)
+
+ test.Assert(t, "answers.upsert", c.exp, ans, true)
+ }
+}
diff --git a/lib/dns/caches.go b/lib/dns/caches.go
new file mode 100644
index 00000000..1f5b9a9c
--- /dev/null
+++ b/lib/dns/caches.go
@@ -0,0 +1,172 @@
+// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "container/list"
+ "sync"
+ "time"
+)
+
+//
+// caches of DNS answers.
+//
+type caches struct {
+ sync.Mutex
+
+ // v contains mapping of DNS question name (a domain name) with their
+ // list of answer.
+ v map[string]*answers
+
+ // lru represent list of non local answers, ordered based on answer
+ // access time in ascending order.
+ lru *list.List
+
+ // pruneDelay define a delay where caches will be pruned.
+ // Default to 1 hour.
+ pruneDelay time.Duration
+
+ // pruneThreshold define negative duration where answers will be
+ // pruned from caches.
+ // Default to -1 hour.
+ pruneThreshold time.Duration
+}
+
+//
+// newCaches create new in memory caches with specific prune delay and
+// threshold.
+// The prune delay MUST be greater than 1 minute or it will set to 1 hour.
+// The prune threshold MUST be greater than -1 minute or it will be set to 1
+// hour.
+//
+func newCaches(pruneDelay, pruneThreshold time.Duration) (ca *caches) {
+ if pruneDelay.Minutes() < 1 {
+ pruneDelay = time.Hour
+ }
+ if pruneThreshold.Minutes() > -1 {
+ pruneThreshold = -1 * time.Hour
+ }
+
+ ca = &caches{
+ v: make(map[string]*answers),
+ lru: list.New(),
+ pruneDelay: pruneDelay,
+ pruneThreshold: pruneThreshold,
+ }
+
+ return
+}
+
+//
+// get an answer from cache based on domain-name, query type, and query class.
+// If answer exist on cache, their accessed time will be updated to current
+// time.
+//
+func (c *caches) get(qname string, qtype, qclass uint16) (an *answer) {
+ c.Lock()
+
+ answers, found := c.v[qname]
+ if found {
+ an, _ = answers.get(qtype, qclass)
+ if an != nil {
+ // Move the answer to the back of LRU if its not
+ // local.
+ if an.receivedAt > 0 {
+ c.lru.MoveToBack(an.el)
+ }
+ }
+ }
+
+ c.Unlock()
+ return
+}
+
+//
+// list return all answers in LRU.
+//
+func (c *caches) list() (list []*answer) {
+ c.Lock()
+ for e := c.lru.Front(); e != nil; e = e.Next() {
+ list = append(list, e.Value.(*answer))
+ }
+ c.Unlock()
+ return
+}
+
+//
+// prune will remove old answers on caches based on accessed time.
+//
+func (c *caches) prune() {
+ c.Lock()
+
+ exp := time.Now().Add(c.pruneThreshold).Unix()
+
+ e := c.lru.Front()
+ for e != nil {
+ an := e.Value.(*answer)
+ if an.accessedAt > exp {
+ break
+ }
+
+ next := e.Next()
+ _ = c.lru.Remove(e)
+ c.remove(an)
+
+ e = next
+ }
+
+ c.Unlock()
+}
+
+//
+// remove an answer from caches.
+//
+func (c *caches) remove(an *answer) {
+ answers, found := c.v[an.qname]
+ if found {
+ answers.remove(an.qtype, an.qclass)
+ }
+ an.clear()
+}
+
+//
+// upsert update or insert answer to caches. If the answer is inserted to
+// caches it will return true, otherwise when its updated it will return
+// false.
+//
+func (c *caches) upsert(nu *answer) (inserted bool) {
+ if nu == nil || nu.msg == nil {
+ return
+ }
+
+ c.Lock()
+
+ answers, found := c.v[nu.qname]
+ if !found {
+ inserted = true
+ c.v[nu.qname] = newAnswers(nu)
+ if nu.receivedAt > 0 {
+ nu.el = c.lru.PushBack(nu)
+ }
+ } else {
+ an := answers.upsert(nu)
+ if an == nil {
+ inserted = true
+ if nu.receivedAt > 0 {
+ // Push the new answer to LRU if new answer is
+ // not local and its inserted to list.
+ nu.el = c.lru.PushBack(nu)
+ }
+ } else {
+ if nu.receivedAt > 0 {
+ c.lru.MoveToBack(an.el)
+ }
+ }
+ }
+
+ c.Unlock()
+
+ return inserted
+}
diff --git a/lib/dns/caches_test.go b/lib/dns/caches_test.go
new file mode 100644
index 00000000..be038c72
--- /dev/null
+++ b/lib/dns/caches_test.go
@@ -0,0 +1,283 @@
+// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "testing"
+ "time"
+
+ "github.com/shuLhan/share/lib/test"
+)
+
+func TestNewCaches(t *testing.T) {
+ cases := []struct {
+ desc string
+ pruneDelay time.Duration
+ pruneThreshold time.Duration
+ expDelay time.Duration
+ expThreshold time.Duration
+ }{{
+ desc: "With invalid delay and threshold",
+ expDelay: time.Hour,
+ expThreshold: -time.Hour,
+ }, {
+ desc: "With 2m delay and threshold",
+ pruneDelay: 2 * time.Minute,
+ pruneThreshold: -2 * time.Minute,
+ expDelay: 2 * time.Minute,
+ expThreshold: -2 * time.Minute,
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ got := newCaches(c.pruneDelay, c.pruneThreshold)
+
+ test.Assert(t, "caches.pruneDelay", c.expDelay,
+ got.pruneDelay, true)
+ test.Assert(t, "caches.pruneThreshold", c.expThreshold,
+ got.pruneThreshold, true)
+ }
+}
+
+func TestCachesGet(t *testing.T) {
+ an1 := &answer{
+ receivedAt: 1,
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ },
+ }
+ an2 := &answer{
+ receivedAt: 2,
+ qname: "test",
+ qtype: 2,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 2,
+ },
+ },
+ }
+ an3 := &answer{
+ receivedAt: 3,
+ qname: "test",
+ qtype: 3,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 3,
+ },
+ },
+ }
+
+ ca := newCaches(0, 0)
+
+ ca.upsert(an1)
+ ca.upsert(an2)
+ ca.upsert(an3)
+
+ cases := []struct {
+ desc string
+ qname string
+ qtype uint16
+ qclass uint16
+ exp *answer
+ expList []*answer
+ }{{
+ desc: "With query not found",
+ expList: []*answer{
+ an1, an2, an3,
+ },
+ }, {
+ desc: "With query found",
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ exp: an1,
+ expList: []*answer{
+ an2, an3, an1,
+ },
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ got := ca.get(c.qname, c.qtype, c.qclass)
+ gotList := ca.list()
+
+ test.Assert(t, "caches.get", c.exp, got, true)
+ test.Assert(t, "caches.list", c.expList, gotList, true)
+ }
+}
+
+func TestCachesPrune(t *testing.T) {
+ at := time.Now().Unix()
+
+ an1 := &answer{
+ receivedAt: 1,
+ accessedAt: 1,
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ },
+ }
+ an2 := &answer{
+ receivedAt: 2,
+ accessedAt: 2,
+ qname: "test",
+ qtype: 2,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 2,
+ },
+ },
+ }
+ an3 := &answer{
+ receivedAt: at,
+ accessedAt: at,
+ qname: "test",
+ qtype: 3,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 3,
+ },
+ },
+ }
+
+ ca := newCaches(0, 0)
+
+ ca.upsert(an1)
+ ca.upsert(an2)
+ ca.upsert(an3)
+
+ t.Logf("%+v\n", ca.list())
+
+ cases := []struct {
+ desc string
+ expList []*answer
+ }{{
+ desc: "With several caches got pruned",
+ expList: []*answer{
+ an3,
+ },
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ ca.prune()
+
+ gotList := ca.list()
+
+ test.Assert(t, "caches.list", c.expList, gotList, true)
+ }
+}
+
+func TestCachesUpsert(t *testing.T) {
+ ca := newCaches(0, 0)
+
+ an1 := &answer{
+ receivedAt: 1,
+ accessedAt: 1,
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 1,
+ },
+ },
+ }
+ an1Update := &answer{
+ receivedAt: 3,
+ accessedAt: 3,
+ qname: "test",
+ qtype: 1,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 3,
+ },
+ },
+ }
+ an2 := &answer{
+ receivedAt: 2,
+ accessedAt: 2,
+ qname: "test",
+ qtype: 2,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 2,
+ },
+ },
+ }
+ an2Update := &answer{
+ receivedAt: 4,
+ accessedAt: 4,
+ qname: "test",
+ qtype: 2,
+ qclass: 1,
+ msg: &Message{
+ Header: &SectionHeader{
+ ID: 4,
+ },
+ },
+ }
+
+ cases := []struct {
+ desc string
+ nu *answer
+ expLen int
+ expList []*answer
+ }{{
+ desc: "With empty answer",
+ }, {
+ desc: "With new answer",
+ nu: an1,
+ expLen: 1,
+ expList: []*answer{an1},
+ }, {
+ desc: "With new answer, different type",
+ nu: an2,
+ expLen: 2,
+ expList: []*answer{an1, an2},
+ }, {
+ desc: "With update on answer",
+ nu: an1Update,
+ expLen: 2,
+ expList: []*answer{an2, an1},
+ }, {
+ desc: "With update on answer (2)",
+ nu: an2Update,
+ expLen: 2,
+ expList: []*answer{an1, an2},
+ }}
+
+ for _, c := range cases {
+ t.Log(c.desc)
+
+ ca.upsert(c.nu)
+
+ gotList := ca.list()
+
+ test.Assert(t, "len(caches.list)", c.expLen, len(gotList), true)
+
+ for x := 0; x < len(gotList); x++ {
+ test.Assert(t, "caches.list", c.expList[x], gotList[x], true)
+ }
+ }
+}
diff --git a/lib/dns/dns_test.go b/lib/dns/dns_test.go
index 7929e168..ae2bc0fc 100644
--- a/lib/dns/dns_test.go
+++ b/lib/dns/dns_test.go
@@ -205,16 +205,14 @@ func (h *serverHandler) ServeDNS(req *Request) {
}
func TestMain(m *testing.M) {
+ var err error
+
log.SetFlags(log.Lmicroseconds)
_testHandler = &serverHandler{}
_testHandler.responses = generateTestResponses()
- _testServer = &Server{
- Handler: _testHandler,
- }
-
serverOptions := &ServerOptions{
IPAddress: "127.0.0.1",
UDPPort: 5300,
@@ -225,11 +223,14 @@ func TestMain(m *testing.M) {
DoHAllowInsecure: true,
}
- err := _testServer.Start(serverOptions)
+ _testServer, err = NewServer(serverOptions, _testHandler)
if err != nil {
log.Fatal(err)
}
+ _testServer.Start()
+
+ // Wait for all listeners running.
time.Sleep(500 * time.Millisecond)
s := m.Run()
diff --git a/lib/dns/example_server_test.go b/lib/dns/example_server_test.go
index 5046c042..e2689f08 100644
--- a/lib/dns/example_server_test.go
+++ b/lib/dns/example_server_test.go
@@ -3,6 +3,7 @@ package dns_test
import (
"fmt"
"log"
+ "time"
"github.com/shuLhan/share/lib/dns"
)
@@ -211,10 +212,6 @@ func ExampleServer() {
handler.generateResponses()
- server := &dns.Server{
- Handler: handler,
- }
-
serverOptions := &dns.ServerOptions{
IPAddress: "127.0.0.1",
TCPPort: 5300,
@@ -225,11 +222,16 @@ func ExampleServer() {
DoHAllowInsecure: true,
}
- err := server.Start(serverOptions)
+ server, err := dns.NewServer(serverOptions, handler)
if err != nil {
log.Fatal(err)
}
+ server.Start()
+
+ // Wait for all listeners running.
+ time.Sleep(500 * time.Millisecond)
+
clientLookup(serverAddress)
server.Stop()
diff --git a/lib/dns/server.go b/lib/dns/server.go
index e20d239b..df15aae9 100644
--- a/lib/dns/server.go
+++ b/lib/dns/server.go
@@ -13,6 +13,8 @@ import (
"log"
"net"
"net/http"
+ "os"
+ "path/filepath"
"strings"
"time"
)
@@ -20,59 +22,226 @@ import (
//
// Server defines DNS server.
//
+// Caches
+//
+// There are two type of answer: local and non-local.
+// Local answer is a DNS record that is loaded from hosts file or master
+// zone file.
+// Non-local answer is a DNS record that is received from parent name
+// servers.
+//
+// Server caches the DNS answers in two storages: map and list.
+// The map caches store local and non local answers, using domain name as a
+// key and list of answers as value,
+//
+// domain-name -> [{A,IN,...},{AAAA,IN,...}]
+//
+// The list caches store non-local answers, ordered by last accessed time,
+// it is used to prune least frequently accessed answers.
+// Local caches will never get pruned.
+//
type Server struct {
- Handler Handler
- udp *net.UDPConn
- tcp *net.TCPListener
- doh *http.Server
+ Handler Handler
+
errListener chan error
+ caches *caches
+
+ udp *net.UDPConn
+ tcp *net.TCPListener
+
+ doh *http.Server
+ dohAddress string
+ dohIdleTimeout time.Duration
+ dohCert *tls.Certificate
+ dohAllowInsecure bool
}
-func (srv *Server) init(opts *ServerOptions) (err error) {
+//
+// NewServer create and initialize server using the options and a .handler.
+//
+func NewServer(opts *ServerOptions, handler Handler) (srv *Server, err error) {
err = opts.init()
if err != nil {
- return
+ return nil, err
+ }
+
+ srv = &Server{
+ Handler: handler,
+ dohAddress: opts.getDoHAddress().String(),
+ dohIdleTimeout: opts.DoHIdleTimeout,
+ dohCert: opts.cert,
+ dohAllowInsecure: opts.DoHAllowInsecure,
}
udpAddr := opts.getUDPAddress()
srv.udp, err = net.ListenUDP("udp", udpAddr)
if err != nil {
- return fmt.Errorf("dns: error listening on UDP '%v': %s",
+ return nil, fmt.Errorf("dns: error listening on UDP '%v': %s",
udpAddr, err.Error())
}
tcpAddr := opts.getTCPAddress()
srv.tcp, err = net.ListenTCP("tcp", tcpAddr)
if err != nil {
- return fmt.Errorf("dns: error listening on TCP '%v': %s",
+ return nil, fmt.Errorf("dns: error listening on TCP '%v': %s",
tcpAddr, err.Error())
}
srv.errListener = make(chan error, 1)
+ srv.caches = newCaches(opts.PruneDelay, opts.PruneThreshold)
+
+ opts.cert = nil
- return nil
+ return srv, nil
}
//
-// Start the server, listening and serve query from clients.
+// LoadHostsDir populate caches with DNS record from hosts formatted files in
+// directory "dir".
+//
+func (srv *Server) LoadHostsDir(dir string) {
+ if len(dir) == 0 {
+ return
+ }
+
+ d, err := os.Open(dir)
+ if err != nil {
+ log.Println("dns: LoadHostsDir: ", err)
+ return
+ }
+
+ fis, err := d.Readdir(0)
+ if err != nil {
+ log.Println("dns: LoadHostsDir: ", err)
+ err = d.Close()
+ if err != nil {
+ log.Println("dns: LoadHostsDir: ", err)
+ }
+ return
+ }
+
+ for x := 0; x < len(fis); x++ {
+ if fis[x].IsDir() {
+ continue
+ }
+
+ hostsFile := filepath.Join(dir, fis[x].Name())
+
+ srv.LoadHostsFile(hostsFile)
+ }
+
+ err = d.Close()
+ if err != nil {
+ log.Println("dns: LoadHostsDir: ", err)
+ }
+}
+
+//
+// LoadHostsFile populate caches with DNS record from hosts formatted file.
+//
+func (srv *Server) LoadHostsFile(path string) {
+ if len(path) == 0 {
+ fmt.Println("dns: loading system hosts file")
+ } else {
+ fmt.Printf("dns: loading hosts file '%s'\n", path)
+ }
+
+ msgs, err := HostsLoad(path)
+ if err != nil {
+ log.Println("dns: LoadHostsFile: " + err.Error())
+ }
+
+ srv.populateCaches(msgs)
+}
+
//
-func (srv *Server) Start(opts *ServerOptions) (err error) {
- err = srv.init(opts)
+// LoadMasterDir populate caches with DNS record from master (zone) formatted
+// files in directory "dir".
+//
+func (srv *Server) LoadMasterDir(dir string) {
+ if len(dir) == 0 {
+ return
+ }
+
+ d, err := os.Open(dir)
if err != nil {
+ log.Println("dns: LoadMasterDir: ", err)
return
}
- if opts.cert != nil {
- dohAddress := opts.getDoHAddress()
- go srv.serveDoH(dohAddress, opts.DoHIdleTimeout, *opts.cert,
- opts.DoHAllowInsecure)
- opts.cert = nil
+ fis, err := d.Readdir(0)
+ if err != nil {
+ log.Println("dns: LoadMasterDir: ", err)
+ err = d.Close()
+ if err != nil {
+ log.Println("dns: LoadMasterDir: ", err)
+ }
+ return
+ }
+
+ for x := 0; x < len(fis); x++ {
+ if fis[x].IsDir() {
+ continue
+ }
+
+ masterFile := filepath.Join(dir, fis[x].Name())
+
+ srv.LoadMasterFile(masterFile)
+ }
+
+ err = d.Close()
+ if err != nil {
+ log.Println("dns: LoadMasterDir: error closing directory:", err)
+ }
+}
+
+//
+// LostMasterFile populate caches with DNS record from master (zone) formatted
+// file.
+//
+func (srv *Server) LoadMasterFile(path string) {
+ fmt.Printf("dns: loading master file '%s'\n", path)
+
+ msgs, err := MasterLoad(path, "", 0)
+ if err != nil {
+ log.Println("dns: LoadMasterFile: " + err.Error())
+ }
+
+ srv.populateCaches(msgs)
+}
+
+//
+// populateCaches add list of message to caches.
+//
+func (srv *Server) populateCaches(msgs []*Message) {
+ var (
+ n int
+ inserted bool
+ isLocal = true
+ )
+
+ for x := 0; x < len(msgs); x++ {
+ an := newAnswer(msgs[x], isLocal)
+ inserted = srv.caches.upsert(an)
+ if inserted {
+ n++
+ }
+ msgs[x] = nil
+ }
+
+ fmt.Printf("dns: %d out of %d records cached\n", n, len(msgs))
+}
+
+//
+// Start the server, listening and serve query from clients.
+//
+func (srv *Server) Start() {
+ if srv.dohCert != nil {
+ go srv.serveDoH()
}
go srv.serveTCP()
go srv.serveUDP()
-
- return nil
}
//
@@ -109,17 +278,15 @@ func (srv *Server) Wait() {
// serveDoH listen for request over HTTPS using certificate and key
// file in parameter. The path to request is static "/dns-query".
//
-func (srv *Server) serveDoH(addr *net.TCPAddr, idleTimeout time.Duration,
- cert tls.Certificate, allowInsecure bool,
-) {
+func (srv *Server) serveDoH() {
srv.doh = &http.Server{
- Addr: addr.String(),
- IdleTimeout: idleTimeout,
+ Addr: srv.dohAddress,
+ IdleTimeout: srv.dohIdleTimeout,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{
- cert,
+ *srv.dohCert,
},
- InsecureSkipVerify: allowInsecure, // nolint: gosec
+ InsecureSkipVerify: srv.dohAllowInsecure, // nolint: gosec
},
}
diff --git a/lib/dns/serveroptions.go b/lib/dns/serveroptions.go
index 3284f682..9ea8ecc1 100644
--- a/lib/dns/serveroptions.go
+++ b/lib/dns/serveroptions.go
@@ -51,6 +51,22 @@ type ServerOptions struct {
// This field is optional.
DoHAllowInsecure bool
+ // PruneDelay define a delay where caches will be pruned.
+ // This field is optional, minimum value is 1 minute, and default
+ // value is 1 hour.
+ // For example, if its set to 1 hour, every 1 hour the caches will be
+ // inspected to remove answers that has not been accessed more than or
+ // equal to PruneThreshold.
+ PruneDelay time.Duration
+
+ // PruneThreshold define negative duration where answers will be
+ // pruned from caches.
+ // This field is optional, minimum value is -1 minute, and default
+ // value is -1 hour,
+ // For example, if its set to -1 minute, any answers that has not been
+ // accessed in the last 1 minute will be removed from cache.
+ PruneThreshold time.Duration
+
ip net.IP
cert *tls.Certificate
}
@@ -88,6 +104,12 @@ func (opts *ServerOptions) init() (err error) {
if opts.DoHIdleTimeout <= 0 {
opts.DoHIdleTimeout = defaultDoHIdleTimeout
}
+ if opts.PruneDelay.Minutes() < 1 {
+ opts.PruneDelay = time.Hour
+ }
+ if opts.PruneThreshold.Minutes() > -1 {
+ opts.PruneThreshold = -1 * time.Hour
+ }
return nil
}