From d046d7a21ac2e9538fe350c21b6887b53b12416c Mon Sep 17 00:00:00 2001 From: Krisztian Szucs Date: Thu, 19 Dec 2024 20:30:34 +0100 Subject: [PATCH] add type matcher to explicitly raise NotImplemented for unsupported types --- cpp/src/arrow/compute/kernels/scalar_hash.cc | 26 +++++++++++++++++-- .../arrow/compute/kernels/scalar_hash_test.cc | 16 ++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_hash.cc b/cpp/src/arrow/compute/kernels/scalar_hash.cc index 6fd5d0d3c9be8..8b443b1d2b474 100644 --- a/cpp/src/arrow/compute/kernels/scalar_hash.cc +++ b/cpp/src/arrow/compute/kernels/scalar_hash.cc @@ -180,6 +180,26 @@ struct FastHashScalar { } }; +class HashableMatcher : public TypeMatcher { + public: + HashableMatcher() {} + + bool Matches(const DataType& type) const override { + return !(is_union(type) || is_binary_view_like(type) || is_list_view(type) || + type.id() == Type::RUN_END_ENCODED); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + return casted != nullptr; + } + + std::string ToString() const override { return "hashable"; } +}; + const FunctionDoc hash32_doc{ "Construct a hash for every element of the input argument", ("This function is not suitable for cryptographic purposes.\n" @@ -191,6 +211,7 @@ const FunctionDoc hash64_doc{ ("This function is not suitable for cryptographic purposes.\n" "Hash results are 64-bit and emitted for each row, including NULLs."), {"hash_input"}}; + } // namespace void RegisterScalarHash(FunctionRegistry* registry) { @@ -199,9 +220,10 @@ void RegisterScalarHash(FunctionRegistry* registry) { auto hash64 = std::make_shared("hash64", Arity::Unary(), hash64_doc); // Add 32-bit and 64-bit kernels to hash32 and hash64 functions - ScalarKernel kernel32({InputType()}, OutputType(uint32()), + auto type_matcher = std::make_shared(); + ScalarKernel kernel32({InputType(type_matcher)}, OutputType(uint32()), FastHashScalar::Exec); - ScalarKernel kernel64({InputType()}, OutputType(uint64()), + ScalarKernel kernel64({InputType(type_matcher)}, OutputType(uint64()), FastHashScalar::Exec); kernel32.null_handling = NullHandling::OUTPUT_NOT_NULL; kernel64.null_handling = NullHandling::OUTPUT_NOT_NULL; diff --git a/cpp/src/arrow/compute/kernels/scalar_hash_test.cc b/cpp/src/arrow/compute/kernels/scalar_hash_test.cc index 96f1b84d51280..503717475f404 100644 --- a/cpp/src/arrow/compute/kernels/scalar_hash_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_hash_test.cc @@ -468,6 +468,22 @@ TEST_F(TestScalarHash, RandomMap) { } } +TEST_F(TestScalarHash, UnsuppoertedTypes) { + auto rand = random::RandomArrayGenerator(kSeed); + auto types = {list_view(int64()), + large_list_view(int64()), + binary_view(), + utf8_view(), + dense_union({field("a", int64()), field("b", binary())}), + sparse_union({field("a", int64()), field("b", binary())}), + run_end_encoded(int16(), utf8())}; + for (auto type : types) { + auto arr = rand.ArrayOf(type, 1, 0); + ASSERT_RAISES(NotImplemented, CallFunction("hash32", {arr})); + ASSERT_RAISES(NotImplemented, CallFunction("hash64", {arr})); + } +} + // copied from cpp/src/arrow/util/hashing_test.cc template static std::unordered_set MakeSequentialIntegers(int32_t n_values) {