@@ -13,6 +13,7 @@
     GroupNormalizer,
     NaNLabelEncoder,
     TimeSeriesDataSet,
+    TimeSynchronizedBatchSampler,
 )
 from pytorch_forecasting.data.examples import get_stallion_data
 
@@ -270,8 +271,24 @@ def test_overwrite_values(test_dataset, value, variable, target):
     assert torch.isclose(outputs[1], control_outputs[1]).all(), "Target should be reset"
 
 
-def test_TimeSynchronizedBatchSampler(test_dataset):
-    dataloader = test_dataset.to_dataloader(batch_sampler="synchronized")
+@pytest.mark.parametrize(
+    "drop_last,shuffle,as_string,batch_size",
+    [
+        (True, True, True, 64),
+        (False, False, False, 64),
+        (True, False, False, 1000),
+    ],
+)
+def test_TimeSynchronizedBatchSampler(test_dataset, shuffle, drop_last, as_string, batch_size):
+    if as_string:
+        dataloader = test_dataset.to_dataloader(
+            batch_sampler="synchronized", shuffle=shuffle, drop_last=drop_last, batch_size=batch_size
+        )
+    else:
+        sampler = TimeSynchronizedBatchSampler(
+            data_source=test_dataset, shuffle=shuffle, drop_last=drop_last, batch_size=batch_size
+        )
+        dataloader = test_dataset.to_dataloader(batch_sampler=sampler)
 
     time_idx_pos = test_dataset.reals.index("time_idx")
     for x, _ in iter(dataloader):  # check all samples
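For reference when reading the hunk above: the parametrized test exercises the two ways of requesting time-synchronized batching. Below is a minimal sketch of both call styles. It is not part of the commit; the synthetic DataFrame, the dataset construction, and the import path are assumptions made for illustration, while the `to_dataloader` and `TimeSynchronizedBatchSampler` arguments mirror the test.

```python
import numpy as np
import pandas as pd

# Import path assumed; the test imports these names from the same module block.
from pytorch_forecasting.data import TimeSeriesDataSet, TimeSynchronizedBatchSampler

# Tiny synthetic panel: 5 series with 30 time steps each (illustrative only).
data = pd.DataFrame(
    {
        "series": np.repeat(np.arange(5), 30).astype(str),
        "time_idx": np.tile(np.arange(30), 5),
        "value": np.random.normal(size=150),
    }
)

# Illustrative dataset; any TimeSeriesDataSet works the same way.
dataset = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=10,
    max_prediction_length=5,
    time_varying_unknown_reals=["value"],
)

# 1) String shortcut: to_dataloader builds the synchronized sampler internally.
dataloader = dataset.to_dataloader(
    batch_sampler="synchronized", shuffle=True, drop_last=True, batch_size=64
)

# 2) Explicit sampler instance, mirroring the non-string branch of the test.
sampler = TimeSynchronizedBatchSampler(
    data_source=dataset, shuffle=True, drop_last=True, batch_size=64
)
dataloader = dataset.to_dataloader(batch_sampler=sampler)
```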