From 862aafa3768049ae7c66c7c22ae3815f2d57119e Mon Sep 17 00:00:00 2001 From: Larry Wang Date: Thu, 1 Aug 2024 18:02:18 -0400 Subject: [PATCH] implement inexact dispatch on guarantee and value set types --- cpp/src/arrow/compute/expression.cc | 67 +++++++++++++----------- cpp/src/arrow/compute/expression_test.cc | 41 +++++++++------ 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index d72dc6d34e0e1..85361623c3ea4 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -28,6 +28,7 @@ #include "arrow/compute/expression_internal.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/util.h" +#include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/kernels/set_lookup_internal.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" @@ -1185,9 +1186,11 @@ Result> PrepareIsInValueSet(std::shared_ptr value_ /// /// \pre `is_in_call` is a call to the `is_in` function /// \return the value set to be simplified, guaranteed to be sorted with no -/// duplicate or null values +/// duplicate or null values and cast to the given type Result> GetIsInValueSetForSimplification( - const Expression::Call* is_in_call, SimplificationContext& context) { + const Expression::Call* is_in_call, + const TypeHolder& type, + SimplificationContext& context) { DCHECK_EQ(is_in_call->function_name, "is_in"); std::shared_ptr& value_set = context.is_in_value_sets[is_in_call]; if (!value_set) { @@ -1202,6 +1205,11 @@ Result> GetIsInValueSetForSimplification( ARROW_ASSIGN_OR_RAISE(state->sorted_and_unique_value_set, PrepareIsInValueSet(unprepared_value_set)); } + if (!state->sorted_and_unique_value_set->type()->Equals(*type)) { + ARROW_ASSIGN_OR_RAISE( + state->sorted_and_unique_value_set, + Cast(*state->sorted_and_unique_value_set, type, CastOptions::Safe())); + } value_set = state->sorted_and_unique_value_set; } return value_set; @@ -1317,25 +1325,17 @@ struct Inequality { /// \return a simplified value set, or a bool if the simplification of the value set /// means the whole is_in expr can become a boolean literal. template - static std::variant, bool> SimplifyIsInValueSet( + static Result, bool>> SimplifyIsInValueSet( const Inequality& guarantee, std::shared_ptr value_set) { using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - using CType = decltype(checked_pointer_cast(value_set)->GetView(0)); DCHECK(guarantee.bound.is_scalar()); - DCHECK_EQ(guarantee.bound.type()->id(), value_set->type_id()); if (value_set->length() == 0) return false; - CType bound; - if constexpr (std::is_same_v, - typename ScalarType::ValueType>) { - bound = guarantee.bound.scalar_as().view(); - } else { - bound = guarantee.bound.scalar_as().value; - } - + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar_bound, + guarantee.bound.scalar()->CastTo(value_set->type())); + auto bound = internal::UnboxScalar::Unbox(*scalar_bound); auto compare = [&bound, &value_set](size_t i) -> Comparison::type { DCHECK(value_set->IsValid(i)); auto value = checked_pointer_cast(value_set)->GetView(i); @@ -1378,7 +1378,7 @@ struct Inequality { case Comparison::NOT_EQUAL: case Comparison::NA: DCHECK(false); - break; + return Status::Invalid("Invalid comparison"); } if (value_set->length() == 0) return false; @@ -1412,27 +1412,29 @@ struct Inequality { if (*lhs.field_ref() != guarantee.target) return std::nullopt; auto options = checked_pointer_cast(is_in_call->options); - Type::type type = options->value_set.type()->id(); - - // For now, we abort simplification if the guarantee bound's type does not - // exactly match the value set's type. - if (guarantee.bound.type()->id() != type) return std::nullopt; + std::array types{guarantee.bound.type().get(), + options->value_set.type().get()}; + TypeHolder cmp_type; + if (types[0] == types[1]) cmp_type = types[0]; + if (!cmp_type) cmp_type = internal::CommonNumeric(types.data(), types.size()); + if (!cmp_type) cmp_type = internal::CommonTemporal(types.data(), types.size()); + if (!cmp_type) cmp_type = internal::CommonBinary(types.data(), types.size()); + if (!cmp_type) return std::nullopt; std::variant, bool> result; - auto simplify_value_set = [&](auto type) -> Status { - using T = decltype(type); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr value_set, - GetIsInValueSetForSimplification(is_in_call, context)); - result = SimplifyIsInValueSet(guarantee, value_set); - return Status::OK(); - }; -#define CASE(TYPE_CLASS) \ - case TYPE_CLASS##Type::type_id: \ - RETURN_NOT_OK(simplify_value_set(TYPE_CLASS##Type{})); \ - break; +#define CASE(TYPE_CLASS) \ + case TYPE_CLASS##Type::type_id: { \ + ARROW_ASSIGN_OR_RAISE( \ + std::shared_ptr value_set, \ + GetIsInValueSetForSimplification(is_in_call, cmp_type, context)); \ + ARROW_ASSIGN_OR_RAISE( \ + result, \ + SimplifyIsInValueSet(guarantee, value_set)); \ + break; \ + } - switch (type) { + switch (cmp_type.id()) { CASE(UInt8) CASE(Int8) CASE(UInt16) @@ -1452,6 +1454,7 @@ struct Inequality { CASE(String) CASE(LargeString) CASE(StringView) + CASE(FixedSizeBinary) default: return std::nullopt; } diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index c13fd8b41ab96..c79d180b674a7 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -1617,58 +1617,67 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) { } TEST(Expression, SimplifyIsIn) { - auto is_in = [](Expression field, std::string json_array) { - SetLookupOptions options{ArrayFromJSON(int32(), json_array), + auto is_in = [](Expression field, std::shared_ptr value_set_type, + std::string json_array) { + SetLookupOptions options{ArrayFromJSON(value_set_type, json_array), SetLookupOptions::MATCH}; return call("is_in", {field}, options); }; - Simplify{is_in(field_ref("i32"), "[]")} + Simplify{is_in(field_ref("i32"), int32(), "[]")} .WithGuarantee(greater(field_ref("i32"), literal(2))) .Expect(false); - Simplify{is_in(field_ref("i32"), "[null]")} + Simplify{is_in(field_ref("i32"), int32(), "[null]")} .WithGuarantee(greater(field_ref("i32"), literal(2))) .Expect(false); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee(equal(field_ref("i32"), literal(7))) .Expect(true); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee(equal(field_ref("i32"), literal(6))) .Expect(false); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee(greater(field_ref("i32"), literal(3))) - .Expect(is_in(field_ref("i32"), "[5,7,9]")); + .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]")); - Simplify{is_in(field_ref("i32"), "[1,null,3,5,null,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]")} .WithGuarantee(greater(field_ref("i32"), literal(3))) - .Expect(is_in(field_ref("i32"), "[5,7,9]")); + .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]")); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee(greater(field_ref("i32"), literal(9))) .Expect(false); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee(less_equal(field_ref("i32"), literal(0))) .Expect(false); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee(greater(field_ref("i32"), literal(0))) .ExpectUnchanged(); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee( or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32")))) .ExpectUnchanged(); - Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")} .WithGuarantee( and_(less_equal(field_ref("i32"), literal(7)), greater(field_ref("i32"), literal(4)))) - .Expect(is_in(field_ref("i32"), "[5,7]")); + .Expect(is_in(field_ref("i32"), int32(), "[5,7]")); + + Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]")} + .WithGuarantee(greater(field_ref("u32"), literal(3))) + .Expect(is_in(field_ref("u32"), int32(), "[5,7,9]")); + + Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]")} + .WithGuarantee(greater(field_ref("u32"), literal(3))) + .Expect(is_in(field_ref("u32"), int64(), "[5,7,9]")); } TEST(Expression, SimplifyThenExecute) {