diff --git a/ci/scripts/java_build.sh b/ci/scripts/java_build.sh index 0fa1edab429c0..212ec6eb11476 100755 --- a/ci/scripts/java_build.sh +++ b/ci/scripts/java_build.sh @@ -72,9 +72,6 @@ if [ $ARROW_JAVA_SKIP_GIT_PLUGIN ]; then mvn="${mvn} -Dmaven.gitcommitid.skip=true" fi -# Use `2 * ncores` threads -mvn="${mvn} -T 2C" - # https://github.com/apache/arrow/issues/41429 # TODO: We want to out-of-source build. This is a workaround. We copy # all needed files to the build directory from the source directory @@ -98,10 +95,12 @@ if [ "${ARROW_JAVA_JNI}" = "ON" ]; then mvn="${mvn} -Darrow.cpp.build.dir=${java_jni_dist_dir} -Parrow-jni" fi -${mvn} clean install +# Use `2 * ncores` threads +${mvn} -T 2C clean install if [ "${BUILD_DOCS_JAVA}" == "ON" ]; then # HTTP pooling is turned of to avoid download issues https://issues.apache.org/jira/browse/ARROW-11633 + # GH-43378: Maven site plugins not compatible with multithreading mkdir -p ${build_dir}/docs/java/reference ${mvn} -Dcheckstyle.skip=true -Dhttp.keepAlive=false -Dmaven.wagon.http.pool=false clean install site rsync -a target/site/apidocs/ ${build_dir}/docs/java/reference diff --git a/cpp/cmake_modules/UseCython.cmake b/cpp/cmake_modules/UseCython.cmake index e15ac59490c6e..7d88daa4fade9 100644 --- a/cpp/cmake_modules/UseCython.cmake +++ b/cpp/cmake_modules/UseCython.cmake @@ -184,4 +184,9 @@ function(cython_add_module _name pyx_target_name generated_files) add_dependencies(${_name} ${pyx_target_name}) endfunction() +execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from Cython.Compiler.Version import version; print(version)" + OUTPUT_VARIABLE CYTHON_VERSION_OUTPUT + OUTPUT_STRIP_TRAILING_WHITESPACE) +set(CYTHON_VERSION "${CYTHON_VERSION_OUTPUT}") + include(CMakeParseArguments) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 9c66a58c54261..67d2c19f98a2d 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -907,6 +907,7 @@ endif() if(ARROW_JSON) 
arrow_add_object_library(ARROW_JSON extension/fixed_shape_tensor.cc + extension/opaque.cc json/options.cc json/chunked_builder.cc json/chunker.cc diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 3df86e7d6936c..bd9be3e8a9532 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -865,6 +865,25 @@ std::shared_ptr GetCastToHalfFloat() { return func; } +struct NullExtensionTypeMatcher : public TypeMatcher { + ~NullExtensionTypeMatcher() override = default; + + bool Matches(const DataType& type) const override { + return type.id() == Type::EXTENSION && + checked_cast(type).storage_id() == Type::NA; + } + + std::string ToString() const override { return "extension"; } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + return casted != nullptr; + } +}; + } // namespace std::vector> GetNumericCasts() { @@ -875,6 +894,10 @@ std::vector> GetNumericCasts() { auto cast_null = std::make_shared("cast_null", Type::NA); DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), OutputAllNull)); + // Explicitly allow casting extension type with null backing array to null + DCHECK_OK(cast_null->AddKernel( + Type::EXTENSION, {InputType(std::make_shared())}, null(), + OutputAllNull)); functions.push_back(cast_null); functions.push_back(GetCastToInteger("cast_int8")); diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index c15c42874d4de..6741ab602f50b 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -21,4 +21,10 @@ add_arrow_test(test PREFIX "arrow-fixed-shape-tensor") +add_arrow_test(test + SOURCES + opaque_test.cc + PREFIX + "arrow-extension-opaque") + arrow_install_all_headers("arrow/extension") diff --git 
a/cpp/src/arrow/extension/opaque.cc b/cpp/src/arrow/extension/opaque.cc new file mode 100644 index 0000000000000..c430bb5d2eaab --- /dev/null +++ b/cpp/src/arrow/extension/opaque.cc @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/extension/opaque.h" + +#include + +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/util/logging.h" + +#include +#include +#include + +namespace arrow::extension { + +std::string OpaqueType::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() + << "[storage_type=" << storage_type_->ToString(show_metadata) + << ", type_name=" << type_name_ << ", vendor_name=" << vendor_name_ << "]>"; + return ss.str(); +} + +bool OpaqueType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& opaque = internal::checked_cast(other); + return storage_type()->Equals(*opaque.storage_type()) && + type_name() == opaque.type_name() && vendor_name() == opaque.vendor_name(); +} + +std::string OpaqueType::Serialize() const { + rapidjson::Document document; + document.SetObject(); + rapidjson::Document::AllocatorType& allocator = document.GetAllocator(); + + rapidjson::Value type_name(rapidjson::StringRef(type_name_)); + document.AddMember(rapidjson::Value("type_name", allocator), type_name, allocator); + rapidjson::Value vendor_name(rapidjson::StringRef(vendor_name_)); + document.AddMember(rapidjson::Value("vendor_name", allocator), vendor_name, allocator); + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result> OpaqueType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + rapidjson::Document document; + const auto& parsed = document.Parse(serialized_data.data(), serialized_data.length()); + if (parsed.HasParseError()) { + return Status::Invalid("Invalid serialized JSON data for OpaqueType: ", + rapidjson::GetParseError_En(parsed.GetParseError()), ": ", + serialized_data); + } else if (!document.IsObject()) { + return Status::Invalid("Invalid serialized JSON data for OpaqueType: not 
an object"); + } + if (!document.HasMember("type_name")) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: missing type_name"); + } else if (!document.HasMember("vendor_name")) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: missing vendor_name"); + } + + const auto& type_name = document["type_name"]; + const auto& vendor_name = document["vendor_name"]; + if (!type_name.IsString()) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: type_name is not a string"); + } else if (!vendor_name.IsString()) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: vendor_name is not a string"); + } + + return opaque(std::move(storage_type), type_name.GetString(), vendor_name.GetString()); +} + +std::shared_ptr OpaqueType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.opaque", + internal::checked_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +std::shared_ptr opaque(std::shared_ptr storage_type, + std::string type_name, std::string vendor_name) { + return std::make_shared(std::move(storage_type), std::move(type_name), + std::move(vendor_name)); +} + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/opaque.h b/cpp/src/arrow/extension/opaque.h new file mode 100644 index 0000000000000..9814b391cbad6 --- /dev/null +++ b/cpp/src/arrow/extension/opaque.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension_type.h" +#include "arrow/type.h" + +namespace arrow::extension { + +/// \brief Opaque is a placeholder for a type from an external (usually +/// non-Arrow) system that could not be interpreted. +class ARROW_EXPORT OpaqueType : public ExtensionType { + public: + /// \brief Construct an OpaqueType. + /// + /// \param[in] storage_type The underlying storage type. Should be + /// arrow::null if there is no data. + /// \param[in] type_name The name of the type in the external system. + /// \param[in] vendor_name The name of the external system. + explicit OpaqueType(std::shared_ptr storage_type, std::string type_name, + std::string vendor_name) + : ExtensionType(std::move(storage_type)), + type_name_(std::move(type_name)), + vendor_name_(std::move(vendor_name)) {} + + std::string extension_name() const override { return "arrow.opaque"; } + std::string ToString(bool show_metadata) const override; + bool ExtensionEquals(const ExtensionType& other) const override; + std::string Serialize() const override; + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + /// Create an OpaqueArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + std::string_view type_name() const { return type_name_; } + std::string_view vendor_name() const { return vendor_name_; } + + private: + std::string type_name_; + std::string vendor_name_; +}; + +/// \brief Opaque is a wrapper for (usually binary) data from an external +/// (often non-Arrow) system that could 
not be interpreted. +class ARROW_EXPORT OpaqueArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Return an OpaqueType instance. +ARROW_EXPORT std::shared_ptr opaque(std::shared_ptr storage_type, + std::string type_name, + std::string vendor_name); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/opaque_test.cc b/cpp/src/arrow/extension/opaque_test.cc new file mode 100644 index 0000000000000..1629cdb39651c --- /dev/null +++ b/cpp/src/arrow/extension/opaque_test.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include "arrow/extension/fixed_shape_tensor.h" +#include "arrow/extension/opaque.h" +#include "arrow/extension_type.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +TEST(OpaqueType, Basics) { + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type2 = internal::checked_pointer_cast( + extension::opaque(null(), "type2", "vendor")); + ASSERT_EQ("arrow.opaque", type->extension_name()); + ASSERT_EQ(*type, *type); + ASSERT_NE(*arrow::null(), *type); + ASSERT_NE(*type, *type2); + ASSERT_EQ(*arrow::null(), *type->storage_type()); + ASSERT_THAT(type->Serialize(), ::testing::Not(::testing::IsEmpty())); + ASSERT_EQ(R"({"type_name":"type","vendor_name":"vendor"})", type->Serialize()); + ASSERT_EQ("type", type->type_name()); + ASSERT_EQ("vendor", type->vendor_name()); + ASSERT_EQ( + "extension", + type->ToString(false)); +} + +TEST(OpaqueType, Equals) { + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type2 = internal::checked_pointer_cast( + extension::opaque(null(), "type2", "vendor")); + auto type3 = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor2")); + auto type4 = internal::checked_pointer_cast( + extension::opaque(int64(), "type", "vendor")); + auto type5 = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type6 = internal::checked_pointer_cast( + extension::fixed_shape_tensor(float64(), {1})); + + ASSERT_EQ(*type, *type); + ASSERT_EQ(*type2, *type2); + ASSERT_EQ(*type3, *type3); + ASSERT_EQ(*type4, *type4); + ASSERT_EQ(*type5, *type5); + + ASSERT_EQ(*type, *type5); + + ASSERT_NE(*type, *type2); + ASSERT_NE(*type, *type3); + 
ASSERT_NE(*type, *type4); + ASSERT_NE(*type, *type6); + + ASSERT_NE(*type2, *type); + ASSERT_NE(*type2, *type3); + ASSERT_NE(*type2, *type4); + ASSERT_NE(*type2, *type6); + + ASSERT_NE(*type3, *type); + ASSERT_NE(*type3, *type2); + ASSERT_NE(*type3, *type4); + ASSERT_NE(*type3, *type6); + + ASSERT_NE(*type4, *type); + ASSERT_NE(*type4, *type2); + ASSERT_NE(*type4, *type3); + ASSERT_NE(*type4, *type6); + ASSERT_NE(*type6, *type4); +} + +TEST(OpaqueType, CreateFromArray) { + auto type = internal::checked_pointer_cast( + extension::opaque(binary(), "geometry", "adbc.postgresql")); + auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); + auto array = ExtensionType::WrapArray(type, storage); + ASSERT_EQ(2, array->length()); + ASSERT_EQ(1, array->null_count()); +} + +void CheckDeserialize(const std::string& serialized, + const std::shared_ptr& expected) { + auto type = internal::checked_pointer_cast(expected); + ASSERT_OK_AND_ASSIGN(auto deserialized, + type->Deserialize(type->storage_type(), serialized)); + ASSERT_EQ(*expected, *deserialized); +} + +TEST(OpaqueType, Deserialize) { + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "type", "vendor_name": "vendor"})", + extension::opaque(null(), "type", "vendor"))); + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "long name", "vendor_name": "long name"})", + extension::opaque(null(), "long name", "long name"))); + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "名前", "vendor_name": "名字"})", + extension::opaque(null(), "名前", "名字"))); + ASSERT_NO_FATAL_FAILURE(CheckDeserialize( + R"({"type_name": "type", "vendor_name": "vendor", "extra_field": 2})", + extension::opaque(null(), "type", "vendor"))); + + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("The document is empty"), + type->Deserialize(null(), R"()")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + 
testing::HasSubstr("Missing a name for object member"), + type->Deserialize(null(), R"({)")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("not an object"), + type->Deserialize(null(), R"([])")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("missing type_name"), + type->Deserialize(null(), R"({})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("type_name is not a string"), + type->Deserialize(null(), R"({"type_name": 2, "vendor_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("type_name is not a string"), + type->Deserialize(null(), R"({"type_name": null, "vendor_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("vendor_name is not a string"), + type->Deserialize(null(), R"({"vendor_name": 2, "type_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("vendor_name is not a string"), + type->Deserialize(null(), R"({"vendor_name": null, "type_name": ""})")); +} + +TEST(OpaqueType, MetadataRoundTrip) { + for (const auto& type : { + extension::opaque(null(), "foo", "bar"), + extension::opaque(binary(), "geometry", "postgis"), + extension::opaque(fixed_size_list(int64(), 4), "foo", "bar"), + extension::opaque(utf8(), "foo", "bar"), + }) { + auto opaque = internal::checked_pointer_cast(type); + std::string serialized = opaque->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + opaque->Deserialize(opaque->storage_type(), serialized)); + ASSERT_EQ(*type, *deserialized); + } +} + +TEST(OpaqueType, BatchRoundTrip) { + auto type = internal::checked_pointer_cast( + extension::opaque(binary(), "geometry", "adbc.postgresql")); + ExtensionTypeGuard guard(type); + + auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); + auto array = ExtensionType::WrapArray(type, storage); + auto batch = + RecordBatch::Make(schema({field("field", type)}), array->length(), {array}); + + std::shared_ptr written; + { + 
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(&written)); + } + + ASSERT_EQ(*batch->schema(), *written->schema()); + ASSERT_BATCHES_EQUAL(*batch, *written); +} + +} // namespace arrow diff --git a/cpp/src/arrow/filesystem/azurefs.cc b/cpp/src/arrow/filesystem/azurefs.cc index a3aa2c8e837d9..9b3c0c0c1d703 100644 --- a/cpp/src/arrow/filesystem/azurefs.cc +++ b/cpp/src/arrow/filesystem/azurefs.cc @@ -3199,4 +3199,31 @@ Result> AzureFileSystem::OpenAppendStream( return impl_->OpenAppendStream(location, metadata, false, this); } +Result AzureFileSystem::PathFromUri(const std::string& uri_string) const { + /// We can not use `internal::PathFromUriHelper` here because for Azure we have to + /// support different URI schemes where the authority is handled differently. + /// Example (both should yield the same path `container/some/path`): + /// - (1) abfss://storageacc.blob.core.windows.net/container/some/path + /// - (2) abfss://acc:pw@container/some/path + /// The authority handling is different with these two URIs. (1) requires no prepending + /// of the authority to the path, while (2) requires to preprend the authority to the + /// path. 
+ std::string path; + Uri uri; + RETURN_NOT_OK(uri.Parse(uri_string)); + RETURN_NOT_OK(AzureOptions::FromUri(uri, &path)); + + std::vector supported_schemes = {"abfs", "abfss"}; + const auto scheme = uri.scheme(); + if (std::find(supported_schemes.begin(), supported_schemes.end(), scheme) == + supported_schemes.end()) { + std::string expected_schemes = + ::arrow::internal::JoinStrings(supported_schemes, ", "); + return Status::Invalid("The filesystem expected a URI with one of the schemes (", + expected_schemes, ") but received ", uri_string); + } + + return path; +} + } // namespace arrow::fs diff --git a/cpp/src/arrow/filesystem/azurefs.h b/cpp/src/arrow/filesystem/azurefs.h index 93d6ec2f945b4..072b061eeb2a9 100644 --- a/cpp/src/arrow/filesystem/azurefs.h +++ b/cpp/src/arrow/filesystem/azurefs.h @@ -367,6 +367,8 @@ class ARROW_EXPORT AzureFileSystem : public FileSystem { Result> OpenAppendStream( const std::string& path, const std::shared_ptr& metadata) override; + + Result PathFromUri(const std::string& uri_string) const override; }; } // namespace arrow::fs diff --git a/cpp/src/arrow/filesystem/azurefs_test.cc b/cpp/src/arrow/filesystem/azurefs_test.cc index 9a11a6f24995a..36646f417cbe1 100644 --- a/cpp/src/arrow/filesystem/azurefs_test.cc +++ b/cpp/src/arrow/filesystem/azurefs_test.cc @@ -2958,5 +2958,14 @@ TEST_F(TestAzuriteFileSystem, OpenInputFileClosed) { ASSERT_RAISES(Invalid, stream->ReadAt(1, 1)); ASSERT_RAISES(Invalid, stream->Seek(2)); } + +TEST_F(TestAzuriteFileSystem, PathFromUri) { + ASSERT_EQ( + "container/some/path", + fs()->PathFromUri("abfss://storageacc.blob.core.windows.net/container/some/path")); + ASSERT_EQ("container/some/path", + fs()->PathFromUri("abfss://acc:pw@container/some/path")); + ASSERT_RAISES(Invalid, fs()->PathFromUri("http://acc:pw@container/some/path")); +} } // namespace fs } // namespace arrow diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 43ac48b87678e..98f93705f6f56 100644 
--- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -262,7 +262,9 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES + test_auth_handlers.cc test_definitions.cc + test_flight_server.cc test_util.cc DEPENDENCIES flight_grpc_gen diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 101bb06b21288..3d52bc3f5ae06 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -52,7 +52,9 @@ // Include before test_util.h (boost), contains Windows fixes #include "arrow/flight/platform.h" #include "arrow/flight/serialization_internal.h" +#include "arrow/flight/test_auth_handlers.h" #include "arrow/flight/test_definitions.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" // OTel includes must come after any gRPC includes, and // client_header_internal.h includes gRPC. See: @@ -247,7 +249,7 @@ TEST(TestFlight, ConnectUriUnix) { // CI environments don't have an IPv6 interface configured TEST(TestFlight, DISABLED_IpV6Port) { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0)); FlightServerOptions options(location); @@ -261,7 +263,7 @@ TEST(TestFlight, DISABLED_IpV6Port) { } TEST(TestFlight, ServerCallContextIncomingHeaders) { - auto server = ExampleTestServer(); + auto server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -290,7 +292,7 @@ TEST(TestFlight, ServerCallContextIncomingHeaders) { class TestFlightClient : public ::testing::Test { public: void SetUp() { - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); diff --git 
a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 665c1f1ba036a..da6fcf81eb737 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -36,6 +36,7 @@ #include "arrow/flight/sql/server.h" #include "arrow/flight/sql/server_session_middleware.h" #include "arrow/flight/sql/types.h" +#include "arrow/flight/test_auth_handlers.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" diff --git a/cpp/src/arrow/flight/test_auth_handlers.cc b/cpp/src/arrow/flight/test_auth_handlers.cc new file mode 100644 index 0000000000000..856ccf0f2b271 --- /dev/null +++ b/cpp/src/arrow/flight/test_auth_handlers.cc @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_auth.h" +#include "arrow/flight/test_auth_handlers.h" +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow::flight { + +// TestServerAuthHandler + +TestServerAuthHandler::TestServerAuthHandler(const std::string& username, + const std::string& password) + : username_(username), password_(password) {} + +TestServerAuthHandler::~TestServerAuthHandler() {} + +Status TestServerAuthHandler::Authenticate(const ServerCallContext& context, + ServerAuthSender* outgoing, + ServerAuthReader* incoming) { + std::string token; + RETURN_NOT_OK(incoming->Read(&token)); + if (token != password_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + RETURN_NOT_OK(outgoing->Write(username_)); + return Status::OK(); +} + +Status TestServerAuthHandler::IsValid(const ServerCallContext& context, + const std::string& token, + std::string* peer_identity) { + if (token != password_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + *peer_identity = username_; + return Status::OK(); +} + +// TestServerBasicAuthHandler + +TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username, + const std::string& password) { + basic_auth_.username = username; + basic_auth_.password = password; +} + +TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {} + +Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context, + ServerAuthSender* outgoing, + ServerAuthReader* incoming) { + std::string token; + RETURN_NOT_OK(incoming->Read(&token)); + ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token)); + if (incoming_auth.username != basic_auth_.username || + incoming_auth.password != basic_auth_.password) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + 
RETURN_NOT_OK(outgoing->Write(basic_auth_.username)); + return Status::OK(); +} + +Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context, + const std::string& token, + std::string* peer_identity) { + if (token != basic_auth_.username) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + *peer_identity = basic_auth_.username; + return Status::OK(); +} + +// TestClientAuthHandler + +TestClientAuthHandler::TestClientAuthHandler(const std::string& username, + const std::string& password) + : username_(username), password_(password) {} + +TestClientAuthHandler::~TestClientAuthHandler() {} + +Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing, + ClientAuthReader* incoming) { + RETURN_NOT_OK(outgoing->Write(password_)); + std::string username; + RETURN_NOT_OK(incoming->Read(&username)); + if (username != username_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + return Status::OK(); +} + +Status TestClientAuthHandler::GetToken(std::string* token) { + *token = password_; + return Status::OK(); +} + +// TestClientBasicAuthHandler + +TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username, + const std::string& password) { + basic_auth_.username = username; + basic_auth_.password = password; +} + +TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {} + +Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing, + ClientAuthReader* incoming) { + ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString()); + RETURN_NOT_OK(outgoing->Write(pb_result)); + RETURN_NOT_OK(incoming->Read(&token_)); + return Status::OK(); +} + +Status TestClientBasicAuthHandler::GetToken(std::string* token) { + *token = token_; + return Status::OK(); +} + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_auth_handlers.h b/cpp/src/arrow/flight/test_auth_handlers.h new file mode 100644 index 
0000000000000..74f48798f3b02 --- /dev/null +++ b/cpp/src/arrow/flight/test_auth_handlers.h @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_auth.h" +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +// A pair of authentication handlers that check for a predefined password +// and set the peer identity to a predefined username. 
+ +namespace arrow::flight { + +class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler { + public: + explicit TestServerAuthHandler(const std::string& username, + const std::string& password); + ~TestServerAuthHandler() override; + Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, + ServerAuthReader* incoming) override; + Status IsValid(const ServerCallContext& context, const std::string& token, + std::string* peer_identity) override; + + private: + std::string username_; + std::string password_; +}; + +class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler { + public: + explicit TestServerBasicAuthHandler(const std::string& username, + const std::string& password); + ~TestServerBasicAuthHandler() override; + Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, + ServerAuthReader* incoming) override; + Status IsValid(const ServerCallContext& context, const std::string& token, + std::string* peer_identity) override; + + private: + BasicAuth basic_auth_; +}; + +class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { + public: + explicit TestClientAuthHandler(const std::string& username, + const std::string& password); + ~TestClientAuthHandler() override; + Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; + Status GetToken(std::string* token) override; + + private: + std::string username_; + std::string password_; +}; + +class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler { + public: + explicit TestClientBasicAuthHandler(const std::string& username, + const std::string& password); + ~TestClientBasicAuthHandler() override; + Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; + Status GetToken(std::string* token) override; + + private: + BasicAuth basic_auth_; + std::string token_; +}; + +} // namespace arrow::flight diff --git 
a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index c43b693d84a47..273d394c288d9 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -27,6 +27,7 @@ #include "arrow/array/util.h" #include "arrow/flight/api.h" #include "arrow/flight/client_middleware.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/flight/types_async.h" @@ -53,7 +54,7 @@ using arrow::internal::checked_cast; // Tests of initialization/shutdown void ConnectivityTest::TestGetPort() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -61,7 +62,7 @@ void ConnectivityTest::TestGetPort() { ASSERT_GT(server->port(), 0); } void ConnectivityTest::TestBuilderHook() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -80,7 +81,7 @@ void ConnectivityTest::TestShutdown() { constexpr int kIterations = 10; ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); for (int i = 0; i < kIterations; i++) { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -92,7 +93,7 @@ void ConnectivityTest::TestShutdown() { } } void ConnectivityTest::TestShutdownWithDeadline() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -105,7 +106,7 @@ void 
ConnectivityTest::TestShutdownWithDeadline() { ASSERT_OK(server->Wait()); } void ConnectivityTest::TestBrokenConnection() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -151,7 +152,7 @@ class GetFlightInfoListener : public AsyncListener { } // namespace void DataTest::SetUpTest() { - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -1822,7 +1823,7 @@ void AsyncClientTest::SetUpTest() { ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); FlightServerOptions server_options(location); ASSERT_OK(server_->Init(server_options)); diff --git a/cpp/src/arrow/flight/test_flight_server.cc b/cpp/src/arrow/flight/test_flight_server.cc new file mode 100644 index 0000000000000..0ea95ebd15b07 --- /dev/null +++ b/cpp/src/arrow/flight/test_flight_server.cc @@ -0,0 +1,417 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/flight/test_flight_server.h" + +#include "arrow/array/array_base.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/flight/server.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/status.h" + +namespace arrow::flight { +namespace { + +class ErrorRecordBatchReader : public RecordBatchReader { + public: + ErrorRecordBatchReader() : schema_(arrow::schema({})) {} + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + *out = nullptr; + return Status::OK(); + } + + Status Close() override { + // This should be propagated over DoGet to the client + return Status::IOError("Expected error"); + } + + private: + std::shared_ptr schema_; +}; + +Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { + if (ticket.ticket == "ticket-ints-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-floats-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleFloatBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-dicts-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleDictBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-large-batch-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleLargeBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else { + return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); + } +} + +} // namespace + 
+std::unique_ptr TestFlightServer::Make() { + return std::make_unique(); +} + +Status TestFlightServer::ListFlights(const ServerCallContext& context, + const Criteria* criteria, + std::unique_ptr* listings) { + std::vector flights = ExampleFlightInfo(); + if (criteria && criteria->expression != "") { + // For test purposes, if we get criteria, return no results + flights.clear(); + } + *listings = std::make_unique(flights); + return Status::OK(); +} + +Status TestFlightServer::GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* out) { + // Test that Arrow-C++ status codes make it through the transport + if (request.type == FlightDescriptor::DescriptorType::CMD && + request.cmd == "status-outofmemory") { + return Status::OutOfMemory("Sentinel"); + } + + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *out = std::make_unique(info); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); +} + +Status TestFlightServer::DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) { + // Test for ARROW-5095 + if (request.ticket == "ARROW-5095-fail") { + return Status::UnknownError("Server-side error"); + } + if (request.ticket == "ARROW-5095-success") { + return Status::OK(); + } + if (request.ticket == "ARROW-13253-DoGet-Batch") { + // Make batch > 2GiB in size + ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); + *data_stream = std::make_unique(std::move(reader)); + return Status::OK(); + } + if (request.ticket == "ticket-stream-error") { + auto reader = std::make_shared(); + *data_stream = std::make_unique(std::move(reader)); + return Status::OK(); + } + + std::shared_ptr batch_reader; + RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); + + *data_stream = 
std::make_unique(batch_reader); + return Status::OK(); +} + +Status TestFlightServer::DoPut(const ServerCallContext&, + std::unique_ptr reader, + std::unique_ptr writer) { + return reader->ToRecordBatches().status(); +} + +Status TestFlightServer::DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) { + // Test various scenarios for a DoExchange + if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) { + return Status::Invalid("Must provide a command descriptor"); + } + + const std::string& cmd = reader->descriptor().cmd; + if (cmd == "error") { + // Immediately return an error to the client. + return Status::NotImplemented("Expected error"); + } else if (cmd == "get") { + return RunExchangeGet(std::move(reader), std::move(writer)); + } else if (cmd == "put") { + return RunExchangePut(std::move(reader), std::move(writer)); + } else if (cmd == "counter") { + return RunExchangeCounter(std::move(reader), std::move(writer)); + } else if (cmd == "total") { + return RunExchangeTotal(std::move(reader), std::move(writer)); + } else if (cmd == "echo") { + return RunExchangeEcho(std::move(reader), std::move(writer)); + } else if (cmd == "large_batch") { + return RunExchangeLargeBatch(std::move(reader), std::move(writer)); + } else if (cmd == "TestUndrained") { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + return Status::OK(); + } else { + return Status::NotImplemented("Scenario not implemented: ", cmd); + } +} + +// A simple example - act like DoGet. 
+Status TestFlightServer::RunExchangeGet(std::unique_ptr reader, + std::unique_ptr writer) { + RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + for (const auto& batch : batches) { + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + } + return Status::OK(); +} + +// A simple example - act like DoPut +Status TestFlightServer::RunExchangePut(std::unique_ptr reader, + std::unique_ptr writer) { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + if (!schema->Equals(ExampleIntSchema(), false)) { + return Status::Invalid("Schema is not as expected"); + } + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + FlightStreamChunk chunk; + for (const auto& batch : batches) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data) { + return Status::Invalid("Expected another batch"); + } + if (!batch->Equals(*chunk.data)) { + return Status::Invalid("Batch does not match"); + } + } + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (chunk.data || chunk.app_metadata) { + return Status::Invalid("Too many batches"); + } + + RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done"))); + return Status::OK(); +} + +// Read some number of record batches from the client, send a +// metadata message back with the count, then echo the batches back. +Status TestFlightServer::RunExchangeCounter(std::unique_ptr reader, + std::unique_ptr writer) { + std::vector> batches; + FlightStreamChunk chunk; + int chunks = 0; + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (chunk.data) { + batches.push_back(chunk.data); + chunks++; + } + } + + // Echo back the number of record batches read. + std::shared_ptr buf = Buffer::FromString(std::to_string(chunks)); + RETURN_NOT_OK(writer->WriteMetadata(buf)); + // Echo the record batches themselves. 
+ if (chunks > 0) { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + RETURN_NOT_OK(writer->Begin(schema)); + + for (const auto& batch : batches) { + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + } + } + + return Status::OK(); +} + +// Read int64 batches from the client, each time sending back a +// batch with a running sum of columns. +Status TestFlightServer::RunExchangeTotal(std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk{}; + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + // Ensure the schema contains only int64 columns + for (const auto& field : schema->fields()) { + if (field->type()->id() != Type::type::INT64) { + return Status::Invalid("Field is not INT64: ", field->name()); + } + } + std::vector sums(schema->num_fields()); + std::vector> columns(schema->num_fields()); + RETURN_NOT_OK(writer->Begin(schema)); + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (chunk.data) { + if (!chunk.data->schema()->Equals(schema, false)) { + // A compliant client implementation would make this impossible + return Status::Invalid("Schemas are incompatible"); + } + + // Update the running totals + auto builder = std::make_shared(); + int col_index = 0; + for (const auto& column : chunk.data->columns()) { + auto arr = std::dynamic_pointer_cast(column); + if (!arr) { + return MakeFlightError(FlightStatusCode::Internal, "Could not cast array"); + } + for (int row = 0; row < column->length(); row++) { + if (!arr->IsNull(row)) { + sums[col_index] += arr->Value(row); + } + } + + builder->Reset(); + RETURN_NOT_OK(builder->Append(sums[col_index])); + RETURN_NOT_OK(builder->Finish(&columns[col_index])); + + col_index++; + } + + // Echo the totals to the client + auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns); + RETURN_NOT_OK(writer->WriteRecordBatch(*response)); + } + } + return Status::OK(); +} + +// Echo the client's messages 
back. +Status TestFlightServer::RunExchangeEcho(std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk; + bool begun = false; + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (!begun && chunk.data) { + begun = true; + RETURN_NOT_OK(writer->Begin(chunk.data->schema())); + } + if (chunk.data && chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); + } else if (chunk.data) { + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); + } else if (chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); + } + } + return Status::OK(); +} + +// Regression test for ARROW-13253 +Status TestFlightServer::RunExchangeLargeBatch( + std::unique_ptr, std::unique_ptr writer) { + ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); + RETURN_NOT_OK(writer->Begin(batch->schema())); + return writer->WriteRecordBatch(*batch); +} + +Status TestFlightServer::RunAction1(const Action& action, + std::unique_ptr* out) { + std::vector results; + for (int i = 0; i < 3; ++i) { + Result result; + std::string value = action.body->ToString() + "-part" + std::to_string(i); + result.body = Buffer::FromString(std::move(value)); + results.push_back(result); + } + *out = std::make_unique(std::move(results)); + return Status::OK(); +} + +Status TestFlightServer::RunAction2(std::unique_ptr* out) { + // Empty + *out = std::make_unique(std::vector{}); + return Status::OK(); +} + +Status TestFlightServer::ListIncomingHeaders(const ServerCallContext& context, + const Action& action, + std::unique_ptr* out) { + std::vector results; + std::string_view prefix(*action.body); + for (const auto& header : context.incoming_headers()) { + if (header.first.substr(0, prefix.size()) != prefix) { + continue; + } + Result result; + result.body = + Buffer::FromString(std::string(header.first) + ": " + std::string(header.second)); + results.push_back(result); 
+ } + *out = std::make_unique(std::move(results)); + return Status::OK(); +} + +Status TestFlightServer::DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) { + if (action.type == "action1") { + return RunAction1(action, out); + } else if (action.type == "action2") { + return RunAction2(out); + } else if (action.type == "list-incoming-headers") { + return ListIncomingHeaders(context, action, out); + } else { + return Status::NotImplemented(action.type); + } +} + +Status TestFlightServer::ListActions(const ServerCallContext& context, + std::vector* out) { + std::vector actions = ExampleActionTypes(); + *out = std::move(actions); + return Status::OK(); +} + +Status TestFlightServer::GetSchema(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* schema) { + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *schema = std::make_unique(info.serialized_schema()); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); +} + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_flight_server.h b/cpp/src/arrow/flight/test_flight_server.h new file mode 100644 index 0000000000000..794dd834c014b --- /dev/null +++ b/cpp/src/arrow/flight/test_flight_server.h @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow::flight { + +class ARROW_FLIGHT_EXPORT TestFlightServer : public FlightServerBase { + public: + static std::unique_ptr Make(); + + Status ListFlights(const ServerCallContext& context, const Criteria* criteria, + std::unique_ptr* listings) override; + + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* out) override; + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override; + + Status DoPut(const ServerCallContext&, std::unique_ptr reader, + std::unique_ptr writer) override; + + Status DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override; + + // A simple example - act like DoGet. + Status RunExchangeGet(std::unique_ptr reader, + std::unique_ptr writer); + + // A simple example - act like DoPut + Status RunExchangePut(std::unique_ptr reader, + std::unique_ptr writer); + + // Read some number of record batches from the client, send a + // metadata message back with the count, then echo the batches back. + Status RunExchangeCounter(std::unique_ptr reader, + std::unique_ptr writer); + + // Read int64 batches from the client, each time sending back a + // batch with a running sum of columns. + Status RunExchangeTotal(std::unique_ptr reader, + std::unique_ptr writer); + + // Echo the client's messages back. 
+ Status RunExchangeEcho(std::unique_ptr reader, + std::unique_ptr writer); + + // Regression test for ARROW-13253 + Status RunExchangeLargeBatch(std::unique_ptr, + std::unique_ptr writer); + + Status RunAction1(const Action& action, std::unique_ptr* out); + + Status RunAction2(std::unique_ptr* out); + + Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, + std::unique_ptr* out); + + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) override; + + Status ListActions(const ServerCallContext& context, + std::vector* out) override; + + Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* schema) override; +}; + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_server.cc b/cpp/src/arrow/flight/test_server.cc index 18bf2b4135990..ba84b8f532e03 100644 --- a/cpp/src/arrow/flight/test_server.cc +++ b/cpp/src/arrow/flight/test_server.cc @@ -26,6 +26,7 @@ #include #include "arrow/flight/server.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/util/logging.h" @@ -38,7 +39,7 @@ std::unique_ptr g_server; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - g_server = arrow::flight::ExampleTestServer(); + g_server = arrow::flight::TestFlightServer::Make(); arrow::flight::Location location; if (FLAGS_unix.empty()) { diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 8b4245e74e843..127827ff38cdd 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -49,8 +49,7 @@ #include "arrow/flight/api.h" #include "arrow/flight/serialization_internal.h" -namespace arrow { -namespace flight { +namespace arrow::flight { namespace bp = boost::process; namespace fs = boost::filesystem; @@ -90,25 +89,6 @@ Status ResolveCurrentExecutable(fs::path* out) { } } -class 
ErrorRecordBatchReader : public RecordBatchReader { - public: - ErrorRecordBatchReader() : schema_(arrow::schema({})) {} - - std::shared_ptr schema() const override { return schema_; } - - Status ReadNext(std::shared_ptr* out) override { - *out = nullptr; - return Status::OK(); - } - - Status Close() override { - // This should be propagated over DoGet to the client - return Status::IOError("Expected error"); - } - - private: - std::shared_ptr schema_; -}; } // namespace void TestServer::Start(const std::vector& extra_args) { @@ -171,364 +151,6 @@ int TestServer::port() const { return port_; } const std::string& TestServer::unix_sock() const { return unix_sock_; } -Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { - if (ticket.ticket == "ticket-ints-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-floats-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleFloatBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-dicts-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleDictBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-large-batch-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleLargeBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else { - return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); - } -} - -class FlightTestServer : public FlightServerBase { - Status ListFlights(const ServerCallContext& context, const Criteria* criteria, - std::unique_ptr* listings) override { - std::vector flights = ExampleFlightInfo(); - if (criteria && criteria->expression != "") { - // For test purposes, 
if we get criteria, return no results - flights.clear(); - } - *listings = std::make_unique(flights); - return Status::OK(); - } - - Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* out) override { - // Test that Arrow-C++ status codes make it through the transport - if (request.type == FlightDescriptor::DescriptorType::CMD && - request.cmd == "status-outofmemory") { - return Status::OutOfMemory("Sentinel"); - } - - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *out = std::make_unique(info); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } - - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - // Test for ARROW-5095 - if (request.ticket == "ARROW-5095-fail") { - return Status::UnknownError("Server-side error"); - } - if (request.ticket == "ARROW-5095-success") { - return Status::OK(); - } - if (request.ticket == "ARROW-13253-DoGet-Batch") { - // Make batch > 2GiB in size - ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - *data_stream = std::make_unique(std::move(reader)); - return Status::OK(); - } - if (request.ticket == "ticket-stream-error") { - auto reader = std::make_shared(); - *data_stream = std::make_unique(std::move(reader)); - return Status::OK(); - } - - std::shared_ptr batch_reader; - RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); - - *data_stream = std::make_unique(batch_reader); - return Status::OK(); - } - - Status DoPut(const ServerCallContext&, std::unique_ptr reader, - std::unique_ptr writer) override { - return reader->ToRecordBatches().status(); - } - - Status DoExchange(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - // Test various scenarios for a 
DoExchange - if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) { - return Status::Invalid("Must provide a command descriptor"); - } - - const std::string& cmd = reader->descriptor().cmd; - if (cmd == "error") { - // Immediately return an error to the client. - return Status::NotImplemented("Expected error"); - } else if (cmd == "get") { - return RunExchangeGet(std::move(reader), std::move(writer)); - } else if (cmd == "put") { - return RunExchangePut(std::move(reader), std::move(writer)); - } else if (cmd == "counter") { - return RunExchangeCounter(std::move(reader), std::move(writer)); - } else if (cmd == "total") { - return RunExchangeTotal(std::move(reader), std::move(writer)); - } else if (cmd == "echo") { - return RunExchangeEcho(std::move(reader), std::move(writer)); - } else if (cmd == "large_batch") { - return RunExchangeLargeBatch(std::move(reader), std::move(writer)); - } else if (cmd == "TestUndrained") { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - return Status::OK(); - } else { - return Status::NotImplemented("Scenario not implemented: ", cmd); - } - } - - // A simple example - act like DoGet. 
- Status RunExchangeGet(std::unique_ptr reader, - std::unique_ptr writer) { - RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - for (const auto& batch : batches) { - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - } - return Status::OK(); - } - - // A simple example - act like DoPut - Status RunExchangePut(std::unique_ptr reader, - std::unique_ptr writer) { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - if (!schema->Equals(ExampleIntSchema(), false)) { - return Status::Invalid("Schema is not as expected"); - } - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - FlightStreamChunk chunk; - for (const auto& batch : batches) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data) { - return Status::Invalid("Expected another batch"); - } - if (!batch->Equals(*chunk.data)) { - return Status::Invalid("Batch does not match"); - } - } - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (chunk.data || chunk.app_metadata) { - return Status::Invalid("Too many batches"); - } - - RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done"))); - return Status::OK(); - } - - // Read some number of record batches from the client, send a - // metadata message back with the count, then echo the batches back. - Status RunExchangeCounter(std::unique_ptr reader, - std::unique_ptr writer) { - std::vector> batches; - FlightStreamChunk chunk; - int chunks = 0; - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (chunk.data) { - batches.push_back(chunk.data); - chunks++; - } - } - - // Echo back the number of record batches read. - std::shared_ptr buf = Buffer::FromString(std::to_string(chunks)); - RETURN_NOT_OK(writer->WriteMetadata(buf)); - // Echo the record batches themselves. 
- if (chunks > 0) { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - RETURN_NOT_OK(writer->Begin(schema)); - - for (const auto& batch : batches) { - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - } - } - - return Status::OK(); - } - - // Read int64 batches from the client, each time sending back a - // batch with a running sum of columns. - Status RunExchangeTotal(std::unique_ptr reader, - std::unique_ptr writer) { - FlightStreamChunk chunk{}; - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - // Ensure the schema contains only int64 columns - for (const auto& field : schema->fields()) { - if (field->type()->id() != Type::type::INT64) { - return Status::Invalid("Field is not INT64: ", field->name()); - } - } - std::vector sums(schema->num_fields()); - std::vector> columns(schema->num_fields()); - RETURN_NOT_OK(writer->Begin(schema)); - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (chunk.data) { - if (!chunk.data->schema()->Equals(schema, false)) { - // A compliant client implementation would make this impossible - return Status::Invalid("Schemas are incompatible"); - } - - // Update the running totals - auto builder = std::make_shared(); - int col_index = 0; - for (const auto& column : chunk.data->columns()) { - auto arr = std::dynamic_pointer_cast(column); - if (!arr) { - return MakeFlightError(FlightStatusCode::Internal, "Could not cast array"); - } - for (int row = 0; row < column->length(); row++) { - if (!arr->IsNull(row)) { - sums[col_index] += arr->Value(row); - } - } - - builder->Reset(); - RETURN_NOT_OK(builder->Append(sums[col_index])); - RETURN_NOT_OK(builder->Finish(&columns[col_index])); - - col_index++; - } - - // Echo the totals to the client - auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns); - RETURN_NOT_OK(writer->WriteRecordBatch(*response)); - } - } - return Status::OK(); - } - - // Echo the client's messages back. 
- Status RunExchangeEcho(std::unique_ptr reader, - std::unique_ptr writer) { - FlightStreamChunk chunk; - bool begun = false; - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (!begun && chunk.data) { - begun = true; - RETURN_NOT_OK(writer->Begin(chunk.data->schema())); - } - if (chunk.data && chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); - } else if (chunk.data) { - RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); - } else if (chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); - } - } - return Status::OK(); - } - - // Regression test for ARROW-13253 - Status RunExchangeLargeBatch(std::unique_ptr, - std::unique_ptr writer) { - ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); - RETURN_NOT_OK(writer->Begin(batch->schema())); - return writer->WriteRecordBatch(*batch); - } - - Status RunAction1(const Action& action, std::unique_ptr* out) { - std::vector results; - for (int i = 0; i < 3; ++i) { - Result result; - std::string value = action.body->ToString() + "-part" + std::to_string(i); - result.body = Buffer::FromString(std::move(value)); - results.push_back(result); - } - *out = std::make_unique(std::move(results)); - return Status::OK(); - } - - Status RunAction2(std::unique_ptr* out) { - // Empty - *out = std::make_unique(std::vector{}); - return Status::OK(); - } - - Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) { - std::vector results; - std::string_view prefix(*action.body); - for (const auto& header : context.incoming_headers()) { - if (header.first.substr(0, prefix.size()) != prefix) { - continue; - } - Result result; - result.body = Buffer::FromString(std::string(header.first) + ": " + - std::string(header.second)); - results.push_back(result); - } - *out = std::make_unique(std::move(results)); - return Status::OK(); - } - - Status 
DoAction(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) override { - if (action.type == "action1") { - return RunAction1(action, out); - } else if (action.type == "action2") { - return RunAction2(out); - } else if (action.type == "list-incoming-headers") { - return ListIncomingHeaders(context, action, out); - } else { - return Status::NotImplemented(action.type); - } - } - - Status ListActions(const ServerCallContext& context, - std::vector* out) override { - std::vector actions = ExampleActionTypes(); - *out = std::move(actions); - return Status::OK(); - } - - Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* schema) override { - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *schema = std::make_unique(info.serialized_schema()); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } -}; - -std::unique_ptr ExampleTestServer() { - return std::make_unique(); -} - FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, const std::vector& endpoints, int64_t total_records, int64_t total_bytes, bool ordered, @@ -701,109 +323,6 @@ std::vector ExampleActionTypes() { return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}}; } -TestServerAuthHandler::TestServerAuthHandler(const std::string& username, - const std::string& password) - : username_(username), password_(password) {} - -TestServerAuthHandler::~TestServerAuthHandler() {} - -Status TestServerAuthHandler::Authenticate(const ServerCallContext& context, - ServerAuthSender* outgoing, - ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - if (token != password_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(username_)); - return Status::OK(); -} - -Status 
TestServerAuthHandler::IsValid(const ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token != password_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = username_; - return Status::OK(); -} - -TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username, - const std::string& password) { - basic_auth_.username = username; - basic_auth_.password = password; -} - -TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {} - -Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context, - ServerAuthSender* outgoing, - ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token)); - if (incoming_auth.username != basic_auth_.username || - incoming_auth.password != basic_auth_.password) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(basic_auth_.username)); - return Status::OK(); -} - -Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token != basic_auth_.username) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = basic_auth_.username; - return Status::OK(); -} - -TestClientAuthHandler::TestClientAuthHandler(const std::string& username, - const std::string& password) - : username_(username), password_(password) {} - -TestClientAuthHandler::~TestClientAuthHandler() {} - -Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing, - ClientAuthReader* incoming) { - RETURN_NOT_OK(outgoing->Write(password_)); - std::string username; - RETURN_NOT_OK(incoming->Read(&username)); - if (username != username_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - return 
Status::OK(); -} - -Status TestClientAuthHandler::GetToken(std::string* token) { - *token = password_; - return Status::OK(); -} - -TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username, - const std::string& password) { - basic_auth_.username = username; - basic_auth_.password = password; -} - -TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {} - -Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing, - ClientAuthReader* incoming) { - ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString()); - RETURN_NOT_OK(outgoing->Write(pb_result)); - RETURN_NOT_OK(incoming->Read(&token_)); - return Status::OK(); -} - -Status TestClientBasicAuthHandler::GetToken(std::string* token) { - *token = token_; - return Status::OK(); -} - Status ExampleTlsCertificates(std::vector* out) { std::string root; RETURN_NOT_OK(GetTestResourceRoot(&root)); @@ -860,5 +379,4 @@ Status ExampleTlsCertificateRoot(CertKeyPair* out) { } } -} // namespace flight -} // namespace arrow +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index c0b42d9b90c5a..15ba6145ecd2b 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -32,9 +32,7 @@ #include "arrow/testing/util.h" #include "arrow/flight/client.h" -#include "arrow/flight/client_auth.h" #include "arrow/flight/server.h" -#include "arrow/flight/server_auth.h" #include "arrow/flight/types.h" #include "arrow/flight/visibility.h" @@ -95,10 +93,6 @@ class ARROW_FLIGHT_EXPORT TestServer { std::shared_ptr<::boost::process::child> server_process_; }; -/// \brief Create a simple Flight server for testing -ARROW_FLIGHT_EXPORT -std::unique_ptr ExampleTestServer(); - // Helper to initialize a server and matching client with callbacks to // populate options. 
template @@ -195,65 +189,6 @@ FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descript int64_t total_records, int64_t total_bytes, bool ordered, std::string app_metadata); -// ---------------------------------------------------------------------- -// A pair of authentication handlers that check for a predefined password -// and set the peer identity to a predefined username. - -class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler { - public: - explicit TestServerAuthHandler(const std::string& username, - const std::string& password); - ~TestServerAuthHandler() override; - Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, - ServerAuthReader* incoming) override; - Status IsValid(const ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; - - private: - std::string username_; - std::string password_; -}; - -class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler { - public: - explicit TestServerBasicAuthHandler(const std::string& username, - const std::string& password); - ~TestServerBasicAuthHandler() override; - Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, - ServerAuthReader* incoming) override; - Status IsValid(const ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; - - private: - BasicAuth basic_auth_; -}; - -class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { - public: - explicit TestClientAuthHandler(const std::string& username, - const std::string& password); - ~TestClientAuthHandler() override; - Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; - Status GetToken(std::string* token) override; - - private: - std::string username_; - std::string password_; -}; - -class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler { - public: - explicit 
TestClientBasicAuthHandler(const std::string& username, - const std::string& password); - ~TestClientBasicAuthHandler() override; - Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; - Status GetToken(std::string* token) override; - - private: - BasicAuth basic_auth_; - std::string token_; -}; - ARROW_FLIGHT_EXPORT Status ExampleTlsCertificates(std::vector* out); diff --git a/docs/source/python/api/arrays.rst b/docs/source/python/api/arrays.rst index aefed00b3d2e0..4ad35b190cdd0 100644 --- a/docs/source/python/api/arrays.rst +++ b/docs/source/python/api/arrays.rst @@ -85,6 +85,7 @@ may expose data type-specific methods or properties. UnionArray ExtensionArray FixedShapeTensorArray + OpaqueArray .. _api.scalar: @@ -143,3 +144,5 @@ classes may expose data type-specific methods or properties. StructScalar UnionScalar ExtensionScalar + FixedShapeTensorScalar + OpaqueScalar diff --git a/docs/source/python/api/datatypes.rst b/docs/source/python/api/datatypes.rst index 7edb4e161541d..a43c5299eae51 100644 --- a/docs/source/python/api/datatypes.rst +++ b/docs/source/python/api/datatypes.rst @@ -67,6 +67,8 @@ These should be used to create Arrow data types and schemas. struct dictionary run_end_encoded + fixed_shape_tensor + opaque field schema from_numpy_dtype @@ -117,6 +119,14 @@ Specific classes and functions for extension types. register_extension_type unregister_extension_type +:doc:`Canonical extension types <../../format/CanonicalExtensions>` +implemented by PyArrow. + +.. autosummary:: + :toctree: ../generated/ + + FixedShapeTensorType + OpaqueType .. _api.types.checking: .. 
currentmodule:: pyarrow.types diff --git a/go/go.mod b/go/go.mod index 09869b7a3836f..9f4222a541bb6 100644 --- a/go/go.mod +++ b/go/go.mod @@ -49,7 +49,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.24.1 github.com/huandu/xstrings v1.4.0 - github.com/substrait-io/substrait-go v0.5.0 + github.com/substrait-io/substrait-go v0.6.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/go/go.sum b/go/go.sum index 2e89a769024ac..c7eb3a66deeec 100644 --- a/go/go.sum +++ b/go/go.sum @@ -24,8 +24,8 @@ github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8c github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= +github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.11.0 h1:n7Z+zx8S9f9KgzG6KtQKf+kwqXZlLNR2F6018Dgau54= @@ -99,8 +99,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/substrait-io/substrait-go v0.5.0 
h1:8sYsoqcrzoNpThPyot1CQpwF6OokxvplLUQJTGlKws4= -github.com/substrait-io/substrait-go v0.5.0/go.mod h1:Co7ko6iIjdqCGcN3LfkKWPVlxONkNZem9omWAGIaOrQ= +github.com/substrait-io/substrait-go v0.6.0 h1:n2G/SGmrn7U5Q39VA8WeM2UfVL5Y/6HX8WAP9uJLNk4= +github.com/substrait-io/substrait-go v0.6.0/go.mod h1:cl8Wsc7aBPDfcHp9+OrUqGpjkgrYlhcDsH/lMP6KUZA= github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java index feb7edfec9495..2921e43cb6410 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java @@ -382,6 +382,17 @@ public VectorWithOrdinal getChildVectorWithOrdinal(String name) { return new VectorWithOrdinal(vector, ordinal); } + /** + * Return the underlying buffers associated with this vector. Note that this doesn't impact the + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it + * (unless they change it). + * + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. + * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. 
+ */ @Override public ArrowBuf[] getBuffers(boolean clear) { final List buffers = new ArrayList<>(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java index 26079cbee951a..f643306cfdcff 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java @@ -102,7 +102,7 @@ private void allocateBuffers() { sizeBuffer = allocateBuffers(sizeAllocationSizeInBytes); } - private ArrowBuf allocateBuffers(final long size) { + protected ArrowBuf allocateBuffers(final long size) { final int curSize = (int) size; ArrowBuf buffer = allocator.buffer(curSize); buffer.readerIndex(0); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java index 1cdb87eba0376..fbe83bad52cf1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java @@ -271,6 +271,17 @@ public void reset() { valueCount = 0; } + /** + * Return the underlying buffers associated with this vector. Note that this doesn't impact the + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it + * (unless they change it). + * + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. + * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. 
+ */ @Override public ArrowBuf[] getBuffers(boolean clear) { final ArrowBuf[] buffers; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java index cb4550848088c..c762eb51725ca 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java @@ -360,6 +360,17 @@ public void reset() { valueCount = 0; } + /** + * Return the underlying buffers associated with this vector. Note that this doesn't impact the + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it + * (unless they change it). + * + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. + * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. + */ @Override public ArrowBuf[] getBuffers(boolean clear) { setReaderAndWriterIndex(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java index b5b32c8032dfe..ed075352c931c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java @@ -882,12 +882,13 @@ public void reset() { /** * Return the underlying buffers associated with this vector. Note that this doesn't impact the - * reference counts for this buffer so it only should be used for in-context access. 
Also note - * that this buffer changes regularly thus external classes shouldn't hold a reference to it + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it * (unless they change it). * - * @param clear Whether to clear vector before returning; the buffers will still be refcounted but - * the returned array will be the only reference to them + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 17ccdbf0eae39..2c61f799a4cf9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -39,6 +39,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueIterableVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.UnionLargeListViewReader; import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; @@ -361,20 +362,17 @@ public TransferPair getTransferPair(Field field, BufferAllocator allocator) { @Override public TransferPair getTransferPair(String ref, BufferAllocator allocator, CallBack callBack) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support getTransferPair(String, BufferAllocator, CallBack) yet"); + return new TransferImpl(ref, allocator, callBack); } @Override 
public TransferPair getTransferPair(Field field, BufferAllocator allocator, CallBack callBack) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support getTransferPair(Field, BufferAllocator, CallBack) yet"); + return new TransferImpl(field, allocator, callBack); } @Override public TransferPair makeTransferPair(ValueVector target) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support makeTransferPair(ValueVector) yet"); + return new TransferImpl((LargeListViewVector) target); } @Override @@ -452,6 +450,159 @@ public OUT accept(VectorVisitor visitor, IN value) { return visitor.visit(this, value); } + private class TransferImpl implements TransferPair { + + LargeListViewVector to; + TransferPair dataTransferPair; + + public TransferImpl(String name, BufferAllocator allocator, CallBack callBack) { + this(new LargeListViewVector(name, allocator, field.getFieldType(), callBack)); + } + + public TransferImpl(Field field, BufferAllocator allocator, CallBack callBack) { + this(new LargeListViewVector(field, allocator, callBack)); + } + + public TransferImpl(LargeListViewVector to) { + this.to = to; + to.addOrGetVector(vector.getField().getFieldType()); + if (to.getDataVector() instanceof ZeroVector) { + to.addOrGetVector(vector.getField().getFieldType()); + } + dataTransferPair = getDataVector().makeTransferPair(to.getDataVector()); + } + + @Override + public void transfer() { + to.clear(); + dataTransferPair.transfer(); + to.validityBuffer = transferBuffer(validityBuffer, to.allocator); + to.offsetBuffer = transferBuffer(offsetBuffer, to.allocator); + to.sizeBuffer = transferBuffer(sizeBuffer, to.allocator); + if (valueCount > 0) { + to.setValueCount(valueCount); + } + clear(); + } + + @Override + public void splitAndTransfer(int startIndex, int length) { + Preconditions.checkArgument( + startIndex >= 0 && length >= 0 && startIndex + length <= valueCount, + "Invalid parameters startIndex: %s, length: %s for 
valueCount: %s", + startIndex, + length, + valueCount); + to.clear(); + if (length > 0) { + // we have to scan by index since there are out-of-order offsets + to.offsetBuffer = to.allocateBuffers((long) length * OFFSET_WIDTH); + to.sizeBuffer = to.allocateBuffers((long) length * SIZE_WIDTH); + + /* splitAndTransfer the size buffer */ + int maxOffsetAndSizeSum = Integer.MIN_VALUE; + int minOffsetValue = Integer.MAX_VALUE; + for (int i = 0; i < length; i++) { + final int offsetValue = offsetBuffer.getInt((long) (startIndex + i) * OFFSET_WIDTH); + final int sizeValue = sizeBuffer.getInt((long) (startIndex + i) * SIZE_WIDTH); + to.sizeBuffer.setInt((long) i * SIZE_WIDTH, sizeValue); + maxOffsetAndSizeSum = Math.max(maxOffsetAndSizeSum, offsetValue + sizeValue); + minOffsetValue = Math.min(minOffsetValue, offsetValue); + } + + /* splitAndTransfer the offset buffer */ + for (int i = 0; i < length; i++) { + final int offsetValue = offsetBuffer.getInt((long) (startIndex + i) * OFFSET_WIDTH); + final int relativeOffset = offsetValue - minOffsetValue; + to.offsetBuffer.setInt((long) i * OFFSET_WIDTH, relativeOffset); + } + + /* splitAndTransfer the validity buffer */ + splitAndTransferValidityBuffer(startIndex, length, to); + + /* splitAndTransfer the data buffer */ + final int childSliceLength = maxOffsetAndSizeSum - minOffsetValue; + dataTransferPair.splitAndTransfer(minOffsetValue, childSliceLength); + to.setValueCount(length); + } + } + + /* + * transfer the validity. 
+ */ + private void splitAndTransferValidityBuffer( + int startIndex, int length, LargeListViewVector target) { + int firstByteSource = BitVectorHelper.byteIndex(startIndex); + int lastByteSource = BitVectorHelper.byteIndex(valueCount - 1); + int byteSizeTarget = getValidityBufferSizeFromCount(length); + int offset = startIndex % 8; + + if (length > 0) { + if (offset == 0) { + // slice + if (target.validityBuffer != null) { + target.validityBuffer.getReferenceManager().release(); + } + target.validityBuffer = validityBuffer.slice(firstByteSource, byteSizeTarget); + target.validityBuffer.getReferenceManager().retain(1); + } else { + /* Copy data + * When the first bit starts from the middle of a byte (offset != 0), + * copy data from src BitVector. + * Each byte in the target is composed by a part in i-th byte, + * another part in (i+1)-th byte. + */ + target.allocateValidityBuffer(byteSizeTarget); + + for (int i = 0; i < byteSizeTarget - 1; i++) { + byte b1 = + BitVectorHelper.getBitsFromCurrentByte(validityBuffer, firstByteSource + i, offset); + byte b2 = + BitVectorHelper.getBitsFromNextByte( + validityBuffer, firstByteSource + i + 1, offset); + + target.validityBuffer.setByte(i, (b1 + b2)); + } + + /* Copying the last piece is done in the following manner: + * if the source vector has 1 or more bytes remaining, we copy + * the last piece as a byte formed by shifting data + * from the current byte and the next byte. + * + * if the source vector has no more bytes remaining + * (we are at the last byte), we copy the last piece as a byte + * by shifting data from the current byte. 
+ */ + if ((firstByteSource + byteSizeTarget - 1) < lastByteSource) { + byte b1 = + BitVectorHelper.getBitsFromCurrentByte( + validityBuffer, firstByteSource + byteSizeTarget - 1, offset); + byte b2 = + BitVectorHelper.getBitsFromNextByte( + validityBuffer, firstByteSource + byteSizeTarget, offset); + + target.validityBuffer.setByte(byteSizeTarget - 1, b1 + b2); + } else { + byte b1 = + BitVectorHelper.getBitsFromCurrentByte( + validityBuffer, firstByteSource + byteSizeTarget - 1, offset); + target.validityBuffer.setByte(byteSizeTarget - 1, b1); + } + } + } + } + + @Override + public ValueVector getTo() { + return to; + } + + @Override + public void copyValueSafe(int from, int to) { + this.to.copyFrom(from, to, LargeListViewVector.this); + } + } + @Override protected FieldReader getReaderImpl() { throw new UnsupportedOperationException( @@ -546,7 +697,8 @@ public void reset() { * (unless they change it). * * @param clear Whether to clear vector before returning, the buffers will still be refcounted but - * the returned array will be the only reference to them + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. 
*/ @Override @@ -561,7 +713,7 @@ public ArrowBuf[] getBuffers(boolean clear) { list.add(validityBuffer); list.add(offsetBuffer); list.add(sizeBuffer); - list.addAll(Arrays.asList(vector.getBuffers(clear))); + list.addAll(Arrays.asList(vector.getBuffers(false))); buffers = list.toArray(new ArrowBuf[list.size()]); } if (clear) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index a1e18210fc686..76682c28fe65d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -726,12 +726,13 @@ public void reset() { /** * Return the underlying buffers associated with this vector. Note that this doesn't impact the - * reference counts for this buffer so it only should be used for in-context access. Also note - * that this buffer changes regularly thus external classes shouldn't hold a reference to it + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it * (unless they change it). * - * @param clear Whether to clear vector before returning; the buffers will still be refcounted but - * the returned array will be the only reference to them + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. 
*/ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 6ced66d81ec21..7f6d92f3be9c8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -704,7 +704,8 @@ public void reset() { * (unless they change it). * * @param clear Whether to clear vector before returning, the buffers will still be refcounted but - * the returned array will be the only reference to them + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java index dda9b6547f758..ca5f572034cee 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java @@ -396,12 +396,13 @@ public int getValueCapacity() { /** * Return the underlying buffers associated with this vector. Note that this doesn't impact the - * reference counts for this buffer so it only should be used for in-context access. Also note - * that this buffer changes regularly thus external classes shouldn't hold a reference to it + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it * (unless they change it). 
* - * @param clear Whether to clear vector before returning; the buffers will still be refcounted but - * the returned array will be the only reference to them + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 626619a9483de..5668325a87eeb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -73,6 +73,7 @@ import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.LargeListView; import org.apache.arrow.vector.types.pojo.ArrowType.ListView; import org.apache.arrow.vector.types.pojo.ArrowType.Union; import org.apache.arrow.vector.types.pojo.Field; @@ -729,7 +730,8 @@ private List readIntoBuffer( } else if (bufferType.equals(OFFSET) || bufferType.equals(SIZE)) { if (type == MinorType.LARGELIST || type == MinorType.LARGEVARCHAR - || type == MinorType.LARGEVARBINARY) { + || type == MinorType.LARGEVARBINARY + || type == MinorType.LARGELISTVIEW) { reader = helper.INT8; } else { reader = helper.INT4; @@ -890,7 +892,10 @@ private void readFromJsonIntoVector(Field field, FieldVector vector) throws IOEx BufferType bufferType = vectorTypes.get(v); nextFieldIs(bufferType.getName()); int innerBufferValueCount = valueCount; - if (bufferType.equals(OFFSET) && !(type instanceof Union) && !(type instanceof ListView)) { + if (bufferType.equals(OFFSET) + && !(type instanceof Union) + && !(type 
instanceof ListView) + && !(type instanceof LargeListView)) { /* offset buffer has 1 additional value capacity except for dense unions and ListView */ innerBufferValueCount = valueCount + 1; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index 929c8c97c0551..68700fe6afd25 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -73,6 +73,7 @@ import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.BaseLargeRepeatedValueViewVector; import org.apache.arrow.vector.complex.BaseRepeatedValueViewVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -232,7 +233,8 @@ private void writeFromVectorIntoJson(Field field, FieldVector vector) throws IOE final int bufferValueCount = (bufferType.equals(OFFSET) && vector.getMinorType() != MinorType.DENSEUNION - && vector.getMinorType() != MinorType.LISTVIEW) + && vector.getMinorType() != MinorType.LISTVIEW + && vector.getMinorType() != MinorType.LARGELISTVIEW) ? 
valueCount + 1 : valueCount; for (int i = 0; i < bufferValueCount; i++) { @@ -274,6 +276,7 @@ private void writeFromVectorIntoJson(Field field, FieldVector vector) throws IOE } else if (bufferType.equals(OFFSET) && vector.getValueCount() == 0 && (vector.getMinorType() == MinorType.LARGELIST + || vector.getMinorType() == MinorType.LARGELISTVIEW || vector.getMinorType() == MinorType.LARGEVARBINARY || vector.getMinorType() == MinorType.LARGEVARCHAR)) { // Empty vectors may not have allocated an offsets buffer @@ -427,6 +430,10 @@ private void writeValueToGenerator( generator.writeNumber( buffer.getInt((long) index * BaseRepeatedValueViewVector.OFFSET_WIDTH)); break; + case LARGELISTVIEW: + generator.writeNumber( + buffer.getInt((long) index * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + break; case LARGELIST: case LARGEVARBINARY: case LARGEVARCHAR: @@ -582,7 +589,12 @@ private void writeValueToGenerator( throw new UnsupportedOperationException("minor type: " + vector.getMinorType()); } } else if (bufferType.equals(SIZE)) { - generator.writeNumber(buffer.getInt((long) index * BaseRepeatedValueViewVector.SIZE_WIDTH)); + if (vector.getMinorType() == MinorType.LISTVIEW) { + generator.writeNumber(buffer.getInt((long) index * BaseRepeatedValueViewVector.SIZE_WIDTH)); + } else { + generator.writeNumber( + buffer.getInt((long) index * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java index 563ac811c4fdb..2ed8d4d7005ea 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java @@ -18,6 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; import 
static org.junit.jupiter.api.Assertions.assertTrue; import java.util.ArrayList; @@ -32,6 +33,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.TransferPair; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -1639,6 +1641,460 @@ public void testOutOfOrderOffset1() { } } + private int validateSizeBufferAndCalculateMinOffset( + int start, + int splitLength, + ArrowBuf fromOffsetBuffer, + ArrowBuf fromSizeBuffer, + ArrowBuf toSizeBuffer) { + int minOffset = fromOffsetBuffer.getInt((long) start * LargeListViewVector.OFFSET_WIDTH); + int fromDataLength; + int toDataLength; + + for (int i = 0; i < splitLength; i++) { + fromDataLength = fromSizeBuffer.getInt((long) (start + i) * LargeListViewVector.SIZE_WIDTH); + toDataLength = toSizeBuffer.getInt((long) (i) * LargeListViewVector.SIZE_WIDTH); + + /* validate size */ + assertEquals( + fromDataLength, + toDataLength, + "Different data lengths at index: " + i + " and start: " + start); + + /* calculate minimum offset */ + int currentOffset = + fromOffsetBuffer.getInt((long) (start + i) * LargeListViewVector.OFFSET_WIDTH); + if (currentOffset < minOffset) { + minOffset = currentOffset; + } + } + + return minOffset; + } + + private void validateOffsetBuffer( + int start, + int splitLength, + ArrowBuf fromOffsetBuffer, + ArrowBuf toOffsetBuffer, + int minOffset) { + int offset1; + int offset2; + + for (int i = 0; i < splitLength; i++) { + offset1 = fromOffsetBuffer.getInt((long) (start + i) * LargeListViewVector.OFFSET_WIDTH); + offset2 = toOffsetBuffer.getInt((long) (i) * LargeListViewVector.OFFSET_WIDTH); + assertEquals( + offset1 - minOffset, + offset2, + "Different offset values at index: " + i + " and start: " + start); + } + } + + private void validateDataBuffer( + int start, + int splitLength, + ArrowBuf 
fromOffsetBuffer, + ArrowBuf fromSizeBuffer, + BigIntVector fromDataVector, + ArrowBuf toOffsetBuffer, + BigIntVector toDataVector) { + int dataLength; + Long fromValue; + for (int i = 0; i < splitLength; i++) { + dataLength = fromSizeBuffer.getInt((long) (start + i) * LargeListViewVector.SIZE_WIDTH); + for (int j = 0; j < dataLength; j++) { + fromValue = + fromDataVector.getObject( + (fromOffsetBuffer.getInt((long) (start + i) * LargeListViewVector.OFFSET_WIDTH) + + j)); + Long toValue = + toDataVector.getObject( + (toOffsetBuffer.getInt((long) i * LargeListViewVector.OFFSET_WIDTH) + j)); + assertEquals( + fromValue, toValue, "Different data values at index: " + i + " and start: " + start); + } + } + } + + /** + * Validate split and transfer of data from fromVector to toVector. Note that this method assumes + * that the child vector is BigIntVector. + * + * @param start start index + * @param splitLength length of data to split and transfer + * @param fromVector fromVector + * @param toVector toVector + */ + private void validateSplitAndTransfer( + TransferPair transferPair, + int start, + int splitLength, + LargeListViewVector fromVector, + LargeListViewVector toVector) { + + transferPair.splitAndTransfer(start, splitLength); + + /* get offsetBuffer of toVector */ + final ArrowBuf toOffsetBuffer = toVector.getOffsetBuffer(); + + /* get sizeBuffer of toVector */ + final ArrowBuf toSizeBuffer = toVector.getSizeBuffer(); + + /* get dataVector of toVector */ + BigIntVector toDataVector = (BigIntVector) toVector.getDataVector(); + + /* get offsetBuffer of fromVector */ + final ArrowBuf fromOffsetBuffer = fromVector.getOffsetBuffer(); + + /* get sizeBuffer of fromVector */ + final ArrowBuf fromSizeBuffer = fromVector.getSizeBuffer(); + + /* get dataVector of fromVector */ + BigIntVector fromDataVector = (BigIntVector) fromVector.getDataVector(); + + /* validate size buffers */ + int minOffset = + validateSizeBufferAndCalculateMinOffset( + start, splitLength, 
fromOffsetBuffer, fromSizeBuffer, toSizeBuffer); + /* validate offset buffers */ + validateOffsetBuffer(start, splitLength, fromOffsetBuffer, toOffsetBuffer, minOffset); + /* validate data */ + validateDataBuffer( + start, + splitLength, + fromOffsetBuffer, + fromSizeBuffer, + fromDataVector, + toOffsetBuffer, + toDataVector); + } + + @Test + public void testSplitAndTransfer() throws Exception { + try (LargeListViewVector fromVector = LargeListViewVector.empty("sourceVector", allocator)) { + + /* Explicitly add the dataVector */ + MinorType type = MinorType.BIGINT; + fromVector.addOrGetVector(FieldType.nullable(type.getType())); + + UnionLargeListViewWriter listViewWriter = fromVector.getWriter(); + + /* allocate memory */ + listViewWriter.allocate(); + + /* populate data */ + listViewWriter.setPosition(0); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(10); + listViewWriter.bigInt().writeBigInt(11); + listViewWriter.bigInt().writeBigInt(12); + listViewWriter.endListView(); + + listViewWriter.setPosition(1); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(13); + listViewWriter.bigInt().writeBigInt(14); + listViewWriter.endListView(); + + listViewWriter.setPosition(2); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(15); + listViewWriter.bigInt().writeBigInt(16); + listViewWriter.bigInt().writeBigInt(17); + listViewWriter.bigInt().writeBigInt(18); + listViewWriter.endListView(); + + listViewWriter.setPosition(3); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(19); + listViewWriter.endListView(); + + listViewWriter.setPosition(4); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(20); + listViewWriter.bigInt().writeBigInt(21); + listViewWriter.bigInt().writeBigInt(22); + listViewWriter.bigInt().writeBigInt(23); + listViewWriter.endListView(); + + fromVector.setValueCount(5); + + /* get offset buffer */ + final ArrowBuf offsetBuffer = 
fromVector.getOffsetBuffer(); + + /* get size buffer */ + final ArrowBuf sizeBuffer = fromVector.getSizeBuffer(); + + /* get dataVector */ + BigIntVector dataVector = (BigIntVector) fromVector.getDataVector(); + + /* check the vector output */ + + int index = 0; + int offset; + int size = 0; + Long actual; + + /* index 0 */ + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(0), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(10), actual); + offset++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(11), actual); + offset++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(12), actual); + assertEquals( + Integer.toString(3), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 1 */ + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(3), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(13), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(14), actual); + size++; + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 2 */ + size = 0; + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(5), Integer.toString(offset)); + size++; + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(15), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(16), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(17), actual); + offset++; + size++; + actual = 
dataVector.getObject(offset); + assertEquals(Long.valueOf(18), actual); + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 3 */ + size = 0; + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(9), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(19), actual); + size++; + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 4 */ + size = 0; + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(10), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(20), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(21), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(22), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(23), actual); + size++; + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* do split and transfer */ + try (LargeListViewVector toVector = LargeListViewVector.empty("toVector", allocator)) { + int[][] transferLengths = {{0, 2}, {3, 1}, {4, 1}}; + TransferPair transferPair = fromVector.makeTransferPair(toVector); + + for (final int[] transferLength : transferLengths) { + int start = transferLength[0]; + int splitLength = transferLength[1]; + validateSplitAndTransfer(transferPair, start, splitLength, fromVector, toVector); + } + } + } + } + + @Test + public void testGetTransferPairWithField() throws Exception { + try (final LargeListViewVector fromVector = 
LargeListViewVector.empty("listview", allocator)) { + + UnionLargeListViewWriter writer = fromVector.getWriter(); + writer.allocate(); + + // set some values + writer.startListView(); + writer.integer().writeInt(1); + writer.integer().writeInt(2); + writer.endListView(); + fromVector.setValueCount(2); + + final TransferPair transferPair = + fromVector.getTransferPair(fromVector.getField(), allocator); + final LargeListViewVector toVector = (LargeListViewVector) transferPair.getTo(); + // Field inside a new vector created by reusing a field should be the same in memory as the + // original field. + assertSame(toVector.getField(), fromVector.getField()); + } + } + + @Test + public void testOutOfOrderOffsetSplitAndTransfer() { + // [[12, -7, 25], null, [0, -127, 127, 50], [], [50, 12]] + try (LargeListViewVector fromVector = LargeListViewVector.empty("fromVector", allocator)) { + // Allocate buffers in LargeListViewVector by calling `allocateNew` method. + fromVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(64, true), null, null); + Field field = new Field("child-vector", fieldType, null); + fromVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = fromVector.getDataVector(); + fieldVector.clear(); + + BigIntVector childVector = (BigIntVector) fieldVector; + + childVector.allocateNew(7); + + childVector.set(0, 0); + childVector.set(1, -127); + childVector.set(2, 127); + childVector.set(3, 50); + childVector.set(4, 12); + childVector.set(5, -7); + childVector.set(6, 25); + + childVector.setValueCount(7); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
+ fromVector.setValidity(0, 1); + fromVector.setValidity(1, 0); + fromVector.setValidity(2, 1); + fromVector.setValidity(3, 1); + fromVector.setValidity(4, 1); + + fromVector.setOffset(0, 4); + fromVector.setOffset(1, 7); + fromVector.setOffset(2, 0); + fromVector.setOffset(3, 0); + fromVector.setOffset(4, 3); + + fromVector.setSize(0, 3); + fromVector.setSize(1, 0); + fromVector.setSize(2, 4); + fromVector.setSize(3, 0); + fromVector.setSize(4, 2); + + // Set value count using `setValueCount` method. + fromVector.setValueCount(5); + + final ArrowBuf offSetBuffer = fromVector.getOffsetBuffer(); + final ArrowBuf sizeBuffer = fromVector.getSizeBuffer(); + + // check offset buffer + assertEquals(4, offSetBuffer.getInt(0 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(7, offSetBuffer.getInt(1 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(0, offSetBuffer.getInt(2 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(0, offSetBuffer.getInt(3 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(3, offSetBuffer.getInt(4 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + + // check size buffer + assertEquals(3, sizeBuffer.getInt(0 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(0, sizeBuffer.getInt(1 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(4, sizeBuffer.getInt(2 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(0, sizeBuffer.getInt(3 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(2, sizeBuffer.getInt(4 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + + // check child vector + assertEquals(0, ((BigIntVector) fromVector.getDataVector()).get(0)); + assertEquals(-127, ((BigIntVector) fromVector.getDataVector()).get(1)); + assertEquals(127, ((BigIntVector) fromVector.getDataVector()).get(2)); + assertEquals(50, ((BigIntVector) fromVector.getDataVector()).get(3)); + assertEquals(12, ((BigIntVector) fromVector.getDataVector()).get(4)); + 
assertEquals(-7, ((BigIntVector) fromVector.getDataVector()).get(5)); + assertEquals(25, ((BigIntVector) fromVector.getDataVector()).get(6)); + + // check values + Object result = fromVector.getObject(0); + ArrayList resultSet = (ArrayList) result; + assertEquals(3, resultSet.size()); + assertEquals(Long.valueOf(12), resultSet.get(0)); + assertEquals(Long.valueOf(-7), resultSet.get(1)); + assertEquals(Long.valueOf(25), resultSet.get(2)); + + assertTrue(fromVector.isNull(1)); + + result = fromVector.getObject(2); + resultSet = (ArrayList) result; + assertEquals(4, resultSet.size()); + assertEquals(Long.valueOf(0), resultSet.get(0)); + assertEquals(Long.valueOf(-127), resultSet.get(1)); + assertEquals(Long.valueOf(127), resultSet.get(2)); + assertEquals(Long.valueOf(50), resultSet.get(3)); + + assertTrue(fromVector.isEmpty(3)); + + result = fromVector.getObject(4); + resultSet = (ArrayList) result; + assertEquals(2, resultSet.size()); + assertEquals(Long.valueOf(50), resultSet.get(0)); + assertEquals(Long.valueOf(12), resultSet.get(1)); + + fromVector.validate(); + + /* do split and transfer */ + try (LargeListViewVector toVector = LargeListViewVector.empty("toVector", allocator)) { + int[][] transferLengths = {{2, 3}, {0, 1}, {0, 3}}; + TransferPair transferPair = fromVector.makeTransferPair(toVector); + + for (final int[] transferLength : transferLengths) { + int start = transferLength[0]; + int splitLength = transferLength[1]; + validateSplitAndTransfer(transferPair, start, splitLength, fromVector, toVector); + } + } + } + } + private void writeIntValues(UnionLargeListViewWriter writer, int[] values) { writer.startListView(); for (int v : values) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java b/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java index d20dc3348b1c9..a3f25bc5207b6 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java +++ 
b/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java @@ -29,6 +29,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; @@ -852,6 +853,25 @@ public void testListVectorZeroStartIndexAndLength() { } } + @Test + public void testLargeListViewVectorZeroStartIndexAndLength() { + try (final LargeListViewVector listVector = + LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector newListVector = LargeListViewVector.empty("newList", allocator)) { + + listVector.allocateNew(); + final int valueCount = 0; + listVector.setValueCount(valueCount); + + final TransferPair tp = listVector.makeTransferPair(newListVector); + + tp.splitAndTransfer(0, 0); + assertEquals(valueCount, newListVector.getValueCount()); + + newListVector.clear(); + } + } + @Test public void testStructVectorZeroStartIndexAndLength() { Map metadata = new HashMap<>(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java index 48cf78a4c2e4a..28d73a8fdfff9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java @@ -25,6 +25,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.NonNullableStructVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; @@ -122,7 +123,10 @@ public 
void testListTypeReset() { "VarList", allocator, FieldType.nullable(MinorType.INT.getType()), null); final FixedSizeListVector fixedList = new FixedSizeListVector( - "FixedList", allocator, FieldType.nullable(new FixedSizeList(2)), null)) { + "FixedList", allocator, FieldType.nullable(new FixedSizeList(2)), null); + final ListViewVector variableViewList = + new ListViewVector( + "VarListView", allocator, FieldType.nullable(MinorType.INT.getType()), null)) { // ListVector variableList.allocateNewSafe(); variableList.startNewValue(0); @@ -136,6 +140,13 @@ public void testListTypeReset() { fixedList.setNull(0); fixedList.setValueCount(1); resetVectorAndVerify(fixedList, fixedList.getBuffers(false)); + + // ListViewVector + variableViewList.allocateNewSafe(); + variableViewList.startNewValue(0); + variableViewList.endValue(0, 0); + variableViewList.setValueCount(1); + resetVectorAndVerify(variableViewList, variableViewList.getBuffers(false)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java index c69a3bfbc1ee2..8037212aaea21 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java @@ -437,10 +437,18 @@ public void testRoundtripEmptyVector() throws Exception { "list", FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), + new Field( + "listview", + FieldType.nullable(ArrowType.ListView.INSTANCE), + Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), new Field( "largelist", FieldType.nullable(ArrowType.LargeList.INSTANCE), Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), + new Field( + "largelistview", + FieldType.nullable(ArrowType.LargeListView.INSTANCE), + Collections.singletonList(Field.nullable("items", new 
ArrowType.Int(32, true)))), new Field( "map", FieldType.nullable(new ArrowType.Map(/*keyssorted*/ false)), diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index a90dee70584b1..5d5eeaf8157b4 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -260,6 +260,7 @@ message(STATUS "Found NumPy version: ${Python3_NumPy_VERSION}") message(STATUS "NumPy include dir: ${NUMPY_INCLUDE_DIRS}") include(UseCython) +message(STATUS "Found Cython version: ${CYTHON_VERSION}") # Arrow C++ and set default PyArrow build options include(GNUInstallDirs) @@ -855,6 +856,10 @@ set(CYTHON_FLAGS "${CYTHON_FLAGS}" "--warning-errors") # undocumented Cython feature. set(CYTHON_FLAGS "${CYTHON_FLAGS}" "--no-c-in-traceback") +if(CYTHON_VERSION VERSION_GREATER_EQUAL "3.1.0a0") + list(APPEND CYTHON_FLAGS "-Xfreethreading_compatible=True") +endif() + foreach(module ${CYTHON_EXTENSIONS}) string(REPLACE "." ";" directories ${module}) list(GET directories -1 module_name) diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index e52e0d242bee5..aa7bab9f97e05 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -173,6 +173,7 @@ def print_entry(label, value): dictionary, run_end_encoded, fixed_shape_tensor, + opaque, field, type_for_alias, DataType, DictionaryType, StructType, @@ -182,7 +183,7 @@ def print_entry(label, value): TimestampType, Time32Type, Time64Type, DurationType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, - RunEndEncodedType, FixedShapeTensorType, + RunEndEncodedType, FixedShapeTensorType, OpaqueType, PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, @@ -216,7 +217,7 @@ def print_entry(label, value): Time32Array, Time64Array, DurationArray, MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, - RunEndEncodedArray, FixedShapeTensorArray, + RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, 
scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, @@ -233,7 +234,8 @@ def print_entry(label, value): StringScalar, LargeStringScalar, StringViewScalar, FixedSizeBinaryScalar, DictionaryScalar, MapScalar, StructScalar, UnionScalar, - RunEndEncodedScalar, ExtensionScalar) + RunEndEncodedScalar, ExtensionScalar, + FixedShapeTensorScalar, OpaqueScalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 997f208a5dec4..6c40a21db96ca 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -4448,6 +4448,34 @@ cdef class FixedShapeTensorArray(ExtensionArray): ) +cdef class OpaqueArray(ExtensionArray): + """ + Concrete class for opaque extension arrays. + + Examples + -------- + Define the extension type for an opaque array + + >>> import pyarrow as pa + >>> opaque_type = pa.opaque( + ... pa.binary(), + ... type_name="geometry", + ... vendor_name="postgis", + ... 
) + + Create an extension array + + >>> arr = [None, b"data"] + >>> storage = pa.array(arr, pa.binary()) + >>> pa.ExtensionArray.from_storage(opaque_type, storage) + + [ + null, + 64617461 + ] + """ + + cdef dict _array_classes = { _Type_NA: NullArray, _Type_BOOL: BooleanArray, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0d871f411b11b..9b008d150f1f1 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2882,6 +2882,19 @@ cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extens " arrow::extension::FixedShapeTensorArray"(CExtensionArray): const CResult[shared_ptr[CTensor]] ToTensor() const + +cdef extern from "arrow/extension/opaque.h" namespace "arrow::extension" nogil: + cdef cppclass COpaqueType \ + " arrow::extension::OpaqueType"(CExtensionType): + + c_string type_name() + c_string vendor_name() + + cdef cppclass COpaqueArray \ + " arrow::extension::OpaqueArray"(CExtensionArray): + pass + + cdef extern from "arrow/util/compression.h" namespace "arrow" nogil: cdef enum CCompressionType" arrow::Compression::type": CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED" diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 082d8470cdbb0..2cb302d20a8ac 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -215,6 +215,11 @@ cdef class FixedShapeTensorType(BaseExtensionType): const CFixedShapeTensorType* tensor_ext_type +cdef class OpaqueType(BaseExtensionType): + cdef: + const COpaqueType* opaque_ext_type + + cdef class PyExtensionType(ExtensionType): pass diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 966273b4bea84..2f9fc1c554209 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -124,6 +124,8 @@ cdef api object pyarrow_wrap_data_type( return cpy_ext_type.GetInstance() elif ext_type.extension_name() == b"arrow.fixed_shape_tensor": out = 
FixedShapeTensorType.__new__(FixedShapeTensorType) + elif ext_type.extension_name() == b"arrow.opaque": + out = OpaqueType.__new__(OpaqueType) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 41bfde39adb6f..12a99c2aece63 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1085,6 +1085,12 @@ cdef class FixedShapeTensorScalar(ExtensionScalar): return pyarrow_wrap_tensor(ctensor) +cdef class OpaqueScalar(ExtensionScalar): + """ + Concrete class for opaque extension scalar. + """ + + cdef dict _scalar_classes = { _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 1c4d0175a2d97..58c54189f223e 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1661,3 +1661,49 @@ def test_legacy_int_type(): batch = ipc_read_batch(buf) assert isinstance(batch.column(0).type, LegacyIntType) assert batch.column(0) == ext_arr + + +@pytest.mark.parametrize("storage_type,storage", [ + (pa.null(), [None] * 4), + (pa.int64(), [1, 2, None, 4]), + (pa.binary(), [None, b"foobar"]), + (pa.list_(pa.int64()), [[], [1, 2], None, [3, None]]), +]) +def test_opaque_type(pickle_module, storage_type, storage): + opaque_type = pa.opaque(storage_type, "type", "vendor") + assert opaque_type.extension_name == "arrow.opaque" + assert opaque_type.storage_type == storage_type + assert opaque_type.type_name == "type" + assert opaque_type.vendor_name == "vendor" + assert "arrow.opaque" in str(opaque_type) + + assert opaque_type == opaque_type + assert opaque_type != storage_type + assert opaque_type != pa.opaque(storage_type, "type2", "vendor") + assert opaque_type != pa.opaque(storage_type, "type", "vendor2") + assert opaque_type != pa.opaque(pa.decimal128(12, 3), "type", "vendor") + + # Pickle roundtrip + result = 
pickle_module.loads(pickle_module.dumps(opaque_type)) + assert result == opaque_type + + # IPC roundtrip + opaque_arr_class = opaque_type.__arrow_ext_class__() + storage = pa.array(storage, storage_type) + arr = pa.ExtensionArray.from_storage(opaque_type, storage) + assert isinstance(arr, opaque_arr_class) + + with registered_extension_type(opaque_type): + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) + + assert batch.column(0).type.extension_name == "arrow.opaque" + assert isinstance(batch.column(0), opaque_arr_class) + + # cast storage -> extension type + result = storage.cast(opaque_type) + assert result == arr + + # cast extension type -> storage type + inner = arr.cast(storage_type) + assert inner == storage diff --git a/python/pyarrow/tests/test_misc.py b/python/pyarrow/tests/test_misc.py index c42e4fbdfc2e8..9a55a38177fc8 100644 --- a/python/pyarrow/tests/test_misc.py +++ b/python/pyarrow/tests/test_misc.py @@ -247,6 +247,9 @@ def test_set_timezone_db_path_non_windows(): pa.ProxyMemoryPool, pa.Device, pa.MemoryManager, + pa.OpaqueArray, + pa.OpaqueScalar, + pa.OpaqueType, ]) def test_extension_type_constructor_errors(klass): # ARROW-2638: prevent calling extension class constructors directly diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 039870accddcb..93d68fb847890 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1837,6 +1837,50 @@ cdef class FixedShapeTensorType(BaseExtensionType): return FixedShapeTensorScalar +cdef class OpaqueType(BaseExtensionType): + """ + Concrete class for opaque extension type. + + Opaque is a placeholder for a type from an external (often non-Arrow) + system that could not be interpreted. 
+ + Examples + -------- + Create an instance of opaque extension type: + + >>> import pyarrow as pa + >>> pa.opaque(pa.int32(), "geometry", "postgis") + OpaqueType(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.opaque_ext_type = type.get() + + @property + def type_name(self): + """ + The name of the type in the external system. + """ + return frombytes(c_string(self.opaque_ext_type.type_name())) + + @property + def vendor_name(self): + """ + The name of the external system. + """ + return frombytes(c_string(self.opaque_ext_type.vendor_name())) + + def __arrow_ext_class__(self): + return OpaqueArray + + def __reduce__(self): + return opaque, (self.storage_type, self.type_name, self.vendor_name) + + def __arrow_ext_scalar_class__(self): + return OpaqueScalar + + _py_extension_type_auto_load = False @@ -5234,6 +5278,63 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N return out +def opaque(DataType storage_type, str type_name not None, str vendor_name not None): + """ + Create instance of opaque extension type. + + Parameters + ---------- + storage_type : DataType + The underlying data type. + type_name : str + The name of the type in the external system. + vendor_name : str + The name of the external system. 
+ + Examples + -------- + Create an instance of an opaque extension type: + + >>> import pyarrow as pa + >>> type = pa.opaque(pa.binary(), "other", "jdbc") + >>> type + OpaqueType(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(binary) + >>> type.type_name + 'other' + >>> type.vendor_name + 'jdbc' + + Create a table with an opaque array: + + >>> arr = [None, b"foobar"] + >>> storage = pa.array(arr, pa.binary()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[null,666F6F626172]] + + Returns + ------- + type : OpaqueType + """ + + cdef: + c_string c_type_name = tobytes(type_name) + c_string c_vendor_name = tobytes(vendor_name) + shared_ptr[CDataType] c_type = make_shared[COpaqueType]( + storage_type.sp_type, c_type_name, c_vendor_name) + OpaqueType out = OpaqueType.__new__(OpaqueType) + out.init(c_type) + return out + + cdef dict _type_aliases = { 'null': null, 'bool': bool_,