From d2bcb54c06f10c5718d5005f6459683df757f682 Mon Sep 17 00:00:00 2001 From: Gorka Kobeaga Date: Fri, 14 Oct 2022 00:52:19 +0200 Subject: [PATCH] Add sample naming --- .../linfa-preprocessing/src/linear_scaling.rs | 13 ++++ .../linfa-preprocessing/src/norm_scaling.rs | 10 +++ .../linfa-preprocessing/src/whitening.rs | 13 ++++ algorithms/linfa-reduction/src/pca.rs | 5 +- algorithms/linfa-tsne/src/lib.rs | 8 +- datasets/src/dataset.rs | 22 +++++- src/dataset/impl_dataset.rs | 76 ++++++++++++++++++- src/dataset/impl_targets.rs | 1 + src/dataset/iter.rs | 2 + src/dataset/mod.rs | 4 +- 10 files changed, 146 insertions(+), 8 deletions(-) diff --git a/algorithms/linfa-preprocessing/src/linear_scaling.rs b/algorithms/linfa-preprocessing/src/linear_scaling.rs index 29e3198d6..f2538ffa7 100644 --- a/algorithms/linfa-preprocessing/src/linear_scaling.rs +++ b/algorithms/linfa-preprocessing/src/linear_scaling.rs @@ -291,10 +291,12 @@ impl, T: AsTargets> /// Panics if the shape of the records is not compatible with the shape of the dataset used for fitting. fn transform(&self, x: DatasetBase, T>) -> DatasetBase, T> { let feature_names = x.feature_names(); + let sample_names = x.sample_names(); let (records, targets, weights) = (x.records, x.targets, x.weights); let records = self.transform(records.to_owned()); DatasetBase::new(records, targets) .with_weights(weights) + .with_sample_names(sample_names) .with_feature_names(feature_names) } } @@ -566,6 +568,17 @@ mod tests { assert_eq!(original_feature_names, transformed.feature_names()) } + #[test] + fn test_retain_sample_names() { + let dataset = linfa_datasets::diabetes(); + let original_sample_names = dataset.sample_names(); + let transformed = LinearScaler::standard() + .fit(&dataset) + .unwrap() + .transform(dataset); + assert_eq!(original_sample_names, transformed.sample_names()) + } + #[test] #[should_panic] fn test_transform_wrong_size_array_standard() { diff --git a/algorithms/linfa-preprocessing/src/norm_scaling.rs b/algorithms/linfa-preprocessing/src/norm_scaling.rs index d0c80aaa7..762bc5ae0 100644 --- a/algorithms/linfa-preprocessing/src/norm_scaling.rs +++ b/algorithms/linfa-preprocessing/src/norm_scaling.rs @@ -82,11 +82,13 @@ impl, T: AsTargets> /// Substitutes the records of the dataset with their scaled versions with unit norm. fn transform(&self, x: DatasetBase, T>) -> DatasetBase, T> { let feature_names = x.feature_names(); + let sample_names = x.sample_names(); let (records, targets, weights) = (x.records, x.targets, x.weights); let records = self.transform(records.to_owned()); DatasetBase::new(records, targets) .with_weights(weights) .with_feature_names(feature_names) + .with_sample_names(sample_names) } } @@ -151,4 +153,12 @@ mod tests { let transformed = NormScaler::l2().transform(dataset); assert_eq!(original_feature_names, transformed.feature_names()) } + + #[test] + fn test_retain_sample_names() { + let dataset = linfa_datasets::diabetes(); + let original_sample_names = dataset.sample_names(); + let transformed = NormScaler::l2().transform(dataset); + assert_eq!(original_sample_names, transformed.sample_names()) + } } diff --git a/algorithms/linfa-preprocessing/src/whitening.rs b/algorithms/linfa-preprocessing/src/whitening.rs index 6d031ad2b..89450901f 100644 --- a/algorithms/linfa-preprocessing/src/whitening.rs +++ b/algorithms/linfa-preprocessing/src/whitening.rs @@ -192,10 +192,12 @@ impl, T: AsTargets> { fn transform(&self, x: DatasetBase, T>) -> DatasetBase, T> { let feature_names = x.feature_names(); + let sample_names = x.sample_names(); let (records, targets, weights) = (x.records, x.targets, x.weights); let records = self.transform(records.to_owned()); DatasetBase::new(records, targets) .with_weights(weights) + .with_sample_names(sample_names) .with_feature_names(feature_names) } } @@ -324,6 +326,17 @@ mod tests { assert_eq!(original_feature_names, transformed.feature_names()) } + #[test] + fn test_retain_sample_names() { + let dataset = linfa_datasets::diabetes(); + let original_sample_names = dataset.sample_names(); + let transformed = Whitener::cholesky() + .fit(&dataset) + .unwrap() + .transform(dataset); + assert_eq!(original_sample_names, transformed.sample_names()) + } + #[test] #[should_panic] fn test_pca_fail_on_empty_input() { diff --git a/algorithms/linfa-reduction/src/pca.rs b/algorithms/linfa-reduction/src/pca.rs index c2e31213b..6d6ab0869 100644 --- a/algorithms/linfa-reduction/src/pca.rs +++ b/algorithms/linfa-reduction/src/pca.rs @@ -201,6 +201,7 @@ impl, T> Transformer, T>, DatasetBase, T>> for Pca { fn transform(&self, ds: DatasetBase, T>) -> DatasetBase, T> { + let sample_names = ds.sample_names(); let DatasetBase { records, targets, @@ -211,7 +212,9 @@ impl, T> let mut new_records = self.default_target(&records); self.predict_inplace(&records, &mut new_records); - DatasetBase::new(new_records, targets).with_weights(weights) + DatasetBase::new(new_records, targets) + .with_weights(weights) + .with_sample_names(sample_names) } } #[cfg(test)] diff --git a/algorithms/linfa-tsne/src/lib.rs b/algorithms/linfa-tsne/src/lib.rs index ea0617734..6c8faaceb 100644 --- a/algorithms/linfa-tsne/src/lib.rs +++ b/algorithms/linfa-tsne/src/lib.rs @@ -70,6 +70,7 @@ impl for TSneValidParams { fn transform(&self, ds: DatasetBase, T>) -> Result, T>> { + let sample_names = ds.sample_names(); let DatasetBase { records, targets, @@ -77,8 +78,11 @@ impl .. } = ds; - self.transform(records) - .map(|new_records| DatasetBase::new(new_records, targets).with_weights(weights)) + self.transform(records).map(|new_records| { + DatasetBase::new(new_records, targets) + .with_weights(weights) + .with_sample_names(sample_names) + }) } } diff --git a/datasets/src/dataset.rs b/datasets/src/dataset.rs index c15f29536..b213089dc 100644 --- a/datasets/src/dataset.rs +++ b/datasets/src/dataset.rs @@ -30,10 +30,14 @@ pub fn iris() -> Dataset { ); let feature_names = vec!["sepal length", "sepal width", "petal length", "petal width"]; + let sample_names = (0..data.nrows()) + .map(|idx| format!("sample-{idx}")) + .collect(); Dataset::new(data, targets) .map_targets(|x| *x as usize) .with_feature_names(feature_names) + .with_sample_names(sample_names) } #[cfg(feature = "diabetes")] @@ -57,8 +61,13 @@ pub fn diabetes() -> Dataset { "lamotrigine", "blood sugar level", ]; + let sample_names = (0..data.nrows()) + .map(|idx| format!("sample-{idx}")) + .collect(); - Dataset::new(data, targets).with_feature_names(feature_names) + Dataset::new(data, targets) + .with_feature_names(feature_names) + .with_sample_names(sample_names) } #[cfg(feature = "winequality")] @@ -85,10 +94,14 @@ pub fn winequality() -> Dataset { "sulphates", "alcohol", ]; + let sample_names = (0..data.nrows()) + .map(|idx| format!("sample-{idx}")) + .collect(); Dataset::new(data, targets) .map_targets(|x| *x as usize) .with_feature_names(feature_names) + .with_sample_names(sample_names) } #[cfg(feature = "linnerud")] @@ -112,8 +125,13 @@ pub fn linnerud() -> Dataset { let output_array = array_from_buf(&output_data[..]); let feature_names = vec!["Chins", "Situps", "Jumps"]; + let sample_names = (0..input_array.nrows()) + .map(|idx| format!("sample-{idx}")) + .collect(); - Dataset::new(input_array, output_array).with_feature_names(feature_names) + Dataset::new(input_array, output_array) + .with_feature_names(feature_names) + .with_sample_names(sample_names) } #[cfg(test)] diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index c835f8328..f9c27af7a 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -30,6 +30,7 @@ impl DatasetBase { targets, weights: Array1::zeros(0), feature_names: Vec::new(), + sample_names: Vec::new(), } } @@ -70,6 +71,19 @@ impl DatasetBase { } } + /// Returns sample names + /// + /// A row name gives a human-readable string describing the sample. + pub fn sample_names(&self) -> Vec { + if !self.sample_names.is_empty() { + self.sample_names.clone() + } else { + (0..self.records.nsamples()) + .map(|idx| format!("sample-{}", idx)) + .collect() + } + } + /// Return records of a dataset /// /// The records are data points from which predictions are made. This functions returns a @@ -88,6 +102,7 @@ impl DatasetBase { targets: self.targets, weights: Array1::zeros(0), feature_names: Vec::new(), + sample_names: Vec::new(), } } @@ -100,6 +115,7 @@ impl DatasetBase { targets, weights: self.weights, feature_names: self.feature_names, + sample_names: self.sample_names, } } @@ -118,6 +134,29 @@ impl DatasetBase { self } + + /// Updates the row names of a dataset + /// + /// ## Panics + /// + /// This method will panic for any of the following three reasons: + /// + /// - If the names vector length is different to nsamples + pub fn with_sample_names>(mut self, names: Vec) -> DatasetBase { + if names.len() == self.records().nsamples() { + let sample_names = names.into_iter().map(|x| x.into()).collect(); + + self.sample_names = sample_names; + } else if !names.is_empty() { + panic!( + "Sample names vector length, {}, is different to nsamples, {}.", + names.len(), + self.records().nsamples() + ); + } + + self + } } impl> DatasetBase { @@ -143,6 +182,7 @@ impl> DatasetBase { targets, weights, feature_names, + sample_names, .. } = self; @@ -153,6 +193,7 @@ impl> DatasetBase { targets: targets.map(fnc), weights, feature_names, + sample_names, } } @@ -215,6 +256,7 @@ where DatasetBase::new(records, targets) .with_feature_names(self.feature_names.clone()) + .with_sample_names(self.sample_names.clone()) .with_weights(self.weights.clone()) } @@ -287,13 +329,26 @@ where } else { (Array1::zeros(0), Array1::zeros(0)) }; + + let (first_sample_names, second_sample_names) = + if self.sample_names.len() == self.nsamples() { + ( + self.sample_names.iter().take(n).collect(), + self.sample_names.iter().skip(n).collect(), + ) + } else { + (Vec::new(), Vec::new()) + }; + let dataset1 = DatasetBase::new(records_first, targets_first) .with_weights(first_weights) - .with_feature_names(self.feature_names.clone()); + .with_feature_names(self.feature_names.clone()) + .with_sample_names(first_sample_names); let dataset2 = DatasetBase::new(records_second, targets_second) .with_weights(second_weights) - .with_feature_names(self.feature_names.clone()); + .with_feature_names(self.feature_names.clone()) + .with_sample_names(second_sample_names); (dataset1, dataset2) } @@ -339,6 +394,7 @@ where label, DatasetBase::new(self.records().view(), targets) .with_feature_names(self.feature_names.clone()) + .with_sample_names(self.sample_names.clone()) .with_weights(self.weights.clone()), ) }) @@ -395,6 +451,7 @@ impl, I: Dimension> From> targets: empty_targets, weights: Array1::zeros(0), feature_names: Vec::new(), + sample_names: Vec::new(), } } } @@ -411,6 +468,7 @@ where targets: rec_tar.1, weights: Array1::zeros(0), feature_names: Vec::new(), + sample_names: Vec::new(), } } } @@ -977,12 +1035,26 @@ impl Dataset { Array1::zeros(0) }; + // split sample_names into two disjoint Vec + let second_sample_names = if self.sample_names.len() == n1 + n2 { + let mut sample_names = self.sample_names; + + let sample_names2 = sample_names.split_off(n1); + self.sample_names = sample_names; + + sample_names2 + } else { + Vec::new() + }; + // create new datasets with attached weights let dataset1 = Dataset::new(first, first_targets) .with_weights(self.weights) + .with_sample_names(self.sample_names) .with_feature_names(feature_names.clone()); let dataset2 = Dataset::new(second, second_targets) .with_weights(second_weights) + .with_sample_names(second_sample_names) .with_feature_names(feature_names); (dataset1, dataset2) diff --git a/src/dataset/impl_targets.rs b/src/dataset/impl_targets.rs index 36692d5c5..918cfc63c 100644 --- a/src/dataset/impl_targets.rs +++ b/src/dataset/impl_targets.rs @@ -231,6 +231,7 @@ where weights: Array1::from(weights), targets, feature_names: self.feature_names.clone(), + sample_names: self.sample_names.clone(), } } } diff --git a/src/dataset/iter.rs b/src/dataset/iter.rs index d1608f9ab..27a70c459 100644 --- a/src/dataset/iter.rs +++ b/src/dataset/iter.rs @@ -82,6 +82,7 @@ where let mut targets = self.dataset.targets.as_targets(); let feature_names; let weights = self.dataset.weights.clone(); + let sample_names = self.dataset.sample_names.clone(); if !self.target_or_feature { // This branch should only run for 2D targets @@ -103,6 +104,7 @@ where targets, weights, feature_names, + sample_names, }; Some(dataset_view) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index b04e48109..d45a06468 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -154,7 +154,7 @@ impl Deref for Pr { /// This is the fundamental structure of a dataset. It contains a number of records about the data /// and may contain targets, weights and feature names. In order to keep the type complexity low /// the dataset base is only generic over the records and targets and introduces a trait bound on -/// the records. `weights` and `feature_names`, on the other hand, are always assumed to be owned +/// the records. `weights`, `feature_names` and `sample_names`, on the other hand, are always assumed to be owned /// and copied when views are created. /// /// # Fields @@ -164,6 +164,7 @@ impl Deref for Pr { /// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets) /// * `weights`: optional weights for each sample with dimensionality (nsamples) /// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures) +/// * `sample_names`: optional descriptive sample names with dimensionality (nsamples) /// /// # Trait bounds /// @@ -180,6 +181,7 @@ where pub weights: Array1, feature_names: Vec, + sample_names: Vec, } /// Targets with precomputed, counted labels