Commit b7409e0
apply L2 regularization based on init_w
L-M-Sherlock committed Jan 7, 2025
1 parent af9b79f commit b7409e0
Showing 1 changed file with 3 additions and 3 deletions.
src/fsrs_optimizer/fsrs_optimizer.py (3 additions, 3 deletions)
@@ -67,7 +67,6 @@
 S_MIN = 0.01
 
 
-DEFAULT_PARAMS_TENSOR = torch.tensor(DEFAULT_PARAMETER, dtype=torch.float)
 DEFAULT_PARAMS_STDDEV_TENSOR = torch.tensor(
     [
         6.61,
@@ -338,6 +337,7 @@ def __init__(
             init_w[17] = 0
             init_w[18] = 0
         self.model = FSRS(init_w, float_delta_t)
+        self.init_w_tensor = torch.tensor(init_w, dtype=torch.float)
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
         self.clipper = ParameterClipper()
         self.gamma = gamma
@@ -392,7 +392,7 @@ def train(self, verbose: bool = True):
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 loss = (self.loss_fn(retentions, labels) * weights).sum()
                 penalty = torch.sum(
-                    torch.square(self.model.w - DEFAULT_PARAMS_TENSOR)
+                    torch.square(self.model.w - self.init_w_tensor)
                     / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR)
                 )
                 loss += penalty * self.gamma * real_batch_size / epoch_len
@@ -447,7 +447,7 @@ def eval(self):
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 loss = (self.loss_fn(retentions, labels) * weights).mean()
                 penalty = torch.sum(
-                    torch.square(self.model.w - DEFAULT_PARAMS_TENSOR)
+                    torch.square(self.model.w - self.init_w_tensor)
                     / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR)
                 )
                 loss += penalty * self.gamma / len(self.train_set.y_train)
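In plain terms, the commit changes the regularization target: the L2 penalty previously pulled the trained weights toward the global DEFAULT_PARAMS_TENSOR, and now pulls them toward the run's own starting parameters init_w, still scaled element-wise by DEFAULT_PARAMS_STDDEV_TENSOR. Below is a minimal, self-contained sketch of the penalty term as it stands after this commit; the parameter count and placeholder values are illustrative assumptions, not the repository's data.

import torch

# Sketch of the penalty after this commit (placeholder values; only the
# formula mirrors the diff). Weights are pulled toward init_w, scaled
# element-wise by the default parameters' standard deviations.
n_params = 19                                  # FSRS parameter count (assumption)
w = torch.nn.Parameter(torch.rand(n_params))   # trainable weights, i.e. self.model.w
init_w_tensor = torch.rand(n_params)           # starting parameters (init_w)
params_stddev = torch.rand(n_params) + 0.1     # stand-in for DEFAULT_PARAMS_STDDEV_TENSOR

penalty = torch.sum(torch.square(w - init_w_tensor) / torch.square(params_stddev))

# In train() the penalty is scaled per batch before joining the data loss:
#     loss += penalty * gamma * real_batch_size / epoch_len
# while eval() scales it by the training-set size:
#     loss += penalty * gamma / len(train_set.y_train)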
