Skip to content

Commit

Permalink
Fix/change RMSE to MAPE in the stability charts (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Aug 9, 2024
1 parent 4b56fd9 commit ef723f8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "5.0.5"
version = "5.0.6"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
15 changes: 11 additions & 4 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score
from sklearn.metrics import (
root_mean_squared_error,
mean_absolute_error,
mean_absolute_percentage_error,
r2_score,
)
from scipy.optimize import minimize
from itertools import accumulate
from tqdm.auto import tqdm
Expand Down Expand Up @@ -1272,7 +1277,9 @@ def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
(
f"{ivl}d"
if ivl < 30
else f"{ivl / 30:.1f}m" if ivl < 365 else f"{ivl / 365:.1f}y"
else f"{ivl / 30:.1f}m"
if ivl < 365
else f"{ivl / 365:.1f}y"
)
for ivl in map(int, t_history.split(","))
]
Expand Down Expand Up @@ -1632,14 +1639,14 @@ def cal_stability(tmp):
analysis_group.dropna(inplace=True)
analysis_group.drop_duplicates(subset=[group_key], inplace=True)
analysis_group.sort_values(by=[group_key], inplace=True)
rmse = root_mean_squared_error(
mape = mean_absolute_percentage_error(
analysis_group["true_s"],
analysis_group["predicted_s"],
sample_weight=analysis_group["total_count"],
)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.set_title(f"RMSE={rmse:.2f}, last rating={last_rating}")
ax1.set_title(f"MAPE={mape:.2f}, last rating={last_rating}")
ax1.scatter(
analysis_group[group_key],
analysis_group["true_s"],
Expand Down

0 comments on commit ef723f8

Please sign in to comment.