Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow marking an argument as constant in the simple function API #7901

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions velox/core/SimpleFunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,11 @@ struct udf_has_name : std::false_type {};
template <typename T>
struct udf_has_name<T, decltype(&T::name, 0)> : std::true_type {};

template <typename Fun, typename TReturn, typename... Args>
template <
typename Fun,
typename TReturn,
typename ConstantChecker,
typename... Args>
class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
public:
using return_type = TReturn;
Expand Down Expand Up @@ -453,9 +457,13 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {
auto builder = exec::FunctionSignatureBuilder();

builder.returnType(analysis.outputType);

int32_t position = 0;
for (const auto& arg : analysis.argsTypes) {
builder.argumentType(arg);
if (ConstantChecker::isConstant[position++]) {
builder.constantArgumentType(arg);
} else {
builder.argumentType(arg);
}
}

for (const auto& [_, variable] : analysis.variables) {
Expand All @@ -474,14 +482,21 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata {

// wraps a UDF object to provide the inheritance
// this is basically just boilerplate-avoidance
template <typename Fun, typename Exec, typename TReturn, typename... TArgs>
template <
typename Fun,
typename Exec,
typename TReturn,
typename ConstantChecker,
typename... TArgs>
class UDFHolder final
: public core::SimpleFunctionMetadata<Fun, TReturn, TArgs...> {
: public core::
SimpleFunctionMetadata<Fun, TReturn, ConstantChecker, TArgs...> {
Fun instance_;

public:
using udf_struct_t = Fun;
using Metadata = core::SimpleFunctionMetadata<Fun, TReturn, TArgs...>;
using Metadata =
core::SimpleFunctionMetadata<Fun, TReturn, ConstantChecker, TArgs...>;

template <typename T>
using exec_resolver = typename Exec::template resolver<T>;
Expand Down
53 changes: 51 additions & 2 deletions velox/expression/tests/SimpleFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ class SimpleFunctionTest : public functions::test::FunctionBaseTest {
CallNullFreeFuncVoidOut<exec::VectorExec>,
exec::VectorExec,
int32_t,
TArgs...>;
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
using holderClassBool = core::UDFHolder<
CallNullFreeFuncBoolOut<exec::VectorExec>,
exec::VectorExec,
int32_t,
TArgs...>;
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
if (voidOutput) {
ASSERT_EQ(
expected,
Expand Down Expand Up @@ -1104,6 +1106,7 @@ TEST_F(SimpleFunctionTest, isAsciiArgs) {
StringInputFunction<exec::VectorExec>,
exec::VectorExec,
int32_t,
ConstantChecker<Varchar>,
Varchar>;
using function_t = exec::SimpleFunctionAdapter<holder_class_t>;

Expand Down Expand Up @@ -1239,4 +1242,50 @@ TEST_F(SimpleFunctionTest, flatNoNullsPathCallNullFree) {
testCallNullFreeSupportFlatNotNulls<int64_t, int64_t>(false, false);
testCallNullFreeSupportFlatNotNulls<Varchar, Varchar>(false, false);
}

template <typename T>
struct ConstantArgumentFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

void initialize(
const core::QueryConfig& /*config*/,
const arg_type<int32_t>* /*first*/,
const arg_type<int32_t>* /*second*/,
const arg_type<Varchar>* /*third*/,
const arg_type<Generic<T1>>* /*fourth*/,
const arg_type<Array<int32_t>>* /*fifth*/,
const arg_type<Map<int32_t, int32_t>>* /*sixth*/) {}

bool callNullable(
out_type<int64_t>& out,
const arg_type<int32_t>* /*first*/,
const arg_type<int32_t>* /*second*/,
const arg_type<Varchar>* /*third*/,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thanks.

const arg_type<Generic<T1>>* /*fourth*/,
const arg_type<Array<int32_t>>* /*fifth*/,
const arg_type<Map<int32_t, int32_t>>* /*sixth*/) {
out = 1;
return true;
}
};

TEST_F(SimpleFunctionTest, constantArgument) {
registerFunction<
ConstantArgumentFunction,
int64_t,
int32_t,
Constant<int32_t>,
Constant<Varchar>,
Constant<Generic<T1>>,
Constant<Array<int32_t>>,
Constant<Map<int32_t, int32_t>>>({"constant_argument_function"});
auto signatures = exec::simpleFunctions().getFunctionSignatures(
"constant_argument_function");
EXPECT_FALSE(signatures[0]->constantArguments().at(0));
EXPECT_TRUE(signatures[0]->constantArguments().at(1));
EXPECT_TRUE(signatures[0]->constantArguments().at(2));
EXPECT_TRUE(signatures[0]->constantArguments().at(3));
EXPECT_TRUE(signatures[0]->constantArguments().at(4));
EXPECT_TRUE(signatures[0]->constantArguments().at(5));
}
} // namespace
16 changes: 12 additions & 4 deletions velox/functions/Registerer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ using ParameterBinder = TempWrapper<T<exec::VectorExec, TArgs...>>;
template <typename Func, typename TReturn, typename... TArgs>
void registerFunction(const std::vector<std::string>& aliases = {}) {
using funcClass = typename Func::template udf<exec::VectorExec>;
using holderClass =
core::UDFHolder<funcClass, exec::VectorExec, TReturn, TArgs...>;
using holderClass = core::UDFHolder<
funcClass,
exec::VectorExec,
TReturn,
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
exec::registerSimpleFunction<holderClass>(aliases);
}

Expand All @@ -49,8 +53,12 @@ void registerFunction(const std::vector<std::string>& aliases = {}) {
template <template <class> typename Func, typename TReturn, typename... TArgs>
void registerFunction(const std::vector<std::string>& aliases = {}) {
using funcClass = Func<exec::VectorExec>;
using holderClass =
core::UDFHolder<funcClass, exec::VectorExec, TReturn, TArgs...>;
using holderClass = core::UDFHolder<
funcClass,
exec::VectorExec,
TReturn,
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
exec::registerSimpleFunction<holderClass>(aliases);
}

Expand Down
29 changes: 29 additions & 0 deletions velox/type/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,35 @@ struct Varchar {
Varchar() {}
};

template <typename T>
struct Constant {};

template <typename T>
struct UnwrapConstantType {
using type = T;
};

template <typename T>
struct UnwrapConstantType<Constant<T>> {
using type = T;
};

template <typename T>
struct isConstantType {
static constexpr bool value = false;
};

template <typename T>
struct isConstantType<Constant<T>> {
static constexpr bool value = true;
};

template <typename... TArgs>
struct ConstantChecker {
static constexpr bool isConstant[sizeof...(TArgs)] = {
isConstantType<TArgs>::value...};
};

template <typename T>
struct SimpleTypeTrait {};

Expand Down