1414import pandas as pd
1515from sklearn .exceptions import NotFittedError
1616from sklearn .preprocessing import StandardScaler
17+ from sklearn .utils import shuffle
1718from sklearn .utils .validation import check_is_fitted
1819import torch
1920from torch .distributions import Beta
2021from torch .nn .utils import rnn
2122from torch .utils .data import DataLoader , Dataset
23+ from torch .utils .data .sampler import Sampler
2224
2325from pytorch_forecasting .data .encoders import EncoderNormalizer , GroupNormalizer , NaNLabelEncoder , TorchNormalizer
2426
@@ -655,6 +657,9 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
655657 # prediction must be for after minimal prediction index + length of prediction
656658 (x ["sequence_length" ] + x ["time" ] - 1 >= self .min_prediction_idx - 1 + self .min_prediction_length )
657659 ]
660+ # todo: add duplicates for
661+ # (x.sequence length > self.min_prediction_length + self.min_encoder_length) &
662+ # (x.time - x.time_start < self.max_prediction_length + self.max_encoder_length)
658663
659664 if predict_mode : # keep longest element per series (i.e. the first element that spans to the end of the series)
660665 # filter all elements that are longer than the allowed maximum sequence length
@@ -766,6 +771,7 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
766771 # fill in missing values (if not all time indices are specified
767772 sequence_length = len (time )
768773 if sequence_length < index .sequence_length :
774+ assert self .allow_missings , "allow_missings should be True if sequences have gaps"
769775 repetitions = torch .cat ([time [1 :] - time [:- 1 ], torch .ones (1 , dtype = time .dtype )])
770776 indices = torch .repeat_interleave (torch .arange (len (time )), repetitions )
771777 repetition_indices = torch .cat ([torch .tensor ([False ], dtype = torch .bool ), indices [1 :] == indices [:- 1 ]])
@@ -970,14 +976,21 @@ def _collate_fn(
970976 target ,
971977 )
972978
973- def to_dataloader (self , train : bool = True , batch_size : int = 64 , ** kwargs ) -> DataLoader :
979+ def to_dataloader (
980+ self , train : bool = True , batch_size : int = 64 , batch_sampler : Union [Sampler , str ] = None , ** kwargs
981+ ) -> DataLoader :
974982 """
975983 Get dataloader from dataset.
976984
977985 Args:
978986 train (bool, optional): if dataloader is used for training or prediction
979987 Will shuffle and drop last batch if True. Defaults to True.
980988 batch_size (int): batch size for training model. Defaults to 64.
989+ batch_sampler (Union[Sampler, str]): batch sampler or string. One of
990+
991+ * "synchronized": ensure that samples in decoder are aligned in time. Does not support missing
992+ values in dataset.
993+
981994 **kwargs: additional arguments to ``DataLoader()``
982995
983996
@@ -1015,12 +1028,26 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
10151028 drop_last = train and len (self ) > batch_size ,
10161029 collate_fn = self ._collate_fn ,
10171030 batch_size = batch_size ,
1031+ batch_sampler = batch_sampler ,
10181032 )
1019-
10201033 default_kwargs .update (kwargs )
1034+ kwargs = default_kwargs
1035+ if kwargs ["batch_sampler" ] is not None :
1036+ sampler = kwargs ["batch_sampler" ]
1037+ if isinstance (sampler , str ):
1038+ if sampler == "synchronized" :
1039+ kwargs ["batch_sampler" ] = TimeSynchronizedBatchSampler (
1040+ self , batch_size = kwargs ["batch_size" ], shuffle = kwargs ["shuffle" ], drop_last = kwargs ["drop_last" ]
1041+ )
1042+ else :
1043+ raise ValueError (f"batch_sampler { sampler } unknown - see docstring for valid batch_sampler" )
1044+ del kwargs ["batch_size" ]
1045+ del kwargs ["shuffle" ]
1046+ del kwargs ["drop_last" ]
1047+
10211048 return DataLoader (
10221049 self ,
1023- ** default_kwargs ,
1050+ ** kwargs ,
10241051 )
10251052
10261053 def get_index (self ) -> pd .DataFrame :
@@ -1045,3 +1072,106 @@ def get_index(self) -> pd.DataFrame:
10451072 index_data [id ] = self .transform_values (id , index_data [id ], inverse = True )
10461073 index = pd .DataFrame (index_data , index = self .index .index )
10471074 return index
1075+
1076+
1077+ class TimeSynchronizedBatchSampler (Sampler ):
1078+ """
1079+ Samples mini-batches randomly but in a time-synchronised manner.
1080+
1081+ Time-synchornisation means that the time index of the first decoder samples are aligned across the batch.
1082+ This sampler does not support missing values in the dataset.
1083+ """
1084+
1085+ def __init__ (
1086+ self ,
1087+ data_source : TimeSeriesDataSet ,
1088+ batch_size : int = 64 ,
1089+ shuffle : bool = False ,
1090+ drop_last : bool = False ,
1091+ ):
1092+ """
1093+ Initialize TimeSynchronizedBatchSampler.
1094+
1095+ Args:
1096+ data_source (TimeSeriesDataSet): timeseries dataset.
1097+ drop_last (bool): if to drop last mini-batch from a group if it is smaller than batch_size.
1098+ Defaults to False.
1099+ shuffle (bool): if to shuffle dataset. Defaults to False.
1100+ batch_size (int, optional): Number of samples in a mini-batch. This is rather the maximum number
1101+ of samples. Because mini-batches are grouped by prediction time, chances are that there
1102+ are multiple where batch size will be smaller than the maximum. Defaults to 64.
1103+ """
1104+ # Since collections.abc.Iterable does not check for `__getitem__`, which
1105+ # is one way for an object to be an iterable, we don't do an `isinstance`
1106+ # check here.
1107+ if not isinstance (batch_size , int ) or isinstance (batch_size , bool ) or batch_size <= 0 :
1108+ raise ValueError (
1109+ "batch_size should be a positive integer value, " "but got batch_size={}" .format (batch_size )
1110+ )
1111+ if not isinstance (drop_last , bool ):
1112+ raise ValueError ("drop_last should be a boolean value, but got " "drop_last={}" .format (drop_last ))
1113+ self .data_source = data_source
1114+ self .batch_size = batch_size
1115+ self .drop_last = drop_last
1116+ self .shuffle = shuffle
1117+ assert not self .data_source .allow_missings , "allow_missings should be False for time-synchronized mini-batches"
1118+
1119+ # construct index from which can be sampled
1120+ self .construct_batch_groups ()
1121+
1122+ def construct_batch_groups (self ):
1123+ """
1124+ Construct index of batches from which can be sampled
1125+ """
1126+ index = self .data_source .index
1127+ # get groups, i.e. group all samples by first predict time
1128+ decoder_lengths = np .min (
1129+ [
1130+ index .time_last - (self .data_source .min_prediction_idx - 1 ),
1131+ index .sequence_length - self .data_source .min_encoder_length ,
1132+ ],
1133+ axis = 0 ,
1134+ ).clip (max = self .data_source .max_prediction_length )
1135+ first_prediction_time = index .time + index .sequence_length - decoder_lengths + 1
1136+ self ._groups = pd .RangeIndex (0 , len (index .index )).groupby (first_prediction_time )
1137+
1138+ # calculate sizes of groups
1139+ self ._group_sizes = {}
1140+ warns = []
1141+ for name , group in self ._groups .items (): # iterate over groups
1142+ if self .drop_last :
1143+ self ._group_sizes [name ] = len (group ) // self .batch_size
1144+ else :
1145+ self ._group_sizes [name ] = (len (group ) + self .batch_size - 1 ) // self .batch_size
1146+ if self ._group_sizes [name ] == 0 :
1147+ self ._group_sizes [name ] = 1
1148+ warns .append (name )
1149+ if len (warns ) > 0 :
1150+ warnings .warn (
1151+ f"Less than { self .batch_size } samples available for { len (warns )} prediction times. "
1152+ f"Use batch size smaller than { self .batch_size } . "
1153+ f"First 10 prediction times with small batch sizes: { warns [:10 ]} "
1154+ )
1155+ # create index from which can be sampled: index is equal to number of batches
1156+ # associate index with prediction time
1157+ self ._group_index = np .repeat (list (self ._group_sizes .keys ()), list (self ._group_sizes .values ()))
1158+ # associate index with batch within prediction time group
1159+ self ._sub_group_index = np .concatenate ([np .arange (size ) for size in self ._group_sizes .values ()])
1160+
1161+ def __iter__ (self ):
1162+ if self .shuffle : # shuffle samples
1163+ groups = {name : shuffle (group ) for name , group in self ._groups .items ()}
1164+ else :
1165+ groups = self ._groups
1166+
1167+ batch_samples = np .random .permutation (len (self ))
1168+ for idx in batch_samples :
1169+ name = self ._group_index [idx ]
1170+ sub_group = self ._sub_group_index [idx ]
1171+ sub_group_start = sub_group * self .batch_size
1172+ sub_group_end = sub_group_start + self .batch_size
1173+ batch = groups [name ][sub_group_start :sub_group_end ]
1174+ yield batch
1175+
1176+ def __len__ (self ):
1177+ return len (self ._group_index )
0 commit comments