From 7c32ce8b4ed529dc9d90fe46ae5647829e5b514a Mon Sep 17 00:00:00 2001 From: Paris Morgan Date: Mon, 14 Oct 2024 16:20:10 -0700 Subject: [PATCH] Add a new read_vector() helper which reads slices --- src/include/detail/linalg/tdb_io.h | 52 ++++++++++++++++++++++++++++++ src/include/test/unit_tdb_io.cc | 30 +++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/src/include/detail/linalg/tdb_io.h b/src/include/detail/linalg/tdb_io.h index d1807be66..9a7e76437 100644 --- a/src/include/detail/linalg/tdb_io.h +++ b/src/include/detail/linalg/tdb_io.h @@ -373,6 +373,58 @@ std::vector read_vector( return read_vector_helper(ctx, uri, 0, 0, temporal_policy, true); } +/** + * Read the contents of a TileDB array into a std::vector. + * @tparam T Type of data element stored. + * @param ctx The TileDB context. + * @param uri The URI of the TileDB array. + * @param slices The slices to read. Each slice is a pair of start and end. + * @param temporal_policy The temporal policy for the read. + * @return The vector of data. + */ +template +std::vector read_vector( + const tiledb::Context& ctx, + const std::string& uri, + const std::vector>& slices, + size_t total_slices_size, + TemporalPolicy temporal_policy = {}) { + if (total_slices_size == 0) { + return {}; + } + scoped_timer _{tdb_func__ + " " + std::string{uri}}; + + auto array_ = tiledb_helpers::open_array( + tdb_func__, ctx, uri, TILEDB_READ, temporal_policy); + auto schema_ = array_->schema(); + + const size_t idx = 0; + auto attr = schema_.attribute(idx); + + std::string attr_name = attr.name(); + + // Create a subarray that reads the array up to the specified subset. + tiledb::Subarray subarray(ctx, *array_); + for (const auto& slice : slices) { + subarray.add_range( + 0, static_cast(slice.first), static_cast(slice.second)); + } + + // @todo: use something non-initializing + std::vector data_(total_slices_size); + + tiledb::Query query(ctx, *array_); + query.set_subarray(subarray).set_data_buffer( + attr_name, data_.data(), total_slices_size); + tiledb_helpers::submit_query(tdb_func__, uri, query); + _memory_data.insert_entry(tdb_func__, total_slices_size * sizeof(T)); + + array_->close(); + assert(tiledb::Query::Status::COMPLETE == query.query_status()); + + return data_; +} + template auto sizes_to_indices(const std::vector& sizes) { std::vector indices(size(sizes) + 1); diff --git a/src/include/test/unit_tdb_io.cc b/src/include/test/unit_tdb_io.cc index 4b285e65d..193eef5be 100644 --- a/src/include/test/unit_tdb_io.cc +++ b/src/include/test/unit_tdb_io.cc @@ -275,3 +275,33 @@ TEST_CASE("create group", "[tdb_io]") { read_group.close(); } + +TEST_CASE("read vector slices", "[tdb_io]") { + tiledb::Context ctx; + std::string uri = + (std::filesystem::temp_directory_path() / "tmp_vector").string(); + + tiledb::VFS vfs(ctx); + if (vfs.is_dir(uri)) { + vfs.remove_dir(uri); + } + + size_t n = 100; + std::vector vector(n); + std::iota(begin(vector), end(vector), 0); + write_vector(ctx, vector, uri); + + auto result = read_vector(ctx, uri); + CHECK(vector == result); + + std::vector> slices; + slices.push_back({0, 1}); // 2 elements. + slices.push_back({3, 3}); // 1 element. + slices.push_back({50, 60}); // 11 elements + size_t total_slices_size = 14; + + auto result_slice = read_vector(ctx, uri, slices, total_slices_size); + auto expected = + std::vector{0, 1, 3, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}; + CHECK(result_slice == expected); +}