diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 6ece1cb444cc0..ccb89aaeae190 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -20,7 +20,6 @@ #include "arrow/compare.h" #include -#include #include #include #include @@ -32,683 +31,28 @@ #include #include "arrow/array.h" -#include "arrow/array/diff.h" #include "arrow/array/statistics.h" #include "arrow/buffer.h" -#include "arrow/scalar.h" +#include "arrow/compare_internal.h" #include "arrow/sparse_tensor.h" #include "arrow/status.h" #include "arrow/tensor.h" #include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/binary_view_util.h" -#include "arrow/util/bit_run_reader.h" #include "arrow/util/bit_util.h" -#include "arrow/util/bitmap_ops.h" -#include "arrow/util/bitmap_reader.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/float16.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging_internal.h" #include "arrow/util/macros.h" -#include "arrow/util/memory_internal.h" -#include "arrow/util/ree_util.h" -#include "arrow/visit_scalar_inline.h" -#include "arrow/visit_type_inline.h" namespace arrow { -using internal::BitmapEquals; -using internal::BitmapReader; -using internal::BitmapUInt64Reader; using internal::checked_cast; -using internal::OptionalBitmapEquals; -using util::Float16; // ---------------------------------------------------------------------- // Public method implementations namespace { -// TODO also handle HALF_FLOAT NaNs - -template -struct FloatingEqualityFlags { - static constexpr bool approximate = Approximate; - static constexpr bool nans_equal = NansEqual; - static constexpr bool signed_zeros_equal = SignedZerosEqual; -}; - -template -struct FloatingEquality { - explicit FloatingEquality(const EqualOptions& options) - : epsilon(static_cast(options.atol())) {} - - bool operator()(T x, T y) const { - if (x == y) { - return Flags::signed_zeros_equal || (std::signbit(x) == std::signbit(y)); - } - if 
(Flags::nans_equal && std::isnan(x) && std::isnan(y)) { - return true; - } - if (Flags::approximate && (fabs(x - y) <= epsilon)) { - return true; - } - return false; - } - - const T epsilon; -}; - -// For half-float equality. -template -struct FloatingEquality { - explicit FloatingEquality(const EqualOptions& options) - : epsilon(static_cast(options.atol())) {} - - bool operator()(uint16_t x, uint16_t y) const { - Float16 f_x = Float16::FromBits(x); - Float16 f_y = Float16::FromBits(y); - if (x == y) { - return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit()); - } - if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) { - return true; - } - if (Flags::approximate && (fabs(f_x.ToFloat() - f_y.ToFloat()) <= epsilon)) { - return true; - } - return false; - } - - const float epsilon; -}; - -template -struct FloatingEqualityDispatcher { - const EqualOptions& options; - bool floating_approximate; - Visitor&& visit; - - template - void DispatchL3() { - if (options.signed_zeros_equal()) { - visit(FloatingEquality>{ - options}); - } else { - visit(FloatingEquality>{ - options}); - } - } - - template - void DispatchL2() { - if (options.nans_equal()) { - DispatchL3(); - } else { - DispatchL3(); - } - } - - void Dispatch() { - if (floating_approximate) { - DispatchL2(); - } else { - DispatchL2(); - } - } -}; - -// Call `visit(equality_func)` where `equality_func` has the signature `bool(T, T)` -// and returns true if the two values compare equal. 
-template -void VisitFloatingEquality(const EqualOptions& options, bool floating_approximate, - Visitor&& visit) { - FloatingEqualityDispatcher{options, floating_approximate, - std::forward(visit)} - .Dispatch(); -} - -inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) { - if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) { - return false; - } - for (const auto& child : type.fields()) { - if (!IdentityImpliesEqualityNansNotEqual(*child->type())) { - return false; - } - } - return true; -} - -inline bool IdentityImpliesEquality(const DataType& type, const EqualOptions& options) { - if (options.nans_equal()) { - return true; - } - return IdentityImpliesEqualityNansNotEqual(type); -} - -bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, - int64_t left_start_idx, int64_t left_end_idx, - int64_t right_start_idx, const EqualOptions& options, - bool floating_approximate); - -class RangeDataEqualsImpl { - public: - // PRE-CONDITIONS: - // - the types are equal - // - the ranges are in bounds - RangeDataEqualsImpl(const EqualOptions& options, bool floating_approximate, - const ArrayData& left, const ArrayData& right, - int64_t left_start_idx, int64_t right_start_idx, - int64_t range_length) - : options_(options), - floating_approximate_(floating_approximate), - left_(left), - right_(right), - left_start_idx_(left_start_idx), - right_start_idx_(right_start_idx), - range_length_(range_length), - result_(false) {} - - bool Compare() { - // Compare null bitmaps - if (left_start_idx_ == 0 && right_start_idx_ == 0 && range_length_ == left_.length && - range_length_ == right_.length) { - // If we're comparing entire arrays, we can first compare the cached null counts - if (left_.GetNullCount() != right_.GetNullCount()) { - return false; - } - } - if (!OptionalBitmapEquals(left_.buffers[0], left_.offset + left_start_idx_, - right_.buffers[0], right_.offset + right_start_idx_, - range_length_)) { - return false; - } - // Compare values 
- return CompareWithType(*left_.type); - } - - bool CompareWithType(const DataType& type) { - result_ = true; - if (range_length_ != 0) { - ARROW_CHECK_OK(VisitTypeInline(type, this)); - } - return result_; - } - - Status Visit(const NullType&) { return Status::OK(); } - - template - enable_if_primitive_ctype Visit(const TypeClass& type) { - return ComparePrimitive(type); - } - - template - enable_if_t::value, Status> Visit(const TypeClass& type) { - return ComparePrimitive(type); - } - - Status Visit(const BooleanType&) { - const uint8_t* left_bits = left_.GetValues(1, 0); - const uint8_t* right_bits = right_.GetValues(1, 0); - auto compare_runs = [&](int64_t i, int64_t length) -> bool { - if (length <= 8) { - // Avoid the BitmapUInt64Reader overhead for very small runs - for (int64_t j = i; j < i + length; ++j) { - if (bit_util::GetBit(left_bits, left_start_idx_ + left_.offset + j) != - bit_util::GetBit(right_bits, right_start_idx_ + right_.offset + j)) { - return false; - } - } - return true; - } else if (length <= 1024) { - BitmapUInt64Reader left_reader(left_bits, left_start_idx_ + left_.offset + i, - length); - BitmapUInt64Reader right_reader(right_bits, right_start_idx_ + right_.offset + i, - length); - while (left_reader.position() < length) { - if (left_reader.NextWord() != right_reader.NextWord()) { - return false; - } - } - DCHECK_EQ(right_reader.position(), length); - } else { - // BitmapEquals is the fastest method on large runs - return BitmapEquals(left_bits, left_start_idx_ + left_.offset + i, right_bits, - right_start_idx_ + right_.offset + i, length); - } - return true; - }; - VisitValidRuns(compare_runs); - return Status::OK(); - } - - Status Visit(const FloatType& type) { return CompareFloating(type); } - - Status Visit(const DoubleType& type) { return CompareFloating(type); } - - Status Visit(const HalfFloatType& type) { return CompareFloating(type); } - - // Also matches StringType - Status Visit(const BinaryType& type) { return 
CompareBinary(type); } - - // Also matches StringViewType - Status Visit(const BinaryViewType& type) { - auto* left_values = left_.GetValues(1) + left_start_idx_; - auto* right_values = right_.GetValues(1) + right_start_idx_; - - auto* left_buffers = left_.buffers.data() + 2; - auto* right_buffers = right_.buffers.data() + 2; - VisitValidRuns([&](int64_t i, int64_t length) { - for (auto end_i = i + length; i < end_i; ++i) { - if (!util::EqualBinaryView(left_values[i], right_values[i], left_buffers, - right_buffers)) { - return false; - } - } - return true; - }); - return Status::OK(); - } - - // Also matches LargeStringType - Status Visit(const LargeBinaryType& type) { return CompareBinary(type); } - - Status Visit(const FixedSizeBinaryType& type) { - const auto byte_width = type.byte_width(); - const uint8_t* left_data = left_.GetValues(1, 0); - const uint8_t* right_data = right_.GetValues(1, 0); - - if (left_data != nullptr && right_data != nullptr) { - auto compare_runs = [&](int64_t i, int64_t length) -> bool { - return memcmp(left_data + (left_start_idx_ + left_.offset + i) * byte_width, - right_data + (right_start_idx_ + right_.offset + i) * byte_width, - length * byte_width) == 0; - }; - VisitValidRuns(compare_runs); - } else { - auto compare_runs = [&](int64_t i, int64_t length) -> bool { return true; }; - VisitValidRuns(compare_runs); - } - return Status::OK(); - } - - // Also matches MapType - Status Visit(const ListType& type) { return CompareList(type); } - - Status Visit(const LargeListType& type) { return CompareList(type); } - - Status Visit(const ListViewType& type) { return CompareListView(type); } - - Status Visit(const LargeListViewType& type) { return CompareListView(type); } - - Status Visit(const FixedSizeListType& type) { - const auto list_size = type.list_size(); - const ArrayData& left_data = *left_.child_data[0]; - const ArrayData& right_data = *right_.child_data[0]; - - auto compare_runs = [&](int64_t i, int64_t length) -> bool { - 
RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data, - (left_start_idx_ + left_.offset + i) * list_size, - (right_start_idx_ + right_.offset + i) * list_size, - length * list_size); - return impl.Compare(); - }; - VisitValidRuns(compare_runs); - return Status::OK(); - } - - Status Visit(const StructType& type) { - const int32_t num_fields = type.num_fields(); - - auto compare_runs = [&](int64_t i, int64_t length) -> bool { - for (int32_t f = 0; f < num_fields; ++f) { - RangeDataEqualsImpl impl(options_, floating_approximate_, *left_.child_data[f], - *right_.child_data[f], - left_start_idx_ + left_.offset + i, - right_start_idx_ + right_.offset + i, length); - if (!impl.Compare()) { - return false; - } - } - return true; - }; - VisitValidRuns(compare_runs); - return Status::OK(); - } - - Status Visit(const SparseUnionType& type) { - const auto& child_ids = type.child_ids(); - const int8_t* left_codes = left_.GetValues(1); - const int8_t* right_codes = right_.GetValues(1); - - // Unions don't have a null bitmap - int64_t run_start = 0; // Start index of the current run - - for (int64_t i = 0; i < range_length_; ++i) { - const auto current_type_id = left_codes[left_start_idx_ + i]; - - if (current_type_id != right_codes[right_start_idx_ + i]) { - result_ = false; - break; - } - // Check if the current element breaks the run - if (i > 0 && current_type_id != left_codes[left_start_idx_ + i - 1]) { - // Compare the previous run - const auto previous_child_num = child_ids[left_codes[left_start_idx_ + i - 1]]; - int64_t run_length = i - run_start; - - RangeDataEqualsImpl impl( - options_, floating_approximate_, *left_.child_data[previous_child_num], - *right_.child_data[previous_child_num], - left_start_idx_ + left_.offset + run_start, - right_start_idx_ + right_.offset + run_start, run_length); - - if (!impl.Compare()) { - result_ = false; - break; - } - - // Start a new run - run_start = i; - } - } - - // Handle the final run - if (result_) { 
- const auto final_child_num = child_ids[left_codes[left_start_idx_ + run_start]]; - int64_t final_run_length = range_length_ - run_start; - - RangeDataEqualsImpl impl( - options_, floating_approximate_, *left_.child_data[final_child_num], - *right_.child_data[final_child_num], left_start_idx_ + left_.offset + run_start, - right_start_idx_ + right_.offset + run_start, final_run_length); - - if (!impl.Compare()) { - result_ = false; - } - } - return Status::OK(); - } - - Status Visit(const DenseUnionType& type) { - const auto& child_ids = type.child_ids(); - const int8_t* left_codes = left_.GetValues(1); - const int8_t* right_codes = right_.GetValues(1); - const int32_t* left_offsets = left_.GetValues(2); - const int32_t* right_offsets = right_.GetValues(2); - - for (int64_t i = 0; i < range_length_; ++i) { - const auto type_id = left_codes[left_start_idx_ + i]; - if (type_id != right_codes[right_start_idx_ + i]) { - result_ = false; - break; - } - const auto child_num = child_ids[type_id]; - RangeDataEqualsImpl impl( - options_, floating_approximate_, *left_.child_data[child_num], - *right_.child_data[child_num], left_offsets[left_start_idx_ + i], - right_offsets[right_start_idx_ + i], 1); - if (!impl.Compare()) { - result_ = false; - break; - } - } - return Status::OK(); - } - - Status Visit(const DictionaryType& type) { - // Compare dictionaries - result_ &= CompareArrayRanges( - *left_.dictionary, *right_.dictionary, - /*left_start_idx=*/0, - /*left_end_idx=*/std::max(left_.dictionary->length, right_.dictionary->length), - /*right_start_idx=*/0, options_, floating_approximate_); - if (result_) { - // Compare indices - result_ &= CompareWithType(*type.index_type()); - } - return Status::OK(); - } - - Status Visit(const RunEndEncodedType& type) { - switch (type.run_end_type()->id()) { - case Type::INT16: - return CompareRunEndEncoded(); - case Type::INT32: - return CompareRunEndEncoded(); - case Type::INT64: - return CompareRunEndEncoded(); - default: - return 
Status::Invalid("invalid run ends type: ", *type.run_end_type()); - } - } - - Status Visit(const ExtensionType& type) { - // Compare storages - result_ &= CompareWithType(*type.storage_type()); - return Status::OK(); - } - - protected: - template - Status ComparePrimitive(const TypeClass&) { - const CType* left_values = left_.GetValues(1); - const CType* right_values = right_.GetValues(1); - VisitValidRuns([&](int64_t i, int64_t length) { - return memcmp(left_values + left_start_idx_ + i, - right_values + right_start_idx_ + i, length * sizeof(CType)) == 0; - }); - return Status::OK(); - } - - template - Status CompareFloating(const TypeClass&) { - using CType = typename TypeClass::c_type; - const CType* left_values = left_.GetValues(1); - const CType* right_values = right_.GetValues(1); - - auto visitor = [&](auto&& compare_func) { - VisitValues([&](int64_t i) { - const CType x = left_values[i + left_start_idx_]; - const CType y = right_values[i + right_start_idx_]; - return compare_func(x, y); - }); - }; - VisitFloatingEquality(options_, floating_approximate_, std::move(visitor)); - return Status::OK(); - } - - template - Status CompareBinary(const TypeClass&) { - const uint8_t* left_data = left_.GetValues(2, 0); - const uint8_t* right_data = right_.GetValues(2, 0); - - if (left_data != nullptr && right_data != nullptr) { - const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset, - int64_t length) -> bool { - return memcmp(left_data + left_offset, right_data + right_offset, length) == 0; - }; - CompareWithOffsets(1, compare_ranges); - } else { - // One of the arrays is an array of empty strings and nulls. - // We just need to compare the offsets. - // (note we must not call memcmp() with null data pointers) - CompareWithOffsets(1, [](...) 
{ return true; }); - } - return Status::OK(); - } - - template - Status CompareList(const TypeClass&) { - const ArrayData& left_data = *left_.child_data[0]; - const ArrayData& right_data = *right_.child_data[0]; - - const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset, - int64_t length) -> bool { - RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data, - left_offset, right_offset, length); - return impl.Compare(); - }; - - CompareWithOffsets(1, compare_ranges); - return Status::OK(); - } - - template - Status CompareListView(const TypeClass& type) { - const ArrayData& left_values = *left_.child_data[0]; - const ArrayData& right_values = *right_.child_data[0]; - - using offset_type = typename TypeClass::offset_type; - const auto* left_offsets = left_.GetValues(1) + left_start_idx_; - const auto* right_offsets = right_.GetValues(1) + right_start_idx_; - const auto* left_sizes = left_.GetValues(2) + left_start_idx_; - const auto* right_sizes = right_.GetValues(2) + right_start_idx_; - - auto compare_view = [&](int64_t i, int64_t length) -> bool { - for (int64_t j = i; j < i + length; ++j) { - if (left_sizes[j] != right_sizes[j]) { - return false; - } - const offset_type size = left_sizes[j]; - if (size == 0) { - continue; - } - RangeDataEqualsImpl impl(options_, floating_approximate_, left_values, - right_values, left_offsets[j], right_offsets[j], size); - if (!impl.Compare()) { - return false; - } - } - return true; - }; - VisitValidRuns(std::move(compare_view)); - return Status::OK(); - } - - template - Status CompareRunEndEncoded() { - auto left_span = ArraySpan(left_); - auto right_span = ArraySpan(right_); - left_span.SetSlice(left_.offset + left_start_idx_, range_length_); - right_span.SetSlice(right_.offset + right_start_idx_, range_length_); - const ree_util::RunEndEncodedArraySpan left(left_span); - const ree_util::RunEndEncodedArraySpan right(right_span); - - const auto& left_values = *left_.child_data[1]; - 
const auto& right_values = *right_.child_data[1]; - - auto it = ree_util::MergedRunsIterator(left, right); - for (; !it.is_end(); ++it) { - RangeDataEqualsImpl impl(options_, floating_approximate_, left_values, right_values, - it.index_into_left_array(), it.index_into_right_array(), - /*range_length=*/1); - if (!impl.Compare()) { - result_ = false; - return Status::OK(); - } - } - return Status::OK(); - } - - template - void CompareWithOffsets(int offsets_buffer_index, CompareRanges&& compare_ranges) { - const offset_type* left_offsets = - left_.GetValues(offsets_buffer_index) + left_start_idx_; - const offset_type* right_offsets = - right_.GetValues(offsets_buffer_index) + right_start_idx_; - - const auto compare_runs = [&](int64_t i, int64_t length) { - for (int64_t j = i; j < i + length; ++j) { - if (left_offsets[j + 1] - left_offsets[j] != - right_offsets[j + 1] - right_offsets[j]) { - return false; - } - } - if (!compare_ranges(left_offsets[i], right_offsets[i], - left_offsets[i + length] - left_offsets[i])) { - return false; - } - return true; - }; - - VisitValidRuns(compare_runs); - } - - template - void VisitValues(CompareValues&& compare_values) { - internal::VisitSetBitRunsVoid(left_.buffers[0], left_.offset + left_start_idx_, - range_length_, [&](int64_t position, int64_t length) { - for (int64_t i = 0; i < length; ++i) { - result_ &= compare_values(position + i); - } - }); - } - - // Visit and compare runs of non-null values - template - void VisitValidRuns(CompareRuns&& compare_runs) { - const uint8_t* left_null_bitmap = left_.GetValues(0, 0); - if (left_null_bitmap == nullptr) { - result_ = compare_runs(0, range_length_); - return; - } - internal::SetBitRunReader reader(left_null_bitmap, left_.offset + left_start_idx_, - range_length_); - while (true) { - const auto run = reader.NextRun(); - if (run.length == 0) { - return; - } - if (!compare_runs(run.position, run.length)) { - result_ = false; - return; - } - } - } - - const EqualOptions& options_; - 
const bool floating_approximate_; - const ArrayData& left_; - const ArrayData& right_; - const int64_t left_start_idx_; - const int64_t right_start_idx_; - const int64_t range_length_; - - bool result_; -}; - -bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, - int64_t left_start_idx, int64_t left_end_idx, - int64_t right_start_idx, const EqualOptions& options, - bool floating_approximate) { - if (left.type->id() != right.type->id() || - !TypeEquals(*left.type, *right.type, false /* check_metadata */)) { - return false; - } - - const int64_t range_length = left_end_idx - left_start_idx; - DCHECK_GE(range_length, 0); - if (left_start_idx + range_length > left.length) { - // Left range too small - return false; - } - if (right_start_idx + range_length > right.length) { - // Right range too small - return false; - } - if (&left == &right && left_start_idx == right_start_idx && - IdentityImpliesEquality(*left.type, options)) { - return true; - } - // Compare values - RangeDataEqualsImpl impl(options, floating_approximate, left, right, left_start_idx, - right_start_idx, range_length); - return impl.Compare(); -} - class TypeEqualsVisitor { public: explicit TypeEqualsVisitor(const DataType& right, bool check_metadata) @@ -874,282 +218,6 @@ class TypeEqualsVisitor { bool check_metadata_; bool result_; }; - -bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, - bool floating_approximate); -bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, - bool floating_approximate); - -class ScalarEqualsVisitor { - public: - // PRE-CONDITIONS: - // - the types are equal - // - the scalars are non-null - explicit ScalarEqualsVisitor(const Scalar& right, const EqualOptions& opts, - bool floating_approximate) - : right_(right), - options_(opts), - floating_approximate_(floating_approximate), - result_(false) {} - - Status Visit(const NullScalar& left) { - result_ = true; - return Status::OK(); - } - - 
Status Visit(const BooleanScalar& left) { - const auto& right = checked_cast(right_); - result_ = left.value == right.value; - return Status::OK(); - } - - template - typename std::enable_if<(is_primitive_ctype::value || - is_temporal_type::value), - Status>::type - Visit(const T& left_) { - const auto& right = checked_cast(right_); - result_ = right.value == left_.value; - return Status::OK(); - } - - Status Visit(const FloatScalar& left) { return CompareFloating(left); } - - Status Visit(const DoubleScalar& left) { return CompareFloating(left); } - - Status Visit(const HalfFloatScalar& left) { return CompareFloating(left); } - - template - enable_if_t::value, Status> Visit(const T& left) { - const auto& right = checked_cast(right_); - result_ = internal::SharedPtrEquals(left.value, right.value); - return Status::OK(); - } - - Status Visit(const Decimal32Scalar& left) { - const auto& right = checked_cast(right_); - result_ = left.value == right.value; - return Status::OK(); - } - - Status Visit(const Decimal64Scalar& left) { - const auto& right = checked_cast(right_); - result_ = left.value == right.value; - return Status::OK(); - } - - Status Visit(const Decimal128Scalar& left) { - const auto& right = checked_cast(right_); - result_ = left.value == right.value; - return Status::OK(); - } - - Status Visit(const Decimal256Scalar& left) { - const auto& right = checked_cast(right_); - result_ = left.value == right.value; - return Status::OK(); - } - - Status Visit(const ListScalar& left) { - const auto& right = checked_cast(right_); - result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const LargeListScalar& left) { - const auto& right = checked_cast(right_); - result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const ListViewScalar& left) { - const auto& right = checked_cast(right_); - result_ = 
ArrayEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const LargeListViewScalar& left) { - const auto& right = checked_cast(right_); - result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const MapScalar& left) { - const auto& right = checked_cast(right_); - result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const FixedSizeListScalar& left) { - const auto& right = checked_cast(right_); - result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const StructScalar& left) { - const auto& right = checked_cast(right_); - - if (right.value.size() != left.value.size()) { - result_ = false; - } else { - bool all_equals = true; - for (size_t i = 0; i < left.value.size() && all_equals; i++) { - all_equals &= ScalarEquals(*left.value[i], *right.value[i], options_, - floating_approximate_); - } - result_ = all_equals; - } - - return Status::OK(); - } - - Status Visit(const DenseUnionScalar& left) { - const auto& right = checked_cast(right_); - result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const SparseUnionScalar& left) { - const auto& right = checked_cast(right_); - result_ = ScalarEquals(*left.value[left.child_id], *right.value[right.child_id], - options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const DictionaryScalar& left) { - const auto& right = checked_cast(right_); - result_ = ScalarEquals(*left.value.index, *right.value.index, options_, - floating_approximate_) && - ArrayEquals(*left.value.dictionary, *right.value.dictionary, options_, - floating_approximate_); - return Status::OK(); - } - - Status Visit(const RunEndEncodedScalar& left) { - const auto& right = checked_cast(right_); - 
result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - Status Visit(const ExtensionScalar& left) { - const auto& right = checked_cast(right_); - result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); - return Status::OK(); - } - - bool result() const { return result_; } - - protected: - template - Status CompareFloating(const ScalarType& left) { - using CType = decltype(left.value); - const auto& right = checked_cast(right_); - - auto visitor = [&](auto&& compare_func) { - result_ = compare_func(left.value, right.value); - }; - VisitFloatingEquality(options_, floating_approximate_, std::move(visitor)); - return Status::OK(); - } - - const Scalar& right_; - const EqualOptions options_; - const bool floating_approximate_; - bool result_; -}; - -Status PrintDiff(const Array& left, const Array& right, std::ostream* os); - -Status PrintDiff(const Array& left, const Array& right, int64_t left_offset, - int64_t left_length, int64_t right_offset, int64_t right_length, - std::ostream* os) { - if (os == nullptr) { - return Status::OK(); - } - - if (!left.type()->Equals(right.type())) { - *os << "# Array types differed: " << *left.type() << " vs " << *right.type() - << std::endl; - return Status::OK(); - } - - if (left.type()->id() == Type::DICTIONARY) { - *os << "# Dictionary arrays differed" << std::endl; - - const auto& left_dict = checked_cast(left); - const auto& right_dict = checked_cast(right); - - *os << "## dictionary diff"; - auto pos = os->tellp(); - RETURN_NOT_OK(PrintDiff(*left_dict.dictionary(), *right_dict.dictionary(), os)); - if (os->tellp() == pos) { - *os << std::endl; - } - - *os << "## indices diff"; - pos = os->tellp(); - RETURN_NOT_OK(PrintDiff(*left_dict.indices(), *right_dict.indices(), os)); - if (os->tellp() == pos) { - *os << std::endl; - } - return Status::OK(); - } - - const auto left_slice = left.Slice(left_offset, left_length); - const auto right_slice = 
right.Slice(right_offset, right_length); - ARROW_ASSIGN_OR_RAISE(auto edits, - Diff(*left_slice, *right_slice, default_memory_pool())); - ARROW_ASSIGN_OR_RAISE(auto formatter, MakeUnifiedDiffFormatter(*left.type(), os)); - return formatter(*edits, *left_slice, *right_slice); -} - -Status PrintDiff(const Array& left, const Array& right, std::ostream* os) { - return PrintDiff(left, right, 0, left.length(), 0, right.length(), os); -} - -bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, - int64_t left_end_idx, int64_t right_start_idx, - const EqualOptions& options, bool floating_approximate) { - bool are_equal = - CompareArrayRanges(*left.data(), *right.data(), left_start_idx, left_end_idx, - right_start_idx, options, floating_approximate); - if (!are_equal) { - ARROW_IGNORE_EXPR(PrintDiff( - left, right, left_start_idx, left_end_idx, right_start_idx, - right_start_idx + (left_end_idx - left_start_idx), options.diff_sink())); - } - return are_equal; -} - -bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, - bool floating_approximate) { - if (left.length() != right.length()) { - ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink())); - return false; - } - return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate); -} - -bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, - bool floating_approximate) { - if (&left == &right && IdentityImpliesEquality(*left.type, options)) { - return true; - } - if (!left.type->Equals(right.type)) { - return false; - } - if (left.is_valid != right.is_valid) { - return false; - } - if (!left.is_valid) { - return true; - } - ScalarEqualsVisitor visitor(right, options, floating_approximate); - auto error = VisitScalarInline(left, &visitor); - DCHECK_OK(error); - return visitor.result(); -} - } // namespace bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, @@ -1176,6 +244,12 
@@ bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions
   return ArrayEquals(left, right, opts, floating_approximate);
 }
 
+// Return true if `left` and `right` have equal types, lengths and contents
+// according to `opts`.
+//
+// Floating-point values are always compared exactly: `floating_approximate` is
+// fixed to false here, whereas the ScalarEquals() overload just below takes it
+// from `options.use_atol()`.
+bool ArrayDataEquals(const ArrayData& left, const ArrayData& right,
+                     const EqualOptions& opts) {
+  // CompareArrayRanges() only verifies that the range [0, left.length) fits inside
+  // `right`. Without this guard, a `left` that is a strict prefix of a longer
+  // `right` would compare equal.
+  if (left.length != right.length) {
+    return false;
+  }
+  const bool floating_approximate = false;
+  return CompareArrayRanges(left, right, 0, left.length, 0, opts, floating_approximate);
+}
+
 bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) {
   return ScalarEquals(left, right, options, options.use_atol());
 }
diff --git a/cpp/src/arrow/compare_internal.h b/cpp/src/arrow/compare_internal.h
new file mode 100644
index 0000000000000..5afeca1033feb
--- /dev/null
+++ b/cpp/src/arrow/compare_internal.h
@@ -0,0 +1,962 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+
+#include "arrow/array/array_dict.h"
+#include "arrow/array/data.h"
+#include "arrow/array/diff.h"
+#include "arrow/compare.h"
+#include "arrow/scalar.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/binary_view_util.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/float16.h"
+#include "arrow/util/logging_internal.h"
+#include "arrow/util/memory_internal.h"
+#include "arrow/util/ree_util.h"
+#include "arrow/visit_scalar_inline.h"
+#include "arrow/visit_type_inline.h"
+
+namespace arrow {
+
+using internal::BitmapEquals;
+using internal::BitmapReader;
+using internal::BitmapUInt64Reader;
+using internal::checked_cast;
+using internal::OptionalBitmapEquals;
+using util::Float16;
+
+// TODO also handle HALF_FLOAT NaNs
+
+// Compile-time bundle of the three floating-point comparison policies.
+//
+// Carrying the policies as template arguments lets FloatingEquality test them
+// as constants instead of re-reading the options for every pair of values.
+template <bool Approximate, bool NansEqual, bool SignedZerosEqual>
+struct FloatingEqualityFlags {
+  static constexpr bool approximate = Approximate;
+  static constexpr bool nans_equal = NansEqual;
+  static constexpr bool signed_zeros_equal = SignedZerosEqual;
+};
+
+// Equality functor for two floating-point values of type `T`.
+//
+// `Flags` is a FloatingEqualityFlags instantiation; `epsilon` is `options.atol()`
+// converted to `T` and is only consulted when `Flags::approximate` is set.
+template <typename T, typename Flags>
+struct FloatingEquality {
+  explicit FloatingEquality(const EqualOptions& options)
+      : epsilon(static_cast<T>(options.atol())) {}
+
+  bool operator()(T x, T y) const {
+    if (x == y) {
+      // +0 == -0 holds here: only the sign bit can tell them apart
+      return Flags::signed_zeros_equal || (std::signbit(x) == std::signbit(y));
+    }
+    if (Flags::nans_equal && std::isnan(x) && std::isnan(y)) {
+      return true;
+    }
+    // std::fabs picks the overload matching `T` rather than relying on an
+    // unqualified C `fabs` being visible
+    if (Flags::approximate && (std::fabs(x - y) <= epsilon)) {
+      return true;
+    }
+    return false;
+  }
+
+  const T epsilon;
+};
+
+// For half-float equality.
+// Specialization for values passed as raw `uint16_t` bit patterns, which are decoded
+// with `Float16::FromBits` before being compared. The tolerance is kept as a `float`
+// and applied to the `ToFloat()` values.
+template <typename Flags>
+struct FloatingEquality<uint16_t, Flags> {
+  explicit FloatingEquality(const EqualOptions& options)
+      : epsilon(static_cast<float>(options.atol())) {}
+
+  bool operator()(uint16_t x, uint16_t y) const {
+    Float16 f_x = Float16::FromBits(x);
+    Float16 f_y = Float16::FromBits(y);
+    const bool x_is_nan = f_x.is_nan();
+    const bool y_is_nan = f_y.is_nan();
+    if (x_is_nan || y_is_nan) {
+      // `x` and `y` are bit patterns, not values: NaNs must be handled before any
+      // bitwise test, otherwise two NaNs with identical bits would be reported
+      // equal even when `Flags::nans_equal` is false.
+      return Flags::nans_equal && x_is_nan && y_is_nan;
+    }
+    if (x == y) {
+      // Identical bits and not NaN: same value, same sign
+      return true;
+    }
+    const float v_x = f_x.ToFloat();
+    const float v_y = f_y.ToFloat();
+    if (v_x == v_y) {
+      // Equal values with different bits (+0 vs -0): comparing the raw bits
+      // alone would wrongly report them unequal.
+      return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit());
+    }
+    if (Flags::approximate && (std::fabs(v_x - v_y) <= epsilon)) {
+      return true;
+    }
+    return false;
+  }
+
+  const float epsilon;
+};
+
+// Turns the three runtime comparison settings into template arguments.
+//
+// Dispatch() branches once on `floating_approximate`, `options.nans_equal()` and
+// `options.signed_zeros_equal()`, then calls `visit` exactly once with the
+// FloatingEquality functor instantiated for that combination.
+template <typename T, typename Visitor>
+struct FloatingEqualityDispatcher {
+  const EqualOptions& options;
+  bool floating_approximate;
+  Visitor&& visit;
+
+  // Last level: choose the signed-zeros policy and invoke the visitor
+  template <bool Approximate, bool NansEqual>
+  void DispatchL3() {
+    if (options.signed_zeros_equal()) {
+      visit(FloatingEquality<T, FloatingEqualityFlags<Approximate, NansEqual, true>>{
+          options});
+    } else {
+      visit(FloatingEquality<T, FloatingEqualityFlags<Approximate, NansEqual, false>>{
+          options});
+    }
+  }
+
+  // Second level: choose the NaN policy
+  template <bool Approximate>
+  void DispatchL2() {
+    if (options.nans_equal()) {
+      DispatchL3<Approximate, true>();
+    } else {
+      DispatchL3<Approximate, false>();
+    }
+  }
+
+  // Entry point: choose between approximate and exact comparison
+  void Dispatch() {
+    if (floating_approximate) {
+      DispatchL2<true>();
+    } else {
+      DispatchL2<false>();
+    }
+  }
+};
+
+// Call `visit(equality_func)` where `equality_func` has the signature `bool(T, T)`
+// and returns true if the two values compare equal.
+template +void VisitFloatingEquality(const EqualOptions& options, bool floating_approximate, + Visitor&& visit) { + FloatingEqualityDispatcher{options, floating_approximate, + std::forward(visit)} + .Dispatch(); +} + +inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) { + if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) { + return false; + } + for (const auto& child : type.fields()) { + if (!IdentityImpliesEqualityNansNotEqual(*child->type())) { + return false; + } + } + return true; +} + +inline bool IdentityImpliesEquality(const DataType& type, const EqualOptions& options) { + if (options.nans_equal()) { + return true; + } + return IdentityImpliesEqualityNansNotEqual(type); +} + +bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, + int64_t left_start_idx, int64_t left_end_idx, + int64_t right_start_idx, const EqualOptions& options, + bool floating_approximate); + +class RangeDataEqualsImpl { + public: + // PRE-CONDITIONS: + // - the types are equal + // - the ranges are in bounds + // - the ArrayData arguments have the same length + RangeDataEqualsImpl(const EqualOptions& options, bool floating_approximate, + const ArrayData& left, const ArrayData& right, + int64_t left_start_idx, int64_t right_start_idx, + int64_t range_length) + : options_(options), + floating_approximate_(floating_approximate), + left_(left), + right_(right), + left_start_idx_(left_start_idx), + right_start_idx_(right_start_idx), + range_length_(range_length), + result_(false) {} + + bool Compare() { + // Compare null bitmaps + if (left_start_idx_ == 0 && right_start_idx_ == 0 && range_length_ == left_.length && + range_length_ == right_.length) { + // If we're comparing entire arrays, we can first compare the cached null counts + if (left_.GetNullCount() != right_.GetNullCount()) { + return false; + } + } + if (!OptionalBitmapEquals(left_.buffers[0], left_.offset + left_start_idx_, + right_.buffers[0], right_.offset + right_start_idx_, + 
range_length_)) { + return false; + } + // Compare values + return CompareWithType(*left_.type); + } + + bool CompareWithType(const DataType& type) { + result_ = true; + if (range_length_ != 0) { + ARROW_CHECK_OK(VisitTypeInline(type, this)); + } + return result_; + } + + Status Visit(const NullType&) { return Status::OK(); } + + template + enable_if_primitive_ctype Visit(const TypeClass& type) { + return ComparePrimitive(type); + } + + template + enable_if_t::value, Status> Visit(const TypeClass& type) { + return ComparePrimitive(type); + } + + Status Visit(const BooleanType&) { + const uint8_t* left_bits = left_.GetValues(1, 0); + const uint8_t* right_bits = right_.GetValues(1, 0); + auto compare_runs = [&](int64_t i, int64_t length) -> bool { + if (length <= 8) { + // Avoid the BitmapUInt64Reader overhead for very small runs + for (int64_t j = i; j < i + length; ++j) { + if (bit_util::GetBit(left_bits, left_start_idx_ + left_.offset + j) != + bit_util::GetBit(right_bits, right_start_idx_ + right_.offset + j)) { + return false; + } + } + return true; + } else if (length <= 1024) { + BitmapUInt64Reader left_reader(left_bits, left_start_idx_ + left_.offset + i, + length); + BitmapUInt64Reader right_reader(right_bits, right_start_idx_ + right_.offset + i, + length); + while (left_reader.position() < length) { + if (left_reader.NextWord() != right_reader.NextWord()) { + return false; + } + } + DCHECK_EQ(right_reader.position(), length); + } else { + // BitmapEquals is the fastest method on large runs + return BitmapEquals(left_bits, left_start_idx_ + left_.offset + i, right_bits, + right_start_idx_ + right_.offset + i, length); + } + return true; + }; + VisitValidRuns(compare_runs); + return Status::OK(); + } + + Status Visit(const FloatType& type) { return CompareFloating(type); } + + Status Visit(const DoubleType& type) { return CompareFloating(type); } + + Status Visit(const HalfFloatType& type) { return CompareFloating(type); } + + // Also matches StringType + 
Status Visit(const BinaryType& type) { return CompareBinary(type); } + + // Also matches StringViewType + Status Visit(const BinaryViewType& type) { + auto* left_values = left_.GetValues(1) + left_start_idx_; + auto* right_values = right_.GetValues(1) + right_start_idx_; + + auto* left_buffers = left_.buffers.data() + 2; + auto* right_buffers = right_.buffers.data() + 2; + VisitValidRuns([&](int64_t i, int64_t length) { + for (auto end_i = i + length; i < end_i; ++i) { + if (!util::EqualBinaryView(left_values[i], right_values[i], left_buffers, + right_buffers)) { + return false; + } + } + return true; + }); + return Status::OK(); + } + + // Also matches LargeStringType + Status Visit(const LargeBinaryType& type) { return CompareBinary(type); } + + Status Visit(const FixedSizeBinaryType& type) { + const auto byte_width = type.byte_width(); + const uint8_t* left_data = left_.GetValues(1, 0); + const uint8_t* right_data = right_.GetValues(1, 0); + + if (left_data != nullptr && right_data != nullptr) { + auto compare_runs = [&](int64_t i, int64_t length) -> bool { + return memcmp(left_data + (left_start_idx_ + left_.offset + i) * byte_width, + right_data + (right_start_idx_ + right_.offset + i) * byte_width, + length * byte_width) == 0; + }; + VisitValidRuns(compare_runs); + } else { + auto compare_runs = [&](int64_t i, int64_t length) -> bool { return true; }; + VisitValidRuns(compare_runs); + } + return Status::OK(); + } + + // Also matches MapType + Status Visit(const ListType& type) { return CompareList(type); } + + Status Visit(const LargeListType& type) { return CompareList(type); } + + Status Visit(const ListViewType& type) { return CompareListView(type); } + + Status Visit(const LargeListViewType& type) { return CompareListView(type); } + + Status Visit(const FixedSizeListType& type) { + const auto list_size = type.list_size(); + const ArrayData& left_data = *left_.child_data[0]; + const ArrayData& right_data = *right_.child_data[0]; + + auto compare_runs = 
[&](int64_t i, int64_t length) -> bool { + RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data, + (left_start_idx_ + left_.offset + i) * list_size, + (right_start_idx_ + right_.offset + i) * list_size, + length * list_size); + return impl.Compare(); + }; + VisitValidRuns(compare_runs); + return Status::OK(); + } + + Status Visit(const StructType& type) { + const int32_t num_fields = type.num_fields(); + + auto compare_runs = [&](int64_t i, int64_t length) -> bool { + for (int32_t f = 0; f < num_fields; ++f) { + RangeDataEqualsImpl impl(options_, floating_approximate_, *left_.child_data[f], + *right_.child_data[f], + left_start_idx_ + left_.offset + i, + right_start_idx_ + right_.offset + i, length); + if (!impl.Compare()) { + return false; + } + } + return true; + }; + VisitValidRuns(compare_runs); + return Status::OK(); + } + + Status Visit(const SparseUnionType& type) { + const auto& child_ids = type.child_ids(); + const int8_t* left_codes = left_.GetValues(1); + const int8_t* right_codes = right_.GetValues(1); + + // Unions don't have a null bitmap + int64_t run_start = 0; // Start index of the current run + + for (int64_t i = 0; i < range_length_; ++i) { + const auto current_type_id = left_codes[left_start_idx_ + i]; + + if (current_type_id != right_codes[right_start_idx_ + i]) { + result_ = false; + break; + } + // Check if the current element breaks the run + if (i > 0 && current_type_id != left_codes[left_start_idx_ + i - 1]) { + // Compare the previous run + const auto previous_child_num = child_ids[left_codes[left_start_idx_ + i - 1]]; + int64_t run_length = i - run_start; + + RangeDataEqualsImpl impl( + options_, floating_approximate_, *left_.child_data[previous_child_num], + *right_.child_data[previous_child_num], + left_start_idx_ + left_.offset + run_start, + right_start_idx_ + right_.offset + run_start, run_length); + + if (!impl.Compare()) { + result_ = false; + break; + } + + // Start a new run + run_start = i; + } + } + 
+ // Handle the final run + if (result_) { + const auto final_child_num = child_ids[left_codes[left_start_idx_ + run_start]]; + int64_t final_run_length = range_length_ - run_start; + + RangeDataEqualsImpl impl( + options_, floating_approximate_, *left_.child_data[final_child_num], + *right_.child_data[final_child_num], left_start_idx_ + left_.offset + run_start, + right_start_idx_ + right_.offset + run_start, final_run_length); + + if (!impl.Compare()) { + result_ = false; + } + } + return Status::OK(); + } + + Status Visit(const DenseUnionType& type) { + const auto& child_ids = type.child_ids(); + const int8_t* left_codes = left_.GetValues(1); + const int8_t* right_codes = right_.GetValues(1); + const int32_t* left_offsets = left_.GetValues(2); + const int32_t* right_offsets = right_.GetValues(2); + + for (int64_t i = 0; i < range_length_; ++i) { + const auto type_id = left_codes[left_start_idx_ + i]; + if (type_id != right_codes[right_start_idx_ + i]) { + result_ = false; + break; + } + const auto child_num = child_ids[type_id]; + RangeDataEqualsImpl impl( + options_, floating_approximate_, *left_.child_data[child_num], + *right_.child_data[child_num], left_offsets[left_start_idx_ + i], + right_offsets[right_start_idx_ + i], 1); + if (!impl.Compare()) { + result_ = false; + break; + } + } + return Status::OK(); + } + + Status Visit(const DictionaryType& type) { + // Compare dictionaries + result_ &= CompareArrayRanges( + *left_.dictionary, *right_.dictionary, + /*left_start_idx=*/0, + /*left_end_idx=*/std::max(left_.dictionary->length, right_.dictionary->length), + /*right_start_idx=*/0, options_, floating_approximate_); + if (result_) { + // Compare indices + result_ &= CompareWithType(*type.index_type()); + } + return Status::OK(); + } + + Status Visit(const RunEndEncodedType& type) { + switch (type.run_end_type()->id()) { + case Type::INT16: + return CompareRunEndEncoded(); + case Type::INT32: + return CompareRunEndEncoded(); + case Type::INT64: + return 
CompareRunEndEncoded(); + default: + return Status::Invalid("invalid run ends type: ", *type.run_end_type()); + } + } + + Status Visit(const ExtensionType& type) { + // Compare storages + result_ &= CompareWithType(*type.storage_type()); + return Status::OK(); + } + + protected: + template + Status ComparePrimitive(const TypeClass&) { + const CType* left_values = left_.GetValues(1); + const CType* right_values = right_.GetValues(1); + VisitValidRuns([&](int64_t i, int64_t length) { + return memcmp(left_values + left_start_idx_ + i, + right_values + right_start_idx_ + i, length * sizeof(CType)) == 0; + }); + return Status::OK(); + } + + template + Status CompareFloating(const TypeClass&) { + using CType = typename TypeClass::c_type; + const CType* left_values = left_.GetValues(1); + const CType* right_values = right_.GetValues(1); + + auto visitor = [&](auto&& compare_func) { + VisitValues([&](int64_t i) { + const CType x = left_values[i + left_start_idx_]; + const CType y = right_values[i + right_start_idx_]; + return compare_func(x, y); + }); + }; + VisitFloatingEquality(options_, floating_approximate_, std::move(visitor)); + return Status::OK(); + } + + template + Status CompareBinary(const TypeClass&) { + const uint8_t* left_data = left_.GetValues(2, 0); + const uint8_t* right_data = right_.GetValues(2, 0); + + if (left_data != nullptr && right_data != nullptr) { + const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset, + int64_t length) -> bool { + return memcmp(left_data + left_offset, right_data + right_offset, length) == 0; + }; + CompareWithOffsets(1, compare_ranges); + } else { + // One of the arrays is an array of empty strings and nulls. + // We just need to compare the offsets. + // (note we must not call memcmp() with null data pointers) + CompareWithOffsets(1, [](...) 
{ return true; }); + } + return Status::OK(); + } + + template + Status CompareList(const TypeClass&) { + const ArrayData& left_data = *left_.child_data[0]; + const ArrayData& right_data = *right_.child_data[0]; + + const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset, + int64_t length) -> bool { + RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data, + left_offset, right_offset, length); + return impl.Compare(); + }; + + CompareWithOffsets(1, compare_ranges); + return Status::OK(); + } + + template + Status CompareListView(const TypeClass& type) { + const ArrayData& left_values = *left_.child_data[0]; + const ArrayData& right_values = *right_.child_data[0]; + + using offset_type = typename TypeClass::offset_type; + const auto* left_offsets = left_.GetValues(1) + left_start_idx_; + const auto* right_offsets = right_.GetValues(1) + right_start_idx_; + const auto* left_sizes = left_.GetValues(2) + left_start_idx_; + const auto* right_sizes = right_.GetValues(2) + right_start_idx_; + + auto compare_view = [&](int64_t i, int64_t length) -> bool { + for (int64_t j = i; j < i + length; ++j) { + if (left_sizes[j] != right_sizes[j]) { + return false; + } + const offset_type size = left_sizes[j]; + if (size == 0) { + continue; + } + RangeDataEqualsImpl impl(options_, floating_approximate_, left_values, + right_values, left_offsets[j], right_offsets[j], size); + if (!impl.Compare()) { + return false; + } + } + return true; + }; + VisitValidRuns(std::move(compare_view)); + return Status::OK(); + } + + template + Status CompareRunEndEncoded() { + auto left_span = ArraySpan(left_); + auto right_span = ArraySpan(right_); + left_span.SetSlice(left_.offset + left_start_idx_, range_length_); + right_span.SetSlice(right_.offset + right_start_idx_, range_length_); + const ree_util::RunEndEncodedArraySpan left(left_span); + const ree_util::RunEndEncodedArraySpan right(right_span); + + const auto& left_values = *left_.child_data[1]; + 
const auto& right_values = *right_.child_data[1]; + + auto it = ree_util::MergedRunsIterator(left, right); + for (; !it.is_end(); ++it) { + RangeDataEqualsImpl impl(options_, floating_approximate_, left_values, right_values, + it.index_into_left_array(), it.index_into_right_array(), + /*range_length=*/1); + if (!impl.Compare()) { + result_ = false; + return Status::OK(); + } + } + return Status::OK(); + } + + template + void CompareWithOffsets(int offsets_buffer_index, CompareRanges&& compare_ranges) { + const offset_type* left_offsets = + left_.GetValues(offsets_buffer_index) + left_start_idx_; + const offset_type* right_offsets = + right_.GetValues(offsets_buffer_index) + right_start_idx_; + + const auto compare_runs = [&](int64_t i, int64_t length) { + for (int64_t j = i; j < i + length; ++j) { + if (left_offsets[j + 1] - left_offsets[j] != + right_offsets[j + 1] - right_offsets[j]) { + return false; + } + } + if (!compare_ranges(left_offsets[i], right_offsets[i], + left_offsets[i + length] - left_offsets[i])) { + return false; + } + return true; + }; + + VisitValidRuns(compare_runs); + } + + template + void VisitValues(CompareValues&& compare_values) { + internal::VisitSetBitRunsVoid(left_.buffers[0], left_.offset + left_start_idx_, + range_length_, [&](int64_t position, int64_t length) { + for (int64_t i = 0; i < length; ++i) { + result_ &= compare_values(position + i); + } + }); + } + + // Visit and compare runs of non-null values + template + void VisitValidRuns(CompareRuns&& compare_runs) { + const uint8_t* left_null_bitmap = left_.GetValues(0, 0); + if (left_null_bitmap == nullptr) { + result_ = compare_runs(0, range_length_); + return; + } + internal::SetBitRunReader reader(left_null_bitmap, left_.offset + left_start_idx_, + range_length_); + while (true) { + const auto run = reader.NextRun(); + if (run.length == 0) { + return; + } + if (!compare_runs(run.position, run.length)) { + result_ = false; + return; + } + } + } + + const EqualOptions& options_; + 
const bool floating_approximate_; + const ArrayData& left_; + const ArrayData& right_; + const int64_t left_start_idx_; + const int64_t right_start_idx_; + const int64_t range_length_; + + bool result_; +}; + +bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, + int64_t left_start_idx, int64_t left_end_idx, + int64_t right_start_idx, const EqualOptions& options, + bool floating_approximate) { + if (left.type->id() != right.type->id() || + !TypeEquals(*left.type, *right.type, false /* check_metadata */)) { + return false; + } + + const int64_t range_length = left_end_idx - left_start_idx; + DCHECK_GE(range_length, 0); + if (left_start_idx + range_length > left.length) { + // Left range too small + return false; + } + if (right_start_idx + range_length > right.length) { + // Right range too small + return false; + } + if (&left == &right && left_start_idx == right_start_idx && + IdentityImpliesEquality(*left.type, options)) { + return true; + } + // Compare values + RangeDataEqualsImpl impl(options, floating_approximate, left, right, left_start_idx, + right_start_idx, range_length); + return impl.Compare(); +} + +bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, + bool floating_approximate); +bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, + bool floating_approximate); + +class ScalarEqualsVisitor { + public: + // PRE-CONDITIONS: + // - the types are equal + // - the scalars are non-null + explicit ScalarEqualsVisitor(const Scalar& right, const EqualOptions& opts, + bool floating_approximate) + : right_(right), + options_(opts), + floating_approximate_(floating_approximate), + result_(false) {} + + Status Visit(const NullScalar& left) { + result_ = true; + return Status::OK(); + } + + Status Visit(const BooleanScalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + template + typename 
std::enable_if<(is_primitive_ctype::value || + is_temporal_type::value), + Status>::type + Visit(const T& left_) { + const auto& right = checked_cast(right_); + result_ = right.value == left_.value; + return Status::OK(); + } + + Status Visit(const FloatScalar& left) { return CompareFloating(left); } + + Status Visit(const DoubleScalar& left) { return CompareFloating(left); } + + Status Visit(const HalfFloatScalar& left) { return CompareFloating(left); } + + template + enable_if_t::value, Status> Visit(const T& left) { + const auto& right = checked_cast(right_); + result_ = internal::SharedPtrEquals(left.value, right.value); + return Status::OK(); + } + + Status Visit(const Decimal32Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + Status Visit(const Decimal64Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + Status Visit(const Decimal128Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + Status Visit(const Decimal256Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + Status Visit(const ListScalar& left) { + const auto& right = checked_cast(right_); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const LargeListScalar& left) { + const auto& right = checked_cast(right_); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const ListViewScalar& left) { + const auto& right = checked_cast(right_); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const LargeListViewScalar& left) { + const auto& right = checked_cast(right_); + 
result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const MapScalar& left) { + const auto& right = checked_cast(right_); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const FixedSizeListScalar& left) { + const auto& right = checked_cast(right_); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const StructScalar& left) { + const auto& right = checked_cast(right_); + + if (right.value.size() != left.value.size()) { + result_ = false; + } else { + bool all_equals = true; + for (size_t i = 0; i < left.value.size() && all_equals; i++) { + all_equals &= ScalarEquals(*left.value[i], *right.value[i], options_, + floating_approximate_); + } + result_ = all_equals; + } + + return Status::OK(); + } + + Status Visit(const DenseUnionScalar& left) { + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const SparseUnionScalar& left) { + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value[left.child_id], *right.value[right.child_id], + options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const DictionaryScalar& left) { + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value.index, *right.value.index, options_, + floating_approximate_) && + ArrayEquals(*left.value.dictionary, *right.value.dictionary, options_, + floating_approximate_); + return Status::OK(); + } + + Status Visit(const RunEndEncodedScalar& left) { + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + Status Visit(const ExtensionScalar& left) { + const auto& right = 
checked_cast(right_); + result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); + } + + bool result() const { return result_; } + + protected: + template + Status CompareFloating(const ScalarType& left) { + using CType = decltype(left.value); + const auto& right = checked_cast(right_); + + auto visitor = [&](auto&& compare_func) { + result_ = compare_func(left.value, right.value); + }; + VisitFloatingEquality(options_, floating_approximate_, std::move(visitor)); + return Status::OK(); + } + + const Scalar& right_; + const EqualOptions options_; + const bool floating_approximate_; + bool result_; +}; + +Status PrintDiff(const Array& left, const Array& right, std::ostream* os); + +Status PrintDiff(const Array& left, const Array& right, int64_t left_offset, + int64_t left_length, int64_t right_offset, int64_t right_length, + std::ostream* os) { + if (os == nullptr) { + return Status::OK(); + } + + if (!left.type()->Equals(right.type())) { + *os << "# Array types differed: " << *left.type() << " vs " << *right.type() + << std::endl; + return Status::OK(); + } + + if (left.type()->id() == Type::DICTIONARY) { + *os << "# Dictionary arrays differed" << std::endl; + + const auto& left_dict = checked_cast(left); + const auto& right_dict = checked_cast(right); + + *os << "## dictionary diff"; + auto pos = os->tellp(); + RETURN_NOT_OK(PrintDiff(*left_dict.dictionary(), *right_dict.dictionary(), os)); + if (os->tellp() == pos) { + *os << std::endl; + } + + *os << "## indices diff"; + pos = os->tellp(); + RETURN_NOT_OK(PrintDiff(*left_dict.indices(), *right_dict.indices(), os)); + if (os->tellp() == pos) { + *os << std::endl; + } + return Status::OK(); + } + + const auto left_slice = left.Slice(left_offset, left_length); + const auto right_slice = right.Slice(right_offset, right_length); + ARROW_ASSIGN_OR_RAISE(auto edits, + Diff(*left_slice, *right_slice, default_memory_pool())); + ARROW_ASSIGN_OR_RAISE(auto formatter, 
MakeUnifiedDiffFormatter(*left.type(), os)); + return formatter(*edits, *left_slice, *right_slice); +} + +Status PrintDiff(const Array& left, const Array& right, std::ostream* os) { + return PrintDiff(left, right, 0, left.length(), 0, right.length(), os); +} + +bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, + int64_t left_end_idx, int64_t right_start_idx, + const EqualOptions& options, bool floating_approximate) { + bool are_equal = + CompareArrayRanges(*left.data(), *right.data(), left_start_idx, left_end_idx, + right_start_idx, options, floating_approximate); + if (!are_equal) { + ARROW_IGNORE_EXPR(PrintDiff( + left, right, left_start_idx, left_end_idx, right_start_idx, + right_start_idx + (left_end_idx - left_start_idx), options.diff_sink())); + } + return are_equal; +} + +bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, + bool floating_approximate) { + if (left.length() != right.length()) { + ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink())); + return false; + } + return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate); +} + +bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, + bool floating_approximate) { + if (&left == &right && IdentityImpliesEquality(*left.type, options)) { + return true; + } + if (!left.type->Equals(right.type)) { + return false; + } + if (left.is_valid != right.is_valid) { + return false; + } + if (!left.is_valid) { + return true; + } + ScalarEqualsVisitor visitor(right, options, floating_approximate); + auto error = VisitScalarInline(left, &visitor); + DCHECK_OK(error); + return visitor.result(); +} +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index fbf1bc55e1bbc..362d41ee30f57 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -151,6 
+151,13 @@ struct GetViewType::value || static T LogicalValue(PhysicalType value) { return value; } }; +template +struct GetViewType> { + using T = const std::shared_ptr; + + static T LogicalValue(T value) { return value; } +}; + template <> struct GetViewType { using T = Decimal32; @@ -349,6 +356,32 @@ struct ArrayIterator> { } }; +template +struct ArrayIterator> { + using offset_type = typename Type::offset_type; + + const ArraySpan& arr; + const std::shared_ptr array_data; + const offset_type* offsets; + offset_type cur_offset; + int64_t position; + + explicit ArrayIterator(const ArraySpan& arr) + : arr(arr), + array_data(arr.child_data[0].ToArrayData()), + offsets(reinterpret_cast(arr.buffers[1].data) + arr.offset), + cur_offset(offsets[0]), + position(0) {} + + const std::shared_ptr operator()() { + offset_type next_offset = offsets[++position]; + const offset_type length = next_offset - cur_offset; + const auto result = array_data->Slice(cur_offset, length); + cur_offset = next_offset; + return result; + } +}; + template <> struct ArrayIterator { const ArraySpan& arr; @@ -425,6 +458,16 @@ struct UnboxScalar> { } }; +template +struct UnboxScalar> { + using T = const std::shared_ptr; + using ScalarT = typename TypeTraits::ScalarType; + + static const T& Unbox(const Scalar& val) { + return checked_cast(val).value->data(); + } +}; + template <> struct UnboxScalar { using T = Decimal32; @@ -1422,6 +1465,22 @@ auto GenerateDecimal(detail::GetTypeId get_id) { } } +// Generate a kernel given a templated functor for list types +// +// See "Numeric" above for description of the generator functor +template