From d4c865872c82f695eedd043ecbe32fe450d8d894 Mon Sep 17 00:00:00 2001 From: James Knight Date: Thu, 7 Jul 2022 13:52:55 +0100 Subject: [PATCH 1/4] -Added bootstrap aggregation for general linfa classifiers. -Added an example using bootstrap aggregation to carry out Random Forest classification. -Factored the linfa trait FromTargetArray to create an additional trait FromTargetArrayOwned. --- algorithms/linfa-ensemble/Cargo.toml | 36 +++++ algorithms/linfa-ensemble/README.md | 21 +++ .../examples/randomforest_iris.rs | 38 +++++ algorithms/linfa-ensemble/src/ensemble.rs | 153 ++++++++++++++++++ algorithms/linfa-ensemble/src/lib.rs | 3 + src/dataset/impl_dataset.rs | 10 +- src/dataset/impl_targets.rs | 26 ++- src/dataset/mod.rs | 9 +- 8 files changed, 281 insertions(+), 15 deletions(-) create mode 100644 algorithms/linfa-ensemble/Cargo.toml create mode 100644 algorithms/linfa-ensemble/README.md create mode 100644 algorithms/linfa-ensemble/examples/randomforest_iris.rs create mode 100644 algorithms/linfa-ensemble/src/ensemble.rs create mode 100644 algorithms/linfa-ensemble/src/lib.rs diff --git a/algorithms/linfa-ensemble/Cargo.toml b/algorithms/linfa-ensemble/Cargo.toml new file mode 100644 index 000000000..c737b4080 --- /dev/null +++ b/algorithms/linfa-ensemble/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "linfa-ensemble" +version = "0.6.0" +edition = "2018" +authors = ["James Knight ", "James Kay "] +description = "A general method for creating ensemble classifiers" +license = "MIT/Apache-2.0" + +repository = "https://github.com/rust-ml/linfa" +readme = "README.md" + +keywords = ["machine-learning", "linfa", "ensemble"] +categories = ["algorithms", "mathematics", "science"] + +[features] +default = [] +serde = ["serde_crate", "ndarray/serde"] + +[dependencies.serde_crate] +package = "serde" +optional = true +version = "1.0" +default-features = false +features = ["std", "derive"] + +[dependencies] +ndarray = { version = "0.15" , features = ["rayon", "approx"]} +ndarray-rand = "0.14" +rand = "0.8.5" + +linfa = { version = "0.6.0", path = "../.." } +linfa-trees = { version = "0.6.0", path = "../linfa-trees"} + +[dev-dependencies] +linfa-datasets = { version = "0.6.0", path = "../../datasets/", features = ["iris"] } + diff --git a/algorithms/linfa-ensemble/README.md b/algorithms/linfa-ensemble/README.md new file mode 100644 index 000000000..fba055aa7 --- /dev/null +++ b/algorithms/linfa-ensemble/README.md @@ -0,0 +1,21 @@ +# Enseble Learning + +`linfa-ensemble` provides pure Rust implementations of Ensemble Learning algorithms for the Linfa toolkit. + +## The Big Picture + +`linfa-ensemble` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. + +## Current state + +`linfa-ensemble` currently provides an implementation of bootstrap aggregation (bagging) for other classifers provided in linfa. + +## Examples + +You can find examples in the `examples/` directory. To run an bootstrap aggregation for ensemble of decision trees (a Random Forest) use: + +```bash +$ cargo run --example randomforest_iris --release +``` + + diff --git a/algorithms/linfa-ensemble/examples/randomforest_iris.rs b/algorithms/linfa-ensemble/examples/randomforest_iris.rs new file mode 100644 index 000000000..862b69894 --- /dev/null +++ b/algorithms/linfa-ensemble/examples/randomforest_iris.rs @@ -0,0 +1,38 @@ +use linfa::prelude::{Fit, Predict, ToConfusionMatrix}; +use linfa_ensemble::{EnsembleLearnerParams}; +use linfa_trees::DecisionTree; +use ndarray_rand::rand::SeedableRng; +use rand::rngs::SmallRng; + +fn main() { + //Number of models in the ensemble + let ensemble_size = 100; + //Proportion of training data given to each model + let bootstrap_proportion = 0.7; + + //Create ensemble learner + let mut learner = EnsembleLearnerParams::new(); + learner + .ensemble_size(ensemble_size) + .bootstrap_proportion(bootstrap_proportion) + .model_params(DecisionTree::params()); + + //Load dataset + let mut rng = SmallRng::seed_from_u64(42); + let (train, test) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + //Train ensemble learner model + let model = learner.fit(&train).unwrap(); + + //Return highest ranking predictions + let final_predictions_ensemble = model.predict(&test); + println!("Final Predictions: \n{:?}", final_predictions_ensemble); + + let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap(); + + println!("{:?}", cm); + println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {},\n Bootstrap Proportion: {}", + 100.0 * cm.accuracy(), ensemble_size, bootstrap_proportion); +} diff --git a/algorithms/linfa-ensemble/src/ensemble.rs b/algorithms/linfa-ensemble/src/ensemble.rs new file mode 100644 index 000000000..f5a4620c1 --- /dev/null +++ b/algorithms/linfa-ensemble/src/ensemble.rs @@ -0,0 +1,153 @@ +use linfa::{ + dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records}, + error::{Error}, + traits::*, + DatasetBase, +}; +use ndarray::{ + Array2, Axis, Array, Dimension +}; +use std::{ + cmp::Eq, + collections::HashMap, + hash::Hash, +}; + +pub struct EnsembleLearner { + pub models: Vec, +} + +impl EnsembleLearner { + + // Generates prediction iterator returning predictions from each model + pub fn generate_predictions<'b, R: Records, T>(&'b self, x: &'b R) -> impl Iterator + 'b + where M: Predict<&'b R, T> + PredictInplace { + self.models.iter().map(move |m| m.predict(x)) + } + + // Consumes prediction iterator to return all predictions + // Orders predictions by total number of models giving that prediciton + pub fn aggregate_predictions(&self, ys: Ys) + -> impl Iterator::Elem, <::Ix as Dimension>::Smaller >, usize)>> + where + Ys::Item: AsTargets, + ::Elem: Copy + Eq + Hash, + { + let mut prediction_maps = Vec::new(); + + for y in ys { + let targets = y.as_targets(); + let no_targets = targets.shape()[0]; + + for i in 0..no_targets { + if prediction_maps.len() == i { + prediction_maps.push(HashMap::new()); + } + *prediction_maps[i].entry(y.as_targets().index_axis(Axis(0), i).to_owned()).or_insert(0) += 1; + } + } + + prediction_maps.into_iter().map(|xs| { + let mut xs: Vec<_> = xs.into_iter().collect(); + xs.sort_by(|(_, x), (_, y)| y.cmp(x)); + xs + }) + } +} + +impl +PredictInplace, T> for EnsembleLearner +where + M: PredictInplace, T>, + ::Elem: Copy + Eq + Hash, + T: AsTargets + AsTargetsMut::Elem>, +{ + fn predict_inplace(&self, x: &Array2, y: &mut T) { + let mut y_array = y.as_targets_mut(); + assert_eq!( + x.nrows(), + y_array.len(), + "The number of data points must match the number of output targets." + ); + + let mut predictions = self.generate_predictions(x); + let aggregated_predictions = self.aggregate_predictions(&mut predictions); + + for (target, output) in y_array.axis_iter_mut(Axis(0)).zip(aggregated_predictions.into_iter()) { + for (t, o) in target.into_iter().zip(output[0].0.iter()) { + *t = *o; + } + } + } + + fn default_target(&self, x: &Array2) -> T { + self.models[0].default_target(x) + } +} + +pub struct EnsembleLearnerParams

{ + pub ensemble_size: usize, + pub bootstrap_proportion: f64, + pub model_params: Option

, +} + +impl

EnsembleLearnerParams

{ + pub fn new() -> EnsembleLearnerParams

{ + EnsembleLearnerParams { + ensemble_size: 1, + bootstrap_proportion: 1.0, + model_params: None, + } + } + + pub fn ensemble_size(&mut self, size: usize) -> &mut EnsembleLearnerParams

{ + self.ensemble_size = size; + self + } + + pub fn bootstrap_proportion(&mut self, proportion: f64) -> &mut EnsembleLearnerParams

{ + self.bootstrap_proportion = proportion; + self + } + + pub fn model_params(&mut self, params: P) -> &mut EnsembleLearnerParams

{ + self.model_params = Some(params); + self + } +} + +impl, T::Owned, Error>> + Fit, T, Error> for EnsembleLearnerParams

+where + D: Clone, + T: FromTargetArrayOwned, + T::Elem: Copy + Eq + Hash, + T::Owned: AsTargets, +{ + type Object = EnsembleLearner; + + fn fit(&self, dataset: &DatasetBase, T>) -> Result { + assert!( + self.model_params.is_some(), + "Must define an underlying model for ensemble learner", + ); + + let mut models = Vec::new(); + let rng = &mut rand::thread_rng(); + + let dataset_size = ((dataset.records.shape()[0] as f64) * self.bootstrap_proportion) as usize; + + let iter = dataset.bootstrap_samples(dataset_size, rng); + + for train in iter { + let model = self.model_params.as_ref().unwrap().fit(&train).unwrap(); + models.push(model); + + if models.len() == self.ensemble_size { + break + } + } + + Ok(EnsembleLearner { models }) + } +} diff --git a/algorithms/linfa-ensemble/src/lib.rs b/algorithms/linfa-ensemble/src/lib.rs new file mode 100644 index 000000000..8d17edeb9 --- /dev/null +++ b/algorithms/linfa-ensemble/src/lib.rs @@ -0,0 +1,3 @@ +mod ensemble; + +pub use ensemble::*; diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index c835f8328..e96090720 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -2,7 +2,7 @@ use super::{ super::traits::{Predict, PredictInplace}, iter::{ChunksIter, DatasetIter, Iter}, AsSingleTargets, AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView, - Float, FromTargetArray, Label, Labels, Records, Result, TargetDim, + Float, FromTargetArray, FromTargetArrayOwned, Label, Labels, Records, Result, TargetDim, }; use crate::traits::Fit; use ndarray::{concatenate, prelude::*, Data, DataMut, Dimension}; @@ -418,7 +418,7 @@ where impl<'b, F: Clone, E: Copy + 'b, D, T> DatasetBase, T> where D: Data, - T: FromTargetArray<'b, Elem = E>, + T: FromTargetArrayOwned, T::Owned: AsTargets, { /// Apply bootstrapping for samples and features @@ -441,7 +441,7 @@ where &'b self, sample_feature_size: (usize, usize), rng: &'b mut R, - ) -> impl Iterator, >::Owned>> + 'b { + ) -> impl Iterator, T::Owned>> + 'b { std::iter::repeat(()).map(move |_| { // sample with replacement let indices = (0..sample_feature_size.0) @@ -481,7 +481,7 @@ where &'b self, num_samples: usize, rng: &'b mut R, - ) -> impl Iterator, >::Owned>> + 'b { + ) -> impl Iterator, T::Owned>> + 'b { std::iter::repeat(()).map(move |_| { // sample with replacement let indices = (0..num_samples) @@ -515,7 +515,7 @@ where &'b self, num_features: usize, rng: &'b mut R, - ) -> impl Iterator, >::Owned>> + 'b { + ) -> impl Iterator, T::Owned>> + 'b { std::iter::repeat(()).map(move |_| { let targets = T::new_targets(self.as_targets().to_owned()); diff --git a/src/dataset/impl_targets.rs b/src/dataset/impl_targets.rs index 080a0a5c5..61e32531a 100644 --- a/src/dataset/impl_targets.rs +++ b/src/dataset/impl_targets.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::{ AsMultiTargets, AsMultiTargetsMut, AsProbabilities, AsSingleTargets, AsSingleTargetsMut, - AsTargets, AsTargetsMut, CountedTargets, DatasetBase, FromTargetArray, Label, Labels, Pr, + AsTargets, AsTargetsMut, CountedTargets, DatasetBase, FromTargetArray, FromTargetArrayOwned, Label, Labels, Pr, TargetDim, }; use ndarray::{ @@ -25,14 +25,17 @@ impl<'a, L, S: Data, I: TargetDim> AsTargets for ArrayBase { impl> AsSingleTargets for T {} impl> AsMultiTargets for T {} -impl<'a, L: Clone + 'a, S: Data, I: TargetDim> FromTargetArray<'a> for ArrayBase { +impl<'a, L: Clone + 'a, S: Data, I: TargetDim> FromTargetArrayOwned for ArrayBase { type Owned = ArrayBase, I>; - type View = ArrayBase, I>; /// Returns an owned representation of the target array fn new_targets(targets: Array) -> Self::Owned { targets } +} + +impl<'a, L: Clone + 'a, S: Data, I: TargetDim> FromTargetArray<'a> for ArrayBase { + type View = ArrayBase, I>; /// Returns a reference to the target array fn new_targets_view(targets: ArrayView<'a, L, I>) -> Self::View { @@ -40,6 +43,7 @@ impl<'a, L: Clone + 'a, S: Data, I: TargetDim> FromTargetArray<'a> for } } + impl, I: TargetDim> AsTargetsMut for ArrayBase { type Elem = L; type Ix = I; @@ -79,23 +83,29 @@ impl> AsTargetsMut for CountedTargets } } -impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets +impl FromTargetArrayOwned for CountedTargets where - T: FromTargetArray<'a, Elem = L>, + T: FromTargetArrayOwned, T::Owned: Labels, - T::View: Labels, { type Owned = CountedTargets; - type View = CountedTargets; fn new_targets(targets: Array) -> Self::Owned { let targets = T::new_targets(targets); - CountedTargets { labels: targets.label_count(), targets, } } +} + + +impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets +where + T: FromTargetArray<'a, Elem = L>, + T::View: Labels, +{ + type View = CountedTargets; fn new_targets_view(targets: ArrayView<'a, L, T::Ix>) -> Self::View { let targets = T::new_targets_view(targets); diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index b04e48109..073c8e6b5 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -259,17 +259,22 @@ pub trait AsMultiTargets: AsTargets { } } +pub trait FromTargetArrayOwned: AsTargets { + type Owned; + + /// Create self object from new target array + fn new_targets(targets: Array) -> Self::Owned; +} + /// Helper trait to construct counted labels /// /// This is implemented for objects which can act as targets and created from a target matrix. For /// targets represented as `ndarray` matrix this is identity, for counted labels, i.e. /// `TargetsWithLabels`, it creates the corresponding wrapper struct. pub trait FromTargetArray<'a>: AsTargets { - type Owned; type View; /// Create self object from new target array - fn new_targets(targets: Array) -> Self::Owned; fn new_targets_view(targets: ArrayView<'a, Self::Elem, Self::Ix>) -> Self::View; } From 4bf178cc65562b28c97de3adc0750afdb71d8f1c Mon Sep 17 00:00:00 2001 From: James Knight Date: Thu, 25 Aug 2022 09:45:06 +0100 Subject: [PATCH 2/4] All fixes from PR review other than changes to predict_inplace --- algorithms/linfa-ensemble/src/ensemble.rs | 54 ++++++++++++----------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/algorithms/linfa-ensemble/src/ensemble.rs b/algorithms/linfa-ensemble/src/ensemble.rs index f5a4620c1..601734b5d 100644 --- a/algorithms/linfa-ensemble/src/ensemble.rs +++ b/algorithms/linfa-ensemble/src/ensemble.rs @@ -12,6 +12,8 @@ use std::{ collections::HashMap, hash::Hash, }; +use rand::Rng; +use rand::rngs::ThreadRng; pub struct EnsembleLearner { pub models: Vec, @@ -21,12 +23,11 @@ impl EnsembleLearner { // Generates prediction iterator returning predictions from each model pub fn generate_predictions<'b, R: Records, T>(&'b self, x: &'b R) -> impl Iterator + 'b - where M: Predict<&'b R, T> + PredictInplace { + where M: Predict<&'b R, T> { self.models.iter().map(move |m| m.predict(x)) } // Consumes prediction iterator to return all predictions - // Orders predictions by total number of models giving that prediciton pub fn aggregate_predictions(&self, ys: Ys) -> impl Iterator::Elem, <::Ix as Dimension>::Smaller >, usize)>> where @@ -53,6 +54,7 @@ impl EnsembleLearner { xs }) } + } impl @@ -66,8 +68,8 @@ where let mut y_array = y.as_targets_mut(); assert_eq!( x.nrows(), - y_array.len(), - "The number of data points must match the number of output targets." + y_array.len_of(Axis(0)), + "The number of data points must match the number of outputs." ); let mut predictions = self.generate_predictions(x); @@ -85,39 +87,45 @@ where } } -pub struct EnsembleLearnerParams

{ +pub struct EnsembleLearnerParams { pub ensemble_size: usize, pub bootstrap_proportion: f64, - pub model_params: Option

, + pub model_params: P, + pub rng: R +} + +impl

EnsembleLearnerParams { + pub fn new(model_params: P) -> EnsembleLearnerParams { + return Self::new_fixed_rng(model_params, rand::thread_rng()) + } } -impl

EnsembleLearnerParams

{ - pub fn new() -> EnsembleLearnerParams

{ +impl EnsembleLearnerParams { + pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams { EnsembleLearnerParams { ensemble_size: 1, bootstrap_proportion: 1.0, - model_params: None, + model_params: model_params, + rng: rng } } - pub fn ensemble_size(&mut self, size: usize) -> &mut EnsembleLearnerParams

{ + pub fn ensemble_size(&mut self, size: usize) -> &mut EnsembleLearnerParams { + assert!(size > 0, "ensemble_size cannot be less than 1. Ensembles must consist of at least one model."); self.ensemble_size = size; self } - pub fn bootstrap_proportion(&mut self, proportion: f64) -> &mut EnsembleLearnerParams

{ + pub fn bootstrap_proportion(&mut self, proportion: f64) -> &mut EnsembleLearnerParams { + assert!(proportion > 0.0, "bootstrap_proportion must be greater than 0. Must provide some data to each model."); self.bootstrap_proportion = proportion; self } - pub fn model_params(&mut self, params: P) -> &mut EnsembleLearnerParams

{ - self.model_params = Some(params); - self - } } -impl, T::Owned, Error>> - Fit, T, Error> for EnsembleLearnerParams

+impl, T::Owned, Error>, R: Rng + Clone> + Fit, T, Error> for EnsembleLearnerParams where D: Clone, T: FromTargetArrayOwned, @@ -127,20 +135,16 @@ where type Object = EnsembleLearner; fn fit(&self, dataset: &DatasetBase, T>) -> Result { - assert!( - self.model_params.is_some(), - "Must define an underlying model for ensemble learner", - ); let mut models = Vec::new(); - let rng = &mut rand::thread_rng(); + let mut rng = self.rng.clone(); - let dataset_size = ((dataset.records.shape()[0] as f64) * self.bootstrap_proportion) as usize; + let dataset_size = ((dataset.records.shape()[0] as f64) * self.bootstrap_proportion).ceil() as usize; - let iter = dataset.bootstrap_samples(dataset_size, rng); + let iter = dataset.bootstrap_samples(dataset_size, &mut rng); for train in iter { - let model = self.model_params.as_ref().unwrap().fit(&train).unwrap(); + let model = self.model_params.fit(&train).unwrap(); models.push(model); if models.len() == self.ensemble_size { From 647083a53837f0069d96157e8fb69dab858d5042 Mon Sep 17 00:00:00 2001 From: James Knight Date: Mon, 5 Sep 2022 10:43:05 +0100 Subject: [PATCH 3/4] updated example --- algorithms/linfa-ensemble/examples/randomforest_iris.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/algorithms/linfa-ensemble/examples/randomforest_iris.rs b/algorithms/linfa-ensemble/examples/randomforest_iris.rs index 862b69894..d3a00fa2e 100644 --- a/algorithms/linfa-ensemble/examples/randomforest_iris.rs +++ b/algorithms/linfa-ensemble/examples/randomforest_iris.rs @@ -11,11 +11,10 @@ fn main() { let bootstrap_proportion = 0.7; //Create ensemble learner - let mut learner = EnsembleLearnerParams::new(); + let mut learner = EnsembleLearnerParams::new(DecisionTree::params()); learner .ensemble_size(ensemble_size) - .bootstrap_proportion(bootstrap_proportion) - .model_params(DecisionTree::params()); + .bootstrap_proportion(bootstrap_proportion); //Load dataset let mut rng = SmallRng::seed_from_u64(42); From 46722ec5b0d4adf7073099daeb9dc4dc45d84f88 Mon Sep 17 00:00:00 2001 From: James Knight Date: Mon, 26 Sep 2022 16:03:50 +0100 Subject: [PATCH 4/4] Added ParamGuard with consuming builder to EnsembleLearnerParams --- .../examples/randomforest_iris.rs | 14 +-- algorithms/linfa-ensemble/src/ensemble.rs | 119 ++++++++++++------ 2 files changed, 86 insertions(+), 47 deletions(-) diff --git a/algorithms/linfa-ensemble/examples/randomforest_iris.rs b/algorithms/linfa-ensemble/examples/randomforest_iris.rs index d3a00fa2e..ce54d50c3 100644 --- a/algorithms/linfa-ensemble/examples/randomforest_iris.rs +++ b/algorithms/linfa-ensemble/examples/randomforest_iris.rs @@ -1,5 +1,5 @@ use linfa::prelude::{Fit, Predict, ToConfusionMatrix}; -use linfa_ensemble::{EnsembleLearnerParams}; +use linfa_ensemble::EnsembleLearnerParams; use linfa_trees::DecisionTree; use ndarray_rand::rand::SeedableRng; use rand::rngs::SmallRng; @@ -10,12 +10,6 @@ fn main() { //Proportion of training data given to each model let bootstrap_proportion = 0.7; - //Create ensemble learner - let mut learner = EnsembleLearnerParams::new(DecisionTree::params()); - learner - .ensemble_size(ensemble_size) - .bootstrap_proportion(bootstrap_proportion); - //Load dataset let mut rng = SmallRng::seed_from_u64(42); let (train, test) = linfa_datasets::iris() @@ -23,7 +17,11 @@ fn main() { .split_with_ratio(0.8); //Train ensemble learner model - let model = learner.fit(&train).unwrap(); + let model = EnsembleLearnerParams::new(DecisionTree::params()) + .ensemble_size(ensemble_size) + .bootstrap_proportion(bootstrap_proportion) + .fit(&train) + .unwrap(); //Return highest ranking predictions let final_predictions_ensemble = model.predict(&test); diff --git a/algorithms/linfa-ensemble/src/ensemble.rs b/algorithms/linfa-ensemble/src/ensemble.rs index 601734b5d..54b06cc7f 100644 --- a/algorithms/linfa-ensemble/src/ensemble.rs +++ b/algorithms/linfa-ensemble/src/ensemble.rs @@ -1,35 +1,43 @@ use linfa::{ dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records}, - error::{Error}, + error::{Error, Result}, traits::*, - DatasetBase, + DatasetBase, ParamGuard, }; -use ndarray::{ - Array2, Axis, Array, Dimension -}; -use std::{ - cmp::Eq, - collections::HashMap, - hash::Hash, -}; -use rand::Rng; +use ndarray::{Array, Array2, Axis, Dimension}; use rand::rngs::ThreadRng; +use rand::Rng; +use std::{cmp::Eq, collections::HashMap, hash::Hash}; pub struct EnsembleLearner { pub models: Vec, } impl EnsembleLearner { - // Generates prediction iterator returning predictions from each model - pub fn generate_predictions<'b, R: Records, T>(&'b self, x: &'b R) -> impl Iterator + 'b - where M: Predict<&'b R, T> { + pub fn generate_predictions<'b, R: Records, T>( + &'b self, + x: &'b R, + ) -> impl Iterator + 'b + where + M: Predict<&'b R, T>, + { self.models.iter().map(move |m| m.predict(x)) } // Consumes prediction iterator to return all predictions - pub fn aggregate_predictions(&self, ys: Ys) - -> impl Iterator::Elem, <::Ix as Dimension>::Smaller >, usize)>> + pub fn aggregate_predictions( + &self, + ys: Ys, + ) -> impl Iterator< + Item = Vec<( + Array< + ::Elem, + <::Ix as Dimension>::Smaller, + >, + usize, + )>, + > where Ys::Item: AsTargets, ::Elem: Copy + Eq + Hash, @@ -44,7 +52,9 @@ impl EnsembleLearner { if prediction_maps.len() == i { prediction_maps.push(HashMap::new()); } - *prediction_maps[i].entry(y.as_targets().index_axis(Axis(0), i).to_owned()).or_insert(0) += 1; + *prediction_maps[i] + .entry(y.as_targets().index_axis(Axis(0), i).to_owned()) + .or_insert(0) += 1; } } @@ -54,11 +64,9 @@ impl EnsembleLearner { xs }) } - } -impl -PredictInplace, T> for EnsembleLearner +impl PredictInplace, T> for EnsembleLearner where M: PredictInplace, T>, ::Elem: Copy + Eq + Hash, @@ -75,7 +83,10 @@ where let mut predictions = self.generate_predictions(x); let aggregated_predictions = self.aggregate_predictions(&mut predictions); - for (target, output) in y_array.axis_iter_mut(Axis(0)).zip(aggregated_predictions.into_iter()) { + for (target, output) in y_array + .axis_iter_mut(Axis(0)) + .zip(aggregated_predictions.into_iter()) + { for (t, o) in target.into_iter().zip(output[0].0.iter()) { *t = *o; } @@ -87,45 +98,72 @@ where } } -pub struct EnsembleLearnerParams { +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct EnsembleLearnerValidParams { pub ensemble_size: usize, pub bootstrap_proportion: f64, pub model_params: P, - pub rng: R + pub rng: R, } +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct EnsembleLearnerParams(EnsembleLearnerValidParams); + impl

EnsembleLearnerParams { pub fn new(model_params: P) -> EnsembleLearnerParams { - return Self::new_fixed_rng(model_params, rand::thread_rng()) + return Self::new_fixed_rng(model_params, rand::thread_rng()); } } impl EnsembleLearnerParams { pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams { - EnsembleLearnerParams { + Self(EnsembleLearnerValidParams { ensemble_size: 1, bootstrap_proportion: 1.0, model_params: model_params, - rng: rng - } + rng: rng, + }) } - pub fn ensemble_size(&mut self, size: usize) -> &mut EnsembleLearnerParams { - assert!(size > 0, "ensemble_size cannot be less than 1. Ensembles must consist of at least one model."); - self.ensemble_size = size; + pub fn ensemble_size(mut self, size: usize) -> Self { + self.0.ensemble_size = size; self } - pub fn bootstrap_proportion(&mut self, proportion: f64) -> &mut EnsembleLearnerParams { - assert!(proportion > 0.0, "bootstrap_proportion must be greater than 0. Must provide some data to each model."); - self.bootstrap_proportion = proportion; + pub fn bootstrap_proportion(mut self, proportion: f64) -> Self { + self.0.bootstrap_proportion = proportion; self } +} + +impl ParamGuard for EnsembleLearnerParams { + type Checked = EnsembleLearnerValidParams; + type Error = Error; + + fn check_ref(&self) -> Result<&Self::Checked> { + if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 { + Err(Error::Parameters(format!( + "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}", + self.0.bootstrap_proportion + ))) + } else if self.0.ensemble_size < 1 { + Err(Error::Parameters(format!( + "Ensemble size should be less than one, but was {}", + self.0.ensemble_size + ))) + } else { + Ok(&self.0) + } + } + fn check(self) -> Result { + self.check_ref()?; + Ok(self.0) + } } -impl, T::Owned, Error>, R: Rng + Clone> - Fit, T, Error> for EnsembleLearnerParams +impl, T::Owned, Error>, R: Rng + Clone> Fit, T, Error> + for EnsembleLearnerValidParams where D: Clone, T: FromTargetArrayOwned, @@ -134,12 +172,15 @@ where { type Object = EnsembleLearner; - fn fit(&self, dataset: &DatasetBase, T>) -> Result { - + fn fit( + &self, + dataset: &DatasetBase, T>, + ) -> core::result::Result { let mut models = Vec::new(); let mut rng = self.rng.clone(); - let dataset_size = ((dataset.records.shape()[0] as f64) * self.bootstrap_proportion).ceil() as usize; + let dataset_size = + ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize; let iter = dataset.bootstrap_samples(dataset_size, &mut rng); @@ -148,7 +189,7 @@ where models.push(model); if models.len() == self.ensemble_size { - break + break; } }