aboutsummaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorShulhan <m.shulhan@gmail.com>2020-03-19 15:19:04 +0700
committerShulhan <m.shulhan@gmail.com>2020-03-20 03:34:25 +0700
commit9ec24a017e26239145d8abb02b801f243b505fa9 (patch)
treed5b17f04c7e521842bd38147ac3ed76cc46a8075 /lib
parent0ebbaa2d1cb7d7b27f6993bb23642e428ff68649 (diff)
downloadpakakeh.go-9ec24a017e26239145d8abb02b801f243b505fa9.tar.xz
reflect: add function IsEqual that works with Equaler interface
The IsEqual() function is like reflect.DeepEqual but its check if a struct have method "IsEqual", if its exist it will call the method to compare the value.
Diffstat (limited to 'lib')
-rw-r--r--lib/reflect/reflect.go195
-rw-r--r--lib/test/test.go18
2 files changed, 197 insertions, 16 deletions
diff --git a/lib/reflect/reflect.go b/lib/reflect/reflect.go
index b2b6eb58..9572b7b2 100644
--- a/lib/reflect/reflect.go
+++ b/lib/reflect/reflect.go
@@ -7,7 +7,9 @@
//
package reflect
-import "reflect"
+import (
+ "reflect"
+)
//
// IsNil will return true if v's type is chan, func, interface, map, pointer,
@@ -22,3 +24,194 @@ func IsNil(v interface{}) bool {
}
return false
}
+
+//
+// IsEqual is a naive interfaces comparison that check and use Equaler
+// interface.
+//
+func IsEqual(x, y interface{}) bool {
+ if x == nil && y == nil {
+ return true
+ }
+
+ v1 := reflect.ValueOf(x)
+ v2 := reflect.ValueOf(y)
+
+ return isEqual(v1, v2)
+}
+
+func isEqual(v1, v2 reflect.Value) bool {
+ if !v1.IsValid() || !v2.IsValid() {
+ return v1.IsValid() == v2.IsValid()
+ }
+
+ t1 := v1.Type()
+ t2 := v2.Type()
+ if t1 != t2 {
+ return false
+ }
+
+ k1 := v1.Kind()
+ k2 := v2.Kind()
+ if k1 != k2 {
+ return false
+ }
+
+ // For debugging.
+ //log.Printf("v1:%v(%s(%v)) v2:%v(%s(%v))", k1, t1.String(), v1,
+ // k2, t2.String(), v2)
+
+ switch k1 {
+ case reflect.Bool:
+ return v1.Bool() == v2.Bool()
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
+ reflect.Int64:
+ return v1.Int() == v2.Int()
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16,
+ reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return v1.Uint() == v2.Uint()
+
+ case reflect.Float32, reflect.Float64:
+ return v1.Float() == v2.Float()
+
+ case reflect.Complex64, reflect.Complex128:
+ return v1.Complex() == v2.Complex()
+
+ case reflect.Array:
+ if v1.Len() != v2.Len() {
+ return false
+ }
+ for x := 0; x < v1.Len(); x++ {
+ if !isEqual(v1.Index(x), v2.Index(x)) {
+ return false
+ }
+ }
+ return true
+
+ case reflect.Chan:
+ if v1.IsNil() && v2.IsNil() {
+ return true
+ }
+ return t1 == t2
+
+ case reflect.Func:
+ if v1.IsNil() && v2.IsNil() {
+ return true
+ }
+ if v2.IsNil() {
+ return false
+ }
+ return t1 == t2
+
+ case reflect.Interface:
+ if v1.IsNil() && v2.IsNil() {
+ return true
+ }
+ if v2.IsNil() {
+ return false
+ }
+ return isEqual(v1.Elem(), v2.Elem())
+
+ case reflect.Map:
+ return isEqualMap(v1, v2)
+
+ case reflect.Ptr:
+ if v1.IsNil() && v2.IsNil() {
+ return true
+ }
+ if v2.IsNil() {
+ return false
+ }
+ m1 := v1.MethodByName("IsEqual")
+ if m1.IsValid() {
+ res := m1.Call([]reflect.Value{
+ v2,
+ })
+ if len(res) == 1 &&
+ res[0].Kind() == reflect.Bool {
+ return res[0].Bool()
+ }
+ }
+ if v1.Pointer() == v2.Pointer() {
+ return true
+ }
+ return isEqual(v1.Elem(), v2.Elem())
+
+ case reflect.Slice:
+ if v1.IsNil() && v2.IsNil() {
+ return true
+ }
+ if v2.IsNil() {
+ return false
+ }
+
+ l1 := v1.Len()
+ l2 := v2.Len()
+ if l1 != l2 {
+ return false
+ }
+
+ for x := 0; x < l1; x++ {
+ s1 := v1.Index(x)
+ s2 := v2.Index(x)
+ if !isEqual(s1, s2) {
+ return false
+ }
+ }
+ return true
+
+ case reflect.String:
+ return v1.String() == v2.String()
+
+ case reflect.Struct:
+ return isEqualStruct(v1, v2)
+
+ case reflect.UnsafePointer:
+ return v1.UnsafeAddr() == v2.UnsafeAddr()
+ }
+
+ return false
+}
+
+func isEqualMap(v1, v2 reflect.Value) bool {
+ if v1.IsNil() && v2.IsNil() {
+ return true
+ }
+ if v2.IsNil() {
+ return false
+ }
+ if v1.Len() != v2.Len() {
+ return false
+ }
+ keys := v1.MapKeys()
+ for x := 0; x < len(keys); x++ {
+ if !isEqual(v1.MapIndex(keys[x]), v2.MapIndex(keys[x])) {
+ return false
+ }
+ }
+ return true
+}
+
+func isEqualStruct(v1, v2 reflect.Value) bool {
+ m1 := v1.MethodByName("IsEqual")
+ if m1.IsValid() {
+ res := m1.Call([]reflect.Value{
+ v2.Addr(),
+ })
+ if len(res) == 1 && res[0].Kind() == reflect.Bool {
+ return res[0].Bool()
+ }
+ }
+
+ n := v1.NumField()
+ for x := 0; x < n; x++ {
+ f1 := v1.Field(x)
+ f2 := v2.Field(x)
+ if !isEqual(f1, f2) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/lib/test/test.go b/lib/test/test.go
index d94b7006..a4b5f04e 100644
--- a/lib/test/test.go
+++ b/lib/test/test.go
@@ -8,11 +8,10 @@
package test
import (
- "reflect"
"runtime"
"testing"
- libreflect "github.com/shuLhan/share/lib/reflect"
+ "github.com/shuLhan/share/lib/reflect"
)
func printStackTrace(t testing.TB, trace []byte) {
@@ -48,18 +47,7 @@ func printStackTrace(t testing.TB, trace []byte) {
// expectation and then terminate the test routine.
//
func Assert(t *testing.T, name string, exp, got interface{}, equal bool) {
- if exp == nil && got == nil && equal {
- return
- }
- if libreflect.IsNil(exp) && libreflect.IsNil(got) && equal {
- return
- }
- eq, ok := exp.(libreflect.Equaler)
- if ok && eq.IsEqual(got) == equal {
- return
- }
-
- if reflect.DeepEqual(exp, got) == equal {
+ if reflect.IsEqual(exp, got) == equal {
return
}
@@ -80,7 +68,7 @@ func Assert(t *testing.T, name string, exp, got interface{}, equal bool) {
// expectation and then terminate the test routine.
//
func AssertBench(b *testing.B, name string, exp, got interface{}, equal bool) {
- if reflect.DeepEqual(exp, got) == equal {
+ if reflect.IsEqual(exp, got) == equal {
return
}