Skip to content

Commit

Permalink
Merge pull request #237 from grafana/js-expose-minmaxscaler
Browse files Browse the repository at this point in the history
feat(js): expose `MinMaxScaler` transform
  • Loading branch information
yoziru authored Jan 14, 2025
2 parents 22be0e3 + 1357c31 commit 6291626
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 47 deletions.
5 changes: 4 additions & 1 deletion js/augurs-transforms-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tsify_next::Tsify;
use wasm_bindgen::prelude::*;

use augurs_core_js::VecF64;
use augurs_forecaster::transforms::{StandardScaler, Transformer, YeoJohnson};
use augurs_forecaster::transforms::{MinMaxScaler, StandardScaler, Transformer, YeoJohnson};

/// A transformation to be applied to the data.
///
Expand All @@ -14,6 +14,8 @@ use augurs_forecaster::transforms::{StandardScaler, Transformer, YeoJohnson};
#[serde(rename_all = "camelCase", tag = "type")]
#[tsify(from_wasm_abi)]
pub enum Transform {
/// Scale the data to the range [0, 1].
MinMaxScaler,
/// Standardize the data such that it has zero mean and unit variance.
StandardScaler {
/// Whether to ignore NaNs.
Expand All @@ -31,6 +33,7 @@ pub enum Transform {
impl Transform {
fn into_transformer(self) -> Box<dyn Transformer> {
match self {
Transform::MinMaxScaler => Box::new(MinMaxScaler::new()),
Transform::StandardScaler { ignore_nans } => {
Box::new(StandardScaler::new().ignore_nans(ignore_nans))
}
Expand Down
89 changes: 43 additions & 46 deletions js/testpkg/transforms.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,55 +40,52 @@ describe('transforms', () => {
}
});

describe('pipeline', () => {
it('works with arrays', () => {
const pt = new Pipeline([{ type: "standardScaler" }, { type: "yeoJohnson" }]);
const transformed = pt.fitTransform(y);
expect(transformed).toBeInstanceOf(Float64Array);
expect(transformed).toHaveLength(y.length);
const inverse = pt.inverseTransform(transformed);
expect(inverse).toBeInstanceOf(Float64Array);
expect(inverse).toHaveLength(y.length);
//@ts-ignore
expect(Array.from(inverse)).toAllBeCloseTo(y);
});

it('handles empty pipeline', () => {
const pt = new Pipeline([]);
expect(() => pt.fitTransform(y)).not.toThrow();
});
const testScalerPipeline = (
description: string,
inputData: number[],
scalerType: "standardScaler" | "minMaxScaler",
ignoreNaNs = false
) => {
describe(description, () => {
it(`works with arrays ${scalerType}`, () => {
const pt = new Pipeline([
scalerType === "standardScaler"
? { type: scalerType, ignoreNaNs }
: { type: scalerType },
{ type: "yeoJohnson", ignoreNaNs }
]);
const transformed = pt.fitTransform(inputData);
expect(transformed).toBeInstanceOf(Float64Array);
expect(transformed).toHaveLength(inputData.length);

it('handles invalid transforms', () => {
// @ts-ignore
expect(() => new Pipeline(["invalidTransform"])).toThrow();
});
});
const inverse = pt.inverseTransform(transformed);
expect(inverse).toBeInstanceOf(Float64Array);
expect(inverse).toHaveLength(inputData.length);
//@ts-ignore
expect(Array.from(inverse)).toAllBeCloseTo(inputData);
});

describe('pipeline with nans', () => {
const yWithNaNs = [...y];
yWithNaNs[10] = NaN;
yWithNaNs[20] = NaN;
it('handles empty pipeline', () => {
const pt = new Pipeline([]);
expect(() => pt.fitTransform(inputData)).not.toThrow();
});

it('works with arrays', () => {
const pt = new Pipeline([{ type: "standardScaler", ignoreNaNs: true }, { type: "yeoJohnson", ignoreNaNs: true }]);
const transformed = pt.fitTransform(yWithNaNs);
expect(transformed).toBeInstanceOf(Float64Array);
expect(transformed).toHaveLength(yWithNaNs.length);
const inverse = pt.inverseTransform(transformed);
expect(inverse).toBeInstanceOf(Float64Array);
expect(inverse).toHaveLength(yWithNaNs.length);
//@ts-ignore
expect(Array.from(inverse)).toAllBeCloseTo(yWithNaNs);
it('handles invalid transforms', () => {
// @ts-ignore
expect(() => new Pipeline(["invalidTransform"])).toThrow();
});
});
};

it('handles empty pipeline', () => {
const pt = new Pipeline([]);
expect(() => pt.fitTransform(yWithNaNs)).not.toThrow();
});
// Test regular pipeline
testScalerPipeline('pipeline with standard scaler', y, 'standardScaler');
testScalerPipeline('pipeline with minmax scaler', y, 'minMaxScaler');

it('handles invalid transforms', () => {
// @ts-ignore
expect(() => new Pipeline(["invalidTransform"])).toThrow();
});
});
})
// Test pipeline with NaNs
const yWithNaNs = [...y];
yWithNaNs[10] = NaN;
yWithNaNs[20] = NaN;

testScalerPipeline('pipeline with NaNs - standard scaler', yWithNaNs, 'standardScaler', true);
testScalerPipeline('pipeline with NaNs - minmax scaler', yWithNaNs, 'minMaxScaler', true);
});

0 comments on commit 6291626

Please sign in to comment.