@@ -353,6 +353,24 @@ def _set_target_normalizer(self, data: pd.DataFrame):
             self.target_normalizer, (TorchNormalizer, NaNLabelEncoder)
         ), f"target_normalizer has to be either None or of class TorchNormalizer but found {self.target_normalizer}"
 
+    @property
+    def _group_ids_mapping(self) -> Dict[str, str]:
+        """
+        Mapping of group id names to group ids used to identify series in dataset -
+        group ids can also be used for target normalizer.
+        The former can change from training to validation and test dataset while the latter must not.
+        """
+        return {name: f"__group_id__{name}" for name in self.group_ids}
+
+    @property
+    def _group_ids(self) -> List[str]:
+        """
+        Group ids used to identify series in dataset.
+
+        See :py:meth:`~TimeSeriesDataSet._group_ids_mapping` for details.
+        """
+        return list(self._group_ids_mapping.values())
+
     def _validate_data(self, data: pd.DataFrame):
         """
         Validate that data will not cause hiccups later on.
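
A minimal sketch of what the two new properties evaluate to, assuming a hypothetical dataset configured with group_ids=["agency", "sku"] (the column names are illustrative only, not taken from the PR):

    # hedged sketch, not part of the diff: derive the internal group id columns
    group_ids = ["agency", "sku"]  # hypothetical user-facing group id columns
    group_ids_mapping = {name: f"__group_id__{name}" for name in group_ids}
    # -> {"agency": "__group_id__agency", "sku": "__group_id__sku"}
    internal_group_ids = list(group_ids_mapping.values())
    # -> ["__group_id__agency", "__group_id__sku"]

The prefixed columns get their own encoders (fitted in _preprocess_data below), so the user-facing group id columns can be re-encoded as ordinary categoricals without affecting how series are identified.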
@@ -403,9 +421,19 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
         Returns:
             pd.DataFrame: pre-processed dataframe
         """
+        # encode group ids - this encoding gets a dedicated NaNLabelEncoder per group id column
+        for name, group_name in self._group_ids_mapping.items():
+            self.categorical_encoders[group_name] = NaNLabelEncoder().fit(data[name].to_numpy().reshape(-1))
+            data[group_name] = self.transform_values(name, data[name], inverse=False, group_id=True)
 
         # encode categoricals
-        for name in set(self.categoricals + self.group_ids):
+        if isinstance(
+            self.target_normalizer, GroupNormalizer
+        ):  # if we use a group normalizer, group_ids must be encoded as well
+            group_ids_to_encode = self.group_ids
+        else:
+            group_ids_to_encode = []
+        for name in set(group_ids_to_encode + self.categoricals):
             allow_nans = name in self.dropout_categoricals
             if name in self.variable_groups:  # fit groups
                 columns = self.variable_groups[name]
@@ -430,7 +458,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
                 self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name])
 
         # encode them
-        for name in set(self.flat_categoricals + self.group_ids):
+        for name in set(group_ids_to_encode + self.flat_categoricals):
             data[name] = self.transform_values(name, data[name], inverse=False)
 
         # save special variables
@@ -472,6 +500,10 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
             data[self.target], scales = self.target_normalizer.transform(data[self.target], data, return_norm=True)
         elif isinstance(self.target_normalizer, NaNLabelEncoder):
             data[self.target] = self.target_normalizer.transform(data[self.target])
+            data["__target__"] = data[
+                self.target
+            ]  # overwrite target because it requires encoding (continuous targets should not be normalized)
+            scales = "no target scales available for categorical target"
         else:
             data[self.target], scales = self.target_normalizer.transform(data[self.target], return_norm=True)
 
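
A rough illustration of the NaNLabelEncoder branch added above, as a standalone sketch with made-up data (not code from the PR):

    # hedged sketch: a categorical target is label-encoded and the encoded values
    # also replace the raw "__target__" copy; no scales exist for such a target
    import pandas as pd
    from pytorch_forecasting.data import NaNLabelEncoder  # assumed import path

    target = pd.Series(["low", "high", "low", "medium"])
    encoder = NaNLabelEncoder().fit(target)
    encoded = encoder.transform(target)  # integer codes replacing the labels
    raw_copy = encoded  # stands in for data["__target__"] being overwritten
    scales = "no target scales available for categorical target"  # placeholder, as in the diff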
@@ -488,6 +520,8 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
 
         if self.target in self.reals:
             self.scalers[self.target] = self.target_normalizer
+        else:
+            self.categorical_encoders[self.target] = self.target_normalizer
 
         # rescale continuous variables apart from target
         for name in self.reals:
@@ -515,7 +549,12 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
         return data
 
     def transform_values(
-        self, name: str, values: Union[pd.Series, torch.Tensor, np.ndarray], data: pd.DataFrame = None, inverse=False
+        self,
+        name: str,
+        values: Union[pd.Series, torch.Tensor, np.ndarray],
+        data: pd.DataFrame = None,
+        inverse=False,
+        group_id: bool = False,
     ) -> np.ndarray:
         """
         Scale and encode values.
@@ -526,12 +565,16 @@ def transform_values(
             data (pd.DataFrame, optional): extra data used for scaling (e.g. dataframe with groups columns).
                 Defaults to None.
             inverse (bool, optional): whether to conduct inverse transformation. Defaults to False.
+            group_id (bool, optional): If the passed name refers to a group id (different encoders are used for these).
+                Defaults to False.
 
         Returns:
             np.ndarray: (de/en)coded/(de)scaled values
         """
+        if group_id:
+            name = self._group_ids_mapping[name]
         # remaining categories
-        if name in set(self.flat_categoricals + self.group_ids):
+        if name in set(self.flat_categoricals + self.group_ids + self._group_ids):
             name = self.variable_to_group_mapping.get(name, name)  # map name to encoder
             encoder = self.categorical_encoders[name]
             if encoder is None:
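
A rough usage sketch of the new group_id flag (not from the PR; the dataset construction and column names below are made up for illustration):

    # hedged sketch: encode and decode group ids through the dedicated
    # "__group_id__*" encoders by passing group_id=True
    import numpy as np
    import pandas as pd
    from pytorch_forecasting.data import TimeSeriesDataSet  # assumed import path

    df = pd.DataFrame(
        dict(
            group=np.repeat(["a", "b"], 20),
            time_idx=np.tile(np.arange(20), 2),
            value=np.random.rand(40),
        )
    )
    dataset = TimeSeriesDataSet(
        df,
        time_idx="time_idx",
        target="value",
        group_ids=["group"],
        max_encoder_length=5,
        max_prediction_length=2,
    )
    codes = dataset.transform_values("group", df["group"], inverse=False, group_id=True)
    labels = dataset.transform_values("group", codes, inverse=True, group_id=True)  # back to "a"/"b"

Passing group_id=True routes the lookup through _group_ids_mapping, so the dedicated series encoders are used even when the plain column name is given.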
@@ -575,7 +618,7 @@ def _data_to_tensors(self, data: pd.DataFrame) -> Dict[str, torch.Tensor]:
             time index
         """
 
-        index = torch.tensor(data[self.group_ids].to_numpy(np.long), dtype=torch.long)
+        index = torch.tensor(data[self._group_ids].to_numpy(np.long), dtype=torch.long)
         time = torch.tensor(data["__time_idx__"].to_numpy(np.long), dtype=torch.long)
 
         categorical = torch.tensor(data[self.flat_categoricals].to_numpy(np.long), dtype=torch.long)
@@ -735,7 +778,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFrame:
         Returns:
             pd.DataFrame: index dataframe
         """
-        g = data.groupby(self.group_ids, observed=True)
+        g = data.groupby(self._group_ids, observed=True)
 
         df_index_first = g["__time_idx__"].transform("nth", 0).to_frame("time_first")
         df_index_last = g["__time_idx__"].transform("nth", -1).to_frame("time_last")
@@ -797,10 +840,10 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFrame:
 
         # check that all groups/series have at least one entry in the index
         if not group_ids.isin(df_index.group_id).all():
-            missing_groups = data.loc[~group_ids.isin(df_index.group_id), self.group_ids].drop_duplicates()
+            missing_groups = data.loc[~group_ids.isin(df_index.group_id), self._group_ids].drop_duplicates()
             # decode values
-            for name in missing_groups.columns:
-                missing_groups[name] = self.transform_values(name, missing_groups[name], inverse=True)
+            for name, id in self._group_ids_mapping.items():
+                missing_groups[id] = self.transform_values(name, missing_groups[id], inverse=True, group_id=True)
             warnings.warn(
                 "Min encoder length and/or min_prediction_idx and/or min prediction length is too large for "
                 f"{len(missing_groups)} series/groups which therefore are not present in the dataset index. "
@@ -1210,7 +1253,7 @@ def x_to_index(self, x: Dict[str, torch.Tensor]) -> pd.DataFrame:
         for id in self.group_ids:
             index_data[id] = x["groups"][:, self.group_ids.index(id)].cpu()
             # decode if possible
-            index_data[id] = self.transform_values(id, index_data[id], inverse=True)
+            index_data[id] = self.transform_values(id, index_data[id], inverse=True, group_id=True)
         index = pd.DataFrame(index_data)
         return index