Skip to content

Commit

Permalink
support decimal in expression fuzzer test
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 26, 2024
1 parent 3aa020d commit ba31081
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 32 deletions.
1 change: 1 addition & 0 deletions velox/expression/ConstantExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ void appendSqlLiteral(
case TypeKind::TINYINT:
case TypeKind::SMALLINT:
case TypeKind::BIGINT:
case TypeKind::HUGEINT:
case TypeKind::TIMESTAMP:
case TypeKind::REAL:
case TypeKind::DOUBLE:
Expand Down
3 changes: 0 additions & 3 deletions velox/expression/ReverseSignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ bool ReverseSignatureBinder::hasConstrainedIntegerVariable(
}

bool ReverseSignatureBinder::tryBind() {
if (hasConstrainedIntegerVariable(signature_.returnType())) {
return false;
}
return SignatureBinderBase::tryBind(signature_.returnType(), returnType_);
}

Expand Down
95 changes: 91 additions & 4 deletions velox/expression/tests/ArgumentTypeFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,42 @@ class ArgumentTypeFuzzerTest : public testing::Test {
}
}

void testFuzzingDecimalSuccess(
const std::shared_ptr<exec::FunctionSignature>& signature,
int32_t expectedArguments,
std::optional<TypeKind> outputKind = std::nullopt) {
std::mt19937 seed{0};
ArgumentTypeFuzzer fuzzer{*signature, seed};
ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs));

auto& argumentTypes = fuzzer.argumentTypes();
ASSERT_LE(argumentTypes.size(), expectedArguments);

auto& argumentSignatures = signature->argumentTypes();
int i;
for (i = 0; i < expectedArguments; ++i) {
ASSERT_TRUE(argumentTypes[i]->isDecimal())
<< "at " << i
<< ": Expected decimal. Got: " << argumentTypes[i]->toString();
}

if (i < argumentTypes.size()) {
ASSERT_TRUE(signature->variableArity());
ASSERT_LE(
argumentTypes.size() - argumentSignatures.size(), kMaxVariadicArgs);
for (int j = i; j < argumentTypes.size(); ++j) {
ASSERT_TRUE(argumentTypes[j]->equivalent(*argumentTypes[i - 1]));
}
}

const auto outputType = fuzzer.fuzzReturnType();
if (outputKind.has_value()) {
ASSERT_TRUE(outputType->kind() == outputKind);
} else {
ASSERT_TRUE(outputType->isDecimal());
}
}

void testFuzzingFailure(
const std::shared_ptr<exec::FunctionSignature>& signature,
const TypePtr& returnType) {
Expand Down Expand Up @@ -222,9 +258,36 @@ TEST_F(ArgumentTypeFuzzerTest, any) {
ASSERT_TRUE(argumentTypes[0] != nullptr);
}

TEST_F(ArgumentTypeFuzzerTest, unsupported) {
// Constraints on the return type is not supported.
auto signature =
TEST_F(ArgumentTypeFuzzerTest, decimal) {
auto signature = exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("a_precision")
.returnType("boolean")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.argumentType("decimal(a_precision, a_scale)")
.build();

testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN);

signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("b_precision")
.integerVariable("b_scale")
.integerVariable(
"r_precision",
"min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)")
.integerVariable("r_scale", "max(a_scale, b_scale)")
.returnType("DECIMAL(r_precision, r_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("DECIMAL(b_precision, b_scale)")
.build();

testFuzzingDecimalSuccess(signature, 2);

signature =
exec::FunctionSignatureBuilder()
.integerVariable("a_scale")
.integerVariable("b_scale")
Expand All @@ -239,7 +302,31 @@ TEST_F(ArgumentTypeFuzzerTest, unsupported) {
.argumentType("decimal(b_precision, b_scale)")
.build();

testFuzzingFailure(signature, DECIMAL(13, 6));
testFuzzingDecimalSuccess(signature, 2, TypeKind::ROW);

signature = exec::FunctionSignatureBuilder()
.integerVariable("i1")
.integerVariable("i2")
.integerVariable("i5")
.integerVariable("i6")
.integerVariable(
"i3", "min(38, max(i1 - i5, i2 - i6) + max(i5, i6) + 1)")
.integerVariable("i7", "max(i5, i6)")
.returnType("decimal(i3,i7)")
.argumentType("decimal(i1,i5)")
.argumentType("decimal(i2,i6)")
.build();
testFuzzingDecimalSuccess(signature, 2);

signature = exec::FunctionSignatureBuilder()
.integerVariable("i1")
.integerVariable("i5")
.returnType("boolean")
.argumentType("decimal(i1,i5)")
.argumentType("decimal(i1,i5)")
.argumentType("decimal(i1,i5)")
.build();
testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN);
}

TEST_F(ArgumentTypeFuzzerTest, lambda) {
Expand Down
164 changes: 146 additions & 18 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,20 @@ bool useTypeName(

bool isSupportedSignature(
const exec::FunctionSignature& signature,
bool enableComplexType) {
// Not supporting lambda functions, or functions using decimal and
// timestamp with time zone types.
bool enableComplexType,
bool enableDecimalType) {
// When enableDecimalType is disabled, not supporting decimal functions.
const bool useDecimal =
(useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal"));
// Not supporting lambda functions, or functions using timestamp with time
// zone types.
return !(
useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(!enableDecimalType && useDecimal) ||
(enableComplexType && useTypeName(signature, "unknown")));
}

Expand Down Expand Up @@ -563,10 +567,14 @@ ExpressionFuzzer::ExpressionFuzzer(
for (const auto& signature : function.second) {
++totalFunctionSignatures;

if (!isSupportedSignature(*signature, options_.enableComplexTypes)) {
if (!isSupportedSignature(
*signature,
options_.enableComplexTypes,
options_.enableDecimalType)) {
continue;
}
if (!(signature->variables().empty() || options_.enableComplexTypes)) {
if (!(signature->variables().empty() || options_.enableComplexTypes ||
options_.enableDecimalType)) {
LOG(WARNING) << "Skipping unsupported signature: " << function.first
<< signature->toString();
continue;
Expand Down Expand Up @@ -690,8 +698,10 @@ ExpressionFuzzer::ExpressionFuzzer(

for (const auto& it : signatureTemplates_) {
auto& returnType = it.signature->returnType().baseName();
auto* returnTypeKey = &returnType;
if (it.typeVariables.find(returnType) != it.typeVariables.end()) {
std::string typeName = returnType;
folly::toLowerAscii(typeName);
const auto* returnTypeKey = &typeName;
if (it.typeVariables.find(typeName) != it.typeVariables.end()) {
// Return type is a template variable.
returnTypeKey = &kTypeParameterName;
}
Expand All @@ -710,6 +720,9 @@ ExpressionFuzzer::ExpressionFuzzer(
// Register function override (for cases where we want to restrict the types
// or parameters we pass to functions).
registerFuncOverride(&ExpressionFuzzer::generateSwitchArgs, "switch");
registerFuncOverride(
&ExpressionFuzzer::generateExtremeFunctionArgs, "greatest");
registerFuncOverride(&ExpressionFuzzer::generateExtremeFunctionArgs, "least");
}

void ExpressionFuzzer::getTicketsForFunctions() {
Expand Down Expand Up @@ -762,9 +775,11 @@ int ExpressionFuzzer::getTickets(const std::string& funcName) {
void ExpressionFuzzer::addToTypeToExpressionListByTicketTimes(
const std::string& type,
const std::string& funcName) {
std::string typeName = type;
folly::toLowerAscii(typeName);
int tickets = getTickets(funcName);
for (int i = 0; i < tickets; i++) {
typeToExpressionList_[type].push_back(funcName);
typeToExpressionList_[typeName].push_back(funcName);
}
}

Expand Down Expand Up @@ -984,6 +999,39 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::generateSwitchArgs(
return inputExpressions;
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateExtremeFunctionArgs(
const CallableSignature& input) {
const auto argTypes = input.args;
VELOX_CHECK_GE(
argTypes.size(),
1,
"Only one input is expected from the template signature.");
if (!argTypes[0]->isDecimal()) {
return generateArgs(input);
}

auto numVarArgs =
!input.variableArity ? 0 : rand32(0, options_.maxNumVarArgs);
std::vector<core::TypedExprPtr> inputExpressions;
inputExpressions.reserve(argTypes.size() + numVarArgs);
inputExpressions.emplace_back(
generateArg(argTypes.at(0), input.constantArgs.at(0)));

// Append varargs to the argument list.
for (int i = 0; i < numVarArgs; i++) {
size_t argClass = rand32(0, 1);
core::TypedExprPtr argExpr;
const auto argType = inputExpressions[0]->type();
if (argClass == kArgConstant) {
argExpr = generateArgConstant(argType);
} else {
argExpr = generateArgColumn(argType);
}
inputExpressions.emplace_back(argExpr);
}
return inputExpressions;
}

ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpressions(
const RowTypePtr& outType) {
state.reset();
Expand Down Expand Up @@ -1060,7 +1108,8 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
} else {
expression = generateExpressionFromConcreteSignatures(
returnType, chosenFunctionName);
if (!expression && options_.enableComplexTypes) {
if (!expression &&
(options_.enableComplexTypes || options_.enableDecimalType)) {
expression = generateExpressionFromSignatureTemplate(
returnType, chosenFunctionName);
}
Expand All @@ -1084,13 +1133,92 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
return funcIt->second(callable);
}

TypePtr ExpressionFuzzer::getConstrainedOutputType(
const std::vector<core::TypedExprPtr>& args,
const exec::FunctionSignature* signature) {
if (signature == nullptr) {
return nullptr;
}
// When function is unnested, the types of args are decided by fuzzer argument
// types. For nested function, they are decided by the return types of
// children functions. To handle the constraints between input types and
// output types of a decimal function, extract the input precisions and scales
// from decimal arguments, and bind them to integer variables.

// Checks if any variable is integer constrained, and get the decimal name
// style.
bool integerConstrained = false;
char decimalNameStyle = 0;
for (const auto& [variableName, variableInfo] : signature->variables()) {
const auto constraint = variableInfo.constraint();
if (variableInfo.isIntegerParameter()) {
if (variableName == "precision" || variableName == "scale") {
decimalNameStyle = 'v';
} else if (
variableName.find("precision") != std::string::npos ||
variableName.find("scale") != std::string::npos) {
decimalNameStyle = 'p';
} else if (variableName.find("i") != std::string::npos) {
decimalNameStyle = 'i';
}
integerConstrained = true;
break;
}
}

std::unordered_map<std::string, int> decimalVariablesBindings;
column_index_t decimalColIndex = 1;
for (column_index_t i = 0; i < args.size(); ++i) {
const auto argType = args[i]->type();
if (argType->isDecimal()) {
const auto [p, s] = getDecimalPrecisionScale(*argType);
switch (decimalNameStyle) {
case 'v': {
decimalVariablesBindings["precision"] = p;
decimalVariablesBindings["scale"] = s;
break;
}
case 'p': {
const auto column = std::string(1, 'a' + i);
decimalVariablesBindings[column + "_precision"] = p;
decimalVariablesBindings[column + "_scale"] = s;
break;
}
case 'i': {
decimalVariablesBindings["i" + std::to_string(decimalColIndex)] = p;
decimalVariablesBindings
["i" + std::to_string(decimalColIndex + kIntegerPairSize)] = s;
decimalColIndex++;
break;
}
default:
VELOX_UNSUPPORTED("Unsupported decimal name style.");
}
}
}

if (integerConstrained && decimalVariablesBindings.size() > 0 && signature) {
// Compute a correct output type according to constraints with the integer
// variables bindings.
ArgumentTypeFuzzer fuzzer{*signature, rng_, decimalVariablesBindings};
return fuzzer.fuzzReturnType();
}
return nullptr;
}

core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type) {
const TypePtr& type,
const exec::FunctionSignature* signature) {
auto args = getArgsForCallable(callable);
// Generate a CallTypedExpr with type because callable.returnType may not have
// the required field names.
return std::make_shared<core::CallTypedExpr>(type, args, callable.name);

// If a constrained output type is generated, use it to avoid breaking the
// constraints between input types and output types. Otherwise, generate a
// CallTypedExpr with type because callable.returnType may not have the
// required field names.
const auto constrainedType = getConstrainedOutputType(args, signature);
return std::make_shared<core::CallTypedExpr>(
constrainedType ? constrainedType : type, args, callable.name);
}

const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature(
Expand Down Expand Up @@ -1172,7 +1300,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures(
}

markSelected(chosen->name);
return getCallExprFromCallable(*chosen, returnType);
return getCallExprFromCallable(*chosen, returnType, nullptr);
}

const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate(
Expand Down Expand Up @@ -1277,7 +1405,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate(
.constantArgs = constantArguments};

markSelected(chosen->name);
return getCallExprFromCallable(callable, returnType);
return getCallExprFromCallable(callable, returnType, chosen->signature);
}

core::TypedExprPtr ExpressionFuzzer::generateCastExpression(
Expand Down
Loading

0 comments on commit ba31081

Please sign in to comment.