Skip to content

Commit

Permalink
3/n migrate interpolate from param base to param origin
Browse files Browse the repository at this point in the history
Summary:
Fixes time series interpolation to account for the changed pandas `resample` API.

Keeping the old logic for now.
New logic uses `origin="start_day"` for backward compatibility.

In the future, we should remove the old code that uses the `base` param and change the default to `origin="start"`.

Differential Revision: D64365553

fbshipit-source-id: 331efff1f20499301553452ebde4ff93c6e0caeb
  • Loading branch information
islijepcevic authored and facebook-github-bot committed Oct 17, 2024
1 parent f8d5e6a commit 40aab6d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
53 changes: 50 additions & 3 deletions kats/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import logging
from collections.abc import Iterable
from enum import auto, Enum, unique
from typing import Any, cast, Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Union

import dateutil
import matplotlib.pyplot as plt
Expand All @@ -38,6 +38,13 @@
from pandas.tseries.frequencies import to_offset

FigSize = Tuple[int, int]
# Closed set of method names used as the ``typing.cast`` target when a
# caller-supplied ``method`` string is forwarded to ``resample().interpolate()``.
# Written as one multi-value ``Literal`` (equivalent to the union of the
# single-value forms) so that no ``|`` is evaluated on typing objects at import time.
INTERPOLATION_METHOD_TYPE = Literal[
    "higher",
    "linear",
    "lower",
    "midpoint",
    "nearest",
]


# Constants
Expand Down Expand Up @@ -946,6 +953,7 @@ def interpolate(
self,
freq: Optional[Union[str, pd.Timedelta]] = None,
base: int = 0,
origin: pd.Timestamp | str = "start_day",
method: str = "linear",
remove_duplicate_time: bool = False,
**kwargs: Any,
Expand All @@ -968,6 +976,12 @@ def interpolate(
See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html
Note that base will be deprecated since version 1.1.0.
The new arguments that you should use are ‘offset’ or ‘origin’.
origin: origin argument for resample().
See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html
Origin "start_day" is chosen for backward compatibility with `base=0`.
When non-default `base` is detected, `origin` will be set to "start".
Future versions of Kats will deprecate `base` and use `origin` instead,
defaulting to "start".
method: A string representing the method to impute the missing time
and data. See the above options (default "linear").
remove_duplicate_index: A boolean to auto-remove any duplicate time
Expand All @@ -978,7 +992,6 @@ def interpolate(
Returns:
A new :class:`TimeSeriesData` object with interpolated data.
"""

if not freq:
freq = self.infer_freq_robust()

Expand All @@ -1002,6 +1015,10 @@ def interpolate(
if remove_duplicate_time:
df = df[~df.index.duplicated()]

if pd.__version__ >= "1.1":
origin = origin if base == 0 else "start"
return self._interpolate_new(df, freq, origin, method, **kwargs)

if method == "linear":
df = df.resample(rule=freq, base=base).interpolate(method="linear")

Expand All @@ -1011,7 +1028,37 @@ def interpolate(
elif method == "bfill":
df = df.resample(rule=freq, base=base).bfill()
else:
df = df.resample(rule=freq, base=base).interpolate(method=method, **kwargs)
df = df.resample(rule=freq, base=base).interpolate(
method=cast(INTERPOLATION_METHOD_TYPE, method), **kwargs
)

df = df.reset_index().rename(columns={"index": self.time_col_name})
return TimeSeriesData(df, time_col_name=self.time_col_name)

def _interpolate_new(
    self,
    df: pd.DataFrame,
    freq: Optional[Union[str, pd.Timedelta]],
    origin: pd.Timestamp | str,
    method: str,
    **kwargs: Any,
) -> TimeSeriesData:
    """Resample ``df`` using the ``origin`` argument and fill the new rows.

    Args:
        df: Data frame to resample.
        freq: Passed to ``resample`` as ``rule``.
        origin: Passed to ``resample`` as ``origin``.
        method: "linear", "ffill" or "bfill" select the matching fill;
            any other value is handed to ``interpolate`` as its ``method``.
        **kwargs: Forwarded to ``interpolate`` only when ``method`` is not
            one of "linear", "ffill" or "bfill".

    Returns:
        A new :class:`TimeSeriesData` object with interpolated data.
    """
    # Every branch resamples with the same arguments, so build it once.
    # pyre-ignore
    resampler = df.resample(rule=freq, origin=origin)

    if method == "ffill":
        filled = resampler.ffill()
    elif method == "bfill":
        filled = resampler.bfill()
    elif method == "linear":
        filled = resampler.interpolate(method="linear")
    else:
        filled = resampler.interpolate(
            method=cast(INTERPOLATION_METHOD_TYPE, method), **kwargs
        )

    # Turn the index back into a regular column named after the time column.
    result = filled.reset_index().rename(columns={"index": self.time_col_name})
    return TimeSeriesData(result, time_col_name=self.time_col_name)
Expand Down
6 changes: 4 additions & 2 deletions kats/tests/test_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,15 +865,17 @@ def test_interpolate_base(self) -> None:
# calculate frequency first
frequency = str(int(ts0.infer_freq_robust().total_seconds())) + "s"

# without base value, interpolate won't work, will return all NaN
# Without base value, interpolate won't work, will return all NaN
# this is because start time is not from "**:00:00" or "**:30:00" type.
# This is equivalent to origin="start_day"
self.assertEqual(
# pyre-fixme[16]: Optional type has no attribute `value`.
ts0.interpolate(freq=frequency).to_dataframe().fillna(0).value.sum(),
0,
)
# with base value, will start from "**:59:59" ("**:00:00" - 1 sec)
# With base value, will start from "**:59:59" ("**:00:00" - 1 sec)
# or "**:29:59" ("**:30:00" -1 sec).
# Here we default to origin="start" instead of origin="start_day", which works.
self.assertEqual(
ts0.interpolate(freq=frequency, base=-1).to_dataframe().isna().value.sum(),
0,
Expand Down

0 comments on commit 40aab6d

Please sign in to comment.