diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index 19e2fbb8f2fd..93b00d4c6c3d 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -324,7 +324,11 @@ struct udf_has_name : std::false_type {}; template struct udf_has_name : std::true_type {}; -template +template < + typename Fun, + typename TReturn, + typename ConstantChecker, + typename... Args> class SimpleFunctionMetadata : public ISimpleFunctionMetadata { public: using return_type = TReturn; @@ -453,9 +457,13 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { auto builder = exec::FunctionSignatureBuilder(); builder.returnType(analysis.outputType); - + int32_t position = 0; for (const auto& arg : analysis.argsTypes) { - builder.argumentType(arg); + if (ConstantChecker::isConstant[position++]) { + builder.constantArgumentType(arg); + } else { + builder.argumentType(arg); + } } for (const auto& [_, variable] : analysis.variables) { @@ -474,14 +482,21 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { // wraps a UDF object to provide the inheritance // this is basically just boilerplate-avoidance -template +template < + typename Fun, + typename Exec, + typename TReturn, + typename ConstantChecker, + typename... TArgs> class UDFHolder final - : public core::SimpleFunctionMetadata { + : public core:: + SimpleFunctionMetadata { Fun instance_; public: using udf_struct_t = Fun; - using Metadata = core::SimpleFunctionMetadata; + using Metadata = + core::SimpleFunctionMetadata; template using exec_resolver = typename Exec::template resolver; diff --git a/velox/expression/tests/SimpleFunctionTest.cpp b/velox/expression/tests/SimpleFunctionTest.cpp index c67594c9c3c3..a77c3de60eca 100644 --- a/velox/expression/tests/SimpleFunctionTest.cpp +++ b/velox/expression/tests/SimpleFunctionTest.cpp @@ -76,12 +76,14 @@ class SimpleFunctionTest : public functions::test::FunctionBaseTest { CallNullFreeFuncVoidOut, exec::VectorExec, int32_t, - TArgs...>; + ConstantChecker, + typename UnwrapConstantType::type...>; using holderClassBool = core::UDFHolder< CallNullFreeFuncBoolOut, exec::VectorExec, int32_t, - TArgs...>; + ConstantChecker, + typename UnwrapConstantType::type...>; if (voidOutput) { ASSERT_EQ( expected, @@ -1104,6 +1106,7 @@ TEST_F(SimpleFunctionTest, isAsciiArgs) { StringInputFunction, exec::VectorExec, int32_t, + ConstantChecker, Varchar>; using function_t = exec::SimpleFunctionAdapter; @@ -1239,4 +1242,50 @@ TEST_F(SimpleFunctionTest, flatNoNullsPathCallNullFree) { testCallNullFreeSupportFlatNotNulls(false, false); testCallNullFreeSupportFlatNotNulls(false, false); } + +template +struct ConstantArgumentFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + void initialize( + const core::QueryConfig& /*config*/, + const arg_type* /*first*/, + const arg_type* /*second*/, + const arg_type* /*third*/, + const arg_type>* /*fourth*/, + const arg_type>* /*fifth*/, + const arg_type>* /*sixth*/) {} + + bool callNullable( + out_type& out, + const arg_type* /*first*/, + const arg_type* /*second*/, + const arg_type* /*third*/, + const arg_type>* /*fourth*/, + const arg_type>* /*fifth*/, + const arg_type>* /*sixth*/) { + out = 1; + return true; + } +}; + +TEST_F(SimpleFunctionTest, constantArgument) { + registerFunction< + ConstantArgumentFunction, + int64_t, + int32_t, + Constant, + Constant, + Constant>, + Constant>, + Constant>>({"constant_argument_function"}); + auto signatures = exec::simpleFunctions().getFunctionSignatures( + "constant_argument_function"); + EXPECT_FALSE(signatures[0]->constantArguments().at(0)); + EXPECT_TRUE(signatures[0]->constantArguments().at(1)); + EXPECT_TRUE(signatures[0]->constantArguments().at(2)); + EXPECT_TRUE(signatures[0]->constantArguments().at(3)); + EXPECT_TRUE(signatures[0]->constantArguments().at(4)); + EXPECT_TRUE(signatures[0]->constantArguments().at(5)); +} } // namespace diff --git a/velox/functions/Registerer.h b/velox/functions/Registerer.h index 16b53a3eaa17..7b893c6c3928 100644 --- a/velox/functions/Registerer.h +++ b/velox/functions/Registerer.h @@ -37,8 +37,12 @@ using ParameterBinder = TempWrapper>; template void registerFunction(const std::vector& aliases = {}) { using funcClass = typename Func::template udf; - using holderClass = - core::UDFHolder; + using holderClass = core::UDFHolder< + funcClass, + exec::VectorExec, + TReturn, + ConstantChecker, + typename UnwrapConstantType::type...>; exec::registerSimpleFunction(aliases); } @@ -49,8 +53,12 @@ void registerFunction(const std::vector& aliases = {}) { template