From 6c0594094084036d0baccabf0f7b11a96c75a063 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Mon, 18 Dec 2023 15:47:56 +0800 Subject: [PATCH] support decimal in expression fuzzer test --- velox/expression/ReverseSignatureBinder.cpp | 3 - .../tests/ArgumentTypeFuzzerTest.cpp | 71 ++++++++++++- velox/expression/tests/ExpressionFuzzer.cpp | 99 ++++++++++++++++--- velox/expression/tests/ExpressionFuzzer.h | 11 ++- velox/expression/tests/FuzzerRunner.cpp | 6 ++ .../tests/utils/ArgumentTypeFuzzer.cpp | 66 ++++++++++++- .../tests/utils/ArgumentTypeFuzzer.h | 15 +++ velox/type/Type.h | 1 + 8 files changed, 247 insertions(+), 25 deletions(-) diff --git a/velox/expression/ReverseSignatureBinder.cpp b/velox/expression/ReverseSignatureBinder.cpp index b21715fef36d3..afd4bd311edc7 100644 --- a/velox/expression/ReverseSignatureBinder.cpp +++ b/velox/expression/ReverseSignatureBinder.cpp @@ -36,9 +36,6 @@ bool ReverseSignatureBinder::hasConstrainedIntegerVariable( } bool ReverseSignatureBinder::tryBind() { - if (hasConstrainedIntegerVariable(signature_.returnType())) { - return false; - } return SignatureBinderBase::tryBind(signature_.returnType(), returnType_); } diff --git a/velox/expression/tests/ArgumentTypeFuzzerTest.cpp b/velox/expression/tests/ArgumentTypeFuzzerTest.cpp index e139ae0c4f7a1..73671b92c6b52 100644 --- a/velox/expression/tests/ArgumentTypeFuzzerTest.cpp +++ b/velox/expression/tests/ArgumentTypeFuzzerTest.cpp @@ -59,6 +59,42 @@ class ArgumentTypeFuzzerTest : public testing::Test { } } + void testFuzzingDecimalSuccess( + const std::shared_ptr& signature, + int32_t expectedArguments, + std::optional outputKind = std::nullopt) { + std::mt19937 seed{0}; + ArgumentTypeFuzzer fuzzer{*signature, seed}; + ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs)); + + auto& argumentTypes = fuzzer.argumentTypes(); + ASSERT_LE(argumentTypes.size(), expectedArguments); + + auto& argumentSignatures = signature->argumentTypes(); + int i; + for (i = 0; i < expectedArguments; ++i) { + ASSERT_TRUE(argumentTypes[i]->isDecimal()) + << "at " << i + << ": Expected decimal. Got: " << argumentTypes[i]->toString(); + } + + if (i < argumentTypes.size()) { + ASSERT_TRUE(signature->variableArity()); + ASSERT_LE( + argumentTypes.size() - argumentSignatures.size(), kMaxVariadicArgs); + for (int j = i; j < argumentTypes.size(); ++j) { + ASSERT_TRUE(argumentTypes[j]->equivalent(*argumentTypes[i - 1])); + } + } + + const auto outputType = fuzzer.fuzzReturnType(); + if (outputKind.has_value()) { + ASSERT_TRUE(outputType->kind() == outputKind); + } else { + ASSERT_TRUE(outputType->isDecimal()); + } + } + void testFuzzingFailure( const std::shared_ptr& signature, const TypePtr& returnType) { @@ -222,9 +258,36 @@ TEST_F(ArgumentTypeFuzzerTest, any) { ASSERT_TRUE(argumentTypes[0] != nullptr); } -TEST_F(ArgumentTypeFuzzerTest, unsupported) { - // Constraints on the return type is not supported. - auto signature = +TEST_F(ArgumentTypeFuzzerTest, decimal) { + auto signature = exec::FunctionSignatureBuilder() + .integerVariable("a_scale") + .integerVariable("a_precision") + .returnType("boolean") + .argumentType("decimal(a_precision, a_scale)") + .argumentType("decimal(a_precision, a_scale)") + .argumentType("decimal(a_precision, a_scale)") + .build(); + + testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN); + + signature = + exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", + "min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)") + .integerVariable("r_scale", "max(a_scale, b_scale)") + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build(); + + testFuzzingDecimalSuccess(signature, 2); + + signature = exec::FunctionSignatureBuilder() .integerVariable("a_scale") .integerVariable("b_scale") @@ -239,7 +302,7 @@ TEST_F(ArgumentTypeFuzzerTest, unsupported) { .argumentType("decimal(b_precision, b_scale)") .build(); - testFuzzingFailure(signature, DECIMAL(13, 6)); + testFuzzingDecimalSuccess(signature, 2, TypeKind::ROW); } TEST_F(ArgumentTypeFuzzerTest, lambda) { diff --git a/velox/expression/tests/ExpressionFuzzer.cpp b/velox/expression/tests/ExpressionFuzzer.cpp index 899c0853a24db..f172473c6f82f 100644 --- a/velox/expression/tests/ExpressionFuzzer.cpp +++ b/velox/expression/tests/ExpressionFuzzer.cpp @@ -471,9 +471,6 @@ bool isSupportedSignature( // timestamp with time zone types. return !( useTypeName(signature, "opaque") || - useTypeName(signature, "long_decimal") || - useTypeName(signature, "short_decimal") || - useTypeName(signature, "decimal") || useTypeName(signature, "timestamp with time zone") || useTypeName(signature, "interval day to second") || (enableComplexType && useTypeName(signature, "unknown"))); @@ -587,7 +584,8 @@ ExpressionFuzzer::ExpressionFuzzer( if (!isSupportedSignature(*signature, options_.enableComplexTypes)) { continue; } - if (!(signature->variables().empty() || options_.enableComplexTypes)) { + if (!(signature->variables().empty() || options_.enableComplexTypes || + options_.enableDecimalType)) { LOG(WARNING) << "Skipping unsupported signature: " << function.first << signature->toString(); continue; @@ -711,8 +709,10 @@ ExpressionFuzzer::ExpressionFuzzer( for (const auto& it : signatureTemplates_) { auto& returnType = it.signature->returnType().baseName(); - auto* returnTypeKey = &returnType; - if (it.typeVariables.find(returnType) != it.typeVariables.end()) { + std::string typeName = returnType; + folly::toLowerAscii(typeName); + const auto* returnTypeKey = &typeName; + if (it.typeVariables.find(typeName) != it.typeVariables.end()) { // Return type is a template variable. returnTypeKey = &kTypeParameterName; } @@ -783,9 +783,11 @@ int ExpressionFuzzer::getTickets(const std::string& funcName) { void ExpressionFuzzer::addToTypeToExpressionListByTicketTimes( const std::string& type, const std::string& funcName) { + std::string typeName = type; + folly::toLowerAscii(typeName); int tickets = getTickets(funcName); for (int i = 0; i < tickets; i++) { - typeToExpressionList_[type].push_back(funcName); + typeToExpressionList_[typeName].push_back(funcName); } } @@ -1081,7 +1083,8 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression( } else { expression = generateExpressionFromConcreteSignatures( returnType, chosenFunctionName); - if (!expression && options_.enableComplexTypes) { + if (!expression && + (options_.enableComplexTypes || options_.enableDecimalType)) { expression = generateExpressionFromSignatureTemplate( returnType, chosenFunctionName); } @@ -1105,13 +1108,81 @@ std::vector ExpressionFuzzer::getArgsForCallable( return funcIt->second(callable); } +TypePtr ExpressionFuzzer::getConstrainedOutputType( + const std::vector& args, + const exec::FunctionSignature* signature) { + // When function is unnested, the types of args are decided by fuzzer argument + // types. For nested function, they are decided by the return types of + // children functions. To handle the constraints between input types and + // output types of a decimal function, extract the input precisions and scales + // from decimal arguments, and bind them to integer variables. + + // Checks if any variable is integer constrained, and get the decimal name + // style. + bool integerConstrained = false; + char decimalNameStyle = 0; + for (const auto& [variableName, variableInfo] : signature->variables()) { + const auto constraint = variableInfo.constraint(); + if (variableInfo.isIntegerParameter()) { + if (variableName.find("precision") != std::string::npos || + variableName.find("scale") != std::string::npos) { + decimalNameStyle = 'p'; + } else if (variableName.find("i") != std::string::npos) { + decimalNameStyle = 'i'; + } + integerConstrained = true; + break; + } + } + + std::unordered_map decimalVariablesBindings; + column_index_t decimalColIndex = 1; + for (column_index_t i = 0; i < args.size(); ++i) { + const auto argType = args[i]->type(); + if (argType->isDecimal()) { + const auto [p, s] = getDecimalPrecisionScale(*argType); + switch (decimalNameStyle) { + case 'p': { + const auto column = std::string(1, 'a' + i); + decimalVariablesBindings[column + "_precision"] = p; + decimalVariablesBindings[column + "_scale"] = s; + break; + } + case 'i': { + decimalVariablesBindings["i" + std::to_string(decimalColIndex)] = p; + decimalVariablesBindings + ["i" + std::to_string(decimalColIndex + kIntegerPairSize)] = s; + decimalColIndex++; + break; + } + default: + VELOX_UNSUPPORTED("Unsupported decimal name style."); + } + } + } + + if (integerConstrained && decimalVariablesBindings.size() > 0 && signature) { + // Compute a correct output type according to constraints with the integer + // variables bindings. + ArgumentTypeFuzzer fuzzer{*signature, rng_, decimalVariablesBindings}; + return fuzzer.fuzzReturnType(); + } + return nullptr; +} + core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable( const CallableSignature& callable, - const TypePtr& type) { + const TypePtr& type, + const exec::FunctionSignature* signature) { auto args = getArgsForCallable(callable); - // Generate a CallTypedExpr with type because callable.returnType may not have - // the required field names. - return std::make_shared(type, args, callable.name); + + // If a constrained output type is generated, use it to avoid breaking the + // constraints between input types and output types. Otherwise, generate a + // CallTypedExpr with type because callable.returnType may not have the + // required field names. + const auto constrainedType = getConstrainedOutputType(args, signature); + return std::make_shared( + constrainedType ? constrainedType : type, args, callable.name); } const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature( @@ -1193,7 +1264,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures( } markSelected(chosen->name); - return getCallExprFromCallable(*chosen, returnType); + return getCallExprFromCallable(*chosen, returnType, nullptr); } const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate( @@ -1298,7 +1369,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate( .constantArgs = constantArguments}; markSelected(chosen->name); - return getCallExprFromCallable(callable, returnType); + return getCallExprFromCallable(callable, returnType, chosen->signature); } core::TypedExprPtr ExpressionFuzzer::generateCastExpression( diff --git a/velox/expression/tests/ExpressionFuzzer.h b/velox/expression/tests/ExpressionFuzzer.h index da6e50a28adcb..20283cb80d3fc 100644 --- a/velox/expression/tests/ExpressionFuzzer.h +++ b/velox/expression/tests/ExpressionFuzzer.h @@ -48,6 +48,10 @@ class ExpressionFuzzer { // types. bool enableComplexTypes = false; + // Enable testing of function signatures with decimal argument or return + // types. + bool enableDecimalType = false; + // Enable generation of expressions where one input column can be used by // multiple subexpressions. bool enableColumnReuse = false; @@ -267,9 +271,14 @@ class ExpressionFuzzer { std::vector generateSwitchArgs( const CallableSignature& input); + TypePtr getConstrainedOutputType( + const std::vector& args, + const exec::FunctionSignature* signature); + core::TypedExprPtr getCallExprFromCallable( const CallableSignature& callable, - const TypePtr& type); + const TypePtr& type, + const exec::FunctionSignature* signature = nullptr); /// Return a random signature mapped to functionName in /// expressionToSignature_ whose return type can match returnType. Return diff --git a/velox/expression/tests/FuzzerRunner.cpp b/velox/expression/tests/FuzzerRunner.cpp index ac741944dfa61..e947f147c6853 100644 --- a/velox/expression/tests/FuzzerRunner.cpp +++ b/velox/expression/tests/FuzzerRunner.cpp @@ -121,6 +121,11 @@ DEFINE_bool( false, "Enable testing of function signatures with complex argument or return types."); +DEFINE_bool( + velox_fuzzer_enable_decimal_type, + false, + "Enable testing of function signatures with decimal argument or return types."); + DEFINE_bool( velox_fuzzer_enable_column_reuse, false, @@ -168,6 +173,7 @@ ExpressionFuzzer::Options getExpressionFuzzerOptions( opts.enableVariadicSignatures = FLAGS_enable_variadic_signatures; opts.enableDereference = FLAGS_enable_dereference; opts.enableComplexTypes = FLAGS_velox_fuzzer_enable_complex_types; + opts.enableDecimalType = FLAGS_velox_fuzzer_enable_decimal_type; opts.enableColumnReuse = FLAGS_velox_fuzzer_enable_column_reuse; opts.enableExpressionReuse = FLAGS_velox_fuzzer_enable_expression_reuse; opts.functionTickets = FLAGS_assign_function_tickets; diff --git a/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp b/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp index 44c8a0a69bf3d..0da8eb6df02fe 100644 --- a/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp +++ b/velox/expression/tests/utils/ArgumentTypeFuzzer.cpp @@ -21,13 +21,16 @@ #include "velox/expression/ReverseSignatureBinder.h" #include "velox/expression/SignatureBinder.h" +#include "velox/expression/type_calculation/TypeCalculation.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" namespace facebook::velox::test { std::string typeToBaseName(const TypePtr& type) { - return boost::algorithm::to_lower_copy(std::string{type->kindName()}); + return type->isDecimal() + ? "decimal" // TODO + : boost::algorithm::to_lower_copy(std::string{type->kindName()}); } std::optional baseNameToTypeKind(const std::string& typeName) { @@ -35,6 +38,58 @@ std::optional baseNameToTypeKind(const std::string& typeName) { return tryMapNameToTypeKind(kindName); } +void ArgumentTypeFuzzer::determineUnboundedIntegerVariables() { + // Assign a random value for all integer values. + for (const auto& [variableName, variableInfo] : variables()) { + if (!variableInfo.isIntegerParameter() || + integerVariablesBindings_.count(variableName)) { + continue; + } + + if (auto pos = variableName.find("precision"); pos != std::string::npos) { + // Handle decimal precisions and scales. + const auto precision = + boost::random::uniform_int_distribution(1, 38)(rng_); + integerVariablesBindings_[variableName] = precision; + const auto colName = variableName.substr(0, pos); + // Corresponding scale should not exceed the generated precision. + for (const auto& [name, info] : variables()) { + if (name == colName + "scale") { + integerVariablesBindings_[name] = + boost::random::uniform_int_distribution( + 0, precision)(rng_); + } + } + } else if (auto pos = variableName.find("i"); pos != std::string::npos) { + VELOX_USER_CHECK_GE(variableName.size(), 2); + auto index = std::stoi(variableName.substr(pos + 1, variableName.size())); + if (index <= kIntegerPairSize) { + const auto precision = + boost::random::uniform_int_distribution(1, 38)(rng_); + integerVariablesBindings_[variableName] = precision; + const auto scaleIndex = index + kIntegerPairSize; + const auto scaleName = "i" + std::to_string(scaleIndex); + integerVariablesBindings_[scaleName] = + boost::random::uniform_int_distribution( + 0, precision)(rng_); + } + } else { + integerVariablesBindings_[variableName] = + boost::random::uniform_int_distribution()(rng_); + } + } + + // Handle constraints. + for (const auto& [variableName, variableInfo] : variables()) { + const auto constraint = variableInfo.constraint(); + if (constraint == "") { + continue; + } + auto calculation = fmt::format("{}={}", variableName, constraint); + expression::calculation::evaluate(calculation, integerVariablesBindings_); + } +} + void ArgumentTypeFuzzer::determineUnboundedTypeVariables() { for (auto& [variableName, variableInfo] : variables()) { if (!variableInfo.isTypeParameter()) { @@ -80,6 +135,7 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) { } } + determineUnboundedIntegerVariables(); determineUnboundedTypeVariables(); for (auto i = 0; i < formalArgsCnt; i++) { TypePtr actualArg; @@ -87,7 +143,7 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) { actualArg = randType(); } else { actualArg = exec::SignatureBinder::tryResolveType( - formalArgs[i], variables(), bindings_); + formalArgs[i], variables(), bindings_, integerVariablesBindings_); VELOX_CHECK(actualArg != nullptr); } argumentTypes_.push_back(actualArg); @@ -113,13 +169,17 @@ TypePtr ArgumentTypeFuzzer::fuzzReturnType() { nullptr, "Only fuzzing uninitialized return type is allowed."); + determineUnboundedIntegerVariables(); determineUnboundedTypeVariables(); if (signature_.returnType().baseName() == "any") { returnType_ = randType(); return returnType_; } else { returnType_ = exec::SignatureBinder::tryResolveType( - signature_.returnType(), variables(), bindings_); + signature_.returnType(), + variables(), + bindings_, + integerVariablesBindings_); VELOX_CHECK_NE(returnType_, nullptr); return returnType_; } diff --git a/velox/expression/tests/utils/ArgumentTypeFuzzer.h b/velox/expression/tests/utils/ArgumentTypeFuzzer.h index 540a353221bfd..d6427c4b1500d 100644 --- a/velox/expression/tests/utils/ArgumentTypeFuzzer.h +++ b/velox/expression/tests/utils/ArgumentTypeFuzzer.h @@ -36,6 +36,14 @@ class ArgumentTypeFuzzer { std::mt19937& rng) : ArgumentTypeFuzzer(signature, nullptr, rng) {} + ArgumentTypeFuzzer( + const exec::FunctionSignature& signature, + std::mt19937& rng, + const std::unordered_map& integerVariablesBindings) + : ArgumentTypeFuzzer(signature, nullptr, rng) { + integerVariablesBindings_ = integerVariablesBindings; + } + ArgumentTypeFuzzer( const exec::FunctionSignature& signature, const TypePtr& returnType, @@ -65,6 +73,10 @@ class ArgumentTypeFuzzer { return signature_.variables(); } + /// Bind each integer variable that is not determined to a randomly generated + /// value. + void determineUnboundedIntegerVariables(); + /// Bind each type variable that is not determined by the return type to a /// randomly generated type. void determineUnboundedTypeVariables(); @@ -83,6 +95,9 @@ class ArgumentTypeFuzzer { /// Bindings between type variables and their actual types. std::unordered_map bindings_; + /// Bindings between integer variables and their values. + std::unordered_map integerVariablesBindings_; + /// RNG to generate random types for unbounded type variables when necessary. std::mt19937& rng_; }; diff --git a/velox/type/Type.h b/velox/type/Type.h index 34fb60af35dbb..aaa4832047bdb 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -1760,6 +1760,7 @@ using S1 = IntegerVariable<5>; using S2 = IntegerVariable<6>; using S3 = IntegerVariable<7>; using S4 = IntegerVariable<8>; +const uint8_t kIntegerPairSize = 4; template struct ShortDecimal {