Commit

fixing all losses
Parry-Parry committed Dec 17, 2024
1 parent 5e3139b commit b910c53
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions rankers/train/loss/torch/listwise.py
@@ -16,7 +16,7 @@ def __init__(self, reduction="batchmean", temperature=1.0):
         self.temperature = temperature
         self.kl_div = torch.nn.KLDivLoss(reduction=self.reduction)
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         return self.kl_div(
             F.log_softmax(pred / self.temperature, dim=1),
             F.softmax(labels / self.temperature, dim=1),
@@ -35,7 +35,7 @@ def __init__(self, reduction="mean", temperature=1.0):
         self.temperature = temperature
         self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         _, g = pred.shape
         i1, i2 = torch.triu_indices(g, g, offset=1)
         pred_diff = pred[:, i1] - pred[:, i2]
@@ -70,7 +70,7 @@ def __init__(
         self.base_margin = base_margin
         self.increment_margin = increment_margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         _, g = pred.shape
         i1, i2 = torch.triu_indices(g, g, offset=1)
 
@@ -96,7 +96,7 @@ def __init__(self, reduction="mean", temperature=1.0, epsilon=1e-8):
         self.temperature = temperature
         self.epsilon = epsilon
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         if not torch.all((labels >= 0) & (labels <= 1)):
             labels = F.softmax(labels / self.temperature, dim=1)
         return self._reduce(
@@ -119,7 +119,7 @@ def __init__(self, reduction="mean", epsilon: float = 1.0, temperature=1.0):
         self.temperature = temperature
         self.ce = torch.nn.CrossEntropyLoss(reduction="none")
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         labels_for_softmax = torch.divide(labels, labels.sum(dim=1))
         expansion = (
             labels_for_softmax * F.softmax(pred / self.temperature, dim=1)
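Every hunk makes the same change: forward() now accepts and ignores arbitrary extra keyword arguments. A minimal sketch of why that matters for a shared call site (the class below is an illustrative stand-in, not the package's exact BaseLoss, and the extra keyword argument is assumed for the example):

import torch
import torch.nn.functional as F
from torch import Tensor


class SketchKLLoss(torch.nn.Module):
    """Illustrative stand-in for the listwise KL-divergence loss."""

    def __init__(self, reduction="batchmean", temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.kl_div = torch.nn.KLDivLoss(reduction=reduction)

    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
        # Unused keyword arguments are accepted and silently ignored.
        return self.kl_div(
            F.log_softmax(pred / self.temperature, dim=1),
            F.softmax(labels / self.temperature, dim=1),
        )


pred = torch.randn(2, 4)    # scores for 2 queries x 4 documents
labels = torch.randn(2, 4)  # teacher scores of the same shape
# An extra keyword no longer raises TypeError, so one call site can serve all losses.
loss = SketchKLLoss()(pred, labels, group_size=4)
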
10 changes: 5 additions & 5 deletions rankers/train/loss/torch/pairwise.py
@@ -15,7 +15,7 @@ class MarginMSELoss(BaseLoss):
 
     name = "MarginMSE"
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         residual_pred = residual(pred)
         residual_label = residual(labels)
         return F.mse_loss(residual_pred, residual_label, reduction=self.reduction)
@@ -31,7 +31,7 @@ def __init__(self, margin=1, reduction="mean"):
         super().__init__(reduction)
         self.margin = margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         pred_residuals = F.relu(residual(F.sigmoid(pred)))
         label_residuals = torch.sign(residual(F.sigmoid(labels)))
         return self._reduce(F.relu(self.margin - (label_residuals * pred_residuals)))
@@ -47,7 +47,7 @@ def __init__(self, margin=1, reduction="mean"):
         super().__init__(reduction)
         self.margin = margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         margin_b = self.margin - residual(labels)
         return self._reduce(F.relu(margin_b - residual(pred)))
 
@@ -58,7 +58,7 @@ class LCELoss(BaseLoss):
 
     name = "LCE"
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         if labels is not None:
             labels = labels.argmax(dim=1)
         else:
@@ -76,7 +76,7 @@ def __init__(self, reduction="mean", temperature=1.0):
         super().__init__(reduction)
         self.temperature = temperature
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         softmax_scores = F.log_softmax(pred / self.temperature, dim=1)
         labels = (
             labels.argmax(dim=1)
2 changes: 1 addition & 1 deletion rankers/train/loss/torch/pointwise.py
@@ -11,7 +11,7 @@ class PointwiseMSELoss(BaseLoss):
 
     name = "PointwiseMSE"
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         flattened_pred = pred.view(-1)
         flattened_labels = labels.view(-1)
         return F.mse_loss(flattened_pred, flattened_labels, reduction=self.reduction)
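The pointwise change mirrors the listwise and pairwise ones, so every loss in the three files now shares a **kwargs-tolerant signature. A hedged sketch of the kind of generic call site this enables (the helper name and the extra keyword are assumptions for illustration, not the package's API):

import torch

def training_step(loss_fn, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # The same keyword arguments can be forwarded to any of the losses; a forward()
    # that does not use them simply absorbs them via **kwargs.
    return loss_fn(pred, labels, num_negatives=pred.shape[1] - 1)
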
