From 40aab6d2344af7fa195472e124f61bc1d6f6db32 Mon Sep 17 00:00:00 2001 From: Ivan Slijepcevic Date: Thu, 17 Oct 2024 08:59:40 -0700 Subject: [PATCH] 3/n migrate interpolate from param base to param origin Summary: Fixing time series interpolation, where pandas `resample` has changed the API. Keeping the old logic for now. New logic uses `origin="start_day"` for backward compatibility. In future, we should remove the old code using `base` param, and move the default of `origin="start"`. Differential Revision: D64365553 fbshipit-source-id: 331efff1f20499301553452ebde4ff93c6e0caeb --- kats/consts.py | 53 ++++++++++++++++++++++++++++++++++++--- kats/tests/test_consts.py | 6 +++-- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/kats/consts.py b/kats/consts.py index f53844d6..7b09d715 100644 --- a/kats/consts.py +++ b/kats/consts.py @@ -26,7 +26,7 @@ import logging from collections.abc import Iterable from enum import auto, Enum, unique -from typing import Any, cast, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Union import dateutil import matplotlib.pyplot as plt @@ -38,6 +38,13 @@ from pandas.tseries.frequencies import to_offset FigSize = Tuple[int, int] +INTERPOLATION_METHOD_TYPE = ( + Literal["higher"] + | Literal["linear"] + | Literal["lower"] + | Literal["midpoint"] + | Literal["nearest"] +) # Constants @@ -946,6 +953,7 @@ def interpolate( self, freq: Optional[Union[str, pd.Timedelta]] = None, base: int = 0, + origin: pd.Timestamp | str = "start_day", method: str = "linear", remove_duplicate_time: bool = False, **kwargs: Any, @@ -968,6 +976,12 @@ def interpolate( See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html Note that base will be deprecated since version 1.1.0. The new arguments that you should use are ‘offset’ or ‘origin’. + origin: base argument for resample(). + See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html + Origin "start_day" is chosen for backward compatibility with `base=0`. + When non-default `base` is detected, `origin` will be set to "start". + Future versions of Kats will deprecate `base` and use `origin` instead, + defaulting to "start". method: A string representing the method to impute the missing time and data. See the above options (default "linear"). remove_duplicate_index: A boolean to auto-remove any duplicate time @@ -978,7 +992,6 @@ def interpolate( Returns: A new :class:`TimeSeriesData` object with interpolated data. """ - if not freq: freq = self.infer_freq_robust() @@ -1002,6 +1015,10 @@ def interpolate( if remove_duplicate_time: df = df[~df.index.duplicated()] + if pd.__version__ >= "1.1": + origin = origin if base == 0 else "start" + return self._interpolate_new(df, freq, origin, method, **kwargs) + if method == "linear": df = df.resample(rule=freq, base=base).interpolate(method="linear") @@ -1011,7 +1028,37 @@ def interpolate( elif method == "bfill": df = df.resample(rule=freq, base=base).bfill() else: - df = df.resample(rule=freq, base=base).interpolate(method=method, **kwargs) + df = df.resample(rule=freq, base=base).interpolate( + method=cast(INTERPOLATION_METHOD_TYPE, method), **kwargs + ) + + df = df.reset_index().rename(columns={"index": self.time_col_name}) + return TimeSeriesData(df, time_col_name=self.time_col_name) + + def _interpolate_new( + self, + df: pd.DataFrame, + freq: Optional[Union[str, pd.Timedelta]], + origin: pd.Timestamp | str, + method: str, + **kwargs: Any, + ) -> TimeSeriesData: + if method == "linear": + # pyre-ignore + df = df.resample(rule=freq, origin=origin).interpolate(method="linear") + + elif method == "ffill": + # pyre-ignore + df = df.resample(rule=freq, origin=origin).ffill() + + elif method == "bfill": + # pyre-ignore + df = df.resample(rule=freq, origin=origin).bfill() + else: + # pyre-ignore + df = df.resample(rule=freq, origin=origin).interpolate( + method=cast(INTERPOLATION_METHOD_TYPE, method), **kwargs + ) df = df.reset_index().rename(columns={"index": self.time_col_name}) return TimeSeriesData(df, time_col_name=self.time_col_name) diff --git a/kats/tests/test_consts.py b/kats/tests/test_consts.py index a599d7e8..0a524dbf 100644 --- a/kats/tests/test_consts.py +++ b/kats/tests/test_consts.py @@ -865,15 +865,17 @@ def test_interpolate_base(self) -> None: # calculate frequency first frequency = str(int(ts0.infer_freq_robust().total_seconds())) + "s" - # without base value, interpolate won't work, will return all NaN + # Without base value, interpolate won't work, will return all NaN # this is because start time is not from "**:00:00" or "**:30:00" type. + # This is equivalent to origin="start_day" self.assertEqual( # pyre-fixme[16]: Optional type has no attribute `value`. ts0.interpolate(freq=frequency).to_dataframe().fillna(0).value.sum(), 0, ) - # with base value, will start from "**:59:59" ("**:00:00" - 1 sec) + # With base value, will start from "**:59:59" ("**:00:00" - 1 sec) # or "**:29:59" ("**:30:00" -1 sec). + # Here we default to origin="start" instead of origin="start_day", which works. self.assertEqual( ts0.interpolate(freq=frequency, base=-1).to_dataframe().isna().value.sum(), 0,