Skip to content

Commit 0e57336

Browse files
authored
masked loss PDU (#147)
1 parent 72c8af7 commit 0e57336

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/trainer/unlearn/pdu.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
117117
maxLogits = logits.max(dim=-1)[0]
118118
averageLogits = logits.mean(dim=-1)
119119

120-
forget_loss = ((maxLogits - averageLogits) ** 2).mean()
120+
forget_loss = (maxLogits - averageLogits) ** 2
121+
mask = (forget_inputs["labels"] != -100).reshape(-1)
122+
forget_loss = (forget_loss * mask).sum() / mask.sum()
121123

122124
retain_inputs = inputs["retain"]
123125
retain_inputs = {

0 commit comments

Comments
 (0)