diff --git a/api.go b/api.go deleted file mode 100644 index 3a24f5e..0000000 --- a/api.go +++ /dev/null @@ -1,11 +0,0 @@ -package teq - -type Teq struct{} - -func (teq *Teq) Equal(t TestingT, a, b interface{}) bool { - ok := a == b - if !ok { - t.Errorf("expected %v, got %v", a, b) - } - return ok -} diff --git a/api_test.go b/api_test.go deleted file mode 100644 index ca73d98..0000000 --- a/api_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package teq_test - -import ( - "fmt" - "testing" - - "github.com/seiyab/teq" -) - -type test struct { - a any - b any - expected []string -} - -func TestEqual(t *testing.T) { - assert := teq.Teq{} - - tests := primitives() - for _, test := range tests { - name := fmt.Sprintf("%T(%v) == %T(%v)", test.a, test.a, test.b, test.b) - t.Run(name, func(t *testing.T) { - mt := &mockT{} - assert.Equal(mt, test.a, test.b) - if len(mt.errors) != len(test.expected) { - t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors)) - } - for i, e := range test.expected { - if mt.errors[i] != e { - t.Errorf("expected %v, got %v at i = %d", e, mt.errors[i], i) - } - } - }) - } -} - -func primitives() []test { - return []test{ - {1, 1, nil}, - {1, 2, []string{"expected 1, got 2"}}, - {uint8(1), uint8(1), nil}, - {uint8(1), uint8(2), []string{"expected 1, got 2"}}, - {"a", "a", nil}, - {"a", "b", []string{"expected a, got b"}}, - - {"a", 1, []string{"expected a, got 1"}}, - } -} diff --git a/deepequal.go b/deepequal.go new file mode 100644 index 0000000..72b5d97 --- /dev/null +++ b/deepequal.go @@ -0,0 +1,149 @@ +// Some code is written referencing following codes: +// - deepequal.go in "reflect" package authored by Go Authors +// - deepequal.go in "github.com/weaveworks/scope/test/reflect" package authored by Weaveworks Ltd + +package teq + +import ( + "reflect" + "unsafe" +) + +// During deepValueEqual, must keep track of checks that are +// in progress. The comparison algorithm assumes that all +// checks in progress are true when it reencounters them. +// Visited comparisons are stored in a map indexed by visit. +type visit struct { + a1 unsafe.Pointer + a2 unsafe.Pointer + typ reflect.Type +} + +const maxDepth = 1_000 + +func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + if depth > maxDepth { + panic("maximum depth exceeded") + } + if !v1.IsValid() || !v2.IsValid() { + return v1.IsValid() == v2.IsValid() + } + if v1.Type() != v2.Type() { + return false + } + + if hard(v1.Kind()) { + if v1.CanAddr() && v2.CanAddr() { + addr1 := v1.Addr().UnsafePointer() + addr2 := v2.Addr().UnsafePointer() + + // Short circuit + if uintptr(addr1) == uintptr(addr2) { + return true + } + if uintptr(addr1) > uintptr(addr2) { + // Canonicalize order to reduce number of entries in visited. + addr1, addr2 = addr2, addr1 + } + + // Short circuit if references are already seen. + typ := v1.Type() + v := visit{addr1, addr2, typ} + if visited[v] { + return true + } + + // Remember for later. + visited[v] = true + } + } + + eqFn, ok := eqs[v1.Kind()] + if !ok { + panic("not implemented") + } + var n next = func(v1, v2 reflect.Value) bool { + return teq.deepValueEqual(v1, v2, visited, depth+1) + } + return eqFn(teq, v1, v2, n) +} + +type next func(v1, v2 reflect.Value) bool + +var eqs = map[reflect.Kind]func(teq Teq, v1, v2 reflect.Value, nx next) bool{ + reflect.Array: arrayEq, + reflect.Slice: todo, + reflect.Interface: interfaceEq, + reflect.Pointer: pointerEq, + reflect.Struct: structEq, + reflect.Map: todo, + reflect.Func: todo, + reflect.Int: intEq, + reflect.Int8: intEq, + reflect.Int16: intEq, + reflect.Int32: intEq, + reflect.Int64: intEq, + reflect.Uint: uintEq, + reflect.Uint8: uintEq, + reflect.Uint16: uintEq, + reflect.Uint32: uintEq, + reflect.Uint64: uintEq, + reflect.Uintptr: uintEq, + reflect.String: stringEq, + reflect.Bool: boolEq, + reflect.Float32: floatEq, + reflect.Float64: floatEq, + reflect.Complex64: complexEq, + reflect.Complex128: complexEq, +} + +func hard(k reflect.Kind) bool { + switch k { + case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct: + return true + } + return false +} + +func todo(teq Teq, v1, v2 reflect.Value, nx next) bool { + panic("not implemented") +} + +func arrayEq(teq Teq, v1, v2 reflect.Value, nx next) bool { + for i := 0; i < v1.Len(); i++ { + if !nx(v1.Index(i), v2.Index(i)) { + return false + } + } + return true +} + +func interfaceEq(teq Teq, v1, v2 reflect.Value, nx next) bool { + if v1.IsNil() || v2.IsNil() { + return v1.IsNil() == v2.IsNil() + } + return nx(v1.Elem(), v2.Elem()) +} + +func pointerEq(teq Teq, v1, v2 reflect.Value, nx next) bool { + if v1.UnsafePointer() == v2.UnsafePointer() { + return true + } + return nx(v1.Elem(), v2.Elem()) +} + +func structEq(teq Teq, v1, v2 reflect.Value, nx next) bool { + for i, n := 0, v1.NumField(); i < n; i++ { + if !nx(v1.Field(i), v2.Field(i)) { + return false + } + } + return true +} + +func intEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Int() == v2.Int() } +func uintEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Uint() == v2.Uint() } +func stringEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.String() == v2.String() } +func boolEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Bool() == v2.Bool() } +func floatEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Float() == v2.Float() } +func complexEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Complex() == v2.Complex() } diff --git a/teq.go b/teq.go new file mode 100644 index 0000000..f7b5d24 --- /dev/null +++ b/teq.go @@ -0,0 +1,33 @@ +package teq + +import ( + "reflect" +) + +type Teq struct{} + +func (teq Teq) Equal(t TestingT, expected, actual any) bool { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Errorf("panic: %v", r) + } + }() + ok := teq.equal(expected, actual) + if !ok { + t.Errorf("expected %v, got %v", expected, actual) + } + return ok +} + +func (teq Teq) equal(x, y any) bool { + if x == nil || y == nil { + return x == y + } + v1 := reflect.ValueOf(x) + v2 := reflect.ValueOf(y) + if v1.Type() != v2.Type() { + return false + } + return teq.deepValueEqual(v1, v2, make(map[visit]bool), 0) +} diff --git a/teq_test.go b/teq_test.go new file mode 100644 index 0000000..821658d --- /dev/null +++ b/teq_test.go @@ -0,0 +1,138 @@ +package teq_test + +import ( + "fmt" + "testing" + + "github.com/seiyab/teq" +) + +type test struct { + a any + b any + expected []string + pendingFormat bool // for development. we don't have stable format yet. +} + +type group struct { + name string + tests []test +} + +func TestEqual(t *testing.T) { + assert := teq.Teq{} + + groups := []group{ + {"primitives", primitives()}, + {"structs", structs()}, + {"recursions", recursions()}, + } + + for _, group := range groups { + t.Run(group.name, func(t *testing.T) { + for _, test := range group.tests { + name := fmt.Sprintf("%T(%v) == %T(%v)", test.a, test.a, test.b, test.b) + t.Run(name, func(t *testing.T) { + mt := &mockT{} + assert.Equal(mt, test.a, test.b) + if len(mt.errors) != len(test.expected) { + if len(mt.errors) > len(test.expected) { + for _, e := range mt.errors { + t.Logf("got %q", e) + } + } + t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors)) + } + if test.pendingFormat { + return + } + for i, e := range test.expected { + if mt.errors[i] != e { + t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i) + } + } + }) + } + }) + } +} + +func primitives() []test { + return []test{ + {1, 1, nil, false}, + {1, 2, []string{"expected 1, got 2"}, false}, + {uint8(1), uint8(1), nil, false}, + {uint8(1), uint8(2), []string{"expected 1, got 2"}, false}, + {1.5, 1.5, nil, false}, + {1.5, 2.5, []string{"expected 1.5, got 2.5"}, false}, + {"a", "a", nil, false}, + {"a", "b", []string{"expected a, got b"}, false}, + + {"a", 1, []string{"expected a, got 1"}, false}, + } +} + +func structs() []test { + type s struct { + i int + } + type anotherS struct { + i int + } + + type withPointer struct { + i *int + } + + return []test{ + {s{1}, s{1}, nil, false}, + {s{1}, s{2}, []string{"expected {1}, got {2}"}, false}, + {s{1}, anotherS{1}, []string{"expected {1}, got {1}"}, false}, + + {withPointer{ref(1)}, withPointer{ref(1)}, nil, false}, + {withPointer{ref(1)}, withPointer{ref(2)}, []string{"expected {1}, got {2}"}, true}, + } +} + +func recursions() []test { + type privateRecursiveStruct struct { + i int + r *privateRecursiveStruct + } + r1_1 := privateRecursiveStruct{1, nil} + r1_1.r = &r1_1 + r1_2 := privateRecursiveStruct{1, nil} + r1_2.r = &r1_2 + r1_3 := privateRecursiveStruct{2, nil} + r1_3.r = &r1_3 + r1_4 := privateRecursiveStruct{4, nil} + r1_5 := privateRecursiveStruct{4, nil} + r1_4.r = &r1_5 + r1_5.r = &r1_4 + r1_6 := privateRecursiveStruct{4, nil} + r1_6.r = &r1_6 + + type PublicRecursiveStruct struct { + I int + R *PublicRecursiveStruct + } + r2_1 := PublicRecursiveStruct{1, nil} + r2_1.R = &r2_1 + r2_2 := PublicRecursiveStruct{1, nil} + r2_2.R = &r2_2 + + return []test{ + {r1_1, r1_1, nil, false}, + {r1_1, r1_2, nil, false}, + {r1_1, r1_3, []string{"expected {1, }, got {2, }"}, true}, + {r1_4, r1_5, nil, false}, + {r1_4, r1_6, nil, false}, + + {r2_1, r2_1, nil, false}, + {r2_1, r2_2, nil, false}, + } +} + +func ref[T any](v T) *T { + return &v +} diff --git a/testingt.go b/testingt.go index 7c995e0..b7f6c87 100644 --- a/testingt.go +++ b/testingt.go @@ -3,6 +3,7 @@ package teq import "testing" type TestingT interface { + Helper() Errorf(format string, args ...interface{}) } diff --git a/testingt_test.go b/testingt_test.go index b6b835c..a7ec7b0 100644 --- a/testingt_test.go +++ b/testingt_test.go @@ -12,6 +12,8 @@ type mockT struct { var _ teq.TestingT = &mockT{} +func (t *mockT) Helper() {} + func (t *mockT) Errorf(format string, args ...interface{}) { t.errors = append(t.errors, fmt.Sprintf(format, args...)) }