Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

settable time precision for small differences #66

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion deep.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import (
"log"
"reflect"
"strings"
"time"
)

var (
// FloatPrecision is the number of decimal places to round float values
// to when comparing.
FloatPrecision = 10

// TimePrecision is a precision used for time.Time.Truncate(), if it is non-zero.
TimePrecision time.Duration

// MaxDiff specifies the maximum number of differences to return.
MaxDiff = 10

Expand Down Expand Up @@ -79,7 +83,11 @@ type cmp struct {
flag map[byte]bool
}

var errorType = reflect.TypeOf((*error)(nil)).Elem()
var (
errorType = reflect.TypeOf((*error)(nil)).Elem()
timeType = reflect.TypeOf(time.Time{})
durationType = reflect.TypeOf(time.Nanosecond)
)

// Equal compares variables a and b, recursing into their structure up to
// MaxDepth levels deep (if greater than zero), and returns a list of differences,
Expand Down Expand Up @@ -203,6 +211,23 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
return
}

fixTimePrecision := func() {
if TimePrecision > 0 {
switch aType {
case timeType, durationType:
aFunc := a.MethodByName("Truncate")
bFunc := a.MethodByName("Truncate")

if aFunc.CanInterface() && bFunc.CanInterface() {
precision := reflect.ValueOf(TimePrecision)

a = aFunc.Call([]reflect.Value{precision})[0]
b = bFunc.Call([]reflect.Value{precision})[0]
}
}
}
}

switch aKind {

/////////////////////////////////////////////////////////////////////
Expand All @@ -221,6 +246,8 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
Iterate through the fields (FirstName, LastName), recurse into their values.
*/

fixTimePrecision()

// Types with an Equal() method, like time.Time, only if struct field
// is exported (CanInterface)
if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() {
Expand Down Expand Up @@ -439,6 +466,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.saveDiff(a.Bool(), b.Bool())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fixTimePrecision()
if a.Int() != b.Int() {
c.saveDiff(a.Int(), b.Int())
}
Expand Down
91 changes: 91 additions & 0 deletions deep_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import (
v2 "github.com/go-test/deep/test/v2"
)

const (
multilineTestError = `wrong diff:
got: %q
expected: %q`
)

func TestString(t *testing.T) {
diff := deep.Equal("foo", "foo")
if len(diff) > 0 {
Expand Down Expand Up @@ -1184,6 +1190,43 @@ func TestTimeUnexported(t *testing.T) {
}
}

func TestTimePrecision(t *testing.T) {
restoreTimePrecision := deep.TimePrecision
t.Cleanup(func() { deep.TimePrecision = restoreTimePrecision })

deep.TimePrecision = 1 * time.Microsecond

now := time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC)
later := now.Add(123 * time.Nanosecond)

shouldBeEqual(t, deep.Equal(now, later))

d1 := 1 * time.Microsecond
d2 := d1 + 123*time.Nanosecond

shouldBeEqual(t, deep.Equal(d1, d2))

restoreCompareUnexportedFields := deep.CompareUnexportedFields
t.Cleanup(func() { deep.CompareUnexportedFields = restoreCompareUnexportedFields })

deep.CompareUnexportedFields = true

type S struct {
t time.Time
d time.Duration
}

s1 := &S{t: now, d: d1}
s2 := &S{t: later, d: d2}

// Since we cannot call `Truncate` on the unexported fields,
// we will show differences here.
shouldBeDiffs(t, deep.Equal(s1, s2),
"t.wall: 0 != 123",
"d: 1000 != 1123",
)
}

func TestInterface(t *testing.T) {
a := map[string]interface{}{
"foo": map[string]string{
Expand Down Expand Up @@ -1613,3 +1656,51 @@ func TestNilPointersAreZero(t *testing.T) {
t.Fatalf("expected 1 diff, got %d: %s", len(diff), diff)
}
}

func reportWrongDiff(t testing.TB, got, expect string) {
t.Helper()

output := fmt.Sprintf("wrong diff: got %q, expected %q", got, expect)
if len(output) > 120 {
output = fmt.Sprintf(multilineTestError, got, expect)
}

t.Error(output)
}

func shouldBeDiffs(t testing.TB, diff []string, head string, tail ...string) {
t.Helper()

if len(diff) == 0 {
t.Fatal("no diffs")
}

if len(diff) != len(tail)+1 {
t.Log("diff:", diff)
t.Errorf("wrong number of diffs: got %d, expected %d", len(diff), len(tail)+1)
}

if expect := head; diff[0] != expect {
reportWrongDiff(t, diff[0], expect)
}

for i, expect := range tail {
if i+1 >= len(diff) {
t.Errorf("missing diff: %q", expect)
continue
}

if got := diff[i+1]; got != expect {
reportWrongDiff(t, got, expect)
}

}
}

func shouldBeEqual(t testing.TB, diff []string) {
t.Helper()

if len(diff) > 0 {
t.Errorf("should be equal: %q", diff)
}
}