aboutsummaryrefslogtreecommitdiff
path: root/lib/mining/knn
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2018-09-17 05:04:26 +0700
committerShulhan <ms@kilabit.info>2018-09-18 01:50:21 +0700
commit1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e (patch)
tree5fa83fc0faa31e09cae82ac4d467cf8ba5f87fc2 /lib/mining/knn
parent446fef94cd712861221c0098dcdd9ae52aaed0eb (diff)
downloadpakakeh.go-1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e.tar.xz
Merge package "github.com/shuLhan/go-mining"
Diffstat (limited to 'lib/mining/knn')
-rw-r--r--lib/mining/knn/knn.go109
-rw-r--r--lib/mining/knn/knn_test.go62
-rw-r--r--lib/mining/knn/neighbor.go150
-rw-r--r--lib/mining/knn/neighbor_test.go84
4 files changed, 405 insertions, 0 deletions
diff --git a/lib/mining/knn/knn.go b/lib/mining/knn/knn.go
new file mode 100644
index 00000000..0fa38f9c
--- /dev/null
+++ b/lib/mining/knn/knn.go
@@ -0,0 +1,109 @@
+// Copyright 2015-2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//
+// Package knn implements the K Nearest Neighbor algorithm, using the
+// Euclidean distance to measure how far apart two samples are.
+//
+package knn
+
+import (
+ "fmt"
+ "math"
+ "sort"
+
+ "github.com/shuLhan/share/lib/debug"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+const (
+ // TEuclidianDistance used in Runtime.DistanceMethod.
+ TEuclidianDistance = 0
+)
+
+//
+// Runtime holds the parameters for KNN processing together with the
+// neighbors collected by the most recent search.
+//
+type Runtime struct {
+ // DistanceMethod selects how the distance between samples is
+ // measured; FindNeighbors only handles TEuclidianDistance.
+ DistanceMethod int
+ // ClassIndex is the index of the class attribute in a row; that
+ // attribute is skipped when the distance is computed.
+ ClassIndex int `json:"ClassIndex"`
+ // K is the maximum number of nearest neighbors returned by
+ // FindNeighbors.
+ K int `json:"K"`
+
+ // AllNeighbors contains every neighbor collected by the last distance
+ // computation, sorted by distance; FindNeighbors resets it on each
+ // call.
+ AllNeighbors Neighbors
+}
+
+//
+// ComputeEuclidianDistance compute the distance of instance with each sample in
+// dataset `samples` and return it.
+//
+func (in *Runtime) ComputeEuclidianDistance(samples *tabula.Rows,
+ instance *tabula.Row,
+) {
+ for x := range *samples {
+ row := (*samples)[x]
+
+ // compute euclidian distance
+ d := 0.0
+ for y, rec := range *row {
+ if y == in.ClassIndex {
+ // skip class attribute
+ continue
+ }
+
+ ir := (*instance)[y]
+ diff := 0.0
+
+ diff = ir.Float() - rec.Float()
+
+ d += math.Abs(diff)
+ }
+
+ // only add sample distance which is not zero (its probably
+ // we calculating with the instance itself)
+ if d != 0 {
+ in.AllNeighbors.Add(row, math.Sqrt(d))
+ }
+ }
+
+ sort.Sort(&in.AllNeighbors)
+}
+
+//
+// FindNeighbors compute the distance between `instance` and all rows in
+// `samples` and return at most K of the nearest one.
+// Neighbors from previous call are discarded; after this call AllNeighbors
+// contains all neighbors of `instance`, not only the K returned.
+//
+func (in *Runtime) FindNeighbors(samples *tabula.Rows, instance *tabula.Row) (
+	kneighbors Neighbors,
+) {
+	// Start from a clean state.
+	in.AllNeighbors = Neighbors{}
+
+	if in.DistanceMethod == TEuclidianDistance {
+		in.ComputeEuclidianDistance(samples, instance)
+	}
+
+	total := in.AllNeighbors.Len()
+
+	// Never request more neighbors than what is available.
+	limit := in.K
+	if total < limit {
+		limit = total
+	}
+
+	if debug.Value >= 2 {
+		fmt.Println("[knn] all neighbors:", total)
+	}
+
+	kneighbors = in.AllNeighbors.SelectRange(0, limit)
+
+	if debug.Value >= 2 {
+		fmt.Println("[knn] k neighbors:", kneighbors.Len())
+	}
+
+	return kneighbors
+}
diff --git a/lib/mining/knn/knn_test.go b/lib/mining/knn/knn_test.go
new file mode 100644
index 00000000..86ef75c0
--- /dev/null
+++ b/lib/mining/knn/knn_test.go
@@ -0,0 +1,62 @@
+// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package knn
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/tabula"
+ "github.com/shuLhan/share/lib/test"
+)
+
+//
+// TestComputeEuclidianDistance search the five nearest neighbors of the first
+// sample in the minority class of the phoneme dataset, and compare the
+// neighbor rows and their distances with the expected values.
+//
+func TestComputeEuclidianDistance(t *testing.T) {
+	// Only exp[1], the concatenation of the five neighbor rows, is
+	// asserted below.
+	var exp = []string{
+		`[0.302891 0.608544 0.47413 1.42718 -0.811085 1]`,
+		`[0.243474 0.505146 0.472892 1.34802 -0.844252 1]` +
+			`[0.202343 0.485983 0.527533 1.47307 -0.809672 1]` +
+			`[0.215496 0.523418 0.51719 1.43548 -0.933981 1]` +
+			`[0.214331 0.546086 0.414773 1.38542 -0.702336 1]` +
+			`[0.301676 0.554505 0.594757 1.21258 -0.873084 1]`,
+	}
+	var expDistances = "[0.5257185558832786" +
+		" 0.5690474496911485" +
+		" 0.5888777462258191" +
+		" 0.6007362149895741" +
+		" 0.672666336306493]"
+
+	// Reading data
+	dataset := tabula.Dataset{}
+	_, e := dsv.SimpleRead("../testdata/phoneme/phoneme.dsv", &dataset)
+	if e != nil {
+		// Fail loudly; a plain return would report the test as passed
+		// without checking anything.
+		t.Fatal(e)
+	}
+
+	// Processing
+	knnIn := Runtime{
+		DistanceMethod: TEuclidianDistance,
+		ClassIndex:     5,
+		K:              5,
+	}
+
+	classes := dataset.GetRows().GroupByValue(knnIn.ClassIndex)
+
+	_, minoritySet := classes.GetMinority()
+
+	kneighbors := knnIn.FindNeighbors(&minoritySet, minoritySet[0])
+
+	// Concatenate the printed form of each neighbor row.
+	var got string
+	rows := kneighbors.Rows()
+	for _, row := range *rows {
+		got += fmt.Sprint(*row)
+	}
+
+	test.Assert(t, "", exp[1], got, true)
+
+	distances := kneighbors.Distances()
+	got = fmt.Sprint(*distances)
+	test.Assert(t, "", expDistances, got, true)
+}
diff --git a/lib/mining/knn/neighbor.go b/lib/mining/knn/neighbor.go
new file mode 100644
index 00000000..e472d09f
--- /dev/null
+++ b/lib/mining/knn/neighbor.go
@@ -0,0 +1,150 @@
+// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package knn
+
+import (
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+//
+// Neighbors pairs each sample row with its distance value, kept in two
+// slices where the same index refers to the same neighbor.
+// A pointer to Neighbors implements the sort interface (Len, Less, Swap),
+// ordering the neighbors by ascending distance.
+//
+type Neighbors struct {
+ // rows contains pointers to the sample rows.
+ rows []*tabula.Row
+ // distances contains the distance value of the row at the same index
+ // in rows.
+ distances []float64
+}
+
+//
+// Rows return pointer to the internal slice of rows, without copying it.
+//
+func (neighbors *Neighbors) Rows() *[]*tabula.Row {
+	return &neighbors.rows
+}
+
+//
+// Row return the row stored at index `idx`.
+// It will panic if `idx` is out of range.
+//
+func (neighbors *Neighbors) Row(idx int) *tabula.Row {
+	return neighbors.rows[idx]
+}
+
+//
+// Distances return pointer to the internal slice of distances, where the
+// value at index i belong to the row at index i.
+//
+func (neighbors *Neighbors) Distances() *[]float64 {
+	return &neighbors.distances
+}
+
+//
+// Distance return the distance stored at index `idx`.
+// It will panic if `idx` is out of range.
+//
+func (neighbors *Neighbors) Distance(idx int) float64 {
+	return neighbors.distances[idx]
+}
+
+//
+// Len return the number of neighbors, counted from the stored distances.
+// This is part of the sort interface.
+//
+func (neighbors *Neighbors) Len() int {
+	return len(neighbors.distances)
+}
+
+//
+// Less return true if the distance at index i is smaller than the distance
+// at index j.
+// This is part of the sort interface.
+//
+func (neighbors *Neighbors) Less(i, j int) bool {
+	return neighbors.distances[i] < neighbors.distances[j]
+}
+
+//
+// Swap exchange the row and the distance at index i with the ones at index
+// j, so each row stays paired with its own distance.
+// This is part of the sort interface.
+//
+func (neighbors *Neighbors) Swap(i, j int) {
+	neighbors.rows[i], neighbors.rows[j] = neighbors.rows[j], neighbors.rows[i]
+	neighbors.distances[i], neighbors.distances[j] =
+		neighbors.distances[j], neighbors.distances[i]
+}
+
+//
+// Add append `row` and its `distance` as new neighbor at the end of the
+// list.
+//
+func (neighbors *Neighbors) Add(row *tabula.Row, distance float64) {
+	neighbors.distances = append(neighbors.distances, distance)
+	neighbors.rows = append(neighbors.rows, row)
+}
+
+//
+// SelectRange return new neighbors that contain the neighbors from index
+// `start` until, but not including, index `end`.
+// It will return an empty set if `start` is negative, if `end` is greater
+// than the number of neighbors, or if `start` is not less than `end`.
+//
+func (neighbors *Neighbors) SelectRange(start, end int) (newn Neighbors) {
+	if start < 0 || end > neighbors.Len() {
+		return newn
+	}
+
+	for ; start < end; start++ {
+		newn.Add(neighbors.rows[start], neighbors.distances[start])
+	}
+
+	return newn
+}
+
+//
+// SelectWhere return new neighbors that contain only the rows which string
+// value at record index `idx` is equal to `val`, together with their
+// distances.
+//
+func (neighbors *Neighbors) SelectWhere(idx int, val string) (newn Neighbors) {
+	for x := range neighbors.rows {
+		row := neighbors.rows[x]
+
+		if (*row)[idx].String() != val {
+			continue
+		}
+
+		newn.Add(row, neighbors.distances[x])
+	}
+
+	return newn
+}
+
+//
+// Contain search for the first neighbor which row IsEqual to `row`.
+// It will return true and the index of that neighbor if found, otherwise it
+// will return false and -1.
+//
+func (neighbors *Neighbors) Contain(row *tabula.Row) (bool, int) {
+	for x := range neighbors.rows {
+		if neighbors.rows[x].IsEqual(row) {
+			return true, x
+		}
+	}
+
+	return false, -1
+}
+
+//
+// Replace neighbor at index `idx` with new `row` and `distance` value.
+// It does nothing if `idx` is out of range, that is negative or greater
+// than or equal to the number of rows.
+//
+func (neighbors *Neighbors) Replace(idx int, row *tabula.Row, distance float64) {
+	// Valid index are 0 to len-1; checking only `idx > len` let
+	// `idx == len` and negative index through and panic below.
+	if idx < 0 || idx >= len(neighbors.rows) {
+		return
+	}
+
+	neighbors.rows[idx] = row
+	neighbors.distances[idx] = distance
+}
diff --git a/lib/mining/knn/neighbor_test.go b/lib/mining/knn/neighbor_test.go
new file mode 100644
index 00000000..047d9a32
--- /dev/null
+++ b/lib/mining/knn/neighbor_test.go
@@ -0,0 +1,84 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package knn
+
+import (
+ "math/rand"
+ "sort"
+ "testing"
+ "time"
+
+ "github.com/shuLhan/share/lib/tabula"
+ "github.com/shuLhan/share/lib/test"
+)
+
+// dataFloat64 is the fixture for the tests in this file: five samples of six
+// values each, converted into rows by createNeigbours and
+// createNeigboursByIdx.
+var dataFloat64 = [][]float64{
+ {0.243474, 0.505146, 0.472892, 1.34802, -0.844252, 1},
+ {0.202343, 0.485983, 0.527533, 1.47307, -0.809672, 1},
+ {0.215496, 0.523418, 0.517190, 1.43548, -0.933981, 1},
+ {0.214331, 0.546086, 0.414773, 1.38542, -0.702336, 1},
+ {0.301676, 0.554505, 0.594757, 1.21258, -0.873084, 1},
+}
+
+// distances contains the distance value assigned to each fixture neighbor by
+// position. TestSort also passes it as the list of sample indices to build
+// the expected rows.
+var distances = []int{4, 3, 2, 1, 0}
+
+// createNeigbours create neighbors from all samples in dataFloat64, in
+// order, where the neighbor at position x get distances[x] as its distance.
+func createNeigbours() (neighbors Neighbors) {
+	for x := range dataFloat64 {
+		row := &tabula.Row{}
+
+		for _, v := range dataFloat64[x] {
+			row.PushBack(tabula.NewRecordReal(v))
+		}
+
+		neighbors.Add(row, float64(distances[x]))
+	}
+
+	return neighbors
+}
+
+// createNeigboursByIdx create neighbors from the samples of dataFloat64
+// selected by `indices`, in that order.
+// The distance of each neighbor is taken from distances using its position
+// in `indices`, not the sample index.
+func createNeigboursByIdx(indices []int) (neighbors Neighbors) {
+	for x := range indices {
+		row := &tabula.Row{}
+
+		for _, v := range dataFloat64[indices[x]] {
+			row.PushBack(tabula.NewRecordReal(v))
+		}
+
+		neighbors.Add(row, float64(distances[x]))
+	}
+
+	return neighbors
+}
+
+// TestContain check that Contain find a clone of a randomly picked row at
+// the picked index, and no longer find it once its first record is changed.
+func TestContain(t *testing.T) {
+	rand.Seed(time.Now().UnixNano())
+
+	neighbors := createNeigbours()
+
+	// Clone one row, selected at random.
+	expIdx := rand.Intn(neighbors.Len())
+	sample := neighbors.Row(expIdx).Clone()
+
+	found, gotIdx := neighbors.Contain(sample)
+
+	test.Assert(t, "", true, found, true)
+	test.Assert(t, "", expIdx, gotIdx, true)
+
+	// Overwrite the first record, so the sample is not equal with any
+	// row in the fixture anymore.
+	(*sample)[0].SetFloat(0)
+
+	found, _ = neighbors.Contain(sample)
+
+	test.Assert(t, "", false, found, true)
+}
+
+// TestSort check that sorting the fixture neighbors order their rows the
+// same as the rows created from the indices in `distances`.
+func TestSort(t *testing.T) {
+	got := createNeigbours()
+	sort.Sort(&got)
+
+	exp := createNeigboursByIdx(distances)
+
+	test.Assert(t, "", exp.Rows(), got.Rows(), true)
+}