Skip to content

Commit

Permalink
rmse score fix (#372)
Browse files Browse the repository at this point in the history
* rmse score fix

* ..

* bump python version

---------

Co-authored-by: Moshe Raboh [email protected] <[email protected]>
  • Loading branch information
mosheraboh and Moshe Raboh [email protected] authored Oct 6, 2024
1 parent e0c7b73 commit 1626ae1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ Note - in general, we find it helpful to follow the same directory structure sho

# Installation

FuseMedML is tested on Python >= 3.7 and PyTorch >= 1.5
FuseMedML is tested on Python >= 3.9 and PyTorch >= 2.0

## We recommend using a Conda environment

Expand Down
16 changes: 9 additions & 7 deletions fuse/eval/metrics/regression/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from fuse.eval.metrics.libs.stat import Stat
from fuse.eval.metrics.metrics_common import MetricDefault
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
r2_score,
root_mean_squared_error,
)


class MetricPearsonCorrelation(MetricDefault):
Expand Down Expand Up @@ -107,26 +112,23 @@ def __init__(
super().__init__(
pred=pred,
target=target,
metric_func=self.mse,
metric_func=self.rmse,
**kwargs,
)

def mse(
def rmse(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
**kwargs: dict,
) -> float:

pred = np.array(pred).flatten()
target = np.array(target).flatten()

assert len(pred) == len(
target
), f"Expected pred and target to have the dimensions but found: {len(pred)} elements in pred and {len(target)} in target"

squared_diff = (pred - target) ** 2
return squared_diff.mean()
return root_mean_squared_error(y_true=target, y_pred=pred)


class MetricR2(MetricDefault):
Expand Down
2 changes: 1 addition & 1 deletion fuse/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pandas>=1.2
tqdm>=4.52.0
scipy>=1.5.4
matplotlib>=3.3.3
scikit-learn>=0.23.2
scikit-learn>=1.4
termcolor>=1.1.0
pycocotools>=2.0.1
pytorch_lightning>=1.6
Expand Down
2 changes: 1 addition & 1 deletion run_all_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ create_env() {
fi

# Python version
PYTHON_VER=3.8
PYTHON_VER=3.9
ENV_NAME="fuse_$PYTHON_VER-CUDA-$force_cuda_version-$(echo -n $requirements | sha256sum | awk '{print $1;}')"
echo $ENV_NAME

Expand Down

0 comments on commit 1626ae1

Please sign in to comment.