Skip to content

Commit

Permalink
feat: 🎸 update metric_forward and fix bugs in l1 and l2 loss
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 13, 2023
1 parent a23494e commit fa83069
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 4 additions & 2 deletions basicts/losses/losses.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F


def l1_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def l1_loss(prediction: torch.Tensor, target: torch._tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean") -> torch.Tensor:
"""unmasked mae."""

return F.l1_loss(prediction, target)


def l2_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def l2_loss(prediction: torch.Tensor, target: torch.Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean") -> torch.Tensor:
"""unmasked mse"""

return F.mse_loss(prediction, target)
Expand Down
9 changes: 7 additions & 2 deletions basicts/runners/base_tsf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,18 @@ def metric_forward(self, metric_func, args) -> torch.Tensor:
# support partial function
# users can define their partial function in the config file
# e.g., functools.partial(masked_mase, freq="4", null_val=np.nan)
if "null_val" in covariate_names and "null_val" not in metric_func.keywords: # if null_val is required but not provided
if "null_val" in metric_func.keywords: # null_val is provided
# assert self.null_val is None, "Null_val is provided in metric function. The CFG.NULL_VAL should not be set."
pass # error when using multiple metrics, some of which require null_val and some do not
elif "null_val" in covariate_names: # null_val is required but not provided
args["null_val"] = self.null_val
metric_item = metric_func(**args)
elif callable(metric_func):
# is a function
# filter out keys that are not in function arguments
metric_item = metric_func(**args, null_val=self.null_val)
if "null_val" in covariate_names: # null_val is required
args["null_val"] = self.null_val
metric_item = metric_func(**args)
else:
raise TypeError("Unknown metric type: {0}".format(type(metric_func)))
return metric_item
Expand Down

0 comments on commit fa83069

Please sign in to comment.