| author | Shulhan <ms@kilabit.info> | 2018-09-17 05:04:26 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2018-09-18 01:50:21 +0700 |
| commit | 1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e (patch) | |
| tree | 5fa83fc0faa31e09cae82ac4d467cf8ba5f87fc2 /lib/mining/classifier/runtime.go | |
| parent | 446fef94cd712861221c0098dcdd9ae52aaed0eb (diff) | |
| download | pakakeh.go-1cae4ca316afa5d177fdbf7a98a0ec7fffe76a3e.tar.xz | |
Merge package "github.com/shuLhan/go-mining"
Diffstat (limited to 'lib/mining/classifier/runtime.go')
| -rw-r--r-- | lib/mining/classifier/runtime.go | 443 |
1 file changed, 443 insertions, 0 deletions
diff --git a/lib/mining/classifier/runtime.go b/lib/mining/classifier/runtime.go
new file mode 100644
index 00000000..0f7a7755
--- /dev/null
+++ b/lib/mining/classifier/runtime.go
@@ -0,0 +1,443 @@
+// Copyright 2016 Mhd Sulhan <ms@kilabit.info>. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package classifier
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/shuLhan/share/lib/debug"
+	"github.com/shuLhan/share/lib/dsv"
+	"github.com/shuLhan/share/lib/numbers"
+	libstrings "github.com/shuLhan/share/lib/strings"
+	"github.com/shuLhan/share/lib/tabula"
+)
+
+const (
+	tag = "[classifier.runtime]"
+)
+
+//
+// Runtime defines a generic type which provides common fields that can be
+// embedded by the real classifier (e.g. RandomForest).
+//
+type Runtime struct {
+	// RunOOB, if true, the OOB will be computed; default is false.
+	RunOOB bool `json:"RunOOB"`
+
+	// OOBStatsFile is the file where OOB statistics will be written.
+	OOBStatsFile string `json:"OOBStatsFile"`
+
+	// PerfFile is the file where performance statistics will be written.
+	PerfFile string `json:"PerfFile"`
+
+	// StatFile is the file where statistics of classifying samples will be
+	// written.
+	StatFile string `json:"StatFile"`
+
+	// oobCms contains the confusion matrix for each OOB iteration.
+	oobCms []CM
+
+	// oobStats contains the classifier statistics for each OOB iteration.
+	oobStats Stats
+
+	// oobStatTotal contains the total OOB statistic values.
+	oobStatTotal Stat
+
+	// oobWriter contains the file writer for statistics.
+	oobWriter *dsv.Writer
+
+	// perfs contains the performance statistics per sample, after
+	// classifying samples on the classifier.
+	perfs Stats
+}
+
+//
+// Initialize will start the runtime for processing by saving the start time
+// and opening the stats file.
+//
+func (rt *Runtime) Initialize() error {
+	rt.oobStatTotal.Start()
+
+	return rt.OpenOOBStatsFile()
+}
+
+//
+// Finalize finishes the runtime: it computes the total statistics, writes
+// them to file, and closes the file.
+//
+func (rt *Runtime) Finalize() (e error) {
+	st := &rt.oobStatTotal
+
+	st.End()
+	st.ID = int64(len(rt.oobStats))
+
+	e = rt.WriteOOBStat(st)
+	if e != nil {
+		return e
+	}
+
+	return rt.CloseOOBStatsFile()
+}
+
+//
+// OOBStats returns all statistic objects.
+//
+func (rt *Runtime) OOBStats() *Stats {
+	return &rt.oobStats
+}
+
+//
+// StatTotal returns the total statistics.
+//
+func (rt *Runtime) StatTotal() *Stat {
+	return &rt.oobStatTotal
+}
+
+//
+// AddOOBCM will append a new confusion matrix.
+//
+func (rt *Runtime) AddOOBCM(cm *CM) {
+	rt.oobCms = append(rt.oobCms, *cm)
+}
+
+//
+// AddStat will append new classifier statistic data.
+//
+func (rt *Runtime) AddStat(stat *Stat) {
+	rt.oobStats = append(rt.oobStats, stat)
+}
+
+//
+// ComputeCM will compute the confusion matrix of samples using the value
+// space, actual, and predicted values.
+//
+func (rt *Runtime) ComputeCM(sampleIds []int,
+	vs, actuals, predicts []string,
+) (
+	cm *CM,
+) {
+	cm = &CM{}
+
+	cm.ComputeStrings(vs, actuals, predicts)
+	cm.GroupIndexPredictionsStrings(sampleIds, actuals, predicts)
+
+	if debug.Value >= 2 {
+		fmt.Println(tag, cm)
+	}
+
+	return cm
+}
+
+//
+// ComputeStatFromCM will compute statistics using the confusion matrix.
+//
+func (rt *Runtime) ComputeStatFromCM(stat *Stat, cm *CM) {
+
+	stat.OobError = cm.GetFalseRate()
+
+	stat.OobErrorMean = rt.oobStatTotal.OobError /
+		float64(len(rt.oobStats)+1)
+
+	stat.TP = int64(cm.TP())
+	stat.FP = int64(cm.FP())
+	stat.TN = int64(cm.TN())
+	stat.FN = int64(cm.FN())
+
+	t := float64(stat.TP + stat.FN)
+	if t == 0 {
+		stat.TPRate = 0
+	} else {
+		stat.TPRate = float64(stat.TP) / t
+	}
+
+	t = float64(stat.FP + stat.TN)
+	if t == 0 {
+		stat.FPRate = 0
+	} else {
+		stat.FPRate = float64(stat.FP) / t
+	}
+
+	t = float64(stat.FP + stat.TN)
+	if t == 0 {
+		stat.TNRate = 0
+	} else {
+		stat.TNRate = float64(stat.TN) / t
+	}
+
+	t = float64(stat.TP + stat.FP)
+	if t == 0 {
+		stat.Precision = 0
+	} else {
+		stat.Precision = float64(stat.TP) / t
+	}
+
+	t = (1 / stat.Precision) + (1 / stat.TPRate)
+	if t == 0 {
+		stat.FMeasure = 0
+	} else {
+		stat.FMeasure = 2 / t
+	}
+
+	t = float64(stat.TP + stat.TN + stat.FP + stat.FN)
+	if t == 0 {
+		stat.Accuracy = 0
+	} else {
+		stat.Accuracy = float64(stat.TP+stat.TN) / t
+	}
+
+	if debug.Value >= 1 {
+		rt.PrintOobStat(stat, cm)
+		rt.PrintStat(stat)
+	}
+}
+
+//
+// ComputeStatTotal computes the total statistics.
+//
+func (rt *Runtime) ComputeStatTotal(stat *Stat) {
+	if stat == nil {
+		return
+	}
+
+	nstat := len(rt.oobStats)
+	if nstat == 0 {
+		return
+	}
+
+	t := &rt.oobStatTotal
+
+	t.OobError += stat.OobError
+	t.OobErrorMean = t.OobError / float64(nstat)
+	t.TP += stat.TP
+	t.FP += stat.FP
+	t.TN += stat.TN
+	t.FN += stat.FN
+
+	total := float64(t.TP + t.FN)
+	if total == 0 {
+		t.TPRate = 0
+	} else {
+		t.TPRate = float64(t.TP) / total
+	}
+
+	total = float64(t.FP + t.TN)
+	if total == 0 {
+		t.FPRate = 0
+	} else {
+		t.FPRate = float64(t.FP) / total
+	}
+
+	total = float64(t.FP + t.TN)
+	if total == 0 {
+		t.TNRate = 0
+	} else {
+		t.TNRate = float64(t.TN) / total
+	}
+
+	total = float64(t.TP + t.FP)
+	if total == 0 {
+		t.Precision = 0
+	} else {
+		t.Precision = float64(t.TP) / total
+	}
+
+	total = (1 / t.Precision) + (1 / t.TPRate)
+	if total == 0 {
+		t.FMeasure = 0
+	} else {
+		t.FMeasure = 2 / total
+	}
+
+	total = float64(t.TP + t.TN + t.FP + t.FN)
+	if total == 0 {
+		t.Accuracy = 0
+	} else {
+		t.Accuracy = float64(t.TP+t.TN) / total
+	}
+}
+
+//
+// OpenOOBStatsFile will open the statistics file for output.
+//
+func (rt *Runtime) OpenOOBStatsFile() error {
+	if rt.oobWriter != nil {
+		_ = rt.CloseOOBStatsFile()
+	}
+	rt.oobWriter = &dsv.Writer{}
+	return rt.oobWriter.OpenOutput(rt.OOBStatsFile)
+}
+
+//
+// WriteOOBStat will write the statistics of the process to file.
+//
+func (rt *Runtime) WriteOOBStat(stat *Stat) error {
+	if rt.oobWriter == nil {
+		return nil
+	}
+	if stat == nil {
+		return nil
+	}
+	return rt.oobWriter.WriteRawRow(stat.ToRow(), nil, nil)
+}
+
+//
+// CloseOOBStatsFile will close the statistics file for writing.
+//
+func (rt *Runtime) CloseOOBStatsFile() (e error) {
+	if rt.oobWriter == nil {
+		return
+	}
+
+	e = rt.oobWriter.Close()
+	rt.oobWriter = nil
+
+	return
+}
+
+//
+// PrintOobStat will print the out-of-bag statistics to standard output.
+//
+func (rt *Runtime) PrintOobStat(stat *Stat, cm *CM) {
+	fmt.Printf("%s OOB error rate: %.4f,"+
+		" total: %.4f, mean %.4f, true rate: %.4f\n", tag,
+		stat.OobError, rt.oobStatTotal.OobError,
+		stat.OobErrorMean, cm.GetTrueRate())
+}
+
+//
+// PrintStat will print statistic values to standard output.
+//
+func (rt *Runtime) PrintStat(stat *Stat) {
+	if stat == nil {
+		statslen := len(rt.oobStats)
+		if statslen <= 0 {
+			return
+		}
+		stat = rt.oobStats[statslen-1]
+	}
+
+	fmt.Printf("%s TPRate: %.4f, FPRate: %.4f,"+
+		" TNRate: %.4f, precision: %.4f, f-measure: %.4f,"+
+		" accuracy: %.4f\n", tag, stat.TPRate, stat.FPRate,
+		stat.TNRate, stat.Precision, stat.FMeasure, stat.Accuracy)
+}
+
+//
+// PrintStatTotal will print the total statistics to standard output.
+//
+func (rt *Runtime) PrintStatTotal(st *Stat) {
+	if st == nil {
+		st = &rt.oobStatTotal
+	}
+	rt.PrintStat(st)
+}
+
+//
+// Performance, given the actual class labels and their probabilities,
+// computes the performance statistics of the classifier.
+//
+// Algorithm,
+// (1) Sort the probabilities in descending order.
+// (2) Sort the actuals and predicts using the sorted index from probs.
+// (3) Compute TP rate, FP rate, and precision.
+// (4) Write performance to file.
+//
+func (rt *Runtime) Performance(samples tabula.ClasetInterface,
+	predicts []string, probs []float64,
+) (
+	perfs Stats,
+) {
+	// (1)
+	actuals := samples.GetClassAsStrings()
+	sortedIds := numbers.IntCreateSeq(0, len(probs)-1)
+	numbers.Floats64InplaceMergesort(probs, sortedIds, 0, len(probs),
+		false)
+
+	// (2)
+	libstrings.SortByIndex(&actuals, sortedIds)
+	libstrings.SortByIndex(&predicts, sortedIds)
+
+	// (3)
+	rt.computePerfByProbs(samples, actuals, probs)
+
+	return rt.perfs
+}
+
+func trapezoidArea(fp, fpprev, tp, tpprev int64) float64 {
+	base := math.Abs(float64(fp - fpprev))
+	heightAvg := float64(tp+tpprev) / float64(2.0)
+	return base * heightAvg
+}
+
+//
+// computePerfByProbs will compute the classifier performance using the
+// probabilities or scores `probs`.
+//
+// This currently only works for a two-class problem.
+//
+func (rt *Runtime) computePerfByProbs(samples tabula.ClasetInterface,
+	actuals []string, probs []float64,
+) {
+	vs := samples.GetClassValueSpace()
+	nactuals := numbers.IntsTo64(samples.Counts())
+	nclass := libstrings.CountTokens(actuals, vs, false)
+
+	pprev := math.Inf(-1)
+	tp := int64(0)
+	fp := int64(0)
+	tpprev := int64(0)
+	fpprev := int64(0)
+
+	auc := float64(0)
+
+	for x, p := range probs {
+		if p != pprev {
+			stat := Stat{}
+			stat.SetTPRate(tp, nactuals[0])
+			stat.SetFPRate(fp, nactuals[1])
+			stat.SetPrecisionFromRate(nactuals[0], nactuals[1])
+
+			auc = auc + trapezoidArea(fp, fpprev, tp, tpprev)
+			stat.SetAUC(auc)
+
+			rt.perfs = append(rt.perfs, &stat)
+
+			pprev = p
+			tpprev = tp
+			fpprev = fp
+		}
+
+		if actuals[x] == vs[0] {
+			tp++
+		} else {
+			fp++
+		}
+	}
+
+	stat := Stat{}
+	stat.SetTPRate(tp, nactuals[0])
+	stat.SetFPRate(fp, nactuals[1])
+	stat.SetPrecisionFromRate(nactuals[0], nactuals[1])
+
+	auc = auc + trapezoidArea(fp, fpprev, tp, tpprev)
+	auc = auc / float64(nclass[0]*nclass[1])
+	stat.SetAUC(auc)
+
+	rt.perfs = append(rt.perfs, &stat)
+
+	if len(rt.perfs) >= 2 {
+		// Replace the first stat with the second stat, because of the
+		// NaN value in the first precision.
+		rt.perfs[0] = rt.perfs[1]
+	}
+}
+
+//
+// WritePerformance will write the performance data to file.
+//
+func (rt *Runtime) WritePerformance() error {
+	return rt.perfs.Write(rt.PerfFile)
+}
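For orientation, the sketch below shows one way a classifier could embed Runtime and drive the OOB-statistics lifecycle added in this commit (Initialize, ComputeCM, ComputeStatFromCM, AddStat, ComputeStatTotal, WriteOOBStat, Finalize). It is a minimal, hypothetical sketch, not code from the commit: the myClassifier type, the "oob.stats" path, the hard-coded sample values, and the exact call ordering are assumptions; only the Runtime, Stat, and CM identifiers come from the package, and the sketch is written as if it lived in the same package to keep it short.

```go
// A minimal sketch, assuming a RandomForest-like classifier embeds Runtime as
// the package comment suggests. Everything prefixed "my" or hard-coded below
// is hypothetical.
package classifier

type myClassifier struct {
	Runtime // embed the generic runtime fields and methods
}

func (mc *myClassifier) Build() error {
	mc.RunOOB = true
	mc.OOBStatsFile = "oob.stats" // hypothetical output path

	// Record the start time and open the OOB stats file.
	if e := mc.Initialize(); e != nil {
		return e
	}

	// One OOB iteration with made-up values; a real classifier would take
	// these from its out-of-bag samples and its own predictions.
	sampleIds := []int{0, 1, 2, 3}
	vs := []string{"1", "0"}
	actuals := []string{"1", "0", "1", "0"}
	predicts := []string{"1", "0", "0", "0"}

	// Confusion matrix -> per-iteration stat -> running totals -> file.
	cm := mc.ComputeCM(sampleIds, vs, actuals, predicts)
	mc.AddOOBCM(cm)

	stat := &Stat{}
	stat.Start()
	mc.ComputeStatFromCM(stat, cm)
	stat.End()

	mc.AddStat(stat)
	mc.ComputeStatTotal(stat)
	if e := mc.WriteOOBStat(stat); e != nil {
		return e
	}

	// Write the accumulated totals and close the stats file.
	return mc.Finalize()
}
```

The embedding keeps the statistics plumbing out of each concrete classifier: a RandomForest or similar type only has to produce a confusion matrix per iteration and hand it to the Runtime methods.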
