Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 29, 2024
1 parent d50c072 commit 8d0651e
Show file tree
Hide file tree
Showing 19 changed files with 409 additions and 151 deletions.
36 changes: 36 additions & 0 deletions velox/expression/tests/ArgumentGenerator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/core/ITypedExpr.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"

namespace facebook::velox::test {

class ExpressionFuzzer;

/// Interface for producing the argument expressions of a specific function in
/// place of the expression fuzzer's default argument generation.
///
/// Generators are handed to ExpressionFuzzer keyed by function name. When the
/// fuzzer needs arguments for a function that has a generator registered under
/// its name, it calls generate() on that generator.
class ArgumentGenerator {
public:
// Virtual so that implementations can be destroyed through a pointer to this
// interface.
virtual ~ArgumentGenerator() = default;

/// Generates the argument expressions for a call matching 'input'.
///
/// @param expressionFuzzer The fuzzer requesting the arguments (the fuzzer
/// passes itself).
/// NOTE(review): presumably never null and only borrowed for the duration of
/// the call - confirm.
/// @param input Signature of the function to generate arguments for.
/// @param maxNumVarArgs The fuzzer passes its 'maxNumVarArgs' option here.
/// NOTE(review): presumably the upper bound on the number of variadic
/// arguments to generate - confirm.
/// @return The generated argument expressions.
virtual std::vector<core::TypedExprPtr> generate(
ExpressionFuzzer* expressionFuzzer,
const CallableSignature& input,
int32_t maxNumVarArgs) = 0;
};

} // namespace facebook::velox::test
10 changes: 6 additions & 4 deletions velox/expression/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,12 @@ target_link_libraries(

add_executable(velox_expression_fuzzer_test ExpressionFuzzerTest.cpp)

target_link_libraries(velox_expression_fuzzer_test velox_expression_fuzzer
velox_functions_prestosql gtest gtest_main)
target_link_libraries(velox_expression_fuzzer_test velox_expression_fuzzer_utility
velox_expression_fuzzer velox_functions_prestosql
gtest gtest_main)

add_executable(spark_expression_fuzzer_test SparkExpressionFuzzerTest.cpp)

target_link_libraries(spark_expression_fuzzer_test velox_expression_fuzzer
velox_functions_spark gtest gtest_main)
target_link_libraries(spark_expression_fuzzer_test spark_expression_fuzzer_utility
velox_expression_fuzzer velox_functions_spark
gtest gtest_main)
117 changes: 8 additions & 109 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,13 @@ ExpressionFuzzer::ExpressionFuzzer(
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators,
const std::optional<ExpressionFuzzer::Options>& options)
: options_(options.value_or(Options())),
vectorFuzzer_(vectorFuzzer),
state{rng_, std::max(1, options_.maxLevelOfNesting)} {
state{rng_, std::max(1, options_.maxLevelOfNesting)},
customArgumentGenerators_(customArgumentGenerators) {
VELOX_CHECK(vectorFuzzer, "Vector fuzzer must be provided");
seed(initialSeed);

Expand Down Expand Up @@ -711,13 +714,6 @@ ExpressionFuzzer::ExpressionFuzzer(
// Register function override (for cases where we want to restrict the types
// or parameters we pass to functions).
registerFuncOverride(&ExpressionFuzzer::generateSwitchArgs, "switch");
registerFuncOverride(
&ExpressionFuzzer::generateExtremeFunctionArgs, "greatest");
registerFuncOverride(&ExpressionFuzzer::generateExtremeFunctionArgs, "least");
registerFuncOverride(
&ExpressionFuzzer::generateMakeTimestampArgs, "make_timestamp");
registerFuncOverride(
&ExpressionFuzzer::generateUnscaledValueArgs, "unscaled_value");
}

void ExpressionFuzzer::getTicketsForFunctions() {
Expand Down Expand Up @@ -950,84 +946,6 @@ core::TypedExprPtr ExpressionFuzzer::generateArg(
}
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateExtremeFunctionArgs(
const CallableSignature& input) {
const auto argTypes = input.args;
VELOX_CHECK_GE(
argTypes.size(),
1,
"At least one input is expected from the template signature.");
if (!argTypes[0]->isDecimal()) {
return generateArgs(input);
}

auto numVarArgs =
!input.variableArity ? 0 : rand32(0, options_.maxNumVarArgs);
std::vector<core::TypedExprPtr> inputExpressions;
inputExpressions.reserve(argTypes.size() + numVarArgs);
inputExpressions.emplace_back(
generateArg(argTypes.at(0), input.constantArgs.at(0)));

// Append varargs to the argument list.
for (int i = 0; i < numVarArgs; i++) {
core::TypedExprPtr argExpr;
// The varargs need to be generated following the result type of the first
// argument. But when nested expression is generated, that cannot be
// guaranteed as argument precisions and scales cannot be inferred from the
// result type through a decimal function signature. Given this limitation,
// generate constant or column only.
const auto argType = inputExpressions[0]->type();
if (rand32(0, 1) == kArgConstant) {
argExpr = generateArgConstant(argType);
} else {
argExpr = generateArgColumn(argType);
}
inputExpressions.emplace_back(argExpr);
}
return inputExpressions;
}

/// Generates arguments for the "make_timestamp" function.
///
/// The first five arguments are produced by generateArg(). The sixth (decimal)
/// argument is restricted to a constant or a column. When the signature has
/// exactly seven arguments, the seventh is a constant picked from a fixed list
/// of valid timezone strings.
///
/// @param input Signature to generate arguments for. Must declare at least six
/// argument types.
/// @return The generated argument expressions in signature order.
std::vector<core::TypedExprPtr> ExpressionFuzzer::generateMakeTimestampArgs(
    const CallableSignature& input) {
  VELOX_CHECK_GE(
      input.args.size(),
      6,
      "At least six inputs are expected from the template signature.");

  // NOTE(review): the result of this coin toss is not used by this function.
  // The call is kept because it presumably advances the vector fuzzer's
  // random state, so dropping it could change what a given seed generates -
  // confirm and remove if that is acceptable.
  vectorFuzzer_->coinToss(0.5);

  std::vector<core::TypedExprPtr> inputExpressions;
  // Up to seven expressions are appended below, so size the buffer from the
  // signature instead of a fixed six.
  inputExpressions.reserve(input.args.size());
  for (int index = 0; index < 5; ++index) {
    inputExpressions.emplace_back(generateArg(input.args[index]));
  }

  // The required result type of the sixth argument is a short decimal type with
  // scale being 6. But when nested expression is generated, that cannot be
  // guaranteed as argument precisions and scales cannot be inferred from the
  // result type through a decimal function signature. Given this limitation,
  // generate constant or column only.
  const auto& decimalArgType = input.args[5];
  inputExpressions.emplace_back(
      rand32(0, 1) == kArgConstant ? generateArgConstant(decimalArgType)
                                   : generateArgColumn(decimalArgType));

  if (input.args.size() == 7) {
    // The 7th argument cannot be randomly generated as it should be a valid
    // timezone string. The list is built once and reused across calls.
    static const std::vector<std::string> kTimezones = {
        "Asia/Kolkata",
        "America/Los_Angeles",
        "Canada/Atlantic",
        "+08:00",
        "-10:00"};
    // Derive the upper index from the list so the two cannot drift apart.
    const auto lastIndex = static_cast<int32_t>(kTimezones.size()) - 1;
    inputExpressions.emplace_back(std::make_shared<core::ConstantTypedExpr>(
        VARCHAR(), variant(kTimezones[rand32(0, lastIndex)])));
  }
  return inputExpressions;
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateSwitchArgs(
const CallableSignature& input) {
VELOX_CHECK_EQ(
Expand All @@ -1050,29 +968,6 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::generateSwitchArgs(
return inputExpressions;
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateUnscaledValueArgs(
const CallableSignature& input) {
VELOX_CHECK_EQ(
input.args.size(),
1,
"Only one input is expected from the template signature.");

// The required result type of input argument is a short decimal type. But
// when nested expression is generated, that cannot be guaranteed as argument
// precisions and scales cannot be inferred from the result type through a
// decimal function signature. Given this limitation, generate constant or
// column only.
std::vector<core::TypedExprPtr> inputExpressions;
core::TypedExprPtr argExpr;
if (rand32(0, 1) == kArgConstant) {
argExpr = generateArgConstant(input.args[0]);
} else {
argExpr = generateArgColumn(input.args[0]);
}
inputExpressions.emplace_back(argExpr);
return inputExpressions;
}

ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpressions(
const RowTypePtr& outType) {
state.reset();
Expand Down Expand Up @@ -1167,6 +1062,10 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(

std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
const CallableSignature& callable) {
if (customArgumentGenerators_.count(callable.name)) {
return customArgumentGenerators_[callable.name]->generate(
this, callable, options_.maxNumVarArgs);
}
auto funcIt = funcArgOverrides_.find(callable.name);
if (funcIt == funcArgOverrides_.end()) {
return generateArgs(callable);
Expand Down
49 changes: 21 additions & 28 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/core/ITypedExpr.h"
#include "velox/core/QueryCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/tests/ArgumentGenerator.h"
#include "velox/expression/tests/ExpressionVerifier.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"
#include "velox/functions/FunctionRegistry.h"
Expand Down Expand Up @@ -107,6 +108,8 @@ class ExpressionFuzzer {
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators,
const std::optional<ExpressionFuzzer::Options>& options = std::nullopt);

template <typename TFunc>
Expand Down Expand Up @@ -195,6 +198,19 @@ class ExpressionFuzzer {

RowTypePtr fuzzRowReturnType(size_t size, char prefix = 'p');

// Generates an argument expression for the argument type 'arg'.
// NOTE(review): presumably chooses the kind of expression at random -
// confirm against the definition.
core::TypedExprPtr generateArg(const TypePtr& arg);

// Generates an argument expression for the argument type 'arg'. Callers pass
// the matching entry of CallableSignature::constantArgs as 'isConstant'.
// NOTE(review): presumably forces a constant expression when 'isConstant' is
// true - confirm against the definition.
core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant);

// Generates the argument expressions for the signature 'input'. Used when no
// function-specific argument generation applies.
std::vector<core::TypedExprPtr> generateArgs(const CallableSignature& input);

// Generates a column argument of type 'arg'.
core::TypedExprPtr generateArgColumn(const TypePtr& arg);

// Generates a constant argument of type 'arg'.
core::TypedExprPtr generateArgConstant(const TypePtr& arg);

// Returns random integer between min and max inclusive.
int32_t rand32(int32_t min, int32_t max);

private:
// Either generates a new expression of the required return type or if
// already generated expressions of the same return type exist then there is
Expand All @@ -218,12 +234,6 @@ class ExpressionFuzzer {

void appendConjunctSignatures();

core::TypedExprPtr generateArgConstant(const TypePtr& arg);

core::TypedExprPtr generateArgColumn(const TypePtr& arg);

core::TypedExprPtr generateArg(const TypePtr& arg);

// Given lambda argument type, generate matching LambdaTypedExpr.
//
// The 'arg' specifies inputs types and result type for the lambda. This
Expand All @@ -234,25 +244,11 @@ class ExpressionFuzzer {
// all input. The constant value is generated using 'generateArgConstant'.
core::TypedExprPtr generateArgFunction(const TypePtr& arg);

std::vector<core::TypedExprPtr> generateArgs(const CallableSignature& input);

std::vector<core::TypedExprPtr> generateArgs(
const std::vector<TypePtr>& argTypes,
const std::vector<bool>& constantArgs,
uint32_t numVarArgs = 0);

core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant);

/// Specialization for the "greatest" and "least" functions: decimal varargs
/// need to be constant or column.
std::vector<core::TypedExprPtr> generateExtremeFunctionArgs(
const CallableSignature& input);

/// Specialization for the "make_timestamp" function: 1) decimal argument
/// needs to be constant or column. 2) timezone argument needs to be valid.
std::vector<core::TypedExprPtr> generateMakeTimestampArgs(
const CallableSignature& input);

/// Specialization for the "switch" function. Takes in a signature that is
/// of the form Switch (condition, then): boolean, T -> T where the type
/// variable is bounded to a randomly selected type. It randomly decides the
Expand All @@ -262,11 +258,6 @@ class ExpressionFuzzer {
std::vector<core::TypedExprPtr> generateSwitchArgs(
const CallableSignature& input);

/// Specialization for the "unscaled_value" function: decimal argument needs
/// to be constant or column.
std::vector<core::TypedExprPtr> generateUnscaledValueArgs(
const CallableSignature& input);

// Return a vector of expressions for each argument of callable in order.
std::vector<core::TypedExprPtr> getArgsForCallable(
const CallableSignature& callable);
Expand Down Expand Up @@ -352,9 +343,6 @@ class ExpressionFuzzer {
state.expressionStats_[funcName]++;
}

// Returns random integer between min and max inclusive.
int32_t rand32(int32_t min, int32_t max);

static const inline std::string kTypeParameterName = "T";

const Options options_;
Expand Down Expand Up @@ -441,6 +429,11 @@ class ExpressionFuzzer {
int32_t remainingLevelOfNesting_;

} state;

// Maps from function name to the generator that generates custom arguments.
std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>
customArgumentGenerators_;

friend class ExpressionFuzzerUnitTest;
};

Expand Down
16 changes: 15 additions & 1 deletion velox/expression/tests/ExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <unordered_set>

#include "velox/expression/tests/FuzzerRunner.h"
#include "velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

DEFINE_int64(
Expand Down Expand Up @@ -65,6 +66,19 @@ int main(int argc, char** argv) {
"regexp_extract_all",
"regexp_like",
};

const std::unordered_map<
std::string,
std::shared_ptr<facebook::velox::test::ArgumentGenerator>>
customArgumentGenerators = {
{"greatest",
std::make_shared<
facebook::velox::functions::test::ExtremeArgumentGenerator>()},
{"least",
std::make_shared<
facebook::velox::functions::test::ExtremeArgumentGenerator>()}};

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;
return FuzzerRunner::run(initialSeed, skipFunctions, {{}});
return FuzzerRunner::run(
initialSeed, skipFunctions, {{}}, customArgumentGenerators);
}
4 changes: 4 additions & 0 deletions velox/expression/tests/ExpressionFuzzerUnitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ TEST_F(ExpressionFuzzerUnitTest, restrictedLevelOfNesting) {
velox::getFunctionSignatures(),
0,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(maxLevelOfNesting),
};

Expand Down Expand Up @@ -116,6 +117,7 @@ TEST_F(ExpressionFuzzerUnitTest, reproduceExpressionWithSeed) {
velox::getFunctionSignatures(),
1234567,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(5)};
for (auto i = 0; i < 10; ++i) {
firstGeneration.push_back(
Expand All @@ -142,6 +144,7 @@ TEST_F(ExpressionFuzzerUnitTest, exprBank) {
velox::getFunctionSignatures(),
0,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(maxLevelOfNesting)};
ExpressionFuzzer::ExprBank exprBank(seed, maxLevelOfNesting);
for (int i = 0; i < 5000; ++i) {
Expand Down Expand Up @@ -170,6 +173,7 @@ TEST_F(ExpressionFuzzerUnitTest, exprBank) {
velox::getFunctionSignatures(),
0,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(maxLevelOfNesting)};
ExpressionFuzzer::ExprBank exprBank(seed, maxLevelOfNesting);
for (int i = 0; i < 1000; ++i) {
Expand Down
5 changes: 4 additions & 1 deletion velox/expression/tests/ExpressionFuzzerVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ RowVectorPtr wrapChildren(
ExpressionFuzzerVerifier::ExpressionFuzzerVerifier(
const FunctionSignatureMap& signatureMap,
size_t initialSeed,
const ExpressionFuzzerVerifier::Options& options)
const ExpressionFuzzerVerifier::Options& options,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators)
: options_(options),
queryCtx_(std::make_shared<core::QueryCtx>(
nullptr,
Expand All @@ -98,6 +100,7 @@ ExpressionFuzzerVerifier::ExpressionFuzzerVerifier(
signatureMap,
initialSeed,
vectorFuzzer_,
customArgumentGenerators,
options.expressionFuzzerOptions) {
seed(initialSeed);

Expand Down
Loading

0 comments on commit 8d0651e

Please sign in to comment.