Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework TSDataset.train_test_split to pass all features to train and test parts #545

Merged
merged 10 commits into from
Jan 13, 2025
18 changes: 14 additions & 4 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,9 @@ def train_test_split(

In case of inconsistencies between ``test_size`` and (``test_start``, ``test_end``), ``test_size`` is ignored

During splitting, all features are kept in both the train and test parts, including the target,
regressors, target components, and prediction intervals.

Parameters
----------
train_start:
Expand Down Expand Up @@ -1210,29 +1213,36 @@ def train_test_split(
if train_start_defined < self.df.index.min():
warnings.warn(f"Min timestamp in df is {self.df.index.min()}.")

train_df = self.df.loc[train_start_defined:train_end_defined][self.raw_df.columns] # type: ignore
# TODO: why do we use self.raw_df.columns instead of self.df.columns? Needs to be discussed.
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
train_df_init = self.df.loc[train_start_defined:train_end_defined][self.raw_df.columns] # type: ignore
train_df = self.df.loc[train_start_defined:train_end_defined][self.df.columns] # type: ignore
train_raw_df = self.raw_df.loc[train_start_defined:train_end_defined] # type: ignore
train = TSDataset(
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
df=train_df,
df=train_df_init,
df_exog=self.df_exog,
freq=self.freq,
known_future=self.known_future,
hierarchical_structure=self.hierarchical_structure,
)
train.df = train_df
train.raw_df = train_raw_df
train._regressors = deepcopy(self.regressors)
train._target_components_names = deepcopy(self.target_components_names)
train._prediction_intervals_names = deepcopy(self._prediction_intervals_names)

test_df = self.df.loc[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore
# TODO: why do we use self.raw_df.columns instead of self.df.columns? Needs to be discussed.
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
test_df_init = self.df.loc[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore
test_df = self.df.loc[test_start_defined:test_end_defined][self.df.columns] # type: ignore
# TODO: why does the test raw_df slice start from train_start_defined here? Needs to be discussed.
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
test_raw_df = self.raw_df.loc[train_start_defined:test_end_defined] # type: ignore
test = TSDataset(
df=test_df,
df=test_df_init,
df_exog=self.df_exog,
freq=self.freq,
known_future=self.known_future,
hierarchical_structure=self.hierarchical_structure,
)
test.df = test_df
test.raw_df = test_raw_df
test._regressors = deepcopy(self.regressors)
test._target_components_names = deepcopy(self.target_components_names)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from etna.datasets.utils import make_timestamp_df_from_alignment
from etna.transforms import AddConstTransform
from etna.transforms import DifferencingTransform
from etna.transforms import LagTransform
from etna.transforms import TimeSeriesImputerTransform


Expand Down Expand Up @@ -956,18 +957,35 @@ def test_train_test_split_pass_regressors_to_output(df_and_regressors):
df, df_exog, known_future = df_and_regressors
ts = TSDataset(df=df, df_exog=df_exog, freq="D", known_future=known_future)
train, test = ts.train_test_split(test_size=5)
assert set(train.regressors).issubset(set(train.features))
assert set(test.regressors).issubset(set(test.features))
assert train.regressors == ts.regressors
assert test.regressors == ts.regressors


def test_train_test_split_pass_transform_regressors_to_output(df_and_regressors):
    """Check regressors of both split parts after a ``LagTransform`` was applied.

    The dataset is built from the ``df_and_regressors`` fixture, transformed with
    ``LagTransform`` on the ``target`` column and then split with ``test_size=5``.
    For each part the test verifies that:

    * every name in ``regressors`` is also present in ``features``;
    * ``regressors`` is equal to the ``regressors`` of the source dataset.
    """
    target_df, exog_df, future_known = df_and_regressors
    dataset = TSDataset(df=target_df, df_exog=exog_df, freq="D", known_future=future_known)
    dataset.fit_transform(transforms=[LagTransform(in_column="target", lags=[1, 2, 3])])

    parts = dataset.train_test_split(test_size=5)

    # First make sure the reported regressors really exist as columns of each part ...
    for part in parts:
        assert set(part.regressors).issubset(set(part.features))

    # ... then make sure the regressor list itself was carried over unchanged.
    for part in parts:
        assert part.regressors == dataset.regressors


def test_train_test_split_pass_target_components_to_output(ts_with_target_components):
train, test = ts_with_target_components.train_test_split(test_size=5)
assert set(train.target_components_names).issubset(set(train.features))
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
assert set(test.target_components_names).issubset(set(test.features))
assert sorted(train.target_components_names) == sorted(ts_with_target_components.target_components_names)
assert sorted(test.target_components_names) == sorted(ts_with_target_components.target_components_names)


def test_train_test_split_pass_prediction_intervals_to_output(ts_with_prediction_intervals):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to add a similar test for hierarchy, testing that structure is copied and current levels are preserved. Seems like we lack this test in general.

train, test = ts_with_prediction_intervals.train_test_split(test_size=5)
assert set(train.prediction_intervals_names).issubset(set(train.features))
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
assert set(test.prediction_intervals_names).issubset(set(test.features))
assert sorted(train.prediction_intervals_names) == sorted(ts_with_prediction_intervals.prediction_intervals_names)
assert sorted(test.prediction_intervals_names) == sorted(ts_with_prediction_intervals.prediction_intervals_names)

Expand Down
Loading