diff --git a/basicts/metrics/r_square.py b/basicts/metrics/r_square.py index 0ad7996..c60d8e8 100644 --- a/basicts/metrics/r_square.py +++ b/basicts/metrics/r_square.py @@ -39,3 +39,4 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float = # 计算 R^2 loss = 1 - (ss_res / (ss_tot + eps)) + return torch.mean(loss)