Skip to content

Commit

Permalink
end this madness
Browse files Browse the repository at this point in the history
  • Loading branch information
Parry-Parry committed Nov 22, 2024
1 parent 4ff3fd6 commit 43fdbeb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rankers/train/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def step(self):
self.d_weight = self.d_weight * (self.t / self.T) ** 2

@abstractmethod
def reg(reps, weight=0):
def reg(self, reps, weight=0):
raise NotImplementedError

def forward(self, query_hidden_states, text_hidden_states, **kwargs):
Expand All @@ -198,15 +198,15 @@ def __init__(
) -> None:
super(FLOPSLoss, self).__init__(q_weight, d_weight, t, T, reduction)

def reg(reps, weight=0):
def reg(self, reps, weight=0):
return (torch.abs(reps).mean(dim=0) ** 2).sum() * weight


class L1Loss(BaseLoss):
def __init__(self, reduction: str = "mean") -> None:
super(L1Loss, self).__init__(reduction)

def reg(reps, weight=0):
def reg(self, reps, weight=0):
return torch.abs(reps).sum(dim=1).mean() * weight


Expand Down

0 comments on commit 43fdbeb

Please sign in to comment.