Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement AddEqual #10

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ func (teq Teq) deepValueEqual(
return false
}

eq, ok := teq.equals[v1.Type()]
if ok {
return eq(v1, v2)
}

tr, ok := teq.transforms[v1.Type()]
if ok {
t1 := tr(v1)
Expand Down
30 changes: 30 additions & 0 deletions teq.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Teq struct {

transforms map[reflect.Type]func(reflect.Value) reflect.Value
formats map[reflect.Type]func(reflect.Value) string
equals map[reflect.Type]func(reflect.Value, reflect.Value) bool
}

// New returns new instance of Teq.
Expand All @@ -20,6 +21,7 @@ func New() Teq {

transforms: make(map[reflect.Type]func(reflect.Value) reflect.Value),
formats: make(map[reflect.Type]func(reflect.Value) string),
equals: make(map[reflect.Type]func(reflect.Value, reflect.Value) bool),
}
}

Expand Down Expand Up @@ -109,6 +111,34 @@ func (teq *Teq) AddFormat(format any) {
teq.formats[ty.In(0)] = reflectFormat
}

// AddEqual adds an equal function to Teq.
// The equal function must have two arguments with the same type and one return value of bool.
// If the passed equal function is not valid, it will panic.
// The equal function will be used for equality check instead of the default equality check.
func (teq *Teq) AddEqual(equal any) {
ty := reflect.TypeOf(equal)
if ty.Kind() != reflect.Func {
panic("equal must be a function")
}
if ty.NumIn() != 2 {
panic("equal must have two arguments")
}
if ty.In(0) != ty.In(1) {
panic("equal must have two arguments with the same type")
}
if ty.NumOut() != 1 {
panic("equal must have only one return value")
}
if ty.Out(0).Kind() != reflect.Bool {
panic("equal must return bool")
}
equalValue := reflect.ValueOf(equal)
reflectEqual := func(v1, v2 reflect.Value) bool {
return equalValue.Call([]reflect.Value{v1, v2})[0].Bool()
}
teq.equals[ty.In(0)] = reflectEqual
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
Expand Down
38 changes: 38 additions & 0 deletions teq_customized_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package teq_test

import (
"math"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -47,6 +48,43 @@ func TestEqual_Customized(t *testing.T) {
t.Error("expected ds1 != ds2, got ds1 == ds2 with reflect.DeepEqual")
}
})

t.Run("AddEqual", func(t *testing.T) {
t.Run("float64", func(t *testing.T) {
d := teq.New()
c := teq.New()
c.AddEqual(func(a, b float64) bool {
const epsilon = 1e-3
return math.Abs(a-b) < epsilon
})

d.NotEqual(t, 1.0, 1.001)
c.Equal(t, 1.0, 1.001)
c.NotEqual(t, 1.0, 1.002)

d.NotEqual(t, float32(1.0), float32(1.001))
c.NotEqual(t, float32(1.0), float32(1.001))

d.NotEqual(t, []float64{1.0, 1.0, 1.001}, []float64{1.001, 1.0, 1.0})
c.Equal(t, []float64{1.0, 1.0, 1.001}, []float64{1.001, 1.0, 1.0})
})

t.Run("time.Time", func(t *testing.T) {
d := teq.New()
c := teq.New()
c.AddEqual(func(a, b time.Time) bool {
return a.Equal(b)
})

secondsEastOfUTC := int((8 * time.Hour).Seconds())
beijing := time.FixedZone("Beijing Time", secondsEastOfUTC)
d1 := time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC)
d2 := time.Date(2000, 2, 1, 20, 30, 0, 0, beijing)

d.NotEqual(t, d1, d2)
c.Equal(t, d1, d2)
})
})
}

func TestEqual_CustomizedFormat(t *testing.T) {
Expand Down
Loading