Skip to content

Commit

Permalink
try decimal fuzzer test
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 15, 2024
1 parent 925cef0 commit ec8b08d
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 19 deletions.
76 changes: 66 additions & 10 deletions velox/expression/tests/ArgumentTypeFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,35 @@ class ArgumentTypeFuzzerTest : public testing::Test {
}
}

void testFuzzingDecimalSuccess(
const std::shared_ptr<exec::FunctionSignature>& signature,
const TypePtr& returnType,
int32_t expectedArguments) {
std::mt19937 seed{0};
ArgumentTypeFuzzer fuzzer{*signature, returnType, seed};
ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs));

auto& argumentTypes = fuzzer.argumentTypes();
ASSERT_LE(argumentTypes.size(), expectedArguments);

auto& argumentSignatures = signature->argumentTypes();
int i;
for (i = 0; i < expectedArguments; ++i) {
ASSERT_TRUE(argumentTypes[i]->isDecimal())
<< "at " << i
<< ": Expected decimal. Got: " << argumentTypes[i]->toString();
}

if (i < argumentTypes.size()) {
ASSERT_TRUE(signature->variableArity());
ASSERT_LE(
argumentTypes.size() - argumentSignatures.size(), kMaxVariadicArgs);
for (int j = i; j < argumentTypes.size(); ++j) {
ASSERT_TRUE(argumentTypes[j]->equivalent(*argumentTypes[i - 1]));
}
}
}

void testFuzzingFailure(
const std::shared_ptr<exec::FunctionSignature>& signature,
const TypePtr& returnType) {
Expand Down Expand Up @@ -222,24 +251,51 @@ TEST_F(ArgumentTypeFuzzerTest, any) {
ASSERT_TRUE(argumentTypes[0] != nullptr);
}

TEST_F(ArgumentTypeFuzzerTest, unsupported) {
// Constraints on the return type is not supported.
auto signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("b_scale")
TEST_F(ArgumentTypeFuzzerTest, decimal) {
auto signature = exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("a_precision")
.returnType("boolean")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.build();

testFuzzingDecimalSuccess(signature, BOOLEAN(), 3);

signature = exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("b_precision")
.integerVariable("b_scale")
.integerVariable(
"r_precision",
"min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)")
.integerVariable("r_scale", "max(a_scale, b_scale)")
.returnType("row(array(decimal(r_precision, r_scale)))")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(b_precision, b_scale)")
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build();

testFuzzingFailure(signature, DECIMAL(13, 6));
testFuzzingDecimalSuccess(signature, BOOLEAN(), 2);

// Constraints on the return type is not supported.
// signature =
// exec::FunctionSignatureBuilder()
// .integerVariable("a_scale")
// .integerVariable("b_scale")
// .integerVariable("a_precision")
// .integerVariable("b_precision")
// .integerVariable(
// "r_precision",
// "min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)")
// .integerVariable("r_scale", "max(a_scale, b_scale)")
// .returnType("row(array(decimal(r_precision, r_scale)))")
// .argumentType("decimal(a_precision, a_scale)")
// .argumentType("decimal(b_precision, b_scale)")
// .build();

// testFuzzingFailure(signature, DECIMAL(13, 6));
}

TEST_F(ArgumentTypeFuzzerTest, lambda) {
Expand Down
9 changes: 4 additions & 5 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,6 @@ bool isSupportedSignature(
// timestamp with time zone types.
return !(
useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(enableComplexType && useTypeName(signature, "unknown")));
Expand Down Expand Up @@ -587,7 +584,8 @@ ExpressionFuzzer::ExpressionFuzzer(
if (!isSupportedSignature(*signature, options_.enableComplexTypes)) {
continue;
}
if (!(signature->variables().empty() || options_.enableComplexTypes)) {
if (!(signature->variables().empty() || options_.enableComplexTypes ||
options_.enableDecimalType)) {
LOG(WARNING) << "Skipping unsupported signature: " << function.first
<< signature->toString();
continue;
Expand Down Expand Up @@ -1081,7 +1079,8 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
} else {
expression = generateExpressionFromConcreteSignatures(
returnType, chosenFunctionName);
if (!expression && options_.enableComplexTypes) {
if (!expression &&
(options_.enableComplexTypes || options_.enableDecimalType)) {
expression = generateExpressionFromSignatureTemplate(
returnType, chosenFunctionName);
}
Expand Down
4 changes: 4 additions & 0 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class ExpressionFuzzer {
// types.
bool enableComplexTypes = false;

// Enable testing of function signatures with decimal argument or return
// types.
bool enableDecimalType = false;

// Enable generation of expressions where one input column can be used by
// multiple subexpressions.
bool enableColumnReuse = false;
Expand Down
6 changes: 6 additions & 0 deletions velox/expression/tests/FuzzerRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ DEFINE_bool(
false,
"Enable testing of function signatures with complex argument or return types.");

DEFINE_bool(
velox_fuzzer_enable_decimal_type,
false,
"Enable testing of function signatures with decimal argument or return types.");

DEFINE_bool(
velox_fuzzer_enable_column_reuse,
false,
Expand Down Expand Up @@ -168,6 +173,7 @@ ExpressionFuzzer::Options getExpressionFuzzerOptions(
opts.enableVariadicSignatures = FLAGS_enable_variadic_signatures;
opts.enableDereference = FLAGS_enable_dereference;
opts.enableComplexTypes = FLAGS_velox_fuzzer_enable_complex_types;
opts.enableDecimalType = FLAGS_velox_fuzzer_enable_decimal_type;
opts.enableColumnReuse = FLAGS_velox_fuzzer_enable_column_reuse;
opts.enableExpressionReuse = FLAGS_velox_fuzzer_enable_expression_reuse;
opts.functionTickets = FLAGS_assign_function_tickets;
Expand Down
46 changes: 44 additions & 2 deletions velox/expression/tests/utils/ArgumentTypeFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,30 @@ std::optional<TypeKind> baseNameToTypeKind(const std::string& typeName) {
return tryMapNameToTypeKind(kindName);
}

void ArgumentTypeFuzzer::determineUnboundedIntegerVariables() {
for (const auto& [variableName, variableInfo] : variables()) {
if (!variableInfo.isIntegerParameter() ||
integerVariablesBindings_.count(variableName)) {
continue;
}

if (auto pos = variableName.find("precision"); pos != std::string::npos) {
std::srand(std::time(nullptr));
const int32_t precision = rand() % 38;
integerVariablesBindings_[variableName] = precision;
const auto colName = variableName.substr(0, pos);
// Corresponding scale should not exceed the precision.
for (const auto& [name, info] : variables()) {
if (name == colName + "scale") {
integerVariablesBindings_[name] = rand() % precision;
}
}
} else {
integerVariablesBindings_[variableName] = rand();
}
}
}

void ArgumentTypeFuzzer::determineUnboundedTypeVariables() {
for (auto& [variableName, variableInfo] : variables()) {
if (!variableInfo.isTypeParameter()) {
Expand Down Expand Up @@ -80,11 +104,18 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) {
}
}

determineUnboundedIntegerVariables();
determineUnboundedTypeVariables();
for (auto i = 0; i < formalArgsCnt; i++) {
TypePtr actualArg;
if (formalArgs[i].baseName() == "any") {
auto baseName = formalArgs[i].baseName();
folly::toLowerAscii(baseName);
if (baseName == "any") {
actualArg = randType();
} else if (formalArgs[i].baseName() == "decimal") {
actualArg = exec::SignatureBinder::tryResolveType(
formalArgs[i], variables(), bindings_, integerVariablesBindings_);
VELOX_CHECK(actualArg != nullptr);
} else {
actualArg = exec::SignatureBinder::tryResolveType(
formalArgs[i], variables(), bindings_);
Expand Down Expand Up @@ -113,10 +144,21 @@ TypePtr ArgumentTypeFuzzer::fuzzReturnType() {
nullptr,
"Only fuzzing uninitialized return type is allowed.");

determineUnboundedIntegerVariables();
determineUnboundedTypeVariables();
if (signature_.returnType().baseName() == "any") {
auto baseName = signature_.returnType().baseName();
folly::toLowerAscii(baseName);
if (baseName == "any") {
returnType_ = randType();
return returnType_;
} else if (baseName == "decimal") {
returnType_ = exec::SignatureBinder::tryResolveType(
signature_.returnType(),
variables(),
bindings_,
integerVariablesBindings_);
VELOX_CHECK_NE(returnType_, nullptr);
return returnType_;
} else {
returnType_ = exec::SignatureBinder::tryResolveType(
signature_.returnType(), variables(), bindings_);
Expand Down
7 changes: 7 additions & 0 deletions velox/expression/tests/utils/ArgumentTypeFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class ArgumentTypeFuzzer {
return signature_.variables();
}

/// Bind each integer variable that is not determined to a randomly generated
/// value.
void determineUnboundedIntegerVariables();

/// Bind each type variable that is not determined by the return type to a
/// randomly generated type.
void determineUnboundedTypeVariables();
Expand All @@ -83,6 +87,9 @@ class ArgumentTypeFuzzer {
/// Bindings between type variables and their actual types.
std::unordered_map<std::string, TypePtr> bindings_;

/// Bindings between integer variables and their values.
std::unordered_map<std::string, int> integerVariablesBindings_;

/// RNG to generate random types for unbounded type variables when necessary.
std::mt19937& rng_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void registerComparisonFunctions(const std::string& prefix) {
registerFunction<BetweenFunction, bool, Timestamp, Timestamp, Timestamp>(
{prefix + "between"});

VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_between, prefix + "between");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_between, prefix + "decimal_between");
}

} // namespace facebook::velox::functions
2 changes: 1 addition & 1 deletion velox/functions/sparksql/RegisterArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void registerArithmeticFunctions(const std::string& prefix) {
registerFunction<sparksql::Log10Function, double, double>({prefix + "log10"});
registerRandFunctions(prefix);

VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_add, prefix + "add");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_add, prefix + "decimal_add");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "subtract");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "multiply");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "divide");
Expand Down

0 comments on commit ec8b08d

Please sign in to comment.