Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43911: [C++] Compute Row: Add ListKeyEncoder support #43912

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions cpp/src/arrow/acero/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,29 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::string ToString() const override { return "HashJoinBasicImpl"; }

private:
void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
Status InitEncoder(int side, HashJoinProjection projection_handle,
RowEncoder* encoder) {
std::vector<TypeHolder> data_types;
int num_cols = schema_[side]->num_cols(projection_handle);
data_types.resize(num_cols);
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] = schema_[side]->data_type(projection_handle, icol);
}
encoder->Init(data_types, ctx_->exec_context());
RETURN_NOT_OK(encoder->Init(data_types, ctx_->exec_context()));
encoder->Clear();
return Status::OK();
}

Status InitLocalStateIfNeeded(size_t thread_index) {
DCHECK_LT(thread_index, local_states_.size());
ThreadLocalState& local_state = local_states_[thread_index];
if (!local_state.is_initialized) {
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
RETURN_NOT_OK(
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys));
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
RETURN_NOT_OK(InitEncoder(0, HashJoinProjection::PAYLOAD,
&local_state.exec_batch_payloads));
}
local_state.is_initialized = true;
}
Expand Down Expand Up @@ -512,8 +516,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
thread_index, *schema_[0], *schema_[1], ctx_->exec_context());
ARROW_ASSIGN_OR_RAISE(
bool use_key_batch_for_dicts,
dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1],
ctx_->exec_context()));
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
Expand Down Expand Up @@ -563,10 +569,11 @@ class HashJoinBasicImpl : public HashJoinImpl {

Status BuildHashTable_exec_task(size_t thread_index, int64_t /*task_id*/) {
AccumulationQueue batches = std::move(build_batches_);
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_->exec_context());
RETURN_NOT_OK(
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_->exec_context()));
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
RETURN_NOT_OK(InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_));
}
hash_table_empty_ = true;

Expand Down
36 changes: 20 additions & 16 deletions cpp/src/arrow/acero/hash_join_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr<Array> dictiona
// Initialize encoder
RowEncoder encoder;
std::vector<TypeHolder> encoder_types{value_type_};
encoder.Init(encoder_types, ctx);
RETURN_NOT_OK(encoder.Init(encoder_types, ctx));

// Encode all dictionary values
int64_t length = dictionary_->data()->length;
Expand Down Expand Up @@ -290,7 +290,7 @@ Result<std::shared_ptr<ArrayData>> HashJoinDictBuild::RemapInputValues(
//
RowEncoder encoder;
std::vector<TypeHolder> encoder_types = {value_type_};
encoder.Init(encoder_types, ctx);
RETURN_NOT_OK(encoder.Init(encoder_types, ctx));

// Encode all
//
Expand Down Expand Up @@ -426,7 +426,7 @@ Result<std::shared_ptr<ArrayData>> HashJoinDictProbe::RemapInput(
opt_build_side->RemapInputValues(ctx, Datum(dict->data()), dict->length()));
} else {
std::vector<TypeHolder> encoder_types = {dict_type.value_type()};
encoder_.Init(encoder_types, ctx);
RETURN_NOT_OK(encoder_.Init(encoder_types, ctx));
RETURN_NOT_OK(
encoder_.EncodeAndAppend(ExecSpan({*dict->data()}, dict->length())));
}
Expand Down Expand Up @@ -514,7 +514,7 @@ Status HashJoinDictBuildMulti::Init(
return Status::OK();
}

void HashJoinDictBuildMulti::InitEncoder(
Status HashJoinDictBuildMulti::InitEncoder(
const SchemaProjectionMaps<HashJoinProjection>& proj_map, RowEncoder* encoder,
ExecContext* ctx) {
int num_cols = proj_map.num_cols(HashJoinProjection::KEY);
Expand All @@ -525,9 +525,9 @@ void HashJoinDictBuildMulti::InitEncoder(
if (HashJoinDictBuild::KeyNeedsProcessing(data_type)) {
data_type = HashJoinDictBuild::DataTypeAfterRemapping();
}
data_types[icol] = data_type;
data_types[icol] = std::move(data_type);
}
encoder->Init(data_types, ctx);
return encoder->Init(data_types, ctx);
}

Status HashJoinDictBuildMulti::EncodeBatch(
Expand Down Expand Up @@ -568,20 +568,21 @@ Status HashJoinDictBuildMulti::PostDecode(

void HashJoinDictProbeMulti::Init(size_t num_threads) {
local_states_.resize(num_threads);
for (size_t i = 0; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
for (auto& local_state : local_states_) {
local_state.is_initialized = false;
}
}

bool HashJoinDictProbeMulti::BatchRemapNeeded(
Result<bool> HashJoinDictProbeMulti::BatchRemapNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
RETURN_NOT_OK(
InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx));
DCHECK_LT(thread_index, local_states_.size());
return local_states_[thread_index].any_needs_remap;
}

void HashJoinDictProbeMulti::InitLocalStateIfNeeded(
Status HashJoinDictProbeMulti::InitLocalStateIfNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
ThreadLocalState& local_state = local_states_[thread_index];
Expand All @@ -603,11 +604,13 @@ void HashJoinDictProbeMulti::InitLocalStateIfNeeded(
}

if (local_state.any_needs_remap) {
InitEncoder(proj_map_probe, proj_map_build, &local_state.post_remap_encoder, ctx);
RETURN_NOT_OK(InitEncoder(proj_map_probe, proj_map_build,
&local_state.post_remap_encoder, ctx));
}
return Status::OK();
}

void HashJoinDictProbeMulti::InitEncoder(
Status HashJoinDictProbeMulti::InitEncoder(
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, RowEncoder* encoder,
ExecContext* ctx) {
Expand All @@ -616,14 +619,14 @@ void HashJoinDictProbeMulti::InitEncoder(
for (int icol = 0; icol < num_cols; ++icol) {
std::shared_ptr<DataType> data_type =
proj_map_probe.data_type(HashJoinProjection::KEY, icol);
std::shared_ptr<DataType> build_data_type =
const std::shared_ptr<DataType>& build_data_type =
proj_map_build.data_type(HashJoinProjection::KEY, icol);
if (HashJoinDictProbe::KeyNeedsProcessing(data_type, build_data_type)) {
data_type = HashJoinDictProbe::DataTypeAfterRemapping(build_data_type);
}
data_types[icol] = data_type;
}
encoder->Init(data_types, ctx);
return encoder->Init(data_types, ctx);
}

Status HashJoinDictProbeMulti::EncodeBatch(
Expand All @@ -632,7 +635,8 @@ Status HashJoinDictProbeMulti::EncodeBatch(
const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch,
RowEncoder** out_encoder, ExecBatch* opt_out_key_batch, ExecContext* ctx) {
ThreadLocalState& local_state = local_states_[thread_index];
InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
RETURN_NOT_OK(
InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx));

ExecBatch projected({}, batch.length);
int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY);
Expand Down
20 changes: 10 additions & 10 deletions cpp/src/arrow/acero/hash_join_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ class HashJoinDictBuildMulti {
public:
Status Init(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
const ExecBatch* opt_non_empty_batch, ExecContext* ctx);
static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
RowEncoder* encoder, ExecContext* ctx);
static Status InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
RowEncoder* encoder, ExecContext* ctx);
Status EncodeBatch(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map,
const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const;
Expand All @@ -280,10 +280,9 @@ class HashJoinDictBuildMulti {
class HashJoinDictProbeMulti {
public:
void Init(size_t num_threads);
bool BatchRemapNeeded(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
ExecContext* ctx);
Result<bool> BatchRemapNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx);
Status EncodeBatch(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
Expand All @@ -292,12 +291,13 @@ class HashJoinDictProbeMulti {
ExecContext* ctx);

private:
void InitLocalStateIfNeeded(
Status InitLocalStateIfNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx);
static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
RowEncoder* encoder, ExecContext* ctx);
static Status InitEncoder(
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, RowEncoder* encoder,
ExecContext* ctx);
struct ThreadLocalState {
bool is_initialized;
// Whether any key column needs remapping (because of dictionaries used) before doing
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ std::vector<std::shared_ptr<Array>> GenRandomUniqueRecords(
val_types.push_back(result[i]->type());
}
RowEncoder encoder;
encoder.Init(val_types, ctx);
auto s = encoder.Init(val_types, ctx);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please at least use DCHECK_OK.

ExecBatch batch({}, num_desired);
batch.values.resize(result.size());
for (size_t i = 0; i < result.size(); ++i) {
Expand Down
106 changes: 64 additions & 42 deletions cpp/src/arrow/compute/row/row_encoder_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,65 @@ using internal::FirstTimeBitmapWriter;
namespace compute {
namespace internal {

// Construct the KeyEncoder appropriate for a single column of `column_type`.
//
// Extension types are encoded via their storage type; when `column_type` is an
// extension type, `*extension_type` is set to it so the caller can re-wrap the
// decoded storage array later. `pool` is only used by encoders that need to
// allocate during initialization (currently DictionaryKeyEncoder).
//
// Returns NotImplemented for types the row encoder cannot handle
// (FixedSizeList, lists of nested element types, and anything not covered by
// the checks below).
Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(
    const TypeHolder& column_type, std::shared_ptr<ExtensionType>* extension_type,
    MemoryPool* pool) {
  const bool is_extension = column_type.id() == Type::EXTENSION;
  // For extension types, dispatch on the underlying storage type. The ternary
  // materializes a TypeHolder temporary whose lifetime is extended by `type`.
  const TypeHolder& type =
      is_extension ? arrow::internal::checked_cast<const ExtensionType*>(column_type.type)
                         ->storage_type()
                   : column_type;

  if (is_extension) {
    *extension_type =
        arrow::internal::checked_pointer_cast<ExtensionType>(column_type.GetSharedPtr());
  }
  if (type.id() == Type::BOOL) {
    return std::make_shared<BooleanKeyEncoder>();
  }

  if (type.id() == Type::DICTIONARY) {
    // Dictionary encoding needs `pool` to accumulate the unified dictionary.
    return std::make_shared<DictionaryKeyEncoder>(type.GetSharedPtr(), pool);
  }

  if (is_fixed_width(type.id())) {
    return std::make_shared<FixedWidthKeyEncoder>(type.GetSharedPtr());
  }

  if (is_binary_like(type.id())) {
    return std::make_shared<VarLengthKeyEncoder<BinaryType>>(type.GetSharedPtr());
  }

  if (is_large_binary_like(type.id())) {
    return std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(type.GetSharedPtr());
  }

  if (is_list(type.id())) {
    auto element_type =
        ::arrow::checked_cast<const BaseListType*>(type.type)->value_type();
    // Nested element types (list-of-list, list-of-struct, ...) are rejected
    // rather than encoded recursively beyond one level.
    if (is_nested(element_type->id())) {
      return Status::NotImplemented("Unsupported nested type in List for row encoder",
                                    type.ToString());
    }
    if (type.id() == Type::FIXED_SIZE_LIST) {
      return Status::NotImplemented("Unsupported FixedSizeList for row encoder",
                                    type.ToString());
    }
    // Recursively build the encoder for the element column.
    // NOTE(review): `element_extension_type` is populated by the recursive call
    // but then discarded, so an extension-typed list element would lose its
    // extension wrapper on decode — confirm whether extension elements are
    // reachable here or should be rejected explicitly.
    // NOTE(review): Status::NotImplemented above concatenates its arguments;
    // consider inserting a ": " separator before type.ToString().
    std::shared_ptr<ExtensionType> element_extension_type;
    ARROW_ASSIGN_OR_RAISE(auto element_encoder,
                          MakeKeyEncoder(element_type, &element_extension_type, pool));
    if (type.id() == Type::LIST) {
      return std::make_shared<ListKeyEncoder<ListType>>(
          type.type->GetSharedPtr(), std::move(element_type), std::move(element_encoder));
    }
    // Only LIST, LARGE_LIST and FIXED_SIZE_LIST satisfy is_list(); the first
    // and last were handled above, so LARGE_LIST is the only possibility left.
    ARROW_CHECK(type.id() == Type::LARGE_LIST);
    return std::make_shared<ListKeyEncoder<LargeListType>>(
        type.type->GetSharedPtr(), std::move(element_type), std::move(element_encoder));
  }

  return Status::NotImplemented("Unsupported type for row encoder", type.ToString());
}

// extract the null bitmap from the leading nullity bytes of encoded keys
Status KeyEncoder::DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
std::shared_ptr<Buffer>* null_bitmap,
Expand Down Expand Up @@ -256,53 +315,15 @@ Result<std::shared_ptr<ArrayData>> DictionaryKeyEncoder::Decode(uint8_t** encode
return data;
}

void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
Status RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
ctx_ = ctx;
encoders_.resize(column_types.size());
extension_types_.resize(column_types.size());

for (size_t i = 0; i < column_types.size(); ++i) {
const bool is_extension = column_types[i].id() == Type::EXTENSION;
const TypeHolder& type =
is_extension
? arrow::internal::checked_cast<const ExtensionType*>(column_types[i].type)
->storage_type()
: column_types[i];

if (is_extension) {
extension_types_[i] = arrow::internal::checked_pointer_cast<ExtensionType>(
column_types[i].GetSharedPtr());
}
if (type.id() == Type::BOOL) {
encoders_[i] = std::make_shared<BooleanKeyEncoder>();
continue;
}

if (type.id() == Type::DICTIONARY) {
encoders_[i] =
std::make_shared<DictionaryKeyEncoder>(type.GetSharedPtr(), ctx->memory_pool());
continue;
}

if (is_fixed_width(type.id())) {
encoders_[i] = std::make_shared<FixedWidthKeyEncoder>(type.GetSharedPtr());
continue;
}

if (is_binary_like(type.id())) {
encoders_[i] =
std::make_shared<VarLengthKeyEncoder<BinaryType>>(type.GetSharedPtr());
continue;
}

if (is_large_binary_like(type.id())) {
encoders_[i] =
std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(type.GetSharedPtr());
continue;
}

// We should not get here
ARROW_DCHECK(false);
ARROW_ASSIGN_OR_RAISE(
encoders_[i],
MakeKeyEncoder(column_types[i], &extension_types_[i], ctx->memory_pool()));
}

int32_t total_length = 0;
Expand All @@ -314,6 +335,7 @@ void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext*
for (size_t i = 0; i < column_types.size(); ++i) {
encoders_[i]->EncodeNull(&buf_ptr);
}
return Status::OK();
}

void RowEncoder::Clear() {
Expand Down
Loading
Loading