Skip to content

Commit

Permalink
Support iter.Seqs in [Not]Contains and [Not]ElementsMatch
Browse files Browse the repository at this point in the history
  • Loading branch information
misberner authored and malte-prophet committed Dec 12, 2024
1 parent 89cbdd9 commit 7f96d99
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 26 deletions.
6 changes: 3 additions & 3 deletions assert/assertion_format.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions assert/assertion_forward.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 14 additions & 5 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}

same, ok := samePointers(expected, actual)
if !ok {
//fails when the arguments are not pointers
// fails when the arguments are not pointers
return !(Fail(t, "Both arguments must be pointers", msgAndArgs...))
}

Expand All @@ -549,7 +549,7 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
func samePointers(first, second interface{}) (same bool, ok bool) {
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
return false, false //not both are pointers
return false, false // not both are pointers
}

firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
Expand Down Expand Up @@ -918,7 +918,7 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {

}

// Contains asserts that the specified string, list(array, slice...) or map contains the
// Contains asserts that the specified string, list(array, slice, sequence...) or map contains the
// specified substring or element.
//
// assert.Contains(t, "Hello World", "World")
Expand All @@ -929,6 +929,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}

s = seqToSlice(s)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
Expand All @@ -952,6 +953,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}

s = seqToSlice(s)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
Expand Down Expand Up @@ -1088,6 +1090,10 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface
if h, ok := t.(tHelper); ok {
h.Helper()
}
// Convert sequences to lists, if applicable
listA = seqToSlice(listA)
listB = seqToSlice(listB)

if isEmpty(listA) && isEmpty(listB) {
return true
}
Expand Down Expand Up @@ -1175,8 +1181,8 @@ func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) stri
return msg.String()
}

// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// NotElementsMatch asserts that the specified listA(array, slice, sequence...) is NOT equal to specified
// listB(array, slice, sequence...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
Expand All @@ -1189,6 +1195,9 @@ func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interf
if h, ok := t.(tHelper); ok {
h.Helper()
}
// Convert sequences to lists, if applicable
listA = seqToSlice(listA)
listB = seqToSlice(listB)
if isEmpty(listA) && isEmpty(listB) {
return Fail(t, "listA and listB contain the same elements", msgAndArgs)
}
Expand Down
179 changes: 179 additions & 0 deletions assert/assertions_seq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
//go:build go1.23 || goexperiment.rangefunc

package assert

import (
"fmt"
"testing"
)

// go.mod version is set to 1.17, which precludes the use of generics (even though this file wouldn't be taken into
// account per the build tags).

func intSeq(s ...int) func(yield func(int) bool) {
return func(yield func(int) bool) {
for _, elem := range s {
if !yield(elem) {
break
}
}
}
}

func strSeq(s ...string) func(yield func(string) bool) {
return func(yield func(string) bool) {
for _, elem := range s {
if !yield(elem) {
break
}
}
}
}

func TestElementsMatch_Seq(t *testing.T) {
mockT := new(testing.T)

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{intSeq(), intSeq(), true},
{intSeq(1), intSeq(1), true},
{intSeq(1, 1), intSeq(1, 1), true},
{intSeq(1, 2), intSeq(1, 2), true},
{intSeq(1, 2), intSeq(2, 1), true},
{strSeq("hello", "world"), strSeq("world", "hello"), true},
{strSeq("hello", "hello"), strSeq("hello", "hello"), true},
{strSeq("hello", "hello", "world"), strSeq("hello", "world", "hello"), true},
{intSeq(), nil, true},

// not matching
{intSeq(1), intSeq(1, 1), false},
{intSeq(1, 2), intSeq(2, 2), false},
{strSeq("hello", "hello"), strSeq("hello"), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ElementsMatch(%#v, %#v)", seqToSlice(c.expected), seqToSlice(c.actual)), func(t *testing.T) {
res := ElementsMatch(mockT, c.actual, c.expected)

if res != c.result {
t.Errorf("ElementsMatch(%#v, %#v) should return %v", seqToSlice(c.actual), seqToSlice(c.expected), c.result)
}
})
}
}

func TestNotElementsMatch_Seq(t *testing.T) {
mockT := new(testing.T)

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
// not matching
{intSeq(1), intSeq(), true},
{intSeq(), intSeq(2), true},
{intSeq(1), intSeq(2), true},
{intSeq(1), intSeq(1, 1), true},
{intSeq(1, 2), intSeq(3, 4), true},
{intSeq(3, 4), intSeq(1, 2), true},
{intSeq(1, 1, 2, 3), intSeq(1, 2, 3), true},
{strSeq("hello"), strSeq("world"), true},
{strSeq("hello", "hello"), strSeq("world", "world"), true},

// matching
{intSeq(), nil, false},
{intSeq(), intSeq(), false},
{intSeq(1), intSeq(1), false},
{intSeq(1, 1), intSeq(1, 1), false},
{intSeq(1, 2), intSeq(2, 1), false},
{intSeq(1, 1, 2), intSeq(1, 2, 1), false},
{strSeq("hello", "world"), strSeq("world", "hello"), false},
{strSeq("hello", "hello"), strSeq("hello", "hello"), false},
{strSeq("hello", "hello", "world"), strSeq("hello", "world", "hello"), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("NotElementsMatch(%#v, %#v)", seqToSlice(c.expected), seqToSlice(c.actual)), func(t *testing.T) {
res := NotElementsMatch(mockT, c.actual, c.expected)

if res != c.result {
t.Errorf("NotElementsMatch(%#v, %#v) should return %v", seqToSlice(c.actual), seqToSlice(c.expected), c.result)
}
})
}
}

func TestContainsNotContains_Seq(t *testing.T) {

type A struct {
Name, Value string
}
complexSeq := func(s ...*A) func(yield func(*A) bool) {
return func(yield func(*A) bool) {
for _, elem := range s {
if !yield(elem) {
break
}
}
}
}

list := []string{"Foo", "Bar"}

complexList := []*A{
{"b", "c"},
{"d", "e"},
{"g", "h"},
{"j", "k"},
}

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{strSeq(list...), "Bar", true},
{strSeq(list...), "Salut", false},
{complexSeq(complexList...), &A{"g", "h"}, true},
{complexSeq(complexList...), &A{"g", "e"}, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("Contains(%#v, %#v)", seqToSlice(c.expected), seqToSlice(c.actual)), func(t *testing.T) {
mockT := new(testing.T)
res := Contains(mockT, c.expected, c.actual)

if res != c.result {
if res {
t.Errorf(
"Contains(%#v, %#v) should return true:\n\t%#v contains %#v",
seqToSlice(c.expected), seqToSlice(c.actual), seqToSlice(c.expected), seqToSlice(c.actual))
} else {
t.Errorf(
"Contains(%#v, %#v) should return false:\n\t%#v does not contain %#v",
seqToSlice(c.expected), seqToSlice(c.actual), seqToSlice(c.expected), seqToSlice(c.actual))
}
}
})
}

for _, c := range cases {
t.Run(fmt.Sprintf("NotContains(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
mockT := new(testing.T)
res := NotContains(mockT, c.expected, c.actual)

// NotContains should be inverse of Contains. If it's not, something is wrong
if res == Contains(mockT, c.expected, c.actual) {
if res {
t.Errorf("NotContains(%#v, %#v) should return true:\n\t%#v does not contains %#v", c.expected, c.actual, c.expected, c.actual)
} else {
t.Errorf("NotContains(%#v, %#v) should return false:\n\t%#v contains %#v", c.expected, c.actual, c.expected, c.actual)
}
}
})
}
}
44 changes: 44 additions & 0 deletions assert/seq_supported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build go1.23 || goexperiment.rangefunc

package assert

import "reflect"

var (
boolType = reflect.TypeOf(true)
)

// seqToSlice checks if x is a sequence, and converts it to a slice of the
// same element type. Otherwise, x is returned as-is.
func seqToSlice(x interface{}) interface{} {
if x == nil {
return nil
}

xv := reflect.ValueOf(x)
xt := xv.Type()
// We're looking for a function with exactly one input parameter and no return values.
if xt.Kind() != reflect.Func || xt.NumIn() != 1 || xt.NumOut() != 0 {
return x
}

// The input parameter should be of type func(T) bool
paramType := xt.In(0)
if paramType.Kind() != reflect.Func || paramType.NumIn() != 1 || paramType.NumOut() != 1 || paramType.Out(0) != boolType {
return x
}

elemType := paramType.In(0)
resultType := reflect.SliceOf(elemType)
result := reflect.MakeSlice(resultType, 0, 0)

yieldFunc := reflect.MakeFunc(paramType, func(args []reflect.Value) []reflect.Value {
result = reflect.Append(result, args[0])
return []reflect.Value{reflect.ValueOf(true)}
})

// Call the function with the yield function as the argument
xv.Call([]reflect.Value{yieldFunc})

return result.Interface()
}
9 changes: 9 additions & 0 deletions assert/seq_unsupported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build !go1.23 && !goexperiment.rangefunc

package assert

// seqToSlice would convert a sequence of elements to a slice of the respective type.
// However, since sequences are not supported given the build tags, it just returns x as-is.
func seqToSlice(x interface{}) interface{} {
return x
}
Loading

0 comments on commit 7f96d99

Please sign in to comment.