Skip to content

Commit

Permalink
Rename nlist to partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan committed Oct 15, 2024
1 parent 9f1a046 commit a642cf1
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 92 deletions.
2 changes: 1 addition & 1 deletion apis/python/src/tiledb/vector_search/ivf_pq_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def create(
id_type=np.dtype(np.uint64).name,
partitioning_index_type=np.dtype(np.uint64).name,
dimensions=dimensions,
n_list=partitions if (partitions is not None and partitions != -1) else 0,
partitions=partitions if (partitions is not None and partitions != -1) else 0,
num_subspaces=num_subspaces,
distance_metric=int(distance_metric),
)
Expand Down
6 changes: 4 additions & 2 deletions apis/python/src/tiledb/vector_search/type_erased_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,11 @@ void init_type_erased_module(py::module_& m) {
"train",
[](IndexIVFPQ& index,
const FeatureVectorArray& vectors,
std::optional<size_t> nlist) { index.train(vectors, nlist); },
std::optional<size_t> partitions) {
index.train(vectors, partitions);
},
py::arg("vectors"),
py::arg("nlist") = std::nullopt)
py::arg("partitions") = std::nullopt)
.def(
"add",
[](IndexIVFPQ& index, const FeatureVectorArray& vectors) {
Expand Down
40 changes: 20 additions & 20 deletions src/include/api/ivf_flat_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ class IndexIVFFlat {
for (auto&& c : *config) {
auto key = c.first;
auto value = c.second;
if (key == "nlist") {
nlist_ = std::stol(value);
if (key == "partitions") {
partitions_ = std::stol(value);
} else if (key == "dimensions") {
dimensions_ = std::stol(value);
} else if (key == "max_iter") {
Expand Down Expand Up @@ -255,12 +255,12 @@ class IndexIVFFlat {
" != " + std::to_string(index_->dimensions()));
}
dimensions_ = index_->dimensions();
if (nlist_ != 0 && nlist_ != index_->num_partitions()) {
if (partitions_ != 0 && partitions_ != index_->num_partitions()) {
throw std::runtime_error(
"nlist mismatch: " + std::to_string(nlist_) +
"partitions mismatch: " + std::to_string(partitions_) +
" != " + std::to_string(index_->num_partitions()));
}
nlist_ = index_->num_partitions();
partitions_ = index_->num_partitions();
}

/**
Expand Down Expand Up @@ -292,49 +292,49 @@ class IndexIVFFlat {
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint32_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT32 &&
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint32_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_UINT8 && id_datatype_ == TILEDB_UINT32 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint32_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT32 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint32_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_UINT8 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint64_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint64_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_UINT8 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint64_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint64_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
}

index_->train(training_set, init);
Expand All @@ -346,12 +346,12 @@ class IndexIVFFlat {
}
dimensions_ = index_->dimensions();

if (nlist_ != 0 && nlist_ != index_->num_partitions()) {
if (partitions_ != 0 && partitions_ != index_->num_partitions()) {
throw std::runtime_error(
"nlist mismatch: " + std::to_string(nlist_) +
"partitions mismatch: " + std::to_string(partitions_) +
" != " + std::to_string(index_->num_partitions()));
}
nlist_ = index_->num_partitions();
partitions_ = index_->num_partitions();
}

/**
Expand Down Expand Up @@ -440,7 +440,7 @@ class IndexIVFFlat {
}

constexpr auto num_partitions() const {
return nlist_;
return partitions_;
}

constexpr tiledb_datatype_t feature_type() const {
Expand Down Expand Up @@ -523,11 +523,11 @@ class IndexIVFFlat {
}

index_impl(
size_t nlist,
size_t partitions,
size_t max_iter,
float tolerance,
std::optional<size_t> num_threads)
: impl_index_(nlist, max_iter, tolerance) {
: impl_index_(partitions, max_iter, tolerance) {
}

index_impl(
Expand Down Expand Up @@ -717,7 +717,7 @@ class IndexIVFFlat {
};

uint64_t dimensions_ = 0;
size_t nlist_ = 0;
size_t partitions_ = 0;
uint32_t max_iterations_ = 2;
float tolerance_ = 1e-4;
std::optional<size_t> num_threads_ = std::nullopt;
Expand Down
Loading

0 comments on commit a642cf1

Please sign in to comment.