diff --git a/checker/checker.go b/checker/checker.go index b46178d4..e4429e09 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -623,9 +623,9 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) } if isAny(collection) { - return arrayType, info{} + return anyArrayType, info{} } - return arrayType, info{} + return collection, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -643,7 +643,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { closure.NumOut() == 1 && closure.NumIn() == 1 && isAny(closure.In(0)) { - return arrayType, info{} + return anyArrayType, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -1122,9 +1122,9 @@ func (v *checker) ArrayNode(node *ast.ArrayNode) (reflect.Type, info) { prev = curr } if allElementsAreSameType && prev != nil { - return arrayType, info{elem: prev} + return arrayType(prev), info{elem: prev} } - return arrayType, info{} + return anyArrayType, info{} } func (v *checker) MapNode(node *ast.MapNode) (reflect.Type, info) { diff --git a/checker/checker_test.go b/checker/checker_test.go index d6a84abc..07b8d118 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -974,3 +974,21 @@ func TestCheck_builtin_without_call(t *testing.T) { }) } } + +func TestCheck_types(t *testing.T) { + tests := []struct { + input string + nodeType reflect.Type + }{ + {`filter([1,2,3], # > 1)`, reflect.TypeOf([]int{})}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + tree, err := parser.Parse(test.input) + require.NoError(t, err) + + nodeType, err := checker.Check(tree, conf.New(nil)) + require.Equal(t, test.nodeType.String(), nodeType.String()) + }) + } +} diff --git a/checker/types.go b/checker/types.go index d10736a7..0942cd32 100644 --- a/checker/types.go +++ b/checker/types.go @@ -13,13 +13,17 @@ var ( integerType = reflect.TypeOf(0) floatType = reflect.TypeOf(float64(0)) stringType = reflect.TypeOf("") - arrayType = reflect.TypeOf([]any{}) + anyArrayType = reflect.TypeOf([]any{}) mapType = reflect.TypeOf(map[string]any{}) anyType = reflect.TypeOf(new(any)).Elem() timeType = reflect.TypeOf(time.Time{}) durationType = reflect.TypeOf(time.Duration(0)) ) +func arrayType(t reflect.Type) reflect.Type { + return reflect.SliceOf(t) +} + func combined(a, b reflect.Type) reflect.Type { if a.Kind() == b.Kind() { return a