Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-45167: [C++] Implement Compute Equals for List Types #45272

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value ||
static T LogicalValue(PhysicalType value) { return value; }
};

template <typename Type>
struct GetViewType<Type, enable_if_list_type<Type>> {
using T = typename TypeTraits<Type>::ScalarType;

static T LogicalValue(T value) { return value; }
};

template <>
struct GetViewType<Decimal32Type> {
using T = Decimal32;
Expand Down Expand Up @@ -322,6 +329,47 @@ struct ArrayIterator<Type, enable_if_base_binary<Type>> {
}
};

template <typename Type>
struct ArrayIterator<Type, enable_if_list_type<Type>> {
using T = typename TypeTraits<Type>::ScalarType;
using ArrayT = typename TypeTraits<Type>::ArrayType;
using offset_type = typename Type::offset_type;

const ArraySpan& arr;
const offset_type* offsets;
offset_type cur_offset;
const ArraySpan& values;
const uint8_t* data;
int64_t position;

explicit ArrayIterator(const ArraySpan& arr)
: arr(arr),
offsets(reinterpret_cast<const offset_type*>(arr.buffers[1].data)),
cur_offset(offsets[arr.offset]),
values(arr.child_data[0]),
position(arr.offset) {}

T operator()() {
offset_type next_offset = offsets[++position];
const auto len = next_offset - cur_offset;
const auto null_count = values.null_count;
const std::shared_ptr<Buffer> nulls_buffer =
null_count > 0 ? *values.buffers[0].owner : nullptr;
std::vector<std::shared_ptr<Buffer>> bufs = {nulls_buffer, *values.buffers[1].owner};
const auto child_offset = values.offset;

// TODO: do not hard code child type. also need to be aware of non-primitive children
const auto array_data = ArrayData::Make(int32(), len, std::move(bufs), null_count,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely do not want to hard code the data type here but I'm not sure what facilities can help select from the ListArray generically - do these exist already or would they need to be built out in this function?

cur_offset + child_offset);
const auto array = MakeArray(array_data);
const auto result = T{array};

cur_offset = next_offset;

return result;
}
};

template <>
struct ArrayIterator<FixedSizeBinaryType> {
const ArraySpan& arr;
Expand Down Expand Up @@ -390,6 +438,12 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
}
};

template <typename Type>
struct UnboxScalar<Type, enable_if_list_type<Type>> {
using T = typename TypeTraits<Type>::ScalarType;
static const T& Unbox(const Scalar& val) { return checked_cast<const T&>(val); }
};

template <>
struct UnboxScalar<Decimal32Type> {
using T = Decimal32;
Expand Down Expand Up @@ -1383,6 +1437,22 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
}
}

// Generate a kernel given a templated functor for list types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateList(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::LIST:
return Generator<Type0, ListType, Args...>::Exec;
case Type::LARGE_LIST:
return Generator<Type0, LargeListType, Args...>::Exec;
default:
DCHECK(false);
return nullptr;
}
}

// END of kernel generator-dispatchers
// ----------------------------------------------------------------------
// BEGIN of DispatchBest helpers
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,14 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
}

if constexpr (std::is_same_v<Op, Equal> || std::is_same_v<Op, NotEqual>) {
for (const auto id : {Type::LIST, Type::LARGE_LIST}) {
auto exec = GenerateList<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
}
}

return func;
}

Expand Down
66 changes: 66 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,72 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) {
}
}

template <typename ArrowType>
class TestCompareList : public ::testing::Test {};
TYPED_TEST_SUITE(TestCompareList, ListArrowTypes);

TYPED_TEST(TestCompareList, ArrayScalar) {
auto value_typ = std::make_shared<Int32Type>();
auto ty = std::make_shared<TypeParam>(std::move(value_typ));

std::vector<std::pair<std::string, std::string>> cases = {
{"equal", "[1, 0, 0, null]"},
{"not_equal", "[0, 1, 1, null]"},
};
auto lhs = ArrayFromJSON(ty, R"([[1, 2, 3], [4, 5, 6], [42], null])");
auto rhs = ScalarFromJSON(ty, R"([1, 2, 3])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
}
}

TYPED_TEST(TestCompareList, ScalarArray) {
auto value_typ = std::make_shared<Int32Type>();
auto ty = std::make_shared<TypeParam>(std::move(value_typ));

std::vector<std::pair<std::string, std::string>> cases = {
{"equal", "[1, 0, 0, null]"},
{"not_equal", "[0, 1, 1, null]"},
};
auto lhs = ScalarFromJSON(ty, R"([1, 2, 3])");
auto rhs = ArrayFromJSON(ty, R"([[1, 2, 3], [4, 5, 6], [42], null])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
}
}

TYPED_TEST(TestCompareList, ArrayArray) {
auto value_typ = std::make_shared<Int32Type>();
auto ty = std::make_shared<TypeParam>(std::move(value_typ));

std::vector<std::pair<std::string, std::string>> cases = {
{"equal", "[1, 0, 0, null]"},
{"not_equal", "[0, 1, 1, null]"},
};
auto lhs = ArrayFromJSON(ty, R"([[1, 2, 3], [4, 5, 6], [7], null])");
auto rhs = ArrayFromJSON(ty, R"([[1, 2, 3], [4, 5], [6, 7, 8], null])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"),
ArrayFromJSON(boolean(), "[]"));
CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"),
ArrayFromJSON(ty, R"([null])"), ArrayFromJSON(boolean(), "[null]"));

CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
}
}

// Helper to organize tests for fixed size binary comparisons
struct CompareCase {
std::shared_ptr<DataType> lhs_type;
Expand Down
Loading