Skip to content

Commit

Permalink
try
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Apr 3, 2024
1 parent dd6d547 commit 3ae1234
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 149 deletions.
10 changes: 9 additions & 1 deletion velox/docs/develop/scalar-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -998,8 +998,16 @@ element of the array and returns a new array of the results.

The signature of a function that handles DECIMAL types can additionally take
variables and constraints to represent the precision and scale values.
The variables of input decimal types store the input precisions and scales.
Their names begin with an incrementing prefix starting from 'a', followed
by '_precision' or '_scale'. Variables of output decimal types store the output
precision and scale. Their names begin with 'r', followed by '_precision'
or '_scale'. When there is only one input decimal type and the output type
has the same precision and scale as the input type, the variables can be
named 'precision' and 'scale'.
The constraints are evaluated using a type calculator built from Flex and Bison
tools. The decimal arithmetic addition function has the following signature:
tools.
The decimal arithmetic addition function has the following signature:

.. code-block:: c++

Expand Down
17 changes: 0 additions & 17 deletions velox/expression/ReverseSignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,6 @@

namespace facebook::velox::exec {

/// Returns true if 'type', or any type signature nested in its parameters,
/// names an integer variable that carries a non-empty constraint.
bool ReverseSignatureBinder::hasConstrainedIntegerVariable(
    const TypeSignature& type) const {
  const auto& children = type.parameters();
  if (!children.empty()) {
    // Parameterized signature: the answer is decided by the nested
    // signatures, so recurse into each of them.
    for (const auto& child : children) {
      if (hasConstrainedIntegerVariable(child)) {
        return true;
      }
    }
    return false;
  }

  // Leaf signature: its base name may refer to a declared variable.
  const auto entry = variables().find(type.baseName());
  if (entry == variables().end()) {
    return false;
  }
  if (!entry->second.isIntegerParameter()) {
    return false;
  }
  return entry->second.constraint() != "";
}

/// Binds the signature's declared return type against 'returnType_' by
/// delegating to SignatureBinderBase::tryBind, and returns its result.
bool ReverseSignatureBinder::tryBind() {
  const auto& declaredReturnType = signature_.returnType();
  return SignatureBinderBase::tryBind(declaredReturnType, returnType_);
}
Expand Down
4 changes: 0 additions & 4 deletions velox/expression/ReverseSignatureBinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ class ReverseSignatureBinder : private SignatureBinderBase {
}

private:
/// Return whether there is a constraint on an integer variable in type
/// signature.
bool hasConstrainedIntegerVariable(const TypeSignature& type) const;

const TypePtr returnType_;
};

Expand Down
75 changes: 62 additions & 13 deletions velox/expression/tests/ArgumentTypeFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class ArgumentTypeFuzzerTest : public testing::Test {
void testFuzzingDecimalSuccess(
const std::shared_ptr<exec::FunctionSignature>& signature,
int32_t expectedArguments,
std::optional<TypeKind> outputKind = std::nullopt) {
const std::function<bool(const std::vector<TypePtr>&, const TypePtr&)>&
returnTypeVerifier) {
std::mt19937 seed{0};
ArgumentTypeFuzzer fuzzer{*signature, seed};
ASSERT_TRUE(fuzzer.fuzzArgumentTypes(kMaxVariadicArgs));
Expand All @@ -87,11 +88,10 @@ class ArgumentTypeFuzzerTest : public testing::Test {
}
}

const auto outputType = fuzzer.fuzzReturnType();
if (outputKind.has_value()) {
ASSERT_TRUE(outputType->kind() == outputKind);
} else {
ASSERT_TRUE(outputType->isDecimal());
const auto returnType = fuzzer.fuzzReturnType();
if (returnTypeVerifier) {
ASSERT_TRUE(returnTypeVerifier(argumentTypes, returnType))
<< ": Got return type: " << returnType->toString();
}
}

Expand Down Expand Up @@ -268,7 +268,14 @@ TEST_F(ArgumentTypeFuzzerTest, decimal) {
.argumentType("decimal(a_precision, a_scale)")
.build();

testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN);
std::function<bool(const std::vector<TypePtr>&, const TypePtr&)> verifier =
[](const std::vector<TypePtr>& argumentTypes, const TypePtr& returnType) {
if (returnType->kind() != TypeKind::BOOLEAN) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 3, verifier);

signature =
exec::FunctionSignatureBuilder()
Expand All @@ -285,7 +292,14 @@ TEST_F(ArgumentTypeFuzzerTest, decimal) {
.argumentType("DECIMAL(b_precision, b_scale)")
.build();

testFuzzingDecimalSuccess(signature, 2);
verifier = [](const std::vector<TypePtr>& argumentTypes,
const TypePtr& returnType) {
if (!returnType->isDecimal()) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 2, verifier);

signature =
exec::FunctionSignatureBuilder()
Expand All @@ -302,7 +316,14 @@ TEST_F(ArgumentTypeFuzzerTest, decimal) {
.argumentType("decimal(b_precision, b_scale)")
.build();

testFuzzingDecimalSuccess(signature, 2, TypeKind::ROW);
verifier = [](const std::vector<TypePtr>& argumentTypes,
const TypePtr& returnType) {
if (returnType->kind() != TypeKind::ROW) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 2, verifier);

signature = exec::FunctionSignatureBuilder()
.integerVariable("i1")
Expand All @@ -316,7 +337,14 @@ TEST_F(ArgumentTypeFuzzerTest, decimal) {
.argumentType("decimal(i1,i5)")
.argumentType("decimal(i2,i6)")
.build();
testFuzzingDecimalSuccess(signature, 2);
verifier = [](const std::vector<TypePtr>& argumentTypes,
const TypePtr& returnType) {
if (!returnType->isDecimal()) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 2, verifier);

signature = exec::FunctionSignatureBuilder()
.integerVariable("i1")
Expand All @@ -326,7 +354,14 @@ TEST_F(ArgumentTypeFuzzerTest, decimal) {
.argumentType("decimal(i1,i5)")
.argumentType("decimal(i1,i5)")
.build();
testFuzzingDecimalSuccess(signature, 3, TypeKind::BOOLEAN);
verifier = [](const std::vector<TypePtr>& argumentTypes,
const TypePtr& returnType) {
if (returnType->kind() != TypeKind::BOOLEAN) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 3, verifier);

signature = exec::FunctionSignatureBuilder()
.integerVariable("precision")
Expand All @@ -335,14 +370,28 @@ TEST_F(ArgumentTypeFuzzerTest, decimal) {
.argumentType("DECIMAL(precision, scale)")
.variableArity()
.build();
testFuzzingDecimalSuccess(signature, 1);
verifier = [](const std::vector<TypePtr>& argumentTypes,
const TypePtr& returnType) {
if (!returnType->isDecimal()) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 1, verifier);

signature = exec::FunctionSignatureBuilder()
.integerVariable("precision", "min(max(6, precision), 18)")
.returnType("timestamp")
.argumentType("decimal(precision, 6)")
.build();
testFuzzingDecimalSuccess(signature, 1, TypeKind::TIMESTAMP);
verifier = [](const std::vector<TypePtr>& argumentTypes,
const TypePtr& returnType) {
if (returnType->kind() != TypeKind::TIMESTAMP) {
return false;
}
return true;
};
testFuzzingDecimalSuccess(signature, 1, verifier);
}

TEST_F(ArgumentTypeFuzzerTest, lambda) {
Expand Down
119 changes: 27 additions & 92 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,10 +692,10 @@ ExpressionFuzzer::ExpressionFuzzer(

for (const auto& it : signatureTemplates_) {
auto& returnType = it.signature->returnType().baseName();
std::string typeName = returnType;
folly::toLowerAscii(typeName);
const auto* returnTypeKey = &typeName;
if (it.typeVariables.find(typeName) != it.typeVariables.end()) {
const auto sanitizedName = exec::sanitizeName(returnType);

const auto* returnTypeKey = &sanitizedName;
if (it.typeVariables.find(sanitizedName) != it.typeVariables.end()) {
// Return type is a template variable.
returnTypeKey = &kTypeParameterName;
}
Expand Down Expand Up @@ -766,11 +766,10 @@ int ExpressionFuzzer::getTickets(const std::string& funcName) {
// Appends 'funcName' to typeToExpressionList_ once per ticket, where the
// number of tickets is whatever getTickets(funcName) returns.
//
// NOTE(review): this text looks like an unmarked diff rendering that contains
// both the earlier lines (typeName lower-cased via folly::toLowerAscii) and
// the later lines (sanitizedName from exec::sanitizeName). As written, the
// function name is appended under both keys on every iteration. Confirm which
// single variant is intended before relying on this block.
void ExpressionFuzzer::addToTypeToExpressionListByTicketTimes(
const std::string& type,
const std::string& funcName) {
// Key variant 1: a copy of 'type' lower-cased in place.
std::string typeName = type;
folly::toLowerAscii(typeName);
// Key variant 2: 'type' passed through exec::sanitizeName.
const auto sanitizedName = exec::sanitizeName(type);
int tickets = getTickets(funcName);
for (int i = 0; i < tickets; i++) {
typeToExpressionList_[typeName].push_back(funcName);
typeToExpressionList_[sanitizedName].push_back(funcName);
}
}

Expand Down Expand Up @@ -1073,97 +1072,33 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
return funcIt->second(callable);
}

/// Derives the return type of a decimal function from the types of its
/// already-generated arguments.
///
/// The precision and scale of every decimal argument are bound to the
/// signature's integer variables, and an ArgumentTypeFuzzer seeded with those
/// bindings is asked for the return type.
///
/// @param args Argument expressions whose types are inspected.
/// @param signature Function signature; may be null.
/// @return The fuzzed return type, or nullptr when 'signature' is null, when
/// the signature declares no integer variable, or when no decimal argument
/// produced a binding.
TypePtr ExpressionFuzzer::getConstrainedOutputType(
    const std::vector<core::TypedExprPtr>& args,
    const exec::FunctionSignature* signature) {
  if (!signature) {
    return nullptr;
  }

  // Detect how the signature names its decimal variables:
  //   'a' - plain "precision" / "scale"
  //   'b' - names containing "precision" / "scale" (e.g. "a_precision")
  //   'c' - names containing "i" (e.g. "i1", "i5")
  //   0   - nothing recognized
  // Any integer variable counts as constrained, even with an empty constraint,
  // because variables are shared across argument and return types.
  bool hasIntegerVariable = false;
  char namingScheme = 0;
  for (const auto& [name, info] : signature->variables()) {
    if (!info.isIntegerParameter()) {
      continue;
    }
    hasIntegerVariable = true;
    if (name == "precision" || name == "scale") {
      namingScheme = 'a';
      break;
    }
    if (name.find("precision") != std::string::npos ||
        name.find("scale") != std::string::npos) {
      namingScheme = 'b';
      break;
    }
    if (name.find("i") != std::string::npos) {
      namingScheme = 'c';
      break;
    }
  }

  // Bind the precision and scale of each decimal argument to the variable
  // names dictated by the detected naming scheme.
  std::unordered_map<std::string, int> bindings;
  column_index_t nextDecimalOrdinal = 1;
  for (column_index_t argIndex = 0; argIndex < args.size(); ++argIndex) {
    const auto argType = args[argIndex]->type();
    if (!argType->isDecimal()) {
      continue;
    }

    const auto [precision, scale] = getDecimalPrecisionScale(*argType);
    if (namingScheme == 'a') {
      bindings["precision"] = precision;
      bindings["scale"] = scale;
    } else if (namingScheme == 'b') {
      // The prefix letter is derived from the argument position.
      const auto prefix = std::string(1, 'a' + argIndex);
      bindings[prefix + "_precision"] = precision;
      bindings[prefix + "_scale"] = scale;
    } else if (namingScheme == 'c') {
      // Precisions use "i<k>", scales use "i<k + kIntegerPairSize>", where k
      // counts decimal arguments starting from 1.
      bindings["i" + std::to_string(nextDecimalOrdinal)] = precision;
      bindings["i" + std::to_string(nextDecimalOrdinal + kIntegerPairSize)] =
          scale;
      ++nextDecimalOrdinal;
    } else {
      VELOX_UNSUPPORTED("Unsupported decimal name style {}.", namingScheme);
    }
  }

  if (!hasIntegerVariable || bindings.empty()) {
    return nullptr;
  }

  // The argument type fuzzer evaluates the constraints internally and yields
  // the matching return type.
  ArgumentTypeFuzzer fuzzer{*signature, rng_, bindings};
  return fuzzer.fuzzReturnType();
}

// Builds a core::CallTypedExpr named 'callable.name' over the arguments
// returned by getArgsForCallable(callable). The expression's type is 'type'
// unless a constrained return type can be derived from the argument types.
//
// NOTE(review): this text looks like an unmarked diff rendering that contains
// two alternative implementations back to back. The first goes through
// getConstrainedOutputType() and returns immediately. The second builds an
// ArgumentTypeFuzzer from the argument types, but is unreachable as written
// because it follows that return. Confirm which single variant is intended.
core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type,
const exec::FunctionSignature* signature) {
auto args = getArgsForCallable(callable);

// For a decimal function (especially a nested one), as argument precisions
// and scales are randomly generated, callable.returnType does not follow the
// required constraints, and the matched result type needs to be recalculated
// from the argument types. If function signature is provided, generates a
// constrained type to avoid breaking the constraints between input types and
// output types. Otherwise, generate a CallTypedExpr with type because
// callable.returnType may not have the required field names.
const auto constrainedType = getConstrainedOutputType(args, signature);
// Falls back to 'type' when no constrained type was produced.
return std::make_shared<core::CallTypedExpr>(
constrainedType ? constrainedType : type, args, callable.name);
// Generate a CallTypedExpr with type because callable.returnType may not have
// the required field names.
auto outputType = type;
// If signature is provided, for a decimal function (especially a nested one),
// as argument precisions and scales are randomly generated,
// callable.returnType does not follow the required constraints, and the
// matched result type needs to be recalculated from the argument types. Use
// ArgumentTypeFuzzer to generate a constrained type to avoid breaking the
// constraints between input types and output types.
if (signature) {
// Collect the concrete type of every generated argument.
std::vector<TypePtr> argTypes;
argTypes.reserve(args.size());
for (const auto& arg : args) {
argTypes.emplace_back(arg->type());
}
ArgumentTypeFuzzer fuzzer{*signature, rng_, argTypes};
// Override 'type' only when the fuzzer yields a non-null return type.
if (auto constrainedType = fuzzer.fuzzReturnType()) {
outputType = constrainedType;
}
}
return std::make_shared<core::CallTypedExpr>(outputType, args, callable.name);
}

const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature(
Expand Down
6 changes: 0 additions & 6 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,6 @@ class ExpressionFuzzer {
std::vector<core::TypedExprPtr> generateSwitchArgs(
const CallableSignature& input);

/// Given the argument types, calculates the return type of a decimal function
/// by evaluating constraints.
TypePtr getConstrainedOutputType(
const std::vector<core::TypedExprPtr>& args,
const exec::FunctionSignature* signature);

core::TypedExprPtr getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type,
Expand Down
Loading

0 comments on commit 3ae1234

Please sign in to comment.