Skip to content

Commit

Permalink
apacheGH-43759: [C++] Acero: Minor code enhancement for Join
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Aug 19, 2024
1 parent a380d69 commit 990e9a1
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 28 deletions.
9 changes: 4 additions & 5 deletions cpp/src/arrow/acero/hash_join_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,20 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr<Array> dictiona
return Status::OK();
}

dictionary_ = dictionary;
dictionary_ = std::move(dictionary);

// Initialize encoder
RowEncoder encoder;
std::vector<TypeHolder> encoder_types;
encoder_types.emplace_back(value_type_);
std::vector<TypeHolder> encoder_types{value_type_};
encoder.Init(encoder_types, ctx);

// Encode all dictionary values
int64_t length = dictionary->data()->length;
int64_t length = dictionary_->data()->length;
if (length >= std::numeric_limits<int32_t>::max()) {
return Status::Invalid(
"Dictionary length in hash join must fit into signed 32-bit integer.");
}
RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary->data()}, length)));
RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary_->data()}, length)));

std::vector<int32_t> entries_to_take;

Expand Down
16 changes: 8 additions & 8 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,30 +61,30 @@ Result<std::vector<FieldRef>> HashJoinSchema::ComputePayload(
const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
// payload = (output + filter) - keys, with no duplicates
std::unordered_set<int> payload_fields;
for (auto ref : output) {
for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}

for (auto ref : filter) {
for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}

for (auto ref : keys) {
for (const auto& ref : keys) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.erase(match[0]);
}

std::vector<FieldRef> payload_refs;
for (auto ref : output) {
for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
payload_fields.erase(match[0]);
}
}
for (auto ref : filter) {
for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
Expand Down Expand Up @@ -198,7 +198,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
return Status::Invalid("Different number of key fields on left (", left_keys.size(),
") and right (", right_keys.size(), ") side of the join");
}
if (left_keys.size() < 1) {
if (left_keys.empty()) {
return Status::Invalid("Join key cannot be empty");
}
for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) {
Expand Down Expand Up @@ -432,7 +432,7 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
indices[0] -= left_schema.num_fields();
FieldPath corrected_path(std::move(indices));
if (right_seen_paths.find(*path) == right_seen_paths.end()) {
right_filter.push_back(corrected_path);
right_filter.emplace_back(corrected_path);
right_seen_paths.emplace(std::move(corrected_path));
}
} else if (left_seen_paths.find(*path) == left_seen_paths.end()) {
Expand Down Expand Up @@ -698,7 +698,7 @@ class HashJoinNode : public ExecNode, public TracedNode {
std::shared_ptr<Schema> output_schema,
std::unique_ptr<HashJoinSchema> schema_mgr, Expression filter,
std::unique_ptr<HashJoinImpl> impl)
: ExecNode(plan, inputs, {"left", "right"},
: ExecNode(plan, std::move(inputs), {"left", "right"},
/*output_schema=*/std::move(output_schema)),
TracedNode(this),
join_type_(join_options.join_type),
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/acero/hash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
const std::string& right_field_name_suffix);

// Reports whether side 0 (the left input) has an empty payload projection;
// simply forwards to PayloadIsEmpty(0).
// NOTE(review): the two signatures below appear to be the before/after lines
// of this diff hunk (the second adds `const`) — confirm against the commit.
bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }
bool LeftPayloadIsEmpty() const { return PayloadIsEmpty(0); }

// Reports whether side 1 (the right input) has an empty payload projection;
// simply forwards to PayloadIsEmpty(1).
// NOTE(review): the two signatures below appear to be the before/after lines
// of this diff hunk (the second adds `const`) — confirm against the commit.
bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }
bool RightPayloadIsEmpty() const { return PayloadIsEmpty(1); }

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
Expand All @@ -88,7 +88,7 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
const SchemaProjectionMap& right_to_filter,
const Expression& filter);

// Returns true when proj_maps[side] reports zero columns for the
// HashJoinProjection::PAYLOAD projection.
// `side` must be 0 or 1; this is checked only by assert(), so it is not
// enforced in builds where assertions are compiled out.
// NOTE(review): the two signature lines below appear to be the before/after
// lines of this diff hunk (the second adds `const`) — confirm against the commit.
bool PayloadIsEmpty(int side) {
bool PayloadIsEmpty(int side) const {
assert(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}
Expand Down
21 changes: 9 additions & 12 deletions cpp/src/arrow/compute/light_array_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,9 @@ Result<KeyColumnMetadata> ColumnMetadataFromDataType(
const std::shared_ptr<DataType>& type) {
const bool is_extension = type->id() == Type::EXTENSION;
const std::shared_ptr<DataType>& typ =
is_extension
? arrow::internal::checked_pointer_cast<ExtensionType>(type->GetSharedPtr())
->storage_type()
: type;
is_extension ? arrow::internal::checked_cast<const ExtensionType*>(type.get())
->storage_type()
: type;

if (typ->id() == Type::DICTIONARY) {
auto bit_width =
Expand Down Expand Up @@ -323,8 +322,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
}

Status ResizableArrayData::ResizeVaryingLengthBuffer() {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

if (!column_metadata.is_fixed_length) {
int64_t min_new_size = buffers_[kFixedLengthBuffer]->data_as<int32_t>()[num_rows_];
Expand All @@ -343,20 +341,19 @@ Status ResizableArrayData::ResizeVaryingLengthBuffer() {
}

// Builds a KeyColumnArray describing this object's current contents: the
// column metadata is derived from data_type_, the length is num_rows_, and the
// three data pointers are the mutable data of the validity, fixed-length and
// variable-length buffers, passed in that order.
// NOTE(review): ValueOrDie() presumably aborts if ColumnMetadataFromDataType
// returns an error, so data_type_ is expected to be supported here — confirm.
// NOTE(review): the declaration of column_metadata appears twice below because
// this is a diff view (old two-line form, then the new single-line form).
KeyColumnArray ResizableArrayData::column_array() const {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
return KeyColumnArray(column_metadata, num_rows_,
buffers_[kValidityBuffer]->mutable_data(),
buffers_[kFixedLengthBuffer]->mutable_data(),
buffers_[kVariableLengthBuffer]->mutable_data());
}

std::shared_ptr<ArrayData> ResizableArrayData::array_data() const {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

auto valid_count = arrow::internal::CountSetBits(
buffers_[kValidityBuffer]->data(), /*offset=*/0, static_cast<int64_t>(num_rows_));
auto valid_count =
arrow::internal::CountSetBits(buffers_[kValidityBuffer]->data(), /*bit_offset=*/0,
static_cast<int64_t>(num_rows_));
int null_count = static_cast<int>(num_rows_) - static_cast<int>(valid_count);

if (column_metadata.is_fixed_length) {
Expand Down

0 comments on commit 990e9a1

Please sign in to comment.