From bdc19d585301474eaf71bcaff90a39cd731dc562 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Sat, 31 Oct 2020 23:07:23 +0000 Subject: [PATCH] Add HyperLogLog implementation (#1223) Implement a HyperLogLog sketch based on the `khmer` implementation but using the estimator from ["New cardinality estimation algorithms for HyperLogLog sketches"](http://oertl.github.io/hyperloglog-sketch-estimation-paper/paper/paper.pdf) (also implemented in `dashing`). This PR also moves `add_sequence` and `add_protein` to `SigsTrait`, closing #1057. The encoding data and methods (`hp`, `dayhoff`, `aa` and `HashFunctions`) was in the MinHash source file, and since it is more general-purpose it was moved to a new module `encodings`, which is then used by `SigsTrait`. (these changes are both spun off #1201) --- Makefile | 1 + include/sourmash.h | 37 ++ sourmash/hll.py | 112 ++++ src/core/Cargo.toml | 11 +- src/core/src/cmd.rs | 3 +- src/core/src/encodings.rs | 358 ++++++++++ src/core/src/errors.rs | 6 + src/core/src/ffi/hyperloglog.rs | 217 ++++++ src/core/src/ffi/minhash.rs | 5 +- src/core/src/ffi/mod.rs | 1 + src/core/src/ffi/signature.rs | 2 +- src/core/src/ffi/utils.rs | 4 +- src/core/src/from.rs | 6 +- src/core/src/lib.rs | 2 + src/core/src/signature.rs | 200 +++++- src/core/src/sketch/hyperloglog/estimators.rs | 178 +++++ src/core/src/sketch/hyperloglog/mod.rs | 373 +++++++++++ src/core/src/sketch/minhash.rs | 627 ++---------------- src/core/src/sketch/mod.rs | 7 +- src/core/src/sketch/nodegraph.rs | 2 +- src/core/src/sketch/ukhs.rs | 39 -- src/core/src/wasm.rs | 5 +- src/core/tests/minhash.rs | 5 +- tests/test_hll.py | 126 ++++ 24 files changed, 1642 insertions(+), 685 deletions(-) create mode 100644 sourmash/hll.py create mode 100644 src/core/src/encodings.rs create mode 100644 src/core/src/ffi/hyperloglog.rs create mode 100644 src/core/src/sketch/hyperloglog/estimators.rs create mode 100644 src/core/src/sketch/hyperloglog/mod.rs delete mode 100644 src/core/src/sketch/ukhs.rs create mode 100644 tests/test_hll.py diff --git a/Makefile b/Makefile index e8a3fa5d69..4644c6345f 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,7 @@ doc: build .PHONY cd doc && make html include/sourmash.h: src/core/src/lib.rs \ + src/core/src/ffi/hyperloglog.rs \ src/core/src/ffi/minhash.rs \ src/core/src/ffi/signature.rs \ src/core/src/ffi/nodegraph.rs \ diff --git a/include/sourmash.h b/include/sourmash.h index 401c6b3407..b5e111b662 100644 --- a/include/sourmash.h +++ b/include/sourmash.h @@ -35,6 +35,7 @@ enum SourmashErrorCode { SOURMASH_ERROR_CODE_INVALID_HASH_FUNCTION = 1104, SOURMASH_ERROR_CODE_READ_DATA = 1201, SOURMASH_ERROR_CODE_STORAGE = 1202, + SOURMASH_ERROR_CODE_HLL_PRECISION_BOUNDS = 1301, SOURMASH_ERROR_CODE_IO = 100001, SOURMASH_ERROR_CODE_UTF8_ERROR = 100002, SOURMASH_ERROR_CODE_PARSE_INT = 100003, @@ -45,6 +46,8 @@ typedef uint32_t SourmashErrorCode; typedef struct SourmashComputeParameters SourmashComputeParameters; +typedef struct SourmashHyperLogLog SourmashHyperLogLog; + typedef struct SourmashKmerMinHash SourmashKmerMinHash; typedef struct SourmashNodegraph SourmashNodegraph; @@ -115,6 +118,40 @@ bool computeparams_track_abundance(const SourmashComputeParameters *ptr); uint64_t hash_murmur(const char *kmer, uint64_t seed); +void hll_add_hash(SourmashHyperLogLog *ptr, uint64_t hash); + +void hll_add_sequence(SourmashHyperLogLog *ptr, const char *sequence, uintptr_t insize, bool force); + +uintptr_t hll_cardinality(const SourmashHyperLogLog *ptr); + +double hll_containment(const SourmashHyperLogLog *ptr, const SourmashHyperLogLog *optr); + +void hll_free(SourmashHyperLogLog *ptr); + +SourmashHyperLogLog *hll_from_buffer(const char *ptr, uintptr_t insize); + +SourmashHyperLogLog *hll_from_path(const char *filename); + +uintptr_t hll_intersection_size(const SourmashHyperLogLog *ptr, const SourmashHyperLogLog *optr); + +uintptr_t hll_ksize(const SourmashHyperLogLog *ptr); + +uintptr_t hll_matches(const SourmashHyperLogLog *ptr, const SourmashKmerMinHash *mh_ptr); + +void hll_merge(SourmashHyperLogLog *ptr, const SourmashHyperLogLog *optr); + +SourmashHyperLogLog *hll_new(void); + +void hll_save(const SourmashHyperLogLog *ptr, const char *filename); + +double hll_similarity(const SourmashHyperLogLog *ptr, const SourmashHyperLogLog *optr); + +const uint8_t *hll_to_buffer(const SourmashHyperLogLog *ptr, uintptr_t *size); + +void hll_update_mh(SourmashHyperLogLog *ptr, const SourmashKmerMinHash *optr); + +SourmashHyperLogLog *hll_with_error_rate(double error_rate, uintptr_t ksize); + void kmerminhash_add_from(SourmashKmerMinHash *ptr, const SourmashKmerMinHash *other); void kmerminhash_add_hash(SourmashKmerMinHash *ptr, uint64_t h); diff --git a/sourmash/hll.py b/sourmash/hll.py new file mode 100644 index 0000000000..c98ded5e8b --- /dev/null +++ b/sourmash/hll.py @@ -0,0 +1,112 @@ +# -*- coding: UTF-8 -*- + +import sys +from tempfile import NamedTemporaryFile + +from ._lowlevel import ffi, lib +from .utils import RustObject, rustcall, decode_str +from .exceptions import SourmashError +from .minhash import to_bytes, MinHash + + +class HLL(RustObject): + __dealloc_func__ = lib.hll_free + + def __init__(self, error_rate, ksize): + self._objptr = lib.hll_with_error_rate(error_rate, ksize) + + def __len__(self): + return self.cardinality() + + def cardinality(self): + return self._methodcall(lib.hll_cardinality) + + @property + def ksize(self): + return self._methodcall(lib.hll_ksize) + + def add_sequence(self, sequence, force=False): + "Add a sequence into the sketch." + self._methodcall(lib.hll_add_sequence, to_bytes(sequence), len(sequence), force) + + def add_kmer(self, kmer): + "Add a kmer into the sketch." + if len(kmer) != self.ksize: + raise ValueError("kmer to add is not {} in length".format(self.ksize)) + self.add_sequence(kmer) + + def add(self, h): + if isinstance(h, str): + return self.add_kmer(h) + return self._methodcall(lib.hll_add_hash, h) + + def update(self, other): + if isinstance(other, HLL): + return self._methodcall(lib.hll_merge, other._objptr) + elif isinstance(other, MinHash): + return self._methodcall(lib.hll_update_mh, other._objptr) + else: + # FIXME: we could take sets here too (or anything that can be + # converted to a list of ints...) + raise TypeError("Must be a HyperLogLog or MinHash") + + def similarity(self, other): + if isinstance(other, HLL): + return self._methodcall(lib.hll_similarity, other._objptr) + else: + # FIXME: we could take sets here too (or anything that can be + # converted to a list of ints...) + raise TypeError("other must be a HyperLogLog") + + def containment(self, other): + if isinstance(other, HLL): + return self._methodcall(lib.hll_containment, other._objptr) + else: + # FIXME: we could take sets here too (or anything that can be + # converted to a list of ints...) + raise TypeError("other must be a HyperLogLog") + + def intersection(self, other): + if isinstance(other, HLL): + return self._methodcall(lib.hll_intersection_size, other._objptr) + else: + # FIXME: we could take sets here too (or anything that can be + # converted to a list of ints...) + raise TypeError("other must be a HyperLogLog") + + @staticmethod + def load(filename): + hll_ptr = rustcall(lib.hll_from_path, to_bytes(filename)) + return HLL._from_objptr(hll_ptr) + + @staticmethod + def from_buffer(buf): + hll_ptr = rustcall(lib.hll_from_buffer, buf, len(buf)) + return HLL._from_objptr(hll_ptr) + + def save(self, filename): + self._methodcall(lib.hll_save, to_bytes(filename)) + + def to_bytes(self, compression=1): + size = ffi.new("uintptr_t *") + rawbuf = self._methodcall(lib.hll_to_buffer, size) + size = size[0] + + rawbuf = ffi.gc(rawbuf, lambda o: lib.nodegraph_buffer_free(o, size), size) + buf = ffi.buffer(rawbuf, size) + + return buf + + def count(self, h): + self.add(h) + + def get(self, h): + raise NotImplementedError("HLL doesn't support membership query") + + def matches(self, mh): + if not isinstance(mh, MinHash): + # FIXME: we could take sets here too (or anything that can be + # converted to a list of ints...) + raise ValueError("mh must be a MinHash") + + return self._methodcall(lib.hll_matches, mh._objptr) diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 36087c5f07..cb41e221b9 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -22,17 +22,20 @@ from-finch = ["finch"] parallel = ["rayon"] [dependencies] +az = "1.0.0" backtrace = "=0.3.46" # later versions require rust 1.40 +bytecount = "0.6.0" byteorder = "1.3.4" cfg-if = "1.0" -failure = "0.1.8" # can remove after .backtrace() is available in std::error::Error finch = { version = "0.3.0", optional = true } fixedbitset = "0.3.0" +getset = "0.1.1" log = "0.4.8" md5 = "0.7.0" murmurhash3 = "0.0.5" niffler = { version = "2.2.0", default-features = false, features = [ "gz" ] } nohash-hasher = "0.2.0" +num-iter = "0.1.41" once_cell = "1.3.1" rayon = { version = "1.3.0", optional = true } serde = { version = "1.0.110", features = ["derive"] } @@ -40,7 +43,6 @@ serde_json = "1.0.53" primal-check = "0.2.3" thiserror = "1.0" typed-builder = "0.7.0" -getset = "0.1.1" [target.'cfg(all(target_arch = "wasm32", target_vendor="unknown"))'.dependencies.wasm-bindgen] version = "0.2.62" @@ -58,15 +60,12 @@ wasm-opt = false # https://github.com/rustwasm/wasm-pack/issues/886 [dev-dependencies] assert_matches = "1.3.0" criterion = "0.3.2" +needletail = { version = "0.4.0", default-features = false } predicates = "1.0.4" proptest = { version = "0.9.6", default-features = false, features = ["std"]} # Upgrade to 0.10 requires rust 1.39 rand = "0.7.3" tempfile = "3.1.0" -[dev-dependencies.needletail] -version = "0.4.0" -default-features = false - [[bench]] name = "index" harness = false diff --git a/src/core/src/cmd.rs b/src/core/src/cmd.rs index 7a7f5c33e2..9df25ef931 100644 --- a/src/core/src/cmd.rs +++ b/src/core/src/cmd.rs @@ -4,9 +4,10 @@ use wasm_bindgen::prelude::*; use getset::{CopyGetters, Getters, Setters}; use typed_builder::TypedBuilder; +use crate::encodings::HashFunctions; use crate::index::MHBT; use crate::signature::Signature; -use crate::sketch::minhash::{max_hash_for_scaled, HashFunctions, KmerMinHashBTree}; +use crate::sketch::minhash::{max_hash_for_scaled, KmerMinHashBTree}; use crate::sketch::Sketch; use crate::Error; diff --git a/src/core/src/encodings.rs b/src/core/src/encodings.rs new file mode 100644 index 0000000000..0ceddc2cbb --- /dev/null +++ b/src/core/src/encodings.rs @@ -0,0 +1,358 @@ +use std::collections::HashMap; +use std::convert::TryFrom; +use std::iter::Iterator; +use std::str; + +use once_cell::sync::Lazy; +#[cfg(all(target_arch = "wasm32", target_vendor = "unknown"))] +use wasm_bindgen::prelude::*; + +use crate::Error; + +#[cfg_attr(all(target_arch = "wasm32", target_vendor = "unknown"), wasm_bindgen)] +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u32)] +pub enum HashFunctions { + murmur64_DNA = 1, + murmur64_protein = 2, + murmur64_dayhoff = 3, + murmur64_hp = 4, +} + +impl HashFunctions { + pub fn dna(&self) -> bool { + *self == HashFunctions::murmur64_DNA + } + + pub fn protein(&self) -> bool { + *self == HashFunctions::murmur64_protein + } + + pub fn dayhoff(&self) -> bool { + *self == HashFunctions::murmur64_dayhoff + } + + pub fn hp(&self) -> bool { + *self == HashFunctions::murmur64_hp + } +} + +impl std::fmt::Display for HashFunctions { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{}", + match self { + HashFunctions::murmur64_DNA => "dna", + HashFunctions::murmur64_protein => "protein", + HashFunctions::murmur64_dayhoff => "dayhoff", + HashFunctions::murmur64_hp => "hp", + } + ) + } +} + +impl TryFrom<&str> for HashFunctions { + type Error = Error; + + fn try_from(moltype: &str) -> Result { + match moltype.to_lowercase().as_ref() { + "dna" => Ok(HashFunctions::murmur64_DNA), + "dayhoff" => Ok(HashFunctions::murmur64_dayhoff), + "hp" => Ok(HashFunctions::murmur64_hp), + "protein" => Ok(HashFunctions::murmur64_protein), + _ => unimplemented!(), + } + } +} + +const COMPLEMENT: [u8; 256] = { + let mut lookup = [0; 256]; + lookup[b'A' as usize] = b'T'; + lookup[b'C' as usize] = b'G'; + lookup[b'G' as usize] = b'C'; + lookup[b'T' as usize] = b'A'; + lookup[b'N' as usize] = b'N'; + lookup +}; + +#[inline] +pub fn revcomp(seq: &[u8]) -> Vec { + seq.iter() + .rev() + .map(|nt| COMPLEMENT[*nt as usize]) + .collect() +} + +static CODONTABLE: Lazy> = Lazy::new(|| { + [ + // F + ("TTT", b'F'), + ("TTC", b'F'), + // L + ("TTA", b'L'), + ("TTG", b'L'), + // S + ("TCT", b'S'), + ("TCC", b'S'), + ("TCA", b'S'), + ("TCG", b'S'), + ("TCN", b'S'), + // Y + ("TAT", b'Y'), + ("TAC", b'Y'), + // * + ("TAA", b'*'), + ("TAG", b'*'), + // * + ("TGA", b'*'), + // C + ("TGT", b'C'), + ("TGC", b'C'), + // W + ("TGG", b'W'), + // L + ("CTT", b'L'), + ("CTC", b'L'), + ("CTA", b'L'), + ("CTG", b'L'), + ("CTN", b'L'), + // P + ("CCT", b'P'), + ("CCC", b'P'), + ("CCA", b'P'), + ("CCG", b'P'), + ("CCN", b'P'), + // H + ("CAT", b'H'), + ("CAC", b'H'), + // Q + ("CAA", b'Q'), + ("CAG", b'Q'), + // R + ("CGT", b'R'), + ("CGC", b'R'), + ("CGA", b'R'), + ("CGG", b'R'), + ("CGN", b'R'), + // I + ("ATT", b'I'), + ("ATC", b'I'), + ("ATA", b'I'), + // M + ("ATG", b'M'), + // T + ("ACT", b'T'), + ("ACC", b'T'), + ("ACA", b'T'), + ("ACG", b'T'), + ("ACN", b'T'), + // N + ("AAT", b'N'), + ("AAC", b'N'), + // K + ("AAA", b'K'), + ("AAG", b'K'), + // S + ("AGT", b'S'), + ("AGC", b'S'), + // R + ("AGA", b'R'), + ("AGG", b'R'), + // V + ("GTT", b'V'), + ("GTC", b'V'), + ("GTA", b'V'), + ("GTG", b'V'), + ("GTN", b'V'), + // A + ("GCT", b'A'), + ("GCC", b'A'), + ("GCA", b'A'), + ("GCG", b'A'), + ("GCN", b'A'), + // D + ("GAT", b'D'), + ("GAC", b'D'), + // E + ("GAA", b'E'), + ("GAG", b'E'), + // G + ("GGT", b'G'), + ("GGC", b'G'), + ("GGA", b'G'), + ("GGG", b'G'), + ("GGN", b'G'), + ] + .iter() + .cloned() + .collect() +}); + +// Dayhoff table from +// Peris, P., López, D., & Campos, M. (2008). +// IgTM: An algorithm to predict transmembrane domains and topology in +// proteins. BMC Bioinformatics, 9(1), 1029–11. +// http://doi.org/10.1186/1471-2105-9-367 +// +// Original source: +// Dayhoff M. O., Schwartz R. M., Orcutt B. C. (1978). +// A model of evolutionary change in proteins, +// in Atlas of Protein Sequence and Structure, +// ed Dayhoff M. O., editor. +// (Washington, DC: National Biomedical Research Foundation; ), 345–352. +// +// | Amino acid | Property | Dayhoff | +// |---------------|-----------------------|---------| +// | C | Sulfur polymerization | a | +// | A, G, P, S, T | Small | b | +// | D, E, N, Q | Acid and amide | c | +// | H, K, R | Basic | d | +// | I, L, M, V | Hydrophobic | e | +// | F, W, Y | Aromatic | f | +static DAYHOFFTABLE: Lazy> = Lazy::new(|| { + [ + // a + (b'C', b'a'), + // b + (b'A', b'b'), + (b'G', b'b'), + (b'P', b'b'), + (b'S', b'b'), + (b'T', b'b'), + // c + (b'D', b'c'), + (b'E', b'c'), + (b'N', b'c'), + (b'Q', b'c'), + // d + (b'H', b'd'), + (b'K', b'd'), + (b'R', b'd'), + // e + (b'I', b'e'), + (b'L', b'e'), + (b'M', b'e'), + (b'V', b'e'), + // e + (b'F', b'f'), + (b'W', b'f'), + (b'Y', b'f'), + ] + .iter() + .cloned() + .collect() +}); + +// HP Hydrophobic/hydrophilic mapping +// From: Phillips, R., Kondev, J., Theriot, J. (2008). +// Physical Biology of the Cell. New York: Garland Science, Taylor & Francis Group. ISBN: 978-0815341635 + +// +// | Amino acid | HP +// |---------------------------------------|---------| +// | A, F, G, I, L, M, P, V, W, Y | h | +// | N, C, S, T, D, E, R, H, K, Q | p | +static HPTABLE: Lazy> = Lazy::new(|| { + [ + // h + (b'A', b'h'), + (b'F', b'h'), + (b'G', b'h'), + (b'I', b'h'), + (b'L', b'h'), + (b'M', b'h'), + (b'P', b'h'), + (b'V', b'h'), + (b'W', b'h'), + (b'Y', b'h'), + // p + (b'N', b'p'), + (b'C', b'p'), + (b'S', b'p'), + (b'T', b'p'), + (b'D', b'p'), + (b'E', b'p'), + (b'R', b'p'), + (b'H', b'p'), + (b'K', b'p'), + (b'Q', b'p'), + ] + .iter() + .cloned() + .collect() +}); + +#[inline] +pub fn translate_codon(codon: &[u8]) -> Result { + if codon.len() == 1 { + return Ok(b'X'); + } + + if codon.len() == 2 { + let mut v = codon.to_vec(); + v.push(b'N'); + match CODONTABLE.get(str::from_utf8(v.as_slice()).unwrap()) { + Some(aa) => return Ok(*aa), + None => return Ok(b'X'), + } + } + + if codon.len() == 3 { + match CODONTABLE.get(str::from_utf8(codon).unwrap()) { + Some(aa) => return Ok(*aa), + None => return Ok(b'X'), + } + } + + Err(Error::InvalidCodonLength { + message: format!("{}", codon.len()), + }) +} + +#[inline] +pub fn aa_to_dayhoff(aa: u8) -> u8 { + match DAYHOFFTABLE.get(&aa) { + Some(letter) => *letter, + None => b'X', + } +} + +pub fn aa_to_hp(aa: u8) -> u8 { + match HPTABLE.get(&aa) { + Some(letter) => *letter, + None => b'X', + } +} + +#[inline] +pub fn to_aa(seq: &[u8], dayhoff: bool, hp: bool) -> Result, Error> { + let mut converted: Vec = Vec::with_capacity(seq.len() / 3); + + for chunk in seq.chunks(3) { + if chunk.len() < 3 { + break; + } + + let residue = translate_codon(chunk)?; + if dayhoff { + converted.push(aa_to_dayhoff(residue) as u8); + } else if hp { + converted.push(aa_to_hp(residue) as u8); + } else { + converted.push(residue); + } + } + + Ok(converted) +} + +pub const VALID: [bool; 256] = { + let mut lookup = [false; 256]; + lookup[b'A' as usize] = true; + lookup[b'C' as usize] = true; + lookup[b'G' as usize] = true; + lookup[b'T' as usize] = true; + lookup +}; diff --git a/src/core/src/errors.rs b/src/core/src/errors.rs index 8c86319f07..0747a3748b 100644 --- a/src/core/src/errors.rs +++ b/src/core/src/errors.rs @@ -39,6 +39,9 @@ pub enum SourmashError { #[error("Codon is invalid length: {message}")] InvalidCodonLength { message: String }, + #[error("Set error rate to a value smaller than 0.367696 and larger than 0.00203125")] + HLLPrecisionBounds, + #[error(transparent)] ReadDataError(#[from] crate::index::storage::ReadDataError), @@ -87,6 +90,8 @@ pub enum SourmashErrorCode { // index-related errors ReadData = 12_01, Storage = 12_02, + // HLL errors + HLLPrecisionBounds = 13_01, // external errors Io = 100_001, Utf8Error = 100_002, @@ -114,6 +119,7 @@ impl SourmashErrorCode { SourmashError::InvalidHashFunction { .. } => SourmashErrorCode::InvalidHashFunction, SourmashError::ReadDataError { .. } => SourmashErrorCode::ReadData, SourmashError::StorageError { .. } => SourmashErrorCode::Storage, + SourmashError::HLLPrecisionBounds { .. } => SourmashErrorCode::HLLPrecisionBounds, SourmashError::SerdeError { .. } => SourmashErrorCode::SerdeError, SourmashError::IOError { .. } => SourmashErrorCode::Io, SourmashError::NifflerError { .. } => SourmashErrorCode::NifflerError, diff --git a/src/core/src/ffi/hyperloglog.rs b/src/core/src/ffi/hyperloglog.rs new file mode 100644 index 0000000000..7d8e39ea82 --- /dev/null +++ b/src/core/src/ffi/hyperloglog.rs @@ -0,0 +1,217 @@ +use std::ffi::CStr; +use std::os::raw::c_char; +use std::slice; + +use crate::index::sbt::Update; +use crate::signature::SigsTrait; +use crate::sketch::hyperloglog::HyperLogLog; + +use crate::ffi::minhash::SourmashKmerMinHash; +use crate::ffi::utils::ForeignObject; + +pub struct SourmashHyperLogLog; + +impl ForeignObject for SourmashHyperLogLog { + type RustObject = HyperLogLog; +} + +#[no_mangle] +pub unsafe extern "C" fn hll_new() -> *mut SourmashHyperLogLog { + SourmashHyperLogLog::from_rust(HyperLogLog::default()) +} + +#[no_mangle] +pub unsafe extern "C" fn hll_free(ptr: *mut SourmashHyperLogLog) { + SourmashHyperLogLog::drop(ptr); +} + +ffi_fn! { +unsafe fn hll_with_error_rate( + error_rate: f64, + ksize: usize, +) -> Result<*mut SourmashHyperLogLog> { + let hll = HyperLogLog::with_error_rate(error_rate, ksize)?; + Ok(SourmashHyperLogLog::from_rust(hll)) +} +} + +#[no_mangle] +pub unsafe extern "C" fn hll_ksize(ptr: *const SourmashHyperLogLog) -> usize { + SourmashHyperLogLog::as_rust(ptr).ksize() +} + +#[no_mangle] +pub unsafe extern "C" fn hll_cardinality(ptr: *const SourmashHyperLogLog) -> usize { + SourmashHyperLogLog::as_rust(ptr).cardinality() +} + +#[no_mangle] +pub unsafe extern "C" fn hll_similarity( + ptr: *const SourmashHyperLogLog, + optr: *const SourmashHyperLogLog, +) -> f64 { + SourmashHyperLogLog::as_rust(ptr).similarity(SourmashHyperLogLog::as_rust(optr)) +} + +#[no_mangle] +pub unsafe extern "C" fn hll_containment( + ptr: *const SourmashHyperLogLog, + optr: *const SourmashHyperLogLog, +) -> f64 { + SourmashHyperLogLog::as_rust(ptr).containment(SourmashHyperLogLog::as_rust(optr)) +} + +#[no_mangle] +pub unsafe extern "C" fn hll_intersection_size( + ptr: *const SourmashHyperLogLog, + optr: *const SourmashHyperLogLog, +) -> usize { + SourmashHyperLogLog::as_rust(ptr).intersection(SourmashHyperLogLog::as_rust(optr)) +} + +ffi_fn! { +unsafe fn hll_add_sequence( + ptr: *mut SourmashHyperLogLog, + sequence: *const c_char, + insize: usize, + force: bool +) -> Result<()> { + + let hll = SourmashHyperLogLog::as_rust_mut(ptr); + + let buf = { + assert!(!ptr.is_null()); + slice::from_raw_parts(sequence as *mut u8, insize) + }; + + hll.add_sequence(buf, force) +} +} + +#[no_mangle] +pub unsafe extern "C" fn hll_add_hash(ptr: *mut SourmashHyperLogLog, hash: u64) { + let hll = SourmashHyperLogLog::as_rust_mut(ptr); + hll.add_hash(hash); +} + +ffi_fn! { +unsafe fn hll_merge( + ptr: *mut SourmashHyperLogLog, + optr: *const SourmashHyperLogLog, +) { + let hll = SourmashHyperLogLog::as_rust_mut(ptr); + let ohll = SourmashHyperLogLog::as_rust(optr); + + // FIXME raise an exception properly + hll.merge(ohll)?; +} +} + +ffi_fn! { +unsafe fn hll_update_mh( + ptr: *mut SourmashHyperLogLog, + optr: *const SourmashKmerMinHash, +) { + let hll = SourmashHyperLogLog::as_rust_mut(ptr); + let mh = SourmashKmerMinHash::as_rust(optr); + + mh.update(hll)? +} +} + +#[no_mangle] +pub unsafe extern "C" fn hll_matches( + ptr: *const SourmashHyperLogLog, + mh_ptr: *const SourmashKmerMinHash, +) -> usize { + let hll = SourmashHyperLogLog::as_rust(ptr); + let mh_hll = SourmashKmerMinHash::as_rust(mh_ptr).as_hll(); + + hll.intersection(&mh_hll) +} + +ffi_fn! { +unsafe fn hll_from_path(filename: *const c_char) -> Result<*mut SourmashHyperLogLog> { + // FIXME use buffer + len instead of c_str + let c_str = { + assert!(!filename.is_null()); + + CStr::from_ptr(filename) + }; + + let (mut input, _) = niffler::from_path(c_str.to_str()?)?; + let hll = HyperLogLog::from_reader(&mut input)?; + + Ok(SourmashHyperLogLog::from_rust(hll)) +} +} + +ffi_fn! { +unsafe fn hll_from_buffer(ptr: *const c_char, insize: usize) -> Result<*mut SourmashHyperLogLog> { + // FIXME use SourmashSlice_u8? + let buf = { + assert!(!ptr.is_null()); + slice::from_raw_parts(ptr as *mut u8, insize) + }; + + let hll = HyperLogLog::from_reader(&mut &buf[..])?; + + Ok(SourmashHyperLogLog::from_rust(hll)) +} +} + +ffi_fn! { +unsafe fn hll_save(ptr: *const SourmashHyperLogLog, filename: *const c_char) -> Result<()> { + let hll = SourmashHyperLogLog::as_rust(ptr); + + // FIXME use buffer + len instead of c_str + let c_str = { + assert!(!filename.is_null()); + + CStr::from_ptr(filename) + }; + + hll.save(c_str.to_str()?)?; + + Ok(()) +} +} + +ffi_fn! { +unsafe fn hll_to_buffer(ptr: *const SourmashHyperLogLog, size: *mut usize) -> Result<*const u8> { + let hll = SourmashHyperLogLog::as_rust(ptr); + + // TODO: remove this + let compression = 1; + + let mut buffer = vec![]; + { + let mut writer = if compression > 0 { + let level = match compression { + 1 => niffler::compression::Level::One, + 2 => niffler::compression::Level::Two, + 3 => niffler::compression::Level::Three, + 4 => niffler::compression::Level::Four, + 5 => niffler::compression::Level::Five, + 6 => niffler::compression::Level::Six, + 7 => niffler::compression::Level::Seven, + 8 => niffler::compression::Level::Eight, + _ => niffler::compression::Level::Nine, + }; + + niffler::get_writer(Box::new(&mut buffer), + niffler::compression::Format::Gzip, + level)? + } else { + Box::new(&mut buffer) + }; + hll.save_to_writer(&mut writer)?; + } + + let b = buffer.into_boxed_slice(); + *size = b.len(); + + // FIXME use SourmashSlice_u8? + Ok(Box::into_raw(b) as *const u8) +} +} diff --git a/src/core/src/ffi/minhash.rs b/src/core/src/ffi/minhash.rs index 980cb7b258..c0812cacda 100644 --- a/src/core/src/ffi/minhash.rs +++ b/src/core/src/ffi/minhash.rs @@ -2,11 +2,10 @@ use std::ffi::CStr; use std::os::raw::c_char; use std::slice; +use crate::encodings::{aa_to_dayhoff, aa_to_hp, translate_codon, HashFunctions}; use crate::ffi::utils::{ForeignObject, SourmashStr}; use crate::signature::SigsTrait; -use crate::sketch::minhash::{ - aa_to_dayhoff, aa_to_hp, translate_codon, HashFunctions, KmerMinHash, -}; +use crate::sketch::minhash::KmerMinHash; pub struct SourmashKmerMinHash; diff --git a/src/core/src/ffi/mod.rs b/src/core/src/ffi/mod.rs index 5fd05a047d..bfd9b46bd7 100644 --- a/src/core/src/ffi/mod.rs +++ b/src/core/src/ffi/mod.rs @@ -7,6 +7,7 @@ pub mod utils; pub mod cmd; +pub mod hyperloglog; pub mod minhash; pub mod nodegraph; pub mod signature; diff --git a/src/core/src/ffi/signature.rs b/src/core/src/ffi/signature.rs index 8a30737462..59859372aa 100644 --- a/src/core/src/ffi/signature.rs +++ b/src/core/src/ffi/signature.rs @@ -4,8 +4,8 @@ use std::io; use std::os::raw::c_char; use std::slice; +use crate::encodings::HashFunctions; use crate::signature::Signature; -use crate::sketch::minhash::HashFunctions; use crate::sketch::Sketch; use crate::ffi::cmd::compute::SourmashComputeParameters; diff --git a/src/core/src/ffi/utils.rs b/src/core/src/ffi/utils.rs index 140058e8c9..2d1ebd3ccd 100644 --- a/src/core/src/ffi/utils.rs +++ b/src/core/src/ffi/utils.rs @@ -9,7 +9,6 @@ use std::slice; use std::str; use std::thread; -use failure::Fail; // can remove after .backtrace() is available in error... use thiserror::Error; use crate::errors::SourmashErrorCode; @@ -111,6 +110,7 @@ pub unsafe extern "C" fn sourmash_err_get_last_message() -> SourmashStr { /// Returns the panic information as string. #[no_mangle] pub unsafe extern "C" fn sourmash_err_get_backtrace() -> SourmashStr { + /* TODO: bring back when backtrace is available in std::error LAST_ERROR.with(|e| { if let Some(ref error) = *e.borrow() { if let Some(backtrace) = error.backtrace() { @@ -125,6 +125,8 @@ pub unsafe extern "C" fn sourmash_err_get_backtrace() -> SourmashStr { Default::default() } }) + */ + SourmashStr::default() } /// Clears the last error. diff --git a/src/core/src/from.rs b/src/core/src/from.rs index e2405f3594..b0bb7a0974 100644 --- a/src/core/src/from.rs +++ b/src/core/src/from.rs @@ -1,7 +1,8 @@ use finch::sketch_schemes::mash::MashSketcher; use finch::sketch_schemes::SketchScheme; -use crate::sketch::minhash::{HashFunctions, KmerMinHash}; +use crate::encodings::HashFunctions; +use crate::sketch::minhash::KmerMinHash; /* TODO: @@ -41,8 +42,9 @@ mod test { use std::collections::HashSet; use std::iter::FromIterator; + use crate::encodings::HashFunctions; use crate::signature::SigsTrait; - use crate::sketch::minhash::{HashFunctions, KmerMinHash}; + use crate::sketch::minhash::KmerMinHash; use finch::sketch_schemes::mash::MashSketcher; use needletail::kmer::CanonicalKmers; diff --git a/src/core/src/lib.rs b/src/core/src/lib.rs index 7f9a7a0efc..07292ddae1 100644 --- a/src/core/src/lib.rs +++ b/src/core/src/lib.rs @@ -26,6 +26,8 @@ pub mod index; pub mod signature; pub mod sketch; +pub mod encodings; + #[cfg(feature = "from-finch")] pub mod from; diff --git a/src/core/src/signature.rs b/src/core/src/signature.rs index dbf21ebacf..ac0f32cab9 100644 --- a/src/core/src/signature.rs +++ b/src/core/src/signature.rs @@ -17,51 +17,199 @@ use typed_builder::TypedBuilder; #[cfg(all(target_arch = "wasm32", target_vendor = "unknown"))] use wasm_bindgen::prelude::*; +use crate::encodings::{aa_to_dayhoff, aa_to_hp, revcomp, to_aa, HashFunctions, VALID}; use crate::index::storage::ToWriter; -use crate::sketch::minhash::HashFunctions; use crate::sketch::Sketch; use crate::Error; +use crate::HashIntoType; pub trait SigsTrait { fn size(&self) -> usize; fn to_vec(&self) -> Vec; - fn check_compatible(&self, other: &Self) -> Result<(), Error>; - fn add_sequence(&mut self, seq: &[u8], _force: bool) -> Result<(), Error>; - fn add_protein(&mut self, seq: &[u8]) -> Result<(), Error>; fn ksize(&self) -> usize; + fn check_compatible(&self, other: &Self) -> Result<(), Error>; + fn seed(&self) -> u64; + + fn hash_function(&self) -> HashFunctions; + + fn add_hash(&mut self, hash: HashIntoType); + fn add_sequence(&mut self, seq: &[u8], force: bool) -> Result<(), Error> { + let ksize = self.ksize() as usize; + let len = seq.len(); + let hash_function = self.hash_function(); + + if len < ksize { + return Ok(()); + }; + + // Here we convert the sequence to upper case and + // pre-calculate the reverse complement for the full sequence... + let sequence = seq.to_ascii_uppercase(); + let rc = revcomp(&sequence); + + if hash_function.dna() { + let mut last_position_check = 0; + + let mut is_valid_kmer = |i| { + for j in std::cmp::max(i, last_position_check)..i + ksize { + if !VALID[sequence[j] as usize] { + return false; + } + last_position_check += 1; + } + true + }; + + for i in 0..=len - ksize { + // ... and then while moving the k-mer window forward for the sequence + // we move another window backwards for the RC. + // For a ksize = 3, and a sequence AGTCGT (len = 6): + // +-+---------+---------------+-------+ + // seq RC |i|i + ksize|len - ksize - i|len - i| + // AGTCGT ACGACT +-+---------+---------------+-------+ + // +-> +-> |0| 2 | 3 | 6 | + // +-> +-> |1| 3 | 2 | 5 | + // +-> +-> |2| 4 | 1 | 4 | + // +-> +-> |3| 5 | 0 | 3 | + // +-+---------+---------------+-------+ + // (leaving this table here because I had to draw to + // get the indices correctly) + + let kmer = &sequence[i..i + ksize]; + + if !is_valid_kmer(i) { + if !force { + // throw error if DNA is not valid + return Err(Error::InvalidDNA { + message: String::from_utf8(kmer.to_vec()).unwrap(), + }); + } + + continue; // skip invalid k-mer + } + + let krc = &rc[len - ksize - i..len - i]; + let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed()); + self.add_hash(hash); + } + } else { + // protein + let aa_ksize = self.ksize() / 3; + + for i in 0..3 { + let substr: Vec = sequence + .iter() + .cloned() + .skip(i) + .take(sequence.len() - i) + .collect(); + let aa = to_aa(&substr, hash_function.dayhoff(), hash_function.hp()).unwrap(); + + aa.windows(aa_ksize as usize).for_each(|n| { + let hash = crate::_hash_murmur(n, self.seed()); + self.add_hash(hash); + }); + + let rc_substr: Vec = rc.iter().cloned().skip(i).take(rc.len() - i).collect(); + let aa_rc = to_aa(&rc_substr, hash_function.dayhoff(), hash_function.hp()).unwrap(); + + aa_rc.windows(aa_ksize as usize).for_each(|n| { + let hash = crate::_hash_murmur(n, self.seed()); + self.add_hash(hash); + }); + } + } + + Ok(()) + } + + fn add_protein(&mut self, seq: &[u8]) -> Result<(), Error> { + let ksize = (self.ksize() / 3) as usize; + let len = seq.len(); + let hash_function = self.hash_function(); + + if len < ksize { + return Ok(()); + } + + if hash_function.protein() { + for aa_kmer in seq.windows(ksize) { + let hash = crate::_hash_murmur(&aa_kmer, self.seed()); + self.add_hash(hash); + } + return Ok(()); + } + + let aa_seq: Vec<_> = match hash_function { + HashFunctions::murmur64_dayhoff => seq.iter().cloned().map(aa_to_dayhoff).collect(), + HashFunctions::murmur64_hp => seq.iter().cloned().map(aa_to_hp).collect(), + invalid => { + return Err(Error::InvalidHashFunction { + function: format!("{}", invalid), + }) + } + }; + + for aa_kmer in aa_seq.windows(ksize) { + let hash = crate::_hash_murmur(&aa_kmer, self.seed()); + self.add_hash(hash); + } + + Ok(()) + } } impl SigsTrait for Sketch { fn size(&self) -> usize { match *self { - Sketch::UKHS(ref ukhs) => ukhs.size(), Sketch::MinHash(ref mh) => mh.size(), Sketch::LargeMinHash(ref mh) => mh.size(), + Sketch::HyperLogLog(ref hll) => hll.size(), } } fn to_vec(&self) -> Vec { match *self { - Sketch::UKHS(ref ukhs) => ukhs.to_vec(), Sketch::MinHash(ref mh) => mh.to_vec(), Sketch::LargeMinHash(ref mh) => mh.to_vec(), + Sketch::HyperLogLog(ref hll) => hll.to_vec(), } } fn ksize(&self) -> usize { match *self { - Sketch::UKHS(ref ukhs) => ukhs.ksize(), Sketch::MinHash(ref mh) => mh.ksize(), Sketch::LargeMinHash(ref mh) => mh.ksize(), + Sketch::HyperLogLog(ref hll) => hll.ksize(), + } + } + + fn seed(&self) -> u64 { + match *self { + Sketch::MinHash(ref mh) => mh.seed(), + Sketch::LargeMinHash(ref mh) => mh.seed(), + Sketch::HyperLogLog(ref hll) => hll.seed(), + } + } + + fn hash_function(&self) -> HashFunctions { + match *self { + Sketch::MinHash(ref mh) => mh.hash_function(), + Sketch::LargeMinHash(ref mh) => mh.hash_function(), + Sketch::HyperLogLog(ref hll) => hll.hash_function(), + } + } + + fn add_hash(&mut self, hash: HashIntoType) { + match *self { + Sketch::MinHash(ref mut mh) => mh.add_hash(hash), + Sketch::LargeMinHash(ref mut mh) => mh.add_hash(hash), + Sketch::HyperLogLog(ref mut hll) => hll.add_hash(hash), } } fn check_compatible(&self, other: &Self) -> Result<(), Error> { match *self { - Sketch::UKHS(ref ukhs) => match other { - Sketch::UKHS(ref ot) => ukhs.check_compatible(ot), - _ => Err(Error::MismatchSignatureType), - }, Sketch::MinHash(ref mh) => match other { Sketch::MinHash(ref ot) => mh.check_compatible(ot), _ => Err(Error::MismatchSignatureType), @@ -70,6 +218,10 @@ impl SigsTrait for Sketch { Sketch::LargeMinHash(ref ot) => mh.check_compatible(ot), _ => Err(Error::MismatchSignatureType), }, + Sketch::HyperLogLog(ref hll) => match other { + Sketch::HyperLogLog(ref ot) => hll.check_compatible(ot), + _ => Err(Error::MismatchSignatureType), + }, } } @@ -77,7 +229,7 @@ impl SigsTrait for Sketch { match *self { Sketch::MinHash(ref mut mh) => mh.add_sequence(seq, force), Sketch::LargeMinHash(ref mut mh) => mh.add_sequence(seq, force), - Sketch::UKHS(_) => unimplemented!(), + Sketch::HyperLogLog(_) => unimplemented!(), } } @@ -85,7 +237,7 @@ impl SigsTrait for Sketch { match *self { Sketch::MinHash(ref mut mh) => mh.add_protein(seq), Sketch::LargeMinHash(ref mut mh) => mh.add_protein(seq), - Sketch::UKHS(_) => unimplemented!(), + Sketch::HyperLogLog(_) => unimplemented!(), } } } @@ -197,7 +349,7 @@ impl Signature { match &self.signatures[0] { Sketch::MinHash(mh) => mh.md5sum(), Sketch::LargeMinHash(mh) => mh.md5sum(), - Sketch::UKHS(hs) => hs.md5sum(), + Sketch::HyperLogLog(_) => unimplemented!(), } } else { // TODO: select the correct signature @@ -297,25 +449,7 @@ impl Signature { None => return true, // TODO: match previous behavior }; } - Sketch::UKHS(hs) => { - if let Some(k) = ksize { - if k != hs.ksize() as usize { - return false; - } - }; - - match moltype { - Some(x) => { - if x == HashFunctions::murmur64_DNA { - return true; - } else { - // TODO: draff only supports dna for now - unimplemented!() - } - } - None => unimplemented!(), - }; - } + Sketch::HyperLogLog(_) => unimplemented!(), }; false }) diff --git a/src/core/src/sketch/hyperloglog/estimators.rs b/src/core/src/sketch/hyperloglog/estimators.rs new file mode 100644 index 0000000000..4c2fbe02cc --- /dev/null +++ b/src/core/src/sketch/hyperloglog/estimators.rs @@ -0,0 +1,178 @@ +use std::cmp; + +pub type CounterType = u8; + +pub fn counts(registers: &[CounterType], q: usize) -> Vec { + let mut counts = vec![0; q + 2]; + + for k in registers { + counts[*k as usize] += 1; + } + + counts +} + +#[allow(clippy::many_single_char_names)] +pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { + let m = 1 << p; + if counts[q + 1] == m { + return std::f64::INFINITY; + } + + let (k_min, _) = counts.iter().enumerate().find(|(_, v)| **v != 0).unwrap(); + let k_min_prime = cmp::max(1, k_min); + + let (k_max, _) = counts + .iter() + .enumerate() + .rev() + .find(|(_, v)| **v != 0) + .unwrap(); + let k_max_prime = cmp::min(q, k_max as usize); + + let mut z = 0.; + for i in num_iter::range_step_inclusive(k_max_prime as i32, k_min_prime as i32, -1) { + z = 0.5 * z + counts[i as usize] as f64; + } + + // ldexp(x, i) = x * (2 ** i) + z *= 2f64.powi(-(k_min_prime as i32)); + + let mut c_prime = counts[q + 1]; + if q >= 1 { + c_prime += counts[k_max_prime]; + } + + let mut g_prev = 0.; + let a = z + (counts[0] as f64); + let b = z + (counts[q + 1] as f64) * 2f64.powi(-(q as i32)); + let m_prime = (m - counts[0]) as f64; + + let mut x = if b <= 1.5 * a { + // weak lower bound (47) + m_prime / (0.5 * b + a) + } else { + // strong lower bound (46) + m_prime / (b * (1. + b / a).ln()) + }; + + let mut delta_x = x; + let del = relerr / (m as f64).sqrt(); + while delta_x > x * del { + // secant method iteration + + let kappa: usize = az::saturating_cast(2. + x.log2().floor()); + + // x_prime in [0, 0.25] + let mut x_prime = x * 2f64.powi(-(cmp::max(k_max_prime, kappa) as i32) - 1); + let x_pp = x_prime * x_prime; + + // Taylor approximation (58) + let mut h = x_prime - (x_pp / 3.) + (x_pp * x_pp) * (1. / 45. - x_pp / 472.5); + + // Calculate h(x/2^k), see (56), at this point x_prime = x / (2^(k+2)) + for _k in num_iter::range_step_inclusive(kappa as i32 - 1, k_max_prime as i32, -1) { + let h_prime = 1. - h; + h = (x_prime + h * h_prime) / (x_prime + h_prime); + x_prime += x_prime; + } + + // compare (53) + let mut g = c_prime as f64 * h; + + for k in num_iter::range_step_inclusive(k_max_prime as i32 - 1, k_min_prime as i32, -1) { + let h_prime = 1. - h; + // Calculate h(x/2^k), see (56), at this point x_prime = x / (2^(k+2)) + h = (x_prime + h * h_prime) / (x_prime + h_prime); + g += counts[k as usize] as f64 * h; + x_prime += x_prime; + } + + g += x * a; + delta_x = if (g > g_prev) | (m_prime >= g) { + // see (54) + delta_x * (m_prime - g) / (g - g_prev) + } else { + 0. + }; + + x += delta_x; + g_prev = g + } + + m as f64 * x +} + +/// Calculate the joint maximum likelihood of A and B. +/// +/// Returns a tuple (only in A, only in B, intersection) +pub fn joint_mle( + k1: &[CounterType], + k2: &[CounterType], + p: usize, + q: usize, +) -> (usize, usize, usize) { + let mut c1 = vec![0; q + 2]; + let mut c2 = vec![0; q + 2]; + let mut cu = vec![0; q + 2]; + let mut cg1 = vec![0; q + 2]; + let mut cg2 = vec![0; q + 2]; + let mut ceq = vec![0; q + 2]; + + for (k1_, k2_) in k1.iter().zip(k2.iter()) { + match k1_.cmp(&k2_) { + cmp::Ordering::Less => { + c1[*k1_ as usize] += 1; + cg2[*k2_ as usize] += 1; + } + cmp::Ordering::Greater => { + cg1[*k1_ as usize] += 1; + c2[*k2_ as usize] += 1; + } + cmp::Ordering::Equal => { + ceq[*k1_ as usize] += 1; + } + } + cu[*cmp::max(k1_, k2_) as usize] += 1; + } + + for (i, (v, u)) in cg1.iter().zip(ceq.iter()).enumerate() { + c1[i] += v + u; + } + + for (i, (v, u)) in cg2.iter().zip(ceq.iter()).enumerate() { + c2[i] += v + u; + } + + let c_ax = mle(&c1, p, q, 0.01); + let c_bx = mle(&c2, p, q, 0.01); + let c_abx = mle(&cu, p, q, 0.01); + + let mut counts_axb_half = vec![0u16; q + 2]; + let mut counts_bxa_half = vec![0u16; q + 2]; + + counts_axb_half[q] = k1.len() as u16; + counts_bxa_half[q] = k2.len() as u16; + + for _q in 0..q { + counts_axb_half[_q] = cg1[_q] + ceq[_q] + cg2[_q + 1]; + debug_assert!(counts_axb_half[q] >= counts_axb_half[_q]); + counts_axb_half[q] -= counts_axb_half[_q]; + + counts_bxa_half[_q] = cg2[_q] + ceq[_q] + cg1[_q + 1]; + debug_assert!(counts_bxa_half[q] >= counts_bxa_half[_q]); + counts_bxa_half[q] -= counts_bxa_half[_q]; + } + + let c_axb_half = mle(&counts_axb_half, p, q - 1, 0.01); + let c_bxa_half = mle(&counts_bxa_half, p, q - 1, 0.01); + + let cx1 = 1.5 * c_bx + 1.5 * c_ax - c_bxa_half - c_axb_half; + let cx2 = 2. * (c_bxa_half + c_axb_half) - 3. * c_abx; + + ( + (c_abx - c_bx) as usize, + (c_abx - c_ax) as usize, + cmp::max(0, (0.5 * (cx1 + cx2)) as usize), + ) +} diff --git a/src/core/src/sketch/hyperloglog/mod.rs b/src/core/src/sketch/hyperloglog/mod.rs new file mode 100644 index 0000000000..2d2429fae8 --- /dev/null +++ b/src/core/src/sketch/hyperloglog/mod.rs @@ -0,0 +1,373 @@ +/* +Based on the HyperLogLog implementations in khmer + https://github.com/dib-lab/khmer/blob/fb65d21eaedf0d397d49ae3debc578897f9d6eb4/src/oxli/hllcounter.cc +using the maximum likelihood estimators from + https://oertl.github.io/hyperloglog-sketch-estimation-paper/paper/paper.pdf +first implemented for genomics in dashing + https://genomebiology.biomedcentral.com/articles/10.1186/s13059-019-1875-0 +*/ + +use std::cmp; +use std::fs::File; +use std::io; +use std::path::Path; + +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use serde::{Deserialize, Serialize}; + +use crate::encodings::HashFunctions; +use crate::index::sbt::Update; +use crate::signature::SigsTrait; +use crate::sketch::KmerMinHash; +use crate::Error; +use crate::HashIntoType; + +pub mod estimators; +use estimators::CounterType; + +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] +pub struct HyperLogLog { + registers: Vec, + p: usize, + q: usize, + ksize: usize, +} + +impl HyperLogLog { + pub fn with_error_rate(error_rate: f64, ksize: usize) -> Result { + let p = f64::ceil(f64::log2(f64::powi(1.04 / error_rate, 2))); + HyperLogLog::new(p as usize, ksize) + } + + pub fn new(p: usize, ksize: usize) -> Result { + if p < 4 || p > 18 { + return Err(Error::HLLPrecisionBounds); + } + + let size = (1 as usize) << p; + let registers = vec![0; size]; + + Ok(HyperLogLog { + registers, + ksize, + p, + q: 64 - p, // FIXME: allow setting q explicitly + }) + } + + pub fn merge(&mut self, other: &HyperLogLog) -> Result<(), Error> { + self.check_compatible(other)?; + self.registers + .iter_mut() + .zip(other.registers.iter()) + .for_each(|(a, b)| *a = cmp::max(*a, *b)); + Ok(()) + } + + pub fn add_word(&mut self, word: &[u8]) { + let hash = crate::_hash_murmur(word, 42); // TODO: decide on seed + self.add_hash(hash); + } + + pub fn add_many(&mut self, hashes: &[HashIntoType]) -> Result<(), Error> { + for min in hashes { + self.add_hash(*min); + } + Ok(()) + } + + pub fn cardinality(&self) -> usize { + let counts = estimators::counts(&self.registers, self.q); + + estimators::mle(&counts, self.p, self.q, 0.01) as usize + } + + pub fn similarity(&self, other: &HyperLogLog) -> f64 { + let (only_a, only_b, intersection) = + estimators::joint_mle(&self.registers, &other.registers, self.p, self.q); + + intersection as f64 / (only_a + only_b + intersection) as f64 + } + + pub fn containment(&self, other: &HyperLogLog) -> f64 { + let (only_a, _, intersection) = + estimators::joint_mle(&self.registers, &other.registers, self.p, self.q); + + intersection as f64 / (only_a + intersection) as f64 + } + + pub fn intersection(&self, other: &HyperLogLog) -> usize { + let (_, _, intersection) = + estimators::joint_mle(&self.registers, &other.registers, self.p, self.q); + + intersection + } + + // save + pub fn save>(&self, path: P) -> Result<(), Error> { + // TODO: if it ends with gz, open a compressed file + // might use get_output here? + self.save_to_writer(&mut File::create(path)?)?; + Ok(()) + } + + pub fn save_to_writer(&self, wtr: &mut W) -> Result<(), Error> + where + W: io::Write, + { + wtr.write_all(b"HLL")?; + wtr.write_u8(1)?; // version + wtr.write_u8(self.p as u8)?; // number of bits used for indexing + wtr.write_u8(self.q as u8)?; // number of bits used for counting leading zeroes + wtr.write_u8(self.ksize as u8)?; // ksize + wtr.write_all(&self.registers.as_slice())?; + + Ok(()) + } + + pub fn from_reader(rdr: R) -> Result + where + R: io::Read, + { + let (mut rdr, _format) = niffler::get_reader(Box::new(rdr))?; + + let signature = rdr.read_u24::()?; + assert_eq!(signature, 0x484c4c); + + let version = rdr.read_u8()?; + assert_eq!(version, 1); + + let p = rdr.read_u8()? as usize; + let q = rdr.read_u8()? as usize; + + let ksize = rdr.read_u8()? as usize; + let n_registers = 1 << p; + + let mut registers = vec![0u8; n_registers]; + rdr.read_exact(&mut registers)?; + + Ok(HyperLogLog { + p, + q, + ksize, + registers, + }) + } + + pub fn from_path>(path: P) -> Result { + let mut reader = io::BufReader::new(File::open(path)?); + Ok(HyperLogLog::from_reader(&mut reader)?) + } +} + +impl SigsTrait for HyperLogLog { + fn size(&self) -> usize { + self.registers.len() + } + + fn to_vec(&self) -> Vec { + self.registers.iter().map(|x| *x as u64).collect() + } + + fn ksize(&self) -> usize { + self.ksize as usize + } + + fn seed(&self) -> u64 { + // TODO: support other seeds + 42 + } + + fn hash_function(&self) -> HashFunctions { + //TODO support other hash functions + HashFunctions::murmur64_DNA + } + + fn add_hash(&mut self, hash: HashIntoType) { + let value = hash >> self.p; + let index = (hash - (value << self.p)) as usize; + + let leftmost = value.leading_zeros() + 1 - (self.p as u32); + + let old_value = self.registers[index]; + self.registers[index] = cmp::max(old_value, leftmost as CounterType); + } + + fn check_compatible(&self, other: &HyperLogLog) -> Result<(), Error> { + if self.ksize() != other.ksize() { + Err(Error::MismatchKSizes) + } else if self.size() != other.size() { + // TODO: create new error + Err(Error::MismatchNum { + n1: self.size() as u32, + n2: other.size() as u32, + }) + } else { + Ok(()) + } + } +} + +impl Update for KmerMinHash { + fn update(&self, other: &mut HyperLogLog) -> Result<(), Error> { + for h in self.mins() { + other.add_hash(h); + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + use std::io::{BufReader, BufWriter, Read}; + use std::path::PathBuf; + + use crate::signature::SigsTrait; + use needletail::{parse_fastx_file, parse_fastx_reader, Sequence}; + + use super::HyperLogLog; + + // TODO: pull more tests from khmer HLL + + #[test] + fn hll_add() { + const ERR_RATE: f64 = 0.01; + const N_UNIQUE: usize = 3356; + const KSIZE: u8 = 21; + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/ecoli.genes.fna"); + + let mut hll = HyperLogLog::with_error_rate(ERR_RATE, KSIZE as usize).unwrap(); + let mut counter: HashSet> = HashSet::new(); + + let mut parser = parse_fastx_file(filename).unwrap(); + while let Some(record) = parser.next() { + let record = record.unwrap(); + let norm_seq = record.normalize(false); + let rc = norm_seq.reverse_complement(); + + hll.add_sequence(&norm_seq, false).unwrap(); + for (_, kmer, _) in norm_seq.canonical_kmers(KSIZE, &rc) { + counter.insert(kmer.into()); + } + } + + assert_eq!(counter.len(), N_UNIQUE); + + let abs_error = (1. - (hll.cardinality() as f64 / N_UNIQUE as f64)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + } + + #[test] + fn hll_joint_mle() { + const ERR_RATE: f64 = 0.01; + const KSIZE: u8 = 21; + + const N_UNIQUE_H1: usize = 500741; + const N_UNIQUE_H2: usize = 995845; + const N_UNIQUE_U: usize = 995845; + + const SIMILARITY: f64 = 0.502783; + const CONTAINMENT_H1: f64 = 1.; + const CONTAINMENT_H2: f64 = 0.502783; + + const INTERSECTION: usize = 500838; + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/genome-s10.fa.gz"); + + let mut hll1 = HyperLogLog::with_error_rate(ERR_RATE, KSIZE as usize).unwrap(); + let mut hll2 = HyperLogLog::with_error_rate(ERR_RATE, KSIZE as usize).unwrap(); + let mut hllu = HyperLogLog::with_error_rate(ERR_RATE, KSIZE as usize).unwrap(); + + let mut buf = vec![]; + let (mut reader, _) = niffler::from_path(filename).unwrap(); + reader.read_to_end(&mut buf).unwrap(); + + let mut parser = parse_fastx_reader(&buf[..]).unwrap(); + while let Some(record) = parser.next() { + let record = record.unwrap(); + let norm_seq = record.normalize(false); + + hll1.add_sequence(&norm_seq, false).unwrap(); + hllu.add_sequence(&norm_seq, false).unwrap(); + } + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/genome-s10+s11.fa.gz"); + + let mut buf = vec![]; + let (mut reader, _) = niffler::from_path(filename).unwrap(); + reader.read_to_end(&mut buf).unwrap(); + + let mut parser = parse_fastx_reader(&buf[..]).unwrap(); + while let Some(record) = parser.next() { + let record = record.unwrap(); + let norm_seq = record.normalize(false); + + hll2.add_sequence(&norm_seq, false).unwrap(); + hllu.add_sequence(&norm_seq, false).unwrap(); + } + + let abs_error = (1. - (hll1.cardinality() as f64 / N_UNIQUE_H1 as f64)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + + let abs_error = (1. - (hll2.cardinality() as f64 / N_UNIQUE_H2 as f64)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + + let similarity = hll1.similarity(&hll2); + let abs_error = (1. - (similarity / SIMILARITY as f64)).abs(); + assert!(abs_error < ERR_RATE, "{} {}", similarity, SIMILARITY); + + let containment = hll1.containment(&hll2); + let abs_error = (1. - (containment / CONTAINMENT_H1 as f64)).abs(); + assert!(abs_error < ERR_RATE, "{} {}", containment, CONTAINMENT_H1); + + let containment = hll2.containment(&hll1); + let abs_error = (1. - (containment / CONTAINMENT_H2 as f64)).abs(); + assert!(abs_error < ERR_RATE, "{} {}", containment, CONTAINMENT_H2); + + let intersection = hll1.intersection(&hll2) as f64; + let abs_error = (1. - (intersection / INTERSECTION as f64)).abs(); + assert!(abs_error < ERR_RATE, "{} {}", intersection, INTERSECTION); + + hll1.merge(&hll2).unwrap(); + + let abs_error = (1. - (hllu.similarity(&hll1) as f64 / 1.)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + + let abs_error = (1. - (hllu.containment(&hll1) as f64 / 1.)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + + let abs_error = (1. - (hll1.containment(&hllu) as f64 / 1.)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + + let intersection = hll1.intersection(&hllu) as f64; + let abs_error = (1. - (intersection / N_UNIQUE_U as f64)).abs(); + assert!(abs_error < ERR_RATE, "{} {}", intersection, N_UNIQUE_U); + } + + #[test] + fn save_load_hll() { + let mut hll = HyperLogLog::with_error_rate(0.01, 1).expect("error building HLL"); + for i in 1..5000 { + hll.add_hash(i) + } + + let mut buf = Vec::new(); + { + let mut writer = BufWriter::new(&mut buf); + hll.save_to_writer(&mut writer).unwrap(); + } + + let mut reader = BufReader::new(&buf[..]); + let hll_new: HyperLogLog = HyperLogLog::from_reader(&mut reader).expect("Loading error"); + + assert_eq!(hll_new.p, hll.p); + assert_eq!(hll_new.q, hll.q); + assert_eq!(hll_new.registers, hll.registers); + assert_eq!(hll_new.ksize, hll.ksize); + } +} diff --git a/src/core/src/sketch/minhash.rs b/src/core/src/sketch/minhash.rs index 04626210b2..0aea151025 100644 --- a/src/core/src/sketch/minhash.rs +++ b/src/core/src/sketch/minhash.rs @@ -1,65 +1,25 @@ use std::cmp::Ordering; -use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::convert::TryFrom; +use std::collections::{BTreeMap, BTreeSet}; use std::f64::consts::PI; use std::fmt::Write; use std::iter::{Iterator, Peekable}; use std::str; use std::sync::Mutex; -use once_cell::sync::Lazy; use serde::de::Deserializer; use serde::ser::{SerializeStruct, Serializer}; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use crate::_hash_murmur; +use crate::encodings::HashFunctions; use crate::signature::SigsTrait; +use crate::sketch::hyperloglog::HyperLogLog; use crate::Error; #[cfg(all(target_arch = "wasm32", target_vendor = "unknown"))] use wasm_bindgen::prelude::*; -#[cfg_attr(all(target_arch = "wasm32", target_vendor = "unknown"), wasm_bindgen)] -#[allow(non_camel_case_types)] -#[derive(Debug, Clone, Copy, PartialEq)] -#[repr(u32)] -pub enum HashFunctions { - murmur64_DNA = 1, - murmur64_protein = 2, - murmur64_dayhoff = 3, - murmur64_hp = 4, -} - -impl std::fmt::Display for HashFunctions { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "{}", - match self { - HashFunctions::murmur64_DNA => "dna", - HashFunctions::murmur64_protein => "protein", - HashFunctions::murmur64_dayhoff => "dayhoff", - HashFunctions::murmur64_hp => "hp", - } - ) - } -} - -impl TryFrom<&str> for HashFunctions { - type Error = Error; - - fn try_from(moltype: &str) -> Result { - match moltype.to_lowercase().as_ref() { - "dna" => Ok(HashFunctions::murmur64_DNA), - "dayhoff" => Ok(HashFunctions::murmur64_dayhoff), - "hp" => Ok(HashFunctions::murmur64_hp), - "protein" => Ok(HashFunctions::murmur64_protein), - _ => unimplemented!(), - } - } -} - pub fn max_hash_for_scaled(scaled: u64) -> u64 { match scaled { 0 => 0, @@ -266,14 +226,6 @@ impl KmerMinHash { self.hash_function == HashFunctions::murmur64_protein } - fn is_dna(&self) -> bool { - self.hash_function == HashFunctions::murmur64_DNA - } - - pub fn seed(&self) -> u64 { - self.seed - } - pub fn max_hash(&self) -> u64 { self.max_hash } @@ -650,6 +602,7 @@ impl KmerMinHash { Ok((common, combined_mh.mins.len() as u64)) } + // FIXME: intersection_size and count_common should be the same? pub fn intersection_size(&self, other: &KmerMinHash) -> Result<(u64, u64), Error> { self.check_compatible(other)?; @@ -764,10 +717,6 @@ impl KmerMinHash { self.hash_function == HashFunctions::murmur64_hp } - pub fn hash_function(&self) -> HashFunctions { - self.hash_function - } - pub fn mins(&self) -> Vec { self.mins.clone() } @@ -815,6 +764,16 @@ impl KmerMinHash { .collect() } } + + pub fn as_hll(&self) -> HyperLogLog { + let mut hll = HyperLogLog::with_error_rate(0.01, self.ksize()).unwrap(); + + for h in &self.mins { + hll.add_hash(*h) + } + + hll + } } impl SigsTrait for KmerMinHash { @@ -830,6 +789,18 @@ impl SigsTrait for KmerMinHash { self.ksize as usize } + fn seed(&self) -> u64 { + self.seed + } + + fn hash_function(&self) -> HashFunctions { + self.hash_function + } + + fn add_hash(&mut self, hash: u64) { + self.add_hash_with_abundance(hash, 1); + } + fn check_compatible(&self, other: &KmerMinHash) -> Result<(), Error> { /* if self.num != other.num { @@ -855,122 +826,6 @@ impl SigsTrait for KmerMinHash { } Ok(()) } - - fn add_sequence(&mut self, seq: &[u8], force: bool) -> Result<(), Error> { - let ksize = self.ksize as usize; - let len = seq.len(); - - if len < ksize { - return Ok(()); - }; - - // Here we convert the sequence to upper case and - // pre-calculate the reverse complement for the full sequence... - let sequence = seq.to_ascii_uppercase(); - let rc = revcomp(&sequence); - - if self.is_dna() { - let mut last_position_check = 0; - - let mut is_valid_kmer = |i| { - for j in std::cmp::max(i, last_position_check)..i + ksize { - if !VALID[sequence[j] as usize] { - return false; - } - last_position_check += 1; - } - true - }; - - for i in 0..=len - ksize { - // ... and then while moving the k-mer window forward for the sequence - // we move another window backwards for the RC. - // For a ksize = 3, and a sequence AGTCGT (len = 6): - // +-+---------+---------------+-------+ - // seq RC |i|i + ksize|len - ksize - i|len - i| - // AGTCGT ACGACT +-+---------+---------------+-------+ - // +-> +-> |0| 2 | 3 | 6 | - // +-> +-> |1| 3 | 2 | 5 | - // +-> +-> |2| 4 | 1 | 4 | - // +-> +-> |3| 5 | 0 | 3 | - // +-+---------+---------------+-------+ - // (leaving this table here because I had to draw to - // get the indices correctly) - - let kmer = &sequence[i..i + ksize]; - - if !is_valid_kmer(i) { - if !force { - // throw error if DNA is not valid - return Err(Error::InvalidDNA { - message: String::from_utf8(kmer.to_vec()).unwrap(), - }); - } - - continue; // skip invalid k-mer - } - - let krc = &rc[len - ksize - i..len - i]; - self.add_word(std::cmp::min(kmer, krc)); - } - } else { - // protein - let aa_ksize = self.ksize / 3; - - for i in 0..3 { - let substr: Vec = sequence - .iter() - .cloned() - .skip(i) - .take(sequence.len() - i) - .collect(); - let aa = to_aa(&substr, self.dayhoff(), self.hp()).unwrap(); - - aa.windows(aa_ksize as usize).for_each(|n| self.add_word(n)); - - let rc_substr: Vec = rc.iter().cloned().skip(i).take(rc.len() - i).collect(); - let aa_rc = to_aa(&rc_substr, self.dayhoff(), self.hp()).unwrap(); - - aa_rc - .windows(aa_ksize as usize) - .for_each(|n| self.add_word(n)); - } - } - - Ok(()) - } - - fn add_protein(&mut self, seq: &[u8]) -> Result<(), Error> { - let ksize = (self.ksize / 3) as usize; - let len = seq.len(); - - if len < ksize { - return Ok(()); - } - - if let HashFunctions::murmur64_protein = self.hash_function { - for aa_kmer in seq.windows(ksize) { - self.add_word(&aa_kmer); - } - return Ok(()); - } - - let aa_seq: Vec<_> = match self.hash_function { - HashFunctions::murmur64_dayhoff => seq.iter().cloned().map(aa_to_dayhoff).collect(), - HashFunctions::murmur64_hp => seq.iter().cloned().map(aa_to_hp).collect(), - invalid => { - return Err(Error::InvalidHashFunction { - function: format!("{}", invalid), - }) - } - }; - - for aa_kmer in aa_seq.windows(ksize) { - self.add_word(aa_kmer); - } - - Ok(()) - } } struct Intersection> { @@ -1013,296 +868,6 @@ impl> Iterator for Intersection { } } -const COMPLEMENT: [u8; 256] = { - let mut lookup = [0; 256]; - lookup[b'A' as usize] = b'T'; - lookup[b'C' as usize] = b'G'; - lookup[b'G' as usize] = b'C'; - lookup[b'T' as usize] = b'A'; - lookup[b'N' as usize] = b'N'; - lookup -}; - -#[inline] -fn revcomp(seq: &[u8]) -> Vec { - seq.iter() - .rev() - .map(|nt| COMPLEMENT[*nt as usize]) - .collect() -} - -static CODONTABLE: Lazy> = Lazy::new(|| { - [ - // F - ("TTT", b'F'), - ("TTC", b'F'), - // L - ("TTA", b'L'), - ("TTG", b'L'), - // S - ("TCT", b'S'), - ("TCC", b'S'), - ("TCA", b'S'), - ("TCG", b'S'), - ("TCN", b'S'), - // Y - ("TAT", b'Y'), - ("TAC", b'Y'), - // * - ("TAA", b'*'), - ("TAG", b'*'), - // * - ("TGA", b'*'), - // C - ("TGT", b'C'), - ("TGC", b'C'), - // W - ("TGG", b'W'), - // L - ("CTT", b'L'), - ("CTC", b'L'), - ("CTA", b'L'), - ("CTG", b'L'), - ("CTN", b'L'), - // P - ("CCT", b'P'), - ("CCC", b'P'), - ("CCA", b'P'), - ("CCG", b'P'), - ("CCN", b'P'), - // H - ("CAT", b'H'), - ("CAC", b'H'), - // Q - ("CAA", b'Q'), - ("CAG", b'Q'), - // R - ("CGT", b'R'), - ("CGC", b'R'), - ("CGA", b'R'), - ("CGG", b'R'), - ("CGN", b'R'), - // I - ("ATT", b'I'), - ("ATC", b'I'), - ("ATA", b'I'), - // M - ("ATG", b'M'), - // T - ("ACT", b'T'), - ("ACC", b'T'), - ("ACA", b'T'), - ("ACG", b'T'), - ("ACN", b'T'), - // N - ("AAT", b'N'), - ("AAC", b'N'), - // K - ("AAA", b'K'), - ("AAG", b'K'), - // S - ("AGT", b'S'), - ("AGC", b'S'), - // R - ("AGA", b'R'), - ("AGG", b'R'), - // V - ("GTT", b'V'), - ("GTC", b'V'), - ("GTA", b'V'), - ("GTG", b'V'), - ("GTN", b'V'), - // A - ("GCT", b'A'), - ("GCC", b'A'), - ("GCA", b'A'), - ("GCG", b'A'), - ("GCN", b'A'), - // D - ("GAT", b'D'), - ("GAC", b'D'), - // E - ("GAA", b'E'), - ("GAG", b'E'), - // G - ("GGT", b'G'), - ("GGC", b'G'), - ("GGA", b'G'), - ("GGG", b'G'), - ("GGN", b'G'), - ] - .iter() - .cloned() - .collect() -}); - -// Dayhoff table from -// Peris, P., López, D., & Campos, M. (2008). -// IgTM: An algorithm to predict transmembrane domains and topology in -// proteins. BMC Bioinformatics, 9(1), 1029–11. -// http://doi.org/10.1186/1471-2105-9-367 -// -// Original source: -// Dayhoff M. O., Schwartz R. M., Orcutt B. C. (1978). -// A model of evolutionary change in proteins, -// in Atlas of Protein Sequence and Structure, -// ed Dayhoff M. O., editor. -// (Washington, DC: National Biomedical Research Foundation; ), 345–352. -// -// | Amino acid | Property | Dayhoff | -// |---------------|-----------------------|---------| -// | C | Sulfur polymerization | a | -// | A, G, P, S, T | Small | b | -// | D, E, N, Q | Acid and amide | c | -// | H, K, R | Basic | d | -// | I, L, M, V | Hydrophobic | e | -// | F, W, Y | Aromatic | f | -static DAYHOFFTABLE: Lazy> = Lazy::new(|| { - [ - // a - (b'C', b'a'), - // b - (b'A', b'b'), - (b'G', b'b'), - (b'P', b'b'), - (b'S', b'b'), - (b'T', b'b'), - // c - (b'D', b'c'), - (b'E', b'c'), - (b'N', b'c'), - (b'Q', b'c'), - // d - (b'H', b'd'), - (b'K', b'd'), - (b'R', b'd'), - // e - (b'I', b'e'), - (b'L', b'e'), - (b'M', b'e'), - (b'V', b'e'), - // e - (b'F', b'f'), - (b'W', b'f'), - (b'Y', b'f'), - ] - .iter() - .cloned() - .collect() -}); - -// HP Hydrophobic/hydrophilic mapping -// From: Phillips, R., Kondev, J., Theriot, J. (2008). -// Physical Biology of the Cell. New York: Garland Science, Taylor & Francis Group. ISBN: 978-0815341635 - -// -// | Amino acid | HP -// |---------------------------------------|---------| -// | A, F, G, I, L, M, P, V, W, Y | h | -// | N, C, S, T, D, E, R, H, K, Q | p | -static HPTABLE: Lazy> = Lazy::new(|| { - [ - // h - (b'A', b'h'), - (b'F', b'h'), - (b'G', b'h'), - (b'I', b'h'), - (b'L', b'h'), - (b'M', b'h'), - (b'P', b'h'), - (b'V', b'h'), - (b'W', b'h'), - (b'Y', b'h'), - // p - (b'N', b'p'), - (b'C', b'p'), - (b'S', b'p'), - (b'T', b'p'), - (b'D', b'p'), - (b'E', b'p'), - (b'R', b'p'), - (b'H', b'p'), - (b'K', b'p'), - (b'Q', b'p'), - ] - .iter() - .cloned() - .collect() -}); - -#[inline] -pub(crate) fn translate_codon(codon: &[u8]) -> Result { - if codon.len() == 1 { - return Ok(b'X'); - } - - if codon.len() == 2 { - let mut v = codon.to_vec(); - v.push(b'N'); - match CODONTABLE.get(str::from_utf8(v.as_slice()).unwrap()) { - Some(aa) => return Ok(*aa), - None => return Ok(b'X'), - } - } - - if codon.len() == 3 { - match CODONTABLE.get(str::from_utf8(codon).unwrap()) { - Some(aa) => return Ok(*aa), - None => return Ok(b'X'), - } - } - - Err(Error::InvalidCodonLength { - message: format!("{}", codon.len()), - }) -} - -#[inline] -pub(crate) fn aa_to_dayhoff(aa: u8) -> u8 { - match DAYHOFFTABLE.get(&aa) { - Some(letter) => *letter, - None => b'X', - } -} - -pub(crate) fn aa_to_hp(aa: u8) -> u8 { - match HPTABLE.get(&aa) { - Some(letter) => *letter, - None => b'X', - } -} - -#[inline] -fn to_aa(seq: &[u8], dayhoff: bool, hp: bool) -> Result, Error> { - let mut converted: Vec = Vec::with_capacity(seq.len() / 3); - - for chunk in seq.chunks(3) { - if chunk.len() < 3 { - break; - } - - let residue = translate_codon(chunk)?; - if dayhoff { - converted.push(aa_to_dayhoff(residue) as u8); - } else if hp { - converted.push(aa_to_hp(residue) as u8); - } else { - converted.push(residue); - } - } - - Ok(converted) -} - -const VALID: [bool; 256] = { - let mut lookup = [false; 256]; - lookup[b'A' as usize] = true; - lookup[b'C' as usize] = true; - lookup[b'G' as usize] = true; - lookup[b'T' as usize] = true; - lookup -}; - //############# // A MinHash implementation for low scaled or large cardinalities @@ -1499,14 +1064,6 @@ impl KmerMinHashBTree { self.hash_function == HashFunctions::murmur64_protein } - fn is_dna(&self) -> bool { - self.hash_function == HashFunctions::murmur64_DNA - } - - pub fn seed(&self) -> u64 { - self.seed - } - pub fn max_hash(&self) -> u64 { self.max_hash } @@ -1588,10 +1145,6 @@ impl KmerMinHashBTree { data.clone().unwrap() } - pub fn add_hash(&mut self, hash: u64) { - self.add_hash_with_abundance(hash, 1); - } - pub fn add_hash_with_abundance(&mut self, hash: u64, abundance: u64) { if hash > self.max_hash && self.max_hash != 0 { // This is a scaled minhash, and we don't need to add the new hash @@ -1936,6 +1489,18 @@ impl SigsTrait for KmerMinHashBTree { self.ksize as usize } + fn seed(&self) -> u64 { + self.seed + } + + fn hash_function(&self) -> HashFunctions { + self.hash_function + } + + fn add_hash(&mut self, hash: u64) { + self.add_hash_with_abundance(hash, 1); + } + fn check_compatible(&self, other: &KmerMinHashBTree) -> Result<(), Error> { /* if self.num != other.num { @@ -1961,122 +1526,6 @@ impl SigsTrait for KmerMinHashBTree { } Ok(()) } - - fn add_sequence(&mut self, seq: &[u8], force: bool) -> Result<(), Error> { - let ksize = self.ksize as usize; - let len = seq.len(); - - if len < ksize { - return Ok(()); - }; - - // Here we convert the sequence to upper case and - // pre-calculate the reverse complement for the full sequence... - let sequence = seq.to_ascii_uppercase(); - let rc = revcomp(&sequence); - - if self.is_dna() { - let mut last_position_check = 0; - - let mut is_valid_kmer = |i| { - for j in std::cmp::max(i, last_position_check)..i + ksize { - if !VALID[sequence[j] as usize] { - return false; - } - last_position_check += 1; - } - true - }; - - for i in 0..=len - ksize { - // ... and then while moving the k-mer window forward for the sequence - // we move another window backwards for the RC. - // For a ksize = 3, and a sequence AGTCGT (len = 6): - // +-+---------+---------------+-------+ - // seq RC |i|i + ksize|len - ksize - i|len - i| - // AGTCGT ACGACT +-+---------+---------------+-------+ - // +-> +-> |0| 2 | 3 | 6 | - // +-> +-> |1| 3 | 2 | 5 | - // +-> +-> |2| 4 | 1 | 4 | - // +-> +-> |3| 5 | 0 | 3 | - // +-+---------+---------------+-------+ - // (leaving this table here because I had to draw to - // get the indices correctly) - - let kmer = &sequence[i..i + ksize]; - - if !is_valid_kmer(i) { - if !force { - // throw error if DNA is not valid - return Err(Error::InvalidDNA { - message: String::from_utf8(kmer.to_vec()).unwrap(), - }); - } - - continue; // skip invalid k-mer - } - - let krc = &rc[len - ksize - i..len - i]; - self.add_word(std::cmp::min(kmer, krc)); - } - } else { - // protein - let aa_ksize = self.ksize / 3; - - for i in 0..3 { - let substr: Vec = sequence - .iter() - .cloned() - .skip(i) - .take(sequence.len() - i) - .collect(); - let aa = to_aa(&substr, self.dayhoff(), self.hp()).unwrap(); - - aa.windows(aa_ksize as usize).for_each(|n| self.add_word(n)); - - let rc_substr: Vec = rc.iter().cloned().skip(i).take(rc.len() - i).collect(); - let aa_rc = to_aa(&rc_substr, self.dayhoff(), self.hp()).unwrap(); - - aa_rc - .windows(aa_ksize as usize) - .for_each(|n| self.add_word(n)); - } - } - - Ok(()) - } - - fn add_protein(&mut self, seq: &[u8]) -> Result<(), Error> { - let ksize = (self.ksize / 3) as usize; - let len = seq.len(); - - if len < ksize { - return Ok(()); - } - - if let HashFunctions::murmur64_protein = self.hash_function { - for aa_kmer in seq.windows(ksize) { - self.add_word(&aa_kmer); - } - return Ok(()); - } - - let aa_seq: Vec<_> = match self.hash_function { - HashFunctions::murmur64_dayhoff => seq.iter().cloned().map(aa_to_dayhoff).collect(), - HashFunctions::murmur64_hp => seq.iter().cloned().map(aa_to_hp).collect(), - invalid => { - return Err(Error::InvalidHashFunction { - function: format!("{}", invalid), - }) - } - }; - - for aa_kmer in aa_seq.windows(ksize) { - self.add_word(aa_kmer); - } - - Ok(()) - } } impl From for KmerMinHash { diff --git a/src/core/src/sketch/mod.rs b/src/core/src/sketch/mod.rs index 3afe6aeb39..98d9efdd2e 100644 --- a/src/core/src/sketch/mod.rs +++ b/src/core/src/sketch/mod.rs @@ -1,17 +1,16 @@ +pub mod hyperloglog; pub mod minhash; pub mod nodegraph; -pub mod ukhs; - use serde::{Deserialize, Serialize}; +use crate::sketch::hyperloglog::HyperLogLog; use crate::sketch::minhash::{KmerMinHash, KmerMinHashBTree}; -use crate::sketch::ukhs::FlatUKHS; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum Sketch { MinHash(KmerMinHash), LargeMinHash(KmerMinHashBTree), - UKHS(FlatUKHS), // FIXME + HyperLogLog(HyperLogLog), } diff --git a/src/core/src/sketch/nodegraph.rs b/src/core/src/sketch/nodegraph.rs index 34a34a8ca3..1e2fac1eb2 100644 --- a/src/core/src/sketch/nodegraph.rs +++ b/src/core/src/sketch/nodegraph.rs @@ -76,7 +76,7 @@ impl Nodegraph { pub fn with_tables(tablesize: usize, n_tables: usize, ksize: usize) -> Nodegraph { let mut tablesizes = Vec::with_capacity(n_tables); - let mut i = (tablesize - 1) as u64; + let mut i = u64::max((tablesize - 1) as u64, 2); if i % 2 == 0 { i -= 1 } diff --git a/src/core/src/sketch/ukhs.rs b/src/core/src/sketch/ukhs.rs deleted file mode 100644 index d8b4e25df9..0000000000 --- a/src/core/src/sketch/ukhs.rs +++ /dev/null @@ -1,39 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::signature::SigsTrait; -use crate::Error; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FlatUKHS {} - -impl FlatUKHS { - pub fn md5sum(&self) -> String { - unimplemented!() - } -} - -impl SigsTrait for FlatUKHS { - fn size(&self) -> usize { - unimplemented!() - } - - fn to_vec(&self) -> Vec { - unimplemented!() - } - - fn ksize(&self) -> usize { - unimplemented!() - } - - fn check_compatible(&self, _other: &Self) -> Result<(), Error> { - unimplemented!() - } - - fn add_sequence(&mut self, _seq: &[u8], _force: bool) -> Result<(), Error> { - unimplemented!() - } - - fn add_protein(&mut self, _seq: &[u8]) -> Result<(), Error> { - unimplemented!() - } -} diff --git a/src/core/src/wasm.rs b/src/core/src/wasm.rs index bfd5e86065..3ada5e7f86 100644 --- a/src/core/src/wasm.rs +++ b/src/core/src/wasm.rs @@ -3,8 +3,9 @@ use wasm_bindgen::prelude::*; use serde_json; use crate::cmd::ComputeParameters; +use crate::encodings::HashFunctions; use crate::signature::{Signature, SigsTrait}; -use crate::sketch::minhash::{HashFunctions, KmerMinHash}; +use crate::sketch::minhash::KmerMinHash; #[wasm_bindgen] impl KmerMinHash { @@ -37,7 +38,7 @@ impl KmerMinHash { hash_function, seed as u64, track_abundance, - num + num, ) } diff --git a/src/core/tests/minhash.rs b/src/core/tests/minhash.rs index a6a1cc4d9e..b7c61e275b 100644 --- a/src/core/tests/minhash.rs +++ b/src/core/tests/minhash.rs @@ -2,10 +2,9 @@ use std::fs::File; use std::io::BufReader; use std::path::PathBuf; +use sourmash::encodings::HashFunctions; use sourmash::signature::{Signature, SigsTrait}; -use sourmash::sketch::minhash::{ - max_hash_for_scaled, HashFunctions, KmerMinHash, KmerMinHashBTree, -}; +use sourmash::sketch::minhash::{max_hash_for_scaled, KmerMinHash, KmerMinHashBTree}; use sourmash::sketch::Sketch; use proptest::collection::vec; diff --git a/tests/test_hll.py b/tests/test_hll.py new file mode 100644 index 0000000000..19491ceec2 --- /dev/null +++ b/tests/test_hll.py @@ -0,0 +1,126 @@ +import gzip +from tempfile import NamedTemporaryFile + +from screed.fasta import fasta_iter +import pytest + +from sourmash.hll import HLL + +from . import sourmash_tst_utils as utils + +K = 21 # size of kmer +ERR_RATE = 0.01 +N_UNIQUE = 3356 +TRANSLATE = {'A': 'T', 'C': 'G', 'T': 'A', 'G': 'C'} + + +def test_hll_add_python(): + # test python code to count unique kmers using HyperLogLog. + # use the lower level add() method, which accepts anything, + # and compare to an exact count using collections.Counter + + filename = utils.get_test_data('ecoli.genes.fna') + hll = HLL(ERR_RATE, K) + counter = set() + + for n, record in enumerate(fasta_iter(open(filename))): + sequence = record['sequence'] + seq_len = len(sequence) + for n in range(0, seq_len + 1 - K): + kmer = sequence[n:n + K] + rc = "".join(TRANSLATE[c] for c in kmer[::-1]) + + hll.add(kmer) + + if rc in counter: + kmer = rc + counter.update([kmer]) + + n_unique = len(counter) + + assert n_unique == N_UNIQUE + assert abs(1 - float(hll.cardinality()) / N_UNIQUE) < ERR_RATE + + +def test_hll_consume_string(): + # test rust code to count unique kmers using HyperLogLog, + # using screed to feed each read to the counter. + + filename = utils.get_test_data('ecoli.genes.fna') + hll = HLL(ERR_RATE, K) + n_consumed = n = 0 + for n, record in enumerate(fasta_iter(open(filename)), 1): + hll.add_sequence(record['sequence']) + + assert abs(1 - float(len(hll)) / N_UNIQUE) < ERR_RATE + + +def test_hll_similarity_containment(): + N_UNIQUE_H1 = 500741 + N_UNIQUE_H2 = 995845 + N_UNIQUE_U = 995845 + + SIMILARITY = 0.502783 + CONTAINMENT_H1 = 1. + CONTAINMENT_H2 = 0.502783 + + INTERSECTION = 500838 + + hll1 = HLL(ERR_RATE, K) + hll2 = HLL(ERR_RATE, K) + hllu = HLL(ERR_RATE, K) + + filename = utils.get_test_data('genome-s10.fa.gz') + for n, record in enumerate(fasta_iter(gzip.GzipFile(filename))): + sequence = record['sequence'] + seq_len = len(sequence) + for n in range(0, seq_len + 1 - K): + kmer = sequence[n:n + K] + hll1.add(kmer) + hllu.add(kmer) + + filename = utils.get_test_data('genome-s10+s11.fa.gz') + for n, record in enumerate(fasta_iter(gzip.GzipFile(filename))): + sequence = record['sequence'] + seq_len = len(sequence) + for n in range(0, seq_len + 1 - K): + kmer = sequence[n:n + K] + hll2.add(kmer) + hllu.add(kmer) + + assert abs(1 - float(hll1.cardinality()) / N_UNIQUE_H1) < ERR_RATE + assert abs(1 - float(hll2.cardinality()) / N_UNIQUE_H2) < ERR_RATE + + assert abs(1 - float(hll1.similarity(hll2)) / SIMILARITY) < ERR_RATE + + assert abs(1 - float(hll1.containment(hll2)) / CONTAINMENT_H1) < ERR_RATE + assert abs(1 - float(hll2.containment(hll1)) / CONTAINMENT_H2) < ERR_RATE + + assert abs(1 - float(hll1.intersection(hll2)) / INTERSECTION) < ERR_RATE + + """ + hll1.merge(hll2) + + assert abs(1 - float(hllu.similarity(hll1))) < ERR_RATE + + assert abs(1 - float(hllu.containment(hll1))) < ERR_RATE + assert abs(1 - float(hllu.containment(hll2))) < ERR_RATE + + assert abs(1 - float(hll1.intersection(hllu)) / N_UNIQUE_U) < ERR_RATE + """ + +def test_hll_save_load(): + filename = utils.get_test_data('ecoli.genes.fna') + hll = HLL(ERR_RATE, K) + n_consumed = n = 0 + for n, record in enumerate(fasta_iter(open(filename)), 1): + hll.add_sequence(record['sequence']) + + assert abs(1 - float(len(hll)) / N_UNIQUE) < ERR_RATE + + with NamedTemporaryFile() as f: + hll.save(f.name) + + new_hll = HLL.load(f.name) + + assert len(hll) == len(new_hll)