Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: transpose the PQ codes to improve search performance #3120

Merged
merged 15 commits into from
Nov 13, 2024
48 changes: 27 additions & 21 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ impl ProductQuantizer {
)?))
}

// the code must be transposed
pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
if code.is_empty() {
return Ok(Float32Array::from(Vec::<f32>::new()));
}

match self.distance_type {
DistanceType::L2 => self.l2_distances(query, code),
DistanceType::Cosine => {
Expand All @@ -167,11 +172,11 @@ impl ProductQuantizer {

#[cfg(target_feature = "avx512f")]
{
Ok(self.compute_l2_distance::<16, 64>(&distance_table, code.values()))
Ok(self.compute_l2_distance(&distance_table, code.values()))
}
#[cfg(not(target_feature = "avx512f"))]
{
Ok(self.compute_l2_distance::<8, 64>(&distance_table, code.values()))
Ok(self.compute_l2_distance(&distance_table, code.values()))
}
}

Expand Down Expand Up @@ -219,17 +224,19 @@ impl ProductQuantizer {
distance_table.extend(distances);
});

// Compute distance from the pre-compute table.
Ok(Float32Array::from_iter_values(
code.values().chunks_exact(self.num_sub_vectors).map(|c| {
c.iter()
.enumerate()
.map(|(sub_vec_idx, centroid)| {
distance_table[sub_vec_idx * 256 + *centroid as usize]
})
.sum::<f32>()
}),
))
let num_vectors = code.len() / self.num_sub_vectors;
let mut distances = vec![0.0; num_vectors];
let num_centroids = num_centroids(self.num_bits);
for (sub_vec_idx, vec_indices) in code.values().chunks_exact(num_vectors).enumerate() {
let dist_table = &distance_table[sub_vec_idx * num_centroids..];
vec_indices
.iter()
.zip(distances.iter_mut())
.for_each(|(&centroid_idx, sum)| {
*sum += dist_table[centroid_idx as usize];
});
}
Ok(distances.into())
}

fn build_l2_distance_table(&self, key: &dyn Array) -> Result<Vec<f32>> {
Expand Down Expand Up @@ -282,12 +289,8 @@ impl ProductQuantizer {
/// -------
/// The squared L2 distance.
#[inline]
fn compute_l2_distance<const C: usize, const V: usize>(
&self,
distance_table: &[f32],
code: &[u8],
) -> Float32Array {
Float32Array::from(compute_l2_distance::<C, V>(
fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array {
Float32Array::from(compute_l2_distance(
distance_table,
self.num_bits,
self.num_sub_vectors,
Expand Down Expand Up @@ -382,7 +385,7 @@ impl Quantization for ProductQuantizer {
}

fn metadata(&self, args: Option<QuantizationMetadata>) -> Result<serde_json::Value> {
let codebook_position = match args {
let codebook_position = match &args {
Some(args) => args.codebook_position,
None => Some(0),
};
Expand All @@ -398,6 +401,7 @@ impl Quantization for ProductQuantizer {
dimension: self.dimension,
codebook: None,
codebook_tensor: tensor.encode_to_vec(),
transposed: args.map(|args| args.transposed).unwrap_or_default(),
})?)
}

Expand Down Expand Up @@ -461,6 +465,7 @@ mod tests {
use lance_linalg::kernels::argmin;
use lance_testing::datagen::generate_random_array;
use num_traits::Zero;
use storage::transpose;

#[test]
fn test_f16_pq_to_protobuf() {
Expand Down Expand Up @@ -502,7 +507,8 @@ mod tests {
let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8));
let query = generate_random_array(DIM);

let dists = pq.compute_distances(&query, &pq_code).unwrap();
let transposed_pq_codes = transpose(&pq_code, TOTAL, 16);
let dists = pq.compute_distances(&query, &transposed_pq_codes).unwrap();

let sub_vec_len = DIM / 16;
let expected = pq_code
Expand Down
82 changes: 75 additions & 7 deletions rust/lance-index/src/vector/pq/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,63 @@ pub(super) fn build_distance_table_dot<T: Dot>(

/// Compute L2 distance from the query to all code.
///
/// Type parameters
/// ---------------
/// - C: the tile size of code-book to run at once.
/// - V: the tile size of PQ code to run at once.
///
/// Parameters
/// ----------
/// - distance_table: the pre-computed L2 distance table.
/// It is a flatten array of [num_sub_vectors, num_centroids] f32.
/// - num_bits: the number of bits used for PQ.
/// - num_sub_vectors: the number of sub-vectors.
/// - code: the PQ code to be used to compute the distances.
/// - code: the transposed PQ code to be used to compute the distances.
///
/// Returns
/// -------
/// The squared L2 distance.
///
#[inline]
pub(super) fn compute_l2_distance<const C: usize, const V: usize>(
pub(super) fn compute_l2_distance(
distance_table: &[f32],
num_bits: u32,
num_sub_vectors: usize,
code: &[u8],
) -> Vec<f32> {
// here `code` has been transposed,
// so code[i][j] is the code of i-th sub-vector of the j-th vector,
// and `code` is a flatten array of [num_sub_vectors, num_vectors] u8,
// so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector.

// `distance_table` is a flatten array of [num_sub_vectors, num_centroids] f32,
let num_vectors = code.len() / num_sub_vectors;
let mut distances = vec![0.0_f32; num_vectors];
let num_centroids = num_centroids(num_bits);
for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
let dist_table = &distance_table[sub_vec_idx * num_centroids..];
vec_indices
.iter()
.zip(distances.iter_mut())
.for_each(|(&centroid_idx, sum)| {
*sum += dist_table[centroid_idx as usize];
});
}

distances
}

/// Compute L2 distance from the query to all code without transposing the code.
/// for testing only
///
/// Type parameters
/// ---------------
/// - C: the tile size of code-book to run at once.
/// - V: the tile size of PQ code to run at once.
///
#[allow(dead_code)]
fn compute_l2_distance_without_transposing<const C: usize, const V: usize>(
distance_table: &[f32],
num_bits: u32,
num_sub_vectors: usize,
code: &[u8],
) -> Vec<f32> {
let num_centroids = num_centroids(num_bits);
let iter = code.chunks_exact(num_sub_vectors * V);
let distances = iter.clone().flat_map(|c| {
let mut sums = [0.0_f32; V];
Expand All @@ -104,3 +135,40 @@ pub(super) fn compute_l2_distance<const C: usize, const V: usize>(
});
distances.chain(remainder).collect()
}

#[cfg(test)]
mod tests {
use crate::vector::pq::storage::transpose;

use super::*;
use arrow_array::UInt8Array;

#[test]
fn test_compute_on_transposed_codes() {
let num_vectors = 100;
let num_sub_vectors = 4;
let num_bits = 8;
let dimension = 16;
let codebook =
Vec::from_iter((0..num_sub_vectors * num_vectors * dimension).map(|v| v as f32));
let query = Vec::from_iter((0..dimension).map(|v| v as f32));
let distance_table = build_distance_table_l2(&codebook, num_bits, num_sub_vectors, &query);

let pq_codes = Vec::from_iter((0..num_vectors * num_sub_vectors).map(|v| v as u8));
let pq_codes = UInt8Array::from_iter_values(pq_codes);
let transposed_codes = transpose(&pq_codes, num_vectors, num_sub_vectors);
let distances = compute_l2_distance(
&distance_table,
num_bits,
num_sub_vectors,
transposed_codes.values(),
);
let expected = compute_l2_distance_without_transposing::<4, 1>(
&distance_table,
num_bits,
num_sub_vectors,
pq_codes.values(),
);
assert_eq!(distances, expected);
}
}
Loading
Loading