Skip to content

Commit

Permalink
perf: implement XTR for retrieving multivector (#3437)
Browse files Browse the repository at this point in the history
  • Loading branch information
BubbleCal authored Mar 4, 2025
1 parent eb16635 commit 87f055f
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 51 deletions.
5 changes: 3 additions & 2 deletions rust/lance-linalg/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ pub fn multivec_distance(
})
.unwrap_or(f32::NAN)
})
.map(|sim| 1.0 - sim)
.collect();
Ok(dists)
}
Expand All @@ -197,8 +198,8 @@ where
.as_primitive::<T>()
.values()
.chunks_exact(dim)
.map(|v| distance_type.func()(q, v))
.min_by(|a, b| a.partial_cmp(b).unwrap())
.map(|v| 1.0 - distance_type.func()(q, v))
.max_by(|a, b| a.total_cmp(b))
.unwrap()
})
.sum()
Expand Down
81 changes: 36 additions & 45 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use crate::index::scalar::detect_scalar_index_type;
use crate::index::vector::utils::{get_vector_dim, get_vector_type};
use crate::index::DatasetIndexInternalExt;
use crate::io::exec::fts::{FlatFtsExec, FtsExec};
use crate::io::exec::knn::MultivectorScoringExec;
use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec};
use crate::io::exec::{get_physical_optimizer, LanceScanConfig};
use crate::io::exec::{
Expand All @@ -90,6 +91,9 @@ pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
lazy_static::lazy_static! {
    /// Default fragment readahead, read once from the
    /// `LANCE_DEFAULT_FRAGMENT_READAHEAD` environment variable.
    ///
    /// `None` when the variable is unset (or is not valid Unicode).
    ///
    /// # Panics
    ///
    /// Panics on first access if the variable is set but cannot be parsed as a `usize`.
    pub static ref DEFAULT_FRAGMENT_READAHEAD: Option<usize> =
        std::env::var("LANCE_DEFAULT_FRAGMENT_READAHEAD")
            .ok()
            .map(|val| {
                // Fail fast on a misconfigured value, but name the variable and the
                // offending value so the panic is actionable.
                val.parse::<usize>().unwrap_or_else(|err| {
                    panic!(
                        "LANCE_DEFAULT_FRAGMENT_READAHEAD must be a valid usize, got {:?}: {}",
                        val, err
                    )
                })
            });

    /// Over-fetch factor used by the multivector (XTR) search path, read once from
    /// the `LANCE_XTR_OVERFETCH` environment variable.
    ///
    /// Defaults to `10` when the variable is unset (or is not valid Unicode).
    ///
    /// # Panics
    ///
    /// Panics on first access if the variable is set but cannot be parsed as a `u32`.
    pub static ref DEFAULT_XTR_OVERFETCH: u32 = std::env::var("LANCE_XTR_OVERFETCH")
        .map(|val| {
            // Same fail-fast policy as above, with a descriptive message.
            val.parse::<u32>().unwrap_or_else(|err| {
                panic!(
                    "LANCE_XTR_OVERFETCH must be a valid u32, got {:?}: {}",
                    val, err
                )
            })
        })
        .unwrap_or(10);
}

// We want to support ~256 concurrent reads to maximize throughput on cloud storage systems
Expand Down Expand Up @@ -1692,13 +1696,13 @@ impl Scanner {

// Find all deltas with the same index name.
let deltas = self.dataset.load_indices_by_name(&index.name).await?;
let (ann_node, is_multivec) = match vector_type {
DataType::FixedSizeList(_, _) => (self.ann(q, &deltas, filter_plan).await?, false),
DataType::List(_) => (self.multivec_ann(q, &deltas, filter_plan).await?, true),
let ann_node = match vector_type {
DataType::FixedSizeList(_, _) => self.ann(q, &deltas, filter_plan).await?,
DataType::List(_) => self.multivec_ann(q, &deltas, filter_plan).await?,
_ => unreachable!(),
};

let mut knn_node = if q.refine_factor.is_some() || is_multivec {
let mut knn_node = if q.refine_factor.is_some() {
let vector_projection = self
.dataset
.empty_projection()
Expand Down Expand Up @@ -2200,69 +2204,56 @@ impl Scanner {
index: &[Index],
filter_plan: &FilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
// We split the query procedure into two steps:
// 1. collect candidates by running a vector search for each query vector
// 2. score the candidates

let over_fetch_factor = *DEFAULT_XTR_OVERFETCH;

let prefilter_source = self.prefilter_source(filter_plan).await?;
let dim = get_vector_dim(self.dataset.schema(), &q.column)?;
// split the query multivectors

let num_queries = q.key.len() / dim;
let new_queries = (0..num_queries)
.map(|i| q.key.slice(i * dim, dim))
.map(|query_vec| {
let mut new_query = q.clone();
new_query.key = query_vec;
// With XTR, we don't need to refine the results with the original vectors,
// but we do need to over-fetch candidates to reach good enough recall.
// TODO: improve the recall with WARP, and expose this parameter to users.
new_query.refine_factor = Some(over_fetch_factor);
new_query
});
let mut ann_nodes = Vec::with_capacity(new_queries.len());
let prefilter_source = self.prefilter_source(filter_plan).await?;
for query in new_queries {
// this produces `nprobes * k * over_fetch_factor * num_indices` candidates
let ann_node = new_knn_exec(
self.dataset.clone(),
index,
&query,
prefilter_source.clone(),
)?;
ann_nodes.push(ann_node);
let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let ann_node = Arc::new(
SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node)
.with_fetch(Some(q.k * over_fetch_factor as usize)),
);
ann_nodes.push(ann_node as Arc<dyn ExecutionPlan>);
}
let ann_node = Arc::new(UnionExec::new(ann_nodes));
let ann_node = Arc::new(RepartitionExec::try_new(
ann_node,
datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
)?);
let schema = ann_node.schema();
// unique by row ids, and get the min distance although it is not used.
let group_expr = vec![(
expressions::col(ROW_ID, schema.as_ref())?,
ROW_ID.to_string(),
)];
// for now multivector is always with cosine distance so here convert the distance to `1 - distance`
// and calculate the sum across all rows with the same row id.
let sum_expr = AggregateExprBuilder::new(
functions_aggregate::sum::sum_udaf(),
vec![expressions::binary(
expressions::lit(1.0),
datafusion_expr::Operator::Minus,
expressions::cast(
expressions::col(DIST_COL, &schema)?,
&schema,
DataType::Float64,
)?,
&schema,
)?],
)
.schema(schema.clone())
.alias(DIST_COL)
.build()?;
let ann_node: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(group_expr),
vec![Arc::new(sum_expr)],
vec![None],
ann_node,
schema,
)?);

let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?);

let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: true,
descending: false,
nulls_first: false,
},
};
Expand Down
6 changes: 4 additions & 2 deletions rust/lance/src/index/vector/ivf/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ mod tests {
.into_iter()
.enumerate()
.map(|(i, dist)| (dist, i as u64))
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.sorted_by(|a, b| a.0.total_cmp(&b.0))
.take(k)
.collect()
}
Expand Down Expand Up @@ -1046,6 +1046,8 @@ mod tests {
}

async fn test_index_multivec(params: VectorIndexParams, nlist: usize, recall_requirement: f32) {
// XTR trades a little recall for performance, so relax the recall requirement by 10%.
let recall_requirement = recall_requirement * 0.9;
match params.metric_type {
DistanceType::Hamming => {
test_index_multivec_impl::<UInt8Type>(params, nlist, recall_requirement, 0..2)
Expand Down Expand Up @@ -1116,7 +1118,7 @@ mod tests {
let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type);
let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();

let recall = row_ids.intersection(&gt_set).count() as f32 / 10.0;
let recall = row_ids.intersection(&gt_set).count() as f32 / 100.0;
assert!(
recall >= recall_requirement,
"recall: {}\n results: {:?}\n\ngt: {:?}",
Expand Down
Loading

0 comments on commit 87f055f

Please sign in to comment.