Skip to content

Commit

Permalink
implement equality of struct (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
seiyab authored Apr 28, 2024
1 parent 57617aa commit 951b441
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 59 deletions.
11 changes: 0 additions & 11 deletions api.go

This file was deleted.

48 changes: 0 additions & 48 deletions api_test.go

This file was deleted.

149 changes: 149 additions & 0 deletions deepequal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Some code is written referencing following codes:
// - deepequal.go in "reflect" package authored by Go Authors
// - deepequal.go in "github.com/weaveworks/scope/test/reflect" package authored by Weaveworks Ltd

package teq

import (
"reflect"
"unsafe"
)

// During deepValueEqual, must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
type visit struct {
a1 unsafe.Pointer
a2 unsafe.Pointer
typ reflect.Type
}

const maxDepth = 1_000

func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if depth > maxDepth {
panic("maximum depth exceeded")
}
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
if v1.Type() != v2.Type() {
return false
}

if hard(v1.Kind()) {
if v1.CanAddr() && v2.CanAddr() {
addr1 := v1.Addr().UnsafePointer()
addr2 := v2.Addr().UnsafePointer()

// Short circuit
if uintptr(addr1) == uintptr(addr2) {
return true
}
if uintptr(addr1) > uintptr(addr2) {
// Canonicalize order to reduce number of entries in visited.
addr1, addr2 = addr2, addr1
}

// Short circuit if references are already seen.
typ := v1.Type()
v := visit{addr1, addr2, typ}
if visited[v] {
return true
}

// Remember for later.
visited[v] = true
}
}

eqFn, ok := eqs[v1.Kind()]
if !ok {
panic("not implemented")
}
var n next = func(v1, v2 reflect.Value) bool {
return teq.deepValueEqual(v1, v2, visited, depth+1)
}
return eqFn(teq, v1, v2, n)
}

type next func(v1, v2 reflect.Value) bool

var eqs = map[reflect.Kind]func(teq Teq, v1, v2 reflect.Value, nx next) bool{
reflect.Array: arrayEq,
reflect.Slice: todo,
reflect.Interface: interfaceEq,
reflect.Pointer: pointerEq,
reflect.Struct: structEq,
reflect.Map: todo,
reflect.Func: todo,
reflect.Int: intEq,
reflect.Int8: intEq,
reflect.Int16: intEq,
reflect.Int32: intEq,
reflect.Int64: intEq,
reflect.Uint: uintEq,
reflect.Uint8: uintEq,
reflect.Uint16: uintEq,
reflect.Uint32: uintEq,
reflect.Uint64: uintEq,
reflect.Uintptr: uintEq,
reflect.String: stringEq,
reflect.Bool: boolEq,
reflect.Float32: floatEq,
reflect.Float64: floatEq,
reflect.Complex64: complexEq,
reflect.Complex128: complexEq,
}

func hard(k reflect.Kind) bool {
switch k {
case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct:
return true
}
return false
}

func todo(teq Teq, v1, v2 reflect.Value, nx next) bool {
panic("not implemented")
}

func arrayEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
for i := 0; i < v1.Len(); i++ {
if !nx(v1.Index(i), v2.Index(i)) {
return false
}
}
return true
}

func interfaceEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() == v2.IsNil()
}
return nx(v1.Elem(), v2.Elem())
}

func pointerEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
if v1.UnsafePointer() == v2.UnsafePointer() {
return true
}
return nx(v1.Elem(), v2.Elem())
}

func structEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
for i, n := 0, v1.NumField(); i < n; i++ {
if !nx(v1.Field(i), v2.Field(i)) {
return false
}
}
return true
}

func intEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Int() == v2.Int() }
func uintEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Uint() == v2.Uint() }
func stringEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.String() == v2.String() }
func boolEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Bool() == v2.Bool() }
func floatEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Float() == v2.Float() }
func complexEq(_ Teq, v1, v2 reflect.Value, _ next) bool { return v1.Complex() == v2.Complex() }
33 changes: 33 additions & 0 deletions teq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package teq

import (
"reflect"
)

type Teq struct{}

func (teq Teq) Equal(t TestingT, expected, actual any) bool {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Errorf("panic: %v", r)
}
}()
ok := teq.equal(expected, actual)
if !ok {
t.Errorf("expected %v, got %v", expected, actual)
}
return ok
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
}
v1 := reflect.ValueOf(x)
v2 := reflect.ValueOf(y)
if v1.Type() != v2.Type() {
return false
}
return teq.deepValueEqual(v1, v2, make(map[visit]bool), 0)
}
138 changes: 138 additions & 0 deletions teq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package teq_test

import (
"fmt"
"testing"

"github.com/seiyab/teq"
)

type test struct {
a any
b any
expected []string
pendingFormat bool // for development. we don't have stable format yet.
}

type group struct {
name string
tests []test
}

func TestEqual(t *testing.T) {
assert := teq.Teq{}

groups := []group{
{"primitives", primitives()},
{"structs", structs()},
{"recursions", recursions()},
}

for _, group := range groups {
t.Run(group.name, func(t *testing.T) {
for _, test := range group.tests {
name := fmt.Sprintf("%T(%v) == %T(%v)", test.a, test.a, test.b, test.b)
t.Run(name, func(t *testing.T) {
mt := &mockT{}
assert.Equal(mt, test.a, test.b)
if len(mt.errors) != len(test.expected) {
if len(mt.errors) > len(test.expected) {
for _, e := range mt.errors {
t.Logf("got %q", e)
}
}
t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors))
}
if test.pendingFormat {
return
}
for i, e := range test.expected {
if mt.errors[i] != e {
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i)
}
}
})
}
})
}
}

func primitives() []test {
return []test{
{1, 1, nil, false},
{1, 2, []string{"expected 1, got 2"}, false},
{uint8(1), uint8(1), nil, false},
{uint8(1), uint8(2), []string{"expected 1, got 2"}, false},
{1.5, 1.5, nil, false},
{1.5, 2.5, []string{"expected 1.5, got 2.5"}, false},
{"a", "a", nil, false},
{"a", "b", []string{"expected a, got b"}, false},

{"a", 1, []string{"expected a, got 1"}, false},
}
}

func structs() []test {
type s struct {
i int
}
type anotherS struct {
i int
}

type withPointer struct {
i *int
}

return []test{
{s{1}, s{1}, nil, false},
{s{1}, s{2}, []string{"expected {1}, got {2}"}, false},
{s{1}, anotherS{1}, []string{"expected {1}, got {1}"}, false},

{withPointer{ref(1)}, withPointer{ref(1)}, nil, false},
{withPointer{ref(1)}, withPointer{ref(2)}, []string{"expected {1}, got {2}"}, true},
}
}

func recursions() []test {
type privateRecursiveStruct struct {
i int
r *privateRecursiveStruct
}
r1_1 := privateRecursiveStruct{1, nil}
r1_1.r = &r1_1
r1_2 := privateRecursiveStruct{1, nil}
r1_2.r = &r1_2
r1_3 := privateRecursiveStruct{2, nil}
r1_3.r = &r1_3
r1_4 := privateRecursiveStruct{4, nil}
r1_5 := privateRecursiveStruct{4, nil}
r1_4.r = &r1_5
r1_5.r = &r1_4
r1_6 := privateRecursiveStruct{4, nil}
r1_6.r = &r1_6

type PublicRecursiveStruct struct {
I int
R *PublicRecursiveStruct
}
r2_1 := PublicRecursiveStruct{1, nil}
r2_1.R = &r2_1
r2_2 := PublicRecursiveStruct{1, nil}
r2_2.R = &r2_2

return []test{
{r1_1, r1_1, nil, false},
{r1_1, r1_2, nil, false},
{r1_1, r1_3, []string{"expected {1, <cyclic>}, got {2, <cyclic>}"}, true},
{r1_4, r1_5, nil, false},
{r1_4, r1_6, nil, false},

{r2_1, r2_1, nil, false},
{r2_1, r2_2, nil, false},
}
}

func ref[T any](v T) *T {
return &v
}
1 change: 1 addition & 0 deletions testingt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package teq
import "testing"

type TestingT interface {
Helper()
Errorf(format string, args ...interface{})
}

Expand Down
Loading

0 comments on commit 951b441

Please sign in to comment.