From 551463d17b3e8ebb68078e95d1f933cced2e1be5 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Fri, 29 Mar 2024 14:55:28 +0800 Subject: [PATCH] Refactor --- velox/expression/tests/ArgumentGenerator.h | 35 +++++++++++ velox/expression/tests/ExpressionFuzzer.cpp | 11 ++-- velox/expression/tests/ExpressionFuzzer.h | 29 +++++---- .../expression/tests/ExpressionFuzzerTest.cpp | 13 +++- .../tests/ExpressionFuzzerUnitTest.cpp | 4 ++ .../tests/ExpressionFuzzerVerifier.cpp | 5 +- .../tests/ExpressionFuzzerVerifier.h | 4 +- velox/expression/tests/FuzzerRunner.cpp | 13 ++-- velox/expression/tests/FuzzerRunner.h | 9 ++- .../tests/SparkExpressionFuzzerTest.cpp | 8 ++- .../fuzzer/ExtremeArgumentGenerator.cpp | 62 +++++++++++++++++++ .../fuzzer/ExtremeArgumentGenerator.h | 32 ++++++++++ 12 files changed, 200 insertions(+), 25 deletions(-) create mode 100644 velox/expression/tests/ArgumentGenerator.h create mode 100644 velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.cpp create mode 100644 velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h diff --git a/velox/expression/tests/ArgumentGenerator.h b/velox/expression/tests/ArgumentGenerator.h new file mode 100644 index 0000000000000..2052563ef7ab9 --- /dev/null +++ b/velox/expression/tests/ArgumentGenerator.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/core/ITypedExpr.h" +#include "velox/expression/tests/utils/FuzzerToolkit.h" + +namespace facebook::velox::test { + +class ExpressionFuzzer; + +class ArgumentGenerator { + public: + virtual ~ArgumentGenerator() = default; + + /// Generates function arguments. + virtual std::vector generate( + ExpressionFuzzer* expressionFuzzer, + const CallableSignature& input) = 0; +}; + +} // namespace facebook::velox::test diff --git a/velox/expression/tests/ExpressionFuzzer.cpp b/velox/expression/tests/ExpressionFuzzer.cpp index 98f4c703e2911..32a8dad1724cd 100644 --- a/velox/expression/tests/ExpressionFuzzer.cpp +++ b/velox/expression/tests/ExpressionFuzzer.cpp @@ -531,10 +531,13 @@ ExpressionFuzzer::ExpressionFuzzer( FunctionSignatureMap signatureMap, size_t initialSeed, const std::shared_ptr& vectorFuzzer, + const std::unordered_map>& + customArgumentGenerators, const std::optional& options) : options_(options.value_or(Options())), vectorFuzzer_(vectorFuzzer), - state{rng_, std::max(1, options_.maxLevelOfNesting)} { + state{rng_, std::max(1, options_.maxLevelOfNesting)}, + customArgumentGenerators_(customArgumentGenerators) { VELOX_CHECK(vectorFuzzer, "Vector fuzzer must be provided"); seed(initialSeed); @@ -711,9 +714,6 @@ ExpressionFuzzer::ExpressionFuzzer( // Register function override (for cases where we want to restrict the types // or parameters we pass to functions). registerFuncOverride(&ExpressionFuzzer::generateSwitchArgs, "switch"); - registerFuncOverride( - &ExpressionFuzzer::generateExtremeFunctionArgs, "greatest"); - registerFuncOverride(&ExpressionFuzzer::generateExtremeFunctionArgs, "least"); registerFuncOverride( &ExpressionFuzzer::generateMakeTimestampArgs, "make_timestamp"); registerFuncOverride( @@ -1167,6 +1167,9 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression( std::vector ExpressionFuzzer::getArgsForCallable( const CallableSignature& callable) { + if (customArgumentGenerators_.count(callable.name)) { + return customArgumentGenerators_[callable.name]->generate(callable); + } auto funcIt = funcArgOverrides_.find(callable.name); if (funcIt == funcArgOverrides_.end()) { return generateArgs(callable); diff --git a/velox/expression/tests/ExpressionFuzzer.h b/velox/expression/tests/ExpressionFuzzer.h index 84eb919c3a279..bc293d82195bb 100644 --- a/velox/expression/tests/ExpressionFuzzer.h +++ b/velox/expression/tests/ExpressionFuzzer.h @@ -19,6 +19,7 @@ #include "velox/core/ITypedExpr.h" #include "velox/core/QueryCtx.h" #include "velox/expression/Expr.h" +#include "velox/expression/tests/ArgumentGenerator.h" #include "velox/expression/tests/ExpressionVerifier.h" #include "velox/expression/tests/utils/FuzzerToolkit.h" #include "velox/functions/FunctionRegistry.h" @@ -107,6 +108,8 @@ class ExpressionFuzzer { FunctionSignatureMap signatureMap, size_t initialSeed, const std::shared_ptr& vectorFuzzer, + const std::unordered_map>& + customArgumentGenerators, const std::optional& options = std::nullopt); template @@ -195,6 +198,17 @@ class ExpressionFuzzer { RowTypePtr fuzzRowReturnType(size_t size, char prefix = 'p'); + core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant); + + core::TypedExprPtr generateArgConstant(const TypePtr& arg); + + core::TypedExprPtr generateArgColumn(const TypePtr& arg); + + std::vector generateArgs(const CallableSignature& input); + + // Returns random integer between min and max inclusive. + int32_t rand32(int32_t min, int32_t max); + private: // Either generates a new expression of the required return type or if // already generated expressions of the same return type exist then there is @@ -218,10 +232,6 @@ class ExpressionFuzzer { void appendConjunctSignatures(); - core::TypedExprPtr generateArgConstant(const TypePtr& arg); - - core::TypedExprPtr generateArgColumn(const TypePtr& arg); - core::TypedExprPtr generateArg(const TypePtr& arg); // Given lambda argument type, generate matching LambdaTypedExpr. @@ -234,15 +244,11 @@ class ExpressionFuzzer { // all input. The constant value is generated using 'generateArgConstant'. core::TypedExprPtr generateArgFunction(const TypePtr& arg); - std::vector generateArgs(const CallableSignature& input); - std::vector generateArgs( const std::vector& argTypes, const std::vector& constantArgs, uint32_t numVarArgs = 0); - core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant); - /// Specialization for the "greatest" and "least" functions: decimal varargs /// need to be constant or column. std::vector generateExtremeFunctionArgs( @@ -352,9 +358,6 @@ class ExpressionFuzzer { state.expressionStats_[funcName]++; } - // Returns random integer between min and max inclusive. - int32_t rand32(int32_t min, int32_t max); - static const inline std::string kTypeParameterName = "T"; const Options options_; @@ -441,6 +444,10 @@ class ExpressionFuzzer { int32_t remainingLevelOfNesting_; } state; + + std::unordered_map>& + customArgumentGenerators_; + friend class ExpressionFuzzerUnitTest; }; diff --git a/velox/expression/tests/ExpressionFuzzerTest.cpp b/velox/expression/tests/ExpressionFuzzerTest.cpp index 1c6a675bfdf73..9762b87e3cde2 100644 --- a/velox/expression/tests/ExpressionFuzzerTest.cpp +++ b/velox/expression/tests/ExpressionFuzzerTest.cpp @@ -19,6 +19,7 @@ #include #include "velox/expression/tests/FuzzerRunner.h" +#include "velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" DEFINE_int64( @@ -65,6 +66,16 @@ int main(int argc, char** argv) { "regexp_extract_all", "regexp_like", }; + + const std::unordered_map> + customArgumentGenerators = { + {"greatest", + std::make_shared()}, + {"least", + std::make_shared< + facebook::velox::test::ExtremeArgumentGenerator>()}}; + size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; - return FuzzerRunner::run(initialSeed, skipFunctions, {{}}); + return FuzzerRunner::run( + initialSeed, skipFunctions, {{}}, customArgumentGenerators); } diff --git a/velox/expression/tests/ExpressionFuzzerUnitTest.cpp b/velox/expression/tests/ExpressionFuzzerUnitTest.cpp index 9fd3a09b3a047..043e45a35fd03 100644 --- a/velox/expression/tests/ExpressionFuzzerUnitTest.cpp +++ b/velox/expression/tests/ExpressionFuzzerUnitTest.cpp @@ -77,6 +77,7 @@ TEST_F(ExpressionFuzzerUnitTest, restrictedLevelOfNesting) { velox::getFunctionSignatures(), 0, vectorfuzzer, + {}, makeOptionsWithMaxLevelNesting(maxLevelOfNesting), }; @@ -116,6 +117,7 @@ TEST_F(ExpressionFuzzerUnitTest, reproduceExpressionWithSeed) { velox::getFunctionSignatures(), 1234567, vectorfuzzer, + {}, makeOptionsWithMaxLevelNesting(5)}; for (auto i = 0; i < 10; ++i) { firstGeneration.push_back( @@ -142,6 +144,7 @@ TEST_F(ExpressionFuzzerUnitTest, exprBank) { velox::getFunctionSignatures(), 0, vectorfuzzer, + {}, makeOptionsWithMaxLevelNesting(maxLevelOfNesting)}; ExpressionFuzzer::ExprBank exprBank(seed, maxLevelOfNesting); for (int i = 0; i < 5000; ++i) { @@ -170,6 +173,7 @@ TEST_F(ExpressionFuzzerUnitTest, exprBank) { velox::getFunctionSignatures(), 0, vectorfuzzer, + {}, makeOptionsWithMaxLevelNesting(maxLevelOfNesting)}; ExpressionFuzzer::ExprBank exprBank(seed, maxLevelOfNesting); for (int i = 0; i < 1000; ++i) { diff --git a/velox/expression/tests/ExpressionFuzzerVerifier.cpp b/velox/expression/tests/ExpressionFuzzerVerifier.cpp index 8e36fd739f21c..94ce099a72f41 100644 --- a/velox/expression/tests/ExpressionFuzzerVerifier.cpp +++ b/velox/expression/tests/ExpressionFuzzerVerifier.cpp @@ -80,7 +80,9 @@ RowVectorPtr wrapChildren( ExpressionFuzzerVerifier::ExpressionFuzzerVerifier( const FunctionSignatureMap& signatureMap, size_t initialSeed, - const ExpressionFuzzerVerifier::Options& options) + const ExpressionFuzzerVerifier::Options& options, + const std::unordered_map>& + customArgumentGenerators) : options_(options), queryCtx_(std::make_shared( nullptr, @@ -98,6 +100,7 @@ ExpressionFuzzerVerifier::ExpressionFuzzerVerifier( signatureMap, initialSeed, vectorFuzzer_, + customArgumentGenerators, options.expressionFuzzerOptions) { seed(initialSeed); diff --git a/velox/expression/tests/ExpressionFuzzerVerifier.h b/velox/expression/tests/ExpressionFuzzerVerifier.h index 2f85b5d52bc71..a9e9b6392fa16 100644 --- a/velox/expression/tests/ExpressionFuzzerVerifier.h +++ b/velox/expression/tests/ExpressionFuzzerVerifier.h @@ -51,7 +51,9 @@ class ExpressionFuzzerVerifier { ExpressionFuzzerVerifier( const FunctionSignatureMap& signatureMap, size_t initialSeed, - const Options& options); + const Options& options, + const std::unordered_map>& + customArgumentGenerators); // This function starts the test that is performed by the // ExpressionFuzzerVerifier which is generating random expressions and diff --git a/velox/expression/tests/FuzzerRunner.cpp b/velox/expression/tests/FuzzerRunner.cpp index e947f147c6853..c05333f8f7f93 100644 --- a/velox/expression/tests/FuzzerRunner.cpp +++ b/velox/expression/tests/FuzzerRunner.cpp @@ -210,8 +210,10 @@ ExpressionFuzzerVerifier::Options getExpressionFuzzerVerifierOptions( int FuzzerRunner::run( size_t seed, const std::unordered_set& skipFunctions, - const std::unordered_map& queryConfigs) { - runFromGtest(seed, skipFunctions, queryConfigs); + const std::unordered_map& queryConfigs, + const std::unordered_map>& + customArgumentGenerators) { + runFromGtest(seed, skipFunctions, queryConfigs, customArgumentGenerators); return RUN_ALL_TESTS(); } @@ -219,13 +221,16 @@ int FuzzerRunner::run( void FuzzerRunner::runFromGtest( size_t seed, const std::unordered_set& skipFunctions, - const std::unordered_map& queryConfigs) { + const std::unordered_map& queryConfigs, + const std::unordered_map>& + customArgumentGenerators) { memory::MemoryManager::testingSetInstance({}); auto signatures = facebook::velox::getFunctionSignatures(); ExpressionFuzzerVerifier( signatures, seed, - getExpressionFuzzerVerifierOptions(skipFunctions, queryConfigs)) + getExpressionFuzzerVerifierOptions(skipFunctions, queryConfigs), + customArgumentGenerators) .go(); } } // namespace facebook::velox::test diff --git a/velox/expression/tests/FuzzerRunner.h b/velox/expression/tests/FuzzerRunner.h index cbf3d5ac290a9..efb09b262b461 100644 --- a/velox/expression/tests/FuzzerRunner.h +++ b/velox/expression/tests/FuzzerRunner.h @@ -22,6 +22,7 @@ #include #include +#include "velox/expression/tests/ArgumentGenerator.h" #include "velox/expression/tests/ExpressionFuzzerVerifier.h" #include "velox/functions/FunctionRegistry.h" @@ -33,12 +34,16 @@ class FuzzerRunner { static int run( size_t seed, const std::unordered_set& skipFunctions, - const std::unordered_map& queryConfigs); + const std::unordered_map& queryConfigs, + const std::unordered_map>& + customArgumentGenerators); static void runFromGtest( size_t seed, const std::unordered_set& skipFunctions, - const std::unordered_map& queryConfigs); + const std::unordered_map& queryConfigs, + const std::unordered_map>& + customArgumentGenerators); }; } // namespace facebook::velox::test diff --git a/velox/expression/tests/SparkExpressionFuzzerTest.cpp b/velox/expression/tests/SparkExpressionFuzzerTest.cpp index 1232c49ed12ab..135c0db358502 100644 --- a/velox/expression/tests/SparkExpressionFuzzerTest.cpp +++ b/velox/expression/tests/SparkExpressionFuzzerTest.cpp @@ -62,5 +62,11 @@ int main(int argc, char** argv) { {facebook::velox::core::QueryConfig::kSessionTimezone, "America/Los_Angeles"}}; - return FuzzerRunner::run(FLAGS_seed, skipFunctions, queryConfigs); + const std::unordered_map< + std::string, + std::shared_ptr> + customArgumentGenerators = {}; + + return FuzzerRunner::run( + FLAGS_seed, skipFunctions, queryConfigs, customArgumentGenerators); } diff --git a/velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.cpp b/velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.cpp new file mode 100644 index 0000000000000..998e3dd3c4af1 --- /dev/null +++ b/velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/tests/ExtremeArgumentGenerator.h" +#include "velox/expression/tests/ExpressionFuzzer.h" + +namespace facebook::velox::test { + +std::vector ExtremeArgumentGenerator::generate( + ExpressionFuzzer* expressionFuzzer, + const CallableSignature& input) { + const auto argTypes = input.args; + VELOX_CHECK_GE( + argTypes.size(), + 1, + "At least one input is expected from the template signature."); + if (!argTypes[0]->isDecimal()) { + return expressionFuzzer->generateArgs(input); + } + + auto numVarArgs = !input.variableArity + ? 0 + : expressionFuzzer->rand32(0, options_.maxNumVarArgs); + std::vector inputExpressions; + inputExpressions.reserve(argTypes.size() + numVarArgs); + inputExpressions.emplace_back( + expressionFuzzer->generateArg(argTypes.at(0), input.constantArgs.at(0))); + + // Append varargs to the argument list. + for (int i = 0; i < numVarArgs; i++) { + core::TypedExprPtr argExpr; + // The varargs need to be generated following the result type of the first + // argument. But when nested expression is generated, that cannot be + // guaranteed as argument precisions and scales cannot be inferred from the + // result type through a decimal function signature. Given this limitation, + // generate constant or column only. + const auto argType = inputExpressions[0]->type(); + if (expressionFuzzer->rand32(0, 1) == kArgConstant) { + argExpr = expressionFuzzer->generateArgConstant(argType); + } else { + argExpr = expressionFuzzer->generateArgColumn(argType); + } + inputExpressions.emplace_back(argExpr); + } + return inputExpressions; +} + +} // namespace facebook::velox::test diff --git a/velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h b/velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h new file mode 100644 index 0000000000000..b30b22703db5d --- /dev/null +++ b/velox/functions/prestosql/fuzzer/ExtremeArgumentGenerator.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/tests/ArgumentGenerator.h" + +namespace facebook::velox::test { + +class ExpressionFuzzer; + +class ExtremeArgumentGenerator : public ArgumentGenerator { + public: + /// Generates function arguments. + std::vector generate( + ExpressionFuzzer* expressionFuzzer, + const CallableSignature& input) override; +}; + +} // namespace facebook::velox::test