diff --git a/rankers/train/loss/listwise.py b/rankers/train/loss/listwise.py
index e4bb638..58c5b8e 100644
--- a/rankers/train/loss/listwise.py
+++ b/rankers/train/loss/listwise.py
@@ -15,7 +15,7 @@ def __init__(self, reduction="batchmean", temperature=1.0):
         self.temperature = temperature
         self.kl_div = torch.nn.KLDivLoss(reduction=self.reduction)
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         return self.kl_div(
             F.log_softmax(pred / self.temperature, dim=1),
             F.softmax(labels / self.temperature, dim=1),
@@ -34,7 +34,7 @@ def __init__(self, reduction="mean", temperature=1.0):
         self.temperature = temperature
         self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         _, g = pred.shape
         i1, i2 = torch.triu_indices(g, g, offset=1)
         pred_diff = pred[:, i1] - pred[:, i2]
@@ -69,7 +69,7 @@ def __init__(
         self.base_margin = base_margin
         self.increment_margin = increment_margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         _, g = pred.shape
         i1, i2 = torch.triu_indices(g, g, offset=1)
 
@@ -95,7 +95,7 @@ def __init__(self, reduction="mean", temperature=1.0, epsilon=1e-8):
         self.temperature = temperature
         self.epsilon = epsilon
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         if not torch.all((labels >= 0) & (labels <= 1)):
             labels = F.softmax(labels / self.temperature, dim=1)
         return self._reduce(
@@ -118,7 +118,7 @@ def __init__(self, reduction="mean", epsilon: float = 1.0, temperature=1.0):
         self.temperature = temperature
         self.ce = torch.nn.CrossEntropyLoss(reduction="none")
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         labels_for_softmax = torch.divide(labels, labels.sum(dim=1))
         expansion = (
             labels_for_softmax * F.softmax(pred / self.temperature, dim=1)
diff --git a/rankers/train/loss/pairwise.py b/rankers/train/loss/pairwise.py
index d1c27cd..9893aad 100644
--- a/rankers/train/loss/pairwise.py
+++ b/rankers/train/loss/pairwise.py
@@ -12,7 +12,7 @@ class MarginMSELoss(BaseLoss):
 
     name = "MarginMSE"
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         residual_pred = pred[:, 0].unsqueeze(1) - pred[:, 1:]
         residual_label = labels[:, 0].unsqueeze(1) - labels[:, 1:]
         return F.mse_loss(residual_pred, residual_label, reduction=self.reduction)
@@ -28,7 +28,7 @@ def __init__(self, margin=1, reduction="mean"):
         super().__init__(reduction)
         self.margin = margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         pred_residuals = F.relu(residual(F.sigmoid(pred)))
         label_residuals = torch.sign(residual(F.sigmoid(labels)))
         return self._reduce(F.relu(self.margin - (label_residuals * pred_residuals)))
@@ -44,7 +44,7 @@ def __init__(self, margin=1, reduction="mean"):
         super().__init__(reduction)
         self.margin = margin
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         margin_b = self.margin - residual(labels)
         return self._reduce(F.relu(margin_b - residual(pred)))
 
@@ -55,7 +55,7 @@ class LCELoss(BaseLoss):
 
     name = "LCE"
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         if labels is not None:
             labels = labels.argmax(dim=1)
         else:
@@ -73,7 +73,7 @@ def __init__(self, reduction="mean", temperature=1.0):
         super().__init__(reduction)
         self.temperature = temperature
 
-    def forward(self, pred: Tensor, labels: Tensor = None) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor = None, **kwargs) -> Tensor:
         softmax_scores = F.log_softmax(pred / self.temperature, dim=1)
         labels = (
             labels.argmax(dim=1)
diff --git a/rankers/train/loss/pointwise.py b/rankers/train/loss/pointwise.py
index a5c82bc..02cfe47 100644
--- a/rankers/train/loss/pointwise.py
+++ b/rankers/train/loss/pointwise.py
@@ -10,7 +10,7 @@ class PointwiseMSELoss(BaseLoss):
 
     name = "PointwiseMSE"
 
-    def forward(self, pred: Tensor, labels: Tensor) -> Tensor:
+    def forward(self, pred: Tensor, labels: Tensor, **kwargs) -> Tensor:
         flattened_pred = pred.view(-1)
         flattened_labels = labels.view(-1)
         return F.mse_loss(flattened_pred, flattened_labels, reduction=self.reduction)
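A minimal sketch of what accepting `**kwargs` buys: a caller can pass the same extra keyword arguments to every loss, and losses that do not need them silently ignore them. The toy class, the `group_mask` argument, and the call site below are illustrative assumptions, not part of this diff or of the `rankers` loss classes themselves.

import torch
import torch.nn.functional as F

# Hypothetical stub mirroring the new signature; the real classes live in
# rankers/train/loss/* and inherit from BaseLoss.
class ToyPointwiseMSE(torch.nn.Module):
    def forward(self, pred, labels, **kwargs):
        # Extra keyword arguments are accepted and simply ignored here.
        return F.mse_loss(pred.view(-1), labels.view(-1))

loss_fn = ToyPointwiseMSE()
pred = torch.randn(2, 4)
labels = torch.randn(2, 4)

# A training loop can now forward arbitrary batch extras (e.g. a group mask)
# to any loss with one call site, without inspecting each loss's signature.
loss = loss_fn(pred, labels, group_mask=torch.ones(2, 4))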