Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43759: [C++] Acero: Minor code enhancement for Join #43760

Merged
5 commits merged on Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions cpp/src/arrow/acero/hash_join_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,20 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr<Array> dictiona
return Status::OK();
}

dictionary_ = dictionary;
dictionary_ = std::move(dictionary);

// Initialize encoder
RowEncoder encoder;
std::vector<TypeHolder> encoder_types;
encoder_types.emplace_back(value_type_);
std::vector<TypeHolder> encoder_types{value_type_};
encoder.Init(encoder_types, ctx);

// Encode all dictionary values
int64_t length = dictionary->data()->length;
int64_t length = dictionary_->data()->length;
if (length >= std::numeric_limits<int32_t>::max()) {
return Status::Invalid(
"Dictionary length in hash join must fit into signed 32-bit integer.");
}
RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary->data()}, length)));
RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary_->data()}, length)));

std::vector<int32_t> entries_to_take;

Expand Down
16 changes: 8 additions & 8 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,30 +61,30 @@ Result<std::vector<FieldRef>> HashJoinSchema::ComputePayload(
const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
// payload = (output + filter) - keys, with no duplicates
std::unordered_set<int> payload_fields;
for (auto ref : output) {
for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}

for (auto ref : filter) {
for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}

for (auto ref : keys) {
for (const auto& ref : keys) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.erase(match[0]);
}

std::vector<FieldRef> payload_refs;
for (auto ref : output) {
for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
payload_fields.erase(match[0]);
}
}
for (auto ref : filter) {
for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
Expand Down Expand Up @@ -198,7 +198,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
return Status::Invalid("Different number of key fields on left (", left_keys.size(),
") and right (", right_keys.size(), ") side of the join");
}
if (left_keys.size() < 1) {
if (left_keys.empty()) {
return Status::Invalid("Join key cannot be empty");
}
for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) {
Expand Down Expand Up @@ -432,7 +432,7 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
indices[0] -= left_schema.num_fields();
FieldPath corrected_path(std::move(indices));
if (right_seen_paths.find(*path) == right_seen_paths.end()) {
right_filter.push_back(corrected_path);
right_filter.emplace_back(corrected_path);
right_seen_paths.emplace(std::move(corrected_path));
}
} else if (left_seen_paths.find(*path) == left_seen_paths.end()) {
Expand Down Expand Up @@ -698,7 +698,7 @@ class HashJoinNode : public ExecNode, public TracedNode {
std::shared_ptr<Schema> output_schema,
std::unique_ptr<HashJoinSchema> schema_mgr, Expression filter,
std::unique_ptr<HashJoinImpl> impl)
: ExecNode(plan, inputs, {"left", "right"},
: ExecNode(plan, std::move(inputs), {"left", "right"},
/*output_schema=*/std::move(output_schema)),
TracedNode(this),
join_type_(join_options.join_type),
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/acero/hash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
const std::string& right_field_name_suffix);

bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }
bool LeftPayloadIsEmpty() const { return PayloadIsEmpty(0); }

bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }
bool RightPayloadIsEmpty() const { return PayloadIsEmpty(1); }

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
Expand All @@ -88,7 +88,7 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
const SchemaProjectionMap& right_to_filter,
const Expression& filter);

bool PayloadIsEmpty(int side) {
bool PayloadIsEmpty(int side) const {
assert(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ Result<std::shared_ptr<ArrayData>> JoinResultMaterialize::FlushBuildColumn(
const std::shared_ptr<DataType>& data_type, const RowArray* row_array, int column_id,
uint32_t* row_ids) {
ResizableArrayData output;
output.Init(data_type, pool_, bit_util::Log2(num_rows_));
RETURN_NOT_OK(output.Init(data_type, pool_, bit_util::Log2(num_rows_)));

for (size_t i = 0; i <= null_ranges_.size(); ++i) {
int row_id_begin =
Expand Down Expand Up @@ -2247,8 +2247,9 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
for (int i = 0; i < num_build_cols; ++i) {
ResizableArrayData column_data;
column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i), pool_,
bit_util::Log2(num_batch_rows));
RETURN_NOT_OK(
column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i),
pool_, bit_util::Log2(num_batch_rows)));
if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) {
RETURN_NOT_OK(build_keys_->DecodeSelected(&column_data, idx, num_batch_rows,
key_ids_maybe_null, pool_));
Expand Down
68 changes: 32 additions & 36 deletions cpp/src/arrow/compute/light_array_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,9 @@ Result<KeyColumnMetadata> ColumnMetadataFromDataType(
const std::shared_ptr<DataType>& type) {
const bool is_extension = type->id() == Type::EXTENSION;
const std::shared_ptr<DataType>& typ =
is_extension
? arrow::internal::checked_pointer_cast<ExtensionType>(type->GetSharedPtr())
->storage_type()
: type;
is_extension ? arrow::internal::checked_cast<const ExtensionType*>(type.get())
->storage_type()
: type;

if (typ->id() == Type::DICTIONARY) {
auto bit_width =
Expand Down Expand Up @@ -205,22 +204,25 @@ Status ColumnArraysFromExecBatch(const ExecBatch& batch,
column_arrays);
}

void ResizableArrayData::Init(const std::shared_ptr<DataType>& data_type,
MemoryPool* pool, int log_num_rows_min) {
Status ResizableArrayData::Init(const std::shared_ptr<DataType>& data_type,
MemoryPool* pool, int log_num_rows_min) {
#ifndef NDEBUG
if (num_rows_allocated_ > 0) {
ARROW_DCHECK(data_type_ != NULLPTR);
KeyColumnMetadata metadata_before =
ColumnMetadataFromDataType(data_type_).ValueOrDie();
KeyColumnMetadata metadata_after = ColumnMetadataFromDataType(data_type).ValueOrDie();
ARROW_DCHECK(data_type_ != nullptr);
const KeyColumnMetadata& metadata_before = column_metadata_;
ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata_after,
ColumnMetadataFromDataType(data_type));
ARROW_DCHECK(metadata_before.is_fixed_length == metadata_after.is_fixed_length &&
metadata_before.fixed_length == metadata_after.fixed_length);
}
#endif
ARROW_DCHECK(data_type != nullptr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you avoid duplicating this line?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

ARROW_ASSIGN_OR_RAISE(column_metadata_, ColumnMetadataFromDataType(data_type));
Clear(/*release_buffers=*/false);
log_num_rows_min_ = log_num_rows_min;
data_type_ = data_type;
pool_ = pool;
return Status::OK();
}

void ResizableArrayData::Clear(bool release_buffers) {
Expand All @@ -246,8 +248,6 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
num_rows_allocated_new *= 2;
}

KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

if (buffers_[kFixedLengthBuffer] == NULLPTR) {
ARROW_DCHECK(buffers_[kValidityBuffer] == NULLPTR &&
buffers_[kVariableLengthBuffer] == NULLPTR);
Expand All @@ -258,8 +258,8 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes, pool_));
memset(mutable_data(kValidityBuffer), 0,
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes);
if (column_metadata.is_fixed_length) {
if (column_metadata.fixed_length == 0) {
if (column_metadata_.is_fixed_length) {
if (column_metadata_.fixed_length == 0) {
ARROW_ASSIGN_OR_RAISE(
buffers_[kFixedLengthBuffer],
AllocateResizableBuffer(
Expand All @@ -271,7 +271,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
ARROW_ASSIGN_OR_RAISE(
buffers_[kFixedLengthBuffer],
AllocateResizableBuffer(
num_rows_allocated_new * column_metadata.fixed_length + kNumPaddingBytes,
num_rows_allocated_new * column_metadata_.fixed_length + kNumPaddingBytes,
pool_));
}
} else {
Expand Down Expand Up @@ -300,15 +300,15 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
memset(mutable_data(kValidityBuffer) + bytes_for_bits_before, 0,
bytes_for_bits_after - bytes_for_bits_before);

if (column_metadata.is_fixed_length) {
if (column_metadata.fixed_length == 0) {
if (column_metadata_.is_fixed_length) {
if (column_metadata_.fixed_length == 0) {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes));
memset(mutable_data(kFixedLengthBuffer) + bytes_for_bits_before, 0,
bytes_for_bits_after - bytes_for_bits_before);
} else {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
num_rows_allocated_new * column_metadata.fixed_length + kNumPaddingBytes));
num_rows_allocated_new * column_metadata_.fixed_length + kNumPaddingBytes));
}
} else {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
Expand All @@ -323,10 +323,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
}

Status ResizableArrayData::ResizeVaryingLengthBuffer() {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

if (!column_metadata.is_fixed_length) {
if (!column_metadata_.is_fixed_length) {
int64_t min_new_size = buffers_[kFixedLengthBuffer]->data_as<int32_t>()[num_rows_];
ARROW_DCHECK(var_len_buf_size_ > 0);
if (var_len_buf_size_ < min_new_size) {
Expand All @@ -343,23 +340,19 @@ Status ResizableArrayData::ResizeVaryingLengthBuffer() {
}

KeyColumnArray ResizableArrayData::column_array() const {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
return KeyColumnArray(column_metadata, num_rows_,
return KeyColumnArray(column_metadata_, num_rows_,
buffers_[kValidityBuffer]->mutable_data(),
buffers_[kFixedLengthBuffer]->mutable_data(),
buffers_[kVariableLengthBuffer]->mutable_data());
}

std::shared_ptr<ArrayData> ResizableArrayData::array_data() const {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

auto valid_count = arrow::internal::CountSetBits(
buffers_[kValidityBuffer]->data(), /*offset=*/0, static_cast<int64_t>(num_rows_));
auto valid_count =
arrow::internal::CountSetBits(buffers_[kValidityBuffer]->data(), /*bit_offset=*/0,
static_cast<int64_t>(num_rows_));
int null_count = static_cast<int>(num_rows_) - static_cast<int>(valid_count);

if (column_metadata.is_fixed_length) {
if (column_metadata_.is_fixed_length) {
return ArrayData::Make(data_type_, num_rows_,
{buffers_[kValidityBuffer], buffers_[kFixedLengthBuffer]},
null_count);
Expand Down Expand Up @@ -493,10 +486,12 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
ARROW_DCHECK(num_rows_before >= 0);
int num_rows_after = num_rows_before + num_rows_to_append;
if (target->num_rows() == 0) {
target->Init(source->type, pool, kLogNumRows);
RETURN_NOT_OK(target->Init(source->type, pool, kLogNumRows));
}
RETURN_NOT_OK(target->ResizeFixedLengthBuffers(num_rows_after));

// Since target->Init has already been called successfully, we can assume that
// creating the ColumnMetadata here never fails
KeyColumnMetadata column_metadata =
ColumnMetadataFromDataType(source->type).ValueOrDie();

Expand Down Expand Up @@ -647,11 +642,12 @@ Status ExecBatchBuilder::AppendNulls(const std::shared_ptr<DataType>& type,
int num_rows_before = target.num_rows();
int num_rows_after = num_rows_before + num_rows_to_append;
if (target.num_rows() == 0) {
target.Init(type, pool, kLogNumRows);
RETURN_NOT_OK(target.Init(type, pool, kLogNumRows));
}
RETURN_NOT_OK(target.ResizeFixedLengthBuffers(num_rows_after));

KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(type).ValueOrDie();
ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata column_metadata,
ColumnMetadataFromDataType(type));

// Process fixed length buffer
//
Expand Down Expand Up @@ -708,7 +704,7 @@ Status ExecBatchBuilder::AppendSelected(MemoryPool* pool, const ExecBatch& batch
const Datum& data = batch.values[col_ids ? col_ids[i] : i];
ARROW_DCHECK(data.is_array());
const std::shared_ptr<ArrayData>& array_data = data.array();
values_[i].Init(array_data->type, pool, kLogNumRows);
RETURN_NOT_OK(values_[i].Init(array_data->type, pool, kLogNumRows));
}
}

Expand Down Expand Up @@ -739,7 +735,7 @@ Status ExecBatchBuilder::AppendNulls(MemoryPool* pool,
if (values_.empty()) {
values_.resize(types.size());
for (size_t i = 0; i < types.size(); ++i) {
values_[i].Init(types[i], pool, kLogNumRows);
RETURN_NOT_OK(values_[i].Init(types[i], pool, kLogNumRows));
}
}

Expand Down
6 changes: 4 additions & 2 deletions cpp/src/arrow/compute/light_array_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ class ARROW_EXPORT ResizableArrayData {
/// \param pool The pool to make allocations on
/// \param log_num_rows_min All resize operations will allocate at least enough
/// space for (1 << log_num_rows_min) rows
void Init(const std::shared_ptr<DataType>& data_type, MemoryPool* pool,
int log_num_rows_min);
Status Init(const std::shared_ptr<DataType>& data_type, MemoryPool* pool,
int log_num_rows_min);

/// \brief Resets the array back to an empty state
/// \param release_buffers If true then allocated memory is released and the
Expand Down Expand Up @@ -351,6 +351,8 @@ class ARROW_EXPORT ResizableArrayData {
static constexpr int64_t kNumPaddingBytes = 64;
int log_num_rows_min_;
std::shared_ptr<DataType> data_type_;
// Only valid when data_type_ != NULLPTR.
KeyColumnMetadata column_metadata_{};
MemoryPool* pool_;
int num_rows_;
int num_rows_allocated_;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/light_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ TEST(ResizableArrayData, Basic) {
arrow::internal::checked_pointer_cast<FixedWidthType>(type)->bit_width() / 8;
{
ResizableArrayData array;
array.Init(type, pool.get(), /*log_num_rows_min=*/16);
ASSERT_OK(array.Init(type, pool.get(), /*log_num_rows_min=*/16));
ASSERT_EQ(0, array.num_rows());
ASSERT_OK(array.ResizeFixedLengthBuffers(2));
ASSERT_EQ(2, array.num_rows());
Expand Down Expand Up @@ -330,7 +330,7 @@ TEST(ResizableArrayData, Binary) {
ARROW_SCOPED_TRACE("Type: ", type->ToString());
{
ResizableArrayData array;
array.Init(type, pool.get(), /*log_num_rows_min=*/4);
ASSERT_OK(array.Init(type, pool.get(), /*log_num_rows_min=*/4));
ASSERT_EQ(0, array.num_rows());
ASSERT_OK(array.ResizeFixedLengthBuffers(2));
ASSERT_EQ(2, array.num_rows());
Expand Down
Loading