Skip to content

Commit

Permalink
Customizability (transform) (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
seiyab authored Apr 29, 2024
1 parent 951b441 commit 76e9755
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 14 deletions.
21 changes: 16 additions & 5 deletions deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ type visit struct {
typ reflect.Type
}

const maxDepth = 1_000

func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if depth > maxDepth {
func (teq Teq) deepValueEqual(
v1, v2 reflect.Value,
visited map[visit]bool,
depth int,
) bool {
if depth > teq.MaxDepth {
panic("maximum depth exceeded")
}
if !v1.IsValid() || !v2.IsValid() {
Expand All @@ -32,6 +34,15 @@ func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, dept
return false
}

tr, ok := teq.transforms[v1.Type()]
if ok {
t1 := tr(v1)
t2 := tr(v2)
newTeq := New()
newTeq.MaxDepth = teq.MaxDepth
return newTeq.deepValueEqual(t1, t2, visited, depth)
}

if hard(v1.Kind()) {
if v1.CanAddr() && v2.CanAddr() {
addr1 := v1.Addr().UnsafePointer()
Expand Down Expand Up @@ -134,7 +145,7 @@ func pointerEq(teq Teq, v1, v2 reflect.Value, nx next) bool {

func structEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
for i, n := 0, v1.NumField(); i < n; i++ {
if !nx(v1.Field(i), v2.Field(i)) {
if !nx(field(v1, i), field(v2, i)) {
return false
}
}
Expand Down
16 changes: 16 additions & 0 deletions misc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package teq

import (
"reflect"
)

func field(v reflect.Value, idx int) reflect.Value {
f1 := v.Field(idx)
if f1.CanAddr() {
return f1
}
vc := reflect.New(v.Type()).Elem()
vc.Set(v)
rf := vc.Field(idx)
return reflect.NewAt(rf.Type(), rf.Addr().UnsafePointer()).Elem()
}
60 changes: 57 additions & 3 deletions teq.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,25 @@ import (
"reflect"
)

type Teq struct{}
type Teq struct {
MaxDepth int

transforms map[reflect.Type]func(reflect.Value) reflect.Value
}

func New() Teq {
return Teq{
MaxDepth: 1_000,

transforms: make(map[reflect.Type]func(reflect.Value) reflect.Value),
}
}

func (teq Teq) Equal(t TestingT, expected, actual any) bool {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Errorf("panic: %v", r)
t.Errorf("panic in github.com/seiyab/teq. please report issue. message: %v", r)
}
}()
ok := teq.equal(expected, actual)
Expand All @@ -20,6 +32,44 @@ func (teq Teq) Equal(t TestingT, expected, actual any) bool {
return ok
}

func (teq Teq) NotEqual(t TestingT, expected, actual any) bool {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Errorf("panic in github.com/seiyab/teq. please report issue. message: %v", r)
}
}()
ok := !teq.equal(expected, actual)
if !ok {
if reflect.DeepEqual(expected, actual) {
t.Error("reflect.DeepEqual(expected, actual) == true.")
} else {
t.Errorf("expected %v != %v", expected, actual)
t.Log("reflect.DeepEqual(expected, actual) == false. maybe transforms made them equal.")
}
}
return ok

}

func (teq *Teq) AddTransform(transform any) {
ty := reflect.TypeOf(transform)
if ty.Kind() != reflect.Func {
panic("transform must be a function")
}
if ty.NumIn() != 1 {
panic("transform must have only one argument")
}
if ty.NumOut() != 1 {
panic("transform must have only one return value")
}
trValue := reflect.ValueOf(transform)
reflectTransform := func(v reflect.Value) reflect.Value {
return trValue.Call([]reflect.Value{v})[0]
}
teq.transforms[ty.In(0)] = reflectTransform
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
Expand All @@ -29,5 +79,9 @@ func (teq Teq) equal(x, y any) bool {
if v1.Type() != v2.Type() {
return false
}
return teq.deepValueEqual(v1, v2, make(map[visit]bool), 0)
return teq.deepValueEqual(
v1, v2,
make(map[visit]bool),
0,
)
}
55 changes: 55 additions & 0 deletions teq_customized_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package teq_test

import (
"reflect"
"testing"
"time"

"github.com/seiyab/teq"
)

func TestEqual_Customized(t *testing.T) {
t.Run("time.Time", func(t *testing.T) {
defaultTeq := teq.New()
customizedTeq := teq.New()
customizedTeq.AddTransform(utc)

secondsEastOfUTC := int((8 * time.Hour).Seconds())
beijing := time.FixedZone("Beijing Time", secondsEastOfUTC)
d1 := time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC)
d2 := time.Date(2000, 2, 1, 20, 30, 0, 0, beijing)

defaultTeq.NotEqual(t, d1, d2)
customizedTeq.Equal(t, d1, d2)
if reflect.DeepEqual(d1, d2) {
t.Error("expected d1 != d2, got d1 == d2 with reflect.DeepEqual")
}

type twoDates struct {
d1 time.Time
d2 time.Time
}
dt1 := twoDates{d1, d2}
dt2 := twoDates{d2, d1}

defaultTeq.NotEqual(t, dt1, dt2)
customizedTeq.Equal(t, dt1, dt2)

if reflect.DeepEqual(dt1, dt2) {
t.Error("expected dt1 != dt2, got dt1 == dt2 with reflect.DeepEqual")
}

t.Skip("slice is not supported yet")
ds1 := []time.Time{d1, d1, d2}
ds2 := []time.Time{d2, d1, d1}
defaultTeq.NotEqual(t, ds1, ds2)
customizedTeq.Equal(t, ds1, ds2)
if reflect.DeepEqual(ds1, ds2) {
t.Error("expected ds1 != ds2, got ds1 == ds2 with reflect.DeepEqual")
}
})
}

func utc(d time.Time) time.Time {
return d.UTC()
}
21 changes: 15 additions & 6 deletions teq_test.go → teq_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type group struct {
}

func TestEqual(t *testing.T) {
assert := teq.Teq{}
assert := teq.New()

groups := []group{
{"primitives", primitives()},
Expand All @@ -43,14 +43,23 @@ func TestEqual(t *testing.T) {
}
t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors))
}
if test.pendingFormat {
return

if !test.pendingFormat {
for i, e := range test.expected {
if mt.errors[i] != e {
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i)
}
}
}
for i, e := range test.expected {
if mt.errors[i] != e {
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i)

{
mt := &mockT{}
assert.NotEqual(mt, test.a, test.b)
if (len(mt.errors) > 0) == (len(test.expected) > 0) {
t.Errorf("expected (len(mt.errors) > 0) = %t, got %t", len(test.expected) > 0, len(mt.errors) > 0)
}
}

})
}
})
Expand Down
2 changes: 2 additions & 0 deletions testingt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import "testing"

type TestingT interface {
Helper()
Error(args ...interface{})
Errorf(format string, args ...interface{})
Log(args ...interface{})
}

var _ TestingT = &testing.T{}
6 changes: 6 additions & 0 deletions testingt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ var _ teq.TestingT = &mockT{}

func (t *mockT) Helper() {}

func (t *mockT) Error(args ...interface{}) {
t.errors = append(t.errors, fmt.Sprint(args...))
}

func (t *mockT) Errorf(format string, args ...interface{}) {
t.errors = append(t.errors, fmt.Sprintf(format, args...))
}

func (t *mockT) Log(args ...interface{}) {}

0 comments on commit 76e9755

Please sign in to comment.