Skip to content

Commit

Permalink
support decimal in expression fuzzer test
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 21, 2024
1 parent 86f12b7 commit 6c05940
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 25 deletions.
3 changes: 0 additions & 3 deletions velox/expression/ReverseSignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ bool ReverseSignatureBinder::hasConstrainedIntegerVariable(
}

bool ReverseSignatureBinder::tryBind() {
if (hasConstrainedIntegerVariable(signature_.returnType())) {
return false;
}
return SignatureBinderBase::tryBind(signature_.returnType(), returnType_);
}

Expand Down
71 changes: 67 additions & 4 deletions velox/expression/tests/ArgumentTypeFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,42 @@ class ArgumentTypeFuzzerTest : public testing::Test {
}
}

void testFuzzingDecimalSuccess(
const std::shared_ptr<exec::FunctionSignature>& signature,
int32_t expectedArguments,
std::optional<TypeKind> outputKind = std::nullopt) {
std::mt19937 seed{0};
ArgumentTypeFuzzer fuzzer{*signature, seed};
ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs));

auto& argumentTypes = fuzzer.argumentTypes();
ASSERT_LE(argumentTypes.size(), expectedArguments);

auto& argumentSignatures = signature->argumentTypes();
int i;
for (i = 0; i < expectedArguments; ++i) {
ASSERT_TRUE(argumentTypes[i]->isDecimal())
<< "at " << i
<< ": Expected decimal. Got: " << argumentTypes[i]->toString();
}

if (i < argumentTypes.size()) {
ASSERT_TRUE(signature->variableArity());
ASSERT_LE(
argumentTypes.size() - argumentSignatures.size(), kMaxVariadicArgs);
for (int j = i; j < argumentTypes.size(); ++j) {
ASSERT_TRUE(argumentTypes[j]->equivalent(*argumentTypes[i - 1]));
}
}

const auto outputType = fuzzer.fuzzReturnType();
if (outputKind.has_value()) {
ASSERT_TRUE(outputType->kind() == outputKind);
} else {
ASSERT_TRUE(outputType->isDecimal());
}
}

void testFuzzingFailure(
const std::shared_ptr<exec::FunctionSignature>& signature,
const TypePtr& returnType) {
Expand Down Expand Up @@ -222,9 +258,36 @@ TEST_F(ArgumentTypeFuzzerTest, any) {
ASSERT_TRUE(argumentTypes[0] != nullptr);
}

TEST_F(ArgumentTypeFuzzerTest, unsupported) {
// Constraints on the return type is not supported.
auto signature =
TEST_F(ArgumentTypeFuzzerTest, decimal) {
auto signature = exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("a_precision")
.returnType("boolean")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.build();

testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN);

signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("b_precision")
.integerVariable("b_scale")
.integerVariable(
"r_precision",
"min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)")
.integerVariable("r_scale", "max(a_scale, b_scale)")
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build();

testFuzzingDecimalSuccess(signature, 2);

signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("b_scale")
Expand All @@ -239,7 +302,7 @@ TEST_F(ArgumentTypeFuzzerTest, unsupported) {
.argumentType("decimal(b_precision, b_scale)")
.build();

testFuzzingFailure(signature, DECIMAL(13, 6));
testFuzzingDecimalSuccess(signature, 2, TypeKind::ROW);
}

TEST_F(ArgumentTypeFuzzerTest, lambda) {
Expand Down
99 changes: 85 additions & 14 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,6 @@ bool isSupportedSignature(
// timestamp with time zone types.
return !(
useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(enableComplexType && useTypeName(signature, "unknown")));
Expand Down Expand Up @@ -587,7 +584,8 @@ ExpressionFuzzer::ExpressionFuzzer(
if (!isSupportedSignature(*signature, options_.enableComplexTypes)) {
continue;
}
if (!(signature->variables().empty() || options_.enableComplexTypes)) {
if (!(signature->variables().empty() || options_.enableComplexTypes ||
options_.enableDecimalType)) {
LOG(WARNING) << "Skipping unsupported signature: " << function.first
<< signature->toString();
continue;
Expand Down Expand Up @@ -711,8 +709,10 @@ ExpressionFuzzer::ExpressionFuzzer(

for (const auto& it : signatureTemplates_) {
auto& returnType = it.signature->returnType().baseName();
auto* returnTypeKey = &returnType;
if (it.typeVariables.find(returnType) != it.typeVariables.end()) {
std::string typeName = returnType;
folly::toLowerAscii(typeName);
const auto* returnTypeKey = &typeName;
if (it.typeVariables.find(typeName) != it.typeVariables.end()) {
// Return type is a template variable.
returnTypeKey = &kTypeParameterName;
}
Expand Down Expand Up @@ -783,9 +783,11 @@ int ExpressionFuzzer::getTickets(const std::string& funcName) {
void ExpressionFuzzer::addToTypeToExpressionListByTicketTimes(
const std::string& type,
const std::string& funcName) {
std::string typeName = type;
folly::toLowerAscii(typeName);
int tickets = getTickets(funcName);
for (int i = 0; i < tickets; i++) {
typeToExpressionList_[type].push_back(funcName);
typeToExpressionList_[typeName].push_back(funcName);
}
}

Expand Down Expand Up @@ -1081,7 +1083,8 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
} else {
expression = generateExpressionFromConcreteSignatures(
returnType, chosenFunctionName);
if (!expression && options_.enableComplexTypes) {
if (!expression &&
(options_.enableComplexTypes || options_.enableDecimalType)) {
expression = generateExpressionFromSignatureTemplate(
returnType, chosenFunctionName);
}
Expand All @@ -1105,13 +1108,81 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
return funcIt->second(callable);
}

TypePtr ExpressionFuzzer::getConstrainedOutputType(
const std::vector<core::TypedExprPtr>& args,
const exec::FunctionSignature* signature) {
// When function is unnested, the types of args are decided by fuzzer argument
// types. For nested function, they are decided by the return types of
// children functions. To handle the constraints between input types and
// output types of a decimal function, extract the input precisions and scales
// from decimal arguments, and bind them to integer variables.

// Checks if any variable is integer constrained, and get the decimal name
// style.
bool integerConstrained = false;
char decimalNameStyle = 0;
for (const auto& [variableName, variableInfo] : signature->variables()) {
const auto constraint = variableInfo.constraint();
if (variableInfo.isIntegerParameter()) {
if (variableName.find("precision") != std::string::npos ||
variableName.find("scale") != std::string::npos) {
decimalNameStyle = 'p';
} else if (variableName.find("i") != std::string::npos) {
decimalNameStyle = 'i';
}
integerConstrained = true;
break;
}
}

std::unordered_map<std::string, int> decimalVariablesBindings;
column_index_t decimalColIndex = 1;
for (column_index_t i = 0; i < args.size(); ++i) {
const auto argType = args[i]->type();
if (argType->isDecimal()) {
const auto [p, s] = getDecimalPrecisionScale(*argType);
switch (decimalNameStyle) {
case 'p': {
const auto column = std::string(1, 'a' + i);
decimalVariablesBindings[column + "_precision"] = p;
decimalVariablesBindings[column + "_scale"] = s;
break;
}
case 'i': {
decimalVariablesBindings["i" + std::to_string(decimalColIndex)] = p;
decimalVariablesBindings
["i" + std::to_string(decimalColIndex + kIntegerPairSize)] = s;
decimalColIndex++;
break;
}
default:
VELOX_UNSUPPORTED("Unsupported decimal name style.");
}
}
}

if (integerConstrained && decimalVariablesBindings.size() > 0 && signature) {
// Compute a correct output type according to constraints with the integer
// variables bindings.
ArgumentTypeFuzzer fuzzer{*signature, rng_, decimalVariablesBindings};
return fuzzer.fuzzReturnType();
}
return nullptr;
}

core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type) {
const TypePtr& type,
const exec::FunctionSignature* signature) {
auto args = getArgsForCallable(callable);
// Generate a CallTypedExpr with type because callable.returnType may not have
// the required field names.
return std::make_shared<core::CallTypedExpr>(type, args, callable.name);

// If a constrained output type is generated, use it to avoid breaking the
// constraints between input types and output types. Otherwise, generate a
// CallTypedExpr with type because callable.returnType may not have the
// required field names.
const auto constrainedType = getConstrainedOutputType(args, signature);
return std::make_shared<core::CallTypedExpr>(
constrainedType ? constrainedType : type, args, callable.name);
}

const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature(
Expand Down Expand Up @@ -1193,7 +1264,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures(
}

markSelected(chosen->name);
return getCallExprFromCallable(*chosen, returnType);
return getCallExprFromCallable(*chosen, returnType, nullptr);
}

const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate(
Expand Down Expand Up @@ -1298,7 +1369,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate(
.constantArgs = constantArguments};

markSelected(chosen->name);
return getCallExprFromCallable(callable, returnType);
return getCallExprFromCallable(callable, returnType, chosen->signature);
}

core::TypedExprPtr ExpressionFuzzer::generateCastExpression(
Expand Down
11 changes: 10 additions & 1 deletion velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class ExpressionFuzzer {
// types.
bool enableComplexTypes = false;

// Enable testing of function signatures with decimal argument or return
// types.
bool enableDecimalType = false;

// Enable generation of expressions where one input column can be used by
// multiple subexpressions.
bool enableColumnReuse = false;
Expand Down Expand Up @@ -267,9 +271,14 @@ class ExpressionFuzzer {
std::vector<core::TypedExprPtr> generateSwitchArgs(
const CallableSignature& input);

TypePtr getConstrainedOutputType(
const std::vector<core::TypedExprPtr>& args,
const exec::FunctionSignature* signature);

core::TypedExprPtr getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type);
const TypePtr& type,
const exec::FunctionSignature* signature = nullptr);

/// Return a random signature mapped to functionName in
/// expressionToSignature_ whose return type can match returnType. Return
Expand Down
6 changes: 6 additions & 0 deletions velox/expression/tests/FuzzerRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ DEFINE_bool(
false,
"Enable testing of function signatures with complex argument or return types.");

DEFINE_bool(
velox_fuzzer_enable_decimal_type,
false,
"Enable testing of function signatures with decimal argument or return types.");

DEFINE_bool(
velox_fuzzer_enable_column_reuse,
false,
Expand Down Expand Up @@ -168,6 +173,7 @@ ExpressionFuzzer::Options getExpressionFuzzerOptions(
opts.enableVariadicSignatures = FLAGS_enable_variadic_signatures;
opts.enableDereference = FLAGS_enable_dereference;
opts.enableComplexTypes = FLAGS_velox_fuzzer_enable_complex_types;
opts.enableDecimalType = FLAGS_velox_fuzzer_enable_decimal_type;
opts.enableColumnReuse = FLAGS_velox_fuzzer_enable_column_reuse;
opts.enableExpressionReuse = FLAGS_velox_fuzzer_enable_expression_reuse;
opts.functionTickets = FLAGS_assign_function_tickets;
Expand Down
Loading

0 comments on commit 6c05940

Please sign in to comment.