Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement equality of struct #1

Merged
merged 4 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions api.go

This file was deleted.

48 changes: 0 additions & 48 deletions api_test.go

This file was deleted.

149 changes: 149 additions & 0 deletions deepequal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Some code is written referencing following codes:
// - deepequal.go in "reflect" package authored by Go Authors
// - deepequal.go in "github.com/weaveworks/scope/test/reflect" package authored by Weaveworks Ltd

package teq

import (
"reflect"
"unsafe"
)

// During deepValueEqual, must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
type visit struct {
a1 unsafe.Pointer
a2 unsafe.Pointer
typ reflect.Type
}

const maxDepth = 1_000

func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if depth > maxDepth {
panic("maximum depth exceeded")
}
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
if v1.Type() != v2.Type() {
return false
}

if hard(v1.Kind()) {
if v1.CanAddr() && v2.CanAddr() {
addr1 := v1.Addr().UnsafePointer()
addr2 := v2.Addr().UnsafePointer()

// Short circuit
if uintptr(addr1) == uintptr(addr2) {
return true
}
if uintptr(addr1) > uintptr(addr2) {
// Canonicalize order to reduce number of entries in visited.
addr1, addr2 = addr2, addr1
}

// Short circuit if references are already seen.
typ := v1.Type()
v := visit{addr1, addr2, typ}
if visited[v] {
return true
}

// Remember for later.
visited[v] = true
}
}

eqFn, ok := eqs[v1.Kind()]
if !ok {
panic("not implemented")
}
var n next = func(v1, v2 reflect.Value) bool {
return teq.deepValueEqual(v1, v2, visited, depth+1)
}
return eqFn(teq, v1, v2, n)
}

type next func(v1, v2 reflect.Value) bool

var eqs = map[reflect.Kind]func(teq Teq, v1, v2 reflect.Value, nx next) bool{
reflect.Array: arrayEq,
reflect.Slice: todo,
reflect.Interface: interfaceEq,
reflect.Pointer: pointerEq,
reflect.Struct: structEq,
reflect.Map: todo,
reflect.Func: todo,
reflect.Int: intEq,
reflect.Int8: intEq,
reflect.Int16: intEq,
reflect.Int32: intEq,
reflect.Int64: intEq,
reflect.Uint: uintEq,
reflect.Uint8: uintEq,
reflect.Uint16: uintEq,
reflect.Uint32: uintEq,
reflect.Uint64: uintEq,
reflect.Uintptr: uintEq,
reflect.String: stringEq,
reflect.Bool: boolEq,
reflect.Float32: floatEq,
reflect.Float64: floatEq,
reflect.Complex64: complexEq,
reflect.Complex128: complexEq,
}

func hard(k reflect.Kind) bool {
switch k {
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct:
return true
}
return false
}

func todo(teq Teq, v1, v2 reflect.Value, nx next) bool {
panic("not implemented")
}

func arrayEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
for i := 0; i < v1.Len(); i++ {
if !nx(v1.Index(i), v2.Index(i)) {
return false
}
}
return true
}

func interfaceEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() == v2.IsNil()
}
return nx(v1.Elem(), v2.Elem())
}

func pointerEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
if v1.UnsafePointer() == v2.UnsafePointer() {
return true
}
return nx(v1.Elem(), v2.Elem())
}

func structEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
for i, n := 0, v1.NumField(); i < n; i++ {
if !nx(v1.Field(i), v2.Field(i)) {
return false
}
}
return true
}

func intEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Int() == v2.Int() }
func uintEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Uint() == v2.Uint() }
func stringEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.String() == v2.String() }
func boolEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Bool() == v2.Bool() }
func floatEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Float() == v2.Float() }
func complexEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Complex() == v2.Complex() }
33 changes: 33 additions & 0 deletions teq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package teq

import (
"reflect"
)

type Teq struct{}

func (teq Teq) Equal(t TestingT, expected, actual any) bool {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Errorf("panic: %v", r)
}
}()
ok := teq.equal(expected, actual)
if !ok {
t.Errorf("expected %v, got %v", expected, actual)
}
return ok
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
}
v1 := reflect.ValueOf(x)
v2 := reflect.ValueOf(y)
if v1.Type() != v2.Type() {
return false
}
return teq.deepValueEqual(v1, v2, make(map[visit]bool), 0)
}
138 changes: 138 additions & 0 deletions teq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package teq_test

import (
"fmt"
"testing"

"github.com/seiyab/teq"
)

type test struct {
a any
b any
expected []string
pendingFormat bool // for development. we don't have stable format yet.
}

type group struct {
name string
tests []test
}

func TestEqual(t *testing.T) {
assert := teq.Teq{}

groups := []group{
{"primitives", primitives()},
{"structs", structs()},
{"recursions", recursions()},
}

for _, group := range groups {
t.Run(group.name, func(t *testing.T) {
for _, test := range group.tests {
name := fmt.Sprintf("%T(%v) == %T(%v)", test.a, test.a, test.b, test.b)
t.Run(name, func(t *testing.T) {
mt := &mockT{}
assert.Equal(mt, test.a, test.b)
if len(mt.errors) != len(test.expected) {
if len(mt.errors) > len(test.expected) {
for _, e := range mt.errors {
t.Logf("got %q", e)
}
}
t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors))
}
if test.pendingFormat {
return
}
for i, e := range test.expected {
if mt.errors[i] != e {
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i)
}
}
})
}
})
}
}

func primitives() []test {
return []test{
{1, 1, nil, false},
{1, 2, []string{"expected 1, got 2"}, false},
{uint8(1), uint8(1), nil, false},
{uint8(1), uint8(2), []string{"expected 1, got 2"}, false},
{1.5, 1.5, nil, false},
{1.5, 2.5, []string{"expected 1.5, got 2.5"}, false},
{"a", "a", nil, false},
{"a", "b", []string{"expected a, got b"}, false},

{"a", 1, []string{"expected a, got 1"}, false},
}
}

func structs() []test {
type s struct {
i int
}
type anotherS struct {
i int
}

type withPointer struct {
i *int
}

return []test{
{s{1}, s{1}, nil, false},
{s{1}, s{2}, []string{"expected {1}, got {2}"}, false},
{s{1}, anotherS{1}, []string{"expected {1}, got {1}"}, false},

{withPointer{ref(1)}, withPointer{ref(1)}, nil, false},
{withPointer{ref(1)}, withPointer{ref(2)}, []string{"expected {1}, got {2}"}, true},
}
}

func recursions() []test {
type privateRecursiveStruct struct {
i int
r *privateRecursiveStruct
}
r1_1 := privateRecursiveStruct{1, nil}
r1_1.r = &r1_1
r1_2 := privateRecursiveStruct{1, nil}
r1_2.r = &r1_2
r1_3 := privateRecursiveStruct{2, nil}
r1_3.r = &r1_3
r1_4 := privateRecursiveStruct{4, nil}
r1_5 := privateRecursiveStruct{4, nil}
r1_4.r = &r1_5
r1_5.r = &r1_4
r1_6 := privateRecursiveStruct{4, nil}
r1_6.r = &r1_6

type PublicRecursiveStruct struct {
I int
R *PublicRecursiveStruct
}
r2_1 := PublicRecursiveStruct{1, nil}
r2_1.R = &r2_1
r2_2 := PublicRecursiveStruct{1, nil}
r2_2.R = &r2_2

return []test{
{r1_1, r1_1, nil, false},
{r1_1, r1_2, nil, false},
{r1_1, r1_3, []string{"expected {1, <cyclic>}, got {2, <cyclic>}"}, true},
{r1_4, r1_5, nil, false},
{r1_4, r1_6, nil, false},

{r2_1, r2_1, nil, false},
{r2_1, r2_2, nil, false},
}
}

func ref[T any](v T) *T {
return &v
}
1 change: 1 addition & 0 deletions testingt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package teq
import "testing"

type TestingT interface {
Helper()
Errorf(format string, args ...interface{})
}

Expand Down
Loading
Loading