Skip to content

Commit

Permalink
Add a new read_vector() helper which reads slices
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan committed Oct 14, 2024
1 parent a187208 commit 7c32ce8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/include/detail/linalg/tdb_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,58 @@ std::vector<T> read_vector(
return read_vector_helper<T>(ctx, uri, 0, 0, temporal_policy, true);
}

/**
 * Read selected slices of a TileDB array into a std::vector.
 *
 * One range per entry of `slices` is added on dimension 0 of the array, and
 * the first attribute in the array schema is read into a buffer holding
 * `total_slices_size` elements.
 *
 * @tparam T Type of data element stored.
 * @tparam Slice Type of the slice bounds (converted to `int` for the query).
 * @param ctx The TileDB context.
 * @param uri The URI of the TileDB array.
 * @param slices The slices to read. Each slice is a pair of start and end,
 * both inclusive (e.g. {3, 3} selects one element).
 * @param total_slices_size The total number of elements covered by `slices`.
 * It is used as the size of the returned vector and of the query buffer, so
 * the caller must keep it consistent with `slices`.
 * @param temporal_policy The temporal policy for the read.
 * @return The vector of data. Empty if `total_slices_size` is 0 or `slices`
 * is empty.
 */
template <class T, typename Slice>
std::vector<T> read_vector(
    const tiledb::Context& ctx,
    const std::string& uri,
    const std::vector<std::pair<Slice, Slice>>& slices,
    size_t total_slices_size,
    TemporalPolicy temporal_policy = {}) {
  // Nothing to read. The `slices.empty()` check matters: a subarray with no
  // explicit ranges selects the whole dimension domain, which would not match
  // a buffer sized for `total_slices_size` elements.
  if (total_slices_size == 0 || slices.empty()) {
    return {};
  }
  scoped_timer _{tdb_func__ + " " + std::string{uri}};

  auto array_ = tiledb_helpers::open_array(
      tdb_func__, ctx, uri, TILEDB_READ, temporal_policy);
  auto schema_ = array_->schema();

  // The data is read from the first attribute of the schema.
  const size_t idx = 0;
  auto attr = schema_.attribute(idx);

  std::string attr_name = attr.name();

  // Build a multi-range subarray: one inclusive [start, end] range on
  // dimension 0 for every requested slice.
  // NOTE(review): bounds are narrowed to `int`, presumably to match the
  // array's dimension datatype -- confirm. Values that do not fit in `int`
  // are not detected here.
  tiledb::Subarray subarray(ctx, *array_);
  for (const auto& slice : slices) {
    subarray.add_range(
        0, static_cast<int>(slice.first), static_cast<int>(slice.second));
  }

  // @todo: use something non-initializing
  std::vector<T> data_(total_slices_size);

  tiledb::Query query(ctx, *array_);
  query.set_subarray(subarray).set_data_buffer(
      attr_name, data_.data(), total_slices_size);
  tiledb_helpers::submit_query(tdb_func__, uri, query);
  _memory_data.insert_entry(tdb_func__, total_slices_size * sizeof(T));

  array_->close();
  // Checked in debug builds only: the buffer is expected to hold the entire
  // result in a single submission.
  assert(tiledb::Query::Status::COMPLETE == query.query_status());

  return data_;
}

template <class T>
auto sizes_to_indices(const std::vector<T>& sizes) {
std::vector<T> indices(size(sizes) + 1);
Expand Down
30 changes: 30 additions & 0 deletions src/include/test/unit_tdb_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,33 @@ TEST_CASE("create group", "[tdb_io]") {

read_group.close();
}

// Writes the values 0..99 to a fresh array, then verifies that a full read
// round-trips and that a multi-slice read returns exactly the selected
// elements.
TEST_CASE("read vector slices", "[tdb_io]") {
  tiledb::Context ctx;
  std::string uri =
      (std::filesystem::temp_directory_path() / "tmp_vector").string();

  // Start from a clean location in case an earlier run left the array behind.
  tiledb::VFS vfs(ctx);
  if (vfs.is_dir(uri)) {
    vfs.remove_dir(uri);
  }

  // Source data: element i holds the value i.
  size_t num_elements = 100;
  std::vector<int> input(num_elements);
  std::iota(begin(input), end(input), 0);
  write_vector(ctx, input, uri);

  // Reading without slices returns the whole array.
  auto full_read = read_vector<int>(ctx, uri);
  CHECK(input == full_read);

  // Slice bounds are inclusive on both ends.
  std::vector<std::pair<uint64_t, uint64_t>> slices{
      {0, 1},     // 2 elements.
      {3, 3},     // 1 element.
      {50, 60}};  // 11 elements.
  size_t total_slices_size = 14;

  auto sliced_read = read_vector<int>(ctx, uri, slices, total_slices_size);
  auto expected =
      std::vector<int>{0, 1, 3, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60};
  CHECK(sliced_read == expected);
}

0 comments on commit 7c32ce8

Please sign in to comment.