aboutsummaryrefslogtreecommitdiff
path: root/lib/mining/classifier/rf
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2018-09-17 05:04:26 +0700
committerShulhan <ms@kilabit.info>2018-09-18 01:50:21 +0700
commit1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e (patch)
tree5fa83fc0faa31e09cae82ac4d467cf8ba5f87fc2 /lib/mining/classifier/rf
parent446fef94cd712861221c0098dcdd9ae52aaed0eb (diff)
downloadpakakeh.go-1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e.tar.xz
Merge package "github.com/shuLhan/go-mining"
Diffstat (limited to 'lib/mining/classifier/rf')
-rw-r--r--lib/mining/classifier/rf/rf.go363
-rw-r--r--lib/mining/classifier/rf/rf_bench_test.go22
-rw-r--r--lib/mining/classifier/rf/rf_test.go190
3 files changed, 575 insertions, 0 deletions
diff --git a/lib/mining/classifier/rf/rf.go b/lib/mining/classifier/rf/rf.go
new file mode 100644
index 00000000..86595efd
--- /dev/null
+++ b/lib/mining/classifier/rf/rf.go
@@ -0,0 +1,363 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//
+// Package rf implement ensemble of classifiers using random forest
+// algorithm by Breiman and Cutler.
+//
+// Breiman, Leo. "Random forests." Machine learning 45.1 (2001): 5-32.
+//
+// The implementation is based on various sources and using author experience.
+//
+package rf
+
+import (
+ "errors"
+ "fmt"
+ "math"
+
+ "github.com/shuLhan/share/lib/debug"
+ "github.com/shuLhan/share/lib/mining/classifier"
+ "github.com/shuLhan/share/lib/mining/classifier/cart"
+ "github.com/shuLhan/share/lib/numbers"
+ libstrings "github.com/shuLhan/share/lib/strings"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+const (
+ // tag is the prefix attached to all log output of this package.
+ tag = "[rf]"
+
+ // DefNumTree default number of tree.
+ DefNumTree = 100
+
+ // DefPercentBoot default percentage of sample that will be used for
+ // bootstrapping a tree.
+ DefPercentBoot = 66
+
+ // DefOOBStatsFile default statistic file output.
+ DefOOBStatsFile = "rf.oob.stat"
+
+ // DefPerfFile default performance file output.
+ DefPerfFile = "rf.perf"
+
+ // DefStatFile default statistic file.
+ DefStatFile = "rf.stat"
+)
+
+var (
+ // ErrNoInput is returned by Build when the samples parameter is nil.
+ ErrNoInput = errors.New("rf: input samples is empty")
+)
+
+//
+// Runtime contains input and output configuration when generating random forest.
+//
+// The exported fields may be set by the caller before calling Build; any
+// zero or negative value is replaced with a default by Initialize.
+//
+type Runtime struct {
+ // Runtime embed common fields for classifier.
+ classifier.Runtime
+
+ // NTree number of tree in forest.
+ NTree int `json:"NTree"`
+ // NRandomFeature number of feature randomly selected for each tree.
+ NRandomFeature int `json:"NRandomFeature"`
+ // PercentBoot percentage of sample for bootstrapping.
+ PercentBoot int `json:"PercentBoot"`
+
+ // nSubsample number of samples used for bootstrapping,
+ // computed in Initialize as NRow * PercentBoot / 100.
+ nSubsample int
+ // trees contain all tree in the forest.
+ trees []cart.Runtime
+ // bagIndices contain list of index of selected samples at bootstrapping
+ // for book-keeping.
+ bagIndices [][]int
+}
+
+//
+// Trees returns all trees in the forest.
+//
+func (forest *Runtime) Trees() []cart.Runtime {
+ return forest.trees
+}
+
+//
+// AddCartTree appends a CART tree to the forest.
+//
+func (forest *Runtime) AddCartTree(tree cart.Runtime) {
+ forest.trees = append(forest.trees, tree)
+}
+
+//
+// AddBagIndex appends a bagging index (the list of row indices sampled for
+// one tree) for book-keeping; Votes uses it to skip in-bag samples.
+//
+func (forest *Runtime) AddBagIndex(bagIndex []int) {
+ forest.bagIndices = append(forest.bagIndices, bagIndex)
+}
+
+//
+// Initialize will check forest inputs and set it to default values if invalid.
+//
+// It will also calculate number of random samples for each tree using,
+//
+// number-of-sample * percentage-of-bootstrap
+//
+func (forest *Runtime) Initialize(samples tabula.ClasetInterface) error {
+ if forest.NTree <= 0 {
+ forest.NTree = DefNumTree
+ }
+ if forest.PercentBoot <= 0 {
+ forest.PercentBoot = DefPercentBoot
+ }
+ if forest.NRandomFeature <= 0 {
+ // Set default value to square-root of features.
+ // The class column is excluded from the count.
+ ncol := samples.GetNColumn() - 1
+ forest.NRandomFeature = int(math.Sqrt(float64(ncol)))
+ }
+ if forest.OOBStatsFile == "" {
+ forest.OOBStatsFile = DefOOBStatsFile
+ }
+ if forest.PerfFile == "" {
+ forest.PerfFile = DefPerfFile
+ }
+ if forest.StatFile == "" {
+ forest.StatFile = DefStatFile
+ }
+
+ // Number of bootstrap samples per tree: NRow * PercentBoot / 100.
+ forest.nSubsample = int(float32(samples.GetNRow()) *
+ (float32(forest.PercentBoot) / 100.0))
+
+ return forest.Runtime.Initialize()
+}
+
+//
+// Build the forest using samples dataset.
+//
+// Algorithm,
+//
+// (0) Recheck input value: number of tree, percentage bootstrap, etc; and
+// Open statistic file output.
+// (1) For 0 to NTree,
+// (1.1) Create new tree, repeat until all trees have been built.
+// (2) Compute and write total statistic.
+//
+func (forest *Runtime) Build(samples tabula.ClasetInterface) (e error) {
+ // check input samples
+ if samples == nil {
+ return ErrNoInput
+ }
+
+ // (0)
+ e = forest.Initialize(samples)
+ if e != nil {
+ return
+ }
+
+ fmt.Println(tag, "Training set :", samples)
+ fmt.Println(tag, "Sample (one row):", samples.GetRow(0))
+ fmt.Println(tag, "Forest config :", forest)
+
+ // (1)
+ for t := 0; t < forest.NTree; t++ {
+ if debug.Value >= 1 {
+ fmt.Println(tag, "tree #", t)
+ }
+
+ // (1.1)
+ // NOTE(review): this loop retries forever when GrowTree keeps
+ // failing; there is no retry limit or back-off.
+ for {
+ _, _, e = forest.GrowTree(samples)
+ if e == nil {
+ break
+ }
+
+ fmt.Println(tag, "error:", e)
+ }
+ }
+
+ // (2)
+ return forest.Finalize()
+}
+
+//
+// GrowTree build a new tree in forest, return OOB error value or error if tree
+// can not grow.
+//
+// Algorithm,
+//
+// (1) Select random samples with replacement, also with OOB.
+// (2) Build tree using CART, without pruning.
+// (3) Add tree to forest.
+// (4) Save index of random samples for calculating error rate later.
+// (5) Run OOB on forest.
+// (6) Calculate OOB error rate and statistic values.
+//
+func (forest *Runtime) GrowTree(samples tabula.ClasetInterface) (
+ cm *classifier.CM, stat *classifier.Stat, e error,
+) {
+ // The tree index doubles as the statistic identifier.
+ stat = &classifier.Stat{}
+ stat.ID = int64(len(forest.trees))
+ stat.Start()
+
+ // (1)
+ bag, oob, bagIdx, oobIdx := tabula.RandomPickRows(
+ samples.(tabula.DatasetInterface),
+ forest.nSubsample, true)
+
+ bagset := bag.(tabula.ClasetInterface)
+
+ if debug.Value >= 2 {
+ bagset.RecountMajorMinor()
+ fmt.Println(tag, "Bagging:", bagset)
+ }
+
+ // (2)
+ // NOTE(review): the local variable `cart` shadows the imported
+ // package `cart`; consider renaming it for readability.
+ cart, e := cart.New(bagset, cart.SplitMethodGini,
+ forest.NRandomFeature)
+ if e != nil {
+ return nil, nil, e
+ }
+
+ // (3)
+ forest.AddCartTree(*cart)
+
+ // (4)
+ forest.AddBagIndex(bagIdx)
+
+ // (5)
+ if forest.RunOOB {
+ oobset := oob.(tabula.ClasetInterface)
+ _, cm, _ = forest.ClassifySet(oobset, oobIdx)
+
+ forest.AddOOBCM(cm)
+ }
+
+ stat.End()
+
+ if debug.Value >= 3 && forest.RunOOB {
+ fmt.Println(tag, "Elapsed time (s):", stat.ElapsedTime)
+ }
+
+ forest.AddStat(stat)
+
+ // (6)
+ if forest.RunOOB {
+ forest.ComputeStatFromCM(stat, cm)
+
+ if debug.Value >= 2 {
+ fmt.Println(tag, "OOB stat:", stat)
+ }
+ }
+
+ forest.ComputeStatTotal(stat)
+ e = forest.WriteOOBStat(stat)
+
+ return cm, stat, e
+}
+
+//
+// ClassifySet given a samples predict their class by running each sample in
+// forest, and return their class prediction with confusion matrix.
+// `samples` is the sample that will be predicted, `sampleIds` is the index of
+// samples.
+// If `sampleIds` is not nil, then sample index will be checked in each tree,
+// if the sample is used for training, their vote is not counted.
+//
+// Algorithm,
+//
+// (0) Get value space (possible class values in dataset)
+// (1) For each row in test-set,
+// (1.1) collect votes in all trees,
+// (1.2) select majority class vote, and
+// (1.3) compute and save the actual class probabilities.
+// (2) Compute confusion matrix from predictions.
+// (3) Compute stat from confusion matrix.
+// (4) Write the stat to file only if sampleIds is empty, which mean its run
+// not from OOB set.
+//
+func (forest *Runtime) ClassifySet(samples tabula.ClasetInterface,
+ sampleIds []int,
+) (
+ predicts []string, cm *classifier.CM, probs []float64,
+) {
+ stat := classifier.Stat{}
+ stat.Start()
+
+ if len(sampleIds) <= 0 {
+ fmt.Println(tag, "Classify set:", samples)
+ fmt.Println(tag, "Classify set sample (one row):",
+ samples.GetRow(0))
+ }
+
+ // (0)
+ vs := samples.GetClassValueSpace()
+ actuals := samples.GetClassAsStrings()
+ // sampleIdx stays -1 when sampleIds is empty, which disables the
+ // in-bag check inside Votes.
+ sampleIdx := -1
+
+ // (1)
+ rows := samples.GetRows()
+ for x, row := range *rows {
+ // (1.1)
+ if len(sampleIds) > 0 {
+ sampleIdx = sampleIds[x]
+ }
+ votes := forest.Votes(row, sampleIdx)
+
+ // (1.2)
+ classProbs := libstrings.FrequencyOfTokens(votes, vs, false)
+
+ _, idx, ok := numbers.Floats64FindMax(classProbs)
+
+ if ok {
+ predicts = append(predicts, vs[idx])
+ }
+
+ // (1.3)
+ // NOTE(review): this saves the vote frequency of the FIRST
+ // value in the class value space, presumably the "positive"
+ // class — confirm against what Performance expects.
+ probs = append(probs, classProbs[0])
+ }
+
+ // (2)
+ cm = forest.ComputeCM(sampleIds, vs, actuals, predicts)
+
+ // (3)
+ forest.ComputeStatFromCM(&stat, cm)
+ stat.End()
+
+ if len(sampleIds) <= 0 {
+ fmt.Println(tag, "CM:", cm)
+ fmt.Println(tag, "Classifying stat:", stat)
+ _ = stat.Write(forest.StatFile)
+ }
+
+ return predicts, cm, probs
+}
+
+//
+// Votes will return votes, or classes, in each tree based on sample.
+// If `sampleIdx` is non-negative it is checked against each tree's bag
+// indices: a tree that used the sample for training is skipped and does
+// not vote.
+//
+// (1) If row is used to build the tree then skip it,
+// (2) classify row in tree,
+// (3) save tree class value.
+//
+func (forest *Runtime) Votes(sample *tabula.Row, sampleIdx int) (
+ votes []string,
+) {
+ for x, tree := range forest.trees {
+ // (1)
+ if sampleIdx >= 0 {
+ exist := numbers.IntsIsExist(forest.bagIndices[x],
+ sampleIdx)
+ if exist {
+ continue
+ }
+ }
+
+ // (2)
+ class := tree.Classify(sample)
+
+ // (3)
+ votes = append(votes, class)
+ }
+ return votes
+}
diff --git a/lib/mining/classifier/rf/rf_bench_test.go b/lib/mining/classifier/rf/rf_bench_test.go
new file mode 100644
index 00000000..eddd721c
--- /dev/null
+++ b/lib/mining/classifier/rf/rf_bench_test.go
@@ -0,0 +1,22 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rf
+
+import (
+ "testing"
+)
+
+// BenchmarkPhoneme benchmarks building a forest on the phoneme dataset
+// with 3..4 random features, by overriding the package test globals.
+func BenchmarkPhoneme(b *testing.B) {
+ SampleDsvFile = "../../testdata/phoneme/phoneme.dsv"
+ OOBStatsFile = "phoneme.oob"
+ PerfFile = "phoneme.perf"
+
+ MinFeature = 3
+ MaxFeature = 4
+
+ for x := 0; x < b.N; x++ {
+ runRandomForest()
+ }
+}
diff --git a/lib/mining/classifier/rf/rf_test.go b/lib/mining/classifier/rf/rf_test.go
new file mode 100644
index 00000000..ba96125b
--- /dev/null
+++ b/lib/mining/classifier/rf/rf_test.go
@@ -0,0 +1,190 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rf
+
+import (
+ "fmt"
+ "log"
+ "os"
+ "testing"
+
+ "github.com/shuLhan/share/lib/dsv"
+ "github.com/shuLhan/share/lib/mining/classifier"
+ "github.com/shuLhan/share/lib/tabula"
+)
+
+// Global options to run for each test.
+var (
+ // SampleDsvFile is the file that contain samples config.
+ SampleDsvFile string
+ // DoTest if it is true then the dataset will be split into training
+ // and test set with random selection without replacement.
+ DoTest = false
+ // NTree number of tree to generate.
+ NTree = 100
+ // NBootstrap percentage of sample used as subsample.
+ NBootstrap = 66
+ // MinFeature number of feature to begin with.
+ MinFeature = 1
+ // MaxFeature maximum number of feature to test.
+ MaxFeature = -1
+ // RunOOB if it is true then the OOB samples will be used to test the
+ // model in each iteration.
+ RunOOB = true
+ // OOBStatsFile is the file where OOB statistic will be saved.
+ OOBStatsFile string
+ // PerfFile is the file where performance statistic will be saved.
+ PerfFile string
+ // StatFile is the file where classifying statistic will be saved.
+ StatFile string
+)
+
+// getSamples reads the dataset from SampleDsvFile. When DoTest is false the
+// whole dataset is returned as train and test is nil; otherwise the rows are
+// split (without replacement) into train and test sets using NBootstrap as
+// the training percentage.
+func getSamples() (train, test tabula.ClasetInterface) {
+ samples := tabula.Claset{}
+ _, e := dsv.SimpleRead(SampleDsvFile, &samples)
+ if nil != e {
+ log.Fatal(e)
+ }
+
+ if !DoTest {
+ return &samples, nil
+ }
+
+ ntrain := int(float32(samples.Len()) * (float32(NBootstrap) / 100.0))
+
+ bag, oob, _, _ := tabula.RandomPickRows(&samples, ntrain, false)
+
+ train = bag.(tabula.ClasetInterface)
+ test = oob.(tabula.ClasetInterface)
+
+ // Propagate the class column index, which RandomPickRows does not copy.
+ train.SetClassIndex(samples.GetClassIndex())
+ test.SetClassIndex(samples.GetClassIndex())
+
+ return train, test
+}
+
+// runRandomForest builds one forest for each feature count in
+// [MinFeature, MaxFeature), writing per-run statistic files prefixed with
+// "N<nfeature>.". When DoTest is set, the model is also evaluated on the
+// held-out test set.
+func runRandomForest() {
+ oobStatsFile := OOBStatsFile
+ perfFile := PerfFile
+ statFile := StatFile
+
+ trainset, testset := getSamples()
+
+ // NOTE(review): this mutates the package-level MaxFeature, so the
+ // value persists across tests that run after this one.
+ if MaxFeature < 0 {
+ MaxFeature = trainset.GetNColumn()
+ }
+
+ for nfeature := MinFeature; nfeature < MaxFeature; nfeature++ {
+ // Add prefix to OOB stats file.
+ oobStatsFile = fmt.Sprintf("N%d.%s", nfeature, OOBStatsFile)
+
+ // Add prefix to performance file.
+ perfFile = fmt.Sprintf("N%d.%s", nfeature, PerfFile)
+
+ // Add prefix to stat file.
+ statFile = fmt.Sprintf("N%d.%s", nfeature, StatFile)
+
+ // Create and build random forest.
+ forest := Runtime{
+ Runtime: classifier.Runtime{
+ RunOOB: RunOOB,
+ OOBStatsFile: oobStatsFile,
+ PerfFile: perfFile,
+ StatFile: statFile,
+ },
+ NTree: NTree,
+ NRandomFeature: nfeature,
+ PercentBoot: NBootstrap,
+ }
+
+ e := forest.Build(trainset)
+ if e != nil {
+ log.Fatal(e)
+ }
+
+ if DoTest {
+ predicts, _, probs := forest.ClassifySet(testset, nil)
+
+ forest.Performance(testset, predicts, probs)
+ e = forest.WritePerformance()
+ if e != nil {
+ log.Fatal(e)
+ }
+ }
+ }
+}
+
+// TestEnsemblingGlass builds and tests a forest on the forensic glass
+// dataset, with OOB disabled and a train/test split enabled.
+func TestEnsemblingGlass(t *testing.T) {
+ SampleDsvFile = "../../testdata/forensic_glass/fgl.dsv"
+ RunOOB = false
+ OOBStatsFile = "glass.oob"
+ StatFile = "glass.stat"
+ PerfFile = "glass.perf"
+ DoTest = true
+
+ runRandomForest()
+}
+
+// TestEnsemblingIris builds a forest on the iris dataset using whatever
+// global options the previously-run tests left behind.
+func TestEnsemblingIris(t *testing.T) {
+ SampleDsvFile = "../../testdata/iris/iris.dsv"
+ OOBStatsFile = "iris.oob"
+
+ runRandomForest()
+}
+
+// TestEnsemblingPhoneme builds and tests a 200-tree forest on the phoneme
+// dataset with 3 random features and OOB disabled.
+func TestEnsemblingPhoneme(t *testing.T) {
+ SampleDsvFile = "../../testdata/phoneme/phoneme.dsv"
+ OOBStatsFile = "phoneme.oob.stat"
+ StatFile = "phoneme.stat"
+ PerfFile = "phoneme.perf"
+
+ NTree = 200
+ MinFeature = 3
+ MaxFeature = 4
+ RunOOB = false
+ DoTest = true
+
+ runRandomForest()
+}
+
+// TestEnsemblingSmotePhoneme builds a forest on the SMOTE-resampled
+// phoneme dataset with 3 random features.
+func TestEnsemblingSmotePhoneme(t *testing.T) {
+ SampleDsvFile = "../../resampling/smote/phoneme_smote.dsv"
+ OOBStatsFile = "phonemesmote.oob"
+
+ MinFeature = 3
+ MaxFeature = 4
+
+ runRandomForest()
+}
+
+// TestEnsemblingLnsmotePhoneme builds a forest on the LN-SMOTE-resampled
+// phoneme dataset with 3 random features.
+func TestEnsemblingLnsmotePhoneme(t *testing.T) {
+ SampleDsvFile = "../../resampling/lnsmote/phoneme_lnsmote.dsv"
+ OOBStatsFile = "phonemelnsmote.oob"
+
+ MinFeature = 3
+ MaxFeature = 4
+
+ runRandomForest()
+}
+
+// TestWvc2010Lnsmote builds a single-tree forest on the LN-SMOTE-resampled
+// WVC2010 features dataset with 5 random features.
+func TestWvc2010Lnsmote(t *testing.T) {
+ SampleDsvFile = "../../testdata/wvc2010lnsmote/wvc2010_features.lnsmote.dsv"
+ OOBStatsFile = "wvc2010lnsmote.oob"
+
+ NTree = 1
+ MinFeature = 5
+ MaxFeature = 6
+
+ runRandomForest()
+}
+
+// TestMain gates the whole package test run: the tests only execute when
+// the TEST_RF environment variable is set to a non-empty value; otherwise
+// the process exits immediately with status 0.
+func TestMain(m *testing.M) {
+ envTestRF := os.Getenv("TEST_RF")
+ if len(envTestRF) == 0 {
+ os.Exit(0)
+ }
+
+ os.Exit(m.Run())
+}