Skip to content

Commit

Permalink
clean up; unify sketch loading for pairwise/multisearch
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Feb 1, 2024
1 parent 893e0a7 commit dbdff4a
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 283 deletions.
1 change: 0 additions & 1 deletion src/fastgather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use sourmash::selection::Selection;
// use camino;

use sourmash::prelude::Select;
use sourmash::signature::SigsTrait;

use crate::utils::{
consume_query_by_gather, load_collection, load_sketches_above_threshold, write_prefetch,
Expand Down
5 changes: 1 addition & 4 deletions src/fastmultigather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
use anyhow::Result;
use rayon::prelude::*;

use serde::Serialize;
use sourmash::prelude::Select;
use sourmash::selection::Selection;
use sourmash::sketch::Sketch;
use sourmash::storage::SigStore;
use sourmash::{selection, signature::Signature};

use std::sync::atomic;
use std::sync::atomic::AtomicUsize;

use std::collections::BinaryHeap;

use camino::{Utf8Path, Utf8PathBuf};
use camino::Utf8Path;

use crate::utils::{
consume_query_by_gather, load_collection, write_prefetch, PrefetchResult, ReportType,
Expand Down
1 change: 0 additions & 1 deletion src/mastiff_manygather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use anyhow::Result;
use rayon::prelude::*;

use sourmash::signature::Signature;
use sourmash::sketch::Sketch;
use std::path::Path;

Expand Down
108 changes: 37 additions & 71 deletions src/multisearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@ use rayon::prelude::*;

use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;

use std::sync::atomic;
use std::sync::atomic::AtomicUsize;

use sourmash::prelude::Select;
use sourmash::selection::Selection;
use sourmash::signature::SigsTrait;
use sourmash::sketch::Sketch;
use sourmash::storage::SigStore;

use crate::utils::{load_collection, ReportType};
use crate::utils::{load_collection, load_mh_with_name_and_md5, ReportType};

/// Search many queries against a list of signatures.
///
Expand All @@ -31,37 +27,14 @@ pub fn multisearch(
) -> Result<(), Box<dyn std::error::Error>> {
// Load all queries into memory at once.

// let queries = load_sketches_from_zip_or_pathlist(&querylist, &template, ReportType::Query)?;
let query_collection = load_collection(query_filepath, selection, ReportType::Query)?;
let mut queries: Vec<SigStore> = vec![];
for (idx, record) in query_collection.iter() {
if let Ok(sig) = query_collection.sig_from_record(record)
// .unwrap()
// .select(&selection) // if we select here, we downsample and the md5sum changes!
// ...which means we would lose the original md5sum that is used in the standard gather results.
{
queries.push(sig);
} else {
eprintln!("Failed to load 'against' record: {}", record.name());
}
}
let queries =
load_mh_with_name_and_md5(query_collection, &selection, ReportType::Query).unwrap();

// Load all against sketches into memory at once.
// let against = load_sketches_from_zip_or_pathlist(&againstlist, &template, ReportType::Against)?;
let against_collection = load_collection(against_filepath, selection, ReportType::Against)?;
let mut against: Vec<SigStore> = vec![];

for (idx, record) in against_collection.iter() {
if let Ok(sig) = against_collection.sig_from_record(record)
// .unwrap()
// .select(&selection) // if we select here, we downsample and the md5sum changes!
// ...which means we would lose the original md5sum that is used in the standard gather results.
{
against.push(sig);
} else {
eprintln!("Failed to load 'against' record: {}", record.name());
}
}
let against =
load_mh_with_name_and_md5(against_collection, &selection, ReportType::Against).unwrap();

// set up a multi-producer, single-consumer channel.
let (send, recv) = std::sync::mpsc::sync_channel(rayon::current_num_threads());
Expand Down Expand Up @@ -94,49 +67,42 @@ pub fn multisearch(

let send = against
.par_iter()
.filter_map(|target| {
.filter_map(|(against_mh, against_name, against_md5)| {
let mut results = vec![];

let ds_against_sig = target.clone().select(&selection).unwrap();
if let Some(against_mh) = ds_against_sig.minhash() {
// search for matches & save containment.
for query_sig in queries.iter() {
let i = processed_cmp.fetch_add(1, atomic::Ordering::SeqCst);
if i % 100000 == 0 {
eprintln!("Processed {} comparisons", i);
}
let ds_q = query_sig.clone().select(&selection).unwrap();
let query_mh = ds_q.minhash()?;
let overlap = query_mh.count_common(&against_mh, false).unwrap() as f64;
// use downsampled sizes
let query_size = query_mh.size() as f64;
let target_size = against_mh.size() as f64;

let containment_query_in_target = overlap / query_size;
let containment_in_target = overlap / target_size;
let max_containment = containment_query_in_target.max(containment_in_target);
let jaccard = overlap / (target_size + query_size - overlap);

if containment_query_in_target > threshold {
results.push((
query_sig.name(),
query_sig.md5sum(),
target.name(),
target.md5sum(),
containment_query_in_target,
max_containment,
jaccard,
overlap,
))
}
// search for matches & save containment.
for (query_mh, query_name, query_md5) in queries.iter() {
let i = processed_cmp.fetch_add(1, atomic::Ordering::SeqCst);
if i % 100000 == 0 {
eprintln!("Processed {} comparisons", i);
}
if results.is_empty() {
None
} else {
Some(results)

let overlap = query_mh.count_common(&against_mh, false).unwrap() as f64;
// use downsampled sizes
let query_size = query_mh.size() as f64;
let target_size = against_mh.size() as f64;

let containment_query_in_target = overlap / query_size;
let containment_in_target = overlap / target_size;
let max_containment = containment_query_in_target.max(containment_in_target);
let jaccard = overlap / (target_size + query_size - overlap);

if containment_query_in_target > threshold {
results.push((
query_name.clone(),
query_md5.clone(),
against_name.clone(),
against_md5.clone(),
containment_query_in_target,
max_containment,
jaccard,
overlap,
))
}
} else {
}
if results.is_empty() {
None
} else {
Some(results)
}
})
.flatten()
Expand Down
20 changes: 3 additions & 17 deletions src/pairwise.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use anyhow::Result;
/// pairwise: massively parallel in-memory pairwise comparisons.
use rayon::prelude::*;
use sourmash::sketch::minhash::KmerMinHash;

use std::fs::File;
use std::io::{BufWriter, Write};
Expand All @@ -11,12 +10,9 @@ use std::sync::atomic;
use std::sync::atomic::AtomicUsize;

use sourmash::signature::SigsTrait;
use sourmash::sketch::Sketch;

use crate::utils::{load_collection, ReportType};
use sourmash::prelude::Select;
use crate::utils::{load_collection, load_mh_with_name_and_md5, ReportType};
use sourmash::selection::Selection;
use sourmash::storage::SigStore;

/// Perform pairwise comparisons of all signatures in a list.
///
Expand All @@ -29,25 +25,15 @@ pub fn pairwise<P: AsRef<Path>>(
output: Option<P>,
) -> Result<(), Box<dyn std::error::Error>> {
// Load all sigs into memory at once.
let collection = load_collection(sigpath, selection, ReportType::Query)?;
let collection = load_collection(sigpath, selection, ReportType::Pairwise)?;

if collection.len() <= 1 {
bail!(
"Pairwise requires two or more sketches. Check input: '{:?}'",
&sigpath
)
}

let mut sketches: Vec<(KmerMinHash, String, String)> = Vec::new();
for (_idx, record) in collection.iter() {
if let Ok(sig) = collection.sig_from_record(record) {
if let Some(ds_mh) = sig.clone().select(&selection)?.minhash().cloned() {
sketches.push((ds_mh, record.name().to_string(), record.md5().to_string()));
}
} else {
eprintln!("Failed to load record: {}", record.name());
}
}
let sketches = load_mh_with_name_and_md5(collection, &selection, ReportType::Pairwise).unwrap();

// set up a multi-producer, single-consumer channel.
let (send, recv) = std::sync::mpsc::sync_channel(rayon::current_num_threads());
Expand Down
4 changes: 2 additions & 2 deletions src/python/tests/test_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_bad_query(runtmp, capfd):
print(captured.err)

assert "WARNING: could not load sketches from path 'no-exist'" in captured.err
assert "WARNING: 1 query paths failed to load. See error messages above." in captured.err
assert "WARNING: 1 signature paths failed to load. See error messages above." in captured.err


def test_bad_query_2(runtmp, capfd):
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_nomatch_query(runtmp, capfd, zip_query):
captured = capfd.readouterr()
print(captured.err)

assert 'WARNING: skipped 1 query paths - no compatible signatures' in captured.err
assert 'WARNING: skipped 1 signature paths - no compatible signatures' in captured.err


@pytest.mark.parametrize("zip_db", [False, True])
Expand Down
Loading

0 comments on commit dbdff4a

Please sign in to comment.