Skip to content

Commit

Permalink
apacheGH-40102: Support scalar aggregates in ExecuteScalarExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
JerAguilon committed Feb 16, 2024
1 parent a03d957 commit 54913ae
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
17 changes: 13 additions & 4 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "arrow/util/string.h"
#include "arrow/util/value_parsing.h"
#include "arrow/util/vector.h"
#include "exec_internal.h"

namespace arrow {

Expand Down Expand Up @@ -298,15 +299,17 @@ bool Expression::IsScalarExpression() const {
}

if (call->function) {
return call->function->kind() == compute::Function::SCALAR;
return call->function->kind() == compute::Function::SCALAR ||
call->function->kind() == compute::Function::SCALAR_AGGREGATE;
}

// this expression is not bound; make a best guess based on
// the default function registry
if (auto function = compute::GetFunctionRegistry()
->GetFunction(call->function_name)
.ValueOr(nullptr)) {
return function->kind() == compute::Function::SCALAR;
return function->kind() == compute::Function::SCALAR ||
function->kind() == compute::Function::SCALAR_AGGREGATE;
}

// unknown function or other error; conservatively return false
Expand Down Expand Up @@ -770,14 +773,20 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& i
input_length = input.length;
}

auto executor = compute::detail::KernelExecutor::MakeScalar();
DCHECK(call->function->kind() == compute::Function::SCALAR ||
call->function->kind() == compute::Function::SCALAR_AGGREGATE);
auto executor = call->function->kind() == compute::Function::SCALAR
? compute::detail::KernelExecutor::MakeScalar()
: compute::detail::KernelExecutor::MakeScalarAggregate();

compute::KernelContext kernel_context(exec_context, call->kernel);
kernel_context.SetState(call->kernel_state.get());

const Kernel* kernel = call->kernel;
std::vector<TypeHolder> types = GetTypes(arguments);
auto options = call->options.get();
auto options = call->options.get() == NULLPTR ? call->function->default_options()
: call->options.get();

RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options}));

compute::detail::DatumAccumulator listener;
Expand Down
15 changes: 14 additions & 1 deletion cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// under the License.

#include "arrow/compute/expression.h"
#include <arrow/compute/api_aggregate.h> // TODO


#include <chrono>
#include <cstdint>
Expand Down Expand Up @@ -77,6 +79,10 @@ Expression true_unless_null(Expression argument) {
return call("true_unless_null", {std::move(argument)});
}

Expression last(Expression l) {
return call("last", {std::move(l)});
}

Expression add(Expression l, Expression r) {
return call("add", {std::move(l), std::move(r)});
}
Expand Down Expand Up @@ -810,7 +816,6 @@ void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
}

ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, *schm, in));

ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in));

AssertDatumsEqual(actual, expected, /*verbose=*/true);
Expand All @@ -821,6 +826,14 @@ void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
}

TEST(Expression, ExecuteCall) {
ExpectExecute(greater(field_ref("a"), last(field_ref("a"))),
ArrayFromJSON(struct_({field("a", float64())}), R"([
{"a": 5},
{"a": 4},
{"a": 3},
{"a": 4}
])"));

ExpectExecute(add(field_ref("a"), literal(3.5)),
ArrayFromJSON(struct_({field("a", float64())}), R"([
{"a": 6.125},
Expand Down

0 comments on commit 54913ae

Please sign in to comment.