From d5e9e1329385497223fa0e213c65c2a4aba45466 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 25 Mar 2024 04:57:14 +0800 Subject: [PATCH 01/13] GH-40751: [C++] Fix protobuf package name setting for builds with substrait (#40753) ### Rationale for this change The problem #40751 seems to be introduced by #40399. Though I'm not entirely sure about the purpose of that, it seems to be missing an `OR ARROW_SUBSTRAIT` in the `if` branch in https://github.com/apache/arrow/commit/5baca0f16e924c42741729f041b31a02883548b9#diff-5cdc95f4e1b618f2f3ef10d370ce05a1ac05d9d401aecff3ccbb3d76bd366b6aR1815 Because other than `ARROW_ORC`, `ARROW_WITH_OPENTELEMETRY` and `ARROW_FLIGHT`, `ARROW_SUBSTRAIT` also implies `ARROW_WITH_PROTOBUF`: https://github.com/apache/arrow/blob/5baca0f16e924c42741729f041b31a02883548b9/cpp/cmake_modules/ThirdpartyToolchain.cmake#L421-L423 ### What changes are included in this PR? Add the possible missing condition of `ARROW_SUBSTRAIT` for the questioning `if` branch. ### Are these changes tested? Manually tested. ### Are there any user-facing changes? None. 
* GitHub Issue: #40751 Lead-authored-by: Ruoxi Sun Co-authored-by: Rossi Sun Co-authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index b8e765f08587a..ad7344b09dd4e 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1812,7 +1812,9 @@ if(ARROW_WITH_PROTOBUF) else() set(ARROW_PROTOBUF_REQUIRED_VERSION "2.6.1") endif() - if(ARROW_ORC OR ARROW_WITH_OPENTELEMETRY) + if(ARROW_ORC + OR ARROW_SUBSTRAIT + OR ARROW_WITH_OPENTELEMETRY) set(ARROW_PROTOBUF_ARROW_CMAKE_PACKAGE_NAME "Arrow") set(ARROW_PROTOBUF_ARROW_PC_PACKAGE_NAME "arrow") elseif(ARROW_FLIGHT) From f100eff39fd37538c5ab4572083029622fc0f5aa Mon Sep 17 00:00:00 2001 From: ZhangHuiGui <106943008+ZhangHuiGui@users.noreply.github.com> Date: Mon, 25 Mar 2024 13:02:24 +0800 Subject: [PATCH 02/13] GH-40308: [C++][Gandiva] Add support for compute module's decimal promotion rules (#40434) ### Rationale for this change Gandiva decimal divide rules are different with our compute module's rules. Some systems such as Redshift use the same rules as our compute module's rules. So it's useful that Gandiva support our compute module's rules too. ### What changes are included in this PR? Support an option argument in GetResultType for compatibilty with **compute module's decimal promotion rules**. ### Are these changes tested? Yes ### Are there any user-facing changes? 
No * GitHub Issue: #40308 Authored-by: ZhangHuiGui Signed-off-by: Sutou Kouhei --- cpp/src/gandiva/decimal_type_util.cc | 19 ++++++++-- cpp/src/gandiva/decimal_type_util.h | 14 ++++++- cpp/src/gandiva/tests/decimal_single_test.cc | 40 +++++++++++++++++--- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/cpp/src/gandiva/decimal_type_util.cc b/cpp/src/gandiva/decimal_type_util.cc index 2abc5a21eaa88..cce4292f3bf15 100644 --- a/cpp/src/gandiva/decimal_type_util.cc +++ b/cpp/src/gandiva/decimal_type_util.cc @@ -30,7 +30,8 @@ constexpr int32_t DecimalTypeUtil::kMinAdjustedScale; // Implementation of decimal rules. Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_types, - Decimal128TypePtr* out_type) { + Decimal128TypePtr* out_type, + bool use_compute_rules) { DCHECK_EQ(in_types.size(), 2); *out_type = nullptr; @@ -59,7 +60,9 @@ Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_type break; case kOpDivide: - result_scale = std::max(kMinAdjustedScale, s1 + p2 + 1); + result_scale = use_compute_rules + ? 
std::max(kMinComputeAdjustedScale, s1 + p2 - s2 + 1) + : std::max(kMinAdjustedScale, s1 + p2 + 1); result_precision = p1 - s1 + s2 + result_scale; break; @@ -68,7 +71,17 @@ Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_type result_precision = std::min(p1 - s1, p2 - s2) + result_scale; break; } - *out_type = MakeAdjustedType(result_precision, result_scale); + + if (use_compute_rules) { + if (result_precision < kMinPrecision || result_precision > kMaxPrecision) { + return Status::Invalid("Decimal precision out of range [", int32_t(kMinPrecision), + ", ", int32_t(kMaxPrecision), "]: ", result_precision); + } + *out_type = MakeType(result_precision, result_scale); + } else { + *out_type = MakeAdjustedType(result_precision, result_scale); + } + return Status::OK(); } diff --git a/cpp/src/gandiva/decimal_type_util.h b/cpp/src/gandiva/decimal_type_util.h index 2b496f6cbf5bd..16ce544717e46 100644 --- a/cpp/src/gandiva/decimal_type_util.h +++ b/cpp/src/gandiva/decimal_type_util.h @@ -45,6 +45,9 @@ class GANDIVA_EXPORT DecimalTypeUtil { /// The maximum precision representable by a 8-byte decimal static constexpr int32_t kMaxDecimal64Precision = 18; + /// The minimum precision representable by a 16-byte decimal + static constexpr int32_t kMinPrecision = 1; + /// The maximum precision representable by a 16-byte decimal static constexpr int32_t kMaxPrecision = 38; @@ -57,10 +60,19 @@ class GANDIVA_EXPORT DecimalTypeUtil { // * There is no strong reason for 6, but both SQLServer and Impala use 6 too. static constexpr int32_t kMinAdjustedScale = 6; + // The same function with kMinAdjustedScale, just for compatibility with + // compute module's decimal promotion rules. + static constexpr int32_t kMinComputeAdjustedScale = 4; + // For specified operation and input scale/precision, determine the output // scale/precision. 
+ // + // The 'use_compute_rules' is for compatibility with compute module's + // decimal promotion rules: + // https://arrow.apache.org/docs/cpp/compute.html#arithmetic-functions static Status GetResultType(Op op, const Decimal128TypeVector& in_types, - Decimal128TypePtr* out_type); + Decimal128TypePtr* out_type, + bool use_compute_rules = false); static Decimal128TypePtr MakeType(int32_t precision, int32_t scale) { return std::dynamic_pointer_cast( diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc b/cpp/src/gandiva/tests/decimal_single_test.cc index 666ee4a68d5de..57c281a4551ef 100644 --- a/cpp/src/gandiva/tests/decimal_single_test.cc +++ b/cpp/src/gandiva/tests/decimal_single_test.cc @@ -49,7 +49,8 @@ class TestDecimalOps : public ::testing::Test { ArrayPtr MakeDecimalVector(const DecimalScalar128& in); void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x, - const DecimalScalar128& y, const DecimalScalar128& expected); + const DecimalScalar128& y, const DecimalScalar128& expected, + bool use_compute_rules = false, bool verify_failed = false); void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, const DecimalScalar128& expected) { @@ -67,8 +68,10 @@ class TestDecimalOps : public ::testing::Test { } void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected) { - Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected); + const DecimalScalar128& expected, bool use_compute_rules = false, + bool verify_failed = false) { + Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected, use_compute_rules, + verify_failed); } void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, @@ -91,7 +94,8 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) { void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function, const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected) { + 
const DecimalScalar128& expected, bool use_compute_rules, + bool verify_failed) { auto x_type = std::make_shared(x.precision(), x.scale()); auto y_type = std::make_shared(y.precision(), y.scale()); auto field_x = field("x", x_type); @@ -99,8 +103,14 @@ void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function, auto schema = arrow::schema({field_x, field_y}); Decimal128TypePtr output_type; - auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type); - ARROW_EXPECT_OK(status); + auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type, + use_compute_rules); + if (verify_failed) { + ASSERT_NOT_OK(status); + return; + } else { + ARROW_EXPECT_OK(status); + } // output fields auto res = field("res", output_type); @@ -283,13 +293,31 @@ TEST_F(TestDecimalOps, TestMultiply) { } TEST_F(TestDecimalOps, TestDivide) { + // fast-path + // + // origin Gandiva's rules DivideAndVerify(decimal_literal("201", 10, 3), // x decimal_literal("301", 10, 2), // y decimal_literal("6677740863787", 23, 14)); // expected + // compute module's rules + DivideAndVerify(decimal_literal("201", 10, 3), // x + decimal_literal("301", 10, 2), // y + decimal_literal("66777408638", 21, 12), // expected + /*use_compute_rules=*/true); + + // max precision beyond 38 + // + // normally under origin Gandiva rules DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x DecimalScalar128(std::string(35, '9'), 38, 20), // x DecimalScalar128("1000000000", 38, 6)); + + // invalid under compute module's rules + DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x + DecimalScalar128(std::string(35, '9'), 38, 20), // x + DecimalScalar128(std::string(35, '9'), 0, 0), // useless expected + /*use_compute_rules=*/true, /*verify_failed=*/true); } TEST_F(TestDecimalOps, TestMod) { From d236ceac4b8cad199aad88d8d57db1b984087409 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Mon, 25 Mar 2024 18:44:56 +0900 Subject: [PATCH 03/13] 
GH-40623: [Python][Docs] Add workaround for autosummary (#40739) ### Rationale for this change Sphinx < 7.3.0 has a problem that auto generated files may use wrong suffix. See https://github.com/sphinx-doc/sphinx/issues/12147 for details. ### What changes are included in this PR? Use `.rst` as the first `source_suffix` item as a workaround of this problem. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #40623 Authored-by: Sutou Kouhei Signed-off-by: AlenkaF --- docs/source/conf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7915e2c2c485a..ad8fa798d6aac 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -208,8 +208,14 @@ # source_suffix = { - '.md': 'markdown', + # We need to keep "'.rst': 'restructuredtext'" as the first item. + # This is a workaround of + # https://github.com/sphinx-doc/sphinx/issues/12147 . + # + # We can sort these items in alphabetical order with Sphinx 7.3.0 + # or later that will include the fix of this problem. 
'.rst': 'restructuredtext', + '.md': 'markdown', } autosummary_generate = True From 2b0427559d94fb25ce61672c1ed17fc245d0a546 Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Mon, 25 Mar 2024 07:03:49 -0400 Subject: [PATCH 04/13] GH-37720: [Format][Docs][FlightSQL] Document stateless prepared statements (#40243) documents changes for stateless management of FlightSQL prepared statement handles based on the design proposal described in apache/arrow#37720 * GitHub Issue: #37720 PRs for language implementations: * Rust: apache/arrow-rs#5433 * Go: apache/arrow#40311 Mailing list discussion: https://lists.apache.org/thread/3kb82ypx99q96g84qv555l6x8r0bppyq --------- Co-authored-by: David Li Co-authored-by: Andrew Lamb Co-authored-by: Sutou Kouhei --- docs/source/format/FlightSql.rst | 15 ++++++++++++++ .../CommandPreparedStatementQuery.mmd | 2 ++ .../CommandPreparedStatementQuery.mmd.svg | 2 +- format/FlightSql.proto | 20 +++++++++++++++++++ 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/docs/source/format/FlightSql.rst b/docs/source/format/FlightSql.rst index 6bb917271366c..5573c0040761f 100644 --- a/docs/source/format/FlightSql.rst +++ b/docs/source/format/FlightSql.rst @@ -141,6 +141,21 @@ the ``type`` should be ``ClosePreparedStatement``). Execute a previously created prepared statement and get the results. When used with DoPut: binds parameter values to the prepared statement. + The server may optionally provide an updated handle in the response. + Updating the handle allows the client to supply all state required to + execute the query in an ActionPreparedStatementExecute message. + For example, stateless servers can encode the bound parameter values into + the new handle, and the client will send that new handle with parameters + back to the server. + + Note that a handle returned from a DoPut call with + CommandPreparedStatementQuery can itself be passed to a subsequent DoPut + call with CommandPreparedStatementQuery to bind a new set of parameters. 
+ The subsequent call itself may return an updated handle which again should + be used for subsequent requests. + + The server is responsible for detecting the case where the client does not + use the updated handle and should return an error. When used with GetFlightInfo: execute the prepared statement. The prepared statement can be reused after fetching results. diff --git a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd index cb50522eb5a32..cbd1eb6014bca 100644 --- a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd +++ b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd @@ -28,6 +28,8 @@ Server->>Client: ActionCreatePreparedStatementResult{handle} loop for each invocation of the prepared statement Client->>Server: DoPut(CommandPreparedStatementQuery) Client->>Server: stream of FlightData +Server-->>Client: DoPutPreparedStatementResult{handle} +Note over Client,Server: optional response with updated handle Client->>Server: GetFlightInfo(CommandPreparedStatementQuery) Server->>Client: FlightInfo{endpoints: [FlightEndpoint{…}, …]} loop for each endpoint in FlightInfo.endpoints diff --git a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg index 96a5bc3688297..cbf6a78e9a5ce 100644 --- a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg +++ b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg @@ -1 +1 @@ -ClientServerDoAction(ActionCreatePreparedStatementRequest)1ActionCreatePreparedStatementResult{handle}2DoPut(CommandPreparedStatementQuery)3stream of FlightData4GetFlightInfo(CommandPreparedStatementQuery)5FlightInfo{endpoints: [FlightEndpoint{…}, …]}6DoGet(endpoint.ticket)7stream of FlightData8loop[for each endpoint in FlightInfo.endpoints]loop[for each invocation of the prepared 
statement]DoAction(ActionClosePreparedStatementRequest)9ActionClosePreparedStatementRequest{}10ClientServer \ No newline at end of file +ServerClientServerClientoptional response with updated handleloop[for each endpoint in FlightInfo.endpoints]loop[for each invocation of the prepared statement]DoAction(ActionCreatePreparedStatementRequest)1ActionCreatePreparedStatementResult{handle}2DoPut(CommandPreparedStatementQuery)3stream of FlightData4DoPutPreparedStatementResult{handle}5GetFlightInfo(CommandPreparedStatementQuery)6FlightInfo{endpoints: [FlightEndpoint{…}, …]}7DoGet(endpoint.ticket)8stream of FlightData9DoAction(ActionClosePreparedStatementRequest)10ActionClosePreparedStatementRequest{}11 \ No newline at end of file diff --git a/format/FlightSql.proto b/format/FlightSql.proto index 581cf1f76d57c..3282ee4f47304 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -1797,6 +1797,26 @@ message DoPutUpdateResult { int64 record_count = 1; } +/* An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. + * + * *Note on legacy behavior*: previous versions of the protocol did not return any result for + * this command, and that behavior should still be supported by clients. In that case, the client + * can continue as though the fields in this message were not provided or set to sensible default values. + */ +message DoPutPreparedStatementResult { + option (experimental) = true; + + // Represents a (potentially updated) opaque handle for the prepared statement on the server. + // Because the handle could potentially be updated, any previous handles for this prepared + // statement should be considered invalid, and all subsequent requests for this prepared + // statement must use this new handle. + // The updated handle allows implementing query parameters with stateless services. 
+ // + // When an updated handle is not provided by the server, clients should contiue + // using the previous handle provided by `ActionCreatePreparedStatementResonse`. + optional bytes prepared_statement_handle = 1; +} + /* * Request message for the "CancelQuery" action. * From 8133a20c70cd50c0567f8fe1eff338257d144eaa Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 25 Mar 2024 22:37:34 +0800 Subject: [PATCH 05/13] GH-40652: [C++] Enlarge dest buffer according to dest offset for `CopyBitmap` benchmark (#40769) ### Rationale for this change `CopyBitmap` benchmark doesn't make dest buffer large enough wrt. the possible dest offset, causing memory issues like when accessing the trailing bytes. ### What changes are included in this PR? Make the dest buffer wrt. dest offset. ### Are these changes tested? The fix itself is test (benchmark). ### Are there any user-facing changes? None. * GitHub Issue: #40652 Authored-by: Ruoxi Sun Signed-off-by: Antoine Pitrou --- cpp/src/arrow/util/bit_util_benchmark.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/util/bit_util_benchmark.cc b/cpp/src/arrow/util/bit_util_benchmark.cc index 3bcb4ceea6303..0bf2c26f12486 100644 --- a/cpp/src/arrow/util/bit_util_benchmark.cc +++ b/cpp/src/arrow/util/bit_util_benchmark.cc @@ -449,7 +449,7 @@ static void CopyBitmap(benchmark::State& state) { // NOLINT non-const reference const uint8_t* src = buffer->data(); const int64_t length = bits_size - OffsetSrc; - auto copy = *AllocateEmptyBitmap(length); + auto copy = *AllocateEmptyBitmap(length + OffsetDest); for (auto _ : state) { internal::CopyBitmap(src, OffsetSrc, length, copy->mutable_data(), OffsetDest); From cc771a013362248269b75e054c2fed9c3d0f352a Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Mon, 25 Mar 2024 07:59:38 -0700 Subject: [PATCH 06/13] GH-40634: [C#] ArrowStreamReader should not be null (#40765) ### What changes are included in this PR? 
Small refactoring in the IPC reader implementation classes of how the schema is read in order to support getting the schema asynchronously through ArrowStreamReader and avoiding the case where ArrowStreamReader.Schema returns null because no record batches have yet been read. ### Are these changes tested? Yes. ### Are there any user-facing changes? A new method ArrowStreamReader.GetSchema has been added to allow the schema to be gotten asynchronously. Closes #40634 * GitHub Issue: #40634 Authored-by: Curt Hagenlocher Signed-off-by: Curt Hagenlocher --- .../FlightRecordBatchStreamReader.cs | 4 +-- .../RecordBatchReaderImplementation.cs | 27 ++++++++++++++----- .../Ipc/ArrowFileReaderImplementation.cs | 6 ++--- .../Ipc/ArrowMemoryReaderImplementation.cs | 11 ++++++-- .../Ipc/ArrowReaderImplementation.cs | 19 +++++++++++-- .../src/Apache.Arrow/Ipc/ArrowStreamReader.cs | 12 +++++++++ .../Ipc/ArrowStreamReaderImplementation.cs | 8 +++--- .../Apache.Arrow.Tests/ArrowReaderVerifier.cs | 3 +++ .../ArrowStreamReaderTests.cs | 2 ++ 9 files changed, 72 insertions(+), 20 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs index d21fb25f5c946..7400ec15e54d6 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs @@ -45,12 +45,12 @@ private protected FlightRecordBatchStreamReader(IAsyncStreamReader Schema => _arrowReaderImplementation.ReadSchema(); + public ValueTask Schema => _arrowReaderImplementation.GetSchemaAsync(); internal ValueTask GetFlightDescriptor() { return _arrowReaderImplementation.ReadFlightDescriptor(); - } + } /// /// Get the application metadata from the latest received record batch diff --git a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs index 
be844ea58e404..99876bf769dc7 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs @@ -48,19 +48,33 @@ public async ValueTask ReadFlightDescriptor() { if (!HasReadSchema) { - await ReadSchema().ConfigureAwait(false); + await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false); } return _flightDescriptor; } - public async ValueTask ReadSchema() + public async ValueTask GetSchemaAsync() + { + if (!HasReadSchema) + { + await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false); + } + return _schema; + } + + public override void ReadSchema() + { + ReadSchemaAsync(CancellationToken.None).AsTask().Wait(); + } + + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken) { if (HasReadSchema) { - return Schema; + return; } - var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); + var moveNextResult = await _flightDataStream.MoveNext(cancellationToken).ConfigureAwait(false); if (!moveNextResult) { @@ -87,12 +101,11 @@ public async ValueTask ReadSchema() switch (message.HeaderType) { case MessageHeader.Schema: - Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); + _schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); break; default: throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}"); } - return Schema; } public override async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) @@ -101,7 +114,7 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati if (!HasReadSchema) { - await ReadSchema().ConfigureAwait(false); + await ReadSchemaAsync(cancellationToken).ConfigureAwait(false); } var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); if (moveNextResult) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs 
b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index 02f36b079349b..4b7c5f914c402 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -52,7 +52,7 @@ public async ValueTask RecordBatchCountAsync(CancellationToken cancellation return _footer.RecordBatchCount; } - protected override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -85,7 +85,7 @@ protected override async ValueTask ReadSchemaAsync(CancellationToken cancellatio } } - protected override void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -139,7 +139,7 @@ private void ReadSchema(Memory buffer) // Deserialize the footer from the footer flatbuffer _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref _dictionaryMemo); - Schema = _footer.Schema; + _schema = _footer.Schema; } public async ValueTask ReadRecordBatchAsync(int index, CancellationToken cancellationToken) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index 6e2336a591bf1..842c56823d07f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -33,6 +33,13 @@ public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer, ICompression _buffer = buffer; } + public override ValueTask ReadSchemaAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + ReadSchema(); + return default; + } + public override ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -93,7 +100,7 @@ public override RecordBatch ReadNextRecordBatch() return batch; } - 
private void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -117,7 +124,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index eb7349a570786..4e273dbde5690 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -30,13 +30,25 @@ namespace Apache.Arrow.Ipc { internal abstract class ArrowReaderImplementation : IDisposable { - public Schema Schema { get; protected set; } - protected bool HasReadSchema => Schema != null; + public Schema Schema + { + get + { + if (!HasReadSchema) + { + ReadSchema(); + } + return _schema; + } + } + + protected internal bool HasReadSchema => _schema != null; private protected DictionaryMemo _dictionaryMemo; private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); private protected readonly MemoryAllocator _allocator; private readonly ICompressionCodecFactory _compressionCodecFactory; + private protected Schema _schema; private protected ArrowReaderImplementation() : this(null, null) { } @@ -57,6 +69,9 @@ protected virtual void Dispose(bool disposing) { } + public abstract ValueTask ReadSchemaAsync(CancellationToken cancellationToken); + public abstract void ReadSchema(); + public abstract ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken); public abstract RecordBatch ReadNextRecordBatch(); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs index cdcfe7875da22..e129da399d59a 100644 --- 
a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs @@ -28,6 +28,9 @@ public class ArrowStreamReader : IArrowReader, IArrowArrayStream, IDisposable { private protected readonly ArrowReaderImplementation _implementation; + /// + /// May block if the schema hasn't yet been read. To avoid blocking, use GetSchemaAsync. + /// public Schema Schema => _implementation.Schema; public ArrowStreamReader(Stream stream) @@ -97,6 +100,15 @@ protected virtual void Dispose(bool disposing) } } + public async ValueTask GetSchema(CancellationToken cancellationToken = default) + { + if (!_implementation.HasReadSchema) + { + await _implementation.ReadSchemaAsync(cancellationToken); + } + return _implementation.Schema; + } + public ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) { return _implementation.ReadNextRecordBatchAsync(cancellationToken); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 5428c88c27bbc..5583a58487bf5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -146,7 +146,7 @@ protected ReadResult ReadMessage() return new ReadResult(messageLength, result); } - protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -164,11 +164,11 @@ protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellation EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); } } - protected virtual void ReadSchema() + public override 
void ReadSchema() { if (HasReadSchema) { @@ -184,7 +184,7 @@ protected virtual void ReadSchema() EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index 10315ff287c0b..2e7488092c2cf 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -38,6 +38,9 @@ public static void VerifyReader(ArrowStreamReader reader, RecordBatch originalBa public static async Task VerifyReaderAsync(ArrowStreamReader reader, RecordBatch originalBatch) { + Schema schema = await reader.GetSchema(); + Assert.NotNull(schema); + RecordBatch readBatch = await reader.ReadNextRecordBatchAsync(); CompareBatches(originalBatch, readBatch); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index ed030cc6ace11..b9e4664fdcd45 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -94,6 +94,8 @@ public async Task ReadRecordBatch_Memory(bool writeEnd) { await TestReaderFromMemory((reader, originalBatch) => { + Assert.NotNull(reader.Schema); + ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; }, writeEnd); From 1781b3248743812f737f32c244434b00e8d18579 Mon Sep 17 00:00:00 2001 From: Kemal <223029+disq@users.noreply.github.com> Date: Mon, 25 Mar 2024 15:12:31 +0000 Subject: [PATCH 07/13] GH-40693: [Go] Fix Decimal type precision loss on GetOneForMarshal (#40694) ### Rationale for this change Loss of precision when using `GetOneForMarshal` on `Decimal128` and `Decimal256` ### What changes are included in this 
PR? Fixes for precision loss with `DecimalType.GetOneForMarshal` * GitHub Issue: #40693 Lead-authored-by: Herman Schaaf Co-authored-by: Kemal Hadimli Signed-off-by: Matt Topol --- go/arrow/array/decimal128.go | 13 +++--- go/arrow/array/decimal128_test.go | 59 +++++++++++++++++++++++++- go/arrow/array/decimal256.go | 12 ++++-- go/arrow/array/decimal256_test.go | 70 ++++++++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 12 deletions(-) diff --git a/go/arrow/array/decimal128.go b/go/arrow/array/decimal128.go index 0dca320cda959..dc5f5d761618e 100644 --- a/go/arrow/array/decimal128.go +++ b/go/arrow/array/decimal128.go @@ -19,7 +19,6 @@ package array import ( "bytes" "fmt" - "math" "math/big" "reflect" "strings" @@ -86,15 +85,19 @@ func (a *Decimal128) setData(data *Data) { a.values = a.values[beg:end] } } - func (a *Decimal128) GetOneForMarshal(i int) interface{} { if a.IsNull(i) { return nil } - typ := a.DataType().(*arrow.Decimal128Type) - f := (&big.Float{}).SetInt(a.Value(i).BigInt()) - f.Quo(f, big.NewFloat(math.Pow10(int(typ.Scale)))) + n := a.Value(i) + scale := typ.Scale + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(-scale)).BigInt())) + } else { + f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(scale)).BigInt())) + } return f.Text('g', int(typ.Precision)) } diff --git a/go/arrow/array/decimal128_test.go b/go/arrow/array/decimal128_test.go index 836a6987df69f..31c6a6f8cadd6 100644 --- a/go/arrow/array/decimal128_test.go +++ b/go/arrow/array/decimal128_test.go @@ -204,7 +204,17 @@ func TestDecimal128StringRoundTrip(t *testing.T) { decimal128.FromI64(9), decimal128.FromI64(10), } - valid := []bool{true, true, true, false, true, true, false, true, true, true} + val1, err := decimal128.FromString("0.99", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + val2, err := decimal128.FromString("1234567890.12345", 
dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + values = append(values, val1, val2) + + valid := []bool{true, true, true, false, true, true, false, true, true, true, true, true} b.AppendValues(values, valid) @@ -224,3 +234,50 @@ func TestDecimal128StringRoundTrip(t *testing.T) { assert.True(t, array.Equal(arr, arr1)) } + +func TestDecimal128GetOneForMarshal(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.Decimal128Type{Precision: 38, Scale: 20} + + b := array.NewDecimal128Builder(mem, dtype) + defer b.Release() + + cases := []struct { + give any + want any + }{ + {"1", "1"}, + {"1.25", "1.25"}, + {"0.99", "0.99"}, + {"1234567890.123456789", "1234567890.123456789"}, + {nil, nil}, + {"-0.99", "-0.99"}, + {"-1234567890.123456789", "-1234567890.123456789"}, + {"0.0000000000000000001", "1e-19"}, + } + for _, v := range cases { + if v.give == nil { + b.AppendNull() + continue + } + + dt, err := decimal128.FromString(v.give.(string), dtype.Precision, dtype.Scale) + if err != nil { + t.Fatal(err) + } + b.Append(dt) + } + + arr := b.NewDecimal128Array() + defer arr.Release() + + if got, want := arr.Len(), len(cases); got != want { + t.Fatalf("invalid array length: got=%d, want=%d", got, want) + } + + for i := range cases { + assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i), "unexpected value at index %d", i) + } +} diff --git a/go/arrow/array/decimal256.go b/go/arrow/array/decimal256.go index 452ac96625bc8..f9c666300fa61 100644 --- a/go/arrow/array/decimal256.go +++ b/go/arrow/array/decimal256.go @@ -19,7 +19,6 @@ package array import ( "bytes" "fmt" - "math" "math/big" "reflect" "strings" @@ -91,10 +90,15 @@ func (a *Decimal256) GetOneForMarshal(i int) interface{} { if a.IsNull(i) { return nil } - typ := a.DataType().(*arrow.Decimal256Type) - f := (&big.Float{}).SetInt(a.Value(i).BigInt()) - f.Quo(f, big.NewFloat(math.Pow10(int(typ.Scale)))) + n := a.Value(i) + scale := 
typ.Scale + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(-scale)).BigInt())) + } else { + f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(scale)).BigInt())) + } return f.Text('g', int(typ.Precision)) } diff --git a/go/arrow/array/decimal256_test.go b/go/arrow/array/decimal256_test.go index 4f0c441210643..c78bd5243a66a 100644 --- a/go/arrow/array/decimal256_test.go +++ b/go/arrow/array/decimal256_test.go @@ -205,7 +205,17 @@ func TestDecimal256StringRoundTrip(t *testing.T) { decimal256.FromI64(9), decimal256.FromI64(10), } - valid := []bool{true, true, true, false, true, true, false, true, true, true} + val1, err := decimal256.FromString("0.99", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + val2, err := decimal256.FromString("1234567890.123456789", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + values = append(values, val1, val2) + + valid := []bool{true, true, true, false, true, true, false, true, true, true, true, true} b.AppendValues(values, valid) @@ -217,11 +227,67 @@ func TestDecimal256StringRoundTrip(t *testing.T) { defer b1.Release() for i := 0; i < arr.Len(); i++ { - assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + v := arr.ValueStr(i) + assert.NoError(t, b1.AppendValueFromString(v)) } arr1 := b1.NewArray().(*array.Decimal256) defer arr1.Release() + for i := 0; i < arr.Len(); i++ { + if arr.IsNull(i) && arr1.IsNull(i) { + continue + } + if arr.Value(i) != arr1.Value(i) { + t.Fatalf("unexpected value at index %d: got=%v, want=%v", i, arr1.Value(i), arr.Value(i)) + } + } assert.True(t, array.Equal(arr, arr1)) } + +func TestDecimal256GetOneForMarshal(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.Decimal256Type{Precision: 38, Scale: 20} + + b := array.NewDecimal256Builder(mem, dtype) + defer b.Release() + + cases := 
[]struct { + give any + want any + }{ + {"1", "1"}, + {"1.25", "1.25"}, + {"0.99", "0.99"}, + {"1234567890.123456789", "1234567890.123456789"}, + {nil, nil}, + {"-0.99", "-0.99"}, + {"-1234567890.123456789", "-1234567890.123456789"}, + {"0.0000000000000000001", "1e-19"}, + } + for _, v := range cases { + if v.give == nil { + b.AppendNull() + continue + } + + dt, err := decimal256.FromString(v.give.(string), dtype.Precision, dtype.Scale) + if err != nil { + t.Fatal(err) + } + b.Append(dt) + } + + arr := b.NewDecimal256Array() + defer arr.Release() + + if got, want := arr.Len(), len(cases); got != want { + t.Fatalf("invalid array length: got=%d, want=%d", got, want) + } + + for i := range cases { + assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i), "unexpected value at index %d", i) + } +} From 3095344d68af3e4353c9ce098d73fe6768bcb626 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 25 Mar 2024 16:49:05 +0100 Subject: [PATCH 08/13] GH-40279: [C++] Reduce S3Client initialization time (#40299) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Rationale for this change By default, S3Client instantiation is extremely slow (around 1ms for every instance). Investigation led to the conclusion that most of this time was spent inside the AWS SDK, parsing a hardcoded piece of JSON data when instantiating an AWS rule engine. Python benchmarks show this repeated initialization cost: ```python >>> from pyarrow.fs import S3FileSystem >>> %time s = S3FileSystem() CPU times: user 21.1 ms, sys: 0 ns, total: 21.1 ms Wall time: 20.9 ms >>> %time s = S3FileSystem() CPU times: user 2.37 ms, sys: 0 ns, total: 2.37 ms Wall time: 2.18 ms >>> %time s = S3FileSystem() CPU times: user 2.42 ms, sys: 0 ns, total: 2.42 ms Wall time: 2.23 ms >>> %timeit s = S3FileSystem() 1.28 ms ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) >>> %timeit s = S3FileSystem() 1.28 ms ± 2.6 µs per loop (mean ± std. dev.
of 7 runs, 1,000 loops each) >>> %timeit s = S3FileSystem(anonymous=True) 1.26 ms ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) ``` ### What changes are included in this PR? Instead of letting the AWS SDK create a new S3EndpointProvider for each S3Client, arrange to only create a single S3EndpointProvider per set of endpoint configuration options. This lets the 1ms instantiation cost be paid only when a new set of endpoint configuration options is given. Python benchmarks show the initialization cost has become a one-time cost: ```python >>> from pyarrow.fs import S3FileSystem >>> %time s = S3FileSystem() CPU times: user 20 ms, sys: 0 ns, total: 20 ms Wall time: 19.8 ms >>> %time s = S3FileSystem() CPU times: user 404 µs, sys: 49 µs, total: 453 µs Wall time: 266 µs >>> %time s = S3FileSystem() CPU times: user 361 µs, sys: 42 µs, total: 403 µs Wall time: 249 µs >>> %timeit s = S3FileSystem() 50.4 µs ± 227 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) >>> %timeit s = S3FileSystem(anonymous=True) 33.5 µs ± 306 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) ``` ### Are these changes tested? By existing tests. ### Are there any user-facing changes? No. 
* GitHub Issue: #40279 Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/filesystem/s3fs.cc | 174 ++++++++++++++++++++++++++++--- 1 file changed, 161 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 2ba64ee22f54f..640888e1c4fa5 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -99,12 +99,21 @@ #define ARROW_S3_HAS_CRT #endif +#if ARROW_AWS_SDK_VERSION_CHECK(1, 10, 0) +#define ARROW_S3_HAS_S3CLIENT_CONFIGURATION +#endif + #ifdef ARROW_S3_HAS_CRT #include #include #include #endif +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION +#include +#include +#endif + #include "arrow/util/windows_fixup.h" #include "arrow/buffer.h" @@ -128,19 +137,17 @@ #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" -namespace arrow { - -using internal::TaskGroup; -using internal::ToChars; -using io::internal::SubmitIO; -using util::Uri; - -namespace fs { +namespace arrow::fs { using ::Aws::Client::AWSError; using ::Aws::S3::S3Errors; namespace S3Model = Aws::S3::Model; +using ::arrow::internal::TaskGroup; +using ::arrow::internal::ToChars; +using ::arrow::io::internal::SubmitIO; +using ::arrow::util::Uri; + using internal::ConnectRetryStrategy; using internal::DetectS3Backend; using internal::ErrorToStatus; @@ -913,6 +920,134 @@ Result> GetClientHolder( // ----------------------------------------------------------------------- // S3 client factory: build S3Client from S3Options +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + +// GH-40279: standard initialization of S3Client creates a new `S3EndpointProvider` +// every time. Its construction takes 1ms, which makes instantiating every S3Client +// very costly (see upstream bug report +// at https://github.com/aws/aws-sdk-cpp/issues/2880). +// To work around this, we build and cache `S3EndpointProvider` instances +// for each distinct endpoint configuration, and reuse them whenever possible. 
+// Since most applications tend to use a single endpoint configuration, this +// makes the 1ms setup cost a once-per-process overhead, making it much more +// bearable - if not ideal. + +struct EndpointConfigKey { + explicit EndpointConfigKey(const Aws::S3::S3ClientConfiguration& config) + : region(config.region), + scheme(config.scheme), + endpoint_override(config.endpointOverride), + use_virtual_addressing(config.useVirtualAddressing) {} + + Aws::String region; + Aws::Http::Scheme scheme; + Aws::String endpoint_override; + bool use_virtual_addressing; + + bool operator==(const EndpointConfigKey& other) const noexcept { + return region == other.region && scheme == other.scheme && + endpoint_override == other.endpoint_override && + use_virtual_addressing == other.use_virtual_addressing; + } +}; + +} // namespace +} // namespace arrow::fs + +template <> +struct std::hash { + std::size_t operator()(const arrow::fs::EndpointConfigKey& key) const noexcept { + // A crude hash is sufficient since we expect the cache to remain very small. + auto h = std::hash{}; + return h(key.region) ^ h(key.endpoint_override); + } +}; + +namespace arrow::fs { +namespace { + +// EndpointProvider configuration happens in a non-thread-safe way, even +// when the updates are idempotent. This is a problem when trying to reuse +// a single EndpointProvider from several clients. +// To work around this, this class ensures reconfiguration of an existing +// EndpointProvider is a no-op. 
+class InitOnceEndpointProvider : public Aws::S3::S3EndpointProviderBase { + public: + explicit InitOnceEndpointProvider( + std::shared_ptr wrapped) + : wrapped_(std::move(wrapped)) {} + + void InitBuiltInParameters(const Aws::S3::S3ClientConfiguration& config) override {} + + void OverrideEndpoint(const Aws::String& endpoint) override { + ARROW_LOG(ERROR) << "unexpected call to InitOnceEndpointProvider::OverrideEndpoint"; + } + Aws::S3::Endpoint::S3ClientContextParameters& AccessClientContextParameters() override { + ARROW_LOG(ERROR) + << "unexpected call to InitOnceEndpointProvider::AccessClientContextParameters"; + // Need to return a reference to something... + return wrapped_->AccessClientContextParameters(); + } + + const Aws::S3::Endpoint::S3ClientContextParameters& GetClientContextParameters() + const override { + return wrapped_->GetClientContextParameters(); + } + Aws::Endpoint::ResolveEndpointOutcome ResolveEndpoint( + const Aws::Endpoint::EndpointParameters& params) const override { + return wrapped_->ResolveEndpoint(params); + } + + protected: + std::shared_ptr wrapped_; +}; + +// A class that instantiates a single EndpointProvider per distinct endpoint +// configuration and initializes it in a thread-safe way. See earlier comments +// for rationale. 
+class EndpointProviderCache { + public: + std::shared_ptr Lookup( + const Aws::S3::S3ClientConfiguration& config) { + auto key = EndpointConfigKey(config); + CacheValue* value; + { + std::unique_lock lock(mutex_); + value = &cache_[std::move(key)]; + } + std::call_once(value->once, [&]() { + auto endpoint_provider = std::make_shared(); + endpoint_provider->InitBuiltInParameters(config); + value->endpoint_provider = + std::make_shared(std::move(endpoint_provider)); + }); + return value->endpoint_provider; + } + + void Reset() { + std::unique_lock lock(mutex_); + cache_.clear(); + } + + static EndpointProviderCache* Instance() { + static EndpointProviderCache instance; + return &instance; + } + + private: + EndpointProviderCache() = default; + + struct CacheValue { + std::once_flag once; + std::shared_ptr endpoint_provider; + }; + + std::mutex mutex_; + std::unordered_map cache_; +}; + +#endif // ARROW_S3_HAS_S3CLIENT_CONFIGURATION + class ClientBuilder { public: explicit ClientBuilder(S3Options options) : options_(std::move(options)) {} @@ -958,9 +1093,6 @@ class ClientBuilder { client_config_.caPath = ToAwsString(internal::global_options.tls_ca_dir_path); } - const bool use_virtual_addressing = - options_.endpoint_override.empty() || options_.force_virtual_addressing; - // Set proxy options if provided if (!options_.proxy_options.scheme.empty()) { if (options_.proxy_options.scheme == "http") { @@ -990,10 +1122,20 @@ class ClientBuilder { client_config_.maxConnections = std::max(io_context->executor()->GetCapacity(), 25); } + const bool use_virtual_addressing = + options_.endpoint_override.empty() || options_.force_virtual_addressing; + +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + client_config_.useVirtualAddressing = use_virtual_addressing; + auto endpoint_provider = EndpointProviderCache::Instance()->Lookup(client_config_); + auto client = std::make_shared(credentials_provider_, endpoint_provider, + client_config_); +#else auto client = std::make_shared( 
credentials_provider_, client_config_, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, use_virtual_addressing); +#endif client->s3_retry_strategy_ = options_.retry_strategy; return GetClientHolder(std::move(client)); } @@ -1002,7 +1144,11 @@ class ClientBuilder { protected: S3Options options_; +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + Aws::S3::S3ClientConfiguration client_config_; +#else Aws::Client::ClientConfiguration client_config_; +#endif std::shared_ptr credentials_provider_; }; @@ -2949,6 +3095,9 @@ struct AwsInstance { "This could lead to a segmentation fault at exit"; } GetClientFinalizer()->Finalize(); +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + EndpointProviderCache::Instance()->Reset(); +#endif Aws::ShutdownAPI(aws_options_); } } @@ -3090,5 +3239,4 @@ Result ResolveS3BucketRegion(const std::string& bucket) { return resolver->ResolveRegion(bucket); } -} // namespace fs -} // namespace arrow +} // namespace arrow::fs From 8efb8dce97693289e7a2dd780b6385714f07a5e3 Mon Sep 17 00:00:00 2001 From: Andrew Grosser Date: Mon, 25 Mar 2024 12:04:19 -0700 Subject: [PATCH 09/13] GH-40755: [JS] fix decimal conversions (#40729) ### Rationale for this change Fixes https://github.com/apache/arrow/issues/40755 Further work is required to complete: https://github.com/apache/arrow/issues/37920 Decimals are broken - need a correct way to convert decimals to numbers in js Also included an option to include a denominator (BigInt(1/scale)) as scale is part of the metadata ### What changes are included in this PR? Submitting a correct way to convert decimals to numbers ### Are these changes tested? Yes, includes a test ### Are there any user-facing changes? 
No **This PR contains a "Critical Fix".** * GitHub Issue: #37920 * GitHub Issue: #40755 --------- Co-authored-by: Dominik Moritz Co-authored-by: Paul Taylor <178183+trxcllnt@users.noreply.github.com> --- js/src/util/bn.ts | 32 ++++++++++++++++++++++---------- js/test/unit/bn-tests.ts | 15 +++++++++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/js/src/util/bn.ts b/js/src/util/bn.ts index af546be5436a2..b4db9cf2b4afe 100644 --- a/js/src/util/bn.ts +++ b/js/src/util/bn.ts @@ -36,7 +36,7 @@ function BigNum(this: any, x: any, ...xs: any) { BigNum.prototype[isArrowBigNumSymbol] = true; BigNum.prototype.toJSON = function >(this: T) { return `"${bigNumToString(this)}"`; }; -BigNum.prototype.valueOf = function >(this: T) { return bigNumToNumber(this); }; +BigNum.prototype.valueOf = function >(this: T, scale?: number) { return bigNumToNumber(this, scale); }; BigNum.prototype.toString = function >(this: T) { return bigNumToString(this); }; BigNum.prototype[Symbol.toPrimitive] = function >(this: T, hint: 'string' | 'number' | 'default' = 'default') { switch (hint) { @@ -68,24 +68,36 @@ Object.assign(SignedBigNum.prototype, BigNum.prototype, { 'constructor': SignedB Object.assign(UnsignedBigNum.prototype, BigNum.prototype, { 'constructor': UnsignedBigNum, 'signed': false, 'TypedArray': Uint32Array, 'BigIntArray': BigUint64Array }); Object.assign(DecimalBigNum.prototype, BigNum.prototype, { 'constructor': DecimalBigNum, 'signed': true, 'TypedArray': Uint32Array, 'BigIntArray': BigUint64Array }); +//FOR ES2020 COMPATIBILITY +const TWO_TO_THE_64 = BigInt(4294967296) * BigInt(4294967296); // 2^64 = 0x10000000000000000n +const TWO_TO_THE_64_MINUS_1 = TWO_TO_THE_64 - BigInt(1); // (2^32 * 2^32) - 1 = 0xFFFFFFFFFFFFFFFFn + /** @ignore */ -function bigNumToNumber>(bn: T) { - const { buffer, byteOffset, length, 'signed': signed } = bn; - const words = new BigUint64Array(buffer, byteOffset, length); +export function bigNumToNumber>(bn: T, scale?: number) { + const 
{ buffer, byteOffset, byteLength, 'signed': signed } = bn; + const words = new BigUint64Array(buffer, byteOffset, byteLength / 8); const negative = signed && words.at(-1)! & (BigInt(1) << BigInt(63)); - let number = negative ? BigInt(1) : BigInt(0); - let i = BigInt(0); + let number = BigInt(0); + let i = 0; if (!negative) { for (const word of words) { - number += word * (BigInt(1) << (BigInt(32) * i++)); + number |= word * (BigInt(1) << BigInt(64 * i++)); } } else { for (const word of words) { - number += ~word * (BigInt(1) << (BigInt(32) * i++)); + number |= (word ^ TWO_TO_THE_64_MINUS_1) * (BigInt(1) << BigInt(64 * i++)); } number *= BigInt(-1); + number -= BigInt(1); + } + if (typeof scale === 'number') { + const denominator = BigInt(Math.pow(10, scale)); + const quotient = number / denominator; + const remainder = number % denominator; + const n = Number(quotient) + (Number(remainder) / Number(denominator)); + return n; } - return number; + return Number(number); } /** @ignore */ @@ -217,7 +229,7 @@ export interface BN extends TypedArrayLike { * arithmetic operators, like `+`. Easy (and unsafe) way to convert BN to * number via `+bn_inst` */ - valueOf(): number; + valueOf(scale?: number): number; /** * Return the JSON representation of the bytes. Must be wrapped in double-quotes, * so it's compatible with JSON.stringify(). 
diff --git a/js/test/unit/bn-tests.ts b/js/test/unit/bn-tests.ts index c9606baf85942..dbda02198ea2e 100644 --- a/js/test/unit/bn-tests.ts +++ b/js/test/unit/bn-tests.ts @@ -83,4 +83,19 @@ describe(`BN`, () => { const d4 = toDecimal(new Uint32Array([0x9D91E773, 0x4BB90CED, 0xAB2354CC, 0x54278E9B])); expect(d4.toString()).toBe('111860543658909349380118287427608635251'); }); + + test(`valueOf for decimal numbers`, () => { + const n1 = new BN(new Uint32Array([0x00000001, 0x00000000, 0x00000000, 0x00000000]), false); + expect(n1.valueOf()).toBe(1); + const n2 = new BN(new Uint32Array([0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]), true); + expect(n2.valueOf()).toBe(-2); + const n3 = new BN(new Uint32Array([0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]), true); + expect(n3.valueOf()).toBe(-1); + const n4 = new BN(new Uint32Array([0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]), true); + expect(n4.valueOf(1)).toBe(-0.1); + const n5 = new BN(new Uint32Array([0x00000000, 0x00000000, 0x00000000, 0x80000000]), false); + expect(n5.valueOf()).toBe(1.7014118346046923e+38); + // const n6 = new BN(new Uint32Array([0x00000000, 0x00000000, 0x00000000, 0x80000000]), false); + // expect(n6.valueOf(1)).toBe(1.7014118346046923e+37); + }); }); From ab6a5fd610cbd10cbded6ef0f235ad54b9273496 Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Mon, 25 Mar 2024 15:39:26 -0400 Subject: [PATCH 10/13] MINOR: [R] Update maintainer in package description (#40692) ### Rationale for this change Update R package maintainer to Jon ### What changes are included in this PR? Update maintainer field, swapping over Nic and Jon! ### Are these changes tested? Nope ### Are there any user-facing changes? Nah! 
Authored-by: Nic Crane Signed-off-by: Nic Crane --- r/DESCRIPTION | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 81b5d79258255..6062a8c4f4689 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -4,10 +4,10 @@ Version: 15.0.2.9000 Authors@R: c( person("Neal", "Richardson", email = "neal.p.richardson@gmail.com", role = c("aut")), person("Ian", "Cook", email = "ianmcook@gmail.com", role = c("aut")), - person("Nic", "Crane", email = "thisisnic@gmail.com", role = c("aut", "cre")), + person("Nic", "Crane", email = "thisisnic@gmail.com", role = c("aut")), person("Dewey", "Dunnington", role = c("aut"), email = "dewey@fishandwhistle.net", comment = c(ORCID = "0000-0002-9415-4582")), person("Romain", "Fran\u00e7ois", role = c("aut"), comment = c(ORCID = "0000-0002-2444-4226")), - person("Jonathan", "Keane", email = "jkeane@gmail.com", role = c("aut")), + person("Jonathan", "Keane", email = "jkeane@gmail.com", role = c("aut", "cre")), person("Drago\u0219", "Moldovan-Gr\u00fcnfeld", email = "dragos.mold@gmail.com", role = c("aut")), person("Jeroen", "Ooms", email = "jeroen@berkeley.edu", role = c("aut")), person("Jacob", "Wujciak-Jens", email = "jacob@wujciak.de", role = c("aut")), From 5fd6b44936a19761e45a8e43d7e76a0a23c5a222 Mon Sep 17 00:00:00 2001 From: Peter Newcomb Date: Mon, 25 Mar 2024 16:48:50 -0400 Subject: [PATCH 11/13] GH-40630: [Go][Parquet] Enable writing of Parquet footer without closing file (#40654) ### Rationale for this change See #40630 ### What changes are included in this PR? 1. Added `FlushWithFooter` method to *file.Writer 2. To support `FlushWithFooter`, refactored `Close` in a way that changes the order of operations in two ways: a. closure of open row group writers is now done after using `defer` to ensure closure of the sink, instead of before b. wiping out of encryption keys is now done by the same deferred function, ensuring that it happens even upon error ### Are these changes tested? 
`file_writer_test.go` has been extended to cover `FlushWithFooter` in a manner equivalent to the existing coverage. ### Are there any user-facing changes? Only the addition of a new public method as described above. No breaking changes to any existing public interfaces, unless the two minor order-of-operation changes described above are somehow a problem. I'm not sure it's a critical fix, but one of the minor changes described above may reduce the likelihood that an attack could inject an error (e.g., an I/O error) to prevent an encryption key from being wiped from memory. * GitHub Issue: #40630 Authored-by: Peter Newcomb Signed-off-by: Matt Topol --- go/parquet/file/file_writer.go | 62 ++++++++++++++++++----------- go/parquet/file/file_writer_test.go | 17 +++++++- go/parquet/metadata/file.go | 15 ++++++- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/go/parquet/file/file_writer.go b/go/parquet/file/file_writer.go index a2cf397cbc80b..57344b25cf05c 100644 --- a/go/parquet/file/file_writer.go +++ b/go/parquet/file/file_writer.go @@ -32,6 +32,7 @@ import ( type Writer struct { sink utils.WriteCloserTell open bool + footerFlushed bool props *parquet.WriterProperties rowGroups int nrows int @@ -125,6 +126,7 @@ func (fw *Writer) appendRowGroup(buffered bool) *rowGroupWriter { fw.rowGroupWriter.Close() } fw.rowGroups++ + fw.footerFlushed = false rgMeta := fw.metadata.AppendRowGroup() fw.rowGroupWriter = newRowGroupWriter(fw.sink, rgMeta, int16(fw.rowGroups)-1, fw.props, buffered, fw.fileEncryptor) return fw.rowGroupWriter @@ -172,12 +174,9 @@ func (fw *Writer) Close() (err error) { // if any functions here panic, we set open to be false so // that this doesn't get called again fw.open = false - if fw.rowGroupWriter != nil { - fw.nrows += fw.rowGroupWriter.nrows - fw.rowGroupWriter.Close() - } - fw.rowGroupWriter = nil + defer func() { + fw.closeEncryptor() ierr := fw.sink.Close() if err != nil { if ierr != nil { @@ -189,30 +188,48 @@ func (fw *Writer) 
Close() (err error) { err = ierr }() + err = fw.FlushWithFooter() + fw.metadata.Clear() + } + return nil +} + +// FlushWithFooter closes any open row group writer and writes the file footer, leaving +// the writer open for additional row groups. Additional footers written by later +// calls to FlushWithFooter or Close will be cumulative, so that only the last footer +// written need ever be read by a reader. +func (fw *Writer) FlushWithFooter() error { + if !fw.footerFlushed { + if fw.rowGroupWriter != nil { + fw.nrows += fw.rowGroupWriter.nrows + fw.rowGroupWriter.Close() + } + fw.rowGroupWriter = nil + + fileMetadata, err := fw.metadata.Snapshot() + if err != nil { + return err + } + fileEncryptProps := fw.props.FileEncryptionProperties() if fileEncryptProps == nil { // non encrypted file - fileMetadata, err := fw.metadata.Finish() - if err != nil { + if _, err = writeFileMetadata(fileMetadata, fw.sink); err != nil { + return err + } + } else { + if err := fw.flushEncryptedFile(fileMetadata, fileEncryptProps); err != nil { return err } - - _, err = writeFileMetadata(fileMetadata, fw.sink) - return err } - return fw.closeEncryptedFile(fileEncryptProps) + fw.footerFlushed = true } return nil } -func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) error { +func (fw *Writer) flushEncryptedFile(fileMetadata *metadata.FileMetaData, props *parquet.FileEncryptionProperties) error { // encrypted file with encrypted footer if props.EncryptedFooter() { - fileMetadata, err := fw.metadata.Finish() - if err != nil { - return err - } - footerLen := int64(0) cryptoMetadata := fw.metadata.GetFileCryptoMetaData() @@ -236,19 +253,18 @@ func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) er return err } } else { - fileMetadata, err := fw.metadata.Finish() - if err != nil { - return err - } footerSigningEncryptor := fw.fileEncryptor.GetFooterSigningEncryptor() - if _, err = writeEncryptedFileMetadata(fileMetadata, fw.sink, 
footerSigningEncryptor, false); err != nil { + if _, err := writeEncryptedFileMetadata(fileMetadata, fw.sink, footerSigningEncryptor, false); err != nil { return err } } + return nil +} + +func (fw *Writer) closeEncryptor() { if fw.fileEncryptor != nil { fw.fileEncryptor.WipeOutEncryptionKeys() } - return nil } func writeFileMetadata(fileMetadata *metadata.FileMetaData, w io.Writer) (n int64, err error) { diff --git a/go/parquet/file/file_writer_test.go b/go/parquet/file/file_writer_test.go index 434c9852c5823..3687fc8778202 100644 --- a/go/parquet/file/file_writer_test.go +++ b/go/parquet/file/file_writer_test.go @@ -64,6 +64,20 @@ func (t *SerializeTestSuite) fileSerializeTest(codec compress.Compression, expec writer := file.NewParquetWriter(sink, t.Schema.Root(), file.WithWriterProps(props)) t.GenerateData(int64(t.rowsPerRG)) + + t.serializeGeneratedData(writer) + writer.FlushWithFooter() + + t.validateSerializedData(writer, sink, expected) + + t.serializeGeneratedData(writer) + writer.Close() + + t.numRowGroups *= 2 + t.validateSerializedData(writer, sink, expected) +} + +func (t *SerializeTestSuite) serializeGeneratedData(writer *file.Writer) { for rg := 0; rg < t.numRowGroups/2; rg++ { rgw := writer.AppendRowGroup() for col := 0; col < t.numCols; col++ { @@ -94,8 +108,9 @@ func (t *SerializeTestSuite) fileSerializeTest(codec compress.Compression, expec } rgw.Close() } - writer.Close() +} +func (t *SerializeTestSuite) validateSerializedData(writer *file.Writer, sink *encoding.BufferWriter, expected compress.Compression) { nrows := t.numRowGroups * t.rowsPerRG t.EqualValues(nrows, writer.NumRows()) diff --git a/go/parquet/metadata/file.go b/go/parquet/metadata/file.go index f40081f172a75..fc376383165b1 100644 --- a/go/parquet/metadata/file.go +++ b/go/parquet/metadata/file.go @@ -104,6 +104,15 @@ func (f *FileMetaDataBuilder) AppendKeyValueMetadata(key string, value string) e // version etc. 
This will clear out this filemetadatabuilder so it can // be re-used func (f *FileMetaDataBuilder) Finish() (*FileMetaData, error) { + out, err := f.Snapshot() + f.Clear() + return out, err +} + +// Snapshot returns finalized metadata of the number of rows, row groups, version etc. +// The snapshot must be used (e.g., serialized) before any additional (meta)data is +// written, as it refers to builder datastructures that will continue to mutate. +func (f *FileMetaDataBuilder) Snapshot() (*FileMetaData, error) { totalRows := int64(0) for _, rg := range f.rowGroups { totalRows += rg.NumRows @@ -161,9 +170,13 @@ func (f *FileMetaDataBuilder) Finish() (*FileMetaData, error) { } out.initColumnOrders() + return out, nil +} + +// Clears out this filemetadatabuilder so it can be re-used +func (f *FileMetaDataBuilder) Clear() { f.metadata = format.NewFileMetaData() f.rowGroups = nil - return out, nil } // KeyValueMetadata is an alias for a slice of thrift keyvalue pairs. From e75bc99fa862a6703d83f44c027a52043851c530 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 26 Mar 2024 14:44:54 +1300 Subject: [PATCH 12/13] GH-40788: [C#] Override Accept in MapArray (#40789) ### Rationale for this change This allows users to implement `IArrowArrayVisitor` and have the `Visit(MapArray)` method called instead of `Visit(ListArray)` or `Visit(IArrowArray)`. ### What changes are included in this PR? Overrides the `Accept` method to check whether the visitor implements `Visit(MapArray)`, and if not, delegates to the base implementation to handle `IArrowArrayVisitor` or fall back to using an `IArrowArrayVisitor`. ### Are these changes tested? Yes, I've added unit tests. ### Are there any user-facing changes? Yes, this is a user-facing change. 
* GitHub Issue: #40788 Authored-by: Adam Reeve Signed-off-by: Curt Hagenlocher --- csharp/src/Apache.Arrow/Arrays/MapArray.cs | 13 +++ .../test/Apache.Arrow.Tests/MapArrayTests.cs | 110 ++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/csharp/src/Apache.Arrow/Arrays/MapArray.cs b/csharp/src/Apache.Arrow/Arrays/MapArray.cs index a6676b134e34a..dad50981ea54d 100644 --- a/csharp/src/Apache.Arrow/Arrays/MapArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/MapArray.cs @@ -135,6 +135,19 @@ private MapArray(ArrayData data, IArrowArray structs) : base(data, structs, Arro { } + public override void Accept(IArrowArrayVisitor visitor) + { + switch (visitor) + { + case IArrowArrayVisitor typedVisitor: + typedVisitor.Visit(this); + break; + default: + base.Accept(visitor); + break; + } + } + public IEnumerable> GetTuples(int index, Func getKey, Func getValue) where TKeyArray : Array where TValueArray : Array { diff --git a/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs b/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs index 7f35f104267dc..21decdacc0588 100644 --- a/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs +++ b/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs @@ -85,8 +85,118 @@ public void MapArray_Should_GetKeyValuePairs() Assert.Equal(new KeyValuePair[] { kv1, kv2 }, array.GetKeyValuePairs(2, GetKey, GetValue).ToArray()); } + [Fact] + public void MapArray_Should_AcceptMapVisitor() + { + var mapArray = BuildMapArray(); + var visitor = new MapOnlyVisitor(); + mapArray.Accept(visitor); + + Assert.True(visitor.MapVisited); + Assert.False(visitor.BaseVisited); + } + + [Fact] + public void MapArray_Should_AcceptListVisitor() + { + var mapArray = BuildMapArray(); + var visitor = new ListOnlyVisitor(); + mapArray.Accept(visitor); + + Assert.True(visitor.ListVisited); + Assert.False(visitor.BaseVisited); + } + + [Fact] + public void MapArray_Should_AcceptListAndMapVisitor() + { + var mapArray = BuildMapArray(); + var visitor = new MapAndListVisitor(); + 
mapArray.Accept(visitor); + + Assert.True(visitor.MapVisited); + Assert.False(visitor.ListVisited); + Assert.False(visitor.BaseVisited); + } + + private static MapArray BuildMapArray() + { + MapType type = new MapType(StringType.Default, Int64Type.Default); + MapArray.Builder builder = new MapArray.Builder(type); + var keyBuilder = builder.KeyBuilder as StringArray.Builder; + var valueBuilder = builder.ValueBuilder as Int64Array.Builder; + + builder.Append(); + keyBuilder.Append("test"); + valueBuilder.Append(1); + + builder.AppendNull(); + + builder.Append(); + keyBuilder.Append("other"); + valueBuilder.Append(123); + keyBuilder.Append("kv"); + valueBuilder.AppendNull(); + + return builder.Build(); + } + private static string GetKey(StringArray array, int index) => array.GetString(index); private static int? GetValue(Int32Array array, int index) => array.GetValue(index); private static long? GetValue(Int64Array array, int index) => array.GetValue(index); + + private sealed class MapOnlyVisitor : IArrowArrayVisitor + { + public bool MapVisited = false; + public bool BaseVisited = false; + + public void Visit(MapArray array) + { + MapVisited = true; + } + + public void Visit(IArrowArray array) + { + BaseVisited = true; + } + } + + private sealed class ListOnlyVisitor : IArrowArrayVisitor + { + public bool ListVisited = false; + public bool BaseVisited = false; + + public void Visit(ListArray array) + { + ListVisited = true; + } + + public void Visit(IArrowArray array) + { + BaseVisited = true; + } + } + + private sealed class MapAndListVisitor : IArrowArrayVisitor, IArrowArrayVisitor + { + public bool MapVisited = false; + public bool ListVisited = false; + public bool BaseVisited = false; + + public void Visit(MapArray array) + { + MapVisited = true; + } + + public void Visit(ListArray array) + { + ListVisited = true; + } + + public void Visit(IArrowArray array) + { + BaseVisited = true; + } + } } } From e3b0bd1feb63d59cd6fb553af976449397b8348e Mon Sep 17 00:00:00 
2001 From: Felipe Oliveira Carvalho Date: Mon, 25 Mar 2024 23:14:27 -0300 Subject: [PATCH 13/13] GH-40783: [C++] Re-order loads and stores in MemoryPoolStats update (#40647) ### Rationale for this change Issue loads as soon as possible so the latency of waiting for memory is masked by doing other operations. ### What changes are included in this PR? - Make all the read-modify-write operations use `memory_order_acq_rel` - Make all the loads and stores use `memory_order_acquire`/`release` respectively - Statically specialize the implementation of `UpdateAllocatedBytes` so `bytes_allocated_` can be updated without waiting for the load of the old value ### Are these changes tested? By existing tests. * GitHub Issue: #40783 Authored-by: Felipe Oliveira Carvalho Signed-off-by: Felipe Oliveira Carvalho --- cpp/src/arrow/memory_pool.cc | 12 ++-- cpp/src/arrow/memory_pool.h | 82 +++++++++++++++++--------- cpp/src/arrow/memory_pool_benchmark.cc | 8 ++- cpp/src/arrow/memory_pool_test.cc | 11 +--- cpp/src/arrow/memory_pool_test.h | 9 +-- cpp/src/arrow/stl_allocator.h | 6 +- java/dataset/src/main/cpp/jni_util.cc | 6 +- 7 files changed, 80 insertions(+), 54 deletions(-) diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc index d58c203d2ae27..2f8ce3a6fa8c7 100644 --- a/cpp/src/arrow/memory_pool.cc +++ b/cpp/src/arrow/memory_pool.cc @@ -472,7 +472,7 @@ class BaseMemoryPoolImpl : public MemoryPool { } #endif - stats_.UpdateAllocatedBytes(size); + stats_.DidAllocateBytes(size); return Status::OK(); } @@ -494,7 +494,7 @@ class BaseMemoryPoolImpl : public MemoryPool { } #endif - stats_.UpdateAllocatedBytes(new_size - old_size); + stats_.DidReallocateBytes(old_size, new_size); return Status::OK(); } @@ -509,7 +509,7 @@ class BaseMemoryPoolImpl : public MemoryPool { #endif Allocator::DeallocateAligned(buffer, size, alignment); - stats_.UpdateAllocatedBytes(-size, /*is_free*/ true); + stats_.DidFreeBytes(size); } void ReleaseUnused() override { 
Allocator::ReleaseUnused(); } @@ -761,20 +761,20 @@ class ProxyMemoryPool::ProxyMemoryPoolImpl { Status Allocate(int64_t size, int64_t alignment, uint8_t** out) { RETURN_NOT_OK(pool_->Allocate(size, alignment, out)); - stats_.UpdateAllocatedBytes(size); + stats_.DidAllocateBytes(size); return Status::OK(); } Status Reallocate(int64_t old_size, int64_t new_size, int64_t alignment, uint8_t** ptr) { RETURN_NOT_OK(pool_->Reallocate(old_size, new_size, alignment, ptr)); - stats_.UpdateAllocatedBytes(new_size - old_size); + stats_.DidReallocateBytes(old_size, new_size); return Status::OK(); } void Free(uint8_t* buffer, int64_t size, int64_t alignment) { pool_->Free(buffer, size, alignment); - stats_.UpdateAllocatedBytes(-size, /*is_free=*/true); + stats_.DidFreeBytes(size); } int64_t bytes_allocated() const { return stats_.bytes_allocated(); } diff --git a/cpp/src/arrow/memory_pool.h b/cpp/src/arrow/memory_pool.h index 712a828041c76..98c6dc3e211b8 100644 --- a/cpp/src/arrow/memory_pool.h +++ b/cpp/src/arrow/memory_pool.h @@ -35,44 +35,68 @@ namespace internal { /////////////////////////////////////////////////////////////////////// // Helper tracking memory statistics -class MemoryPoolStats { - public: - MemoryPoolStats() : bytes_allocated_(0), max_memory_(0) {} - - int64_t max_memory() const { return max_memory_.load(); } - - int64_t bytes_allocated() const { return bytes_allocated_.load(); } +/// \brief Memory pool statistics +/// +/// 64-byte aligned so that all atomic values are on the same cache line. +class alignas(64) MemoryPoolStats { + private: + // All atomics are updated according to Acquire-Release ordering. + // https://en.cppreference.com/w/cpp/atomic/memory_order#Release-Acquire_ordering + // + // max_memory_, total_allocated_bytes_, and num_allocs_ only go up (they are + // monotonically increasing) which can allow some optimizations. 
+ std::atomic max_memory_{0}; + std::atomic bytes_allocated_{0}; + std::atomic total_allocated_bytes_{0}; + std::atomic num_allocs_{0}; - int64_t total_bytes_allocated() const { return total_allocated_bytes_.load(); } + public: + int64_t max_memory() const { return max_memory_.load(std::memory_order_acquire); } - int64_t num_allocations() const { return num_allocs_.load(); } + int64_t bytes_allocated() const { + return bytes_allocated_.load(std::memory_order_acquire); + } - inline void UpdateAllocatedBytes(int64_t diff, bool is_free = false) { - auto allocated = bytes_allocated_.fetch_add(diff) + diff; - // "maximum" allocated memory is ill-defined in multi-threaded code, - // so don't try to be too rigorous here - if (diff > 0 && allocated > max_memory_) { - max_memory_ = allocated; - } + int64_t total_bytes_allocated() const { + return total_allocated_bytes_.load(std::memory_order_acquire); + } - // Reallocations might just expand/contract the allocation in place or might - // copy to a new location. We can't really know, so we just represent the - // optimistic case. - if (diff > 0) { - total_allocated_bytes_ += diff; + int64_t num_allocations() const { return num_allocs_.load(std::memory_order_acquire); } + + inline void DidAllocateBytes(int64_t size) { + // Issue the load before everything else. max_memory_ is monotonically increasing, + // so we can use a relaxed load before the read-modify-write. + auto max_memory = max_memory_.load(std::memory_order_relaxed); + const auto old_bytes_allocated = + bytes_allocated_.fetch_add(size, std::memory_order_acq_rel); + // Issue store operations on values that we don't depend on to proceed + // with execution. When done, max_memory and old_bytes_allocated have + // a higher chance of being available on CPU registers. This also has the + // nice side-effect of putting 3 atomic stores close to each other in the + // instruction stream. 
+ total_allocated_bytes_.fetch_add(size, std::memory_order_acq_rel); + num_allocs_.fetch_add(1, std::memory_order_acq_rel); + + // If other threads are updating max_memory_ concurrently we leave the loop without + // updating knowing that it already reached a value even higher than ours. + const auto allocated = old_bytes_allocated + size; + while (max_memory < allocated && !max_memory_.compare_exchange_weak( + /*expected=*/max_memory, /*desired=*/allocated, + std::memory_order_acq_rel)) { } + } - // We count any reallocation as a allocation. - if (!is_free) { - num_allocs_ += 1; + inline void DidReallocateBytes(int64_t old_size, int64_t new_size) { + if (new_size > old_size) { + DidAllocateBytes(new_size - old_size); + } else { + DidFreeBytes(old_size - new_size); } } - protected: - std::atomic bytes_allocated_ = 0; - std::atomic max_memory_ = 0; - std::atomic total_allocated_bytes_ = 0; - std::atomic num_allocs_ = 0; + inline void DidFreeBytes(int64_t size) { + bytes_allocated_.fetch_sub(size, std::memory_order_acq_rel); + } }; } // namespace internal diff --git a/cpp/src/arrow/memory_pool_benchmark.cc b/cpp/src/arrow/memory_pool_benchmark.cc index fe7a3dd2f8ee0..c2e55314b56f9 100644 --- a/cpp/src/arrow/memory_pool_benchmark.cc +++ b/cpp/src/arrow/memory_pool_benchmark.cc @@ -114,8 +114,12 @@ static void AllocateTouchDeallocate( state.SetBytesProcessed(state.iterations() * nbytes); } -#define BENCHMARK_ALLOCATE_ARGS \ - ->RangeMultiplier(16)->Range(4096, 16 * 1024 * 1024)->ArgName("size")->UseRealTime() +#define BENCHMARK_ALLOCATE_ARGS \ + ->RangeMultiplier(16) \ + ->Range(4096, 16 * 1024 * 1024) \ + ->ArgName("size") \ + ->UseRealTime() \ + ->ThreadRange(1, 32) #define BENCHMARK_ALLOCATE(benchmark_func, template_param) \ BENCHMARK_TEMPLATE(benchmark_func, template_param) BENCHMARK_ALLOCATE_ARGS diff --git a/cpp/src/arrow/memory_pool_test.cc b/cpp/src/arrow/memory_pool_test.cc index 81d9d69ba346d..3f0a852876718 100644 --- a/cpp/src/arrow/memory_pool_test.cc +++ 
b/cpp/src/arrow/memory_pool_test.cc @@ -106,11 +106,6 @@ TEST(DefaultMemoryPool, Identity) { specific_pools.end()); } -// Death tests and valgrind are known to not play well 100% of the time. See -// googletest documentation -#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)) - -// TODO: is this still a death test? TEST(DefaultMemoryPoolDeathTest, Statistics) { MemoryPool* pool = default_memory_pool(); uint8_t* data1; @@ -137,18 +132,16 @@ TEST(DefaultMemoryPoolDeathTest, Statistics) { ASSERT_EQ(150, pool->max_memory()); ASSERT_EQ(200, pool->total_bytes_allocated()); ASSERT_EQ(50, pool->bytes_allocated()); - ASSERT_EQ(4, pool->num_allocations()); + ASSERT_EQ(3, pool->num_allocations()); pool->Free(data1, 50); ASSERT_EQ(150, pool->max_memory()); ASSERT_EQ(200, pool->total_bytes_allocated()); ASSERT_EQ(0, pool->bytes_allocated()); - ASSERT_EQ(4, pool->num_allocations()); + ASSERT_EQ(3, pool->num_allocations()); } -#endif // ARROW_VALGRIND - TEST(LoggingMemoryPool, Logging) { auto pool = MemoryPool::CreateDefault(); diff --git a/cpp/src/arrow/memory_pool_test.h b/cpp/src/arrow/memory_pool_test.h index e4a07099f830f..32f1cc5d1d310 100644 --- a/cpp/src/arrow/memory_pool_test.h +++ b/cpp/src/arrow/memory_pool_test.h @@ -38,19 +38,20 @@ class TestMemoryPoolBase : public ::testing::Test { auto pool = memory_pool(); uint8_t* data; + const auto old_bytes_allocated = pool->bytes_allocated(); ASSERT_OK(pool->Allocate(100, &data)); EXPECT_EQ(static_cast(0), reinterpret_cast(data) % 64); - ASSERT_EQ(100, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated + 100, pool->bytes_allocated()); uint8_t* data2; ASSERT_OK(pool->Allocate(27, &data2)); EXPECT_EQ(static_cast(0), reinterpret_cast(data2) % 64); - ASSERT_EQ(127, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated + 127, pool->bytes_allocated()); pool->Free(data, 100); - ASSERT_EQ(27, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated + 27, pool->bytes_allocated()); pool->Free(data2, 27); - 
ASSERT_EQ(0, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated, pool->bytes_allocated()); } void TestOOM() { diff --git a/cpp/src/arrow/stl_allocator.h b/cpp/src/arrow/stl_allocator.h index a1f4ae9feb82b..82e6aaa8772b9 100644 --- a/cpp/src/arrow/stl_allocator.h +++ b/cpp/src/arrow/stl_allocator.h @@ -110,7 +110,7 @@ class STLMemoryPool : public MemoryPool { } catch (std::bad_alloc& e) { return Status::OutOfMemory(e.what()); } - stats_.UpdateAllocatedBytes(size); + stats_.DidAllocateBytes(size); return Status::OK(); } @@ -124,13 +124,13 @@ class STLMemoryPool : public MemoryPool { } memcpy(*ptr, old_ptr, std::min(old_size, new_size)); alloc_.deallocate(old_ptr, old_size); - stats_.UpdateAllocatedBytes(new_size - old_size); + stats_.DidReallocateBytes(old_size, new_size); return Status::OK(); } void Free(uint8_t* buffer, int64_t size, int64_t /*alignment*/) override { alloc_.deallocate(buffer, size); - stats_.UpdateAllocatedBytes(-size, /*is_free=*/true); + stats_.DidFreeBytes(size); } int64_t bytes_allocated() const override { return stats_.bytes_allocated(); } diff --git a/java/dataset/src/main/cpp/jni_util.cc b/java/dataset/src/main/cpp/jni_util.cc index f1b5a7f7c650e..8e899527f6a99 100644 --- a/java/dataset/src/main/cpp/jni_util.cc +++ b/java/dataset/src/main/cpp/jni_util.cc @@ -97,7 +97,11 @@ class ReservationListenableMemoryPool::Impl { int64_t Reserve(int64_t diff) { std::lock_guard lock(mutex_); - stats_.UpdateAllocatedBytes(diff); + if (diff > 0) { + stats_.DidAllocateBytes(diff); + } else if (diff < 0) { + stats_.DidFreeBytes(-diff); + } int64_t new_block_count; int64_t bytes_reserved = stats_.bytes_allocated(); if (bytes_reserved == 0) {