GH-45334: [C++][Acero] Fix swiss join overflow issues in row offset calculation for fixed length and null masks (apache#45336)

### Rationale for this change

apache#45334 
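
For context, the overflow pattern (as I understand it from the issue) is a 32-bit multiplication of a row id by a per-row byte width that wraps before being widened. A minimal, self-contained sketch with illustrative values, not the actual Arrow code:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative values: a row id near the row counts used in the new test,
  // times a per-row byte width of 8.
  uint32_t row_id = 720898048;
  uint32_t fixed_length = 8;

  // Buggy: the multiplication is performed in 32 bits and wraps before the
  // result is widened to 64 bits.
  uint64_t bad_offset = row_id * fixed_length;

  // Fixed: widen one operand first so the product cannot wrap.
  uint64_t good_offset = static_cast<uint64_t>(row_id) * fixed_length;

  std::cout << bad_offset << " vs " << good_offset << std::endl;
  // Prints "1472217088 vs 5767184384": the wrapped offset points at the wrong
  // row, corrupting reads of fixed-length data and null masks.
}
```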

### What changes are included in this PR?

1. An almighty test case that can effectively reveal all the bugs mentioned in the issue;
2. Rather than only fixing the bugs directly (simply casting to 64-bit in the offending multiplications would suffice), I also refined the buffer accessors of the row table to eliminate more potential issues of the same kind (which I believe do exist); a sketch of the new accessor shape follows this list:
    1. `null_masks()` -> `null_masks(row_id)`, which does overflow-safe indexing internally;
    2. `is_null(row_id, col_pos)`, which does overflow-safe indexing and fetches the column's null bit directly;
    3. `data(1)` -> `fixed_length_rows(row_id)`, which first asserts that the row table is fixed-length, then does overflow-safe indexing internally;
    4. `data(2)` -> `var_length_rows()`, which only asserts that the row table is var-length. It is meant to be paired with `offsets()` (already 64-bit since apache#43389 );
    5. The `data(0/1/2)` members are made private.
3. The AVX2 specializations are fixed individually by using 64-bit multiplication and indexing.
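
The following is a minimal, hypothetical sketch of the refined accessor surface, not the actual declarations (the real `RowTableImpl` lives in `arrow/compute/row/row_internal.h` and differs in detail):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch only; names and layout are illustrative.
struct RowTableMetadataSketch {
  bool is_fixed_length;
  uint32_t fixed_length;              // bytes per row for fixed-length layout
  uint32_t null_masks_bytes_per_row;  // bytes of null mask per row
};

class RowTableSketch {
 public:
  using offset_type = int64_t;  // 64-bit row offsets since apache#43389

  // Overflow-safe: widen row_id before multiplying by the per-row byte width.
  const uint8_t* null_masks(uint32_t row_id) const {
    return null_masks_ +
           static_cast<int64_t>(row_id) * metadata_.null_masks_bytes_per_row;
  }

  // Overflow-safe bit lookup for a single column's null bit.
  bool is_null(uint32_t row_id, uint32_t col_pos) const {
    int64_t bit_id =
        static_cast<int64_t>(row_id) * metadata_.null_masks_bytes_per_row * 8 +
        col_pos;
    return (null_masks_[bit_id / 8] >> (bit_id % 8)) & 1;
  }

  // Asserts fixed-length layout, then does overflow-safe indexing.
  const uint8_t* fixed_length_rows(uint32_t row_id) const {
    assert(metadata_.is_fixed_length);
    return rows_ + static_cast<int64_t>(row_id) * metadata_.fixed_length;
  }

  // Asserts var-length layout; callers pair this with 64-bit offsets().
  const uint8_t* var_length_rows() const {
    assert(!metadata_.is_fixed_length);
    return rows_;
  }

 private:
  // The raw data(0/1/2) buffers are private in the refined design.
  const uint8_t* null_masks_ = nullptr;
  const uint8_t* rows_ = nullptr;
  RowTableMetadataSketch metadata_{};
};
```

The point of this shape is that every accessor taking a `row_id` widens it before multiplying, so call sites cannot reintroduce the 32-bit product.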

### Are these changes tested?

Yes.

### Are there any user-facing changes?

None.

* GitHub Issue: apache#45334

Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
zanmato1984 authored Jan 27, 2025
1 parent ac1e7ec commit b818560
Showing 14 changed files with 313 additions and 198 deletions.
99 changes: 99 additions & 0 deletions cpp/src/arrow/acero/hash_join_node_test.cc
@@ -3449,5 +3449,104 @@ TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBVarLength)) {
num_batches_left * num_rows_per_batch_left * num_batches_right);
}

// GH-45334: The row ids of the matching rows on the right side (the build side) are very
// large, causing the index calculation to overflow.
TEST(HashJoin, BuildSideLargeRowIds) {
GTEST_SKIP() << "Test disabled due to excessively time and resource consuming, "
"for local debugging only.";

// A fair number of match rows to trigger both SIMD and non-SIMD code paths.
const int64_t num_match_rows = 35;
const int64_t num_rows_per_match_batch = 35;
const int64_t num_match_batches = num_match_rows / num_rows_per_match_batch;

const int64_t num_unmatch_rows_large = 720898048;
const int64_t num_rows_per_unmatch_batch_large = 352001;
const int64_t num_unmatch_batches_large =
num_unmatch_rows_large / num_rows_per_unmatch_batch_large;

auto schema_small =
schema({field("small_key", int64()), field("small_payload", int64())});
auto schema_large =
schema({field("large_key", int64()), field("large_payload", int64())});

// A carefully chosen key value that hashes to 0xFFFFFFFE, causing the match rows to be
// placed at the high end of the row table.
const int64_t match_key = 289339070;
const int64_t match_payload = 42;

// Match arrays of length num_rows_per_match_batch.
ASSERT_OK_AND_ASSIGN(
auto match_key_arr,
Constant(MakeScalar(match_key))->Generate(num_rows_per_match_batch));
ASSERT_OK_AND_ASSIGN(
auto match_payload_arr,
Constant(MakeScalar(match_payload))->Generate(num_rows_per_match_batch));
// Append 1 row of null to trigger null processing code paths.
ASSERT_OK_AND_ASSIGN(auto null_arr, MakeArrayOfNull(int64(), 1));
ASSERT_OK_AND_ASSIGN(match_key_arr, Concatenate({match_key_arr, null_arr}));
ASSERT_OK_AND_ASSIGN(match_payload_arr, Concatenate({match_payload_arr, null_arr}));
// Match batch.
ExecBatch match_batch({match_key_arr, match_payload_arr}, num_rows_per_match_batch + 1);

// Small batch.
ExecBatch batch_small = match_batch;

// Large unmatch batches.
const int64_t seed = 42;
std::vector<ExecBatch> unmatch_batches_large;
unmatch_batches_large.reserve(num_unmatch_batches_large);
ASSERT_OK_AND_ASSIGN(auto unmatch_payload_arr_large,
MakeArrayOfNull(int64(), num_rows_per_unmatch_batch_large));
int64_t unmatch_range_per_batch =
(std::numeric_limits<int64_t>::max() - match_key) / num_unmatch_batches_large;
for (int i = 0; i < num_unmatch_batches_large; ++i) {
auto unmatch_key_arr_large = RandomArrayGenerator(seed).Int64(
num_rows_per_unmatch_batch_large,
/*min=*/match_key + 1 + i * unmatch_range_per_batch,
/*max=*/match_key + 1 + (i + 1) * unmatch_range_per_batch);
unmatch_batches_large.push_back(
ExecBatch({unmatch_key_arr_large, unmatch_payload_arr_large},
num_rows_per_unmatch_batch_large));
}
// Large match batch.
ExecBatch match_batch_large = match_batch;

// Batches with schemas.
auto batches_small = BatchesWithSchema{
std::vector<ExecBatch>(num_match_batches, batch_small), schema_small};
auto batches_large = BatchesWithSchema{std::move(unmatch_batches_large), schema_large};
for (int i = 0; i < num_match_batches; i++) {
batches_large.batches.push_back(match_batch_large);
}

Declaration source_small{
"exec_batch_source",
ExecBatchSourceNodeOptions(batches_small.schema, batches_small.batches)};
Declaration source_large{
"exec_batch_source",
ExecBatchSourceNodeOptions(batches_large.schema, batches_large.batches)};

HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"small_key"},
/*right_keys=*/{"large_key"});
Declaration join{
"hashjoin", {std::move(source_small), std::move(source_large)}, join_opts};

// Join should emit num_match_rows * num_match_rows rows.
ASSERT_OK_AND_ASSIGN(auto batches_result, DeclarationToExecBatches(std::move(join)));
Declaration result{"exec_batch_source",
ExecBatchSourceNodeOptions(std::move(batches_result.schema),
std::move(batches_result.batches))};
AssertRowCountEq(result, num_match_rows * num_match_rows);

// All rows should be match_key/payload.
auto predicate = and_({equal(field_ref("small_key"), literal(match_key)),
equal(field_ref("small_payload"), literal(match_payload)),
equal(field_ref("large_key"), literal(match_key)),
equal(field_ref("large_payload"), literal(match_payload))});
Declaration filter{"filter", {result}, FilterNodeOptions{std::move(predicate)}};
AssertRowCountEq(std::move(filter), num_match_rows * num_match_rows);
}

} // namespace acero
} // namespace arrow
36 changes: 22 additions & 14 deletions cpp/src/arrow/acero/swiss_join.cc
@@ -477,14 +477,15 @@ void RowArrayMerge::CopyFixedLength(RowTableImpl* target, const RowTableImpl& so
const int64_t* source_rows_permutation) {
int64_t num_source_rows = source.length();

int64_t fixed_length = target->metadata().fixed_length;
uint32_t fixed_length = target->metadata().fixed_length;

// Permutation of source rows is optional. Without permutation all that is
// needed is memcpy.
//
if (!source_rows_permutation) {
memcpy(target->mutable_data(1) + fixed_length * first_target_row_id, source.data(1),
fixed_length * num_source_rows);
DCHECK_LE(first_target_row_id, std::numeric_limits<uint32_t>::max());
memcpy(target->mutable_fixed_length_rows(static_cast<uint32_t>(first_target_row_id)),
source.fixed_length_rows(/*row_id=*/0), fixed_length * num_source_rows);
} else {
// Row length must be a multiple of 64-bits due to enforced alignment.
// Loop for each output row copying a fixed number of 64-bit words.
@@ -494,10 +495,13 @@ void RowArrayMerge::CopyFixedLength(RowTableImpl* target, const RowTableImpl& so
int64_t num_words_per_row = fixed_length / sizeof(uint64_t);
for (int64_t i = 0; i < num_source_rows; ++i) {
int64_t source_row_id = source_rows_permutation[i];
DCHECK_LE(source_row_id, std::numeric_limits<uint32_t>::max());
const uint64_t* source_row_ptr = reinterpret_cast<const uint64_t*>(
source.data(1) + fixed_length * source_row_id);
source.fixed_length_rows(static_cast<uint32_t>(source_row_id)));
int64_t target_row_id = first_target_row_id + i;
DCHECK_LE(target_row_id, std::numeric_limits<uint32_t>::max());
uint64_t* target_row_ptr = reinterpret_cast<uint64_t*>(
target->mutable_data(1) + fixed_length * (first_target_row_id + i));
target->mutable_fixed_length_rows(static_cast<uint32_t>(target_row_id)));

for (int64_t word = 0; word < num_words_per_row; ++word) {
target_row_ptr[word] = source_row_ptr[word];
Expand Down Expand Up @@ -529,16 +533,16 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl&

// We can simply memcpy bytes of rows if their order has not changed.
//
memcpy(target->mutable_data(2) + target_offsets[first_target_row_id], source.data(2),
source_offsets[num_source_rows] - source_offsets[0]);
memcpy(target->mutable_var_length_rows() + target_offsets[first_target_row_id],
source.var_length_rows(), source_offsets[num_source_rows] - source_offsets[0]);
} else {
int64_t target_row_offset = first_target_row_offset;
uint64_t* target_row_ptr =
reinterpret_cast<uint64_t*>(target->mutable_data(2) + target_row_offset);
uint64_t* target_row_ptr = reinterpret_cast<uint64_t*>(
target->mutable_var_length_rows() + target_row_offset);
for (int64_t i = 0; i < num_source_rows; ++i) {
int64_t source_row_id = source_rows_permutation[i];
const uint64_t* source_row_ptr = reinterpret_cast<const uint64_t*>(
source.data(2) + source_offsets[source_row_id]);
source.var_length_rows() + source_offsets[source_row_id]);
int64_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id];
// Though the row offset is 64-bit, the length of a single row must be 32-bit as
// required by the current row table implementation.
@@ -564,14 +568,18 @@ void RowArrayMerge::CopyNulls(RowTableImpl* target, const RowTableImpl& source,
const int64_t* source_rows_permutation) {
int64_t num_source_rows = source.length();
int num_bytes_per_row = target->metadata().null_masks_bytes_per_row;
uint8_t* target_nulls = target->null_masks() + num_bytes_per_row * first_target_row_id;
DCHECK_LE(first_target_row_id, std::numeric_limits<uint32_t>::max());
uint8_t* target_nulls =
target->mutable_null_masks(static_cast<uint32_t>(first_target_row_id));
if (!source_rows_permutation) {
memcpy(target_nulls, source.null_masks(), num_bytes_per_row * num_source_rows);
memcpy(target_nulls, source.null_masks(/*row_id=*/0),
num_bytes_per_row * num_source_rows);
} else {
for (int64_t i = 0; i < num_source_rows; ++i) {
for (uint32_t i = 0; i < num_source_rows; ++i) {
int64_t source_row_id = source_rows_permutation[i];
DCHECK_LE(source_row_id, std::numeric_limits<uint32_t>::max());
const uint8_t* source_nulls =
source.null_masks() + num_bytes_per_row * source_row_id;
source.null_masks(static_cast<uint32_t>(source_row_id));
for (int64_t byte = 0; byte < num_bytes_per_row; ++byte) {
*target_nulls++ = *source_nulls++;
}
34 changes: 8 additions & 26 deletions cpp/src/arrow/acero/swiss_join_avx2.cc
@@ -16,6 +16,7 @@
// under the License.

#include "arrow/acero/swiss_join_internal.h"
#include "arrow/compute/row/row_util_avx2_internal.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/simd.h"

@@ -46,7 +47,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu

if (!is_fixed_length_column) {
int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id);
const uint8_t* row_ptr_base = rows.data(2);
const uint8_t* row_ptr_base = rows.var_length_rows();
const RowTableImpl::offset_type* row_offsets = rows.offsets();
auto row_offsets_i64 =
reinterpret_cast<const arrow::util::int64_for_gather_t*>(row_offsets);
@@ -172,7 +173,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu
if (is_fixed_length_row) {
// Case 3: This is a fixed length column in fixed length row
//
const uint8_t* row_ptr_base = rows.data(1);
const uint8_t* row_ptr_base = rows.fixed_length_rows(/*row_id=*/0);
for (int i = 0; i < num_rows / kUnroll; ++i) {
// Load 8 32-bit row ids.
__m256i row_id =
@@ -197,7 +198,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu
} else {
// Case 4: This is a fixed length column in varying length row
//
const uint8_t* row_ptr_base = rows.data(2);
const uint8_t* row_ptr_base = rows.var_length_rows();
const RowTableImpl::offset_type* row_offsets = rows.offsets();
auto row_offsets_i64 =
reinterpret_cast<const arrow::util::int64_for_gather_t*>(row_offsets);
@@ -237,31 +238,12 @@ int RowArrayAccessor::VisitNulls_avx2(const RowTableImpl& rows, int column_id,
//
constexpr int kUnroll = 8;

const uint8_t* null_masks = rows.null_masks();
__m256i null_bits_per_row =
_mm256_set1_epi32(8 * rows.metadata().null_masks_bytes_per_row);
__m256i pos_after_encoding =
_mm256_set1_epi32(rows.metadata().pos_after_encoding(column_id));
uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id);
for (int i = 0; i < num_rows / kUnroll; ++i) {
__m256i row_id = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_ids) + i);
__m256i bit_id = _mm256_mullo_epi32(row_id, null_bits_per_row);
bit_id = _mm256_add_epi32(bit_id, pos_after_encoding);
__m256i bytes = _mm256_i32gather_epi32(reinterpret_cast<const int*>(null_masks),
_mm256_srli_epi32(bit_id, 3), 1);
__m256i bit_in_word = _mm256_sllv_epi32(
_mm256_set1_epi32(1), _mm256_and_si256(bit_id, _mm256_set1_epi32(7)));
// `result` will contain one 32-bit word per tested null bit, either 0xffffffff if the
// null bit was set or 0 if it was unset.
__m256i result =
_mm256_cmpeq_epi32(_mm256_and_si256(bytes, bit_in_word), bit_in_word);
// NB: Be careful about sign-extension when casting the return value of
// _mm256_movemask_epi8 (signed 32-bit) to unsigned 64-bit, which will pollute the
// higher bits of the following OR.
uint32_t null_bytes_lo = static_cast<uint32_t>(
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(result))));
uint64_t null_bytes_hi =
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(result, 1)));
uint64_t null_bytes = null_bytes_lo | (null_bytes_hi << 32);
__m256i null32 = GetNullBitInt32(rows, pos_after_encoding, row_id);
null32 = _mm256_cmpeq_epi32(null32, _mm256_set1_epi32(1));
uint64_t null_bytes = arrow::compute::Cmp32To8(null32);

process_8_values_fn(i * kUnroll, null_bytes);
}
14 changes: 5 additions & 9 deletions cpp/src/arrow/acero/swiss_join_internal.h
@@ -72,7 +72,7 @@ class RowArrayAccessor {

if (!is_fixed_length_column) {
int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id);
const uint8_t* row_ptr_base = rows.data(2);
const uint8_t* row_ptr_base = rows.var_length_rows();
const RowTableImpl::offset_type* row_offsets = rows.offsets();
uint32_t field_offset_within_row, field_length;

@@ -108,22 +108,21 @@ class RowArrayAccessor {
if (field_length == 0) {
field_length = 1;
}
uint32_t row_length = rows.metadata().fixed_length;

bool is_fixed_length_row = rows.metadata().is_fixed_length;
if (is_fixed_length_row) {
// Case 3: This is a fixed length column in a fixed length row
//
const uint8_t* row_ptr_base = rows.data(1) + field_offset_within_row;
for (int i = 0; i < num_rows; ++i) {
uint32_t row_id = row_ids[i];
const uint8_t* row_ptr = row_ptr_base + row_length * row_id;
const uint8_t* row_ptr =
rows.fixed_length_rows(row_id) + field_offset_within_row;
process_value_fn(i, row_ptr, field_length);
}
} else {
// Case 4: This is a fixed length column in a varying length row
//
const uint8_t* row_ptr_base = rows.data(2) + field_offset_within_row;
const uint8_t* row_ptr_base = rows.var_length_rows() + field_offset_within_row;
const RowTableImpl::offset_type* row_offsets = rows.offsets();
for (int i = 0; i < num_rows; ++i) {
uint32_t row_id = row_ids[i];
@@ -142,13 +141,10 @@
template <class PROCESS_VALUE_FN>
static void VisitNulls(const RowTableImpl& rows, int column_id, int num_rows,
const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn) {
const uint8_t* null_masks = rows.null_masks();
uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id);
for (int i = 0; i < num_rows; ++i) {
uint32_t row_id = row_ids[i];
int64_t bit_id = row_id * null_mask_num_bytes * 8 + pos_after_encoding;
process_value_fn(i, bit_util::GetBit(null_masks, bit_id) ? 0xff : 0);
process_value_fn(i, rows.is_null(row_id, pos_after_encoding) ? 0xff : 0);
}
}

16 changes: 5 additions & 11 deletions cpp/src/arrow/compute/row/compare_internal.cc
@@ -55,13 +55,10 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com

if (!col.data(0)) {
// Remove rows from the result for which the column value is a null
const uint8_t* null_masks = rows.null_masks();
uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
int64_t bitid = irow_right * null_mask_num_bytes * 8 + null_bit_id;
match_bytevector[i] &= (bit_util::GetBit(null_masks, bitid) ? 0 : 0xff);
match_bytevector[i] &= (rows.is_null(irow_right, null_bit_id) ? 0 : 0xff);
}
} else if (!rows.has_any_nulls(ctx)) {
// Remove rows from the result for which the column value on left side is
@@ -74,15 +71,12 @@
bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0xff : 0;
}
} else {
const uint8_t* null_masks = rows.null_masks();
uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
const uint8_t* non_nulls = col.data(0);
ARROW_DCHECK(non_nulls);
for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + null_bit_id;
int right_null = bit_util::GetBit(null_masks, bitid_right) ? 0xff : 0;
int right_null = rows.is_null(irow_right, null_bit_id) ? 0xff : 0;
int left_null =
bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0 : 0xff;
match_bytevector[i] |= left_null & right_null;
@@ -101,7 +95,7 @@
if (is_fixed_length) {
uint32_t fixed_length = rows.metadata().fixed_length;
const uint8_t* rows_left = col.data(1);
const uint8_t* rows_right = rows.data(1);
const uint8_t* rows_right = rows.fixed_length_rows(/*row_id=*/0);
for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
// irow_right is used to index into row data so promote to the row offset type.
@@ -113,7 +107,7 @@
} else {
const uint8_t* rows_left = col.data(1);
const RowTableImpl::offset_type* offsets_right = rows.offsets();
const uint8_t* rows_right = rows.data(2);
const uint8_t* rows_right = rows.var_length_rows();
for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
@@ -246,7 +240,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper(
const uint32_t* offsets_left = col.offsets();
const RowTableImpl::offset_type* offsets_right = rows.offsets();
const uint8_t* rows_left = col.data(2);
const uint8_t* rows_right = rows.data(2);
const uint8_t* rows_right = rows.var_length_rows();
for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];