-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
323 additions
and
59 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
// Some code is written referencing following codes: | ||
// - deepequal.go in "reflect" package authored by Go Authors | ||
// - deepequal.go in "github.com/weaveworks/scope/test/reflect" package authored by Weaveworks Ltd | ||
|
||
package teq | ||
|
||
import ( | ||
"reflect" | ||
"unsafe" | ||
) | ||
|
||
// During deepValueEqual, must keep track of checks that are | ||
// in progress. The comparison algorithm assumes that all | ||
// checks in progress are true when it reencounters them. | ||
// Visited comparisons are stored in a map indexed by visit. | ||
type visit struct { | ||
a1 unsafe.Pointer | ||
a2 unsafe.Pointer | ||
typ reflect.Type | ||
} | ||
|
||
const maxDepth = 1_000 | ||
|
||
func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { | ||
if depth > maxDepth { | ||
panic("maximum depth exceeded") | ||
} | ||
if !v1.IsValid() || !v2.IsValid() { | ||
return v1.IsValid() == v2.IsValid() | ||
} | ||
if v1.Type() != v2.Type() { | ||
return false | ||
} | ||
|
||
if hard(v1.Kind()) { | ||
if v1.CanAddr() && v2.CanAddr() { | ||
addr1 := v1.Addr().UnsafePointer() | ||
addr2 := v2.Addr().UnsafePointer() | ||
|
||
// Short circuit | ||
if uintptr(addr1) == uintptr(addr2) { | ||
return true | ||
} | ||
if uintptr(addr1) > uintptr(addr2) { | ||
// Canonicalize order to reduce number of entries in visited. | ||
addr1, addr2 = addr2, addr1 | ||
} | ||
|
||
// Short circuit if references are already seen. | ||
typ := v1.Type() | ||
v := visit{addr1, addr2, typ} | ||
if visited[v] { | ||
return true | ||
} | ||
|
||
// Remember for later. | ||
visited[v] = true | ||
} | ||
} | ||
|
||
eqFn, ok := eqs[v1.Kind()] | ||
if !ok { | ||
panic("not implemented") | ||
} | ||
var n next = func(v1, v2 reflect.Value) bool { | ||
return teq.deepValueEqual(v1, v2, visited, depth+1) | ||
} | ||
return eqFn(teq, v1, v2, n) | ||
} | ||
|
||
type next func(v1, v2 reflect.Value) bool | ||
|
||
var eqs = map[reflect.Kind]func(teq Teq, v1, v2 reflect.Value, nx next) bool{ | ||
reflect.Array: arrayEq, | ||
reflect.Slice: todo, | ||
reflect.Interface: interfaceEq, | ||
reflect.Pointer: pointerEq, | ||
reflect.Struct: structEq, | ||
reflect.Map: todo, | ||
reflect.Func: todo, | ||
reflect.Int: intEq, | ||
reflect.Int8: intEq, | ||
reflect.Int16: intEq, | ||
reflect.Int32: intEq, | ||
reflect.Int64: intEq, | ||
reflect.Uint: uintEq, | ||
reflect.Uint8: uintEq, | ||
reflect.Uint16: uintEq, | ||
reflect.Uint32: uintEq, | ||
reflect.Uint64: uintEq, | ||
reflect.Uintptr: uintEq, | ||
reflect.String: stringEq, | ||
reflect.Bool: boolEq, | ||
reflect.Float32: floatEq, | ||
reflect.Float64: floatEq, | ||
reflect.Complex64: complexEq, | ||
reflect.Complex128: complexEq, | ||
} | ||
|
||
func hard(k reflect.Kind) bool { | ||
switch k { | ||
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct: | ||
return true | ||
} | ||
return false | ||
} | ||
|
||
func todo(teq Teq, v1, v2 reflect.Value, nx next) bool { | ||
panic("not implemented") | ||
} | ||
|
||
func arrayEq(teq Teq, v1, v2 reflect.Value, nx next) bool { | ||
for i := 0; i < v1.Len(); i++ { | ||
if !nx(v1.Index(i), v2.Index(i)) { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func interfaceEq(teq Teq, v1, v2 reflect.Value, nx next) bool { | ||
if v1.IsNil() || v2.IsNil() { | ||
return v1.IsNil() == v2.IsNil() | ||
} | ||
return nx(v1.Elem(), v2.Elem()) | ||
} | ||
|
||
func pointerEq(teq Teq, v1, v2 reflect.Value, nx next) bool { | ||
if v1.UnsafePointer() == v2.UnsafePointer() { | ||
return true | ||
} | ||
return nx(v1.Elem(), v2.Elem()) | ||
} | ||
|
||
func structEq(teq Teq, v1, v2 reflect.Value, nx next) bool { | ||
for i, n := 0, v1.NumField(); i < n; i++ { | ||
if !nx(v1.Field(i), v2.Field(i)) { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func intEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Int() == v2.Int() } | ||
func uintEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Uint() == v2.Uint() } | ||
func stringEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.String() == v2.String() } | ||
func boolEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Bool() == v2.Bool() } | ||
func floatEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Float() == v2.Float() } | ||
func complexEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Complex() == v2.Complex() } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package teq | ||
|
||
import ( | ||
"reflect" | ||
) | ||
|
||
type Teq struct{} | ||
|
||
func (teq Teq) Equal(t TestingT, expected, actual any) bool { | ||
t.Helper() | ||
defer func() { | ||
if r := recover(); r != nil { | ||
t.Errorf("panic: %v", r) | ||
} | ||
}() | ||
ok := teq.equal(expected, actual) | ||
if !ok { | ||
t.Errorf("expected %v, got %v", expected, actual) | ||
} | ||
return ok | ||
} | ||
|
||
func (teq Teq) equal(x, y any) bool { | ||
if x == nil || y == nil { | ||
return x == y | ||
} | ||
v1 := reflect.ValueOf(x) | ||
v2 := reflect.ValueOf(y) | ||
if v1.Type() != v2.Type() { | ||
return false | ||
} | ||
return teq.deepValueEqual(v1, v2, make(map[visit]bool), 0) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
package teq_test | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/seiyab/teq" | ||
) | ||
|
||
type test struct { | ||
a any | ||
b any | ||
expected []string | ||
pendingFormat bool // for development. we don't have stable format yet. | ||
} | ||
|
||
type group struct { | ||
name string | ||
tests []test | ||
} | ||
|
||
func TestEqual(t *testing.T) { | ||
assert := teq.Teq{} | ||
|
||
groups := []group{ | ||
{"primitives", primitives()}, | ||
{"structs", structs()}, | ||
{"recursions", recursions()}, | ||
} | ||
|
||
for _, group := range groups { | ||
t.Run(group.name, func(t *testing.T) { | ||
for _, test := range group.tests { | ||
name := fmt.Sprintf("%T(%v) == %T(%v)", test.a, test.a, test.b, test.b) | ||
t.Run(name, func(t *testing.T) { | ||
mt := &mockT{} | ||
assert.Equal(mt, test.a, test.b) | ||
if len(mt.errors) != len(test.expected) { | ||
if len(mt.errors) > len(test.expected) { | ||
for _, e := range mt.errors { | ||
t.Logf("got %q", e) | ||
} | ||
} | ||
t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors)) | ||
} | ||
if test.pendingFormat { | ||
return | ||
} | ||
for i, e := range test.expected { | ||
if mt.errors[i] != e { | ||
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i) | ||
} | ||
} | ||
}) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func primitives() []test { | ||
return []test{ | ||
{1, 1, nil, false}, | ||
{1, 2, []string{"expected 1, got 2"}, false}, | ||
{uint8(1), uint8(1), nil, false}, | ||
{uint8(1), uint8(2), []string{"expected 1, got 2"}, false}, | ||
{1.5, 1.5, nil, false}, | ||
{1.5, 2.5, []string{"expected 1.5, got 2.5"}, false}, | ||
{"a", "a", nil, false}, | ||
{"a", "b", []string{"expected a, got b"}, false}, | ||
|
||
{"a", 1, []string{"expected a, got 1"}, false}, | ||
} | ||
} | ||
|
||
func structs() []test { | ||
type s struct { | ||
i int | ||
} | ||
type anotherS struct { | ||
i int | ||
} | ||
|
||
type withPointer struct { | ||
i *int | ||
} | ||
|
||
return []test{ | ||
{s{1}, s{1}, nil, false}, | ||
{s{1}, s{2}, []string{"expected {1}, got {2}"}, false}, | ||
{s{1}, anotherS{1}, []string{"expected {1}, got {1}"}, false}, | ||
|
||
{withPointer{ref(1)}, withPointer{ref(1)}, nil, false}, | ||
{withPointer{ref(1)}, withPointer{ref(2)}, []string{"expected {1}, got {2}"}, true}, | ||
} | ||
} | ||
|
||
func recursions() []test { | ||
type privateRecursiveStruct struct { | ||
i int | ||
r *privateRecursiveStruct | ||
} | ||
r1_1 := privateRecursiveStruct{1, nil} | ||
r1_1.r = &r1_1 | ||
r1_2 := privateRecursiveStruct{1, nil} | ||
r1_2.r = &r1_2 | ||
r1_3 := privateRecursiveStruct{2, nil} | ||
r1_3.r = &r1_3 | ||
r1_4 := privateRecursiveStruct{4, nil} | ||
r1_5 := privateRecursiveStruct{4, nil} | ||
r1_4.r = &r1_5 | ||
r1_5.r = &r1_4 | ||
r1_6 := privateRecursiveStruct{4, nil} | ||
r1_6.r = &r1_6 | ||
|
||
type PublicRecursiveStruct struct { | ||
I int | ||
R *PublicRecursiveStruct | ||
} | ||
r2_1 := PublicRecursiveStruct{1, nil} | ||
r2_1.R = &r2_1 | ||
r2_2 := PublicRecursiveStruct{1, nil} | ||
r2_2.R = &r2_2 | ||
|
||
return []test{ | ||
{r1_1, r1_1, nil, false}, | ||
{r1_1, r1_2, nil, false}, | ||
{r1_1, r1_3, []string{"expected {1, <cyclic>}, got {2, <cyclic>}"}, true}, | ||
{r1_4, r1_5, nil, false}, | ||
{r1_4, r1_6, nil, false}, | ||
|
||
{r2_1, r2_1, nil, false}, | ||
{r2_1, r2_2, nil, false}, | ||
} | ||
} | ||
|
||
func ref[T any](v T) *T { | ||
return &v | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.