From 5a789893f05c094735e6128bc6a4413160f19689 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 12 Nov 2024 14:27:22 +0800 Subject: [PATCH 01/15] perf: traspose the pq code to make calculating efficient Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/distance.rs | 65 ++++++++++++++-------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index bd45e45a44..beb676fae6 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -76,31 +76,48 @@ pub(super) fn compute_l2_distance( num_sub_vectors: usize, code: &[u8], ) -> Vec { + // here `code` has been transposed, + // so code[i][j] is the code of i-th sub-vector of the j-th vector, + // and `code` is a flatten array of [num_sub_vectors, num_vectors] u8, + // so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector. + let num_vectors = code.len() / num_sub_vectors; + let mut distances = vec![0.0_f32; num_vectors]; let num_centroids = num_centroids(num_bits); + for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { + let dist_table = &distance_table[sub_vec_idx * num_centroids..]; + vec_indices + .iter() + .zip(distances.iter_mut()) + .for_each(|(&centroid_idx, sum)| { + *sum += dist_table[centroid_idx as usize]; + }); + } - let iter = code.chunks_exact(num_sub_vectors * V); - let distances = iter.clone().flat_map(|c| { - let mut sums = [0.0_f32; V]; - for i in (0..num_sub_vectors).step_by(C) { - for (vec_idx, sum) in sums.iter_mut().enumerate() { - let vec_start = vec_idx * num_sub_vectors; - let s = c[vec_start + i..] 
- .iter() - .take(min(C, num_sub_vectors - i)) - .enumerate() - .map(|(k, c)| distance_table[(i + k) * num_centroids + *c as usize]) - .sum::(); - *sum += s; - } - } - sums.into_iter() - }); + distances + + // let iter = code.chunks_exact(num_sub_vectors * V); + // let distances = iter.clone().flat_map(|c| { + // let mut sums = [0.0_f32; V]; + // for i in (0..num_sub_vectors).step_by(C) { + // for (vec_idx, sum) in sums.iter_mut().enumerate() { + // let vec_start = vec_idx * num_sub_vectors; + // let s = c[vec_start + i..] + // .iter() + // .take(min(C, num_sub_vectors - i)) + // .enumerate() + // .map(|(k, c)| distance_table[(i + k) * num_centroids + *c as usize]) + // .sum::(); + // *sum += s; + // } + // } + // sums.into_iter() + // }); // Remainder - let remainder = iter.remainder().chunks(num_sub_vectors).map(|c| { - c.iter() - .enumerate() - .map(|(sub_vec_idx, code)| distance_table[sub_vec_idx * num_centroids + *code as usize]) - .sum::() - }); - distances.chain(remainder).collect() + // let remainder = iter.remainder().chunks(num_sub_vectors).map(|c| { + // c.iter() + // .enumerate() + // .map(|(sub_vec_idx, code)| distance_table[sub_vec_idx * num_centroids + *code as usize]) + // .sum::() + // }); + // distances.chain(remainder).collect() } From ed4522461bf672bc78a5e4b9ca1c4a6a428f6ba9 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 12 Nov 2024 17:15:23 +0800 Subject: [PATCH 02/15] transpose pq codes when writing Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq.rs | 1 + rust/lance-index/src/vector/pq/distance.rs | 4 +- rust/lance-index/src/vector/pq/storage.rs | 81 +++++++++++----------- rust/lance/src/index/vector/pq.rs | 1 + 4 files changed, 46 insertions(+), 41 deletions(-) diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index f6195ec0fb..352a65a926 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -398,6 +398,7 @@ impl Quantization for ProductQuantizer { 
dimension: self.dimension, codebook: None, codebook_tensor: tensor.encode_to_vec(), + transposed: false, })?) } diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index beb676fae6..f98aabd846 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::cmp::min; - use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2}; use super::{num_centroids, utils::get_sub_vector_centroids}; @@ -80,6 +78,8 @@ pub(super) fn compute_l2_distance( // so code[i][j] is the code of i-th sub-vector of the j-th vector, // and `code` is a flatten array of [num_sub_vectors, num_vectors] u8, // so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector. + + // `distance_table` is a flatten array of [num_sub_vectors, num_centroids] f32, let num_vectors = code.len() / num_sub_vectors; let mut distances = vec![0.0_f32; num_vectors]; let num_centroids = num_centroids(num_bits); diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index bab860fa48..71f470e880 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -62,6 +62,7 @@ pub struct ProductQuantizationMetadata { // empty for old format pub codebook_tensor: Vec, + pub transposed: bool, } impl DeepSizeOf for ProductQuantizationMetadata { @@ -151,6 +152,7 @@ impl ProductQuantizationStorage { num_sub_vectors: usize, dimension: usize, distance_type: DistanceType, + transposed: bool, ) -> Result { let Some(row_ids) = batch.column_by_name(ROW_ID) else { return Err(Error::Index { @@ -180,18 +182,35 @@ impl ProductQuantizationStorage { ), location: location!(), })?; - let pq_code: Arc = pq_code_fsl - .values() - .as_primitive_opt::() - .ok_or(Error::Index { - message: format!( - "{PQ_CODE_COLUMN} 
column is not of type UInt8: {}", - pq_col.data_type() - ), - location: location!(), - })? - .clone() - .into(); + + let pq_code: Arc = if transposed { + pq_code_fsl + .values() + .as_primitive_opt::() + .ok_or(Error::Index { + message: format!( + "{PQ_CODE_COLUMN} column is not of type UInt8: {}", + pq_col.data_type() + ), + location: location!(), + })? + .clone() + .into() + } else { + let mut transposed_code = vec![0; pq_code_fsl.values().len()]; + for (vec_idx, codes) in pq_code_fsl + .values() + .as_primitive::() + .values() + .chunks_exact(pq_code_fsl.len()) + .enumerate() + { + for (cluster_idx, code) in codes.iter().enumerate() { + transposed_code[cluster_idx * pq_code_fsl.len() + vec_idx] = *code; + } + } + Arc::new(UInt8Array::from(transposed_code)) + }; Ok(Self { codebook, @@ -231,7 +250,6 @@ impl ProductQuantizationStorage { let metric_type = quantizer.distance_type; let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN); let batch = transform.transform(batch)?; - Self::new( codebook, batch, @@ -239,6 +257,7 @@ impl ProductQuantizationStorage { num_sub_vectors, dimension, metric_type, + false, ) } @@ -327,6 +346,7 @@ impl ProductQuantizationStorage { dimension: self.dimension, codebook: None, codebook_tensor: Vec::new(), + transposed: true, }; let index_metadata = IndexMetadata { @@ -373,6 +393,7 @@ impl QuantizerStorage for ProductQuantizationStorage { .values() .as_primitive::() .clone(); + let codebook = FixedSizeListArray::try_new_from_values(codebook, metadata.dimension as i32)?; @@ -386,6 +407,7 @@ impl QuantizerStorage for ProductQuantizationStorage { metadata.num_sub_vectors, metadata.dimension, distance_type, + metadata.transposed, ) } } @@ -411,35 +433,15 @@ impl VectorStore for ProductQuantizationStorage { let codebook_tensor = pb::Tensor::decode(metadata.codebook_tensor.as_slice())?; let codebook = FixedSizeListArray::try_from(&codebook_tensor)?; - let row_ids = batch - .column_by_name(ROW_ID) - .ok_or(Error::Index { - 
message: "Row IDs column not found in batch".to_string(), - location: location!(), - })? - .as_primitive::() - .clone(); - let pq_code = batch - .column_by_name(PQ_CODE_COLUMN) - .ok_or(Error::Index { - message: "PQ code column not found in batch".to_string(), - location: location!(), - })? - .as_fixed_size_list() - .values() - .as_primitive::() - .clone(); - - Ok(Self { + Self::new( codebook, batch, - num_bits: metadata.num_bits, - num_sub_vectors: metadata.num_sub_vectors, - dimension: metadata.dimension, + metadata.num_bits, + metadata.num_sub_vectors, + metadata.dimension, distance_type, - pq_code: Arc::new(pq_code), - row_ids: Arc::new(row_ids), - }) + metadata.transposed, + ) } fn to_batches(&self) -> Result> { @@ -451,6 +453,7 @@ impl VectorStore for ProductQuantizationStorage { dimension: self.dimension, codebook: None, codebook_tensor: codebook, + transposed: true, // we always transpose the pq codes for efficiency }; let metadata_json = serde_json::to_string(&metadata)?; diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 9aac8424c1..20137b56f6 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -448,6 +448,7 @@ pub(crate) fn build_pq_storage( pq.code_dim(), pq.dimension, distance_type, + false, )?; Ok(pq_store) From 9650a02c5067dcc014d00b03f614a998bfd1703d Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 12 Nov 2024 17:41:12 +0800 Subject: [PATCH 03/15] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 71f470e880..ef609d2cf8 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -202,11 +202,11 @@ impl ProductQuantizationStorage { .values() .as_primitive::() .values() - .chunks_exact(pq_code_fsl.len()) + 
.chunks_exact(pq_code_fsl.value_length() as usize) .enumerate() { - for (cluster_idx, code) in codes.iter().enumerate() { - transposed_code[cluster_idx * pq_code_fsl.len() + vec_idx] = *code; + for (sub_vec_idx, code) in codes.iter().enumerate() { + transposed_code[sub_vec_idx * pq_code_fsl.len() + vec_idx] = *code; } } Arc::new(UInt8Array::from(transposed_code)) From 5a950022fae1612ee54f5828055df61fe70c094f Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 12 Nov 2024 20:43:28 +0800 Subject: [PATCH 04/15] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq.rs | 24 ++++++------ rust/lance-index/src/vector/pq/storage.rs | 48 ++++++++--------------- rust/lance/src/index/vector/ivf/v2.rs | 4 +- rust/lance/src/index/vector/pq.rs | 12 ++++-- 4 files changed, 40 insertions(+), 48 deletions(-) diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 352a65a926..d80a96641d 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -219,17 +219,19 @@ impl ProductQuantizer { distance_table.extend(distances); }); - // Compute distance from the pre-compute table. 
- Ok(Float32Array::from_iter_values( - code.values().chunks_exact(self.num_sub_vectors).map(|c| { - c.iter() - .enumerate() - .map(|(sub_vec_idx, centroid)| { - distance_table[sub_vec_idx * 256 + *centroid as usize] - }) - .sum::() - }), - )) + let num_vectors = code.len() / self.num_sub_vectors; + let mut distances = vec![0.0; num_vectors]; + let num_centroids = num_centroids(self.num_bits); + for (sub_vec_idx, vec_indices) in code.values().chunks_exact(num_vectors).enumerate() { + let dist_table = &distance_table[sub_vec_idx * num_centroids..]; + vec_indices + .iter() + .zip(distances.iter_mut()) + .for_each(|(&centroid_idx, sum)| { + *sum += dist_table[centroid_idx as usize]; + }); + } + Ok(distances.into()) } fn build_l2_distance_table(&self, key: &dyn Array) -> Result> { diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index ef609d2cf8..97e3cfcd68 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -175,41 +175,15 @@ impl ProductQuantizationStorage { location: location!(), }); }; - let pq_code_fsl = pq_col.as_fixed_size_list_opt().ok_or(Error::Index { - message: format!( - "{PQ_CODE_COLUMN} column is not of type UInt8: {}", - pq_col.data_type() - ), - location: location!(), - })?; + let pq_codes = pq_col + .as_fixed_size_list() + .values() + .as_primitive::(); let pq_code: Arc = if transposed { - pq_code_fsl - .values() - .as_primitive_opt::() - .ok_or(Error::Index { - message: format!( - "{PQ_CODE_COLUMN} column is not of type UInt8: {}", - pq_col.data_type() - ), - location: location!(), - })? 
- .clone() - .into() + pq_codes.clone().into() } else { - let mut transposed_code = vec![0; pq_code_fsl.values().len()]; - for (vec_idx, codes) in pq_code_fsl - .values() - .as_primitive::() - .values() - .chunks_exact(pq_code_fsl.value_length() as usize) - .enumerate() - { - for (sub_vec_idx, code) in codes.iter().enumerate() { - transposed_code[sub_vec_idx * pq_code_fsl.len() + vec_idx] = *code; - } - } - Arc::new(UInt8Array::from(transposed_code)) + transpose(pq_codes, num_sub_vectors, row_ids.len()).into() }; Ok(Self { @@ -368,6 +342,16 @@ impl ProductQuantizationStorage { } } +pub fn transpose(pq_codes: &UInt8Array, num_sub_vectors: usize, num_vectors: usize) -> UInt8Array { + let mut transposed_codes = vec![0; pq_codes.len()]; + for (vec_idx, codes) in pq_codes.values().chunks_exact(num_sub_vectors).enumerate() { + for (sub_vec_idx, code) in codes.iter().enumerate() { + transposed_codes[sub_vec_idx * num_vectors + vec_idx] = *code; + } + } + transposed_codes.into() +} + #[async_trait] impl QuantizerStorage for ProductQuantizationStorage { type Metadata = ProductQuantizationMetadata; diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index d4807cac06..9d67667038 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -659,8 +659,8 @@ mod tests { #[rstest] #[case(4, DistanceType::L2, 0.9)] - #[case(4, DistanceType::Cosine, 0.6)] - #[case(4, DistanceType::Dot, 0.2)] + #[case(4, DistanceType::Cosine, 0.9)] + #[case(4, DistanceType::Dot, 0.9)] #[tokio::test] async fn test_build_ivf_pq( #[case] nlist: usize, diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 20137b56f6..1770226bbb 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -19,7 +19,7 @@ use lance_core::utils::tokio::spawn_cpu; use lance_core::ROW_ID; use lance_core::{utils::address::RowAddress, ROW_ID_FIELD}; use 
lance_index::vector::ivf::storage::IvfModel; -use lance_index::vector::pq::storage::ProductQuantizationStorage; +use lance_index::vector::pq::storage::{transpose, ProductQuantizationStorage}; use lance_index::vector::quantizer::{Quantization, QuantizationType, Quantizer}; use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ @@ -245,7 +245,7 @@ impl VectorIndex for PQIndex { length: usize, ) -> Result> { let pq_code_length = self.pq.code_dim() * length; - let pq_code = read_fixed_stride_array( + let pq_codes = read_fixed_stride_array( reader.as_ref(), &DataType::UInt8, offset, @@ -264,8 +264,14 @@ impl VectorIndex for PQIndex { ) .await?; + let pq_codes = transpose( + pq_codes.as_primitive(), + self.pq.num_sub_vectors, + row_ids.len(), + ); + Ok(Box::new(Self { - code: Some(Arc::new(pq_code.as_primitive().clone())), + code: Some(Arc::new(pq_codes)), row_ids: Some(Arc::new(row_ids.as_primitive().clone())), pq: self.pq.clone(), metric_type: self.metric_type, From 71432db742d0fdd626badad7f8b28fba4bc1c725 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 12 Nov 2024 21:02:10 +0800 Subject: [PATCH 05/15] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 97e3cfcd68..374b580bf0 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -147,7 +147,7 @@ impl PartialEq for ProductQuantizationStorage { impl ProductQuantizationStorage { pub fn new( codebook: FixedSizeListArray, - batch: RecordBatch, + mut batch: RecordBatch, num_bits: u32, num_sub_vectors: usize, dimension: usize, @@ -175,15 +175,17 @@ impl ProductQuantizationStorage { location: location!(), }); }; - let pq_codes = pq_col + let pq_code = pq_col .as_fixed_size_list() .values() .as_primitive::(); let pq_code: Arc = if transposed { - 
pq_codes.clone().into() + pq_code.clone().into() } else { - transpose(pq_codes, num_sub_vectors, row_ids.len()).into() + let pq_code: Arc<_> = transpose(pq_code, num_sub_vectors, row_ids.len()).into(); + batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code.clone())?; + pq_code }; Ok(Self { From 0937f4a08297abdd370227a02c5a905e3ae5514d Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 12 Nov 2024 22:01:53 +0800 Subject: [PATCH 06/15] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq.rs | 4 +-- rust/lance-index/src/vector/pq/storage.rs | 36 ++++++++++++----------- rust/lance-index/src/vector/quantizer.rs | 1 + rust/lance/src/index/vector/builder.rs | 12 ++++++-- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index d80a96641d..1362143a2b 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -384,7 +384,7 @@ impl Quantization for ProductQuantizer { } fn metadata(&self, args: Option) -> Result { - let codebook_position = match args { + let codebook_position = match &args { Some(args) => args.codebook_position, None => Some(0), }; @@ -400,7 +400,7 @@ impl Quantization for ProductQuantizer { dimension: self.dimension, codebook: None, codebook_tensor: tensor.encode_to_vec(), - transposed: false, + transposed: args.map(|args| args.transposed).unwrap_or_default(), })?) 
} diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 374b580bf0..6fd185ea13 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -10,7 +10,7 @@ use std::{cmp::min, collections::HashMap, sync::Arc}; use arrow::datatypes::{self}; use arrow_array::{ cast::AsArray, - types::{Float32Type, UInt64Type, UInt8Type}, + types::{Float32Type, UInt64Type}, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array, }; use arrow_array::{Array, ArrayRef}; @@ -169,24 +169,26 @@ impl ProductQuantizationStorage { .clone() .into(); - let Some(pq_col) = batch.column_by_name(PQ_CODE_COLUMN) else { - return Err(Error::Index { - message: format!("{PQ_CODE_COLUMN} column not found from PQ storage"), - location: location!(), - }); - }; - let pq_code = pq_col + if !transposed { + let pq_col = batch[PQ_CODE_COLUMN].as_fixed_size_list(); + let transposed_code = transpose( + pq_col.values().as_primitive(), + num_sub_vectors, + row_ids.len(), + ); + let pq_code_fsl = Arc::new(FixedSizeListArray::try_new_from_values( + transposed_code, + num_sub_vectors as i32, + )?); + batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code_fsl)?; + } + + let pq_code = batch[PQ_CODE_COLUMN] .as_fixed_size_list() .values() - .as_primitive::(); - - let pq_code: Arc = if transposed { - pq_code.clone().into() - } else { - let pq_code: Arc<_> = transpose(pq_code, num_sub_vectors, row_ids.len()).into(); - batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code.clone())?; - pq_code - }; + .as_primitive() + .clone() + .into(); Ok(Self { codebook, diff --git a/rust/lance-index/src/vector/quantizer.rs b/rust/lance-index/src/vector/quantizer.rs index 860c000227..1290a0f07b 100644 --- a/rust/lance-index/src/vector/quantizer.rs +++ b/rust/lance-index/src/vector/quantizer.rs @@ -161,6 +161,7 @@ pub struct QuantizationMetadata { // For PQ pub codebook_position: Option, pub codebook: Option, + pub transposed: bool, 
} #[async_trait] diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 0531cfd24e..21db5ea559 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -16,7 +16,9 @@ use lance_file::v2::reader::FileReaderOptions; use lance_file::v2::{reader::FileReader, writer::FileWriter}; use lance_index::vector::flat::storage::FlatStorage; use lance_index::vector::ivf::storage::IvfModel; -use lance_index::vector::quantizer::{QuantizationType, QuantizerBuildParams}; +use lance_index::vector::quantizer::{ + QuantizationMetadata, QuantizationType, QuantizerBuildParams, +}; use lance_index::vector::storage::STORAGE_METADATA_KEY; use lance_index::vector::v3::shuffler::IvfShufflerReader; use lance_index::vector::v3::subindex::SubIndexType; @@ -573,7 +575,13 @@ impl IvfIndexBuilde storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); // For now, each partition's metadata is just the quantizer, // it's all the same for now, so we just take the first one - let storage_partition_metadata = vec![quantizer.metadata(None)?.to_string()]; + let storage_partition_metadata = vec![quantizer + .metadata(Some(QuantizationMetadata { + codebook_position: Some(0), + codebook: None, + transposed: true, + }))? 
+ .to_string()]; storage_writer.add_schema_metadata( STORAGE_METADATA_KEY, serde_json::to_string(&storage_partition_metadata)?, From 8018846191942f99a0b04c804d92fdc5c7ebe0f3 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 14:34:14 +0800 Subject: [PATCH 07/15] fix hnsw Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 6fd185ea13..4511d8eed2 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -559,10 +559,15 @@ impl PQDistCalculator { } } - fn get_pq_code(&self, id: u32) -> &[u8] { - let start = id as usize * self.num_sub_vectors; - let end = start + self.num_sub_vectors; - &self.pq_code.values()[start..end] + fn get_pq_code(&self, id: u32) -> Vec { + let num_vectors = self.pq_code.len() / self.num_sub_vectors; + self.pq_code + .values() + .iter() + .skip(id as usize) + .step_by(num_vectors) + .map(|&c| c as usize) + .collect() } } From 9fd8429581de833d3531b16639449f7c4eec0d76 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 14:42:06 +0800 Subject: [PATCH 08/15] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 4511d8eed2..71cb644639 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -575,9 +575,9 @@ impl DistCalculator for PQDistCalculator { fn distance(&self, id: u32) -> f32 { let pq_code = self.get_pq_code(id); pq_code - .iter() + .into_iter() .enumerate() - .map(|(i, &c)| self.distance_table[i * self.num_centroids + c as usize]) + .map(|(i, c)| self.distance_table[i * self.num_centroids + c]) .sum() } } From 388bcc76ce5572b1fc82a0a78d0da5c78449d8bb Mon 
Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 16:28:57 +0800 Subject: [PATCH 09/15] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 24 ++++++---- rust/lance/src/index/vector/ivf.rs | 8 +++- rust/lance/src/index/vector/ivf/io.rs | 8 +++- rust/lance/src/index/vector/pq.rs | 58 +++++++++++++++-------- 4 files changed, 69 insertions(+), 29 deletions(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 71cb644639..1678f68e05 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -7,13 +7,13 @@ use std::{cmp::min, collections::HashMap, sync::Arc}; -use arrow::datatypes::{self}; +use arrow::datatypes::{self, UInt8Type}; use arrow_array::{ cast::AsArray, types::{Float32Type, UInt64Type}, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array, }; -use arrow_array::{Array, ArrayRef}; +use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use deepsize::DeepSizeOf; @@ -172,9 +172,9 @@ impl ProductQuantizationStorage { if !transposed { let pq_col = batch[PQ_CODE_COLUMN].as_fixed_size_list(); let transposed_code = transpose( - pq_col.values().as_primitive(), - num_sub_vectors, + pq_col.values().as_primitive::(), row_ids.len(), + num_sub_vectors, ); let pq_code_fsl = Arc::new(FixedSizeListArray::try_new_from_values( transposed_code, @@ -346,13 +346,21 @@ impl ProductQuantizationStorage { } } -pub fn transpose(pq_codes: &UInt8Array, num_sub_vectors: usize, num_vectors: usize) -> UInt8Array { - let mut transposed_codes = vec![0; pq_codes.len()]; - for (vec_idx, codes) in pq_codes.values().chunks_exact(num_sub_vectors).enumerate() { +pub fn transpose( + original: &PrimitiveArray, + num_rows: usize, + num_columns: usize, +) -> PrimitiveArray +where + PrimitiveArray: From>, +{ + let mut transposed_codes = vec![T::default_value(); original.len()]; + 
for (vec_idx, codes) in original.values().chunks_exact(num_columns).enumerate() { for (sub_vec_idx, code) in codes.iter().enumerate() { - transposed_codes[sub_vec_idx * num_vectors + vec_idx] = *code; + transposed_codes[sub_vec_idx * num_rows + vec_idx] = *code; } } + transposed_codes.into() } diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 05bf683ac4..678ef7dfaa 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -36,6 +36,7 @@ use lance_file::{ }; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::pq::storage::transpose; use lance_index::vector::quantizer::QuantizationType; use lance_index::vector::v3::shuffler::IvfShuffler; use lance_index::vector::v3::subindex::{IvfSubIndex, SubIndexType}; @@ -1358,7 +1359,12 @@ impl RemapPageTask { ivf.offsets.push(writer.tell().await?); ivf.lengths .push(page.row_ids.as_ref().unwrap().len() as u32); - PlainEncoder::write(writer, &[page.code.as_ref().unwrap().as_ref()]).await?; + let original_pq = transpose( + page.code.as_ref().unwrap(), + page.pq.code_dim(), + page.row_ids.as_ref().unwrap().len(), + ); + PlainEncoder::write(writer, &[&original_pq]).await?; PlainEncoder::write(writer, &[page.row_ids.as_ref().unwrap().as_ref()]).await?; Ok(()) } diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index d9820bee24..090895e9b3 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -24,6 +24,7 @@ use lance_index::scalar::IndexWriter; use lance_index::vector::hnsw::HNSW; use lance_index::vector::hnsw::{builder::HnswBuildParams, HnswMetadata}; use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::pq::storage::transpose; use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::quantizer::{Quantization, Quantizer}; use 
lance_index::vector::v3::subindex::IvfSubIndex; @@ -199,9 +200,14 @@ pub(super) async fn write_pq_partitions( location: location!(), })?; if let Some(pq_code) = pq_index.code.as_ref() { + let original_pq_codes = transpose( + &pq_code, + pq_index.pq.num_sub_vectors, + pq_code.len() / pq_index.pq.code_dim(), + ); let fsl = Arc::new( FixedSizeListArray::try_new_from_values( - pq_code.as_ref().clone(), + original_pq_codes, pq_index.pq.code_dim() as i32, ) .unwrap(), diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 1770226bbb..f116de51e4 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -105,22 +105,35 @@ impl PQIndex { pre_filter: &dyn PreFilter, code: Arc, row_ids: Arc, - num_sub_vectors: i32, + _num_sub_vectors: i32, ) -> Result<(Arc, Arc)> { + let num_vectors = row_ids.len(); let indices_to_keep = pre_filter.filter_row_ids(Box::new(row_ids.values().iter())); let indices_to_keep = UInt64Array::from(indices_to_keep); let row_ids = take(row_ids.as_ref(), &indices_to_keep, None)?; let row_ids = Arc::new(as_primitive_array(&row_ids).clone()); - let code = FixedSizeListArray::try_new_from_values(code.as_ref().clone(), num_sub_vectors) - .unwrap(); - let code = take(&code, &indices_to_keep, None)?; - let code = as_fixed_size_list_array(&code).values().clone(); - let code = Arc::new(as_primitive_array(&code).clone()); + let code = Arc::new( + indices_to_keep + .values() + .iter() + .flat_map(|&idx| Self::get_pq_codes(&code, idx as usize, num_vectors)) + .collect(), + ); Ok((code, row_ids)) } + + fn get_pq_codes(transposed_codes: &UInt8Array, vec_idx: usize, num_vectors: usize) -> Vec { + transposed_codes + .values() + .iter() + .skip(vec_idx) + .step_by(num_vectors) + .cloned() + .collect() + } } #[async_trait] @@ -266,8 +279,8 @@ impl VectorIndex for PQIndex { let pq_codes = transpose( pq_codes.as_primitive(), - self.pq.num_sub_vectors, row_ids.len(), + self.pq.num_sub_vectors, ); 
Ok(Box::new(Self { @@ -287,29 +300,36 @@ impl VectorIndex for PQIndex { } fn remap(&mut self, mapping: &HashMap>) -> Result<()> { - let code = self - .code - .as_ref() - .unwrap() - .values() - .chunks_exact(self.pq.code_dim()); + let num_vectors = self.row_ids.as_ref().unwrap().len(); let row_ids = self.row_ids.as_ref().unwrap().values().iter(); + let transposed_codes = self.code.as_ref().unwrap(); let remapped = row_ids - .zip(code) - .filter_map(|(old_row_id, code)| { + .enumerate() + .filter_map(|(vec_idx, old_row_id)| { let new_row_id = mapping.get(old_row_id).cloned(); // If the row id is not in the mapping then this row is not remapped and we keep as is let new_row_id = new_row_id.unwrap_or(Some(*old_row_id)); - new_row_id.map(|new_row_id| (new_row_id, code)) + new_row_id.map(|new_row_id| { + ( + new_row_id, + Self::get_pq_codes(transposed_codes, vec_idx, num_vectors), + ) + }) }) .collect::>(); self.row_ids = Some(Arc::new(UInt64Array::from_iter_values( remapped.iter().map(|(row_id, _)| *row_id), ))); - self.code = Some(Arc::new(UInt8Array::from_iter_values( - remapped.into_iter().flat_map(|(_, code)| code).copied(), - ))); + + let pq_codes = + UInt8Array::from_iter_values(remapped.into_iter().flat_map(|(_, code)| code)); + let transposed_codes = transpose( + &pq_codes, + self.row_ids.as_ref().unwrap().len(), + self.pq.num_sub_vectors, + ); + self.code = Some(Arc::new(transposed_codes)); Ok(()) } From 048d9484cb933b0c2ee8a3f1e4341c3c50124ed5 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 16:55:52 +0800 Subject: [PATCH 10/15] fix with filters Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/distance.rs | 26 ---------------------- rust/lance/src/index/vector/pq.rs | 20 ++++++++++------- 2 files changed, 12 insertions(+), 34 deletions(-) diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index f98aabd846..9558d2ab0d 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ 
b/rust/lance-index/src/vector/pq/distance.rs @@ -94,30 +94,4 @@ pub(super) fn compute_l2_distance( } distances - - // let iter = code.chunks_exact(num_sub_vectors * V); - // let distances = iter.clone().flat_map(|c| { - // let mut sums = [0.0_f32; V]; - // for i in (0..num_sub_vectors).step_by(C) { - // for (vec_idx, sum) in sums.iter_mut().enumerate() { - // let vec_start = vec_idx * num_sub_vectors; - // let s = c[vec_start + i..] - // .iter() - // .take(min(C, num_sub_vectors - i)) - // .enumerate() - // .map(|(k, c)| distance_table[(i + k) * num_centroids + *c as usize]) - // .sum::(); - // *sum += s; - // } - // } - // sums.into_iter() - // }); - // Remainder - // let remainder = iter.remainder().chunks(num_sub_vectors).map(|c| { - // c.iter() - // .enumerate() - // .map(|(sub_vec_idx, code)| distance_table[sub_vec_idx * num_centroids + *code as usize]) - // .sum::() - // }); - // distances.chain(remainder).collect() } diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index f116de51e4..e8f1550c08 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -114,15 +114,19 @@ impl PQIndex { let row_ids = take(row_ids.as_ref(), &indices_to_keep, None)?; let row_ids = Arc::new(as_primitive_array(&row_ids).clone()); - let code = Arc::new( - indices_to_keep - .values() - .iter() - .flat_map(|&idx| Self::get_pq_codes(&code, idx as usize, num_vectors)) - .collect(), - ); + let code = code + .values() + .chunks_exact(num_vectors) + .flat_map(|c| { + let mut filtered = Vec::with_capacity(indices_to_keep.len()); + for idx in indices_to_keep.values() { + filtered.push(c[*idx as usize]); + } + filtered + }) + .collect(); - Ok((code, row_ids)) + Ok((Arc::new(code), row_ids)) } fn get_pq_codes(transposed_codes: &UInt8Array, vec_idx: usize, num_vectors: usize) -> Vec { From 43221200788db8d4f11ffe5b48830a67fa67d957 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 16:59:07 +0800 Subject: [PATCH 11/15] 
comments Signed-off-by: BubbleCal --- rust/lance/src/index/vector/ivf/io.rs | 2 +- rust/lance/src/index/vector/pq.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index 090895e9b3..8290f88ab2 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -201,7 +201,7 @@ pub(super) async fn write_pq_partitions( })?; if let Some(pq_code) = pq_index.code.as_ref() { let original_pq_codes = transpose( - &pq_code, + pq_code, pq_index.pq.num_sub_vectors, pq_code.len() / pq_index.pq.code_dim(), ); diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index e8f1550c08..e497517c24 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -52,6 +52,8 @@ pub struct PQIndex { pub pq: ProductQuantizer, /// PQ code + /// the PQ codes are stored in a transposed way, + /// call `Self::get_pq_codes` to get the PQ code for a specific vector. pub code: Option>, /// ROW Id used to refer to the actual row in dataset. 
From d9768235733e86ba44f5214b2f17abfc0d3ed99f Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 17:17:26 +0800 Subject: [PATCH 12/15] fix empty Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 1362143a2b..9368832467 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -142,6 +142,10 @@ impl ProductQuantizer { } pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result { + if code.is_empty() { + return Ok(Float32Array::from(Vec::::new())); + } + match self.distance_type { DistanceType::L2 => self.l2_distances(query, code), DistanceType::Cosine => { From 3e5e23a203b82c958b01c2d6e5a0c1a6fbf7462e Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 17:57:30 +0800 Subject: [PATCH 13/15] fix transpose Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 1678f68e05..92000418fa 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -354,6 +354,10 @@ pub fn transpose( where PrimitiveArray: From>, { + if original.is_empty() { + return original.clone(); + } + let mut transposed_codes = vec![T::default_value(); original.len()]; for (vec_idx, codes) in original.values().chunks_exact(num_columns).enumerate() { for (sub_vec_idx, code) in codes.iter().enumerate() { From 8c31845c0db033b0de1ec777cf7eb44c98cd4d91 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 18:20:37 +0800 Subject: [PATCH 14/15] fix ut Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 9368832467..15083fe1ca 
100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -141,6 +141,7 @@ impl ProductQuantizer { )?)) } + // the code must be transposed pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result { if code.is_empty() { return Ok(Float32Array::from(Vec::::new())); @@ -468,6 +469,7 @@ mod tests { use lance_linalg::kernels::argmin; use lance_testing::datagen::generate_random_array; use num_traits::Zero; + use storage::transpose; #[test] fn test_f16_pq_to_protobuf() { @@ -509,7 +511,8 @@ mod tests { let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8)); let query = generate_random_array(DIM); - let dists = pq.compute_distances(&query, &pq_code).unwrap(); + let transposed_pq_codes = transpose(&pq_code, TOTAL, 16); + let dists = pq.compute_distances(&query, &transposed_pq_codes).unwrap(); let sub_vec_len = DIM / 16; let expected = pq_code From b5302f733344b667814125d99a3496ff4a999969 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 13 Nov 2024 21:45:37 +0800 Subject: [PATCH 15/15] add test Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq.rs | 12 +-- rust/lance-index/src/vector/pq/distance.rs | 91 ++++++++++++++++++++-- 2 files changed, 88 insertions(+), 15 deletions(-) diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 15083fe1ca..411791b7e8 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -172,11 +172,11 @@ impl ProductQuantizer { #[cfg(target_feature = "avx512f")] { - Ok(self.compute_l2_distance::<16, 64>(&distance_table, code.values())) + Ok(self.compute_l2_distance(&distance_table, code.values())) } #[cfg(not(target_feature = "avx512f"))] { - Ok(self.compute_l2_distance::<8, 64>(&distance_table, code.values())) + Ok(self.compute_l2_distance(&distance_table, code.values())) } } @@ -289,12 +289,8 @@ impl ProductQuantizer { /// ------- /// The squared L2 distance. 
#[inline] - fn compute_l2_distance( - &self, - distance_table: &[f32], - code: &[u8], - ) -> Float32Array { - Float32Array::from(compute_l2_distance::( + fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array { + Float32Array::from(compute_l2_distance( distance_table, self.num_bits, self.num_sub_vectors, diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index 9558d2ab0d..416b859514 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::cmp::min; + use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2}; use super::{num_centroids, utils::get_sub_vector_centroids}; @@ -50,25 +52,20 @@ pub(super) fn build_distance_table_dot( /// Compute L2 distance from the query to all code. /// -/// Type parameters -/// --------------- -/// - C: the tile size of code-book to run at once. -/// - V: the tile size of PQ code to run at once. -/// /// Parameters /// ---------- /// - distance_table: the pre-computed L2 distance table. /// It is a flatten array of [num_sub_vectors, num_centroids] f32. /// - num_bits: the number of bits used for PQ. /// - num_sub_vectors: the number of sub-vectors. -/// - code: the PQ code to be used to compute the distances. +/// - code: the transposed PQ code to be used to compute the distances. /// /// Returns /// ------- /// The squared L2 distance. /// #[inline] -pub(super) fn compute_l2_distance( +pub(super) fn compute_l2_distance( distance_table: &[f32], num_bits: u32, num_sub_vectors: usize, @@ -95,3 +92,83 @@ pub(super) fn compute_l2_distance( distances } + +/// Compute L2 distance from the query to all code without transposing the code. +/// for testing only +/// +/// Type parameters +/// --------------- +/// - C: the tile size of code-book to run at once. 
+/// - V: the tile size of PQ code to run at once. +/// +#[allow(dead_code)] +fn compute_l2_distance_without_transposing( + distance_table: &[f32], + num_bits: u32, + num_sub_vectors: usize, + code: &[u8], +) -> Vec { + let num_centroids = num_centroids(num_bits); + let iter = code.chunks_exact(num_sub_vectors * V); + let distances = iter.clone().flat_map(|c| { + let mut sums = [0.0_f32; V]; + for i in (0..num_sub_vectors).step_by(C) { + for (vec_idx, sum) in sums.iter_mut().enumerate() { + let vec_start = vec_idx * num_sub_vectors; + let s = c[vec_start + i..] + .iter() + .take(min(C, num_sub_vectors - i)) + .enumerate() + .map(|(k, c)| distance_table[(i + k) * num_centroids + *c as usize]) + .sum::(); + *sum += s; + } + } + sums.into_iter() + }); + // Remainder + let remainder = iter.remainder().chunks(num_sub_vectors).map(|c| { + c.iter() + .enumerate() + .map(|(sub_vec_idx, code)| distance_table[sub_vec_idx * num_centroids + *code as usize]) + .sum::() + }); + distances.chain(remainder).collect() +} + +#[cfg(test)] +mod tests { + use crate::vector::pq::storage::transpose; + + use super::*; + use arrow_array::UInt8Array; + + #[test] + fn test_compute_on_transposed_codes() { + let num_vectors = 100; + let num_sub_vectors = 4; + let num_bits = 8; + let dimension = 16; + let codebook = + Vec::from_iter((0..num_sub_vectors * num_vectors * dimension).map(|v| v as f32)); + let query = Vec::from_iter((0..dimension).map(|v| v as f32)); + let distance_table = build_distance_table_l2(&codebook, num_bits, num_sub_vectors, &query); + + let pq_codes = Vec::from_iter((0..num_vectors * num_sub_vectors).map(|v| v as u8)); + let pq_codes = UInt8Array::from_iter_values(pq_codes); + let transposed_codes = transpose(&pq_codes, num_vectors, num_sub_vectors); + let distances = compute_l2_distance( + &distance_table, + num_bits, + num_sub_vectors, + transposed_codes.values(), + ); + let expected = compute_l2_distance_without_transposing::<4, 1>( + &distance_table, + num_bits, + 
num_sub_vectors, + pq_codes.values(), + ); + assert_eq!(distances, expected); + } +}