diff --git a/src/streaming/s3.sink.cpp b/src/streaming/s3.sink.cpp index 04ea3c7..81150ab 100644 --- a/src/streaming/s3.sink.cpp +++ b/src/streaming/s3.sink.cpp @@ -128,17 +128,13 @@ zarr::S3Sink::is_multipart_upload_() const void zarr::S3Sink::create_multipart_upload_() { - if (!is_multipart_upload_()) { - multipart_upload_ = {}; - } - - if (!multipart_upload_->upload_id.empty()) { - return; - } + multipart_upload_ = MultiPartUpload{}; + auto connection = connection_pool_->get_connection(); multipart_upload_->upload_id = - connection_pool_->get_connection()->create_multipart_object(bucket_name_, - object_key_); + connection->create_multipart_object(bucket_name_, object_key_); + + connection_pool_->return_connection(std::move(connection)); } bool @@ -148,9 +144,11 @@ zarr::S3Sink::flush_part_() return false; } - auto connection = connection_pool_->get_connection(); + if (!is_multipart_upload_()) { + create_multipart_upload_(); + } - create_multipart_upload_(); + auto connection = connection_pool_->get_connection(); bool retval = false; try { diff --git a/tests/unit-tests/CMakeLists.txt b/tests/unit-tests/CMakeLists.txt index aefbc34..6069b73 100644 --- a/tests/unit-tests/CMakeLists.txt +++ b/tests/unit-tests/CMakeLists.txt @@ -14,6 +14,7 @@ set(tests s3-connection-upload-multipart-object file-sink-write s3-sink-write + s3-sink-write-multipart sink-creator-make-metadata-sinks sink-creator-make-data-sinks array-writer-downsample-writer-config diff --git a/tests/unit-tests/s3-sink-write-multipart.cpp b/tests/unit-tests/s3-sink-write-multipart.cpp new file mode 100644 index 0000000..a335281 --- /dev/null +++ b/tests/unit-tests/s3-sink-write-multipart.cpp @@ -0,0 +1,112 @@ +#include "s3.sink.hh" +#include "unit.test.macros.hh" + +#include +#include + +namespace { +bool +get_credentials(std::string& endpoint, + std::string& bucket_name, + std::string& access_key_id, + std::string& secret_access_key) +{ + char* env = nullptr; + if (!(env = std::getenv("ZARR_S3_ENDPOINT"))) { + LOG_ERROR("ZARR_S3_ENDPOINT not set."); + return false; + } + endpoint = env; + + if (!(env = std::getenv("ZARR_S3_BUCKET_NAME"))) { + LOG_ERROR("ZARR_S3_BUCKET_NAME not set."); + return false; + } + bucket_name = env; + + if (!(env = std::getenv("ZARR_S3_ACCESS_KEY_ID"))) { + LOG_ERROR("ZARR_S3_ACCESS_KEY_ID not set."); + return false; + } + access_key_id = env; + + if (!(env = std::getenv("ZARR_S3_SECRET_ACCESS_KEY"))) { + LOG_ERROR("ZARR_S3_SECRET_ACCESS_KEY not set."); + return false; + } + secret_access_key = env; + + return true; +} +} // namespace + +int +main() +{ + std::string s3_endpoint, bucket_name, s3_access_key_id, + s3_secret_access_key; + + if (!get_credentials( + s3_endpoint, bucket_name, s3_access_key_id, s3_secret_access_key)) { + LOG_WARNING("Failed to get credentials. Skipping test."); + return 0; + } + + int retval = 1; + const std::string object_name = "test-object"; + + try { + auto pool = std::make_shared( + 1, s3_endpoint, s3_access_key_id, s3_secret_access_key); + + auto conn = pool->get_connection(); + if (!conn->is_connection_valid()) { + LOG_ERROR("Failed to connect to S3."); + return 1; + } + CHECK(conn->bucket_exists(bucket_name)); + CHECK(conn->delete_object(bucket_name, object_name)); + CHECK(!conn->object_exists(bucket_name, object_name)); + + pool->return_connection(std::move(conn)); + + std::vector data((5 << 20) + 1, std::byte{ 0 }); + { + auto sink = + std::make_unique(bucket_name, object_name, pool); + CHECK(sink->write(0, data)); + CHECK(zarr::finalize_sink(std::move(sink))); + } + + conn = pool->get_connection(); + CHECK(conn->object_exists(bucket_name, object_name)); + pool->return_connection(std::move(conn)); + + // Verify the object size. + { + minio::s3::BaseUrl url(s3_endpoint); + url.https = s3_endpoint.starts_with("https://"); + + minio::creds::StaticProvider provider(s3_access_key_id, + s3_secret_access_key); + + minio::s3::Client client(url, &provider); + minio::s3::StatObjectArgs args; + args.bucket = bucket_name; + args.object = object_name; + + minio::s3::StatObjectResponse resp = client.StatObject(args); + EXPECT_EQ(int, data.size(), resp.size); + } + + // cleanup + conn = pool->get_connection(); + CHECK(conn->delete_object(bucket_name, object_name)); + + retval = 0; + } catch (const std::exception& e) { + LOG_ERROR("Exception: ", e.what()); + } + + return retval; +} \ No newline at end of file