diff --git a/internal/engine/integration/procedure_test.go b/internal/engine/integration/procedure_test.go index c31a85aef..d039b7d07 100644 --- a/internal/engine/integration/procedure_test.go +++ b/internal/engine/integration/procedure_test.go @@ -160,6 +160,86 @@ func Test_Procedures(t *testing.T) { {"satoshi", "hello world"}, }, }, + { + name: "string functions", + procedure: `procedure string_funcs() public view { + $val := 'hello world'; + $val := $val || '!!!'; + $val := upper($val); + if $val != 'HELLO WORLD!!!' { + error('upper failed'); + } + $val := lower($val); + if $val != 'hello world!!!' { + error('lower failed'); + } + + if bit_length($val) != 112 { + error('bit_length failed'); + } + if char_length($val) != 14 or character_length($val) != 14 or length($val) != 14 { + error('length failed'); + } + if octet_length($val) != 14 { + error('octet_length failed'); + } + $val := rtrim($val, '!'); + if $val != 'hello world' { + error('rtrim failed'); + } + if rtrim($val||' ') != 'hello world' { + error('rtrim 2 failed'); + } + + $val := ltrim($val, 'h'); + if $val != 'ello world' { + error('ltrim failed'); + } + if ltrim(' '||$val) != 'ello world' { // add a space and trim it off + error('ltrim 2 failed'); + } + + $val := lpad($val, 11, 'h'); + if $val != 'hello world' { + error('lpad failed'); + } + if lpad($val, 12) != ' hello world' { + error('lpad 2 failed'); + } + + $val := rpad($val, 12, '!'); + if $val != 'hello world!' { + error('rpad failed'); + } + if rpad($val, 13) != 'hello world! ' { + error('rpad 2 failed'); + } + + if overlay($val, 'xx', 2, 5) != 'hxxworld!' { + error('overlay failed'); + } + if overlay($val, 'xx', 2) != 'hxxlo world!' { + error('overlay 2 failed'); + } + + if position('world', $val) != 7 { + error('position failed'); + } + if substring($val, 7, 5) != 'world' { + error('substring failed'); + } + if substring($val, 7) != 'world!' { + error('substring 2 failed'); + } + + if trim(' ' || $val || ' ') != 'hello world!' { + error('trim failed'); + } + if trim('a'||$val||'a', 'a') != 'hello world!' { + error('trim 2 failed'); + } + }`, + }, } for _, test := range tests { diff --git a/parse/analyze.go b/parse/analyze.go index 796747650..072f90241 100644 --- a/parse/analyze.go +++ b/parse/analyze.go @@ -848,17 +848,25 @@ func (s *sqlAnalyzer) VisitExpressionArithmetic(p0 *ExpressionArithmetic) any { return s.expressionTypeErr(p0.Right) } - if !left.Equals(right) { - return s.typeErr(p0.Right, right, left) - } - // both must be numeric UNLESS it is a concat if p0.Operator == ArithmeticOperatorConcat { - s.expect(p0.Left, left, types.TextType) + if !left.Equals(types.TextType) || !right.Equals(types.TextType) { + // Postgres supports concatenation on non-text types, but we do not, + // so we give a more descriptive error here. + // see the note at the top of: https://www.postgresql.org/docs/16.1/functions-string.html + s.errs.AddErr(p0.Left, ErrType, "concatenation only allowed on text types. received %s and %s", left.String(), right.String()) + return cast(p0, types.UnknownType) + } } else { s.expectedNumeric(p0.Left, left) } + // we check this after to return a more helpful error message if + // the user is not concatenating strings. + if !left.Equals(right) { + return s.typeErr(p0.Right, right, left) + } + return cast(p0, left) } diff --git a/parse/functions.go b/parse/functions.go index 598fb3e63..50e071eba 100644 --- a/parse/functions.go +++ b/parse/functions.go @@ -15,22 +15,13 @@ var ( return nil, wrapErrArgumentNumber(1, len(args)) } - if !args[0].EqualsStrict(types.IntType) { - return nil, wrapErrArgumentType(types.IntType, args[0]) + if !args[0].EqualsStrict(types.IntType) && args[0].Name != types.DecimalStr { + return nil, fmt.Errorf("expected argument to be int or decimal, got %s", args[0].String()) } - return types.IntType, nil - }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("abs") - } - if distinct { - return "", errDistinct("abs") - } - - return fmt.Sprintf("abs(%s)", inputs[0]), nil + return args[0], nil }, + PGFormat: defaultFormat("abs"), }, "error": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { @@ -44,64 +35,177 @@ var ( return nil, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("error") + PGFormat: defaultFormat("error"), + }, + "uuid_generate_v5": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + // first argument must be a uuid, second argument must be text + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) } - if distinct { - return "", errDistinct("error") + + if !args[0].EqualsStrict(types.UUIDType) { + return nil, wrapErrArgumentType(types.UUIDType, args[0]) + } + + if !args[1].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[1]) } - return fmt.Sprintf("error(%s)", inputs[0]), nil + return types.UUIDType, nil }, + PGFormat: defaultFormat("uuid_generate_v5"), }, - "length": { + "encode": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - if len(args) != 1 { - return nil, wrapErrArgumentNumber(1, len(args)) + // first must be blob, second must be text + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) + } + + if !args[0].EqualsStrict(types.BlobType) { + return nil, wrapErrArgumentType(types.BlobType, args[0]) + } + + if !args[1].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[1]) + } + + return types.TextType, nil + }, + PGFormat: defaultFormat("encode"), + }, + "decode": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + // first must be text, second must be text + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) } if !args[0].EqualsStrict(types.TextType) { return nil, wrapErrArgumentType(types.TextType, args[0]) } - return types.IntType, nil + if !args[1].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[1]) + } + + return types.BlobType, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", fmt.Errorf("cannot use * with length") + PGFormat: defaultFormat("decode"), + }, + "digest": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + // first must be either text or blob, second must be text + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) } - if distinct { - return "", fmt.Errorf("cannot use distinct with length") + + if !args[0].EqualsStrict(types.TextType) && !args[0].EqualsStrict(types.BlobType) { + return nil, fmt.Errorf("expected first argument to be text or blob, got %s", args[0].String()) } - return fmt.Sprintf("length(%s)", inputs[0]), nil + if !args[1].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[1]) + } + + return types.BlobType, nil }, + PGFormat: defaultFormat("digest"), }, - "lower": { + // array functions + "array_append": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) + } + + if !args[0].IsArray { + return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String()) + } + + if args[1].IsArray { + return nil, fmt.Errorf("%w: expected second argument to be a scalar, got %s", ErrType, args[1].String()) + } + + if !strings.EqualFold(args[0].Name, args[1].Name) { + return nil, fmt.Errorf("%w: append type must be equal to scalar array type: array type: %s append type: %s", ErrType, args[0].Name, args[1].Name) + } + + return args[0], nil + }, + PGFormat: defaultFormat("array_append"), + }, + "array_prepend": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) + } + + if args[0].IsArray { + return nil, fmt.Errorf("%w: expected first argument to be a scalar, got %s", ErrType, args[0].String()) + } + + if !args[1].IsArray { + return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, args[1].String()) + } + + if !strings.EqualFold(args[0].Name, args[1].Name) { + return nil, fmt.Errorf("%w: prepend type must be equal to scalar array type: array type: %s prepend type: %s", ErrType, args[1].Name, args[0].Name) + } + + return args[1], nil + }, + PGFormat: defaultFormat("array_prepend"), + }, + "array_cat": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) != 2 { + return nil, wrapErrArgumentNumber(2, len(args)) + } + + if !args[0].IsArray { + return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String()) + } + + if !args[1].IsArray { + return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, args[1].String()) + } + + if !strings.EqualFold(args[0].Name, args[1].Name) { + return nil, fmt.Errorf("%w: expected both arrays to be of the same scalar type, got %s and %s", ErrType, args[0].Name, args[1].Name) + } + + return args[0], nil + }, + PGFormat: defaultFormat("array_cat"), + }, + "array_length": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { if len(args) != 1 { return nil, wrapErrArgumentNumber(1, len(args)) } - if !args[0].EqualsStrict(types.TextType) { - return nil, wrapErrArgumentType(types.TextType, args[0]) + if !args[0].IsArray { + return nil, fmt.Errorf("expected argument to be an array, got %s", args[0].String()) } - return types.TextType, nil + return types.IntType, nil }, PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { if star { - return "", errStar("lower") + return "", errStar("array_length") } if distinct { - return "", errDistinct("lower") + return "", errDistinct("array_length") } - return fmt.Sprintf("lower(%s)", inputs[0]), nil + return fmt.Sprintf("array_length(%s, 1)", inputs[0]), nil }, }, - "upper": { + // string functions + // the main SQL string functions defined here: https://www.postgresql.org/docs/16.1/functions-string.html + "bit_length": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { if len(args) != 1 { return nil, wrapErrArgumentNumber(1, len(args)) @@ -111,103 +215,125 @@ var ( return nil, wrapErrArgumentType(types.TextType, args[0]) } - return types.TextType, nil + return types.IntType, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("upper") + PGFormat: defaultFormat("bit_length"), + }, + "char_length": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) != 1 { + return nil, wrapErrArgumentNumber(1, len(args)) } - if distinct { - return "", errDistinct("upper") + + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - return fmt.Sprintf("upper(%s)", inputs[0]), nil + return types.IntType, nil }, + PGFormat: defaultFormat("char_length"), }, - "format": { + "character_length": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - if len(args) < 1 { - return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(args)) + if len(args) != 1 { + return nil, wrapErrArgumentNumber(1, len(args)) } if !args[0].EqualsStrict(types.TextType) { return nil, wrapErrArgumentType(types.TextType, args[0]) } - return types.TextType, nil + return types.IntType, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("format") + PGFormat: defaultFormat("character_length"), + }, + "length": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) != 1 { + return nil, wrapErrArgumentNumber(1, len(args)) } - if distinct { - return "", errDistinct("format") + + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - return fmt.Sprintf("format(%s)", strings.Join(inputs, ", ")), nil + return types.IntType, nil }, + PGFormat: defaultFormat("length"), }, - "uuid_generate_v5": { + "lower": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - // first argument must be a uuid, second argument must be text - if len(args) != 2 { - return nil, wrapErrArgumentNumber(2, len(args)) + if len(args) != 1 { + return nil, wrapErrArgumentNumber(1, len(args)) } - if !args[0].EqualsStrict(types.UUIDType) { - return nil, wrapErrArgumentType(types.UUIDType, args[0]) + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - if !args[1].EqualsStrict(types.TextType) { - return nil, wrapErrArgumentType(types.TextType, args[1]) + return types.TextType, nil + }, + PGFormat: defaultFormat("lower"), + }, + "lpad": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + //can have 2-3 args. 1 and 3 must be text, 2 must be int + if len(args) < 2 || len(args) > 3 { + return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", len(args)) } - return types.UUIDType, nil - }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("uuid_generate_v5") + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - if distinct { - return "", errDistinct("uuid_generate_v5") + + if !args[1].EqualsStrict(types.IntType) { + return nil, wrapErrArgumentType(types.IntType, args[1]) + } + + if len(args) == 3 && !args[2].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[2]) } - return fmt.Sprintf("uuid_generate_v5(%s)", strings.Join(inputs, ", ")), nil + return types.TextType, nil }, + PGFormat: defaultFormat("lpad"), }, - "encode": { + "ltrim": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - // first must be blob, second must be text - if len(args) != 2 { - return nil, wrapErrArgumentNumber(2, len(args)) - } - - if !args[0].EqualsStrict(types.BlobType) { - return nil, wrapErrArgumentType(types.BlobType, args[0]) + //can have 1 or 2 args. both must be text + if len(args) < 1 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args)) } - if !args[1].EqualsStrict(types.TextType) { - return nil, wrapErrArgumentType(types.TextType, args[1]) + for _, arg := range args { + if !arg.EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, arg) + } } return types.TextType, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("encode") + PGFormat: defaultFormat("ltrim"), + }, + "octet_length": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) != 1 { + return nil, wrapErrArgumentNumber(1, len(args)) } - if distinct { - return "", errDistinct("encode") + + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - return fmt.Sprintf("encode(%s)", strings.Join(inputs, ", ")), nil + return types.IntType, nil }, + PGFormat: defaultFormat("octet_length"), }, - "decode": { + "overlay": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - // first must be text, second must be text - if len(args) != 2 { - return nil, wrapErrArgumentNumber(2, len(args)) + // 3-4 arguments. 1 and 2 must be text, 3 must be int, 4 must be int + if len(args) < 3 || len(args) > 4 { + return nil, fmt.Errorf("invalid number of arguments: expected 3 or 4, got %d", len(args)) } if !args[0].EqualsStrict(types.TextType) { @@ -218,163 +344,200 @@ var ( return nil, wrapErrArgumentType(types.TextType, args[1]) } - return types.BlobType, nil + if !args[2].EqualsStrict(types.IntType) { + return nil, wrapErrArgumentType(types.IntType, args[2]) + } + + if len(args) == 4 && !args[3].EqualsStrict(types.IntType) { + return nil, wrapErrArgumentType(types.IntType, args[3]) + } + + return types.TextType, nil }, PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { + if distinct { + return "", errDistinct("overlay") + } + if star { - return "", errStar("decode") + return "", errStar("overlay") } - if distinct { - return "", errDistinct("decode") + + str := strings.Builder{} + str.WriteString("overlay(") + str.WriteString(inputs[0]) + str.WriteString(" placing ") + str.WriteString(inputs[1]) + str.WriteString(" from ") + str.WriteString(inputs[2]) + if len(inputs) == 4 { + str.WriteString(" for ") + str.WriteString(inputs[3]) } + str.WriteString(")") - return fmt.Sprintf("decode(%s)", strings.Join(inputs, ", ")), nil + return str.String(), nil }, }, - "digest": { + "position": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - // first must be either text or blob, second must be text + // 2 arguments. both must be text if len(args) != 2 { return nil, wrapErrArgumentNumber(2, len(args)) } - if !args[0].EqualsStrict(types.TextType) && !args[0].EqualsStrict(types.BlobType) { - return nil, fmt.Errorf("expected first argument to be text or blob, got %s", args[0].String()) - } - - if !args[1].EqualsStrict(types.TextType) { - return nil, wrapErrArgumentType(types.TextType, args[1]) + for _, arg := range args { + if !arg.EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, arg) + } } - return types.BlobType, nil + return types.IntType, nil }, PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("digest") - } if distinct { - return "", errDistinct("digest") + return "", errDistinct("position") + } + + if star { + return "", errStar("position") } - return fmt.Sprintf("digest(%s)", strings.Join(inputs, ", ")), nil + return fmt.Sprintf("position(%s in %s)", inputs[0], inputs[1]), nil }, }, - // array functions - "array_append": { + "rpad": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - if len(args) != 2 { - return nil, wrapErrArgumentNumber(2, len(args)) + // 2-3 args, 1 and 3 must be text, 2 must be int + if len(args) < 2 || len(args) > 3 { + return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", len(args)) } - if !args[0].IsArray { - return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String()) + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - if args[1].IsArray { - return nil, fmt.Errorf("%w: expected second argument to be a scalar, got %s", ErrType, args[1].String()) + if !args[1].EqualsStrict(types.IntType) { + return nil, wrapErrArgumentType(types.IntType, args[1]) } - if !strings.EqualFold(args[0].Name, args[1].Name) { - return nil, fmt.Errorf("%w: append type must be equal to scalar array type: array type: %s append type: %s", ErrType, args[0].Name, args[1].Name) + if len(args) == 3 && !args[2].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[2]) } - return args[0], nil + return types.TextType, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("array_append") + PGFormat: defaultFormat("rpad"), + }, + "rtrim": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + // 1-2 args, both must be text + if len(args) < 1 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args)) } - if distinct { - return "", errDistinct("array_append") + + for _, arg := range args { + if !arg.EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, arg) + } } - return fmt.Sprintf("array_append(%s)", strings.Join(inputs, ", ")), nil + return types.TextType, nil }, + PGFormat: defaultFormat("rtrim"), }, - "array_prepend": { + "substring": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - if len(args) != 2 { - return nil, wrapErrArgumentNumber(2, len(args)) + // 2-3 args, 1 must be text, 2 and 3 must be int + // Postgres supports several different usages of substring, however Kwil only supports 1. + // In Postgres, substring can be used to both impose a string over a range, or to perform + // regex matching. Kwil only supports the former, as regex matching is not supported. + // Therefore, the second and third arguments must be integers. + if len(args) < 2 || len(args) > 3 { + return nil, fmt.Errorf("invalid number of arguments: expected 2 or 3, got %d", len(args)) } - if args[0].IsArray { - return nil, fmt.Errorf("%w: expected first argument to be a scalar, got %s", ErrType, args[0].String()) + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - if !args[1].IsArray { - return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, args[1].String()) + if !args[1].EqualsStrict(types.IntType) { + return nil, wrapErrArgumentType(types.IntType, args[1]) } - if !strings.EqualFold(args[0].Name, args[1].Name) { - return nil, fmt.Errorf("%w: prepend type must be equal to scalar array type: array type: %s prepend type: %s", ErrType, args[1].Name, args[0].Name) + if len(args) == 3 && !args[2].EqualsStrict(types.IntType) { + return nil, wrapErrArgumentType(types.IntType, args[2]) } - return args[1], nil + return types.TextType, nil }, PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("array_prepend") - } if distinct { - return "", errDistinct("array_prepend") - } - - return fmt.Sprintf("array_prepend(%s)", strings.Join(inputs, ", ")), nil - }, - }, - "array_cat": { - ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { - if len(args) != 2 { - return nil, wrapErrArgumentNumber(2, len(args)) + return "", errDistinct("substring") } - if !args[0].IsArray { - return nil, fmt.Errorf("%w: expected first argument to be an array, got %s", ErrType, args[0].String()) - } - - if !args[1].IsArray { - return nil, fmt.Errorf("%w: expected second argument to be an array, got %s", ErrType, args[1].String()) + if star { + return "", errStar("substring") } - if !strings.EqualFold(args[0].Name, args[1].Name) { - return nil, fmt.Errorf("%w: expected both arrays to be of the same scalar type, got %s and %s", ErrType, args[0].Name, args[1].Name) + str := strings.Builder{} + str.WriteString("substring(") + str.WriteString(inputs[0]) + str.WriteString(" from ") + str.WriteString(inputs[1]) + if len(inputs) == 3 { + str.WriteString(" for ") + str.WriteString(inputs[2]) } + str.WriteString(")") - return args[0], nil + return str.String(), nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("array_cat") + }, + "trim": { // kwil only supports trim both + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + // 1-2 args, both must be text + if len(args) < 1 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments: expected 1 or 2, got %d", len(args)) } - if distinct { - return "", errDistinct("array_cat") + + for _, arg := range args { + if !arg.EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, arg) + } } - return fmt.Sprintf("array_cat(%s)", strings.Join(inputs, ", ")), nil + return types.TextType, nil }, + PGFormat: defaultFormat("trim"), }, - "array_length": { + "upper": { ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { if len(args) != 1 { return nil, wrapErrArgumentNumber(1, len(args)) } - if !args[0].IsArray { - return nil, fmt.Errorf("expected argument to be an array, got %s", args[0].String()) + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - return types.IntType, nil + return types.TextType, nil }, - PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { - if star { - return "", errStar("array_length") + PGFormat: defaultFormat("upper"), + }, + "format": { + ValidateArgs: func(args []*types.DataType) (*types.DataType, error) { + if len(args) < 1 { + return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(args)) } - if distinct { - return "", errDistinct("array_length") + + if !args[0].EqualsStrict(types.TextType) { + return nil, wrapErrArgumentType(types.TextType, args[0]) } - return fmt.Sprintf("array_length(%s, 1)", inputs[0]), nil + return types.TextType, nil }, + PGFormat: defaultFormat("format"), }, // Aggregate functions "count": { @@ -426,6 +589,20 @@ var ( } ) +// defaultFormat is the default PGFormat function for functions that do not have a custom one. +func defaultFormat(name string) FormatFunc { + return func(inputs []string, distinct bool, star bool) (string, error) { + if star { + return "", errStar(name) + } + if distinct { + return "", errDistinct(name) + } + + return fmt.Sprintf("%s(%s)", name, strings.Join(inputs, ", ")), nil + } +} + func errDistinct(funcName string) error { return fmt.Errorf(`%w: cannot use DISTINCT with function "%s"`, ErrFunctionSignature, funcName) }