diff --git a/velox/expression/ConstantExpr.cpp b/velox/expression/ConstantExpr.cpp index 28c43048be338..381d93741e45e 100644 --- a/velox/expression/ConstantExpr.cpp +++ b/velox/expression/ConstantExpr.cpp @@ -162,6 +162,7 @@ void appendSqlLiteral( case TypeKind::TINYINT: case TypeKind::SMALLINT: case TypeKind::BIGINT: + case TypeKind::HUGEINT: case TypeKind::TIMESTAMP: case TypeKind::REAL: case TypeKind::DOUBLE: diff --git a/velox/expression/ReverseSignatureBinder.cpp b/velox/expression/ReverseSignatureBinder.cpp index b21715fef36d3..afd4bd311edc7 100644 --- a/velox/expression/ReverseSignatureBinder.cpp +++ b/velox/expression/ReverseSignatureBinder.cpp @@ -36,9 +36,6 @@ bool ReverseSignatureBinder::hasConstrainedIntegerVariable( } bool ReverseSignatureBinder::tryBind() { - if (hasConstrainedIntegerVariable(signature_.returnType())) { - return false; - } return SignatureBinderBase::tryBind(signature_.returnType(), returnType_); } diff --git a/velox/expression/tests/ArgumentTypeFuzzerTest.cpp b/velox/expression/tests/ArgumentTypeFuzzerTest.cpp index e139ae0c4f7a1..bbaabac61cec1 100644 --- a/velox/expression/tests/ArgumentTypeFuzzerTest.cpp +++ b/velox/expression/tests/ArgumentTypeFuzzerTest.cpp @@ -59,6 +59,42 @@ class ArgumentTypeFuzzerTest : public testing::Test { } } + void testFuzzingDecimalSuccess( + const std::shared_ptr& signature, + int32_t expectedArguments, + std::optional outputKind = std::nullopt) { + std::mt19937 seed{0}; + ArgumentTypeFuzzer fuzzer{*signature, seed}; + ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs)); + + auto& argumentTypes = fuzzer.argumentTypes(); + ASSERT_LE(argumentTypes.size(), expectedArguments); + + auto& argumentSignatures = signature->argumentTypes(); + int i; + for (i = 0; i < expectedArguments; ++i) { + ASSERT_TRUE(argumentTypes[i]->isDecimal()) + << "at " << i + << ": Expected decimal. Got: " << argumentTypes[i]->toString(); + } + + if (i < argumentTypes.size()) { + ASSERT_TRUE(signature->variableArity()); + ASSERT_LE( + argumentTypes.size() - argumentSignatures.size(), kMaxVariadicArgs); + for (int j = i; j < argumentTypes.size(); ++j) { + ASSERT_TRUE(argumentTypes[j]->equivalent(*argumentTypes[i - 1])); + } + } + + const auto outputType = fuzzer.fuzzReturnType(); + if (outputKind.has_value()) { + ASSERT_TRUE(outputType->kind() == outputKind); + } else { + ASSERT_TRUE(outputType->isDecimal()); + } + } + void testFuzzingFailure( const std::shared_ptr& signature, const TypePtr& returnType) { @@ -222,9 +258,36 @@ TEST_F(ArgumentTypeFuzzerTest, any) { ASSERT_TRUE(argumentTypes[0] != nullptr); } -TEST_F(ArgumentTypeFuzzerTest, unsupported) { - // Constraints on the return type is not supported. - auto signature = +TEST_F(ArgumentTypeFuzzerTest, decimal) { + auto signature = exec::FunctionSignatureBuilder() + .integerVariable("a_scale") + .integerVariable("a_precision") + .returnType("boolean") + .argumentType("decimal(a_precision, a_scale)") + .argumentType("decimal(a_precision, a_scale)") + .argumentType("decimal(a_precision, a_scale)") + .build(); + + testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN); + + signature = + exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", + "min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)") + .integerVariable("r_scale", "max(a_scale, b_scale)") + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build(); + + testFuzzingDecimalSuccess(signature, 2); + + signature = exec::FunctionSignatureBuilder() .integerVariable("a_scale") .integerVariable("b_scale") @@ -239,7 +302,31 @@ TEST_F(ArgumentTypeFuzzerTest, unsupported) { .argumentType("decimal(b_precision, b_scale)") .build(); - testFuzzingFailure(signature, DECIMAL(13, 6)); + testFuzzingDecimalSuccess(signature, 2, TypeKind::ROW); + + signature = exec::FunctionSignatureBuilder() + .integerVariable("i1") + .integerVariable("i2") + .integerVariable("i5") + .integerVariable("i6") + .integerVariable( + "i3", "min(38, max(i1 - i5, i2 - i6) + max(i5, i6) + 1)") + .integerVariable("i7", "max(i5, i6)") + .returnType("decimal(i3,i7)") + .argumentType("decimal(i1,i5)") + .argumentType("decimal(i2,i6)") + .build(); + testFuzzingDecimalSuccess(signature, 2); + + signature = exec::FunctionSignatureBuilder() + .integerVariable("i1") + .integerVariable("i5") + .returnType("boolean") + .argumentType("decimal(i1,i5)") + .argumentType("decimal(i1,i5)") + .argumentType("decimal(i1,i5)") + .build(); + testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN); } TEST_F(ArgumentTypeFuzzerTest, lambda) { diff --git a/velox/expression/tests/ExpressionFuzzer.cpp b/velox/expression/tests/ExpressionFuzzer.cpp index ec99da84e3a31..4ff197d120bde 100644 --- a/velox/expression/tests/ExpressionFuzzer.cpp +++ b/velox/expression/tests/ExpressionFuzzer.cpp @@ -445,16 +445,20 @@ bool useTypeName( bool isSupportedSignature( const exec::FunctionSignature& signature, - bool enableComplexType) { - // Not supporting lambda functions, or functions using decimal and - // timestamp with time zone types. + bool enableComplexType, + bool enableDecimalType) { + // When enableDecimalType is disabled, not supporting decimal functions. + const bool useDecimal = + (useTypeName(signature, "long_decimal") || + useTypeName(signature, "short_decimal") || + useTypeName(signature, "decimal")); + // Not supporting lambda functions, or functions using timestamp with time + // zone types. return !( useTypeName(signature, "opaque") || - useTypeName(signature, "long_decimal") || - useTypeName(signature, "short_decimal") || - useTypeName(signature, "decimal") || useTypeName(signature, "timestamp with time zone") || useTypeName(signature, "interval day to second") || + (!enableDecimalType && useDecimal) || (enableComplexType && useTypeName(signature, "unknown"))); } @@ -563,10 +567,14 @@ ExpressionFuzzer::ExpressionFuzzer( for (const auto& signature : function.second) { ++totalFunctionSignatures; - if (!isSupportedSignature(*signature, options_.enableComplexTypes)) { + if (!isSupportedSignature( + *signature, + options_.enableComplexTypes, + options_.enableDecimalType)) { continue; } - if (!(signature->variables().empty() || options_.enableComplexTypes)) { + if (!(signature->variables().empty() || options_.enableComplexTypes || + options_.enableDecimalType)) { LOG(WARNING) << "Skipping unsupported signature: " << function.first << signature->toString(); continue; @@ -690,8 +698,10 @@ ExpressionFuzzer::ExpressionFuzzer( for (const auto& it : signatureTemplates_) { auto& returnType = it.signature->returnType().baseName(); - auto* returnTypeKey = &returnType; - if (it.typeVariables.find(returnType) != it.typeVariables.end()) { + std::string typeName = returnType; + folly::toLowerAscii(typeName); + const auto* returnTypeKey = &typeName; + if (it.typeVariables.find(typeName) != it.typeVariables.end()) { // Return type is a template variable. returnTypeKey = &kTypeParameterName; } @@ -710,6 +720,13 @@ ExpressionFuzzer::ExpressionFuzzer( // Register function override (for cases where we want to restrict the types // or parameters we pass to functions). registerFuncOverride(&ExpressionFuzzer::generateSwitchArgs, "switch"); + registerFuncOverride( + &ExpressionFuzzer::generateExtremeFunctionArgs, "greatest"); + registerFuncOverride(&ExpressionFuzzer::generateExtremeFunctionArgs, "least"); + registerFuncOverride( + &ExpressionFuzzer::generateMakeTimestampArgs, "make_timestamp"); + registerFuncOverride( + &ExpressionFuzzer::generateUnscaledValueArgs, "unscaled_value"); } void ExpressionFuzzer::getTicketsForFunctions() { @@ -762,9 +779,11 @@ int ExpressionFuzzer::getTickets(const std::string& funcName) { void ExpressionFuzzer::addToTypeToExpressionListByTicketTimes( const std::string& type, const std::string& funcName) { + std::string typeName = type; + folly::toLowerAscii(typeName); int tickets = getTickets(funcName); for (int i = 0; i < tickets; i++) { - typeToExpressionList_[type].push_back(funcName); + typeToExpressionList_[typeName].push_back(funcName); } } @@ -950,6 +969,78 @@ std::vector ExpressionFuzzer::generateEmptyApproxSetArgs( return {generateArgConstant(input.args[0])}; } +std::vector ExpressionFuzzer::generateExtremeFunctionArgs( + const CallableSignature& input) { + const auto argTypes = input.args; + VELOX_CHECK_GE( + argTypes.size(), + 1, + "Only one input is expected from the template signature."); + if (!argTypes[0]->isDecimal()) { + return generateArgs(input); + } + + auto numVarArgs = + !input.variableArity ? 0 : rand32(0, options_.maxNumVarArgs); + std::vector inputExpressions; + inputExpressions.reserve(argTypes.size() + numVarArgs); + inputExpressions.emplace_back( + generateArg(argTypes.at(0), input.constantArgs.at(0))); + + // Append varargs to the argument list. + for (int i = 0; i < numVarArgs; i++) { + size_t argClass = rand32(0, 1); + core::TypedExprPtr argExpr; + const auto argType = inputExpressions[0]->type(); + if (argClass == kArgConstant) { + argExpr = generateArgConstant(argType); + } else { + argExpr = generateArgColumn(argType); + } + inputExpressions.emplace_back(argExpr); + } + return inputExpressions; +} + +std::vector ExpressionFuzzer::generateMakeTimestampArgs( + const CallableSignature& input) { + VELOX_CHECK_GE( + input.args.size(), + 6, + "At least six inputs are expected from the template signature."); + bool useTimezone = vectorFuzzer_->coinToss(0.5); + std::vector inputExpressions; + inputExpressions.reserve(6); + for (int index = 0; index < 5; ++index) { + inputExpressions.emplace_back(generateArg(input.args[index])); + } + + // It cannot be ensured that the generated expression follows the required + // return type, so only constant or column can be generated as the decimal + // argument. + size_t argClass = rand32(0, 1); + core::TypedExprPtr argExpr; + if (argClass == kArgConstant) { + argExpr = generateArgConstant(input.args[5]); + } else { + argExpr = generateArgColumn(input.args[5]); + } + inputExpressions.emplace_back(argExpr); + + if (input.args.size() == 7) { + std::vector timezoneSet = { + "Asia/Kolkata", + "America/Los_Angeles", + "Canada/Atlantic", + "+08:00", + "-10:00"}; + size_t zoneIndex = rand32(0, 4); + inputExpressions.emplace_back(std::make_shared( + VARCHAR(), variant(timezoneSet[zoneIndex]))); + } + return inputExpressions; +} + // Specialization for the "regexp_replace" function: second and third // (optional) parameters always need to be constant. std::vector ExpressionFuzzer::generateRegexpReplaceArgs( @@ -984,6 +1075,28 @@ std::vector ExpressionFuzzer::generateSwitchArgs( return inputExpressions; } +std::vector ExpressionFuzzer::generateUnscaledValueArgs( + const CallableSignature& input) { + VELOX_CHECK_EQ( + input.args.size(), + 1, + "Only one input is expected from the template signature."); + + // It cannot be ensured that the generated expression follows the required + // return type, so only constant or column can be generated as the decimal + // argument. + std::vector inputExpressions; + size_t argClass = rand32(0, 1); + core::TypedExprPtr argExpr; + if (argClass == kArgConstant) { + argExpr = generateArgConstant(input.args[0]); + } else { + argExpr = generateArgColumn(input.args[0]); + } + inputExpressions.emplace_back(argExpr); + return inputExpressions; +} + ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpressions( const RowTypePtr& outType) { state.reset(); @@ -1060,7 +1173,8 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression( } else { expression = generateExpressionFromConcreteSignatures( returnType, chosenFunctionName); - if (!expression && options_.enableComplexTypes) { + if (!expression && + (options_.enableComplexTypes || options_.enableDecimalType)) { expression = generateExpressionFromSignatureTemplate( returnType, chosenFunctionName); } @@ -1084,13 +1198,92 @@ std::vector ExpressionFuzzer::getArgsForCallable( return funcIt->second(callable); } +TypePtr ExpressionFuzzer::getConstrainedOutputType( + const std::vector& args, + const exec::FunctionSignature* signature) { + if (signature == nullptr) { + return nullptr; + } + // When function is unnested, the types of args are decided by fuzzer argument + // types. For nested function, they are decided by the return types of + // children functions. To handle the constraints between input types and + // output types of a decimal function, extract the input precisions and scales + // from decimal arguments, and bind them to integer variables. + + // Checks if any variable is integer constrained, and get the decimal name + // style. + bool integerConstrained = false; + char decimalNameStyle = 0; + for (const auto& [variableName, variableInfo] : signature->variables()) { + const auto constraint = variableInfo.constraint(); + if (variableInfo.isIntegerParameter()) { + if (variableName == "precision" || variableName == "scale") { + decimalNameStyle = 'v'; + } else if ( + variableName.find("precision") != std::string::npos || + variableName.find("scale") != std::string::npos) { + decimalNameStyle = 'p'; + } else if (variableName.find("i") != std::string::npos) { + decimalNameStyle = 'i'; + } + integerConstrained = true; + break; + } + } + + std::unordered_map decimalVariablesBindings; + column_index_t decimalColIndex = 1; + for (column_index_t i = 0; i < args.size(); ++i) { + const auto argType = args[i]->type(); + if (argType->isDecimal()) { + const auto [p, s] = getDecimalPrecisionScale(*argType); + switch (decimalNameStyle) { + case 'v': { + decimalVariablesBindings["precision"] = p; + decimalVariablesBindings["scale"] = s; + break; + } + case 'p': { + const auto column = std::string(1, 'a' + i); + decimalVariablesBindings[column + "_precision"] = p; + decimalVariablesBindings[column + "_scale"] = s; + break; + } + case 'i': { + decimalVariablesBindings["i" + std::to_string(decimalColIndex)] = p; + decimalVariablesBindings + ["i" + std::to_string(decimalColIndex + kIntegerPairSize)] = s; + decimalColIndex++; + break; + } + default: + VELOX_UNSUPPORTED("Unsupported decimal name style."); + } + } + } + + if (integerConstrained && decimalVariablesBindings.size() > 0 && signature) { + // Compute a correct output type according to constraints with the integer + // variables bindings. + ArgumentTypeFuzzer fuzzer{*signature, rng_, decimalVariablesBindings}; + return fuzzer.fuzzReturnType(); + } + return nullptr; +} + core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable( const CallableSignature& callable, - const TypePtr& type) { + const TypePtr& type, + const exec::FunctionSignature* signature) { auto args = getArgsForCallable(callable); - // Generate a CallTypedExpr with type because callable.returnType may not have - // the required field names. - return std::make_shared(type, args, callable.name); + + // If a constrained output type is generated, use it to avoid breaking the + // constraints between input types and output types. Otherwise, generate a + // CallTypedExpr with type because callable.returnType may not have the + // required field names. + const auto constrainedType = getConstrainedOutputType(args, signature); + return std::make_shared( + constrainedType ? constrainedType : type, args, callable.name); } const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature( @@ -1172,7 +1365,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures( } markSelected(chosen->name); - return getCallExprFromCallable(*chosen, returnType); + return getCallExprFromCallable(*chosen, returnType, nullptr); } const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate( @@ -1277,7 +1470,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate( .constantArgs = constantArguments}; markSelected(chosen->name); - return getCallExprFromCallable(callable, returnType); + return getCallExprFromCallable(callable, returnType, chosen->signature); } core::TypedExprPtr ExpressionFuzzer::generateCastExpression( diff --git a/velox/expression/tests/ExpressionFuzzer.h b/velox/expression/tests/ExpressionFuzzer.h index da6e50a28adcb..508999633d972 100644 --- a/velox/expression/tests/ExpressionFuzzer.h +++ b/velox/expression/tests/ExpressionFuzzer.h @@ -48,6 +48,10 @@ class ExpressionFuzzer { // types. bool enableComplexTypes = false; + // Enable testing of function signatures with decimal argument or return + // types. + bool enableDecimalType = false; + // Enable generation of expressions where one input column can be used by // multiple subexpressions. bool enableColumnReuse = false; @@ -249,15 +253,17 @@ class ExpressionFuzzer { std::vector generateEmptyApproxSetArgs( const CallableSignature& input); + std::vector generateExtremeFunctionArgs( + const CallableSignature& input); + + std::vector generateMakeTimestampArgs( + const CallableSignature& input); + /// Specialization for the "regexp_replace" function: second and third /// (optional) parameters always need to be constant. std::vector generateRegexpReplaceArgs( const CallableSignature& input); - // Return a vector of expressions for each argument of callable in order. - std::vector getArgsForCallable( - const CallableSignature& callable); - /// Specialization for the "switch" function. Takes in a signature that is /// of the form Switch (condition, then): boolean, T -> T where the type /// variable is bounded to a randomly selected type. It randomly decides the @@ -267,9 +273,21 @@ class ExpressionFuzzer { std::vector generateSwitchArgs( const CallableSignature& input); + std::vector generateUnscaledValueArgs( + const CallableSignature& input); + + // Return a vector of expressions for each argument of callable in order. + std::vector getArgsForCallable( + const CallableSignature& callable); + + TypePtr getConstrainedOutputType( + const std::vector& args, + const exec::FunctionSignature* signature); + core::TypedExprPtr getCallExprFromCallable( const CallableSignature& callable, - const TypePtr& type); + const TypePtr& type, + const exec::FunctionSignature* signature = nullptr); /// Return a random signature mapped to functionName in /// expressionToSignature_ whose return type can match returnType. Return diff --git a/velox/expression/tests/FuzzerRunner.cpp b/velox/expression/tests/FuzzerRunner.cpp index ac741944dfa61..e947f147c6853 100644 --- a/velox/expression/tests/FuzzerRunner.cpp +++ b/velox/expression/tests/FuzzerRunner.cpp @@ -121,6 +121,11 @@ DEFINE_bool( false, "Enable testing of function signatures with complex argument or return types."); +DEFINE_bool( + velox_fuzzer_enable_decimal_type, + false, + "Enable testing of function signatures with decimal argument or return types."); + DEFINE_bool( velox_fuzzer_enable_column_reuse, false, @@ -168,6 +173,7 @@ ExpressionFuzzer::Options getExpressionFuzzerOptions( opts.enableVariadicSignatures = FLAGS_enable_variadic_signatures; opts.enableDereference = FLAGS_enable_dereference; opts.enableComplexTypes = FLAGS_velox_fuzzer_enable_complex_types; + opts.enableDecimalType = FLAGS_velox_fuzzer_enable_decimal_type; opts.enableColumnReuse = FLAGS_velox_fuzzer_enable_column_reuse; opts.enableExpressionReuse = FLAGS_velox_fuzzer_enable_expression_reuse; opts.functionTickets = FLAGS_assign_function_tickets; diff --git a/velox/expression/tests/SparkExpressionFuzzerTest.cpp b/velox/expression/tests/SparkExpressionFuzzerTest.cpp index c9531632f4137..1232c49ed12ab 100644 --- a/velox/expression/tests/SparkExpressionFuzzerTest.cpp +++ b/velox/expression/tests/SparkExpressionFuzzerTest.cpp @@ -58,7 +58,9 @@ int main(int argc, char** argv) { // Required by spark_partition_id function. std::unordered_map queryConfigs = { - {facebook::velox::core::QueryConfig::kSparkPartitionId, "123"}}; + {facebook::velox::core::QueryConfig::kSparkPartitionId, "123"}, + {facebook::velox::core::QueryConfig::kSessionTimezone, + "America/Los_Angeles"}}; return FuzzerRunner::run(FLAGS_seed, skipFunctions, queryConfigs); } diff --git a/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp b/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp index 44c8a0a69bf3d..93ed98cac6666 100644 --- a/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp +++ b/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp @@ -21,13 +21,17 @@ #include "velox/expression/ReverseSignatureBinder.h" #include "velox/expression/SignatureBinder.h" +#include "velox/expression/type_calculation/TypeCalculation.h" +#include "velox/type/SimpleFunctionApi.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" namespace facebook::velox::test { std::string typeToBaseName(const TypePtr& type) { - return boost::algorithm::to_lower_copy(std::string{type->kindName()}); + return type->isDecimal() + ? "decimal" + : boost::algorithm::to_lower_copy(std::string{type->kindName()}); } std::optional baseNameToTypeKind(const std::string& typeName) { @@ -35,6 +39,60 @@ std::optional baseNameToTypeKind(const std::string& typeName) { return tryMapNameToTypeKind(kindName); } +void ArgumentTypeFuzzer::determineUnboundedIntegerVariables() { + // Assign a random value for all integer values. + for (const auto& [variableName, variableInfo] : variables()) { + if (!variableInfo.isIntegerParameter() || + integerVariablesBindings_.count(variableName)) { + continue; + } + + if (variableName == "precision" && returnType_ && + returnType_->isDecimal()) { + const auto [precision, scale] = getDecimalPrecisionScale(*returnType_); + integerVariablesBindings_["precision"] = precision; + integerVariablesBindings_["scale"] = scale; + } else if (auto pos = variableName.find("precision"); + pos != std::string::npos) { + // Handle decimal precisions and scales. + const auto precision = + boost::random::uniform_int_distribution(1, 38)(rng_); + integerVariablesBindings_[variableName] = precision; + const auto colName = variableName.substr(0, pos); + // Corresponding scale should not exceed the generated precision. + integerVariablesBindings_[colName + "scale"] = + boost::random::uniform_int_distribution(0, precision)(rng_); + } else if (auto pos = variableName.find("i"); pos != std::string::npos) { + VELOX_USER_CHECK_GE(variableName.size(), 2); + const auto index = + std::stoi(variableName.substr(pos + 1, variableName.size())); + if (index <= kIntegerPairSize) { + const auto precision = + boost::random::uniform_int_distribution(1, 38)(rng_); + integerVariablesBindings_[variableName] = precision; + const auto scaleIndex = index + kIntegerPairSize; + const auto scaleName = "i" + std::to_string(scaleIndex); + integerVariablesBindings_[scaleName] = + boost::random::uniform_int_distribution( + 0, precision)(rng_); + } + } else { + integerVariablesBindings_[variableName] = + boost::random::uniform_int_distribution()(rng_); + } + } + + // Handle constraints. + for (const auto& [variableName, variableInfo] : variables()) { + const auto constraint = variableInfo.constraint(); + if (constraint == "") { + continue; + } + auto calculation = fmt::format("{}={}", variableName, constraint); + expression::calculation::evaluate(calculation, integerVariablesBindings_); + } +} + void ArgumentTypeFuzzer::determineUnboundedTypeVariables() { for (auto& [variableName, variableInfo] : variables()) { if (!variableInfo.isTypeParameter()) { @@ -80,6 +138,7 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) { } } + determineUnboundedIntegerVariables(); determineUnboundedTypeVariables(); for (auto i = 0; i < formalArgsCnt; i++) { TypePtr actualArg; @@ -87,7 +146,7 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) { actualArg = randType(); } else { actualArg = exec::SignatureBinder::tryResolveType( - formalArgs[i], variables(), bindings_); + formalArgs[i], variables(), bindings_, integerVariablesBindings_); VELOX_CHECK(actualArg != nullptr); } argumentTypes_.push_back(actualArg); @@ -113,13 +172,17 @@ TypePtr ArgumentTypeFuzzer::fuzzReturnType() { nullptr, "Only fuzzing uninitialized return type is allowed."); + determineUnboundedIntegerVariables(); determineUnboundedTypeVariables(); if (signature_.returnType().baseName() == "any") { returnType_ = randType(); return returnType_; } else { returnType_ = exec::SignatureBinder::tryResolveType( - signature_.returnType(), variables(), bindings_); + signature_.returnType(), + variables(), + bindings_, + integerVariablesBindings_); VELOX_CHECK_NE(returnType_, nullptr); return returnType_; } diff --git a/velox/expression/tests/utils/ArgumentTypeFuzzer.h b/velox/expression/tests/utils/ArgumentTypeFuzzer.h index 540a353221bfd..d6427c4b1500d 100644 --- a/velox/expression/tests/utils/ArgumentTypeFuzzer.h +++ b/velox/expression/tests/utils/ArgumentTypeFuzzer.h @@ -36,6 +36,14 @@ class ArgumentTypeFuzzer { std::mt19937& rng) : ArgumentTypeFuzzer(signature, nullptr, rng) {} + ArgumentTypeFuzzer( + const exec::FunctionSignature& signature, + std::mt19937& rng, + const std::unordered_map& integerVariablesBindings) + : ArgumentTypeFuzzer(signature, nullptr, rng) { + integerVariablesBindings_ = integerVariablesBindings; + } + ArgumentTypeFuzzer( const exec::FunctionSignature& signature, const TypePtr& returnType, @@ -65,6 +73,10 @@ class ArgumentTypeFuzzer { return signature_.variables(); } + /// Bind each integer variable that is not determined to a randomly generated + /// value. + void determineUnboundedIntegerVariables(); + /// Bind each type variable that is not determined by the return type to a /// randomly generated type. void determineUnboundedTypeVariables(); @@ -83,6 +95,9 @@ class ArgumentTypeFuzzer { /// Bindings between type variables and their actual types. std::unordered_map bindings_; + /// Bindings between integer variables and their values. + std::unordered_map integerVariablesBindings_; + /// RNG to generate random types for unbounded type variables when necessary. std::mt19937& rng_; }; diff --git a/velox/functions/prestosql/DecimalFunctions.cpp b/velox/functions/prestosql/DecimalFunctions.cpp index 747163a9fc36b..25d371a9db764 100644 --- a/velox/functions/prestosql/DecimalFunctions.cpp +++ b/velox/functions/prestosql/DecimalFunctions.cpp @@ -364,7 +364,9 @@ void registerDecimalMultiply(const std::string& prefix) { exec::SignatureVariable( S3::name(), fmt::format( - "{a_scale} + {b_scale}", + "min({a_scale} + {b_scale}, min(38, {a_precision} + {b_precision}))", + fmt::arg("a_precision", P1::name()), + fmt::arg("b_precision", P2::name()), fmt::arg("a_scale", S1::name()), fmt::arg("b_scale", S2::name())), exec::ParameterType::kIntegerParameter), diff --git a/velox/functions/sparksql/MakeTimestamp.cpp b/velox/functions/sparksql/MakeTimestamp.cpp index 4466482195a1b..00e7c3195c053 100644 --- a/velox/functions/sparksql/MakeTimestamp.cpp +++ b/velox/functions/sparksql/MakeTimestamp.cpp @@ -149,7 +149,7 @@ class MakeTimestampFunction : public exec::VectorFunction { static std::vector> signatures() { return { exec::FunctionSignatureBuilder() - .integerVariable("precision") + .integerVariable("precision", "min(max(6, precision), 18)") .returnType("timestamp") .argumentType("integer") .argumentType("integer") @@ -159,7 +159,7 @@ class MakeTimestampFunction : public exec::VectorFunction { .argumentType("decimal(precision, 6)") .build(), exec::FunctionSignatureBuilder() - .integerVariable("precision") + .integerVariable("precision", "min(max(6, precision), 18)") .returnType("timestamp") .argumentType("integer") .argumentType("integer") diff --git a/velox/functions/sparksql/UnscaledValueFunction.cpp b/velox/functions/sparksql/UnscaledValueFunction.cpp index 7833db779d359..bbf2c82e5058b 100644 --- a/velox/functions/sparksql/UnscaledValueFunction.cpp +++ b/velox/functions/sparksql/UnscaledValueFunction.cpp @@ -49,8 +49,8 @@ class UnscaledValueFunction final : public exec::VectorFunction { std::vector> unscaledValueSignatures() { return {exec::FunctionSignatureBuilder() - .integerVariable("precision") - .integerVariable("scale") + .integerVariable("precision", "min(precision, 18)") + .integerVariable("scale", "min(min(precision, 18), scale)") .returnType("bigint") .argumentType("DECIMAL(precision, scale)") .build()}; diff --git a/velox/type/SimpleFunctionApi.h b/velox/type/SimpleFunctionApi.h index 70848cfc79531..62028532964ee 100644 --- a/velox/type/SimpleFunctionApi.h +++ b/velox/type/SimpleFunctionApi.h @@ -63,6 +63,7 @@ using S1 = IntegerVariable<5>; using S2 = IntegerVariable<6>; using S3 = IntegerVariable<7>; using S4 = IntegerVariable<8>; +const uint8_t kIntegerPairSize = 4; template struct ShortDecimal {