Skip to content

Commit

Permalink
support simple
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 21, 2024
1 parent 9263b90 commit 1aa93b2
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 35 deletions.
91 changes: 61 additions & 30 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1104,50 +1104,81 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
return funcIt->second(callable);
}

core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type,
TypePtr ExpressionFuzzer::getConstrainedOutputType(
const std::vector<core::TypedExprPtr>& args,
const exec::FunctionSignature* signature) {
auto args = getArgsForCallable(callable);

// When function is unnested, the types of args are decided by fuzzer argument
// types. For nested function, they are decided by the return types of
// children functions. To handle the constraints between input types and
// output types of a decimal function, extract the input precisions and scales
// from decimal arguments, and bind them to integer variables.
std::unordered_map<std::string, int> integerVariablesBindings;

// Checks if any variable is integer constrained, and get the decimal name
// style.
bool integerConstrained = false;
char decimalNameStyle = 0;
for (const auto& [variableName, variableInfo] : signature->variables()) {
const auto constraint = variableInfo.constraint();
if (variableInfo.isIntegerParameter()) {
if (variableName.find("precision") != std::string::npos ||
variableName.find("scale") != std::string::npos) {
decimalNameStyle = 'p';
} else if (variableName.find("i") != std::string::npos) {
decimalNameStyle = 'i';
}
integerConstrained = true;
break;
}
}

std::unordered_map<std::string, int> decimalVariablesBindings;
column_index_t decimalColIndex = 1;
for (column_index_t i = 0; i < args.size(); ++i) {
const auto argType = args[i]->type();
if (argType->isDecimal()) {
const auto [p, s] = getDecimalPrecisionScale(*argType);
const auto column = std::string(1, 'a' + i);
integerVariablesBindings[column + "_precision"] = p;
integerVariablesBindings[column + "_scale"] = s;
switch (decimalNameStyle) {
case 'p': {
const auto column = std::string(1, 'a' + i);
decimalVariablesBindings[column + "_precision"] = p;
decimalVariablesBindings[column + "_scale"] = s;
break;
}
case 'i': {
decimalVariablesBindings["i" + std::to_string(decimalColIndex)] = p;
decimalVariablesBindings
["i" + std::to_string(decimalColIndex + kIntegerPairSize)] = s;
decimalColIndex++;
break;
}
default:
VELOX_UNSUPPORTED("Unsupported decimal name style.");
}
}
}

// Generate a CallTypedExpr with type because callable.returnType may not have
// the required field names.
auto outputType = type;

if (integerVariablesBindings.size() > 0 && signature) {
// Checks if any variable is integer constrained.
bool integerConstrained = false;
for (const auto& [_, variableInfo] : signature->variables()) {
const auto constraint = variableInfo.constraint();
if (variableInfo.isIntegerParameter() && constraint != "") {
integerConstrained = true;
break;
}
}
if (integerConstrained) {
// Compute a correct output type according to constraints with the integer
// variables bindings.
ArgumentTypeFuzzer fuzzer{*signature, rng_, integerVariablesBindings};
outputType = fuzzer.fuzzReturnType();
}
if (integerConstrained && decimalVariablesBindings.size() > 0 && signature) {
// Compute a correct output type according to constraints with the integer
// variables bindings.
ArgumentTypeFuzzer fuzzer{*signature, rng_, decimalVariablesBindings};
return fuzzer.fuzzReturnType();
}
return std::make_shared<core::CallTypedExpr>(outputType, args, callable.name);
return nullptr;
}

/// Builds a CallTypedExpr named 'callable.name' over freshly generated
/// arguments for 'callable'.
///
/// The expression's type is the one returned by getConstrainedOutputType()
/// when that is non-null, so that the constraints between input types and
/// output types are not broken. Otherwise 'type' is used, because
/// callable.returnType may not have the required field names.
core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
    const CallableSignature& callable,
    const TypePtr& type,
    const exec::FunctionSignature* signature) {
  auto args = getArgsForCallable(callable);

  // Prefer the constrained type; fall back to the requested one.
  TypePtr outputType = getConstrainedOutputType(args, signature);
  if (!outputType) {
    outputType = type;
  }

  return std::make_shared<core::CallTypedExpr>(
      outputType, args, callable.name);
}

const CallableSignature* ExpressionFuzzer::chooseRandomConcreteSignature(
Expand Down
4 changes: 4 additions & 0 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ class ExpressionFuzzer {
std::vector<core::TypedExprPtr> generateSwitchArgs(
const CallableSignature& input);

/// Returns the output type computed from the integer-variable constraints of
/// 'signature' and the precisions and scales of the decimal expressions in
/// 'args', or nullptr when no constrained output type is generated.
TypePtr getConstrainedOutputType(
const std::vector<core::TypedExprPtr>& args,
const exec::FunctionSignature* signature);

core::TypedExprPtr getCallExprFromCallable(
const CallableSignature& callable,
const TypePtr& type,
Expand Down
23 changes: 18 additions & 5 deletions velox/expression/tests/utils/ArgumentTypeFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace facebook::velox::test {

/// Returns the base name of 'type': "decimal" for decimal types, otherwise the
/// lower-cased kind name of the type.
std::string typeToBaseName(const TypePtr& type) {
  if (type->isDecimal()) {
    // TODO(review): the lower-case spelling matches the lower-cased kind names
    // returned below; confirm every consumer of this name expects it.
    return "decimal";
  }
  return boost::algorithm::to_lower_copy(std::string{type->kindName()});
}

Expand All @@ -46,10 +46,7 @@ void ArgumentTypeFuzzer::determineUnboundedIntegerVariables() {
continue;
}

if (auto pos = variableName.find("precision"); pos == std::string::npos) {
integerVariablesBindings_[variableName] =
boost::random::uniform_int_distribution<int32_t>()(rng_);
} else {
if (auto pos = variableName.find("precision"); pos != std::string::npos) {
// Handle decimal precisions and scales.
const auto precision =
boost::random::uniform_int_distribution<uint32_t>(1, 38)(rng_);
Expand All @@ -63,6 +60,22 @@ void ArgumentTypeFuzzer::determineUnboundedIntegerVariables() {
0, precision)(rng_);
}
}
} else if (auto pos = variableName.find("i"); pos != std::string::npos) {
VELOX_USER_CHECK_GE(variableName.size(), 2);
auto index = std::stoi(variableName.substr(pos + 1, variableName.size()));
if (index <= kIntegerPairSize) {
const auto precision =
boost::random::uniform_int_distribution<uint32_t>(1, 38)(rng_);
integerVariablesBindings_[variableName] = precision;
const auto scaleIndex = index + kIntegerPairSize;
const auto scaleName = "i" + std::to_string(scaleIndex);
integerVariablesBindings_[scaleName] =
boost::random::uniform_int_distribution<uint32_t>(
0, precision)(rng_);
}
} else {
integerVariablesBindings_[variableName] =
boost::random::uniform_int_distribution<int32_t>()(rng_);
}
}

Expand Down
1 change: 1 addition & 0 deletions velox/type/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,7 @@ using S1 = IntegerVariable<5>;
using S2 = IntegerVariable<6>;
using S3 = IntegerVariable<7>;
using S4 = IntegerVariable<8>;
/// Offset between the index of a precision integer variable and the index of
/// its paired scale variable: the scale variables S1..S4 above are
/// IntegerVariable<5>..IntegerVariable<8>. Declared 'inline constexpr' so that
/// every translation unit including this header shares one definition.
inline constexpr uint8_t kIntegerPairSize = 4;

template <typename P, typename S>
struct ShortDecimal {
Expand Down

0 comments on commit 1aa93b2

Please sign in to comment.