Skip to content

Commit 72d67e3

Browse files
committed
add TimeSeriesScaleMeanMaxVariance
1 parent 87495c1 commit 72d67e3

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

tslearn/preprocessing/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
from .preprocessing import (
77
TimeSeriesScalerMeanVariance,
88
TimeSeriesScalerMinMax,
9-
TimeSeriesResampler
9+
TimeSeriesResampler,
10+
TimeSeriesScaleMeanMaxVariance
1011
)
1112

1213
__all__ = [
1314
"TimeSeriesResampler",
1415
"TimeSeriesScalerMinMax",
15-
"TimeSeriesScalerMeanVariance"
16+
"TimeSeriesScalerMeanVariance",
17+
"TimeSeriesScaleMeanMaxVariance",
1618
]

tslearn/preprocessing/preprocessing.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,91 @@ def transform(self, X, y=None, **kwargs):
296296

297297
def _more_tags(self):
298298
return {'allow_nan': True}
299+
300+
301+
class TimeSeriesScaleMeanMaxVariance(TransformerMixin, TimeSeriesBaseEstimator):
302+
"""Scaler for time series. Scales time series so that their mean (resp.
303+
standard deviation) in the signal with the max amplitue is
304+
mu (resp. std). The scaling relationships between each signal are preserved
305+
This is supplement to the TimeSeriesScalerMeanVariance method
306+
307+
Parameters
308+
----------
309+
mu : float (default: 0.)
310+
Mean of the output time series.
311+
std : float (default: 1.)
312+
Standard deviation of the output time series.
313+
314+
Notes
315+
-----
316+
This method requires a dataset of equal-sized time series.
317+
318+
NaNs within a time series are ignored when calculating mu and std.
319+
"""
320+
321+
def __init__(self, mu=0., std=1.):
322+
self.mu = mu
323+
self.std = std
324+
325+
def fit(self, X, y=None, **kwargs):
326+
"""A dummy method such that it complies to the sklearn requirements.
327+
Since this method is completely stateless, it just returns itself.
328+
329+
Parameters
330+
----------
331+
X
332+
Ignored
333+
334+
Returns
335+
-------
336+
self
337+
"""
338+
X = check_array(X, allow_nd=True, force_all_finite=False)
339+
X = to_time_series_dataset(X)
340+
self._X_fit_dims = X.shape
341+
return self
342+
343+
def fit_transform(self, X, y=None, **kwargs):
344+
"""Fit to data, then transform it.
345+
346+
Parameters
347+
----------
348+
X : array-like of shape (n_ts, sz, d)
349+
Time series dataset to be rescaled.
350+
351+
Returns
352+
-------
353+
numpy.ndarray
354+
Resampled time series dataset.
355+
"""
356+
return self.fit(X).transform(X)
357+
358+
def transform(self, X, y=None, **kwargs):
359+
"""Fit to data, then transform it.
360+
361+
Parameters
362+
----------
363+
X : array-like of shape (n_ts, sz, d)
364+
Time series dataset to be rescaled
365+
366+
Returns
367+
-------
368+
numpy.ndarray
369+
Rescaled time series dataset
370+
"""
371+
check_is_fitted(self, '_X_fit_dims')
372+
X = check_array(X, allow_nd=True, force_all_finite=False)
373+
X_ = to_time_series_dataset(X)
374+
X_ = check_dims(X_, X_fit_dims=self._X_fit_dims, extend=False)
375+
mean_t = numpy.nanmean(X_, axis=1, keepdims=True)
376+
std_t = numpy.nanstd(X_, axis=1, keepdims=True)
377+
# retain the scaling relation cross the signals,
378+
# the max std_t is set to self.std
379+
max_std = max(std_t)
380+
if max_std ==0.: max_std = 1
381+
X_ = (X_ - mean_t) * self.std / max_std + self.mu
382+
383+
return X_
384+
385+
def _more_tags(self):
386+
return {'allow_nan': True}

0 commit comments

Comments
 (0)