Commit 544d773

Fix/cosmetic-changes (#350)

L-M-Sherlock authored Jul 14, 2023
1 parent ced61d2 commit 544d773
Showing 3 changed files with 328 additions and 675 deletions.
874 changes: 264 additions & 610 deletions fsrs4anki_optimizer.ipynb

Large diffs are not rendered by default.

127 changes: 63 additions & 64 deletions package/fsrs4anki_optimizer/fsrs4anki_optimizer.py
@@ -209,7 +209,7 @@ def build_dataset(self, train_set: pd.DataFrame, test_set: pd.DataFrame):
self.test_set = RevlogDataset(test_set)
sampler = RevlogSampler(self.test_set, batch_size=self.batch_size)
self.test_data_loader = DataLoader(self.test_set, batch_sampler=sampler, collate_fn=collate_fn)
print("dataset built")
tqdm.write("dataset built")

def train(self, verbose: bool=True):
best_loss = np.inf
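The recurring change in this diff is print(...) → tqdm.write(...). A plain print issued while a tqdm progress bar is active collides with the bar's carriage-return redraw and garbles the output; tqdm.write prints the message on its own line and then redraws the bar intact. A minimal sketch of the difference (the loop body is illustrative, not taken from the optimizer):

import time
from tqdm import tqdm

for epoch in tqdm(range(3), desc="train"):
    time.sleep(0.1)
    # print(f"epoch {epoch} done")     # would interleave with the bar's redraw
    tqdm.write(f"epoch {epoch} done")  # prints cleanly above the bar

This is why even one-off messages such as "dataset built" are routed through tqdm.write: the optimizer keeps tqdm bars active while it logs.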
@@ -262,7 +262,6 @@ def eval(self):
retentions = power_forgetting_curve(delta_ts, stabilities)
tran_loss = self.loss_fn(retentions, labels).mean()
self.avg_train_losses.append(tran_loss)
tqdm.write(f"Loss in trainset: {tran_loss:.4f}")

sequences, delta_ts, labels, seq_lens = self.test_set.x_train, self.test_set.t_train, self.test_set.y_train, self.test_set.seq_len
real_batch_size = seq_lens.shape[0]
@@ -271,7 +270,6 @@ def eval(self):
retentions = power_forgetting_curve(delta_ts, stabilities)
test_loss = self.loss_fn(retentions, labels).mean()
self.avg_eval_losses.append(test_loss)
tqdm.write(f"Loss in testset: {test_loss:.4f}")

w = list(map(lambda x: round(float(x), 4), dict(self.model.named_parameters())['w'].data))
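The eval step above converts predicted stabilities into recall probabilities with power_forgetting_curve before computing the loss. A sketch of that mapping, assuming the FSRS v4 form of the curve (normalized so recall is 90% when the elapsed time equals the stability):

import numpy as np

def power_forgetting_curve(t, s):
    # assumed FSRS v4 definition: R(t, s) = (1 + t / (9 * s)) ** -1, so R(s, s) == 0.9
    return (1 + t / (9 * s)) ** -1

delta_ts = np.array([1.0, 5.0, 10.0])    # days since the previous review
stabilities = np.array([5.0, 5.0, 5.0])  # model-estimated stability, in days
print(power_forgetting_curve(delta_ts, stabilities))  # ~[0.978, 0.900, 0.818]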

@@ -318,7 +316,7 @@ def anki_extract(filename: str):
# Extract the collection file or deck file to get the .anki21 database.
with zipfile.ZipFile(f'{filename}', 'r') as zip_ref:
zip_ref.extractall('./')
print("Deck file extracted successfully!")
tqdm.write("Deck file extracted successfully!")

def create_time_series(self, timezone: str, revlog_start_date: str, next_day_starts_at: int, filter_out_suspended_cards: bool = False):
"""Step 2"""
@@ -374,7 +372,7 @@ def create_time_series(self, timezone: str, revlog_start_date: str, next_day_starts_at: int, filter_out_suspended_cards: bool = False):
self.type_sequence = np.array(df['type'])
self.time_sequence = np.array(df['time'])
df.to_csv("revlog.csv", index=False)
print("revlog.csv saved.")
tqdm.write("revlog.csv saved.")

df = df[(df['type'] != 3) | (df['factor'] != 0)].copy()
df['real_days'] = df['review_date'] - timedelta(hours=int(next_day_starts_at))
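Subtracting next_day_starts_at hours implements Anki's "next day starts at" preference: a review done in the small hours is attributed to the previous calendar day before day-based intervals are computed. A toy illustration with invented values:

from datetime import datetime, timedelta

next_day_starts_at = 4                 # e.g. the day rolls over at 4 AM
review = datetime(2023, 7, 14, 2, 30)  # a 2:30 AM review
print((review - timedelta(hours=next_day_starts_at)).date())  # 2023-07-13, the previous day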
@@ -402,14 +400,14 @@ def cum_concat(x):
df = df.groupby('cid').filter(lambda group: group['id'].min() > time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) * 1000)
df['y'] = df['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x])
df.to_csv('revlog_history.tsv', sep="\t", index=False)
print("Trainset saved.")
tqdm.write("Trainset saved.")

S0_dataset = df[df['i'] == 2]
self.S0_dataset_group = S0_dataset.groupby(by=['r_history', 'delta_t'], group_keys=False).agg({'y': ['mean', 'count']}).reset_index()

df['retention'] = df.groupby(by=['r_history', 'delta_t'], group_keys=False)['y'].transform('mean')
df['total_cnt'] = df.groupby(by=['r_history', 'delta_t'], group_keys=False)['id'].transform('count')
print("Retention calculated.")
tqdm.write("Retention calculated.")

df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_ivl', 'factor', 'time', 'type', 'create_date', 'review_date', 'real_days', 'r', 't_history', 'y'])
df.drop_duplicates(inplace=True)
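The y column above collapses Anki's four answer buttons into a binary recall label (1 = Again is a lapse, 2–4 = Hard/Good/Easy are recalls), and retention is then the empirical recall rate within each (rating history, interval) cell. A self-contained sketch with toy rows:

import pandas as pd

df = pd.DataFrame({'r_history': ['3', '3', '3,3'],
                   'delta_t': [1, 1, 3],
                   'r': [1, 3, 4]})
df['y'] = df['r'].map({1: 0, 2: 1, 3: 1, 4: 1})  # lapse vs. recall
df['retention'] = df.groupby(['r_history', 'delta_t'])['y'].transform('mean')
print(df)  # the two ('3', 1) rows share retention 0.5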
@@ -434,7 +432,7 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
return group

df = df.groupby(by=['r_history'], group_keys=False).progress_apply(cal_stability)
print("Stability calculated.")
tqdm.write("Stability calculated.")
df.reset_index(drop = True, inplace = True)
df.drop_duplicates(inplace=True)
df.sort_values(by=['r_history'], inplace=True, ignore_index=True)
@@ -449,7 +447,7 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
df['last_recall'] = df['r_history'].map(lambda x: x[-1])
df = df[df.groupby(['i', 'r_history'], group_keys=False)['group_cnt'].transform(max) == df['group_cnt']]
df.to_csv('./stability_for_analysis.tsv', sep='\t', index=None)
print("Analysis saved!")
tqdm.write("Analysis saved!")
caption = "1:again, 2:hard, 3:good, 4:easy\n"
analysis = df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False)
return caption + analysis
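The regex ^[1-4][^124]*$ above selects rating histories whose first review may be any button but whose every later review was Good: after the first character, no further 1, 2, or 4 may appear (the comma separators pass through [^124]). A quick check:

import re

pattern = r'^[1-4][^124]*$'
for history in ('1,3,3', '4,3', '3,2,3', '2,1,3'):
    print(history, bool(re.match(pattern, history)))
# -> True, True, False, False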
@@ -475,7 +473,7 @@ def define_model(self):
https://github.com/open-spaced-repetition/fsrs4anki/wiki/Free-Spaced-Repetition-Scheduler
'''

- def pretrain(self):
+ def pretrain(self, verbose=True):
rating_stability = {}
rating_count = {}

@@ -486,53 +484,55 @@ def pretrain(self):
delta_t = group['delta_t']
recall = group['y']['mean']
count = group['y']['count']
- if sum(count) < 100:
-     print(f'Not enough data for first rating {first_rating}. Expected at least 100, got {sum(count)}.')
+ total_count = sum(count)
+ if total_count < 100:
+     tqdm.write(f'Not enough data for first rating {first_rating}. Expected at least 100, got {total_count}.')
continue
- params, covs = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/np.sqrt(count), bounds=((0.1), (3650)))
- rating_stability[int(first_rating)] = params[0]
- rating_count[int(first_rating)] = sum(count)
- print('Weighted fit parameters:', params)
- print('Number of reviews:', sum(count))
+ params, _ = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/np.sqrt(count), bounds=((0.1), (365)))
+ stability = params[0]
+ rating_stability[int(first_rating)] = stability
+ rating_count[int(first_rating)] = total_count
predict_recall = power_forgetting_curve(delta_t, *params)
- print(f'RMSE: {mean_squared_error(recall, predict_recall, sample_weight=count, squared=False):.4f}')
+ rmse = mean_squared_error(recall, predict_recall, sample_weight=count, squared=False)

+ if verbose:
+     plt.plot(delta_t, recall, label='Exact')
+     plt.plot(np.linspace(0, 30), power_forgetting_curve(np.linspace(0, 30), *params), label=f'Weighted fit (RMSE: {rmse:.4f})')
+     count_percent = np.array([x/total_count for x in count])
+     plt.scatter(delta_t, recall, s=count_percent * 1000, alpha=0.5)
+     plt.legend(loc='upper right', fancybox=True, shadow=False)
+     plt.grid(True)
+     plt.ylim(0, 1)
+     plt.xlim(0, 30)
+     plt.xlabel('Interval')
+     plt.ylabel('Recall')
+     plt.title(f'Forgetting curve for first rating {first_rating} (n={total_count}, s={stability:.2f})')
+     plt.show()
+ tqdm.write(str(rating_stability))

- plt.plot(delta_t, recall, label='Exact')
- plt.plot(np.linspace(0, 30), power_forgetting_curve(np.linspace(0, 30), *params), label='Weighted fit')
- count_percent = np.array([x/sum(count) for x in count])
- plt.scatter(delta_t, recall, s=count_percent * 1000, alpha=0.5)
- plt.legend(loc='upper right', fancybox=True, shadow=False)
- plt.grid(True)
- plt.ylim(0, 1)
- plt.xlim(0, 30)
- plt.xlabel('Interval')
- plt.ylabel('Recall')
- plt.title('Forgetting curve for first rating ' + first_rating)
- plt.show()
-
- print(rating_stability)
if len(rating_stability) < 2:
raise Exception("Not enough data for pretraining!")

def S0_rating_curve(rating, a, b, c):
return np.exp(a + b * rating) + c

params, covs = curve_fit(S0_rating_curve, list(rating_stability.keys()), list(rating_stability.values()), sigma=1/np.sqrt(list(rating_count.values())), method='dogbox', bounds=((-15, 0.03, -5), (15, 7, 30)))
- print('Weighted fit parameters:', params)
- predict_stability = S0_rating_curve(np.array(list(rating_stability.keys())), *params)
- print("Fit stability:", predict_stability)
- print(f'RMSE: {mean_squared_error(list(rating_stability.values()), predict_stability, sample_weight=list(rating_count.values()), squared=False):.4f}')
- plt.plot(list(rating_stability.keys()), list(rating_stability.values()), label='Exact')
- plt.plot(list(rating_stability.keys()), predict_stability, label='Weighted fit')
- plt.legend(loc='upper right', fancybox=True, shadow=False)
- plt.grid(True)
- plt.xlabel('First rating')
- plt.ylabel('Stability')
- plt.title('Stability for first rating')
- plt.show()
+ if verbose:
+     tqdm.write(f'Weighted fit parameters: {params}')
+     predict_stability = S0_rating_curve(np.array(list(rating_stability.keys())), *params)
+     tqdm.write(f"Fit stability: {predict_stability}")
+     tqdm.write(f'RMSE: {mean_squared_error(list(rating_stability.values()), predict_stability, sample_weight=list(rating_count.values()), squared=False):.4f}')
+     plt.plot(list(rating_stability.keys()), list(rating_stability.values()), label='Exact')
+     plt.plot(list(rating_stability.keys()), predict_stability, label='Weighted fit')
+     plt.legend(loc='upper right', fancybox=True, shadow=False)
+     plt.grid(True)
+     plt.xlabel('First rating')
+     plt.ylabel('Stability')
+     plt.title('Stability for first rating')
+     plt.show()

for rating in (1, 2, 3, 4):
- again_extrap = max(min(S0_rating_curve(1, *params), 3650), 0.1)
+ again_extrap = max(min(S0_rating_curve(1, *params), 365), 0.1)
# if there isn't enough data to calculate the value for "Again" exactly
if 1 not in rating_stability:
# then check if there exists an exact value for "Hard"
@@ -549,12 +549,12 @@ def S0_rating_curve(rating, a, b, c):
else:
rating_stability[1] = again_extrap
elif rating not in rating_stability:
- rating_stability[rating] = max(min(S0_rating_curve(rating, *params), 3650), 0.1)
+ rating_stability[rating] = max(min(S0_rating_curve(rating, *params), 365), 0.1)

rating_stability = {k: round(v, 2) for k, v in sorted(rating_stability.items(), key=lambda item: item[0])}
for rating, stability in rating_stability.items():
self.init_w[rating-1] = stability

tqdm.write(f"Pretrain finished!")

def train(self, lr: float = 4e-2, n_epoch: int = 5, n_splits: int = 5, batch_size: int = 512, verbose: bool = True):
"""Step 4"""
@@ -564,14 +564,14 @@ def train(self, lr: float = 4e-2, n_epoch: int = 5, n_splits: int = 5, batch_size: int = 512, verbose: bool = True):
raise ValueError('Training data is inadequate.')
self.dataset['tensor'] = self.dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)
self.dataset['group'] = self.dataset['r_history'] + self.dataset['t_history']
print("Tensorized!")
tqdm.write("Tensorized!")

w = []
plots = []
if n_splits > 1:
sgkf = StratifiedGroupKFold(n_splits=n_splits)
for train_index, test_index in sgkf.split(self.dataset, self.dataset['i'], self.dataset['group']):
print("TRAIN:", len(train_index), "TEST:", len(test_index))
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
train_set = self.dataset.iloc[train_index].copy()
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(train_set, test_set, self.init_w, n_epoch=n_epoch, lr=lr, batch_size=batch_size)
@@ -586,7 +586,7 @@ def train(self, lr: float = 4e-2, n_epoch: int = 5, n_splits: int = 5, batch_size: int = 512, verbose: bool = True):
avg_w = np.round(np.mean(w, axis=0), 4)
self.w = avg_w.tolist()

print("\nTraining finished!")
tqdm.write("\nTraining finished!")
return plots
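The cross-validation above uses StratifiedGroupKFold so that folds are stratified on the review index i while all reviews sharing one (r_history + t_history) group stay in the same fold; the final weights are the element-wise mean over folds. A toy sketch of the splitter (labels and groups invented):

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

X = np.zeros((8, 1))
y = np.array([1, 1, 2, 2, 1, 1, 2, 2])       # stand-in for 'i'
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])  # stand-in for 'group'
for train_idx, test_idx in StratifiedGroupKFold(n_splits=2).split(X, y, groups):
    print("TRAIN:", train_idx, "TEST:", test_idx)  # a group never straddles folds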

def preview(self, requestRetention: float):
@@ -620,7 +620,6 @@ def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
for i in range(len(test_rating_sequence.split(','))):
r_history = test_rating_sequence[:2*i+1]
states = my_collection.predict(t_history, r_history)
- print(states)
next_t = next_interval(states[0], requestRetention)
t_history += f',{int(next_t)}'
difficulty = round(float(states[1]), 1)
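next_interval above inverts the forgetting curve: given a state's stability and the requested retention, it returns the interval at which predicted recall drops to that target. Assuming the FSRS v4 power curve R = (1 + t/(9s)) ** -1, solving for t gives a sketch like:

def next_interval(stability, request_retention):
    # assumed inverse of the v4 power curve; the code above then truncates with int()
    return 9 * stability * (1 / request_retention - 1)

print(next_interval(5.0, 0.90))  # 5.0  -- at R = 0.9 the interval equals the stability
print(next_interval(5.0, 0.85))  # ~7.9 -- a lower retention target lengthens the interval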
@@ -643,7 +642,7 @@ def predict_memory_states(self):
prediction.sort_values(by=['r_history'], inplace=True)
prediction.rename(columns={"id": "count"}, inplace=True)
prediction.to_csv("./prediction.tsv", sep='\t', index=None)
print("prediction.tsv saved.")
tqdm.write("prediction.tsv saved.")
prediction['difficulty'] = prediction['difficulty'].map(lambda x: int(round(x)))
self.difficulty_distribution = prediction.groupby(by=['difficulty'])['count'].sum() / prediction['count'].sum()
self.difficulty_distribution_padding = np.zeros(10)
@@ -683,8 +682,8 @@ def find_optimal_retention(self):
if 2 in type_count and 2 in type_block:
f_time = round(type_time[2]/type_block[2]/1000 + r_time, 1)

print(f"average time for failed cards: {f_time}s")
print(f"average time for recalled cards: {r_time}s")
tqdm.write(f"average time for failed cards: {f_time}s")
tqdm.write(f"average time for recalled cards: {r_time}s")

def stability2index(stability):
return (np.log(stability) / np.log(base)).round().astype(int) + index_offset
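stability2index maps stabilities onto a geometric grid, so index i corresponds to stability base ** (i - index_offset); rounding in log space makes the buckets multiplicatively equal in width. A sketch with illustrative constants (base and index_offset here are stand-ins, not the module's actual values):

import numpy as np

base, index_offset = 1.01, 200  # illustrative values only

def stability2index(stability):
    return int(round(np.log(stability) / np.log(base))) + index_offset

for s in (0.5, 1.0, 10.0):
    i = stability2index(s)
    print(s, i, round(base ** (i - index_offset), 3))  # round-trips approximately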
@@ -700,7 +699,7 @@ def cal_next_recall_stability(s, r, d, response):
terminal_stability = cal_next_recall_stability(terminal_stability, 0.96, d_range, 1)
index_len = stability2index(terminal_stability)
stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])
print(f"terminal stability: {stability_list.max(): .2f}")
tqdm.write(f"terminal stability: {stability_list.max(): .2f}")
df = pd.DataFrame(columns=["retention", "difficulty", "time"])

for percentage in tqdm(range(96, 66, -2), desc="find optimal retention"):
@@ -735,7 +734,7 @@ def cal_next_recall_stability(s, r, d, response):

df.sort_values(by=["difficulty", "retention"], inplace=True)
df.to_csv("./expected_time.csv", index=False)
print("expected_time.csv saved.")
tqdm.write("expected_time.csv saved.")

optimal_retention_list = np.zeros(10)
fig = plt.figure()
@@ -749,7 +748,7 @@ def cal_next_recall_stability(s, r, d, response):

self.optimal_retention = np.inner(self.difficulty_distribution_padding, optimal_retention_list)

print(f"\n-----suggested retention (experimental): {self.optimal_retention:.2f}-----")
tqdm.write(f"\n-----suggested retention (experimental): {self.optimal_retention:.2f}-----")

ax.set_ylabel("expected time (second)")
ax.set_xlabel("retention")
@@ -790,7 +789,7 @@ def calibration_graph(self):
plot_brier(self.dataset['p'], self.dataset['y'], bins=40, ax=fig1.add_subplot(111))
fig2 = plt.figure(figsize=(16, 12))
for last_rating in ("1","2","3","4"):
print(f"\nLast rating: {last_rating}")
tqdm.write(f"\nLast rating: {last_rating}")
plot_brier(self.dataset[self.dataset['r_history'].str.endswith(last_rating)]['p'], self.dataset[self.dataset['r_history'].str.endswith(last_rating)]['y'], bins=40, ax=fig2.add_subplot(2, 2, int(last_rating)), title=f"Last rating: {last_rating}")

def to_percent(temp, position):
@@ -875,7 +874,7 @@ def compare_with_sm2(self):
self.dataset['sm2_ivl'] = self.dataset['tensor'].map(sm2)
self.dataset['sm2_p'] = np.exp(np.log(0.9) * self.dataset['delta_t'] / self.dataset['sm2_ivl'])
self.dataset['log_loss'] = self.dataset.apply(lambda row: - np.log(row['sm2_p']) if row['y'] == 1 else - np.log(1 - row['sm2_p']), axis=1)
print(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}")
tqdm.write(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}")
cross_comparison = self.dataset[['sm2_p', 'p', 'y']].copy()
fig1 = plt.figure()
plot_brier(cross_comparison['sm2_p'], cross_comparison['y'], bins=40, ax=fig1.add_subplot(111))
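The SM-2 baseline above assumes retention decays exponentially and is exactly 90% when a card is reviewed on schedule, i.e. sm2_p = exp(ln(0.9) * delta_t / ivl) = 0.9 ** (delta_t / ivl). Worked numbers:

import numpy as np

def sm2_retention(delta_t, ivl):
    return np.exp(np.log(0.9) * delta_t / ivl)

print(sm2_retention(10, 10))  # 0.900 -- reviewed exactly on schedule
print(sm2_retention(20, 10))  # 0.810 -- one interval overdue: 0.9 ** 2
print(sm2_retention(5, 10))   # ~0.949 -- reviewed early: 0.9 ** 0.5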
@@ -891,13 +890,13 @@ def compare_with_sm2(self):
ax.axhline(y = 0.0, color = 'black', linestyle = '-')

cross_comparison_group = cross_comparison.groupby(by='SM2_bin').agg({'y': ['mean'], 'FSRS_B-W': ['mean'], 'p': ['mean', 'count']})
print(f"Universal Metric of FSRS: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['p', 'mean'], sample_weight=cross_comparison_group['p', 'count'], squared=False):.4f}")
tqdm.write(f"Universal Metric of FSRS: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['p', 'mean'], sample_weight=cross_comparison_group['p', 'count'], squared=False):.4f}")
cross_comparison_group['p', 'percent'] = cross_comparison_group['p', 'count'] / cross_comparison_group['p', 'count'].sum()
ax.scatter(cross_comparison_group.index, cross_comparison_group['FSRS_B-W', 'mean'], s=cross_comparison_group['p', 'percent'] * 1024, alpha=0.5)
ax.plot(cross_comparison_group['FSRS_B-W', 'mean'], label='FSRS by SM2')

cross_comparison_group = cross_comparison.groupby(by='FSRS_bin').agg({'y': ['mean'], 'SM2_B-W': ['mean'], 'sm2_p': ['mean', 'count']})
print(f"Universal Metric of SM2: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['sm2_p', 'mean'], sample_weight=cross_comparison_group['sm2_p', 'count'], squared=False):.4f}")
tqdm.write(f"Universal Metric of SM2: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['sm2_p', 'mean'], sample_weight=cross_comparison_group['sm2_p', 'count'], squared=False):.4f}")
cross_comparison_group['sm2_p', 'percent'] = cross_comparison_group['sm2_p', 'count'] / cross_comparison_group['sm2_p', 'count'].sum()
ax.scatter(cross_comparison_group.index, cross_comparison_group['SM2_B-W', 'mean'], s=cross_comparison_group['sm2_p', 'percent'] * 1024, alpha=0.5)
ax.plot(cross_comparison_group['SM2_B-W', 'mean'], label='SM2 by FSRS')
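The "Universal Metric" printed above is a count-weighted RMSE between mean outcomes and mean predictions inside bins defined by the other algorithm's predictions, so each model is scored on strata it did not itself define. A self-contained sketch of the binning step (synthetic, well-calibrated data, so the metric comes out near zero):

import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(0)
df = pd.DataFrame({'p': rng.uniform(0.5, 1.0, 1000)})
df['y'] = (rng.uniform(size=1000) < df['p']).astype(int)  # outcomes drawn from p
df['bin'] = (df['p'] * 10).round() / 10                   # coarse probability bins
g = df.groupby('bin').agg(y_mean=('y', 'mean'), p_mean=('p', 'mean'), n=('p', 'count'))
print(mean_squared_error(g['y_mean'], g['p_mean'], sample_weight=g['n'], squared=False))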
@@ -947,13 +946,13 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
bin_counts = brier['detail']['bin_counts']
r2 = r2_score(bin_correct_means, bin_prediction_means, sample_weight=bin_counts)
rmse = np.sqrt(mean_squared_error(bin_correct_means, bin_prediction_means, sample_weight=bin_counts))
print(f"R-squared: {r2:.4f}")
print(f"RMSE: {rmse:.4f}")
tqdm.write(f"R-squared: {r2:.4f}")
tqdm.write(f"RMSE: {rmse:.4f}")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.grid(True)
fit_wls = sm.WLS(bin_correct_means, sm.add_constant(bin_prediction_means), weights=bin_counts).fit()
- print(fit_wls.params)
+ tqdm.write(str(fit_wls.params))
y_regression = [fit_wls.params[0] + fit_wls.params[1]*x for x in bin_prediction_means]
ax.plot(bin_prediction_means, y_regression, label='Weighted Least Squares Regression', color="green")
ax.plot(bin_prediction_means, bin_correct_means, label='Actual Calibration', color="#1f77b4")
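The regression line drawn above is a weighted least-squares fit of observed recall on predicted recall, with bin counts as weights; perfect calibration would give intercept 0 and slope 1. A minimal sketch with invented bins:

import numpy as np
import statsmodels.api as sm

pred = np.array([0.70, 0.80, 0.90, 0.95])  # bin mean predictions (invented)
obs = np.array([0.72, 0.78, 0.91, 0.94])   # bin mean outcomes (invented)
counts = np.array([50, 200, 800, 300])
fit = sm.WLS(obs, sm.add_constant(pred), weights=counts).fit()
print(fit.params)  # [intercept, slope], close to [0, 1] here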
2 changes: 1 addition & 1 deletion package/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fsrs4anki_optimizer"
version = "4.0.1"
version = "4.0.2"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
