Skip to content

Commit edd718d

Browse files
authored
[ENH] refactor test data scenario generation to tests._data_scenarios (#1877)
This PR moves all test data scenario generation to `tests._data_scenarios`. It also deduplicates current data scenario code.
1 parent 01f97c4 commit edd718d

File tree

5 files changed

+32
-316
lines changed

5 files changed

+32
-316
lines changed

pytorch_forecasting/models/mlp/_decodermlp_metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ def _get_test_dataloaders_from(cls, params):
8585
"""
8686
data_loader_kwargs = params.get("data_loader_kwargs", {})
8787

88-
from pytorch_forecasting.tests._conftest import (
89-
_data_with_covariates,
88+
from pytorch_forecasting.tests._data_scenarios import (
89+
data_with_covariates,
9090
make_dataloaders,
9191
)
9292

93-
dwc = _data_with_covariates()
93+
dwc = data_with_covariates()
9494
dwc.assign(target=lambda x: x.volume)
9595
dl_default_kwargs = dict(
9696
target="target",

pytorch_forecasting/models/nbeats/_nbeats_metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def _get_test_dataloaders_from(cls, params):
5454
Dict of dataloaders created from the parameters.
5555
Train, validation, and test dataloaders, in this order.
5656
"""
57-
from pytorch_forecasting.tests._conftest import (
58-
_dataloaders_fixed_window_without_covariates,
57+
from pytorch_forecasting.tests._data_scenarios import (
58+
dataloaders_fixed_window_without_covariates,
5959
)
6060

61-
return _dataloaders_fixed_window_without_covariates()
61+
return dataloaders_fixed_window_without_covariates()

pytorch_forecasting/tests/_conftest.py

Lines changed: 10 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -19,47 +19,11 @@ def gpus():
1919

2020
@pytest.fixture(scope="session")
2121
def data_with_covariates():
22-
return _data_with_covariates()
23-
24-
25-
def _data_with_covariates():
26-
data = get_stallion_data()
27-
data["month"] = data.date.dt.month.astype(str)
28-
data["log_volume"] = np.log1p(data.volume)
29-
data["weight"] = 1 + np.sqrt(data.volume)
30-
31-
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
32-
data["time_idx"] -= data["time_idx"].min()
33-
34-
# convert special days into strings
35-
special_days = [
36-
"easter_day",
37-
"good_friday",
38-
"new_year",
39-
"christmas",
40-
"labor_day",
41-
"independence_day",
42-
"revolution_day_memorial",
43-
"regional_games",
44-
"fifa_u_17_world_cup",
45-
"football_gold_cup",
46-
"beer_capital",
47-
"music_fest",
48-
]
49-
data[special_days] = (
50-
data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
22+
from pytorch_forecasting.tests._data_scenarios import (
23+
data_with_covariates as _data_with_covariates,
5124
)
52-
data = data.astype(dict(industry_volume=float))
53-
54-
# select data subset
55-
data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][
56-
lambda x: x.agency.isin(data.agency.unique()[:2])
57-
]
5825

59-
# default target
60-
data["target"] = data["volume"].clip(1e-3, 1.0)
61-
62-
return data
26+
return _data_with_covariates()
6327

6428

6529
def make_dataloaders(data_with_covariates, **kwargs):
@@ -161,47 +125,12 @@ def multiple_dataloaders_with_covariates(data_with_covariates, request):
161125

162126
@pytest.fixture(scope="session")
163127
def dataloaders_with_different_encoder_decoder_length(data_with_covariates):
164-
return make_dataloaders(
165-
data_with_covariates.copy(),
166-
target="target",
167-
time_varying_known_categoricals=["special_days", "month"],
168-
variable_groups=dict(
169-
special_days=[
170-
"easter_day",
171-
"good_friday",
172-
"new_year",
173-
"christmas",
174-
"labor_day",
175-
"independence_day",
176-
"revolution_day_memorial",
177-
"regional_games",
178-
"fifa_u_17_world_cup",
179-
"football_gold_cup",
180-
"beer_capital",
181-
"music_fest",
182-
]
183-
),
184-
time_varying_known_reals=[
185-
"time_idx",
186-
"price_regular",
187-
"price_actual",
188-
"discount",
189-
"discount_in_percent",
190-
],
191-
time_varying_unknown_categoricals=[],
192-
time_varying_unknown_reals=[
193-
"target",
194-
"volume",
195-
"log_volume",
196-
"industry_volume",
197-
"soda_volume",
198-
"avg_max_temp",
199-
],
200-
static_categoricals=["agency"],
201-
add_relative_time_idx=False,
202-
target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False),
128+
from pytorch_forecasting.tests._data_scenarios import (
129+
dataloaders_with_different_encoder_decoder_length as _dataloader,
203130
)
204131

132+
return _dataloader()
133+
205134

206135
@pytest.fixture(scope="session")
207136
def dataloaders_with_covariates(data_with_covariates):
@@ -228,43 +157,8 @@ def dataloaders_multi_target(data_with_covariates):
228157

229158
@pytest.fixture(scope="session")
230159
def dataloaders_fixed_window_without_covariates():
231-
return _dataloaders_fixed_window_without_covariates()
232-
233-
234-
def _dataloaders_fixed_window_without_covariates():
235-
data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2)
236-
validation = data.series.iloc[:2]
237-
238-
max_encoder_length = 30
239-
max_prediction_length = 10
240-
241-
training = TimeSeriesDataSet(
242-
data[lambda x: ~x.series.isin(validation)],
243-
time_idx="time_idx",
244-
target="value",
245-
categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
246-
group_ids=["series"],
247-
static_categoricals=[],
248-
max_encoder_length=max_encoder_length,
249-
max_prediction_length=max_prediction_length,
250-
time_varying_unknown_reals=["value"],
251-
target_normalizer=EncoderNormalizer(),
160+
from pytorch_forecasting.tests._data_scenarios import (
161+
dataloaders_fixed_window_without_covariates as _dataloader,
252162
)
253163

254-
validation = TimeSeriesDataSet.from_dataset(
255-
training,
256-
data[lambda x: x.series.isin(validation)],
257-
stop_randomization=True,
258-
)
259-
batch_size = 2
260-
train_dataloader = training.to_dataloader(
261-
train=True, batch_size=batch_size, num_workers=0
262-
)
263-
val_dataloader = validation.to_dataloader(
264-
train=False, batch_size=batch_size, num_workers=0
265-
)
266-
test_dataloader = validation.to_dataloader(
267-
train=False, batch_size=batch_size, num_workers=0
268-
)
269-
270-
return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader)
164+
return _dataloader()

pytorch_forecasting/tests/_data_scenarios.py

Lines changed: 6 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import pytest
32
import torch
43

54
from pytorch_forecasting import TimeSeriesDataSet
@@ -9,14 +8,6 @@
98
torch.manual_seed(23)
109

1110

12-
@pytest.fixture(scope="session")
13-
def gpus():
14-
if torch.cuda.is_available():
15-
return [0]
16-
else:
17-
return 0
18-
19-
2011
def data_with_covariates():
2112
data = get_stallion_data()
2213
data["month"] = data.date.dt.month.astype(str)
@@ -87,77 +78,9 @@ def make_dataloaders(data_with_covariates, **kwargs):
8778
return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader)
8879

8980

90-
@pytest.fixture(
91-
params=[
92-
dict(),
93-
dict(
94-
static_categoricals=["agency", "sku"],
95-
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
96-
time_varying_known_categoricals=["special_days", "month"],
97-
variable_groups=dict(
98-
special_days=[
99-
"easter_day",
100-
"good_friday",
101-
"new_year",
102-
"christmas",
103-
"labor_day",
104-
"independence_day",
105-
"revolution_day_memorial",
106-
"regional_games",
107-
"fifa_u_17_world_cup",
108-
"football_gold_cup",
109-
"beer_capital",
110-
"music_fest",
111-
]
112-
),
113-
time_varying_known_reals=[
114-
"time_idx",
115-
"price_regular",
116-
"price_actual",
117-
"discount",
118-
"discount_in_percent",
119-
],
120-
time_varying_unknown_categoricals=[],
121-
time_varying_unknown_reals=[
122-
"volume",
123-
"log_volume",
124-
"industry_volume",
125-
"soda_volume",
126-
"avg_max_temp",
127-
],
128-
constant_fill_strategy={"volume": 0},
129-
categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)},
130-
),
131-
dict(static_categoricals=["agency", "sku"]),
132-
dict(randomize_length=True, min_encoder_length=2),
133-
dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2),
134-
dict(target_normalizer=GroupNormalizer(transformation="log1p")),
135-
dict(
136-
target_normalizer=GroupNormalizer(
137-
groups=["agency", "sku"], transformation="softplus", center=False
138-
)
139-
),
140-
dict(target="agency"),
141-
# test multiple targets
142-
dict(target=["industry_volume", "volume"]),
143-
dict(target=["agency", "volume"]),
144-
dict(
145-
target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1
146-
),
147-
dict(target=["agency", "volume"], weight="volume"),
148-
# test weights
149-
dict(target="volume", weight="volume"),
150-
],
151-
scope="session",
152-
)
153-
def multiple_dataloaders_with_covariates(data_with_covariates, request):
154-
return make_dataloaders(data_with_covariates, **request.param)
155-
156-
157-
@pytest.fixture(scope="session")
158-
def dataloaders_with_different_encoder_decoder_length(data_with_covariates):
81+
def dataloaders_with_different_encoder_decoder_length():
15982
return make_dataloaders(
160-
data_with_covariates.copy(),
83+
data_with_covariates(),
16184
target="target",
16285
time_varying_known_categoricals=["special_days", "month"],
16386
variable_groups=dict(
@@ -198,10 +121,9 @@ def dataloaders_with_different_encoder_decoder_length(data_with_covariates):
198121
)
199122

200123

201-
@pytest.fixture(scope="session")
202-
def dataloaders_with_covariates(data_with_covariates):
124+
def dataloaders_with_covariates():
203125
return make_dataloaders(
204-
data_with_covariates.copy(),
126+
data_with_covariates(),
205127
target="target",
206128
time_varying_known_reals=["discount"],
207129
time_varying_unknown_reals=["target"],
@@ -211,17 +133,15 @@ def dataloaders_with_covariates(data_with_covariates):
211133
)
212134

213135

214-
@pytest.fixture(scope="session")
215-
def dataloaders_multi_target(data_with_covariates):
136+
def dataloaders_multi_target():
216137
return make_dataloaders(
217-
data_with_covariates.copy(),
138+
data_with_covariates(),
218139
time_varying_unknown_reals=["target", "discount"],
219140
target=["target", "discount"],
220141
add_relative_time_idx=False,
221142
)
222143

223144

224-
@pytest.fixture(scope="session")
225145
def dataloaders_fixed_window_without_covariates():
226146
data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2)
227147
validation = data.series.iloc[:2]

0 commit comments

Comments
 (0)