Skip to content

Commit

Permalink
add rudimentary quality tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Dec 18, 2024
1 parent b50d696 commit 9b94bff
Showing 1 changed file with 86 additions and 4 deletions.
90 changes: 86 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_hash_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include <gtest/gtest.h>
#include <unordered_set>

#include "arrow/chunked_array.h"
#include "arrow/compute/api.h"
Expand All @@ -36,10 +37,8 @@ namespace arrow {
namespace compute {

// Fixed seed constant. Its use is not visible in this excerpt; presumably it
// seeds the random array generators so that test inputs are reproducible
// (TODO confirm against the rest of the file).
constexpr auto kSeed = 0x94378165;
// constexpr auto kArrayLengths = {0, 50, 100};
// constexpr auto kNullProbabilities = {0.0, 0.5, 1.0};
// NOTE(review): kArrayLengths and kNullProbabilities are each defined twice
// below. This looks like the removed and added sides of a diff hunk rendered
// together; only one pair of definitions can exist in the real file.
constexpr auto kArrayLengths = {5};
constexpr auto kNullProbabilities = {0.0};
constexpr auto kArrayLengths = {0, 50, 100};
constexpr auto kNullProbabilities = {0.0, 0.5, 1.0};

class TestScalarHash : public ::testing::Test {
public:
Expand Down Expand Up @@ -136,6 +135,38 @@ class TestScalarHash : public ::testing::Test {
}
}

void CheckHashQuality(const std::string& func, const std::shared_ptr<Array>& arr,
float tolerance = 1.0) {
ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func, {arr}));
auto hashes = result.make_array();

auto expected = arr->length();
if (arr->null_count()) {
expected -= (arr->null_count() - 1);
}
if (func == "hash64") {
auto hashes64 = dynamic_cast<const UInt64Array*>(hashes.get());
std::unordered_set<uint64_t> hash_set;
for (int64_t i = 0; i < hashes64->length(); ++i) {
hash_set.insert(hashes64->Value(i));
}
ASSERT_LE(hash_set.size(), expected);
ASSERT_GE(hash_set.size(), expected * tolerance);
} else if (func == "hash32") {
auto hashes32 = dynamic_cast<const UInt32Array*>(hashes.get());
std::unordered_set<uint32_t> hash_set;
for (int64_t i = 0; i < hashes32->length(); ++i) {
if (hashes32->IsValid(i)) {
hash_set.insert(hashes32->Value(i));
}
}
ASSERT_LE(hash_set.size(), expected);
ASSERT_GE(hash_set.size(), expected * tolerance);
} else {
FAIL() << "Unknown function: " << func;
}
}

void CheckPrimitive(const std::string& func, const std::shared_ptr<Array>& arr) {
ASSERT_OK_AND_ASSIGN(Datum hash_result, CallFunction(func, {arr}));
CheckDeterministic(func, arr);
Expand Down Expand Up @@ -394,5 +425,56 @@ TEST_F(TestScalarHash, RandomMap) {
}
}

// copied from cpp/src/arrow/util/hashing_test.cc
//
// Return the set {0, 1, ..., n_values - 1}, each element converted to
// `Integer`. A debug check verifies that the conversion kept all values
// distinct, i.e. that the set ends up with exactly `n_values` elements.
template <typename Integer>
static std::unordered_set<Integer> MakeSequentialIntegers(int32_t n_values) {
  std::unordered_set<Integer> sequence;
  sequence.reserve(n_values);

  int32_t next = 0;
  while (next < n_values) {
    sequence.insert(static_cast<Integer>(next));
    ++next;
  }

  DCHECK_EQ(sequence.size(), static_cast<uint32_t>(n_values));
  return sequence;
}

// copied from cpp/src/arrow/util/hashing_test.cc
//
// Return a set of exactly `n_values` distinct strings. Each candidate string
// is 0 to 24 bytes long and made of characters in the range '0'..'z'.
// Candidates are drawn from a default_random_engine seeded with 42 until the
// set has reached the requested size; duplicates are absorbed by the set.
static std::unordered_set<std::string> MakeDistinctStrings(int32_t n_values) {
  const auto wanted = static_cast<uint32_t>(n_values);

  std::unordered_set<std::string> distinct;
  distinct.reserve(n_values);

  std::default_random_engine rng(42);
  std::uniform_int_distribution<int32_t> pick_length(0, 24);
  std::uniform_int_distribution<uint32_t> pick_char('0', 'z');

  while (distinct.size() < wanted) {
    // Draw the length first, then one character per position, in order.
    const int32_t n_chars = pick_length(rng);
    std::string candidate(n_chars, 'X');
    for (char& ch : candidate) {
      ch = static_cast<uint8_t>(pick_char(rng));
    }
    distinct.insert(std::move(candidate));
  }
  return distinct;
}

// Run the hash quality check for both hash functions on two inputs made of
// distinct values: 100000 sequential int32 values and 10000 distinct strings.
// CheckHashQuality is called with its default tolerance.
TEST_F(TestScalarHash, HashQuality) {
  for (const char* const& func : {"hash32", "hash64"}) {
    {
      const auto distinct_ints = MakeSequentialIntegers<int32_t>(100000);
      std::vector<int32_t> int_values(distinct_ints.begin(), distinct_ints.end());
      std::shared_ptr<Array> int_array;
      arrow::ArrayFromVector<Int32Type>(int_values, &int_array);
      CheckHashQuality(func, int_array);
    }
    {
      const auto distinct_strings = MakeDistinctStrings(10000);
      std::vector<std::string> string_values(distinct_strings.begin(),
                                             distinct_strings.end());
      std::shared_ptr<Array> string_array;
      arrow::ArrayFromVector<StringType>(string_values, &string_array);
      CheckHashQuality(func, string_array);
    }
  }
}

} // namespace compute
} // namespace arrow

0 comments on commit 9b94bff

Please sign in to comment.