diff --git a/.github/workflows/matlab.yml b/.github/workflows/matlab.yml
index 221ed5c77cd47..6921e12213b5b 100644
--- a/.github/workflows/matlab.yml
+++ b/.github/workflows/matlab.yml
@@ -53,6 +53,8 @@ jobs:
         run: sudo apt-get install ninja-build
       - name: Install MATLAB
         uses: matlab-actions/setup-matlab@v1
+        with:
+          release: R2023a
       - name: Install ccache
         run: sudo apt-get install ccache
       - name: Setup ccache
@@ -99,6 +101,8 @@ jobs:
         run: brew install ninja
       - name: Install MATLAB
        uses: matlab-actions/setup-matlab@v1
+        with:
+          release: R2023a
       - name: Install ccache
         run: brew install ccache
       - name: Setup ccache
@@ -135,6 +139,8 @@ jobs:
           fetch-depth: 0
       - name: Install MATLAB
         uses: matlab-actions/setup-matlab@v1
+        with:
+          release: R2023a
       - name: Download Timezone Database
         shell: bash
         run: ci/scripts/download_tz_database.sh
diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp
index 7fe005f94a5bb..9692f277d183f 100644
--- a/c_glib/arrow-glib/compute.cpp
+++ b/c_glib/arrow-glib/compute.cpp
@@ -3346,7 +3346,7 @@ garrow_set_lookup_options_get_property(GObject *object,
     g_value_set_object(value, priv->value_set);
     break;
   case PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS:
-    g_value_set_boolean(value, options->skip_nulls);
+    g_value_set_boolean(value, options->skip_nulls.has_value() && options->skip_nulls.value());
     break;
   default:
     G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
@@ -3398,13 +3398,11 @@ garrow_set_lookup_options_class_init(GArrowSetLookupOptionsClass *klass)
    *
    * Since: 6.0.0
    */
-  spec = g_param_spec_boolean("skip-nulls",
-                              "Skip NULLs",
-                              "Whether NULLs are skipped or not",
-                              options.skip_nulls,
-                              static_cast<GParamFlags>(G_PARAM_READWRITE));
-  g_object_class_install_property(gobject_class,
-                                  PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS,
+  auto skip_nulls = (options.skip_nulls.has_value() && options.skip_nulls.value());
+  spec =
+    g_param_spec_boolean("skip-nulls", "Skip NULLs", "Whether NULLs are skipped or not",
+                         skip_nulls, static_cast<GParamFlags>(G_PARAM_READWRITE));
+  g_object_class_install_property(gobject_class, PROP_SET_LOOKUP_OPTIONS_SKIP_NULLS,
                                   spec);
 }
@@ -6458,9 +6456,10 @@ garrow_set_lookup_options_new_raw(
                                    arrow_copied_options.get());
   auto value_set = garrow_datum_new_raw(&(arrow_copied_set_lookup_options->value_set));
+  auto skip_nulls = (arrow_options->skip_nulls.has_value() && arrow_options->skip_nulls.value());
   auto options = g_object_new(GARROW_TYPE_SET_LOOKUP_OPTIONS,
                               "value-set", value_set,
-                              "skip-nulls", arrow_options->skip_nulls,
+                              "skip-nulls", skip_nulls,
                               NULL);
   return GARROW_SET_LOOKUP_OPTIONS(options);
 }
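Note on the GLib binding change above: the C++ field is now a std::optional<bool>, but a GObject boolean property has only two states, so the binding collapses a disengaged optional and false together. A minimal standalone sketch of that coercion (the helper name is illustrative, not part of the patch):

#include <optional>

// Sketch: collapsing std::optional<bool> to the two-state boolean a GObject
// property can carry. nullopt and false both surface as false, matching the
// has_value() && value() expression used in compute.cpp above.
bool ToGBoolean(const std::optional<bool>& skip_nulls) {
  return skip_nulls.has_value() && skip_nulls.value();
}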
diff --git a/ci/conda_env_archery.txt b/ci/conda_env_archery.txt
index ace7a42acb026..40875e0a55039 100644
--- a/ci/conda_env_archery.txt
+++ b/ci/conda_env_archery.txt
@@ -25,7 +25,7 @@ jira
 pygit2
 pygithub
 ruamel.yaml
-setuptools_scm
+setuptools_scm<8.0.0
 toolz
 
 # benchmark
diff --git a/ci/conda_env_crossbow.txt b/ci/conda_env_crossbow.txt
index 347294650ca28..59b799720f12b 100644
--- a/ci/conda_env_crossbow.txt
+++ b/ci/conda_env_crossbow.txt
@@ -21,5 +21,5 @@ jinja2
 jira
 pygit2
 ruamel.yaml
-setuptools_scm
+setuptools_scm<8.0.0
 toolz
diff --git a/ci/conda_env_python.txt b/ci/conda_env_python.txt
index 4ae5c3614a1dc..d914229ec58c0 100644
--- a/ci/conda_env_python.txt
+++ b/ci/conda_env_python.txt
@@ -28,4 +28,4 @@ pytest-faulthandler
 pytest-lazy-fixture
 s3fs>=2021.8.0
 setuptools
-setuptools_scm
+setuptools_scm<8.0.0
diff --git a/ci/docker/conda-python-cython2.dockerfile b/ci/docker/conda-python-cython2.dockerfile
new file mode 100644
index 0000000000000..d67ef677276c7
--- /dev/null
+++ b/ci/docker/conda-python-cython2.dockerfile
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+ARG repo
+ARG arch
+ARG python=3.8
+FROM ${repo}:${arch}-conda-python-${python}
+
+RUN mamba install -q -y "cython<3" && \
+    mamba clean --all
diff --git a/ci/scripts/integration_arrow.sh b/ci/scripts/integration_arrow.sh
index 30cbb2d63791c..a165f8027bf8f 100755
--- a/ci/scripts/integration_arrow.sh
+++ b/ci/scripts/integration_arrow.sh
@@ -22,10 +22,12 @@ set -ex
 arrow_dir=${1}
 gold_dir=$arrow_dir/testing/data/arrow-ipc-stream/integration
 
-pip install -e $arrow_dir/dev/archery
+pip install -e $arrow_dir/dev/archery[integration]
 
 # Rust can be enabled by exporting ARCHERY_INTEGRATION_WITH_RUST=1
-archery integration \
+time archery integration \
+    --run-c-data \
+    --run-ipc \
     --run-flight \
     --with-cpp=1 \
     --with-csharp=1 \
diff --git a/ci/scripts/matlab_build.sh b/ci/scripts/matlab_build.sh
index 235002da3afc6..d3f86adbb8a2b 100755
--- a/ci/scripts/matlab_build.sh
+++ b/ci/scripts/matlab_build.sh
@@ -29,8 +29,6 @@ cmake \
     -S ${source_dir} \
     -B ${build_dir} \
     -G Ninja \
-    -D MATLAB_BUILD_TESTS=ON \
     -D CMAKE_INSTALL_PREFIX=${install_dir} \
     -D MATLAB_ADD_INSTALL_DIR_TO_SEARCH_PATH=OFF
 cmake --build ${build_dir} --config Release --target install
-ctest --test-dir ${build_dir}
diff --git a/cpp/cmake_modules/BuildUtils.cmake b/cpp/cmake_modules/BuildUtils.cmake
index 9112b836c9ef4..083ac2fe9a862 100644
--- a/cpp/cmake_modules/BuildUtils.cmake
+++ b/cpp/cmake_modules/BuildUtils.cmake
@@ -99,7 +99,7 @@ function(arrow_create_merged_static_lib output_target)
   if(APPLE)
     set(BUNDLE_COMMAND "libtool" "-no_warning_for_no_symbols" "-static" "-o"
                        ${output_lib_path} ${all_library_paths})
-  elseif(CMAKE_CXX_COMPILER_ID MATCHES "^(Clang|GNU|Intel)$")
+  elseif(CMAKE_CXX_COMPILER_ID MATCHES "^(Clang|GNU|Intel|IntelLLVM)$")
     set(ar_script_path ${CMAKE_BINARY_DIR}/${ARG_NAME}.ar)
 
     file(WRITE ${ar_script_path}.in "CREATE ${output_lib_path}\n")
diff --git a/cpp/cmake_modules/SetupCxxFlags.cmake b/cpp/cmake_modules/SetupCxxFlags.cmake
index a5f5659723c28..5531415ac2277 100644
--- a/cpp/cmake_modules/SetupCxxFlags.cmake
+++ b/cpp/cmake_modules/SetupCxxFlags.cmake
@@ -329,7 +329,8 @@ if("${BUILD_WARNING_LEVEL}" STREQUAL "CHECKIN")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-sign-conversion")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wunused-result")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wdate-time")
-  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL
+                                                   "IntelLLVM")
     if(WIN32)
       set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
       set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wno-deprecated")
@@ -360,7 +361,8 @@ elseif("${BUILD_WARNING_LEVEL}" STREQUAL "EVERYTHING")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wextra")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-unused-parameter")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wunused-result")
-  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL
+                                                   "IntelLLVM")
     if(WIN32)
       set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
     else()
@@ -383,7 +385,8 @@ else()
          OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
          OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
     set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
-  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL
+                                                   "IntelLLVM")
     if(WIN32)
       set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
     else()
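The three CMake hunks above extend the existing Intel branches to Intel's LLVM-based compiler (icpx), which CMake identifies as IntelLLVM. For completeness, the two front ends can also be told apart from C++ itself via Intel's documented predefined macros; a small standalone probe (not part of the patch):

#include <iostream>

int main() {
#if defined(__INTEL_LLVM_COMPILER)
  // icpx / the LLVM-based front end (CMake reports CMAKE_CXX_COMPILER_ID "IntelLLVM")
  std::cout << "IntelLLVM front end\n";
#elif defined(__INTEL_COMPILER)
  // classic icc/icpc (CMake reports "Intel")
  std::cout << "classic Intel front end\n";
#else
  std::cout << "not an Intel compiler\n";
#endif
}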
"${CXX_COMMON_FLAGS} -Wextra") set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-unused-parameter") set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wunused-result") - elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL + "IntelLLVM") if(WIN32) set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall") else() @@ -383,7 +385,8 @@ else() OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall") - elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL + "IntelLLVM") if(WIN32) set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall") else() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index f474d0c517fa0..9a6117011535e 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -383,7 +383,11 @@ endif() # if(ARROW_BUILD_INTEGRATION OR ARROW_BUILD_TESTS) - list(APPEND ARROW_SRCS integration/json_integration.cc integration/json_internal.cc) + list(APPEND + ARROW_SRCS + integration/c_data_integration_internal.cc + integration/json_integration.cc + integration/json_internal.cc) endif() if(ARROW_CSV) diff --git a/cpp/src/arrow/array/array_dict.cc b/cpp/src/arrow/array/array_dict.cc index cccc7bb78220d..c9e2f93cde66f 100644 --- a/cpp/src/arrow/array/array_dict.cc +++ b/cpp/src/arrow/array/array_dict.cc @@ -282,9 +282,9 @@ class DictionaryUnifierImpl : public DictionaryUnifier { *out_type = arrow::dictionary(index_type, value_type_); // Build unified dictionary array - std::shared_ptr data; - RETURN_NOT_OK(DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table_, - 0 /* start_offset */, &data)); + ARROW_ASSIGN_OR_RAISE( + auto data, DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table_, + 0 /* start_offset */)); *out_dict = MakeArray(data); return Status::OK(); } @@ -299,9 +299,9 @@ class DictionaryUnifierImpl : public DictionaryUnifier { } // Build unified dictionary array - std::shared_ptr data; - RETURN_NOT_OK(DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table_, - 0 /* start_offset */, &data)); + ARROW_ASSIGN_OR_RAISE( + auto data, DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table_, + 0 /* start_offset */)); *out_dict = MakeArray(data); return Status::OK(); } diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index 061fb600412fd..525b0afbc908a 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -106,8 +106,9 @@ class DictionaryMemoTable::DictionaryMemoTableImpl { enable_if_memoize Visit(const T&) { using ConcreteMemoTable = typename DictionaryTraits::MemoTableType; auto memo_table = checked_cast(memo_table_); - return DictionaryTraits::GetDictionaryArrayData(pool_, value_type_, *memo_table, - start_offset_, out_); + ARROW_ASSIGN_OR_RAISE(*out_, DictionaryTraits::GetDictionaryArrayData( + pool_, value_type_, *memo_table, start_offset_)); + return Status::OK(); } }; diff --git a/cpp/src/arrow/array/dict_internal.h b/cpp/src/arrow/array/dict_internal.h index 5245c8d0ff313..3c1c8c453d1e7 100644 --- a/cpp/src/arrow/array/dict_internal.h +++ b/cpp/src/arrow/array/dict_internal.h @@ -29,6 +29,7 @@ #include "arrow/array.h" #include "arrow/buffer.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -63,11 +64,9 @@ struct DictionaryTraits { using T = BooleanType; using 
diff --git a/cpp/src/arrow/array/dict_internal.h b/cpp/src/arrow/array/dict_internal.h
index 5245c8d0ff313..3c1c8c453d1e7 100644
--- a/cpp/src/arrow/array/dict_internal.h
+++ b/cpp/src/arrow/array/dict_internal.h
@@ -29,6 +29,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/type_traits.h"
@@ -63,11 +64,9 @@ struct DictionaryTraits<BooleanType> {
   using T = BooleanType;
   using MemoTableType = typename HashTraits<T>::MemoTableType;
 
-  static Status GetDictionaryArrayData(MemoryPool* pool,
-                                       const std::shared_ptr<DataType>& type,
-                                       const MemoTableType& memo_table,
-                                       int64_t start_offset,
-                                       std::shared_ptr<ArrayData>* out) {
+  static Result<std::shared_ptr<ArrayData>> GetDictionaryArrayData(
+      MemoryPool* pool, const std::shared_ptr<DataType>& type,
+      const MemoTableType& memo_table, int64_t start_offset) {
     if (start_offset < 0) {
       return Status::Invalid("invalid start_offset ", start_offset);
     }
@@ -82,7 +81,9 @@ struct DictionaryTraits<BooleanType> {
                        : builder.Append(bool_values[i]));
     }
 
-    return builder.FinishInternal(out);
+    std::shared_ptr<ArrayData> out;
+    RETURN_NOT_OK(builder.FinishInternal(&out));
+    return out;
   }
 };  // namespace internal
 
@@ -91,11 +92,9 @@ struct DictionaryTraits<T, enable_if_has_c_type<T>> {
   using c_type = typename T::c_type;
   using MemoTableType = typename HashTraits<T>::MemoTableType;
 
-  static Status GetDictionaryArrayData(MemoryPool* pool,
-                                       const std::shared_ptr<DataType>& type,
-                                       const MemoTableType& memo_table,
-                                       int64_t start_offset,
-                                       std::shared_ptr<ArrayData>* out) {
+  static Result<std::shared_ptr<ArrayData>> GetDictionaryArrayData(
+      MemoryPool* pool, const std::shared_ptr<DataType>& type,
+      const MemoTableType& memo_table, int64_t start_offset) {
     auto dict_length = static_cast<int64_t>(memo_table.size()) - start_offset;
     // This makes a copy, but we assume a dictionary array is usually small
     // compared to the size of the dictionary-using array.
@@ -112,8 +111,7 @@ struct DictionaryTraits<T, enable_if_has_c_type<T>> {
     RETURN_NOT_OK(
         ComputeNullBitmap(pool, memo_table, start_offset, &null_count, &null_bitmap));
 
-    *out = ArrayData::Make(type, dict_length, {null_bitmap, dict_buffer}, null_count);
-    return Status::OK();
+    return ArrayData::Make(type, dict_length, {null_bitmap, dict_buffer}, null_count);
   }
 };
 
@@ -121,11 +119,9 @@ template <typename T>
 struct DictionaryTraits<T, enable_if_base_binary<T>> {
   using MemoTableType = typename HashTraits<T>::MemoTableType;
 
-  static Status GetDictionaryArrayData(MemoryPool* pool,
-                                       const std::shared_ptr<DataType>& type,
-                                       const MemoTableType& memo_table,
-                                       int64_t start_offset,
-                                       std::shared_ptr<ArrayData>* out) {
+  static Result<std::shared_ptr<ArrayData>> GetDictionaryArrayData(
+      MemoryPool* pool, const std::shared_ptr<DataType>& type,
+      const MemoTableType& memo_table, int64_t start_offset) {
     using offset_type = typename T::offset_type;
 
     // Create the offsets buffer
@@ -148,11 +144,9 @@ struct DictionaryTraits<T, enable_if_base_binary<T>> {
     RETURN_NOT_OK(
         ComputeNullBitmap(pool, memo_table, start_offset, &null_count, &null_bitmap));
 
-    *out = ArrayData::Make(type, dict_length,
-                           {null_bitmap, std::move(dict_offsets), std::move(dict_data)},
-                           null_count);
-
-    return Status::OK();
+    return ArrayData::Make(type, dict_length,
+                           {null_bitmap, std::move(dict_offsets), std::move(dict_data)},
+                           null_count);
   }
 };
 
@@ -160,11 +154,9 @@ template <typename T>
 struct DictionaryTraits<T, enable_if_fixed_size_binary<T>> {
   using MemoTableType = typename HashTraits<T>::MemoTableType;
 
-  static Status GetDictionaryArrayData(MemoryPool* pool,
-                                       const std::shared_ptr<DataType>& type,
-                                       const MemoTableType& memo_table,
-                                       int64_t start_offset,
-                                       std::shared_ptr<ArrayData>* out) {
+  static Result<std::shared_ptr<ArrayData>> GetDictionaryArrayData(
+      MemoryPool* pool, const std::shared_ptr<DataType>& type,
+      const MemoTableType& memo_table, int64_t start_offset) {
     const T& concrete_type = internal::checked_cast<const T&>(*type);
 
     // Create the data buffer
@@ -182,9 +174,8 @@ struct DictionaryTraits<T, enable_if_fixed_size_binary<T>> {
     RETURN_NOT_OK(
         ComputeNullBitmap(pool, memo_table, start_offset, &null_count, &null_bitmap));
 
-    *out = ArrayData::Make(type, dict_length, {null_bitmap, std::move(dict_data)},
-                           null_count);
-    return Status::OK();
+    return ArrayData::Make(type, dict_length, {null_bitmap, std::move(dict_data)},
+                           null_count);
   }
 };
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index d7a61d0a55985..eaec940556361 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -275,6 +275,29 @@ struct EnumTraits
   }
 };
 
+template <>
+struct EnumTraits<compute::SetLookupOptions::NullMatchingBehavior>
+    : BasicEnumTraits<compute::SetLookupOptions::NullMatchingBehavior,
+                      compute::SetLookupOptions::NullMatchingBehavior::MATCH,
+                      compute::SetLookupOptions::NullMatchingBehavior::SKIP,
+                      compute::SetLookupOptions::NullMatchingBehavior::EMIT_NULL,
+                      compute::SetLookupOptions::NullMatchingBehavior::INCONCLUSIVE> {
+  static std::string name() { return "SetLookupOptions::NullMatchingBehavior"; }
+  static std::string value_name(compute::SetLookupOptions::NullMatchingBehavior value) {
+    switch (value) {
+      case compute::SetLookupOptions::NullMatchingBehavior::MATCH:
+        return "MATCH";
+      case compute::SetLookupOptions::NullMatchingBehavior::SKIP:
+        return "SKIP";
+      case compute::SetLookupOptions::NullMatchingBehavior::EMIT_NULL:
+        return "EMIT_NULL";
+      case compute::SetLookupOptions::NullMatchingBehavior::INCONCLUSIVE:
+        return "INCONCLUSIVE";
+    }
+    return "";
+  }
+};
+
 }  // namespace internal
 
 namespace compute {
@@ -286,6 +309,7 @@ using ::arrow::internal::checked_cast;
 
 namespace internal {
 namespace {
+using ::arrow::internal::CoercedDataMember;
 using ::arrow::internal::DataMember;
 static auto kArithmeticOptionsType = GetFunctionOptionsType<ArithmeticOptions>(
     DataMember("check_overflow", &ArithmeticOptions::check_overflow));
@@ -344,7 +368,8 @@ static auto kRoundToMultipleOptionsType = GetFunctionOptionsType<RoundToMultipleOptions>(
 static auto kSetLookupOptionsType = GetFunctionOptionsType<SetLookupOptions>(
     DataMember("value_set", &SetLookupOptions::value_set),
-    DataMember("skip_nulls", &SetLookupOptions::skip_nulls));
+    CoercedDataMember("null_matching_behavior", &SetLookupOptions::null_matching_behavior,
+                      &SetLookupOptions::GetNullMatchingBehavior));
 static auto kSliceOptionsType = GetFunctionOptionsType<SliceOptions>(
     DataMember("start", &SliceOptions::start), DataMember("stop", &SliceOptions::stop),
     DataMember("step", &SliceOptions::step));
@@ -540,8 +565,29 @@ constexpr char RoundToMultipleOptions::kTypeName[];
 SetLookupOptions::SetLookupOptions(Datum value_set, bool skip_nulls)
     : FunctionOptions(internal::kSetLookupOptionsType),
       value_set(std::move(value_set)),
-      skip_nulls(skip_nulls) {}
-SetLookupOptions::SetLookupOptions() : SetLookupOptions({}, false) {}
+      skip_nulls(skip_nulls) {
+  if (skip_nulls) {
+    this->null_matching_behavior = SetLookupOptions::SKIP;
+  } else {
+    this->null_matching_behavior = SetLookupOptions::MATCH;
+  }
+}
+SetLookupOptions::SetLookupOptions(
+    Datum value_set, SetLookupOptions::NullMatchingBehavior null_matching_behavior)
+    : FunctionOptions(internal::kSetLookupOptionsType),
+      value_set(std::move(value_set)),
+      null_matching_behavior(std::move(null_matching_behavior)) {}
+SetLookupOptions::SetLookupOptions()
+    : SetLookupOptions({}, SetLookupOptions::NullMatchingBehavior::MATCH) {}
+
+SetLookupOptions::NullMatchingBehavior SetLookupOptions::GetNullMatchingBehavior() const {
+  if (!this->skip_nulls.has_value()) {
+    return this->null_matching_behavior;
+  } else if (this->skip_nulls.value()) {
+    return SetLookupOptions::SKIP;
+  } else {
+    return SetLookupOptions::MATCH;
+  }
+}
 constexpr char SetLookupOptions::kTypeName[];
 
 SliceOptions::SliceOptions(int64_t start, int64_t stop, int64_t step)
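For readers tracking the dual representation introduced above: skip_nulls is kept only for backward compatibility, and GetNullMatchingBehavior() collapses the two fields into one effective setting, with an engaged skip_nulls taking precedence. A small standalone illustration of the mapping (sketch, not PR code):

#include <cassert>
#include "arrow/compute/api_scalar.h"

using arrow::compute::SetLookupOptions;

void NullMatchingCoercionExamples() {
  // Constructed through the legacy bool: skip_nulls is engaged and wins.
  SetLookupOptions legacy_skip({}, /*skip_nulls=*/true);
  assert(legacy_skip.GetNullMatchingBehavior() == SetLookupOptions::SKIP);

  SetLookupOptions legacy_match({}, /*skip_nulls=*/false);
  assert(legacy_match.GetNullMatchingBehavior() == SetLookupOptions::MATCH);

  // Constructed through the new enum: skip_nulls stays disengaged (nullopt),
  // so the enum value is reported as-is.
  SetLookupOptions modern({}, SetLookupOptions::INCONCLUSIVE);
  assert(!modern.skip_nulls.has_value());
  assert(modern.GetNullMatchingBehavior() == SetLookupOptions::INCONCLUSIVE);
}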
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 0a06a2829f0da..9f12471ddca14 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -268,19 +268,49 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
 /// Options for IsIn and IndexIn functions
 class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
  public:
-  explicit SetLookupOptions(Datum value_set, bool skip_nulls = false);
+  /// How to handle null values.
+  enum NullMatchingBehavior {
+    /// MATCH, any null in `value_set` is successfully matched in
+    /// the input.
+    MATCH,
+    /// SKIP, any null in `value_set` is ignored and nulls in the input
+    /// produce null (IndexIn) or false (IsIn) values in the output.
+    SKIP,
+    /// EMIT_NULL, any null in `value_set` is ignored and nulls in the
+    /// input produce null (IndexIn and IsIn) values in the output.
+    EMIT_NULL,
+    /// INCONCLUSIVE, null values are regarded as unknown values, which is
+    /// SQL-compatible. Nulls in the input produce null (IndexIn and IsIn)
+    /// values in the output. In addition, if `value_set` contains a null,
+    /// non-null unmatched values in the input also produce null values
+    /// (IndexIn and IsIn) in the output.
+    INCONCLUSIVE
+  };
+
+  explicit SetLookupOptions(Datum value_set, NullMatchingBehavior = MATCH);
+
   SetLookupOptions();
+
+  // DEPRECATED(will be removed once skip_nulls is removed)
+  explicit SetLookupOptions(Datum value_set, bool skip_nulls);
+
   static constexpr char const kTypeName[] = "SetLookupOptions";
 
   /// The set of values to look up input values into.
   Datum value_set;
+
+  NullMatchingBehavior null_matching_behavior;
+
+  // DEPRECATED(will be removed once skip_nulls is removed)
+  NullMatchingBehavior GetNullMatchingBehavior() const;
+
+  // DEPRECATED(use null_matching_behavior instead)
   /// Whether nulls in `value_set` count for lookup.
   ///
   /// If true, any null in `value_set` is ignored and nulls in the input
   /// produce null (IndexIn) or false (IsIn) values in the output.
   /// If false, any null in `value_set` is successfully matched in
   /// the input.
-  bool skip_nulls;
+  std::optional<bool> skip_nulls;
 };
 
 /// Options for struct_field function
diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc
index b852f6f6b0cdb..44159e76600fb 100644
--- a/cpp/src/arrow/compute/expression_test.cc
+++ b/cpp/src/arrow/compute/expression_test.cc
@@ -263,8 +263,9 @@ TEST(Expression, ToString) {
   auto in_12 = call("index_in", {field_ref("beta")},
                     compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")});
-  EXPECT_EQ(in_12.ToString(),
-            "index_in(beta, {value_set=int32:[\n 1,\n 2\n], skip_nulls=false})");
+  EXPECT_EQ(
+      in_12.ToString(),
+      "index_in(beta, {value_set=int32:[\n 1,\n 2\n], null_matching_behavior=MATCH})");
 
   EXPECT_EQ(and_(field_ref("a"), field_ref("b")).ToString(), "(a and b)");
   EXPECT_EQ(or_(field_ref("a"), field_ref("b")).ToString(), "(a or b)");
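A summary of what the four behaviors mean for is_in, distilled from the option documentation above and the kernel tests later in this diff. The helper below is a sketch, not library code; the compute functions it calls (IsIn, SetLookupOptions) are the real API:

// Behavior of is_in(input, value_set) assuming value_set = [2, null, 1]:
//
//   input   MATCH   SKIP    EMIT_NULL  INCONCLUSIVE
//   1       true    true    true       true
//   3       false   false   false      null   (set contains null -> unknown)
//   null    true    false   null       null
//
#include "arrow/compute/api.h"

arrow::Result<arrow::Datum> IsInSqlStyle(const arrow::Datum& input,
                                         const arrow::Datum& value_set) {
  // INCONCLUSIVE mirrors SQL's three-valued logic for `x IN (...)`.
  arrow::compute::SetLookupOptions options(
      value_set, arrow::compute::SetLookupOptions::INCONCLUSIVE);
  return arrow::compute::IsIn(input, options);
}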
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
index 00d391653d240..e2d5583e36e6b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
@@ -44,6 +44,7 @@ struct SetLookupState : public SetLookupStateBase {
   explicit SetLookupState(MemoryPool* pool) : memory_pool(pool) {}
 
   Status Init(const SetLookupOptions& options) {
+    this->null_matching_behavior = options.GetNullMatchingBehavior();
     if (options.value_set.is_array()) {
       const ArrayData& value_set = *options.value_set.array();
       memo_index_to_value_index.reserve(value_set.length);
@@ -66,7 +67,8 @@ struct SetLookupState : public SetLookupStateBase {
     } else {
       return Status::Invalid("value_set should be an array or chunked array");
     }
-    if (!options.skip_nulls && lookup_table->GetNull() >= 0) {
+    if (this->null_matching_behavior != SetLookupOptions::SKIP &&
+        lookup_table->GetNull() >= 0) {
       null_index = memo_index_to_value_index[lookup_table->GetNull()];
     }
     value_set_type = options.value_set.type();
@@ -117,19 +119,23 @@ struct SetLookupState : public SetLookupStateBase {
   // be mapped back to indices in the value_set.
   std::vector<int32_t> memo_index_to_value_index;
   int32_t null_index = -1;
+  SetLookupOptions::NullMatchingBehavior null_matching_behavior;
 };
 
 template <>
 struct SetLookupState<NullType> : public SetLookupStateBase {
   explicit SetLookupState(MemoryPool*) {}
 
-  Status Init(const SetLookupOptions& options) {
-    value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls;
+  Status Init(SetLookupOptions& options) {
+    null_matching_behavior = options.GetNullMatchingBehavior();
+    value_set_has_null = (options.value_set.length() > 0) &&
+                         this->null_matching_behavior != SetLookupOptions::SKIP;
     value_set_type = null();
     return Status::OK();
   }
 
   bool value_set_has_null;
+  SetLookupOptions::NullMatchingBehavior null_matching_behavior;
 };
 
 // TODO: Put this concept somewhere reusable
@@ -270,14 +276,20 @@ struct IndexInVisitor {
       : ctx(ctx), data(data), out(out), out_bitmap(out->buffers[0].data) {}
 
   Status Visit(const DataType& type) {
-    DCHECK_EQ(type.id(), Type::NA);
+    DCHECK(false) << "IndexIn " << type;
+    return Status::NotImplemented("IndexIn has no implementation with value type ", type);
+  }
+
+  Status Visit(const NullType&) {
     const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state());
     if (data.length != 0) {
-      // skip_nulls is honored for consistency with other types
-      bit_util::SetBitsTo(out_bitmap, out->offset, out->length, state.value_set_has_null);
+      bit_util::SetBitsTo(out_bitmap, out->offset, out->length,
+                          state.null_matching_behavior == SetLookupOptions::MATCH &&
+                              state.value_set_has_null);
 
       // Set all values to 0, which will be unmasked only if null is in the value_set
+      // and null_matching_behavior is equal to MATCH
       std::memset(out->GetValues<int32_t>(1), 0x00, out->length * sizeof(int32_t));
     }
     return Status::OK();
@@ -305,7 +317,8 @@ struct IndexInVisitor {
           bitmap_writer.Next();
         },
         [&]() {
-          if (state.null_index != -1) {
+          if (state.null_index != -1 &&
+              state.null_matching_behavior == SetLookupOptions::MATCH) {
             bitmap_writer.Set();
 
             // value_set included null
@@ -379,49 +392,86 @@ Status ExecIndexIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
   return IndexInVisitor(ctx, batch[0].array, out->array_span_mutable()).Execute();
 }
 
-// ----------------------------------------------------------------------
-
 // IsIn writes the results into a preallocated boolean data bitmap
 struct IsInVisitor {
   KernelContext* ctx;
   const ArraySpan& data;
   ArraySpan* out;
+  uint8_t* out_boolean_bitmap;
+  uint8_t* out_null_bitmap;
 
   IsInVisitor(KernelContext* ctx, const ArraySpan& data, ArraySpan* out)
-      : ctx(ctx), data(data), out(out) {}
+      : ctx(ctx),
+        data(data),
+        out(out),
+        out_boolean_bitmap(out->buffers[1].data),
+        out_null_bitmap(out->buffers[0].data) {}
 
   Status Visit(const DataType& type) {
-    DCHECK_EQ(type.id(), Type::NA);
+    DCHECK(false) << "IsIn " << type;
+    return Status::NotImplemented("IsIn has no implementation with value type ", type);
+  }
+
+  Status Visit(const NullType&) {
     const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state());
-    // skip_nulls is honored for consistency with other types
-    bit_util::SetBitsTo(out->buffers[1].data, out->offset, out->length,
-                        state.value_set_has_null);
+
+    if (state.null_matching_behavior == SetLookupOptions::MATCH &&
+        state.value_set_has_null) {
+      bit_util::SetBitsTo(out_boolean_bitmap, out->offset, out->length, true);
+      bit_util::SetBitsTo(out_null_bitmap, out->offset, out->length, true);
+    } else if (state.null_matching_behavior == SetLookupOptions::SKIP ||
+               (!state.value_set_has_null &&
+                state.null_matching_behavior == SetLookupOptions::MATCH)) {
+      bit_util::SetBitsTo(out_boolean_bitmap, out->offset, out->length, false);
+      bit_util::SetBitsTo(out_null_bitmap, out->offset, out->length, true);
+    } else {
+      bit_util::SetBitsTo(out_null_bitmap, out->offset, out->length, false);
+    }
     return Status::OK();
   }
 
   template <typename Type>
   Status ProcessIsIn(const SetLookupState<Type>& state, const ArraySpan& input) {
     using T = typename GetViewType<Type>::T;
-    FirstTimeBitmapWriter writer(out->buffers[1].data, out->offset, out->length);
+    FirstTimeBitmapWriter writer_boolean(out_boolean_bitmap, out->offset, out->length);
+    FirstTimeBitmapWriter writer_null(out_null_bitmap, out->offset, out->length);
+    bool value_set_has_null = state.null_index != -1;
     VisitArraySpanInline<Type>(
         input,
         [&](T v) {
-          if (state.lookup_table->Get(v) != -1) {
-            writer.Set();
-          } else {
-            writer.Clear();
+          if (state.lookup_table->Get(v) != -1) {  // true
+            writer_boolean.Set();
+            writer_null.Set();
+          } else if (state.null_matching_behavior == SetLookupOptions::INCONCLUSIVE &&
+                     value_set_has_null) {  // null
+            writer_boolean.Clear();
+            writer_null.Clear();
+          } else {  // false
+            writer_boolean.Clear();
+            writer_null.Set();
           }
-          writer.Next();
+          writer_boolean.Next();
+          writer_null.Next();
         },
         [&]() {
-          if (state.null_index != -1) {
-            writer.Set();
-          } else {
-            writer.Clear();
+          if (state.null_matching_behavior == SetLookupOptions::MATCH &&
+              value_set_has_null) {  // true
+            writer_boolean.Set();
+            writer_null.Set();
+          } else if (state.null_matching_behavior == SetLookupOptions::SKIP ||
+                     (!value_set_has_null &&
+                      state.null_matching_behavior == SetLookupOptions::MATCH)) {  // false
+            writer_boolean.Clear();
+            writer_null.Set();
+          } else {  // null
+            writer_boolean.Clear();
+            writer_null.Clear();
          }
-          writer.Next();
+          writer_boolean.Next();
+          writer_null.Next();
        });
-    writer.Finish();
+    writer_boolean.Finish();
+    writer_null.Finish();
     return Status::OK();
   }
@@ -598,7 +648,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) {
   ScalarKernel isin_base;
   isin_base.init = InitSetLookup;
   isin_base.exec = ExecIsIn;
-  isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL;
+  isin_base.null_handling = NullHandling::COMPUTED_PREALLOCATE;
   auto is_in = std::make_shared<ScalarFunction>("is_in", Arity::Unary(), is_in_doc);
 
   AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get());
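Because IsIn can now produce nulls, the kernel switches to COMPUTED_PREALLOCATE null handling and writes two bitmaps: validity (writer_null) and values (writer_boolean). A scalar model of the per-element logic, as a standalone sketch whose names and types are illustrative rather than the kernel's:

#include <optional>

enum class Behavior { kMatch, kSkip, kEmitNull, kInconclusive };

// Returns the (validity, value) pair for one element, encoded as optional<bool>:
// nullopt means "output is null" (validity bit cleared).
std::optional<bool> IsInOneValue(std::optional<int> input, bool in_set,
                                 bool value_set_has_null, Behavior behavior) {
  if (input.has_value()) {
    if (in_set) return true;  // found: always true
    if (behavior == Behavior::kInconclusive && value_set_has_null)
      return std::nullopt;    // unmatched against an unknown (null) set member
    return false;
  }
  // Null input:
  if (behavior == Behavior::kMatch) return value_set_has_null;  // true or false
  if (behavior == Behavior::kSkip) return false;
  return std::nullopt;  // EMIT_NULL and INCONCLUSIVE
}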
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
index d1645eb8d9a49..89e10d1b54103 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
@@ -50,7 +50,67 @@ namespace compute {
 
 void CheckIsIn(const std::shared_ptr<Array> input,
                const std::shared_ptr<Array>& value_set, const std::string& expected_json,
-               bool skip_nulls = false) {
+               SetLookupOptions::NullMatchingBehavior null_matching_behavior =
+                   SetLookupOptions::MATCH) {
+  auto expected = ArrayFromJSON(boolean(), expected_json);
+
+  ASSERT_OK_AND_ASSIGN(Datum actual_datum,
+                       IsIn(input, SetLookupOptions(value_set, null_matching_behavior)));
+  std::shared_ptr<Array> actual = actual_datum.make_array();
+  ValidateOutput(actual_datum);
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+void CheckIsIn(const std::shared_ptr<DataType>& type, const std::string& input_json,
+               const std::string& value_set_json, const std::string& expected_json,
+               SetLookupOptions::NullMatchingBehavior null_matching_behavior =
+                   SetLookupOptions::MATCH) {
+  auto input = ArrayFromJSON(type, input_json);
+  auto value_set = ArrayFromJSON(type, value_set_json);
+  CheckIsIn(input, value_set, expected_json, null_matching_behavior);
+}
+
+void CheckIsInChunked(const std::shared_ptr<ChunkedArray>& input,
+                      const std::shared_ptr<ChunkedArray>& value_set,
+                      const std::shared_ptr<ChunkedArray>& expected,
+                      SetLookupOptions::NullMatchingBehavior null_matching_behavior =
+                          SetLookupOptions::MATCH) {
+  ASSERT_OK_AND_ASSIGN(Datum actual_datum,
+                       IsIn(input, SetLookupOptions(value_set, null_matching_behavior)));
+  auto actual = actual_datum.chunked_array();
+  ValidateOutput(actual_datum);
+
+  // Output contiguous in a single chunk
+  ASSERT_EQ(1, actual->num_chunks());
+  ASSERT_TRUE(actual->Equals(*expected));
+}
+
+void CheckIsInDictionary(const std::shared_ptr<DataType>& type,
+                         const std::shared_ptr<DataType>& index_type,
+                         const std::string& input_dictionary_json,
+                         const std::string& input_index_json,
+                         const std::string& value_set_json,
+                         const std::string& expected_json,
+                         SetLookupOptions::NullMatchingBehavior null_matching_behavior =
+                             SetLookupOptions::MATCH) {
+  auto dict_type = dictionary(index_type, type);
+  auto indices = ArrayFromJSON(index_type, input_index_json);
+  auto dict = ArrayFromJSON(type, input_dictionary_json);
+
+  ASSERT_OK_AND_ASSIGN(auto input, DictionaryArray::FromArrays(dict_type, indices, dict));
+  auto value_set = ArrayFromJSON(type, value_set_json);
+  auto expected = ArrayFromJSON(boolean(), expected_json);
+
+  ASSERT_OK_AND_ASSIGN(Datum actual_datum,
+                       IsIn(input, SetLookupOptions(value_set, null_matching_behavior)));
+  std::shared_ptr<Array> actual = actual_datum.make_array();
+  ValidateOutput(actual_datum);
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+void CheckIsIn(const std::shared_ptr<Array> input,
+               const std::shared_ptr<Array>& value_set, const std::string& expected_json,
+               bool skip_nulls) {
   auto expected = ArrayFromJSON(boolean(), expected_json);
 
   ASSERT_OK_AND_ASSIGN(Datum actual_datum,
value_set, expected_json, null_matching_behavior); +} + +void CheckIsInChunked(const std::shared_ptr& input, + const std::shared_ptr& value_set, + const std::shared_ptr& expected, + SetLookupOptions::NullMatchingBehavior null_matching_behavior = + SetLookupOptions::MATCH) { + ASSERT_OK_AND_ASSIGN(Datum actual_datum, + IsIn(input, SetLookupOptions(value_set, null_matching_behavior))); + auto actual = actual_datum.chunked_array(); + ValidateOutput(actual_datum); + + // Output contiguous in a single chunk + ASSERT_EQ(1, actual->num_chunks()); + ASSERT_TRUE(actual->Equals(*expected)); +} + +void CheckIsInDictionary(const std::shared_ptr& type, + const std::shared_ptr& index_type, + const std::string& input_dictionary_json, + const std::string& input_index_json, + const std::string& value_set_json, + const std::string& expected_json, + SetLookupOptions::NullMatchingBehavior null_matching_behavior = + SetLookupOptions::MATCH) { + auto dict_type = dictionary(index_type, type); + auto indices = ArrayFromJSON(index_type, input_index_json); + auto dict = ArrayFromJSON(type, input_dictionary_json); + + ASSERT_OK_AND_ASSIGN(auto input, DictionaryArray::FromArrays(dict_type, indices, dict)); + auto value_set = ArrayFromJSON(type, value_set_json); + auto expected = ArrayFromJSON(boolean(), expected_json); + + ASSERT_OK_AND_ASSIGN(Datum actual_datum, + IsIn(input, SetLookupOptions(value_set, null_matching_behavior))); + std::shared_ptr actual = actual_datum.make_array(); + ValidateOutput(actual_datum); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +void CheckIsIn(const std::shared_ptr input, + const std::shared_ptr& value_set, const std::string& expected_json, + bool skip_nulls) { auto expected = ArrayFromJSON(boolean(), expected_json); ASSERT_OK_AND_ASSIGN(Datum actual_datum, @@ -62,7 +122,7 @@ void CheckIsIn(const std::shared_ptr input, void CheckIsIn(const std::shared_ptr& type, const std::string& input_json, const std::string& value_set_json, const std::string& expected_json, - bool skip_nulls = false) { + bool skip_nulls) { auto input = ArrayFromJSON(type, input_json); auto value_set = ArrayFromJSON(type, value_set_json); CheckIsIn(input, value_set, expected_json, skip_nulls); @@ -70,8 +130,7 @@ void CheckIsIn(const std::shared_ptr& type, const std::string& input_j void CheckIsInChunked(const std::shared_ptr& input, const std::shared_ptr& value_set, - const std::shared_ptr& expected, - bool skip_nulls = false) { + const std::shared_ptr& expected, bool skip_nulls) { ASSERT_OK_AND_ASSIGN(Datum actual_datum, IsIn(input, SetLookupOptions(value_set, skip_nulls))); auto actual = actual_datum.chunked_array(); @@ -87,7 +146,7 @@ void CheckIsInDictionary(const std::shared_ptr& type, const std::string& input_dictionary_json, const std::string& input_index_json, const std::string& value_set_json, - const std::string& expected_json, bool skip_nulls = false) { + const std::string& expected_json, bool skip_nulls) { auto dict_type = dictionary(index_type, type); auto indices = ArrayFromJSON(index_type, input_index_json); auto dict = ArrayFromJSON(type, input_dictionary_json); @@ -185,18 +244,43 @@ TYPED_TEST(TestIsInKernelPrimitive, IsIn) { /*skip_nulls=*/false); CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]", + 
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[null, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[null, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Nulls in right array
   CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
             /*skip_nulls=*/false);
   CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[null, true, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Nulls in both the arrays
   CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[true, true, true, false, true]",
             /*skip_nulls=*/false);
   CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[true, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]",
+            "[false, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[null, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[null, true, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Duplicates in right array
   CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
@@ -204,6 +288,18 @@ TYPED_TEST(TestIsInKernelPrimitive, IsIn) {
             /*skip_nulls=*/false);
   CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
             "[false, true, true, false, true]", /*skip_nulls=*/true);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+            "[true, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+            "[false, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+            "[null, true, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+            "[null, true, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Empty Arrays
   CheckIsIn(type, "[]", "[]", "[]");
@@ -217,11 +313,30 @@ TEST_F(TestIsInKernel, NullType) {
   CheckIsIn(type, "[]", "[]", "[]");
 
   CheckIsIn(type, "[null, null]", "[null]", "[false, false]", /*skip_nulls=*/true);
+  CheckIsIn(type, "[null, null]", "[null]", "[false, false]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[null, null]", "[null]", "[null, null]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[null, null]", "[null]", "[null, null]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
+
+  CheckIsIn(type, "[null, null]", "[]", "[false, false]", /*skip_nulls=*/true);
+  CheckIsIn(type, "[null, null]", "[]", "[false, false]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[null, null]", "[]", "[null, null]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[null, null]", "[]", "[null, null]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Duplicates in right array
   CheckIsIn(type, "[null, null, null]", "[null, null]", "[true, true, true]");
   CheckIsIn(type, "[null, null]", "[null, null]", "[false, false]", /*skip_nulls=*/true);
+  CheckIsIn(type, "[null, null]", "[null, null]", "[false, false]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[null, null]", "[null, null]", "[null, null]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[null, null]", "[null, null]", "[null, null]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 }
@@ -232,12 +347,36 @@ TEST_F(TestIsInKernel, TimeTimestamp) {
             "[true, true, false, true, true]", /*skip_nulls=*/false);
   CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
             "[true, false, false, true, true]", /*skip_nulls=*/true);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
+            "[true, true, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
+            "[true, false, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
+            "[true, null, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
+            "[true, null, null, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Duplicates in right array
   CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
             "[true, true, false, true, true]", /*skip_nulls=*/false);
   CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
             "[true, false, false, true, true]", /*skip_nulls=*/true);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+            "[true, true, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+            "[true, false, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+            "[true, null, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+            "[true, null, null, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 }
 
 // Disallow mixing timezone-aware and timezone-naive values
false, null]", + "[false, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, null]", + "[false, true, null, false, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, null]", + "[null, true, null, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); } TYPED_TEST_SUITE(TestIsInKernelBinary, BaseBinaryArrowTypes); @@ -309,6 +508,18 @@ TYPED_TEST(TestIsInKernelBinary, Binary) { CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", "[true, true, false, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", "[true, true, false, true, true]", @@ -316,6 +527,18 @@ TYPED_TEST(TestIsInKernelBinary, Binary) { CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", "[true, true, false, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", + "[true, true, false, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", + "[true, true, null, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); // Duplicates in right array CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", @@ -324,6 +547,18 @@ TYPED_TEST(TestIsInKernelBinary, Binary) { CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", + R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", + R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", + R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", + R"([null, "aaa", "aaa", "", "", null])", "[true, true, null, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); } TEST_F(TestIsInKernel, FixedSizeBinary) { @@ -335,6 +570,18 @@ TEST_F(TestIsInKernel, FixedSizeBinary) { CheckIsIn(type, R"(["aaa", "bbb", 
"ccc", null, "bbb"])", R"(["aaa", "bbb"])", "[true, true, false, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", "[true, true, false, true, true]", @@ -342,6 +589,18 @@ TEST_F(TestIsInKernel, FixedSizeBinary) { CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", "[true, true, false, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", + "[true, true, false, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", + "[true, true, null, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); // Duplicates in right array CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", @@ -352,6 +611,22 @@ TEST_F(TestIsInKernel, FixedSizeBinary) { R"(["aaa", null, "aaa", "bbb", "bbb", null])", "[true, true, false, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", + R"(["aaa", null, "aaa", "bbb", "bbb", null])", + "[true, true, false, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", + R"(["aaa", null, "aaa", "bbb", "bbb", null])", + "[true, true, false, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", + R"(["aaa", null, "aaa", "bbb", "bbb", null])", + "[true, true, false, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", + R"(["aaa", null, "aaa", "bbb", "bbb", null])", + "[true, true, null, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(fixed_size_binary(3), R"(["abc"])"), @@ -366,6 +641,18 @@ TEST_F(TestIsInKernel, Decimal) { CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", "[true, false, true, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", + "[true, false, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", + "[true, false, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + 
CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", + "[true, false, true, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", + "[true, false, true, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9", null])", "[true, false, true, true, true]", @@ -373,6 +660,18 @@ TEST_F(TestIsInKernel, Decimal) { CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9", null])", "[true, false, true, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"(["12.3", "78.9", null])", "[true, false, true, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"(["12.3", "78.9", null])", "[true, false, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"(["12.3", "78.9", null])", "[true, false, true, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"(["12.3", "78.9", null])", "[true, null, true, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); // Duplicates in right array CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", @@ -383,6 +682,22 @@ TEST_F(TestIsInKernel, Decimal) { R"([null, "12.3", "12.3", "78.9", "78.9", null])", "[true, false, true, false, true]", /*skip_nulls=*/true); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"([null, "12.3", "12.3", "78.9", "78.9", null])", + "[true, false, true, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"([null, "12.3", "12.3", "78.9", "78.9", null])", + "[true, false, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"([null, "12.3", "12.3", "78.9", "78.9", null])", + "[true, false, true, null, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", + R"([null, "12.3", "12.3", "78.9", "78.9", null])", + "[true, null, true, null, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); CheckIsIn(ArrayFromJSON(decimal128(4, 2), R"(["12.30", "45.60", "78.90"])"), ArrayFromJSON(type, R"(["12.3", "78.9"])"), "[true, false, true]"); @@ -405,6 +720,20 @@ TEST_F(TestIsInKernel, DictionaryArray) { /*value_set_json=*/"[4.1, 42, -1.0]", /*expected_json=*/"[true, true, false, true]", /*skip_nulls=*/false); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 2, null, 0]", + /*value_set_json=*/R"(["A", "B", "C"])", + /*expected_json=*/"[true, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsInDictionary(/*type=*/float32(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/"[4.1, -1.0, 42, 9.8]", + /*input_index_json=*/"[1, 2, null, 0]", + /*value_set_json=*/"[4.1, 42, -1.0]", + /*expected_json=*/"[true, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); // With nulls and skip_nulls=false CheckIsInDictionary(/*type=*/utf8(), @@ -428,6 
@@ -335,6 +570,18 @@ TEST_F(TestIsInKernel, FixedSizeBinary) {
   CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
             "[true, true, false, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
+            "[true, true, false, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
+            "[true, true, false, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
+            "[true, true, false, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
+            "[true, true, false, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
             "[true, true, false, true, true]",
@@ -342,6 +589,18 @@ TEST_F(TestIsInKernel, FixedSizeBinary) {
   CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
             "[true, true, false, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
+            "[true, true, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
+            "[true, true, false, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
+            "[true, true, false, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
+            "[true, true, null, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Duplicates in right array
   CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
@@ -352,6 +611,22 @@ TEST_F(TestIsInKernel, FixedSizeBinary) {
             R"(["aaa", null, "aaa", "bbb", "bbb", null])",
             "[true, true, false, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+            R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+            "[true, true, false, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+            R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+            "[true, true, false, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+            R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+            "[true, true, false, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+            R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+            "[true, true, null, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   ASSERT_RAISES(Invalid,
                 IsIn(ArrayFromJSON(fixed_size_binary(3), R"(["abc"])"),
@@ -366,6 +641,18 @@ TEST_F(TestIsInKernel, Decimal) {
   CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
             "[true, false, true, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
+            "[true, false, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
+            "[true, false, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
+            "[true, false, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
+            "[true, false, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
             R"(["12.3", "78.9", null])", "[true, false, true, true, true]",
@@ -373,6 +660,18 @@ TEST_F(TestIsInKernel, Decimal) {
   CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
             R"(["12.3", "78.9", null])", "[true, false, true, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"(["12.3", "78.9", null])", "[true, false, true, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"(["12.3", "78.9", null])", "[true, false, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"(["12.3", "78.9", null])", "[true, false, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"(["12.3", "78.9", null])", "[true, null, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   // Duplicates in right array
   CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
@@ -383,6 +682,22 @@ TEST_F(TestIsInKernel, Decimal) {
             R"([null, "12.3", "12.3", "78.9", "78.9", null])",
             "[true, false, true, false, true]",
             /*skip_nulls=*/true);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+            "[true, false, true, true, true]",
+            /*null_matching_behavior=*/SetLookupOptions::MATCH);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+            "[true, false, true, false, true]",
+            /*null_matching_behavior=*/SetLookupOptions::SKIP);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+            "[true, false, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+            "[true, null, true, null, true]",
+            /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE);
 
   CheckIsIn(ArrayFromJSON(decimal128(4, 2), R"(["12.30", "45.60", "78.90"])"),
             ArrayFromJSON(type, R"(["12.3", "78.9"])"), "[true, false, true]");
/*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A", null])", + /*expected_json=*/"[true, null, null, true, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A", null])", + /*expected_json=*/"[null, null, null, true, null]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A"])", + /*expected_json=*/"[null, false, null, true, null]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); // With duplicates in value_set CheckIsInDictionary(/*type=*/utf8(), @@ -474,6 +891,41 @@ TEST_F(TestIsInKernel, DictionaryArray) { /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", /*expected_json=*/"[true, false, false, true, true]", /*skip_nulls=*/true); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 2, null, 0]", + /*value_set_json=*/R"(["A", "A", "B", "A", "B", "C"])", + /*expected_json=*/"[true, true, false, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", + /*expected_json=*/"[true, false, true, true, true]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", + /*expected_json=*/"[true, false, false, true, true]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", + /*expected_json=*/"[true, false, null, true, true]", + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + CheckIsInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", + /*expected_json=*/"[true, null, null, true, true]", + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); } } @@ -487,14 +939,38 @@ TEST_F(TestIsInKernel, ChunkedArrayInvoke) { CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false); CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::SKIP); + expected = ChunkedArrayFromJSON( + boolean(), {"[true, true, true, true, false]", "[true, null, true, false]"}); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + expected = 
ChunkedArrayFromJSON( + boolean(), {"[true, true, true, true, false]", "[true, null, true, false]"}); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); value_set = ChunkedArrayFromJSON(utf8(), {R"(["", "def"])", R"([null])"}); expected = ChunkedArrayFromJSON( boolean(), {"[false, true, true, false, false]", "[true, true, false, false]"}); CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::MATCH); expected = ChunkedArrayFromJSON( boolean(), {"[false, true, true, false, false]", "[true, false, false, false]"}); CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::SKIP); + expected = ChunkedArrayFromJSON( + boolean(), {"[false, true, true, false, false]", "[true, null, false, false]"}); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + expected = ChunkedArrayFromJSON( + boolean(), {"[null, true, true, null, null]", "[true, null, null, null]"}); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); // Duplicates in value_set value_set = @@ -502,9 +978,21 @@ TEST_F(TestIsInKernel, ChunkedArrayInvoke) { expected = ChunkedArrayFromJSON( boolean(), {"[false, true, true, false, false]", "[true, true, false, false]"}); CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::MATCH); expected = ChunkedArrayFromJSON( boolean(), {"[false, true, true, false, false]", "[true, false, false, false]"}); CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::SKIP); + expected = ChunkedArrayFromJSON( + boolean(), {"[false, true, true, false, false]", "[true, null, false, false]"}); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::EMIT_NULL); + expected = ChunkedArrayFromJSON( + boolean(), {"[null, true, true, null, null]", "[true, null, null, null]"}); + CheckIsInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::INCONCLUSIVE); } // ---------------------------------------------------------------------- @@ -514,7 +1002,70 @@ class TestIndexInKernel : public ::testing::Test { public: void CheckIndexIn(const std::shared_ptr<Array>& input, const std::shared_ptr<Array>& value_set, - const std::string& expected_json, bool skip_nulls = false) { + const std::string& expected_json, + SetLookupOptions::NullMatchingBehavior null_matching_behavior = + SetLookupOptions::MATCH) { + std::shared_ptr<Array> expected = ArrayFromJSON(int32(), expected_json); + + SetLookupOptions options(value_set, null_matching_behavior); + ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, options)); + std::shared_ptr<Array> actual = actual_datum.make_array(); + ValidateOutput(actual_datum); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + void CheckIndexIn(const std::shared_ptr<DataType>& type, const std::string& input_json, + const std::string& value_set_json, const std::string& expected_json, + SetLookupOptions::NullMatchingBehavior null_matching_behavior = + SetLookupOptions::MATCH) { + std::shared_ptr<Array> input = ArrayFromJSON(type, input_json); + std::shared_ptr<Array> value_set = ArrayFromJSON(type, value_set_json); + return CheckIndexIn(input, value_set, expected_json, null_matching_behavior); + } + + void CheckIndexInChunked(const std::shared_ptr<ChunkedArray>& input, + const std::shared_ptr<ChunkedArray>& value_set, + const std::shared_ptr<ChunkedArray>& expected, + SetLookupOptions::NullMatchingBehavior null_matching_behavior = + SetLookupOptions::MATCH) { + ASSERT_OK_AND_ASSIGN( + Datum actual, + IndexIn(input, SetLookupOptions(value_set, null_matching_behavior))); + ASSERT_EQ(Datum::CHUNKED_ARRAY, actual.kind()); + ValidateOutput(actual); + + auto actual_chunked = actual.chunked_array(); + + // Output contiguous in a single chunk + ASSERT_EQ(1, actual_chunked->num_chunks()); + ASSERT_TRUE(actual_chunked->Equals(*expected)); + } + + void CheckIndexInDictionary( + const std::shared_ptr<DataType>& type, const std::shared_ptr<DataType>& index_type, + const std::string& input_dictionary_json, const std::string& input_index_json, + const std::string& value_set_json, const std::string& expected_json, + SetLookupOptions::NullMatchingBehavior null_matching_behavior = + SetLookupOptions::MATCH) { + auto dict_type = dictionary(index_type, type); + auto indices = ArrayFromJSON(index_type, input_index_json); + auto dict = ArrayFromJSON(type, input_dictionary_json); + + ASSERT_OK_AND_ASSIGN(auto input, + DictionaryArray::FromArrays(dict_type, indices, dict)); + auto value_set = ArrayFromJSON(type, value_set_json); + auto expected = ArrayFromJSON(int32(), expected_json); + + SetLookupOptions options(value_set, null_matching_behavior); + ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, options)); + std::shared_ptr<Array> actual = actual_datum.make_array(); + ValidateOutput(actual_datum); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + void CheckIndexIn(const std::shared_ptr<Array>& input, + const std::shared_ptr<Array>& value_set, + const std::string& expected_json, bool skip_nulls) { std::shared_ptr<Array> expected = ArrayFromJSON(int32(), expected_json); SetLookupOptions options(value_set, skip_nulls); @@ -526,7 +1077,7 @@ class TestIndexInKernel : public ::testing::Test { void CheckIndexIn(const std::shared_ptr<DataType>& type, const std::string& input_json, const std::string& value_set_json, const std::string& expected_json, - bool skip_nulls = false) { + bool skip_nulls) { std::shared_ptr<Array> input = ArrayFromJSON(type, input_json); std::shared_ptr<Array> value_set = ArrayFromJSON(type, value_set_json); return CheckIndexIn(input, value_set, expected_json, skip_nulls); @@ -553,7 +1104,7 @@ class TestIndexInKernel : public ::testing::Test { const std::string& input_dictionary_json, const std::string& input_index_json, const std::string& value_set_json, - const std::string& expected_json, bool skip_nulls = false) { + const std::string& expected_json, bool skip_nulls) { auto dict_type = dictionary(index_type, type); auto indices = ArrayFromJSON(index_type, input_index_json); auto dict = ArrayFromJSON(type, input_dictionary_json); @@ -656,6 +1207,16 @@ TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) { /*value_set=*/"[1, 3]", /*expected=*/"[null, 0, null, 1, null]", /*skip_nulls=*/true); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, 3]", + /*expected=*/"[null, 0, null, 1, null]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, 3]", + /*expected=*/"[null, 0, null, 1, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Same with duplicates in value_set this->CheckIndexIn(type, /*input=*/"[0, 1, 2, 3, null]", @@ -667,6 
+1228,16 @@ TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) { /*value_set=*/"[1, 1, 3, 3]", /*expected=*/"[null, 0, null, 2, null]", /*skip_nulls=*/true); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, 1, 3, 3]", + /*expected=*/"[null, 0, null, 2, null]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, 1, 3, 3]", + /*expected=*/"[null, 0, null, 2, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Nulls in value_set this->CheckIndexIn(type, @@ -679,12 +1250,27 @@ TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) { /*value_set=*/"[1, 1, null, null, 3, 3]", /*expected=*/"[null, 0, null, 4, null]", /*skip_nulls=*/true); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, null, 3]", + /*expected=*/"[null, 0, null, 2, 1]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, 1, null, null, 3, 3]", + /*expected=*/"[null, 0, null, 4, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Same with duplicates in value_set this->CheckIndexIn(type, /*input=*/"[0, 1, 2, 3, null]", /*value_set=*/"[1, 1, null, null, 3, 3]", /*expected=*/"[null, 0, null, 4, 2]", /*skip_nulls=*/false); + this->CheckIndexIn(type, + /*input=*/"[0, 1, 2, 3, null]", + /*value_set=*/"[1, 1, null, null, 3, 3]", + /*expected=*/"[null, 0, null, 4, 2]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); } TEST_F(TestIndexInKernel, NullType) { @@ -695,6 +1281,10 @@ TEST_F(TestIndexInKernel, NullType) { CheckIndexIn(null(), "[null, null]", "[null]", "[null, null]", /*skip_nulls=*/true); CheckIndexIn(null(), "[null, null]", "[]", "[null, null]", /*skip_nulls=*/true); + CheckIndexIn(null(), "[null, null]", "[null]", "[null, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIndexIn(null(), "[null, null]", "[]", "[null, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); } TEST_F(TestIndexInKernel, TimeTimestamp) { @@ -979,6 +1569,11 @@ TEST_F(TestIndexInKernel, FixedSizeBinary) { /*value_set=*/R"(["aaa", null, "bbb", "ccc"])", /*expected=*/R"([2, null, null, 0, 3, 0])", /*skip_nulls=*/true); + CheckIndexIn(fixed_size_binary(3), + /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", + /*value_set=*/R"(["aaa", null, "bbb", "ccc"])", + /*expected=*/R"([2, null, null, 0, 3, 0])", + /*null_matching_behavior=*/SetLookupOptions::SKIP); CheckIndexIn(fixed_size_binary(3), /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", @@ -989,6 +1584,11 @@ TEST_F(TestIndexInKernel, FixedSizeBinary) { /*value_set=*/R"(["aaa", "bbb", "ccc"])", /*expected=*/R"([1, null, null, 0, 2, 0])", /*skip_nulls=*/true); + CheckIndexIn(fixed_size_binary(3), + /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", + /*value_set=*/R"(["aaa", "bbb", "ccc"])", + /*expected=*/R"([1, null, null, 0, 2, 0])", + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Duplicates in value_set CheckIndexIn(fixed_size_binary(3), @@ -1000,6 +1600,11 @@ TEST_F(TestIndexInKernel, FixedSizeBinary) { /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", "ccc"])", /*expected=*/R"([4, null, null, 0, 6, 0])", /*skip_nulls=*/true); + CheckIndexIn(fixed_size_binary(3), + /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", + /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", "ccc"])", + /*expected=*/R"([4, null, null, 0, 6, 0])", + 
/*null_matching_behavior=*/SetLookupOptions::SKIP); // Empty input array CheckIndexIn(fixed_size_binary(5), R"([])", R"(["bbbbb", null, "aaaaa", "ccccc"])", @@ -1026,6 +1631,11 @@ TEST_F(TestIndexInKernel, MonthDayNanoInterval) { /*value_set=*/R"([null, [4, 5, 6], [5, -1, 5]])", /*expected=*/R"([2, 0, 1, 2, null])", /*skip_nulls=*/false); + CheckIndexIn(type, + /*input=*/R"([[5, -1, 5], null, [4, 5, 6], [5, -1, 5], [1, 2, 3]])", + /*value_set=*/R"([null, [4, 5, 6], [5, -1, 5]])", + /*expected=*/R"([2, 0, 1, 2, null])", + /*null_matching_behavior=*/SetLookupOptions::MATCH); // Duplicates in value_set CheckIndexIn( @@ -1034,6 +1644,12 @@ TEST_F(TestIndexInKernel, MonthDayNanoInterval) { /*value_set=*/R"([null, null, [0, 0, 0], [0, 0, 0], [7, 8, 0], [7, 8, 0]])", /*expected=*/R"([4, 0, 2, 4, null])", /*skip_nulls=*/false); + CheckIndexIn( + type, + /*input=*/R"([[7, 8, 0], null, [0, 0, 0], [7, 8, 0], [0, 0, 1]])", + /*value_set=*/R"([null, null, [0, 0, 0], [0, 0, 0], [7, 8, 0], [7, 8, 0]])", + /*expected=*/R"([4, 0, 2, 4, null])", + /*null_matching_behavior=*/SetLookupOptions::MATCH); } TEST_F(TestIndexInKernel, Decimal) { @@ -1048,6 +1664,16 @@ TEST_F(TestIndexInKernel, Decimal) { /*value_set=*/R"([null, "11", "12"])", /*expected=*/R"([2, null, 1, 2, null])", /*skip_nulls=*/true); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"([null, "11", "12"])", + /*expected=*/R"([2, 0, 1, 2, null])", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"([null, "11", "12"])", + /*expected=*/R"([2, null, 1, 2, null])", + /*null_matching_behavior=*/SetLookupOptions::SKIP); CheckIndexIn(type, /*input=*/R"(["12", null, "11", "12", "13"])", @@ -1059,6 +1685,16 @@ TEST_F(TestIndexInKernel, Decimal) { /*value_set=*/R"(["11", "12"])", /*expected=*/R"([1, null, 0, 1, null])", /*skip_nulls=*/true); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"(["11", "12"])", + /*expected=*/R"([1, null, 0, 1, null])", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"(["11", "12"])", + /*expected=*/R"([1, null, 0, 1, null])", + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Duplicates in value_set CheckIndexIn(type, @@ -1076,6 +1712,21 @@ TEST_F(TestIndexInKernel, Decimal) { /*value_set=*/R"([null, "11", "12"])", /*expected=*/R"([2, 0, 1, 2, null])", /*skip_nulls=*/false); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"([null, null, "11", "11", "12", "12"])", + /*expected=*/R"([4, 0, 2, 4, null])", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"([null, null, "11", "11", "12", "12"])", + /*expected=*/R"([4, null, 2, 4, null])", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIndexIn(type, + /*input=*/R"(["12", null, "11", "12", "13"])", + /*value_set=*/R"([null, "11", "12"])", + /*expected=*/R"([2, 0, 1, 2, null])", + /*null_matching_behavior=*/SetLookupOptions::MATCH); CheckIndexIn( ArrayFromJSON(decimal256(3, 1), R"(["12.0", null, "11.0", "12.0", "13.0"])"), @@ -1099,6 +1750,20 @@ TEST_F(TestIndexInKernel, DictionaryArray) { /*value_set_json=*/"[4.1, 42, -1.0]", /*expected_json=*/"[2, 1, null, 0]", /*skip_nulls=*/false); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + 
/*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 2, null, 0]", + /*value_set_json=*/R"(["A", "B", "C"])", + /*expected_json=*/"[1, 2, null, 0]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexInDictionary(/*type=*/float32(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/"[4.1, -1.0, 42, 9.8]", + /*input_index_json=*/"[1, 2, null, 0]", + /*value_set_json=*/"[4.1, 42, -1.0]", + /*expected_json=*/"[2, 1, null, 0]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); // With nulls and skip_nulls=false CheckIndexInDictionary(/*type=*/utf8(), @@ -1122,6 +1787,27 @@ TEST_F(TestIndexInKernel, DictionaryArray) { /*value_set_json=*/R"(["C", "B", "A"])", /*expected_json=*/"[null, null, null, 2, null]", /*skip_nulls=*/false); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A", null])", + /*expected_json=*/"[1, null, 3, 2, 1]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A", null])", + /*expected_json=*/"[3, null, 3, 2, 3]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A"])", + /*expected_json=*/"[null, null, null, 2, null]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); // With nulls and skip_nulls=true CheckIndexInDictionary(/*type=*/utf8(), @@ -1145,6 +1831,27 @@ TEST_F(TestIndexInKernel, DictionaryArray) { /*value_set_json=*/R"(["C", "B", "A"])", /*expected_json=*/"[null, null, null, 2, null]", /*skip_nulls=*/true); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A", null])", + /*expected_json=*/"[1, null, null, 2, 1]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A", null])", + /*expected_json=*/"[null, null, null, 2, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "B", "A"])", + /*expected_json=*/"[null, null, null, 2, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); // With duplicates in value_set CheckIndexInDictionary(/*type=*/utf8(), @@ -1168,6 +1875,27 @@ TEST_F(TestIndexInKernel, DictionaryArray) { /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])", /*expected_json=*/"[null, null, null, 4, null]", /*skip_nulls=*/true); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", + /*input_index_json=*/"[1, 2, null, 0]", + /*value_set_json=*/R"(["A", "A", "B", "B", "C", "C"])", + /*expected_json=*/"[2, 4, null, 0]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + 
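+
+  // For reference, the NullMatchingBehavior values exercised in these tests
+  // (MATCH is the default):
+  //   MATCH        - a null in the input matches a null in the value set;
+  //   SKIP         - nulls in the value set are ignored, so a null input has
+  //                  no match (false for is_in, a null index for index_in);
+  //   EMIT_NULL    - a null input yields a null result;
+  //   INCONCLUSIVE - a null input yields null and, when the value set itself
+  //                  contains a null, non-matching inputs yield null as well.
+  // Only MATCH and SKIP are exercised for index_in. A minimal user-side
+  // sketch (`input` and `value_set` here are placeholder names):
+  //   SetLookupOptions opts(value_set, SetLookupOptions::EMIT_NULL);
+  //   ARROW_ASSIGN_OR_RAISE(Datum mask, IsIn(input, opts));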
CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])", + /*expected_json=*/"[6, null, 6, 4, 6]", + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexInDictionary(/*type=*/utf8(), + /*index_type=*/index_ty, + /*input_dictionary_json=*/R"(["A", null, "C", "D"])", + /*input_index_json=*/"[1, 3, null, 0, 1]", + /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])", + /*expected_json=*/"[null, null, null, 4, null]", + /*null_matching_behavior=*/SetLookupOptions::SKIP); } } @@ -1181,21 +1909,33 @@ TEST_F(TestIndexInKernel, ChunkedArrayInvoke) { CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false); CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); + CheckIndexInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::MATCH); + CheckIndexInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Null in value_set value_set = ChunkedArrayFromJSON(utf8(), {R"(["ghi", "def"])", R"([null, "abc"])"}); expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, 2, 3, null]"}); CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false); + CheckIndexInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::MATCH); expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, null, 3, null]"}); CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); + CheckIndexInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::SKIP); // Duplicates in value_set value_set = ChunkedArrayFromJSON( utf8(), {R"(["ghi", "ghi", "def"])", R"(["def", null, null, "abc"])"}); expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, 4, 6, null]"}); CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false); + CheckIndexInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::MATCH); expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, null, 6, null]"}); CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); + CheckIndexInChunked(input, value_set, expected, + /*null_matching_behavior=*/SetLookupOptions::SKIP); } TEST(TestSetLookup, DispatchBest) { diff --git a/cpp/src/arrow/compute/kernels/vector_hash.cc b/cpp/src/arrow/compute/kernels/vector_hash.cc index a7bb2d88c291b..d9143b760f32b 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -285,8 +285,9 @@ class RegularHashKernel : public HashKernel { Status FlushFinal(ExecResult* out) override { return action_.FlushFinal(out); } Status GetDictionary(std::shared_ptr<ArrayData>* out) override { - return DictionaryTraits<Type>::GetDictionaryArrayData(pool_, type_, *memo_table_, - 0 /* start_offset */, out); + ARROW_ASSIGN_OR_RAISE(*out, DictionaryTraits<Type>::GetDictionaryArrayData( + pool_, type_, *memo_table_, 0 /* start_offset */)); + return Status::OK(); } std::shared_ptr<DataType> value_type() const override { return type_; } diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index c30441d911e4e..751937e93b937 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -88,11 +88,28 @@ parquet::ArrowReaderProperties MakeArrowReaderProperties( return properties; } -template <typename M> +parquet::ArrowReaderProperties 
MakeArrowReaderProperties( + const ParquetFileFormat& format, const parquet::FileMetaData& metadata, + const ScanOptions& options, const ParquetFragmentScanOptions& parquet_scan_options) { + auto arrow_properties = MakeArrowReaderProperties(format, metadata); + arrow_properties.set_batch_size(options.batch_size); + // Must be set here since the sync ScanTask handles pre-buffering itself + arrow_properties.set_pre_buffer( + parquet_scan_options.arrow_reader_properties->pre_buffer()); + arrow_properties.set_cache_options( + parquet_scan_options.arrow_reader_properties->cache_options()); + arrow_properties.set_io_context( + parquet_scan_options.arrow_reader_properties->io_context()); + arrow_properties.set_use_threads(options.use_threads); + return arrow_properties; +} + Result<std::shared_ptr<SchemaManifest>> GetSchemaManifest( - const M& metadata, const parquet::ArrowReaderProperties& properties) { + const parquet::FileMetaData& metadata, + const parquet::ArrowReaderProperties& properties) { auto manifest = std::make_shared<SchemaManifest>(); - const std::shared_ptr<const KeyValueMetadata>& key_value_metadata = nullptr; + const std::shared_ptr<const KeyValueMetadata>& key_value_metadata = + metadata.key_value_metadata(); RETURN_NOT_OK(SchemaManifest::Make(metadata.schema(), key_value_metadata, properties, manifest.get())); return manifest; @@ -410,13 +427,42 @@ Result<std::shared_ptr<Schema>> ParquetFileFormat::Inspect( Result<std::shared_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReader( const FileSource& source, const std::shared_ptr<ScanOptions>& options) const { - return GetReaderAsync(source, options, nullptr).result(); + return GetReader(source, options, /*metadata=*/nullptr); } Result<std::shared_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReader( const FileSource& source, const std::shared_ptr<ScanOptions>& options, const std::shared_ptr<parquet::FileMetaData>& metadata) const { - return GetReaderAsync(source, options, metadata).result(); + ARROW_ASSIGN_OR_RAISE( + auto parquet_scan_options, + GetFragmentScanOptions<ParquetFragmentScanOptions>(kParquetTypeName, options.get(), + default_fragment_scan_options)); + auto properties = + MakeReaderProperties(*this, parquet_scan_options.get(), options->pool); + ARROW_ASSIGN_OR_RAISE(auto input, source.Open()); + // `parquet::ParquetFileReader::Open` does not wrap thrown exceptions in a + // Status, so `open_parquet_file` is used to catch them. 
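+  // (BEGIN/END_PARQUET_CATCH_EXCEPTIONS expand to a try/catch that converts a
+  // thrown ::parquet::ParquetException into an arrow::Status, so the lambda
+  // below can return a Result instead of letting the exception escape.)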
+ auto open_parquet_file = [&]() -> Result<std::unique_ptr<parquet::ParquetFileReader>> { + BEGIN_PARQUET_CATCH_EXCEPTIONS + auto reader = parquet::ParquetFileReader::Open(std::move(input), + std::move(properties), metadata); + return reader; + END_PARQUET_CATCH_EXCEPTIONS + }; + + auto reader_opt = open_parquet_file(); + if (!reader_opt.ok()) { + return WrapSourceError(reader_opt.status(), source.path()); + } + auto reader = std::move(reader_opt).ValueOrDie(); + + std::shared_ptr<parquet::FileMetaData> reader_metadata = reader->metadata(); + auto arrow_properties = + MakeArrowReaderProperties(*this, *reader_metadata, *options, *parquet_scan_options); + std::unique_ptr<parquet::arrow::FileReader> arrow_reader; + RETURN_NOT_OK(parquet::arrow::FileReader::Make( + options->pool, std::move(reader), std::move(arrow_properties), &arrow_reader)); + return arrow_reader; } Future<std::shared_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReaderAsync( @@ -445,16 +491,8 @@ Future<std::shared_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReader ARROW_ASSIGN_OR_RAISE(std::unique_ptr<parquet::ParquetFileReader> reader, reader_fut.MoveResult()); std::shared_ptr<parquet::FileMetaData> metadata = reader->metadata(); - auto arrow_properties = MakeArrowReaderProperties(*self, *metadata); - arrow_properties.set_batch_size(options->batch_size); - // Must be set here since the sync ScanTask handles pre-buffering itself - arrow_properties.set_pre_buffer( - parquet_scan_options->arrow_reader_properties->pre_buffer()); - arrow_properties.set_cache_options( - parquet_scan_options->arrow_reader_properties->cache_options()); - arrow_properties.set_io_context( - parquet_scan_options->arrow_reader_properties->io_context()); - arrow_properties.set_use_threads(options->use_threads); + auto arrow_properties = + MakeArrowReaderProperties(*this, *metadata, *options, *parquet_scan_options); std::unique_ptr<parquet::arrow::FileReader> arrow_reader; RETURN_NOT_OK(parquet::arrow::FileReader::Make(options->pool, std::move(reader), std::move(arrow_properties), diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 42f923f0e6a27..177ca824179a8 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -18,12 +18,14 @@ #include "arrow/dataset/file_parquet.h" #include +#include #include #include "arrow/compute/api_scalar.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/test_util_internal.h" +#include "arrow/io/interfaces.h" #include "arrow/io/memory.h" #include "arrow/io/test_common.h" #include "arrow/io/util_internal.h" @@ -63,11 +65,15 @@ class ParquetFormatHelper { public: using FormatType = ParquetFileFormat; - static Result<std::shared_ptr<Buffer>> Write(RecordBatchReader* reader) { + static Result<std::shared_ptr<Buffer>> Write( + RecordBatchReader* reader, + const std::shared_ptr<ArrowWriterProperties>& arrow_properties = + default_arrow_writer_properties()) { auto pool = ::arrow::default_memory_pool(); std::shared_ptr<Buffer> out; auto sink = CreateOutputStream(pool); - RETURN_NOT_OK(WriteRecordBatchReader(reader, pool, sink)); + RETURN_NOT_OK(WriteRecordBatchReader(reader, pool, sink, default_writer_properties(), + arrow_properties)); return sink->Finish(); } static std::shared_ptr<ParquetFileFormat> MakeFormat() { @@ -367,6 +373,29 @@ TEST_F(TestParquetFileFormat, MultithreadedScan) { ASSERT_EQ(batches.size(), kNumRowGroups); } +TEST_F(TestParquetFileFormat, SingleThreadExecutor) { + // Reset capacity for io executor + struct PoolResetGuard { + int original_capacity = io::GetIOThreadPoolCapacity(); + ~PoolResetGuard() { DCHECK_OK(io::SetIOThreadPoolCapacity(original_capacity)); } + } guard; + ASSERT_OK(io::SetIOThreadPoolCapacity(1)); + + auto reader = GetRecordBatchReader(schema({field("utf8", utf8())})); + 
ASSERT_OK_AND_ASSIGN(auto buffer, ParquetFormatHelper::Write(reader.get())); + auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(buffer); + auto source = std::make_shared<FileSource>(std::move(buffer_reader), buffer->size()); + auto options = std::make_shared<ScanOptions>(); + + { + auto fragment = MakeFragment(*source); + auto count_rows = fragment->CountRows(literal(true), options); + ASSERT_OK_AND_ASSIGN(auto result, count_rows.MoveResult()); + ASSERT_EQ(expected_rows(), result); + } +} + class TestParquetFileSystemDataset : public WriteFileSystemDatasetMixin, public testing::Test { public: @@ -678,6 +707,29 @@ TEST_P(TestParquetFileFormatScan, PredicatePushdownRowGroupFragmentsUsingStringC CountRowGroupsInFragment(fragment, {0, 3}, equal(field_ref("x"), literal("a"))); } +TEST_P(TestParquetFileFormatScan, PredicatePushdownRowGroupFragmentsUsingDurationColumn) { + // GH-37111: with store_schema enabled, the Parquet Arrow writer stores the + // writer schema (and possibly field_id) in key_value_metadata, and an + // `arrow::duration` column is stored as int64. This test ensures that the + // dataset layer parses the writer schema correctly. + auto table = TableFromJSON(schema({field("t", duration(TimeUnit::NANO))}), + { + R"([{"t": 1}])", + R"([{"t": 2}, {"t": 3}])", + }); + TableBatchReader table_reader(*table); + ASSERT_OK_AND_ASSIGN( + auto buffer, + ParquetFormatHelper::Write( + &table_reader, ArrowWriterProperties::Builder().store_schema()->build())); + auto source = std::make_shared<FileSource>(buffer); + SetSchema({field("t", duration(TimeUnit::NANO))}); + ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); + + auto expr = equal(field_ref("t"), literal(::arrow::DurationScalar(1, TimeUnit::NANO))); + CountRowGroupsInFragment(fragment, {0}, expr); +} + // Tests projection with nested/indexed FieldRefs. 
// https://github.com/apache/arrow/issues/35579 TEST_P(TestParquetFileFormatScan, ProjectWithNonNamedFieldRefs) { diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 29f8882225ae3..08fbcde6fd9de 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -1454,14 +1454,7 @@ class ObjectOutputStream final : public io::OutputStream { // OutputStream interface - Status Close() override { - auto fut = CloseAsync(); - return fut.status(); - } - - Future<> CloseAsync() override { - if (closed_) return Status::OK(); - + Status EnsureReadyToFlushFromClose() { if (current_part_) { // Upload last part RETURN_NOT_OK(CommitCurrentPart()); @@ -1472,36 +1465,56 @@ class ObjectOutputStream final : public io::OutputStream { RETURN_NOT_OK(UploadPart("", 0)); } - // Wait for in-progress uploads to finish (if async writes are enabled) - return FlushAsync().Then([this]() { - ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); + return Status::OK(); + } - // At this point, all part uploads have finished successfully - DCHECK_GT(part_number_, 1); - DCHECK_EQ(upload_state_->completed_parts.size(), - static_cast<size_t>(part_number_ - 1)); - - S3Model::CompletedMultipartUpload completed_upload; - completed_upload.SetParts(upload_state_->completed_parts); - S3Model::CompleteMultipartUploadRequest req; - req.SetBucket(ToAwsString(path_.bucket)); - req.SetKey(ToAwsString(path_.key)); - req.SetUploadId(upload_id_); - req.SetMultipartUpload(std::move(completed_upload)); - - auto outcome = - client_lock.Move()->CompleteMultipartUploadWithErrorFixup(std::move(req)); - if (!outcome.IsSuccess()) { - return ErrorToStatus( - std::forward_as_tuple("When completing multiple part upload for key '", - path_.key, "' in bucket '", path_.bucket, "': "), - "CompleteMultipartUpload", outcome.GetError()); - } + Status FinishPartUploadAfterFlush() { + ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); - holder_ = nullptr; - closed_ = true; - return Status::OK(); - }); + // At this point, all part uploads have finished successfully + DCHECK_GT(part_number_, 1); + DCHECK_EQ(upload_state_->completed_parts.size(), + static_cast<size_t>(part_number_ - 1)); + + S3Model::CompletedMultipartUpload completed_upload; + completed_upload.SetParts(upload_state_->completed_parts); + S3Model::CompleteMultipartUploadRequest req; + req.SetBucket(ToAwsString(path_.bucket)); + req.SetKey(ToAwsString(path_.key)); + req.SetUploadId(upload_id_); + req.SetMultipartUpload(std::move(completed_upload)); + + auto outcome = + client_lock.Move()->CompleteMultipartUploadWithErrorFixup(std::move(req)); + if (!outcome.IsSuccess()) { + return ErrorToStatus( + std::forward_as_tuple("When completing multiple part upload for key '", + path_.key, "' in bucket '", path_.bucket, "': "), + "CompleteMultipartUpload", outcome.GetError()); + } + + holder_ = nullptr; + closed_ = true; + return Status::OK(); + } + + Status Close() override { + if (closed_) return Status::OK(); + + RETURN_NOT_OK(EnsureReadyToFlushFromClose()); + + RETURN_NOT_OK(Flush()); + return FinishPartUploadAfterFlush(); + } + + Future<> CloseAsync() override { + if (closed_) return Status::OK(); + + RETURN_NOT_OK(EnsureReadyToFlushFromClose()); + + auto self = std::dynamic_pointer_cast<ObjectOutputStream>(shared_from_this()); + // Wait for in-progress uploads to finish (if async writes are enabled) + return FlushAsync().Then([self]() { return self->FinishPartUploadAfterFlush(); }); } bool closed() const override { return closed_; } diff --git 
a/cpp/src/arrow/filesystem/s3fs_test.cc b/cpp/src/arrow/filesystem/s3fs_test.cc index e9f14fde72316..f88ee7eef9332 100644 --- a/cpp/src/arrow/filesystem/s3fs_test.cc +++ b/cpp/src/arrow/filesystem/s3fs_test.cc @@ -590,6 +590,21 @@ class TestS3FS : public S3TestMixin { AssertObjectContents(client_.get(), "bucket", "somefile", "new data"); } + void TestOpenOutputStreamCloseAsyncDestructor() { + std::shared_ptr<io::OutputStream> stream; + ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/somefile")); + ASSERT_OK(stream->Write("new data")); + // Destructor implicitly closes stream and completes the multipart upload. + // GH-37670: it doesn't matter whether the flush is triggered asynchronously + // after CloseAsync or synchronously after stream.reset(); we're just + // checking that `closeAsyncFut` keeps the stream alive until completion + // rather than segfaulting on a dangling stream + auto closeAsyncFut = stream->CloseAsync(); + stream.reset(); + ASSERT_OK(closeAsyncFut.MoveResult()); + AssertObjectContents(client_.get(), "bucket", "somefile", "new data"); + } + protected: S3Options options_; std::shared_ptr<S3FileSystem> fs_; @@ -1177,6 +1192,16 @@ TEST_F(TestS3FS, OpenOutputStreamDestructorSyncWrite) { TestOpenOutputStreamDestructor(); } +TEST_F(TestS3FS, OpenOutputStreamAsyncDestructorBackgroundWrites) { + TestOpenOutputStreamCloseAsyncDestructor(); +} + +TEST_F(TestS3FS, OpenOutputStreamAsyncDestructorSyncWrite) { + options_.background_writes = false; + MakeFileSystem(); + TestOpenOutputStreamCloseAsyncDestructor(); +} + TEST_F(TestS3FS, OpenOutputStreamMetadata) { std::shared_ptr<io::OutputStream> stream; diff --git a/cpp/src/arrow/integration/c_data_integration_internal.cc b/cpp/src/arrow/integration/c_data_integration_internal.cc new file mode 100644 index 0000000000000..79e09eaf91a39 --- /dev/null +++ b/cpp/src/arrow/integration/c_data_integration_internal.cc @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
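+
+// Implements the integration entrypoints declared in
+// c_data_integration_internal.h: each helper reads a schema or record batch
+// from an integration JSON file and either exports it through the C Data
+// Interface or imports the C-side object and compares it against the JSON
+// reference, reporting any failure as an error string.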
+ +#include "arrow/integration/c_data_integration_internal.h" + +#include <sstream> +#include <string> + +#include "arrow/c/bridge.h" +#include "arrow/integration/json_integration.h" +#include "arrow/io/file.h" +#include "arrow/memory_pool.h" +#include "arrow/pretty_print.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/logging.h" + +namespace arrow::internal::integration { +namespace { + +template <typename Func> +const char* StatusToErrorString(Func&& func) { + static std::string error; + + Status st = func(); + if (st.ok()) { + return nullptr; + } + error = st.ToString(); + ARROW_CHECK_GT(error.length(), 0); + return error.c_str(); +} + +Result<std::shared_ptr<Schema>> ReadSchemaFromJson(const std::string& json_path, + MemoryPool* pool) { + ARROW_ASSIGN_OR_RAISE(auto file, io::ReadableFile::Open(json_path, pool)); + ARROW_ASSIGN_OR_RAISE(auto reader, IntegrationJsonReader::Open(pool, file)); + return reader->schema(); +} + +Result<std::shared_ptr<RecordBatch>> ReadBatchFromJson(const std::string& json_path, + int num_batch, MemoryPool* pool) { + ARROW_ASSIGN_OR_RAISE(auto file, io::ReadableFile::Open(json_path, pool)); + ARROW_ASSIGN_OR_RAISE(auto reader, IntegrationJsonReader::Open(pool, file)); + return reader->ReadRecordBatch(num_batch); +} + +// XXX ideally, we should allow use of a custom memory pool in the C bridge API, +// but that requires a non-trivial refactor + +Status ExportSchemaFromJson(std::string json_path, ArrowSchema* out) { + auto pool = default_memory_pool(); + ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchemaFromJson(json_path, pool)); + return ExportSchema(*schema, out); +} + +Status ImportSchemaAndCompareToJson(std::string json_path, ArrowSchema* c_schema) { + auto pool = default_memory_pool(); + ARROW_ASSIGN_OR_RAISE(auto json_schema, ReadSchemaFromJson(json_path, pool)); + ARROW_ASSIGN_OR_RAISE(auto imported_schema, ImportSchema(c_schema)); + if (!imported_schema->Equals(json_schema, /*check_metadata=*/true)) { + return Status::Invalid("Schemas are different:", "\n- Json Schema: ", *json_schema, + "\n- Imported Schema: ", *imported_schema); + } + return Status::OK(); + } + +Status ExportBatchFromJson(std::string json_path, int num_batch, ArrowArray* out) { + auto pool = default_memory_pool(); + ARROW_ASSIGN_OR_RAISE(auto batch, ReadBatchFromJson(json_path, num_batch, pool)); + return ExportRecordBatch(*batch, out); +} + +Status ImportBatchAndCompareToJson(std::string json_path, int num_batch, + ArrowArray* c_batch) { + auto pool = default_memory_pool(); + ARROW_ASSIGN_OR_RAISE(auto batch, ReadBatchFromJson(json_path, num_batch, pool)); + ARROW_ASSIGN_OR_RAISE(auto imported_batch, ImportRecordBatch(c_batch, batch->schema())); + RETURN_NOT_OK(imported_batch->ValidateFull()); + if (!imported_batch->Equals(*batch, /*check_metadata=*/true)) { + std::stringstream pp_expected; + std::stringstream pp_actual; + PrettyPrintOptions options(/*indent=*/2); + options.window = 50; + ARROW_CHECK_OK(PrettyPrint(*batch, options, &pp_expected)); + ARROW_CHECK_OK(PrettyPrint(*imported_batch, options, &pp_actual)); + return Status::Invalid("Record Batches are different:", "\n- Json Batch: ", + pp_expected.str(), "\n- Imported Batch: ", pp_actual.str()); + } + return Status::OK(); +} + +} // namespace +} // namespace arrow::internal::integration + +const char* ArrowCpp_CDataIntegration_ExportSchemaFromJson(const char* json_path, + ArrowSchema* out) { + using namespace arrow::internal::integration; // NOLINT(build/namespaces) + return StatusToErrorString([=]() { 
return ExportSchemaFromJson(json_path, out); }); +} + +const char* ArrowCpp_CDataIntegration_ImportSchemaAndCompareToJson(const char* json_path, + ArrowSchema* schema) { + using namespace arrow::internal::integration; // NOLINT(build/namespaces) + return StatusToErrorString( + [=]() { return ImportSchemaAndCompareToJson(json_path, schema); }); +} + +const char* ArrowCpp_CDataIntegration_ExportBatchFromJson(const char* json_path, + int num_batch, + ArrowArray* out) { + using namespace arrow::internal::integration; // NOLINT(build/namespaces) + return StatusToErrorString( + [=]() { return ExportBatchFromJson(json_path, num_batch, out); }); +} + +const char* ArrowCpp_CDataIntegration_ImportBatchAndCompareToJson(const char* json_path, + int num_batch, + ArrowArray* batch) { + using namespace arrow::internal::integration; // NOLINT(build/namespaces) + return StatusToErrorString( + [=]() { return ImportBatchAndCompareToJson(json_path, num_batch, batch); }); +} + +int64_t ArrowCpp_BytesAllocated() { + auto pool = arrow::default_memory_pool(); + return pool->bytes_allocated(); +} diff --git a/cpp/src/arrow/integration/c_data_integration_internal.h b/cpp/src/arrow/integration/c_data_integration_internal.h new file mode 100644 index 0000000000000..0a62363dffab3 --- /dev/null +++ b/cpp/src/arrow/integration/c_data_integration_internal.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/c/abi.h" +#include "arrow/util/visibility.h" + +// This file only serves as documentation for the C Data Interface integration +// entrypoints. The actual functions are called by Archery through DLL symbol lookup. 
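+//
+// Each entrypoint returns nullptr on success, or an error string describing
+// the failure; the string lives in a static buffer on the C++ side and is
+// only guaranteed to remain valid until the next failing call of the same
+// entrypoint.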
+ +extern "C" { + +ARROW_EXPORT +const char* ArrowCpp_CDataIntegration_ExportSchemaFromJson(const char* json_path, + ArrowSchema* out); + +ARROW_EXPORT +const char* ArrowCpp_CDataIntegration_ImportSchemaAndCompareToJson(const char* json_path, + ArrowSchema* schema); + +ARROW_EXPORT +const char* ArrowCpp_CDataIntegration_ExportBatchFromJson(const char* json_path, + int num_batch, ArrowArray* out); + +ARROW_EXPORT +const char* ArrowCpp_CDataIntegration_ImportBatchAndCompareToJson(const char* json_path, + int num_batch, + ArrowArray* batch); + +ARROW_EXPORT +int64_t ArrowCpp_BytesAllocated(); + +} // extern "C" diff --git a/cpp/src/arrow/integration/json_integration.cc b/cpp/src/arrow/integration/json_integration.cc index 178abe5e8b687..590f6eddd7c24 100644 --- a/cpp/src/arrow/integration/json_integration.cc +++ b/cpp/src/arrow/integration/json_integration.cc @@ -144,10 +144,9 @@ class IntegrationJsonReader::Impl { } Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) { - DCHECK_GE(i, 0) << "i out of bounds"; - DCHECK_LT(i, static_cast<int>(record_batches_->GetArray().Size())) - << "i out of bounds"; - + if (i < 0 || i >= static_cast<int>(record_batches_->GetArray().Size())) { + return Status::IndexError("record batch index ", i, " out of bounds"); + } return json::ReadRecordBatch(record_batches_->GetArray()[i], schema_, &dictionary_memo_, pool_); } diff --git a/cpp/src/arrow/io/interfaces.cc b/cpp/src/arrow/io/interfaces.cc index e7819e139f67a..d3229fd067cbe 100644 --- a/cpp/src/arrow/io/interfaces.cc +++ b/cpp/src/arrow/io/interfaces.cc @@ -123,7 +123,8 @@ Result<std::shared_ptr<const KeyValueMetadata>> InputStream::ReadMetadata() { // executor Future<std::shared_ptr<const KeyValueMetadata>> InputStream::ReadMetadataAsync( const IOContext& ctx) { - auto self = shared_from_this(); + std::shared_ptr<InputStream> self = + std::dynamic_pointer_cast<InputStream>(shared_from_this()); return DeferNotOk(internal::SubmitIO(ctx, [self] { return self->ReadMetadata(); })); } @@ -165,7 +166,7 @@ Result<std::shared_ptr<Buffer>> RandomAccessFile::ReadAt(int64_t position, Future<std::shared_ptr<Buffer>> RandomAccessFile::ReadAsync(const IOContext& ctx, int64_t position, int64_t nbytes) { - auto self = checked_pointer_cast<RandomAccessFile>(shared_from_this()); + auto self = std::dynamic_pointer_cast<RandomAccessFile>(shared_from_this()); return DeferNotOk(internal::SubmitIO( ctx, [self, position, nbytes] { return self->ReadAt(position, nbytes); })); } diff --git a/cpp/src/arrow/io/interfaces.h b/cpp/src/arrow/io/interfaces.h index dcbe4feb261fb..d2a11b7b6d7ce 100644 --- a/cpp/src/arrow/io/interfaces.h +++ b/cpp/src/arrow/io/interfaces.h @@ -96,7 +96,7 @@ struct ARROW_EXPORT IOContext { StopToken stop_token_; }; -class ARROW_EXPORT FileInterface { +class ARROW_EXPORT FileInterface : public std::enable_shared_from_this<FileInterface> { public: virtual ~FileInterface() = 0; @@ -205,9 +205,7 @@ class ARROW_EXPORT OutputStream : virtual public FileInterface, public Writable OutputStream() = default; }; -class ARROW_EXPORT InputStream : virtual public FileInterface, - virtual public Readable, - public std::enable_shared_from_this<InputStream> { +class ARROW_EXPORT InputStream : virtual public FileInterface, virtual public Readable { public: /// \brief Advance or skip stream indicated number of bytes /// \param[in] nbytes the number to move forward diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 69b827b8fe78d..3ae007c20efe7 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -1519,6 +1519,22 @@ class ReaderWriterMixin : public ExtensionTypesMixin { } + void TestWriteAfterClose() { + // Part of GH-35095. 
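+    // IpcFormatWriter now tracks a closed_ flag so that WriteRecordBatch()
+    // after Close() fails with Status::Invalid instead of writing to a
+    // closed sink (see the writer.cc hunk below).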
+ std::shared_ptr<RecordBatch> batch_ints; + ASSERT_OK(MakeIntRecordBatch(&batch_ints)); + + auto schema = batch_ints->schema(); + + WriterHelper writer_helper; + ASSERT_OK(writer_helper.Init(schema, IpcWriteOptions::Defaults())); + ASSERT_OK(writer_helper.WriteBatch(batch_ints)); + ASSERT_OK(writer_helper.Finish()); + + // Write after close raises status + ASSERT_RAISES(Invalid, writer_helper.WriteBatch(batch_ints)); + } + void TestWriteDifferentSchema() { // Test writing batches with a different schema than the RecordBatchWriter // was initialized with. @@ -1991,6 +2007,9 @@ TEST_F(TestFileFormatGenerator, DictionaryRoundTrip) { TestDictionaryRoundtrip() TEST_F(TestFileFormatGeneratorCoalesced, DictionaryRoundTrip) { TestDictionaryRoundtrip(); } +TEST_F(TestFileFormat, WriteAfterClose) { TestWriteAfterClose(); } + +TEST_F(TestStreamFormat, WriteAfterClose) { TestWriteAfterClose(); } TEST_F(TestStreamFormat, DifferentSchema) { TestWriteDifferentSchema(); } diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 1d230601566a0..e4b49ed56464e 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -1070,6 +1070,9 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { Status WriteRecordBatch( const RecordBatch& batch, const std::shared_ptr<const KeyValueMetadata>& custom_metadata) override { + if (closed_) { + return Status::Invalid("Destination already closed"); + } if (!batch.schema()->Equals(schema_, false /* check_metadata */)) { return Status::Invalid("Tried to write record batch with different schema"); } @@ -1101,7 +1104,9 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { Status Close() override { RETURN_NOT_OK(CheckStarted()); - return payload_writer_->Close(); + RETURN_NOT_OK(payload_writer_->Close()); + closed_ = true; + return Status::OK(); } Status Start() { @@ -1213,6 +1218,7 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { std::unordered_map<int64_t, std::shared_ptr<Array>> last_dictionaries_; bool started_ = false; + bool closed_ = false; IpcWriteOptions options_; WriteStats stats_; }; diff --git a/cpp/src/arrow/symbols.map b/cpp/src/arrow/symbols.map index 9ef0e404bc091..0144e6116554b 100644 --- a/cpp/src/arrow/symbols.map +++ b/cpp/src/arrow/symbols.map @@ -32,6 +32,7 @@ }; # Also export C-level helpers arrow_*; + Arrow*; # ARROW-14771: export Protobuf symbol table descriptor_table_Flight_2eproto; descriptor_table_FlightSql_2eproto; diff --git a/cpp/src/arrow/util/reflection_internal.h b/cpp/src/arrow/util/reflection_internal.h index d7de913bafd88..5d281a265ff71 100644 --- a/cpp/src/arrow/util/reflection_internal.h +++ b/cpp/src/arrow/util/reflection_internal.h @@ -71,6 +71,30 @@ constexpr DataMemberProperty<Class, Type> DataMember(std::string_view name, return {name, ptr}; } +template <typename C, typename T> +struct CoercedDataMemberProperty { + using Class = C; + using Type = T; + + constexpr Type get(const Class& obj) const { return (obj.*get_coerced_)(); } + + void set(Class* obj, Type value) const { (*obj).*ptr_for_set_ = std::move(value); } + + constexpr std::string_view name() const { return name_; } + + std::string_view name_; + Type Class::*ptr_for_set_; + Type (Class::*get_coerced_)() const; +}; + +template <typename Class, typename Type> +constexpr CoercedDataMemberProperty<Class, Type> CoercedDataMember(std::string_view name, + Type Class::*ptr, + Type (Class::*get)() + const) { + return {name, ptr, get}; +} + template <typename... Properties> struct PropertyTuple { template <typename Fn> diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py index 70f865cc2fa70..7a3b45f9788e6 100644 --- a/dev/archery/archery/cli.py +++ b/dev/archery/archery/cli.py 
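(The cli.py hunks below add --run-ipc and --run-c-data flags alongside the existing --run-flight, and require at least one format and one language to be selected. A hypothetical invocation testing only the C++ implementation over two formats:

    archery integration --with-cpp=1 --run-ipc --run-c-data

If no format flag is given, the command now aborts with a click.UsageError instead of a bare Exception.)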
@@ -723,8 +723,12 @@ def _set_default(opt, default): envvar="ARCHERY_INTEGRATION_WITH_RUST") @click.option('--write_generated_json', default="", help='Generate test JSON to indicated path') +@click.option('--run-ipc', is_flag=True, default=False, + help='Run IPC integration tests') @click.option('--run-flight', is_flag=True, default=False, help='Run Flight integration tests') +@click.option('--run-c-data', is_flag=True, default=False, + help='Run C Data Interface integration tests') @click.option('--debug', is_flag=True, default=False, help='Run executables in debug mode as relevant') @click.option('--serial', is_flag=True, default=False, @@ -753,15 +757,19 @@ def integration(with_all=False, random_seed=12345, **args): gen_path = args['write_generated_json'] languages = ['cpp', 'csharp', 'java', 'js', 'go', 'rust'] + formats = ['ipc', 'flight', 'c_data'] enabled_languages = 0 for lang in languages: - param = 'with_{}'.format(lang) + param = f'with_{lang}' if with_all: args[param] = with_all + enabled_languages += args[param] - if args[param]: - enabled_languages += 1 + enabled_formats = 0 + for fmt in formats: + param = f'run_{fmt}' + enabled_formats += args[param] if gen_path: # XXX See GH-37575: this option is only used by the JS test suite @@ -769,8 +777,13 @@ def integration(with_all=False, random_seed=12345, **args): os.makedirs(gen_path, exist_ok=True) write_js_test_json(gen_path) else: + if enabled_formats == 0: + raise click.UsageError( + "Need to enable at least one format to test " + "(IPC, Flight, C Data Interface); try --help") if enabled_languages == 0: - raise Exception("Must enable at least 1 language to test") + raise click.UsageError( + "Need to enable at least one language to test; try --help") run_all_tests(**args) diff --git a/dev/archery/archery/integration/cdata.py b/dev/archery/archery/integration/cdata.py new file mode 100644 index 0000000000000..c201f5f867f8f --- /dev/null +++ b/dev/archery/archery/integration/cdata.py @@ -0,0 +1,107 @@ +# licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
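+
+# Helpers for the C Data Interface integration tests: a lazily-constructed
+# cffi.FFI object carrying the ArrowSchema/ArrowArray/ArrowArrayStream
+# declarations, and a context manager that verifies the exporter's
+# allocations are released once the importer has dropped the imported data.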
+ +import cffi +from contextlib import contextmanager +import functools + +from .tester import CDataExporter, CDataImporter + + +_c_data_decls = """ + struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; + }; + + struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; + }; + + struct ArrowArrayStream { + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + + const char* (*get_last_error)(struct ArrowArrayStream*); + + // Release callback + void (*release)(struct ArrowArrayStream*); + // Opaque producer-specific data + void* private_data; + }; + """ + + +@functools.lru_cache +def ffi() -> cffi.FFI: + """ + Return a FFI object supporting C Data Interface types. + """ + ffi = cffi.FFI() + ffi.cdef(_c_data_decls) + return ffi + + +@contextmanager +def check_memory_released(exporter: CDataExporter, importer: CDataImporter): + """ + A context manager for memory release checks. + + The context manager arranges cooperation between the exporter and importer + to try and release memory at the end of the enclosed block. + + However, if either the exporter or importer doesn't support deterministic + memory release, no memory check is performed. + """ + do_check = (exporter.supports_releasing_memory and + importer.supports_releasing_memory) + if do_check: + before = exporter.record_allocation_state() + yield + # We don't use a `finally` clause: if the enclosed block raised an + # exception, no need to add another one. 
+ if do_check: + ok = exporter.compare_allocation_state(before, importer.gc_until) + if not ok: + after = exporter.record_allocation_state() + raise RuntimeError( + f"Memory was not released correctly after roundtrip: " + f"before = {before}, after = {after} (should have been equal)") diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index e83fa0152931a..299881c4b613a 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -25,6 +25,7 @@ import numpy as np from .util import frombytes, tobytes, random_bytes, random_utf8 +from .util import SKIP_C_SCHEMA, SKIP_C_ARRAY def metadata_key_values(pairs): @@ -664,6 +665,26 @@ def _get_type(self): return OrderedDict([('name', 'largeutf8')]) +class BinaryViewField(BinaryField): + + @property + def column_class(self): + return BinaryViewColumn + + def _get_type(self): + return OrderedDict([('name', 'binaryview')]) + + +class StringViewField(StringField): + + @property + def column_class(self): + return StringViewColumn + + def _get_type(self): + return OrderedDict([('name', 'utf8view')]) + + class Schema(object): def __init__(self, fields, metadata=None): @@ -743,6 +764,74 @@ class LargeStringColumn(_BaseStringColumn, _LargeOffsetsMixin): pass +class BinaryViewColumn(PrimitiveColumn): + + def _encode_value(self, x): + return frombytes(binascii.hexlify(x).upper()) + + def _get_buffers(self): + views = [] + data_buffers = [] + # a small default data buffer size is used so we can exercise + # arrays with multiple data buffers with small data sets + DEFAULT_BUFFER_SIZE = 32 + INLINE_SIZE = 12 + + for i, v in enumerate(self.values): + if not self.is_valid[i]: + v = b'' + assert isinstance(v, bytes) + + if len(v) <= INLINE_SIZE: + # Append an inline view, skip data buffer management. + views.append(OrderedDict([ + ('SIZE', len(v)), + ('INLINED', self._encode_value(v)), + ])) + continue + + if len(data_buffers) == 0: + # No data buffers have been added yet; + # add this string whole (we may append to it later). + offset = 0 + data_buffers.append(v) + elif len(data_buffers[-1]) + len(v) > DEFAULT_BUFFER_SIZE: + # Appending this string to the current active data buffer + # would overflow the default buffer size; add it whole. + offset = 0 + data_buffers.append(v) + else: + # Append this string to the current active data buffer. 
+ offset = len(data_buffers[-1]) + data_buffers[-1] += v + + # the prefix is always 4 bytes so it may not be utf-8 + # even if the whole string view is + prefix = frombytes(binascii.hexlify(v[:4]).upper()) + + views.append(OrderedDict([ + ('SIZE', len(v)), + ('PREFIX_HEX', prefix), + ('BUFFER_INDEX', len(data_buffers) - 1), + ('OFFSET', offset), + ])) + + return [ + ('VALIDITY', [int(x) for x in self.is_valid]), + ('VIEWS', views), + ('VARIADIC_DATA_BUFFERS', [ + frombytes(binascii.hexlify(b).upper()) + for b in data_buffers + ]), + ] + + +class StringViewColumn(BinaryViewColumn): + + def _encode_value(self, x): + return frombytes(x) + + class FixedSizeBinaryColumn(PrimitiveColumn): def _encode_value(self, x): @@ -1224,15 +1313,16 @@ def get_json(self): class File(object): def __init__(self, name, schema, batches, dictionaries=None, - skip=None, path=None, quirks=None): + skip_testers=None, path=None, quirks=None): self.name = name self.schema = schema self.dictionaries = dictionaries or [] self.batches = batches - self.skip = set() + self.skipped_testers = set() + self.skipped_formats = {} self.path = path - if skip: - self.skip.update(skip) + if skip_testers: + self.skipped_testers.update(skip_testers) # For tracking flags like whether to validate decimal values # fit into the given precision (ARROW-13558). self.quirks = set() @@ -1258,14 +1348,39 @@ def write(self, path): f.write(json.dumps(self.get_json(), indent=2).encode('utf-8')) self.path = path - def skip_category(self, category): - """Skip this test for the given category. + def skip_tester(self, tester): + """Skip this test for the given tester (such as 'C#'). + """ + self.skipped_testers.add(tester) + return self - Category should be SKIP_ARROW or SKIP_FLIGHT. + def skip_format(self, format, tester='all'): + """Skip this test for the given format, and optionally tester. """ - self.skip.add(category) + self.skipped_formats.setdefault(format, set()).add(tester) return self + def add_skips_from(self, other_file): + """Add skips from another File object. + """ + self.skipped_testers.update(other_file.skipped_testers) + for format, testers in other_file.skipped_formats.items(): + self.skipped_formats.setdefault(format, set()).update(testers) + + def should_skip(self, tester, format): + """Whether this (tester, format) combination should be skipped. + """ + if tester in self.skipped_testers: + return True + testers = self.skipped_formats.get(format, ()) + return 'all' in testers or tester in testers + + @property + def num_batches(self): + """The number of record batches in this file. 
+ """ + return len(self.batches) + def get_field(name, type_, **kwargs): if type_ == 'binary': @@ -1295,8 +1410,8 @@ def get_field(name, type_, **kwargs): raise TypeError(dtype) -def _generate_file(name, fields, batch_sizes, dictionaries=None, skip=None, - metadata=None): +def _generate_file(name, fields, batch_sizes, *, + dictionaries=None, metadata=None): schema = Schema(fields, metadata=metadata) batches = [] for size in batch_sizes: @@ -1307,7 +1422,7 @@ def _generate_file(name, fields, batch_sizes, dictionaries=None, skip=None, batches.append(RecordBatch(size, columns)) - return File(name, schema, batches, dictionaries, skip=skip) + return File(name, schema, batches, dictionaries) def generate_custom_metadata_case(): @@ -1541,6 +1656,15 @@ def generate_run_end_encoded_case(): return _generate_file("run_end_encoded", fields, batch_sizes) +def generate_binary_view_case(): + fields = [ + BinaryViewField('bv'), + StringViewField('sv'), + ] + batch_sizes = [0, 7, 256] + return _generate_file("binary_view", fields, batch_sizes) + + def generate_nested_large_offsets_case(): fields = [ LargeListField('large_list_nullable', get_field('item', 'int32')), @@ -1666,8 +1790,8 @@ def _temp_path(): generate_primitive_case([0, 0, 0], name='primitive_zerolength'), generate_primitive_large_offsets_case([17, 20]) - .skip_category('C#') - .skip_category('JS'), + .skip_tester('C#') + .skip_tester('JS'), generate_null_case([10, 0]), @@ -1676,65 +1800,78 @@ def _temp_path(): generate_decimal128_case(), generate_decimal256_case() - .skip_category('JS'), + .skip_tester('JS'), generate_datetime_case(), generate_duration_case() - .skip_category('C#') - .skip_category('JS'), # TODO(ARROW-5239): Intervals + JS + .skip_tester('C#') + .skip_tester('JS'), # TODO(ARROW-5239): Intervals + JS generate_interval_case() - .skip_category('C#') - .skip_category('JS'), # TODO(ARROW-5239): Intervals + JS + .skip_tester('C#') + .skip_tester('JS'), # TODO(ARROW-5239): Intervals + JS generate_month_day_nano_interval_case() - .skip_category('C#') - .skip_category('JS'), + .skip_tester('C#') + .skip_tester('JS'), generate_map_case() - .skip_category('C#'), + .skip_tester('C#'), generate_non_canonical_map_case() - .skip_category('C#') - .skip_category('Java'), # TODO(ARROW-8715) + .skip_tester('C#') + .skip_tester('Java') # TODO(ARROW-8715) + # Canonical map names are restored on import, so the schemas are unequal + .skip_format(SKIP_C_SCHEMA, 'C++'), generate_nested_case(), generate_recursive_nested_case(), generate_nested_large_offsets_case() - .skip_category('C#') - .skip_category('JS'), + .skip_tester('C#') + .skip_tester('JS'), generate_unions_case(), generate_custom_metadata_case() - .skip_category('C#'), + .skip_tester('C#'), generate_duplicate_fieldnames_case() - .skip_category('C#') - .skip_category('JS'), + .skip_tester('C#') + .skip_tester('JS'), generate_dictionary_case() - .skip_category('C#'), + .skip_tester('C#'), generate_dictionary_unsigned_case() - .skip_category('C#') - .skip_category('Java'), # TODO(ARROW-9377) + .skip_tester('C#') + .skip_tester('Java'), # TODO(ARROW-9377) generate_nested_dictionary_case() - .skip_category('C#') - .skip_category('Java'), # TODO(ARROW-7779) + .skip_tester('C#') + .skip_tester('Java'), # TODO(ARROW-7779) generate_run_end_encoded_case() - .skip_category('C#') - .skip_category('Java') - .skip_category('JS') - .skip_category('Rust'), + .skip_tester('C#') + .skip_tester('Java') + .skip_tester('JS') + .skip_tester('Rust'), + + generate_binary_view_case() + .skip_tester('C++') + 
.skip_tester('C#') + .skip_tester('Go') + .skip_tester('Java') + .skip_tester('JS') + .skip_tester('Rust'), generate_extension_case() - .skip_category('C#'), + .skip_tester('C#') + # TODO: ensure the extension is registered in the C++ entrypoint + .skip_format(SKIP_C_SCHEMA, 'C++') + .skip_format(SKIP_C_ARRAY, 'C++'), ] generated_paths = [] diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 0ee9ab814e5e6..2fd1d2d7f0c44 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -25,17 +25,19 @@ import sys import tempfile import traceback -from typing import Callable, List +from typing import Callable, List, Optional +from . import cdata from .scenario import Scenario -from .tester import Tester -from .tester_cpp import CPPTester +from .tester import Tester, CDataExporter, CDataImporter +from .tester_cpp import CppTester from .tester_go import GoTester from .tester_rust import RustTester from .tester_java import JavaTester from .tester_js import JSTester from .tester_csharp import CSharpTester -from .util import guid, SKIP_ARROW, SKIP_FLIGHT, printer +from .util import guid, printer +from .util import SKIP_C_ARRAY, SKIP_C_SCHEMA, SKIP_FLIGHT, SKIP_IPC from ..utils.source import ARROW_ROOT_DEFAULT from . import datagen @@ -76,7 +78,7 @@ def __init__(self, json_files, self.json_files = [json_file for json_file in self.json_files if self.match in json_file.name] - def run(self): + def run_ipc(self): """ Run Arrow IPC integration tests for the matrix of enabled implementations. @@ -84,23 +86,24 @@ def run(self): for producer, consumer in itertools.product( filter(lambda t: t.PRODUCER, self.testers), filter(lambda t: t.CONSUMER, self.testers)): - self._compare_implementations( + self._compare_ipc_implementations( producer, consumer, self._produce_consume, self.json_files) if self.gold_dirs: for gold_dir, consumer in itertools.product( self.gold_dirs, filter(lambda t: t.CONSUMER, self.testers)): - log('\n\n\n\n') + log('\n') log('******************************************************') log('Tests against golden files in {}'.format(gold_dir)) log('******************************************************') def run_gold(_, consumer, test_case: datagen.File): return self._run_gold(gold_dir, consumer, test_case) - self._compare_implementations( + self._compare_ipc_implementations( consumer, consumer, run_gold, self._gold_tests(gold_dir)) + log('\n') def run_flight(self): """ @@ -112,6 +115,18 @@ def run_flight(self): self.testers) for server, client in itertools.product(servers, clients): self._compare_flight_implementations(server, client) + log('\n') + + def run_c_data(self): + """ + Run Arrow C Data interface integration tests for the matrix of + enabled implementations. 
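+
+        Schemas and record batches are passed between the two
+        implementations in-process, through cffi-allocated ArrowSchema
+        and ArrowArray structs.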
+ """ + for producer, consumer in itertools.product( + filter(lambda t: t.C_DATA_EXPORTER, self.testers), + filter(lambda t: t.C_DATA_IMPORTER, self.testers)): + self._compare_c_data_implementations(producer, consumer) + log('\n') def _gold_tests(self, gold_dir): prefix = os.path.basename(os.path.normpath(gold_dir)) @@ -125,28 +140,31 @@ def _gold_tests(self, gold_dir): with open(out_path, "wb") as out: out.write(i.read()) + # Find the generated file with the same name as this gold file try: - skip = next(f for f in self.json_files - if f.name == name).skip + equiv_json_file = next(f for f in self.json_files + if f.name == name) except StopIteration: - skip = set() + equiv_json_file = None + + skip_testers = set() if name == 'union' and prefix == '0.17.1': - skip.add("Java") - skip.add("JS") + skip_testers.add("Java") + skip_testers.add("JS") if prefix == '1.0.0-bigendian' or prefix == '1.0.0-littleendian': - skip.add("C#") - skip.add("Java") - skip.add("JS") - skip.add("Rust") + skip_testers.add("C#") + skip_testers.add("Java") + skip_testers.add("JS") + skip_testers.add("Rust") if prefix == '2.0.0-compression': - skip.add("C#") - skip.add("JS") + skip_testers.add("C#") + skip_testers.add("JS") # See https://github.com/apache/arrow/pull/9822 for how to # disable specific compression type tests. if prefix == '4.0.0-shareddict': - skip.add("C#") + skip_testers.add("C#") quirks = set() if prefix in {'0.14.1', '0.17.1', @@ -157,12 +175,18 @@ def _gold_tests(self, gold_dir): quirks.add("no_date64_validate") quirks.add("no_times_validate") - yield datagen.File(name, None, None, skip=skip, path=out_path, - quirks=quirks) + json_file = datagen.File(name, schema=None, batches=None, + path=out_path, + skip_testers=skip_testers, + quirks=quirks) + if equiv_json_file is not None: + json_file.add_skips_from(equiv_json_file) + yield json_file def _run_test_cases(self, case_runner: Callable[[datagen.File], Outcome], - test_cases: List[datagen.File]) -> None: + test_cases: List[datagen.File], + *, serial: Optional[bool] = None) -> None: """ Populate self.failures with the outcomes of the ``case_runner`` ran against ``test_cases`` @@ -171,10 +195,13 @@ def case_wrapper(test_case): with printer.cork(): return case_runner(test_case) + if serial is None: + serial = self.serial + if self.failures and self.stop_on_error: return - if self.serial: + if serial: for outcome in map(case_wrapper, test_cases): if outcome.failure is not None: self.failures.append(outcome.failure) @@ -189,7 +216,7 @@ def case_wrapper(test_case): if self.stop_on_error: break - def _compare_implementations( + def _compare_ipc_implementations( self, producer: Tester, consumer: Tester, @@ -221,22 +248,17 @@ def _run_ipc_test_case( outcome = Outcome() json_path = test_case.path - log('==========================================================') + log('=' * 70) log('Testing file {0}'.format(json_path)) - log('==========================================================') - - if producer.name in test_case.skip: - log('-- Skipping test because producer {0} does ' - 'not support'.format(producer.name)) - outcome.skipped = True - elif consumer.name in test_case.skip: - log('-- Skipping test because consumer {0} does ' - 'not support'.format(consumer.name)) + if test_case.should_skip(producer.name, SKIP_IPC): + log(f'-- Skipping test because producer {producer.name} does ' + f'not support IPC') outcome.skipped = True - elif SKIP_ARROW in test_case.skip: - log('-- Skipping test') + elif test_case.should_skip(consumer.name, SKIP_IPC): + log(f'-- 
Skipping test because consumer {consumer.name} does ' + f'not support IPC') outcome.skipped = True else: @@ -247,6 +269,8 @@ def _run_ipc_test_case( outcome.failure = Failure(test_case, producer, consumer, sys.exc_info()) + log('=' * 70) + return outcome def _produce_consume(self, @@ -344,22 +368,17 @@ def _run_flight_test_case(self, """ outcome = Outcome() - log('=' * 58) + log('=' * 70) log('Testing file {0}'.format(test_case.name)) - log('=' * 58) - - if producer.name in test_case.skip: - log('-- Skipping test because producer {0} does ' - 'not support'.format(producer.name)) - outcome.skipped = True - elif consumer.name in test_case.skip: - log('-- Skipping test because consumer {0} does ' - 'not support'.format(consumer.name)) + if test_case.should_skip(producer.name, SKIP_FLIGHT): + log(f'-- Skipping test because producer {producer.name} does ' + f'not support Flight') outcome.skipped = True - elif SKIP_FLIGHT in test_case.skip: - log('-- Skipping test') + elif test_case.should_skip(consumer.name, SKIP_FLIGHT): + log(f'-- Skipping test because consumer {consumer.name} does ' + f'not support Flight') outcome.skipped = True else: @@ -380,6 +399,125 @@ def _run_flight_test_case(self, outcome.failure = Failure(test_case, producer, consumer, sys.exc_info()) + log('=' * 70) + + return outcome + + def _compare_c_data_implementations( + self, + producer: Tester, + consumer: Tester + ): + log('##########################################################') + log(f'C Data Interface: ' + f'{producer.name} exporting, {consumer.name} importing') + log('##########################################################') + + # Serial execution is required for proper memory accounting + serial = True + + exporter = producer.make_c_data_exporter() + importer = consumer.make_c_data_importer() + + case_runner = partial(self._run_c_schema_test_case, producer, consumer, + exporter, importer) + self._run_test_cases(case_runner, self.json_files, serial=serial) + + case_runner = partial(self._run_c_array_test_cases, producer, consumer, + exporter, importer) + self._run_test_cases(case_runner, self.json_files, serial=serial) + + def _run_c_schema_test_case(self, + producer: Tester, consumer: Tester, + exporter: CDataExporter, + importer: CDataImporter, + test_case: datagen.File) -> Outcome: + """ + Run one C ArrowSchema test case. + """ + outcome = Outcome() + + def do_run(): + json_path = test_case.path + ffi = cdata.ffi() + c_schema_ptr = ffi.new("struct ArrowSchema*") + with cdata.check_memory_released(exporter, importer): + exporter.export_schema_from_json(json_path, c_schema_ptr) + importer.import_schema_and_compare_to_json(json_path, c_schema_ptr) + + log('=' * 70) + log(f'Testing C ArrowSchema from file {test_case.name!r}') + + if test_case.should_skip(producer.name, SKIP_C_SCHEMA): + log(f'-- Skipping test because producer {producer.name} does ' + f'not support C ArrowSchema') + outcome.skipped = True + + elif test_case.should_skip(consumer.name, SKIP_C_SCHEMA): + log(f'-- Skipping test because consumer {consumer.name} does ' + f'not support C ArrowSchema') + outcome.skipped = True + + else: + try: + do_run() + except Exception: + traceback.print_exc(file=printer.stdout) + outcome.failure = Failure(test_case, producer, consumer, + sys.exc_info()) + + log('=' * 70) + + return outcome + + def _run_c_array_test_cases(self, + producer: Tester, consumer: Tester, + exporter: CDataExporter, + importer: CDataImporter, + test_case: datagen.File) -> Outcome: + """ + Run one set C ArrowArray test cases. 
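+
+        One export/import roundtrip is run per record batch in the
+        JSON file.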
+ """ + outcome = Outcome() + + def do_run(): + json_path = test_case.path + ffi = cdata.ffi() + c_array_ptr = ffi.new("struct ArrowArray*") + for num_batch in range(test_case.num_batches): + log(f'... with record batch #{num_batch}') + with cdata.check_memory_released(exporter, importer): + exporter.export_batch_from_json(json_path, + num_batch, + c_array_ptr) + importer.import_batch_and_compare_to_json(json_path, + num_batch, + c_array_ptr) + + log('=' * 70) + log(f'Testing C ArrowArray ' + f'from file {test_case.name!r}') + + if test_case.should_skip(producer.name, SKIP_C_ARRAY): + log(f'-- Skipping test because producer {producer.name} does ' + f'not support C ArrowArray') + outcome.skipped = True + + elif test_case.should_skip(consumer.name, SKIP_C_ARRAY): + log(f'-- Skipping test because consumer {consumer.name} does ' + f'not support C ArrowArray') + outcome.skipped = True + + else: + try: + do_run() + except Exception: + traceback.print_exc(file=printer.stdout) + outcome.failure = Failure(test_case, producer, consumer, + sys.exc_info()) + + log('=' * 70) + return outcome @@ -387,7 +525,7 @@ def get_static_json_files(): glob_pattern = os.path.join(ARROW_ROOT_DEFAULT, 'integration', 'data', '*.json') return [ - datagen.File(name=os.path.basename(p), path=p, skip=set(), + datagen.File(name=os.path.basename(p), path=p, schema=None, batches=None) for p in glob.glob(glob_pattern) ] @@ -395,13 +533,14 @@ def get_static_json_files(): def run_all_tests(with_cpp=True, with_java=True, with_js=True, with_csharp=True, with_go=True, with_rust=False, - run_flight=False, tempdir=None, **kwargs): + run_ipc=False, run_flight=False, run_c_data=False, + tempdir=None, **kwargs): tempdir = tempdir or tempfile.mkdtemp(prefix='arrow-integration-') testers: List[Tester] = [] if with_cpp: - testers.append(CPPTester(**kwargs)) + testers.append(CppTester(**kwargs)) if with_java: testers.append(JavaTester(**kwargs)) @@ -434,54 +573,57 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, Scenario( "ordered", description="Ensure FlightInfo.ordered is supported.", - skip={"JS", "C#", "Rust"}, + skip_testers={"JS", "C#", "Rust"}, ), Scenario( "expiration_time:do_get", description=("Ensure FlightEndpoint.expiration_time with " "DoGet is working as expected."), - skip={"JS", "C#", "Rust"}, + skip_testers={"JS", "C#", "Rust"}, ), Scenario( "expiration_time:list_actions", description=("Ensure FlightEndpoint.expiration_time related " "pre-defined actions is working with ListActions " "as expected."), - skip={"JS", "C#", "Rust"}, + skip_testers={"JS", "C#", "Rust"}, ), Scenario( "expiration_time:cancel_flight_info", description=("Ensure FlightEndpoint.expiration_time and " "CancelFlightInfo are working as expected."), - skip={"JS", "C#", "Rust"}, + skip_testers={"JS", "C#", "Rust"}, ), Scenario( "expiration_time:renew_flight_endpoint", description=("Ensure FlightEndpoint.expiration_time and " "RenewFlightEndpoint are working as expected."), - skip={"JS", "C#", "Rust"}, + skip_testers={"JS", "C#", "Rust"}, ), Scenario( "poll_flight_info", description="Ensure PollFlightInfo is supported.", - skip={"JS", "C#", "Rust"} + skip_testers={"JS", "C#", "Rust"} ), Scenario( "flight_sql", description="Ensure Flight SQL protocol is working as expected.", - skip={"Rust"} + skip_testers={"Rust"} ), Scenario( "flight_sql:extension", description="Ensure Flight SQL extensions work as expected.", - skip={"Rust"} + skip_testers={"Rust"} ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) - 
runner.run() + if run_ipc: + runner.run_ipc() if run_flight: runner.run_flight() + if run_c_data: + runner.run_c_data() fail_count = 0 if runner.failures: @@ -492,7 +634,8 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, log(test_case.name, producer.name, "producing, ", consumer.name, "consuming") if exc_info: - traceback.print_exception(*exc_info) + exc_type, exc_value, exc_tb = exc_info + log(f'{exc_type}: {exc_value}') log() log(fail_count, "failures") diff --git a/dev/archery/archery/integration/scenario.py b/dev/archery/archery/integration/scenario.py index 1fcbca64e6a1f..89c64452e5fc5 100644 --- a/dev/archery/archery/integration/scenario.py +++ b/dev/archery/archery/integration/scenario.py @@ -23,7 +23,10 @@ class Scenario: Does not correspond to a particular IPC JSON file. """ - def __init__(self, name, description, skip=None): + def __init__(self, name, description, skip_testers=None): self.name = name self.description = description - self.skip = skip or set() + self.skipped_testers = skip_testers or set() + + def should_skip(self, tester, format): + return tester in self.skipped_testers diff --git a/dev/archery/archery/integration/tester.py b/dev/archery/archery/integration/tester.py index 54bfe621efd92..6a3061992d006 100644 --- a/dev/archery/archery/integration/tester.py +++ b/dev/archery/archery/integration/tester.py @@ -17,12 +17,181 @@ # Base class for language-specific integration test harnesses +from abc import ABC, abstractmethod +import os import subprocess +import typing from .util import log -class Tester(object): +_Predicate = typing.Callable[[], bool] + + +class CDataExporter(ABC): + + @abstractmethod + def export_schema_from_json(self, json_path: os.PathLike, + c_schema_ptr: object): + """ + Read a JSON integration file and export its schema. + + Parameters + ---------- + json_path : Path + Path to the JSON file + c_schema_ptr : cffi pointer value + Pointer to the ``ArrowSchema`` struct to export to. + """ + + @abstractmethod + def export_batch_from_json(self, json_path: os.PathLike, + num_batch: int, + c_array_ptr: object): + """ + Read a JSON integration file and export one of its batches. + + Parameters + ---------- + json_path : Path + Path to the JSON file + num_batch : int + Number of the record batch in the JSON file + c_schema_ptr : cffi pointer value + Pointer to the ``ArrowArray`` struct to export to. + """ + + @property + @abstractmethod + def supports_releasing_memory(self) -> bool: + """ + Whether the implementation is able to release memory deterministically. + + Here, "release memory" means that, after the `release` callback of + a C Data Interface export is called, `compare_allocation_state` is + able to trigger the deallocation of the memory underlying the export + (for example buffer data). + + If false, then `record_allocation_state` and `compare_allocation_state` + are allowed to raise NotImplementedError. + """ + + def record_allocation_state(self) -> object: + """ + Record the current memory allocation state. + + Returns + ------- + state : object + Opaque object representing the allocation state, + for example the number of allocated bytes. + """ + raise NotImplementedError + + def compare_allocation_state(self, recorded: object, + gc_until: typing.Callable[[_Predicate], bool] + ) -> bool: + """ + Compare the current memory allocation state with the recorded one. 
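+
+        Implementations without a garbage collector of their own
+        typically pass ``gc_until`` a predicate that re-reads the
+        allocation state and compares it against ``recorded``.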
+ + Parameters + ---------- + recorded : object + The previous allocation state returned by + `record_allocation_state()` + gc_until : callable + A callable itself accepting a callable predicate, and + returning a boolean. + `gc_until` should try to release memory until the predicate + becomes true, or until it decides to give up. The final value + of the predicate should be returned. + `gc_until` is typically provided by the C Data Interface importer. + + Returns + ------- + success : bool + Whether memory allocation state finally reached its previously + recorded value. + """ + raise NotImplementedError + + +class CDataImporter(ABC): + + @abstractmethod + def import_schema_and_compare_to_json(self, json_path: os.PathLike, + c_schema_ptr: object): + """ + Import schema and compare it to the schema of a JSON integration file. + + An error is raised if importing fails or the schemas differ. + + Parameters + ---------- + json_path : Path + The path to the JSON file + c_schema_ptr : cffi pointer value + Pointer to the ``ArrowSchema`` struct to import from. + """ + + @abstractmethod + def import_batch_and_compare_to_json(self, json_path: os.PathLike, + num_batch: int, + c_array_ptr: object): + """ + Import record batch and compare it to one of the batches + from a JSON integration file. + + The schema used for importing the record batch is the one from + the JSON file. + + An error is raised if importing fails or the batches differ. + + Parameters + ---------- + json_path : Path + The path to the JSON file + num_batch : int + Number of the record batch in the JSON file + c_array_ptr : cffi pointer value + Pointer to the ``ArrowArray`` struct to import from. + """ + + @property + @abstractmethod + def supports_releasing_memory(self) -> bool: + """ + Whether the implementation is able to release memory deterministically. + + Here, "release memory" means calling the `release` callback of + a C Data Interface export (which should then trigger a deallocation + mechanism on the exporter). + + If false, then `gc_until` is allowed to raise NotImplementedError. + """ + + def gc_until(self, predicate: _Predicate): + """ + Try to release memory until the predicate becomes true, or fail. + + Depending on the CDataImporter implementation, this may for example + try once, or run a garbage collector a given number of times, or + any other implementation-specific strategy for releasing memory. + + The running time should be kept reasonable and compatible with + execution of multiple C Data integration tests. + + This should not raise if `supports_releasing_memory` is true. + + Returns + ------- + success : bool + The final value of the predicate. + """ + raise NotImplementedError + + +class Tester: """ The interface to declare a tester to run integration tests against. 
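
    A tester advertises its capabilities through the class attributes
    below (``PRODUCER``, ``CONSUMER``, ``FLIGHT_SERVER``, ``FLIGHT_CLIENT``,
    ``C_DATA_EXPORTER``, ``C_DATA_IMPORTER``) and implements the
    corresponding entry points.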
""" @@ -34,8 +203,12 @@ class Tester(object): FLIGHT_SERVER = False # whether the language supports receiving Flight FLIGHT_CLIENT = False + # whether the language supports the C Data Interface as an exporter + C_DATA_EXPORTER = False + # whether the language supports the C Data Interface as an importer + C_DATA_IMPORTER = False - # the name shown in the logs + # the name used for skipping and shown in the logs name = "unknown" def __init__(self, debug=False, **args): @@ -85,3 +258,9 @@ def flight_server(self, scenario_name=None): def flight_request(self, port, json_path=None, scenario_name=None): raise NotImplementedError + + def make_c_data_exporter(self) -> CDataExporter: + raise NotImplementedError + + def make_c_data_importer(self) -> CDataImporter: + raise NotImplementedError diff --git a/dev/archery/archery/integration/tester_cpp.py b/dev/archery/archery/integration/tester_cpp.py index 52cc565dc00a3..9ddc3c480002a 100644 --- a/dev/archery/archery/integration/tester_cpp.py +++ b/dev/archery/archery/integration/tester_cpp.py @@ -16,10 +16,12 @@ # under the License. import contextlib +import functools import os import subprocess -from .tester import Tester +from . import cdata +from .tester import Tester, CDataExporter, CDataImporter from .util import run_cmd, log from ..utils.source import ARROW_ROOT_DEFAULT @@ -39,12 +41,19 @@ "localhost", ] +_dll_suffix = ".dll" if os.name == "nt" else ".so" -class CPPTester(Tester): +_DLL_PATH = _EXE_PATH +_ARROW_DLL = os.path.join(_DLL_PATH, "libarrow" + _dll_suffix) + + +class CppTester(Tester): PRODUCER = True CONSUMER = True FLIGHT_SERVER = True FLIGHT_CLIENT = True + C_DATA_EXPORTER = True + C_DATA_IMPORTER = True name = 'C++' @@ -133,3 +142,104 @@ def flight_request(self, port, json_path=None, scenario_name=None): if self.debug: log(' '.join(cmd)) run_cmd(cmd) + + def make_c_data_exporter(self): + return CppCDataExporter(self.debug, self.args) + + def make_c_data_importer(self): + return CppCDataImporter(self.debug, self.args) + + +_cpp_c_data_entrypoints = """ + const char* ArrowCpp_CDataIntegration_ExportSchemaFromJson( + const char* json_path, struct ArrowSchema* out); + const char* ArrowCpp_CDataIntegration_ImportSchemaAndCompareToJson( + const char* json_path, struct ArrowSchema* schema); + + const char* ArrowCpp_CDataIntegration_ExportBatchFromJson( + const char* json_path, int num_batch, struct ArrowArray* out); + const char* ArrowCpp_CDataIntegration_ImportBatchAndCompareToJson( + const char* json_path, int num_batch, struct ArrowArray* batch); + + int64_t ArrowCpp_BytesAllocated(); + """ + + +@functools.lru_cache +def _load_ffi(ffi, lib_path=_ARROW_DLL): + ffi.cdef(_cpp_c_data_entrypoints) + dll = ffi.dlopen(lib_path) + dll.ArrowCpp_CDataIntegration_ExportSchemaFromJson + return dll + + +class _CDataBase: + + def __init__(self, debug, args): + self.debug = debug + self.args = args + self.ffi = cdata.ffi() + self.dll = _load_ffi(self.ffi) + + def _check_c_error(self, c_error): + """ + Check a `const char*` error return from an integration entrypoint. + + A null means success, a non-empty string is an error message. + The string is statically allocated on the C++ side. 
+ """ + assert self.ffi.typeof(c_error) is self.ffi.typeof("const char*") + if c_error != self.ffi.NULL: + error = self.ffi.string(c_error).decode('utf8', + errors='replace') + raise RuntimeError( + f"C++ C Data Integration call failed: {error}") + + +class CppCDataExporter(CDataExporter, _CDataBase): + + def export_schema_from_json(self, json_path, c_schema_ptr): + c_error = self.dll.ArrowCpp_CDataIntegration_ExportSchemaFromJson( + str(json_path).encode(), c_schema_ptr) + self._check_c_error(c_error) + + def export_batch_from_json(self, json_path, num_batch, c_array_ptr): + c_error = self.dll.ArrowCpp_CDataIntegration_ExportBatchFromJson( + str(json_path).encode(), num_batch, c_array_ptr) + self._check_c_error(c_error) + + @property + def supports_releasing_memory(self): + return True + + def record_allocation_state(self): + return self.dll.ArrowCpp_BytesAllocated() + + def compare_allocation_state(self, recorded, gc_until): + def pred(): + # No GC on our side, so just compare allocation state + return self.record_allocation_state() == recorded + + return gc_until(pred) + + +class CppCDataImporter(CDataImporter, _CDataBase): + + def import_schema_and_compare_to_json(self, json_path, c_schema_ptr): + c_error = self.dll.ArrowCpp_CDataIntegration_ImportSchemaAndCompareToJson( + str(json_path).encode(), c_schema_ptr) + self._check_c_error(c_error) + + def import_batch_and_compare_to_json(self, json_path, num_batch, + c_array_ptr): + c_error = self.dll.ArrowCpp_CDataIntegration_ImportBatchAndCompareToJson( + str(json_path).encode(), num_batch, c_array_ptr) + self._check_c_error(c_error) + + @property + def supports_releasing_memory(self): + return True + + def gc_until(self, predicate): + # No GC on our side, so can evaluate predicate immediately + return predicate() diff --git a/dev/archery/archery/integration/util.py b/dev/archery/archery/integration/util.py index 80ba30052e4da..afef7d5eb13b9 100644 --- a/dev/archery/archery/integration/util.py +++ b/dev/archery/archery/integration/util.py @@ -32,8 +32,10 @@ def guid(): # SKIP categories -SKIP_ARROW = 'arrow' +SKIP_C_ARRAY = 'c_array' +SKIP_C_SCHEMA = 'c_schema' SKIP_FLIGHT = 'flight' +SKIP_IPC = 'ipc' class _Printer: diff --git a/dev/archery/archery/lang/python.py b/dev/archery/archery/lang/python.py index 8600a0d7c48c0..d4c1853d097b2 100644 --- a/dev/archery/archery/lang/python.py +++ b/dev/archery/archery/lang/python.py @@ -16,6 +16,7 @@ # under the License. 
from contextlib import contextmanager +from enum import EnumMeta import inspect import tokenize @@ -112,6 +113,10 @@ def inspect_signature(obj): class NumpyDoc: + IGNORE_VALIDATION_ERRORS_FOR_TYPE = { + # Enum function signatures should never be documented + EnumMeta: ["PR01"] + } def __init__(self, symbols=None): if not have_numpydoc: @@ -229,6 +234,10 @@ def callback(obj): continue if disallow_rules and errcode in disallow_rules: continue + if any(isinstance(obj, obj_type) and errcode in errcode_list + for obj_type, errcode_list + in NumpyDoc.IGNORE_VALIDATION_ERRORS_FOR_TYPE.items()): + continue errors.append((errcode, errmsg)) if len(errors): diff --git a/dev/archery/setup.py b/dev/archery/setup.py index 627e576fb6f59..e2c89ae204bd6 100755 --- a/dev/archery/setup.py +++ b/dev/archery/setup.py @@ -28,16 +28,17 @@ jinja_req = 'jinja2>=2.11' extras = { - 'lint': ['numpydoc==1.1.0', 'autopep8', 'flake8==6.1.0', 'cython-lint', - 'cmake_format==0.6.13'], 'benchmark': ['pandas'], - 'docker': ['ruamel.yaml', 'python-dotenv'], - 'release': ['pygithub', jinja_req, 'jira', 'semver', 'gitpython'], 'crossbow': ['github3.py', jinja_req, 'pygit2>=1.6.0', 'requests', - 'ruamel.yaml', 'setuptools_scm'], + 'ruamel.yaml', 'setuptools_scm<8.0.0'], 'crossbow-upload': ['github3.py', jinja_req, 'ruamel.yaml', 'setuptools_scm'], - 'numpydoc': ['numpydoc==1.1.0'] + 'docker': ['ruamel.yaml', 'python-dotenv'], + 'integration': ['cffi'], + 'lint': ['numpydoc==1.1.0', 'autopep8', 'flake8==6.1.0', 'cython-lint', + 'cmake_format==0.6.13'], + 'numpydoc': ['numpydoc==1.1.0'], + 'release': ['pygithub', jinja_req, 'jira', 'semver', 'gitpython'], } extras['bot'] = extras['crossbow'] + ['pygithub', 'jira'] extras['all'] = list(set(functools.reduce(operator.add, extras.values()))) diff --git a/dev/release/post-11-bump-versions-test.rb b/dev/release/post-11-bump-versions-test.rb index 79d17e84eb7cb..0ef4646236740 100644 --- a/dev/release/post-11-bump-versions-test.rb +++ b/dev/release/post-11-bump-versions-test.rb @@ -235,7 +235,7 @@ def test_version_post_tag ] end - Dir.glob("go/**/{go.mod,*.go,*.go.*}") do |path| + Dir.glob("go/**/{go.mod,*.go,*.go.*,README.md}") do |path| if path == "go/arrow/doc.go" expected_changes << { path: path, @@ -253,19 +253,34 @@ def test_version_post_tag hunks = [] if release_type == :major lines = File.readlines(path, chomp: true) - target_lines = lines.grep(/#{Regexp.escape(import_path)}/) + target_lines = lines.each_with_index.select do |line, i| + line.include?(import_path) + end next if target_lines.empty? 
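+          # Group the matches into expected hunks: matches closer together
+          # than Git's default of three context lines are merged into a
+          # single hunk.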
- hunk = [] - target_lines.each do |line| - hunk << "-#{line}" + n_context_lines = 3 # The default of Git's diff.context + target_hunks = [[target_lines.first[0]]] + previous_i = target_lines.first[1] + target_lines[1..-1].each do |line, i| + if i - previous_i < n_context_lines + target_hunks.last << line + else + target_hunks << [line] + end + previous_i = i end - target_lines.each do |line| - new_line = line.gsub("v#{@snapshot_major_version}") do - "v#{@next_major_version}" + target_hunks.each do |lines| + hunk = [] + lines.each do |line,| + hunk << "-#{line}" + end + lines.each do |line| + new_line = line.gsub("v#{@snapshot_major_version}") do + "v#{@next_major_version}" + end + hunk << "+#{new_line}" end - hunk << "+#{new_line}" + hunks << hunk end - hunks << hunk end if path == "go/parquet/writer_properties.go" hunks << [ diff --git a/dev/release/utils-prepare.sh b/dev/release/utils-prepare.sh index ceb51812c11ae..464702b811d8b 100644 --- a/dev/release/utils-prepare.sh +++ b/dev/release/utils-prepare.sh @@ -155,8 +155,8 @@ update_versions() { popd pushd "${ARROW_DIR}/go" - find . "(" -name "*.go*" -o -name "go.mod" ")" -exec sed -i.bak -E -e \ - "s|(github\\.com/apache/arrow/go)/v[0-9]+|\1/v${major_version}|" {} \; + find . "(" -name "*.go*" -o -name "go.mod" -o -name README.md ")" -exec sed -i.bak -E -e \ + "s|(github\\.com/apache/arrow/go)/v[0-9]+|\1/v${major_version}|g" {} \; # update parquet writer version sed -i.bak -E -e \ "s/\"parquet-go version .+\"/\"parquet-go version ${version}\"/" \ diff --git a/dev/tasks/conda-recipes/arrow-cpp/meta.yaml b/dev/tasks/conda-recipes/arrow-cpp/meta.yaml index ac4b29eb5ee7e..fbe40af3dae01 100644 --- a/dev/tasks/conda-recipes/arrow-cpp/meta.yaml +++ b/dev/tasks/conda-recipes/arrow-cpp/meta.yaml @@ -244,7 +244,7 @@ outputs: - numpy - python - setuptools - - setuptools_scm + - setuptools_scm <8.0.0 run: # - {{ pin_subpackage('libarrow', exact=True) }} - libarrow ={{ version }}=*_{{ PKG_BUILDNUM }}_{{ build_ext }} @@ -327,7 +327,7 @@ outputs: - numpy - python - setuptools - - setuptools_scm + - setuptools_scm <8.0.0 run: - {{ pin_subpackage('pyarrow', exact=True) }} - python diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index ed238778635d3..29e038a922412 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1286,6 +1286,14 @@ tasks: PYTHON: "3.10" image: conda-python-substrait + test-conda-python-3.10-cython2: + ci: github + template: docker-tests/github.linux.yml + params: + env: + PYTHON: "3.10" + image: conda-python-cython2 + test-debian-11-python-3: ci: azure template: docker-tests/azure.linux.yml diff --git a/docker-compose.yml b/docker-compose.yml index a79b13c0a5f91..8ae06900c57f9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -119,6 +119,7 @@ x-hierarchy: - conda-python: - conda-python-pandas: - conda-python-docs + - conda-python-cython2 - conda-python-dask - conda-python-hdfs - conda-python-java-integration @@ -1349,6 +1350,30 @@ services: /arrow/ci/scripts/java_build.sh /arrow /build /tmp/dist/java && /arrow/ci/scripts/java_cdata_integration.sh /arrow /tmp/dist/java" ] + conda-python-cython2: + # Usage: + # docker-compose build conda + # docker-compose build conda-cpp + # docker-compose build conda-python + # docker-compose build conda-python-cython2 + # docker-compose run --rm conda-python-cython2 + image: ${REPO}:${ARCH}-conda-python-${PYTHON}-cython2 + build: + context: . 
+ dockerfile: ci/docker/conda-python-cython2.dockerfile + cache_from: + - ${REPO}:${ARCH}-conda-python-${PYTHON}-cython2 + args: + repo: ${REPO} + arch: ${ARCH} + python: ${PYTHON} + shm_size: *shm-size + environment: + <<: [*common, *ccache] + PYTEST_ARGS: # inherit + volumes: *conda-volumes + command: *python-conda-command + ################################## R ######################################## ubuntu-r: diff --git a/docs/source/format/Columnar.rst b/docs/source/format/Columnar.rst index 3390f1b7b5f2c..afbe2a08ee28c 100644 --- a/docs/source/format/Columnar.rst +++ b/docs/source/format/Columnar.rst @@ -21,7 +21,7 @@ Arrow Columnar Format ********************* -*Version: 1.3* +*Version: 1.4* The "Arrow Columnar Format" includes a language-agnostic in-memory data structure specification, metadata serialization, and a protocol @@ -108,6 +108,10 @@ the different physical layouts defined by Arrow: * **Variable-size Binary**: a sequence of values each having a variable byte length. Two variants of this layout are supported using 32-bit and 64-bit length encoding. +* **Views of Variable-size Binary**: a sequence of values each having a + variable byte length. In contrast to Variable-size Binary, the values + of this layout are distributed across potentially multiple buffers + instead of densely and sequentially packed in a single buffer. * **Fixed-size List**: a nested layout where each value has the same number of elements taken from a child data type. * **Variable-size List**: a nested layout where each value is a @@ -350,6 +354,51 @@ will be represented as follows: :: |----------------|-----------------------| | joemark | unspecified (padding) | +Variable-size Binary View Layout +-------------------------------- + +.. versionadded:: Arrow Columnar Format 1.4 + +Each value in this layout consists of 0 or more bytes. These bytes' +locations are indicated using a **views** buffer, which may point to one +of potentially several **data** buffers or may contain the characters +inline. + +The views buffer contains `length` view structures with the following layout: + +:: + + * Short strings, length <= 12 + | Bytes 0-3 | Bytes 4-15 | + |------------|---------------------------------------| + | length | data (padded with 0) | + + * Long strings, length > 12 + | Bytes 0-3 | Bytes 4-7 | Bytes 8-11 | Bytes 12-15 | + |------------|------------|------------|-------------| + | length | prefix | buf. index | offset | + +In both the long and short string cases, the first four bytes encode the +length of the string and can be used to determine how the rest of the view +should be interpreted. + +In the short string case the string's bytes are inlined- stored inside the +view itself, in the twelve bytes which follow the length. + +In the long string case, a buffer index indicates which data buffer +stores the data bytes and an offset indicates where in that buffer the +data bytes begin. Buffer index 0 refers to the first data buffer, IE +the first buffer **after** the validity buffer and the views buffer. +The half-open range ``[offset, offset + length)`` must be entirely contained +within the indicated buffer. A copy of the first four bytes of the string is +stored inline in the prefix, after the length. This prefix enables a +profitable fast path for string comparisons, which are frequently determined +within the first four bytes. + +All integers (length, buffer index, and offset) are signed. + +This layout is adapted from TU Munich's `UmbraDB`_. + .. 
_variable-size-list-layout: Variable-size List Layout @@ -880,19 +929,20 @@ For the avoidance of ambiguity, we provide listing the order and type of memory buffers for each layout. .. csv-table:: Buffer Layouts - :header: "Layout Type", "Buffer 0", "Buffer 1", "Buffer 2" - :widths: 30, 20, 20, 20 - - "Primitive",validity,data, - "Variable Binary",validity,offsets,data - "List",validity,offsets, - "Fixed-size List",validity,, - "Struct",validity,, - "Sparse Union",type ids,, - "Dense Union",type ids,offsets, - "Null",,, - "Dictionary-encoded",validity,data (indices), - "Run-end encoded",,, + :header: "Layout Type", "Buffer 0", "Buffer 1", "Buffer 2", "Variadic Buffers" + :widths: 30, 20, 20, 20, 20 + + "Primitive",validity,data,, + "Variable Binary",validity,offsets,data, + "Variable Binary View",validity,views,,data + "List",validity,offsets,, + "Fixed-size List",validity,,, + "Struct",validity,,, + "Sparse Union",type ids,,, + "Dense Union",type ids,offsets,, + "Null",,,, + "Dictionary-encoded",validity,data (indices),, + "Run-end encoded",,,, Logical Types ============= @@ -1071,6 +1121,39 @@ bytes. Since this metadata can be used to communicate in-memory pointer addresses between libraries, it is recommended to set ``size`` to the actual memory size rather than the padded size. +Variadic buffers +^^^^^^^^^^^^^^^^ + +Some types such as Utf8View are represented using a variable number of buffers. +For each such Field in the pre-ordered flattened logical schema, there will be +an entry in ``variadicBufferCounts`` to indicate the number of variadic buffers +which belong to that Field in the current RecordBatch. + +For example, consider the schema :: + + col1: Struct + col2: Utf8View + +This has two fields with variadic buffers, so ``variadicBufferCounts`` will +have two entries in each RecordBatch. For a RecordBatch of this schema with +``variadicBufferCounts = [3, 2]``, the flattened buffers would be:: + + buffer 0: col1 validity + buffer 1: col1.a validity + buffer 2: col1.a values + buffer 3: col1.b validity + buffer 4: col1.b views + buffer 5: col1.b data + buffer 6: col1.b data + buffer 7: col1.b data + buffer 8: col1.c validity + buffer 9: col1.c values + buffer 10: col2 validity + buffer 11: col2 views + buffer 12: col2 data + buffer 13: col2 data + + Byte Order (`Endianness`_) --------------------------- @@ -1346,3 +1429,4 @@ the Arrow spec. .. _Endianness: https://en.wikipedia.org/wiki/Endianness .. _SIMD: https://software.intel.com/en-us/cpp-compiler-developer-guide-and-reference-introduction-to-the-simd-data-layout-templates .. _Parquet: https://parquet.apache.org/docs/ +.. _UmbraDB: https://db.in.tum.de/~freitag/papers/p29-neumann-cidr20.pdf diff --git a/docs/source/java/dataset.rst b/docs/source/java/dataset.rst index 35ffa81058072..a4381e0814638 100644 --- a/docs/source/java/dataset.rst +++ b/docs/source/java/dataset.rst @@ -132,12 +132,10 @@ within method ``Scanner::schema()``: .. _java-dataset-projection: -Projection -========== +Projection (Subset of Columns) +============================== -User can specify projections in ScanOptions. For ``FileSystemDataset``, only -column projection is allowed for now, which means, only column names -in the projection list will be accepted. For example: +User can specify projections in ScanOptions. For example: .. code-block:: Java @@ -159,6 +157,27 @@ Or use shortcut construtor: Then all columns will be emitted during scanning. 
+Projection (Produce New Columns) and Filters +============================================ + +User can specify projections (new columns) or filters in ScanOptions using Substrait. For example: + +.. code-block:: Java + + ByteBuffer substraitExpressionFilter = getSubstraitExpressionFilter(); + ByteBuffer substraitExpressionProject = getSubstraitExpressionProjection(); + // Use Substrait APIs to create an Expression and serialize to a ByteBuffer + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitExpressionFilter(substraitExpressionFilter) + .substraitExpressionProjection(getSubstraitExpressionProjection()) + .build(); + +.. seealso:: + + :doc:`Executing Projections and Filters Using Extended Expressions ` + Projections and Filters using Substrait. + Read Data from HDFS =================== diff --git a/docs/source/java/substrait.rst b/docs/source/java/substrait.rst index 41effedbf01d9..d8d49a96e88f8 100644 --- a/docs/source/java/substrait.rst +++ b/docs/source/java/substrait.rst @@ -22,8 +22,10 @@ Substrait The ``arrow-dataset`` module can execute Substrait_ plans via the :doc:`Acero <../cpp/streaming_execution>` query engine. -Executing Substrait Plans -========================= +.. contents:: + +Executing Queries Using Substrait Plans +======================================= Plans can reference data in files via URIs, or "named tables" that must be provided along with the plan. @@ -102,6 +104,349 @@ Here is an example of a Java program that queries a Parquet file using Java Subs 0 ALGERIA 0 haggle. carefully final deposits detect slyly agai 1 ARGENTINA 1 al foxes promise slyly according to the regular accounts. bold requests alon +Executing Projections and Filters Using Extended Expressions +============================================================ + +Dataset also supports projections and filters with Substrait's `Extended Expression`_. +This requires the substrait-java library. + +This Java program: + +- Loads a Parquet file containing the "nation" table from the TPC-H benchmark. +- Projects two new columns: + - ``N_NAME || ' - ' || N_COMMENT`` + - ``N_REGIONKEY + 10`` +- Applies a filter: ``N_NATIONKEY > 18`` + +.. 
code-block:: Java + + import io.substrait.extension.ExtensionCollector; + import io.substrait.proto.Expression; + import io.substrait.proto.ExpressionReference; + import io.substrait.proto.ExtendedExpression; + import io.substrait.proto.FunctionArgument; + import io.substrait.proto.SimpleExtensionDeclaration; + import io.substrait.proto.SimpleExtensionURI; + import io.substrait.type.NamedStruct; + import io.substrait.type.Type; + import io.substrait.type.TypeCreator; + import io.substrait.type.proto.TypeProtoConverter; + import java.nio.ByteBuffer; + import java.util.ArrayList; + import java.util.Arrays; + import java.util.Base64; + import java.util.HashMap; + import java.util.List; + import java.util.Optional; + import org.apache.arrow.dataset.file.FileFormat; + import org.apache.arrow.dataset.file.FileSystemDatasetFactory; + import org.apache.arrow.dataset.jni.NativeMemoryPool; + import org.apache.arrow.dataset.scanner.ScanOptions; + import org.apache.arrow.dataset.scanner.Scanner; + import org.apache.arrow.dataset.source.Dataset; + import org.apache.arrow.dataset.source.DatasetFactory; + import org.apache.arrow.memory.BufferAllocator; + import org.apache.arrow.memory.RootAllocator; + import org.apache.arrow.vector.ipc.ArrowReader; + + public class ClientSubstraitExtendedExpressionsCookbook { + + public static void main(String[] args) throws Exception { + // project and filter dataset using extended expression definition - 03 Expressions: + // Expression 01 - CONCAT: N_NAME || ' - ' || N_COMMENT = col 1 || ' - ' || col 3 + // Expression 02 - ADD: N_REGIONKEY + 10 = col 1 + 10 + // Expression 03 - FILTER: N_NATIONKEY > 18 = col 3 > 18 + projectAndFilterDataset(); + } + + public static void projectAndFilterDataset() { + String uri = "file:///Users/data/tpch_parquet/nation.parquet"; + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getSubstraitExpressionFilter()) + .substraitProjection(getSubstraitExpressionProjection()) + .build(); + try ( + BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = new FileSystemDatasetFactory( + allocator, NativeMemoryPool.getDefault(), + FileFormat.PARQUET, uri); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + while (reader.loadNextBatch()) { + System.out.println( + reader.getVectorSchemaRoot().contentToTSVString()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static ByteBuffer getSubstraitExpressionProjection() { + // Expression: N_REGIONKEY + 10 = col 3 + 10 + Expression.Builder selectionBuilderProjectOne = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 2) + ) + ) + ); + Expression.Builder literalBuilderProjectOne = Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setI32(10) + ); + io.substrait.proto.Type outputProjectOne = TypeCreator.NULLABLE.I32.accept( + new TypeProtoConverter(new ExtensionCollector())); + Expression.Builder expressionBuilderProjectOne = Expression. + newBuilder(). + setScalarFunction( + Expression. + ScalarFunction. + newBuilder(). + setFunctionReference(0). + setOutputType(outputProjectOne). + addArguments( + 0, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectOne) + ). 
+ addArguments( + 1, + FunctionArgument.newBuilder().setValue( + literalBuilderProjectOne) + ) + ); + ExpressionReference.Builder expressionReferenceBuilderProjectOne = ExpressionReference.newBuilder(). + setExpression(expressionBuilderProjectOne) + .addOutputNames("ADD_TEN_TO_COLUMN_N_REGIONKEY"); + + // Expression: name || name = N_NAME || "-" || N_COMMENT = col 1 || col 3 + Expression.Builder selectionBuilderProjectTwo = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 1) + ) + ) + ); + Expression.Builder selectionBuilderProjectTwoConcatLiteral = Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setString(" - ") + ); + Expression.Builder selectionBuilderProjectOneToConcat = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 3) + ) + ) + ); + io.substrait.proto.Type outputProjectTwo = TypeCreator.NULLABLE.STRING.accept( + new TypeProtoConverter(new ExtensionCollector())); + Expression.Builder expressionBuilderProjectTwo = Expression. + newBuilder(). + setScalarFunction( + Expression. + ScalarFunction. + newBuilder(). + setFunctionReference(1). + setOutputType(outputProjectTwo). + addArguments( + 0, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectTwo) + ). + addArguments( + 1, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectTwoConcatLiteral) + ). + addArguments( + 2, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectOneToConcat) + ) + ); + ExpressionReference.Builder expressionReferenceBuilderProjectTwo = ExpressionReference.newBuilder(). + setExpression(expressionBuilderProjectTwo) + .addOutputNames("CONCAT_COLUMNS_N_NAME_AND_N_COMMENT"); + + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", + "N_REGIONKEY", "N_COMMENT"); + List dataTypes = Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING + ); + NamedStruct of = NamedStruct.of( + columnNames, + Type.Struct.builder().fields(dataTypes).nullable(false).build() + ); + // Extensions URI + HashMap extensionUris = new HashMap<>(); + extensionUris.put( + "key-001", + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("/functions_arithmetic.yaml") + .build() + ); + // Extensions + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionAdd = SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(0) + .setName("add:i32_i32") + .setExtensionUriReference(1)) + .build(); + SimpleExtensionDeclaration extensionFunctionGreaterThan = SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("concat:vchar") + .setExtensionUriReference(2)) + .build(); + extensions.add(extensionFunctionAdd); + extensions.add(extensionFunctionGreaterThan); + // Extended Expression + ExtendedExpression.Builder extendedExpressionBuilder = + ExtendedExpression.newBuilder(). + addReferredExpr(0, + expressionReferenceBuilderProjectOne). + addReferredExpr(1, + expressionReferenceBuilderProjectTwo). 
+ setBaseSchema(of.toProto(new TypeProtoConverter( + new ExtensionCollector()))); + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + ExtendedExpression extendedExpression = extendedExpressionBuilder.build(); + byte[] extendedExpressions = Base64.getDecoder().decode( + Base64.getEncoder().encodeToString( + extendedExpression.toByteArray())); + ByteBuffer substraitExpressionProjection = ByteBuffer.allocateDirect( + extendedExpressions.length); + substraitExpressionProjection.put(extendedExpressions); + return substraitExpressionProjection; + } + + private static ByteBuffer getSubstraitExpressionFilter() { + // Expression: Filter: N_NATIONKEY > 18 = col 1 > 18 + Expression.Builder selectionBuilderFilterOne = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 0) + ) + ) + ); + Expression.Builder literalBuilderFilterOne = Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setI32(18) + ); + io.substrait.proto.Type outputFilterOne = TypeCreator.NULLABLE.BOOLEAN.accept( + new TypeProtoConverter(new ExtensionCollector())); + Expression.Builder expressionBuilderFilterOne = Expression. + newBuilder(). + setScalarFunction( + Expression. + ScalarFunction. + newBuilder(). + setFunctionReference(1). + setOutputType(outputFilterOne). + addArguments( + 0, + FunctionArgument.newBuilder().setValue( + selectionBuilderFilterOne) + ). + addArguments( + 1, + FunctionArgument.newBuilder().setValue( + literalBuilderFilterOne) + ) + ); + ExpressionReference.Builder expressionReferenceBuilderFilterOne = ExpressionReference.newBuilder(). + setExpression(expressionBuilderFilterOne) + .addOutputNames("COLUMN_N_NATIONKEY_GREATER_THAN_18"); + + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", + "N_REGIONKEY", "N_COMMENT"); + List dataTypes = Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING + ); + NamedStruct of = NamedStruct.of( + columnNames, + Type.Struct.builder().fields(dataTypes).nullable(false).build() + ); + // Extensions URI + HashMap extensionUris = new HashMap<>(); + extensionUris.put( + "key-001", + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("/functions_comparison.yaml") + .build() + ); + // Extensions + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("gt:any_any") + .setExtensionUriReference(1)) + .build(); + extensions.add(extensionFunctionLowerThan); + // Extended Expression + ExtendedExpression.Builder extendedExpressionBuilder = + ExtendedExpression.newBuilder(). + addReferredExpr(0, + expressionReferenceBuilderFilterOne). 
+ setBaseSchema(of.toProto(new TypeProtoConverter( + new ExtensionCollector()))); + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + ExtendedExpression extendedExpression = extendedExpressionBuilder.build(); + byte[] extendedExpressions = Base64.getDecoder().decode( + Base64.getEncoder().encodeToString( + extendedExpression.toByteArray())); + ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect( + extendedExpressions.length); + substraitExpressionFilter.put(extendedExpressions); + return substraitExpressionFilter; + } + } + +.. code-block:: text + + ADD_TEN_TO_COLUMN_N_REGIONKEY CONCAT_COLUMNS_N_NAME_AND_N_COMMENT + 13 ROMANIA - ular asymptotes are about the furious multipliers. express dependencies nag above the ironically ironic account + 14 SAUDI ARABIA - ts. silent requests haggle. closely express packages sleep across the blithely + 12 VIETNAM - hely enticingly express accounts. even, final + 13 RUSSIA - requests against the platelets use never according to the quickly regular pint + 13 UNITED KINGDOM - eans boost carefully special requests. accounts are. carefull + 11 UNITED STATES - y final packages. slow foxes cajole quickly. quickly silent platelets breach ironic accounts. unusual pinto be + .. _`Substrait`: https://substrait.io/ .. _`Substrait Java`: https://github.com/substrait-io/substrait-java -.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html \ No newline at end of file +.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html +.. _`Extended Expression`: https://github.com/substrait-io/substrait/blob/main/site/docs/expressions/extended_expression.md diff --git a/format/Message.fbs b/format/Message.fbs index 170ea8fbced89..92a629f3f9d95 100644 --- a/format/Message.fbs +++ b/format/Message.fbs @@ -99,6 +99,22 @@ table RecordBatch { /// Optional compression of the message body compression: BodyCompression; + + /// Some types such as Utf8View are represented using a variable number of buffers. + /// For each such Field in the pre-ordered flattened logical schema, there will be + /// an entry in variadicBufferCounts to indicate the number of number of variadic + /// buffers which belong to that Field in the current RecordBatch. + /// + /// For example, the schema + /// col1: Struct + /// col2: Utf8View + /// contains two Fields with variadic buffers so variadicBufferCounts will have + /// two entries, the first counting the variadic buffers of `col1.b` and the + /// second counting `col2`'s. + /// + /// This field may be omitted if and only if the schema contains no Fields with + /// a variable number of buffers, such as BinaryView and Utf8View. + variadicBufferCounts: [long]; } /// For sending dictionary encoding information. Any Field can be diff --git a/format/Schema.fbs b/format/Schema.fbs index ce29c25b7d1c8..fdaf623931760 100644 --- a/format/Schema.fbs +++ b/format/Schema.fbs @@ -22,6 +22,7 @@ /// Version 1.1 - Add Decimal256. /// Version 1.2 - Add Interval MONTH_DAY_NANO. /// Version 1.3 - Add Run-End Encoded. +/// Version 1.4 - Add BinaryView, Utf8View, and variadicBufferCounts. 
 namespace org.apache.arrow.flatbuf;
@@ -171,6 +172,27 @@ table LargeUtf8 {
 table LargeBinary {
 }
 
+/// Logically the same as Utf8, but the internal representation uses a view
+/// struct that contains the string length and either the string's entire data
+/// inline (for small strings) or an inlined prefix, an index of another buffer,
+/// and an offset pointing to a slice in that buffer (for non-small strings).
+///
+/// Since it uses a variable number of data buffers, each Field with this type
+/// must have a corresponding entry in `variadicBufferCounts`.
+table Utf8View {
+}
+
+/// Logically the same as Binary, but the internal representation uses a view
+/// struct that contains the string length and either the string's entire data
+/// inline (for small strings) or an inlined prefix, an index of another buffer,
+/// and an offset pointing to a slice in that buffer (for non-small strings).
+///
+/// Since it uses a variable number of data buffers, each Field with this type
+/// must have a corresponding entry in `variadicBufferCounts`.
+table BinaryView {
+}
+
+
 table FixedSizeBinary {
   /// Number of bytes per value
   byteWidth: int;
@@ -427,6 +449,8 @@ union Type {
   LargeUtf8,
   LargeList,
   RunEndEncoded,
+  BinaryView,
+  Utf8View,
 }
 
 /// ----------------------------------------------------------------------
diff --git a/go/README.md b/go/README.md
index 5b3f72760f331..660549cb1b366 100644
--- a/go/README.md
+++ b/go/README.md
@@ -20,7 +20,7 @@
 Apache Arrow for Go
 ===================
 
-[![GoDoc](https://godoc.org/github.com/apache/arrow/go/arrow?status.svg)](https://godoc.org/github.com/apache/arrow/go/arrow)
+[![Go Reference](https://pkg.go.dev/badge/github.com/apache/arrow/go/v14.svg)](https://pkg.go.dev/github.com/apache/arrow/go/v14)
 
 [Apache Arrow][arrow] is a cross-language development platform for in-memory
 data. It specifies a standardized language-independent columnar memory format
diff --git a/go/arrow/flight/flightsql/driver/README.md b/go/arrow/flight/flightsql/driver/README.md
index f81cb9250e1c9..b8850527c19c1 100644
--- a/go/arrow/flight/flightsql/driver/README.md
+++ b/go/arrow/flight/flightsql/driver/README.md
@@ -36,7 +36,7 @@ connection pooling, transactions combined with ease of use (see (#usage)).
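Returning to the ``Utf8View``/``BinaryView`` tables in Schema.fbs above: the view struct is described only informally there. The sketch below decodes one view the way those comments read; the 16-byte view width and the 12-byte inline threshold come from the C++ implementation of these types and should be treated as assumptions here, since Schema.fbs itself does not pin them down.

.. code-block:: java

   import java.nio.ByteBuffer;
   import java.nio.ByteOrder;

   // Illustrative decoder for a single 16-byte view struct.
   static byte[] decodeView(ByteBuffer rawView, ByteBuffer[] variadicBuffers) {
     ByteBuffer view = rawView.duplicate().order(ByteOrder.LITTLE_ENDIAN);
     int length = view.getInt(0);
     byte[] out = new byte[length];
     if (length <= 12) {
       // Small string: the entire data sits inline after the 4-byte length.
       for (int i = 0; i < length; i++) {
         out[i] = view.get(4 + i);
       }
     } else {
       // Non-small string: a 4-byte prefix at bytes 4..7, then the index of a
       // variadic data buffer and an offset to a slice within that buffer.
       int bufferIndex = view.getInt(8);
       int offset = view.getInt(12);
       ByteBuffer data = variadicBuffers[bufferIndex].duplicate();
       data.position(offset);
       data.get(out, 0, length);
     }
     return out;
   }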
## Prerequisites * Go 1.17+ -* Installation via `go get -u github.com/apache/arrow/go/v12/arrow/flight/flightsql` +* Installation via `go get -u github.com/apache/arrow/go/v14/arrow/flight/flightsql` * Backend speaking FlightSQL --------------------------------------- @@ -55,7 +55,7 @@ import ( "database/sql" "time" - _ "github.com/apache/arrow/go/v12/arrow/flight/flightsql" + _ "github.com/apache/arrow/go/v14/arrow/flight/flightsql" ) // Open the connection to an SQLite backend @@ -141,7 +141,7 @@ import ( "log" "time" - "github.com/apache/arrow/go/v12/arrow/flight/flightsql" + "github.com/apache/arrow/go/v14/arrow/flight/flightsql" ) func main() { diff --git a/go/parquet/file/file_writer.go b/go/parquet/file/file_writer.go index 64a21473c293a..c6289434bbe6e 100644 --- a/go/parquet/file/file_writer.go +++ b/go/parquet/file/file_writer.go @@ -41,23 +41,24 @@ type Writer struct { // The Schema of this writer Schema *schema.Schema - // The current FileMetadata to write - FileMetadata *metadata.FileMetaData - // The current keyvalue metadata - KeyValueMetadata metadata.KeyValueMetadata } -type WriteOption func(*Writer) +type writerConfig struct { + props *parquet.WriterProperties + keyValueMetadata metadata.KeyValueMetadata +} + +type WriteOption func(*writerConfig) func WithWriterProps(props *parquet.WriterProperties) WriteOption { - return func(w *Writer) { - w.props = props + return func(c *writerConfig) { + c.props = props } } func WithWriteMetadata(meta metadata.KeyValueMetadata) WriteOption { - return func(w *Writer) { - w.KeyValueMetadata = meta + return func(c *writerConfig) { + c.keyValueMetadata = meta } } @@ -66,19 +67,23 @@ func WithWriteMetadata(meta metadata.KeyValueMetadata) WriteOption { // If props is nil, then the default Writer Properties will be used. If the key value metadata is not nil, // it will be added to the file. func NewParquetWriter(w io.Writer, sc *schema.GroupNode, opts ...WriteOption) *Writer { + config := &writerConfig{} + for _, o := range opts { + o(config) + } + if config.props == nil { + config.props = parquet.NewWriterProperties() + } + fileSchema := schema.NewSchema(sc) fw := &Writer{ + props: config.props, sink: &utils.TellWrapper{Writer: w}, open: true, Schema: fileSchema, } - for _, o := range opts { - o(fw) - } - if fw.props == nil { - fw.props = parquet.NewWriterProperties() - } - fw.metadata = *metadata.NewFileMetadataBuilder(fw.Schema, fw.props, fw.KeyValueMetadata) + + fw.metadata = *metadata.NewFileMetadataBuilder(fw.Schema, fw.props, config.keyValueMetadata) fw.startFile() return fw } @@ -154,6 +159,11 @@ func (fw *Writer) startFile() { } } +// AppendKeyValueMetadata appends a key/value pair to the existing key/value metadata +func (fw *Writer) AppendKeyValueMetadata(key string, value string) error { + return fw.metadata.AppendKeyValueMetadata(key, value) +} + // Close closes any open row group writer and writes the file footer. Subsequent // calls to close will have no effect. 
func (fw *Writer) Close() (err error) { @@ -180,11 +190,12 @@ func (fw *Writer) Close() (err error) { fileEncryptProps := fw.props.FileEncryptionProperties() if fileEncryptProps == nil { // non encrypted file - if fw.FileMetadata, err = fw.metadata.Finish(); err != nil { + fileMetadata, err := fw.metadata.Finish() + if err != nil { return err } - _, err = writeFileMetadata(fw.FileMetadata, fw.sink) + _, err = writeFileMetadata(fileMetadata, fw.sink) return err } @@ -193,12 +204,12 @@ func (fw *Writer) Close() (err error) { return nil } -func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) (err error) { +func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) error { // encrypted file with encrypted footer if props.EncryptedFooter() { - fw.FileMetadata, err = fw.metadata.Finish() + fileMetadata, err := fw.metadata.Finish() if err != nil { - return + return err } footerLen := int64(0) @@ -211,7 +222,7 @@ func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) (e footerLen += n footerEncryptor := fw.fileEncryptor.GetFooterEncryptor() - n, err = writeEncryptedFileMetadata(fw.FileMetadata, fw.sink, footerEncryptor, true) + n, err = writeEncryptedFileMetadata(fileMetadata, fw.sink, footerEncryptor, true) if err != nil { return err } @@ -224,11 +235,12 @@ func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) (e return err } } else { - if fw.FileMetadata, err = fw.metadata.Finish(); err != nil { - return + fileMetadata, err := fw.metadata.Finish() + if err != nil { + return err } footerSigningEncryptor := fw.fileEncryptor.GetFooterSigningEncryptor() - if _, err = writeEncryptedFileMetadata(fw.FileMetadata, fw.sink, footerSigningEncryptor, false); err != nil { + if _, err = writeEncryptedFileMetadata(fileMetadata, fw.sink, footerSigningEncryptor, false); err != nil { return err } } diff --git a/go/parquet/file/file_writer_test.go b/go/parquet/file/file_writer_test.go index 0cca1cd40d4c9..af083ebe60e4f 100644 --- a/go/parquet/file/file_writer_test.go +++ b/go/parquet/file/file_writer_test.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow/go/v14/parquet/internal/testutils" "github.com/apache/arrow/go/v14/parquet/schema" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -371,6 +372,34 @@ func TestAllNulls(t *testing.T) { assert.Equal(t, []int16{0, 0, 0}, defLevels[:]) } +func TestKeyValueMetadata(t *testing.T) { + fields := schema.FieldList{ + schema.NewInt32Node("unused", parquet.Repetitions.Optional, -1), + } + sc, _ := schema.NewGroupNode("root", parquet.Repetitions.Required, fields, -1) + sink := encoding.NewBufferWriter(0, memory.DefaultAllocator) + + writer := file.NewParquetWriter(sink, sc) + + testKey := "testKey" + testValue := "testValue" + writer.AppendKeyValueMetadata(testKey, testValue) + writer.Close() + + buffer := sink.Finish() + defer buffer.Release() + props := parquet.NewReaderProperties(memory.DefaultAllocator) + props.BufferedStreamEnabled = true + + reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()), file.WithReadProps(props)) + assert.NoError(t, err) + + metadata := reader.MetaData() + got := metadata.KeyValueMetadata().FindValue(testKey) + require.NotNil(t, got) + assert.Equal(t, testValue, *got) +} + func createSerializeTestSuite(typ reflect.Type) suite.TestingSuite { return &SerializeTestSuite{PrimitiveTypedTest: testutils.NewPrimitiveTypedTest(typ)} } diff --git a/go/parquet/metadata/file.go 
b/go/parquet/metadata/file.go index efe3c01c25b33..dddd95c5df670 100644 --- a/go/parquet/metadata/file.go +++ b/go/parquet/metadata/file.go @@ -95,6 +95,11 @@ func (f *FileMetaDataBuilder) AppendRowGroup() *RowGroupMetaDataBuilder { return f.currentRgBldr } +// AppendKeyValueMetadata appends a key/value pair to the existing key/value metadata +func (f *FileMetaDataBuilder) AppendKeyValueMetadata(key string, value string) error { + return f.kvmeta.Append(key, value) +} + // Finish will finalize the metadata of the number of rows, row groups, // version etc. This will clear out this filemetadatabuilder so it can // be re-used diff --git a/go/parquet/metadata/metadata_test.go b/go/parquet/metadata/metadata_test.go index 0db64d88ab0f4..b685dd2223274 100644 --- a/go/parquet/metadata/metadata_test.go +++ b/go/parquet/metadata/metadata_test.go @@ -272,6 +272,41 @@ func TestKeyValueMetadata(t *testing.T) { assert.True(t, faccessor.KeyValueMetadata().Equals(kvmeta)) } +func TestKeyValueMetadataAppend(t *testing.T) { + props := parquet.NewWriterProperties(parquet.WithVersion(parquet.V1_0)) + + fields := schema.FieldList{ + schema.NewInt32Node("int_col", parquet.Repetitions.Required, -1), + schema.NewFloat32Node("float_col", parquet.Repetitions.Required, -1), + } + root, err := schema.NewGroupNode("schema", parquet.Repetitions.Repeated, fields, -1) + require.NoError(t, err) + schema := schema.NewSchema(root) + + kvmeta := metadata.NewKeyValueMetadata() + key1 := "test_key1" + value1 := "test_value1" + require.NoError(t, kvmeta.Append(key1, value1)) + + fbuilder := metadata.NewFileMetadataBuilder(schema, props, kvmeta) + + key2 := "test_key2" + value2 := "test_value2" + require.NoError(t, fbuilder.AppendKeyValueMetadata(key2, value2)) + faccessor, err := fbuilder.Finish() + require.NoError(t, err) + + kv := faccessor.KeyValueMetadata() + + got1 := kv.FindValue(key1) + require.NotNil(t, got1) + assert.Equal(t, value1, *got1) + + got2 := kv.FindValue(key2) + require.NotNil(t, got2) + assert.Equal(t, value2, *got2) +} + func TestApplicationVersion(t *testing.T) { version := metadata.NewAppVersion("parquet-mr version 1.7.9") version1 := metadata.NewAppVersion("parquet-mr version 1.8.0") diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index 654d3d813cf85..3c20cf2d4757b 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -360,6 +360,51 @@ func simpleRoundTrip(t *testing.T, tbl arrow.Table, rowGroupSize int64) { } } +func TestWriteKeyValueMetadata(t *testing.T) { + kv := map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + sc := arrow.NewSchema([]arrow.Field{ + {Name: "int32", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + }, nil) + bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) + defer bldr.Release() + for _, b := range bldr.Fields() { + b.AppendNull() + } + + rec := bldr.NewRecord() + defer rec.Release() + + props := parquet.NewWriterProperties( + parquet.WithVersion(parquet.V1_0), + ) + var buf bytes.Buffer + fw, err := pqarrow.NewFileWriter(sc, &buf, props, pqarrow.DefaultWriterProps()) + require.NoError(t, err) + err = fw.Write(rec) + require.NoError(t, err) + + for key, value := range kv { + require.NoError(t, fw.AppendKeyValueMetadata(key, value)) + } + + err = fw.Close() + require.NoError(t, err) + + reader, err := file.NewParquetReader(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + + for key, value := range kv { + got := 
reader.MetaData().KeyValueMetadata().FindValue(key)
+		require.NotNil(t, got)
+		assert.Equal(t, value, *got)
+	}
+}
+
 func TestWriteEmptyLists(t *testing.T) {
 	sc := arrow.NewSchema([]arrow.Field{
 		{Name: "f1", Type: arrow.ListOf(arrow.FixedWidthTypes.Date32)},
diff --git a/go/parquet/pqarrow/file_writer.go b/go/parquet/pqarrow/file_writer.go
index 052220e716c77..aa0bae7b1fdfb 100644
--- a/go/parquet/pqarrow/file_writer.go
+++ b/go/parquet/pqarrow/file_writer.go
@@ -272,6 +272,11 @@ func (fw *FileWriter) WriteTable(tbl arrow.Table, chunkSize int64) error {
 	return nil
 }
 
+// AppendKeyValueMetadata appends a key/value pair to the existing key/value metadata
+func (fw *FileWriter) AppendKeyValueMetadata(key string, value string) error {
+	return fw.wr.AppendKeyValueMetadata(key, value)
+}
+
 // Close flushes out the data and closes the file. It can be called multiple times,
 // subsequent calls after the first will have no effect.
 func (fw *FileWriter) Close() error {
diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index 5640bc4349670..49e0f1720909f 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -29,6 +29,8 @@
 #include "arrow/filesystem/path_util.h"
 #include "arrow/filesystem/s3fs.h"
 #include "arrow/engine/substrait/util.h"
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/engine/substrait/relation.h"
 #include "arrow/ipc/api.h"
 #include "arrow/util/iterator.h"
 #include "jni_util.h"
@@ -200,7 +202,6 @@ arrow::Result<std::shared_ptr<arrow::Schema>> SchemaFromColumnNames(
       return arrow::Status::Invalid("Partition column '", ref.ToString(),
                                     "' is not in dataset schema");
     }
   }
-
   return schema(std::move(columns))->WithMetadata(input->metadata());
 }
 }  // namespace
@@ -317,6 +318,14 @@ std::shared_ptr<arrow::Table> GetTableByName(const std::vector<std::string>& names,
   return it->second;
 }
 
+std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobject byte_buffer) {
+  const auto* buff = reinterpret_cast<const uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
+  int length = env->GetDirectBufferCapacity(byte_buffer);
+  std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::AllocateBuffer(length));
+  std::memcpy(buffer->mutable_data(), buff, length);
+  return buffer;
+}
+
 /*
  * Class:     org_apache_arrow_dataset_jni_NativeMemoryPool
  * Method:    getDefaultMemoryPool
@@ -455,11 +464,12 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
 /*
  * Class:     org_apache_arrow_dataset_jni_JniWrapper
  * Method:    createScanner
- * Signature: (J[Ljava/lang/String;JJ)J
+ * Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
  */
 JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
-    JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns, jlong batch_size,
-    jlong memory_pool_id) {
+    JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
+    jobject substrait_projection, jobject substrait_filter,
+    jlong batch_size, jlong memory_pool_id) {
   JNI_METHOD_START
   arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
   if (pool == nullptr) {
@@ -474,6 +484,40 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
     std::vector<std::string> column_vector = ToStringVector(env, columns);
     JniAssertOkOrThrow(scanner_builder->Project(column_vector));
   }
+  if (substrait_projection != nullptr) {
+    std::shared_ptr<arrow::Buffer> buffer =
+        LoadArrowBufferFromByteBuffer(env, substrait_projection);
+    std::vector<arrow::compute::Expression> project_exprs;
+    std::vector<std::string> project_names;
+    arrow::engine::BoundExpressions bounded_expression =
+        JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer));
+    for (arrow::engine::NamedExpression& named_expression :
+         bounded_expression.named_expressions) {
+      project_exprs.push_back(std::move(named_expression.expression));
+      project_names.push_back(std::move(named_expression.name));
+    }
+    JniAssertOkOrThrow(
+        scanner_builder->Project(std::move(project_exprs), std::move(project_names)));
+  }
+  if (substrait_filter != nullptr) {
+    std::shared_ptr<arrow::Buffer> buffer =
+        LoadArrowBufferFromByteBuffer(env, substrait_filter);
+    std::optional<arrow::compute::Expression> filter_expr = std::nullopt;
+    arrow::engine::BoundExpressions bounded_expression =
+        JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer));
+    for (arrow::engine::NamedExpression& named_expression :
+         bounded_expression.named_expressions) {
+      if (named_expression.expression.type()->id() == arrow::Type::BOOL) {
+        filter_expr = named_expression.expression;
+      } else {
+        JniThrow("There is no filter expression in the expression provided");
+      }
+    }
+    if (filter_expr == std::nullopt) {
+      JniThrow("The filter expression has not been provided");
+    }
+    JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
+  }
   JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));
 
   auto scanner = JniGetOrThrow(scanner_builder->Finish());
@@ -748,10 +792,7 @@ JNIEXPORT void JNICALL
   arrow::engine::ConversionOptions conversion_options;
   conversion_options.named_table_provider = std::move(table_provider);
   // mapping arrow::Buffer
-  auto* buff = reinterpret_cast<const uint8_t*>(env->GetDirectBufferAddress(plan));
-  int length = env->GetDirectBufferCapacity(plan);
-  std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::AllocateBuffer(length));
-  std::memcpy(buffer->mutable_data(), buff, length);
+  std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, plan);
   // execute plan
   std::shared_ptr<arrow::RecordBatchReader> reader_out =
       JniGetOrThrow(arrow::engine::ExecuteSerializedPlan(*buffer, nullptr, nullptr, conversion_options));
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
index 93cc5d7a37040..a7df5be42f13b 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
@@ -17,6 +17,8 @@
 
 package org.apache.arrow.dataset.jni;
 
+import java.nio.ByteBuffer;
+
 /**
  * JNI wrapper for Dataset API's native implementation.
  */
@@ -66,15 +68,19 @@ private JniWrapper() {
 
   /**
    * Create Scanner from a Dataset and get the native pointer of the Scanner.
+   *
    * @param datasetId the native pointer of the arrow::dataset::Dataset instance.
    * @param columns desired column names.
    *                Columns not in this list will not be emitted when performing scan operation. Null equals
    *                to "all columns".
+   * @param substraitProjection serialized Substrait extended expression evaluated to project new columns
+   * @param substraitFilter serialized Substrait extended expression evaluated to filter scanned rows
    * @param batchSize batch size of scanned record batches.
    * @param memoryPool identifier of memory pool used in the native scanner.
    * @return the native pointer of the arrow::dataset::Scanner instance.
    */
-  public native long createScanner(long datasetId, String[] columns, long batchSize, long memoryPool);
+  public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection,
+                                   ByteBuffer substraitFilter, long batchSize, long memoryPool);
 
   /**
    * Get a serialized schema from native instance of a Scanner.
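One detail of the native code above is easy to miss: ``LoadArrowBufferFromByteBuffer`` reads the expression through ``GetDirectBufferAddress``/``GetDirectBufferCapacity``, so Java callers must pass a *direct* ``ByteBuffer`` whose capacity equals the payload length exactly; the buffer's position is never consulted, which is why the callers in this patch skip ``flip()``. A minimal sketch of the Java-side preparation (the helper name is hypothetical):

.. code-block:: java

   import java.nio.ByteBuffer;

   // The native side copies getDirectBufferCapacity() bytes from the buffer's
   // base address, so size the direct buffer to the exact payload length.
   static ByteBuffer toDirectBuffer(byte[] serializedExtendedExpression) {
     ByteBuffer buffer = ByteBuffer.allocateDirect(serializedExtendedExpression.length);
     buffer.put(serializedExtendedExpression); // no flip() needed; position is ignored
     return buffer;
   }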
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
index 30ff1a9302f7a..d9abad9971c4e 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
@@ -40,8 +40,12 @@ public synchronized NativeScanner newScan(ScanOptions options) {
     if (closed) {
       throw new NativeInstanceReleasedException();
     }
+
     long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null),
-        options.getBatchSize(), context.getMemoryPool().getNativeInstanceId());
+        options.getSubstraitProjection().orElse(null),
+        options.getSubstraitFilter().orElse(null), options.getBatchSize(),
+        context.getMemoryPool().getNativeInstanceId());
+
     return new NativeScanner(context, scannerId);
   }
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java
index f5a1af384b24e..995d05ac3b314 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java
@@ -17,6 +17,7 @@
 
 package org.apache.arrow.dataset.scanner;
 
+import java.nio.ByteBuffer;
 import java.util.Optional;
 
 import org.apache.arrow.util.Preconditions;
@@ -25,8 +26,10 @@
  * Options used during scanning.
  */
 public class ScanOptions {
-  private final Optional<String[]> columns;
-  private final long batchSize;
+  private final long batchSize;
+  private final Optional<String[]> columns;
+  private final Optional<ByteBuffer> substraitProjection;
+  private final Optional<ByteBuffer> substraitFilter;
 
   /**
    * Constructor.
@@ -56,6 +59,8 @@ public ScanOptions(long batchSize, Optional<String[]> columns) {
     Preconditions.checkNotNull(columns);
     this.batchSize = batchSize;
     this.columns = columns;
+    this.substraitProjection = Optional.empty();
+    this.substraitFilter = Optional.empty();
   }
 
   public ScanOptions(long batchSize) {
@@ -69,4 +74,77 @@ public Optional<String[]> getColumns() {
   public long getBatchSize() {
     return batchSize;
   }
+
+  public Optional<ByteBuffer> getSubstraitProjection() {
+    return substraitProjection;
+  }
+
+  public Optional<ByteBuffer> getSubstraitFilter() {
+    return substraitFilter;
+  }
+
+  /**
+   * Builder for Options used during scanning.
+   */
+  public static class Builder {
+    private final long batchSize;
+    private Optional<String[]> columns;
+    private ByteBuffer substraitProjection;
+    private ByteBuffer substraitFilter;
+
+    /**
+     * Constructor.
+     * @param batchSize Maximum number of rows per returned {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}
+     */
+    public Builder(long batchSize) {
+      this.batchSize = batchSize;
+    }
+
+    /**
+     * Set the projected columns. Empty for scanning all columns.
+     *
+     * @param columns Projected columns. Empty for scanning all columns.
+     * @return this Builder, for chaining.
+     */
+    public Builder columns(Optional<String[]> columns) {
+      Preconditions.checkNotNull(columns);
+      this.columns = columns;
+      return this;
+    }
+
+    /**
+     * Set the Substrait extended expression used to project new columns.
+     *
+     * @param substraitProjection serialized Substrait extended expression evaluated to project new columns.
+     * @return this Builder, for chaining.
+     */
+    public Builder substraitProjection(ByteBuffer substraitProjection) {
+      Preconditions.checkNotNull(substraitProjection);
+      this.substraitProjection = substraitProjection;
+      return this;
+    }
+
+    /**
+     * Set the Substrait extended expression used to filter scanned rows.
+     *
+     * @param substraitFilter serialized Substrait extended expression evaluated as the filter.
+     * @return this Builder, for chaining.
+     */
+    public Builder substraitFilter(ByteBuffer substraitFilter) {
+      Preconditions.checkNotNull(substraitFilter);
+      this.substraitFilter = substraitFilter;
+      return this;
+    }
+
+    public ScanOptions build() {
+      return new ScanOptions(this);
+    }
+  }
+
+  private ScanOptions(Builder builder) {
+    batchSize = builder.batchSize;
+    columns = builder.columns;
+    substraitProjection = Optional.ofNullable(builder.substraitProjection);
+    substraitFilter = Optional.ofNullable(builder.substraitFilter);
+  }
 }
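Note how the builder above coexists with the two legacy constructors: ``Optional.ofNullable`` in the private constructor turns unset Substrait fields into ``Optional.empty()``, so for a plain column scan the two forms below behave identically (values are illustrative):

.. code-block:: java

   import java.util.Optional;

   import org.apache.arrow.dataset.scanner.ScanOptions;

   static void equivalentOptions() {
     ScanOptions legacy = new ScanOptions(/*batchSize*/ 32768,
         Optional.of(new String[] {"id"}));
     ScanOptions viaBuilder = new ScanOptions.Builder(/*batchSize*/ 32768)
         .columns(Optional.of(new String[] {"id"}))
         .build(); // substraitProjection and substraitFilter remain Optional.empty()
   }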
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java b/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java
index c23b7e002880a..0fba72892cdc6 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java
@@ -18,6 +18,8 @@
 package org.apache.arrow.dataset.substrait;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 
 import java.nio.ByteBuffer;
 import java.nio.file.Files;
@@ -27,6 +29,7 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Optional;
 
 import org.apache.arrow.dataset.ParquetWriteSupport;
 import org.apache.arrow.dataset.TestDataset;
@@ -85,7 +88,7 @@ public void testRunQueryLocalFiles() throws Exception {
   }
 
   @Test
-  public void testRunQueryNamedTableNation() throws Exception {
+  public void testRunQueryNamedTable() throws Exception {
     //Query:
     //SELECT id, name FROM Users
     //Isthmus:
@@ -123,7 +126,7 @@ public void testRunQueryNamedTableNation() throws Exception {
   }
 
   @Test(expected = RuntimeException.class)
-  public void testRunQueryNamedTableNationWithException() throws Exception {
+  public void testRunQueryNamedTableWithException() throws Exception {
     //Query:
     //SELECT id, name FROM Users
     //Isthmus:
@@ -160,7 +163,7 @@ public void testRunQueryNamedTableNationWithException() throws Exception {
   }
 
   @Test
-  public void testRunBinaryQueryNamedTableNation() throws Exception {
+  public void testRunBinaryQueryNamedTable() throws Exception {
     //Query:
     //SELECT id, name FROM Users
     //Isthmus:
@@ -187,9 +190,7 @@ public void testRunBinaryQueryNamedTableNation() throws Exception {
     Map<String, ArrowReader> mapTableToArrowReader = new HashMap<>();
     mapTableToArrowReader.put("USERS", reader);
     // get binary plan
-    byte[] plan = Base64.getDecoder().decode(binaryPlan);
-    ByteBuffer substraitPlan = ByteBuffer.allocateDirect(plan.length);
-    substraitPlan.put(plan);
+    ByteBuffer substraitPlan = getByteBuffer(binaryPlan);
    // run query
    try (ArrowReader arrowReader = new AceroSubstraitConsumer(rootAllocator()).runQuery(
        substraitPlan,
@@ -204,4 +205,256 @@
       }
     }
   }
+
+  @Test
+  public void testRunExtendedExpressionsFilter() throws Exception {
+    final Schema schema = new Schema(Arrays.asList(
+        Field.nullable("id", new ArrowType.Int(32, true)),
+        Field.nullable("name", new ArrowType.Utf8())
+    ), null);
+    // Substrait Extended Expression: Filter:
+    // Expression 01: WHERE ID < 20
+    String base64EncodedSubstraitFilter = "Ch4IARIaL2Z1bmN0aW9uc19jb21wYXJpc29uLnlhbWwSEhoQCAIQAhoKbHQ6YW55X2F" +
+        "ueRo3ChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGhdmaWx0ZXJfaWRfbG93ZXJfdGhhbl8yMCIaCgJJRAoETkFNRRIOCgQqAhA" +
+        "BCgRiAhABGAI=";
+    ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter);
ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + rowcount += reader.getVectorSchemaRoot().getRowCount(); + assertTrue(reader.getVectorSchemaRoot().getVector("id").toString().equals("[19, 1, 11]")); + assertTrue(reader.getVectorSchemaRoot().getVector("name").toString() + .equals("[value_19, value_1, value_11]")); + } + assertEquals(3, rowcount); + } + } + + @Test + public void testRunExtendedExpressionsFilterWithProjectionsInsteadOfFilterException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Project New Column: + // Expression ADD: id + 2 + // Expression CONCAT: name + '-' + name + String base64EncodedSubstraitFilter = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwSERoPCAEaC2FkZDppM" + + "zJfaTMyEhQaEggCEAEaDGNvbmNhdDp2Y2hhchoxChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoTYWRkX3R3b190b19jb2x1" + + "bW5fYRpGCi0aKwgBGgRiAhABIgoaCBIGCgQSAggBIgkaBwoFYgMgLSAiChoIEgYKBBICCAEaFWNvbmNhdF9jb2x1bW5fYV9hbmR" + + "fYiIaCgJJRAoETkFNRRIOCgQqAhABCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish() + ) { + Exception e = assertThrows(RuntimeException.class, () -> dataset.newScan(options)); + assertTrue(e.getMessage().startsWith("There is no filter expression in the expression provided")); + } + } + + @Test + public void testRunExtendedExpressionsFilterWithEmptyFilterException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + String base64EncodedSubstraitFilter = ""; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new 
FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish() + ) { + Exception e = assertThrows(RuntimeException.class, () -> dataset.newScan(options)); + assertTrue(e.getMessage().contains("no anonymous struct type was provided to which names could be attached.")); + } + } + + @Test + public void testRunExtendedExpressionsProjection() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("add_two_to_column_a", new ArrowType.Int(32, true)), + Field.nullable("concat_column_a_and_b", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Project New Column: + // Expression ADD: id + 2 + // Expression CONCAT: name + '-' + name + String binarySubstraitExpressionProject = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwSERoPCAEaC2FkZDppM" + + "zJfaTMyEhQaEggCEAEaDGNvbmNhdDp2Y2hhchoxChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoTYWRkX3R3b190b19jb2x1" + + "bW5fYRpGCi0aKwgBGgRiAhABIgoaCBIGCgQSAggBIgkaBwoFYgMgLSAiChoIEgYKBBICCAEaFWNvbmNhdF9jb2x1bW5fYV9hbmR" + + "fYiIaCgJJRAoETkFNRRIOCgQqAhABCgRiAhABGAI="; + ByteBuffer substraitExpressionProject = getByteBuffer(binarySubstraitExpressionProject); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionProject) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + assertTrue(reader.getVectorSchemaRoot().getVector("add_two_to_column_a").toString() + .equals("[21, 3, 13, 23, 47]")); + assertTrue(reader.getVectorSchemaRoot().getVector("concat_column_a_and_b").toString() + .equals("[value_19 - value_19, value_1 - value_1, value_11 - value_11, " + + "value_21 - value_21, value_45 - value_45]")); + rowcount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(5, rowcount); + } + } + + @Test + public void testRunExtendedExpressionsProjectionWithFilterInsteadOfProjectionException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("filter_id_lower_than_20", new ArrowType.Bool()) + ), null); + // Substrait Extended Expression: Filter: + // Expression 01: WHERE ID < 20 + String binarySubstraitExpressionFilter = "Ch4IARIaL2Z1bmN0aW9uc19jb21wYXJpc29uLnlhbWwSEhoQCAIQAhoKbHQ6YW55X2F" + + "ueRo3ChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGhdmaWx0ZXJfaWRfbG93ZXJfdGhhbl8yMCIaCgJJRAoETkFNRRIOCgQqAhA" + + "BCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(binarySubstraitExpressionFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new 
FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + assertTrue(reader.getVectorSchemaRoot().getVector("filter_id_lower_than_20").toString() + .equals("[true, true, true, false, false]")); + rowcount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(5, rowcount); + } + } + + @Test + public void testRunExtendedExpressionsProjectionWithEmptyProjectionException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + String base64EncodedSubstraitFilter = ""; + ByteBuffer substraitExpressionProjection = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionProjection) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish() + ) { + Exception e = assertThrows(RuntimeException.class, () -> dataset.newScan(options)); + assertTrue(e.getMessage().contains("no anonymous struct type was provided to which names could be attached.")); + } + } + + @Test + public void testRunExtendedExpressionsProjectAndFilter() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("add_two_to_column_a", new ArrowType.Int(32, true)), + Field.nullable("concat_column_a_and_b", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Project New Column: + // Expression ADD: id + 2 + // Expression CONCAT: name + '-' + name + String binarySubstraitExpressionProject = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwSERoPCAEaC2FkZDppM" + + "zJfaTMyEhQaEggCEAEaDGNvbmNhdDp2Y2hhchoxChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoTYWRkX3R3b190b19jb2x1" + + "bW5fYRpGCi0aKwgBGgRiAhABIgoaCBIGCgQSAggBIgkaBwoFYgMgLSAiChoIEgYKBBICCAEaFWNvbmNhdF9jb2x1bW5fYV9hbmR" + + "fYiIaCgJJRAoETkFNRRIOCgQqAhABCgRiAhABGAI="; + ByteBuffer substraitExpressionProject = getByteBuffer(binarySubstraitExpressionProject); + // Substrait Extended Expression: Filter: + // Expression 01: WHERE ID < 20 + String base64EncodedSubstraitFilter = "Ch4IARIaL2Z1bmN0aW9uc19jb21wYXJpc29uLnlhbWwSEhoQCAIQAhoKbHQ6YW55X2F" + + "ueRo3ChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGhdmaWx0ZXJfaWRfbG93ZXJfdGhhbl8yMCIaCgJJRAoETkFNRRIOCgQqAhA" + + "BCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionProject) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = 
new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + assertTrue(reader.getVectorSchemaRoot().getVector("add_two_to_column_a").toString() + .equals("[21, 3, 13]")); + assertTrue(reader.getVectorSchemaRoot().getVector("concat_column_a_and_b").toString() + .equals("[value_19 - value_19, value_1 - value_1, value_11 - value_11]")); + rowcount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(3, rowcount); + } + } + + private static ByteBuffer getByteBuffer(String base64EncodedSubstrait) { + byte[] decodedSubstrait = Base64.getDecoder().decode(base64EncodedSubstrait); + ByteBuffer substraitExpression = ByteBuffer.allocateDirect(decodedSubstrait.length); + substraitExpression.put(decodedSubstrait); + return substraitExpression; + } } diff --git a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java index 1f80f25266b57..5e6580b6131c1 100644 --- a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -124,6 +124,24 @@ public void write(${name}Holder holder) { } + <#if minor.class?ends_with("VarBinary")> + public void writeTo${minor.class}(byte[] value) { + fail("${name}"); + } + + public void writeTo${minor.class}(byte[] value, int offset, int length) { + fail("${name}"); + } + + public void writeTo${minor.class}(ByteBuffer value) { + fail("${name}"); + } + + public void writeTo${minor.class}(ByteBuffer value, int offset, int length) { + fail("${name}"); + } + + public void writeNull() { diff --git a/java/vector/src/main/codegen/templates/ComplexWriters.java b/java/vector/src/main/codegen/templates/ComplexWriters.java index 0b1e321afb70e..4ae4c4f75f208 100644 --- a/java/vector/src/main/codegen/templates/ComplexWriters.java +++ b/java/vector/src/main/codegen/templates/ComplexWriters.java @@ -180,6 +180,28 @@ public void writeNull() { vector.setValueCount(idx()+1); } + + <#if minor.class?ends_with("VarBinary")> + public void writeTo${minor.class}(byte[] value) { + vector.setSafe(idx(), value); + vector.setValueCount(idx() + 1); + } + + public void writeTo${minor.class}(byte[] value, int offset, int length) { + vector.setSafe(idx(), value, offset, length); + vector.setValueCount(idx() + 1); + } + + public void writeTo${minor.class}(ByteBuffer value) { + vector.setSafe(idx(), value, 0, value.remaining()); + vector.setValueCount(idx() + 1); + } + + public void writeTo${minor.class}(ByteBuffer value, int offset, int length) { + vector.setSafe(idx(), value, offset, length); + vector.setValueCount(idx() + 1); + } + } <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/writer/${eName}Writer.java" /> @@ -223,6 +245,17 @@ public interface ${eName}Writer extends BaseWriter { @Deprecated public void writeBigEndianBytesTo${minor.class}(byte[] value); + +<#if minor.class?ends_with("VarBinary")> + public void writeTo${minor.class}(byte[] value); + + public void writeTo${minor.class}(byte[] value, int offset, int length); + + public void writeTo${minor.class}(ByteBuffer value); + + public void writeTo${minor.class}(ByteBuffer value, int offset, int length); + + } diff --git 
a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java new file mode 100644 index 0000000000000..7c06509b23c87 --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.complex.writer; + +import java.nio.ByteBuffer; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.complex.impl.LargeVarBinaryWriterImpl; +import org.apache.arrow.vector.complex.impl.VarBinaryWriterImpl; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSimpleWriter { + + private BufferAllocator allocator; + + @Before + public void init() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @After + public void terminate() throws Exception { + allocator.close(); + } + + @Test + public void testWriteByteArrayToVarBinary() { + try (VarBinaryVector vector = new VarBinaryVector("test", allocator); + VarBinaryWriterImpl writer = new VarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + writer.writeToVarBinary(input); + byte[] result = vector.get(0); + Assert.assertArrayEquals(input, result); + } + } + + @Test + public void testWriteByteArrayWithOffsetToVarBinary() { + try (VarBinaryVector vector = new VarBinaryVector("test", allocator); + VarBinaryWriterImpl writer = new VarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + writer.writeToVarBinary(input, 1, 1); + byte[] result = vector.get(0); + Assert.assertArrayEquals(new byte[] { 0x02 }, result); + } + } + + @Test + public void testWriteByteBufferToVarBinary() { + try (VarBinaryVector vector = new VarBinaryVector("test", allocator); + VarBinaryWriterImpl writer = new VarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + ByteBuffer buffer = ByteBuffer.wrap(input); + writer.writeToVarBinary(buffer); + byte[] result = vector.get(0); + Assert.assertArrayEquals(input, result); + } + } + + @Test + public void testWriteByteBufferWithOffsetToVarBinary() { + try (VarBinaryVector vector = new VarBinaryVector("test", allocator); + VarBinaryWriterImpl writer = new VarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + ByteBuffer buffer = ByteBuffer.wrap(input); + writer.writeToVarBinary(buffer, 1, 1); + byte[] result = vector.get(0); + Assert.assertArrayEquals(new byte[] { 0x02 }, result); + } + } + + @Test + public void 
testWriteByteArrayToLargeVarBinary() { + try (LargeVarBinaryVector vector = new LargeVarBinaryVector("test", allocator); + LargeVarBinaryWriterImpl writer = new LargeVarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + writer.writeToLargeVarBinary(input); + byte[] result = vector.get(0); + Assert.assertArrayEquals(input, result); + } + } + + @Test + public void testWriteByteArrayWithOffsetToLargeVarBinary() { + try (LargeVarBinaryVector vector = new LargeVarBinaryVector("test", allocator); + LargeVarBinaryWriterImpl writer = new LargeVarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + writer.writeToLargeVarBinary(input, 1, 1); + byte[] result = vector.get(0); + Assert.assertArrayEquals(new byte[] { 0x02 }, result); + } + } + + @Test + public void testWriteByteBufferToLargeVarBinary() { + try (LargeVarBinaryVector vector = new LargeVarBinaryVector("test", allocator); + LargeVarBinaryWriterImpl writer = new LargeVarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + ByteBuffer buffer = ByteBuffer.wrap(input); + writer.writeToLargeVarBinary(buffer); + byte[] result = vector.get(0); + Assert.assertArrayEquals(input, result); + } + } + + @Test + public void testWriteByteBufferWithOffsetToLargeVarBinary() { + try (LargeVarBinaryVector vector = new LargeVarBinaryVector("test", allocator); + LargeVarBinaryWriterImpl writer = new LargeVarBinaryWriterImpl(vector)) { + byte[] input = new byte[] { 0x01, 0x02 }; + ByteBuffer buffer = ByteBuffer.wrap(input); + writer.writeToLargeVarBinary(buffer, 1, 1); + byte[] result = vector.get(0); + Assert.assertArrayEquals(new byte[] { 0x02 }, result); + } + } +} diff --git a/matlab/CMakeLists.txt b/matlab/CMakeLists.txt index d73173b58e78a..b7af37a278536 100644 --- a/matlab/CMakeLists.txt +++ b/matlab/CMakeLists.txt @@ -17,9 +17,9 @@ cmake_minimum_required(VERSION 3.20) -# Build the Arrow C++ libraries. +# Build the Arrow C++ libraries using ExternalProject_Add. function(build_arrow) - set(options BUILD_GTEST) + set(options) set(one_value_args) set(multi_value_args) @@ -34,73 +34,54 @@ function(build_arrow) set(ARROW_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/arrow_ep-prefix") set(ARROW_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/arrow_ep-build") - set(ARROW_CMAKE_ARGS "-DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX}" - "-DCMAKE_INSTALL_LIBDIR=lib" "-DARROW_BUILD_STATIC=OFF") - - if(Arrow_FOUND - AND NOT GTest_FOUND - AND ARG_BUILD_GTEST) - # If find_package has already found a valid Arrow installation, then - # we don't want to link against the Arrow libraries that will be built - # from source. - # - # However, we still need to create a library target to trigger building - # of the arrow_ep target, which will ultimately build the bundled - # GoogleTest binaries. - add_library(arrow_shared_for_gtest SHARED IMPORTED) - set(ARROW_LIBRARY_TARGET arrow_shared_for_gtest) + set(ARROW_CMAKE_ARGS + "-DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX}" "-DCMAKE_INSTALL_LIBDIR=lib" + "-DARROW_BUILD_STATIC=OFF" "-DARROW_CSV=ON") + + add_library(arrow_shared SHARED IMPORTED) + set(ARROW_LIBRARY_TARGET arrow_shared) + + # Set the runtime shared library (.dll, .so, or .dylib) + if(WIN32) + # The shared library (i.e. .dll) is located in the "bin" directory. + set(ARROW_SHARED_LIBRARY_DIR "${ARROW_PREFIX}/bin") else() - add_library(arrow_shared SHARED IMPORTED) - set(ARROW_LIBRARY_TARGET arrow_shared) - - # Set the runtime shared library (.dll, .so, or .dylib) - if(WIN32) - # The shared library (i.e. .dll) is located in the "bin" directory. 
- set(ARROW_SHARED_LIBRARY_DIR "${ARROW_PREFIX}/bin") - else() - # The shared library (i.e. .so or .dylib) is located in the "lib" directory. - set(ARROW_SHARED_LIBRARY_DIR "${ARROW_PREFIX}/lib") - endif() - - set(ARROW_SHARED_LIB_FILENAME - "${CMAKE_SHARED_LIBRARY_PREFIX}arrow${CMAKE_SHARED_LIBRARY_SUFFIX}") - set(ARROW_SHARED_LIB "${ARROW_SHARED_LIBRARY_DIR}/${ARROW_SHARED_LIB_FILENAME}") - - set_target_properties(arrow_shared PROPERTIES IMPORTED_LOCATION ${ARROW_SHARED_LIB}) - - # Set the link-time import library (.lib) - if(WIN32) - # The import library (i.e. .lib) is located in the "lib" directory. - set(ARROW_IMPORT_LIB_FILENAME - "${CMAKE_IMPORT_LIBRARY_PREFIX}arrow${CMAKE_IMPORT_LIBRARY_SUFFIX}") - set(ARROW_IMPORT_LIB "${ARROW_PREFIX}/lib/${ARROW_IMPORT_LIB_FILENAME}") - - set_target_properties(arrow_shared PROPERTIES IMPORTED_IMPLIB ${ARROW_IMPORT_LIB}) - endif() - - # Set the include directories - set(ARROW_INCLUDE_DIR "${ARROW_PREFIX}/include") - file(MAKE_DIRECTORY "${ARROW_INCLUDE_DIR}") - set_target_properties(arrow_shared PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - ${ARROW_INCLUDE_DIR}) - - # Set the build byproducts for the ExternalProject build - # The appropriate libraries need to be guaranteed to be available when linking the test - # executables. - if(WIN32) - set(ARROW_BUILD_BYPRODUCTS "${ARROW_IMPORT_LIB}") - else() - set(ARROW_BUILD_BYPRODUCTS "${ARROW_SHARED_LIB}") - endif() + # The shared library (i.e. .so or .dylib) is located in the "lib" directory. + set(ARROW_SHARED_LIBRARY_DIR "${ARROW_PREFIX}/lib") endif() - # Building the Arrow C++ libraries and bundled GoogleTest binaries requires ExternalProject. - include(ExternalProject) + set(ARROW_SHARED_LIB_FILENAME + "${CMAKE_SHARED_LIBRARY_PREFIX}arrow${CMAKE_SHARED_LIBRARY_SUFFIX}") + set(ARROW_SHARED_LIB "${ARROW_SHARED_LIBRARY_DIR}/${ARROW_SHARED_LIB_FILENAME}") + + set_target_properties(arrow_shared PROPERTIES IMPORTED_LOCATION ${ARROW_SHARED_LIB}) + + # Set the link-time import library (.lib) + if(WIN32) + # The import library (i.e. .lib) is located in the "lib" directory. + set(ARROW_IMPORT_LIB_FILENAME + "${CMAKE_IMPORT_LIBRARY_PREFIX}arrow${CMAKE_IMPORT_LIBRARY_SUFFIX}") + set(ARROW_IMPORT_LIB "${ARROW_PREFIX}/lib/${ARROW_IMPORT_LIB_FILENAME}") - if(ARG_BUILD_GTEST) - enable_gtest() + set_target_properties(arrow_shared PROPERTIES IMPORTED_IMPLIB ${ARROW_IMPORT_LIB}) endif() + # Set the include directories + set(ARROW_INCLUDE_DIR "${ARROW_PREFIX}/include") + file(MAKE_DIRECTORY "${ARROW_INCLUDE_DIR}") + set_target_properties(arrow_shared PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ${ARROW_INCLUDE_DIR}) + + # Set the build byproducts for the ExternalProject build + if(WIN32) + set(ARROW_BUILD_BYPRODUCTS "${ARROW_IMPORT_LIB}") + else() + set(ARROW_BUILD_BYPRODUCTS "${ARROW_SHARED_LIB}") + endif() + + # Building the Arrow C++ libraries requires ExternalProject. 
+ include(ExternalProject) + externalproject_add(arrow_ep SOURCE_DIR "${CMAKE_SOURCE_DIR}/../cpp" BINARY_DIR "${ARROW_BINARY_DIR}" @@ -109,69 +90,8 @@ function(build_arrow) add_dependencies(${ARROW_LIBRARY_TARGET} arrow_ep) - if(ARG_BUILD_GTEST) - build_gtest() - endif() endfunction() -macro(enable_gtest) - set(ARROW_GTEST_INCLUDE_DIR "${ARROW_PREFIX}/include/arrow-gtest") - - set(ARROW_GTEST_IMPORT_LIB_DIR "${ARROW_PREFIX}/lib") - if(WIN32) - set(ARROW_GTEST_SHARED_LIB_DIR "${ARROW_PREFIX}/bin") - else() - set(ARROW_GTEST_SHARED_LIB_DIR "${ARROW_PREFIX}/lib") - endif() - set(ARROW_GTEST_IMPORT_LIB - "${ARROW_GTEST_IMPORT_LIB_DIR}/${CMAKE_IMPORT_LIBRARY_PREFIX}arrow_gtest${CMAKE_IMPORT_LIBRARY_SUFFIX}" - ) - set(ARROW_GTEST_MAIN_IMPORT_LIB - "${ARROW_GTEST_IMPORT_LIB_DIR}/${CMAKE_IMPORT_LIBRARY_PREFIX}arrow_gtest_main${CMAKE_IMPORT_LIBRARY_SUFFIX}" - ) - set(ARROW_GTEST_SHARED_LIB - "${ARROW_GTEST_SHARED_LIB_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}arrow_gtest${CMAKE_SHARED_LIBRARY_SUFFIX}" - ) - set(ARROW_GTEST_MAIN_SHARED_LIB - "${ARROW_GTEST_SHARED_LIB_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}arrow_gtest_main${CMAKE_SHARED_LIBRARY_SUFFIX}" - ) - - list(APPEND ARROW_CMAKE_ARGS "-DARROW_BUILD_TESTS=ON") - - # The appropriate libraries need to be guaranteed to be available when linking the test - # executables. - if(WIN32) - # On Windows, add the gtest link libraries as BUILD_BYPRODUCTS for arrow_ep. - list(APPEND ARROW_BUILD_BYPRODUCTS "${ARROW_GTEST_IMPORT_LIB}" - "${ARROW_GTEST_MAIN_IMPORT_LIB}") - else() - # On Linux and macOS, add the gtest shared libraries as BUILD_BYPRODUCTS for arrow_ep. - list(APPEND ARROW_BUILD_BYPRODUCTS "${ARROW_GTEST_SHARED_LIB}" - "${ARROW_GTEST_MAIN_SHARED_LIB}") - endif() -endmacro() - -# Build the GoogleTest binaries that are bundled with the Arrow C++ libraries. -macro(build_gtest) - file(MAKE_DIRECTORY "${ARROW_GTEST_INCLUDE_DIR}") - - # Create target GTest::gtest - add_library(GTest::gtest SHARED IMPORTED) - set_target_properties(GTest::gtest - PROPERTIES IMPORTED_IMPLIB ${ARROW_GTEST_IMPORT_LIB} - IMPORTED_LOCATION ${ARROW_GTEST_SHARED_LIB} - INTERFACE_INCLUDE_DIRECTORIES - ${ARROW_GTEST_INCLUDE_DIR}) - add_dependencies(GTest::gtest arrow_ep) - - # Create target GTest::gtest_main - add_library(GTest::gtest_main SHARED IMPORTED) - set_target_properties(GTest::gtest_main - PROPERTIES IMPORTED_IMPLIB ${ARROW_GTEST_MAIN_IMPORT_LIB} - IMPORTED_LOCATION ${ARROW_GTEST_MAIN_SHARED_LIB}) - add_dependencies(GTest::gtest_main arrow_ep) -endmacro() - set(CMAKE_CXX_STANDARD 17) set(MLARROW_VERSION "14.0.0-SNAPSHOT") @@ -185,8 +105,6 @@ if(WIN32) set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL") endif() -option(MATLAB_BUILD_TESTS "Build the C++ tests for the MATLAB interface" OFF) - # Add tools/cmake directory to the CMAKE_MODULE_PATH. list(PREPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/tools/cmake) @@ -208,56 +126,9 @@ else() set(MATLAB_BUILD_OUTPUT_DIR "${CMAKE_BINARY_DIR}") endif() -# Only build the MATLAB interface C++ tests if MATLAB_BUILD_TESTS=ON. -if(MATLAB_BUILD_TESTS) - # find_package(GTest) supports custom GTEST_ROOT as well as package managers. - find_package(GTest) - - if(NOT GTest_FOUND) - # find_package(Arrow) supports custom ARROW_HOME as well as package - # managers. - find_package(Arrow QUIET) - # Trigger an automatic build of the Arrow C++ libraries and bundled - # GoogleTest binaries. If a valid Arrow installation was not already - # found by find_package, then build_arrow will use the Arrow - # C++ libraries that are built from source. 
- build_arrow(BUILD_GTEST) - else() - # On Windows, IMPORTED_LOCATION needs to be set to indicate where the shared - # libraries live when GTest is found. - if(WIN32) - set(GTEST_SHARED_LIB_DIR "${GTEST_ROOT}/bin") - set(GTEST_SHARED_LIBRARY_FILENAME - "${CMAKE_SHARED_LIBRARY_PREFIX}gtest${CMAKE_SHARED_LIBRARY_SUFFIX}") - set(GTEST_SHARED_LIBRARY_LIB - "${GTEST_SHARED_LIB_DIR}/${GTEST_SHARED_LIBRARY_FILENAME}") - - set(GTEST_MAIN_SHARED_LIB_DIR "${GTEST_ROOT}/bin") - set(GTEST_MAIN_SHARED_LIBRARY_FILENAME - "${CMAKE_SHARED_LIBRARY_PREFIX}gtest_main${CMAKE_SHARED_LIBRARY_SUFFIX}") - set(GTEST_MAIN_SHARED_LIBRARY_LIB - "${GTEST_MAIN_SHARED_LIB_DIR}/${GTEST_MAIN_SHARED_LIBRARY_FILENAME}") - - set_target_properties(GTest::gtest PROPERTIES IMPORTED_LOCATION - "${GTEST_SHARED_LIBRARY_LIB}") - - set_target_properties(GTest::gtest_main - PROPERTIES IMPORTED_LOCATION - "${GTEST_MAIN_SHARED_LIBRARY_LIB}") - endif() - - find_package(Arrow QUIET) - if(NOT Arrow_FOUND) - # Trigger an automatic build of the Arrow C++ libraries. - build_arrow() - endif() - endif() - -else() - find_package(Arrow QUIET) - if(NOT Arrow_FOUND) - build_arrow() - endif() +find_package(Arrow QUIET) +if(NOT Arrow_FOUND) + build_arrow() endif() # MATLAB is Required @@ -311,56 +182,6 @@ else() message(STATUS "ARROW_INCLUDE_DIR: ${ARROW_INCLUDE_DIR}") endif() -# ############################################################################## -# C++ Tests -# ############################################################################## -# Only build the C++ tests if MATLAB_BUILD_TESTS=ON. -if(MATLAB_BUILD_TESTS) - enable_testing() - - # Define a test executable target. TODO: Remove the placeholder test. This is - # just for testing GoogleTest integration. - add_executable(placeholder_test ${CMAKE_SOURCE_DIR}/src/placeholder_test.cc) - - # Declare a dependency on the GTest::gtest and GTest::gtest_main IMPORTED - # targets. - target_link_libraries(placeholder_test GTest::gtest GTest::gtest_main) - - # Ensure using GTest:gtest and GTest::gtest_main on macOS without - # specifying DYLD_LIBRARY_DIR. - set_target_properties(placeholder_test - PROPERTIES BUILD_RPATH - "$;$" - ) - - # Add test targets for C++ tests. - add_test(PlaceholderTestTarget placeholder_test) - - # On Windows: - # Add the directory of gtest.dll and gtest_main.dll to the %PATH% for running - # all tests. - # Add the directory of libmx.dll, libmex.dll, and libarrow.dll to the %PATH% for running - # CheckNumArgsTestTarget. - # Note: When appending to the path using set_test_properties' ENVIRONMENT property, - # make sure that we escape ';' to prevent CMake from interpreting the input as - # a list of strings. 
- if(WIN32) - get_target_property(GTEST_SHARED_LIB GTest::gtest IMPORTED_LOCATION) - get_filename_component(GTEST_SHARED_LIB_DIR ${GTEST_SHARED_LIB} DIRECTORY) - - get_target_property(GTEST_MAIN_SHARED_LIB GTest::gtest_main IMPORTED_LOCATION) - get_filename_component(GTEST_MAIN_SHARED_LIB_DIR ${GTEST_MAIN_SHARED_LIB} DIRECTORY) - - set_tests_properties(PlaceholderTestTarget - PROPERTIES ENVIRONMENT - "PATH=${GTEST_SHARED_LIB_DIR}\;${GTEST_MAIN_SHARED_LIB_DIR}\;$ENV{PATH}" - ) - - get_target_property(ARROW_SHARED_LIB arrow_shared IMPORTED_LOCATION) - get_filename_component(ARROW_SHARED_LIB_DIR ${ARROW_SHARED_LIB} DIRECTORY) - endif() -endif() - # ############################################################################## # Install # ############################################################################## diff --git a/matlab/README.md b/matlab/README.md index d6b08fbee1c15..0a2bdf01f465f 100644 --- a/matlab/README.md +++ b/matlab/README.md @@ -100,31 +100,12 @@ As part of the install step, the installation directory is added to the [MATLAB ## Test -There are two kinds of tests for the MATLAB Interface: MATLAB and C++. - -### MATLAB - To run the MATLAB tests, start MATLAB in the `arrow/matlab` directory and call the [`runtests`](https://mathworks.com/help/matlab/ref/runtests.html) command on the `test` directory with `IncludeSubFolders=true`: ``` matlab >> runtests("test", IncludeSubFolders=true); ``` -### C++ - -To enable the C++ tests, set the `MATLAB_BUILD_TESTS` flag to `ON` at build time: - -```console -$ cmake -S . -B build -D MATLAB_BUILD_TESTS=ON -$ cmake --build build --config Release -``` - -After building with the `MATLAB_BUILD_TESTS` flag enabled, the C++ tests can be run using [CTest](https://cmake.org/cmake/help/latest/manual/ctest.1.html): - -```console -$ ctest --test-dir build -``` - ## Usage Included below are some example code snippets that illustrate how to use the MATLAB interface. diff --git a/matlab/doc/matlab_interface_for_apache_arrow_design.md b/matlab/doc/matlab_interface_for_apache_arrow_design.md index 79b43fd02518b..17c7ba254c0ea 100644 --- a/matlab/doc/matlab_interface_for_apache_arrow_design.md +++ b/matlab/doc/matlab_interface_for_apache_arrow_design.md @@ -257,14 +257,13 @@ For large tables used in a multi-process "data processing pipeline", a user coul ## Testing To ensure code quality, we would like to include the following testing infrastructure, at a minimum: -1. C++ APIs - - GoogleTest C++ Unit Tests - - Integration with CI workflows -2. MATLAB APIs - - [MATLAB Class-Based Unit Tests] - - Integration with CI workflows + +1. [MATLAB Class-Based Unit Tests] +2. [MATLAB CI Workflows] 3. [Integration Testing] +**Note**: To test internal C++ code, we can use a [MEX function] to call the C++ code from a MATLAB Class-Based Unit Test. + ## Documentation To ensure usability, discoverability, and accessibility, we would like to include high quality documentation for the MATLAB Interface for Apache Arrow. 
@@ -318,3 +317,4 @@ The table below provides a high-level roadmap for the development of specific ca [`apache-arrow` package via the `npm` package manager]: https://www.npmjs.com/package/apache-arrow [Rust user]: https://github.com/apache/arrow-rs [`arrow` crate via the `cargo` package manager]: https://crates.io/crates/arrow +[MATLAB CI Workflows]: https://github.com/apache/arrow/actions/workflows/matlab.yml diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/array.cc b/matlab/src/cpp/arrow/matlab/array/proxy/array.cc index ed6152259891d..5fa533632f928 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/array.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/array.cc @@ -31,7 +31,6 @@ namespace arrow::matlab::array::proxy { // Register Proxy methods. REGISTER_METHOD(Array, toString); - REGISTER_METHOD(Array, toMATLAB); REGISTER_METHOD(Array, getLength); REGISTER_METHOD(Array, getValid); REGISTER_METHOD(Array, getType); diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/array.h b/matlab/src/cpp/arrow/matlab/array/proxy/array.h index 185e107f75391..46e1fa5a81380 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/array.h @@ -42,8 +42,6 @@ class Array : public libmexclass::proxy::Proxy { void getType(libmexclass::proxy::method::Context& context); - virtual void toMATLAB(libmexclass::proxy::method::Context& context) = 0; - void isEqual(libmexclass::proxy::method::Context& context); std::shared_ptr array; diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.cc b/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.cc index 5be0cfb5a3d13..6a6e478274823 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.cc @@ -25,7 +25,9 @@ namespace arrow::matlab::array::proxy { BooleanArray::BooleanArray(std::shared_ptr array) - : arrow::matlab::array::proxy::Array{std::move(array)} {} + : arrow::matlab::array::proxy::Array{std::move(array)} { + REGISTER_METHOD(BooleanArray, toMATLAB); + } libmexclass::proxy::MakeResult BooleanArray::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { ::matlab::data::StructArray opts = constructor_arguments[0]; diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.h b/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.h index 775673c29eada..edc00b178e42a 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/boolean_array.h @@ -31,7 +31,7 @@ namespace arrow::matlab::array::proxy { static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments); protected: - void toMATLAB(libmexclass::proxy::method::Context& context) override; + void toMATLAB(libmexclass::proxy::method::Context& context); }; } diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h b/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h index f9da38dbaa062..4b4ddb6588678 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h @@ -40,7 +40,9 @@ class NumericArray : public arrow::matlab::array::proxy::Array { public: NumericArray(const std::shared_ptr> numeric_array) - : arrow::matlab::array::proxy::Array{std::move(numeric_array)} {} + : arrow::matlab::array::proxy::Array{std::move(numeric_array)} { + REGISTER_METHOD(NumericArray, toMATLAB); + } static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& 
constructor_arguments) { using MatlabBuffer = arrow::matlab::buffer::MatlabBuffer; @@ -67,7 +69,7 @@ class NumericArray : public arrow::matlab::array::proxy::Array { } protected: - void toMATLAB(libmexclass::proxy::method::Context& context) override { + void toMATLAB(libmexclass::proxy::method::Context& context) { using CType = typename arrow::TypeTraits::CType; using NumericArray = arrow::NumericArray; diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/string_array.cc b/matlab/src/cpp/arrow/matlab/array/proxy/string_array.cc index c583e8851a3ac..7160e88a3c8a0 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/string_array.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/string_array.cc @@ -28,7 +28,9 @@ namespace arrow::matlab::array::proxy { StringArray::StringArray(const std::shared_ptr string_array) - : arrow::matlab::array::proxy::Array(std::move(string_array)) {} + : arrow::matlab::array::proxy::Array(std::move(string_array)) { + REGISTER_METHOD(StringArray, toMATLAB); + } libmexclass::proxy::MakeResult StringArray::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { namespace mda = ::matlab::data; diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/string_array.h b/matlab/src/cpp/arrow/matlab/array/proxy/string_array.h index bdcfedd7cdda3..4cc01f0a02f8c 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/string_array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/string_array.h @@ -32,7 +32,7 @@ namespace arrow::matlab::array::proxy { static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments); protected: - void toMATLAB(libmexclass::proxy::method::Context& context) override; + void toMATLAB(libmexclass::proxy::method::Context& context); }; } diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/struct_array.cc b/matlab/src/cpp/arrow/matlab/array/proxy/struct_array.cc new file mode 100644 index 0000000000000..c6d9e47a9b0c4 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/array/proxy/struct_array.cc @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
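+// +// Implementation of the MATLAB Proxy class for arrow::StructArray. Each +// method registered via REGISTER_METHOD in the constructor below becomes +// callable from MATLAB through the arrow.array.proxy.StructArray proxy.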
+ +#include "arrow/matlab/array/proxy/struct_array.h" +#include "arrow/matlab/array/proxy/wrap.h" +#include "arrow/matlab/bit/pack.h" +#include "arrow/matlab/error/error.h" +#include "arrow/matlab/index/validate.h" + +#include "arrow/util/utf8.h" + +#include "libmexclass/proxy/ProxyManager.h" + +namespace arrow::matlab::array::proxy { + + StructArray::StructArray(std::shared_ptr struct_array) + : proxy::Array{std::move(struct_array)} { + REGISTER_METHOD(StructArray, getNumFields); + REGISTER_METHOD(StructArray, getFieldByIndex); + REGISTER_METHOD(StructArray, getFieldByName); + REGISTER_METHOD(StructArray, getFieldNames); + } + + libmexclass::proxy::MakeResult StructArray::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { + namespace mda = ::matlab::data; + using libmexclass::proxy::ProxyManager; + + mda::StructArray opts = constructor_arguments[0]; + const mda::TypedArray arrow_array_proxy_ids = opts[0]["ArrayProxyIDs"]; + const mda::StringArray field_names_mda = opts[0]["FieldNames"]; + const mda::TypedArray validity_bitmap_mda = opts[0]["Valid"]; + + std::vector> arrow_arrays; + arrow_arrays.reserve(arrow_array_proxy_ids.getNumberOfElements()); + + // Retrieve all of the Arrow Array Proxy instances from the libmexclass ProxyManager. + for (const auto& arrow_array_proxy_id : arrow_array_proxy_ids) { + auto proxy = ProxyManager::getProxy(arrow_array_proxy_id); + auto arrow_array_proxy = std::static_pointer_cast(proxy); + auto arrow_array = arrow_array_proxy->unwrap(); + arrow_arrays.push_back(arrow_array); + } + + // Convert the utf-16 encoded field names into utf-8 encoded strings + std::vector field_names; + field_names.reserve(field_names_mda.getNumberOfElements()); + for (const auto& field_name : field_names_mda) { + const auto field_name_utf16 = std::u16string(field_name); + MATLAB_ASSIGN_OR_ERROR(const auto field_name_utf8, + arrow::util::UTF16StringToUTF8(field_name_utf16), + error::UNICODE_CONVERSION_ERROR_ID); + field_names.push_back(field_name_utf8); + } + + // Pack the validity bitmap values. 
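+ // bit::packValid converts the MATLAB logical vector into an Arrow-style + // validity bitmap buffer (one bit per element; a 0 bit marks a null element).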
+ MATLAB_ASSIGN_OR_ERROR(auto validity_bitmap_buffer, + bit::packValid(validity_bitmap_mda), + error::BITPACK_VALIDITY_BITMAP_ERROR_ID); + + // Create the StructArray + MATLAB_ASSIGN_OR_ERROR(auto array, + arrow::StructArray::Make(arrow_arrays, field_names, validity_bitmap_buffer), + error::STRUCT_ARRAY_MAKE_FAILED); + + // Construct the StructArray Proxy + auto struct_array = std::static_pointer_cast(array); + return std::make_shared(std::move(struct_array)); + } + + void StructArray::getNumFields(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + + mda::ArrayFactory factory; + const auto num_fields = array->type()->num_fields(); + context.outputs[0] = factory.createScalar(num_fields); + } + + void StructArray::getFieldByIndex(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + using namespace libmexclass::proxy; + + mda::StructArray args = context.inputs[0]; + const mda::TypedArray index_mda = args[0]["Index"]; + const auto matlab_index = int32_t(index_mda[0]); + + auto struct_array = std::static_pointer_cast(array); + + const auto num_fields = struct_array->type()->num_fields(); + + // Validate there is at least 1 field + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT( + index::validateNonEmptyContainer(num_fields), + context, error::INDEX_EMPTY_CONTAINER); + + // Validate the matlab index provided is within the range [1, num_fields] + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT( + index::validateInRange(matlab_index, num_fields), + context, error::INDEX_OUT_OF_RANGE); + + // Note: MATLAB uses 1-based indexing, so subtract 1. + const int32_t index = matlab_index - 1; + + auto field_array = struct_array->field(index); + + // Wrap the array within a proxy object if possible. + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto field_array_proxy, + proxy::wrap(field_array), + context, error::UNKNOWN_PROXY_FOR_ARRAY_TYPE); + const auto field_array_proxy_id = ProxyManager::manageProxy(field_array_proxy); + const auto type_id = field_array->type_id(); + + // Return a struct with two fields: ProxyID and TypeID. The MATLAB + // layer will use these values to construct the appropriate MATLAB + // arrow.array.Array subclass. + mda::ArrayFactory factory; + mda::StructArray output = factory.createStructArray({1, 1}, {"ProxyID", "TypeID"}); + output[0]["ProxyID"] = factory.createScalar(field_array_proxy_id); + output[0]["TypeID"] = factory.createScalar(static_cast(type_id)); + context.outputs[0] = output; + } + + void StructArray::getFieldByName(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + using libmexclass::proxy::ProxyManager; + + mda::StructArray args = context.inputs[0]; + + const mda::StringArray name_mda = args[0]["Name"]; + const auto name_utf16 = std::u16string(name_mda[0]); + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto name, + arrow::util::UTF16StringToUTF8(name_utf16), + context, error::UNICODE_CONVERSION_ERROR_ID); + + + auto struct_array = std::static_pointer_cast(array); + auto field_array = struct_array->GetFieldByName(name); + if (!field_array) { + // Return an error if we could not query the field by name. + const auto msg = "Could not find field named " + name + "."; + context.error = libmexclass::error::Error{ + error::ARROW_TABULAR_SCHEMA_AMBIGUOUS_FIELD_NAME, msg}; + return; + } + + // Wrap the array within a proxy object if possible. 
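+ // If the field's Arrow type has no corresponding proxy class yet, this + // reports error::UNKNOWN_PROXY_FOR_ARRAY_TYPE back to the MATLAB layer.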
+ MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto field_array_proxy, + proxy::wrap(field_array), + context, error::UNKNOWN_PROXY_FOR_ARRAY_TYPE); + const auto field_array_proxy_id = ProxyManager::manageProxy(field_array_proxy); + const auto type_id = field_array->type_id(); + + // Return a struct with two fields: ProxyID and TypeID. The MATLAB + // layer will use these values to construct the appropriate MATLAB + // arrow.array.Array subclass. + mda::ArrayFactory factory; + mda::StructArray output = factory.createStructArray({1, 1}, {"ProxyID", "TypeID"}); + output[0]["ProxyID"] = factory.createScalar(field_array_proxy_id); + output[0]["TypeID"] = factory.createScalar(static_cast(type_id)); + context.outputs[0] = output; + } + + void StructArray::getFieldNames(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + + const auto& fields = array->type()->fields(); + const auto num_fields = fields.size(); + std::vector names; + names.reserve(num_fields); + + for (size_t i = 0; i < num_fields; ++i) { + auto str_utf8 = fields[i]->name(); + + // MATLAB strings are UTF-16 encoded. Must convert UTF-8 + // encoded field names before returning to MATLAB. + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto str_utf16, + arrow::util::UTF8StringToUTF16(str_utf8), + context, error::UNICODE_CONVERSION_ERROR_ID); + const mda::MATLABString matlab_string = mda::MATLABString(std::move(str_utf16)); + names.push_back(matlab_string); + } + + mda::ArrayFactory factory; + context.outputs[0] = factory.createArray({1, num_fields}, names.begin(), names.end()); + } +} diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/struct_array.h b/matlab/src/cpp/arrow/matlab/array/proxy/struct_array.h new file mode 100644 index 0000000000000..cfb548c4e50df --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/array/proxy/struct_array.h @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "arrow/matlab/array/proxy/array.h" + +namespace arrow::matlab::array::proxy { + +class StructArray : public arrow::matlab::array::proxy::Array { + public: + StructArray(std::shared_ptr struct_array); + + ~StructArray() {} + + static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments); + + protected: + + void getNumFields(libmexclass::proxy::method::Context& context); + + void getFieldByIndex(libmexclass::proxy::method::Context& context); + + void getFieldByName(libmexclass::proxy::method::Context& context); + + void getFieldNames(libmexclass::proxy::method::Context& context); + +}; + +} diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc b/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc index a8e3f239919cc..b14f4b18711cb 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc @@ -21,6 +21,7 @@ #include "arrow/matlab/array/proxy/boolean_array.h" #include "arrow/matlab/array/proxy/numeric_array.h" #include "arrow/matlab/array/proxy/string_array.h" +#include "arrow/matlab/array/proxy/struct_array.h" namespace arrow::matlab::array::proxy { @@ -61,6 +62,8 @@ namespace arrow::matlab::array::proxy { return std::make_shared>(std::static_pointer_cast(array)); case ID::STRING: return std::make_shared(std::static_pointer_cast(array)); + case ID::STRUCT: + return std::make_shared(std::static_pointer_cast(array)); default: return arrow::Status::NotImplemented("Unsupported DataType: " + array->type()->ToString()); } } diff --git a/matlab/src/cpp/arrow/matlab/error/error.h index 4ff77da8d8360..347bc25b5f3a6 100644 --- a/matlab/src/cpp/arrow/matlab/error/error.h +++ b/matlab/src/cpp/arrow/matlab/error/error.h @@ -182,6 +182,9 @@ namespace arrow::matlab::error { static const char* TABLE_INVALID_NUMERIC_COLUMN_INDEX = "arrow:tabular:table:InvalidNumericColumnIndex"; static const char* FAILED_TO_OPEN_FILE_FOR_WRITE = "arrow:io:FailedToOpenFileForWrite"; static const char* FAILED_TO_OPEN_FILE_FOR_READ = "arrow:io:FailedToOpenFileForRead"; + static const char* CSV_FAILED_TO_WRITE_TABLE = "arrow:io:csv:FailedToWriteTable"; + static const char* CSV_FAILED_TO_CREATE_TABLE_READER = "arrow:io:csv:FailedToCreateTableReader"; + static const char* CSV_FAILED_TO_READ_TABLE = "arrow:io:csv:FailedToReadTable"; static const char* FEATHER_FAILED_TO_WRITE_TABLE = "arrow:io:feather:FailedToWriteTable"; static const char* TABLE_FROM_RECORD_BATCH = "arrow:table:FromRecordBatch"; static const char* FEATHER_FAILED_TO_CREATE_READER = "arrow:io:feather:FailedToCreateReader"; @@ -192,7 +195,7 @@ namespace arrow::matlab::error { static const char* CHUNKED_ARRAY_MAKE_FAILED = "arrow:chunkedarray:MakeFailed"; static const char* CHUNKED_ARRAY_NUMERIC_INDEX_WITH_EMPTY_CHUNKED_ARRAY = "arrow:chunkedarray:NumericIndexWithEmptyChunkedArray"; static const char* CHUNKED_ARRAY_INVALID_NUMERIC_CHUNK_INDEX = "arrow:chunkedarray:InvalidNumericChunkIndex"; - + static const char* STRUCT_ARRAY_MAKE_FAILED = "arrow:array:StructArrayMakeFailed"; static const char* INDEX_EMPTY_CONTAINER = "arrow:index:EmptyContainer"; static const char* INDEX_OUT_OF_RANGE = "arrow:index:OutOfRange"; } diff --git a/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_reader.cc b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_reader.cc new file mode 100644 index 0000000000000..ab9935ce145a8 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_reader.cc @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or
more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "libmexclass/proxy/ProxyManager.h" + +#include "arrow/matlab/error/error.h" +#include "arrow/matlab/io/csv/proxy/table_reader.h" +#include "arrow/matlab/tabular/proxy/table.h" + +#include "arrow/util/utf8.h" + +#include "arrow/result.h" + +#include "arrow/io/file.h" +#include "arrow/io/interfaces.h" +#include "arrow/csv/reader.h" +#include "arrow/table.h" + +namespace arrow::matlab::io::csv::proxy { + + TableReader::TableReader(const std::string& filename) : filename{filename} { + REGISTER_METHOD(TableReader, read); + REGISTER_METHOD(TableReader, getFilename); + } + + libmexclass::proxy::MakeResult TableReader::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { + namespace mda = ::matlab::data; + using TableReaderProxy = arrow::matlab::io::csv::proxy::TableReader; + + mda::StructArray args = constructor_arguments[0]; + const mda::StringArray filename_utf16_mda = args[0]["Filename"]; + const auto filename_utf16 = std::u16string(filename_utf16_mda[0]); + MATLAB_ASSIGN_OR_ERROR(const auto filename, arrow::util::UTF16StringToUTF8(filename_utf16), error::UNICODE_CONVERSION_ERROR_ID); + + return std::make_shared(filename); + } + + void TableReader::read(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + using namespace libmexclass::proxy; + namespace csv = ::arrow::csv; + using TableProxy = arrow::matlab::tabular::proxy::Table; + + mda::ArrayFactory factory; + + // Create a file input stream. + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto source, arrow::io::ReadableFile::Open(filename, arrow::default_memory_pool()), context, error::FAILED_TO_OPEN_FILE_FOR_READ); + + const ::arrow::io::IOContext io_context; + const auto read_options = csv::ReadOptions::Defaults(); + const auto parse_options = csv::ParseOptions::Defaults(); + const auto convert_options = csv::ConvertOptions::Defaults(); + + // Create a TableReader from the file input stream. + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto table_reader, + csv::TableReader::Make(io_context, source, read_options, parse_options, convert_options), + context, + error::CSV_FAILED_TO_CREATE_TABLE_READER); + + // Read a Table from the file. 
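+ // Note: Read() parses the entire CSV file and materializes it as an + // in-memory arrow::Table; there is no streaming/chunked reading here.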
+ MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto table, table_reader->Read(), context, error::CSV_FAILED_TO_READ_TABLE); + + auto table_proxy = std::make_shared(table); + const auto table_proxy_id = ProxyManager::manageProxy(table_proxy); + + const auto table_proxy_id_mda = factory.createScalar(table_proxy_id); + + context.outputs[0] = table_proxy_id_mda; + } + + void TableReader::getFilename(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + mda::ArrayFactory factory; + + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto filename_utf16, arrow::util::UTF8StringToUTF16(filename), context, error::UNICODE_CONVERSION_ERROR_ID); + auto filename_utf16_mda = factory.createScalar(filename_utf16); + context.outputs[0] = filename_utf16_mda; + } + +} diff --git a/matlab/src/placeholder_test.cc b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_reader.h similarity index 56% rename from matlab/src/placeholder_test.cc rename to matlab/src/cpp/arrow/matlab/io/csv/proxy/table_reader.h index eef37e178f623..d5dfce50e4096 100644 --- a/matlab/src/placeholder_test.cc +++ b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_reader.h @@ -15,13 +15,24 @@ // specific language governing permissions and limitations // under the License. -#include +#pragma once -namespace arrow { -namespace matlab { -namespace test { -// TODO: Remove this placeholder test. -TEST(PlaceholderTestSuite, PlaceholderTestCase) { ASSERT_TRUE(true); } -} // namespace test -} // namespace matlab -} // namespace arrow +#include "libmexclass/proxy/Proxy.h" + +namespace arrow::matlab::io::csv::proxy { + + class TableReader : public libmexclass::proxy::Proxy { + public: + TableReader(const std::string& filename); + ~TableReader() {} + static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments); + + protected: + void read(libmexclass::proxy::method::Context& context); + void getFilename(libmexclass::proxy::method::Context& context); + + private: + const std::string filename; + }; + +} diff --git a/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_writer.cc b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_writer.cc new file mode 100644 index 0000000000000..b24bd81b06681 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_writer.cc @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
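+// +// Implementation of the MATLAB Proxy class that writes an arrow::Table to a +// CSV file using arrow::csv::WriteCSV with default write options.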
+ +#include "arrow/matlab/io/csv/proxy/table_writer.h" +#include "arrow/matlab/tabular/proxy/table.h" +#include "arrow/matlab/error/error.h" + +#include "arrow/result.h" +#include "arrow/table.h" +#include "arrow/util/utf8.h" + +#include "arrow/io/file.h" +#include "arrow/csv/writer.h" +#include "arrow/csv/options.h" + +#include "libmexclass/proxy/ProxyManager.h" + +namespace arrow::matlab::io::csv::proxy { + + TableWriter::TableWriter(const std::string& filename) : filename{filename} { + REGISTER_METHOD(TableWriter, getFilename); + REGISTER_METHOD(TableWriter, write); + } + + libmexclass::proxy::MakeResult TableWriter::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { + namespace mda = ::matlab::data; + mda::StructArray opts = constructor_arguments[0]; + const mda::StringArray filename_mda = opts[0]["Filename"]; + using TableWriterProxy = ::arrow::matlab::io::csv::proxy::TableWriter; + + const auto filename_utf16 = std::u16string(filename_mda[0]); + MATLAB_ASSIGN_OR_ERROR(const auto filename_utf8, + arrow::util::UTF16StringToUTF8(filename_utf16), + error::UNICODE_CONVERSION_ERROR_ID); + + return std::make_shared(filename_utf8); + } + + void TableWriter::getFilename(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto utf16_filename, + arrow::util::UTF8StringToUTF16(filename), + context, + error::UNICODE_CONVERSION_ERROR_ID); + mda::ArrayFactory factory; + auto str_mda = factory.createScalar(utf16_filename); + context.outputs[0] = str_mda; + } + + void TableWriter::write(libmexclass::proxy::method::Context& context) { + namespace csv = ::arrow::csv; + namespace mda = ::matlab::data; + using TableProxy = ::arrow::matlab::tabular::proxy::Table; + + mda::StructArray opts = context.inputs[0]; + const mda::TypedArray table_proxy_id_mda = opts[0]["TableProxyID"]; + const uint64_t table_proxy_id = table_proxy_id_mda[0]; + + auto proxy = libmexclass::proxy::ProxyManager::getProxy(table_proxy_id); + auto table_proxy = std::static_pointer_cast(proxy); + auto table = table_proxy->unwrap(); + + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto output_stream, + arrow::io::FileOutputStream::Open(filename), + context, + error::FAILED_TO_OPEN_FILE_FOR_WRITE); + const auto options = csv::WriteOptions::Defaults(); + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT(csv::WriteCSV(*table, options, output_stream.get()), + context, + error::CSV_FAILED_TO_WRITE_TABLE); + } +} diff --git a/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_writer.h b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_writer.h new file mode 100644 index 0000000000000..b9916bd9bdc22 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/io/csv/proxy/table_writer.h @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "libmexclass/proxy/Proxy.h" + +namespace arrow::matlab::io::csv::proxy { + + class TableWriter : public libmexclass::proxy::Proxy { + public: + TableWriter(const std::string& filename); + ~TableWriter() {} + static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments); + + protected: + void getFilename(libmexclass::proxy::method::Context& context); + void write(libmexclass::proxy::method::Context& context); + + private: + const std::string filename; + }; + +} diff --git a/matlab/src/cpp/arrow/matlab/proxy/factory.cc b/matlab/src/cpp/arrow/matlab/proxy/factory.cc index ebeb020a9e7c7..62ed84fedcf6a 100644 --- a/matlab/src/cpp/arrow/matlab/proxy/factory.cc +++ b/matlab/src/cpp/arrow/matlab/proxy/factory.cc @@ -21,6 +21,7 @@ #include "arrow/matlab/array/proxy/timestamp_array.h" #include "arrow/matlab/array/proxy/time32_array.h" #include "arrow/matlab/array/proxy/time64_array.h" +#include "arrow/matlab/array/proxy/struct_array.h" #include "arrow/matlab/array/proxy/chunked_array.h" #include "arrow/matlab/tabular/proxy/record_batch.h" #include "arrow/matlab/tabular/proxy/table.h" @@ -37,6 +38,8 @@ #include "arrow/matlab/type/proxy/field.h" #include "arrow/matlab/io/feather/proxy/writer.h" #include "arrow/matlab/io/feather/proxy/reader.h" +#include "arrow/matlab/io/csv/proxy/table_writer.h" +#include "arrow/matlab/io/csv/proxy/table_reader.h" #include "factory.h" @@ -55,6 +58,7 @@ libmexclass::proxy::MakeResult Factory::make_proxy(const ClassName& class_name, REGISTER_PROXY(arrow.array.proxy.Int64Array , arrow::matlab::array::proxy::NumericArray); REGISTER_PROXY(arrow.array.proxy.BooleanArray , arrow::matlab::array::proxy::BooleanArray); REGISTER_PROXY(arrow.array.proxy.StringArray , arrow::matlab::array::proxy::StringArray); + REGISTER_PROXY(arrow.array.proxy.StructArray , arrow::matlab::array::proxy::StructArray); REGISTER_PROXY(arrow.array.proxy.TimestampArray, arrow::matlab::array::proxy::NumericArray); REGISTER_PROXY(arrow.array.proxy.Time32Array , arrow::matlab::array::proxy::NumericArray); REGISTER_PROXY(arrow.array.proxy.Time64Array , arrow::matlab::array::proxy::NumericArray); @@ -85,6 +89,8 @@ libmexclass::proxy::MakeResult Factory::make_proxy(const ClassName& class_name, REGISTER_PROXY(arrow.type.proxy.StructType , arrow::matlab::type::proxy::StructType); REGISTER_PROXY(arrow.io.feather.proxy.Writer , arrow::matlab::io::feather::proxy::Writer); REGISTER_PROXY(arrow.io.feather.proxy.Reader , arrow::matlab::io::feather::proxy::Reader); + REGISTER_PROXY(arrow.io.csv.proxy.TableWriter , arrow::matlab::io::csv::proxy::TableWriter); + REGISTER_PROXY(arrow.io.csv.proxy.TableReader , arrow::matlab::io::csv::proxy::TableReader); return libmexclass::error::Error{error::UNKNOWN_PROXY_ERROR_ID, "Did not find matching C++ proxy for " + class_name}; }; diff --git a/matlab/src/matlab/+arrow/+array/Array.m b/matlab/src/matlab/+arrow/+array/Array.m index 4505d4b006ad8..436d5b80aa6a8 100644 --- a/matlab/src/matlab/+arrow/+array/Array.m +++ b/matlab/src/matlab/+arrow/+array/Array.m @@ -21,12 +21,9 @@ Proxy end - properties (Dependent) + properties(Dependent, SetAccess=private, GetAccess=public) Length Valid % Validity bitmap - end - - properties(Dependent, SetAccess=private, GetAccess=public) Type(1, 1) arrow.type.Type end diff --git a/matlab/src/matlab/+arrow/+array/BooleanArray.m 
b/matlab/src/matlab/+arrow/+array/BooleanArray.m index b9ef36b5a70c9..dc38ef93e545c 100644 --- a/matlab/src/matlab/+arrow/+array/BooleanArray.m +++ b/matlab/src/matlab/+arrow/+array/BooleanArray.m @@ -16,8 +16,8 @@ classdef BooleanArray < arrow.array.Array % arrow.array.BooleanArray - properties (Hidden, SetAccess=private) - NullSubstitionValue = false; + properties (Hidden, GetAccess=public, SetAccess=private) + NullSubstitutionValue = false; end methods @@ -35,7 +35,7 @@ function matlabArray = toMATLAB(obj) matlabArray = obj.Proxy.toMATLAB(); - matlabArray(~obj.Valid) = obj.NullSubstitionValue; + matlabArray(~obj.Valid) = obj.NullSubstitutionValue; end end diff --git a/matlab/src/matlab/+arrow/+array/ChunkedArray.m b/matlab/src/matlab/+arrow/+array/ChunkedArray.m index 96d7bb57a4021..ede95323f4865 100644 --- a/matlab/src/matlab/+arrow/+array/ChunkedArray.m +++ b/matlab/src/matlab/+arrow/+array/ChunkedArray.m @@ -66,7 +66,8 @@ for ii = 1:obj.NumChunks chunk = obj.chunk(ii); endIndex = startIndex + chunk.Length - 1; - data(startIndex:endIndex) = toMATLAB(chunk); + % Use 2D indexing to support tabular MATLAB types. + data(startIndex:endIndex, :) = toMATLAB(chunk); startIndex = endIndex + 1; end end diff --git a/matlab/src/matlab/+arrow/+array/Date32Array.m b/matlab/src/matlab/+arrow/+array/Date32Array.m index a462bd4f85ac1..cfe56bc67fb94 100644 --- a/matlab/src/matlab/+arrow/+array/Date32Array.m +++ b/matlab/src/matlab/+arrow/+array/Date32Array.m @@ -17,7 +17,7 @@ classdef Date32Array < arrow.array.Array - properties(Access=private) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = NaT end diff --git a/matlab/src/matlab/+arrow/+array/Date64Array.m b/matlab/src/matlab/+arrow/+array/Date64Array.m index f5da26bbb5594..c67b82a5bbc47 100644 --- a/matlab/src/matlab/+arrow/+array/Date64Array.m +++ b/matlab/src/matlab/+arrow/+array/Date64Array.m @@ -17,7 +17,7 @@ classdef Date64Array < arrow.array.Array - properties(Access=private) + properties(Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = NaT end diff --git a/matlab/src/matlab/+arrow/+array/Float32Array.m b/matlab/src/matlab/+arrow/+array/Float32Array.m index fe90db335b5aa..d12e772c41428 100644 --- a/matlab/src/matlab/+arrow/+array/Float32Array.m +++ b/matlab/src/matlab/+arrow/+array/Float32Array.m @@ -16,7 +16,7 @@ classdef Float32Array < arrow.array.NumericArray % arrow.array.Float32Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = single(NaN); end diff --git a/matlab/src/matlab/+arrow/+array/Float64Array.m b/matlab/src/matlab/+arrow/+array/Float64Array.m index ecf91e28954b5..028331b4f99c0 100644 --- a/matlab/src/matlab/+arrow/+array/Float64Array.m +++ b/matlab/src/matlab/+arrow/+array/Float64Array.m @@ -16,7 +16,7 @@ classdef Float64Array < arrow.array.NumericArray % arrow.array.Float64Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = NaN; end diff --git a/matlab/src/matlab/+arrow/+array/Int16Array.m b/matlab/src/matlab/+arrow/+array/Int16Array.m index 53c96c6eeb85c..aee94b39c8969 100644 --- a/matlab/src/matlab/+arrow/+array/Int16Array.m +++ b/matlab/src/matlab/+arrow/+array/Int16Array.m @@ -16,7 +16,7 @@ classdef Int16Array < arrow.array.NumericArray % arrow.array.Int16Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = int16(0) end diff --git 
a/matlab/src/matlab/+arrow/+array/Int32Array.m b/matlab/src/matlab/+arrow/+array/Int32Array.m index d85bcaf627f7b..a0c0c76afa0e7 100644 --- a/matlab/src/matlab/+arrow/+array/Int32Array.m +++ b/matlab/src/matlab/+arrow/+array/Int32Array.m @@ -16,7 +16,7 @@ classdef Int32Array < arrow.array.NumericArray % arrow.array.Int32Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = int32(0) end diff --git a/matlab/src/matlab/+arrow/+array/Int64Array.m b/matlab/src/matlab/+arrow/+array/Int64Array.m index 72199df88ded1..1f8b1c793984a 100644 --- a/matlab/src/matlab/+arrow/+array/Int64Array.m +++ b/matlab/src/matlab/+arrow/+array/Int64Array.m @@ -16,7 +16,7 @@ classdef Int64Array < arrow.array.NumericArray % arrow.array.Int64Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = int64(0); end diff --git a/matlab/src/matlab/+arrow/+array/Int8Array.m b/matlab/src/matlab/+arrow/+array/Int8Array.m index 0e9d8eec0edf5..02e21178ffe49 100644 --- a/matlab/src/matlab/+arrow/+array/Int8Array.m +++ b/matlab/src/matlab/+arrow/+array/Int8Array.m @@ -16,7 +16,7 @@ classdef Int8Array < arrow.array.NumericArray % arrow.array.Int8Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = int8(0); end diff --git a/matlab/src/matlab/+arrow/+array/NumericArray.m b/matlab/src/matlab/+arrow/+array/NumericArray.m index 8f465ce425e23..088ccfd6aa53f 100644 --- a/matlab/src/matlab/+arrow/+array/NumericArray.m +++ b/matlab/src/matlab/+arrow/+array/NumericArray.m @@ -16,7 +16,7 @@ classdef NumericArray < arrow.array.Array % arrow.array.NumericArray - properties(Abstract, Access=protected) + properties(Abstract, Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue; end diff --git a/matlab/src/matlab/+arrow/+array/StringArray.m b/matlab/src/matlab/+arrow/+array/StringArray.m index 18fdec9ac70c3..e016aeb704a4d 100644 --- a/matlab/src/matlab/+arrow/+array/StringArray.m +++ b/matlab/src/matlab/+arrow/+array/StringArray.m @@ -16,8 +16,8 @@ classdef StringArray < arrow.array.Array % arrow.array.StringArray - properties (Hidden, SetAccess=private) - NullSubstitionValue = string(missing); + properties (Hidden, GetAccess=public, SetAccess=private) + NullSubstitutionValue = string(missing); end methods @@ -35,7 +35,7 @@ function matlabArray = toMATLAB(obj) matlabArray = obj.Proxy.toMATLAB(); - matlabArray(~obj.Valid) = obj.NullSubstitionValue; + matlabArray(~obj.Valid) = obj.NullSubstitutionValue; end end diff --git a/matlab/src/matlab/+arrow/+array/StructArray.m b/matlab/src/matlab/+arrow/+array/StructArray.m new file mode 100644 index 0000000000000..589e39fecd015 --- /dev/null +++ b/matlab/src/matlab/+arrow/+array/StructArray.m @@ -0,0 +1,146 @@ +% arrow.array.StructArray + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. 
You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef StructArray < arrow.array.Array + + properties (Dependent, GetAccess=public, SetAccess=private) + NumFields + FieldNames + end + + properties (Hidden, Dependent, GetAccess=public, SetAccess=private) + NullSubstitutionValue + end + + methods + function obj = StructArray(proxy) + arguments + proxy(1, 1) libmexclass.proxy.Proxy {validate(proxy, "arrow.array.proxy.StructArray")} + end + import arrow.internal.proxy.validate + obj@arrow.array.Array(proxy); + end + + function numFields = get.NumFields(obj) + numFields = obj.Proxy.getNumFields(); + end + + function fieldNames = get.FieldNames(obj) + fieldNames = obj.Proxy.getFieldNames(); + end + + function F = field(obj, idx) + import arrow.internal.validate.* + + idx = index.numericOrString(idx, "int32", AllowNonScalar=false); + + if isnumeric(idx) + args = struct(Index=idx); + fieldStruct = obj.Proxy.getFieldByIndex(args); + else + args = struct(Name=idx); + fieldStruct = obj.Proxy.getFieldByName(args); + end + + traits = arrow.type.traits.traits(arrow.type.ID(fieldStruct.TypeID)); + proxy = libmexclass.proxy.Proxy(Name=traits.ArrayProxyClassName, ID=fieldStruct.ProxyID); + F = traits.ArrayConstructor(proxy); + end + + function T = toMATLAB(obj) + T = table(obj); + end + + function T = table(obj) + import arrow.tabular.internal.* + + numFields = obj.NumFields; + matlabArrays = cell(1, numFields); + + invalid = ~obj.Valid; + numInvalid = nnz(invalid); + + for ii = 1:numFields + arrowArray = obj.field(ii); + matlabArray = toMATLAB(arrowArray); + if numInvalid ~= 0 + % MATLAB tables do not support null values themselves. + % So, to encode the StructArray's null values, we + % iterate over each variable in the resulting MATLAB + % table, and for each variable, we set the value of all + % null elements to the "NullSubstitutionValue" that + % corresponds to the variable's type (e.g. NaN for + % double, NaT for datetime, etc.). + matlabArray(invalid, :) = repmat(arrowArray.NullSubstitutionValue, [numInvalid 1]); + end + matlabArrays{ii} = matlabArray; + end + + fieldNames = [obj.Type.Fields.Name]; + validVariableNames = makeValidVariableNames(fieldNames); + validDimensionNames = makeValidDimensionNames(validVariableNames); + + T = table(matlabArrays{:}, ... + VariableNames=validVariableNames, ... + DimensionNames=validDimensionNames); + end + + function nullSubVal = get.NullSubstitutionValue(obj) + % Return a cell array containing each field's type-specific + % "null" value.
For example, NaN is the type-specific null + % value for Float32Arrays and Float64Arrays + numFields = obj.NumFields; + nullSubVal = cell(1, numFields); + for ii = 1:obj.NumFields + nullSubVal{ii} = obj.field(ii).NullSubstitutionValue; + end + end + end + + methods (Static) + function array = fromArrays(arrowArrays, opts) + arguments(Repeating) + arrowArrays(1, 1) arrow.array.Array + end + arguments + opts.FieldNames(1, :) string {mustBeNonmissing} = compose("Field%d", 1:numel(arrowArrays)) + opts.Valid + end + + import arrow.tabular.internal.validateArrayLengths + import arrow.tabular.internal.validateColumnNames + import arrow.array.internal.getArrayProxyIDs + import arrow.internal.validate.parseValid + + if numel(arrowArrays) == 0 + error("arrow:struct:ZeroFields", ... + "Must supply at least one field array."); + end + + validateArrayLengths(arrowArrays); + validateColumnNames(opts.FieldNames, numel(arrowArrays)); + validElements = parseValid(opts, arrowArrays{1}.Length); + + arrayProxyIDs = getArrayProxyIDs(arrowArrays); + args = struct(ArrayProxyIDs=arrayProxyIDs, ... + FieldNames=opts.FieldNames, Valid=validElements); + proxyName = "arrow.array.proxy.StructArray"; + proxy = arrow.internal.proxy.create(proxyName, args); + array = arrow.array.StructArray(proxy); + end + end +end \ No newline at end of file diff --git a/matlab/src/matlab/+arrow/+array/Time32Array.m b/matlab/src/matlab/+arrow/+array/Time32Array.m index 85babd26a721a..ae40a3a0b740c 100644 --- a/matlab/src/matlab/+arrow/+array/Time32Array.m +++ b/matlab/src/matlab/+arrow/+array/Time32Array.m @@ -17,7 +17,7 @@ classdef Time32Array < arrow.array.Array - properties(Access=private) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = seconds(NaN); end diff --git a/matlab/src/matlab/+arrow/+array/Time64Array.m b/matlab/src/matlab/+arrow/+array/Time64Array.m index f85eeb1f8f0c9..cd4b948324272 100644 --- a/matlab/src/matlab/+arrow/+array/Time64Array.m +++ b/matlab/src/matlab/+arrow/+array/Time64Array.m @@ -17,7 +17,7 @@ classdef Time64Array < arrow.array.Array - properties(Access=private) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = seconds(NaN); end diff --git a/matlab/src/matlab/+arrow/+array/TimestampArray.m b/matlab/src/matlab/+arrow/+array/TimestampArray.m index 80198f965fe92..9289d0a099f7c 100644 --- a/matlab/src/matlab/+arrow/+array/TimestampArray.m +++ b/matlab/src/matlab/+arrow/+array/TimestampArray.m @@ -16,7 +16,7 @@ classdef TimestampArray < arrow.array.Array % arrow.array.TimestampArray - properties(Access=private) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = NaT; end diff --git a/matlab/src/matlab/+arrow/+array/UInt16Array.m b/matlab/src/matlab/+arrow/+array/UInt16Array.m index 9d3f33c279175..d5487ee130d93 100644 --- a/matlab/src/matlab/+arrow/+array/UInt16Array.m +++ b/matlab/src/matlab/+arrow/+array/UInt16Array.m @@ -16,7 +16,7 @@ classdef UInt16Array < arrow.array.NumericArray % arrow.array.UInt16Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = uint16(0) end diff --git a/matlab/src/matlab/+arrow/+array/UInt32Array.m b/matlab/src/matlab/+arrow/+array/UInt32Array.m index 5235d4fb15576..43c1caac3b791 100644 --- a/matlab/src/matlab/+arrow/+array/UInt32Array.m +++ b/matlab/src/matlab/+arrow/+array/UInt32Array.m @@ -16,7 +16,7 @@ classdef UInt32Array < arrow.array.NumericArray % arrow.array.UInt32Array - properties (Access=protected) + 
properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = uint32(0) end diff --git a/matlab/src/matlab/+arrow/+array/UInt64Array.m b/matlab/src/matlab/+arrow/+array/UInt64Array.m index 2d69bd031ac31..047e7102dd5c5 100644 --- a/matlab/src/matlab/+arrow/+array/UInt64Array.m +++ b/matlab/src/matlab/+arrow/+array/UInt64Array.m @@ -16,7 +16,7 @@ classdef UInt64Array < arrow.array.NumericArray % arrow.array.UInt64Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = uint64(0) end diff --git a/matlab/src/matlab/+arrow/+array/UInt8Array.m b/matlab/src/matlab/+arrow/+array/UInt8Array.m index 3d007376bc89a..901a003161220 100644 --- a/matlab/src/matlab/+arrow/+array/UInt8Array.m +++ b/matlab/src/matlab/+arrow/+array/UInt8Array.m @@ -16,7 +16,7 @@ classdef UInt8Array < arrow.array.NumericArray % arrow.array.UInt8Array - properties (Access=protected) + properties (Hidden, GetAccess=public, SetAccess=private) NullSubstitutionValue = uint8(0) end diff --git a/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m b/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m index c0bedaf2faf39..d3a751ca46731 100644 --- a/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m +++ b/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m @@ -23,6 +23,10 @@ opts.NumRows(1, 1) {mustBeFinite, mustBeNonnegative} = 3; end + % Seed the random number generator to ensure + % reproducible results in tests. + rng(1); + import arrow.type.ID import arrow.array.* @@ -59,6 +63,13 @@ matlabData{ii} = randomDatetimes(opts.NumRows); cmd = compose("%s.fromMATLAB(matlabData{ii})", name); arrowArrays{ii} = eval(cmd); + elseif name == "arrow.array.StructArray" + dates = randomDatetimes(opts.NumRows); + strings = randomStrings(opts.NumRows); + timestampArray = arrow.array(dates); + stringArray = arrow.array(strings); + arrowArrays{ii} = StructArray.fromArrays(timestampArray, stringArray); + matlabData{ii} = table(dates, strings, VariableNames=["Field1", "Field2"]); else error("arrow:test:SupportedArrayCase", ... "Missing if-branch for array class " + name); diff --git a/matlab/src/matlab/+arrow/+internal/+validate/parseValid.m b/matlab/src/matlab/+arrow/+internal/+validate/parseValid.m new file mode 100644 index 0000000000000..3281e24ec1963 --- /dev/null +++ b/matlab/src/matlab/+arrow/+internal/+validate/parseValid.m @@ -0,0 +1,46 @@ +%PARSEVALID Utility function for parsing the Valid name-value pair. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +function validElements = parseValid(opts, numElements) + if ~isfield(opts, "Valid") + % If Valid is not a field in opts, return an empty logical array. 
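+ % (An empty logical array is the convention used throughout this + % package to signal that all elements are valid.)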
+ validElements = logical.empty(0, 1); + return; + end + + valid = opts.Valid; + if islogical(valid) + validElements = reshape(valid, [], 1); + if ~isscalar(validElements) + % Verify the logical vector has the correct number of elements + validateattributes(validElements, "logical", {'numel', numElements}); + elseif validElements == false + validElements = false(numElements, 1); + else % validElements == true + % Return an empty logical to represent all elements are valid. + validElements = logical.empty(0, 1); + end + else + % valid is a list of indices. Verify the indices are numeric, + % integers, and within the range [1, numElements] + validateattributes(valid, "numeric", {'integer', '>', 0, '<=', numElements}); + % Create a logical vector that contains true values at the indices + % specified by opts.Valid. + validElements = false([numElements 1]); + validElements(valid) = true; + end +end \ No newline at end of file diff --git a/matlab/src/matlab/+arrow/+internal/+validate/parseValidElements.m b/matlab/src/matlab/+arrow/+internal/+validate/parseValidElements.m index 4081f4092740b..8a43dbb4d78e1 100644 --- a/matlab/src/matlab/+arrow/+internal/+validate/parseValidElements.m +++ b/matlab/src/matlab/+arrow/+internal/+validate/parseValidElements.m @@ -21,7 +21,7 @@ % precedence over InferNulls. if isfield(opts, "Valid") - validElements = parseValid(numel(data), opts.Valid); + validElements = arrow.internal.validate.parseValid(opts, numel(data)); else validElements = parseInferNulls(data, opts.InferNulls); end @@ -33,29 +33,6 @@ end end -function validElements = parseValid(numElements, valid) - if islogical(valid) - validElements = reshape(valid, [], 1); - if ~isscalar(validElements) - % Verify the logical vector has the correct number of elements - validateattributes(validElements, "logical", {'numel', numElements}); - elseif validElements == false - validElements = false(numElements, 1); - else % validElements == true - % Return an empty logical to represent all elements are valid. - validElements = logical.empty(0, 1); - end - else - % valid is a list of indices. Verify the indices are numeric, - % integers, and within the range 1 < indices < numElements. - validateattributes(valid, "numeric", {'integer', '>', 0, '<=', numElements}); - % Create a logical vector that contains true values at the indices - % specified by opts.Valid. - validElements = false([numElements 1]); - validElements(valid) = true; - end -end - function validElements = parseInferNulls(data, inferNulls) if inferNulls && ~(isinteger(data) || islogical(data)) % Only call ismissing on data types that have a "missing" value, diff --git a/matlab/src/matlab/+arrow/+io/+csv/TableReader.m b/matlab/src/matlab/+arrow/+io/+csv/TableReader.m new file mode 100644 index 0000000000000..1e0308bb8d4fe --- /dev/null +++ b/matlab/src/matlab/+arrow/+io/+csv/TableReader.m @@ -0,0 +1,51 @@ +%TABLEREADER Reads tabular data from a CSV file into an arrow.tabular.Table. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. 
You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef TableReader + + properties (GetAccess=public, SetAccess=private, Hidden) + Proxy + end + + properties (Dependent, SetAccess=private, GetAccess=public) + Filename + end + + methods + + function obj = TableReader(filename) + arguments + filename (1, 1) string {mustBeNonmissing, mustBeNonzeroLengthText} + end + + args = struct(Filename=filename); + obj.Proxy = arrow.internal.proxy.create("arrow.io.csv.proxy.TableReader", args); + end + + function table = read(obj) + tableProxyID = obj.Proxy.read(); + proxy = libmexclass.proxy.Proxy(Name="arrow.tabular.proxy.Table", ID=tableProxyID); + table = arrow.tabular.Table(proxy); + end + + function filename = get.Filename(obj) + filename = obj.Proxy.getFilename(); + end + + end + +end \ No newline at end of file diff --git a/matlab/src/matlab/+arrow/+io/+csv/TableWriter.m b/matlab/src/matlab/+arrow/+io/+csv/TableWriter.m new file mode 100644 index 0000000000000..eb1aafe08f545 --- /dev/null +++ b/matlab/src/matlab/+arrow/+io/+csv/TableWriter.m @@ -0,0 +1,51 @@ +%TABLEWRITER Writes tabular data in an arrow.tabular.Table to a CSV file. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. 
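+% +% Illustrative round-trip sketch (assumes an existing arrow.tabular.Table +% named T and a writable path "data.csv"): +% +% writer = arrow.io.csv.TableWriter("data.csv"); +% writer.write(T); +% reader = arrow.io.csv.TableReader("data.csv"); +% T2 = reader.read();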
+classdef TableWriter < matlab.mixin.Scalar + + properties(Hidden, SetAccess=private, GetAccess=public) + Proxy + end + + properties(Dependent, SetAccess=private, GetAccess=public) + Filename + end + + methods + function obj = TableWriter(filename) + arguments + filename (1, 1) string {mustBeNonmissing, mustBeNonzeroLengthText} + end + + args = struct(Filename=filename); + proxyName = "arrow.io.csv.proxy.TableWriter"; + obj.Proxy = arrow.internal.proxy.create(proxyName, args); + end + + function write(obj, table) + arguments + obj (1, 1) arrow.io.csv.TableWriter + table (1, 1) arrow.tabular.Table + end + args = struct(TableProxyID=table.Proxy.ID); + obj.Proxy.write(args); + end + + function filename = get.Filename(obj) + filename = obj.Proxy.getFilename(); + end + end +end diff --git a/matlab/src/matlab/+arrow/+type/+traits/StructTraits.m b/matlab/src/matlab/+arrow/+type/+traits/StructTraits.m index a8ed98f8ae468..0f8b7b3a2a663 100644 --- a/matlab/src/matlab/+arrow/+type/+traits/StructTraits.m +++ b/matlab/src/matlab/+arrow/+type/+traits/StructTraits.m @@ -16,21 +16,18 @@ classdef StructTraits < arrow.type.traits.TypeTraits properties (Constant) - % TODO: When arrow.array.StructArray is implemented, set these - % properties appropriately - ArrayConstructor = missing - ArrayClassName = missing - ArrayProxyClassName = missing + ArrayConstructor = @arrow.array.StructArray + ArrayClassName = "arrow.array.StructArray" + ArrayProxyClassName = "arrow.array.proxy.StructArray" + + % TODO: Implement fromMATLAB ArrayStaticConstructor = missing TypeConstructor = @arrow.type.StructType TypeClassName = "arrow.type.StructType" TypeProxyClassName = "arrow.type.proxy.StructType" - - % TODO: When arrow.array.StructArray is implemented, set these - % properties appropriately - MatlabConstructor = missing - MatlabClassName = missing + MatlabConstructor = @table + MatlabClassName = "table" end end \ No newline at end of file diff --git a/matlab/src/matlab/+arrow/+type/StructType.m b/matlab/src/matlab/+arrow/+type/StructType.m index 6c1318f6376f3..331ac75a2ee16 100644 --- a/matlab/src/matlab/+arrow/+type/StructType.m +++ b/matlab/src/matlab/+arrow/+type/StructType.m @@ -33,14 +33,28 @@ end methods (Hidden) - % TODO: Consider using a mixin approach to add this behavior. For - % example, ChunkedArray's toMATLAB method could check if its - % Type inherits from a mixin called "Preallocateable" (or something - % more descriptive). If so, we can call preallocateMATLABArray - % in the toMATLAB method. - function preallocateMATLABArray(~) - error("arrow:type:UnsupportedFunction", ... - "preallocateMATLABArray is not supported for StructType"); - end + function data = preallocateMATLABArray(obj, numElements) + import arrow.tabular.internal.* + + fields = obj.Fields; + + % Construct the VariableNames and VariableDimensionNames + fieldNames = [fields.Name]; + validVariableNames = makeValidVariableNames(fieldNames); + validDimensionNames = makeValidDimensionNames(validVariableNames); + + % Recursively call preallocateMATLABArray to handle + % preallocation of nested types + variableData = cell(1, numel(fields)); + for ii = 1:numel(fields) + type = fields(ii).Type; + variableData{ii} = preallocateMATLABArray(type, numElements); + end + + % Return a table with the appropriate schema and dimensions + data = table(variableData{:}, ... + VariableNames=validVariableNames, ... 
+ DimensionNames=validDimensionNames); + end end end \ No newline at end of file diff --git a/matlab/test/arrow/array/tStructArray.m b/matlab/test/arrow/array/tStructArray.m new file mode 100644 index 0000000000000..639df65befbf5 --- /dev/null +++ b/matlab/test/arrow/array/tStructArray.m @@ -0,0 +1,277 @@ +%TSTRUCTARRAY Unit tests for arrow.array.StructArray + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef tStructArray < matlab.unittest.TestCase + + properties + Float64Array = arrow.array([1 NaN 3 4 5]); + StringArray = arrow.array(["A" "B" "C" "D" missing]); + end + + methods (Test) + function Basic(tc) + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + tc.verifyInstanceOf(array, "arrow.array.StructArray"); + end + + function FieldNames(tc) + % Verify the FieldNames property is set to the expected value. + import arrow.array.StructArray + + % Default field names used + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + tc.verifyEqual(array.FieldNames, ["Field1", "Field2"]); + + % Field names provided + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["A", "B"]); + tc.verifyEqual(array.FieldNames, ["A", "B"]); + + % Duplicate field names provided + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["C", "C"]); + tc.verifyEqual(array.FieldNames, ["C", "C"]); + end + + function FieldNamesError(tc) + % Verify the FieldNames nv-pair errors when expected. + import arrow.array.StructArray + + % Wrong type provided + fcn = @() StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames={table table}); + tc.verifyError(fcn, "MATLAB:validation:UnableToConvert"); + + % Wrong number of field names provided + fcn = @() StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames="A"); + tc.verifyError(fcn, "arrow:tabular:WrongNumberColumnNames"); + + % Missing string provided + fcn = @() StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["A" missing]); + tc.verifyError(fcn, "MATLAB:validators:mustBeNonmissing"); + end + + function FieldNamesNoSetter(tc) + % Verify the FieldNames property is read-only. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + fcn = @() setfield(array, "FieldNames", ["A", "B"]); + tc.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + function NumFields(tc) + % Verify the NumFields property is set to the expected value. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + tc.verifyEqual(array.NumFields, int32(2)); + end + + function NumFieldsNoSetter(tc) + % Verify the NumFields property is read-only. 
+ import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + fcn = @() setfield(array, "NumFields", 10); + tc.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + function Valid(tc) + % Verify the Valid property is set to the expected value. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + expectedValid = true([5 1]); + tc.verifyEqual(array.Valid, expectedValid); + + % Supply the Valid nv-pair + valid = [true true false true false]; + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, Valid=valid); + tc.verifyEqual(array.Valid, valid'); + end + + function ValidNVPairError(tc) + % Verify the Valid nv-pair errors when expected. + import arrow.array.StructArray + + % Provided an invalid index + fcn = @() StructArray.fromArrays(tc.Float64Array, tc.StringArray, Valid=10); + tc.verifyError(fcn, "MATLAB:notLessEqual"); + + % Provided a logical vector with more elements than the array + % length + fcn = @() StructArray.fromArrays(tc.Float64Array, tc.StringArray, Valid=false([7 1])); + tc.verifyError(fcn, "MATLAB:incorrectNumel"); + end + + function ValidNoSetter(tc) + % Verify the Valid property is read-only. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + fcn = @() setfield(array, "Valid", false); + tc.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + function Length(tc) + % Verify the Length property is set to the expected value. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + tc.verifyEqual(array.Length, int64(5)); + end + + function LengthNoSetter(tc) + % Verify the Length property is read-only. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + fcn = @() setfield(array, "Length", 1); + tc.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + function Type(tc) + % Verify the Type property is set to the expected value. + import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + field1 = arrow.field("X", arrow.float64()); + field2 = arrow.field("Y", arrow.string()); + expectedType = arrow.struct(field1, field2); + tc.verifyEqual(array.Type, expectedType); + end + + function TypeNoSetter(tc) + % Verify the Type property is read-only. 
+ import arrow.array.StructArray + + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + fcn = @() setfield(array, "Type", tc.Float64Array.Type); + tc.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + function FieldByIndex(tc) + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + + % Extract 1st field + field1 = array.field(1); + tc.verifyEqual(field1, tc.Float64Array); + + % Extract 2nd field + field2 = array.field(2); + tc.verifyEqual(field2, tc.StringArray); + end + + function FieldByIndexError(tc) + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + + % Supply a nonscalar vector + fcn = @() array.field([1 2]); + tc.verifyError(fcn, "arrow:badsubscript:NonScalar"); + + % Supply a noninteger + fcn = @() array.field(1.1); + tc.verifyError(fcn, "arrow:badsubscript:NonInteger"); + end + + function FieldByName(tc) + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + + % Extract 1st field + field1 = array.field("Field1"); + tc.verifyEqual(field1, tc.Float64Array); + + % Extract 2nd field + field2 = array.field("Field2"); + tc.verifyEqual(field2, tc.StringArray); + end + + function FieldByNameError(tc) + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray); + + % Supply a nonscalar string array + fcn = @() array.field(["Field1" "Field2"]); + tc.verifyError(fcn, "arrow:badsubscript:NonScalar"); + + % Supply a nonexistent field name + fcn = @() array.field("B"); + tc.verifyError(fcn, "arrow:tabular:schema:AmbiguousFieldName"); + end + + function toMATLAB(tc) + % Verify toMATLAB returns the expected MATLAB table + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + expectedTable = table(toMATLAB(tc.Float64Array), toMATLAB(tc.StringArray), VariableNames=["X", "Y"]); + actualTable = toMATLAB(array); + tc.verifyEqual(actualTable, expectedTable); + + % Verify table elements that correspond to "null" values + % in the StructArray are set to the type-specific null values. + valid = [1 2 5]; + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"], Valid=valid); + float64NullValue = tc.Float64Array.NullSubstitutionValue; + stringNullValue = tc.StringArray.NullSubstitutionValue; + expectedTable([3 4], :) = repmat({float64NullValue stringNullValue}, [2 1]); + actualTable = toMATLAB(array); + tc.verifyEqual(actualTable, expectedTable); + end + + function table(tc) + % Verify the table method returns the expected MATLAB table + import arrow.array.StructArray + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + expectedTable = table(toMATLAB(tc.Float64Array), toMATLAB(tc.StringArray), VariableNames=["X", "Y"]); + actualTable = table(array); + tc.verifyEqual(actualTable, expectedTable); + + % Verify table elements that correspond to "null" values + % in the StructArray are set to the type-specific null values.
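+            % (For reference: Float64Array substitutes NaN and StringArray
+            % substitutes the missing string for null elements; see the
+            % NullSubstitutionValue property used below.)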
+ valid = [1 2 5]; + array = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"], Valid=valid); + float64NullValue = tc.Float64Array.NullSubstitutionValue; + stringNullValue = tc.StringArray.NullSubstitutionValue; + expectedTable([3 4], :) = repmat({float64NullValue stringNullValue}, [2 1]); + actualTable = table(array); + tc.verifyEqual(actualTable, expectedTable); + end + + function IsEqualTrue(tc) + % Verify isequal returns true when expected. + import arrow.array.StructArray + array1 = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + array2 = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + tc.verifyTrue(isequal(array1, array2)); + end + + function IsEqualFalse(tc) + % Verify isequal returns false when expected. + import arrow.array.StructArray + array1 = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["X", "Y"]); + array2 = StructArray.fromArrays(tc.StringArray, tc.Float64Array, FieldNames=["X", "Y"]); + array3 = StructArray.fromArrays(tc.Float64Array, tc.StringArray, FieldNames=["A", "B"]); + % StructArrays have the same FieldNames but the Fields have different types. + tc.verifyFalse(isequal(array1, array2)); + % Fields of the StructArrays have the same types but the StructArrays have different FieldNames. + tc.verifyFalse(isequal(array1, array3)); + end + + end +end \ No newline at end of file diff --git a/matlab/test/arrow/io/csv/CSVTest.m b/matlab/test/arrow/io/csv/CSVTest.m new file mode 100644 index 0000000000000..49f77eaaa7c63 --- /dev/null +++ b/matlab/test/arrow/io/csv/CSVTest.m @@ -0,0 +1,102 @@ +%CSVTEST Superclass for CSV-related tests. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements.  See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License.  You may obtain a copy of the License at +% +%   http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied.  See the License for the specific language governing +% permissions and limitations under the License. +classdef CSVTest < matlab.unittest.TestCase + + properties + Filename + end + + methods (TestClassSetup) + + function initializeProperties(~) + % Seed the random number generator.
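+            % A fixed seed keeps the randomized tables built by
+            % makeArrowTable reproducible across test runs.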
+ rng(1); + end + + end + + methods (TestMethodSetup) + + function setupTestFilename(testCase) + import matlab.unittest.fixtures.TemporaryFolderFixture + fixture = testCase.applyFixture(TemporaryFolderFixture); + testCase.Filename = fullfile(fixture.Folder, "filename.csv"); + end + + end + + methods + + function verifyRoundTrip(testCase, arrowTable) + import arrow.io.csv.* + + writer = TableWriter(testCase.Filename); + reader = TableReader(testCase.Filename); + + writer.write(arrowTable); + arrowTableRead = reader.read(); + + testCase.verifyEqual(arrowTableRead, arrowTable); + end + + function arrowTable = makeArrowTable(testCase, opts) + arguments + testCase + opts.Type + opts.ColumnNames + opts.NumRows + opts.WithNulls (1, 1) logical = false + end + + if opts.Type == "numeric" + matlabTable = array2table(rand(opts.NumRows, numel(opts.ColumnNames))); + elseif opts.Type == "string" + matlabTable = array2table("A" + rand(opts.NumRows, numel(opts.ColumnNames)) + "B"); + end + + if opts.WithNulls + matlabTable = testCase.setNullValues(matlabTable, NullPercentage=0.2); + end + + arrays = cell(1, width(matlabTable)); + for ii = 1:width(matlabTable) + arrays{ii} = arrow.array(matlabTable.(ii)); + end + arrowTable = arrow.tabular.Table.fromArrays(arrays{:}, ColumnNames=opts.ColumnNames); + end + + function tWithNulls = setNullValues(testCase, t, opts) + arguments + testCase %#ok + t table + opts.NullPercentage (1, 1) double {mustBeGreaterThanOrEqual(opts.NullPercentage, 0)} = 0.5 + end + + tWithNulls = t; + for ii = 1:width(t) + temp = tWithNulls.(ii); + numValues = numel(temp); + numNulls = uint64(opts.NullPercentage * numValues); + nullIndices = randperm(numValues, numNulls); + temp(nullIndices) = missing; + tWithNulls.(ii) = temp; + end + end + + end + +end diff --git a/matlab/test/arrow/io/csv/tError.m b/matlab/test/arrow/io/csv/tError.m new file mode 100644 index 0000000000000..24c420e7ba2dd --- /dev/null +++ b/matlab/test/arrow/io/csv/tError.m @@ -0,0 +1,73 @@ +%TERROR Error tests for CSV. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. 
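+
+% The cases below share one pattern (a sketch; each case asserts its
+% own error identifier):
+%
+%   fcn = @() TableReader(table);
+%   testCase.verifyError(fcn, "MATLAB:validation:UnableToConvert");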
+classdef tError < CSVTest + + methods(Test) + + function EmptyFile(testCase) + import arrow.io.csv.* + + arrowTableWrite = arrow.table(); + + writer = TableWriter(testCase.Filename); + reader = TableReader(testCase.Filename); + + writer.write(arrowTableWrite); + fcn = @() reader.read(); + testCase.verifyError(fcn, "arrow:io:csv:FailedToReadTable"); + end + + function InvalidWriterFilenameType(testCase) + import arrow.io.csv.* + fcn = @() TableWriter(table); + testCase.verifyError(fcn, "MATLAB:validation:UnableToConvert"); + fcn = @() TableWriter(["a", "b"]); + testCase.verifyError(fcn, "MATLAB:validation:IncompatibleSize"); + end + + function InvalidReaderFilenameType(testCase) + import arrow.io.csv.* + fcn = @() TableReader(table); + testCase.verifyError(fcn, "MATLAB:validation:UnableToConvert"); + fcn = @() TableReader(["a", "b"]); + testCase.verifyError(fcn, "MATLAB:validation:IncompatibleSize"); + end + + function InvalidWriterWriteType(testCase) + import arrow.io.csv.* + writer = TableWriter(testCase.Filename); + fcn = @() writer.write("text"); + testCase.verifyError(fcn, "MATLAB:validation:UnableToConvert"); + end + + function WriterFilenameNoSetter(testCase) + import arrow.io.csv.* + writer = TableWriter(testCase.Filename); + fcn = @() setfield(writer, "Filename", "filename.csv"); + testCase.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + function ReaderFilenameNoSetter(testCase) + import arrow.io.csv.* + reader = TableReader(testCase.Filename); + fcn = @() setfield(reader, "Filename", "filename.csv"); + testCase.verifyError(fcn, "MATLAB:class:SetProhibited"); + end + + end + +end \ No newline at end of file diff --git a/matlab/test/arrow/io/csv/tRoundTrip.m b/matlab/test/arrow/io/csv/tRoundTrip.m new file mode 100644 index 0000000000000..cb35822580106 --- /dev/null +++ b/matlab/test/arrow/io/csv/tRoundTrip.m @@ -0,0 +1,62 @@ +%TROUNDTRIP Round trip tests for CSV. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. +classdef tRoundTrip < CSVTest + + properties (TestParameter) + NumRows = { ... + 2, ... + 10, ... + 100 ... + } + WithNulls = { ... + true, ... + false ... + } + ColumnNames = {... + ["A", "B", "C"], ... + ["😀", "🌲", "🥭", " ", "ABC"], ... + [" ", " ", " "] + } + end + + methods(Test) + + function Numeric(testCase, NumRows, WithNulls, ColumnNames) + arrowTable = testCase.makeArrowTable(... + Type="numeric", ... + NumRows=NumRows, ... + WithNulls=WithNulls, ... + ColumnNames=ColumnNames ... + ); + + testCase.verifyRoundTrip(arrowTable); + end + + function String(testCase, NumRows, ColumnNames) + arrowTable = testCase.makeArrowTable(... + Type="string", ... + NumRows=NumRows, ... + WithNulls=false, ... + ColumnNames=ColumnNames ... 
+ ); + + testCase.verifyRoundTrip(arrowTable); + end + + end + +end \ No newline at end of file diff --git a/matlab/test/arrow/type/traits/tStructTraits.m b/matlab/test/arrow/type/traits/tStructTraits.m index 6a97b1e1852d6..07833aca162b5 100644 --- a/matlab/test/arrow/type/traits/tStructTraits.m +++ b/matlab/test/arrow/type/traits/tStructTraits.m @@ -17,15 +17,15 @@ properties TraitsConstructor = @arrow.type.traits.StructTraits - ArrayConstructor = missing - ArrayClassName = missing - ArrayProxyClassName = missing + ArrayConstructor = @arrow.array.StructArray + ArrayClassName = "arrow.array.StructArray" + ArrayProxyClassName = "arrow.array.proxy.StructArray" ArrayStaticConstructor = missing TypeConstructor = @arrow.type.StructType TypeClassName = "arrow.type.StructType" TypeProxyClassName = "arrow.type.proxy.StructType" - MatlabConstructor = missing - MatlabClassName = missing + MatlabConstructor = @table + MatlabClassName = "table" end end \ No newline at end of file diff --git a/matlab/tools/cmake/BuildMatlabArrowInterface.cmake b/matlab/tools/cmake/BuildMatlabArrowInterface.cmake index 40c6b5a51d4fe..149a688b27e15 100644 --- a/matlab/tools/cmake/BuildMatlabArrowInterface.cmake +++ b/matlab/tools/cmake/BuildMatlabArrowInterface.cmake @@ -47,6 +47,7 @@ set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_SOURCES "${CMAKE_SOURCE_DIR}/src/cpp/a "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/timestamp_array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/time32_array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/time64_array.cc" + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/struct_array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/chunked_array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/wrap.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc" @@ -70,10 +71,10 @@ set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_SOURCES "${CMAKE_SOURCE_DIR}/src/cpp/a "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/type/proxy/wrap.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/io/feather/proxy/writer.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/io/feather/proxy/reader.cc" + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/io/csv/proxy/table_writer.cc" + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/io/csv/proxy/table_reader.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/index/validate.cc") - - set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_FACTORY_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/proxy") set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_FACTORY_SOURCES "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/proxy/factory.cc") set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_LIBRARY_INCLUDE_DIRS ${MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_LIBRARY_ROOT_INCLUDE_DIR} diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 242ba8448f4a6..29f8d2da72f3a 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -168,37 +168,44 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${PYARROW_CXXFLAGS}") if(MSVC) # MSVC version of -Wno-return-type-c-linkage - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4190") + string(APPEND CMAKE_CXX_FLAGS " /wd4190") # Cython generates some bitshift expressions that MSVC does not like in # __Pyx_PyFloat_DivideObjC - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4293") + string(APPEND CMAKE_CXX_FLAGS " /wd4293") # Converting to/from C++ bool is pretty wonky in Cython. 
The C4800 warning # seems harmless, and probably not worth the effort of working around it - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4800") + string(APPEND CMAKE_CXX_FLAGS " /wd4800") # See https://github.com/cython/cython/issues/2731. Change introduced in # Cython 0.29.1 causes "unsafe use of type 'bool' in operation" - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4804") + string(APPEND CMAKE_CXX_FLAGS " /wd4804") + + # See https://github.com/cython/cython/issues/4445. + # + # Cython 3 emits "(void)__Pyx_PyObject_CallMethod0;" to suppress an + # "unused function" warning, but the code emits another "function + # call missing argument list" warning. + string(APPEND CMAKE_CXX_FLAGS " /wd4551") else() # Enable perf and other tools to work properly - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + string(APPEND CMAKE_CXX_FLAGS " -fno-omit-frame-pointer") # Suppress Cython warnings - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable -Wno-maybe-uninitialized") + string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-variable -Wno-maybe-uninitialized") if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # Cython warnings in clang - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-parentheses-equality") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-constant-logical-operand") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-declarations") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sometimes-uninitialized") + string(APPEND CMAKE_CXX_FLAGS " -Wno-parentheses-equality") + string(APPEND CMAKE_CXX_FLAGS " -Wno-constant-logical-operand") + string(APPEND CMAKE_CXX_FLAGS " -Wno-missing-declarations") + string(APPEND CMAKE_CXX_FLAGS " -Wno-sometimes-uninitialized") # We have public Cython APIs which return C++ types, which are in an extern # "C" block (no symbol mangling) and clang doesn't like this - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-type-c-linkage") + string(APPEND CMAKE_CXX_FLAGS " -Wno-return-type-c-linkage") endif() endif() diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 609307528d2ec..25f77d8160ea8 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2366,7 +2366,7 @@ cdef class Expression(_Weakrefable): 1, 2, 3 - ], skip_nulls=false})> + ], null_matching_behavior=MATCH})> """ def __init__(self): diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index d29fa125e2061..48ee676915311 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -1078,7 +1078,8 @@ cdef class FileSystemDataset(Dataset): @classmethod def from_paths(cls, paths, schema=None, format=None, filesystem=None, partitions=None, root_partition=None): - """A Dataset created from a list of paths on a particular filesystem. + """ + A Dataset created from a list of paths on a particular filesystem.
Parameters ---------- diff --git a/python/pyarrow/_dataset_parquet.pyx b/python/pyarrow/_dataset_parquet.pyx index 79bd270ce54d2..cf5c44c1c964a 100644 --- a/python/pyarrow/_dataset_parquet.pyx +++ b/python/pyarrow/_dataset_parquet.pyx @@ -595,6 +595,10 @@ cdef class ParquetFileWriteOptions(FileWriteOptions): ), column_encoding=self._properties["column_encoding"], data_page_version=self._properties["data_page_version"], + encryption_properties=self._properties["encryption_properties"], + write_batch_size=self._properties["write_batch_size"], + dictionary_pagesize_limit=self._properties["dictionary_pagesize_limit"], + write_page_index=self._properties["write_page_index"], ) def _set_arrow_properties(self): @@ -631,6 +635,10 @@ cdef class ParquetFileWriteOptions(FileWriteOptions): coerce_timestamps=None, allow_truncated_timestamps=False, use_compliant_nested_type=True, + encryption_properties=None, + write_batch_size=None, + dictionary_pagesize_limit=None, + write_page_index=False, ) self._set_properties() self._set_arrow_properties() diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 42b221ed72a1b..79aa24e4ce8e3 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -988,8 +988,10 @@ cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin): cdef shared_ptr[CMetadataRecordBatchReader] reader def __iter__(self): - while True: - yield self.read_chunk() + return self + + def __next__(self): + return self.read_chunk() @property def schema(self): @@ -1699,7 +1701,9 @@ cdef class FlightClient(_Weakrefable): def close(self): """Close the client and disconnect.""" - check_flight_status(self.client.get().Close()) + client = self.client.get() + if client != NULL: + check_flight_status(client.Close()) def __del__(self): # Not ideal, but close() wasn't originally present so diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 4efad2c4d1bc5..067cb5f91681b 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -27,9 +27,10 @@ from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * +# TODO GH-37235: Fix exception handling cdef CDeclaration _create_named_table_provider( dict named_args, const std_vector[c_string]& names, const CSchema& schema -): +) noexcept: cdef: c_string c_name shared_ptr[CTable] c_in_table diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 4bddd2d080f5f..c4cf5830c4128 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -118,16 +118,16 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: c_bool Equals(const CLocation& other) @staticmethod - CResult[CLocation] Parse(c_string& uri_string) + CResult[CLocation] Parse(const c_string& uri_string) @staticmethod - CResult[CLocation] ForGrpcTcp(c_string& host, int port) + CResult[CLocation] ForGrpcTcp(const c_string& host, int port) @staticmethod - CResult[CLocation] ForGrpcTls(c_string& host, int port) + CResult[CLocation] ForGrpcTls(const c_string& host, int port) @staticmethod - CResult[CLocation] ForGrpcUnix(c_string& path) + CResult[CLocation] ForGrpcUnix(const c_string& path) cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint": CFlightEndpoint() @@ -172,7 +172,9 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CResult[unique_ptr[CFlightInfo]] Next() cdef cppclass CSimpleFlightListing" arrow::flight::SimpleFlightListing": - 
CSimpleFlightListing(vector[CFlightInfo]&& info) + # This doesn't work with Cython >= 3 + # CSimpleFlightListing(vector[CFlightInfo]&& info) + CSimpleFlightListing(const vector[CFlightInfo]& info) cdef cppclass CFlightPayload" arrow::flight::FlightPayload": shared_ptr[CBuffer] descriptor @@ -310,7 +312,10 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CCallHeaders" arrow::flight::CallHeaders": cppclass const_iterator: pair[c_string, c_string] operator*() + # For Cython < 3 const_iterator operator++() + # For Cython >= 3 + const_iterator operator++(int) bint operator==(const_iterator) bint operator!=(const_iterator) const_iterator cbegin() diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index a8398597fe6cd..53e521fc11468 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -436,8 +436,10 @@ cdef class MessageReader(_Weakrefable): return result def __iter__(self): - while True: - yield self.read_next_message() + return self + + def __next__(self): + return self.read_next_message() def read_next_message(self): """ @@ -656,11 +658,10 @@ cdef class RecordBatchReader(_Weakrefable): # cdef block is in lib.pxd def __iter__(self): - while True: - try: - yield self.read_next_batch() - except StopIteration: - return + return self + + def __next__(self): + return self.read_next_batch() @property def schema(self): diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index e07949c675524..9a66dc81226d4 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -819,8 +819,8 @@ cdef class MapScalar(ListScalar): Iterate over this element's values. """ arr = self.values - if array is None: - raise StopIteration + if arr is None: + return for k, v in zip(arr.field(self.type.key_field.name), arr.field(self.type.item_field.name)): yield (k.as_py(), v.as_py()) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index b8a0c38089980..39c3c43daea37 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -1615,9 +1615,13 @@ def test_fragments_repr(tempdir, dataset): # partitioned parquet dataset fragment = list(dataset.get_fragments())[0] assert ( + # Ordering of partition items is non-deterministic repr(fragment) == "" + "partition=[key=xxx, group=1]>" or + repr(fragment) == + "" ) # single-file parquet dataset (no partition information in repr) @@ -5291,6 +5295,38 @@ def test_write_dataset_preserve_field_metadata(tempdir): assert dataset.to_table().schema.equals(schema_metadata, check_metadata=True) +def test_write_dataset_write_page_index(tempdir): + for write_statistics in [True, False]: + for write_page_index in [True, False]: + schema = pa.schema([ + pa.field("x", pa.int64()), + pa.field("y", pa.int64())]) + + arrays = [[1, 2, 3], [None, 5, None]] + table = pa.Table.from_arrays(arrays, schema=schema) + + file_format = ds.ParquetFileFormat() + base_dir = tempdir / f"write_page_index_{write_page_index}" + ds.write_dataset( + table, + base_dir, + format="parquet", + file_options=file_format.make_write_options( + write_statistics=write_statistics, + write_page_index=write_page_index, + ), + existing_data_behavior='overwrite_or_ignore', + ) + ds1 = ds.dataset(base_dir, format="parquet") + + for file in ds1.files: + # The page index should only be present in the metadata when requested + metadata = pq.read_metadata(file) + cc = metadata.row_group(0).column(0) + assert cc.has_offset_index is write_page_index + assert cc.has_column_index is (write_page_index and write_statistics) + 
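+
+# A standalone sketch of the option exercised above (ds and pq are this
+# module's pyarrow.dataset and pyarrow.parquet imports; the output path
+# is hypothetical):
+#
+#   opts = ds.ParquetFileFormat().make_write_options(write_page_index=True)
+#   ds.write_dataset(table, "/tmp/page_index_demo", format="parquet",
+#                    file_options=opts)
+#   # pq.read_metadata() on the written file then reports
+#   # has_offset_index/has_column_index on each column chunk.
+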
+ @pytest.mark.parametrize('dstype', [ "fs", "mem" ]) diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index 5f6c8c813f12a..8a1dcfb057f74 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -700,6 +700,10 @@ def test_map(pickle_module): for i, j in zip(s, v): assert i == j + # test iteration with missing values + for _ in pa.scalar(None, type=ty): + pass + assert s.as_py() == v assert s[1] == ( pa.scalar('b', type=pa.string()), diff --git a/python/pyproject.toml b/python/pyproject.toml index 7e61304585809..a1de6ac4f1c7e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,7 +19,7 @@ requires = [ "cython >= 0.29.31,<3", "oldest-supported-numpy>=0.14", - "setuptools_scm", + "setuptools_scm < 8.0.0", "setuptools >= 40.1.0", "wheel" ] diff --git a/python/requirements-build.txt b/python/requirements-build.txt index 6378d1b94e1bb..efd653ec470d5 100644 --- a/python/requirements-build.txt +++ b/python/requirements-build.txt @@ -1,4 +1,4 @@ cython>=0.29.31,<3 oldest-supported-numpy>=0.14 -setuptools_scm +setuptools_scm<8.0.0 setuptools>=38.6.0 diff --git a/python/requirements-wheel-build.txt b/python/requirements-wheel-build.txt index e4f5243fbc2fe..00504b0c731a1 100644 --- a/python/requirements-wheel-build.txt +++ b/python/requirements-wheel-build.txt @@ -1,5 +1,5 @@ cython>=0.29.31,<3 oldest-supported-numpy>=0.14 -setuptools_scm +setuptools_scm<8.0.0 setuptools>=58 wheel diff --git a/python/setup.py b/python/setup.py index abd9d03cfb17e..062aac307b1e4 100755 --- a/python/setup.py +++ b/python/setup.py @@ -492,7 +492,7 @@ def has_ext_modules(foo): 'pyarrow/_generated_version.py'), 'version_scheme': guess_next_dev_version }, - setup_requires=['setuptools_scm', 'cython >= 0.29.31,<3'] + setup_requires, + setup_requires=['setuptools_scm < 8.0.0', 'cython >= 0.29.31,<3'] + setup_requires, install_requires=install_requires, tests_require=['pytest', 'pandas', 'hypothesis'], python_requires='>=3.8', diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 8f44f8936bdd3..09183250ba3e0 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -183,6 +183,22 @@ configure_tzdb <- function() { # Just to be extra safe, let's wrap this in a try(); # we don't want a failed startup message to prevent the package from loading try({ + # On macOS only, check if we are running under emulation, and warn that this will not work + if (on_rosetta()) { + packageStartupMessage( + paste( + "Warning:", + " It appears that you are running R and Arrow in emulation (i.e. you're", + " running an Intel version of R on a non-Intel Mac). This configuration is", + " not supported by arrow; you should install a native (arm64) build of R", + " and use arrow with that.
See https://cran.r-project.org/bin/macosx/", + "", + sep = "\n" + ) + ) + } + + + features <- arrow_info()$capabilities # That has all of the #ifdef features, plus the compression libs and the # string libraries (but not the memory allocators, they're added elsewhere) @@ -225,6 +241,11 @@ on_macos_10_13_or_lower <- function() { package_version(unname(Sys.info()["release"])) < "18.0.0" } +on_rosetta <- function() { + identical(tolower(Sys.info()[["sysname"]]), "darwin") && + identical(system("sysctl -n sysctl.proc_translated", intern = TRUE), "1") +} + option_use_threads <- function() { !is_false(getOption("arrow.use_threads")) } diff --git a/r/R/install-arrow.R b/r/R/install-arrow.R index 8380fa2af989c..7017d4f39b876 100644 --- a/r/R/install-arrow.R +++ b/r/R/install-arrow.R @@ -61,7 +61,6 @@ install_arrow <- function(nightly = FALSE, verbose = Sys.getenv("ARROW_R_DEV", FALSE), repos = getOption("repos"), ...) { - sysname <- tolower(Sys.info()[["sysname"]]) conda <- isTRUE(grepl("conda", R.Version()$platform)) if (conda) { @@ -80,8 +79,7 @@ install_arrow <- function(nightly = FALSE, # On the M1, we can't use the usual autobrew, which pulls Intel dependencies apple_m1 <- grepl("arm-apple|aarch64.*darwin", R.Version()$platform) # On Rosetta, we have to build without JEMALLOC, so we also can't autobrew - rosetta <- identical(sysname, "darwin") && identical(system("sysctl -n sysctl.proc_translated", intern = TRUE), "1") - if (rosetta) { + if (on_rosetta()) { Sys.setenv(ARROW_JEMALLOC = "OFF") } - if (apple_m1 || rosetta) { + if (apple_m1 || on_rosetta()) { diff --git a/r/README.md b/r/README.md index d343d6979c0a7..3c1e3570ffdd4 100644 --- a/r/README.md +++ b/r/README.md @@ -73,6 +73,8 @@ additional steps should be required. There are some special cases to note: +- On macOS, the R you use with Arrow should match the architecture of the machine you are using. If you're using an ARM (aka M1, M2, etc.) processor, use R compiled for arm64. If you're using an Intel-based Mac, use R compiled for x86. Using R and Arrow compiled for Intel-based Macs on an ARM-based Mac will result in segfaults and crashes. + - On Linux the installation process can sometimes be more involved because CRAN does not host binaries for Linux. For more information please see the [installation guide](https://arrow.apache.org/docs/r/articles/install.html). diff --git a/r/inst/build_arrow_static.sh b/r/inst/build_arrow_static.sh index fe56b9fca9e59..52ac5b7d3245b 100755 --- a/r/inst/build_arrow_static.sh +++ b/r/inst/build_arrow_static.sh @@ -55,6 +55,13 @@ else ARROW_DEFAULT_PARAM="OFF" fi +# Disable mimalloc when building with the Intel LLVM compiler (icpx): the bundled mimalloc (2.0.x) does not support it +case "$CXX" in + *icpx*) + ARROW_MIMALLOC="OFF" + ;; +esac + mkdir -p "${BUILD_DIR}" pushd "${BUILD_DIR}" ${CMAKE} -DARROW_BOOST_USE_SHARED=OFF \