@@ -19,47 +19,11 @@ def gpus():
1919
2020@pytest .fixture (scope = "session" )
2121def data_with_covariates ():
22- return _data_with_covariates ()
23-
24-
25- def _data_with_covariates ():
26- data = get_stallion_data ()
27- data ["month" ] = data .date .dt .month .astype (str )
28- data ["log_volume" ] = np .log1p (data .volume )
29- data ["weight" ] = 1 + np .sqrt (data .volume )
30-
31- data ["time_idx" ] = data ["date" ].dt .year * 12 + data ["date" ].dt .month
32- data ["time_idx" ] -= data ["time_idx" ].min ()
33-
34- # convert special days into strings
35- special_days = [
36- "easter_day" ,
37- "good_friday" ,
38- "new_year" ,
39- "christmas" ,
40- "labor_day" ,
41- "independence_day" ,
42- "revolution_day_memorial" ,
43- "regional_games" ,
44- "fifa_u_17_world_cup" ,
45- "football_gold_cup" ,
46- "beer_capital" ,
47- "music_fest" ,
48- ]
49- data [special_days ] = (
50- data [special_days ].apply (lambda x : x .map ({0 : "" , 1 : x .name })).astype ("category" )
22+ from pytorch_forecasting .tests ._data_scenarios import (
23+ data_with_covariates as _data_with_covariates ,
5124 )
52- data = data .astype (dict (industry_volume = float ))
53-
54- # select data subset
55- data = data [lambda x : x .sku .isin (data .sku .unique ()[:2 ])][
56- lambda x : x .agency .isin (data .agency .unique ()[:2 ])
57- ]
5825
59- # default target
60- data ["target" ] = data ["volume" ].clip (1e-3 , 1.0 )
61-
62- return data
26+ return _data_with_covariates ()
6327
6428
6529def make_dataloaders (data_with_covariates , ** kwargs ):
@@ -161,47 +125,12 @@ def multiple_dataloaders_with_covariates(data_with_covariates, request):
161125
162126@pytest .fixture (scope = "session" )
163127def dataloaders_with_different_encoder_decoder_length (data_with_covariates ):
164- return make_dataloaders (
165- data_with_covariates .copy (),
166- target = "target" ,
167- time_varying_known_categoricals = ["special_days" , "month" ],
168- variable_groups = dict (
169- special_days = [
170- "easter_day" ,
171- "good_friday" ,
172- "new_year" ,
173- "christmas" ,
174- "labor_day" ,
175- "independence_day" ,
176- "revolution_day_memorial" ,
177- "regional_games" ,
178- "fifa_u_17_world_cup" ,
179- "football_gold_cup" ,
180- "beer_capital" ,
181- "music_fest" ,
182- ]
183- ),
184- time_varying_known_reals = [
185- "time_idx" ,
186- "price_regular" ,
187- "price_actual" ,
188- "discount" ,
189- "discount_in_percent" ,
190- ],
191- time_varying_unknown_categoricals = [],
192- time_varying_unknown_reals = [
193- "target" ,
194- "volume" ,
195- "log_volume" ,
196- "industry_volume" ,
197- "soda_volume" ,
198- "avg_max_temp" ,
199- ],
200- static_categoricals = ["agency" ],
201- add_relative_time_idx = False ,
202- target_normalizer = GroupNormalizer (groups = ["agency" , "sku" ], center = False ),
128+ from pytorch_forecasting .tests ._data_scenarios import (
129+ dataloaders_with_different_encoder_decoder_length as _dataloader ,
203130 )
204131
132+ return _dataloader ()
133+
205134
206135@pytest .fixture (scope = "session" )
207136def dataloaders_with_covariates (data_with_covariates ):
@@ -228,43 +157,8 @@ def dataloaders_multi_target(data_with_covariates):
228157
229158@pytest .fixture (scope = "session" )
230159def dataloaders_fixed_window_without_covariates ():
231- return _dataloaders_fixed_window_without_covariates ()
232-
233-
234- def _dataloaders_fixed_window_without_covariates ():
235- data = generate_ar_data (seasonality = 10.0 , timesteps = 50 , n_series = 2 )
236- validation = data .series .iloc [:2 ]
237-
238- max_encoder_length = 30
239- max_prediction_length = 10
240-
241- training = TimeSeriesDataSet (
242- data [lambda x : ~ x .series .isin (validation )],
243- time_idx = "time_idx" ,
244- target = "value" ,
245- categorical_encoders = {"series" : NaNLabelEncoder ().fit (data .series )},
246- group_ids = ["series" ],
247- static_categoricals = [],
248- max_encoder_length = max_encoder_length ,
249- max_prediction_length = max_prediction_length ,
250- time_varying_unknown_reals = ["value" ],
251- target_normalizer = EncoderNormalizer (),
160+ from pytorch_forecasting .tests ._data_scenarios import (
161+ dataloaders_fixed_window_without_covariates as _dataloader ,
252162 )
253163
254- validation = TimeSeriesDataSet .from_dataset (
255- training ,
256- data [lambda x : x .series .isin (validation )],
257- stop_randomization = True ,
258- )
259- batch_size = 2
260- train_dataloader = training .to_dataloader (
261- train = True , batch_size = batch_size , num_workers = 0
262- )
263- val_dataloader = validation .to_dataloader (
264- train = False , batch_size = batch_size , num_workers = 0
265- )
266- test_dataloader = validation .to_dataloader (
267- train = False , batch_size = batch_size , num_workers = 0
268- )
269-
270- return dict (train = train_dataloader , val = val_dataloader , test = test_dataloader )
164+ return _dataloader ()
0 commit comments