aboutsummaryrefslogtreecommitdiff
path: root/lib/mining/classifier
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2018-09-17 05:04:26 +0700
committerShulhan <ms@kilabit.info>2018-09-18 01:50:21 +0700
commit1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e (patch)
tree5fa83fc0faa31e09cae82ac4d467cf8ba5f87fc2 /lib/mining/classifier
parent446fef94cd712861221c0098dcdd9ae52aaed0eb (diff)
downloadpakakeh.go-1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e.tar.xz
Merge package "github.com/shuLhan/go-mining"
Diffstat (limited to 'lib/mining/classifier')
-rw-r--r--lib/mining/classifier/cart/cart.go480
-rw-r--r--lib/mining/classifier/cart/cart_test.go62
-rw-r--r--lib/mining/classifier/cart/node.go44
-rw-r--r--lib/mining/classifier/cm.go442
-rw-r--r--lib/mining/classifier/cm_test.go69
-rw-r--r--lib/mining/classifier/crf/crf.go470
-rw-r--r--lib/mining/classifier/crf/crf_test.go92
-rw-r--r--lib/mining/classifier/rf/rf.go363
-rw-r--r--lib/mining/classifier/rf/rf_bench_test.go22
-rw-r--r--lib/mining/classifier/rf/rf_test.go190
-rw-r--r--lib/mining/classifier/runtime.go443
-rw-r--r--lib/mining/classifier/stat.go176
-rw-r--r--lib/mining/classifier/stats.go143
-rw-r--r--lib/mining/classifier/stats_interface.go68
14 files changed, 3064 insertions, 0 deletions
diff --git a/lib/mining/classifier/cart/cart.go b/lib/mining/classifier/cart/cart.go
new file mode 100644
index 00000000..449781d9
--- /dev/null
+++ b/lib/mining/classifier/cart/cart.go
@@ -0,0 +1,480 @@
+// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//
+// Package cart implement the Classification and Regression Tree by Breiman, et al.
+// CART is binary decision tree.
+//
+// Breiman, Leo, et al. Classification and regression trees. CRC press,
+// 1984.
+//
+// The implementation is based on Data Mining book,
+//
+// Han, Jiawei, Micheline Kamber, and Jian Pei. Data mining: concepts and
+// techniques: concepts and techniques. Elsevier, 2011.
+//
+package cart
+
+import (
+ "fmt"
+
+ "github.com/shuLhan/share/lib/debug"
+ "github.com/shuLhan/share/lib/mining/gain/gini"
+ "github.com/shuLhan/share/lib/mining/tree/binary"
+ "github.com/shuLhan/share/lib/numbers"
+ libstrings "github.com/shuLhan/share/lib/strings"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
const (
	// SplitMethodGini if defined in Runtime, the dataset will be splitted
	// using Gini gain for each possible value or partition.
	//
	// This option is used in Runtime.SplitMethod, and is currently the
	// only supported split method (Build falls back to it by default).
	SplitMethodGini = "gini"
)
+
const (
	// ColFlagParent denote that the column is parent/split node, i.e. it
	// has already been used as a split attribute by an ancestor node.
	ColFlagParent = 1
	// ColFlagSkip denote that the column would be skipped when computing
	// the gain (see SelectRandomFeature and computeGain).
	ColFlagSkip = 2
)
+
//
// Runtime data for building CART.
//
type Runtime struct {
	// SplitMethod define the criteria to used for splitting.
	// Currently only SplitMethodGini is supported.
	SplitMethod string `json:"SplitMethod"`
	// NRandomFeature if less or equal to zero compute gain on all feature,
	// otherwise select n random feature and compute gain only on selected
	// features.
	NRandomFeature int `json:"NRandomFeature"`
	// OOBErrVal is the last out-of-bag error value in the tree, as
	// computed by CountOOBError.
	OOBErrVal float64
	// Tree in classification.
	Tree binary.Tree
}
+
+//
+// New create new Runtime object.
+//
+func New(D tabula.ClasetInterface, splitMethod string, nRandomFeature int) (
+ *Runtime, error,
+) {
+ runtime := &Runtime{
+ SplitMethod: splitMethod,
+ NRandomFeature: nRandomFeature,
+ Tree: binary.Tree{},
+ }
+
+ e := runtime.Build(D)
+ if e != nil {
+ return nil, e
+ }
+
+ return runtime, nil
+}
+
+//
+// Build will create a tree using CART algorithm.
+//
+func (runtime *Runtime) Build(D tabula.ClasetInterface) (e error) {
+ // Re-check input configuration.
+ switch runtime.SplitMethod {
+ case SplitMethodGini:
+ // Do nothing.
+ default:
+ // Set default split method to Gini index.
+ runtime.SplitMethod = SplitMethodGini
+ }
+
+ runtime.Tree.Root, e = runtime.splitTreeByGain(D)
+
+ return
+}
+
//
// splitTreeByGain calculate the gain in all dataset, and split into two node:
// left and right.
//
// The recursion terminates when the dataset is empty, when all samples
// belong to a single class, or when the maximum gain is zero; in those
// cases a leaf node is returned.
//
// Return node with the split information.
//
func (runtime *Runtime) splitTreeByGain(D tabula.ClasetInterface) (
	node *binary.BTNode,
	e error,
) {
	node = &binary.BTNode{}

	D.RecountMajorMinor()

	// if dataset is empty return node labeled with majority classes in
	// dataset.
	nrow := D.GetNRow()

	if nrow <= 0 {
		if debug.Value >= 2 {
			fmt.Printf("[cart] empty dataset (%s) : %v\n",
				D.MajorityClass(), D)
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  D.MajorityClass(),
			Size:   0,
		}
		return node, nil
	}

	// if all dataset is in the same class, return node as leaf with class
	// is set to that class.
	single, name := D.IsInSingleClass()
	if single {
		if debug.Value >= 2 {
			fmt.Printf("[cart] in single class (%s): %v\n", name,
				D.GetColumns())
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  name,
			Size:   nrow,
		}
		return node, nil
	}

	if debug.Value >= 2 {
		fmt.Println("[cart] D:", D)
	}

	// calculate the Gini gain for each attribute.
	gains := runtime.computeGain(D)

	// get attribute with maximum Gini gain.
	MaxGainIdx := gini.FindMaxGain(&gains)
	MaxGain := gains[MaxGainIdx]

	// if maxgain value is 0, use majority class as node and terminate
	// the process
	if MaxGain.GetMaxGainValue() == 0 {
		if debug.Value >= 2 {
			fmt.Println("[cart] max gain 0 with target",
				D.GetClassAsStrings(),
				" and majority class is ", D.MajorityClass())
		}

		node.Value = NodeValue{
			IsLeaf: true,
			Class:  D.MajorityClass(),
			Size:   0,
		}
		return node, nil
	}

	// using the sorted index in MaxGain, sort all field in dataset
	tabula.SortColumnsByIndex(D, MaxGain.SortedIndex)

	if debug.Value >= 2 {
		fmt.Println("[cart] maxgain:", MaxGain)
	}

	// Now that we have attribute with max gain in MaxGainIdx, and their
	// gain and partition value in Gains[MaxGainIdx] and
	// GetMaxPartValue(), we split the dataset based on type of max-gain
	// attribute.
	// If its continuous, split the attribute using numeric value.
	// If its discrete, split the attribute using subset (partition) of
	// nominal values.
	var splitV interface{}

	if MaxGain.IsContinu {
		splitV = MaxGain.GetMaxPartGainValue()
	} else {
		attrPartV := MaxGain.GetMaxPartGainValue()
		attrSubV := attrPartV.(libstrings.Row)
		splitV = attrSubV[0]
	}

	if debug.Value >= 2 {
		fmt.Println("[cart] maxgainindex:", MaxGainIdx)
		fmt.Println("[cart] split v:", splitV)
	}

	node.Value = NodeValue{
		SplitAttrName: D.GetColumn(MaxGainIdx).GetName(),
		IsLeaf:        false,
		IsContinu:     MaxGain.IsContinu,
		Size:          nrow,
		SplitAttrIdx:  MaxGainIdx,
		SplitV:        splitV,
	}

	dsL, dsR, e := tabula.SplitRowsByValue(D, MaxGainIdx, splitV)

	if e != nil {
		return node, e
	}

	splitL := dsL.(tabula.ClasetInterface)
	splitR := dsR.(tabula.ClasetInterface)

	// Set the flag to parent in attribute referenced by
	// MaxGainIdx, so it will not computed again in the next round.
	// All other column flags are cleared so they become candidates again.
	cols := splitL.GetColumns()
	for x := range *cols {
		if x == MaxGainIdx {
			(*cols)[x].Flag = ColFlagParent
		} else {
			(*cols)[x].Flag = 0
		}
	}

	cols = splitR.GetColumns()
	for x := range *cols {
		if x == MaxGainIdx {
			(*cols)[x].Flag = ColFlagParent
		} else {
			(*cols)[x].Flag = 0
		}
	}

	// Recurse into both partitions to build the subtrees.
	nodeLeft, e := runtime.splitTreeByGain(splitL)
	if e != nil {
		return node, e
	}

	nodeRight, e := runtime.splitTreeByGain(splitR)
	if e != nil {
		return node, e
	}

	node.SetLeft(nodeLeft)
	node.SetRight(nodeRight)

	return node, nil
}
+
// SelectRandomFeature if NRandomFeature is greater than zero, select and
// compute gain in n random features instead of in all features.
//
// Non-selected candidate columns are marked with ColFlagSkip so that
// computeGain will ignore them.
func (runtime *Runtime) SelectRandomFeature(D tabula.ClasetInterface) {
	if runtime.NRandomFeature <= 0 {
		// all features selected
		return
	}

	ncols := D.GetNColumn()

	// count all features minus class
	nfeature := ncols - 1
	if runtime.NRandomFeature >= nfeature {
		// Do nothing if number of random feature equal or greater than
		// number of feature in dataset.
		return
	}

	// exclude class index and parent node index
	excludeIdx := []int{D.GetClassIndex()}
	cols := D.GetColumns()
	for x, col := range *cols {
		if (col.Flag & ColFlagParent) == ColFlagParent {
			excludeIdx = append(excludeIdx, x)
		} else {
			// Mark every candidate as skipped first; the flag is
			// cleared below only on the randomly picked columns.
			(*cols)[x].Flag |= ColFlagSkip
		}
	}

	// Select random features excluding feature in `excludeIdx`.
	var pickedIdx []int
	for x := 0; x < runtime.NRandomFeature; x++ {
		idx := numbers.IntPickRandPositive(ncols, false, pickedIdx,
			excludeIdx)
		pickedIdx = append(pickedIdx, idx)

		// Remove skip flag on selected column
		col := D.GetColumn(idx)
		col.Flag = col.Flag &^ ColFlagSkip
	}

	if debug.Value >= 1 {
		fmt.Println("[cart] selected random features:", pickedIdx)
		fmt.Println("[cart] selected columns :", D.GetColumns())
	}
}
+
//
// computeGain calculate the gini index for each value in each attribute.
//
// Columns flagged with ColFlagParent or ColFlagSkip, and the class column
// itself, are excluded from the computation (their gain entry is marked
// with Skip).
//
// NOTE(review): gains is only allocated when SplitMethod is
// SplitMethodGini; Build guarantees this, but calling computeGain with any
// other method would index a nil slice — confirm if other methods are ever
// added.
//
func (runtime *Runtime) computeGain(D tabula.ClasetInterface) (
	gains []gini.Gini,
) {
	switch runtime.SplitMethod {
	case SplitMethodGini:
		// create gains value for all attribute minus target class.
		gains = make([]gini.Gini, D.GetNColumn())
	}

	runtime.SelectRandomFeature(D)

	classVS := D.GetClassValueSpace()
	classIdx := D.GetClassIndex()
	classType := D.GetClassType()

	for x, col := range *D.GetColumns() {
		// skip class attribute.
		if x == classIdx {
			continue
		}

		// skip column flagged with parent
		if (col.Flag & ColFlagParent) == ColFlagParent {
			gains[x].Skip = true
			continue
		}

		// ignore column flagged with skip
		if (col.Flag & ColFlagSkip) == ColFlagSkip {
			gains[x].Skip = true
			continue
		}

		// compute gain: continuous attributes use the numeric
		// routines, discrete attributes use the partition routine.
		if col.GetType() == tabula.TReal {
			attr := col.ToFloatSlice()

			if classType == tabula.TString {
				target := D.GetClassAsStrings()
				gains[x].ComputeContinu(&attr, &target,
					&classVS)
			} else {
				targetReal := D.GetClassAsReals()
				classVSReal := libstrings.ToFloat64(classVS)

				gains[x].ComputeContinuFloat(&attr,
					&targetReal, &classVSReal)
			}
		} else {
			attr := col.ToStringSlice()
			attrV := col.ValueSpace

			if debug.Value >= 2 {
				fmt.Println("[cart] attr :", attr)
				fmt.Println("[cart] attrV:", attrV)
			}

			target := D.GetClassAsStrings()
			gains[x].ComputeDiscrete(&attr, &attrV, &target,
				&classVS)
		}

		if debug.Value >= 2 {
			fmt.Println("[cart] gain :", gains[x])
		}
	}
	return
}
+
+//
+// Classify return the prediction of one sample.
+//
+func (runtime *Runtime) Classify(data *tabula.Row) (class string) {
+ node := runtime.Tree.Root
+ nodev := node.Value.(NodeValue)
+
+ for !nodev.IsLeaf {
+ if nodev.IsContinu {
+ splitV := nodev.SplitV.(float64)
+ attrV := (*data)[nodev.SplitAttrIdx].Float()
+
+ if attrV < splitV {
+ node = node.Left
+ } else {
+ node = node.Right
+ }
+ } else {
+ splitV := nodev.SplitV.([]string)
+ attrV := (*data)[nodev.SplitAttrIdx].String()
+
+ if libstrings.IsContain(splitV, attrV) {
+ node = node.Left
+ } else {
+ node = node.Right
+ }
+ }
+ nodev = node.Value.(NodeValue)
+ }
+
+ return nodev.Class
+}
+
+//
+// ClassifySet set the class attribute based on tree classification.
+//
+func (runtime *Runtime) ClassifySet(data tabula.ClasetInterface) (e error) {
+ nrow := data.GetNRow()
+ targetAttr := data.GetClassColumn()
+
+ for i := 0; i < nrow; i++ {
+ class := runtime.Classify(data.GetRow(i))
+
+ _ = (*targetAttr).Records[i].SetValue(class, tabula.TString)
+ }
+
+ return
+}
+
//
// CountOOBError process out-of-bag data on tree and return error value.
//
// The class column of `oob` is cleared, re-classified through the tree,
// compared against the original values, and then restored — so `oob`
// leaves this function with its original class values intact.
//
func (runtime *Runtime) CountOOBError(oob tabula.Claset) (
	errval float64,
	e error,
) {
	// save the original target to be compared later.
	origTarget := oob.GetClassAsStrings()

	if debug.Value >= 2 {
		fmt.Println("[cart] OOB:", oob.Columns)
		fmt.Println("[cart] TREE:", &runtime.Tree)
	}

	// reset the target.
	oobtarget := oob.GetClassColumn()
	oobtarget.ClearValues()

	e = runtime.ClassifySet(&oob)

	if e != nil {
		// set original target values back.
		oobtarget.SetValues(origTarget)
		return
	}

	target := oobtarget.ToStringSlice()

	if debug.Value >= 2 {
		fmt.Println("[cart] original target:", origTarget)
		fmt.Println("[cart] classify target:", target)
	}

	// count how many target value is miss-classified.
	runtime.OOBErrVal, _, _ = libstrings.CountMissRate(origTarget, target)

	// set original target values back.
	oobtarget.SetValues(origTarget)

	return runtime.OOBErrVal, nil
}
+
+//
+// String yes, it will print it JSON like format.
+//
+func (runtime *Runtime) String() (s string) {
+ s = fmt.Sprintf("NRandomFeature: %d\n"+
+ " SplitMethod : %s\n"+
+ " Tree :\n%v", runtime.NRandomFeature,
+ runtime.SplitMethod,
+ runtime.Tree.String())
+ return s
+}
diff --git a/lib/mining/classifier/cart/cart_test.go b/lib/mining/classifier/cart/cart_test.go
new file mode 100644
index 00000000..14e89b12
--- /dev/null
+++ b/lib/mining/classifier/cart/cart_test.go
@@ -0,0 +1,62 @@
+// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cart
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/tabula"
+ "github.com/shuLhan/share/lib/test"
+)
+
const (
	// NRows is the expected number of rows in the iris test dataset.
	NRows = 150
)
+
// TestCART builds a CART tree from the iris dataset and verifies that
// classifying the same (training) set reproduces the original class labels.
func TestCART(t *testing.T) {
	fds := "../../testdata/iris/iris.dsv"

	ds := tabula.Claset{}

	_, e := dsv.SimpleRead(fds, &ds)
	if nil != e {
		t.Fatal(e)
	}

	fmt.Println("[cart_test] class index:", ds.GetClassIndex())

	// copy target to be compared later.
	targetv := ds.GetClassAsStrings()

	test.Assert(t, "", NRows, ds.GetNRow(), true)

	// Build CART tree.
	CART, e := New(&ds, SplitMethodGini, 0)
	if e != nil {
		t.Fatal(e)
	}

	fmt.Println("[cart_test] CART Tree:\n", CART)

	// Create test set by re-reading the same file.
	testset := tabula.Claset{}
	_, e = dsv.SimpleRead(fds, &testset)

	if nil != e {
		t.Fatal(e)
	}

	testset.GetClassColumn().ClearValues()

	// Classify test set; on training data the tree should be exact.
	e = CART.ClassifySet(&testset)
	if nil != e {
		t.Fatal(e)
	}

	test.Assert(t, "", targetv, testset.GetClassAsStrings(), true)
}
diff --git a/lib/mining/classifier/cart/node.go b/lib/mining/classifier/cart/node.go
new file mode 100644
index 00000000..b64dd13c
--- /dev/null
+++ b/lib/mining/classifier/cart/node.go
@@ -0,0 +1,44 @@
+// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cart
+
+import (
+ "fmt"
+ "reflect"
+)
+
//
// NodeValue hold the content of a single node in the CART tree: either a
// leaf carrying a class label, or an internal node carrying its split
// information.
//
type NodeValue struct {
	// Class of leaf node.
	Class string
	// SplitAttrName define the name of attribute which cause the split.
	SplitAttrName string
	// IsLeaf define whether node is a leaf or not.
	IsLeaf bool
	// IsContinu define whether the node split is continuous or discrete.
	IsContinu bool
	// Size define number of sample that this node hold before splitting.
	Size int
	// SplitAttrIdx define the attribute which cause the split.
	SplitAttrIdx int
	// SplitV define the split value.
	SplitV interface{}
}

//
// String return a printable representation of the node value: the class
// label for a leaf, or the split value for an internal node.
//
func (nodev *NodeValue) String() (s string) {
	if nodev.IsLeaf {
		return fmt.Sprintf("Class: %s", nodev.Class)
	}
	return fmt.Sprintf("(SplitValue: %v)", reflect.ValueOf(nodev.SplitV))
}
diff --git a/lib/mining/classifier/cm.go b/lib/mining/classifier/cm.go
new file mode 100644
index 00000000..0dc2ee05
--- /dev/null
+++ b/lib/mining/classifier/cm.go
@@ -0,0 +1,442 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+ "fmt"
+ "strconv"
+
+ "github.com/shuLhan/share/lib/tabula"
+)
+
//
// CM represent the matrix of classification.
//
// Columns are indexed by actual class value, and the matrix cells count
// how many samples of each actual class were predicted as each class.
//
type CM struct {
	tabula.Dataset
	// rowNames contain name in each row.
	rowNames []string
	// nSamples contain number of class.
	nSamples int64
	// nTrue contain number of true positive and negative, i.e. the sum
	// of the matrix diagonal (see computeClassError).
	nTrue int64
	// nFalse contain number of false positive and negative, i.e. the sum
	// of all off-diagonal cells.
	nFalse int64

	// tpIds contain index of true-positive samples.
	tpIds []int
	// fpIds contain index of false-positive samples.
	fpIds []int
	// tnIds contain index of true-negative samples.
	tnIds []int
	// fnIds contain index of false-negative samples.
	fnIds []int
}
+
+//
+// initByNumeric will initialize confusion matrix using numeric value space.
+//
+func (cm *CM) initByNumeric(vs []int64) {
+ var colTypes []int
+ var colNames []string
+
+ for _, v := range vs {
+ vstr := strconv.FormatInt(v, 10)
+ colTypes = append(colTypes, tabula.TInteger)
+ colNames = append(colNames, vstr)
+ cm.rowNames = append(cm.rowNames, vstr)
+ }
+
+ // class error column
+ colTypes = append(colTypes, tabula.TReal)
+ colNames = append(colNames, "class_error")
+
+ cm.Dataset.Init(tabula.DatasetModeMatrix, colTypes, colNames)
+}
+
//
// ComputeStrings will calculate confusion matrix using targets and predictions
// class values.
//
// Column x holds, for actual class valueSpace[x], the count of samples
// predicted as each class in valueSpace, followed by the class-error column
// computed at the end.
//
func (cm *CM) ComputeStrings(valueSpace, targets, predictions []string) {
	cm.init(valueSpace)

	for x, target := range valueSpace {
		col := cm.GetColumn(x)

		for _, predict := range valueSpace {
			cnt := cm.countTargetPrediction(target, predict,
				targets, predictions)

			col.PushBack(tabula.NewRecordInt(cnt))
		}

		cm.PushColumnToRows(*col)
	}

	cm.computeClassError()
}
+
//
// ComputeNumeric will calculate confusion matrix using targets and predictions
// values.
//
// Column x holds, for actual class vs[x], the count of samples predicted as
// each value in vs, followed by the class-error column computed at the end.
//
func (cm *CM) ComputeNumeric(vs, actuals, predictions []int64) {
	cm.initByNumeric(vs)

	for x, act := range vs {
		col := cm.GetColumn(x)

		for _, pred := range vs {
			cnt := cm.countNumeric(act, pred, actuals, predictions)

			rec := tabula.NewRecordInt(cnt)
			col.PushBack(rec)
		}

		cm.PushColumnToRows(*col)
	}

	cm.computeClassError()
}
+
+//
+// create will initialize confusion matrix using value space.
+//
+func (cm *CM) init(valueSpace []string) {
+ var colTypes []int
+ var colNames []string
+
+ for _, v := range valueSpace {
+ colTypes = append(colTypes, tabula.TInteger)
+ colNames = append(colNames, v)
+ cm.rowNames = append(cm.rowNames, v)
+ }
+
+ // class error column
+ colTypes = append(colTypes, tabula.TReal)
+ colNames = append(colNames, "class_error")
+
+ cm.Dataset.Init(tabula.DatasetModeMatrix, colTypes, colNames)
+}
+
+//
+// countTargetPrediction will count and return number of true positive or false
+// positive in predictions using targets values.
+//
+func (cm *CM) countTargetPrediction(target, predict string,
+ targets, predictions []string,
+) (
+ cnt int64,
+) {
+ predictslen := len(predictions)
+
+ for x, v := range targets {
+ // In case out of range, where predictions length less than
+ // targets length.
+ if x >= predictslen {
+ break
+ }
+ if v != target {
+ continue
+ }
+ if predict == predictions[x] {
+ cnt++
+ }
+ }
+ return
+}
+
+//
+// countNumeric will count and return number of `pred` in predictions where
+// actuals value is `act`.
+//
+func (cm *CM) countNumeric(act, pred int64, actuals, predictions []int64) (
+ cnt int64,
+) {
+ // Find minimum length to mitigate out-of-range loop.
+ minlen := len(actuals)
+ if len(predictions) < minlen {
+ minlen = len(predictions)
+ }
+
+ for x := 0; x < minlen; x++ {
+ if actuals[x] != act {
+ continue
+ }
+ if predictions[x] != pred {
+ continue
+ }
+ cnt++
+ }
+ return cnt
+}
+
+//
+// computeClassError will compute the classification error in matrix.
+//
+func (cm *CM) computeClassError() {
+ var tp, fp int64
+
+ cm.nSamples = 0
+ cm.nFalse = 0
+
+ classcol := cm.GetNColumn() - 1
+ col := cm.GetColumnClassError()
+ rows := cm.GetDataAsRows()
+ for x, row := range *rows {
+ for y, cell := range *row {
+ if y == classcol {
+ break
+ }
+ if x == y {
+ tp = cell.Integer()
+ } else {
+ fp += cell.Integer()
+ }
+ }
+
+ nSamplePerRow := tp + fp
+ errv := float64(fp) / float64(nSamplePerRow)
+ col.PushBack(tabula.NewRecordReal(errv))
+
+ cm.nSamples += nSamplePerRow
+ cm.nTrue += tp
+ cm.nFalse += fp
+ }
+
+ cm.PushColumnToRows(*col)
+}
+
+//
+// GroupIndexPredictions given index of samples, group the samples by their
+// class of prediction. For example,
+//
+// sampleIds: [0, 1, 2, 3, 4, 5]
+// actuals: [1, 1, 0, 0, 1, 0]
+// predictions: [1, 0, 1, 0, 1, 1]
+//
+// This function will group the index by true-positive, false-positive,
+// true-negative, and false-negative, which result in,
+//
+// true-positive indices: [0, 4]
+// false-positive indices: [2, 5]
+// true-negative indices: [3]
+// false-negative indices: [1]
+//
+// This function assume that positive value as "1" and negative value as "0".
+//
+func (cm *CM) GroupIndexPredictions(sampleIds []int,
+ actuals, predictions []int64,
+) {
+ // Reset indices.
+ cm.tpIds = nil
+ cm.fpIds = nil
+ cm.tnIds = nil
+ cm.fnIds = nil
+
+ // Make sure we are not out-of-range when looping, always pick the
+ // minimum length between the three parameters.
+ min := len(sampleIds)
+ if len(actuals) < min {
+ min = len(actuals)
+ }
+ if len(predictions) < min {
+ min = len(predictions)
+ }
+
+ for x := 0; x < min; x++ {
+ if actuals[x] == 1 {
+ if predictions[x] == 1 {
+ cm.tpIds = append(cm.tpIds, sampleIds[x])
+ } else {
+ cm.fnIds = append(cm.fnIds, sampleIds[x])
+ }
+ } else {
+ if predictions[x] == 1 {
+ cm.fpIds = append(cm.fpIds, sampleIds[x])
+ } else {
+ cm.tnIds = append(cm.tnIds, sampleIds[x])
+ }
+ }
+ }
+}
+
+//
+// GroupIndexPredictionsStrings is an alternative to GroupIndexPredictions
+// which work with string class.
+//
+func (cm *CM) GroupIndexPredictionsStrings(sampleIds []int,
+ actuals, predictions []string,
+) {
+ if len(sampleIds) <= 0 {
+ return
+ }
+
+ // Reset indices.
+ cm.tpIds = nil
+ cm.fpIds = nil
+ cm.tnIds = nil
+ cm.fnIds = nil
+
+ // Make sure we are not out-of-range when looping, always pick the
+ // minimum length between the three parameters.
+ min := len(sampleIds)
+ if len(actuals) < min {
+ min = len(actuals)
+ }
+ if len(predictions) < min {
+ min = len(predictions)
+ }
+
+ for x := 0; x < min; x++ {
+ if actuals[x] == "1" {
+ if predictions[x] == "1" {
+ cm.tpIds = append(cm.tpIds, sampleIds[x])
+ } else {
+ cm.fnIds = append(cm.fnIds, sampleIds[x])
+ }
+ } else {
+ if predictions[x] == "1" {
+ cm.fpIds = append(cm.fpIds, sampleIds[x])
+ } else {
+ cm.tnIds = append(cm.tnIds, sampleIds[x])
+ }
+ }
+ }
+}
+
//
// GetColumnClassError return the last column which is the column that contain
// the error of classification ("class_error", filled by computeClassError).
//
func (cm *CM) GetColumnClassError() *tabula.Column {
	return cm.GetColumn(cm.GetNColumn() - 1)
}
+
//
// GetTrueRate return the rate of correctly classified samples in term of
//
//	true / (true + false)
//
// where "true" is the sum of the matrix diagonal (true-positives plus
// true-negatives) and "false" is the sum of all off-diagonal cells — see
// computeClassError.
//
func (cm *CM) GetTrueRate() float64 {
	return float64(cm.nTrue) / float64(cm.nTrue+cm.nFalse)
}
+
//
// GetFalseRate return the rate of miss-classified samples in term of,
//
//	false / (true + false)
//
// where "false" is the sum of all off-diagonal cells and "true" is the sum
// of the matrix diagonal — see computeClassError.
//
func (cm *CM) GetFalseRate() float64 {
	return float64(cm.nFalse) / float64(cm.nTrue+cm.nFalse)
}
+
//
// TP return number of true-positive in confusion matrix, taken from cell
// (0,0).  This assumes the positive class is the first value in the value
// space — see the tests in cm_test.
//
func (cm *CM) TP() int {
	row := cm.GetRow(0)
	if row == nil {
		return 0
	}

	v, _ := row.GetIntAt(0)
	return int(v)
}
+
//
// FP return number of false-positive in confusion matrix, taken from cell
// (0,1): samples of the second (negative) class predicted as positive.
//
func (cm *CM) FP() int {
	row := cm.GetRow(0)
	if row == nil {
		return 0
	}

	v, _ := row.GetIntAt(1)
	return int(v)
}
+
//
// FN return number of false-negative, taken from cell (1,0): samples of
// the first (positive) class predicted as negative.
//
func (cm *CM) FN() int {
	row := cm.GetRow(1)
	if row == nil {
		return 0
	}
	v, _ := row.GetIntAt(0)
	return int(v)
}
+
//
// TN return number of true-negative, taken from cell (1,1).
//
func (cm *CM) TN() int {
	row := cm.GetRow(1)
	if row == nil {
		return 0
	}
	v, _ := row.GetIntAt(1)
	return int(v)
}
+
//
// TPIndices return indices of all true-positive samples as grouped by the
// last call to GroupIndexPredictions or GroupIndexPredictionsStrings.
//
func (cm *CM) TPIndices() []int {
	return cm.tpIds
}
+
//
// FNIndices return indices of all false-negative samples as grouped by the
// last call to GroupIndexPredictions or GroupIndexPredictionsStrings.
//
func (cm *CM) FNIndices() []int {
	return cm.fnIds
}
+
//
// FPIndices return indices of all false-positive samples as grouped by the
// last call to GroupIndexPredictions or GroupIndexPredictionsStrings.
//
func (cm *CM) FPIndices() []int {
	return cm.fpIds
}
+
//
// TNIndices return indices of all true-negative samples as grouped by the
// last call to GroupIndexPredictions or GroupIndexPredictionsStrings.
//
func (cm *CM) TNIndices() []int {
	return cm.tnIds
}
+
+//
+// String will return the output of confusion matrix in table like format.
+//
+func (cm *CM) String() (s string) {
+ s += "Confusion Matrix:\n"
+
+ // Row header: column names.
+ s += "\t"
+ for _, col := range cm.GetColumnsName() {
+ s += col + "\t"
+ }
+ s += "\n"
+
+ rows := cm.GetDataAsRows()
+ for x, row := range *rows {
+ s += cm.rowNames[x] + "\t"
+
+ for _, v := range *row {
+ s += v.String() + "\t"
+ }
+ s += "\n"
+ }
+
+ s += fmt.Sprintf("TP-FP indices %d %d\n", len(cm.tpIds), len(cm.fpIds))
+ s += fmt.Sprintf("FN-TN indices %d %d\n", len(cm.fnIds), len(cm.tnIds))
+
+ return
+}
diff --git a/lib/mining/classifier/cm_test.go b/lib/mining/classifier/cm_test.go
new file mode 100644
index 00000000..6fd47d93
--- /dev/null
+++ b/lib/mining/classifier/cm_test.go
@@ -0,0 +1,69 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/shuLhan/share/lib/test"
+)
+
// TestComputeNumeric verifies the confusion matrix cells (TP, FN, TN, FP)
// computed from numeric actual/predicted values.
func TestComputeNumeric(t *testing.T) {
	actuals := []int64{1, 1, 1, 0, 0, 0, 0}
	predics := []int64{1, 1, 0, 0, 0, 0, 1}
	vs := []int64{1, 0}
	// Expected counts in order: TP, FN, TN, FP.
	exp := []int{2, 1, 3, 1}

	cm := &CM{}

	cm.ComputeNumeric(vs, actuals, predics)

	test.Assert(t, "", exp[0], cm.TP(), true)
	test.Assert(t, "", exp[1], cm.FN(), true)
	test.Assert(t, "", exp[2], cm.TN(), true)
	test.Assert(t, "", exp[3], cm.FP(), true)

	fmt.Println(cm)
}
+
// TestComputeStrings verifies the confusion matrix cells (TP, FN, TN, FP)
// computed from string actual/predicted class values.
func TestComputeStrings(t *testing.T) {
	actuals := []string{"1", "1", "1", "0", "0", "0", "0"}
	predics := []string{"1", "1", "0", "0", "0", "0", "1"}
	vs := []string{"1", "0"}
	// Expected counts in order: TP, FN, TN, FP.
	exp := []int{2, 1, 3, 1}

	cm := &CM{}

	cm.ComputeStrings(vs, actuals, predics)

	test.Assert(t, "", exp[0], cm.TP(), true)
	test.Assert(t, "", exp[1], cm.FN(), true)
	test.Assert(t, "", exp[2], cm.TN(), true)
	test.Assert(t, "", exp[3], cm.FP(), true)

	fmt.Println(cm)
}
+
// TestGroupIndexPredictions verifies that sample indices are grouped into
// true-positive, false-negative, false-positive, and true-negative sets.
func TestGroupIndexPredictions(t *testing.T) {
	testIds := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	actuals := []int64{1, 1, 1, 1, 0, 0, 0, 0, 0, 0}
	predics := []int64{1, 1, 0, 1, 0, 0, 0, 0, 1, 0}
	exp := [][]int{
		{0, 1, 3},       // tp
		{2},             // fn
		{8},             // fp
		{4, 5, 6, 7, 9}, // tn
	}

	cm := &CM{}

	cm.GroupIndexPredictions(testIds, actuals, predics)

	test.Assert(t, "", exp[0], cm.TPIndices(), true)
	test.Assert(t, "", exp[1], cm.FNIndices(), true)
	test.Assert(t, "", exp[2], cm.FPIndices(), true)
	test.Assert(t, "", exp[3], cm.TNIndices(), true)
}
diff --git a/lib/mining/classifier/crf/crf.go b/lib/mining/classifier/crf/crf.go
new file mode 100644
index 00000000..3df7dd83
--- /dev/null
+++ b/lib/mining/classifier/crf/crf.go
@@ -0,0 +1,470 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//
+// Package crf implement the cascaded random forest algorithm, proposed by
+// Baumann et.al in their paper:
+//
+// Baumann, Florian, et al. "Cascaded Random Forest for Fast Object
+// Detection." Image Analysis. Springer Berlin Heidelberg, 2013. 131-142.
+//
+//
+package crf
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "sort"
+
+ "github.com/shuLhan/share/lib/debug"
+ "github.com/shuLhan/share/lib/mining/classifier"
+ "github.com/shuLhan/share/lib/mining/classifier/rf"
+ "github.com/shuLhan/share/lib/numbers"
+ libstrings "github.com/shuLhan/share/lib/strings"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+const (
+ tag = "[crf]"
+
+ // DefStage default number of stage
+ DefStage = 200
+ // DefTPRate default threshold for true-positive rate.
+ DefTPRate = 0.9
+ // DefTNRate default threshold for true-negative rate.
+ DefTNRate = 0.7
+
+ // DefNumTree default number of tree.
+ DefNumTree = 1
+ // DefPercentBoot default percentage of sample that will be used for
+ // bootstrapping a tree.
+ DefPercentBoot = 66
+ // DefPerfFile default performance file output.
+ DefPerfFile = "crf.perf"
+ // DefStatFile default statistic file output.
+ DefStatFile = "crf.stat"
+)
+
+var (
+ // ErrNoInput will tell you when no input is given.
+ ErrNoInput = errors.New("rf: input samples is empty")
+)
+
+//
+// Runtime define the cascaded random forest runtime input and output.
+//
+type Runtime struct {
+ // Runtime embed common fields for classifier.
+ classifier.Runtime
+
+ // NStage number of stage.
+ NStage int `json:"NStage"`
+ // TPRate threshold for true positive rate per stage.
+ TPRate float64 `json:"TPRate"`
+ // TNRate threshold for true negative rate per stage.
+ TNRate float64 `json:"TNRate"`
+
+ // NTree number of tree in each stage.
+ NTree int `json:"NTree"`
+ // NRandomFeature number of features used to split the dataset.
+ NRandomFeature int `json:"NRandomFeature"`
+ // PercentBoot percentage of bootstrap.
+ PercentBoot int `json:"PercentBoot"`
+
+ // forests contain forest for each stage.
+ forests []*rf.Runtime
+ // weights contain weight for each stage.
+ weights []float64
+ // tnset contain sample of all true-negative in each iteration.
+ tnset *tabula.Claset
+}
+
+//
+// New create and return new input for cascaded random-forest.
+//
+func New(nstage, ntree, percentboot, nfeature int,
+ tprate, tnrate float64,
+ samples tabula.ClasetInterface,
+) (
+ crf *Runtime,
+) {
+ crf = &Runtime{
+ NStage: nstage,
+ NTree: ntree,
+ PercentBoot: percentboot,
+ NRandomFeature: nfeature,
+ TPRate: tprate,
+ TNRate: tnrate,
+ }
+
+ return crf
+}
+
+//
+// AddForest will append new forest.
+//
+func (crf *Runtime) AddForest(forest *rf.Runtime) {
+ crf.forests = append(crf.forests, forest)
+}
+
+//
+// Initialize will check crf inputs and set it to default values if its
+// invalid.
+//
+func (crf *Runtime) Initialize(samples tabula.ClasetInterface) error {
+ if crf.NStage <= 0 {
+ crf.NStage = DefStage
+ }
+ if crf.TPRate <= 0 || crf.TPRate >= 1 {
+ crf.TPRate = DefTPRate
+ }
+ if crf.TNRate <= 0 || crf.TNRate >= 1 {
+ crf.TNRate = DefTNRate
+ }
+ if crf.NTree <= 0 {
+ crf.NTree = DefNumTree
+ }
+ if crf.PercentBoot <= 0 {
+ crf.PercentBoot = DefPercentBoot
+ }
+ if crf.NRandomFeature <= 0 {
+ // Set default value to square-root of features.
+ ncol := samples.GetNColumn() - 1
+ crf.NRandomFeature = int(math.Sqrt(float64(ncol)))
+ }
+ if crf.PerfFile == "" {
+ crf.PerfFile = DefPerfFile
+ }
+ if crf.StatFile == "" {
+ crf.StatFile = DefStatFile
+ }
+ crf.tnset = samples.Clone().(*tabula.Claset)
+
+ return crf.Runtime.Initialize()
+}
+
+//
+// Build given a sample dataset, build the stage with randomforest.
+//
+func (crf *Runtime) Build(samples tabula.ClasetInterface) (e error) {
+ if samples == nil {
+ return ErrNoInput
+ }
+
+ e = crf.Initialize(samples)
+ if e != nil {
+ return
+ }
+
+ fmt.Println(tag, "Training samples:", samples)
+ fmt.Println(tag, "Sample (one row):", samples.GetRow(0))
+ fmt.Println(tag, "Config:", crf)
+
+ for x := 0; x < crf.NStage; x++ {
+ if debug.Value >= 1 {
+ fmt.Println(tag, "Stage #", x)
+ }
+
+ forest, e := crf.createForest(samples)
+ if e != nil {
+ return e
+ }
+
+ e = crf.finalizeStage(forest)
+ if e != nil {
+ return e
+ }
+ }
+
+ return crf.Finalize()
+}
+
+//
+// createForest will create and return a forest and run the training `samples`
+// on it.
+//
+// Algorithm,
+// (1) Initialize forest.
+// (2) For 0 to maximum number of tree in forest,
+// (2.1) grow one tree until success.
+// (2.2) If the tree's tp-rate and tn-rate are greater than threshold, stop growing.
+// (3) Calculate weight.
+// (4) TODO: Move true-negative from samples. The collection of true-negative
+// will be used again to test the model and after test and the sample with FP
+// will be moved to training samples again.
+// (5) Refill samples with false-positive.
+//
+func (crf *Runtime) createForest(samples tabula.ClasetInterface) (
+ forest *rf.Runtime, e error,
+) {
+ var cm *classifier.CM
+ var stat *classifier.Stat
+
+ fmt.Println(tag, "Forest samples:", samples)
+
+ // (1)
+ forest = &rf.Runtime{
+ Runtime: classifier.Runtime{
+ RunOOB: true,
+ },
+ NTree: crf.NTree,
+ NRandomFeature: crf.NRandomFeature,
+ }
+
+ e = forest.Initialize(samples)
+ if e != nil {
+ return nil, e
+ }
+
+ // (2)
+ for t := 0; t < crf.NTree; t++ {
+ if debug.Value >= 2 {
+ fmt.Println(tag, "Tree #", t)
+ }
+
+ // (2.1)
+ for {
+ cm, stat, e = forest.GrowTree(samples)
+ if e == nil {
+ break
+ }
+ }
+
+ // (2.2)
+ if stat.TPRate > crf.TPRate &&
+ stat.TNRate > crf.TNRate {
+ break
+ }
+ }
+
+ e = forest.Finalize()
+ if e != nil {
+ return nil, e
+ }
+
+ // (3)
+ crf.computeWeight(stat)
+
+ if debug.Value >= 1 {
+ fmt.Println(tag, "Weight:", stat.FMeasure)
+ }
+
+ // (4)
+ crf.deleteTrueNegative(samples, cm)
+
+ // (5)
+ crf.runTPSet(samples)
+
+ samples.RecountMajorMinor()
+
+ return forest, nil
+}
+
+//
+// finalizeStage save forest and write the forest statistic to file.
+//
+func (crf *Runtime) finalizeStage(forest *rf.Runtime) (e error) {
+ stat := forest.StatTotal()
+ stat.ID = int64(len(crf.forests))
+
+ e = crf.WriteOOBStat(stat)
+ if e != nil {
+ return e
+ }
+
+ crf.AddStat(stat)
+ crf.ComputeStatTotal(stat)
+
+ if debug.Value >= 1 {
+ crf.PrintStatTotal(nil)
+ }
+
+ // (7)
+ crf.AddForest(forest)
+
+ return nil
+}
+
+//
+// computeWeight will compute the weight of stage based on F-measure of the
+// last tree in forest.
+//
+func (crf *Runtime) computeWeight(stat *classifier.Stat) {
+ crf.weights = append(crf.weights, math.Exp(stat.FMeasure))
+}
+
+//
+// deleteTrueNegative will delete all samples data where their row index is in
+// true-negative values in confusion matrix and move it to TN-set.
+//
+// (1) Move true negative to tnset on the first iteration, on the next
+// iteration it will be fully deleted.
+// (2) Delete TN from sample set one-by-one with offset, to make sure we
+// are not deleting with wrong index.
+
+func (crf *Runtime) deleteTrueNegative(samples tabula.ClasetInterface,
+ cm *classifier.CM,
+) {
+ var row *tabula.Row
+
+ tnids := cm.TNIndices()
+ sort.Ints(tnids)
+
+ // (1)
+ if len(crf.weights) <= 1 {
+ for _, i := range tnids {
+ crf.tnset.PushRow(samples.GetRow(i))
+ }
+ }
+
+ // (2)
+ c := 0
+ for x, i := range tnids {
+ row = samples.DeleteRow(i - x)
+ if row != nil {
+ c++
+ }
+ }
+
+ if debug.Value >= 1 {
+ fmt.Println(tag, "# TN", len(tnids), "# deleted", c)
+ }
+}
+
+//
+// refillWithFP will copy the false-positive data in the TN-set `tnset`
+// and append it to `samples`.
+//
+func (crf *Runtime) refillWithFP(samples, tnset tabula.ClasetInterface,
+ cm *classifier.CM,
+) {
+ // Get and sort FP.
+ fpids := cm.FPIndices()
+ sort.Ints(fpids)
+
+ // Move FP samples from TN-set to training set samples.
+ for _, i := range fpids {
+ samples.PushRow(tnset.GetRow(i))
+ }
+
+ // Delete FP from TN-set.
+ var row *tabula.Row
+ c := 0
+ for x, i := range fpids {
+ row = tnset.DeleteRow(i - x)
+ if row != nil {
+ c++
+ }
+ }
+
+ if debug.Value >= 1 {
+ fmt.Println(tag, "# FP", len(fpids), "# refilled", c)
+ }
+}
+
+//
+// runTPSet will run true-positive set into trained stage, to get the
+// false-positive. The FP samples will be added to training set.
+//
+func (crf *Runtime) runTPSet(samples tabula.ClasetInterface) {
+ // Skip the first stage, because we just got tnset from them.
+ if len(crf.weights) <= 1 {
+ return
+ }
+
+ tnIds := numbers.IntCreateSeq(0, crf.tnset.Len()-1)
+ _, cm, _ := crf.ClassifySetByWeight(crf.tnset, tnIds)
+
+ crf.refillWithFP(samples, crf.tnset, cm)
+}
+
+//
+// ClassifySetByWeight will classify each instance in samples by weight
+// with respect to its single performance.
+//
+// Algorithm,
+// (1) For each instance in samples,
+// (1.1) for each stage,
+// (1.1.1) collect votes for instance in current stage.
+// (1.1.2) Compute probabilities of each class in votes.
+//
+// prob_class = count_of_class / total_votes
+//
+// (1.1.3) Compute total of probabilities times the stage weight.
+//
+// stage_prob = prob_class * stage_weight
+//
+// (1.2) Divide each class stage probabilities with
+//
+// stage_prob = stage_prob /
+// (sum_of_all_weights * number_of_tree_in_forest)
+//
+// (1.3) Select class label with highest probabilities.
+// (1.4) Save stage probabilities for positive class.
+// (2) Compute confusion matrix.
+//
+func (crf *Runtime) ClassifySetByWeight(samples tabula.ClasetInterface,
+ sampleIds []int,
+) (
+ predicts []string, cm *classifier.CM, probs []float64,
+) {
+ stat := classifier.Stat{}
+ stat.Start()
+
+ vs := samples.GetClassValueSpace()
+ stageProbs := make([]float64, len(vs))
+ stageSumProbs := make([]float64, len(vs))
+ sumWeights := numbers.Floats64Sum(crf.weights)
+
+ // (1)
+ rows := samples.GetDataAsRows()
+ for _, row := range *rows {
+ for y := range stageSumProbs {
+ stageSumProbs[y] = 0
+ }
+
+ // (1.1)
+ for y, forest := range crf.forests {
+ // (1.1.1)
+ votes := forest.Votes(row, -1)
+
+ // (1.1.2)
+ probs := libstrings.FrequencyOfTokens(votes, vs, false)
+
+ // (1.1.3)
+ for z := range probs {
+ stageSumProbs[z] += probs[z]
+ stageProbs[z] += probs[z] * crf.weights[y]
+ }
+ }
+
+ // (1.2)
+ stageWeight := sumWeights * float64(crf.NTree)
+
+ for x := range stageProbs {
+ stageProbs[x] = stageProbs[x] / stageWeight
+ }
+
+ // (1.3)
+ _, maxi, ok := numbers.Floats64FindMax(stageProbs)
+ if ok {
+ predicts = append(predicts, vs[maxi])
+ }
+
+ probs = append(probs, stageSumProbs[0]/
+ float64(len(crf.forests)))
+ }
+
+ // (2)
+ actuals := samples.GetClassAsStrings()
+ cm = crf.ComputeCM(sampleIds, vs, actuals, predicts)
+
+ crf.ComputeStatFromCM(&stat, cm)
+ stat.End()
+
+ _ = stat.Write(crf.StatFile)
+
+ return predicts, cm, probs
+}
diff --git a/lib/mining/classifier/crf/crf_test.go b/lib/mining/classifier/crf/crf_test.go
new file mode 100644
index 00000000..4e530591
--- /dev/null
+++ b/lib/mining/classifier/crf/crf_test.go
@@ -0,0 +1,92 @@
+// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package crf
+
+import (
+ "fmt"
+ "os"
+ "testing"
+
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/mining/classifier"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+var (
+ SampleFile string
+ PerfFile string
+ StatFile string
+ NStage = 200
+ NTree = 1
+)
+
+func runCRF(t *testing.T) {
+ // read trainingset.
+ samples := tabula.Claset{}
+ _, e := dsv.SimpleRead(SampleFile, &samples)
+ if e != nil {
+ t.Fatal(e)
+ }
+
+ nbag := (samples.Len() * 63) / 100
+ train, test, _, testIds := tabula.RandomPickRows(&samples, nbag, false)
+
+ trainset := train.(tabula.ClasetInterface)
+ testset := test.(tabula.ClasetInterface)
+
+ crfRuntime := Runtime{
+ Runtime: classifier.Runtime{
+ StatFile: StatFile,
+ PerfFile: PerfFile,
+ },
+ NStage: NStage,
+ NTree: NTree,
+ }
+
+ e = crfRuntime.Build(trainset)
+ if e != nil {
+ t.Fatal(e)
+ }
+
+ testset.RecountMajorMinor()
+ fmt.Println("Testset:", testset)
+
+ predicts, cm, probs := crfRuntime.ClassifySetByWeight(testset, testIds)
+
+ fmt.Println("Confusion matrix:", cm)
+
+ crfRuntime.Performance(testset, predicts, probs)
+ e = crfRuntime.WritePerformance()
+ if e != nil {
+ t.Fatal(e)
+ }
+}
+
+func TestPhoneme200_1(t *testing.T) {
+ SampleFile = "../../testdata/phoneme/phoneme.dsv"
+ PerfFile = "phoneme_200_1.perf"
+ StatFile = "phoneme_200_1.stat"
+
+ runCRF(t)
+}
+
+func TestPhoneme200_10(t *testing.T) {
+ SampleFile = "../../testdata/phoneme/phoneme.dsv"
+ PerfFile = "phoneme_200_10.perf"
+ StatFile = "phoneme_200_10.stat"
+ NTree = 10
+
+ runCRF(t)
+}
+
+func TestMain(m *testing.M) {
+ envTestCRF := os.Getenv("TEST_CRF")
+
+ if len(envTestCRF) == 0 {
+ os.Exit(0)
+ }
+
+ os.Exit(m.Run())
+}
diff --git a/lib/mining/classifier/rf/rf.go b/lib/mining/classifier/rf/rf.go
new file mode 100644
index 00000000..86595efd
--- /dev/null
+++ b/lib/mining/classifier/rf/rf.go
@@ -0,0 +1,363 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//
+// Package rf implement ensemble of classifiers using random forest
+// algorithm by Breiman and Cutler.
+//
+// Breiman, Leo. "Random forests." Machine learning 45.1 (2001): 5-32.
+//
+// The implementation is based on various sources and using author experience.
+//
+package rf
+
+import (
+ "errors"
+ "fmt"
+ "math"
+
+ "github.com/shuLhan/share/lib/debug"
+ "github.com/shuLhan/share/lib/mining/classifier"
+ "github.com/shuLhan/share/lib/mining/classifier/cart"
+ "github.com/shuLhan/share/lib/numbers"
+ libstrings "github.com/shuLhan/share/lib/strings"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+const (
+ tag = "[rf]"
+
+ // DefNumTree default number of tree.
+ DefNumTree = 100
+
+ // DefPercentBoot default percentage of sample that will be used for
+ // bootstrapping a tree.
+ DefPercentBoot = 66
+
+ // DefOOBStatsFile default statistic file output.
+ DefOOBStatsFile = "rf.oob.stat"
+
+ // DefPerfFile default performance file output.
+ DefPerfFile = "rf.perf"
+
+ // DefStatFile default statistic file.
+ DefStatFile = "rf.stat"
+)
+
+var (
+ // ErrNoInput will tell you when no input is given.
+ ErrNoInput = errors.New("rf: input samples is empty")
+)
+
+//
+// Runtime contains input and output configuration when generating random forest.
+//
+type Runtime struct {
+ // Runtime embed common fields for classifier.
+ classifier.Runtime
+
+ // NTree number of tree in forest.
+ NTree int `json:"NTree"`
+ // NRandomFeature number of feature randomly selected for each tree.
+ NRandomFeature int `json:"NRandomFeature"`
+ // PercentBoot percentage of sample for bootstrapping.
+ PercentBoot int `json:"PercentBoot"`
+
+ // nSubsample number of samples used for bootstrapping.
+ nSubsample int
+ // trees contain all tree in the forest.
+ trees []cart.Runtime
+ // bagIndices contain list of index of selected samples at bootstraping
+ // for book-keeping.
+ bagIndices [][]int
+}
+
+//
+// Trees return all tree in forest.
+//
+func (forest *Runtime) Trees() []cart.Runtime {
+ return forest.trees
+}
+
+//
+// AddCartTree add tree to forest
+//
+func (forest *Runtime) AddCartTree(tree cart.Runtime) {
+ forest.trees = append(forest.trees, tree)
+}
+
+//
+// AddBagIndex add bagging index for book keeping.
+//
+func (forest *Runtime) AddBagIndex(bagIndex []int) {
+ forest.bagIndices = append(forest.bagIndices, bagIndex)
+}
+
+//
+// Initialize will check forest inputs and set it to default values if invalid.
+//
+// It will also calculate number of random samples for each tree using,
+//
+// number-of-sample * percentage-of-bootstrap
+//
+//
+func (forest *Runtime) Initialize(samples tabula.ClasetInterface) error {
+ if forest.NTree <= 0 {
+ forest.NTree = DefNumTree
+ }
+ if forest.PercentBoot <= 0 {
+ forest.PercentBoot = DefPercentBoot
+ }
+ if forest.NRandomFeature <= 0 {
+ // Set default value to square-root of features.
+ ncol := samples.GetNColumn() - 1
+ forest.NRandomFeature = int(math.Sqrt(float64(ncol)))
+ }
+ if forest.OOBStatsFile == "" {
+ forest.OOBStatsFile = DefOOBStatsFile
+ }
+ if forest.PerfFile == "" {
+ forest.PerfFile = DefPerfFile
+ }
+ if forest.StatFile == "" {
+ forest.StatFile = DefStatFile
+ }
+
+ forest.nSubsample = int(float32(samples.GetNRow()) *
+ (float32(forest.PercentBoot) / 100.0))
+
+ return forest.Runtime.Initialize()
+}
+
+//
+// Build the forest using samples dataset.
+//
+// Algorithm,
+//
+// (0) Recheck input value: number of tree, percentage bootstrap, etc; and
+// Open statistic file output.
+// (1) For 0 to NTree,
+// (1.1) Create new tree, repeat until all trees has been build.
+// (2) Compute and write total statistic.
+//
+func (forest *Runtime) Build(samples tabula.ClasetInterface) (e error) {
+ // check input samples
+ if samples == nil {
+ return ErrNoInput
+ }
+
+ // (0)
+ e = forest.Initialize(samples)
+ if e != nil {
+ return
+ }
+
+ fmt.Println(tag, "Training set :", samples)
+ fmt.Println(tag, "Sample (one row):", samples.GetRow(0))
+ fmt.Println(tag, "Forest config :", forest)
+
+ // (1)
+ for t := 0; t < forest.NTree; t++ {
+ if debug.Value >= 1 {
+ fmt.Println(tag, "tree #", t)
+ }
+
+ // (1.1)
+ for {
+ _, _, e = forest.GrowTree(samples)
+ if e == nil {
+ break
+ }
+
+ fmt.Println(tag, "error:", e)
+ }
+ }
+
+ // (2)
+ return forest.Finalize()
+}
+
+//
+// GrowTree build a new tree in forest, return OOB error value or error if tree
+// can not grow.
+//
+// Algorithm,
+//
+// (1) Select random samples with replacement, also with OOB.
+// (2) Build tree using CART, without pruning.
+// (3) Add tree to forest.
+// (4) Save index of random samples for calculating error rate later.
+// (5) Run OOB on forest.
+// (6) Calculate OOB error rate and statistic values.
+//
+func (forest *Runtime) GrowTree(samples tabula.ClasetInterface) (
+ cm *classifier.CM, stat *classifier.Stat, e error,
+) {
+ stat = &classifier.Stat{}
+ stat.ID = int64(len(forest.trees))
+ stat.Start()
+
+ // (1)
+ bag, oob, bagIdx, oobIdx := tabula.RandomPickRows(
+ samples.(tabula.DatasetInterface),
+ forest.nSubsample, true)
+
+ bagset := bag.(tabula.ClasetInterface)
+
+ if debug.Value >= 2 {
+ bagset.RecountMajorMinor()
+ fmt.Println(tag, "Bagging:", bagset)
+ }
+
+ // (2)
+ cart, e := cart.New(bagset, cart.SplitMethodGini,
+ forest.NRandomFeature)
+ if e != nil {
+ return nil, nil, e
+ }
+
+ // (3)
+ forest.AddCartTree(*cart)
+
+ // (4)
+ forest.AddBagIndex(bagIdx)
+
+ // (5)
+ if forest.RunOOB {
+ oobset := oob.(tabula.ClasetInterface)
+ _, cm, _ = forest.ClassifySet(oobset, oobIdx)
+
+ forest.AddOOBCM(cm)
+ }
+
+ stat.End()
+
+ if debug.Value >= 3 && forest.RunOOB {
+ fmt.Println(tag, "Elapsed time (s):", stat.ElapsedTime)
+ }
+
+ forest.AddStat(stat)
+
+ // (6)
+ if forest.RunOOB {
+ forest.ComputeStatFromCM(stat, cm)
+
+ if debug.Value >= 2 {
+ fmt.Println(tag, "OOB stat:", stat)
+ }
+ }
+
+ forest.ComputeStatTotal(stat)
+ e = forest.WriteOOBStat(stat)
+
+ return cm, stat, e
+}
+
+//
+// ClassifySet given samples, predict their class by running each sample in the
+// forest, and return their class prediction with confusion matrix.
+// `samples` is the sample that will be predicted, `sampleIds` is the index of
+// samples.
+// If `sampleIds` is not nil, then sample index will be checked in each tree,
+// if the sample is used for training, their vote is not counted.
+//
+// Algorithm,
+//
+// (0) Get value space (possible class values in dataset)
+// (1) For each row in test-set,
+// (1.1) collect votes in all trees,
+// (1.2) select majority class vote, and
+// (1.3) compute and save the actual class probabilities.
+// (2) Compute confusion matrix from predictions.
+// (3) Compute stat from confusion matrix.
+// (4) Write the stat to file only if sampleIds is empty, which mean its run
+// not from OOB set.
+//
+func (forest *Runtime) ClassifySet(samples tabula.ClasetInterface,
+ sampleIds []int,
+) (
+ predicts []string, cm *classifier.CM, probs []float64,
+) {
+ stat := classifier.Stat{}
+ stat.Start()
+
+ if len(sampleIds) <= 0 {
+ fmt.Println(tag, "Classify set:", samples)
+ fmt.Println(tag, "Classify set sample (one row):",
+ samples.GetRow(0))
+ }
+
+ // (0)
+ vs := samples.GetClassValueSpace()
+ actuals := samples.GetClassAsStrings()
+ sampleIdx := -1
+
+ // (1)
+ rows := samples.GetRows()
+ for x, row := range *rows {
+ // (1.1)
+ if len(sampleIds) > 0 {
+ sampleIdx = sampleIds[x]
+ }
+ votes := forest.Votes(row, sampleIdx)
+
+ // (1.2)
+ classProbs := libstrings.FrequencyOfTokens(votes, vs, false)
+
+ _, idx, ok := numbers.Floats64FindMax(classProbs)
+
+ if ok {
+ predicts = append(predicts, vs[idx])
+ }
+
+ // (1.3)
+ probs = append(probs, classProbs[0])
+ }
+
+ // (2)
+ cm = forest.ComputeCM(sampleIds, vs, actuals, predicts)
+
+ // (3)
+ forest.ComputeStatFromCM(&stat, cm)
+ stat.End()
+
+ if len(sampleIds) <= 0 {
+ fmt.Println(tag, "CM:", cm)
+ fmt.Println(tag, "Classifying stat:", stat)
+ _ = stat.Write(forest.StatFile)
+ }
+
+ return predicts, cm, probs
+}
+
+//
+// Votes will return votes, or classes, in each tree based on sample.
+// If `sampleIdx` is non-negative then it will be checked whether it has been
+// used when training the tree; if it was, that tree's vote will be skipped.
+//
+// (1) If row is used to build the tree then skip it,
+// (2) classify row in tree,
+// (3) save tree class value.
+//
+func (forest *Runtime) Votes(sample *tabula.Row, sampleIdx int) (
+ votes []string,
+) {
+ for x, tree := range forest.trees {
+ // (1)
+ if sampleIdx >= 0 {
+ exist := numbers.IntsIsExist(forest.bagIndices[x],
+ sampleIdx)
+ if exist {
+ continue
+ }
+ }
+
+ // (2)
+ class := tree.Classify(sample)
+
+ // (3)
+ votes = append(votes, class)
+ }
+ return votes
+}
diff --git a/lib/mining/classifier/rf/rf_bench_test.go b/lib/mining/classifier/rf/rf_bench_test.go
new file mode 100644
index 00000000..eddd721c
--- /dev/null
+++ b/lib/mining/classifier/rf/rf_bench_test.go
@@ -0,0 +1,22 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rf
+
+import (
+ "testing"
+)
+
+func BenchmarkPhoneme(b *testing.B) {
+ SampleDsvFile = "../../testdata/phoneme/phoneme.dsv"
+ OOBStatsFile = "phoneme.oob"
+ PerfFile = "phoneme.perf"
+
+ MinFeature = 3
+ MaxFeature = 4
+
+ for x := 0; x < b.N; x++ {
+ runRandomForest()
+ }
+}
diff --git a/lib/mining/classifier/rf/rf_test.go b/lib/mining/classifier/rf/rf_test.go
new file mode 100644
index 00000000..ba96125b
--- /dev/null
+++ b/lib/mining/classifier/rf/rf_test.go
@@ -0,0 +1,190 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rf
+
+import (
+ "fmt"
+ "log"
+ "os"
+ "testing"
+
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/mining/classifier"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+// Global options to run for each test.
+var (
+ // SampleDsvFile is the file that contain samples config.
+ SampleDsvFile string
+ // DoTest if it is true then the dataset will be split into training and
+ // test set with random selection without replacement.
+ DoTest = false
+ // NTree number of tree to generate.
+ NTree = 100
+ // NBootstrap percentage of sample used as subsample.
+ NBootstrap = 66
+ // MinFeature number of feature to begin with.
+ MinFeature = 1
+ // MaxFeature maximum number of feature to test
+ MaxFeature = -1
+ // RunOOB if it is true then the OOB samples will be used to test the
+ // model in each iteration.
+ RunOOB = true
+ // OOBStatsFile is the file where OOB statistic will be saved.
+ OOBStatsFile string
+ // PerfFile is the file where performance statistic will be saved.
+ PerfFile string
+ // StatFile is the file where classifying statistic will be saved.
+ StatFile string
+)
+
+func getSamples() (train, test tabula.ClasetInterface) {
+ samples := tabula.Claset{}
+ _, e := dsv.SimpleRead(SampleDsvFile, &samples)
+ if nil != e {
+ log.Fatal(e)
+ }
+
+ if !DoTest {
+ return &samples, nil
+ }
+
+ ntrain := int(float32(samples.Len()) * (float32(NBootstrap) / 100.0))
+
+ bag, oob, _, _ := tabula.RandomPickRows(&samples, ntrain, false)
+
+ train = bag.(tabula.ClasetInterface)
+ test = oob.(tabula.ClasetInterface)
+
+ train.SetClassIndex(samples.GetClassIndex())
+ test.SetClassIndex(samples.GetClassIndex())
+
+ return train, test
+}
+
+func runRandomForest() {
+ oobStatsFile := OOBStatsFile
+ perfFile := PerfFile
+ statFile := StatFile
+
+ trainset, testset := getSamples()
+
+ if MaxFeature < 0 {
+ MaxFeature = trainset.GetNColumn()
+ }
+
+ for nfeature := MinFeature; nfeature < MaxFeature; nfeature++ {
+ // Add prefix to OOB stats file.
+ oobStatsFile = fmt.Sprintf("N%d.%s", nfeature, OOBStatsFile)
+
+ // Add prefix to performance file.
+ perfFile = fmt.Sprintf("N%d.%s", nfeature, PerfFile)
+
+ // Add prefix to stat file.
+ statFile = fmt.Sprintf("N%d.%s", nfeature, StatFile)
+
+ // Create and build random forest.
+ forest := Runtime{
+ Runtime: classifier.Runtime{
+ RunOOB: RunOOB,
+ OOBStatsFile: oobStatsFile,
+ PerfFile: perfFile,
+ StatFile: statFile,
+ },
+ NTree: NTree,
+ NRandomFeature: nfeature,
+ PercentBoot: NBootstrap,
+ }
+
+ e := forest.Build(trainset)
+ if e != nil {
+ log.Fatal(e)
+ }
+
+ if DoTest {
+ predicts, _, probs := forest.ClassifySet(testset, nil)
+
+ forest.Performance(testset, predicts, probs)
+ e = forest.WritePerformance()
+ if e != nil {
+ log.Fatal(e)
+ }
+ }
+ }
+}
+
+func TestEnsemblingGlass(t *testing.T) {
+ SampleDsvFile = "../../testdata/forensic_glass/fgl.dsv"
+ RunOOB = false
+ OOBStatsFile = "glass.oob"
+ StatFile = "glass.stat"
+ PerfFile = "glass.perf"
+ DoTest = true
+
+ runRandomForest()
+}
+
+func TestEnsemblingIris(t *testing.T) {
+ SampleDsvFile = "../../testdata/iris/iris.dsv"
+ OOBStatsFile = "iris.oob"
+
+ runRandomForest()
+}
+
+func TestEnsemblingPhoneme(t *testing.T) {
+ SampleDsvFile = "../../testdata/phoneme/phoneme.dsv"
+ OOBStatsFile = "phoneme.oob.stat"
+ StatFile = "phoneme.stat"
+ PerfFile = "phoneme.perf"
+
+ NTree = 200
+ MinFeature = 3
+ MaxFeature = 4
+ RunOOB = false
+ DoTest = true
+
+ runRandomForest()
+}
+
+func TestEnsemblingSmotePhoneme(t *testing.T) {
+ SampleDsvFile = "../../resampling/smote/phoneme_smote.dsv"
+ OOBStatsFile = "phonemesmote.oob"
+
+ MinFeature = 3
+ MaxFeature = 4
+
+ runRandomForest()
+}
+
+func TestEnsemblingLnsmotePhoneme(t *testing.T) {
+ SampleDsvFile = "../../resampling/lnsmote/phoneme_lnsmote.dsv"
+ OOBStatsFile = "phonemelnsmote.oob"
+
+ MinFeature = 3
+ MaxFeature = 4
+
+ runRandomForest()
+}
+
+func TestWvc2010Lnsmote(t *testing.T) {
+ SampleDsvFile = "../../testdata/wvc2010lnsmote/wvc2010_features.lnsmote.dsv"
+ OOBStatsFile = "wvc2010lnsmote.oob"
+
+ NTree = 1
+ MinFeature = 5
+ MaxFeature = 6
+
+ runRandomForest()
+}
+
+func TestMain(m *testing.M) {
+ envTestRF := os.Getenv("TEST_RF")
+ if len(envTestRF) == 0 {
+ os.Exit(0)
+ }
+
+ os.Exit(m.Run())
+}
diff --git a/lib/mining/classifier/runtime.go b/lib/mining/classifier/runtime.go
new file mode 100644
index 00000000..0f7a7755
--- /dev/null
+++ b/lib/mining/classifier/runtime.go
@@ -0,0 +1,443 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+ "fmt"
+ "math"
+
+ "github.com/shuLhan/share/lib/debug"
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/numbers"
+ libstrings "github.com/shuLhan/share/lib/strings"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+const (
+	// tag is the prefix written on debug and statistic output of this
+	// package.
+	tag = "[classifier.runtime]"
+)
+
+//
+// Runtime define a generic type which provide common fields that can be
+// embedded by the real classifier (e.g. RandomForest).
+//
+// NOTE: Runtime contains no locking; synchronize externally if an instance
+// is shared between goroutines.
+//
+type Runtime struct {
+	// RunOOB if its true the OOB will be computed, default is false.
+	RunOOB bool `json:"RunOOB"`
+
+	// OOBStatsFile is the file where OOB statistic will be written.
+	OOBStatsFile string `json:"OOBStatsFile"`
+
+	// PerfFile is the file where statistic of performance will be written.
+	PerfFile string `json:"PerfFile"`
+
+	// StatFile is the file where statistic of classifying samples will be
+	// written.
+	StatFile string `json:"StatFile"`
+
+	// oobCms contain confusion matrix value for each OOB in iteration.
+	oobCms []CM
+
+	// oobStats contain statistic of classifier for each OOB in iteration.
+	oobStats Stats
+
+	// oobStatTotal contain total OOB statistic values.
+	oobStatTotal Stat
+
+	// oobWriter contain file writer for statistic.
+	oobWriter *dsv.Writer
+
+	// perfs contain performance statistic per sample, after classifying
+	// sample on classifier.
+	perfs Stats
+}
+
+//
+// Initialize will start the runtime for processing by saving start time and
+// opening stats file.
+//
+func (rt *Runtime) Initialize() error {
+	// Record the start time on the aggregate (total) statistic.
+	rt.oobStatTotal.Start()
+
+	return rt.OpenOOBStatsFile()
+}
+
+//
+// Finalize finish the runtime, compute total statistic, write it to file, and
+// close the file.
+//
+func (rt *Runtime) Finalize() (e error) {
+	st := &rt.oobStatTotal
+
+	// Stamp the end time, and use the number of collected OOB statistics
+	// as the ID of the total row.
+	st.End()
+	st.ID = int64(len(rt.oobStats))
+
+	e = rt.WriteOOBStat(st)
+	if e != nil {
+		return e
+	}
+
+	return rt.CloseOOBStatsFile()
+}
+
+//
+// OOBStats return all statistic objects.
+//
+func (rt *Runtime) OOBStats() *Stats {
+	return &rt.oobStats
+}
+
+//
+// StatTotal return total statistic.
+//
+func (rt *Runtime) StatTotal() *Stat {
+	return &rt.oobStatTotal
+}
+
+//
+// AddOOBCM will append new confusion matrix.
+// The CM is stored by value, so later mutation of `cm` by the caller does
+// not affect the stored copy.
+//
+func (rt *Runtime) AddOOBCM(cm *CM) {
+	rt.oobCms = append(rt.oobCms, *cm)
+}
+
+//
+// AddStat will append new classifier statistic data.
+// The stat is stored as a pointer; the caller must not reuse it afterwards.
+//
+func (rt *Runtime) AddStat(stat *Stat) {
+	rt.oobStats = append(rt.oobStats, stat)
+}
+
+//
+// ComputeCM will compute confusion matrix of sample using value space, actual
+// and prediction values.
+//
+// `sampleIds` presumably maps each actual/predict pair back to its sample
+// index, for grouping inside the CM -- confirm against the CM type.
+//
+func (rt *Runtime) ComputeCM(sampleIds []int,
+	vs, actuals, predicts []string,
+) (
+	cm *CM,
+) {
+	cm = &CM{}
+
+	cm.ComputeStrings(vs, actuals, predicts)
+	cm.GroupIndexPredictionsStrings(sampleIds, actuals, predicts)
+
+	// Dump the matrix on higher debug levels.
+	if debug.Value >= 2 {
+		fmt.Println(tag, cm)
+	}
+
+	return cm
+}
+
+//
+// ComputeStatFromCM will compute statistic using confusion matrix.
+//
+// Derived rates: TPRate (recall) = TP/(TP+FN), FPRate = FP/(FP+TN),
+// TNRate = TN/(TN+FP), Precision = TP/(TP+FP), FMeasure = harmonic mean of
+// precision and recall, Accuracy = (TP+TN)/all. Each denominator is guarded
+// against zero.
+//
+func (rt *Runtime) ComputeStatFromCM(stat *Stat, cm *CM) {
+
+	stat.OobError = cm.GetFalseRate()
+
+	// Mean OOB error: running total divided by the number of iterations
+	// including the current one (hence the +1).
+	// NOTE(review): assumes ComputeStatTotal has not yet added the
+	// current OobError into oobStatTotal -- confirm call order.
+	stat.OobErrorMean = rt.oobStatTotal.OobError /
+		float64(len(rt.oobStats)+1)
+
+	stat.TP = int64(cm.TP())
+	stat.FP = int64(cm.FP())
+	stat.TN = int64(cm.TN())
+	stat.FN = int64(cm.FN())
+
+	t := float64(stat.TP + stat.FN)
+	if t == 0 {
+		stat.TPRate = 0
+	} else {
+		stat.TPRate = float64(stat.TP) / t
+	}
+
+	t = float64(stat.FP + stat.TN)
+	if t == 0 {
+		stat.FPRate = 0
+	} else {
+		stat.FPRate = float64(stat.FP) / t
+	}
+
+	// TNRate shares the FP+TN denominator with FPRate.
+	t = float64(stat.FP + stat.TN)
+	if t == 0 {
+		stat.TNRate = 0
+	} else {
+		stat.TNRate = float64(stat.TN) / t
+	}
+
+	t = float64(stat.TP + stat.FP)
+	if t == 0 {
+		stat.Precision = 0
+	} else {
+		stat.Precision = float64(stat.TP) / t
+	}
+
+	// When Precision or TPRate is zero its reciprocal is +Inf, so the
+	// guard never triggers and FMeasure becomes 2/+Inf = 0.
+	t = (1 / stat.Precision) + (1 / stat.TPRate)
+	if t == 0 {
+		stat.FMeasure = 0
+	} else {
+		stat.FMeasure = 2 / t
+	}
+
+	t = float64(stat.TP + stat.TN + stat.FP + stat.FN)
+	if t == 0 {
+		stat.Accuracy = 0
+	} else {
+		stat.Accuracy = float64(stat.TP+stat.TN) / t
+	}
+
+	if debug.Value >= 1 {
+		rt.PrintOobStat(stat, cm)
+		rt.PrintStat(stat)
+	}
+}
+
+//
+// ComputeStatTotal compute total statistic.
+//
+// It accumulates the raw counters (OobError, TP, FP, TN, FN) from `stat`
+// into the running total, then recomputes every derived rate from the
+// accumulated counters. No-op when `stat` is nil or no OOB statistic has
+// been collected yet.
+//
+func (rt *Runtime) ComputeStatTotal(stat *Stat) {
+	if stat == nil {
+		return
+	}
+
+	nstat := len(rt.oobStats)
+	if nstat == 0 {
+		return
+	}
+
+	t := &rt.oobStatTotal
+
+	t.OobError += stat.OobError
+	t.OobErrorMean = t.OobError / float64(nstat)
+	t.TP += stat.TP
+	t.FP += stat.FP
+	t.TN += stat.TN
+	t.FN += stat.FN
+
+	total := float64(t.TP + t.FN)
+	if total == 0 {
+		t.TPRate = 0
+	} else {
+		t.TPRate = float64(t.TP) / total
+	}
+
+	total = float64(t.FP + t.TN)
+	if total == 0 {
+		t.FPRate = 0
+	} else {
+		t.FPRate = float64(t.FP) / total
+	}
+
+	// TNRate shares the FP+TN denominator with FPRate.
+	total = float64(t.FP + t.TN)
+	if total == 0 {
+		t.TNRate = 0
+	} else {
+		t.TNRate = float64(t.TN) / total
+	}
+
+	total = float64(t.TP + t.FP)
+	if total == 0 {
+		t.Precision = 0
+	} else {
+		t.Precision = float64(t.TP) / total
+	}
+
+	// Harmonic mean; zero precision or recall yields +Inf reciprocals,
+	// so FMeasure evaluates to 2/+Inf = 0.
+	total = (1 / t.Precision) + (1 / t.TPRate)
+	if total == 0 {
+		t.FMeasure = 0
+	} else {
+		t.FMeasure = 2 / total
+	}
+
+	total = float64(t.TP + t.TN + t.FP + t.FN)
+	if total == 0 {
+		t.Accuracy = 0
+	} else {
+		t.Accuracy = float64(t.TP+t.TN) / total
+	}
+}
+
+//
+// OpenOOBStatsFile will open statistic file for output.
+// Any previously open writer is closed first, so repeated calls are safe.
+//
+func (rt *Runtime) OpenOOBStatsFile() error {
+	if rt.oobWriter != nil {
+		_ = rt.CloseOOBStatsFile()
+	}
+	rt.oobWriter = &dsv.Writer{}
+	return rt.oobWriter.OpenOutput(rt.OOBStatsFile)
+}
+
+//
+// WriteOOBStat will write statistic of process to file.
+// It is a silent no-op when the stats file is not open or `stat` is nil.
+//
+func (rt *Runtime) WriteOOBStat(stat *Stat) error {
+	if rt.oobWriter == nil {
+		return nil
+	}
+	if stat == nil {
+		return nil
+	}
+	return rt.oobWriter.WriteRawRow(stat.ToRow(), nil, nil)
+}
+
+//
+// CloseOOBStatsFile will close statistics file for writing.
+// Idempotent: calling it with no open writer returns nil.
+//
+func (rt *Runtime) CloseOOBStatsFile() (e error) {
+	if rt.oobWriter == nil {
+		return
+	}
+
+	e = rt.oobWriter.Close()
+	rt.oobWriter = nil
+
+	return
+}
+
+//
+// PrintOobStat will print the out-of-bag statistic to standard output.
+//
+func (rt *Runtime) PrintOobStat(stat *Stat, cm *CM) {
+	fmt.Printf("%s OOB error rate: %.4f,"+
+		" total: %.4f, mean %.4f, true rate: %.4f\n", tag,
+		stat.OobError, rt.oobStatTotal.OobError,
+		stat.OobErrorMean, cm.GetTrueRate())
+}
+
+//
+// PrintStat will print statistic value to standard output.
+// When `stat` is nil it falls back to the most recently collected OOB
+// statistic, or returns silently if none exist.
+//
+func (rt *Runtime) PrintStat(stat *Stat) {
+	if stat == nil {
+		statslen := len(rt.oobStats)
+		if statslen <= 0 {
+			return
+		}
+		stat = rt.oobStats[statslen-1]
+	}
+
+	fmt.Printf("%s TPRate: %.4f, FPRate: %.4f,"+
+		" TNRate: %.4f, precision: %.4f, f-measure: %.4f,"+
+		" accuracy: %.4f\n", tag, stat.TPRate, stat.FPRate,
+		stat.TNRate, stat.Precision, stat.FMeasure, stat.Accuracy)
+}
+
+//
+// PrintStatTotal will print total statistic to standard output.
+// Passing nil prints the accumulated total statistic.
+//
+func (rt *Runtime) PrintStatTotal(st *Stat) {
+	if st == nil {
+		st = &rt.oobStatTotal
+	}
+	rt.PrintStat(st)
+}
+
+//
+// Performance given an actuals class label and their probabilities, compute
+// the performance statistic of classifier.
+//
+// Algorithm,
+// (1) Sort the probabilities in descending order.
+// (2) Sort the actuals and predicts using sorted index from probs
+// (3) Compute tpr, fpr, precision
+// (4) Write performance to file.
+//
+// NOTE(review): the sort is done in place, so the caller's `probs` and
+// `predicts` slices are reordered as a side effect. Results are appended to
+// rt.perfs, so repeated calls accumulate.
+//
+func (rt *Runtime) Performance(samples tabula.ClasetInterface,
+	predicts []string, probs []float64,
+) (
+	perfs Stats,
+) {
+	// (1) The `false` flag presumably selects descending order per the
+	// algorithm above -- confirm against the numbers package.
+	actuals := samples.GetClassAsStrings()
+	sortedIds := numbers.IntCreateSeq(0, len(probs)-1)
+	numbers.Floats64InplaceMergesort(probs, sortedIds, 0, len(probs),
+		false)
+
+	// (2) Reorder actuals and predicts to follow the sorted probs.
+	libstrings.SortByIndex(&actuals, sortedIds)
+	libstrings.SortByIndex(&predicts, sortedIds)
+
+	// (3)
+	rt.computePerfByProbs(samples, actuals, probs)
+
+	return rt.perfs
+}
+
+// trapezoidArea returns the area of the ROC trapezoid spanned by two
+// consecutive (fp, tp) points: |fp-fpprev| * (tp+tpprev)/2.
+func trapezoidArea(fp, fpprev, tp, tpprev int64) float64 {
+	base := math.Abs(float64(fp - fpprev))
+	heightAvg := float64(tp+tpprev) / float64(2.0)
+	return base * heightAvg
+}
+
+//
+// computePerfByProbs will compute classifier performance using probabilities
+// or score `probs`.
+//
+// This currently only work for two class problem. The first value in the
+// class value space (vs[0]) is treated as the positive class. One Stat is
+// emitted per distinct probability threshold, with AUC accumulated by the
+// trapezoid rule.
+//
+// NOTE(review): assumes samples.Counts() is ordered the same way as
+// GetClassValueSpace(), i.e. nactuals[0] is the positive count -- confirm.
+//
+func (rt *Runtime) computePerfByProbs(samples tabula.ClasetInterface,
+	actuals []string, probs []float64,
+) {
+	vs := samples.GetClassValueSpace()
+	nactuals := numbers.IntsTo64(samples.Counts())
+	nclass := libstrings.CountTokens(actuals, vs, false)
+
+	// pprev starts at -Inf so the first probability always opens a new
+	// threshold.
+	pprev := math.Inf(-1)
+	tp := int64(0)
+	fp := int64(0)
+	tpprev := int64(0)
+	fpprev := int64(0)
+
+	auc := float64(0)
+
+	for x, p := range probs {
+		if p != pprev {
+			// New threshold: snapshot the rates so far.
+			stat := Stat{}
+			stat.SetTPRate(tp, nactuals[0])
+			stat.SetFPRate(fp, nactuals[1])
+			stat.SetPrecisionFromRate(nactuals[0], nactuals[1])
+
+			auc = auc + trapezoidArea(fp, fpprev, tp, tpprev)
+			stat.SetAUC(auc)
+
+			rt.perfs = append(rt.perfs, &stat)
+
+			pprev = p
+			tpprev = tp
+			fpprev = fp
+		}
+
+		// Count this sample as a true or false positive.
+		if actuals[x] == vs[0] {
+			tp++
+		} else {
+			fp++
+		}
+	}
+
+	// Final point of the curve, with AUC normalized by the product of
+	// positive and negative counts.
+	stat := Stat{}
+	stat.SetTPRate(tp, nactuals[0])
+	stat.SetFPRate(fp, nactuals[1])
+	stat.SetPrecisionFromRate(nactuals[0], nactuals[1])
+
+	auc = auc + trapezoidArea(fp, fpprev, tp, tpprev)
+	auc = auc / float64(nclass[0]*nclass[1])
+	stat.SetAUC(auc)
+
+	rt.perfs = append(rt.perfs, &stat)
+
+	if len(rt.perfs) >= 2 {
+		// Replace the first stat with second stat, because of NaN
+		// value on the first precision.
+		rt.perfs[0] = rt.perfs[1]
+	}
+}
+
+//
+// WritePerformance will write performance data to file.
+//
+func (rt *Runtime) WritePerformance() error {
+	return rt.perfs.Write(rt.PerfFile)
+}
diff --git a/lib/mining/classifier/stat.go b/lib/mining/classifier/stat.go
new file mode 100644
index 00000000..790aec10
--- /dev/null
+++ b/lib/mining/classifier/stat.go
@@ -0,0 +1,176 @@
+// Copyright 2015-2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+ "time"
+
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+//
+// Stat hold statistic value of classifier, including TP rate, FP rate, precision,
+// and recall.
+//
+// NOTE: Stat contains no locking; synchronize externally if an instance is
+// shared between goroutines.
+//
+type Stat struct {
+	// ID unique id for this statistic (e.g. number of tree).
+	ID int64
+	// StartTime contain the start time of classifier in unix timestamp.
+	StartTime int64
+	// EndTime contain the end time of classifier in unix timestamp.
+	EndTime int64
+	// ElapsedTime contain actual time, in seconds, between end and start
+	// time.
+	ElapsedTime int64
+	// TP contain true-positive value.
+	TP int64
+	// FP contain false-positive value.
+	FP int64
+	// TN contain true-negative value.
+	TN int64
+	// FN contain false-negative value.
+	FN int64
+	// OobError contain out-of-bag error.
+	OobError float64
+	// OobErrorMean contain mean of out-of-bag error.
+	OobErrorMean float64
+	// TPRate contain true-positive rate (recall): tp/(tp+fn)
+	TPRate float64
+	// FPRate contain false-positive rate: fp/(fp+tn)
+	FPRate float64
+	// TNRate contain true-negative rate: tn/(tn+fp)
+	TNRate float64
+	// Precision contain: tp/(tp+fp)
+	Precision float64
+	// FMeasure contain value of F-measure or the harmonic mean of
+	// precision and recall.
+	FMeasure float64
+	// Accuracy contain the degree of closeness of measurements of a
+	// quantity to that quantity's true value.
+	Accuracy float64
+	// AUC contain the area under curve.
+	AUC float64
+}
+
+// SetAUC will set the AUC value.
+func (stat *Stat) SetAUC(v float64) {
+	stat.AUC = v
+}
+
+//
+// SetTPRate will set TP and TPRate using number of positive `p`.
+// There is no guard against p == 0: the rate becomes +Inf or NaN.
+//
+func (stat *Stat) SetTPRate(tp, p int64) {
+	stat.TP = tp
+	stat.TPRate = float64(tp) / float64(p)
+}
+
+//
+// SetFPRate will set FP and FPRate using number of negative `n`.
+// There is no guard against n == 0: the rate becomes +Inf or NaN.
+//
+func (stat *Stat) SetFPRate(fp, n int64) {
+	stat.FP = fp
+	stat.FPRate = float64(fp) / float64(n)
+}
+
+//
+// SetPrecisionFromRate will set Precision value using tprate and fprate.
+// `p` and `n` is the number of positive and negative class in samples.
+// Precision = (TPRate*p) / (TPRate*p + FPRate*n); NaN when both products
+// are zero.
+//
+func (stat *Stat) SetPrecisionFromRate(p, n int64) {
+	stat.Precision = (stat.TPRate * float64(p)) /
+		((stat.TPRate * float64(p)) + (stat.FPRate * float64(n)))
+}
+
+//
+// Recall return value of recall.
+// Recall is identical to the true-positive rate.
+//
+func (stat *Stat) Recall() float64 {
+	return stat.TPRate
+}
+
+//
+// Sum will add statistic from other stat object to current stat, not including
+// the start and end time.
+//
+// NOTE: the ID, time fields, and AUC are not accumulated.
+//
+func (stat *Stat) Sum(other *Stat) {
+	stat.OobError += other.OobError
+	stat.OobErrorMean += other.OobErrorMean
+	stat.TP += other.TP
+	stat.FP += other.FP
+	stat.TN += other.TN
+	stat.FN += other.FN
+	stat.TPRate += other.TPRate
+	stat.FPRate += other.FPRate
+	stat.TNRate += other.TNRate
+	stat.Precision += other.Precision
+	stat.FMeasure += other.FMeasure
+	stat.Accuracy += other.Accuracy
+}
+
+//
+// ToRow will convert the stat to tabula.row in the order of Stat field.
+// (Exception: OobError and OobErrorMean are emitted before the counters,
+// not in struct order.)
+//
+func (stat *Stat) ToRow() (row *tabula.Row) {
+	row = &tabula.Row{}
+
+	row.PushBack(tabula.NewRecordInt(stat.ID))
+	row.PushBack(tabula.NewRecordInt(stat.StartTime))
+	row.PushBack(tabula.NewRecordInt(stat.EndTime))
+	row.PushBack(tabula.NewRecordInt(stat.ElapsedTime))
+	row.PushBack(tabula.NewRecordReal(stat.OobError))
+	row.PushBack(tabula.NewRecordReal(stat.OobErrorMean))
+	row.PushBack(tabula.NewRecordInt(stat.TP))
+	row.PushBack(tabula.NewRecordInt(stat.FP))
+	row.PushBack(tabula.NewRecordInt(stat.TN))
+	row.PushBack(tabula.NewRecordInt(stat.FN))
+	row.PushBack(tabula.NewRecordReal(stat.TPRate))
+	row.PushBack(tabula.NewRecordReal(stat.FPRate))
+	row.PushBack(tabula.NewRecordReal(stat.TNRate))
+	row.PushBack(tabula.NewRecordReal(stat.Precision))
+	row.PushBack(tabula.NewRecordReal(stat.FMeasure))
+	row.PushBack(tabula.NewRecordReal(stat.Accuracy))
+	row.PushBack(tabula.NewRecordReal(stat.AUC))
+
+	return
+}
+
+//
+// Start will start the timer.
+// The timestamp has one-second resolution (unix seconds).
+//
+func (stat *Stat) Start() {
+	stat.StartTime = time.Now().Unix()
+}
+
+//
+// End will stop the timer and compute the elapsed time.
+//
+func (stat *Stat) End() {
+	stat.EndTime = time.Now().Unix()
+	stat.ElapsedTime = stat.EndTime - stat.StartTime
+}
+
+//
+// Write will write the content of stat to `file`.
+// An empty file name is a deliberate no-op returning nil.
+//
+func (stat *Stat) Write(file string) (e error) {
+	if file == "" {
+		return
+	}
+
+	writer := &dsv.Writer{}
+	e = writer.OpenOutput(file)
+	if e != nil {
+		return e
+	}
+
+	e = writer.WriteRawRow(stat.ToRow(), nil, nil)
+	if e != nil {
+		return e
+	}
+
+	return writer.Close()
+}
diff --git a/lib/mining/classifier/stats.go b/lib/mining/classifier/stats.go
new file mode 100644
index 00000000..b29d3510
--- /dev/null
+++ b/lib/mining/classifier/stats.go
@@ -0,0 +1,143 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+ "github.com/shuLhan/share/lib/dsv"
+)
+
+//
+// Stats define list of statistic values.
+//
+type Stats []*Stat
+
+//
+// Add will add other stat object to the slice.
+// The stat is stored as a pointer, not copied.
+//
+func (stats *Stats) Add(stat *Stat) {
+	*stats = append(*stats, stat)
+}
+
+//
+// StartTimes return all start times in unix timestamp.
+// Like the other append-based getters below, it returns nil for an empty
+// receiver.
+//
+func (stats *Stats) StartTimes() (times []int64) {
+	for _, stat := range *stats {
+		times = append(times, stat.StartTime)
+	}
+	return
+}
+
+//
+// EndTimes return all end times in unix timestamp.
+//
+func (stats *Stats) EndTimes() (times []int64) {
+	for _, stat := range *stats {
+		times = append(times, stat.EndTime)
+	}
+	return
+}
+
+//
+// OobErrorMeans return all out-of-bag error mean values.
+// Unlike its siblings this one pre-sizes the result, so an empty receiver
+// yields an empty (non-nil) slice.
+//
+func (stats *Stats) OobErrorMeans() (oobmeans []float64) {
+	oobmeans = make([]float64, len(*stats))
+	for x, stat := range *stats {
+		oobmeans[x] = stat.OobErrorMean
+	}
+	return
+}
+
+//
+// TPRates return all true-positive rate values.
+//
+func (stats *Stats) TPRates() (tprates []float64) {
+	for _, stat := range *stats {
+		tprates = append(tprates, stat.TPRate)
+	}
+	return
+}
+
+//
+// FPRates return all false-positive rate values.
+//
+func (stats *Stats) FPRates() (fprates []float64) {
+	for _, stat := range *stats {
+		fprates = append(fprates, stat.FPRate)
+	}
+	return
+}
+
+//
+// TNRates will return all true-negative rate values.
+//
+func (stats *Stats) TNRates() (tnrates []float64) {
+	for _, stat := range *stats {
+		tnrates = append(tnrates, stat.TNRate)
+	}
+	return
+}
+
+//
+// Precisions return all precision values.
+//
+func (stats *Stats) Precisions() (precs []float64) {
+	for _, stat := range *stats {
+		precs = append(precs, stat.Precision)
+	}
+	return
+}
+
+//
+// Recalls return all recall values.
+// Recall is the true-positive rate, so this simply delegates to TPRates.
+//
+func (stats *Stats) Recalls() (recalls []float64) {
+	return stats.TPRates()
+}
+
+//
+// FMeasures return all F-measure values.
+//
+func (stats *Stats) FMeasures() (fmeasures []float64) {
+	for _, stat := range *stats {
+		fmeasures = append(fmeasures, stat.FMeasure)
+	}
+	return
+}
+
+//
+// Accuracies return all accuracy values.
+//
+func (stats *Stats) Accuracies() (accuracies []float64) {
+	for _, stat := range *stats {
+		accuracies = append(accuracies, stat.Accuracy)
+	}
+	return
+}
+
+//
+// Write will write all statistic data to `file`, one row per Stat.
+// An empty file name is a deliberate no-op returning nil.
+//
+func (stats *Stats) Write(file string) (e error) {
+	if file == "" {
+		return
+	}
+
+	writer := &dsv.Writer{}
+	e = writer.OpenOutput(file)
+	if e != nil {
+		return e
+	}
+
+	for _, st := range *stats {
+		e = writer.WriteRawRow(st.ToRow(), nil, nil)
+		if e != nil {
+			return e
+		}
+	}
+
+	return writer.Close()
+}
diff --git a/lib/mining/classifier/stats_interface.go b/lib/mining/classifier/stats_interface.go
new file mode 100644
index 00000000..a9b60b9a
--- /dev/null
+++ b/lib/mining/classifier/stats_interface.go
@@ -0,0 +1,68 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+//
+// ComputeFMeasures given array of precisions and recalls, compute F-measure
+// (the harmonic mean of precision and recall) of each instance and return
+// it. Only the first min(len(precisions), len(recalls)) pairs are used, so
+// indexing never goes out of range.
+//
+func ComputeFMeasures(precisions, recalls []float64) (fmeasures []float64) {
+	// Pair up only as many values as both slices provide.
+	n := len(precisions)
+	if len(recalls) < n {
+		n = len(recalls)
+	}
+
+	for x := 0; x < n; x++ {
+		fmeasures = append(fmeasures,
+			2/((1/precisions[x])+(1/recalls[x])))
+	}
+	return
+}
+
+//
+// ComputeAccuracies will compute and return accuracy from array of
+// true-positive, false-positive, true-negative, and false-negative; using
+// formula,
+//
+//	(tp + tn) / (tp + fp + tn + fn)
+//
+// The number of computed values is the minimum length of the four input
+// slices, so indexing never goes out of range even when they differ.
+//
+func ComputeAccuracies(tp, fp, tn, fn []int64) (accuracies []float64) {
+	// Take the minimum length over ALL four inputs. The previous code
+	// compared len(fp) against len(tn) instead of the running minimum
+	// and never bounded by len(tn), which could panic with short inputs.
+	minlen := len(tp)
+	if len(fp) < minlen {
+		minlen = len(fp)
+	}
+	if len(tn) < minlen {
+		minlen = len(tn)
+	}
+	if len(fn) < minlen {
+		minlen = len(fn)
+	}
+
+	for x := 0; x < minlen; x++ {
+		acc := float64(tp[x]+tn[x]) /
+			float64(tp[x]+fp[x]+tn[x]+fn[x])
+		accuracies = append(accuracies, acc)
+	}
+	return
+}
+
+//
+// ComputeElapsedTimes will compute and return elapsed time between `start`
+// and `end` timestamps. Only the first min(len(start), len(end)) pairs are
+// used, so indexing never goes out of range.
+//
+func ComputeElapsedTimes(start, end []int64) (elaps []int64) {
+	// Iterate over the shorter of the two inputs.
+	n := len(start)
+	if len(end) < n {
+		n = len(end)
+	}
+
+	for x := 0; x < n; x++ {
+		elaps = append(elaps, end[x]-start[x])
+	}
+	return
+}