diff --git a/book/book.toml b/book/book.toml index 27c2d8a..1a65a73 100644 --- a/book/book.toml +++ b/book/book.toml @@ -4,3 +4,6 @@ language = "en" multilingual = false src = "src" title = "augurs - a time series toolkit" + +[rust] +edition = "2021" diff --git a/book/src/tutorials/clustering.md b/book/src/tutorials/clustering.md index 6d01e5e..17a6816 100644 --- a/book/src/tutorials/clustering.md +++ b/book/src/tutorials/clustering.md @@ -14,7 +14,10 @@ Let's start with a simple example using DBSCAN clustering: ```rust # extern crate augurs; -use augurs::{clustering::DbscanClusterer, dtw::Dtw}; +use augurs::{ + clustering::{DbscanCluster, DbscanClusterer}, + dtw::Dtw, +}; // Sample time series data const SERIES: &[&[f64]] = &[ @@ -39,11 +42,20 @@ fn main() { let min_cluster_size = 2; // Perform clustering - let clusters: Vec = DbscanClusterer::new(epsilon, min_cluster_size) + let clusters = DbscanClusterer::new(epsilon, min_cluster_size) .fit(&distance_matrix); // Clusters are labeled: -1 for noise, 0+ for cluster membership - assert_eq!(clusters, vec![0, 0, 1, 1, -1]); + assert_eq!( + clusters, + vec![ + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Noise, + ] + ); } ``` diff --git a/crates/augurs-clustering/README.md b/crates/augurs-clustering/README.md index d6658e2..f7b47dd 100644 --- a/crates/augurs-clustering/README.md +++ b/crates/augurs-clustering/README.md @@ -8,7 +8,7 @@ A crate such as [`augurs-dtw`] must be used to calculate the distance matrix for ## Usage ```rust -use augurs::clustering::{DbscanClusterer, DistanceMatrix}; +use augurs::clustering::{DbscanCluster, DbscanClusterer, DistanceMatrix}; # fn main() -> Result<(), Box> { // Start with a distance matrix. @@ -32,7 +32,16 @@ let min_cluster_size = 2; // Use DBSCAN to detect clusters of series. // Note that we don't need to specify the number of clusters in advance. let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix); -assert_eq!(clusters, vec![0, 0, 0, 1, 1]); +assert_eq!( + clusters, + vec![ + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + ], +); # Ok(()) # } ``` diff --git a/crates/augurs-clustering/src/lib.rs b/crates/augurs-clustering/src/lib.rs index b52b905..8a02a15 100644 --- a/crates/augurs-clustering/src/lib.rs +++ b/crates/augurs-clustering/src/lib.rs @@ -1,9 +1,64 @@ #![doc = include_str!("../README.md")] -use std::collections::VecDeque; +use std::{collections::VecDeque, num::NonZeroU32}; pub use augurs_core::DistanceMatrix; +/// A cluster identified by the DBSCAN algorithm. +/// +/// This is either a noise cluster, or a cluster with a specific ID. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DbscanCluster { + /// A noise cluster. + Noise, + /// A cluster with the given ID. + /// + /// The ID is not guaranteed to remain the same between runs of the algorithm. + /// + /// We use a `NonZeroU32` here to ensure that the ID is never zero. This is mostly + /// just a size optimization. + Cluster(NonZeroU32), +} + +impl DbscanCluster { + /// Returns true if this cluster is a noise cluster. + pub fn is_noise(&self) -> bool { + matches!(self, Self::Noise) + } + + /// Returns true if this cluster is a cluster with the given ID. + pub fn is_cluster(&self) -> bool { + matches!(self, Self::Cluster(_)) + } + + /// Returns the ID of the cluster, if it is a cluster, or `-1` if it is a noise cluster. + pub fn as_i32(&self) -> i32 { + match self { + Self::Noise => -1, + Self::Cluster(id) => id.get() as i32, + } + } + + fn increment(&mut self) { + match self { + Self::Noise => unreachable!(), + Self::Cluster(id) => *id = id.checked_add(1).expect("cluster ID overflow"), + } + } +} + +// Simplify tests by allowing comparisons with i32. +#[cfg(test)] +impl PartialEq for DbscanCluster { + fn eq(&self, other: &i32) -> bool { + if self.is_noise() { + *other == -1 + } else { + self.as_i32() == *other + } + } +} + /// DBSCAN clustering algorithm. #[derive(Debug)] pub struct DbscanClusterer { @@ -40,11 +95,11 @@ impl DbscanClusterer { /// Run the DBSCAN clustering algorithm. /// - /// The return value is a vector of cluster assignments, with `-1` indicating noise. - pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec { + /// The return value is a vector of cluster assignments, with `DbscanCluster::Noise` indicating noise. + pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec { let n = distance_matrix.shape().0; - let mut clusters = vec![-1; n]; - let mut cluster = 0; + let mut clusters = vec![DbscanCluster::Noise; n]; + let mut cluster = DbscanCluster::Cluster(NonZeroU32::new(1).unwrap()); let mut visited = vec![false; n]; let mut to_visit = VecDeque::with_capacity(n); @@ -53,7 +108,7 @@ impl DbscanClusterer { for (i, d) in distance_matrix.iter().enumerate() { // Skip if already assigned to a cluster. - if clusters[i] != -1 { + if clusters[i].is_cluster() { continue; } self.find_neighbours(i, d, &mut neighbours); @@ -67,7 +122,7 @@ impl DbscanClusterer { clusters[i] = cluster; // Mark all noise neighbours as visited and add them to the queue. for neighbour in neighbours.drain(..) { - if clusters[neighbour] == -1 { + if clusters[neighbour].is_noise() { visited[neighbour] = true; to_visit.push_back(neighbour); } @@ -87,7 +142,7 @@ impl DbscanClusterer { } } } - cluster += 1; + cluster.increment(); } clusters } @@ -123,19 +178,19 @@ mod test { assert_eq!(clusters, vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, -1, -1]); + assert_eq!(clusters, vec![1, 1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix); assert_eq!(clusters, vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(clusters, vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(clusters, vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, 0]); + assert_eq!(clusters, vec![1, 1, 1, 1]); } #[test] @@ -151,36 +206,36 @@ mod test { let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap(); let clusters = DbscanClusterer::new(10.0, 3).fit(&distance_matrix); let expected = vec![ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 2, -1, 2, -1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 3, -1, 3, -1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ]; assert_eq!(clusters, expected); } diff --git a/crates/augurs/tests/integration.rs b/crates/augurs/tests/integration.rs index 068132a..a8d992c 100644 --- a/crates/augurs/tests/integration.rs +++ b/crates/augurs/tests/integration.rs @@ -18,6 +18,9 @@ fn test_changepoint() { #[cfg(feature = "clustering")] #[test] fn test_clustering() { + fn convert_clusters(clusters: Vec) -> Vec { + clusters.into_iter().map(|c| c.as_i32()).collect() + } use augurs::{clustering::DbscanClusterer, DistanceMatrix}; let distance_matrix = vec![ vec![0.0, 1.0, 2.0, 3.0], @@ -27,22 +30,22 @@ fn test_clustering() { ]; let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap(); let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![-1, -1, -1, -1]); + assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, -1, -1]); + assert_eq!(convert_clusters(clusters), vec![1, 1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![-1, -1, -1, -1]); + assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, 0]); + assert_eq!(convert_clusters(clusters), vec![1, 1, 1, 1]); } #[cfg(feature = "dtw")] diff --git a/crates/pyaugurs/src/clustering.rs b/crates/pyaugurs/src/clustering.rs index adad947..b96086a 100644 --- a/crates/pyaugurs/src/clustering.rs +++ b/crates/pyaugurs/src/clustering.rs @@ -82,8 +82,15 @@ impl Dbscan { &self, py: Python<'_>, distance_matrix: InputDistanceMatrix<'_>, - ) -> PyResult>> { + ) -> PyResult>> { let distance_matrix = distance_matrix.try_into()?; - Ok(self.inner.fit(&distance_matrix).into_pyarray(py).into()) + Ok(self + .inner + .fit(&distance_matrix) + .into_iter() + .map(|x| x.as_i32()) + .collect::>() + .into_pyarray(py) + .into()) } } diff --git a/examples/clustering/examples/dbscan_clustering.rs b/examples/clustering/examples/dbscan_clustering.rs index 08fab21..fb93b1e 100644 --- a/examples/clustering/examples/dbscan_clustering.rs +++ b/examples/clustering/examples/dbscan_clustering.rs @@ -7,7 +7,10 @@ //! The resulting clusters are assigned a label of -1 for noise, 0 for the first cluster, and 1 for //! the second cluster. -use augurs::{clustering::DbscanClusterer, dtw::Dtw}; +use augurs::{ + clustering::{DbscanCluster, DbscanClusterer}, + dtw::Dtw, +}; // This is a very trivial example dataset containing 5 time series which // form two obvious clusters, plus a noise cluster. @@ -49,7 +52,16 @@ fn main() { let min_cluster_size = 2; // Run DBSCAN clustering on the distance matrix. - let clusters: Vec = - DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 1, 1, -1]); + let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix); + println!("Clusters: {:?}", clusters); + assert_eq!( + clusters, + vec![ + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Noise, + ] + ); } diff --git a/js/augurs-clustering-js/src/lib.rs b/js/augurs-clustering-js/src/lib.rs index a3c8a69..350af65 100644 --- a/js/augurs-clustering-js/src/lib.rs +++ b/js/augurs-clustering-js/src/lib.rs @@ -44,7 +44,12 @@ impl DbscanClusterer { /// The return value is an `Int32Array` of cluster IDs, with `-1` indicating noise. #[wasm_bindgen] #[allow(non_snake_case)] - pub fn fit(&self, distanceMatrix: VecVecF64) -> Result, JsError> { - Ok(self.inner.fit(&DistanceMatrix::new(distanceMatrix)?.into())) + pub fn fit(&self, distanceMatrix: VecVecF64) -> Result, JsError> { + Ok(self + .inner + .fit(&DistanceMatrix::new(distanceMatrix)?.into()) + .into_iter() + .map(|x| x.as_i32()) + .collect::>()) } } diff --git a/js/testpkg/clustering.test.ts b/js/testpkg/clustering.test.ts index 66ed7d3..8ed45c2 100644 --- a/js/testpkg/clustering.test.ts +++ b/js/testpkg/clustering.test.ts @@ -22,7 +22,7 @@ describe('clustering', () => { [2, 3, 0, 4], [3, 3, 4, 0], ]); - expect(labels).toEqual(new Int32Array([0, 0, -1, -1])); + expect(labels).toEqual(new Int32Array([1, 1, -1, -1])); }); it('can be fit with a raw distance matrix of typed arrays', () => { @@ -33,7 +33,7 @@ describe('clustering', () => { new Float64Array([2, 3, 0, 4]), new Float64Array([3, 3, 4, 0]), ]); - expect(labels).toEqual(new Int32Array([0, 0, -1, -1])); + expect(labels).toEqual(new Int32Array([1, 1, -1, -1])); }); it('can be fit with a distance matrix from augurs', () => { @@ -46,6 +46,6 @@ describe('clustering', () => { ]); const clusterer = new DbscanClusterer({ epsilon: 0.5, minClusterSize: 2 }); const labels = clusterer.fit(distanceMatrix); - expect(labels).toEqual(new Int32Array([0, 0, 0, -1])); + expect(labels).toEqual(new Int32Array([1, 1, 1, -1])); }) });