Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ generate:
go run ./cmd/minimock/minimock.go -i ./tests.formatterAlias -o ./tests/formatter_alias_mock.go
go run ./cmd/minimock/minimock.go -i ./tests.formatterType -o ./tests/formatter_type_mock.go
go run ./cmd/minimock/minimock.go -i ./tests.reader -o ./tests/reader_mock.go -gr
go run ./cmd/minimock/minimock.go -i ./tests.funcCaller -o ./tests/func_caller_mock.go

./bin:
mkdir ./bin
Expand Down
20 changes: 20 additions & 0 deletions equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func Equal(a, b interface{}) bool {
continue
}

if checkFunctions(aFieldValue, bFieldValue) {
continue
}

if !reflect.DeepEqual(aFieldValue, bFieldValue) {
return false
}
Expand All @@ -48,6 +52,10 @@ func Equal(a, b interface{}) bool {
return true
}

if checkFunctions(a, b) {
return true
}

return reflect.DeepEqual(a, b)
}

Expand Down Expand Up @@ -103,3 +111,15 @@ func unexported(field reflect.Value) interface{} {
func unexportedVal(field reflect.Value) reflect.Value {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
}

// checkFunctions returns true
func checkFunctions(a, b interface{}) (differs bool) {
if a == nil || b == nil {
return false
}
if reflect.TypeOf(a).Kind() != reflect.Func || reflect.TypeOf(b).Kind() != reflect.Func {
return false
}

return reflect.ValueOf(a).UnsafePointer() == reflect.ValueOf(b).UnsafePointer()
}
74 changes: 74 additions & 0 deletions equal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package minimock

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestCheckFunctions(t *testing.T) {
t.Parallel()
var (
validFunc = func() {}
nilFunc func() = nil
)

tests := []struct {
name string
a interface{}
b interface{}
want bool
}{{
name: "one is nil, another is nil func",
a: nil,
b: nilFunc,
want: false,
}, {
name: "arguments are nil",
a: nil,
b: nil,
want: false,
}, {
name: "nil func and nil",
a: nilFunc,
b: nil,
want: false,
}, {
name: "arguments are not functions",
a: 1,
b: "string",
want: false,
}, {
name: "both functions are nil",
a: nilFunc,
b: nilFunc,
want: true,
}, {
name: "a is nil",
a: nilFunc,
b: validFunc,
want: false,
}, {
name: "b is nil",
a: validFunc,
b: nilFunc,
want: false,
}, {
name: "functions are anonymous",
a: func() {},
b: func() {},
want: false,
}, {
name: "functions are equal",
a: validFunc,
b: validFunc,
want: true,
}}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, checkFunctions(tt.a, tt.b))
})
}
}
Loading