Skip to content

Commit

Permalink
make skipFunctions, onlyFunctions, and specialForms as ExpressionFuzz…
Browse files Browse the repository at this point in the history
…er options. (#7882)

Summary:
Pull Request resolved: #7882

More refactoring for the expression fuzzer.

1) Push the skipFunctions, onlyFunctions, and specialForms as ExpressionFuzzer options.
2) FuzzerRunner: just a tool that wrap ExpressionFuzzerVerifier into a unit test.
3) Move the comment from FuzzerRunner class to ExpressionFuzzerVerifier since it describes the later.

Next diff :
4) Move all the flags from ExpressionFuzzerVerifier to FuzzerRunner and pass them through ExpressionFuzzerVerifier::Options .

1. spark fuzzer used to only support and, or not it uses all of them "and,or,cast,coalesce,if,switch".

Reviewed By: kevinwilfong

Differential Revision: D51856248

fbshipit-source-id: f06c9667d05f82cb6f572bdf42e5dab9e9315ae2
  • Loading branch information
laithsakka authored and facebook-github-bot committed Dec 6, 2023
1 parent e7bebbd commit e963545
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 394 deletions.
307 changes: 291 additions & 16 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
namespace facebook::velox::test {

namespace {

using exec::SignatureBinder;
using exec::SignatureBinderBase;

Expand Down Expand Up @@ -71,15 +70,286 @@ class FullSignatureBinder : public SignatureBinderBase {
bool bound_{false};
};

static const std::vector<std::string> kIntegralTypes{
"tinyint",
"smallint",
"integer",
"bigint",
"boolean"};

static const std::vector<std::string> kFloatingPointTypes{"real", "double"};

facebook::velox::exec::FunctionSignaturePtr makeCastSignature(
const std::string& fromType,
const std::string& toType) {
return facebook::velox::exec::FunctionSignatureBuilder()
.argumentType(fromType)
.returnType(toType)
.build();
}

void addCastFromIntegralSignatures(
const std::string& toType,
std::vector<facebook::velox::exec::FunctionSignaturePtr>& signatures) {
for (const auto& fromType : kIntegralTypes) {
signatures.push_back(makeCastSignature(fromType, toType));
}
}

void addCastFromFloatingPointSignatures(
const std::string& toType,
std::vector<facebook::velox::exec::FunctionSignaturePtr>& signatures) {
for (const auto& fromType : kFloatingPointTypes) {
signatures.push_back(makeCastSignature(fromType, toType));
}
}

void addCastFromVarcharSignature(
const std::string& toType,
std::vector<facebook::velox::exec::FunctionSignaturePtr>& signatures) {
signatures.push_back(makeCastSignature("varchar", toType));
}

void addCastFromTimestampSignature(
const std::string& toType,
std::vector<facebook::velox::exec::FunctionSignaturePtr>& signatures) {
signatures.push_back(makeCastSignature("timestamp", toType));
}

void addCastFromDateSignature(
const std::string& toType,
std::vector<facebook::velox::exec::FunctionSignaturePtr>& signatures) {
signatures.push_back(makeCastSignature("date", toType));
}

std::vector<facebook::velox::exec::FunctionSignaturePtr>
getSignaturesForCast() {
std::vector<facebook::velox::exec::FunctionSignaturePtr> signatures;

// To integral types.
for (const auto& toType : kIntegralTypes) {
addCastFromIntegralSignatures(toType, signatures);
addCastFromFloatingPointSignatures(toType, signatures);
addCastFromVarcharSignature(toType, signatures);
}

// To floating-point types.
for (const auto& toType : kFloatingPointTypes) {
addCastFromIntegralSignatures(toType, signatures);
addCastFromFloatingPointSignatures(toType, signatures);
addCastFromVarcharSignature(toType, signatures);
}

// To varchar type.
addCastFromIntegralSignatures("varchar", signatures);
addCastFromFloatingPointSignatures("varchar", signatures);
addCastFromVarcharSignature("varchar", signatures);
addCastFromDateSignature("varchar", signatures);
addCastFromTimestampSignature("varchar", signatures);

// To timestamp type.
addCastFromVarcharSignature("timestamp", signatures);
addCastFromDateSignature("timestamp", signatures);

// To date type.
addCastFromVarcharSignature("date", signatures);
addCastFromTimestampSignature("date", signatures);

// For each supported translation pair T --> U, add signatures of array(T) -->
// array(U), map(varchar, T) --> map(varchar, U), row(T) --> row(U).
auto size = signatures.size();
for (auto i = 0; i < size; ++i) {
auto from = signatures[i]->argumentTypes()[0].baseName();
auto to = signatures[i]->returnType().baseName();

signatures.push_back(makeCastSignature(
fmt::format("array({})", from), fmt::format("array({})", to)));

signatures.push_back(makeCastSignature(
fmt::format("map(varchar, {})", from),
fmt::format("map(varchar, {})", to)));

signatures.push_back(makeCastSignature(
fmt::format("row({})", from), fmt::format("row({})", to)));
}
return signatures;
}

static const std::unordered_map<
std::string,
std::vector<facebook::velox::exec::FunctionSignaturePtr>>
kSpecialForms = {
{"and",
std::vector<facebook::velox::exec::FunctionSignaturePtr>{
// Signature: and (condition,...) -> output:
// boolean, boolean,.. -> boolean
facebook::velox::exec::FunctionSignatureBuilder()
.argumentType("boolean")
.argumentType("boolean")
.variableArity()
.returnType("boolean")
.build()}},
{"or",
std::vector<facebook::velox::exec::FunctionSignaturePtr>{
// Signature: or (condition,...) -> output:
// boolean, boolean,.. -> boolean
facebook::velox::exec::FunctionSignatureBuilder()
.argumentType("boolean")
.argumentType("boolean")
.variableArity()
.returnType("boolean")
.build()}},
{"coalesce",
std::vector<facebook::velox::exec::FunctionSignaturePtr>{
// Signature: coalesce (input,...) -> output:
// T, T,.. -> T
facebook::velox::exec::FunctionSignatureBuilder()
.typeVariable("T")
.argumentType("T")
.argumentType("T")
.variableArity()
.returnType("T")
.build()}},
{
"if",
std::vector<facebook::velox::exec::FunctionSignaturePtr>{
// Signature: if (condition, then) -> output:
// boolean, T -> T
facebook::velox::exec::FunctionSignatureBuilder()
.typeVariable("T")
.argumentType("boolean")
.argumentType("T")
.returnType("T")
.build(),
// Signature: if (condition, then, else) -> output:
// boolean, T, T -> T
facebook::velox::exec::FunctionSignatureBuilder()
.typeVariable("T")
.argumentType("boolean")
.argumentType("T")
.argumentType("T")
.returnType("T")
.build()},
},
{
"switch",
std::vector<facebook::velox::exec::FunctionSignaturePtr>{
// Signature: Switch (condition, then) -> output:
// boolean, T -> T
// This is only used to bind to a randomly selected type for the
// output, then while generating arguments, an override is used
// to generate inputs that can create variation of multiple
// cases and may or may not include a final else clause.
facebook::velox::exec::FunctionSignatureBuilder()
.typeVariable("T")
.argumentType("boolean")
.argumentType("T")
.returnType("T")
.build()},
},
{
"cast",
/// TODO: Add supported Cast signatures to CastTypedExpr and expose
/// them to fuzzer instead of hard-coding signatures here.
getSignaturesForCast(),
},
};

static std::unordered_set<std::string> splitNames(const std::string& names) {
// Parse, lower case and trim it.
std::vector<folly::StringPiece> nameList;
folly::split(',', names, nameList);
std::unordered_set<std::string> nameSet;

for (const auto& it : nameList) {
auto str = folly::trimWhitespace(it).toString();
folly::toLowerAscii(str);
nameSet.insert(str);
}
return nameSet;
}

static std::pair<std::string, std::string> splitSignature(
const std::string& signature) {
const auto parenPos = signature.find("(");

if (parenPos != std::string::npos) {
return {signature.substr(0, parenPos), signature.substr(parenPos)};
}

return {signature, ""};
}

// Parse the comma separated list of function names, and use it to filter the
// input signatures.
static void filterSignatures(
facebook::velox::FunctionSignatureMap& input,
const std::string& onlyFunctions,
const std::unordered_set<std::string>& skipFunctions) {
if (!onlyFunctions.empty()) {
// Parse, lower case and trim it.
auto nameSet = splitNames(onlyFunctions);

// Use the generated set to filter the input signatures.
for (auto it = input.begin(); it != input.end();) {
if (!nameSet.count(it->first)) {
it = input.erase(it);
} else
it++;
}
}

for (auto skip : skipFunctions) {
// 'skip' can be function name or signature.
const auto [skipName, skipSignature] = splitSignature(skip);

if (skipSignature.empty()) {
input.erase(skipName);
} else {
auto it = input.find(skipName);
if (it != input.end()) {
// Compiler refuses to reference 'skipSignature' from the lambda as
// is.
const auto& signatureToRemove = skipSignature;

auto removeIt = std::find_if(
it->second.begin(), it->second.end(), [&](const auto& signature) {
return signature->toString() == signatureToRemove;
});
VELOX_CHECK(
removeIt != it->second.end(), "Skip signature not found: {}", skip);
it->second.erase(removeIt);
}
}
}
}

static void appendSpecialForms(
facebook::velox::FunctionSignatureMap& signatureMap,
const std::string& specialForms) {
auto specialFormNames = splitNames(specialForms);
for (const auto& [name, signatures] : kSpecialForms) {
if (specialFormNames.count(name) == 0) {
LOG(INFO) << "Skipping special form: " << name;
continue;
}
std::vector<const facebook::velox::exec::FunctionSignature*> rawSignatures;
for (const auto& signature : signatures) {
rawSignatures.push_back(signature.get());
}
signatureMap.insert({name, std::move(rawSignatures)});
}
}

/// Returns if `functionName` with the given `argTypes` is deterministic.
/// Returns true if the function was not found or determinism cannot be
/// established.
bool isDeterministic(
const std::string& functionName,
const std::vector<TypePtr>& argTypes) {
// We know that the 'cast', 'and', and 'or' special forms are deterministic.
// Hard-code them here because they are not real functions and hence cannot be
// resolved by the code below.
// Hard-code them here because they are not real functions and hence cannot
// be resolved by the code below.
if (functionName == "and" || functionName == "or" ||
functionName == "coalesce" || functionName == "if" ||
functionName == "switch" || functionName == "cast") {
Expand All @@ -93,9 +363,9 @@ bool isDeterministic(
}

// Vector functions are a bit more complicated. We need to fetch the list of
// available signatures and check if any of them bind given the current input
// arg types. If it binds (if there's a match), we fetch the function and
// return the isDeterministic bool.
// available signatures and check if any of them bind given the current
// input arg types. If it binds (if there's a match), we fetch the function
// and return the isDeterministic bool.
try {
if (auto vectorFunctionSignatures =
exec::getVectorFunctionSignatures(functionName)) {
Expand All @@ -110,12 +380,12 @@ bool isDeterministic(
}
}
}
// TODO: Some stateful functions can only be built when constant arguments are
// passed, making the getVectorFunction() call above to throw. We only have a
// few of these functions, so for now we assume they are deterministic so they
// are picked for Fuzz testing. Once we make the isDeterministic() flag static
// (and hence we won't need to build the function object in here) we can clean
// up this code.
// TODO: Some stateful functions can only be built when constant arguments
// are passed, making the getVectorFunction() call above to throw. We only
// have a few of these functions, so for now we assume they are
// deterministic so they are picked for Fuzz testing. Once we make the
// isDeterministic() flag static (and hence we won't need to build the
// function object in here) we can clean up this code.
catch (const std::exception& e) {
LOG(WARNING) << "Unable to determine if '" << functionName
<< "' is deterministic or not. Assuming it is.";
Expand Down Expand Up @@ -178,8 +448,8 @@ bool containTypeName(
return false;
}

// Determine whether the signature has an argument or return type that contains
// typeName. typeName should be in lower case.
// Determine whether the signature has an argument or return type that
// contains typeName. typeName should be in lower case.
bool useTypeName(
const exec::FunctionSignature& signature,
const std::string& typeName) {
Expand Down Expand Up @@ -237,7 +507,8 @@ BufferPtr extractNonNullIndices(const RowVectorPtr& data) {
}

/// Wraps child vectors of the specified 'rowVector' in dictionary using
/// specified 'indices'. Returns new RowVector created from the wrapped vectors.
/// specified 'indices'. Returns new RowVector created from the wrapped
/// vectors.
RowVectorPtr wrapChildren(
const BufferPtr& indices,
const RowVectorPtr& rowVector) {
Expand Down Expand Up @@ -283,7 +554,7 @@ uint32_t levelOfNesting(const TypePtr& type) {
} // namespace

ExpressionFuzzer::ExpressionFuzzer(
const FunctionSignatureMap& signatureMap,
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::optional<ExpressionFuzzer::Options>& options)
Expand All @@ -293,6 +564,10 @@ ExpressionFuzzer::ExpressionFuzzer(
VELOX_CHECK(vectorFuzzer, "Vector fuzzer must be provided");
seed(initialSeed);

appendSpecialForms(signatureMap, options_.specialForms);
filterSignatures(
signatureMap, options_.useOnlyFunctions, options_.skipFunctions);

size_t totalFunctions = 0;
size_t totalFunctionSignatures = 0;
size_t supportedFunctionSignatures = 0;
Expand Down
Loading

0 comments on commit e963545

Please sign in to comment.