Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EXP: avoid cloning in gather, use iterators directly #3394

Open
wants to merge 4 commits into
base: latest
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions src/core/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
use crate::prelude::*;
use crate::selection::Selection;
use crate::signature::SigsTrait;
use crate::sketch::minhash::KmerMinHash;
use crate::sketch::minhash::KmerMinHashBTree;
use crate::storage::SigStore;
use crate::Error::CannotUpsampleScaled;
use crate::Result;
Expand Down Expand Up @@ -208,8 +208,8 @@

#[allow(clippy::too_many_arguments)]
pub fn calculate_gather_stats(
orig_query: &KmerMinHash,
remaining_query: KmerMinHash,
orig_query: &KmerMinHashBTree,
remaining_query: KmerMinHashBTree,
match_sig: SigStore,
match_size: usize,
gather_result_rank: u32,
Expand All @@ -219,6 +219,8 @@
calc_ani_ci: bool,
confidence: Option<f64>,
) -> Result<(GatherResult, (Vec<u64>, u64))> {
use crate::sketch::minhash::Intersection;

// get match_mh
let match_mh = match_sig.minhash().expect("cannot retrieve sketch");

Expand All @@ -234,10 +236,18 @@
.expect("cannot downsample match");

// calculate intersection
let isect = match_mh
.intersection(&remaining_query)
.expect("could not do intersection");
let isect_size = isect.0.len();
// Using Intersection directly here has a pretty big requirement:
// the sketches MUST BE COMPATIBLE
// (as in: same ksize, max_hash, hash_function, seed)
// this should be covered by the call to downsample_scaled above,
// but important to keep in mind in the future if code changes
let isect_values: Vec<_> = Intersection::new(match_mh.iter_mins(), remaining_query.iter_mins())
.copied()
.collect();

let isect_size = isect_values.len();
let isect = (isect_values, isect_size as u64);

trace!("isect_size: {}", isect_size);
trace!("query.size: {}", remaining_query.size());

Expand All @@ -246,7 +256,14 @@
(remaining_query.size() - isect_size) as u64 * remaining_query.scaled() as u64;

// stats for this match vs original query
let (intersect_orig, _) = match_mh.intersection_size(orig_query).unwrap();
// Using Intersection directly here has a pretty big requirement:
// the sketches MUST BE COMPATIBLE
// (as in: same ksize, max_hash, hash_function, seed)
// this should be covered by the call to downsample_scaled above,
// but important to keep in mind in the future if code changes
let intersect_orig =
Intersection::new(match_mh.iter_mins(), orig_query.iter_mins()).count() as u64;

let intersect_bp = match_mh.scaled() as u64 * intersect_orig;
let f_orig_query = intersect_orig as f64 / orig_query.size() as f64;
let f_match_orig = intersect_orig as f64 / match_mh.size() as f64;
Expand Down Expand Up @@ -303,12 +320,8 @@
// If abundance, calculate abund-related metrics (vs current query)
if calc_abund_stats {
// take abunds from subtracted query
let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&remaining_query) {
Ok((abunds, unique_weighted_found)) => (abunds, unique_weighted_found),
Err(e) => {
return Err(e);
}
};
let (abunds, unique_weighted_found) = match_mh
.inflated_abundances(remaining_query.iter_mins(), remaining_query.iter_abunds())?;

Check warning on line 324 in src/core/src/index/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/index/mod.rs#L324

Added line #L324 was not covered by tests

n_unique_weighted_found = unique_weighted_found;
sum_total_weighted_found = sum_weighted_found + n_unique_weighted_found;
Expand Down Expand Up @@ -399,6 +412,7 @@
orig_query.add_hash_with_abundance(8, 1);
orig_query.add_hash_with_abundance(10, 1); // Non-matching hash

let orig_query: KmerMinHashBTree = orig_query.into();
let query = orig_query.clone();
let total_weighted_hashes = orig_query.sum_abunds();

Expand Down
11 changes: 6 additions & 5 deletions src/core/src/index/revindex/disk_revindex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,13 @@ impl RevIndexOps for RevIndex {
query_colors: QueryColors,
hash_to_color: HashToColor,
threshold: usize,
orig_query: &KmerMinHash,
orig_query: KmerMinHash,
selection: Option<Selection>,
) -> Result<Vec<GatherResult>> {
let mut match_size = usize::MAX;
let mut matches = vec![];
let mut query = KmerMinHashBTree::from(orig_query.clone());
let orig_query: KmerMinHashBTree = orig_query.into();
let mut query = orig_query.clone();
let mut sum_weighted_found = 0;
let _selection = selection.unwrap_or_else(|| self.collection.selection());
let total_weighted_hashes = orig_query.sum_abunds();
Expand Down Expand Up @@ -405,18 +406,18 @@ impl RevIndexOps for RevIndex {

// repeatedly downsample query, then extract to KmerMinHash
// => calculate_gather_stats
query = query
let query_mh = query
.clone()
.downsample_scaled(max_scaled)
.expect("cannot downsample query");
let query_mh = KmerMinHash::from(query.clone());

// just calculate essentials here
let gather_result_rank = matches.len() as u32;

// grab the specific intersection:
// Calculate stats
let (gather_result, isect) = calculate_gather_stats(
orig_query,
&orig_query,
query_mh,
match_sig,
match_size,
Expand Down
14 changes: 7 additions & 7 deletions src/core/src/index/revindex/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub trait RevIndexOps {
query_colors: QueryColors,
hash_to_color: HashToColor,
threshold: usize,
query: &KmerMinHash,
query: KmerMinHash,
selection: Option<Selection>,
) -> Result<Vec<GatherResult>>;

Expand Down Expand Up @@ -553,7 +553,7 @@ mod test {
query_colors,
hash_to_color,
0,
&query,
query,
Some(selection),
)?;

Expand Down Expand Up @@ -620,7 +620,7 @@ mod test {
query_colors,
hash_to_color,
5, // 50kb threshold
&query,
query,
Some(selection),
)?;

Expand Down Expand Up @@ -770,7 +770,7 @@ mod test {
query_colors,
hash_to_color,
0,
&query,
query,
Some(selection),
)?;

Expand Down Expand Up @@ -909,7 +909,7 @@ mod test {
query_colors,
hash_to_color,
0,
&query,
query.clone(),
Some(selection.clone()),
)
.expect("failed to gather!");
Expand All @@ -927,7 +927,7 @@ mod test {
query_colors,
hash_to_color,
0,
&query,
query.clone(),
Some(selection.clone()),
)?;
assert_eq!(matches_external, matches_internal);
Expand All @@ -944,7 +944,7 @@ mod test {
query_colors,
hash_to_color,
0,
&query,
query,
Some(selection.clone()),
)?;
assert_eq!(matches_external, matches_moved);
Expand Down
102 changes: 76 additions & 26 deletions src/core/src/sketch/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,10 @@
self.mins.iter()
}

pub fn iter_abunds(&self) -> Option<impl Iterator<Item = &u64>> {
self.abunds.as_ref().map(|abunds| abunds.iter())
}

pub fn abunds(&self) -> Option<Vec<u64>> {
self.abunds.clone()
}
Expand Down Expand Up @@ -828,33 +832,21 @@
Ok(())
}

pub fn inflated_abundances(&self, abunds_from: &KmerMinHash) -> Result<(Vec<u64>, u64), Error> {
self.check_compatible(abunds_from)?;
pub fn inflated_abundances<'a, 'b, M: Iterator<Item = &'a u64>, A: Iterator<Item = &'a u64>>(
&'b self,
mins_from: M,
abunds_from: Option<A>,
) -> Result<(Vec<u64>, u64), Error> {
//self.check_compatible(abunds_from)?;

// check that abunds_from has abundances
if abunds_from.abunds.is_none() {
if abunds_from.is_none() {
return Err(Error::NeedsAbundanceTracking);
}

let self_iter = self.mins.iter();
let abunds_iter = abunds_from.abunds.as_ref().unwrap().iter();
let abunds_from_iter = abunds_from.mins.iter().zip(abunds_iter);

let (abundances, total_abundance): (Vec<u64>, u64) = self_iter
.merge_join_by(abunds_from_iter, |&self_val, &(other_val, _)| {
self_val.cmp(other_val)
})
.filter_map(|either| match either {
itertools::EitherOrBoth::Both(_self_val, (_other_val, other_abund)) => {
Some(*other_abund)
}
_ => None,
})
.fold((Vec::new(), 0u64), |(mut acc_vec, acc_sum), abund| {
acc_vec.push(abund);
(acc_vec, acc_sum + abund)
});

Ok((abundances, total_abundance))
inflated_abundances(self_iter, mins_from, abunds_from)
}
}

Expand Down Expand Up @@ -912,21 +904,21 @@
}
}

struct Intersection<T, I: Iterator<Item = T>> {
pub(crate) struct Intersection<T, I: Iterator<Item = T>, J: Iterator<Item = T>> {
iter: Peekable<I>,
other: Peekable<I>,
other: Peekable<J>,
}

impl<T, I: Iterator<Item = T>> Intersection<T, I> {
pub fn new(left: I, right: I) -> Self {
impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>> Intersection<T, I, J> {
pub fn new(left: I, right: J) -> Self {

Check warning on line 913 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L913

Added line #L913 was not covered by tests
Intersection {
iter: left.peekable(),
other: right.peekable(),
}
}
}

impl<T: Ord, I: Iterator<Item = T>> Iterator for Intersection<T, I> {
impl<T: Ord, I: Iterator<Item = T>, J: Iterator<Item = T>> Iterator for Intersection<T, I, J> {
type Item = T;

fn next(&mut self) -> Option<T> {
Expand Down Expand Up @@ -1534,6 +1526,10 @@
self.mins.iter()
}

pub fn iter_abunds(&self) -> Option<impl Iterator<Item = &u64>> {
self.abunds.as_ref().map(|abunds| abunds.values())

Check warning on line 1530 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1529-L1530

Added lines #L1529 - L1530 were not covered by tests
}

pub fn abunds(&self) -> Option<Vec<u64>> {
self.abunds
.as_ref()
Expand All @@ -1551,6 +1547,13 @@
}
}

// Approximate total number of kmers
// this could be improved by generating an HLL estimate while sketching instead
// (for scaled minhashes)
pub fn n_unique_kmers(&self) -> u64 {
self.size() as u64 * self.scaled() as u64 // + (self.ksize - 1) for bp estimation

Check warning on line 1554 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1554

Added line #L1554 was not covered by tests
}

// create a downsampled copy of self
pub fn downsample_scaled(self, scaled: ScaledType) -> Result<KmerMinHashBTree, Error> {
if self.scaled() == scaled || self.scaled() == 0 {
Expand Down Expand Up @@ -1594,6 +1597,53 @@
self.size() as u64
}
}

pub fn inflated_abundances<'a, 'b, M: Iterator<Item = &'a u64>, A: Iterator<Item = &'a u64>>(
&'b self,
mins_from: M,
abunds_from: Option<A>,
) -> Result<(Vec<u64>, u64), Error> {
// check that abunds_from has abundances
if abunds_from.is_none() {
return Err(Error::NeedsAbundanceTracking);

Check warning on line 1608 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1607-L1608

Added lines #L1607 - L1608 were not covered by tests
}

let self_iter = self.mins.iter();

Check warning on line 1611 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1611

Added line #L1611 was not covered by tests

inflated_abundances(self_iter, mins_from, abunds_from)

Check warning on line 1613 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1613

Added line #L1613 was not covered by tests
}
}

fn inflated_abundances<
'a,
'b,
M: Iterator<Item = &'b u64>,
N: Iterator<Item = &'a u64>,
A: Iterator<Item = &'a u64>,
>(
self_iter: M,
mins_from: N,
abunds_from: Option<A>,
) -> Result<(Vec<u64>, u64), Error> {
let abunds_iter = abunds_from.unwrap();

Check warning on line 1628 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1628

Added line #L1628 was not covered by tests
let abunds_from_iter = mins_from.zip(abunds_iter);

let (abundances, total_abundance): (Vec<u64>, u64) = self_iter
.merge_join_by(abunds_from_iter, |&self_val, &(other_val, _)| {
self_val.cmp(other_val)

Check warning on line 1633 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1632-L1633

Added lines #L1632 - L1633 were not covered by tests
})
.filter_map(|either| match either {
itertools::EitherOrBoth::Both(_self_val, (_other_val, other_abund)) => {

Check warning on line 1636 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1636

Added line #L1636 was not covered by tests
Some(*other_abund)
}
_ => None,

Check warning on line 1639 in src/core/src/sketch/minhash.rs

View check run for this annotation

Codecov / codecov/patch

src/core/src/sketch/minhash.rs#L1639

Added line #L1639 was not covered by tests
})
.fold((Vec::new(), 0u64), |(mut acc_vec, acc_sum), abund| {
acc_vec.push(abund);
(acc_vec, acc_sum + abund)
});

Ok((abundances, total_abundance))
}

impl SigsTrait for KmerMinHashBTree {
Expand Down
4 changes: 2 additions & 2 deletions src/core/tests/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ fn test_inflated_abundances() {
// Attempt to inflate minhash_a using minhash_b's abundances
assert!(a.inflate(&b).is_ok());

let (abunds, total_abund) = a.inflated_abundances(&b).unwrap();
let (abunds, total_abund) = a.inflated_abundances(b.iter_mins(), b.iter_abunds()).unwrap();
assert_eq!(abunds, vec![2, 4]);
assert_eq!(total_abund, 6);
}
Expand All @@ -858,7 +858,7 @@ fn test_inflated_abunds_noabund() {
a.add_hash(10);
a.add_hash(20);
a.add_hash(30);
let result = a.inflated_abundances(&a);
let result = a.inflated_abundances(a.iter_mins(), a.iter_abunds());
assert!(matches!(
result,
Err(sourmash::Error::NeedsAbundanceTracking)
Expand Down
Loading