diff --git a/src/streaming/CMakeLists.txt b/src/streaming/CMakeLists.txt index 33d68e2a..598fffe5 100644 --- a/src/streaming/CMakeLists.txt +++ b/src/streaming/CMakeLists.txt @@ -9,6 +9,12 @@ add_library(${tgt} zarr.stream.cpp zarr.common.hh zarr.common.cpp + blosc.compression.params.hh + blosc.compression.params.cpp + thread.pool.hh + thread.pool.cpp + s3.connection.hh + s3.connection.cpp ) target_include_directories(${tgt} diff --git a/src/streaming/blosc.compression.params.cpp b/src/streaming/blosc.compression.params.cpp new file mode 100644 index 00000000..19af8e8c --- /dev/null +++ b/src/streaming/blosc.compression.params.cpp @@ -0,0 +1,23 @@ +#include "blosc.compression.params.hh" + +const char* +zarr::blosc_codec_to_string(ZarrCompressionCodec codec) +{ + switch (codec) { + case ZarrCompressionCodec_BloscZstd: + return "zstd"; + case ZarrCompressionCodec_BloscLZ4: + return "lz4"; + default: + return "unrecognized codec"; + } +} + +zarr::BloscCompressionParams::BloscCompressionParams(std::string_view codec_id, + uint8_t clevel, + uint8_t shuffle) + : codec_id{ codec_id } + , clevel{ clevel } + , shuffle{ shuffle } +{ +} diff --git a/src/streaming/blosc.compression.params.hh b/src/streaming/blosc.compression.params.hh new file mode 100644 index 00000000..dcd9cd42 --- /dev/null +++ b/src/streaming/blosc.compression.params.hh @@ -0,0 +1,25 @@ +#pragma once + +#include "acquire.zarr.h" + +#include + +#include +#include + +namespace zarr { +const char* +blosc_codec_to_string(ZarrCompressionCodec codec); + +struct BloscCompressionParams +{ + std::string codec_id; + uint8_t clevel{ 1 }; + uint8_t shuffle{ 1 }; + + BloscCompressionParams() = default; + BloscCompressionParams(std::string_view codec_id, + uint8_t clevel, + uint8_t shuffle); +}; +} // namespace zarr diff --git a/src/streaming/s3.connection.cpp b/src/streaming/s3.connection.cpp new file mode 100644 index 00000000..c9ffbf88 --- /dev/null +++ b/src/streaming/s3.connection.cpp @@ -0,0 +1,262 @@ +#include 
"macros.hh" +#include "s3.connection.hh" + +#include + +#include +#include +#include + +zarr::S3Connection::S3Connection(const std::string& endpoint, + const std::string& access_key_id, + const std::string& secret_access_key) +{ + minio::s3::BaseUrl url(endpoint); + url.https = endpoint.starts_with("https"); + + provider_ = std::make_unique( + access_key_id, secret_access_key); + client_ = std::make_unique(url, provider_.get()); + + CHECK(client_); +} + +bool +zarr::S3Connection::is_connection_valid() +{ + return static_cast(client_->ListBuckets()); +} + +bool +zarr::S3Connection::bucket_exists(std::string_view bucket_name) +{ + minio::s3::BucketExistsArgs args; + args.bucket = bucket_name; + + auto response = client_->BucketExists(args); + return response.exist; +} + +bool +zarr::S3Connection::object_exists(std::string_view bucket_name, + std::string_view object_name) +{ + minio::s3::StatObjectArgs args; + args.bucket = bucket_name; + args.object = object_name; + + auto response = client_->StatObject(args); + // casts to true if response code in 200 range and error message is empty + return static_cast(response); +} + +std::string +zarr::S3Connection::put_object(std::string_view bucket_name, + std::string_view object_name, + std::span data) +{ + EXPECT(!bucket_name.empty(), "Bucket name must not be empty."); + EXPECT(!object_name.empty(), "Object name must not be empty."); + EXPECT(!data.empty(), "Data must not be empty."); + + minio::utils::CharBuffer buffer(reinterpret_cast(data.data()), + data.size()); + std::basic_istream stream(&buffer); + + LOG_DEBUG( + "Putting object %s in bucket %s", object_name.data(), bucket_name.data()); + minio::s3::PutObjectArgs args(stream, static_cast(data.size()), 0); + args.bucket = bucket_name; + args.object = object_name; + + auto response = client_->PutObject(args); + if (!response) { + LOG_ERROR("Failed to put object %s in bucket %s: %s", + object_name.data(), + bucket_name.data(), + response.Error().String().c_str()); + 
return {}; + } + + return response.etag; +} + +bool +zarr::S3Connection::delete_object(std::string_view bucket_name, + std::string_view object_name) +{ + EXPECT(!bucket_name.empty(), "Bucket name must not be empty."); + EXPECT(!object_name.empty(), "Object name must not be empty."); + + LOG_DEBUG("Deleting object %s from bucket %s", + object_name.data(), + bucket_name.data()); + minio::s3::RemoveObjectArgs args; + args.bucket = bucket_name; + args.object = object_name; + + auto response = client_->RemoveObject(args); + if (!response) { + LOG_ERROR("Failed to delete object %s from bucket %s: %s", + object_name.data(), + bucket_name.data(), + response.Error().String().c_str()); + return false; + } + + return true; +} + +std::string +zarr::S3Connection::create_multipart_object(std::string_view bucket_name, + std::string_view object_name) +{ + EXPECT(!bucket_name.empty(), "Bucket name must not be empty."); + EXPECT(!object_name.empty(), "Object name must not be empty."); + + LOG_DEBUG( + "Creating multipart object ", object_name, " in bucket ", bucket_name); + minio::s3::CreateMultipartUploadArgs args; + args.bucket = bucket_name; + args.object = object_name; + + auto response = client_->CreateMultipartUpload(args); + if (!response) { + LOG_ERROR("Failed to create multipart object ", + object_name, + " in bucket ", + bucket_name, + ": ", + response.Error().String()); + } + EXPECT(!response.upload_id.empty(), "Upload id returned empty."); + + return response.upload_id; +} + +std::string +zarr::S3Connection::upload_multipart_object_part(std::string_view bucket_name, + std::string_view object_name, + std::string_view upload_id, + std::span data, + unsigned int part_number) +{ + EXPECT(!bucket_name.empty(), "Bucket name must not be empty."); + EXPECT(!object_name.empty(), "Object name must not be empty."); + EXPECT(!data.empty(), "Number of bytes must be positive."); + EXPECT(part_number, "Part number must be positive."); + + LOG_DEBUG("Uploading multipart object part ", + 
part_number, + " for object ", + object_name, + " in bucket ", + bucket_name); + + std::string_view data_buffer(reinterpret_cast(data.data()), + data.size()); + + minio::s3::UploadPartArgs args; + args.bucket = bucket_name; + args.object = object_name; + args.part_number = part_number; + args.upload_id = upload_id; + args.data = data_buffer; + + auto response = client_->UploadPart(args); + if (!response) { + LOG_ERROR("Failed to upload part ", + part_number, + " for object ", + object_name, + " in bucket ", + bucket_name, + ": ", + response.Error().String()); + return {}; + } + + return response.etag; +} + +bool +zarr::S3Connection::complete_multipart_object( + std::string_view bucket_name, + std::string_view object_name, + std::string_view upload_id, + const std::list& parts) +{ + EXPECT(!bucket_name.empty(), "Bucket name must not be empty."); + EXPECT(!object_name.empty(), "Object name must not be empty."); + EXPECT(!upload_id.empty(), "Upload id must not be empty."); + EXPECT(!parts.empty(), "Parts list must not be empty."); + + LOG_DEBUG("Completing multipart object %s in bucket %s", + object_name.data(), + bucket_name.data()); + minio::s3::CompleteMultipartUploadArgs args; + args.bucket = bucket_name; + args.object = object_name; + args.upload_id = upload_id; + args.parts = parts; + + auto response = client_->CompleteMultipartUpload(args); + if (!response) { + LOG_ERROR("Failed to complete multipart object %s in bucket %s: %s", + object_name.data(), + bucket_name.data(), + response.Error().String().c_str()); + return false; + } + + return true; +} + +zarr::S3ConnectionPool::S3ConnectionPool(size_t n_connections, + const std::string& endpoint, + const std::string& access_key_id, + const std::string& secret_access_key) +{ + for (auto i = 0; i < n_connections; ++i) { + auto connection = std::make_unique( + endpoint, access_key_id, secret_access_key); + + if (connection->is_connection_valid()) { + connections_.push_back(std::make_unique( + endpoint, access_key_id, 
secret_access_key)); + } + } + + CHECK(!connections_.empty()); +} + +zarr::S3ConnectionPool::~S3ConnectionPool() +{ + is_accepting_connections_ = false; + cv_.notify_all(); +} + +std::unique_ptr +zarr::S3ConnectionPool::get_connection() +{ + std::unique_lock lock(connections_mutex_); + cv_.wait(lock, [this] { + return !is_accepting_connections_ || !connections_.empty(); + }); + + if (!is_accepting_connections_ || connections_.empty()) { + return nullptr; + } + + auto conn = std::move(connections_.back()); + connections_.pop_back(); + return conn; +} + +void +zarr::S3ConnectionPool::return_connection(std::unique_ptr&& conn) +{ + std::unique_lock lock(connections_mutex_); + connections_.push_back(std::move(conn)); + cv_.notify_one(); +} diff --git a/src/streaming/s3.connection.hh b/src/streaming/s3.connection.hh new file mode 100644 index 00000000..91e71f27 --- /dev/null +++ b/src/streaming/s3.connection.hh @@ -0,0 +1,138 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace zarr { +class S3Connection +{ + public: + S3Connection(const std::string& endpoint, + const std::string& access_key_id, + const std::string& secret_access_key); + + /** + * @brief Test a connection by listing all buckets at this connection's + * endpoint. + * @returns True if the connection is valid, otherwise false. + */ + bool is_connection_valid(); + + /* Bucket operations */ + + /** + * @brief Check whether a bucket exists. + * @param bucket_name The name of the bucket. + * @returns True if the bucket exists, otherwise false. + */ + bool bucket_exists(std::string_view bucket_name); + + /* Object operations */ + + /** + * @brief Check whether an object exists. + * @param bucket_name The name of the bucket containing the object. + * @param object_name The name of the object. + * @returns True if the object exists, otherwise false. 
+ */ + bool object_exists(std::string_view bucket_name, + std::string_view object_name); + + /** + * @brief Put an object. + * @param bucket_name The name of the bucket to put the object in. + * @param object_name The name of the object. + * @param data The data to put in the object. + * @returns The etag of the object. + * @throws std::runtime_error if the bucket name is empty, the object name + * is empty, or @p data is empty. + */ + [[nodiscard]] std::string put_object(std::string_view bucket_name, + std::string_view object_name, + std::span data); + + /** + * @brief Delete an object. + * @param bucket_name The name of the bucket containing the object. + * @param object_name The name of the object. + * @returns True if the object was successfully deleted, otherwise false. + * @throws std::runtime_error if the bucket name is empty or the object + * name is empty. + */ + [[nodiscard]] bool delete_object(std::string_view bucket_name, + std::string_view object_name); + + /* Multipart object operations */ + + /// @brief Create a multipart object. + /// @param bucket_name The name of the bucket containing the object. + /// @param object_name The name of the object. + /// @returns The upload id of the multipart object. Nonempty if and only if + /// the operation succeeds. + /// @throws std::runtime_error if the bucket name is empty or the object + /// name is empty. + [[nodiscard]] std::string create_multipart_object( + std::string_view bucket_name, + std::string_view object_name); + + /// @brief Upload a part of a multipart object. + /// @param bucket_name The name of the bucket containing the object. + /// @param object_name The name of the object. + /// @param upload_id The upload id of the multipart object. + /// @param data The data to upload. + /// @param part_number The part number of the object. + /// @returns The etag of the uploaded part. Nonempty if and only if the + /// operation is successful. 
+ /// @throws std::runtime_error if the bucket name is empty, the object name + /// is empty, @p data is empty, or @p part_number is 0. + [[nodiscard]] std::string upload_multipart_object_part( + std::string_view bucket_name, + std::string_view object_name, + std::string_view upload_id, + std::span data, + unsigned int part_number); + + /// @brief Complete a multipart object. + /// @param bucket_name The name of the bucket containing the object. + /// @param object_name The name of the object. + /// @param upload_id The upload id of the multipart object. + /// @param parts List of the parts making up the object. + /// @returns True if the object was successfully completed, otherwise false. + [[nodiscard]] bool complete_multipart_object( + std::string_view bucket_name, + std::string_view object_name, + std::string_view upload_id, + const std::list& parts); + + private: + std::unique_ptr client_; + std::unique_ptr provider_; +}; + +class S3ConnectionPool +{ + public: + S3ConnectionPool(size_t n_connections, + const std::string& endpoint, + const std::string& access_key_id, + const std::string& secret_access_key); + ~S3ConnectionPool(); + + std::unique_ptr get_connection(); + void return_connection(std::unique_ptr&& conn); + + private: + std::vector> connections_; + std::mutex connections_mutex_; + std::condition_variable cv_; + + std::atomic is_accepting_connections_{true}; +}; +} // namespace zarr diff --git a/src/streaming/thread.pool.cpp b/src/streaming/thread.pool.cpp new file mode 100644 index 00000000..2550dd8a --- /dev/null +++ b/src/streaming/thread.pool.cpp @@ -0,0 +1,94 @@ +#include "thread.pool.hh" + +zarr::ThreadPool::ThreadPool(unsigned int n_threads, ErrorCallback&& err) + : error_handler_{ std::move(err) } +{ + const auto max_threads = std::max(std::thread::hardware_concurrency(), 1u); + n_threads = std::clamp(n_threads, 1u, max_threads); + + for (auto i = 0; i < n_threads; ++i) { + threads_.emplace_back([this] { process_tasks_(); }); + } +} + 
// ---- src/streaming/thread.pool.hh ----
#pragma once

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <thread>
#include <vector>

namespace zarr {
/// A fixed-size pool of worker threads draining a FIFO job queue.
class ThreadPool
{
  public:
    /// A job receives a string to fill with a diagnostic message and returns
    /// false on failure.
    using Task = std::function<bool(std::string&)>;
    using ErrorCallback = std::function<void(const std::string&)>;

    // The error handler `err` is called when a job returns false. This can
    // happen when the job encounters an error, or otherwise fails. The
    // std::string& argument to the error handler is a diagnostic message from
    // the failing job and is logged to the error stream by the Zarr driver
    // when the next call to `append()` is made.
    ThreadPool(unsigned int n_threads, ErrorCallback&& err);

    /// Discards any jobs still queued, then joins the workers.
    ~ThreadPool() noexcept;

    /**
     * @brief Push a job onto the job queue.
     *
     * @param job The job to push onto the queue.
     * @return true if the job was successfully pushed onto the queue, false
     * otherwise.
     */
    [[nodiscard]] bool push_job(Task&& job);

    /**
     * @brief Block until all jobs on the queue have processed, then spin down
     * the threads.
     * @note After calling this function, the job queue no longer accepts
     * jobs.
     */
    void await_stop() noexcept;

  private:
    ErrorCallback error_handler_;

    std::vector<std::thread> threads_;
    std::mutex jobs_mutex_;
    std::condition_variable cv_;
    std::queue<Task> jobs_;

    std::atomic<bool> is_accepting_jobs_{ true };

    // Callers must hold jobs_mutex_.
    std::optional<Task> pop_from_job_queue_() noexcept;
    [[nodiscard]] bool should_stop_() const noexcept;
    void process_tasks_();
};
} // namespace zarr

// ---- src/streaming/thread.pool.cpp ----
zarr::ThreadPool::ThreadPool(unsigned int n_threads, ErrorCallback&& err)
  : error_handler_{ std::move(err) }
{
    // Always run at least one worker, never more than the hardware supports.
    const unsigned int max_threads =
      std::max(std::thread::hardware_concurrency(), 1u);
    n_threads = std::clamp(n_threads, 1u, max_threads);

    for (unsigned int i = 0; i < n_threads; ++i) {
        threads_.emplace_back([this] { process_tasks_(); });
    }
}

zarr::ThreadPool::~ThreadPool() noexcept
{
    {
        // Destruction is not a flush: drop any work that never started.
        std::unique_lock lock(jobs_mutex_);
        while (!jobs_.empty()) {
            jobs_.pop();
        }
    }

    await_stop();
}

bool
zarr::ThreadPool::push_job(Task&& job)
{
    std::unique_lock lock(jobs_mutex_);
    if (!is_accepting_jobs_) {
        return false;
    }

    jobs_.push(std::move(job));
    cv_.notify_one();

    return true;
}

void
zarr::ThreadPool::await_stop() noexcept
{
    {
        std::scoped_lock lock(jobs_mutex_);
        is_accepting_jobs_ = false;

        cv_.notify_all();
    }

    // spin down threads
    for (auto& thread : threads_) {
        if (thread.joinable()) {
            thread.join();
        }
    }
}

std::optional<zarr::ThreadPool::Task>
zarr::ThreadPool::pop_from_job_queue_() noexcept
{
    if (jobs_.empty()) {
        return std::nullopt;
    }

    auto job = std::move(jobs_.front());
    jobs_.pop();
    return job;
}

bool
zarr::ThreadPool::should_stop_() const noexcept
{
    // Workers exit only after the queue has fully drained.
    return !is_accepting_jobs_ && jobs_.empty();
}

void
zarr::ThreadPool::process_tasks_()
{
    while (true) {
        std::unique_lock lock(jobs_mutex_);
        cv_.wait(lock, [&] { return should_stop_() || !jobs_.empty(); });

        if (should_stop_()) {
            break;
        }

        if (auto job = pop_from_job_queue_(); job.has_value()) {
            // Run the job without holding the queue lock.
            lock.unlock();
            if (std::string err_msg; !job.value()(err_msg)) {
                error_handler_(err_msg);
            }
        }
    }
}
std::string& bucket_name, + std::string& access_key_id, + std::string& secret_access_key) +{ + char* env = nullptr; + if (!(env = std::getenv("ZARR_S3_ENDPOINT"))) { + LOG_ERROR("ZARR_S3_ENDPOINT not set."); + return false; + } + endpoint = env; + + if (!(env = std::getenv("ZARR_S3_BUCKET_NAME"))) { + LOG_ERROR("ZARR_S3_BUCKET_NAME not set."); + return false; + } + bucket_name = env; + + if (!(env = std::getenv("ZARR_S3_ACCESS_KEY_ID"))) { + LOG_ERROR("ZARR_S3_ACCESS_KEY_ID not set."); + return false; + } + access_key_id = env; + + if (!(env = std::getenv("ZARR_S3_SECRET_ACCESS_KEY"))) { + LOG_ERROR("ZARR_S3_SECRET_ACCESS_KEY not set."); + return false; + } + secret_access_key = env; + + return true; +} +} // namespace + +int +main() +{ + std::string s3_endpoint, bucket_name, s3_access_key_id, + s3_secret_access_key; + if (!get_credentials( + s3_endpoint, bucket_name, s3_access_key_id, s3_secret_access_key)) { + LOG_WARNING("Failed to get credentials. Skipping test."); + return 0; + } + + int retval = 1; + + try { + zarr::S3Connection conn{ s3_endpoint, + s3_access_key_id, + s3_secret_access_key }; + + if (!conn.is_connection_valid()) { + LOG_ERROR("Failed to connect to S3."); + return 1; + } + + if (conn.bucket_exists("")) { + LOG_ERROR("False positive response for empty bucket name."); + return 1; + } + + CHECK(conn.bucket_exists(bucket_name)); + + retval = 0; + } catch (const std::exception& e) { + LOG_ERROR("Failed: %s", e.what()); + } + + return retval; +} \ No newline at end of file diff --git a/tests/unit-tests/s3-connection-object-exists-check-false-positives.cpp b/tests/unit-tests/s3-connection-object-exists-check-false-positives.cpp new file mode 100644 index 00000000..944ae327 --- /dev/null +++ b/tests/unit-tests/s3-connection-object-exists-check-false-positives.cpp @@ -0,0 +1,85 @@ +#include "s3.connection.hh" +#include "unit.test.macros.hh" + +#include +#include + +namespace { +bool +get_credentials(std::string& endpoint, + std::string& bucket_name, + 
std::string& access_key_id, + std::string& secret_access_key) +{ + char* env = nullptr; + if (!(env = std::getenv("ZARR_S3_ENDPOINT"))) { + LOG_ERROR("ZARR_S3_ENDPOINT not set."); + return false; + } + endpoint = env; + + if (!(env = std::getenv("ZARR_S3_BUCKET_NAME"))) { + LOG_ERROR("ZARR_S3_BUCKET_NAME not set."); + return false; + } + bucket_name = env; + + if (!(env = std::getenv("ZARR_S3_ACCESS_KEY_ID"))) { + LOG_ERROR("ZARR_S3_ACCESS_KEY_ID not set."); + return false; + } + access_key_id = env; + + if (!(env = std::getenv("ZARR_S3_SECRET_ACCESS_KEY"))) { + LOG_ERROR("ZARR_S3_SECRET_ACCESS_KEY not set."); + return false; + } + secret_access_key = env; + + return true; +} +} // namespace + +int +main() +{ + std::string s3_endpoint, bucket_name, s3_access_key_id, + s3_secret_access_key; + if (!get_credentials( + s3_endpoint, bucket_name, s3_access_key_id, s3_secret_access_key)) { + LOG_WARNING("Failed to get credentials. Skipping test."); + return 0; + } + + int retval = 1; + const std::string object_name = "test-object"; + + try { + zarr::S3Connection conn{ s3_endpoint, + s3_access_key_id, + s3_secret_access_key }; + + if (!conn.is_connection_valid()) { + LOG_ERROR("Failed to connect to S3."); + return 1; + } + + CHECK(conn.bucket_exists(bucket_name)); + + if (conn.object_exists("", object_name)) { + LOG_ERROR("False positive for empty bucket name."); + return 1; + } + + if (conn.object_exists(bucket_name, "")) { + LOG_ERROR("False positive for empty object name."); + return 1; + } + + retval = 0; + } catch (const std::exception& e) { + LOG_ERROR("Failed: %s", e.what()); + } + + return retval; +} \ No newline at end of file diff --git a/tests/unit-tests/s3-connection-put-object.cpp b/tests/unit-tests/s3-connection-put-object.cpp new file mode 100644 index 00000000..beef999b --- /dev/null +++ b/tests/unit-tests/s3-connection-put-object.cpp @@ -0,0 +1,88 @@ +#include "s3.connection.hh" +#include "unit.test.macros.hh" + +#include + +namespace { +bool 
+get_credentials(std::string& endpoint, + std::string& bucket_name, + std::string& access_key_id, + std::string& secret_access_key) +{ + char* env = nullptr; + if (!(env = std::getenv("ZARR_S3_ENDPOINT"))) { + LOG_ERROR("ZARR_S3_ENDPOINT not set."); + return false; + } + endpoint = env; + + if (!(env = std::getenv("ZARR_S3_BUCKET_NAME"))) { + LOG_ERROR("ZARR_S3_BUCKET_NAME not set."); + return false; + } + bucket_name = env; + + if (!(env = std::getenv("ZARR_S3_ACCESS_KEY_ID"))) { + LOG_ERROR("ZARR_S3_ACCESS_KEY_ID not set."); + return false; + } + access_key_id = env; + + if (!(env = std::getenv("ZARR_S3_SECRET_ACCESS_KEY"))) { + LOG_ERROR("ZARR_S3_SECRET_ACCESS_KEY not set."); + return false; + } + secret_access_key = env; + + return true; +} +} // namespace + +int +main() +{ + std::string s3_endpoint, bucket_name, s3_access_key_id, + s3_secret_access_key; + if (!get_credentials( + s3_endpoint, bucket_name, s3_access_key_id, s3_secret_access_key)) { + LOG_WARNING("Failed to get credentials. 
Skipping test."); + return 0; + } + + int retval = 1; + const std::string object_name = "test-object"; + + try { + zarr::S3Connection conn{ s3_endpoint, + s3_access_key_id, + s3_secret_access_key }; + + if (!conn.is_connection_valid()) { + LOG_ERROR("Failed to connect to S3."); + return 1; + } + CHECK(conn.bucket_exists(bucket_name)); + CHECK(conn.delete_object(bucket_name, object_name)); + CHECK(!conn.object_exists(bucket_name, object_name)); + + std::vector data(1024, 0); + + std::string etag = + conn.put_object(bucket_name, + object_name, + std::span(data.data(), data.size())); + CHECK(!etag.empty()); + + CHECK(conn.object_exists(bucket_name, object_name)); + + // cleanup + CHECK(conn.delete_object(bucket_name, object_name)); + + retval = 0; + } catch (const std::exception& e) { + LOG_ERROR("Failed: %s", e.what()); + } + + return retval; +} \ No newline at end of file diff --git a/tests/unit-tests/s3-connection-upload-multipart-object.cpp b/tests/unit-tests/s3-connection-upload-multipart-object.cpp new file mode 100644 index 00000000..d3d6b316 --- /dev/null +++ b/tests/unit-tests/s3-connection-upload-multipart-object.cpp @@ -0,0 +1,127 @@ +#include "s3.connection.hh" +#include "unit.test.macros.hh" + +#include + +namespace { +bool +get_credentials(std::string& endpoint, + std::string& bucket_name, + std::string& access_key_id, + std::string& secret_access_key) +{ + char* env = nullptr; + if (!(env = std::getenv("ZARR_S3_ENDPOINT"))) { + LOG_ERROR("ZARR_S3_ENDPOINT not set."); + return false; + } + endpoint = env; + + if (!(env = std::getenv("ZARR_S3_BUCKET_NAME"))) { + LOG_ERROR("ZARR_S3_BUCKET_NAME not set."); + return false; + } + bucket_name = env; + + if (!(env = std::getenv("ZARR_S3_ACCESS_KEY_ID"))) { + LOG_ERROR("ZARR_S3_ACCESS_KEY_ID not set."); + return false; + } + access_key_id = env; + + if (!(env = std::getenv("ZARR_S3_SECRET_ACCESS_KEY"))) { + LOG_ERROR("ZARR_S3_SECRET_ACCESS_KEY not set."); + return false; + } + secret_access_key = env; + + return 
true; +} +} // namespace + +int +main() +{ + std::string s3_endpoint, bucket_name, s3_access_key_id, + s3_secret_access_key; + if (!get_credentials( + s3_endpoint, bucket_name, s3_access_key_id, s3_secret_access_key)) { + LOG_WARNING("Failed to get credentials. Skipping test."); + return 0; + } + + int retval = 1; + const std::string object_name = "test-object"; + + try { + zarr::S3Connection conn( + s3_endpoint, s3_access_key_id, s3_secret_access_key); + + if (!conn.is_connection_valid()) { + LOG_ERROR("Failed to connect to S3."); + return 1; + } + CHECK(conn.bucket_exists(bucket_name)); + CHECK(conn.delete_object(bucket_name, object_name)); + CHECK(!conn.object_exists(bucket_name, object_name)); + + std::string upload_id = + conn.create_multipart_object(bucket_name, object_name); + CHECK(!upload_id.empty()); + + std::list parts; + + // parts need to be at least 5MiB, except the last part + std::vector data(5 << 20, 0); + for (auto i = 0; i < 4; ++i) { + std::string etag = conn.upload_multipart_object_part( + bucket_name, + object_name, + upload_id, + std::span(data.data(), data.size()), + i + 1); + CHECK(!etag.empty()); + + minio::s3::Part part; + part.number = i + 1; + part.etag = etag; + part.size = data.size(); + + parts.push_back(part); + } + + // last part is 1MiB + { + const unsigned int part_number = parts.size() + 1; + const size_t part_size = 1 << 20; // 1MiB + std::string etag = conn.upload_multipart_object_part( + bucket_name, + object_name, + upload_id, + std::span(data.data(), data.size()), + part_number); + CHECK(!etag.empty()); + + minio::s3::Part part; + part.number = part_number; + part.etag = etag; + part.size = part_size; + + parts.push_back(part); + } + + CHECK(conn.complete_multipart_object( + bucket_name, object_name, upload_id, parts)); + + CHECK(conn.object_exists(bucket_name, object_name)); + + // cleanup + CHECK(conn.delete_object(bucket_name, object_name)); + + retval = 0; + } catch (const std::exception& e) { + LOG_ERROR("Failed: %s", 
e.what()); + } + + return retval; +} \ No newline at end of file diff --git a/tests/unit-tests/thread-pool-push-to-job-queue.cpp b/tests/unit-tests/thread-pool-push-to-job-queue.cpp new file mode 100644 index 00000000..12c26e08 --- /dev/null +++ b/tests/unit-tests/thread-pool-push-to-job-queue.cpp @@ -0,0 +1,60 @@ +#include "thread.pool.hh" +#include "unit.test.macros.hh" + +#include +#include +#include +#include + +namespace fs = std::filesystem; + +int +main() +{ + int retval = 0; + + fs::path tmp_path = fs::temp_directory_path() / TEST; + CHECK(!fs::exists(tmp_path)); + + zarr::ThreadPool pool{ 1, [](const std::string&) {} }; + + CHECK(pool.push_job([&tmp_path](std::string&) { + std::ofstream ofs(tmp_path); + ofs << "Hello, Acquire!"; + ofs.close(); + return true; + })); + pool.await_stop(); + + CHECK(fs::exists(tmp_path)); + + std::ifstream ifs(tmp_path); + CHECK(ifs.is_open()); + + std::string contents; + while (!ifs.eof()) { + std::getline(ifs, contents); + } + ifs.close(); + + if (contents != "Hello, Acquire!") { + fprintf(stderr, + "Expected 'Hello, Acquire!' but got '%s'\n", + contents.c_str()); + retval = 1; + } + + goto Cleanup; + +Finalize: + return retval; + +Cleanup: + std::error_code ec; + if (!fs::remove(tmp_path, ec)) { + fprintf(stderr, "Failed to remove file: %s\n", ec.message().c_str()); + retval = 1; + } + + goto Finalize; +} \ No newline at end of file