diff options
| author | Shulhan <ms@kilabit.info> | 2018-09-17 05:04:26 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2018-09-18 01:50:21 +0700 |
| commit | 1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e (patch) | |
| tree | 5fa83fc0faa31e09cae82ac4d467cf8ba5f87fc2 /lib/mining/classifier | |
| parent | 446fef94cd712861221c0098dcdd9ae52aaed0eb (diff) | |
| download | pakakeh.go-1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e.tar.xz | |
Merge package "github.com/shuLhan/go-mining"
Diffstat (limited to 'lib/mining/classifier')
| -rw-r--r-- | lib/mining/classifier/cart/cart.go | 480 | ||||
| -rw-r--r-- | lib/mining/classifier/cart/cart_test.go | 62 | ||||
| -rw-r--r-- | lib/mining/classifier/cart/node.go | 44 | ||||
| -rw-r--r-- | lib/mining/classifier/cm.go | 442 | ||||
| -rw-r--r-- | lib/mining/classifier/cm_test.go | 69 | ||||
| -rw-r--r-- | lib/mining/classifier/crf/crf.go | 470 | ||||
| -rw-r--r-- | lib/mining/classifier/crf/crf_test.go | 92 | ||||
| -rw-r--r-- | lib/mining/classifier/rf/rf.go | 363 | ||||
| -rw-r--r-- | lib/mining/classifier/rf/rf_bench_test.go | 22 | ||||
| -rw-r--r-- | lib/mining/classifier/rf/rf_test.go | 190 | ||||
| -rw-r--r-- | lib/mining/classifier/runtime.go | 443 | ||||
| -rw-r--r-- | lib/mining/classifier/stat.go | 176 | ||||
| -rw-r--r-- | lib/mining/classifier/stats.go | 143 | ||||
| -rw-r--r-- | lib/mining/classifier/stats_interface.go | 68 |
14 files changed, 3064 insertions(+), 0 deletions(-)
diff --git a/lib/mining/classifier/cart/cart.go b/lib/mining/classifier/cart/cart.go new file mode 100644 index 00000000..449781d9 --- /dev/null +++ b/lib/mining/classifier/cart/cart.go @@ -0,0 +1,480 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +// Package cart implement the Classification and Regression Tree by Breiman, et al. +// CART is binary decision tree. +// +// Breiman, Leo, et al. Classification and regression trees. CRC press, +// 1984. +// +// The implementation is based on Data Mining book, +// +// Han, Jiawei, Micheline Kamber, and Jian Pei. Data mining: concepts and +// techniques: concepts and techniques. Elsevier, 2011. +// +package cart + +import ( + "fmt" + + "github.com/shuLhan/share/lib/debug" + "github.com/shuLhan/share/lib/mining/gain/gini" + "github.com/shuLhan/share/lib/mining/tree/binary" + "github.com/shuLhan/share/lib/numbers" + libstrings "github.com/shuLhan/share/lib/strings" + "github.com/shuLhan/share/lib/tabula" +) + +const ( + // SplitMethodGini if defined in Runtime, the dataset will be splitted + // using Gini gain for each possible value or partition. + // + // This option is used in Runtime.SplitMethod. + SplitMethodGini = "gini" +) + +const ( + // ColFlagParent denote that the column is parent/split node. + ColFlagParent = 1 + // ColFlagSkip denote that the column would be skipped. + ColFlagSkip = 2 +) + +// +// Runtime data for building CART. +// +type Runtime struct { + // SplitMethod define the criteria to used for splitting. + SplitMethod string `json:"SplitMethod"` + // NRandomFeature if less or equal to zero compute gain on all feature, + // otherwise select n random feature and compute gain only on selected + // features. + NRandomFeature int `json:"NRandomFeature"` + // OOBErrVal is the last out-of-bag error value in the tree. + OOBErrVal float64 + // Tree in classification. 
+ Tree binary.Tree +} + +// +// New create new Runtime object. +// +func New(D tabula.ClasetInterface, splitMethod string, nRandomFeature int) ( + *Runtime, error, +) { + runtime := &Runtime{ + SplitMethod: splitMethod, + NRandomFeature: nRandomFeature, + Tree: binary.Tree{}, + } + + e := runtime.Build(D) + if e != nil { + return nil, e + } + + return runtime, nil +} + +// +// Build will create a tree using CART algorithm. +// +func (runtime *Runtime) Build(D tabula.ClasetInterface) (e error) { + // Re-check input configuration. + switch runtime.SplitMethod { + case SplitMethodGini: + // Do nothing. + default: + // Set default split method to Gini index. + runtime.SplitMethod = SplitMethodGini + } + + runtime.Tree.Root, e = runtime.splitTreeByGain(D) + + return +} + +// +// splitTreeByGain calculate the gain in all dataset, and split into two node: +// left and right. +// +// Return node with the split information. +// +func (runtime *Runtime) splitTreeByGain(D tabula.ClasetInterface) ( + node *binary.BTNode, + e error, +) { + node = &binary.BTNode{} + + D.RecountMajorMinor() + + // if dataset is empty return node labeled with majority classes in + // dataset. + nrow := D.GetNRow() + + if nrow <= 0 { + if debug.Value >= 2 { + fmt.Printf("[cart] empty dataset (%s) : %v\n", + D.MajorityClass(), D) + } + + node.Value = NodeValue{ + IsLeaf: true, + Class: D.MajorityClass(), + Size: 0, + } + return node, nil + } + + // if all dataset is in the same class, return node as leaf with class + // is set to that class. + single, name := D.IsInSingleClass() + if single { + if debug.Value >= 2 { + fmt.Printf("[cart] in single class (%s): %v\n", name, + D.GetColumns()) + } + + node.Value = NodeValue{ + IsLeaf: true, + Class: name, + Size: nrow, + } + return node, nil + } + + if debug.Value >= 2 { + fmt.Println("[cart] D:", D) + } + + // calculate the Gini gain for each attribute. + gains := runtime.computeGain(D) + + // get attribute with maximum Gini gain. 
+ MaxGainIdx := gini.FindMaxGain(&gains) + MaxGain := gains[MaxGainIdx] + + // if maxgain value is 0, use majority class as node and terminate + // the process + if MaxGain.GetMaxGainValue() == 0 { + if debug.Value >= 2 { + fmt.Println("[cart] max gain 0 with target", + D.GetClassAsStrings(), + " and majority class is ", D.MajorityClass()) + } + + node.Value = NodeValue{ + IsLeaf: true, + Class: D.MajorityClass(), + Size: 0, + } + return node, nil + } + + // using the sorted index in MaxGain, sort all field in dataset + tabula.SortColumnsByIndex(D, MaxGain.SortedIndex) + + if debug.Value >= 2 { + fmt.Println("[cart] maxgain:", MaxGain) + } + + // Now that we have attribute with max gain in MaxGainIdx, and their + // gain dan partition value in Gains[MaxGainIdx] and + // GetMaxPartValue(), we split the dataset based on type of max-gain + // attribute. + // If its continuous, split the attribute using numeric value. + // If its discrete, split the attribute using subset (partition) of + // nominal values. + var splitV interface{} + + if MaxGain.IsContinu { + splitV = MaxGain.GetMaxPartGainValue() + } else { + attrPartV := MaxGain.GetMaxPartGainValue() + attrSubV := attrPartV.(libstrings.Row) + splitV = attrSubV[0] + } + + if debug.Value >= 2 { + fmt.Println("[cart] maxgainindex:", MaxGainIdx) + fmt.Println("[cart] split v:", splitV) + } + + node.Value = NodeValue{ + SplitAttrName: D.GetColumn(MaxGainIdx).GetName(), + IsLeaf: false, + IsContinu: MaxGain.IsContinu, + Size: nrow, + SplitAttrIdx: MaxGainIdx, + SplitV: splitV, + } + + dsL, dsR, e := tabula.SplitRowsByValue(D, MaxGainIdx, splitV) + + if e != nil { + return node, e + } + + splitL := dsL.(tabula.ClasetInterface) + splitR := dsR.(tabula.ClasetInterface) + + // Set the flag to parent in attribute referenced by + // MaxGainIdx, so it will not computed again in the next round. 
+ cols := splitL.GetColumns() + for x := range *cols { + if x == MaxGainIdx { + (*cols)[x].Flag = ColFlagParent + } else { + (*cols)[x].Flag = 0 + } + } + + cols = splitR.GetColumns() + for x := range *cols { + if x == MaxGainIdx { + (*cols)[x].Flag = ColFlagParent + } else { + (*cols)[x].Flag = 0 + } + } + + nodeLeft, e := runtime.splitTreeByGain(splitL) + if e != nil { + return node, e + } + + nodeRight, e := runtime.splitTreeByGain(splitR) + if e != nil { + return node, e + } + + node.SetLeft(nodeLeft) + node.SetRight(nodeRight) + + return node, nil +} + +// SelectRandomFeature if NRandomFeature is greater than zero, select and +// compute gain in n random features instead of in all features +func (runtime *Runtime) SelectRandomFeature(D tabula.ClasetInterface) { + if runtime.NRandomFeature <= 0 { + // all features selected + return + } + + ncols := D.GetNColumn() + + // count all features minus class + nfeature := ncols - 1 + if runtime.NRandomFeature >= nfeature { + // Do nothing if number of random feature equal or greater than + // number of feature in dataset. + return + } + + // exclude class index and parent node index + excludeIdx := []int{D.GetClassIndex()} + cols := D.GetColumns() + for x, col := range *cols { + if (col.Flag & ColFlagParent) == ColFlagParent { + excludeIdx = append(excludeIdx, x) + } else { + (*cols)[x].Flag |= ColFlagSkip + } + } + + // Select random features excluding feature in `excludeIdx`. + var pickedIdx []int + for x := 0; x < runtime.NRandomFeature; x++ { + idx := numbers.IntPickRandPositive(ncols, false, pickedIdx, + excludeIdx) + pickedIdx = append(pickedIdx, idx) + + // Remove skip flag on selected column + col := D.GetColumn(idx) + col.Flag = col.Flag &^ ColFlagSkip + } + + if debug.Value >= 1 { + fmt.Println("[cart] selected random features:", pickedIdx) + fmt.Println("[cart] selected columns :", D.GetColumns()) + } +} + +// +// computeGain calculate the gini index for each value in each attribute. 
+// +func (runtime *Runtime) computeGain(D tabula.ClasetInterface) ( + gains []gini.Gini, +) { + switch runtime.SplitMethod { + case SplitMethodGini: + // create gains value for all attribute minus target class. + gains = make([]gini.Gini, D.GetNColumn()) + } + + runtime.SelectRandomFeature(D) + + classVS := D.GetClassValueSpace() + classIdx := D.GetClassIndex() + classType := D.GetClassType() + + for x, col := range *D.GetColumns() { + // skip class attribute. + if x == classIdx { + continue + } + + // skip column flagged with parent + if (col.Flag & ColFlagParent) == ColFlagParent { + gains[x].Skip = true + continue + } + + // ignore column flagged with skip + if (col.Flag & ColFlagSkip) == ColFlagSkip { + gains[x].Skip = true + continue + } + + // compute gain. + if col.GetType() == tabula.TReal { + attr := col.ToFloatSlice() + + if classType == tabula.TString { + target := D.GetClassAsStrings() + gains[x].ComputeContinu(&attr, &target, + &classVS) + } else { + targetReal := D.GetClassAsReals() + classVSReal := libstrings.ToFloat64(classVS) + + gains[x].ComputeContinuFloat(&attr, + &targetReal, &classVSReal) + } + } else { + attr := col.ToStringSlice() + attrV := col.ValueSpace + + if debug.Value >= 2 { + fmt.Println("[cart] attr :", attr) + fmt.Println("[cart] attrV:", attrV) + } + + target := D.GetClassAsStrings() + gains[x].ComputeDiscrete(&attr, &attrV, &target, + &classVS) + } + + if debug.Value >= 2 { + fmt.Println("[cart] gain :", gains[x]) + } + } + return +} + +// +// Classify return the prediction of one sample. 
+// +func (runtime *Runtime) Classify(data *tabula.Row) (class string) { + node := runtime.Tree.Root + nodev := node.Value.(NodeValue) + + for !nodev.IsLeaf { + if nodev.IsContinu { + splitV := nodev.SplitV.(float64) + attrV := (*data)[nodev.SplitAttrIdx].Float() + + if attrV < splitV { + node = node.Left + } else { + node = node.Right + } + } else { + splitV := nodev.SplitV.([]string) + attrV := (*data)[nodev.SplitAttrIdx].String() + + if libstrings.IsContain(splitV, attrV) { + node = node.Left + } else { + node = node.Right + } + } + nodev = node.Value.(NodeValue) + } + + return nodev.Class +} + +// +// ClassifySet set the class attribute based on tree classification. +// +func (runtime *Runtime) ClassifySet(data tabula.ClasetInterface) (e error) { + nrow := data.GetNRow() + targetAttr := data.GetClassColumn() + + for i := 0; i < nrow; i++ { + class := runtime.Classify(data.GetRow(i)) + + _ = (*targetAttr).Records[i].SetValue(class, tabula.TString) + } + + return +} + +// +// CountOOBError process out-of-bag data on tree and return error value. +// +func (runtime *Runtime) CountOOBError(oob tabula.Claset) ( + errval float64, + e error, +) { + // save the original target to be compared later. + origTarget := oob.GetClassAsStrings() + + if debug.Value >= 2 { + fmt.Println("[cart] OOB:", oob.Columns) + fmt.Println("[cart] TREE:", &runtime.Tree) + } + + // reset the target. + oobtarget := oob.GetClassColumn() + oobtarget.ClearValues() + + e = runtime.ClassifySet(&oob) + + if e != nil { + // set original target values back. + oobtarget.SetValues(origTarget) + return + } + + target := oobtarget.ToStringSlice() + + if debug.Value >= 2 { + fmt.Println("[cart] original target:", origTarget) + fmt.Println("[cart] classify target:", target) + } + + // count how many target value is miss-classified. + runtime.OOBErrVal, _, _ = libstrings.CountMissRate(origTarget, target) + + // set original target values back. 
+ oobtarget.SetValues(origTarget) + + return runtime.OOBErrVal, nil +} + +// +// String yes, it will print it JSON like format. +// +func (runtime *Runtime) String() (s string) { + s = fmt.Sprintf("NRandomFeature: %d\n"+ + " SplitMethod : %s\n"+ + " Tree :\n%v", runtime.NRandomFeature, + runtime.SplitMethod, + runtime.Tree.String()) + return s +} diff --git a/lib/mining/classifier/cart/cart_test.go b/lib/mining/classifier/cart/cart_test.go new file mode 100644 index 00000000..14e89b12 --- /dev/null +++ b/lib/mining/classifier/cart/cart_test.go @@ -0,0 +1,62 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cart + +import ( + "fmt" + "testing" + + "github.com/shuLhan/share/lib/dsv" + "github.com/shuLhan/share/lib/tabula" + "github.com/shuLhan/share/lib/test" +) + +const ( + NRows = 150 +) + +func TestCART(t *testing.T) { + fds := "../../testdata/iris/iris.dsv" + + ds := tabula.Claset{} + + _, e := dsv.SimpleRead(fds, &ds) + if nil != e { + t.Fatal(e) + } + + fmt.Println("[cart_test] class index:", ds.GetClassIndex()) + + // copy target to be compared later. + targetv := ds.GetClassAsStrings() + + test.Assert(t, "", NRows, ds.GetNRow(), true) + + // Build CART tree. 
+ CART, e := New(&ds, SplitMethodGini, 0) + if e != nil { + t.Fatal(e) + } + + fmt.Println("[cart_test] CART Tree:\n", CART) + + // Create test set + testset := tabula.Claset{} + _, e = dsv.SimpleRead(fds, &testset) + + if nil != e { + t.Fatal(e) + } + + testset.GetClassColumn().ClearValues() + + // Classifiy test set + e = CART.ClassifySet(&testset) + if nil != e { + t.Fatal(e) + } + + test.Assert(t, "", targetv, testset.GetClassAsStrings(), true) +} diff --git a/lib/mining/classifier/cart/node.go b/lib/mining/classifier/cart/node.go new file mode 100644 index 00000000..b64dd13c --- /dev/null +++ b/lib/mining/classifier/cart/node.go @@ -0,0 +1,44 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cart + +import ( + "fmt" + "reflect" +) + +// +// NodeValue of tree in CART. +// +type NodeValue struct { + // Class of leaf node. + Class string + // SplitAttrName define the name of attribute which cause the split. + SplitAttrName string + // IsLeaf define whether node is a leaf or not. + IsLeaf bool + // IsContinu define whether the node split is continuous or discrete. + IsContinu bool + // Size define number of sample that this node hold before splitting. + Size int + // SplitAttrIdx define the attribute which cause the split. + SplitAttrIdx int + // SplitV define the split value. + SplitV interface{} +} + +// +// String will return the value of node for printable. +// +func (nodev *NodeValue) String() (s string) { + if nodev.IsLeaf { + s = fmt.Sprintf("Class: %s", nodev.Class) + } else { + s = fmt.Sprintf("(SplitValue: %v)", + reflect.ValueOf(nodev.SplitV)) + } + + return s +} diff --git a/lib/mining/classifier/cm.go b/lib/mining/classifier/cm.go new file mode 100644 index 00000000..0dc2ee05 --- /dev/null +++ b/lib/mining/classifier/cm.go @@ -0,0 +1,442 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package classifier + +import ( + "fmt" + "strconv" + + "github.com/shuLhan/share/lib/tabula" +) + +// +// CM represent the matrix of classification. +// +type CM struct { + tabula.Dataset + // rowNames contain name in each row. + rowNames []string + // nSamples contain number of class. + nSamples int64 + // nTrue contain number of true positive and negative. + nTrue int64 + // nFalse contain number of false positive and negative. + nFalse int64 + + // tpIds contain index of true-positive samples. + tpIds []int + // fpIds contain index of false-positive samples. + fpIds []int + // tnIds contain index of true-negative samples. + tnIds []int + // fnIds contain index of false-negative samples. + fnIds []int +} + +// +// initByNumeric will initialize confusion matrix using numeric value space. +// +func (cm *CM) initByNumeric(vs []int64) { + var colTypes []int + var colNames []string + + for _, v := range vs { + vstr := strconv.FormatInt(v, 10) + colTypes = append(colTypes, tabula.TInteger) + colNames = append(colNames, vstr) + cm.rowNames = append(cm.rowNames, vstr) + } + + // class error column + colTypes = append(colTypes, tabula.TReal) + colNames = append(colNames, "class_error") + + cm.Dataset.Init(tabula.DatasetModeMatrix, colTypes, colNames) +} + +// +// ComputeStrings will calculate confusion matrix using targets and predictions +// class values. +// +func (cm *CM) ComputeStrings(valueSpace, targets, predictions []string) { + cm.init(valueSpace) + + for x, target := range valueSpace { + col := cm.GetColumn(x) + + for _, predict := range valueSpace { + cnt := cm.countTargetPrediction(target, predict, + targets, predictions) + + col.PushBack(tabula.NewRecordInt(cnt)) + } + + cm.PushColumnToRows(*col) + } + + cm.computeClassError() +} + +// +// ComputeNumeric will calculate confusion matrix using targets and predictions +// values. 
+// +func (cm *CM) ComputeNumeric(vs, actuals, predictions []int64) { + cm.initByNumeric(vs) + + for x, act := range vs { + col := cm.GetColumn(x) + + for _, pred := range vs { + cnt := cm.countNumeric(act, pred, actuals, predictions) + + rec := tabula.NewRecordInt(cnt) + col.PushBack(rec) + } + + cm.PushColumnToRows(*col) + } + + cm.computeClassError() +} + +// +// create will initialize confusion matrix using value space. +// +func (cm *CM) init(valueSpace []string) { + var colTypes []int + var colNames []string + + for _, v := range valueSpace { + colTypes = append(colTypes, tabula.TInteger) + colNames = append(colNames, v) + cm.rowNames = append(cm.rowNames, v) + } + + // class error column + colTypes = append(colTypes, tabula.TReal) + colNames = append(colNames, "class_error") + + cm.Dataset.Init(tabula.DatasetModeMatrix, colTypes, colNames) +} + +// +// countTargetPrediction will count and return number of true positive or false +// positive in predictions using targets values. +// +func (cm *CM) countTargetPrediction(target, predict string, + targets, predictions []string, +) ( + cnt int64, +) { + predictslen := len(predictions) + + for x, v := range targets { + // In case out of range, where predictions length less than + // targets length. + if x >= predictslen { + break + } + if v != target { + continue + } + if predict == predictions[x] { + cnt++ + } + } + return +} + +// +// countNumeric will count and return number of `pred` in predictions where +// actuals value is `act`. +// +func (cm *CM) countNumeric(act, pred int64, actuals, predictions []int64) ( + cnt int64, +) { + // Find minimum length to mitigate out-of-range loop. + minlen := len(actuals) + if len(predictions) < minlen { + minlen = len(predictions) + } + + for x := 0; x < minlen; x++ { + if actuals[x] != act { + continue + } + if predictions[x] != pred { + continue + } + cnt++ + } + return cnt +} + +// +// computeClassError will compute the classification error in matrix. 
+// +func (cm *CM) computeClassError() { + var tp, fp int64 + + cm.nSamples = 0 + cm.nFalse = 0 + + classcol := cm.GetNColumn() - 1 + col := cm.GetColumnClassError() + rows := cm.GetDataAsRows() + for x, row := range *rows { + for y, cell := range *row { + if y == classcol { + break + } + if x == y { + tp = cell.Integer() + } else { + fp += cell.Integer() + } + } + + nSamplePerRow := tp + fp + errv := float64(fp) / float64(nSamplePerRow) + col.PushBack(tabula.NewRecordReal(errv)) + + cm.nSamples += nSamplePerRow + cm.nTrue += tp + cm.nFalse += fp + } + + cm.PushColumnToRows(*col) +} + +// +// GroupIndexPredictions given index of samples, group the samples by their +// class of prediction. For example, +// +// sampleIds: [0, 1, 2, 3, 4, 5] +// actuals: [1, 1, 0, 0, 1, 0] +// predictions: [1, 0, 1, 0, 1, 1] +// +// This function will group the index by true-positive, false-positive, +// true-negative, and false-negative, which result in, +// +// true-positive indices: [0, 4] +// false-positive indices: [2, 5] +// true-negative indices: [3] +// false-negative indices: [1] +// +// This function assume that positive value as "1" and negative value as "0". +// +func (cm *CM) GroupIndexPredictions(sampleIds []int, + actuals, predictions []int64, +) { + // Reset indices. + cm.tpIds = nil + cm.fpIds = nil + cm.tnIds = nil + cm.fnIds = nil + + // Make sure we are not out-of-range when looping, always pick the + // minimum length between the three parameters. 
+ min := len(sampleIds) + if len(actuals) < min { + min = len(actuals) + } + if len(predictions) < min { + min = len(predictions) + } + + for x := 0; x < min; x++ { + if actuals[x] == 1 { + if predictions[x] == 1 { + cm.tpIds = append(cm.tpIds, sampleIds[x]) + } else { + cm.fnIds = append(cm.fnIds, sampleIds[x]) + } + } else { + if predictions[x] == 1 { + cm.fpIds = append(cm.fpIds, sampleIds[x]) + } else { + cm.tnIds = append(cm.tnIds, sampleIds[x]) + } + } + } +} + +// +// GroupIndexPredictionsStrings is an alternative to GroupIndexPredictions +// which work with string class. +// +func (cm *CM) GroupIndexPredictionsStrings(sampleIds []int, + actuals, predictions []string, +) { + if len(sampleIds) <= 0 { + return + } + + // Reset indices. + cm.tpIds = nil + cm.fpIds = nil + cm.tnIds = nil + cm.fnIds = nil + + // Make sure we are not out-of-range when looping, always pick the + // minimum length between the three parameters. + min := len(sampleIds) + if len(actuals) < min { + min = len(actuals) + } + if len(predictions) < min { + min = len(predictions) + } + + for x := 0; x < min; x++ { + if actuals[x] == "1" { + if predictions[x] == "1" { + cm.tpIds = append(cm.tpIds, sampleIds[x]) + } else { + cm.fnIds = append(cm.fnIds, sampleIds[x]) + } + } else { + if predictions[x] == "1" { + cm.fpIds = append(cm.fpIds, sampleIds[x]) + } else { + cm.tnIds = append(cm.tnIds, sampleIds[x]) + } + } + } +} + +// +// GetColumnClassError return the last column which is the column that contain +// the error of classification. 
+// +func (cm *CM) GetColumnClassError() *tabula.Column { + return cm.GetColumn(cm.GetNColumn() - 1) +} + +// +// GetTrueRate return true-positive rate in term of +// +// true-positive / (true-positive + false-positive) +// +func (cm *CM) GetTrueRate() float64 { + return float64(cm.nTrue) / float64(cm.nTrue+cm.nFalse) +} + +// +// GetFalseRate return false-positive rate in term of, +// +// false-positive / (false-positive + true negative) +// +func (cm *CM) GetFalseRate() float64 { + return float64(cm.nFalse) / float64(cm.nTrue+cm.nFalse) +} + +// +// TP return number of true-positive in confusion matrix. +// +func (cm *CM) TP() int { + row := cm.GetRow(0) + if row == nil { + return 0 + } + + v, _ := row.GetIntAt(0) + return int(v) +} + +// +// FP return number of false-positive in confusion matrix. +// +func (cm *CM) FP() int { + row := cm.GetRow(0) + if row == nil { + return 0 + } + + v, _ := row.GetIntAt(1) + return int(v) +} + +// +// FN return number of false-negative. +// +func (cm *CM) FN() int { + row := cm.GetRow(1) + if row == nil { + return 0 + } + v, _ := row.GetIntAt(0) + return int(v) +} + +// +// TN return number of true-negative. +// +func (cm *CM) TN() int { + row := cm.GetRow(1) + if row == nil { + return 0 + } + v, _ := row.GetIntAt(1) + return int(v) +} + +// +// TPIndices return indices of all true-positive samples. +// +func (cm *CM) TPIndices() []int { + return cm.tpIds +} + +// +// FNIndices return indices of all false-negative samples. +// +func (cm *CM) FNIndices() []int { + return cm.fnIds +} + +// +// FPIndices return indices of all false-positive samples. +// +func (cm *CM) FPIndices() []int { + return cm.fpIds +} + +// +// TNIndices return indices of all true-negative samples. +// +func (cm *CM) TNIndices() []int { + return cm.tnIds +} + +// +// String will return the output of confusion matrix in table like format. +// +func (cm *CM) String() (s string) { + s += "Confusion Matrix:\n" + + // Row header: column names. 
+ s += "\t" + for _, col := range cm.GetColumnsName() { + s += col + "\t" + } + s += "\n" + + rows := cm.GetDataAsRows() + for x, row := range *rows { + s += cm.rowNames[x] + "\t" + + for _, v := range *row { + s += v.String() + "\t" + } + s += "\n" + } + + s += fmt.Sprintf("TP-FP indices %d %d\n", len(cm.tpIds), len(cm.fpIds)) + s += fmt.Sprintf("FN-TN indices %d %d\n", len(cm.fnIds), len(cm.tnIds)) + + return +} diff --git a/lib/mining/classifier/cm_test.go b/lib/mining/classifier/cm_test.go new file mode 100644 index 00000000..6fd47d93 --- /dev/null +++ b/lib/mining/classifier/cm_test.go @@ -0,0 +1,69 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package classifier + +import ( + "fmt" + "testing" + + "github.com/shuLhan/share/lib/test" +) + +func TestComputeNumeric(t *testing.T) { + actuals := []int64{1, 1, 1, 0, 0, 0, 0} + predics := []int64{1, 1, 0, 0, 0, 0, 1} + vs := []int64{1, 0} + exp := []int{2, 1, 3, 1} + + cm := &CM{} + + cm.ComputeNumeric(vs, actuals, predics) + + test.Assert(t, "", exp[0], cm.TP(), true) + test.Assert(t, "", exp[1], cm.FN(), true) + test.Assert(t, "", exp[2], cm.TN(), true) + test.Assert(t, "", exp[3], cm.FP(), true) + + fmt.Println(cm) +} + +func TestComputeStrings(t *testing.T) { + actuals := []string{"1", "1", "1", "0", "0", "0", "0"} + predics := []string{"1", "1", "0", "0", "0", "0", "1"} + vs := []string{"1", "0"} + exp := []int{2, 1, 3, 1} + + cm := &CM{} + + cm.ComputeStrings(vs, actuals, predics) + + test.Assert(t, "", exp[0], cm.TP(), true) + test.Assert(t, "", exp[1], cm.FN(), true) + test.Assert(t, "", exp[2], cm.TN(), true) + test.Assert(t, "", exp[3], cm.FP(), true) + + fmt.Println(cm) +} + +func TestGroupIndexPredictions(t *testing.T) { + testIds := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + actuals := []int64{1, 1, 1, 1, 0, 0, 0, 0, 0, 0} + predics := []int64{1, 1, 0, 1, 0, 0, 0, 0, 1, 0} + 
exp := [][]int{ + {0, 1, 3}, // tp + {2}, // fn + {8}, // fp + {4, 5, 6, 7, 9}, // tn + } + + cm := &CM{} + + cm.GroupIndexPredictions(testIds, actuals, predics) + + test.Assert(t, "", exp[0], cm.TPIndices(), true) + test.Assert(t, "", exp[1], cm.FNIndices(), true) + test.Assert(t, "", exp[2], cm.FPIndices(), true) + test.Assert(t, "", exp[3], cm.TNIndices(), true) +} diff --git a/lib/mining/classifier/crf/crf.go b/lib/mining/classifier/crf/crf.go new file mode 100644 index 00000000..3df7dd83 --- /dev/null +++ b/lib/mining/classifier/crf/crf.go @@ -0,0 +1,470 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +// Package crf implement the cascaded random forest algorithm, proposed by +// Baumann et.al in their paper: +// +// Baumann, Florian, et al. "Cascaded Random Forest for Fast Object +// Detection." Image Analysis. Springer Berlin Heidelberg, 2013. 131-142. +// +// +package crf + +import ( + "errors" + "fmt" + "math" + "sort" + + "github.com/shuLhan/share/lib/debug" + "github.com/shuLhan/share/lib/mining/classifier" + "github.com/shuLhan/share/lib/mining/classifier/rf" + "github.com/shuLhan/share/lib/numbers" + libstrings "github.com/shuLhan/share/lib/strings" + "github.com/shuLhan/share/lib/tabula" +) + +const ( + tag = "[crf]" + + // DefStage default number of stage + DefStage = 200 + // DefTPRate default threshold for true-positive rate. + DefTPRate = 0.9 + // DefTNRate default threshold for true-negative rate. + DefTNRate = 0.7 + + // DefNumTree default number of tree. + DefNumTree = 1 + // DefPercentBoot default percentage of sample that will be used for + // bootstraping a tree. + DefPercentBoot = 66 + // DefPerfFile default performance file output. + DefPerfFile = "crf.perf" + // DefStatFile default statistic file output. + DefStatFile = "crf.stat" +) + +var ( + // ErrNoInput will tell you when no input is given. 
+ ErrNoInput = errors.New("rf: input samples is empty") +) + +// +// Runtime define the cascaded random forest runtime input and output. +// +type Runtime struct { + // Runtime embed common fields for classifier. + classifier.Runtime + + // NStage number of stage. + NStage int `json:"NStage"` + // TPRate threshold for true positive rate per stage. + TPRate float64 `json:"TPRate"` + // TNRate threshold for true negative rate per stage. + TNRate float64 `json:"TNRate"` + + // NTree number of tree in each stage. + NTree int `json:"NTree"` + // NRandomFeature number of features used to split the dataset. + NRandomFeature int `json:"NRandomFeature"` + // PercentBoot percentage of bootstrap. + PercentBoot int `json:"PercentBoot"` + + // forests contain forest for each stage. + forests []*rf.Runtime + // weights contain weight for each stage. + weights []float64 + // tnset contain sample of all true-negative in each iteration. + tnset *tabula.Claset +} + +// +// New create and return new input for cascaded random-forest. +// +func New(nstage, ntree, percentboot, nfeature int, + tprate, tnrate float64, + samples tabula.ClasetInterface, +) ( + crf *Runtime, +) { + crf = &Runtime{ + NStage: nstage, + NTree: ntree, + PercentBoot: percentboot, + NRandomFeature: nfeature, + TPRate: tprate, + TNRate: tnrate, + } + + return crf +} + +// +// AddForest will append new forest. +// +func (crf *Runtime) AddForest(forest *rf.Runtime) { + crf.forests = append(crf.forests, forest) +} + +// +// Initialize will check crf inputs and set it to default values if its +// invalid. 
+// +func (crf *Runtime) Initialize(samples tabula.ClasetInterface) error { + if crf.NStage <= 0 { + crf.NStage = DefStage + } + if crf.TPRate <= 0 || crf.TPRate >= 1 { + crf.TPRate = DefTPRate + } + if crf.TNRate <= 0 || crf.TNRate >= 1 { + crf.TNRate = DefTNRate + } + if crf.NTree <= 0 { + crf.NTree = DefNumTree + } + if crf.PercentBoot <= 0 { + crf.PercentBoot = DefPercentBoot + } + if crf.NRandomFeature <= 0 { + // Set default value to square-root of features. + ncol := samples.GetNColumn() - 1 + crf.NRandomFeature = int(math.Sqrt(float64(ncol))) + } + if crf.PerfFile == "" { + crf.PerfFile = DefPerfFile + } + if crf.StatFile == "" { + crf.StatFile = DefStatFile + } + crf.tnset = samples.Clone().(*tabula.Claset) + + return crf.Runtime.Initialize() +} + +// +// Build given a sample dataset, build the stage with randomforest. +// +func (crf *Runtime) Build(samples tabula.ClasetInterface) (e error) { + if samples == nil { + return ErrNoInput + } + + e = crf.Initialize(samples) + if e != nil { + return + } + + fmt.Println(tag, "Training samples:", samples) + fmt.Println(tag, "Sample (one row):", samples.GetRow(0)) + fmt.Println(tag, "Config:", crf) + + for x := 0; x < crf.NStage; x++ { + if debug.Value >= 1 { + fmt.Println(tag, "Stage #", x) + } + + forest, e := crf.createForest(samples) + if e != nil { + return e + } + + e = crf.finalizeStage(forest) + if e != nil { + return e + } + } + + return crf.Finalize() +} + +// +// createForest will create and return a forest and run the training `samples` +// on it. +// +// Algorithm, +// (1) Initialize forest. +// (2) For 0 to maximum number of tree in forest, +// (2.1) grow one tree until success. +// (2.2) If tree tp-rate and tn-rate greater than threshold, stop growing. +// (3) Calculate weight. +// (4) TODO: Move true-negative from samples. The collection of true-negative +// will be used again to test the model and after test and the sample with FP +// will be moved to training samples again. 
+// (5) Refill samples with false-positive. +// +func (crf *Runtime) createForest(samples tabula.ClasetInterface) ( + forest *rf.Runtime, e error, +) { + var cm *classifier.CM + var stat *classifier.Stat + + fmt.Println(tag, "Forest samples:", samples) + + // (1) + forest = &rf.Runtime{ + Runtime: classifier.Runtime{ + RunOOB: true, + }, + NTree: crf.NTree, + NRandomFeature: crf.NRandomFeature, + } + + e = forest.Initialize(samples) + if e != nil { + return nil, e + } + + // (2) + for t := 0; t < crf.NTree; t++ { + if debug.Value >= 2 { + fmt.Println(tag, "Tree #", t) + } + + // (2.1) + for { + cm, stat, e = forest.GrowTree(samples) + if e == nil { + break + } + } + + // (2.2) + if stat.TPRate > crf.TPRate && + stat.TNRate > crf.TNRate { + break + } + } + + e = forest.Finalize() + if e != nil { + return nil, e + } + + // (3) + crf.computeWeight(stat) + + if debug.Value >= 1 { + fmt.Println(tag, "Weight:", stat.FMeasure) + } + + // (4) + crf.deleteTrueNegative(samples, cm) + + // (5) + crf.runTPSet(samples) + + samples.RecountMajorMinor() + + return forest, nil +} + +// +// finalizeStage save forest and write the forest statistic to file. +// +func (crf *Runtime) finalizeStage(forest *rf.Runtime) (e error) { + stat := forest.StatTotal() + stat.ID = int64(len(crf.forests)) + + e = crf.WriteOOBStat(stat) + if e != nil { + return e + } + + crf.AddStat(stat) + crf.ComputeStatTotal(stat) + + if debug.Value >= 1 { + crf.PrintStatTotal(nil) + } + + // (7) + crf.AddForest(forest) + + return nil +} + +// +// computeWeight will compute the weight of stage based on F-measure of the +// last tree in forest. +// +func (crf *Runtime) computeWeight(stat *classifier.Stat) { + crf.weights = append(crf.weights, math.Exp(stat.FMeasure)) +} + +// +// deleteTrueNegative will delete all samples data where their row index is in +// true-negative values in confusion matrix and move it to TN-set. 
+// +// (1) Move true negative to tnset on the first iteration, on the next +// iteration it will be full deleted. +// (2) Delete TN from sample set one-by-one with offset, to make sure we +// are not deleting with wrong index. + +func (crf *Runtime) deleteTrueNegative(samples tabula.ClasetInterface, + cm *classifier.CM, +) { + var row *tabula.Row + + tnids := cm.TNIndices() + sort.Ints(tnids) + + // (1) + if len(crf.weights) <= 1 { + for _, i := range tnids { + crf.tnset.PushRow(samples.GetRow(i)) + } + } + + // (2) + c := 0 + for x, i := range tnids { + row = samples.DeleteRow(i - x) + if row != nil { + c++ + } + } + + if debug.Value >= 1 { + fmt.Println(tag, "# TN", len(tnids), "# deleted", c) + } +} + +// +// refillWithFP will copy the false-positive data in training set `tnset` +// and append it to `samples`. +// +func (crf *Runtime) refillWithFP(samples, tnset tabula.ClasetInterface, + cm *classifier.CM, +) { + // Get and sort FP. + fpids := cm.FPIndices() + sort.Ints(fpids) + + // Move FP samples from TN-set to training set samples. + for _, i := range fpids { + samples.PushRow(tnset.GetRow(i)) + } + + // Delete FP from training set. + var row *tabula.Row + c := 0 + for x, i := range fpids { + row = tnset.DeleteRow(i - x) + if row != nil { + c++ + } + } + + if debug.Value >= 1 { + fmt.Println(tag, "# FP", len(fpids), "# refilled", c) + } +} + +// +// runTPSet will run true-positive set into trained stage, to get the +// false-positive. The FP samples will be added to training set. +// +func (crf *Runtime) runTPSet(samples tabula.ClasetInterface) { + // Skip the first stage, because we just got tnset from them. + if len(crf.weights) <= 1 { + return + } + + tnIds := numbers.IntCreateSeq(0, crf.tnset.Len()-1) + _, cm, _ := crf.ClassifySetByWeight(crf.tnset, tnIds) + + crf.refillWithFP(samples, crf.tnset, cm) +} + +// +// ClassifySetByWeight will classify each instance in samples by weight +// with respect to its single performance. 
+// +// Algorithm, +// (1) For each instance in samples, +// (1.1) for each stage, +// (1.1.1) collect votes for instance in current stage. +// (1.1.2) Compute probabilities of each classes in votes. +// +// prob_class = count_of_class / total_votes +// +// (1.1.3) Compute total of probabilites times of stage weight. +// +// stage_prob = prob_class * stage_weight +// +// (1.2) Divide each class stage probabilites with +// +// stage_prob = stage_prob / +// (sum_of_all_weights * number_of_tree_in_forest) +// +// (1.3) Select class label with highest probabilites. +// (1.4) Save stage probabilities for positive class. +// (2) Compute confusion matrix. +// +func (crf *Runtime) ClassifySetByWeight(samples tabula.ClasetInterface, + sampleIds []int, +) ( + predicts []string, cm *classifier.CM, probs []float64, +) { + stat := classifier.Stat{} + stat.Start() + + vs := samples.GetClassValueSpace() + stageProbs := make([]float64, len(vs)) + stageSumProbs := make([]float64, len(vs)) + sumWeights := numbers.Floats64Sum(crf.weights) + + // (1) + rows := samples.GetDataAsRows() + for _, row := range *rows { + for y := range stageSumProbs { + stageSumProbs[y] = 0 + } + + // (1.1) + for y, forest := range crf.forests { + // (1.1.1) + votes := forest.Votes(row, -1) + + // (1.1.2) + probs := libstrings.FrequencyOfTokens(votes, vs, false) + + // (1.1.3) + for z := range probs { + stageSumProbs[z] += probs[z] + stageProbs[z] += probs[z] * crf.weights[y] + } + } + + // (1.2) + stageWeight := sumWeights * float64(crf.NTree) + + for x := range stageProbs { + stageProbs[x] = stageProbs[x] / stageWeight + } + + // (1.3) + _, maxi, ok := numbers.Floats64FindMax(stageProbs) + if ok { + predicts = append(predicts, vs[maxi]) + } + + probs = append(probs, stageSumProbs[0]/ + float64(len(crf.forests))) + } + + // (2) + actuals := samples.GetClassAsStrings() + cm = crf.ComputeCM(sampleIds, vs, actuals, predicts) + + crf.ComputeStatFromCM(&stat, cm) + stat.End() + + _ = stat.Write(crf.StatFile) + 
+ return predicts, cm, probs +} diff --git a/lib/mining/classifier/crf/crf_test.go b/lib/mining/classifier/crf/crf_test.go new file mode 100644 index 00000000..4e530591 --- /dev/null +++ b/lib/mining/classifier/crf/crf_test.go @@ -0,0 +1,92 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package crf + +import ( + "fmt" + "os" + "testing" + + "github.com/shuLhan/share/lib/dsv" + "github.com/shuLhan/share/lib/mining/classifier" + "github.com/shuLhan/share/lib/tabula" +) + +var ( + SampleFile string + PerfFile string + StatFile string + NStage = 200 + NTree = 1 +) + +func runCRF(t *testing.T) { + // read trainingset. + samples := tabula.Claset{} + _, e := dsv.SimpleRead(SampleFile, &samples) + if e != nil { + t.Fatal(e) + } + + nbag := (samples.Len() * 63) / 100 + train, test, _, testIds := tabula.RandomPickRows(&samples, nbag, false) + + trainset := train.(tabula.ClasetInterface) + testset := test.(tabula.ClasetInterface) + + crfRuntime := Runtime{ + Runtime: classifier.Runtime{ + StatFile: StatFile, + PerfFile: PerfFile, + }, + NStage: NStage, + NTree: NTree, + } + + e = crfRuntime.Build(trainset) + if e != nil { + t.Fatal(e) + } + + testset.RecountMajorMinor() + fmt.Println("Testset:", testset) + + predicts, cm, probs := crfRuntime.ClassifySetByWeight(testset, testIds) + + fmt.Println("Confusion matrix:", cm) + + crfRuntime.Performance(testset, predicts, probs) + e = crfRuntime.WritePerformance() + if e != nil { + t.Fatal(e) + } +} + +func TestPhoneme200_1(t *testing.T) { + SampleFile = "../../testdata/phoneme/phoneme.dsv" + PerfFile = "phoneme_200_1.perf" + StatFile = "phoneme_200_1.stat" + + runCRF(t) +} + +func TestPhoneme200_10(t *testing.T) { + SampleFile = "../../testdata/phoneme/phoneme.dsv" + PerfFile = "phoneme_200_10.perf" + StatFile = "phoneme_200_10.stat" + NTree = 10 + + runCRF(t) +} + +func TestMain(m *testing.M) { + 
envTestCRF := os.Getenv("TEST_CRF") + + if len(envTestCRF) == 0 { + os.Exit(0) + } + + os.Exit(m.Run()) +} diff --git a/lib/mining/classifier/rf/rf.go b/lib/mining/classifier/rf/rf.go new file mode 100644 index 00000000..86595efd --- /dev/null +++ b/lib/mining/classifier/rf/rf.go @@ -0,0 +1,363 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +// Package rf implement ensemble of classifiers using random forest +// algorithm by Breiman and Cutler. +// +// Breiman, Leo. "Random forests." Machine learning 45.1 (2001): 5-32. +// +// The implementation is based on various sources and using author experience. +// +package rf + +import ( + "errors" + "fmt" + "math" + + "github.com/shuLhan/share/lib/debug" + "github.com/shuLhan/share/lib/mining/classifier" + "github.com/shuLhan/share/lib/mining/classifier/cart" + "github.com/shuLhan/share/lib/numbers" + libstrings "github.com/shuLhan/share/lib/strings" + "github.com/shuLhan/share/lib/tabula" +) + +const ( + tag = "[rf]" + + // DefNumTree default number of tree. + DefNumTree = 100 + + // DefPercentBoot default percentage of sample that will be used for + // bootstraping a tree. + DefPercentBoot = 66 + + // DefOOBStatsFile default statistic file output. + DefOOBStatsFile = "rf.oob.stat" + + // DefPerfFile default performance file output. + DefPerfFile = "rf.perf" + + // DefStatFile default statistic file. + DefStatFile = "rf.stat" +) + +var ( + // ErrNoInput will tell you when no input is given. + ErrNoInput = errors.New("rf: input samples is empty") +) + +// +// Runtime contains input and output configuration when generating random forest. +// +type Runtime struct { + // Runtime embed common fields for classifier. + classifier.Runtime + + // NTree number of tree in forest. + NTree int `json:"NTree"` + // NRandomFeature number of feature randomly selected for each tree. 
+ NRandomFeature int `json:"NRandomFeature"` + // PercentBoot percentage of sample for bootstraping. + PercentBoot int `json:"PercentBoot"` + + // nSubsample number of samples used for bootstraping. + nSubsample int + // trees contain all tree in the forest. + trees []cart.Runtime + // bagIndices contain list of index of selected samples at bootstraping + // for book-keeping. + bagIndices [][]int +} + +// +// Trees return all tree in forest. +// +func (forest *Runtime) Trees() []cart.Runtime { + return forest.trees +} + +// +// AddCartTree add tree to forest +// +func (forest *Runtime) AddCartTree(tree cart.Runtime) { + forest.trees = append(forest.trees, tree) +} + +// +// AddBagIndex add bagging index for book keeping. +// +func (forest *Runtime) AddBagIndex(bagIndex []int) { + forest.bagIndices = append(forest.bagIndices, bagIndex) +} + +// +// Initialize will check forest inputs and set it to default values if invalid. +// +// It will also calculate number of random samples for each tree using, +// +// number-of-sample * percentage-of-bootstrap +// +// +func (forest *Runtime) Initialize(samples tabula.ClasetInterface) error { + if forest.NTree <= 0 { + forest.NTree = DefNumTree + } + if forest.PercentBoot <= 0 { + forest.PercentBoot = DefPercentBoot + } + if forest.NRandomFeature <= 0 { + // Set default value to square-root of features. + ncol := samples.GetNColumn() - 1 + forest.NRandomFeature = int(math.Sqrt(float64(ncol))) + } + if forest.OOBStatsFile == "" { + forest.OOBStatsFile = DefOOBStatsFile + } + if forest.PerfFile == "" { + forest.PerfFile = DefPerfFile + } + if forest.StatFile == "" { + forest.StatFile = DefStatFile + } + + forest.nSubsample = int(float32(samples.GetNRow()) * + (float32(forest.PercentBoot) / 100.0)) + + return forest.Runtime.Initialize() +} + +// +// Build the forest using samples dataset. +// +// Algorithm, +// +// (0) Recheck input value: number of tree, percentage bootstrap, etc; and +// Open statistic file output. 
+// (1) For 0 to NTree, +// (1.1) Create new tree, repeat until all trees has been build. +// (2) Compute and write total statistic. +// +func (forest *Runtime) Build(samples tabula.ClasetInterface) (e error) { + // check input samples + if samples == nil { + return ErrNoInput + } + + // (0) + e = forest.Initialize(samples) + if e != nil { + return + } + + fmt.Println(tag, "Training set :", samples) + fmt.Println(tag, "Sample (one row):", samples.GetRow(0)) + fmt.Println(tag, "Forest config :", forest) + + // (1) + for t := 0; t < forest.NTree; t++ { + if debug.Value >= 1 { + fmt.Println(tag, "tree #", t) + } + + // (1.1) + for { + _, _, e = forest.GrowTree(samples) + if e == nil { + break + } + + fmt.Println(tag, "error:", e) + } + } + + // (2) + return forest.Finalize() +} + +// +// GrowTree build a new tree in forest, return OOB error value or error if tree +// can not grow. +// +// Algorithm, +// +// (1) Select random samples with replacement, also with OOB. +// (2) Build tree using CART, without pruning. +// (3) Add tree to forest. +// (4) Save index of random samples for calculating error rate later. +// (5) Run OOB on forest. +// (6) Calculate OOB error rate and statistic values. 
+// +func (forest *Runtime) GrowTree(samples tabula.ClasetInterface) ( + cm *classifier.CM, stat *classifier.Stat, e error, +) { + stat = &classifier.Stat{} + stat.ID = int64(len(forest.trees)) + stat.Start() + + // (1) + bag, oob, bagIdx, oobIdx := tabula.RandomPickRows( + samples.(tabula.DatasetInterface), + forest.nSubsample, true) + + bagset := bag.(tabula.ClasetInterface) + + if debug.Value >= 2 { + bagset.RecountMajorMinor() + fmt.Println(tag, "Bagging:", bagset) + } + + // (2) + cart, e := cart.New(bagset, cart.SplitMethodGini, + forest.NRandomFeature) + if e != nil { + return nil, nil, e + } + + // (3) + forest.AddCartTree(*cart) + + // (4) + forest.AddBagIndex(bagIdx) + + // (5) + if forest.RunOOB { + oobset := oob.(tabula.ClasetInterface) + _, cm, _ = forest.ClassifySet(oobset, oobIdx) + + forest.AddOOBCM(cm) + } + + stat.End() + + if debug.Value >= 3 && forest.RunOOB { + fmt.Println(tag, "Elapsed time (s):", stat.ElapsedTime) + } + + forest.AddStat(stat) + + // (6) + if forest.RunOOB { + forest.ComputeStatFromCM(stat, cm) + + if debug.Value >= 2 { + fmt.Println(tag, "OOB stat:", stat) + } + } + + forest.ComputeStatTotal(stat) + e = forest.WriteOOBStat(stat) + + return cm, stat, e +} + +// +// ClassifySet given a samples predict their class by running each sample in +// forest, adn return their class prediction with confusion matrix. +// `samples` is the sample that will be predicted, `sampleIds` is the index of +// samples. +// If `sampleIds` is not nil, then sample index will be checked in each tree, +// if the sample is used for training, their vote is not counted. +// +// Algorithm, +// +// (0) Get value space (possible class values in dataset) +// (1) For each row in test-set, +// (1.1) collect votes in all trees, +// (1.2) select majority class vote, and +// (1.3) compute and save the actual class probabilities. +// (2) Compute confusion matrix from predictions. +// (3) Compute stat from confusion matrix. 
+// (4) Write the stat to file only if sampleIds is empty, which mean its run +// not from OOB set. +// +func (forest *Runtime) ClassifySet(samples tabula.ClasetInterface, + sampleIds []int, +) ( + predicts []string, cm *classifier.CM, probs []float64, +) { + stat := classifier.Stat{} + stat.Start() + + if len(sampleIds) <= 0 { + fmt.Println(tag, "Classify set:", samples) + fmt.Println(tag, "Classify set sample (one row):", + samples.GetRow(0)) + } + + // (0) + vs := samples.GetClassValueSpace() + actuals := samples.GetClassAsStrings() + sampleIdx := -1 + + // (1) + rows := samples.GetRows() + for x, row := range *rows { + // (1.1) + if len(sampleIds) > 0 { + sampleIdx = sampleIds[x] + } + votes := forest.Votes(row, sampleIdx) + + // (1.2) + classProbs := libstrings.FrequencyOfTokens(votes, vs, false) + + _, idx, ok := numbers.Floats64FindMax(classProbs) + + if ok { + predicts = append(predicts, vs[idx]) + } + + // (1.3) + probs = append(probs, classProbs[0]) + } + + // (2) + cm = forest.ComputeCM(sampleIds, vs, actuals, predicts) + + // (3) + forest.ComputeStatFromCM(&stat, cm) + stat.End() + + if len(sampleIds) <= 0 { + fmt.Println(tag, "CM:", cm) + fmt.Println(tag, "Classifying stat:", stat) + _ = stat.Write(forest.StatFile) + } + + return predicts, cm, probs +} + +// +// Votes will return votes, or classes, in each tree based on sample. +// If checkIdx is true then the `sampleIdx` will be checked in if it has been used +// when training the tree, if its exist then the sample will be skipped. +// +// (1) If row is used to build the tree then skip it, +// (2) classify row in tree, +// (3) save tree class value. 
+// +func (forest *Runtime) Votes(sample *tabula.Row, sampleIdx int) ( + votes []string, +) { + for x, tree := range forest.trees { + // (1) + if sampleIdx >= 0 { + exist := numbers.IntsIsExist(forest.bagIndices[x], + sampleIdx) + if exist { + continue + } + } + + // (2) + class := tree.Classify(sample) + + // (3) + votes = append(votes, class) + } + return votes +} diff --git a/lib/mining/classifier/rf/rf_bench_test.go b/lib/mining/classifier/rf/rf_bench_test.go new file mode 100644 index 00000000..eddd721c --- /dev/null +++ b/lib/mining/classifier/rf/rf_bench_test.go @@ -0,0 +1,22 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rf + +import ( + "testing" +) + +func BenchmarkPhoneme(b *testing.B) { + SampleDsvFile = "../../testdata/phoneme/phoneme.dsv" + OOBStatsFile = "phoneme.oob" + PerfFile = "phoneme.perf" + + MinFeature = 3 + MaxFeature = 4 + + for x := 0; x < b.N; x++ { + runRandomForest() + } +} diff --git a/lib/mining/classifier/rf/rf_test.go b/lib/mining/classifier/rf/rf_test.go new file mode 100644 index 00000000..ba96125b --- /dev/null +++ b/lib/mining/classifier/rf/rf_test.go @@ -0,0 +1,190 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rf + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/shuLhan/share/lib/dsv" + "github.com/shuLhan/share/lib/mining/classifier" + "github.com/shuLhan/share/lib/tabula" +) + +// Global options to run for each test. +var ( + // SampleDsvFile is the file that contain samples config. + SampleDsvFile string + // DoTest if its true then the dataset will splited into training and + // test set with random selection without replacement. + DoTest = false + // NTree number of tree to generate. 
+ NTree = 100 + // NBootstrap percentage of sample used as subsample. + NBootstrap = 66 + // MinFeature number of feature to begin with. + MinFeature = 1 + // MaxFeature maximum number of feature to test + MaxFeature = -1 + // RunOOB if its true the the OOB samples will be used to test the + // model in each iteration. + RunOOB = true + // OOBStatsFile is the file where OOB statistic will be saved. + OOBStatsFile string + // PerfFile is the file where performance statistic will be saved. + PerfFile string + // StatFile is the file where classifying statistic will be saved. + StatFile string +) + +func getSamples() (train, test tabula.ClasetInterface) { + samples := tabula.Claset{} + _, e := dsv.SimpleRead(SampleDsvFile, &samples) + if nil != e { + log.Fatal(e) + } + + if !DoTest { + return &samples, nil + } + + ntrain := int(float32(samples.Len()) * (float32(NBootstrap) / 100.0)) + + bag, oob, _, _ := tabula.RandomPickRows(&samples, ntrain, false) + + train = bag.(tabula.ClasetInterface) + test = oob.(tabula.ClasetInterface) + + train.SetClassIndex(samples.GetClassIndex()) + test.SetClassIndex(samples.GetClassIndex()) + + return train, test +} + +func runRandomForest() { + oobStatsFile := OOBStatsFile + perfFile := PerfFile + statFile := StatFile + + trainset, testset := getSamples() + + if MaxFeature < 0 { + MaxFeature = trainset.GetNColumn() + } + + for nfeature := MinFeature; nfeature < MaxFeature; nfeature++ { + // Add prefix to OOB stats file. + oobStatsFile = fmt.Sprintf("N%d.%s", nfeature, OOBStatsFile) + + // Add prefix to performance file. + perfFile = fmt.Sprintf("N%d.%s", nfeature, PerfFile) + + // Add prefix to stat file. + statFile = fmt.Sprintf("N%d.%s", nfeature, StatFile) + + // Create and build random forest. 
+ forest := Runtime{ + Runtime: classifier.Runtime{ + RunOOB: RunOOB, + OOBStatsFile: oobStatsFile, + PerfFile: perfFile, + StatFile: statFile, + }, + NTree: NTree, + NRandomFeature: nfeature, + PercentBoot: NBootstrap, + } + + e := forest.Build(trainset) + if e != nil { + log.Fatal(e) + } + + if DoTest { + predicts, _, probs := forest.ClassifySet(testset, nil) + + forest.Performance(testset, predicts, probs) + e = forest.WritePerformance() + if e != nil { + log.Fatal(e) + } + } + } +} + +func TestEnsemblingGlass(t *testing.T) { + SampleDsvFile = "../../testdata/forensic_glass/fgl.dsv" + RunOOB = false + OOBStatsFile = "glass.oob" + StatFile = "glass.stat" + PerfFile = "glass.perf" + DoTest = true + + runRandomForest() +} + +func TestEnsemblingIris(t *testing.T) { + SampleDsvFile = "../../testdata/iris/iris.dsv" + OOBStatsFile = "iris.oob" + + runRandomForest() +} + +func TestEnsemblingPhoneme(t *testing.T) { + SampleDsvFile = "../../testdata/phoneme/phoneme.dsv" + OOBStatsFile = "phoneme.oob.stat" + StatFile = "phoneme.stat" + PerfFile = "phoneme.perf" + + NTree = 200 + MinFeature = 3 + MaxFeature = 4 + RunOOB = false + DoTest = true + + runRandomForest() +} + +func TestEnsemblingSmotePhoneme(t *testing.T) { + SampleDsvFile = "../../resampling/smote/phoneme_smote.dsv" + OOBStatsFile = "phonemesmote.oob" + + MinFeature = 3 + MaxFeature = 4 + + runRandomForest() +} + +func TestEnsemblingLnsmotePhoneme(t *testing.T) { + SampleDsvFile = "../../resampling/lnsmote/phoneme_lnsmote.dsv" + OOBStatsFile = "phonemelnsmote.oob" + + MinFeature = 3 + MaxFeature = 4 + + runRandomForest() +} + +func TestWvc2010Lnsmote(t *testing.T) { + SampleDsvFile = "../../testdata/wvc2010lnsmote/wvc2010_features.lnsmote.dsv" + OOBStatsFile = "wvc2010lnsmote.oob" + + NTree = 1 + MinFeature = 5 + MaxFeature = 6 + + runRandomForest() +} + +func TestMain(m *testing.M) { + envTestRF := os.Getenv("TEST_RF") + if len(envTestRF) == 0 { + os.Exit(0) + } + + os.Exit(m.Run()) +} diff --git 
a/lib/mining/classifier/runtime.go b/lib/mining/classifier/runtime.go new file mode 100644 index 00000000..0f7a7755 --- /dev/null +++ b/lib/mining/classifier/runtime.go @@ -0,0 +1,443 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package classifier + +import ( + "fmt" + "math" + + "github.com/shuLhan/share/lib/debug" + "github.com/shuLhan/share/lib/dsv" + "github.com/shuLhan/share/lib/numbers" + libstrings "github.com/shuLhan/share/lib/strings" + "github.com/shuLhan/share/lib/tabula" +) + +const ( + tag = "[classifier.runtime]" +) + +// +// Runtime define a generic type which provide common fields that can be +// embedded by the real classifier (e.g. RandomForest). +// +type Runtime struct { + // RunOOB if its true the OOB will be computed, default is false. + RunOOB bool `json:"RunOOB"` + + // OOBStatsFile is the file where OOB statistic will be written. + OOBStatsFile string `json:"OOBStatsFile"` + + // PerfFile is the file where statistic of performance will be written. + PerfFile string `json:"PerfFile"` + + // StatFile is the file where statistic of classifying samples will be + // written. + StatFile string `json:"StatFile"` + + // oobCms contain confusion matrix value for each OOB in iteration. + oobCms []CM + + // oobStats contain statistic of classifier for each OOB in iteration. + oobStats Stats + + // oobStatTotal contain total OOB statistic values. + oobStatTotal Stat + + // oobWriter contain file writer for statistic. + oobWriter *dsv.Writer + + // perfs contain performance statistic per sample, after classifying + // sample on classifier. + perfs Stats +} + +// +// Initialize will start the runtime for processing by saving start time and +// opening stats file. 
+// +func (rt *Runtime) Initialize() error { + rt.oobStatTotal.Start() + + return rt.OpenOOBStatsFile() +} + +// +// Finalize finish the runtime, compute total statistic, write it to file, and +// close the file. +// +func (rt *Runtime) Finalize() (e error) { + st := &rt.oobStatTotal + + st.End() + st.ID = int64(len(rt.oobStats)) + + e = rt.WriteOOBStat(st) + if e != nil { + return e + } + + return rt.CloseOOBStatsFile() +} + +// +// OOBStats return all statistic objects. +// +func (rt *Runtime) OOBStats() *Stats { + return &rt.oobStats +} + +// +// StatTotal return total statistic. +// +func (rt *Runtime) StatTotal() *Stat { + return &rt.oobStatTotal +} + +// +// AddOOBCM will append new confusion matrix. +// +func (rt *Runtime) AddOOBCM(cm *CM) { + rt.oobCms = append(rt.oobCms, *cm) +} + +// +// AddStat will append new classifier statistic data. +// +func (rt *Runtime) AddStat(stat *Stat) { + rt.oobStats = append(rt.oobStats, stat) +} + +// +// ComputeCM will compute confusion matrix of sample using value space, actual +// and prediction values. +// +func (rt *Runtime) ComputeCM(sampleIds []int, + vs, actuals, predicts []string, +) ( + cm *CM, +) { + cm = &CM{} + + cm.ComputeStrings(vs, actuals, predicts) + cm.GroupIndexPredictionsStrings(sampleIds, actuals, predicts) + + if debug.Value >= 2 { + fmt.Println(tag, cm) + } + + return cm +} + +// +// ComputeStatFromCM will compute statistic using confusion matrix. 
+// +func (rt *Runtime) ComputeStatFromCM(stat *Stat, cm *CM) { + + stat.OobError = cm.GetFalseRate() + + stat.OobErrorMean = rt.oobStatTotal.OobError / + float64(len(rt.oobStats)+1) + + stat.TP = int64(cm.TP()) + stat.FP = int64(cm.FP()) + stat.TN = int64(cm.TN()) + stat.FN = int64(cm.FN()) + + t := float64(stat.TP + stat.FN) + if t == 0 { + stat.TPRate = 0 + } else { + stat.TPRate = float64(stat.TP) / t + } + + t = float64(stat.FP + stat.TN) + if t == 0 { + stat.FPRate = 0 + } else { + stat.FPRate = float64(stat.FP) / t + } + + t = float64(stat.FP + stat.TN) + if t == 0 { + stat.TNRate = 0 + } else { + stat.TNRate = float64(stat.TN) / t + } + + t = float64(stat.TP + stat.FP) + if t == 0 { + stat.Precision = 0 + } else { + stat.Precision = float64(stat.TP) / t + } + + t = (1 / stat.Precision) + (1 / stat.TPRate) + if t == 0 { + stat.FMeasure = 0 + } else { + stat.FMeasure = 2 / t + } + + t = float64(stat.TP + stat.TN + stat.FP + stat.FN) + if t == 0 { + stat.Accuracy = 0 + } else { + stat.Accuracy = float64(stat.TP+stat.TN) / t + } + + if debug.Value >= 1 { + rt.PrintOobStat(stat, cm) + rt.PrintStat(stat) + } +} + +// +// ComputeStatTotal compute total statistic. 
+// +func (rt *Runtime) ComputeStatTotal(stat *Stat) { + if stat == nil { + return + } + + nstat := len(rt.oobStats) + if nstat == 0 { + return + } + + t := &rt.oobStatTotal + + t.OobError += stat.OobError + t.OobErrorMean = t.OobError / float64(nstat) + t.TP += stat.TP + t.FP += stat.FP + t.TN += stat.TN + t.FN += stat.FN + + total := float64(t.TP + t.FN) + if total == 0 { + t.TPRate = 0 + } else { + t.TPRate = float64(t.TP) / total + } + + total = float64(t.FP + t.TN) + if total == 0 { + t.FPRate = 0 + } else { + t.FPRate = float64(t.FP) / total + } + + total = float64(t.FP + t.TN) + if total == 0 { + t.TNRate = 0 + } else { + t.TNRate = float64(t.TN) / total + } + + total = float64(t.TP + t.FP) + if total == 0 { + t.Precision = 0 + } else { + t.Precision = float64(t.TP) / total + } + + total = (1 / t.Precision) + (1 / t.TPRate) + if total == 0 { + t.FMeasure = 0 + } else { + t.FMeasure = 2 / total + } + + total = float64(t.TP + t.TN + t.FP + t.FN) + if total == 0 { + t.Accuracy = 0 + } else { + t.Accuracy = float64(t.TP+t.TN) / total + } +} + +// +// OpenOOBStatsFile will open statistic file for output. +// +func (rt *Runtime) OpenOOBStatsFile() error { + if rt.oobWriter != nil { + _ = rt.CloseOOBStatsFile() + } + rt.oobWriter = &dsv.Writer{} + return rt.oobWriter.OpenOutput(rt.OOBStatsFile) +} + +// +// WriteOOBStat will write statistic of process to file. +// +func (rt *Runtime) WriteOOBStat(stat *Stat) error { + if rt.oobWriter == nil { + return nil + } + if stat == nil { + return nil + } + return rt.oobWriter.WriteRawRow(stat.ToRow(), nil, nil) +} + +// +// CloseOOBStatsFile will close statistics file for writing. +// +func (rt *Runtime) CloseOOBStatsFile() (e error) { + if rt.oobWriter == nil { + return + } + + e = rt.oobWriter.Close() + rt.oobWriter = nil + + return +} + +// +// PrintOobStat will print the out-of-bag statistic to standard output. 
+// +func (rt *Runtime) PrintOobStat(stat *Stat, cm *CM) { + fmt.Printf("%s OOB error rate: %.4f,"+ + " total: %.4f, mean %.4f, true rate: %.4f\n", tag, + stat.OobError, rt.oobStatTotal.OobError, + stat.OobErrorMean, cm.GetTrueRate()) +} + +// +// PrintStat will print statistic value to standard output. +// +func (rt *Runtime) PrintStat(stat *Stat) { + if stat == nil { + statslen := len(rt.oobStats) + if statslen <= 0 { + return + } + stat = rt.oobStats[statslen-1] + } + + fmt.Printf("%s TPRate: %.4f, FPRate: %.4f,"+ + " TNRate: %.4f, precision: %.4f, f-measure: %.4f,"+ + " accuracy: %.4f\n", tag, stat.TPRate, stat.FPRate, + stat.TNRate, stat.Precision, stat.FMeasure, stat.Accuracy) +} + +// +// PrintStatTotal will print total statistic to standard output. +// +func (rt *Runtime) PrintStatTotal(st *Stat) { + if st == nil { + st = &rt.oobStatTotal + } + rt.PrintStat(st) +} + +// +// Performance given an actuals class label and their probabilities, compute +// the performance statistic of classifier. +// +// Algorithm, +// (1) Sort the probabilities in descending order. +// (2) Sort the actuals and predicts using sorted index from probs +// (3) Compute tpr, fpr, precision +// (4) Write performance to file. 
+//
+func (rt *Runtime) Performance(samples tabula.ClasetInterface,
+	predicts []string, probs []float64,
+) (
+	perfs Stats,
+) {
+	// (1) Sort the scores together with a sequence of their original
+	// indices, so actuals and predicts can be reordered to match.
+	// NOTE(review): the last mergesort parameter is false — presumably
+	// descending order (highest score first), as ROC computation
+	// requires; confirm against numbers.Floats64InplaceMergesort.
+	actuals := samples.GetClassAsStrings()
+	sortedIds := numbers.IntCreateSeq(0, len(probs)-1)
+	numbers.Floats64InplaceMergesort(probs, sortedIds, 0, len(probs),
+		false)
+
+	// (2) Reorder actuals and predicts to follow the sorted scores.
+	// NOTE(review): this mutates the caller's `predicts` slice in
+	// place.
+	libstrings.SortByIndex(&actuals, sortedIds)
+	libstrings.SortByIndex(&predicts, sortedIds)
+
+	// (3) Accumulate one Stat per distinct score into rt.perfs.
+	rt.computePerfByProbs(samples, actuals, probs)
+
+	// NOTE(review): the named return `perfs` is never assigned;
+	// rt.perfs is returned directly.
+	return rt.perfs
+}
+
+// trapezoidArea return the area of the ROC trapezoid whose base is
+// |fp-fpprev| and whose height is the average of tp and tpprev.
+func trapezoidArea(fp, fpprev, tp, tpprev int64) float64 {
+	base := math.Abs(float64(fp - fpprev))
+	heightAvg := float64(tp+tpprev) / float64(2.0)
+	return base * heightAvg
+}
+
+//
+// computePerfByProbs will compute classifier performance using probabilities
+// or score `probs`.
+//
+// This currently only work for two class problem: the first value in the
+// class value space (vs[0]) is treated as the positive class, everything
+// else as negative.
+//
+func (rt *Runtime) computePerfByProbs(samples tabula.ClasetInterface,
+	actuals []string, probs []float64,
+) {
+	vs := samples.GetClassValueSpace()
+	nactuals := numbers.IntsTo64(samples.Counts())
+	nclass := libstrings.CountTokens(actuals, vs, false)
+
+	// pprev starts at -Inf so the first probability always opens a new
+	// Stat.
+	pprev := math.Inf(-1)
+	tp := int64(0)
+	fp := int64(0)
+	tpprev := int64(0)
+	fpprev := int64(0)
+
+	auc := float64(0)
+
+	for x, p := range probs {
+		if p != pprev {
+			// New threshold: snapshot the rates accumulated so
+			// far and extend the AUC by one trapezoid.
+			stat := Stat{}
+			stat.SetTPRate(tp, nactuals[0])
+			stat.SetFPRate(fp, nactuals[1])
+			stat.SetPrecisionFromRate(nactuals[0], nactuals[1])
+
+			auc = auc + trapezoidArea(fp, fpprev, tp, tpprev)
+			stat.SetAUC(auc)
+
+			rt.perfs = append(rt.perfs, &stat)
+
+			pprev = p
+			tpprev = tp
+			fpprev = fp
+		}
+
+		if actuals[x] == vs[0] {
+			tp++
+		} else {
+			fp++
+		}
+	}
+
+	// Final Stat covering all samples; the AUC is normalized by the
+	// product of the two class counts so it falls in [0,1].
+	// NOTE(review): nclass[0]*nclass[1] is zero when either class is
+	// absent from actuals, which makes the AUC NaN — confirm callers
+	// guard against single-class samples.
+	stat := Stat{}
+	stat.SetTPRate(tp, nactuals[0])
+	stat.SetFPRate(fp, nactuals[1])
+	stat.SetPrecisionFromRate(nactuals[0], nactuals[1])
+
+	auc = auc + trapezoidArea(fp, fpprev, tp, tpprev)
+	auc = auc / float64(nclass[0]*nclass[1])
+	stat.SetAUC(auc)
+
+	rt.perfs = append(rt.perfs, &stat)
+
+	if len(rt.perfs) >= 2 {
+		// Replace the first stat with second stat, because of NaN
+		// value on the first precision (0/0 before any sample is
+		// consumed).
+		rt.perfs[0] = rt.perfs[1]
+	}
+}
+
+//
+// WritePerformance will write performance data to file.
+//
+func (rt *Runtime) WritePerformance() error {
+	return rt.perfs.Write(rt.PerfFile)
+}
diff --git a/lib/mining/classifier/stat.go b/lib/mining/classifier/stat.go
new file mode 100644
index 00000000..790aec10
--- /dev/null
+++ b/lib/mining/classifier/stat.go
@@ -0,0 +1,176 @@
+// Copyright 2015-2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+	"time"
+
+	"github.com/shuLhan/share/lib/dsv"
+	"github.com/shuLhan/share/lib/tabula"
+)
+
+//
+// Stat hold statistic value of classifier, including TP rate, FP rate, precision,
+// and recall.
+//
+type Stat struct {
+	// ID unique id for this statistic (e.g. number of tree).
+	ID int64
+	// StartTime contain the start time of classifier in unix timestamp.
+	StartTime int64
+	// EndTime contain the end time of classifier in unix timestamp.
+	EndTime int64
+	// ElapsedTime contain actual time, in seconds, between end and start
+	// time.
+	ElapsedTime int64
+	// TP contain true-positive value.
+	TP int64
+	// FP contain false-positive value.
+	FP int64
+	// TN contain true-negative value.
+	TN int64
+	// FN contain false-negative value.
+	FN int64
+	// OobError contain out-of-bag error.
+	OobError float64
+	// OobErrorMean contain mean of out-of-bag error.
+	OobErrorMean float64
+	// TPRate contain true-positive rate (recall): tp/(tp+fn)
+	TPRate float64
+	// FPRate contain false-positive rate: fp/(fp+tn)
+	FPRate float64
+	// TNRate contain true-negative rate: tn/(tn+fp)
+	TNRate float64
+	// Precision contain: tp/(tp+fp)
+	Precision float64
+	// FMeasure contain value of F-measure or the harmonic mean of
+	// precision and recall.
+	FMeasure float64
+	// Accuracy contain the degree of closeness of measurements of a
+	// quantity to that quantity's true value.
+	Accuracy float64
+	// AUC contain the area under curve.
+	AUC float64
+}
+
+// SetAUC will set the AUC value.
+func (stat *Stat) SetAUC(v float64) {
+	stat.AUC = v
+}
+
+//
+// SetTPRate will set TP and TPRate using number of positive `p`.
+// NOTE(review): TPRate is NaN when p is zero.
+//
+func (stat *Stat) SetTPRate(tp, p int64) {
+	stat.TP = tp
+	stat.TPRate = float64(tp) / float64(p)
+}
+
+//
+// SetFPRate will set FP and FPRate using number of negative `n`.
+// NOTE(review): FPRate is NaN when n is zero.
+//
+func (stat *Stat) SetFPRate(fp, n int64) {
+	stat.FP = fp
+	stat.FPRate = float64(fp) / float64(n)
+}
+
+//
+// SetPrecisionFromRate will set Precision value using tprate and fprate.
+// `p` and `n` is the number of positive and negative class in samples.
+// The result is NaN when both weighted rates are zero (0/0); callers
+// compensate for this (see computePerfByProbs).
+//
+func (stat *Stat) SetPrecisionFromRate(p, n int64) {
+	stat.Precision = (stat.TPRate * float64(p)) /
+		((stat.TPRate * float64(p)) + (stat.FPRate * float64(n)))
+}
+
+//
+// Recall return value of recall, which is the same as the TP rate.
+//
+func (stat *Stat) Recall() float64 {
+	return stat.TPRate
+}
+
+//
+// Sum will add statistic from other stat object to current stat, not including
+// the start and end time. ID and AUC are also not accumulated.
+//
+func (stat *Stat) Sum(other *Stat) {
+	stat.OobError += other.OobError
+	stat.OobErrorMean += other.OobErrorMean
+	stat.TP += other.TP
+	stat.FP += other.FP
+	stat.TN += other.TN
+	stat.FN += other.FN
+	stat.TPRate += other.TPRate
+	stat.FPRate += other.FPRate
+	stat.TNRate += other.TNRate
+	stat.Precision += other.Precision
+	stat.FMeasure += other.FMeasure
+	stat.Accuracy += other.Accuracy
+}
+
+//
+// ToRow will convert the stat to tabula.row in the order of Stat field.
+// +func (stat *Stat) ToRow() (row *tabula.Row) { + row = &tabula.Row{} + + row.PushBack(tabula.NewRecordInt(stat.ID)) + row.PushBack(tabula.NewRecordInt(stat.StartTime)) + row.PushBack(tabula.NewRecordInt(stat.EndTime)) + row.PushBack(tabula.NewRecordInt(stat.ElapsedTime)) + row.PushBack(tabula.NewRecordReal(stat.OobError)) + row.PushBack(tabula.NewRecordReal(stat.OobErrorMean)) + row.PushBack(tabula.NewRecordInt(stat.TP)) + row.PushBack(tabula.NewRecordInt(stat.FP)) + row.PushBack(tabula.NewRecordInt(stat.TN)) + row.PushBack(tabula.NewRecordInt(stat.FN)) + row.PushBack(tabula.NewRecordReal(stat.TPRate)) + row.PushBack(tabula.NewRecordReal(stat.FPRate)) + row.PushBack(tabula.NewRecordReal(stat.TNRate)) + row.PushBack(tabula.NewRecordReal(stat.Precision)) + row.PushBack(tabula.NewRecordReal(stat.FMeasure)) + row.PushBack(tabula.NewRecordReal(stat.Accuracy)) + row.PushBack(tabula.NewRecordReal(stat.AUC)) + + return +} + +// +// Start will start the timer. +// +func (stat *Stat) Start() { + stat.StartTime = time.Now().Unix() +} + +// +// End will stop the timer and compute the elapsed time. +// +func (stat *Stat) End() { + stat.EndTime = time.Now().Unix() + stat.ElapsedTime = stat.EndTime - stat.StartTime +} + +// +// Write will write the content of stat to `file`. +// +func (stat *Stat) Write(file string) (e error) { + if file == "" { + return + } + + writer := &dsv.Writer{} + e = writer.OpenOutput(file) + if e != nil { + return e + } + + e = writer.WriteRawRow(stat.ToRow(), nil, nil) + if e != nil { + return e + } + + return writer.Close() +} diff --git a/lib/mining/classifier/stats.go b/lib/mining/classifier/stats.go new file mode 100644 index 00000000..b29d3510 --- /dev/null +++ b/lib/mining/classifier/stats.go @@ -0,0 +1,143 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package classifier + +import ( + "github.com/shuLhan/share/lib/dsv" +) + +// +// Stats define list of statistic values. +// +type Stats []*Stat + +// +// Add will add other stat object to the slice. +// +func (stats *Stats) Add(stat *Stat) { + *stats = append(*stats, stat) +} + +// +// StartTimes return all start times in unix timestamp. +// +func (stats *Stats) StartTimes() (times []int64) { + for _, stat := range *stats { + times = append(times, stat.StartTime) + } + return +} + +// +// EndTimes return all end times in unix timestamp. +// +func (stats *Stats) EndTimes() (times []int64) { + for _, stat := range *stats { + times = append(times, stat.EndTime) + } + return +} + +// +// OobErrorMeans return all out-of-bag error mean values. +// +func (stats *Stats) OobErrorMeans() (oobmeans []float64) { + oobmeans = make([]float64, len(*stats)) + for x, stat := range *stats { + oobmeans[x] = stat.OobErrorMean + } + return +} + +// +// TPRates return all true-positive rate values. +// +func (stats *Stats) TPRates() (tprates []float64) { + for _, stat := range *stats { + tprates = append(tprates, stat.TPRate) + } + return +} + +// +// FPRates return all false-positive rate values. +// +func (stats *Stats) FPRates() (fprates []float64) { + for _, stat := range *stats { + fprates = append(fprates, stat.FPRate) + } + return +} + +// +// TNRates will return all true-negative rate values. +// +func (stats *Stats) TNRates() (tnrates []float64) { + for _, stat := range *stats { + tnrates = append(tnrates, stat.TNRate) + } + return +} + +// +// Precisions return all precision values. +// +func (stats *Stats) Precisions() (precs []float64) { + for _, stat := range *stats { + precs = append(precs, stat.Precision) + } + return +} + +// +// Recalls return all recall values. +// +func (stats *Stats) Recalls() (recalls []float64) { + return stats.TPRates() +} + +// +// FMeasures return all F-measure values. 
+// +func (stats *Stats) FMeasures() (fmeasures []float64) { + for _, stat := range *stats { + fmeasures = append(fmeasures, stat.FMeasure) + } + return +} + +// +// Accuracies return all accuracy values. +// +func (stats *Stats) Accuracies() (accuracies []float64) { + for _, stat := range *stats { + accuracies = append(accuracies, stat.Accuracy) + } + return +} + +// +// Write will write all statistic data to `file`. +// +func (stats *Stats) Write(file string) (e error) { + if file == "" { + return + } + + writer := &dsv.Writer{} + e = writer.OpenOutput(file) + if e != nil { + return e + } + + for _, st := range *stats { + e = writer.WriteRawRow(st.ToRow(), nil, nil) + if e != nil { + return e + } + } + + return writer.Close() +} diff --git a/lib/mining/classifier/stats_interface.go b/lib/mining/classifier/stats_interface.go new file mode 100644 index 00000000..a9b60b9a --- /dev/null +++ b/lib/mining/classifier/stats_interface.go @@ -0,0 +1,68 @@ +// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package classifier + +// +// ComputeFMeasures given array of precisions and recalls, compute F-measure +// of each instance and return it. +// +func ComputeFMeasures(precisions, recalls []float64) (fmeasures []float64) { + // Get the minimum length of precision and recall. + // This is to make sure that we are not looping out of range. 
+ minlen := len(precisions) + recallslen := len(recalls) + if recallslen < minlen { + minlen = recallslen + } + + for x := 0; x < minlen; x++ { + f := 2 / ((1 / precisions[x]) + (1 / recalls[x])) + fmeasures = append(fmeasures, f) + } + return +} + +// +// ComputeAccuracies will compute and return accuracy from array of +// true-positive, false-positive, true-negative, and false-negative; using +// formula, +// +// (tp + tn) / (tp + tn + tn + fn) +// +func ComputeAccuracies(tp, fp, tn, fn []int64) (accuracies []float64) { + // Get minimum length of input, just to make sure we are not looping + // out of range. + minlen := len(tp) + if len(fp) < len(tn) { + minlen = len(fp) + } + if len(fn) < minlen { + minlen = len(fn) + } + + for x := 0; x < minlen; x++ { + acc := float64(tp[x]+tn[x]) / + float64(tp[x]+fp[x]+tn[x]+fn[x]) + accuracies = append(accuracies, acc) + } + return +} + +// +// ComputeElapsedTimes will compute and return elapsed time between `start` +// and `end` timestamps. +// +func ComputeElapsedTimes(start, end []int64) (elaps []int64) { + // Get minimum length. + minlen := len(start) + if len(end) < minlen { + minlen = len(end) + } + + for x := 0; x < minlen; x++ { + elaps = append(elaps, end[x]-start[x]) + } + return +} |
