diff options
| author | Shulhan <ms@kilabit.info> | 2018-09-17 05:04:26 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2018-09-18 01:50:21 +0700 |
| commit | 1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e (patch) | |
| tree | 5fa83fc0faa31e09cae82ac4d467cf8ba5f87fc2 /lib/mining/classifier/cart | |
| parent | 446fef94cd712861221c0098dcdd9ae52aaed0eb (diff) | |
| download | pakakeh.go-1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e.tar.xz | |
Merge package "github.com/shuLhan/go-mining"
Diffstat (limited to 'lib/mining/classifier/cart')
| -rw-r--r-- | lib/mining/classifier/cart/cart.go | 480 | ||||
| -rw-r--r-- | lib/mining/classifier/cart/cart_test.go | 62 | ||||
| -rw-r--r-- | lib/mining/classifier/cart/node.go | 44 |
3 files changed, 586 insertions, 0 deletions
diff --git a/lib/mining/classifier/cart/cart.go b/lib/mining/classifier/cart/cart.go new file mode 100644 index 00000000..449781d9 --- /dev/null +++ b/lib/mining/classifier/cart/cart.go @@ -0,0 +1,480 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +// Package cart implement the Classification and Regression Tree by Breiman, et al. +// CART is binary decision tree. +// +// Breiman, Leo, et al. Classification and regression trees. CRC press, +// 1984. +// +// The implementation is based on Data Mining book, +// +// Han, Jiawei, Micheline Kamber, and Jian Pei. Data mining: concepts and +// techniques: concepts and techniques. Elsevier, 2011. +// +package cart + +import ( + "fmt" + + "github.com/shuLhan/share/lib/debug" + "github.com/shuLhan/share/lib/mining/gain/gini" + "github.com/shuLhan/share/lib/mining/tree/binary" + "github.com/shuLhan/share/lib/numbers" + libstrings "github.com/shuLhan/share/lib/strings" + "github.com/shuLhan/share/lib/tabula" +) + +const ( + // SplitMethodGini if defined in Runtime, the dataset will be splitted + // using Gini gain for each possible value or partition. + // + // This option is used in Runtime.SplitMethod. + SplitMethodGini = "gini" +) + +const ( + // ColFlagParent denote that the column is parent/split node. + ColFlagParent = 1 + // ColFlagSkip denote that the column would be skipped. + ColFlagSkip = 2 +) + +// +// Runtime data for building CART. +// +type Runtime struct { + // SplitMethod define the criteria to used for splitting. + SplitMethod string `json:"SplitMethod"` + // NRandomFeature if less or equal to zero compute gain on all feature, + // otherwise select n random feature and compute gain only on selected + // features. + NRandomFeature int `json:"NRandomFeature"` + // OOBErrVal is the last out-of-bag error value in the tree. + OOBErrVal float64 + // Tree in classification. + Tree binary.Tree +} + +// +// New create new Runtime object. +// +func New(D tabula.ClasetInterface, splitMethod string, nRandomFeature int) ( + *Runtime, error, +) { + runtime := &Runtime{ + SplitMethod: splitMethod, + NRandomFeature: nRandomFeature, + Tree: binary.Tree{}, + } + + e := runtime.Build(D) + if e != nil { + return nil, e + } + + return runtime, nil +} + +// +// Build will create a tree using CART algorithm. +// +func (runtime *Runtime) Build(D tabula.ClasetInterface) (e error) { + // Re-check input configuration. + switch runtime.SplitMethod { + case SplitMethodGini: + // Do nothing. + default: + // Set default split method to Gini index. + runtime.SplitMethod = SplitMethodGini + } + + runtime.Tree.Root, e = runtime.splitTreeByGain(D) + + return +} + +// +// splitTreeByGain calculate the gain in all dataset, and split into two node: +// left and right. +// +// Return node with the split information. +// +func (runtime *Runtime) splitTreeByGain(D tabula.ClasetInterface) ( + node *binary.BTNode, + e error, +) { + node = &binary.BTNode{} + + D.RecountMajorMinor() + + // if dataset is empty return node labeled with majority classes in + // dataset. + nrow := D.GetNRow() + + if nrow <= 0 { + if debug.Value >= 2 { + fmt.Printf("[cart] empty dataset (%s) : %v\n", + D.MajorityClass(), D) + } + + node.Value = NodeValue{ + IsLeaf: true, + Class: D.MajorityClass(), + Size: 0, + } + return node, nil + } + + // if all dataset is in the same class, return node as leaf with class + // is set to that class. + single, name := D.IsInSingleClass() + if single { + if debug.Value >= 2 { + fmt.Printf("[cart] in single class (%s): %v\n", name, + D.GetColumns()) + } + + node.Value = NodeValue{ + IsLeaf: true, + Class: name, + Size: nrow, + } + return node, nil + } + + if debug.Value >= 2 { + fmt.Println("[cart] D:", D) + } + + // calculate the Gini gain for each attribute. + gains := runtime.computeGain(D) + + // get attribute with maximum Gini gain. + MaxGainIdx := gini.FindMaxGain(&gains) + MaxGain := gains[MaxGainIdx] + + // if maxgain value is 0, use majority class as node and terminate + // the process + if MaxGain.GetMaxGainValue() == 0 { + if debug.Value >= 2 { + fmt.Println("[cart] max gain 0 with target", + D.GetClassAsStrings(), + " and majority class is ", D.MajorityClass()) + } + + node.Value = NodeValue{ + IsLeaf: true, + Class: D.MajorityClass(), + Size: 0, + } + return node, nil + } + + // using the sorted index in MaxGain, sort all field in dataset + tabula.SortColumnsByIndex(D, MaxGain.SortedIndex) + + if debug.Value >= 2 { + fmt.Println("[cart] maxgain:", MaxGain) + } + + // Now that we have attribute with max gain in MaxGainIdx, and their + // gain dan partition value in Gains[MaxGainIdx] and + // GetMaxPartValue(), we split the dataset based on type of max-gain + // attribute. + // If its continuous, split the attribute using numeric value. + // If its discrete, split the attribute using subset (partition) of + // nominal values. + var splitV interface{} + + if MaxGain.IsContinu { + splitV = MaxGain.GetMaxPartGainValue() + } else { + attrPartV := MaxGain.GetMaxPartGainValue() + attrSubV := attrPartV.(libstrings.Row) + splitV = attrSubV[0] + } + + if debug.Value >= 2 { + fmt.Println("[cart] maxgainindex:", MaxGainIdx) + fmt.Println("[cart] split v:", splitV) + } + + node.Value = NodeValue{ + SplitAttrName: D.GetColumn(MaxGainIdx).GetName(), + IsLeaf: false, + IsContinu: MaxGain.IsContinu, + Size: nrow, + SplitAttrIdx: MaxGainIdx, + SplitV: splitV, + } + + dsL, dsR, e := tabula.SplitRowsByValue(D, MaxGainIdx, splitV) + + if e != nil { + return node, e + } + + splitL := dsL.(tabula.ClasetInterface) + splitR := dsR.(tabula.ClasetInterface) + + // Set the flag to parent in attribute referenced by + // MaxGainIdx, so it will not computed again in the next round. + cols := splitL.GetColumns() + for x := range *cols { + if x == MaxGainIdx { + (*cols)[x].Flag = ColFlagParent + } else { + (*cols)[x].Flag = 0 + } + } + + cols = splitR.GetColumns() + for x := range *cols { + if x == MaxGainIdx { + (*cols)[x].Flag = ColFlagParent + } else { + (*cols)[x].Flag = 0 + } + } + + nodeLeft, e := runtime.splitTreeByGain(splitL) + if e != nil { + return node, e + } + + nodeRight, e := runtime.splitTreeByGain(splitR) + if e != nil { + return node, e + } + + node.SetLeft(nodeLeft) + node.SetRight(nodeRight) + + return node, nil +} + +// SelectRandomFeature if NRandomFeature is greater than zero, select and +// compute gain in n random features instead of in all features +func (runtime *Runtime) SelectRandomFeature(D tabula.ClasetInterface) { + if runtime.NRandomFeature <= 0 { + // all features selected + return + } + + ncols := D.GetNColumn() + + // count all features minus class + nfeature := ncols - 1 + if runtime.NRandomFeature >= nfeature { + // Do nothing if number of random feature equal or greater than + // number of feature in dataset. + return + } + + // exclude class index and parent node index + excludeIdx := []int{D.GetClassIndex()} + cols := D.GetColumns() + for x, col := range *cols { + if (col.Flag & ColFlagParent) == ColFlagParent { + excludeIdx = append(excludeIdx, x) + } else { + (*cols)[x].Flag |= ColFlagSkip + } + } + + // Select random features excluding feature in `excludeIdx`. + var pickedIdx []int + for x := 0; x < runtime.NRandomFeature; x++ { + idx := numbers.IntPickRandPositive(ncols, false, pickedIdx, + excludeIdx) + pickedIdx = append(pickedIdx, idx) + + // Remove skip flag on selected column + col := D.GetColumn(idx) + col.Flag = col.Flag &^ ColFlagSkip + } + + if debug.Value >= 1 { + fmt.Println("[cart] selected random features:", pickedIdx) + fmt.Println("[cart] selected columns :", D.GetColumns()) + } +} + +// +// computeGain calculate the gini index for each value in each attribute. +// +func (runtime *Runtime) computeGain(D tabula.ClasetInterface) ( + gains []gini.Gini, +) { + switch runtime.SplitMethod { + case SplitMethodGini: + // create gains value for all attribute minus target class. + gains = make([]gini.Gini, D.GetNColumn()) + } + + runtime.SelectRandomFeature(D) + + classVS := D.GetClassValueSpace() + classIdx := D.GetClassIndex() + classType := D.GetClassType() + + for x, col := range *D.GetColumns() { + // skip class attribute. + if x == classIdx { + continue + } + + // skip column flagged with parent + if (col.Flag & ColFlagParent) == ColFlagParent { + gains[x].Skip = true + continue + } + + // ignore column flagged with skip + if (col.Flag & ColFlagSkip) == ColFlagSkip { + gains[x].Skip = true + continue + } + + // compute gain. + if col.GetType() == tabula.TReal { + attr := col.ToFloatSlice() + + if classType == tabula.TString { + target := D.GetClassAsStrings() + gains[x].ComputeContinu(&attr, &target, + &classVS) + } else { + targetReal := D.GetClassAsReals() + classVSReal := libstrings.ToFloat64(classVS) + + gains[x].ComputeContinuFloat(&attr, + &targetReal, &classVSReal) + } + } else { + attr := col.ToStringSlice() + attrV := col.ValueSpace + + if debug.Value >= 2 { + fmt.Println("[cart] attr :", attr) + fmt.Println("[cart] attrV:", attrV) + } + + target := D.GetClassAsStrings() + gains[x].ComputeDiscrete(&attr, &attrV, &target, + &classVS) + } + + if debug.Value >= 2 { + fmt.Println("[cart] gain :", gains[x]) + } + } + return +} + +// +// Classify return the prediction of one sample. +// +func (runtime *Runtime) Classify(data *tabula.Row) (class string) { + node := runtime.Tree.Root + nodev := node.Value.(NodeValue) + + for !nodev.IsLeaf { + if nodev.IsContinu { + splitV := nodev.SplitV.(float64) + attrV := (*data)[nodev.SplitAttrIdx].Float() + + if attrV < splitV { + node = node.Left + } else { + node = node.Right + } + } else { + splitV := nodev.SplitV.([]string) + attrV := (*data)[nodev.SplitAttrIdx].String() + + if libstrings.IsContain(splitV, attrV) { + node = node.Left + } else { + node = node.Right + } + } + nodev = node.Value.(NodeValue) + } + + return nodev.Class +} + +// +// ClassifySet set the class attribute based on tree classification. +// +func (runtime *Runtime) ClassifySet(data tabula.ClasetInterface) (e error) { + nrow := data.GetNRow() + targetAttr := data.GetClassColumn() + + for i := 0; i < nrow; i++ { + class := runtime.Classify(data.GetRow(i)) + + _ = (*targetAttr).Records[i].SetValue(class, tabula.TString) + } + + return +} + +// +// CountOOBError process out-of-bag data on tree and return error value. +// +func (runtime *Runtime) CountOOBError(oob tabula.Claset) ( + errval float64, + e error, +) { + // save the original target to be compared later. + origTarget := oob.GetClassAsStrings() + + if debug.Value >= 2 { + fmt.Println("[cart] OOB:", oob.Columns) + fmt.Println("[cart] TREE:", &runtime.Tree) + } + + // reset the target. + oobtarget := oob.GetClassColumn() + oobtarget.ClearValues() + + e = runtime.ClassifySet(&oob) + + if e != nil { + // set original target values back. + oobtarget.SetValues(origTarget) + return + } + + target := oobtarget.ToStringSlice() + + if debug.Value >= 2 { + fmt.Println("[cart] original target:", origTarget) + fmt.Println("[cart] classify target:", target) + } + + // count how many target value is miss-classified. + runtime.OOBErrVal, _, _ = libstrings.CountMissRate(origTarget, target) + + // set original target values back. + oobtarget.SetValues(origTarget) + + return runtime.OOBErrVal, nil +} + +// +// String yes, it will print it JSON like format. +// +func (runtime *Runtime) String() (s string) { + s = fmt.Sprintf("NRandomFeature: %d\n"+ + " SplitMethod : %s\n"+ + " Tree :\n%v", runtime.NRandomFeature, + runtime.SplitMethod, + runtime.Tree.String()) + return s +} diff --git a/lib/mining/classifier/cart/cart_test.go b/lib/mining/classifier/cart/cart_test.go new file mode 100644 index 00000000..14e89b12 --- /dev/null +++ b/lib/mining/classifier/cart/cart_test.go @@ -0,0 +1,62 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cart + +import ( + "fmt" + "testing" + + "github.com/shuLhan/share/lib/dsv" + "github.com/shuLhan/share/lib/tabula" + "github.com/shuLhan/share/lib/test" +) + +const ( + NRows = 150 +) + +func TestCART(t *testing.T) { + fds := "../../testdata/iris/iris.dsv" + + ds := tabula.Claset{} + + _, e := dsv.SimpleRead(fds, &ds) + if nil != e { + t.Fatal(e) + } + + fmt.Println("[cart_test] class index:", ds.GetClassIndex()) + + // copy target to be compared later. + targetv := ds.GetClassAsStrings() + + test.Assert(t, "", NRows, ds.GetNRow(), true) + + // Build CART tree. + CART, e := New(&ds, SplitMethodGini, 0) + if e != nil { + t.Fatal(e) + } + + fmt.Println("[cart_test] CART Tree:\n", CART) + + // Create test set + testset := tabula.Claset{} + _, e = dsv.SimpleRead(fds, &testset) + + if nil != e { + t.Fatal(e) + } + + testset.GetClassColumn().ClearValues() + + // Classifiy test set + e = CART.ClassifySet(&testset) + if nil != e { + t.Fatal(e) + } + + test.Assert(t, "", targetv, testset.GetClassAsStrings(), true) +} diff --git a/lib/mining/classifier/cart/node.go b/lib/mining/classifier/cart/node.go new file mode 100644 index 00000000..b64dd13c --- /dev/null +++ b/lib/mining/classifier/cart/node.go @@ -0,0 +1,44 @@ +// Copyright 2015 Mhd Sulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cart + +import ( + "fmt" + "reflect" +) + +// +// NodeValue of tree in CART. +// +type NodeValue struct { + // Class of leaf node. + Class string + // SplitAttrName define the name of attribute which cause the split. + SplitAttrName string + // IsLeaf define whether node is a leaf or not. + IsLeaf bool + // IsContinu define whether the node split is continuous or discrete. + IsContinu bool + // Size define number of sample that this node hold before splitting. + Size int + // SplitAttrIdx define the attribute which cause the split. + SplitAttrIdx int + // SplitV define the split value. + SplitV interface{} +} + +// +// String will return the value of node for printable. +// +func (nodev *NodeValue) String() (s string) { + if nodev.IsLeaf { + s = fmt.Sprintf("Class: %s", nodev.Class) + } else { + s = fmt.Sprintf("(SplitValue: %v)", + reflect.ValueOf(nodev.SplitV)) + } + + return s +} |
