Skip to content

Commit

Permalink
Ensure that slices work properly
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Dec 18, 2024
1 parent 88e7707 commit 8ced842
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
12 changes: 8 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct FastHashScalar {
static Status HashArray(const ArraySpan& array, LightContext* hash_ctx,
MemoryPool* memory_pool, c_type* out) {
// KeyColumnArray objects are being passed to the hashing utility
KeyColumnArray column;
std::vector<KeyColumnArray> columns(1);

auto type_id = array.type->id();
Expand All @@ -126,21 +127,24 @@ struct FastHashScalar {
if (is_nested(child.type->id())) {
ARROW_ASSIGN_OR_RAISE(child_hashes[i],
HashChild(array, child, hash_ctx, memory_pool));
ARROW_ASSIGN_OR_RAISE(columns[i], ToColumnArray(*child_hashes[i], hash_ctx));
ARROW_ASSIGN_OR_RAISE(column, ToColumnArray(*child_hashes[i], hash_ctx));
} else {
ARROW_ASSIGN_OR_RAISE(columns[i], ToColumnArray(child, hash_ctx));
ARROW_ASSIGN_OR_RAISE(column, ToColumnArray(child, hash_ctx));
}
columns[i] = column.Slice(array.offset, array.length);
}
Hasher::HashMultiColumn(columns, hash_ctx, out);
} else if (is_list_like(type_id)) {
auto values = array.child_data[0];
ARROW_ASSIGN_OR_RAISE(auto value_hashes,
HashChild(array, values, hash_ctx, memory_pool));
ARROW_ASSIGN_OR_RAISE(
columns[0], ToColumnArray(array, hash_ctx, value_hashes->buffers[1]->data()));
column, ToColumnArray(array, hash_ctx, value_hashes->buffers[1]->data()));
columns[0] = column.Slice(array.offset, array.length);
Hasher::HashMultiColumn(columns, hash_ctx, out);
} else {
ARROW_ASSIGN_OR_RAISE(columns[0], ToColumnArray(array, hash_ctx));
ARROW_ASSIGN_OR_RAISE(column, ToColumnArray(array, hash_ctx));
columns[0] = column.Slice(array.offset, array.length);
Hasher::HashMultiColumn(columns, hash_ctx, out);
}
return Status::OK();
Expand Down
29 changes: 26 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_hash_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ namespace arrow {
namespace compute {

constexpr auto kSeed = 0x94378165;
constexpr auto kArrayLengths = {0, 50, 100};
constexpr auto kNullProbabilities = {0.0, 0.5, 1.0};
// constexpr auto kArrayLengths = {0, 50, 100};
// constexpr auto kNullProbabilities = {0.0, 0.5, 1.0};
constexpr auto kArrayLengths = {5};
constexpr auto kNullProbabilities = {0.0};

class TestScalarHash : public ::testing::Test {
public:
Expand Down Expand Up @@ -103,12 +105,13 @@ class TestScalarHash : public ::testing::Test {
}

void CheckDeterministic(const std::string& func, const std::shared_ptr<Array>& arr) {
// Check that the hash is deterministic
// Check that the hash is deterministic between different runs
ASSERT_OK_AND_ASSIGN(Datum res1, CallFunction(func, {arr}));
ASSERT_OK_AND_ASSIGN(Datum res2, CallFunction(func, {arr}));
ValidateOutput(res1);
ValidateOutput(res2);
ASSERT_EQ(res1.length(), arr->length());
ASSERT_EQ(res2.length(), arr->length());
if (func == "hash64") {
ASSERT_EQ(res1.type()->id(), Type::UINT64);
} else if (func == "hash32") {
Expand All @@ -117,6 +120,25 @@ class TestScalarHash : public ::testing::Test {
FAIL() << "Unknown function: " << func;
}
AssertDatumsEqual(res1, res2);

// Check that slicing the array does not affect the hash
auto hashes = res1.make_array();

ARROW_LOG(INFO) << "Truth: " << hashes->ToString();

if (arr->length() >= 1) {
  // Slice with a non-zero offset and no explicit length: hashing the tail of
  // the input must give the same values as the tail of the full-array hashes.
  auto in1 = arr->Slice(1);
  ASSERT_OK_AND_ASSIGN(Datum out1, CallFunction(func, {in1}));
  ARROW_LOG(INFO) << "Result: " << out1.make_array()->ToString();
  ARROW_LOG(INFO) << "Hashes: " << hashes->Slice(1)->ToString();
  ValidateOutput(out1);
  AssertArraysEqual(*out1.make_array(), *hashes->Slice(1));
}
// This must be an independent `if`, not an `else if`: every array of length
// >= 4 also has length >= 1, so chaining the conditions would make this
// offset-and-length slice check unreachable.
if (arr->length() >= 4) {
  // Slice with both an offset and an explicit length in the middle of the array.
  auto in2 = arr->Slice(2, 2);
  ASSERT_OK_AND_ASSIGN(Datum out2, CallFunction(func, {in2}));
  ValidateOutput(out2);
  AssertArraysEqual(*out2.make_array(), *hashes->Slice(2, 2));
}
}

void CheckPrimitive(const std::string& func, const std::shared_ptr<Array>& arr) {
Expand Down Expand Up @@ -229,6 +251,7 @@ TEST_F(TestScalarHash, NumericLike) {
CheckPrimitive(func, ArrayFromJSON(type, R"([])"));
CheckPrimitive(func, ArrayFromJSON(type, R"([null])"));
CheckPrimitive(func, ArrayFromJSON(type, R"([1])"));
CheckPrimitive(func, ArrayFromJSON(type, R"([1, 2])"));
CheckPrimitive(func, ArrayFromJSON(type, R"([1, 2, null])"));
CheckPrimitive(func, ArrayFromJSON(type, R"([null, 2, 3])"));
CheckPrimitive(func, ArrayFromJSON(type, R"([1, 2, 3, 4])"));
Expand Down

0 comments on commit 8ced842

Please sign in to comment.