From b8185605295e55b1dc8740684351403f1860d87f Mon Sep 17 00:00:00 2001 From: Rossi Sun <zanmato1984@gmail.com> Date: Tue, 28 Jan 2025 00:25:29 +0800 Subject: [PATCH] GH-45334: [C++][Acero] Fix swiss join overflow issues in row offset calculation for fixed length and null masks (#45336) ### Rationale for this change #45334 ### What changes are included in this PR? 1. An all-mighty test case that can effectively reveal all the bugs mentioned in the issue; 2. Other than directly fixing the bugs (actually simply casting to 64-bit somewhere in the multiplication will do), I did some refinement to the buffer accessors of the row table, in order to eliminate more potential similar issues (which I believe do exist): 1. `null_masks()` -> `null_masks(row_id)` which does overflow-safe indexing inside; 2. `is_null(row_id, col_pos)` which does overflow-safe indexing and directly gets the bit of the column; 3. `data(1)` -> `fixed_length_rows(row_id)` which first asserts the row table being fixed-length, then does overflow-safe indexing inside; 4. `data(2)` -> `var_length_rows()` which only asserts the row table being var-length. It is supposed to be paired by the `offsets()` (which is already 64-bit by #43389 ); 5. The `data(0/1/2)` members are made private. 3. The AVX2 specializations are fixed individually by using 64-bit multiplication and indexing. ### Are these changes tested? Yes. ### Are there any user-facing changes? None. * GitHub Issue: #45334 Authored-by: Rossi Sun <zanmato1984@gmail.com> Signed-off-by: Rossi Sun <zanmato1984@gmail.com> --- cpp/src/arrow/acero/hash_join_node_test.cc | 99 +++++++++++++++++++ cpp/src/arrow/acero/swiss_join.cc | 36 ++++--- cpp/src/arrow/acero/swiss_join_avx2.cc | 34 ++----- cpp/src/arrow/acero/swiss_join_internal.h | 14 +-- cpp/src/arrow/compute/row/compare_internal.cc | 16 +-- .../compute/row/compare_internal_avx2.cc | 81 +++------------ cpp/src/arrow/compute/row/compare_test.cc | 6 +- cpp/src/arrow/compute/row/encode_internal.cc | 54 +++++----- cpp/src/arrow/compute/row/encode_internal.h | 10 +- .../arrow/compute/row/encode_internal_avx2.cc | 15 ++- cpp/src/arrow/compute/row/row_internal.cc | 10 +- cpp/src/arrow/compute/row/row_internal.h | 61 +++++++++--- cpp/src/arrow/compute/row/row_test.cc | 11 +-- .../compute/row/row_util_avx2_internal.h | 64 ++++++++++++ 14 files changed, 313 insertions(+), 198 deletions(-) create mode 100644 cpp/src/arrow/compute/row/row_util_avx2_internal.h diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 94504ccc9ba75..654fd59c45d5a 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -3449,5 +3449,104 @@ TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBVarLength)) { num_batches_left * num_rows_per_batch_left * num_batches_right); } +// GH-45334: The row ids of the matching rows on the right side (the build side) are very +// big, causing the index calculation overflow. +TEST(HashJoin, BuildSideLargeRowIds) { + GTEST_SKIP() << "Test disabled due to excessively time and resource consuming, " + "for local debugging only."; + + // A fair amount of match rows to trigger both SIMD and non-SIMD code paths. + const int64_t num_match_rows = 35; + const int64_t num_rows_per_match_batch = 35; + const int64_t num_match_batches = num_match_rows / num_rows_per_match_batch; + + const int64_t num_unmatch_rows_large = 720898048; + const int64_t num_rows_per_unmatch_batch_large = 352001; + const int64_t num_unmatch_batches_large = + num_unmatch_rows_large / num_rows_per_unmatch_batch_large; + + auto schema_small = + schema({field("small_key", int64()), field("small_payload", int64())}); + auto schema_large = + schema({field("large_key", int64()), field("large_payload", int64())}); + + // A carefully chosen key value which hashes to 0xFFFFFFFE, making the match rows to be + // placed at higher address of the row table. + const int64_t match_key = 289339070; + const int64_t match_payload = 42; + + // Match arrays of length num_rows_per_match_batch. + ASSERT_OK_AND_ASSIGN( + auto match_key_arr, + Constant(MakeScalar(match_key))->Generate(num_rows_per_match_batch)); + ASSERT_OK_AND_ASSIGN( + auto match_payload_arr, + Constant(MakeScalar(match_payload))->Generate(num_rows_per_match_batch)); + // Append 1 row of null to trigger null processing code paths. + ASSERT_OK_AND_ASSIGN(auto null_arr, MakeArrayOfNull(int64(), 1)); + ASSERT_OK_AND_ASSIGN(match_key_arr, Concatenate({match_key_arr, null_arr})); + ASSERT_OK_AND_ASSIGN(match_payload_arr, Concatenate({match_payload_arr, null_arr})); + // Match batch. + ExecBatch match_batch({match_key_arr, match_payload_arr}, num_rows_per_match_batch + 1); + + // Small batch. + ExecBatch batch_small = match_batch; + + // Large unmatch batches. + const int64_t seed = 42; + std::vector<ExecBatch> unmatch_batches_large; + unmatch_batches_large.reserve(num_unmatch_batches_large); + ASSERT_OK_AND_ASSIGN(auto unmatch_payload_arr_large, + MakeArrayOfNull(int64(), num_rows_per_unmatch_batch_large)); + int64_t unmatch_range_per_batch = + (std::numeric_limits<int64_t>::max() - match_key) / num_unmatch_batches_large; + for (int i = 0; i < num_unmatch_batches_large; ++i) { + auto unmatch_key_arr_large = RandomArrayGenerator(seed).Int64( + num_rows_per_unmatch_batch_large, + /*min=*/match_key + 1 + i * unmatch_range_per_batch, + /*max=*/match_key + 1 + (i + 1) * unmatch_range_per_batch); + unmatch_batches_large.push_back( + ExecBatch({unmatch_key_arr_large, unmatch_payload_arr_large}, + num_rows_per_unmatch_batch_large)); + } + // Large match batch. + ExecBatch match_batch_large = match_batch; + + // Batches with schemas. + auto batches_small = BatchesWithSchema{ + std::vector<ExecBatch>(num_match_batches, batch_small), schema_small}; + auto batches_large = BatchesWithSchema{std::move(unmatch_batches_large), schema_large}; + for (int i = 0; i < num_match_batches; i++) { + batches_large.batches.push_back(match_batch_large); + } + + Declaration source_small{ + "exec_batch_source", + ExecBatchSourceNodeOptions(batches_small.schema, batches_small.batches)}; + Declaration source_large{ + "exec_batch_source", + ExecBatchSourceNodeOptions(batches_large.schema, batches_large.batches)}; + + HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"small_key"}, + /*right_keys=*/{"large_key"}); + Declaration join{ + "hashjoin", {std::move(source_small), std::move(source_large)}, join_opts}; + + // Join should emit num_match_rows * num_match_rows rows. + ASSERT_OK_AND_ASSIGN(auto batches_result, DeclarationToExecBatches(std::move(join))); + Declaration result{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_result.schema), + std::move(batches_result.batches))}; + AssertRowCountEq(result, num_match_rows * num_match_rows); + + // All rows should be match_key/payload. + auto predicate = and_({equal(field_ref("small_key"), literal(match_key)), + equal(field_ref("small_payload"), literal(match_payload)), + equal(field_ref("large_key"), literal(match_key)), + equal(field_ref("large_payload"), literal(match_payload))}); + Declaration filter{"filter", {result}, FilterNodeOptions{std::move(predicate)}}; + AssertRowCountEq(std::move(filter), num_match_rows * num_match_rows); +} + } // namespace acero } // namespace arrow diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index fc3be1b462e60..85e14ac469ce7 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -477,14 +477,15 @@ void RowArrayMerge::CopyFixedLength(RowTableImpl* target, const RowTableImpl& so const int64_t* source_rows_permutation) { int64_t num_source_rows = source.length(); - int64_t fixed_length = target->metadata().fixed_length; + uint32_t fixed_length = target->metadata().fixed_length; // Permutation of source rows is optional. Without permutation all that is // needed is memcpy. // if (!source_rows_permutation) { - memcpy(target->mutable_data(1) + fixed_length * first_target_row_id, source.data(1), - fixed_length * num_source_rows); + DCHECK_LE(first_target_row_id, std::numeric_limits<uint32_t>::max()); + memcpy(target->mutable_fixed_length_rows(static_cast<uint32_t>(first_target_row_id)), + source.fixed_length_rows(/*row_id=*/0), fixed_length * num_source_rows); } else { // Row length must be a multiple of 64-bits due to enforced alignment. // Loop for each output row copying a fixed number of 64-bit words. @@ -494,10 +495,13 @@ void RowArrayMerge::CopyFixedLength(RowTableImpl* target, const RowTableImpl& so int64_t num_words_per_row = fixed_length / sizeof(uint64_t); for (int64_t i = 0; i < num_source_rows; ++i) { int64_t source_row_id = source_rows_permutation[i]; + DCHECK_LE(source_row_id, std::numeric_limits<uint32_t>::max()); const uint64_t* source_row_ptr = reinterpret_cast<const uint64_t*>( - source.data(1) + fixed_length * source_row_id); + source.fixed_length_rows(static_cast<uint32_t>(source_row_id))); + int64_t target_row_id = first_target_row_id + i; + DCHECK_LE(target_row_id, std::numeric_limits<uint32_t>::max()); uint64_t* target_row_ptr = reinterpret_cast<uint64_t*>( - target->mutable_data(1) + fixed_length * (first_target_row_id + i)); + target->mutable_fixed_length_rows(static_cast<uint32_t>(target_row_id))); for (int64_t word = 0; word < num_words_per_row; ++word) { target_row_ptr[word] = source_row_ptr[word]; @@ -529,16 +533,16 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& // We can simply memcpy bytes of rows if their order has not changed. // - memcpy(target->mutable_data(2) + target_offsets[first_target_row_id], source.data(2), - source_offsets[num_source_rows] - source_offsets[0]); + memcpy(target->mutable_var_length_rows() + target_offsets[first_target_row_id], + source.var_length_rows(), source_offsets[num_source_rows] - source_offsets[0]); } else { int64_t target_row_offset = first_target_row_offset; - uint64_t* target_row_ptr = - reinterpret_cast<uint64_t*>(target->mutable_data(2) + target_row_offset); + uint64_t* target_row_ptr = reinterpret_cast<uint64_t*>( + target->mutable_var_length_rows() + target_row_offset); for (int64_t i = 0; i < num_source_rows; ++i) { int64_t source_row_id = source_rows_permutation[i]; const uint64_t* source_row_ptr = reinterpret_cast<const uint64_t*>( - source.data(2) + source_offsets[source_row_id]); + source.var_length_rows() + source_offsets[source_row_id]); int64_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; // Though the row offset is 64-bit, the length of a single row must be 32-bit as // required by current row table implementation. @@ -564,14 +568,18 @@ void RowArrayMerge::CopyNulls(RowTableImpl* target, const RowTableImpl& source, const int64_t* source_rows_permutation) { int64_t num_source_rows = source.length(); int num_bytes_per_row = target->metadata().null_masks_bytes_per_row; - uint8_t* target_nulls = target->null_masks() + num_bytes_per_row * first_target_row_id; + DCHECK_LE(first_target_row_id, std::numeric_limits<uint32_t>::max()); + uint8_t* target_nulls = + target->mutable_null_masks(static_cast<uint32_t>(first_target_row_id)); if (!source_rows_permutation) { - memcpy(target_nulls, source.null_masks(), num_bytes_per_row * num_source_rows); + memcpy(target_nulls, source.null_masks(/*row_id=*/0), + num_bytes_per_row * num_source_rows); } else { - for (int64_t i = 0; i < num_source_rows; ++i) { + for (uint32_t i = 0; i < num_source_rows; ++i) { int64_t source_row_id = source_rows_permutation[i]; + DCHECK_LE(source_row_id, std::numeric_limits<uint32_t>::max()); const uint8_t* source_nulls = - source.null_masks() + num_bytes_per_row * source_row_id; + source.null_masks(static_cast<uint32_t>(source_row_id)); for (int64_t byte = 0; byte < num_bytes_per_row; ++byte) { *target_nulls++ = *source_nulls++; } diff --git a/cpp/src/arrow/acero/swiss_join_avx2.cc b/cpp/src/arrow/acero/swiss_join_avx2.cc index 1d6b7eda6e6a0..deeee2a4e110d 100644 --- a/cpp/src/arrow/acero/swiss_join_avx2.cc +++ b/cpp/src/arrow/acero/swiss_join_avx2.cc @@ -16,6 +16,7 @@ // under the License. #include "arrow/acero/swiss_join_internal.h" +#include "arrow/compute/row/row_util_avx2_internal.h" #include "arrow/util/bit_util.h" #include "arrow/util/simd.h" @@ -46,7 +47,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu if (!is_fixed_length_column) { int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); - const uint8_t* row_ptr_base = rows.data(2); + const uint8_t* row_ptr_base = rows.var_length_rows(); const RowTableImpl::offset_type* row_offsets = rows.offsets(); auto row_offsets_i64 = reinterpret_cast<const arrow::util::int64_for_gather_t*>(row_offsets); @@ -172,7 +173,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu if (is_fixed_length_row) { // Case 3: This is a fixed length column in fixed length row // - const uint8_t* row_ptr_base = rows.data(1); + const uint8_t* row_ptr_base = rows.fixed_length_rows(/*row_id=*/0); for (int i = 0; i < num_rows / kUnroll; ++i) { // Load 8 32-bit row ids. __m256i row_id = @@ -197,7 +198,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu } else { // Case 4: This is a fixed length column in varying length row // - const uint8_t* row_ptr_base = rows.data(2); + const uint8_t* row_ptr_base = rows.var_length_rows(); const RowTableImpl::offset_type* row_offsets = rows.offsets(); auto row_offsets_i64 = reinterpret_cast<const arrow::util::int64_for_gather_t*>(row_offsets); @@ -237,31 +238,12 @@ int RowArrayAccessor::VisitNulls_avx2(const RowTableImpl& rows, int column_id, // constexpr int kUnroll = 8; - const uint8_t* null_masks = rows.null_masks(); - __m256i null_bits_per_row = - _mm256_set1_epi32(8 * rows.metadata().null_masks_bytes_per_row); - __m256i pos_after_encoding = - _mm256_set1_epi32(rows.metadata().pos_after_encoding(column_id)); + uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id); for (int i = 0; i < num_rows / kUnroll; ++i) { __m256i row_id = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_ids) + i); - __m256i bit_id = _mm256_mullo_epi32(row_id, null_bits_per_row); - bit_id = _mm256_add_epi32(bit_id, pos_after_encoding); - __m256i bytes = _mm256_i32gather_epi32(reinterpret_cast<const int*>(null_masks), - _mm256_srli_epi32(bit_id, 3), 1); - __m256i bit_in_word = _mm256_sllv_epi32( - _mm256_set1_epi32(1), _mm256_and_si256(bit_id, _mm256_set1_epi32(7))); - // `result` will contain one 32-bit word per tested null bit, either 0xffffffff if the - // null bit was set or 0 if it was unset. - __m256i result = - _mm256_cmpeq_epi32(_mm256_and_si256(bytes, bit_in_word), bit_in_word); - // NB: Be careful about sign-extension when casting the return value of - // _mm256_movemask_epi8 (signed 32-bit) to unsigned 64-bit, which will pollute the - // higher bits of the following OR. - uint32_t null_bytes_lo = static_cast<uint32_t>( - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(result)))); - uint64_t null_bytes_hi = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(result, 1))); - uint64_t null_bytes = null_bytes_lo | (null_bytes_hi << 32); + __m256i null32 = GetNullBitInt32(rows, pos_after_encoding, row_id); + null32 = _mm256_cmpeq_epi32(null32, _mm256_set1_epi32(1)); + uint64_t null_bytes = arrow::compute::Cmp32To8(null32); process_8_values_fn(i * kUnroll, null_bytes); } diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h index f2f3ac5b1bf93..85f443b0323c7 100644 --- a/cpp/src/arrow/acero/swiss_join_internal.h +++ b/cpp/src/arrow/acero/swiss_join_internal.h @@ -72,7 +72,7 @@ class RowArrayAccessor { if (!is_fixed_length_column) { int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); - const uint8_t* row_ptr_base = rows.data(2); + const uint8_t* row_ptr_base = rows.var_length_rows(); const RowTableImpl::offset_type* row_offsets = rows.offsets(); uint32_t field_offset_within_row, field_length; @@ -108,22 +108,21 @@ class RowArrayAccessor { if (field_length == 0) { field_length = 1; } - uint32_t row_length = rows.metadata().fixed_length; bool is_fixed_length_row = rows.metadata().is_fixed_length; if (is_fixed_length_row) { // Case 3: This is a fixed length column in a fixed length row // - const uint8_t* row_ptr_base = rows.data(1) + field_offset_within_row; for (int i = 0; i < num_rows; ++i) { uint32_t row_id = row_ids[i]; - const uint8_t* row_ptr = row_ptr_base + row_length * row_id; + const uint8_t* row_ptr = + rows.fixed_length_rows(row_id) + field_offset_within_row; process_value_fn(i, row_ptr, field_length); } } else { // Case 4: This is a fixed length column in a varying length row // - const uint8_t* row_ptr_base = rows.data(2) + field_offset_within_row; + const uint8_t* row_ptr_base = rows.var_length_rows() + field_offset_within_row; const RowTableImpl::offset_type* row_offsets = rows.offsets(); for (int i = 0; i < num_rows; ++i) { uint32_t row_id = row_ids[i]; @@ -142,13 +141,10 @@ class RowArrayAccessor { template <class PROCESS_VALUE_FN> static void VisitNulls(const RowTableImpl& rows, int column_id, int num_rows, const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn) { - const uint8_t* null_masks = rows.null_masks(); - uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id); for (int i = 0; i < num_rows; ++i) { uint32_t row_id = row_ids[i]; - int64_t bit_id = row_id * null_mask_num_bytes * 8 + pos_after_encoding; - process_value_fn(i, bit_util::GetBit(null_masks, bit_id) ? 0xff : 0); + process_value_fn(i, rows.is_null(row_id, pos_after_encoding) ? 0xff : 0); } } diff --git a/cpp/src/arrow/compute/row/compare_internal.cc b/cpp/src/arrow/compute/row/compare_internal.cc index 5e1a87b795202..b7a01ea75ad7d 100644 --- a/cpp/src/arrow/compute/row/compare_internal.cc +++ b/cpp/src/arrow/compute/row/compare_internal.cc @@ -55,13 +55,10 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com if (!col.data(0)) { // Remove rows from the result for which the column value is a null - const uint8_t* null_masks = rows.null_masks(); - uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - int64_t bitid = irow_right * null_mask_num_bytes * 8 + null_bit_id; - match_bytevector[i] &= (bit_util::GetBit(null_masks, bitid) ? 0 : 0xff); + match_bytevector[i] &= (rows.is_null(irow_right, null_bit_id) ? 0 : 0xff); } } else if (!rows.has_any_nulls(ctx)) { // Remove rows from the result for which the column value on left side is @@ -74,15 +71,12 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0xff : 0; } } else { - const uint8_t* null_masks = rows.null_masks(); - uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; const uint8_t* non_nulls = col.data(0); ARROW_DCHECK(non_nulls); for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + null_bit_id; - int right_null = bit_util::GetBit(null_masks, bitid_right) ? 0xff : 0; + int right_null = rows.is_null(irow_right, null_bit_id) ? 0xff : 0; int left_null = bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0 : 0xff; match_bytevector[i] |= left_null & right_null; @@ -101,7 +95,7 @@ void KeyCompare::CompareBinaryColumnToRowHelper( if (is_fixed_length) { uint32_t fixed_length = rows.metadata().fixed_length; const uint8_t* rows_left = col.data(1); - const uint8_t* rows_right = rows.data(1); + const uint8_t* rows_right = rows.fixed_length_rows(/*row_id=*/0); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; // irow_right is used to index into row data so promote to the row offset type. @@ -113,7 +107,7 @@ void KeyCompare::CompareBinaryColumnToRowHelper( } else { const uint8_t* rows_left = col.data(1); const RowTableImpl::offset_type* offsets_right = rows.offsets(); - const uint8_t* rows_right = rows.data(2); + const uint8_t* rows_right = rows.var_length_rows(); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; @@ -246,7 +240,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper( const uint32_t* offsets_left = col.offsets(); const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); - const uint8_t* rows_right = rows.data(2); + const uint8_t* rows_right = rows.var_length_rows(); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; diff --git a/cpp/src/arrow/compute/row/compare_internal_avx2.cc b/cpp/src/arrow/compute/row/compare_internal_avx2.cc index 9f6e1adfe2108..8af84ac6b2f52 100644 --- a/cpp/src/arrow/compute/row/compare_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/compare_internal_avx2.cc @@ -16,6 +16,7 @@ // under the License. #include "arrow/compute/row/compare_internal.h" +#include "arrow/compute/row/row_util_avx2_internal.h" #include "arrow/compute/util.h" #include "arrow/util/bit_util.h" #include "arrow/util/simd.h" @@ -49,9 +50,6 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( if (!col.data(0)) { // Remove rows from the result for which the column value is a null - const uint8_t* null_masks = rows.null_masks(); - uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; - uint32_t num_processed = 0; constexpr uint32_t unroll = 8; for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { @@ -64,21 +62,9 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( irow_right = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i); } - __m256i bitid = - _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); - bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id)); - __m256i right = - _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); - right = _mm256_and_si256( - _mm256_set1_epi32(1), - _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7)))); + __m256i right = GetNullBitInt32(rows, null_bit_id, irow_right); __m256i cmp = _mm256_cmpeq_epi32(right, _mm256_setzero_si256()); - uint32_t result_lo = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); - uint32_t result_hi = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); - reinterpret_cast<uint64_t*>(match_bytevector)[i] &= - result_lo | (static_cast<uint64_t>(result_hi) << 32); + reinterpret_cast<uint64_t*>(match_bytevector)[i] &= Cmp32To8(cmp); } num_processed = num_rows_to_compare / unroll * unroll; return num_processed; @@ -107,18 +93,11 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128); cmp = _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), bits); } - uint32_t result_lo = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); - uint32_t result_hi = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); - reinterpret_cast<uint64_t*>(match_bytevector)[i] &= - result_lo | (static_cast<uint64_t>(result_hi) << 32); - num_processed = num_rows_to_compare / unroll * unroll; + reinterpret_cast<uint64_t*>(match_bytevector)[i] &= Cmp32To8(cmp); } + num_processed = num_rows_to_compare / unroll * unroll; return num_processed; } else { - const uint8_t* null_masks = rows.null_masks(); - uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; const uint8_t* non_nulls = col.data(0); ARROW_DCHECK(non_nulls); @@ -147,29 +126,11 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( left_null = _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), _mm256_setzero_si256()); } - __m256i bitid = - _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); - bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id)); - __m256i right = - _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); - right = _mm256_and_si256( - _mm256_set1_epi32(1), - _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7)))); + __m256i right = GetNullBitInt32(rows, null_bit_id, irow_right); __m256i right_null = _mm256_cmpeq_epi32(right, _mm256_set1_epi32(1)); - uint64_t left_null_64 = - static_cast<uint32_t>(_mm256_movemask_epi8( - _mm256_cvtepi32_epi64(_mm256_castsi256_si128(left_null)))) | - (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(left_null, 1))))) - << 32); - - uint64_t right_null_64 = - static_cast<uint32_t>(_mm256_movemask_epi8( - _mm256_cvtepi32_epi64(_mm256_castsi256_si128(right_null)))) | - (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_null, 1))))) - << 32); + uint64_t left_null_64 = Cmp32To8(left_null); + uint64_t right_null_64 = Cmp32To8(right_null); reinterpret_cast<uint64_t*>(match_bytevector)[i] |= left_null_64 & right_null_64; reinterpret_cast<uint64_t*>(match_bytevector)[i] &= ~(left_null_64 ^ right_null_64); @@ -189,7 +150,7 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( if (is_fixed_length) { uint32_t fixed_length = rows.metadata().fixed_length; const uint8_t* rows_left = col.data(1); - const uint8_t* rows_right = rows.data(1); + const uint8_t* rows_right = rows.fixed_length_rows(/*row_id=*/0); constexpr uint32_t unroll = 8; __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { @@ -234,7 +195,7 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( } else { const uint8_t* rows_left = col.data(1); const RowTableImpl::offset_type* offsets_right = rows.offsets(); - const uint8_t* rows_right = rows.data(2); + const uint8_t* rows_right = rows.var_length_rows(); constexpr uint32_t unroll = 8; __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { @@ -321,12 +282,7 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r __m256i cmp = _mm256_cmpeq_epi32(left, right); - uint32_t result_lo = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); - uint32_t result_hi = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); - - return result_lo | (static_cast<uint64_t>(result_hi) << 32); + return Cmp32To8(cmp); } template <int column_width> @@ -372,12 +328,7 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas __m256i cmp = _mm256_cmpeq_epi32(left, right); - uint32_t result_lo = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); - uint32_t result_hi = - _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); - - return result_lo | (static_cast<uint64_t>(result_hi) << 32); + return Cmp32To8(cmp); } template <bool use_selection> @@ -402,9 +353,9 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig reinterpret_cast<const arrow::util::int64_for_gather_t*>(right_base); __m256i right_lo = _mm256_i64gather_epi64(right_base_i64, offset_right_lo, 1); __m256i right_hi = _mm256_i64gather_epi64(right_base_i64, offset_right_hi, 1); - uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo)); - uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi)); - return result_lo | (static_cast<uint64_t>(result_hi) << 32); + __m256i cmp_lo = _mm256_cmpeq_epi64(left_lo, right_lo); + __m256i cmp_hi = _mm256_cmpeq_epi64(left_hi, right_hi); + return Cmp64To8(cmp_lo, cmp_hi); } template <bool use_selection> @@ -554,7 +505,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( const uint32_t* offsets_left = col.offsets(); const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); - const uint8_t* rows_right = rows.data(2); + const uint8_t* rows_right = rows.var_length_rows(); for (uint32_t i = 0; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; diff --git a/cpp/src/arrow/compute/row/compare_test.cc b/cpp/src/arrow/compute/row/compare_test.cc index 5e8ee7c58a782..2b8f4d97561e8 100644 --- a/cpp/src/arrow/compute/row/compare_test.cc +++ b/cpp/src/arrow/compute/row/compare_test.cc @@ -327,7 +327,7 @@ TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver2GB)) { ASSERT_OK_AND_ASSIGN(RowTableImpl row_table_right, MakeRowTableFromExecBatch(batch_left)); // The row table must contain an offset buffer. - ASSERT_NE(row_table_right.data(2), NULLPTR); + ASSERT_NE(row_table_right.var_length_rows(), NULLPTR); // The whole point of this test. ASSERT_GT(row_table_right.offsets()[num_rows - 1], k2GB); @@ -387,7 +387,7 @@ TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBFixedLength)) { RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(), num_rows_row_table)); // The row table must not contain a third buffer. - ASSERT_EQ(row_table_right.data(2), NULLPTR); + ASSERT_EQ(row_table_right.var_length_rows(), NULLPTR); // The row data must be greater than 4GB. ASSERT_GT(row_table_right.buffer_size(1), k4GB); @@ -460,7 +460,7 @@ TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBVarLength)) { RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(), num_rows_row_table)); // The row table must contain an offset buffer. - ASSERT_NE(row_table_right.data(2), NULLPTR); + ASSERT_NE(row_table_right.var_length_rows(), NULLPTR); // At least the last row should be located at over 4GB. ASSERT_GT(row_table_right.offsets()[num_rows_row_table - 1], k4GB); diff --git a/cpp/src/arrow/compute/row/encode_internal.cc b/cpp/src/arrow/compute/row/encode_internal.cc index 127d43021d639..0e2720a286634 100644 --- a/cpp/src/arrow/compute/row/encode_internal.cc +++ b/cpp/src/arrow/compute/row/encode_internal.cc @@ -260,36 +260,32 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, col_prep.metadata().fixed_length == rows.metadata().fixed_length) { DCHECK_EQ(offset_within_row, 0); uint32_t row_size = rows.metadata().fixed_length; - memcpy(col_prep.mutable_data(1), rows.data(1) + start_row * row_size, - num_rows * row_size); + memcpy(col_prep.mutable_data(1), rows.fixed_length_rows(start_row), + static_cast<int64_t>(num_rows) * row_size); } else if (rows.metadata().is_fixed_length) { - uint32_t row_size = rows.metadata().fixed_length; - const uint8_t* row_base = - rows.data(1) + static_cast<RowTableImpl::offset_type>(start_row) * row_size; - row_base += offset_within_row; uint8_t* col_base = col_prep.mutable_data(1); switch (col_prep.metadata().fixed_length) { case 1: for (uint32_t i = 0; i < num_rows; ++i) { - col_base[i] = row_base[i * row_size]; + col_base[i] = *(rows.fixed_length_rows(start_row + i) + offset_within_row); } break; case 2: for (uint32_t i = 0; i < num_rows; ++i) { - reinterpret_cast<uint16_t*>(col_base)[i] = - *reinterpret_cast<const uint16_t*>(row_base + i * row_size); + reinterpret_cast<uint16_t*>(col_base)[i] = *reinterpret_cast<const uint16_t*>( + rows.fixed_length_rows(start_row + i) + offset_within_row); } break; case 4: for (uint32_t i = 0; i < num_rows; ++i) { - reinterpret_cast<uint32_t*>(col_base)[i] = - *reinterpret_cast<const uint32_t*>(row_base + i * row_size); + reinterpret_cast<uint32_t*>(col_base)[i] = *reinterpret_cast<const uint32_t*>( + rows.fixed_length_rows(start_row + i + offset_within_row)); } break; case 8: for (uint32_t i = 0; i < num_rows; ++i) { - reinterpret_cast<uint64_t*>(col_base)[i] = - *reinterpret_cast<const uint64_t*>(row_base + i * row_size); + reinterpret_cast<uint64_t*>(col_base)[i] = *reinterpret_cast<const uint64_t*>( + rows.fixed_length_rows(start_row + i) + offset_within_row); } break; default: @@ -297,7 +293,7 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, } } else { const RowTableImpl::offset_type* row_offsets = rows.offsets() + start_row; - const uint8_t* row_base = rows.data(2); + const uint8_t* row_base = rows.var_length_rows(); row_base += offset_within_row; uint8_t* col_base = col_prep.mutable_data(1); switch (col_prep.metadata().fixed_length) { @@ -343,14 +339,14 @@ void EncoderBinary::EncodeSelectedImp(uint32_t offset_within_row, RowTableImpl* if (is_fixed_length) { uint32_t row_width = rows->metadata().fixed_length; const uint8_t* src_base = col.data(1); - uint8_t* dst = rows->mutable_data(1) + offset_within_row; + uint8_t* dst = rows->mutable_fixed_length_rows(/*row_id=*/0) + offset_within_row; for (uint32_t i = 0; i < num_selected; ++i) { copy_fn(dst, src_base, selection[i]); dst += row_width; } if (col.data(0)) { const uint8_t* non_null_bits = col.data(0); - uint8_t* dst = rows->mutable_data(1) + offset_within_row; + dst = rows->mutable_fixed_length_rows(/*row_id=*/0) + offset_within_row; for (uint32_t i = 0; i < num_selected; ++i) { bool is_null = !bit_util::GetBit(non_null_bits, selection[i] + col.bit_offset(0)); if (is_null) { @@ -361,14 +357,14 @@ void EncoderBinary::EncodeSelectedImp(uint32_t offset_within_row, RowTableImpl* } } else { const uint8_t* src_base = col.data(1); - uint8_t* dst = rows->mutable_data(2) + offset_within_row; + uint8_t* dst = rows->mutable_var_length_rows() + offset_within_row; const RowTableImpl::offset_type* offsets = rows->offsets(); for (uint32_t i = 0; i < num_selected; ++i) { copy_fn(dst + offsets[i], src_base, selection[i]); } if (col.data(0)) { const uint8_t* non_null_bits = col.data(0); - uint8_t* dst = rows->mutable_data(2) + offset_within_row; + uint8_t* dst = rows->mutable_var_length_rows() + offset_within_row; const RowTableImpl::offset_type* offsets = rows->offsets(); for (uint32_t i = 0; i < num_selected; ++i) { bool is_null = !bit_util::GetBit(non_null_bits, selection[i] + col.bit_offset(0)); @@ -584,16 +580,13 @@ void EncoderBinaryPair::DecodeImp(uint32_t num_rows_to_skip, uint32_t start_row, uint8_t* dst_A = col1->mutable_data(1); uint8_t* dst_B = col2->mutable_data(1); - uint32_t fixed_length = rows.metadata().fixed_length; const RowTableImpl::offset_type* offsets; const uint8_t* src_base; if (is_row_fixed_length) { - src_base = rows.data(1) + - static_cast<RowTableImpl::offset_type>(start_row) * fixed_length + - offset_within_row; + src_base = rows.fixed_length_rows(start_row) + offset_within_row; offsets = nullptr; } else { - src_base = rows.data(2) + offset_within_row; + src_base = rows.var_length_rows() + offset_within_row; offsets = rows.offsets() + start_row; } @@ -601,6 +594,7 @@ void EncoderBinaryPair::DecodeImp(uint32_t num_rows_to_skip, uint32_t start_row, using col2_type_const = typename std::add_const<col2_type>::type; if (is_row_fixed_length) { + uint32_t fixed_length = rows.metadata().fixed_length; const uint8_t* src = src_base + num_rows_to_skip * fixed_length; for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) { reinterpret_cast<col1_type*>(dst_A)[i] = *reinterpret_cast<col1_type_const*>(src); @@ -654,7 +648,7 @@ void EncoderOffsets::Decode(uint32_t start_row, uint32_t num_rows, for (uint32_t i = 0; i < num_rows; ++i) { // Find the beginning of cumulative lengths array for next row - const uint8_t* row = rows.data(2) + row_offsets[i]; + const uint8_t* row = rows.var_length_rows() + row_offsets[i]; const uint32_t* varbinary_ends = rows.metadata().varbinary_end_array(row); // Update the offset of each column @@ -728,7 +722,7 @@ void EncoderOffsets::EncodeSelectedImp(uint32_t ivarbinary, RowTableImpl* rows, const std::vector<KeyColumnArray>& cols, uint32_t num_selected, const uint16_t* selection) { const RowTableImpl::offset_type* row_offsets = rows->offsets(); - uint8_t* row_base = rows->mutable_data(2) + + uint8_t* row_base = rows->mutable_var_length_rows() + rows->metadata().varbinary_end_array_offset + ivarbinary * sizeof(uint32_t); const uint32_t* col_offsets = cols[ivarbinary].offsets(); @@ -824,8 +818,6 @@ void EncoderNulls::Decode(uint32_t start_row, uint32_t num_rows, const RowTableI DCHECK(col.mutable_data(0) || col.metadata().is_null_type); } - const uint8_t* null_masks = rows.null_masks(); - uint32_t null_masks_bytes_per_row = rows.metadata().null_masks_bytes_per_row; for (size_t col = 0; col < cols->size(); ++col) { if ((*cols)[col].metadata().is_null_type) { continue; @@ -839,9 +831,7 @@ void EncoderNulls::Decode(uint32_t start_row, uint32_t num_rows, const RowTableI memset(non_nulls + 1, 0xff, bit_util::BytesForBits(num_rows - bits_in_first_byte)); } for (uint32_t row = 0; row < num_rows; ++row) { - uint32_t null_masks_bit_id = - (start_row + row) * null_masks_bytes_per_row * 8 + static_cast<uint32_t>(col); - bool is_set = bit_util::GetBit(null_masks, null_masks_bit_id); + bool is_set = rows.is_null(start_row + row, static_cast<uint32_t>(col)); if (is_set) { bit_util::ClearBit(non_nulls, bit_offset + row); } @@ -853,7 +843,7 @@ void EncoderVarBinary::EncodeSelected(uint32_t ivarbinary, RowTableImpl* rows, const KeyColumnArray& cols, uint32_t num_selected, const uint16_t* selection) { const RowTableImpl::offset_type* row_offsets = rows->offsets(); - uint8_t* row_base = rows->mutable_data(2); + uint8_t* row_base = rows->mutable_var_length_rows(); const uint32_t* col_offsets = cols.offsets(); const uint8_t* col_base = cols.data(2); @@ -882,7 +872,7 @@ void EncoderVarBinary::EncodeSelected(uint32_t ivarbinary, RowTableImpl* rows, void EncoderNulls::EncodeSelected(RowTableImpl* rows, const std::vector<KeyColumnArray>& cols, uint32_t num_selected, const uint16_t* selection) { - uint8_t* null_masks = rows->null_masks(); + uint8_t* null_masks = rows->mutable_null_masks(/*row_id=*/0); uint32_t null_mask_num_bytes = rows->metadata().null_masks_bytes_per_row; memset(null_masks, 0, null_mask_num_bytes * num_selected); for (size_t icol = 0; icol < cols.size(); ++icol) { diff --git a/cpp/src/arrow/compute/row/encode_internal.h b/cpp/src/arrow/compute/row/encode_internal.h index 37538fcc4b835..5ad82e0c8e749 100644 --- a/cpp/src/arrow/compute/row/encode_internal.h +++ b/cpp/src/arrow/compute/row/encode_internal.h @@ -164,11 +164,10 @@ class EncoderBinary { uint32_t col_width = col_const->metadata().fixed_length; if (is_row_fixed_length) { - uint32_t row_width = rows_const->metadata().fixed_length; for (uint32_t i = 0; i < num_rows; ++i) { const uint8_t* src; uint8_t* dst; - src = rows_const->data(1) + row_width * (start_row + i) + offset_within_row; + src = rows_const->fixed_length_rows(start_row + i) + offset_within_row; dst = col_mutable_maybe_null->mutable_data(1) + col_width * i; copy_fn(dst, src, col_width); } @@ -177,7 +176,8 @@ class EncoderBinary { for (uint32_t i = 0; i < num_rows; ++i) { const uint8_t* src; uint8_t* dst; - src = rows_const->data(2) + row_offsets[start_row + i] + offset_within_row; + src = rows_const->var_length_rows() + row_offsets[start_row + i] + + offset_within_row; dst = col_mutable_maybe_null->mutable_data(1) + col_width * i; copy_fn(dst, src, col_width); } @@ -277,7 +277,7 @@ class EncoderVarBinary { col_offset_next = col_offsets[i + 1]; RowTableImpl::offset_type row_offset = row_offsets_for_batch[i]; - const uint8_t* row = rows_const->data(2) + row_offset; + const uint8_t* row = rows_const->var_length_rows() + row_offset; uint32_t offset_within_row; uint32_t length; @@ -293,7 +293,7 @@ class EncoderVarBinary { const uint8_t* src; uint8_t* dst; - src = rows_const->data(2) + row_offset; + src = rows_const->var_length_rows() + row_offset; dst = col_mutable_maybe_null->mutable_data(2) + col_offset; copy_fn(dst, src, length); } diff --git a/cpp/src/arrow/compute/row/encode_internal_avx2.cc b/cpp/src/arrow/compute/row/encode_internal_avx2.cc index d2e317deb890c..650d24b8efc51 100644 --- a/cpp/src/arrow/compute/row/encode_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/encode_internal_avx2.cc @@ -75,14 +75,9 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows uint32_t fixed_length = rows.metadata().fixed_length; const RowTableImpl::offset_type* offsets; - const uint8_t* src_base; if (is_row_fixed_length) { - src_base = rows.data(1) + - static_cast<RowTableImpl::offset_type>(fixed_length) * start_row + - offset_within_row; offsets = nullptr; } else { - src_base = rows.data(2) + offset_within_row; offsets = rows.offsets() + start_row; } @@ -94,14 +89,15 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows for (uint32_t i = 0; i < num_rows / unroll; ++i) { const __m128i *src0, *src1, *src2, *src3; if (is_row_fixed_length) { - const uint8_t* src = src_base + (i * unroll) * fixed_length; + const uint8_t* src = + rows.fixed_length_rows(start_row + i * unroll) + offset_within_row; src0 = reinterpret_cast<const __m128i*>(src); src1 = reinterpret_cast<const __m128i*>(src + fixed_length); src2 = reinterpret_cast<const __m128i*>(src + fixed_length * 2); src3 = reinterpret_cast<const __m128i*>(src + fixed_length * 3); } else { + const uint8_t* src = rows.var_length_rows() + offset_within_row; const RowTableImpl::offset_type* row_offsets = offsets + i * unroll; - const uint8_t* src = src_base; src0 = reinterpret_cast<const __m128i*>(src + row_offsets[0]); src1 = reinterpret_cast<const __m128i*>(src + row_offsets[1]); src2 = reinterpret_cast<const __m128i*>(src + row_offsets[2]); @@ -127,7 +123,8 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows uint8_t buffer[64]; for (uint32_t i = 0; i < num_rows / unroll; ++i) { if (is_row_fixed_length) { - const uint8_t* src = src_base + (i * unroll) * fixed_length; + const uint8_t* src = + rows.fixed_length_rows(start_row + i * unroll) + offset_within_row; for (int j = 0; j < unroll; ++j) { if (col_width == 1) { reinterpret_cast<uint16_t*>(buffer)[j] = @@ -141,8 +138,8 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows } } } else { + const uint8_t* src = rows.var_length_rows() + offset_within_row; const RowTableImpl::offset_type* row_offsets = offsets + i * unroll; - const uint8_t* src = src_base; for (int j = 0; j < unroll; ++j) { if (col_width == 1) { reinterpret_cast<uint16_t*>(buffer)[j] = diff --git a/cpp/src/arrow/compute/row/row_internal.cc b/cpp/src/arrow/compute/row/row_internal.cc index aa7e62add45ff..492cc71ac49f3 100644 --- a/cpp/src/arrow/compute/row/row_internal.cc +++ b/cpp/src/arrow/compute/row/row_internal.cc @@ -406,10 +406,14 @@ bool RowTableImpl::has_any_nulls(const LightContext* ctx) const { return true; } if (num_rows_for_has_any_nulls_ < num_rows_) { - auto size_per_row = metadata().null_masks_bytes_per_row; + DCHECK_LE(num_rows_for_has_any_nulls_, std::numeric_limits<uint32_t>::max()); + int64_t num_bytes = + metadata().null_masks_bytes_per_row * (num_rows_ - num_rows_for_has_any_nulls_); + DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max()); has_any_nulls_ = !util::bit_util::are_all_bytes_zero( - ctx->hardware_flags, null_masks() + size_per_row * num_rows_for_has_any_nulls_, - static_cast<uint32_t>(size_per_row * (num_rows_ - num_rows_for_has_any_nulls_))); + ctx->hardware_flags, + null_masks(static_cast<uint32_t>(num_rows_for_has_any_nulls_)), + static_cast<uint32_t>(num_bytes)); num_rows_for_has_any_nulls_ = num_rows_; } return has_any_nulls_; diff --git a/cpp/src/arrow/compute/row/row_internal.h b/cpp/src/arrow/compute/row/row_internal.h index 3ab86fd1fc6ed..0919773a2281b 100644 --- a/cpp/src/arrow/compute/row/row_internal.h +++ b/cpp/src/arrow/compute/row/row_internal.h @@ -199,29 +199,44 @@ class ARROW_EXPORT RowTableImpl { const RowTableMetadata& metadata() const { return metadata_; } /// \brief The number of rows stored in the table int64_t length() const { return num_rows_; } - // Accessors into the table's buffers - const uint8_t* data(int i) const { - ARROW_DCHECK(i >= 0 && i < kMaxBuffers); - if (ARROW_PREDICT_TRUE(buffers_[i])) { - return buffers_[i]->data(); - } - return NULLPTR; + + const uint8_t* null_masks(uint32_t row_id) const { + return data(0) + static_cast<int64_t>(row_id) * metadata_.null_masks_bytes_per_row; } - uint8_t* mutable_data(int i) { - ARROW_DCHECK(i >= 0 && i < kMaxBuffers); - if (ARROW_PREDICT_TRUE(buffers_[i])) { - return buffers_[i]->mutable_data(); - } - return NULLPTR; + uint8_t* mutable_null_masks(uint32_t row_id) { + return mutable_data(0) + + static_cast<int64_t>(row_id) * metadata_.null_masks_bytes_per_row; + } + bool is_null(uint32_t row_id, uint32_t col_pos) const { + return bit_util::GetBit(null_masks(row_id), col_pos); } + + const uint8_t* fixed_length_rows(uint32_t row_id) const { + ARROW_DCHECK(metadata_.is_fixed_length); + return data(1) + static_cast<int64_t>(row_id) * metadata_.fixed_length; + } + uint8_t* mutable_fixed_length_rows(uint32_t row_id) { + ARROW_DCHECK(metadata_.is_fixed_length); + return mutable_data(1) + static_cast<int64_t>(row_id) * metadata_.fixed_length; + } + const offset_type* offsets() const { + ARROW_DCHECK(!metadata_.is_fixed_length); return reinterpret_cast<const offset_type*>(data(1)); } offset_type* mutable_offsets() { + ARROW_DCHECK(!metadata_.is_fixed_length); return reinterpret_cast<offset_type*>(mutable_data(1)); } - const uint8_t* null_masks() const { return null_masks_->data(); } - uint8_t* null_masks() { return null_masks_->mutable_data(); } + + const uint8_t* var_length_rows() const { + ARROW_DCHECK(!metadata_.is_fixed_length); + return data(2); + } + uint8_t* mutable_var_length_rows() { + ARROW_DCHECK(!metadata_.is_fixed_length); + return mutable_data(2); + } /// \brief True if there is a null value anywhere in the table /// @@ -237,6 +252,22 @@ class ARROW_EXPORT RowTableImpl { } private: + // Accessors into the table's buffers + const uint8_t* data(int i) const { + ARROW_DCHECK(i >= 0 && i < kMaxBuffers); + if (ARROW_PREDICT_TRUE(buffers_[i])) { + return buffers_[i]->data(); + } + return NULLPTR; + } + uint8_t* mutable_data(int i) { + ARROW_DCHECK(i >= 0 && i < kMaxBuffers); + if (ARROW_PREDICT_TRUE(buffers_[i])) { + return buffers_[i]->mutable_data(); + } + return NULLPTR; + } + /// \brief Resize the fixed length buffers to store `num_extra_rows` more rows. The /// fixed length buffers are buffers_[0] for null masks, buffers_[1] for row data if the /// row is fixed length, or for row offsets otherwise. diff --git a/cpp/src/arrow/compute/row/row_test.cc b/cpp/src/arrow/compute/row/row_test.cc index 5057ce91b5bea..49d8f2a9afe14 100644 --- a/cpp/src/arrow/compute/row/row_test.cc +++ b/cpp/src/arrow/compute/row/row_test.cc @@ -92,9 +92,8 @@ TEST(RowTableMemoryConsumption, Encode) { ASSERT_OK_AND_ASSIGN(auto row_table, MakeRowTableFromColumn(col, num_rows, dt->byte_width(), /*string_alignment=*/0)); - ASSERT_NE(row_table.data(0), NULLPTR); - ASSERT_NE(row_table.data(1), NULLPTR); - ASSERT_EQ(row_table.data(2), NULLPTR); + ASSERT_NE(row_table.null_masks(/*row_id=*/0), NULLPTR); + ASSERT_NE(row_table.fixed_length_rows(/*row_id=*/0), NULLPTR); int64_t actual_null_mask_size = num_rows * row_table.metadata().null_masks_bytes_per_row; @@ -113,9 +112,9 @@ TEST(RowTableMemoryConsumption, Encode) { SCOPED_TRACE("encoding var length column of " + std::to_string(num_rows) + " rows"); ASSERT_OK_AND_ASSIGN(auto row_table, MakeRowTableFromColumn(var_length_column, num_rows, 4, 4)); - ASSERT_NE(row_table.data(0), NULLPTR); - ASSERT_NE(row_table.data(1), NULLPTR); - ASSERT_NE(row_table.data(2), NULLPTR); + ASSERT_NE(row_table.null_masks(/*row_id=*/0), NULLPTR); + ASSERT_NE(row_table.offsets(), NULLPTR); + ASSERT_NE(row_table.var_length_rows(), NULLPTR); int64_t actual_null_mask_size = num_rows * row_table.metadata().null_masks_bytes_per_row; diff --git a/cpp/src/arrow/compute/row/row_util_avx2_internal.h b/cpp/src/arrow/compute/row/row_util_avx2_internal.h new file mode 100644 index 0000000000000..a8fce7e0e8687 --- /dev/null +++ b/cpp/src/arrow/compute/row/row_util_avx2_internal.h @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include "arrow/compute/row/row_internal.h" +#include "arrow/util/simd.h" + +#if !defined(ARROW_HAVE_AVX2) && !defined(ARROW_HAVE_AVX512) && \ + !defined(ARROW_HAVE_RUNTIME_AVX2) && !defined(ARROW_HAVE_RUNTIME_AVX512) +# error "This file should only be included when AVX2 or AVX512 is enabled" +#endif + +namespace arrow::compute { + +// Convert 8 64-bit comparision results, each being 0 or -1, to 8 bytes. +inline uint64_t Cmp64To8(__m256i cmp64_lo, __m256i cmp64_hi) { + uint32_t cmp_lo = _mm256_movemask_epi8(cmp64_lo); + uint32_t cmp_hi = _mm256_movemask_epi8(cmp64_hi); + return cmp_lo | (static_cast<uint64_t>(cmp_hi) << 32); +} + +// Convert 8 32-bit comparision results, each being 0 or -1, to 8 bytes. +inline uint64_t Cmp32To8(__m256i cmp32) { + return Cmp64To8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp32)), + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp32, 1))); +} + +// Get null bits for 8 32-bit row ids in `row_id32` at `col_pos` as a vector of 32-bit +// integers. Note that the result integer is 0 if the corresponding column is not null, or +// 1 otherwise. +inline __m256i GetNullBitInt32(const RowTableImpl& rows, uint32_t col_pos, + __m256i row_id32) { + const uint8_t* null_masks = rows.null_masks(/*row_id=*/0); + __m256i null_mask_num_bits = + _mm256_set1_epi64x(rows.metadata().null_masks_bytes_per_row * 8); + __m256i row_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(row_id32)); + __m256i row_hi = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(row_id32, 1)); + __m256i bit_id_lo = _mm256_mul_epi32(row_lo, null_mask_num_bits); + __m256i bit_id_hi = _mm256_mul_epi32(row_hi, null_mask_num_bits); + bit_id_lo = _mm256_add_epi64(bit_id_lo, _mm256_set1_epi64x(col_pos)); + bit_id_hi = _mm256_add_epi64(bit_id_hi, _mm256_set1_epi64x(col_pos)); + __m128i right_lo = _mm256_i64gather_epi32(reinterpret_cast<const int*>(null_masks), + _mm256_srli_epi64(bit_id_lo, 3), 1); + __m128i right_hi = _mm256_i64gather_epi32(reinterpret_cast<const int*>(null_masks), + _mm256_srli_epi64(bit_id_hi, 3), 1); + __m256i right = _mm256_set_m128i(right_hi, right_lo); + return _mm256_and_si256(_mm256_set1_epi32(1), _mm256_srli_epi32(right, col_pos & 7)); +} + +} // namespace arrow::compute