diff options
| -rw-r--r-- | lib/dns/answer.go | 104 | ||||
| -rw-r--r-- | lib/dns/answer_test.go | 279 | ||||
| -rw-r--r-- | lib/dns/answers.go | 76 | ||||
| -rw-r--r-- | lib/dns/answers_test.go | 237 | ||||
| -rw-r--r-- | lib/dns/caches.go | 172 | ||||
| -rw-r--r-- | lib/dns/caches_test.go | 283 | ||||
| -rw-r--r-- | lib/dns/dns_test.go | 11 | ||||
| -rw-r--r-- | lib/dns/example_server_test.go | 12 | ||||
| -rw-r--r-- | lib/dns/server.go | 219 | ||||
| -rw-r--r-- | lib/dns/serveroptions.go | 22 |
10 files changed, 1379 insertions, 36 deletions
diff --git a/lib/dns/answer.go b/lib/dns/answer.go new file mode 100644 index 00000000..80d1dbb8 --- /dev/null +++ b/lib/dns/answer.go @@ -0,0 +1,104 @@ +// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "container/list" + "time" +) + +// +// answer maintain the record of DNS response for cache. +// +type answer struct { + // receivedAt contains time when message is received. If answer is + // from local cache (host or zone file), its value is 0. + receivedAt int64 + + // accessedAt contains time when message last accessed. This field + // is used to prune old answer from caches. + accessedAt int64 + + // qname contains DNS question name, a copy of msg.Question.Name. + qname string + // qtype contains DNS question type, a copy of msg.Question.Type. + qtype uint16 + // qclass contains DNS question class, a copy of msg.Question.Class. + qclass uint16 + + // msg contains the unpacked DNS message. + msg *Message + + // el contains pointer to the cache in LRU. + el *list.Element +} + +// +// newAnswer create new answer from Message. +// If is not local (isLocal=false), the received and accessed time will be set +// to current timestamp. +// +func newAnswer(msg *Message, isLocal bool) (an *answer) { + if msg == nil || msg.Question == nil || len(msg.Answer) == 0 { + return nil + } + + an = &answer{ + qname: string(msg.Question.Name), + qtype: msg.Question.Type, + qclass: msg.Question.Class, + msg: msg, + } + if isLocal { + return + } + at := time.Now().Unix() + an.receivedAt = at + an.accessedAt = at + return +} + +// +// clear the answer fields. +// +func (an *answer) clear() { + an.msg = nil + an.el = nil +} + +// +// get the raw packet in the message. +// Before the raw packet is returned, the answer accessed time will be updated +// to current time and each resource record's TTL in message is subtracted +// based on received time. +// +func (an *answer) get() (packet []byte) { + if an.receivedAt > 0 { + an.accessedAt = time.Now().Unix() + ttl := uint32(an.accessedAt - an.receivedAt) + an.msg.SubTTL(ttl) + } + + packet = make([]byte, len(an.msg.Packet)) + copy(packet, an.msg.Packet) + return +} + +// +// update the answer with new message. +// +func (an *answer) update(nu *answer) { + if nu == nil || nu.msg == nil { + return + } + + if an.receivedAt > 0 { + an.receivedAt = nu.receivedAt + an.accessedAt = nu.accessedAt + } + + an.msg = nu.msg + nu.msg = nil +} diff --git a/lib/dns/answer_test.go b/lib/dns/answer_test.go new file mode 100644 index 00000000..03e64a9f --- /dev/null +++ b/lib/dns/answer_test.go @@ -0,0 +1,279 @@ +// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "container/list" + "testing" + "time" + + "github.com/shuLhan/share/lib/test" +) + +func TestNewAnswer(t *testing.T) { + at := time.Now().Unix() + + msg1 := &Message{ + Header: &SectionHeader{ + ID: 1, + }, + Question: &SectionQuestion{ + Name: []byte("test"), + Type: 1, + Class: 1, + }, + Answer: []*ResourceRecord{{ + Name: []byte("test"), + Type: QueryTypeA, + Class: QueryClassIN, + TTL: 3600, + rdlen: 4, + Text: &RDataText{ + Value: []byte("127.0.0.1"), + }, + }}, + } + + cases := []struct { + desc string + msg *Message + exp *answer + expMsg *Message + expQName string + expQType uint16 + expQClass uint16 + isLocal bool + }{{ + desc: "With nil msg", + }, { + desc: "With local message", + msg: msg1, + isLocal: true, + exp: &answer{ + qname: "test", + qtype: 1, + qclass: 1, + msg: msg1, + }, + expQName: "test", + expQType: 1, + expQClass: 1, + expMsg: msg1, + }, { + desc: "With non local message", + msg: msg1, + exp: &answer{ + qname: "test", + qtype: 1, + qclass: 1, + msg: msg1, + }, + expQName: "test", + expQType: 1, + expQClass: 1, + expMsg: msg1, + }} + + for _, c := range cases { + t.Log(c.desc) + + got := newAnswer(c.msg, c.isLocal) + + if got == nil { + test.Assert(t, "newAnswer", got, c.exp, true) + continue + } + + if c.isLocal { + test.Assert(t, "newAnswer.receivedAt", int64(0), got.receivedAt, true) + test.Assert(t, "newAnswer.accessedAt", int64(0), got.accessedAt, true) + } else { + test.Assert(t, "newAnswer.receivedAt", true, got.receivedAt >= at, true) + test.Assert(t, "newAnswer.accessedAt", true, got.accessedAt >= at, true) + } + + test.Assert(t, "newAnswer.qname", c.expQName, got.qname, true) + test.Assert(t, "newAnswer.qtype", c.expQType, got.qtype, true) + test.Assert(t, "newAnswer.qclass", c.expQClass, got.qclass, true) + test.Assert(t, "newAnswer.msg", c.expMsg, got.msg, true) + } +} + +func TestAnswerClear(t *testing.T) { + msg := NewMessage() + el := &list.Element{ + Value: 1, + } + + an := &answer{ + msg: msg, + el: el, + } + + an.clear() + + var expMsg *Message + var expEl *list.Element + + test.Assert(t, "answer.msg", expMsg, an.msg, true) + test.Assert(t, "answer.el", expEl, an.el, true) +} + +func TestAnswerGet(t *testing.T) { + // kilabit.info A + res := &Message{ + Header: &SectionHeader{ + ID: 1, + QDCount: 1, + ANCount: 1, + }, + Question: &SectionQuestion{ + Name: []byte("kilabit.info"), + Type: QueryTypeA, + Class: QueryClassIN, + }, + Answer: []*ResourceRecord{{ + Name: []byte("kilabit.info"), + Type: QueryTypeA, + Class: QueryClassIN, + TTL: 3600, + rdlen: 4, + Text: &RDataText{ + Value: []byte("127.0.0.1"), + }, + }}, + Authority: []*ResourceRecord{}, + Additional: []*ResourceRecord{}, + } + + _, err := res.Pack() + if err != nil { + t.Fatal("Pack: ", err) + } + + at := time.Now().Unix() + + cases := []struct { + desc string + msg *Message + isLocal bool + }{{ + desc: "With local answer", + msg: res, + isLocal: true, + }, { + desc: "With non local answer", + msg: res, + }} + + for _, c := range cases { + t.Log(c.desc) + + an := newAnswer(c.msg, c.isLocal) + + if !c.isLocal { + an.receivedAt -= 5 + } + + gotPacket := an.get() + + if c.isLocal { + test.Assert(t, "receivedAt", int64(0), an.receivedAt, true) + test.Assert(t, "accessedAt", int64(0), an.accessedAt, true) + test.Assert(t, "packet", c.msg.Packet, gotPacket, true) + continue + } + + test.Assert(t, "receivedAt", an.receivedAt >= at-5, true, true) + test.Assert(t, "accessedAt", an.accessedAt >= at, true, true) + got := &Message{ + Header: &SectionHeader{}, + Question: &SectionQuestion{}, + Packet: gotPacket, + } + err := got.Unpack() + if err != nil { + t.Fatal(err) + } + test.Assert(t, "Message.Header", c.msg.Header, got.Header, true) + test.Assert(t, "Message.Question", c.msg.Question, got.Question, true) + test.Assert(t, "Answer.TTL", c.msg.Answer[0].TTL, got.Answer[0].TTL, true) + } +} + +func TestAnswerUpdate(t *testing.T) { + at := time.Now().Unix() - 5 + msg1 := &Message{ + Header: &SectionHeader{ + ID: 1, + }, + } + msg2 := &Message{ + Header: &SectionHeader{ + ID: 1, + }, + } + + cases := []struct { + desc string + an *answer + nu *answer + expReceivedAt int64 + expAccessedAt int64 + expMsg *Message + }{{ + desc: "With nil parameter", + an: &answer{ + receivedAt: 1, + accessedAt: 1, + msg: msg1, + }, + expReceivedAt: 1, + expAccessedAt: 1, + expMsg: msg1, + }, { + desc: "With local answer", + an: &answer{ + receivedAt: 0, + accessedAt: 0, + msg: msg1, + }, + nu: &answer{ + receivedAt: at, + accessedAt: at, + msg: msg2, + }, + expReceivedAt: 0, + expAccessedAt: 0, + expMsg: nil, + }, { + desc: "With non local answer", + an: &answer{ + receivedAt: 1, + accessedAt: 1, + msg: msg1, + }, + nu: &answer{ + receivedAt: at, + accessedAt: at, + msg: msg2, + }, + expReceivedAt: at, + expAccessedAt: at, + expMsg: nil, + }} + + for _, c := range cases { + t.Log(c.desc) + + c.an.update(c.nu) + + test.Assert(t, "receivedAt", c.expReceivedAt, c.an.receivedAt, true) + test.Assert(t, "accessedAt", c.expAccessedAt, c.an.accessedAt, true) + if c.nu != nil { + test.Assert(t, "c.nu.msg", c.expMsg, c.nu.msg, true) + } + } +} diff --git a/lib/dns/answers.go b/lib/dns/answers.go new file mode 100644 index 00000000..01813ec4 --- /dev/null +++ b/lib/dns/answers.go @@ -0,0 +1,76 @@ +// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +// +// answers maintain list of answer. +// +type answers struct { + v []*answer +} + +// +// newAnswers create and initialize list of answer with one element. +// +func newAnswers(an *answer) (ans *answers) { + ans = &answers{ + v: make([]*answer, 0, 1), + } + if an != nil && an.msg != nil { + ans.v = append(ans.v, an) + } + return +} + +// +// get an answer with specific query type and class in slice. +// If found, it will return its element and index in slice; otherwise it will +// return nil on answer. +// +func (ans *answers) get(qtype, qclass uint16) (an *answer, x int) { + for x = 0; x < len(ans.v); x++ { + if ans.v[x].qtype != qtype { + continue + } + if ans.v[x].qclass != qclass { + continue + } + + an = ans.v[x] + return + } + return +} + +// +// remove the answer from list. +// +func (ans *answers) remove(qtype, qclass uint16) { + an, x := ans.get(qtype, qclass) + if an != nil { + ans.v[x] = ans.v[len(ans.v)-1] + ans.v[len(ans.v)-1] = nil + ans.v = ans.v[:len(ans.v)-1] + } +} + +// +// upsert update or insert new answer to list. +// If answer is updated, it will return the old answer; otherwise new insert +// is inserted to list and it will return nil instead. +// +func (ans *answers) upsert(nu *answer) (an *answer) { + if nu == nil || nu.msg == nil { + return + } + var x int + an, x = ans.get(nu.qtype, nu.qclass) + if an != nil { + ans.v[x].update(nu) + } else { + ans.v = append(ans.v, nu) + } + return +} diff --git a/lib/dns/answers_test.go b/lib/dns/answers_test.go new file mode 100644 index 00000000..4990763c --- /dev/null +++ b/lib/dns/answers_test.go @@ -0,0 +1,237 @@ +// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "testing" + + "github.com/shuLhan/share/lib/test" +) + +func TestNewAnswers(t *testing.T) { + cases := []struct { + desc string + an *answer + expLen int + expV []*answer + }{{ + desc: "With nil parameter", + expV: make([]*answer, 0, 1), + }, { + desc: "With nil message", + an: &answer{}, + expLen: 0, + expV: []*answer{}, + }, { + desc: "With valid answer", + an: &answer{ + msg: &Message{}, + }, + expLen: 1, + expV: []*answer{{ + msg: &Message{}, + }}, + }} + + for _, c := range cases { + t.Log(c.desc) + + got := newAnswers(c.an) + + test.Assert(t, "len(answers.v)", len(got.v), c.expLen, true) + test.Assert(t, "answers.v", got.v, c.expV, true) + } +} + +func TestAnswersGet(t *testing.T) { + msg := &Message{ + Question: &SectionQuestion{ + Name: []byte("test"), + Type: 1, + Class: 1, + }, + Answer: []*ResourceRecord{{ + Name: []byte("test"), + Type: QueryTypeA, + Class: QueryClassIN, + }}, + } + an := newAnswer(msg, true) + ans := newAnswers(an) + + cases := []struct { + desc string + qtype uint16 + qclass uint16 + exp *answer + expIndex int + }{{ + desc: "With query type and class not found", + expIndex: 1, + }, { + desc: "With query type not found", + qclass: 1, + expIndex: 1, + }, { + desc: "With query class not found", + qtype: 1, + expIndex: 1, + }, { + desc: "With valid query type and class", + qtype: 1, + qclass: 1, + exp: an, + expIndex: 0, + }} + + for _, c := range cases { + t.Log(c.desc) + + got, x := ans.get(c.qtype, c.qclass) + + test.Assert(t, "answers.get", c.exp, got, true) + test.Assert(t, "answers.get index", c.expIndex, x, true) + } +} + +func TestAnswersRemove(t *testing.T) { + msg := &Message{ + Question: &SectionQuestion{ + Name: []byte("test"), + Type: 1, + Class: 1, + }, + Answer: []*ResourceRecord{{ + Name: []byte("test"), + Type: QueryTypeA, + Class: QueryClassIN, + }}, + } + + an := newAnswer(msg, true) + ans := newAnswers(an) + + cases := []struct { + desc string + qtype, qclass uint16 + exp *answers + expLen int + }{{ + desc: "With query type and class not found", + exp: ans, + expLen: 1, + }, { + desc: "With query type not found", + qclass: 1, + exp: ans, + expLen: 1, + }, { + desc: "With query class not found", + qtype: 1, + exp: ans, + expLen: 1, + }, { + desc: "With valid query type and class", + qtype: 1, + qclass: 1, + exp: &answers{ + v: make([]*answer, 0, 1), + }, + expLen: 0, + }} + + for _, c := range cases { + t.Log(c.desc) + + ans.remove(c.qtype, c.qclass) + + test.Assert(t, "len(answers.v)", c.expLen, len(ans.v), true) + test.Assert(t, "cap(answers.v)", 1, cap(ans.v), true) + test.Assert(t, "answers", c.exp, ans, true) + } +} + +func TestAnswersUpdate(t *testing.T) { + an1 := &answer{ + qtype: 1, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 1, + }, + }, + } + an2 := &answer{ + qtype: 2, + qclass: 1, + msg: &Message{}, + } + an3 := &answer{ + qtype: 1, + qclass: 2, + msg: &Message{}, + } + an4 := &answer{ + qtype: 1, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 2, + }, + }, + } + + ans := newAnswers(an1) + + cases := []struct { + desc string + nu *answer + exp *answers + }{{ + desc: "With nil parameter", + exp: ans, + }, { + desc: "With query type not found", + nu: an2, + exp: &answers{ + v: []*answer{ + an1, + an2, + }, + }, + }, { + desc: "With query class not found", + nu: an3, + exp: &answers{ + v: []*answer{ + an1, + an2, + an3, + }, + }, + }, { + desc: "With query found", + nu: an4, + exp: &answers{ + v: []*answer{ + { + qtype: 1, + qclass: 1, + msg: an4.msg, + }, + an2, + an3, + }, + }, + }} + + for _, c := range cases { + t.Log(c.desc) + + ans.upsert(c.nu) + + test.Assert(t, "answers.upsert", c.exp, ans, true) + } +} diff --git a/lib/dns/caches.go b/lib/dns/caches.go new file mode 100644 index 00000000..1f5b9a9c --- /dev/null +++ b/lib/dns/caches.go @@ -0,0 +1,172 @@ +// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "container/list" + "sync" + "time" +) + +// +// caches of DNS answers. +// +type caches struct { + sync.Mutex + + // v contains mapping of DNS question name (a domain name) with their + // list of answer. + v map[string]*answers + + // lru represent list of non local answers, ordered based on answer + // access time in ascending order. + lru *list.List + + // pruneDelay define a delay where caches will be pruned. + // Default to 1 hour. + pruneDelay time.Duration + + // pruneThreshold define negative duration where answers will be + // pruned from caches. + // Default to -1 hour. + pruneThreshold time.Duration +} + +// +// newCaches create new in memory caches with specific prune delay and +// threshold. +// The prune delay MUST be greater than 1 minute or it will set to 1 hour. +// The prune threshold MUST be greater than -1 minute or it will be set to 1 +// hour. +// +func newCaches(pruneDelay, pruneThreshold time.Duration) (ca *caches) { + if pruneDelay.Minutes() < 1 { + pruneDelay = time.Hour + } + if pruneThreshold.Minutes() > -1 { + pruneThreshold = -1 * time.Hour + } + + ca = &caches{ + v: make(map[string]*answers), + lru: list.New(), + pruneDelay: pruneDelay, + pruneThreshold: pruneThreshold, + } + + return +} + +// +// get an answer from cache based on domain-name, query type, and query class. +// If answer exist on cache, their accessed time will be updated to current +// time. +// +func (c *caches) get(qname string, qtype, qclass uint16) (an *answer) { + c.Lock() + + answers, found := c.v[qname] + if found { + an, _ = answers.get(qtype, qclass) + if an != nil { + // Move the answer to the back of LRU if its not + // local. + if an.receivedAt > 0 { + c.lru.MoveToBack(an.el) + } + } + } + + c.Unlock() + return +} + +// +// list return all answers in LRU. +// +func (c *caches) list() (list []*answer) { + c.Lock() + for e := c.lru.Front(); e != nil; e = e.Next() { + list = append(list, e.Value.(*answer)) + } + c.Unlock() + return +} + +// +// prune will remove old answers on caches based on accessed time. +// +func (c *caches) prune() { + c.Lock() + + exp := time.Now().Add(c.pruneThreshold).Unix() + + e := c.lru.Front() + for e != nil { + an := e.Value.(*answer) + if an.accessedAt > exp { + break + } + + next := e.Next() + _ = c.lru.Remove(e) + c.remove(an) + + e = next + } + + c.Unlock() +} + +// +// remove an answer from caches. +// +func (c *caches) remove(an *answer) { + answers, found := c.v[an.qname] + if found { + answers.remove(an.qtype, an.qclass) + } + an.clear() +} + +// +// upsert update or insert answer to caches. If the answer is inserted to +// caches it will return true, otherwise when its updated it will return +// false. +// +func (c *caches) upsert(nu *answer) (inserted bool) { + if nu == nil || nu.msg == nil { + return + } + + c.Lock() + + answers, found := c.v[nu.qname] + if !found { + inserted = true + c.v[nu.qname] = newAnswers(nu) + if nu.receivedAt > 0 { + nu.el = c.lru.PushBack(nu) + } + } else { + an := answers.upsert(nu) + if an == nil { + inserted = true + if nu.receivedAt > 0 { + // Push the new answer to LRU if new answer is + // not local and its inserted to list. + nu.el = c.lru.PushBack(nu) + } + } else { + if nu.receivedAt > 0 { + c.lru.MoveToBack(an.el) + } + } + } + + c.Unlock() + + return inserted +} diff --git a/lib/dns/caches_test.go b/lib/dns/caches_test.go new file mode 100644 index 00000000..be038c72 --- /dev/null +++ b/lib/dns/caches_test.go @@ -0,0 +1,283 @@ +// Copyright 2019, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "testing" + "time" + + "github.com/shuLhan/share/lib/test" +) + +func TestNewCaches(t *testing.T) { + cases := []struct { + desc string + pruneDelay time.Duration + pruneThreshold time.Duration + expDelay time.Duration + expThreshold time.Duration + }{{ + desc: "With invalid delay and threshold", + expDelay: time.Hour, + expThreshold: -time.Hour, + }, { + desc: "With 2m delay and threshold", + pruneDelay: 2 * time.Minute, + pruneThreshold: -2 * time.Minute, + expDelay: 2 * time.Minute, + expThreshold: -2 * time.Minute, + }} + + for _, c := range cases { + t.Log(c.desc) + + got := newCaches(c.pruneDelay, c.pruneThreshold) + + test.Assert(t, "caches.pruneDelay", c.expDelay, + got.pruneDelay, true) + test.Assert(t, "caches.pruneThreshold", c.expThreshold, + got.pruneThreshold, true) + } +} + +func TestCachesGet(t *testing.T) { + an1 := &answer{ + receivedAt: 1, + qname: "test", + qtype: 1, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 1, + }, + }, + } + an2 := &answer{ + receivedAt: 2, + qname: "test", + qtype: 2, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 2, + }, + }, + } + an3 := &answer{ + receivedAt: 3, + qname: "test", + qtype: 3, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 3, + }, + }, + } + + ca := newCaches(0, 0) + + ca.upsert(an1) + ca.upsert(an2) + ca.upsert(an3) + + cases := []struct { + desc string + qname string + qtype uint16 + qclass uint16 + exp *answer + expList []*answer + }{{ + desc: "With query not found", + expList: []*answer{ + an1, an2, an3, + }, + }, { + desc: "With query found", + qname: "test", + qtype: 1, + qclass: 1, + exp: an1, + expList: []*answer{ + an2, an3, an1, + }, + }} + + for _, c := range cases { + t.Log(c.desc) + + got := ca.get(c.qname, c.qtype, c.qclass) + gotList := ca.list() + + test.Assert(t, "caches.get", c.exp, got, true) + test.Assert(t, "caches.list", c.expList, gotList, true) + } +} + +func TestCachesPrune(t *testing.T) { + at := time.Now().Unix() + + an1 := &answer{ + receivedAt: 1, + accessedAt: 1, + qname: "test", + qtype: 1, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 1, + }, + }, + } + an2 := &answer{ + receivedAt: 2, + accessedAt: 2, + qname: "test", + qtype: 2, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 2, + }, + }, + } + an3 := &answer{ + receivedAt: at, + accessedAt: at, + qname: "test", + qtype: 3, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 3, + }, + }, + } + + ca := newCaches(0, 0) + + ca.upsert(an1) + ca.upsert(an2) + ca.upsert(an3) + + t.Logf("%+v\n", ca.list()) + + cases := []struct { + desc string + expList []*answer + }{{ + desc: "With several caches got pruned", + expList: []*answer{ + an3, + }, + }} + + for _, c := range cases { + t.Log(c.desc) + + ca.prune() + + gotList := ca.list() + + test.Assert(t, "caches.list", c.expList, gotList, true) + } +} + +func TestCachesUpsert(t *testing.T) { + ca := newCaches(0, 0) + + an1 := &answer{ + receivedAt: 1, + accessedAt: 1, + qname: "test", + qtype: 1, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 1, + }, + }, + } + an1Update := &answer{ + receivedAt: 3, + accessedAt: 3, + qname: "test", + qtype: 1, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 3, + }, + }, + } + an2 := &answer{ + receivedAt: 2, + accessedAt: 2, + qname: "test", + qtype: 2, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 2, + }, + }, + } + an2Update := &answer{ + receivedAt: 4, + accessedAt: 4, + qname: "test", + qtype: 2, + qclass: 1, + msg: &Message{ + Header: &SectionHeader{ + ID: 4, + }, + }, + } + + cases := []struct { + desc string + nu *answer + expLen int + expList []*answer + }{{ + desc: "With empty answer", + }, { + desc: "With new answer", + nu: an1, + expLen: 1, + expList: []*answer{an1}, + }, { + desc: "With new answer, different type", + nu: an2, + expLen: 2, + expList: []*answer{an1, an2}, + }, { + desc: "With update on answer", + nu: an1Update, + expLen: 2, + expList: []*answer{an2, an1}, + }, { + desc: "With update on answer (2)", + nu: an2Update, + expLen: 2, + expList: []*answer{an1, an2}, + }} + + for _, c := range cases { + t.Log(c.desc) + + ca.upsert(c.nu) + + gotList := ca.list() + + test.Assert(t, "len(caches.list)", c.expLen, len(gotList), true) + + for x := 0; x < len(gotList); x++ { + test.Assert(t, "caches.list", c.expList[x], gotList[x], true) + } + } +} diff --git a/lib/dns/dns_test.go b/lib/dns/dns_test.go index 7929e168..ae2bc0fc 100644 --- a/lib/dns/dns_test.go +++ b/lib/dns/dns_test.go @@ -205,16 +205,14 @@ func (h *serverHandler) ServeDNS(req *Request) { } func TestMain(m *testing.M) { + var err error + log.SetFlags(log.Lmicroseconds) _testHandler = &serverHandler{} _testHandler.responses = generateTestResponses() - _testServer = &Server{ - Handler: _testHandler, - } - serverOptions := &ServerOptions{ IPAddress: "127.0.0.1", UDPPort: 5300, @@ -225,11 +223,14 @@ func TestMain(m *testing.M) { DoHAllowInsecure: true, } - err := _testServer.Start(serverOptions) + _testServer, err = NewServer(serverOptions, _testHandler) if err != nil { log.Fatal(err) } + _testServer.Start() + + // Wait for all listeners running. time.Sleep(500 * time.Millisecond) s := m.Run() diff --git a/lib/dns/example_server_test.go b/lib/dns/example_server_test.go index 5046c042..e2689f08 100644 --- a/lib/dns/example_server_test.go +++ b/lib/dns/example_server_test.go @@ -3,6 +3,7 @@ package dns_test import ( "fmt" "log" + "time" "github.com/shuLhan/share/lib/dns" ) @@ -211,10 +212,6 @@ func ExampleServer() { handler.generateResponses() - server := &dns.Server{ - Handler: handler, - } - serverOptions := &dns.ServerOptions{ IPAddress: "127.0.0.1", TCPPort: 5300, @@ -225,11 +222,16 @@ func ExampleServer() { DoHAllowInsecure: true, } - err := server.Start(serverOptions) + server, err := dns.NewServer(serverOptions, handler) if err != nil { log.Fatal(err) } + server.Start() + + // Wait for all listeners running. + time.Sleep(500 * time.Millisecond) + clientLookup(serverAddress) server.Stop() diff --git a/lib/dns/server.go b/lib/dns/server.go index e20d239b..df15aae9 100644 --- a/lib/dns/server.go +++ b/lib/dns/server.go @@ -13,6 +13,8 @@ import ( "log" "net" "net/http" + "os" + "path/filepath" "strings" "time" ) @@ -20,59 +22,226 @@ import ( // // Server defines DNS server. // +// Caches +// +// There are two type of answer: local and non-local. +// Local answer is a DNS record that is loaded from hosts file or master +// zone file. +// Non-local answer is a DNS record that is received from parent name +// servers. +// +// Server caches the DNS answers in two storages: map and list. +// The map caches store local and non local answers, using domain name as a +// key and list of answers as value, +// +// domain-name -> [{A,IN,...},{AAAA,IN,...}] +// +// The list caches store non-local answers, ordered by last accessed time, +// it is used to prune least frequently accessed answers. +// Local caches will never get pruned. +// type Server struct { - Handler Handler - udp *net.UDPConn - tcp *net.TCPListener - doh *http.Server + Handler Handler + errListener chan error + caches *caches + + udp *net.UDPConn + tcp *net.TCPListener + + doh *http.Server + dohAddress string + dohIdleTimeout time.Duration + dohCert *tls.Certificate + dohAllowInsecure bool } -func (srv *Server) init(opts *ServerOptions) (err error) { +// +// NewServer create and initialize server using the options and a .handler. +// +func NewServer(opts *ServerOptions, handler Handler) (srv *Server, err error) { err = opts.init() if err != nil { - return + return nil, err + } + + srv = &Server{ + Handler: handler, + dohAddress: opts.getDoHAddress().String(), + dohIdleTimeout: opts.DoHIdleTimeout, + dohCert: opts.cert, + dohAllowInsecure: opts.DoHAllowInsecure, } udpAddr := opts.getUDPAddress() srv.udp, err = net.ListenUDP("udp", udpAddr) if err != nil { - return fmt.Errorf("dns: error listening on UDP '%v': %s", + return nil, fmt.Errorf("dns: error listening on UDP '%v': %s", udpAddr, err.Error()) } tcpAddr := opts.getTCPAddress() srv.tcp, err = net.ListenTCP("tcp", tcpAddr) if err != nil { - return fmt.Errorf("dns: error listening on TCP '%v': %s", + return nil, fmt.Errorf("dns: error listening on TCP '%v': %s", tcpAddr, err.Error()) } srv.errListener = make(chan error, 1) + srv.caches = newCaches(opts.PruneDelay, opts.PruneThreshold) + + opts.cert = nil - return nil + return srv, nil } // -// Start the server, listening and serve query from clients. +// LoadHostsDir populate caches with DNS record from hosts formatted files in +// directory "dir". +// +func (srv *Server) LoadHostsDir(dir string) { + if len(dir) == 0 { + return + } + + d, err := os.Open(dir) + if err != nil { + log.Println("dns: LoadHostsDir: ", err) + return + } + + fis, err := d.Readdir(0) + if err != nil { + log.Println("dns: LoadHostsDir: ", err) + err = d.Close() + if err != nil { + log.Println("dns: LoadHostsDir: ", err) + } + return + } + + for x := 0; x < len(fis); x++ { + if fis[x].IsDir() { + continue + } + + hostsFile := filepath.Join(dir, fis[x].Name()) + + srv.LoadHostsFile(hostsFile) + } + + err = d.Close() + if err != nil { + log.Println("dns: LoadHostsDir: ", err) + } +} + +// +// LoadHostsFile populate caches with DNS record from hosts formatted file. +// +func (srv *Server) LoadHostsFile(path string) { + if len(path) == 0 { + fmt.Println("dns: loading system hosts file") + } else { + fmt.Printf("dns: loading hosts file '%s'\n", path) + } + + msgs, err := HostsLoad(path) + if err != nil { + log.Println("dns: LoadHostsFile: " + err.Error()) + } + + srv.populateCaches(msgs) +} + // -func (srv *Server) Start(opts *ServerOptions) (err error) { - err = srv.init(opts) +// LoadMasterDir populate caches with DNS record from master (zone) formatted +// files in directory "dir". +// +func (srv *Server) LoadMasterDir(dir string) { + if len(dir) == 0 { + return + } + + d, err := os.Open(dir) if err != nil { + log.Println("dns: LoadMasterDir: ", err) return } - if opts.cert != nil { - dohAddress := opts.getDoHAddress() - go srv.serveDoH(dohAddress, opts.DoHIdleTimeout, *opts.cert, - opts.DoHAllowInsecure) - opts.cert = nil + fis, err := d.Readdir(0) + if err != nil { + log.Println("dns: LoadMasterDir: ", err) + err = d.Close() + if err != nil { + log.Println("dns: LoadMasterDir: ", err) + } + return + } + + for x := 0; x < len(fis); x++ { + if fis[x].IsDir() { + continue + } + + masterFile := filepath.Join(dir, fis[x].Name()) + + srv.LoadMasterFile(masterFile) + } + + err = d.Close() + if err != nil { + log.Println("dns: LoadMasterDir: error closing directory:", err) + } +} + +// +// LostMasterFile populate caches with DNS record from master (zone) formatted +// file. +// +func (srv *Server) LoadMasterFile(path string) { + fmt.Printf("dns: loading master file '%s'\n", path) + + msgs, err := MasterLoad(path, "", 0) + if err != nil { + log.Println("dns: LoadMasterFile: " + err.Error()) + } + + srv.populateCaches(msgs) +} + +// +// populateCaches add list of message to caches. +// +func (srv *Server) populateCaches(msgs []*Message) { + var ( + n int + inserted bool + isLocal = true + ) + + for x := 0; x < len(msgs); x++ { + an := newAnswer(msgs[x], isLocal) + inserted = srv.caches.upsert(an) + if inserted { + n++ + } + msgs[x] = nil + } + + fmt.Printf("dns: %d out of %d records cached\n", n, len(msgs)) +} + +// +// Start the server, listening and serve query from clients. +// +func (srv *Server) Start() { + if srv.dohCert != nil { + go srv.serveDoH() } go srv.serveTCP() go srv.serveUDP() - - return nil } // @@ -109,17 +278,15 @@ func (srv *Server) Wait() { // serveDoH listen for request over HTTPS using certificate and key // file in parameter. The path to request is static "/dns-query". // -func (srv *Server) serveDoH(addr *net.TCPAddr, idleTimeout time.Duration, - cert tls.Certificate, allowInsecure bool, -) { +func (srv *Server) serveDoH() { srv.doh = &http.Server{ - Addr: addr.String(), - IdleTimeout: idleTimeout, + Addr: srv.dohAddress, + IdleTimeout: srv.dohIdleTimeout, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{ - cert, + *srv.dohCert, }, - InsecureSkipVerify: allowInsecure, // nolint: gosec + InsecureSkipVerify: srv.dohAllowInsecure, // nolint: gosec }, } diff --git a/lib/dns/serveroptions.go b/lib/dns/serveroptions.go index 3284f682..9ea8ecc1 100644 --- a/lib/dns/serveroptions.go +++ b/lib/dns/serveroptions.go @@ -51,6 +51,22 @@ type ServerOptions struct { // This field is optional. DoHAllowInsecure bool + // PruneDelay define a delay where caches will be pruned. + // This field is optional, minimum value is 1 minute, and default + // value is 1 hour. + // For example, if its set to 1 hour, every 1 hour the caches will be + // inspected to remove answers that has not been accessed more than or + // equal to PruneThreshold. + PruneDelay time.Duration + + // PruneThreshold define negative duration where answers will be + // pruned from caches. + // This field is optional, minimum value is -1 minute, and default + // value is -1 hour, + // For example, if its set to -1 minute, any answers that has not been + // accessed in the last 1 minute will be removed from cache. + PruneThreshold time.Duration + ip net.IP cert *tls.Certificate } @@ -88,6 +104,12 @@ func (opts *ServerOptions) init() (err error) { if opts.DoHIdleTimeout <= 0 { opts.DoHIdleTimeout = defaultDoHIdleTimeout } + if opts.PruneDelay.Minutes() < 1 { + opts.PruneDelay = time.Hour + } + if opts.PruneThreshold.Minutes() > -1 { + opts.PruneThreshold = -1 * time.Hour + } return nil } |
