Skip to content

Commit

Permalink
Fix/too large initial stability for easy (#357)
Browse files Browse the repository at this point in the history
L-M-Sherlock authored Jul 16, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 884741b commit 487eee1
Showing 4 changed files with 238 additions and 232 deletions.
460 changes: 232 additions & 228 deletions fsrs4anki_optimizer.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions package/fsrs4anki_optimizer/__main__.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,7 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
print(analysis)

optimizer.define_model()
optimizer.pretrain()
optimizer.train()

optimizer.predict_memory_states()
7 changes: 4 additions & 3 deletions package/fsrs4anki_optimizer/fsrs4anki_optimizer.py
Original file line number Diff line number Diff line change
@@ -404,6 +404,7 @@ def cum_concat(x):

S0_dataset = df[df['i'] == 2]
self.S0_dataset_group = S0_dataset.groupby(by=['r_history', 'delta_t'], group_keys=False).agg({'y': ['mean', 'count']}).reset_index()
self.S0_dataset_group.to_csv('stability_for_pretrain.tsv', sep='\t', index=None)

df['retention'] = df.groupby(by=['r_history', 'delta_t'], group_keys=False)['y'].transform('mean')
df['total_cnt'] = df.groupby(by=['r_history', 'delta_t'], group_keys=False)['id'].transform('count')
@@ -488,7 +489,7 @@ def pretrain(self, verbose=True):
if total_count < 100:
tqdm.write(f'Not enough data for first rating {first_rating}. Expected at least 100, got {total_count}.')
continue
params, _ = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/np.sqrt(count), bounds=((0.1), (365)))
params, _ = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/np.sqrt(count), bounds=((0.1), (60 if total_count < 1000 else 365)))
stability = params[0]
rating_stability[int(first_rating)] = stability
rating_count[int(first_rating)] = total_count
@@ -537,7 +538,7 @@ def S0_rating_curve(rating, a, b, c):
plt.show()

for rating in (1, 2, 3, 4):
again_extrap = max(min(S0_rating_curve(1, *params), 365), 0.1)
again_extrap = max(min(S0_rating_curve(1, *params), 60), 0.1)
# if there isn't enough data to calculate the value for "Again" exactly
if 1 not in rating_stability:
# then check if there exists an exact value for "Hard"
@@ -554,7 +555,7 @@ def S0_rating_curve(rating, a, b, c):
else:
rating_stability[1] = again_extrap
elif rating not in rating_stability:
rating_stability[rating] = max(min(S0_rating_curve(rating, *params), 365), 0.1)
rating_stability[rating] = max(min(S0_rating_curve(rating, *params), 60), 0.1)

rating_stability = {k: round(v, 2) for k, v in sorted(rating_stability.items(), key=lambda item: item[0])}
for rating, stability in rating_stability.items():
2 changes: 1 addition & 1 deletion package/pyproject.toml
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fsrs4anki_optimizer"
version = "4.0.3"
version = "4.0.4"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",

0 comments on commit 487eee1

Please sign in to comment.