
Commit

flexible loss
Parry-Parry committed Nov 22, 2024
1 parent 20b3ddb commit c40c315
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions rankers/train/loss/listwise.py
@@ -15,7 +15,7 @@ def __init__(self, reduction="batchmean", temperature=1.0):
         self.temperature = temperature
         self.kl_div = torch.nn.KLDivLoss(reduction=self.reduction)
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         return self.kl_div(
             F.log_softmax(pred / self.temperature, dim=1),
             F.softmax(labels / self.temperature, dim=1),
@@ -34,7 +34,7 @@ def __init__(self, reduction="mean", temperature=1.0):
         self.temperature = temperature
         self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         _, g = pred.shape
         i1, i2 = torch.triu_indices(g, g, offset=1)
         pred_diff = pred[:, i1] - pred[:, i2]
@@ -69,7 +69,7 @@ def __init__(
         self.base_margin = base_margin
         self.increment_margin = increment_margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         _, g = pred.shape
         i1, i2 = torch.triu_indices(g, g, offset=1)
 
@@ -95,7 +95,7 @@ def __init__(self, reduction="mean", temperature=1.0, epsilon=1e-8):
         self.temperature = temperature
         self.epsilon = epsilon
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         if not torch.all((labels >= 0) & (labels <= 1)):
             labels = F.softmax(labels / self.temperature, dim=1)
         return self._reduce(
@@ -118,7 +118,7 @@ def __init__(self, reduction="mean", epsilon: float = 1.0, temperature=1.0):
         self.temperature = temperature
         self.ce = torch.nn.CrossEntropyLoss(reduction="none")
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         labels_for_softmax = torch.divide(labels, labels.sum(dim=1))
         expansion = (
             labels_for_softmax * F.softmax(pred / self.temperature, dim=1)
10 changes: 5 additions & 5 deletions rankers/train/loss/pairwise.py
@@ -12,7 +12,7 @@ class MarginMSELoss(BaseLoss):
 
     name = "MarginMSE"
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         residual_pred = pred[:, 0].unsqueeze(1) - pred[:, 1:]
         residual_label = labels[:, 0].unsqueeze(1) - labels[:, 1:]
         return F.mse_loss(residual_pred, residual_label, reduction=self.reduction)
@@ -28,7 +28,7 @@ def __init__(self, margin=1, reduction="mean"):
         super().__init__(reduction)
         self.margin = margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         pred_residuals = F.relu(residual(F.sigmoid(pred)))
         label_residuals = torch.sign(residual(F.sigmoid(labels)))
         return self._reduce(F.relu(self.margin - (label_residuals * pred_residuals)))
@@ -44,7 +44,7 @@ def __init__(self, margin=1, reduction="mean"):
         super().__init__(reduction)
         self.margin = margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         margin_b = self.margin - residual(labels)
         return self._reduce(F.relu(margin_b - residual(pred)))
 
@@ -55,7 +55,7 @@ class LCELoss(BaseLoss):
 
     name = "LCE"
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         if labels is not None:
             labels = labels.argmax(dim=1)
         else:
@@ -73,7 +73,7 @@ def __init__(self, reduction="mean", temperature=1.0):
         super().__init__(reduction)
         self.temperature = temperature
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         softmax_scores = F.log_softmax(pred / self.temperature, dim=1)
         labels = (
             labels.argmax(dim=1)
2 changes: 1 addition & 1 deletion rankers/train/loss/pointwise.py
@@ -10,7 +10,7 @@ class PointwiseMSELoss(BaseLoss):
 
     name = "PointwiseMSE"
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         flattened_pred = pred.view(-1)
         flattened_labels = labels.view(-1)
         return F.mse_loss(flattened_pred, flattened_labels, reduction=self.reduction)
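
The change is identical in every file: each loss's forward now accepts **kwargs, so a trainer can pass extra keyword arguments uniformly and losses that do not need them simply ignore them. Below is a minimal sketch of the pattern, using a simplified stand-in for the listwise KL-divergence loss (the real class derives from the package's BaseLoss; the class body and the group_mask argument here are illustrative assumptions, not the library's API):

import torch
import torch.nn.functional as F
from torch import Tensor


class KLDivergenceLoss(torch.nn.Module):
    # Simplified stand-in for the listwise KL loss; the library version
    # subclasses BaseLoss and takes its reduction from that base class.
    def __init__(self, reduction="batchmean", temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.kl_div = torch.nn.KLDivLoss(reduction=reduction)

    # **kwargs absorbs any extra keyword arguments a caller forwards.
    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
        return self.kl_div(
            F.log_softmax(pred / self.temperature, dim=1),
            F.softmax(labels / self.temperature, dim=1),
        )


pred = torch.randn(4, 8)   # (batch, group) scores from the ranker
labels = torch.rand(4, 8)  # teacher scores for the same documents
loss_fn = KLDivergenceLoss()

# Without **kwargs in forward, the unused keyword would raise a TypeError;
# with it, the extra argument is silently ignored.
loss = loss_fn(pred, labels, group_mask=torch.ones(4, 8, dtype=torch.bool))
print(loss)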
