Skip to content

Commit 4dbc322

Browse files
author
Jan Beitner
committed
Increase timesynchronizedsampler test coverage
1 parent 51c9372 commit 4dbc322

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

tests/test_data.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
GroupNormalizer,
1414
NaNLabelEncoder,
1515
TimeSeriesDataSet,
16+
TimeSynchronizedBatchSampler,
1617
)
1718
from pytorch_forecasting.data.examples import get_stallion_data
1819

@@ -270,8 +271,24 @@ def test_overwrite_values(test_dataset, value, variable, target):
270271
assert torch.isclose(outputs[1], control_outputs[1]).all(), "Target should be reset"
271272

272273

273-
def test_TimeSynchronizedBatchSampler(test_dataset):
274-
dataloader = test_dataset.to_dataloader(batch_sampler="synchronized")
274+
@pytest.mark.parametrize(
275+
"drop_last,shuffle,as_string,batch_size",
276+
[
277+
(True, True, True, 64),
278+
(False, False, False, 64),
279+
(True, False, False, 1000),
280+
],
281+
)
282+
def test_TimeSynchronizedBatchSampler(test_dataset, shuffle, drop_last, as_string, batch_size):
283+
if as_string:
284+
dataloader = test_dataset.to_dataloader(
285+
batch_sampler="synchronized", shuffle=shuffle, drop_last=drop_last, batch_size=batch_size
286+
)
287+
else:
288+
sampler = TimeSynchronizedBatchSampler(
289+
data_source=test_dataset, shuffle=shuffle, drop_last=drop_last, batch_size=batch_size
290+
)
291+
dataloader = test_dataset.to_dataloader(batch_sampler=sampler)
275292

276293
time_idx_pos = test_dataset.reals.index("time_idx")
277294
for x, _ in iter(dataloader): # check all samples

0 commit comments

Comments
 (0)