diff --git a/cpp/src/arrow/acero/hash_join_node.cc b/cpp/src/arrow/acero/hash_join_node.cc index 254dad361ff87..8924bb5451813 100644 --- a/cpp/src/arrow/acero/hash_join_node.cc +++ b/cpp/src/arrow/acero/hash_join_node.cc @@ -45,13 +45,19 @@ using compute::KeyColumnArray; namespace acero { // Check if a type is supported in a join (as either a key or non-key column) -bool HashJoinSchema::IsTypeSupported(const DataType& type) { +bool HashJoinSchema::IsTypeSupported(const DataType& type, bool is_key) { const Type::type id = type.id(); if (id == Type::DICTIONARY) { - return IsTypeSupported(*checked_cast(type).value_type()); + return IsTypeSupported(*checked_cast(type).value_type(), + is_key); } if (id == Type::EXTENSION) { - return IsTypeSupported(*checked_cast(type).storage_type()); + return IsTypeSupported(*checked_cast(type).storage_type(), + is_key); + } + // If it's a key column, do not support NULL (Type::NA) + if (id == Type::NA && !is_key) { + return true; } return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id); } @@ -214,7 +220,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc const FieldPath& match = result.ValueUnsafe(); const std::shared_ptr& type = (left_side ? 
left_schema.fields() : right_schema.fields())[match[0]]->type(); - if (!IsTypeSupported(*type)) { + if (!IsTypeSupported(*type, true)) { return Status::Invalid("Data type ", *type, " is not supported in join key field"); } } @@ -234,14 +240,14 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc } for (const auto& field : left_schema.fields()) { const auto& type = *field->type(); - if (!IsTypeSupported(type)) { + if (!IsTypeSupported(type, false)) { return Status::Invalid("Data type ", type, " is not supported in join non-key field ", field->name()); } } for (const auto& field : right_schema.fields()) { const auto& type = *field->type(); - if (!IsTypeSupported(type)) { + if (!IsTypeSupported(type, false)) { return Status::Invalid("Data type ", type, " is not supported in join non-key field ", field->name()); } diff --git a/cpp/src/arrow/acero/hash_join_node.h b/cpp/src/arrow/acero/hash_join_node.h index cca64d59830b2..dc6e61e0109ab 100644 --- a/cpp/src/arrow/acero/hash_join_node.h +++ b/cpp/src/arrow/acero/hash_join_node.h @@ -75,7 +75,7 @@ class ARROW_ACERO_EXPORT HashJoinSchema { SchemaProjectionMaps proj_maps[2]; private: - static bool IsTypeSupported(const DataType& type); + static bool IsTypeSupported(const DataType& type, bool is_key); Status CollectFilterColumns(std::vector& left_filter, std::vector& right_filter, diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 58551f4eca00a..fd45fd83f4430 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -254,7 +254,7 @@ struct RandomDataTypeConstraints { void Default() { data_type_enabled_mask = - kInt1 | kInt2 | kInt4 | kInt8 | kBool | kBinary | kString | kLargeString; + kInt1 | kInt2 | kInt4 | kInt8 | kBool | kBinary | kString | kLargeString | kNull; min_null_probability = 0.0; max_null_probability = 0.2; min_binary_length = 1; @@ -282,6 +282,8 @@ struct 
RandomDataTypeConstraints { } } + void withoutNullColumn() { data_type_enabled_mask &= ~kNull; } + // Data type mask constants static constexpr int64_t kInt1 = 1; static constexpr int64_t kInt2 = 2; @@ -291,6 +293,7 @@ struct RandomDataTypeConstraints { static constexpr int64_t kBinary = 32; static constexpr int64_t kString = 64; static constexpr int64_t kLargeString = 128; + static constexpr int64_t kNull = 256; }; struct RandomDataType { @@ -300,10 +303,18 @@ struct RandomDataType { int min_string_length; int max_string_length; bool is_large_string; + bool is_null_type; static RandomDataType Random(Random64Bit& rng, const RandomDataTypeConstraints& constraints) { RandomDataType result; + if (constraints.data_type_enabled_mask & constraints.kNull) { + // 10% chance of null type column + result.is_null_type = ((rng.next() % 100) < 10); + } else { + result.is_null_type = false; + } + if ((constraints.data_type_enabled_mask & constraints.kString) != 0) { if (constraints.data_type_enabled_mask != constraints.kString) { // Both string and fixed length types enabled @@ -386,7 +397,9 @@ std::vector> GenRandomRecords( std::vector> result; random::RandomArrayGenerator rag(static_cast(rng.next())); for (size_t i = 0; i < data_types.size(); ++i) { - if (data_types[i].is_fixed_length) { + if (data_types[i].is_null_type) { + result.push_back(std::make_shared(num_rows)); + } else if (data_types[i].is_fixed_length) { switch (data_types[i].fixed_length) { case 0: result.push_back(rag.Boolean(num_rows, 0.5, data_types[i].null_probability)); @@ -465,15 +478,20 @@ void TakeUsingVector(ExecContext* ctx, const std::vector> AllocateBitmap(indices.size(), ctx->memory_pool())); uint8_t* non_nulls = null_buf->mutable_data(); memset(non_nulls, 0xFF, bit_util::BytesForBits(indices.size())); - if ((*result)[i]->data()->buffers.size() == 2) { - (*result)[i] = MakeArray( - ArrayData::Make((*result)[i]->type(), indices.size(), - {std::move(null_buf), (*result)[i]->data()->buffers[1]})); + if
((*result)[i]->type()->id() == Type::NA) { + (*result)[i] = MakeArray(ArrayData::Make((*result)[i]->type(), indices.size(), + {std::move(null_buf)})); } else { - (*result)[i] = MakeArray( - ArrayData::Make((*result)[i]->type(), indices.size(), - {std::move(null_buf), (*result)[i]->data()->buffers[1], - (*result)[i]->data()->buffers[2]})); + if ((*result)[i]->data()->buffers.size() == 2) { + (*result)[i] = MakeArray( + ArrayData::Make((*result)[i]->type(), indices.size(), + {std::move(null_buf), (*result)[i]->data()->buffers[1]})); + } else { + (*result)[i] = MakeArray( + ArrayData::Make((*result)[i]->type(), indices.size(), + {std::move(null_buf), (*result)[i]->data()->buffers[1], + (*result)[i]->data()->buffers[2]})); + } } } (*result)[i]->data()->SetNullCount(kUnknownNullCount); @@ -481,6 +499,9 @@ void TakeUsingVector(ExecContext* ctx, const std::vector> for (size_t i = 0; i < indices.size(); ++i) { if (indices[i] < 0) { for (size_t col = 0; col < result->size(); ++col) { + if ((*result)[col]->data()->buffers[0] == NULLPTR) { + continue; + } uint8_t* non_nulls = (*result)[col]->data()->buffers[0]->mutable_data(); bit_util::ClearBit(non_nulls, i); } @@ -989,8 +1010,11 @@ TEST(HashJoin, Random) { default_memory_pool(), parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); // Constraints - RandomDataTypeConstraints type_constraints; - type_constraints.Default(); + RandomDataTypeConstraints key_type_constraints; + key_type_constraints.Default(); + key_type_constraints.withoutNullColumn(); + RandomDataTypeConstraints non_key_type_constraints; + non_key_type_constraints.Default(); // type_constraints.OnlyInt(1, true); constexpr int max_num_key_fields = 3; constexpr int max_num_payload_fields = 3; @@ -1017,7 +1041,7 @@ TEST(HashJoin, Random) { int num_key_fields = rng.from_range(1, max_num_key_fields); RandomDataTypeVector key_types; for (int i = 0; i < num_key_fields; ++i) { - key_types.AddRandom(rng, type_constraints); + key_types.AddRandom(rng, key_type_constraints); } // Generate lists of payload data types @@ -1026,7 +1050,7 @@ TEST(HashJoin, Random) { for (int i = 0; i < 2; ++i) { num_payload_fields[i] = rng.from_range(0, max_num_payload_fields); for (int j = 0; j < num_payload_fields[i]; ++j) { - payload_types[i].AddRandom(rng, type_constraints); + payload_types[i].AddRandom(rng, non_key_type_constraints); } } diff --git a/cpp/src/arrow/compute/kernels/row_encoder.cc b/cpp/src/arrow/compute/kernels/row_encoder.cc index 8224eaa6d6315..55d74d3e3349a 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.cc +++ b/cpp/src/arrow/compute/kernels/row_encoder.cc @@ -272,6 +272,12 @@ void RowEncoder::Init(const std::vector& column_types, ExecContext* extension_types_[i] = arrow::internal::checked_pointer_cast( column_types[i].GetSharedPtr()); } + + if (type.id() == Type::NA) { + encoders_[i] = std::make_shared(); + continue; + } + if (type.id() == Type::BOOL) { encoders_[i] = std::make_shared(); continue; diff --git a/cpp/src/arrow/compute/light_array.cc b/cpp/src/arrow/compute/light_array.cc index 4e8b2b2d7cc3a..70f5a39095af5 100644 --- a/cpp/src/arrow/compute/light_array.cc +++ b/cpp/src/arrow/compute/light_array.cc @@ -34,9 +34,19 @@ KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& 
metadata, int64_t length "This class was intended to be a POD type"); metadata_ = metadata; length_ = length; - buffers_[kValidityBuffer] = validity_buffer; - buffers_[kFixedLengthBuffer] = fixed_length_buffer; - buffers_[kVariableLengthBuffer] = var_length_buffer; + + // Check if the column is of Null Type + if (metadata.is_null_type) { + // For Null type columns, only the validity buffer is relevant. + buffers_[kValidityBuffer] = validity_buffer; + buffers_[kFixedLengthBuffer] = nullptr; + buffers_[kVariableLengthBuffer] = nullptr; + } else { + buffers_[kValidityBuffer] = validity_buffer; + buffers_[kFixedLengthBuffer] = fixed_length_buffer; + buffers_[kVariableLengthBuffer] = var_length_buffer; + } + mutable_buffers_[kValidityBuffer] = mutable_buffers_[kFixedLengthBuffer] = mutable_buffers_[kVariableLengthBuffer] = nullptr; bit_offset_[kValidityBuffer] = bit_offset_validity; @@ -158,13 +168,22 @@ Result ColumnArrayFromArrayData( KeyColumnArray ColumnArrayFromArrayDataAndMetadata( const std::shared_ptr& array_data, const KeyColumnMetadata& metadata, int64_t start_row, int64_t num_rows) { + const uint8_t* fixed_length_buffer = nullptr; + const uint8_t* var_length_buffer = nullptr; + + // Check if the column is of Null Type + if (!metadata.is_null_type) { + fixed_length_buffer = array_data->buffers[1]->data(); + if (array_data->buffers.size() > 2 && array_data->buffers[2] != NULLPTR) { + var_length_buffer = array_data->buffers[2]->data(); + } + } + KeyColumnArray column_array = KeyColumnArray( metadata, array_data->offset + start_row + num_rows, array_data->buffers[0] != NULLPTR ? array_data->buffers[0]->data() : nullptr, - array_data->buffers[1]->data(), - (array_data->buffers.size() > 2 && array_data->buffers[2] != NULLPTR) - ? 
array_data->buffers[2]->data() - : nullptr); + fixed_length_buffer, var_length_buffer); + return column_array.Slice(array_data->offset + start_row, num_rows); } @@ -502,7 +521,10 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(source->type).ValueOrDie(); - if (column_metadata.is_fixed_length) { + if (column_metadata.fixed_length == 0 && column_metadata.is_fixed_length && + column_metadata.is_null_type) { + // Null column + } else if (column_metadata.is_fixed_length) { // Fixed length column // uint32_t fixed_length = column_metadata.fixed_length;