Skip to content

Commit

Permalink
support decimal in expression fuzzer test
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 20, 2024
1 parent 5e07790 commit 8e4fa85
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 20 deletions.
3 changes: 0 additions & 3 deletions velox/expression/ReverseSignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ bool ReverseSignatureBinder::hasConstrainedIntegerVariable(
}

// Tries to bind the signature's return type to returnType_. As shown here, the
// call gives up with 'false' when hasConstrainedIntegerVariable() reports true
// for the return type; otherwise it returns the result of
// SignatureBinderBase::tryBind().
bool ReverseSignatureBinder::tryBind() {
// NOTE(review): the header of this hunk reports 3 deletions and no additions,
// so this guard is presumably the code the commit removes — confirm against
// the +/- markers, which this view does not show.
if (hasConstrainedIntegerVariable(signature_.returnType())) {
return false;
}
return SignatureBinderBase::tryBind(signature_.returnType(), returnType_);
}

Expand Down
71 changes: 67 additions & 4 deletions velox/expression/tests/ArgumentTypeFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,42 @@ class ArgumentTypeFuzzerTest : public testing::Test {
}
}

// Fuzzes the argument types and the return type of 'signature' with a fixed
// seed and verifies the outcome for a decimal function.
//
// @param signature Function signature to fuzz.
// @param expectedArguments Number of leading arguments that must be decimal.
// Any argument generated beyond that count is only allowed for a signature
// with variable arity and must be equivalent to the last checked argument.
// @param outputKind Expected kind of the fuzzed return type. When not set, the
// return type is expected to be a decimal.
void testFuzzingDecimalSuccess(
    const std::shared_ptr<exec::FunctionSignature>& signature,
    int32_t expectedArguments,
    std::optional<TypeKind> outputKind = std::nullopt) {
  ASSERT_GE(expectedArguments, 0);
  const auto expectedCount = static_cast<size_t>(expectedArguments);

  std::mt19937 seed{0};
  ArgumentTypeFuzzer fuzzer{*signature, seed};
  ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs));

  const auto& argumentTypes = fuzzer.argumentTypes();
  // The fuzzer must produce at least the expected arguments; a variadic
  // signature may add more. This also keeps the indexing below in bounds.
  ASSERT_GE(argumentTypes.size(), expectedCount);

  for (size_t i = 0; i < expectedCount; ++i) {
    ASSERT_TRUE(argumentTypes[i]->isDecimal())
        << "at " << i
        << ": Expected decimal. Got: " << argumentTypes[i]->toString();
  }

  if (expectedCount < argumentTypes.size()) {
    // Extra arguments can only come from the variadic tail, which repeats the
    // type of the last non-variadic argument.
    ASSERT_TRUE(signature->variableArity());
    ASSERT_GT(expectedCount, 0u);
    ASSERT_LE(
        argumentTypes.size() - signature->argumentTypes().size(),
        kMaxVariadicArgs);
    const auto& lastCheckedType = argumentTypes[expectedCount - 1];
    for (size_t j = expectedCount; j < argumentTypes.size(); ++j) {
      ASSERT_TRUE(argumentTypes[j]->equivalent(*lastCheckedType));
    }
  }

  const auto outputType = fuzzer.fuzzReturnType();
  if (outputKind.has_value()) {
    ASSERT_EQ(outputType->kind(), outputKind.value());
  } else {
    ASSERT_TRUE(outputType->isDecimal());
  }
}

void testFuzzingFailure(
const std::shared_ptr<exec::FunctionSignature>& signature,
const TypePtr& returnType) {
Expand Down Expand Up @@ -222,9 +258,36 @@ TEST_F(ArgumentTypeFuzzerTest, any) {
ASSERT_TRUE(argumentTypes[0] != nullptr);
}

TEST_F(ArgumentTypeFuzzerTest, unsupported) {
// Constraints on the return type is not supported.
auto signature =
TEST_F(ArgumentTypeFuzzerTest, decimal) {
auto signature = exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("a_precision")
.returnType("boolean")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.build();

testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN);

signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("b_precision")
.integerVariable("b_scale")
.integerVariable(
"r_precision",
"min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)")
.integerVariable("r_scale", "max(a_scale, b_scale)")
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build();

testFuzzingDecimalSuccess(signature, 2);

signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("b_scale")
Expand All @@ -239,7 +302,7 @@ TEST_F(ArgumentTypeFuzzerTest, unsupported) {
.argumentType("decimal(b_precision, b_scale)")
.build();

testFuzzingFailure(signature, DECIMAL(13, 6));
testFuzzingDecimalSuccess(signature, 2, TypeKind::ROW);
}

TEST_F(ArgumentTypeFuzzerTest, lambda) {
Expand Down
54 changes: 45 additions & 9 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,6 @@ bool isSupportedSignature(
// timestamp with time zone types.
return !(
useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(enableComplexType && useTypeName(signature, "unknown")));
Expand Down Expand Up @@ -587,7 +584,8 @@ ExpressionFuzzer::ExpressionFuzzer(
if (!isSupportedSignature(*signature, options_.enableComplexTypes)) {
continue;
}
if (!(signature->variables().empty() || options_.enableComplexTypes)) {
if (!(signature->variables().empty() || options_.enableComplexTypes ||
options_.enableDecimalType)) {
LOG(WARNING) << "Skipping unsupported signature: " << function.first
<< signature->toString();
continue;
Expand Down Expand Up @@ -1081,7 +1079,8 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
} else {
expression = generateExpressionFromConcreteSignatures(
returnType, chosenFunctionName);
if (!expression && options_.enableComplexTypes) {
if (!expression &&
(options_.enableComplexTypes || options_.enableDecimalType)) {
expression = generateExpressionFromSignatureTemplate(
returnType, chosenFunctionName);
}
Expand All @@ -1107,11 +1106,48 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(

// Builds a CallTypedExpr that invokes 'callable.name' on arguments produced by
// getArgsForCallable(). 'type' is the requested result type. When 'signature'
// is given, at least one argument is decimal, and the signature has an integer
// variable with a non-empty constraint, the result type is instead recomputed
// by ArgumentTypeFuzzer::fuzzReturnType() from the precisions and scales of the
// decimal arguments.
// NOTE(review): this view interleaves the pre- and post-change lines of the
// hunk without +/- markers (the parameter list closes twice and there are two
// return statements) — confirm which lines belong to the final code.
core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type) {
const TypePtr& type,
const exec::FunctionSignature* signature) {
auto args = getArgsForCallable(callable);

// When function is unnested, the types of args are decided by fuzzer argument
// types. For nested function, they are decided by the return types of
// children functions. To handle the constraints between input types and
// output types of a decimal function, extract the input precisions and scales
// from decimal arguments, and bind them to integer variables.
std::unordered_map<std::string, int> integerVariablesBindings;
for (column_index_t i = 0; i < args.size(); ++i) {
const auto argType = args[i]->type();
if (argType->isDecimal()) {
const auto [p, s] = getDecimalPrecisionScale(*argType);
// Variable names are derived from the argument position: argument 0 yields
// "a_precision"/"a_scale", argument 1 "b_precision"/"b_scale", and so on.
// NOTE(review): presumably relies on signatures naming their integer
// variables this way and on having at most 26 arguments — confirm.
const auto column = std::string(1, 'a' + i);
integerVariablesBindings[column + "_precision"] = p;
integerVariablesBindings[column + "_scale"] = s;
}
}

// Generate a CallTypedExpr with type because callable.returnType may not have
// the required field names.
return std::make_shared<core::CallTypedExpr>(type, args, callable.name);
// Start from the requested type; it is replaced below only when the signature
// constrains an integer variable.
auto outputType = type;

if (integerVariablesBindings.size() > 0 && signature) {
// Checks if any variable is integer constrained.
bool integerConstrained = false;
for (const auto& [_, variableInfo] : signature->variables()) {
const auto constraint = variableInfo.constraint();
if (variableInfo.isIntegerParameter() && constraint != "") {
integerConstrained = true;
break;
}
}
if (integerConstrained) {
// Compute a correct output type according to constraints with the integer
// variables bindings.
ArgumentTypeFuzzer fuzzer{*signature, rng_, integerVariablesBindings};
outputType = fuzzer.fuzzReturnType();
}
}
return std::make_shared<core::CallTypedExpr>(outputType, args, callable.name);
}

const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature(
Expand Down Expand Up @@ -1193,7 +1229,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures(
}

markSelected(chosen->name);
return getCallExprFromCallable(*chosen, returnType);
return getCallExprFromCallable(*chosen, returnType, nullptr);
}

const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate(
Expand Down Expand Up @@ -1298,7 +1334,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate(
.constantArgs = constantArguments};

markSelected(chosen->name);
return getCallExprFromCallable(callable, returnType);
return getCallExprFromCallable(callable, returnType, chosen->signature);
}

core::TypedExprPtr ExpressionFuzzer::generateCastExpression(
Expand Down
7 changes: 6 additions & 1 deletion velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class ExpressionFuzzer {
// types.
bool enableComplexTypes = false;

// Enable testing of function signatures with decimal argument or return
// types.
bool enableDecimalType = false;

// Enable generation of expressions where one input column can be used by
// multiple subexpressions.
bool enableColumnReuse = false;
Expand Down Expand Up @@ -269,7 +273,8 @@ class ExpressionFuzzer {

core::TypedExprPtr getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type);
const TypePtr& type,
const exec::FunctionSignature* signature = nullptr);

/// Return a random signature mapped to functionName in
/// expressionToSignature_ whose return type can match returnType. Return
Expand Down
6 changes: 6 additions & 0 deletions velox/expression/tests/FuzzerRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ DEFINE_bool(
false,
"Enable testing of function signatures with complex argument or return types.");

// Command-line flag, off by default, that getExpressionFuzzerOptions() copies
// into ExpressionFuzzer::Options::enableDecimalType.
DEFINE_bool(
velox_fuzzer_enable_decimal_type,
false,
"Enable testing of function signatures with decimal argument or return types.");

DEFINE_bool(
velox_fuzzer_enable_column_reuse,
false,
Expand Down Expand Up @@ -168,6 +173,7 @@ ExpressionFuzzer::Options getExpressionFuzzerOptions(
opts.enableVariadicSignatures = FLAGS_enable_variadic_signatures;
opts.enableDereference = FLAGS_enable_dereference;
opts.enableComplexTypes = FLAGS_velox_fuzzer_enable_complex_types;
opts.enableDecimalType = FLAGS_velox_fuzzer_enable_decimal_type;
opts.enableColumnReuse = FLAGS_velox_fuzzer_enable_column_reuse;
opts.enableExpressionReuse = FLAGS_velox_fuzzer_enable_expression_reuse;
opts.functionTickets = FLAGS_assign_function_tickets;
Expand Down
53 changes: 50 additions & 3 deletions velox/expression/tests/utils/ArgumentTypeFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,62 @@

#include "velox/expression/ReverseSignatureBinder.h"
#include "velox/expression/SignatureBinder.h"
#include "velox/expression/type_calculation/TypeCalculation.h"
#include "velox/type/Type.h"
#include "velox/vector/fuzzer/VectorFuzzer.h"

namespace facebook::velox::test {

// Returns the base name used for 'type' in function signatures: the literal
// "DECIMAL" for decimal types, otherwise the lower-cased kind name.
// NOTE(review): two return statements appear back to back below; this view
// shows the diff without +/- markers, so the first one is presumably the
// removed line and the conditional return its replacement — confirm.
std::string typeToBaseName(const TypePtr& type) {
return boost::algorithm::to_lower_copy(std::string{type->kindName()});
// NOTE(review): "DECIMAL" is upper case while every other base name is
// lower-cased — confirm that consumers compare names case-insensitively.
return type->isDecimal()
? "DECIMAL"
: boost::algorithm::to_lower_copy(std::string{type->kindName()});
}

// Looks up the TypeKind for a base type name. The name is upper-cased before
// it is handed to tryMapNameToTypeKind(), so the lookup does not depend on the
// letter case of 'typeName'. Returns whatever tryMapNameToTypeKind() yields.
std::optional<TypeKind> baseNameToTypeKind(const std::string& typeName) {
  return tryMapNameToTypeKind(boost::algorithm::to_upper_copy(typeName));
}

void ArgumentTypeFuzzer::determineUnboundedIntegerVariables() {
// Assign a random value for all integer values.
for (const auto& [variableName, variableInfo] : variables()) {
if (!variableInfo.isIntegerParameter() ||
integerVariablesBindings_.count(variableName)) {
continue;
}

if (auto pos = variableName.find("precision"); pos == std::string::npos) {
integerVariablesBindings_[variableName] =
boost::random::uniform_int_distribution<int32_t>()(rng_);
} else {
// Handle decimal precisions and scales.
const auto precision =
boost::random::uniform_int_distribution<uint32_t>(1, 38)(rng_);
integerVariablesBindings_[variableName] = precision;
const auto colName = variableName.substr(0, pos);
// Corresponding scale should not exceed the generated precision.
for (const auto& [name, info] : variables()) {
if (name == colName + "scale") {
integerVariablesBindings_[name] =
boost::random::uniform_int_distribution<uint32_t>(
0, precision)(rng_);
}
}
}
}

// Handle constraints.
for (const auto& [variableName, variableInfo] : variables()) {
const auto constraint = variableInfo.constraint();
if (constraint == "") {
continue;
}
auto calculation = fmt::format("{}={}", variableName, constraint);
expression::calculation::evaluate(calculation, integerVariablesBindings_);
}
}

void ArgumentTypeFuzzer::determineUnboundedTypeVariables() {
for (auto& [variableName, variableInfo] : variables()) {
if (!variableInfo.isTypeParameter()) {
Expand Down Expand Up @@ -80,14 +122,15 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) {
}
}

determineUnboundedIntegerVariables();
determineUnboundedTypeVariables();
for (auto i = 0; i < formalArgsCnt; i++) {
TypePtr actualArg;
if (formalArgs[i].baseName() == "any") {
actualArg = randType();
} else {
actualArg = exec::SignatureBinder::tryResolveType(
formalArgs[i], variables(), bindings_);
formalArgs[i], variables(), bindings_, integerVariablesBindings_);
VELOX_CHECK(actualArg != nullptr);
}
argumentTypes_.push_back(actualArg);
Expand All @@ -113,13 +156,17 @@ TypePtr ArgumentTypeFuzzer::fuzzReturnType() {
nullptr,
"Only fuzzing uninitialized return type is allowed.");

determineUnboundedIntegerVariables();
determineUnboundedTypeVariables();
if (signature_.returnType().baseName() == "any") {
returnType_ = randType();
return returnType_;
} else {
returnType_ = exec::SignatureBinder::tryResolveType(
signature_.returnType(), variables(), bindings_);
signature_.returnType(),
variables(),
bindings_,
integerVariablesBindings_);
VELOX_CHECK_NE(returnType_, nullptr);
return returnType_;
}
Expand Down
15 changes: 15 additions & 0 deletions velox/expression/tests/utils/ArgumentTypeFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ class ArgumentTypeFuzzer {
std::mt19937& rng)
: ArgumentTypeFuzzer(signature, nullptr, rng) {}

/// Creates a fuzzer without a predetermined return type whose integer
/// variables listed in 'integerVariablesBindings' start out bound to the given
/// values.
ArgumentTypeFuzzer(
    const exec::FunctionSignature& signature,
    std::mt19937& rng,
    const std::unordered_map<std::string, int>& integerVariablesBindings)
    : ArgumentTypeFuzzer(signature, /*returnType=*/nullptr, rng) {
  // A delegating constructor cannot initialize further members, so the
  // bindings are copied once delegation has finished.
  this->integerVariablesBindings_ = integerVariablesBindings;
}

ArgumentTypeFuzzer(
const exec::FunctionSignature& signature,
const TypePtr& returnType,
Expand Down Expand Up @@ -65,6 +73,10 @@ class ArgumentTypeFuzzer {
return signature_.variables();
}

/// Bind each integer variable that is not determined to a randomly generated
/// value. A variable whose name contains "precision" is treated as a decimal
/// precision and drawn from [1, 38]; the matching "scale" variable is drawn
/// from [0, precision]. Variables that carry a constraint are then computed
/// from that constraint.
void determineUnboundedIntegerVariables();

/// Bind each type variable that is not determined by the return type to a
/// randomly generated type.
void determineUnboundedTypeVariables();
Expand All @@ -83,6 +95,9 @@ class ArgumentTypeFuzzer {
/// Bindings between type variables and their actual types.
std::unordered_map<std::string, TypePtr> bindings_;

/// Bindings between integer variables and their values.
std::unordered_map<std::string, int> integerVariablesBindings_;

/// RNG to generate random types for unbounded type variables when necessary.
std::mt19937& rng_;
};
Expand Down

0 comments on commit 8e4fa85

Please sign in to comment.