Skip to content

Commit

Permalink
GH-44052: [C++][Compute] Reduce the complexity of row segmenter (#44053)
Browse files Browse the repository at this point in the history
### Rationale for this change

As described in #44052, currently `AnyKeysSegmenter::GetNextSegment` has `O(n*m)` complexity, where `n` is the number of rows in a batch, and `m` is the number of segments in this batch (a "segment" is the group of contiguous rows who have the same segment key). This is because in each invocation of the method, it computes all the group ids of the remaining rows in this batch, where it's only interested in the first group, making the rest of the computation a waste.

In this PR I introduced a new API `GetSegments` (and subsequently deprecated the old `GetNextSegment`) to compute the group ids only once and iterate all the segments outside to avoid the duplicated computation. This reduces the complexity from `O(n*m)` to `O(n)`.

### What changes are included in this PR?

1. Because `grouper.h` is a [public header](https://github.com/apache/arrow/blob/8556001e6a8b4c7f35d4e18c28704d7811005904/cpp/src/arrow/compute/api.h#L47), so I assume `RowSegmenter::GetNextSegment` is a public API and only deprecate it instead of removing it.
2. Implement new API `RowSegmenter::GetSegments` and update the call-sites.
3. Some code reorg of the segmenter code (mostly moving to inside a class).
4. A new benchmark for the segmented aggregation. (The benchmark result is listed in the comments below, which shows up to `50x` speedup, nearly `O(n*m)` to `O(n)` complexity reduction.)

### Are these changes tested?

Legacy tests are sufficient.

### Are there any user-facing changes?

Yes.

**This PR includes breaking changes to public APIs.**

The API `RowSegmenter::GetNextSegment` is deprecated due to its inefficiency and replaced with a more efficient one `RowSegmenter::GetSegments`.

* GitHub Issue: #44052

Lead-authored-by: Ruoxi Sun <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
zanmato1984 and pitrou authored Sep 18, 2024
1 parent 9576a41 commit 3d6d581
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 130 deletions.
119 changes: 98 additions & 21 deletions cpp/src/arrow/acero/aggregate_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "arrow/array/array_primitive.h"
#include "arrow/compute/api.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/util/benchmark_util.h"
Expand Down Expand Up @@ -325,7 +326,8 @@ BENCHMARK_TEMPLATE(ReferenceSum, SumBitmapVectorizeUnroll<int64_t>)

std::shared_ptr<RecordBatch> RecordBatchFromArrays(
const std::vector<std::shared_ptr<Array>>& arguments,
const std::vector<std::shared_ptr<Array>>& keys) {
const std::vector<std::shared_ptr<Array>>& keys,
const std::vector<std::shared_ptr<Array>>& segment_keys) {
std::vector<std::shared_ptr<Field>> fields;
std::vector<std::shared_ptr<Array>> all_arrays;
int64_t length = -1;
Expand All @@ -347,37 +349,56 @@ std::shared_ptr<RecordBatch> RecordBatchFromArrays(
fields.push_back(field("key" + ToChars(key_idx), key->type()));
all_arrays.push_back(key);
}
for (std::size_t segment_key_idx = 0; segment_key_idx < segment_keys.size();
segment_key_idx++) {
const auto& segment_key = segment_keys[segment_key_idx];
DCHECK_EQ(segment_key->length(), length);
fields.push_back(
field("segment_key" + ToChars(segment_key_idx), segment_key->type()));
all_arrays.push_back(segment_key);
}
return RecordBatch::Make(schema(std::move(fields)), length, std::move(all_arrays));
}

Result<std::shared_ptr<Table>> BatchGroupBy(
std::shared_ptr<RecordBatch> batch, std::vector<Aggregate> aggregates,
std::vector<FieldRef> keys, bool use_threads = false,
MemoryPool* memory_pool = default_memory_pool()) {
std::vector<FieldRef> keys, std::vector<FieldRef> segment_keys,
bool use_threads = false, MemoryPool* memory_pool = default_memory_pool()) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table,
Table::FromRecordBatches({std::move(batch)}));
Declaration plan = Declaration::Sequence(
{{"table_source", TableSourceNodeOptions(std::move(table))},
{"aggregate", AggregateNodeOptions(std::move(aggregates), std::move(keys))}});
{"aggregate", AggregateNodeOptions(std::move(aggregates), std::move(keys),
std::move(segment_keys))}});
return DeclarationToTable(std::move(plan), use_threads, memory_pool);
}

static void BenchmarkGroupBy(benchmark::State& state, std::vector<Aggregate> aggregates,
const std::vector<std::shared_ptr<Array>>& arguments,
const std::vector<std::shared_ptr<Array>>& keys) {
std::shared_ptr<RecordBatch> batch = RecordBatchFromArrays(arguments, keys);
static void BenchmarkAggregate(
benchmark::State& state, std::vector<Aggregate> aggregates,
const std::vector<std::shared_ptr<Array>>& arguments,
const std::vector<std::shared_ptr<Array>>& keys,
const std::vector<std::shared_ptr<Array>>& segment_keys = {}) {
std::shared_ptr<RecordBatch> batch =
RecordBatchFromArrays(arguments, keys, segment_keys);
std::vector<FieldRef> key_refs;
for (std::size_t key_idx = 0; key_idx < keys.size(); key_idx++) {
key_refs.emplace_back(static_cast<int>(key_idx + arguments.size()));
}
std::vector<FieldRef> segment_key_refs;
for (std::size_t segment_key_idx = 0; segment_key_idx < segment_keys.size();
segment_key_idx++) {
segment_key_refs.emplace_back(
static_cast<int>(segment_key_idx + arguments.size() + keys.size()));
}
for (std::size_t arg_idx = 0; arg_idx < arguments.size(); arg_idx++) {
aggregates[arg_idx].target = {FieldRef(static_cast<int>(arg_idx))};
}
int64_t total_bytes = TotalBufferSize(*batch);
for (auto _ : state) {
ABORT_NOT_OK(BatchGroupBy(batch, aggregates, key_refs));
ABORT_NOT_OK(BatchGroupBy(batch, aggregates, key_refs, segment_key_refs));
}
state.SetBytesProcessed(total_bytes * state.iterations());
state.SetItemsProcessed(batch->num_rows() * state.iterations());
}

#define GROUP_BY_BENCHMARK(Name, Impl) \
Expand All @@ -404,7 +425,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyStringSet, [&] {
/*min_length=*/3,
/*max_length=*/32);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallStringSet, [&] {
Expand All @@ -419,7 +440,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallStringSet, [&] {
/*min_length=*/3,
/*max_length=*/32);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumStringSet, [&] {
Expand All @@ -434,7 +455,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumStringSet, [&] {
/*min_length=*/3,
/*max_length=*/32);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntegerSet, [&] {
Expand All @@ -448,7 +469,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntegerSet, [&] {
/*min=*/0,
/*max=*/15);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntegerSet, [&] {
Expand All @@ -462,7 +483,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntegerSet, [&] {
/*min=*/0,
/*max=*/255);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntegerSet, [&] {
Expand All @@ -476,7 +497,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntegerSet, [&] {
/*min=*/0,
/*max=*/4095);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntStringPairSet, [&] {
Expand All @@ -494,7 +515,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntStringPairSet, [&] {
/*min_length=*/3,
/*max_length=*/32);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntStringPairSet, [&] {
Expand All @@ -512,7 +533,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntStringPairSet, [&] {
/*min_length=*/3,
/*max_length=*/32);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
});

GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntStringPairSet, [&] {
Expand All @@ -530,7 +551,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntStringPairSet, [&] {
/*min_length=*/3,
/*max_length=*/32);

BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
});

// Grouped MinMax
Expand All @@ -543,7 +564,7 @@ GROUP_BY_BENCHMARK(MinMaxDoublesGroupedByMediumInt, [&] {
/*nan_probability=*/args.null_proportion / 10);
auto int_key = rng.Int64(args.size, /*min=*/0, /*max=*/63);

BenchmarkGroupBy(state, {{"hash_min_max", ""}}, {input}, {int_key});
BenchmarkAggregate(state, {{"hash_min_max", ""}}, {input}, {int_key});
});

GROUP_BY_BENCHMARK(MinMaxShortStringsGroupedByMediumInt, [&] {
Expand All @@ -553,7 +574,7 @@ GROUP_BY_BENCHMARK(MinMaxShortStringsGroupedByMediumInt, [&] {
/*null_probability=*/args.null_proportion);
auto int_key = rng.Int64(args.size, /*min=*/0, /*max=*/63);

BenchmarkGroupBy(state, {{"hash_min_max", ""}}, {input}, {int_key});
BenchmarkAggregate(state, {{"hash_min_max", ""}}, {input}, {int_key});
});

GROUP_BY_BENCHMARK(MinMaxLongStringsGroupedByMediumInt, [&] {
Expand All @@ -563,7 +584,7 @@ GROUP_BY_BENCHMARK(MinMaxLongStringsGroupedByMediumInt, [&] {
/*null_probability=*/args.null_proportion);
auto int_key = rng.Int64(args.size, /*min=*/0, /*max=*/63);

BenchmarkGroupBy(state, {{"hash_min_max", ""}}, {input}, {int_key});
BenchmarkAggregate(state, {{"hash_min_max", ""}}, {input}, {int_key});
});

//
Expand Down Expand Up @@ -866,5 +887,61 @@ BENCHMARK(TDigestKernelDoubleMedian)->Apply(QuantileKernelArgs);
BENCHMARK(TDigestKernelDoubleDeciles)->Apply(QuantileKernelArgs);
BENCHMARK(TDigestKernelDoubleCentiles)->Apply(QuantileKernelArgs);

//
// Segmented Aggregate
//

static void BenchmarkSegmentedAggregate(
benchmark::State& state, int64_t num_rows, std::vector<Aggregate> aggregates,
const std::vector<std::shared_ptr<Array>>& arguments,
const std::vector<std::shared_ptr<Array>>& keys, int64_t num_segment_keys,
int64_t num_segments) {
ASSERT_GT(num_segments, 0);

auto rng = random::RandomArrayGenerator(42);
auto segment_key = rng.Int64(num_rows, /*min=*/0, /*max=*/num_segments - 1);
int64_t* values = segment_key->data()->GetMutableValues<int64_t>(1);
std::sort(values, values + num_rows);
// num_segment_keys copies of the segment key.
ArrayVector segment_keys(num_segment_keys, segment_key);

BenchmarkAggregate(state, std::move(aggregates), arguments, keys, segment_keys);
}

template <typename... Args>
static void CountScalarSegmentedByInts(benchmark::State& state, Args&&...) {
constexpr int64_t num_rows = 32 * 1024;

// A trivial column to count from.
auto arg = ConstantArrayGenerator::Zeroes(num_rows, int32());

BenchmarkSegmentedAggregate(state, num_rows, {{"count", ""}}, {arg}, /*keys=*/{},
state.range(0), state.range(1));
}
BENCHMARK(CountScalarSegmentedByInts)
->ArgNames({"SegmentKeys", "Segments"})
->ArgsProduct({{0, 1, 2}, benchmark::CreateRange(1, 256, 8)});

template <typename... Args>
static void CountGroupByIntsSegmentedByInts(benchmark::State& state, Args&&...) {
constexpr int64_t num_rows = 32 * 1024;

// A trivial column to count from.
auto arg = ConstantArrayGenerator::Zeroes(num_rows, int32());

auto rng = random::RandomArrayGenerator(42);
int64_t num_keys = state.range(0);
ArrayVector keys(num_keys);
for (auto& key : keys) {
key = rng.Int64(num_rows, /*min=*/0, /*max=*/64);
}

BenchmarkSegmentedAggregate(state, num_rows, {{"hash_count", ""}}, {arg}, keys,
state.range(1), state.range(2));
}
BENCHMARK(CountGroupByIntsSegmentedByInts)
->ArgNames({"Keys", "SegmentKeys", "Segments"})
->ArgsProduct({{1, 2}, {0, 1, 2}, benchmark::CreateRange(1, 256, 8)});

} // namespace acero
} // namespace arrow
9 changes: 3 additions & 6 deletions cpp/src/arrow/acero/aggregate_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,14 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
template <typename BatchHandler>
Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch,
const std::vector<int>& ids, const BatchHandler& handle_batch) {
int64_t offset = 0;
ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids));
ExecSpan segment_batch(segment_exec_batch);

while (true) {
ARROW_ASSIGN_OR_RAISE(compute::Segment segment,
segmenter->GetNextSegment(segment_batch, offset));
if (segment.offset >= segment_batch.length) break; // condition of no-next-segment
ARROW_ASSIGN_OR_RAISE(auto segments, segmenter->GetSegments(segment_batch));
for (const auto& segment : segments) {
ARROW_RETURN_NOT_OK(handle_batch(batch, segment));
offset = segment.offset + segment.length;
}

return Status::OK();
}

Expand Down
Loading

0 comments on commit 3d6d581

Please sign in to comment.