diff --git a/velox/expression/tests/ExpressionFuzzer.cpp b/velox/expression/tests/ExpressionFuzzer.cpp index 941c9655a916..05b64f49e61e 100644 --- a/velox/expression/tests/ExpressionFuzzer.cpp +++ b/velox/expression/tests/ExpressionFuzzer.cpp @@ -398,33 +398,40 @@ ExpressionFuzzer::ExpressionFuzzer( (double)unsupportedFunctionSignatures / totalFunctionSignatures * 100); // We sort the available signatures before inserting them into - // signaturesMap_. The purpose of this step is to ensure the vector of - // function signatures associated with each key in signaturesMap_ has a - // deterministic order, so that we can deterministically generate - // expressions across platforms. We just do this once and the vector is - // small, so it doesn't need to be very efficient. + // typeToExpressionList_ and expressionToSignature_. The purpose of this step + // is to ensure the vector of function signatures associated with each key in + // signaturesMap_ has a deterministic order, so that we can deterministically + // generate expressions across platforms. We just do this once and the vector + // is small, so it doesn't need to be very efficient. sortCallableSignatures(signatures_); - // Generates signaturesMap, which maps a given type to the function - // signature that returns it. for (const auto& it : signatures_) { - signaturesMap_[it.returnType->kind()].push_back(&it); + auto returnType = typeToBaseName(it.returnType); + if (typeToExpressionList_[returnType].empty() || + typeToExpressionList_[returnType].back() != it.name) { + // Ensure only one entry for a function name is added. This + // gives all others a fair chance to be selected. Since signatures + // are sorted on the function name this check will always work. + typeToExpressionList_[returnType].push_back(it.name); + } + expressionToSignature_[it.name][returnType].push_back(&it); } // Similarly, sort all template signatures. sortSignatureTemplates(signatureTemplates_); - // Insert signature templates into signatureTemplateMap_ grouped by their - // return type base name. If the return type is a type variable, insert the - // signature template into the list of key kTypeParameterName. for (const auto& it : signatureTemplates_) { auto& returnType = it.signature->returnType().baseName(); - if (it.typeVariables.find(returnType) == it.typeVariables.end()) { - signatureTemplateMap_[it.signature->returnType().baseName()].push_back( - &it); - } else { - signatureTemplateMap_[kTypeParameterName].push_back(&it); + auto* returnTypeKey = &returnType; + if (it.typeVariables.find(returnType) != it.typeVariables.end()) { + // Return type is a template variable. + returnTypeKey = &kTypeParameterName; + } + if (typeToExpressionList_[*returnTypeKey].empty() || + typeToExpressionList_[*returnTypeKey].back() != it.name) { + typeToExpressionList_[*returnTypeKey].push_back(it.name); } + expressionToTemplatedSignature_[it.name][*returnTypeKey].push_back(&it); } // Register function override (for cases where we want to restrict the types @@ -613,41 +620,42 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression( bool reuseExpression = FLAGS_velox_fuzzer_enable_expression_reuse && !listOfCandidateExprs.empty() && vectorFuzzer_.coinToss(0.3); if (!reuseExpression) { + auto baseType = typeToBaseName(returnType); + VELOX_CHECK_NE( + baseType, "T", "returnType should have all concrete types defined"); + // Randomly pick among all functions that support this return type. Also, + // consider all functions that have return type "T" as they can + // support any concrete return type. + auto& baseList = typeToExpressionList_[baseType]; + auto& templateList = typeToExpressionList_[kTypeParameterName]; + uint32_t numEligible = baseList.size() + templateList.size(); core::TypedExprPtr expression; - // Generate a cast expression with 40% chance. - if (FLAGS_enable_cast && vectorFuzzer_.coinToss(0.4)) { - expression = generateCastExpression(returnType); - if (!expression) { - LOG(INFO) << "Casting to '" << returnType->toString() - << "' is unsupported. Returning a constant instead."; - expression = generateArgConstant(returnType); + if (numEligible > 0) { + size_t chosenExprIndex = + boost::random::uniform_int_distribution( + 0, numEligible - 1)(rng_); + std::string chosenFunctionName; + if (chosenExprIndex < baseList.size()) { + chosenFunctionName = baseList[chosenExprIndex]; + } else { + chosenExprIndex -= baseList.size(); + chosenFunctionName = templateList[chosenExprIndex]; } - return expression; - } - auto firstAttempt = - &ExpressionFuzzer::generateExpressionFromConcreteSignatures; - auto secondAttempt = - &ExpressionFuzzer::generateExpressionFromSignatureTemplate; - size_t useSignatureTemplate = - boost::random::uniform_int_distribution(0, 1)(rng_); - if (FLAGS_velox_fuzzer_enable_complex_types && useSignatureTemplate) { - std::swap(firstAttempt, secondAttempt); + expression = generateExpressionFromConcreteSignatures( + returnType, chosenFunctionName); + if (!expression && FLAGS_velox_fuzzer_enable_complex_types) { + expression = generateExpressionFromSignatureTemplate( + returnType, chosenFunctionName); + } } - - expression = (this->*firstAttempt)(returnType); if (!expression) { - if (FLAGS_velox_fuzzer_enable_complex_types) { - expression = (this->*secondAttempt)(returnType); - } - if (!expression) { - LOG(INFO) << "Couldn't find any function to return '" - << returnType->toString() - << "'. Returning a constant instead."; - expression = generateArgConstant(returnType); - } + LOG(INFO) << "Couldn't find any function to return '" + << returnType->toString() << "'. Returning a constant instead."; + return generateArgConstant(returnType); } + if (remainingLevelOfNesting_ == 0) { // Only add expressions that do not have nested expressions. listOfCandidateExprs.emplace_back(expression); @@ -677,18 +685,22 @@ core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable( } core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures( - const TypePtr& returnType) { - auto it = signaturesMap_.find(returnType->kind()); - if (it == signaturesMap_.end()) { + const TypePtr& returnType, + const std::string& functionName) { + if (expressionToSignature_.find(functionName) == + expressionToSignature_.end()) { + return nullptr; + } + auto baseType = typeToBaseName(returnType); + auto itr = expressionToSignature_[functionName].find(baseType); + if (itr == expressionToSignature_[functionName].end()) { return nullptr; } - // Only function signatures whose return type equals to returnType are // eligible. There may be ineligible signatures in signaturesMap_ because // the map keys only differentiate top-level type kinds. std::vector eligible; - const auto& signatures = it->second; - for (const auto* signature : signatures) { + for (auto signature : itr->second) { if (signature->returnType->equivalent(*returnType)) { eligible.push_back(signature); } @@ -708,17 +720,22 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures( const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate( const TypePtr& returnType, - const std::string& typeName) { + const std::string& typeName, + const std::string& functionName) { std::vector eligible; - auto it = signatureTemplateMap_.find(typeName); - if (it == signatureTemplateMap_.end()) { + if (expressionToTemplatedSignature_.find(functionName) == + expressionToTemplatedSignature_.end()) { + return nullptr; + } + auto it = expressionToTemplatedSignature_[functionName].find(typeName); + if (it == expressionToTemplatedSignature_[functionName].end()) { return nullptr; } // Only function signatures whose return type can match returnType are // eligible. There may be ineligible signatures in signaturesMap_ because // the map keys only differentiate the top-level type names. auto& signatureTemplates = it->second; - for (auto* signatureTemplate : signatureTemplates) { + for (auto signatureTemplate : signatureTemplates) { exec::ReverseSignatureBinder binder{ *signatureTemplate->signature, returnType}; if (binder.tryBind()) { @@ -735,12 +752,15 @@ const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate( } core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate( - const TypePtr& returnType) { + const TypePtr& returnType, + const std::string& functionName) { auto typeName = typeToBaseName(returnType); - auto* chosen = chooseRandomSignatureTemplate(returnType, typeName); + auto* chosen = + chooseRandomSignatureTemplate(returnType, typeName, functionName); if (!chosen) { - chosen = chooseRandomSignatureTemplate(returnType, kTypeParameterName); + chosen = chooseRandomSignatureTemplate( + returnType, kTypeParameterName, functionName); if (!chosen) { return nullptr; } diff --git a/velox/expression/tests/ExpressionFuzzer.h b/velox/expression/tests/ExpressionFuzzer.h index d29930c426c1..99299bf89fc7 100644 --- a/velox/expression/tests/ExpressionFuzzer.h +++ b/velox/expression/tests/ExpressionFuzzer.h @@ -137,22 +137,27 @@ class ExpressionFuzzer { core::TypedExprPtr getCallExprFromCallable(const CallableSignature& callable); - /// Generate an expression with a random concrete function signature that - /// returns returnType. + /// Generate an expression by randomly selecting a concrete function signature + /// that returns 'returnType' among all signatures that the function named + /// 'functionName' supports. core::TypedExprPtr generateExpressionFromConcreteSignatures( - const TypePtr& returnType); + const TypePtr& returnType, + const std::string& functionName); - /// Return a random signature template mapped to typeName in - /// signatureTemplateMap_ whose return type can match returnType. Return - /// nullptr if no such signature template exists. + /// Return a random signature template mapped to typeName and functionName in + /// expressionToTemplatedSignature_ whose return type can match returnType. + /// Return nullptr if no such signature template exists. const SignatureTemplate* chooseRandomSignatureTemplate( const TypePtr& returnType, - const std::string& typeName); + const std::string& typeName, + const std::string& functionName); - /// Generate an expression with a random function signature template that - /// returns returnType. + /// Generate an expression by randomly selecting a function signature template + /// that returns 'returnType' among all signature templates that the function + /// named 'functionName' supports. core::TypedExprPtr generateExpressionFromSignatureTemplate( - const TypePtr& returnType); + const TypePtr& returnType, + const std::string& functionName); /// Generate a cast expression that returns the specified type. Return a /// nullptr if casting to the specified type is not supported. The supported @@ -194,18 +199,31 @@ class ExpressionFuzzer { size_t currentSeed_{0}; std::vector signatures_; - - /// Maps a given type to the functions that return that type. - std::unordered_map> - signaturesMap_; - std::vector signatureTemplates_; - /// Maps the base name of the return type signature to the functions that - /// return this type. Base name could be "T" if the return type is a type - /// variable. - std::unordered_map> - signatureTemplateMap_; + /// Maps the base name of a return type signature to the function names that + /// support that return type. Base name could be "T" if the return type is a + /// type variable. + std::unordered_map> + typeToExpressionList_; + + /// Maps the base name of a *concrete* return type signature to the function + /// names that support that return type. Those names then each further map to + /// a list of CallableSignature objects that they support. Base name could be + /// "T" if the return type is a type variable. + std::unordered_map< + std::string, + std::unordered_map>> + expressionToSignature_; + + /// Maps the base name of a *templated* return type signature to the function + /// names that support that return type. Those names then each further map to + /// a list of SignatureTemplate objects that they support. Base name could be + /// "T" if the return type is a type variable. + std::unordered_map< + std::string, + std::unordered_map>> + expressionToTemplatedSignature_; /// The remaining levels of expression nesting. It's initialized by /// FLAGS_max_level_of_nesting and updated in generateExpression(). When its