Skip to content

Commit

Permalink
implement AddEqual (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
seiyab authored May 7, 2024
1 parent 39f806b commit 3ea7ad0
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ func (teq Teq) deepValueEqual(
return false
}

eq, ok := teq.equals[v1.Type()]
if ok {
return eq(v1, v2)
}

tr, ok := teq.transforms[v1.Type()]
if ok {
t1 := tr(v1)
Expand Down
30 changes: 30 additions & 0 deletions teq.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Teq struct {

transforms map[reflect.Type]func(reflect.Value) reflect.Value
formats map[reflect.Type]func(reflect.Value) string
equals map[reflect.Type]func(reflect.Value, reflect.Value) bool
}

// New returns new instance of Teq.
Expand All @@ -20,6 +21,7 @@ func New() Teq {

transforms: make(map[reflect.Type]func(reflect.Value) reflect.Value),
formats: make(map[reflect.Type]func(reflect.Value) string),
equals: make(map[reflect.Type]func(reflect.Value, reflect.Value) bool),
}
}

Expand Down Expand Up @@ -109,6 +111,34 @@ func (teq *Teq) AddFormat(format any) {
teq.formats[ty.In(0)] = reflectFormat
}

// AddEqual adds an equal function to Teq.
// The equal function must have two arguments with the same type and one return value of bool.
// If the passed equal function is not valid, it will panic.
// The equal function will be used for equality check instead of the default equality check.
func (teq *Teq) AddEqual(equal any) {
ty := reflect.TypeOf(equal)
if ty.Kind() != reflect.Func {
panic("equal must be a function")
}
if ty.NumIn() != 2 {
panic("equal must have two arguments")
}
if ty.In(0) != ty.In(1) {
panic("equal must have two arguments with the same type")
}
if ty.NumOut() != 1 {
panic("equal must have only one return value")
}
if ty.Out(0).Kind() != reflect.Bool {
panic("equal must return bool")
}
equalValue := reflect.ValueOf(equal)
reflectEqual := func(v1, v2 reflect.Value) bool {
return equalValue.Call([]reflect.Value{v1, v2})[0].Bool()
}
teq.equals[ty.In(0)] = reflectEqual
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
Expand Down
38 changes: 38 additions & 0 deletions teq_customized_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package teq_test

import (
"math"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -47,6 +48,43 @@ func TestEqual_Customized(t *testing.T) {
t.Error("expected ds1 != ds2, got ds1 == ds2 with reflect.DeepEqual")
}
})

t.Run("AddEqual", func(t *testing.T) {
t.Run("float64", func(t *testing.T) {
d := teq.New()
c := teq.New()
c.AddEqual(func(a, b float64) bool {
const epsilon = 1e-3
return math.Abs(a-b) < epsilon
})

d.NotEqual(t, 1.0, 1.001)
c.Equal(t, 1.0, 1.001)
c.NotEqual(t, 1.0, 1.002)

d.NotEqual(t, float32(1.0), float32(1.001))
c.NotEqual(t, float32(1.0), float32(1.001))

d.NotEqual(t, []float64{1.0, 1.0, 1.001}, []float64{1.001, 1.0, 1.0})
c.Equal(t, []float64{1.0, 1.0, 1.001}, []float64{1.001, 1.0, 1.0})
})

t.Run("time.Time", func(t *testing.T) {
d := teq.New()
c := teq.New()
c.AddEqual(func(a, b time.Time) bool {
return a.Equal(b)
})

secondsEastOfUTC := int((8 * time.Hour).Seconds())
beijing := time.FixedZone("Beijing Time", secondsEastOfUTC)
d1 := time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC)
d2 := time.Date(2000, 2, 1, 20, 30, 0, 0, beijing)

d.NotEqual(t, d1, d2)
c.Equal(t, d1, d2)
})
})
}

func TestEqual_CustomizedFormat(t *testing.T) {
Expand Down

0 comments on commit 3ea7ad0

Please sign in to comment.