diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 652a780..1095682 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -90,13 +90,15 @@ def stability_after_success( return new_s def stability_after_failure(self, state: Tensor, r: Tensor) -> Tensor: + old_s = state[:, 0] new_s = ( self.w[11] * torch.pow(state[:, 1], -self.w[12]) - * (torch.pow(state[:, 0] + 1, self.w[13]) - 1) + * (torch.pow(old_s + 1, self.w[13]) - 1) * torch.exp((1 - r) * self.w[14]) ) - return torch.minimum(new_s, state[:, 0]) + new_minimum_s = old_s / torch.exp(self.w[17] * self.w[18]) + return torch.minimum(new_s, new_minimum_s) def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor: new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18]))