Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename nlist to partitions #551

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apis/python/src/tiledb/vector_search/ivf_pq_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def create(
id_type=np.dtype(np.uint64).name,
partitioning_index_type=np.dtype(np.uint64).name,
dimensions=dimensions,
n_list=partitions if (partitions is not None and partitions != -1) else 0,
partitions=partitions if (partitions is not None and partitions != -1) else 0,
num_subspaces=num_subspaces,
distance_metric=int(distance_metric),
)
Expand Down
6 changes: 4 additions & 2 deletions apis/python/src/tiledb/vector_search/type_erased_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,11 @@ void init_type_erased_module(py::module_& m) {
"train",
[](IndexIVFPQ& index,
const FeatureVectorArray& vectors,
std::optional<size_t> nlist) { index.train(vectors, nlist); },
std::optional<size_t> partitions) {
index.train(vectors, partitions);
},
py::arg("vectors"),
py::arg("nlist") = std::nullopt)
py::arg("partitions") = std::nullopt)
.def(
"add",
[](IndexIVFPQ& index, const FeatureVectorArray& vectors) {
Expand Down
40 changes: 20 additions & 20 deletions src/include/api/ivf_flat_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ class IndexIVFFlat {
for (auto&& c : *config) {
auto key = c.first;
auto value = c.second;
if (key == "nlist") {
nlist_ = std::stol(value);
if (key == "partitions") {
partitions_ = std::stol(value);
} else if (key == "dimensions") {
dimensions_ = std::stol(value);
} else if (key == "max_iter") {
Expand Down Expand Up @@ -255,12 +255,12 @@ class IndexIVFFlat {
" != " + std::to_string(index_->dimensions()));
}
dimensions_ = index_->dimensions();
if (nlist_ != 0 && nlist_ != index_->num_partitions()) {
if (partitions_ != 0 && partitions_ != index_->num_partitions()) {
throw std::runtime_error(
"nlist mismatch: " + std::to_string(nlist_) +
"partitions mismatch: " + std::to_string(partitions_) +
" != " + std::to_string(index_->num_partitions()));
}
nlist_ = index_->num_partitions();
partitions_ = index_->num_partitions();
}

/**
Expand Down Expand Up @@ -292,49 +292,49 @@ class IndexIVFFlat {
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint32_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT32 &&
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint32_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_UINT8 && id_datatype_ == TILEDB_UINT32 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint32_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT32 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint32_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_UINT8 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint64_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT32) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint64_t, uint32_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_UINT8 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<uint8_t, uint64_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
} else if (
feature_datatype_ == TILEDB_FLOAT32 && id_datatype_ == TILEDB_UINT64 &&
px_datatype_ == TILEDB_UINT64) {
index_ = std::make_unique<
index_impl<ivf_flat_index<float, uint64_t, uint64_t>>>(
nlist_, max_iterations_, tolerance_, num_threads_);
partitions_, max_iterations_, tolerance_, num_threads_);
}

index_->train(training_set, init);
Expand All @@ -346,12 +346,12 @@ class IndexIVFFlat {
}
dimensions_ = index_->dimensions();

if (nlist_ != 0 && nlist_ != index_->num_partitions()) {
if (partitions_ != 0 && partitions_ != index_->num_partitions()) {
throw std::runtime_error(
"nlist mismatch: " + std::to_string(nlist_) +
"partitions mismatch: " + std::to_string(partitions_) +
" != " + std::to_string(index_->num_partitions()));
}
nlist_ = index_->num_partitions();
partitions_ = index_->num_partitions();
}

/**
Expand Down Expand Up @@ -440,7 +440,7 @@ class IndexIVFFlat {
}

constexpr auto num_partitions() const {
return nlist_;
return partitions_;
}

constexpr tiledb_datatype_t feature_type() const {
Expand Down Expand Up @@ -523,11 +523,11 @@ class IndexIVFFlat {
}

index_impl(
size_t nlist,
size_t partitions,
size_t max_iter,
float tolerance,
std::optional<size_t> num_threads)
: impl_index_(nlist, max_iter, tolerance) {
: impl_index_(partitions, max_iter, tolerance) {
}

index_impl(
Expand Down Expand Up @@ -717,7 +717,7 @@ class IndexIVFFlat {
};

uint64_t dimensions_ = 0;
size_t nlist_ = 0;
size_t partitions_ = 0;
uint32_t max_iterations_ = 2;
float tolerance_ = 1e-4;
std::optional<size_t> num_threads_ = std::nullopt;
Expand Down
Loading
Loading