Skip to content

Commit eaa7e53

Browse files
authored
Merge pull request #62 from jdb78/feature/time-synchronized-sampling
Add time-synchronized sampling to dataset
2 parents 4fba73d + 4dbc322 commit eaa7e53

File tree

5 files changed

+180
-7
lines changed

5 files changed

+180
-7
lines changed

pytorch_forecasting/data/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
to abstracts the necessary work.
66
"""
77
from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer
8-
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
8+
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet, TimeSynchronizedBatchSampler
99

10-
__all__ = ["TimeSeriesDataSet", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", "EncoderNormalizer"]
10+
__all__ = [
11+
"TimeSeriesDataSet",
12+
"NaNLabelEncoder",
13+
"GroupNormalizer",
14+
"TorchNormalizer",
15+
"EncoderNormalizer",
16+
"TimeSynchronizedBatchSampler",
17+
]

pytorch_forecasting/data/timeseries.py

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
import pandas as pd
1515
from sklearn.exceptions import NotFittedError
1616
from sklearn.preprocessing import StandardScaler
17+
from sklearn.utils import shuffle
1718
from sklearn.utils.validation import check_is_fitted
1819
import torch
1920
from torch.distributions import Beta
2021
from torch.nn.utils import rnn
2122
from torch.utils.data import DataLoader, Dataset
23+
from torch.utils.data.sampler import Sampler
2224

2325
from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer
2426

@@ -655,6 +657,9 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
655657
# prediction must be for after minimal prediction index + length of prediction
656658
(x["sequence_length"] + x["time"] - 1 >= self.min_prediction_idx - 1 + self.min_prediction_length)
657659
]
660+
# todo: add duplicates for
661+
# (x.sequence length > self.min_prediction_length + self.min_encoder_length) &
662+
# (x.time - x.time_start < self.max_prediction_length + self.max_encoder_length)
658663

659664
if predict_mode: # keep longest element per series (i.e. the first element that spans to the end of the series)
660665
# filter all elements that are longer than the allowed maximum sequence length
@@ -766,6 +771,7 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
766771
# fill in missing values (if not all time indices are specified
767772
sequence_length = len(time)
768773
if sequence_length < index.sequence_length:
774+
assert self.allow_missings, "allow_missings should be True if sequences have gaps"
769775
repetitions = torch.cat([time[1:] - time[:-1], torch.ones(1, dtype=time.dtype)])
770776
indices = torch.repeat_interleave(torch.arange(len(time)), repetitions)
771777
repetition_indices = torch.cat([torch.tensor([False], dtype=torch.bool), indices[1:] == indices[:-1]])
@@ -970,14 +976,21 @@ def _collate_fn(
970976
target,
971977
)
972978

973-
def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> DataLoader:
979+
def to_dataloader(
980+
self, train: bool = True, batch_size: int = 64, batch_sampler: Union[Sampler, str] = None, **kwargs
981+
) -> DataLoader:
974982
"""
975983
Get dataloader from dataset.
976984
977985
Args:
978986
train (bool, optional): if dataloader is used for training or prediction
979987
Will shuffle and drop last batch if True. Defaults to True.
980988
batch_size (int): batch size for training model. Defaults to 64.
989+
batch_sampler (Union[Sampler, str]): batch sampler or string. One of
990+
991+
* "synchronized": ensure that samples in decoder are aligned in time. Does not support missing
992+
values in dataset.
993+
981994
**kwargs: additional arguments to ``DataLoader()``
982995
983996
@@ -1015,12 +1028,26 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
10151028
drop_last=train and len(self) > batch_size,
10161029
collate_fn=self._collate_fn,
10171030
batch_size=batch_size,
1031+
batch_sampler=batch_sampler,
10181032
)
1019-
10201033
default_kwargs.update(kwargs)
1034+
kwargs = default_kwargs
1035+
if kwargs["batch_sampler"] is not None:
1036+
sampler = kwargs["batch_sampler"]
1037+
if isinstance(sampler, str):
1038+
if sampler == "synchronized":
1039+
kwargs["batch_sampler"] = TimeSynchronizedBatchSampler(
1040+
self, batch_size=kwargs["batch_size"], shuffle=kwargs["shuffle"], drop_last=kwargs["drop_last"]
1041+
)
1042+
else:
1043+
raise ValueError(f"batch_sampler {sampler} unknown - see docstring for valid batch_sampler")
1044+
del kwargs["batch_size"]
1045+
del kwargs["shuffle"]
1046+
del kwargs["drop_last"]
1047+
10211048
return DataLoader(
10221049
self,
1023-
**default_kwargs,
1050+
**kwargs,
10241051
)
10251052

10261053
def get_index(self) -> pd.DataFrame:
@@ -1045,3 +1072,106 @@ def get_index(self) -> pd.DataFrame:
10451072
index_data[id] = self.transform_values(id, index_data[id], inverse=True)
10461073
index = pd.DataFrame(index_data, index=self.index.index)
10471074
return index
1075+
1076+
1077+
class TimeSynchronizedBatchSampler(Sampler):
1078+
"""
1079+
Samples mini-batches randomly but in a time-synchronised manner.
1080+
1081+
Time-synchornisation means that the time index of the first decoder samples are aligned across the batch.
1082+
This sampler does not support missing values in the dataset.
1083+
"""
1084+
1085+
def __init__(
1086+
self,
1087+
data_source: TimeSeriesDataSet,
1088+
batch_size: int = 64,
1089+
shuffle: bool = False,
1090+
drop_last: bool = False,
1091+
):
1092+
"""
1093+
Initialize TimeSynchronizedBatchSampler.
1094+
1095+
Args:
1096+
data_source (TimeSeriesDataSet): timeseries dataset.
1097+
drop_last (bool): if to drop last mini-batch from a group if it is smaller than batch_size.
1098+
Defaults to False.
1099+
shuffle (bool): if to shuffle dataset. Defaults to False.
1100+
batch_size (int, optional): Number of samples in a mini-batch. This is rather the maximum number
1101+
of samples. Because mini-batches are grouped by prediction time, chances are that there
1102+
are multiple where batch size will be smaller than the maximum. Defaults to 64.
1103+
"""
1104+
# Since collections.abc.Iterable does not check for `__getitem__`, which
1105+
# is one way for an object to be an iterable, we don't do an `isinstance`
1106+
# check here.
1107+
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0:
1108+
raise ValueError(
1109+
"batch_size should be a positive integer value, " "but got batch_size={}".format(batch_size)
1110+
)
1111+
if not isinstance(drop_last, bool):
1112+
raise ValueError("drop_last should be a boolean value, but got " "drop_last={}".format(drop_last))
1113+
self.data_source = data_source
1114+
self.batch_size = batch_size
1115+
self.drop_last = drop_last
1116+
self.shuffle = shuffle
1117+
assert not self.data_source.allow_missings, "allow_missings should be False for time-synchronized mini-batches"
1118+
1119+
# construct index from which can be sampled
1120+
self.construct_batch_groups()
1121+
1122+
def construct_batch_groups(self):
1123+
"""
1124+
Construct index of batches from which can be sampled
1125+
"""
1126+
index = self.data_source.index
1127+
# get groups, i.e. group all samples by first predict time
1128+
decoder_lengths = np.min(
1129+
[
1130+
index.time_last - (self.data_source.min_prediction_idx - 1),
1131+
index.sequence_length - self.data_source.min_encoder_length,
1132+
],
1133+
axis=0,
1134+
).clip(max=self.data_source.max_prediction_length)
1135+
first_prediction_time = index.time + index.sequence_length - decoder_lengths + 1
1136+
self._groups = pd.RangeIndex(0, len(index.index)).groupby(first_prediction_time)
1137+
1138+
# calculate sizes of groups
1139+
self._group_sizes = {}
1140+
warns = []
1141+
for name, group in self._groups.items(): # iterate over groups
1142+
if self.drop_last:
1143+
self._group_sizes[name] = len(group) // self.batch_size
1144+
else:
1145+
self._group_sizes[name] = (len(group) + self.batch_size - 1) // self.batch_size
1146+
if self._group_sizes[name] == 0:
1147+
self._group_sizes[name] = 1
1148+
warns.append(name)
1149+
if len(warns) > 0:
1150+
warnings.warn(
1151+
f"Less than {self.batch_size} samples available for {len(warns)} prediction times. "
1152+
f"Use batch size smaller than {self.batch_size}. "
1153+
f"First 10 prediction times with small batch sizes: {warns[:10]}"
1154+
)
1155+
# create index from which can be sampled: index is equal to number of batches
1156+
# associate index with prediction time
1157+
self._group_index = np.repeat(list(self._group_sizes.keys()), list(self._group_sizes.values()))
1158+
# associate index with batch within prediction time group
1159+
self._sub_group_index = np.concatenate([np.arange(size) for size in self._group_sizes.values()])
1160+
1161+
def __iter__(self):
1162+
if self.shuffle: # shuffle samples
1163+
groups = {name: shuffle(group) for name, group in self._groups.items()}
1164+
else:
1165+
groups = self._groups
1166+
1167+
batch_samples = np.random.permutation(len(self))
1168+
for idx in batch_samples:
1169+
name = self._group_index[idx]
1170+
sub_group = self._sub_group_index[idx]
1171+
sub_group_start = sub_group * self.batch_size
1172+
sub_group_end = sub_group_start + self.batch_size
1173+
batch = groups[name][sub_group_start:sub_group_end]
1174+
yield batch
1175+
1176+
def __len__(self):
1177+
return len(self._group_index)

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def __init__(
9595
time_varying_categoricals_decoder: integer of positions of categorical variables for decoder
9696
time_varying_reals_encoder: integer of positions of continuous variables for encoder
9797
time_varying_reals_decoder: integer of positions of continuous variables for decoder
98+
categorical_groups: dictionary where values
99+
are list of categorical variables that are forming together a new categorical
100+
variable which is the key in the dictionary
98101
x_reals: order of continuous variables in tensor passed to forward function
99102
x_categoricals: order of categorical variables in tensor passed to forward function
100103
hidden_continuous_size: default for hidden size for processing continous variables (similar to categorical

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_dataset(test_data):
4747
test_data,
4848
time_idx="time_idx",
4949
target="volume",
50-
time_varying_known_reals=["price_regular"],
50+
time_varying_known_reals=["price_regular", "time_idx"],
5151
group_ids=["agency", "sku"],
5252
static_categoricals=["agency"],
5353
max_encoder_length=5,

tests/test_data.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from sklearn.preprocessing import StandardScaler
99
import torch
1010

11-
from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TimeSeriesDataSet
11+
from pytorch_forecasting.data import (
12+
EncoderNormalizer,
13+
GroupNormalizer,
14+
NaNLabelEncoder,
15+
TimeSeriesDataSet,
16+
TimeSynchronizedBatchSampler,
17+
)
1218
from pytorch_forecasting.data.examples import get_stallion_data
1319

1420
torch.manual_seed(23)
@@ -263,3 +269,30 @@ def test_overwrite_values(test_dataset, value, variable, target):
263269
changed = torch.isclose(outputs[0][name], control_outputs[0][name]).all()
264270
assert changed, f"Output {name} should be reset"
265271
assert torch.isclose(outputs[1], control_outputs[1]).all(), "Target should be reset"
272+
273+
274+
@pytest.mark.parametrize(
275+
"drop_last,shuffle,as_string,batch_size",
276+
[
277+
(True, True, True, 64),
278+
(False, False, False, 64),
279+
(True, False, False, 1000),
280+
],
281+
)
282+
def test_TimeSynchronizedBatchSampler(test_dataset, shuffle, drop_last, as_string, batch_size):
283+
if as_string:
284+
dataloader = test_dataset.to_dataloader(
285+
batch_sampler="synchronized", shuffle=shuffle, drop_last=drop_last, batch_size=batch_size
286+
)
287+
else:
288+
sampler = TimeSynchronizedBatchSampler(
289+
data_source=test_dataset, shuffle=shuffle, drop_last=drop_last, batch_size=batch_size
290+
)
291+
dataloader = test_dataset.to_dataloader(batch_sampler=sampler)
292+
293+
time_idx_pos = test_dataset.reals.index("time_idx")
294+
for x, _ in iter(dataloader): # check all samples
295+
time_idx_of_first_prediction = x["decoder_cont"][:, 0, time_idx_pos]
296+
assert torch.isclose(
297+
time_idx_of_first_prediction, time_idx_of_first_prediction[0]
298+
).all(), "Time index should be the same for the first prediction"

0 commit comments

Comments
 (0)