Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 29, 2024
1 parent d50c072 commit 4c9dd49
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 27 deletions.
35 changes: 35 additions & 0 deletions velox/expression/tests/ArgumentGenerator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/core/ITypedExpr.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"

namespace facebook::velox::test {

class ExpressionFuzzer;

class ArgumentGenerator {
public:
virtual ~ArgumentGenerator() = default;

/// Generates function arguments.
virtual std::vector<core::TypedExprPtr> generate(
ExpressionFuzzer* expressionFuzzer,
const CallableSignature& input) = 0;
};

} // namespace facebook::velox::test
5 changes: 3 additions & 2 deletions velox/expression/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,9 @@ target_link_libraries(

add_executable(velox_expression_fuzzer_test ExpressionFuzzerTest.cpp)

target_link_libraries(velox_expression_fuzzer_test velox_expression_fuzzer
velox_functions_prestosql gtest gtest_main)
target_link_libraries(velox_expression_fuzzer_test velox_expression_fuzzer_utility
velox_expression_fuzzer velox_functions_prestosql
gtest gtest_main)

add_executable(spark_expression_fuzzer_test SparkExpressionFuzzerTest.cpp)

Expand Down
11 changes: 7 additions & 4 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,13 @@ ExpressionFuzzer::ExpressionFuzzer(
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators,
const std::optional<ExpressionFuzzer::Options>& options)
: options_(options.value_or(Options())),
vectorFuzzer_(vectorFuzzer),
state{rng_, std::max(1, options_.maxLevelOfNesting)} {
state{rng_, std::max(1, options_.maxLevelOfNesting)},
customArgumentGenerators_(customArgumentGenerators) {
VELOX_CHECK(vectorFuzzer, "Vector fuzzer must be provided");
seed(initialSeed);

Expand Down Expand Up @@ -711,9 +714,6 @@ ExpressionFuzzer::ExpressionFuzzer(
// Register function override (for cases where we want to restrict the types
// or parameters we pass to functions).
registerFuncOverride(&ExpressionFuzzer::generateSwitchArgs, "switch");
registerFuncOverride(
&ExpressionFuzzer::generateExtremeFunctionArgs, "greatest");
registerFuncOverride(&ExpressionFuzzer::generateExtremeFunctionArgs, "least");
registerFuncOverride(
&ExpressionFuzzer::generateMakeTimestampArgs, "make_timestamp");
registerFuncOverride(
Expand Down Expand Up @@ -1167,6 +1167,9 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(

std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
const CallableSignature& callable) {
if (customArgumentGenerators_.count(callable.name)) {
return customArgumentGenerators_[callable.name]->generate(this, callable);
}
auto funcIt = funcArgOverrides_.find(callable.name);
if (funcIt == funcArgOverrides_.end()) {
return generateArgs(callable);
Expand Down
29 changes: 18 additions & 11 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/core/ITypedExpr.h"
#include "velox/core/QueryCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/tests/ArgumentGenerator.h"
#include "velox/expression/tests/ExpressionVerifier.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"
#include "velox/functions/FunctionRegistry.h"
Expand Down Expand Up @@ -107,6 +108,8 @@ class ExpressionFuzzer {
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators,
const std::optional<ExpressionFuzzer::Options>& options = std::nullopt);

template <typename TFunc>
Expand Down Expand Up @@ -195,6 +198,17 @@ class ExpressionFuzzer {

RowTypePtr fuzzRowReturnType(size_t size, char prefix = 'p');

core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant);

core::TypedExprPtr generateArgConstant(const TypePtr& arg);

core::TypedExprPtr generateArgColumn(const TypePtr& arg);

std::vector<core::TypedExprPtr> generateArgs(const CallableSignature& input);

// Returns random integer between min and max inclusive.
int32_t rand32(int32_t min, int32_t max);

private:
// Either generates a new expression of the required return type or if
// already generated expressions of the same return type exist then there is
Expand All @@ -218,10 +232,6 @@ class ExpressionFuzzer {

void appendConjunctSignatures();

core::TypedExprPtr generateArgConstant(const TypePtr& arg);

core::TypedExprPtr generateArgColumn(const TypePtr& arg);

core::TypedExprPtr generateArg(const TypePtr& arg);

// Given lambda argument type, generate matching LambdaTypedExpr.
Expand All @@ -234,15 +244,11 @@ class ExpressionFuzzer {
// all input. The constant value is generated using 'generateArgConstant'.
core::TypedExprPtr generateArgFunction(const TypePtr& arg);

std::vector<core::TypedExprPtr> generateArgs(const CallableSignature& input);

std::vector<core::TypedExprPtr> generateArgs(
const std::vector<TypePtr>& argTypes,
const std::vector<bool>& constantArgs,
uint32_t numVarArgs = 0);

core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant);

/// Specialization for the "greatest" and "least" functions: decimal varargs
/// need to be constant or column.
std::vector<core::TypedExprPtr> generateExtremeFunctionArgs(
Expand Down Expand Up @@ -352,9 +358,6 @@ class ExpressionFuzzer {
state.expressionStats_[funcName]++;
}

// Returns random integer between min and max inclusive.
int32_t rand32(int32_t min, int32_t max);

static const inline std::string kTypeParameterName = "T";

const Options options_;
Expand Down Expand Up @@ -441,6 +444,10 @@ class ExpressionFuzzer {
int32_t remainingLevelOfNesting_;

} state;

std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>
customArgumentGenerators_;

friend class ExpressionFuzzerUnitTest;
};

Expand Down
15 changes: 14 additions & 1 deletion velox/expression/tests/ExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <unordered_set>

#include "velox/expression/tests/FuzzerRunner.h"
#include "velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

DEFINE_int64(
Expand Down Expand Up @@ -65,6 +66,18 @@ int main(int argc, char** argv) {
"regexp_extract_all",
"regexp_like",
};

const std::unordered_map<
std::string,
std::shared_ptr<facebook::velox::test::ArgumentGenerator>>
customArgumentGenerators = {
{"greatest",
std::make_shared<facebook::velox::test::ExtremeArgumentGenerator>()},
{"least",
std::make_shared<
facebook::velox::test::ExtremeArgumentGenerator>()}};

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;
return FuzzerRunner::run(initialSeed, skipFunctions, {{}});
return FuzzerRunner::run(
initialSeed, skipFunctions, {{}}, customArgumentGenerators);
}
4 changes: 4 additions & 0 deletions velox/expression/tests/ExpressionFuzzerUnitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ TEST_F(ExpressionFuzzerUnitTest, restrictedLevelOfNesting) {
velox::getFunctionSignatures(),
0,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(maxLevelOfNesting),
};

Expand Down Expand Up @@ -116,6 +117,7 @@ TEST_F(ExpressionFuzzerUnitTest, reproduceExpressionWithSeed) {
velox::getFunctionSignatures(),
1234567,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(5)};
for (auto i = 0; i < 10; ++i) {
firstGeneration.push_back(
Expand All @@ -142,6 +144,7 @@ TEST_F(ExpressionFuzzerUnitTest, exprBank) {
velox::getFunctionSignatures(),
0,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(maxLevelOfNesting)};
ExpressionFuzzer::ExprBank exprBank(seed, maxLevelOfNesting);
for (int i = 0; i < 5000; ++i) {
Expand Down Expand Up @@ -170,6 +173,7 @@ TEST_F(ExpressionFuzzerUnitTest, exprBank) {
velox::getFunctionSignatures(),
0,
vectorfuzzer,
{},
makeOptionsWithMaxLevelNesting(maxLevelOfNesting)};
ExpressionFuzzer::ExprBank exprBank(seed, maxLevelOfNesting);
for (int i = 0; i < 1000; ++i) {
Expand Down
5 changes: 4 additions & 1 deletion velox/expression/tests/ExpressionFuzzerVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ RowVectorPtr wrapChildren(
ExpressionFuzzerVerifier::ExpressionFuzzerVerifier(
const FunctionSignatureMap& signatureMap,
size_t initialSeed,
const ExpressionFuzzerVerifier::Options& options)
const ExpressionFuzzerVerifier::Options& options,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators)
: options_(options),
queryCtx_(std::make_shared<core::QueryCtx>(
nullptr,
Expand All @@ -98,6 +100,7 @@ ExpressionFuzzerVerifier::ExpressionFuzzerVerifier(
signatureMap,
initialSeed,
vectorFuzzer_,
customArgumentGenerators,
options.expressionFuzzerOptions) {
seed(initialSeed);

Expand Down
4 changes: 3 additions & 1 deletion velox/expression/tests/ExpressionFuzzerVerifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ class ExpressionFuzzerVerifier {
ExpressionFuzzerVerifier(
const FunctionSignatureMap& signatureMap,
size_t initialSeed,
const Options& options);
const Options& options,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators);

// This function starts the test that is performed by the
// ExpressionFuzzerVerifier which is generating random expressions and
Expand Down
13 changes: 9 additions & 4 deletions velox/expression/tests/FuzzerRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,22 +210,27 @@ ExpressionFuzzerVerifier::Options getExpressionFuzzerVerifierOptions(
int FuzzerRunner::run(
size_t seed,
const std::unordered_set<std::string>& skipFunctions,
const std::unordered_map<std::string, std::string>& queryConfigs) {
runFromGtest(seed, skipFunctions, queryConfigs);
const std::unordered_map<std::string, std::string>& queryConfigs,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators) {
runFromGtest(seed, skipFunctions, queryConfigs, customArgumentGenerators);
return RUN_ALL_TESTS();
}

// static
void FuzzerRunner::runFromGtest(
size_t seed,
const std::unordered_set<std::string>& skipFunctions,
const std::unordered_map<std::string, std::string>& queryConfigs) {
const std::unordered_map<std::string, std::string>& queryConfigs,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators) {
memory::MemoryManager::testingSetInstance({});
auto signatures = facebook::velox::getFunctionSignatures();
ExpressionFuzzerVerifier(
signatures,
seed,
getExpressionFuzzerVerifierOptions(skipFunctions, queryConfigs))
getExpressionFuzzerVerifierOptions(skipFunctions, queryConfigs),
customArgumentGenerators)
.go();
}
} // namespace facebook::velox::test
9 changes: 7 additions & 2 deletions velox/expression/tests/FuzzerRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <unordered_set>
#include <vector>

#include "velox/expression/tests/ArgumentGenerator.h"
#include "velox/expression/tests/ExpressionFuzzerVerifier.h"
#include "velox/functions/FunctionRegistry.h"

Expand All @@ -33,12 +34,16 @@ class FuzzerRunner {
static int run(
size_t seed,
const std::unordered_set<std::string>& skipFunctions,
const std::unordered_map<std::string, std::string>& queryConfigs);
const std::unordered_map<std::string, std::string>& queryConfigs,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators);

static void runFromGtest(
size_t seed,
const std::unordered_set<std::string>& skipFunctions,
const std::unordered_map<std::string, std::string>& queryConfigs);
const std::unordered_map<std::string, std::string>& queryConfigs,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators);
};

} // namespace facebook::velox::test
8 changes: 7 additions & 1 deletion velox/expression/tests/SparkExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,11 @@ int main(int argc, char** argv) {
{facebook::velox::core::QueryConfig::kSessionTimezone,
"America/Los_Angeles"}};

return FuzzerRunner::run(FLAGS_seed, skipFunctions, queryConfigs);
const std::unordered_map<
std::string,
std::shared_ptr<facebook::velox::test::ArgumentGenerator>>
customArgumentGenerators = {};

return FuzzerRunner::run(
FLAGS_seed, skipFunctions, queryConfigs, customArgumentGenerators);
}
7 changes: 7 additions & 0 deletions velox/functions/prestosql/fuzzer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ target_link_libraries(
velox_functions_prestosql
gtest
gtest_main)

add_executable(velox_expression_fuzzer_utility ExtremeArgumentGenerator.cpp)
target_link_libraries(
velox_expression_fuzzer_utility
velox_expression_fuzzer
gtest
gtest_main)
60 changes: 60 additions & 0 deletions velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h"
#include "velox/expression/tests/ExpressionFuzzer.h"

namespace facebook::velox::test {

std::vector<core::TypedExprPtr> ExtremeArgumentGenerator::generate(
ExpressionFuzzer* expressionFuzzer,
const CallableSignature& input) {
const auto argTypes = input.args;
VELOX_CHECK_GE(
argTypes.size(),
1,
"At least one input is expected from the template signature.");
if (!argTypes[0]->isDecimal()) {
return expressionFuzzer->generateArgs(input);
}

auto numVarArgs = !input.variableArity
? 0
: expressionFuzzer->rand32(0, options_.maxNumVarArgs);
std::vector<core::TypedExprPtr> inputExpressions;
inputExpressions.reserve(argTypes.size() + numVarArgs);
inputExpressions.emplace_back(
expressionFuzzer->generateArg(argTypes.at(0), input.constantArgs.at(0)));

// Append varargs to the argument list.
for (int i = 0; i < numVarArgs; i++) {
core::TypedExprPtr argExpr;
// The varargs need to be generated following the result type of the first
// argument. But when nested expression is generated, that cannot be
// guaranteed as argument precisions and scales cannot be inferred from the
// result type through a decimal function signature. Given this limitation,
// generate constant or column only.
const auto argType = inputExpressions[0]->type();
if (expressionFuzzer->rand32(0, 1) == kArgConstant) {
argExpr = expressionFuzzer->generateArgConstant(argType);
} else {
argExpr = expressionFuzzer->generateArgColumn(argType);
}
inputExpressions.emplace_back(argExpr);
}
return inputExpressions;
}

} // namespace facebook::velox::test
Loading

0 comments on commit 4c9dd49

Please sign in to comment.