From b7409e0e991b3955bec048a7f108b9ba90cf35ea Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 7 Jan 2025 10:06:37 +0800 Subject: [PATCH] apply L2 regularization based on init_w --- src/fsrs_optimizer/fsrs_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index a0f9cac..851ef03 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -67,7 +67,6 @@ S_MIN = 0.01 -DEFAULT_PARAMS_TENSOR = torch.tensor(DEFAULT_PARAMETER, dtype=torch.float) DEFAULT_PARAMS_STDDEV_TENSOR = torch.tensor( [ 6.61, @@ -338,6 +337,7 @@ def __init__( init_w[17] = 0 init_w[18] = 0 self.model = FSRS(init_w, float_delta_t) + self.init_w_tensor = torch.tensor(init_w, dtype=torch.float) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) self.clipper = ParameterClipper() self.gamma = gamma @@ -392,7 +392,7 @@ def train(self, verbose: bool = True): retentions = power_forgetting_curve(delta_ts, stabilities) loss = (self.loss_fn(retentions, labels) * weights).sum() penalty = torch.sum( - torch.square(self.model.w - DEFAULT_PARAMS_TENSOR) + torch.square(self.model.w - self.init_w_tensor) / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR) ) loss += penalty * self.gamma * real_batch_size / epoch_len @@ -447,7 +447,7 @@ def eval(self): retentions = power_forgetting_curve(delta_ts, stabilities) loss = (self.loss_fn(retentions, labels) * weights).mean() penalty = torch.sum( - torch.square(self.model.w - DEFAULT_PARAMS_TENSOR) + torch.square(self.model.w - self.init_w_tensor) / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR) ) loss += penalty * self.gamma / len(self.train_set.y_train)