Commit 48f3595

Merge pull request #172 from jdb78/fix/allow-new-series-ids-in-test-data

Re-encode group ids by dataset to identify series

2 parents c7574bc + 0cb2248

File tree

3 files changed: +104 −11 lines
pytorch_forecasting/data/timeseries.py

Lines changed: 53 additions & 10 deletions
```diff
@@ -353,6 +353,24 @@ def _set_target_normalizer(self, data: pd.DataFrame):
             self.target_normalizer, (TorchNormalizer, NaNLabelEncoder)
         ), f"target_normalizer has to be either None or of class TorchNormalizer but found {self.target_normalizer}"
 
+    @property
+    def _group_ids_mapping(self) -> Dict[str, str]:
+        """
+        Mapping of group id names to group ids used to identify series in dataset -
+        group ids can also be used for target normalizer.
+        The former can change from training to validation and test dataset while the latter must not.
+        """
+        return {name: f"__group_id__{name}" for name in self.group_ids}
+
+    @property
+    def _group_ids(self) -> List[str]:
+        """
+        Group ids used to identify series in dataset.
+
+        See :py:meth:`~TimeSeriesDataSet._group_ids_mapping` for details.
+        """
+        return list(self._group_ids_mapping.values())
+
     def _validate_data(self, data: pd.DataFrame):
         """
         Validate that data will not cause hiccups later on.
```
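These two properties separate the user-facing group id columns (which must stay stable across training, validation, and test datasets, e.g. for a `GroupNormalizer`) from internal `__group_id__`-prefixed columns that merely index the series within one dataset. Illustratively, for a dataset constructed with `group_ids=["agency", "sku"]` the properties would evaluate to:

```python
# Illustrative values only, derived from the property definitions above.
dataset._group_ids_mapping  # {"agency": "__group_id__agency", "sku": "__group_id__sku"}
dataset._group_ids          # ["__group_id__agency", "__group_id__sku"]
```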
```diff
@@ -403,9 +421,19 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
         Returns:
             pd.DataFrame: pre-processed dataframe
         """
+        # encode group ids - this encoding identifies the series in the dataset
+        for name, group_name in self._group_ids_mapping.items():
+            self.categorical_encoders[group_name] = NaNLabelEncoder().fit(data[name].to_numpy().reshape(-1))
+            data[group_name] = self.transform_values(name, data[name], inverse=False, group_id=True)
 
         # encode categoricals
-        for name in set(self.categoricals + self.group_ids):
+        if isinstance(
+            self.target_normalizer, GroupNormalizer
+        ):  # if we use a group normalizer, group_ids must be encoded as well
+            group_ids_to_encode = self.group_ids
+        else:
+            group_ids_to_encode = []
+        for name in set(group_ids_to_encode + self.categoricals):
             allow_nans = name in self.dropout_categoricals
             if name in self.variable_groups:  # fit groups
                 columns = self.variable_groups[name]
@@ -430,7 +458,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
                 self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name])
 
         # encode them
-        for name in set(self.flat_categoricals + self.group_ids):
+        for name in set(group_ids_to_encode + self.flat_categoricals):
             data[name] = self.transform_values(name, data[name], inverse=False)
 
         # save special variables
```
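This is the core of the fix: each dataset fits a fresh `NaNLabelEncoder` for its internal `__group_id__` columns, so series that never appeared during training still get valid internal indices, while the original group id columns are only encoded (with stable, shared encoders) when a `GroupNormalizer` actually needs them. A minimal sketch of the per-dataset re-fitting idea, assuming `NaNLabelEncoder.fit` behaves as used in the diff:

```python
import numpy as np

from pytorch_forecasting.data import NaNLabelEncoder

train_groups = np.array(["Agency_01", "Agency_01", "Agency_02"])
test_groups = np.array(["Agency_01", "Agency_99"])  # Agency_99 is new at test time

# Each dataset fits its own encoder for the internal __group_id__ column,
# so the previously unseen Agency_99 simply becomes another series index here
# instead of breaking the encoding learned on the training data.
train_encoder = NaNLabelEncoder().fit(train_groups)
test_encoder = NaNLabelEncoder().fit(test_groups)
```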
```diff
@@ -472,6 +500,10 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
             data[self.target], scales = self.target_normalizer.transform(data[self.target], data, return_norm=True)
         elif isinstance(self.target_normalizer, NaNLabelEncoder):
             data[self.target] = self.target_normalizer.transform(data[self.target])
+            data["__target__"] = data[
+                self.target
+            ]  # overwrite target because it requires encoding (continuous targets should not be normalized)
+            scales = "no target scales available for categorical target"
         else:
             data[self.target], scales = self.target_normalizer.transform(data[self.target], return_norm=True)
 
@@ -488,6 +520,8 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
 
         if self.target in self.reals:
             self.scalers[self.target] = self.target_normalizer
+        else:
+            self.categorical_encoders[self.target] = self.target_normalizer
 
         # rescale continuous variables apart from target
         for name in self.reals:
```
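For context, this branch is taken when the target itself is categorical, i.e. when the dataset is constructed with a `NaNLabelEncoder` as `target_normalizer`; the encoder is then also registered under `categorical_encoders` so the target can be decoded later. A hypothetical configuration that would exercise it:

```python
# Hypothetical setup (column names are placeholders); compare
# test_categorical_target in tests/test_data.py further down.
dataset = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="agency",  # categorical target
    group_ids=["agency", "sku"],
    target_normalizer=NaNLabelEncoder(),
    max_encoder_length=5,
    max_prediction_length=2,
)
```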
```diff
@@ -515,7 +549,12 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
         return data
 
     def transform_values(
-        self, name: str, values: Union[pd.Series, torch.Tensor, np.ndarray], data: pd.DataFrame = None, inverse=False
+        self,
+        name: str,
+        values: Union[pd.Series, torch.Tensor, np.ndarray],
+        data: pd.DataFrame = None,
+        inverse=False,
+        group_id: bool = False,
     ) -> np.ndarray:
         """
         Scale and encode values.
@@ -526,12 +565,16 @@ def transform_values(
             data (pd.DataFrame, optional): extra data used for scaling (e.g. dataframe with groups columns).
                 Defaults to None.
             inverse (bool, optional): whether to conduct inverse transformation. Defaults to False.
+            group_id (bool, optional): if the passed name refers to a group id (different encoders are used for these).
+                Defaults to False.
 
         Returns:
             np.ndarray: (de/en)coded/(de)scaled values
         """
+        if group_id:
+            name = self._group_ids_mapping[name]
         # remaining categories
-        if name in set(self.flat_categoricals + self.group_ids):
+        if name in set(self.flat_categoricals + self.group_ids + self._group_ids):
             name = self.variable_to_group_mapping.get(name, name)  # map name to encoder
             encoder = self.categorical_encoders[name]
             if encoder is None:
```
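The new `group_id` flag lets callers keep referring to the original column names while `transform_values` internally resolves them to the `__group_id__` encoders via `_group_ids_mapping`. A hypothetical round trip on a fitted dataset:

```python
# Encode raw agency labels with the dataset-specific __group_id__agency encoder ...
encoded = dataset.transform_values("agency", data["agency"], inverse=False, group_id=True)
# ... and decode internal series indices back into the original labels.
labels = dataset.transform_values("agency", encoded, inverse=True, group_id=True)
```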
```diff
@@ -575,7 +618,7 @@ def _data_to_tensors(self, data: pd.DataFrame) -> Dict[str, torch.Tensor]:
             time index
         """
 
-        index = torch.tensor(data[self.group_ids].to_numpy(np.long), dtype=torch.long)
+        index = torch.tensor(data[self._group_ids].to_numpy(np.long), dtype=torch.long)
         time = torch.tensor(data["__time_idx__"].to_numpy(np.long), dtype=torch.long)
 
         categorical = torch.tensor(data[self.flat_categoricals].to_numpy(np.long), dtype=torch.long)
@@ -735,7 +778,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFrame:
         Returns:
             pd.DataFrame: index dataframe
         """
-        g = data.groupby(self.group_ids, observed=True)
+        g = data.groupby(self._group_ids, observed=True)
 
         df_index_first = g["__time_idx__"].transform("nth", 0).to_frame("time_first")
         df_index_last = g["__time_idx__"].transform("nth", -1).to_frame("time_last")
@@ -797,10 +840,10 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFrame:
 
         # check that all groups/series have at least one entry in the index
         if not group_ids.isin(df_index.group_id).all():
-            missing_groups = data.loc[~group_ids.isin(df_index.group_id), self.group_ids].drop_duplicates()
+            missing_groups = data.loc[~group_ids.isin(df_index.group_id), self._group_ids].drop_duplicates()
             # decode values
-            for name in missing_groups.columns:
-                missing_groups[name] = self.transform_values(name, missing_groups[name], inverse=True)
+            for name, id in self._group_ids_mapping.items():
+                missing_groups[id] = self.transform_values(name, missing_groups[id], inverse=True, group_id=True)
             warnings.warn(
                 "Min encoder length and/or min_prediction_idx and/or min prediction length is too large for "
                 f"{len(missing_groups)} series/groups which therefore are not present in the dataset index. "
@@ -1210,7 +1253,7 @@ def x_to_index(self, x: Dict[str, torch.Tensor]) -> pd.DataFrame:
         for id in self.group_ids:
             index_data[id] = x["groups"][:, self.group_ids.index(id)].cpu()
             # decode if possible
-            index_data[id] = self.transform_values(id, index_data[id], inverse=True)
+            index_data[id] = self.transform_values(id, index_data[id], inverse=True, group_id=True)
         index = pd.DataFrame(index_data)
         return index
```
tests/conftest.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -11,6 +11,18 @@
 from pytorch_forecasting.data.examples import get_stallion_data  # isort:skip
 
 
+# for vscode debugging: https://stackoverflow.com/a/62563106/14121677
+if os.getenv("_PYTEST_RAISE", "0") != "0":
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_exception_interact(call):
+        raise call.excinfo.value
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_internalerror(excinfo):
+        raise excinfo.value
+
+
 @pytest.fixture
 def test_data():
     data = get_stallion_data()
```
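With these hooks in place, running the test suite with the environment variable set (e.g. `_PYTEST_RAISE=1`) makes exceptions propagate unwrapped, so a debugger attached to pytest (the linked Stack Overflow answer targets VS Code) stops at the original raise site rather than inside pytest's error reporting. Note the hunk assumes `os` is already imported earlier in conftest.py, above the lines shown.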

tests/test_data.py

Lines changed: 39 additions & 1 deletion
```diff
@@ -373,10 +373,48 @@ def test_categorical_target(test_data):
         min_encoder_length=1,
     )
 
-    x, y = next(iter(dataset.to_dataloader()))
+    _, y = next(iter(dataset.to_dataloader()))
     assert y.dtype is torch.long, "target must be of type long"
 
 
 def test_pickle(test_dataset):
     pickle.dumps(test_dataset)
     pickle.dumps(test_dataset.to_dataloader())
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {},
+        dict(
+            target_normalizer=GroupNormalizer(
+                groups=["agency", "sku"], log_scale=True, scale_by_group=True, log_zero_value=1.0
+            ),
+        ),
+    ],
+)
+def test_new_group_ids(test_data, kwargs):
+    """Test for new group ids in dataset"""
+    train_agency = test_data["agency"].iloc[0]
+    train_dataset = TimeSeriesDataSet(
+        test_data[lambda x: x.agency == train_agency],
+        time_idx="time_idx",
+        target="volume",
+        group_ids=["agency", "sku"],
+        max_encoder_length=5,
+        max_prediction_length=2,
+        min_prediction_length=1,
+        min_encoder_length=1,
+        categorical_encoders=dict(agency=NaNLabelEncoder(add_nan=True), sku=NaNLabelEncoder(add_nan=True)),
+        **kwargs,
+    )
+
+    # test sampling from training dataset
+    next(iter(train_dataset.to_dataloader()))
+
+    # create test dataset with group ids that have not been observed before
+    test_dataset = TimeSeriesDataSet.from_dataset(train_dataset, test_data)
+
+    # check that we can iterate through dataset without error
+    for _ in iter(test_dataset.to_dataloader()):
+        pass
```