diff --git a/sentence_transformers/losses/SoftmaxLoss.py b/sentence_transformers/losses/SoftmaxLoss.py
index 887356e11..48a30c452 100644
--- a/sentence_transformers/losses/SoftmaxLoss.py
+++ b/sentence_transformers/losses/SoftmaxLoss.py
@@ -4,6 +4,8 @@
 from typing import Callable, Iterable
 
 import torch
+import transformers
+from packaging import version
 from torch import Tensor, nn
 
 from sentence_transformers.SentenceTransformer import SentenceTransformer
@@ -103,6 +105,13 @@ def __init__(
         )
         self.loss_fct = loss_fct
 
+        if version.parse(transformers.__version__) < version.parse("4.43.0"):
+            logger.warning(
+                "SoftmaxLoss requires transformers >= 4.43.0 to work correctly. "
+                "Otherwise, the classifier layer that maps embeddings to the labels cannot be updated. "
+                "Consider updating transformers with `pip install transformers>=4.43.0`."
+            )
+
     def forward(
         self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
     ) -> Tensor | tuple[Tensor, Tensor]:
diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index 719958360..e4b65298f 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import warnings
+from collections import OrderedDict
 from contextlib import nullcontext
 from typing import TYPE_CHECKING, Any, Callable
 
@@ -761,3 +762,41 @@ def create_model_card(
         self.model.model_card_data.add_tags(tags)
 
         self.model._create_model_card(self.args.output_dir, model_name=model_name)
+
+    def get_optimizer_cls_and_kwargs(
+        self, args: SentenceTransformerTrainingArguments, model: SentenceTransformer | None = None
+    ) -> tuple[Any, Any]:
+        """
+        We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model`
+        itself, but the SentenceTransformer losses can have weights that should be updated as well, e.g.
+        SoftmaxLoss (see #2872).
+
+        This method requires `transformers` >= 4.43.0.
+        """
+
+        if isinstance(self.loss, dict):
+            loss_model = nn.Sequential(OrderedDict(self.loss))
+        else:
+            loss_model = self.loss
+        optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model)
+
+        # If the kwargs were not overridden by the super() call, then we should override them here so that the
+        # potential weights in the loss(es) can also be updated.
+        if not {"params", "model", "optimizer_dict"} & set(optimizer_kwargs.keys()):
+            decay_parameters = self.get_decay_parameter_names(loss_model)
+            optimizer_kwargs["optimizer_dict"] = [
+                {
+                    "params": [
+                        p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+        return optimizer_cls, optimizer_kwargs
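
For context, a minimal usage sketch (not part of the diff) showing why the loss itself has to reach the optimizer: SoftmaxLoss owns a `classifier` linear layer that is not a submodule of the SentenceTransformer model, so without this override its weights never receive updates. The model name, dataset subset, and step count below are illustrative assumptions, not values taken from the patch.

```python
from datasets import load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

model = SentenceTransformer("microsoft/mpnet-base")  # any base model works here

# Assumption: the "pair-class" subset exposes (premise, hypothesis, label) columns,
# which is the column layout SoftmaxLoss expects.
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:1000]")

loss = losses.SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3,
)

# Snapshot the loss's classifier weights before training.
before = loss.classifier.weight.detach().clone()

trainer = SentenceTransformerTrainer(
    model=model,
    args=SentenceTransformerTrainingArguments(output_dir="softmax-demo", max_steps=10),
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

# With transformers >= 4.43.0 and the override above, the classifier is included in the
# optimizer's parameter groups, so its weights should have changed after a few steps.
print("classifier updated:", not before.equal(loss.classifier.weight.detach().cpu()))
```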