 from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer
 
 
+def _find_end_indices(diffs: np.ndarray, max_lengths: np.ndarray, min_length: int) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Identify end indices in series even if some values are missing.
+
+    Args:
+        diffs (np.ndarray): array of differences to the next time step. NaNs should be filled up with ones.
+        max_lengths (np.ndarray): maximum length of sequence by position.
+        min_length (int): minimum length of sequence.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: tuple of arrays where the first contains the end indices and the second
+            the start and end index pairs of shorter sequences that are still missing.
+    """
+    missing_start_ends = []
+    end_indices = []
+    length = 1
+    start_idx = 0
+    max_idx = len(diffs) - 1
+    max_length = max_lengths[start_idx]
+
+    for idx, diff in enumerate(diffs):
+        if length >= max_length:
+            # window is full: record an end index for every exhausted start position
+            while length >= max_length:
+                if length == max_length:
+                    end_indices.append(idx)
+                else:
+                    end_indices.append(idx - 1)
+                length -= diffs[start_idx]
+                if start_idx < max_idx:
+                    start_idx += 1
+                max_length = max_lengths[start_idx]
+        elif length >= min_length:
+            # admissible shorter sequence ending at idx that the branch above will not emit
+            missing_start_ends.append([start_idx, idx])
+        length += diff
+    if len(missing_start_ends) > 0:  # required for numba compliance
+        return np.asarray(end_indices), np.asarray(missing_start_ends)
+    else:
+        return np.asarray(end_indices), np.empty((0, 2), dtype=np.int64)
+
+
+# numba is an optional dependency: jit-compile the helper if it is available
+try:
+    import numba
+
+    _find_end_indices = numba.jit(nopython=True)(_find_end_indices)
+except ImportError:
+    pass
+
+
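A minimal sketch of what the helper returns, on hypothetical inputs: a fully regular series of five steps (all diffs equal to one), a maximum window of 3 clipped towards the series end, and a minimum length of 1.

    import numpy as np

    diffs = np.array([1.0, 1.0, 1.0, 1.0, 1.0])  # no gaps between steps
    max_lengths = np.array([3, 3, 3, 2, 1])      # window of 3, clipped at the end

    end_indices, missing = _find_end_indices(diffs, max_lengths, min_length=1)
    # end_indices -> [2, 3, 4, 4, 4]: longest admissible end per start position
    # missing     -> [[0, 0], [0, 1]]: shorter sequences ending at steps 0 and 1,
    #                which _construct_index appends so that a sequence also
    #                finishes on every time step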
 class TimeSeriesDataSet(Dataset):
     """
     PyTorch Dataset for fitting timeseries models.
@@ -125,12 +173,16 @@ def __init__(
         """
         super().__init__()
         self.max_encoder_length = max_encoder_length
-        self.min_encoder_length = min_encoder_length or max_encoder_length
+        if min_encoder_length is None:
+            min_encoder_length = max_encoder_length
+        self.min_encoder_length = min_encoder_length
         assert (
             self.min_encoder_length <= self.max_encoder_length
         ), "max encoder length has to be larger than or equal to min encoder length"
         self.max_prediction_length = max_prediction_length
-        self.min_prediction_length = min_prediction_length or max_prediction_length
+        if min_prediction_length is None:
+            min_prediction_length = max_prediction_length
+        self.min_prediction_length = min_prediction_length
         assert (
             self.min_prediction_length <= self.max_prediction_length
         ), "max prediction length has to be larger than or equal to min prediction length"
@@ -155,7 +207,9 @@ def __init__(
         else:
             randomize_length = (0.2, 0.05)
         self.randomize_length = randomize_length
-        self.min_prediction_idx = min_prediction_idx or data[self.time_idx].min()
+        if min_prediction_idx is None:
+            min_prediction_idx = data[self.time_idx].min()
+        self.min_prediction_idx = min_prediction_idx
         self.constant_fill_strategy = {} if len(constant_fill_strategy) == 0 else constant_fill_strategy
         self.predict_mode = predict_mode
         self.allow_missings = allow_missings
@@ -623,52 +677,54 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
         df_index["count"] = (df_index["time_last"] - df_index["time_first"]).astype(int) + 1
         df_index["group_id"] = g.ngroup()
 
+        min_sequence_length = self.min_prediction_length + self.min_encoder_length
+        max_sequence_length = self.max_prediction_length + self.max_encoder_length
+
         # calculate maximum index to include from current index_start
-        max_time = (df_index["time"] + self.max_encoder_length + self.max_prediction_length).clip(
-            upper=df_index["count"] + df_index.time_first
-        )
+        max_time = (df_index["time"] + max_sequence_length - 1).clip(upper=df_index["count"] + df_index.time_first - 1)
 
         # if there are missing timesteps, we cannot say directly what is the last timestep to include
         # therefore we iterate until it is found
         if (df_index["time_diff_to_next"] != 1).any():
             assert (
                 self.allow_missings
             ), "Time difference between steps has been identified as larger than 1 - set allow_missings=True"
-            df_index["index_end"] = df_index["index_start"]
-            for _ in range(df_index["count"].max()):
-                new_end_time = (
-                    df_index[["time", "time_diff_to_next"]].iloc[df_index["index_end"]].sum(axis=1).to_numpy()
-                )
-                df_index["index_end"] = df_index["index_end"].where(
-                    new_end_time + 1 > max_time, df_index["index_end"] + 1
-                )
-        else:
-            # direct calculation of end index if there are no missing timesteps in the data
-            df_index["index_end"] = df_index["index_start"] + (max_time - df_index["time"] - 1)
+
+        df_index["index_end"], missing_sequences = _find_end_indices(
+            diffs=df_index.time_diff_to_next.to_numpy(),
+            max_lengths=(max_time - df_index.time).to_numpy() + 1,
+            min_length=min_sequence_length,
+        )
+        # add duplicates, mostly with a shorter sequence length, for the start of the timeseries:
+        # the steps above ensure that a sequence starts on every time step, while the missing_sequences
+        # ensure that a sequence also finishes on every time step
+        if len(missing_sequences) > 0:
+            shortened_sequences = df_index.iloc[missing_sequences[:, 0]].assign(index_end=missing_sequences[:, 1])
+
+            # concatenate shortened sequences
+            df_index = pd.concat([df_index, shortened_sequences], axis=0, ignore_index=True)
 
         # filter out where encode and decode length are not satisfied
         df_index["sequence_length"] = df_index["time"].iloc[df_index["index_end"]].to_numpy() - df_index["time"] + 1
 
         # filter too short sequences
         df_index = df_index[
             # sequence must be at least of minimal prediction length
-            lambda x: (x.sequence_length >= self.min_prediction_length + self.min_encoder_length)
+            lambda x: (x.sequence_length >= min_sequence_length)
             &
             # prediction must be for after minimal prediction index + length of prediction
-            (x["sequence_length"] + x["time"] - 1 >= self.min_prediction_idx - 1 + self.min_prediction_length)
+            (x["sequence_length"] + x["time"] >= self.min_prediction_idx + self.min_prediction_length)
         ]
-        # todo: add duplicates for
-        # (x.sequence length > self.min_prediction_length + self.min_encoder_length) &
-        # (x.time - x.time_start < self.max_prediction_length + self.max_encoder_length)
 
         if predict_mode:  # keep longest element per series (i.e. the first element that spans to the end of the series)
             # filter all elements that are longer than the allowed maximum sequence length
             df_index = df_index[
-                lambda x: (x["time_last"] - x["time"] + 1 <= self.max_prediction_length + self.max_encoder_length)
-                & (x["sequence_length"] >= self.min_prediction_length + self.min_encoder_length)
+                lambda x: (x["time_last"] - x["time"] + 1 <= max_sequence_length)
+                & (x["sequence_length"] >= min_sequence_length)
             ]
             # choose longest sequence
             df_index = df_index.loc[df_index.groupby("group_id").sequence_length.idxmax()]
+
         assert len(df_index) > 0, "filters should not remove entries"
 
         return df_index
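A worked check of the inclusive end-index arithmetic above, using hypothetical numbers: a group with 10 gapless observations starting at time 0.

    time_first, count = 0, 10
    max_sequence_length, min_sequence_length = 5, 3

    for time in (3, 8):
        # inclusive last time index a sequence starting at this position may cover
        max_time = min(time + max_sequence_length - 1, count + time_first - 1)
        sequence_length = max_time - time + 1
        print(time, max_time, sequence_length, sequence_length >= min_sequence_length)
    # 3 7 5 True   -> a full-length window fits before the series end
    # 8 9 2 False  -> clipped to length 2 and removed by the sequence_length filter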
@@ -690,8 +746,10 @@ def plot_randomization(
         """
         if betas is None:
             betas = self.randomize_length
-        length = length or self.max_encoder_length
-        min_length = min_length or self.min_encoder_length
+        if length is None:
+            length = self.max_encoder_length
+        if min_length is None:
+            min_length = self.min_encoder_length
         probabilities = Beta(betas[0], betas[1]).sample((1000,))
 
         lengths = ((length - min_length) * probabilities).round() + min_length
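For intuition on the sampling that plot_randomization visualises (unchanged by this commit): with the default betas of (0.2, 0.05), the Beta distribution has mean 0.8 and is U-shaped, so most sampled encoder lengths sit at or near the maximum. A sketch with hypothetical bounds:

    from torch.distributions import Beta

    length, min_length = 24, 12  # hypothetical encoder length bounds
    probabilities = Beta(0.2, 0.05).sample((1000,))
    lengths = ((length - min_length) * probabilities).round() + min_length
    print(lengths.mean())  # around 21.6 = 12 + 12 * 0.8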
@@ -1050,27 +1108,19 @@ def to_dataloader(
             **kwargs,
         )
 
-    def get_index(self) -> pd.DataFrame:
+    def x_to_index(self, x) -> pd.DataFrame:
         """
-        Data index / order in which items are returned in train=False mode by dataloader.
+        Decode dataframe index from x.
 
         Returns:
             dataframe with time index column for first prediction and group ids
         """
-        decoder_length = pd.DataFrame(
-            dict(
-                prediction_idx=self.data["time"][self.index.index_end.to_numpy()] - (self.min_prediction_idx - 1),
-                sequence_length=self.index.sequence_length,
-                max_prediction_length=self.max_prediction_length,
-            )
-        ).min(axis=1)
-        encoder_lengths = self.index.sequence_length - decoder_length
-        index_data = {self.time_idx: self.index.time + encoder_lengths}
+        index_data = {self.time_idx: x["decoder_time_idx"][:, 0]}
         for id in self.group_ids:
-            index_data[id] = self.data["groups"][:, self.group_ids.index(id)][self.index.index_start.to_numpy()]
+            index_data[id] = x["groups"][:, self.group_ids.index(id)]
             # decode if possible
             index_data[id] = self.transform_values(id, index_data[id], inverse=True)
-        index = pd.DataFrame(index_data, index=self.index.index)
+        index = pd.DataFrame(index_data)
         return index
 
 
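A sketch of the intended use of x_to_index, assuming a fitted TimeSeriesDataSet named dataset (hypothetical variable names):

    dataloader = dataset.to_dataloader(train=False, batch_size=64)
    x, y = next(iter(dataloader))  # x is the dictionary of model inputs
    index = dataset.x_to_index(x)
    # -> DataFrame with one row per sample in the batch: the first decoder
    #    time step (x["decoder_time_idx"][:, 0]) plus the decoded group ids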