From 288e73d9dd83ea4b9d4fa2ab0ec9f7d5f72eb4f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gergely=20B=C3=B3di?= Date: Thu, 5 Sep 2024 10:04:13 +0200 Subject: [PATCH 1/2] settable time precision for small differences --- deep.go | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/deep.go b/deep.go index 4be3e1f..5fe2cc6 100644 --- a/deep.go +++ b/deep.go @@ -9,6 +9,7 @@ import ( "log" "reflect" "strings" + "time" ) var ( @@ -16,6 +17,9 @@ var ( // to when comparing. FloatPrecision = 10 + // TimePrecision is a precision used for time.Time.Truncate(), if it is non-zero. + TimePrecision time.Duration + // MaxDiff specifies the maximum number of differences to return. MaxDiff = 10 @@ -79,7 +83,11 @@ type cmp struct { flag map[byte]bool } -var errorType = reflect.TypeOf((*error)(nil)).Elem() +var ( + errorType = reflect.TypeOf((*error)(nil)).Elem() + timeType = reflect.TypeOf(time.Time{}) + durationType = reflect.TypeOf(time.Nanosecond) +) // Equal compares variables a and b, recursing into their structure up to // MaxDepth levels deep (if greater than zero), and returns a list of differences, @@ -203,6 +211,23 @@ func (c *cmp) equals(a, b reflect.Value, level int) { return } + fixTimePrecision := func() { + if TimePrecision > 0 { + switch aType { + case timeType, durationType: + aFunc := a.MethodByName("Truncate") + bFunc := a.MethodByName("Truncate") + + if aFunc.CanInterface() && bFunc.CanInterface() { + precision := reflect.ValueOf(TimePrecision) + + a = aFunc.Call([]reflect.Value{precision})[0] + b = bFunc.Call([]reflect.Value{precision})[0] + } + } + } + } + switch aKind { ///////////////////////////////////////////////////////////////////// @@ -221,6 +246,8 @@ func (c *cmp) equals(a, b reflect.Value, level int) { Iterate through the fields (FirstName, LastName), recurse into their values. */ + fixTimePrecision() + // Types with an Equal() method, like time.Time, only if struct field // is exported (CanInterface) if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() { @@ -439,6 +466,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.saveDiff(a.Bool(), b.Bool()) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fixTimePrecision() if a.Int() != b.Int() { c.saveDiff(a.Int(), b.Int()) } From d956abba4b25ada4c530bf381b78c4095748f73f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gergely=20B=C3=B3di?= Date: Thu, 5 Sep 2024 10:31:06 +0200 Subject: [PATCH 2/2] tests --- deep_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/deep_test.go b/deep_test.go index 74377e1..1203f3a 100644 --- a/deep_test.go +++ b/deep_test.go @@ -14,6 +14,12 @@ import ( v2 "github.com/go-test/deep/test/v2" ) +const ( + multilineTestError = `wrong diff: + got: %q + expected: %q` +) + func TestString(t *testing.T) { diff := deep.Equal("foo", "foo") if len(diff) > 0 { @@ -1184,6 +1190,43 @@ func TestTimeUnexported(t *testing.T) { } } +func TestTimePrecision(t *testing.T) { + restoreTimePrecision := deep.TimePrecision + t.Cleanup(func() { deep.TimePrecision = restoreTimePrecision }) + + deep.TimePrecision = 1 * time.Microsecond + + now := time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC) + later := now.Add(123 * time.Nanosecond) + + shouldBeEqual(t, deep.Equal(now, later)) + + d1 := 1 * time.Microsecond + d2 := d1 + 123*time.Nanosecond + + shouldBeEqual(t, deep.Equal(d1, d2)) + + restoreCompareUnexportedFields := deep.CompareUnexportedFields + t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields }) + + deep.CompareUnexportedFields = true + + type S struct { + t time.Time + d time.Duration + } + + s1 := &S{t: now, d: d1} + s2 := &S{t: later, d: d2} + + // Since we cannot call `Truncate` on the unexported fields, + // we will show differences here. + shouldBeDiffs(t, deep.Equal(s1, s2), + "t.wall: 0 != 123", + "d: 1000 != 1123", + ) +} + func TestInterface(t *testing.T) { a := map[string]interface{}{ "foo": map[string]string{ @@ -1613,3 +1656,51 @@ func TestNilPointersAreZero(t *testing.T) { t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff) } } + +func reportWrongDiff(t testing.TB, got, expect string) { + t.Helper() + + output := fmt.Sprintf("wrong diff: got %q, expected %q", got, expect) + if len(output) > 120 { + output = fmt.Sprintf(multilineTestError, got, expect) + } + + t.Error(output) +} + +func shouldBeDiffs(t testing.TB, diff []string, head string, tail ...string) { + t.Helper() + + if len(diff) == 0 { + t.Fatal("no diffs") + } + + if len(diff) != len(tail)+1 { + t.Log("diff:", diff) + t.Errorf("wrong number of diffs: got %d, expected %d", len(diff), len(tail)+1) + } + + if expect := head; diff[0] != expect { + reportWrongDiff(t, diff[0], expect) + } + + for i, expect := range tail { + if i+1 >= len(diff) { + t.Errorf("missing diff: %q", expect) + continue + } + + if got := diff[i+1]; got != expect { + reportWrongDiff(t, got, expect) + } + + } +} + +func shouldBeEqual(t testing.TB, diff []string) { + t.Helper() + + if len(diff) > 0 { + t.Errorf("should be equal: %q", diff) + } +}