From 3ea7ad094bea7c98eab3c5660fd93873c9c6bcf6 Mon Sep 17 00:00:00 2001 From: Seiya <20365512+seiyab@users.noreply.github.com> Date: Tue, 7 May 2024 18:53:54 +0900 Subject: [PATCH] implement AddEqual (#10) --- deepequal.go | 5 +++++ teq.go | 30 ++++++++++++++++++++++++++++++ teq_customized_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/deepequal.go b/deepequal.go index f1dc661..bd785cf 100644 --- a/deepequal.go +++ b/deepequal.go @@ -35,6 +35,11 @@ func (teq Teq) deepValueEqual( return false } + eq, ok := teq.equals[v1.Type()] + if ok { + return eq(v1, v2) + } + tr, ok := teq.transforms[v1.Type()] if ok { t1 := tr(v1) diff --git a/teq.go b/teq.go index 15d9cd7..17b798c 100644 --- a/teq.go +++ b/teq.go @@ -11,6 +11,7 @@ type Teq struct { transforms map[reflect.Type]func(reflect.Value) reflect.Value formats map[reflect.Type]func(reflect.Value) string + equals map[reflect.Type]func(reflect.Value, reflect.Value) bool } // New returns new instance of Teq. @@ -20,6 +21,7 @@ func New() Teq { transforms: make(map[reflect.Type]func(reflect.Value) reflect.Value), formats: make(map[reflect.Type]func(reflect.Value) string), + equals: make(map[reflect.Type]func(reflect.Value, reflect.Value) bool), } } @@ -109,6 +111,34 @@ func (teq *Teq) AddFormat(format any) { teq.formats[ty.In(0)] = reflectFormat } +// AddEqual adds an equal function to Teq. +// The equal function must have two arguments with the same type and one return value of bool. +// If the passed equal function is not valid, it will panic. +// The equal function will be used for equality check instead of the default equality check. +func (teq *Teq) AddEqual(equal any) { + ty := reflect.TypeOf(equal) + if ty.Kind() != reflect.Func { + panic("equal must be a function") + } + if ty.NumIn() != 2 { + panic("equal must have two arguments") + } + if ty.In(0) != ty.In(1) { + panic("equal must have two arguments with the same type") + } + if ty.NumOut() != 1 { + panic("equal must have only one return value") + } + if ty.Out(0).Kind() != reflect.Bool { + panic("equal must return bool") + } + equalValue := reflect.ValueOf(equal) + reflectEqual := func(v1, v2 reflect.Value) bool { + return equalValue.Call([]reflect.Value{v1, v2})[0].Bool() + } + teq.equals[ty.In(0)] = reflectEqual +} + func (teq Teq) equal(x, y any) bool { if x == nil || y == nil { return x == y diff --git a/teq_customized_test.go b/teq_customized_test.go index 56b9853..32ee024 100644 --- a/teq_customized_test.go +++ b/teq_customized_test.go @@ -1,6 +1,7 @@ package teq_test import ( + "math" "reflect" "testing" "time" @@ -47,6 +48,43 @@ func TestEqual_Customized(t *testing.T) { t.Error("expected ds1 != ds2, got ds1 == ds2 with reflect.DeepEqual") } }) + + t.Run("AddEqual", func(t *testing.T) { + t.Run("float64", func(t *testing.T) { + d := teq.New() + c := teq.New() + c.AddEqual(func(a, b float64) bool { + const epsilon = 1e-3 + return math.Abs(a-b) < epsilon + }) + + d.NotEqual(t, 1.0, 1.001) + c.Equal(t, 1.0, 1.001) + c.NotEqual(t, 1.0, 1.002) + + d.NotEqual(t, float32(1.0), float32(1.001)) + c.NotEqual(t, float32(1.0), float32(1.001)) + + d.NotEqual(t, []float64{1.0, 1.0, 1.001}, []float64{1.001, 1.0, 1.0}) + c.Equal(t, []float64{1.0, 1.0, 1.001}, []float64{1.001, 1.0, 1.0}) + }) + + t.Run("time.Time", func(t *testing.T) { + d := teq.New() + c := teq.New() + c.AddEqual(func(a, b time.Time) bool { + return a.Equal(b) + }) + + secondsEastOfUTC := int((8 * time.Hour).Seconds()) + beijing := time.FixedZone("Beijing Time", secondsEastOfUTC) + d1 := time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC) + d2 := time.Date(2000, 2, 1, 20, 30, 0, 0, beijing) + + d.NotEqual(t, d1, d2) + c.Equal(t, d1, d2) + }) + }) } func TestEqual_CustomizedFormat(t *testing.T) {