
Commit 653f9b8

Author: Jan Beitner
Message: Fix tests - ensure mase is calculated in correct places
Parent: 4ee69b6

File tree: 4 files changed (+12 additions, -6 deletions)

pytorch_forecasting/data/encoders.py (4 additions, 1 deletion)

@@ -94,7 +94,10 @@ def transform(self, y: Iterable) -> Union[torch.Tensor, np.ndarray]:
         if self.warn:
             cond = ~np.isin(y, self.classes_)
             if cond.any():
-                warnings.warn(f"Found {y[cond].nunique()} unknown classes which were set to NaN", UserWarning)
+                warnings.warn(
+                    f"Found {np.unique(np.asarray(y)[cond]).size} unknown classes which were set to NaN",
+                    UserWarning,
+                )
 
         encoded = [self.classes_.get(v, 0) for v in y]
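The old message called y[cond].nunique(), which exists only on pandas objects; counting with plain NumPy works for any array-like input. A minimal sketch of the new counting logic, with a hypothetical classes_ vocabulary standing in for the encoder's fitted classes:

import numpy as np

classes_ = {"a": 0, "b": 1}  # hypothetical fitted vocabulary
y = np.asarray(["a", "b", "c", "c", "d"])

cond = ~np.isin(y, list(classes_))   # True where a value was never seen during fit
n_unknown = np.unique(y[cond]).size  # distinct unknown classes: 2 ("c" and "d")
print(f"Found {n_unknown} unknown classes which were set to NaN")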

pytorch_forecasting/models/base_model.py (2 additions, 2 deletions)

@@ -183,11 +183,11 @@ def step(self, x: Dict[str, torch.Tensor], y: torch.Tensor, batch_idx: int, label
             # for smoothness of loss function
             monotinicity_loss = 10 * torch.pow(monotinicity_loss, 2)
             if isinstance(self.loss, MASE):
-                self.loss(prediction, y)
-            else:
                 loss = self.loss(
                     prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
                 )
+            else:
+                loss = self.loss(prediction, y)
 
             loss = loss * (1 + monotinicity_loss)
         else:
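The branches were inverted before this fix: the MASE path computed the loss without assigning it, and MASE never received the encoder history it needs. MASE divides the forecast error by the mean absolute error of a naive one-step forecast on the in-sample history, which is why it alone takes encoder_target. A rough sketch of that scaling idea, not the library's implementation:

import torch

def mase_sketch(prediction: torch.Tensor, target: torch.Tensor, encoder_target: torch.Tensor) -> torch.Tensor:
    # shapes assumed (batch, time); scale = in-sample naive one-step MAE
    scale = (encoder_target[:, 1:] - encoder_target[:, :-1]).abs().mean(dim=-1)
    mae = (prediction - target).abs().mean(dim=-1)
    return (mae / scale.clamp(min=1e-8)).mean()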

pytorch_forecasting/models/nbeats/__init__.py (4 additions, 1 deletion)

@@ -218,7 +218,10 @@ def step(self, x, y, batch_idx, label) -> Dict[str, torch.Tensor]:
         )
         backcast_weight = backcast_weight / (backcast_weight + 1)  # normalize
         forecast_weight = 1 - backcast_weight
-        backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight
+        if isinstance(self.loss, MASE):
+            backcast_loss = self.loss(backcast, x["encoder_target"], x["decoder_target"]) * backcast_weight
+        else:
+            backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight
         if label == "train":
             log["loss"] = log["loss"] * forecast_weight + backcast_loss
             log["log"]["train_loss"] = log["log"]["train_loss"] * forecast_weight + backcast_loss

pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py (2 additions, 2 deletions)

@@ -351,15 +351,15 @@ def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
                     variable_embedding = self.prescalers[name](variable_embedding)
                 weight_inputs.append(variable_embedding)
                 var_outputs.append(self.single_variable_grns[name](variable_embedding))
-            var_outputs = torch.stack(var_outputs, axis=-1)
+            var_outputs = torch.stack(var_outputs, dim=-1)
 
             # calculate variable weights
             flat_embedding = torch.cat(weight_inputs, dim=-1)
             sparse_weights = self.flattened_grn(flat_embedding, context)
             sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)
 
             outputs = var_outputs * sparse_weights
-            outputs = outputs.sum(axis=-1)
+            outputs = outputs.sum(dim=-1)
         else:  # for one input, do not perform variable selection but just encoding
             name = next(iter(self.single_variable_grns.keys()))
             variable_embedding = x[name]
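dim is the canonical keyword for torch.stack and Tensor.sum; the NumPy-style axis spelling is not accepted uniformly across PyTorch functions and versions, so switching to dim also makes the intent explicit. A toy version of the combine step this code performs, with hypothetical shapes:

import torch

# three variable encodings of shape (batch, hidden), as the selection network might produce
var_outputs = [torch.randn(4, 8) for _ in range(3)]
stacked = torch.stack(var_outputs, dim=-1)          # (batch, hidden, n_vars)

weights = torch.softmax(torch.randn(4, 3), dim=-1)  # (batch, n_vars), sums to 1 per sample
sparse_weights = weights.unsqueeze(-2)              # (batch, 1, n_vars), broadcasts over hidden

outputs = (stacked * sparse_weights).sum(dim=-1)    # (batch, hidden) weighted combination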
