Skip to content

Commit

Permalink
fix: really fix filtering of NaNs this time (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Jan 7, 2025
1 parent af2a921 commit 570ec3f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
2 changes: 1 addition & 1 deletion crates/augurs-prophet/src/prophet/prep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ pub(super) struct Features {

impl<O> Prophet<O> {
pub(super) fn preprocess(&mut self, data: TrainingData) -> Result<Preprocessed, Error> {
let data = data.filter_nans();
let n = data.ds.len();
if n != data.y.len() {
return Err(Error::MismatchedLengths {
Expand All @@ -207,7 +208,6 @@ impl<O> Prophet<O> {
if n < 2 {
return Err(Error::NotEnoughData);
}
let data = data.filter_nans();

let mut history_dates = data.ds.clone();
history_dates.sort_unstable();
Expand Down
30 changes: 29 additions & 1 deletion crates/augurs-prophet/tests/wasmstan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use augurs_prophet::{Prophet, ProphetOptions};

#[test]
fn wasmstan() {
tracing_subscriber::fmt::init();
tracing_subscriber::fmt::try_init().ok();
let opts = ProphetOptions::default();
let opt = WasmstanOptimizer::new();
let mut prophet = Prophet::new(opts, opt);
Expand All @@ -31,6 +31,34 @@ fn wasmstan() {
assert_all_close(&predictions.yhat.point, EXPECTED);
}

#[test]
fn wasmstan_nans() {
tracing_subscriber::fmt::try_init().ok();
let opts = ProphetOptions::default();
let opt = WasmstanOptimizer::new();
let mut prophet = Prophet::new(opts, opt);
let mut y = TRAINING_Y.to_vec();
y[100] = f64::NAN;
y[200] = f64::NAN;
y[300] = f64::NAN;
let training_data = TrainingData::new(TRAINING_DS.to_vec(), y).unwrap();
tracing::info!("fitting");
prophet
.fit(
training_data,
OptimizeOpts {
seed: Some(100),
..Default::default()
},
)
.unwrap();
let prediction_data = PredictionData::new(PREDICTION_DS.to_vec());
tracing::info!("predicting");
let predictions = prophet.predict(Some(prediction_data)).unwrap();
tracing::info!("done");
assert_all_close(&predictions.yhat.point, EXPECTED);
}

static TRAINING_DS: &[TimestampSeconds] = &[
1727168400, 1727169600, 1727170800, 1727172000, 1727173200, 1727174400, 1727175600, 1727176800,
1727178000, 1727179200, 1727180400, 1727181600, 1727182800, 1727184000, 1727185200, 1727186400,
Expand Down

0 comments on commit 570ec3f

Please sign in to comment.