Skip to content

Commit c7574bc

Browse files
authored
Merge pull request #170 from jdb78/feature/improved-encoder-keyerror
Better KeyError for NaNLabelEncoder
2 parents 9cd0dab + 684a450 commit c7574bc

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytorch_forecasting/data/encoders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,12 @@ def transform(self, y: Iterable) -> Union[torch.Tensor, np.ndarray]:
102102
encoded = [self.classes_.get(v, 0) for v in y]
103103

104104
else:
105-
encoded = [self.classes_[v] for v in y]
105+
try:
106+
encoded = [self.classes_[v] for v in y]
107+
except KeyError as e:
108+
raise KeyError(
109+
f"Unknown category '{e.args[0]}' encountered. Set `add_nan=True` to allow unknown categories"
110+
)
106111

107112
if isinstance(y, torch.Tensor):
108113
encoded = torch.tensor(encoded, dtype=torch.long, device=y.device)

0 commit comments

Comments
 (0)