From d26c6d6c77373fd2dc956185fd55b734d1cf8047 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 16 Dec 2024 16:37:16 +0800 Subject: [PATCH] Feat/option enable_short_term in training --- src/fsrs_optimizer/fsrs_optimizer.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 1095682..b18b9af 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -299,7 +299,11 @@ def __init__( batch_size: int = 256, max_seq_len: int = 64, float_delta_t: bool = False, + enable_short_term: bool = True, ) -> None: + if not enable_short_term: + init_w[17] = 0 + init_w[18] = 0 self.model = FSRS(init_w, float_delta_t) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) self.clipper = ParameterClipper() @@ -315,6 +319,7 @@ def __init__( self.avg_eval_losses = [] self.loss_fn = nn.BCELoss(reduction="none") self.float_delta_t = float_delta_t + self.enable_short_term = enable_short_term def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]): self.train_set = BatchDataset( @@ -353,9 +358,12 @@ def train(self, verbose: bool = True): retentions = power_forgetting_curve(delta_ts, stabilities) loss = self.loss_fn(retentions, labels).sum() loss.backward() - if self.float_delta_t: + if self.float_delta_t or not self.enable_short_term: for param in self.model.parameters(): param.grad[:4] = torch.zeros(4) + if not self.enable_short_term: + for param in self.model.parameters(): + param.grad[17:19] = torch.zeros(2) self.optimizer.step() self.scheduler.step() self.model.apply(self.clipper) @@ -504,10 +512,14 @@ def loss(stability): class Optimizer: float_delta_t: bool = False + enable_short_term: bool = True - def __init__(self, float_delta_t: bool = False) -> None: + def __init__( + self, float_delta_t: bool = False, enable_short_term: bool = True + ) -> None: tqdm.pandas() self.float_delta_t = float_delta_t + self.enable_short_term = enable_short_term global S_MIN S_MIN = 1e-6 if float_delta_t else 0.01 @@ -1197,6 +1209,7 @@ def train( lr=lr, batch_size=batch_size, float_delta_t=self.float_delta_t, + enable_short_term=self.enable_short_term, ) w.append(trainer.train(verbose=verbose)) self.w = w[-1] @@ -1217,6 +1230,7 @@ def train( lr=lr, batch_size=batch_size, float_delta_t=self.float_delta_t, + enable_short_term=self.enable_short_term, ) w.append(trainer.train(verbose=verbose)) if verbose: