Skip to content

Commit

Permalink
Update tdbMatrixMultiRange to support both slices and individual indi…
Browse files Browse the repository at this point in the history
…ces (#550)
  • Loading branch information
jparismorgan authored Oct 15, 2024
1 parent 9f1a046 commit 849b6e3
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 35 deletions.
139 changes: 118 additions & 21 deletions src/include/detail/linalg/tdb_matrix_multi_range.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,18 @@ class tdbBlockedMatrixMultiRange : public Matrix<T, LayoutPolicy, I> {
std::unique_ptr<tiledb::Array> array_;
tiledb::ArraySchema schema_;

// The indices of all the columns to load. The size of this is the total
// number of columns.
// This class supports two ways of multi-range queries. One of these two will
// be used.
enum class QueryType { ColumnIndices, ColumnSlices };
QueryType query_type_{QueryType::ColumnIndices};

// 1. By passing the indices of the columns to load.
std::vector<I> column_indices_;
// 2. By passing slices of the columns to load.
std::vector<std::pair<I, I>> column_slices_;

// The total number of columns in either column_indices_ or column_slices_.
size_t total_num_columns_{0};

// The max number of columns that can fit in allocated memory
size_t column_capacity_{0};
Expand All @@ -79,20 +88,13 @@ class tdbBlockedMatrixMultiRange : public Matrix<T, LayoutPolicy, I> {
// The final index numbers of the resident columns
size_t last_resident_col_{0};

size_t get_elements_to_load() const {
// Note that here we try to load column_indices_.size() vectors. If we are
[[nodiscard]] size_t get_elements_to_load() const {
// Note that here we try to load the max number of vectors. If we are
// time travelling, these vectors may not exist in the array, but we still
// need to load them to know that they don't exist.
return std::min(
column_capacity_, column_indices_.size() - last_resident_col_);
return std::min(column_capacity_, total_num_columns_ - last_resident_col_);
}

public:
tdbBlockedMatrixMultiRange(tdbBlockedMatrixMultiRange&& rhs) = default;

/** Default destructor. array will be closed when it goes out of scope */
virtual ~tdbBlockedMatrixMultiRange() = default;

/**
* @brief Construct a new tdbBlockedMatrixMultiRange object, limited to
* `upper_bound` vectors. In this case, the `Matrix` is column-major, so the
Expand All @@ -102,30 +104,43 @@ class tdbBlockedMatrixMultiRange : public Matrix<T, LayoutPolicy, I> {
* @param uri URI of the TileDB array to read.
* @param indices The indices of the columns to read.
* @param dimensions The number of dimensions in each vector.
* @param upper_bound The maximum number of vectors to read.
* @param query_type The type of query to perform.
* @param column_indices The indices of the columns to read. Should only be
* passed with QueryType::ColumnIndices.
* @param column_slices The slices of the columns to read. Should only be
* passed with QueryType::ColumnSlices.
* @param total_num_columns The total number of columns in the array.
* @param upper_bound The maximum number of vectors to read in at once.
* @param temporal_policy The TemporalPolicy to use for reading the array
* data.
*/
tdbBlockedMatrixMultiRange(
const tiledb::Context& ctx,
const std::string& uri,
const std::vector<I>& column_indices,
size_type dimensions,
QueryType query_type,
const std::vector<I>& column_indices,
const std::vector<std::pair<I, I>>& column_slices,
size_t total_num_columns,
size_t upper_bound,
TemporalPolicy temporal_policy = TemporalPolicy{})
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
: Base(dimensions, column_indices.size())
: Base(dimensions, total_num_columns)
, ctx_{ctx}
, dimensions_{dimensions}
, uri_{uri}
, array_(std::make_unique<tiledb::Array>(
ctx, uri, TILEDB_READ, temporal_policy.to_tiledb_temporal_policy()))
, schema_{array_->schema()}
, column_indices_{column_indices} {
, query_type_{query_type}
, column_indices_{column_indices}
, column_slices_{column_slices}
, total_num_columns_{total_num_columns} {
constructor_timer.stop();

// The default is to load all the vectors.
if (upper_bound == 0 || upper_bound > column_indices_.size()) {
column_capacity_ = column_indices_.size();
if (upper_bound == 0 || upper_bound > total_num_columns_) {
column_capacity_ = total_num_columns_;
} else {
column_capacity_ = upper_bound;
}
Expand All @@ -150,6 +165,80 @@ class tdbBlockedMatrixMultiRange : public Matrix<T, LayoutPolicy, I> {
Base::operator=(Base{std::move(data), dimensions, column_capacity_});
}

public:
tdbBlockedMatrixMultiRange(tdbBlockedMatrixMultiRange&& rhs) = default;

/** Default destructor. array will be closed when it goes out of scope */
virtual ~tdbBlockedMatrixMultiRange() = default;

/**
* @brief Construct a new tdbBlockedMatrixMultiRange object, limited to
* `upper_bound` vectors. In this case, the `Matrix` is column-major, so the
* number of vectors is the number of columns.
*
* @param ctx The TileDB context to use.
* @param uri URI of the TileDB array to read.
* @param dimensions The number of dimensions in each vector.
* @param column_indices The indices of the columns to read.
* @param upper_bound The maximum number of vectors to read in at once.
* @param temporal_policy The TemporalPolicy to use for reading the array
* data.
*/
tdbBlockedMatrixMultiRange(
const tiledb::Context& ctx,
const std::string& uri,
size_type dimensions,
const std::vector<I>& column_indices,
size_t upper_bound,
TemporalPolicy temporal_policy = TemporalPolicy{})
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
: tdbBlockedMatrixMultiRange(
ctx,
uri,
dimensions,
QueryType::ColumnIndices,
column_indices,
{},
column_indices.size(),
upper_bound,
temporal_policy) {
}

/**
* @brief Construct a new tdbBlockedMatrixMultiRange object, limited to
* `upper_bound` vectors. In this case, the `Matrix` is column-major, so the
* number of vectors is the number of columns.
*
* @param ctx The TileDB context to use.
* @param uri URI of the TileDB array to read.
* @param dimensions The number of dimensions in each vector.
* @param column_slices The slices of the columns to read.
* @param total_slices_size The total number of columns in the slices.
* @param upper_bound The maximum number of vectors to read in at once.
* @param temporal_policy The TemporalPolicy to use for reading the array
* data.
*/
tdbBlockedMatrixMultiRange(
const tiledb::Context& ctx,
const std::string& uri,
size_type dimensions,
const std::vector<std::pair<I, I>>& column_slices,
size_t total_slices_size,
size_t upper_bound,
TemporalPolicy temporal_policy = TemporalPolicy{})
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
: tdbBlockedMatrixMultiRange(
ctx,
uri,
dimensions,
QueryType::ColumnSlices,
{},
column_slices,
total_slices_size,
upper_bound,
temporal_policy) {
}

bool load() {
scoped_timer _{"tdb_matrix_multi_range@load"};

Expand Down Expand Up @@ -180,9 +269,17 @@ class tdbBlockedMatrixMultiRange : public Matrix<T, LayoutPolicy, I> {
subarray.add_range(0, 0, static_cast<int>(dimensions_) - 1);

// Setup the query ranges.
for (size_t i = first_resident_col; i < last_resident_col_; ++i) {
const auto index = static_cast<int>(column_indices_[i]);
subarray.add_range(1, index, index);
if (query_type_ == QueryType::ColumnIndices) {
for (size_t i = 0; i < column_indices_.size(); ++i) {
const auto index = static_cast<int>(column_indices_[i]);
subarray.add_range(1, index, index);
}
} else {
for (size_t i = 0; i < column_slices_.size(); ++i) {
const auto start = static_cast<int>(column_slices_[i].first);
const auto end = static_cast<int>(column_slices_[i].second);
subarray.add_range(1, start, end);
}
}

// Execute the query.
Expand Down
2 changes: 1 addition & 1 deletion src/include/index/ivf_pq_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -1273,8 +1273,8 @@ class ivf_pq_index {
tdbColMajorMatrixMultiRange<feature_type, uint64_t>(
group_->cached_ctx(),
group_->feature_vectors_uri(),
vector_indices,
dimensions_,
vector_indices,
0,
temporal_policy_);
feature_vectors.load();
Expand Down
45 changes: 32 additions & 13 deletions src/include/test/unit_tdb_matrix_multi_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ TEMPLATE_TEST_CASE(
std::iota(column_indices.begin(), column_indices.end(), 0);

auto Y = tdbColMajorMatrixMultiRange<TestType>(
ctx, tmp_matrix_uri, column_indices, dimensions, num_vectors);
ctx, tmp_matrix_uri, dimensions, column_indices, num_vectors);
CHECK(Y.load() == true);
for (int i = 0; i < 5; ++i) {
CHECK(Y.load() == false);
Expand Down Expand Up @@ -115,7 +115,7 @@ TEMPLATE_TEST_CASE(
std::vector<size_t> column_indices(num_vectors);
std::iota(column_indices.begin(), column_indices.end(), 0);
auto Y = tdbColMajorMatrixMultiRange<TestType>(
ctx, tmp_matrix_uri, column_indices, dimensions, num_vectors);
ctx, tmp_matrix_uri, dimensions, column_indices, num_vectors);
Y.load();
B = std::move(Y);
}
Expand All @@ -125,13 +125,13 @@ TEMPLATE_TEST_CASE(
std::iota(column_indices.begin(), column_indices.end(), 0);
auto Y = tdbColMajorMatrixMultiRange<TestType>(
tdbColMajorMatrixMultiRange<TestType>(
ctx, tmp_matrix_uri, column_indices, dimensions, num_vectors));
ctx, tmp_matrix_uri, dimensions, column_indices, num_vectors));
}

std::vector<size_t> column_indices(num_vectors);
std::iota(column_indices.begin(), column_indices.end(), 0);
auto Y = tdbColMajorMatrixMultiRange<TestType>(
ctx, tmp_matrix_uri, column_indices, dimensions, num_vectors);
ctx, tmp_matrix_uri, dimensions, column_indices, num_vectors);
Y.load();

CHECK(::num_vectors(Y) == ::num_vectors(X));
Expand Down Expand Up @@ -206,7 +206,7 @@ TEST_CASE("limit column_indices", "[tdb_matrix_multi_range]") {
std::vector<size_t> column_indices = {
0, 1, 2, 3, 10, 100, 15, 299, 309, 4, 100};
auto Y = tdbBlockedMatrixMultiRange<T, LayoutPolicy, I>(
ctx, tmp_matrix_uri, column_indices, dimensions, 0);
ctx, tmp_matrix_uri, dimensions, column_indices, 0);
Y.load();
CHECK(::num_vectors(Y) == column_indices.size());
CHECK(::dimensions(Y) == ::dimensions(X));
Expand Down Expand Up @@ -242,7 +242,12 @@ TEST_CASE("empty matrix", "[tdb_matrix_multi_range]") {
{
// No dimensions and no num_vectors.
auto X = tdbColMajorMatrixMultiRange<float>(
ctx, tmp_matrix_uri, {}, 0, 0, TemporalPolicy{TimeTravel, 50});
ctx,
tmp_matrix_uri,
0,
std::vector<size_t>{},
0,
TemporalPolicy{TimeTravel, 50});
X.load();
CHECK(X.num_cols() == 0);
CHECK(::num_vectors(X) == 0);
Expand All @@ -253,7 +258,12 @@ TEST_CASE("empty matrix", "[tdb_matrix_multi_range]") {
{
// All dimensions and no num_vectors.
auto X = tdbColMajorMatrixMultiRange<float>(
ctx, tmp_matrix_uri, {}, 0, 0, TemporalPolicy{TimeTravel, 50});
ctx,
tmp_matrix_uri,
0,
std::vector<size_t>{},
0,
TemporalPolicy{TimeTravel, 50});
X.load();
CHECK(X.num_cols() == 0);
CHECK(::num_vectors(X) == 0);
Expand All @@ -264,7 +274,12 @@ TEST_CASE("empty matrix", "[tdb_matrix_multi_range]") {
{
// No dimensions and all num_vectors.
auto X = tdbColMajorMatrixMultiRange<float>(
ctx, tmp_matrix_uri, {}, 0, 0, TemporalPolicy{TimeTravel, 50});
ctx,
tmp_matrix_uri,
0,
std::vector<size_t>{},
0,
TemporalPolicy{TimeTravel, 50});
X.load();
CHECK(X.num_cols() == 0);
CHECK(::num_vectors(X) == 0);
Expand All @@ -275,7 +290,12 @@ TEST_CASE("empty matrix", "[tdb_matrix_multi_range]") {
{
// No constraints.
auto X = tdbColMajorMatrixMultiRange<float>(
ctx, tmp_matrix_uri, {}, 0, 0, TemporalPolicy{TimeTravel, 50});
ctx,
tmp_matrix_uri,
0,
std::vector<size_t>{},
0,
TemporalPolicy{TimeTravel, 50});
X.load();
CHECK(X.num_cols() == 0);
CHECK(::num_vectors(X) == 0);
Expand Down Expand Up @@ -304,12 +324,11 @@ TEST_CASE("time travel", "[tdb_matrix_multi_range]") {

std::vector<size_t> column_indices(num_vectors);
std::iota(column_indices.begin(), column_indices.end(), 0);
debug_matrix(column_indices, "column_indices", 100);

{
// We can load the matrix at the creation timestamp.
auto Y = tdbColMajorMatrixMultiRange<int>(
ctx, tmp_matrix_uri, column_indices, dimensions, 0);
ctx, tmp_matrix_uri, dimensions, column_indices, 0);
CHECK(Y.load());
CHECK(::num_vectors(Y) == ::num_vectors(X));
CHECK(::dimensions(Y) == ::dimensions(X));
Expand All @@ -327,8 +346,8 @@ TEST_CASE("time travel", "[tdb_matrix_multi_range]") {
auto Y = tdbColMajorMatrixMultiRange<int>(
ctx,
tmp_matrix_uri,
column_indices,
dimensions,
column_indices,
num_vectors,
TemporalPolicy{TimeTravel, 100});
CHECK(Y.load());
Expand All @@ -348,8 +367,8 @@ TEST_CASE("time travel", "[tdb_matrix_multi_range]") {
auto Y = tdbColMajorMatrixMultiRange<int>(
ctx,
tmp_matrix_uri,
column_indices,
dimensions,
column_indices,
num_vectors,
TemporalPolicy{TimeTravel, 5});
CHECK(Y.load());
Expand Down

0 comments on commit 849b6e3

Please sign in to comment.