Commit

[fix] Fix SoftmaxLoss by initializing the optimizer over the loss(es) rather than the model (UKPLab#2881)
tomaarsen authored Aug 30, 2024
1 parent 52bf210 commit 0a32ec8
Showing 2 changed files with 48 additions and 0 deletions.
9 changes: 9 additions & 0 deletions sentence_transformers/losses/SoftmaxLoss.py
@@ -4,6 +4,8 @@
from typing import Callable, Iterable

import torch
import transformers
from packaging import version
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer
@@ -103,6 +105,13 @@ def __init__(
        )
        self.loss_fct = loss_fct

        if version.parse(transformers.__version__) < version.parse("4.43.0"):
            logger.warning(
                "SoftmaxLoss requires transformers >= 4.43.0 to work correctly. "
                "Otherwise, the classifier layer that maps embeddings to the labels cannot be updated. "
                "Consider updating transformers with `pip install transformers>=4.43.0`."
            )

    def forward(
        self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
    ) -> Tensor | tuple[Tensor, Tensor]:
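For context, the classifier that maps concatenated embeddings to the labels lives inside SoftmaxLoss itself, not inside the SentenceTransformer model, which is why the optimizer must see the loss's parameters. A minimal training sketch along those lines (the model name and toy NLI-style rows below are illustrative assumptions, not part of this commit):

from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import SoftmaxLoss

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Toy NLI-style data: two sentence columns plus an integer "label" column.
train_dataset = Dataset.from_dict(
    {
        "premise": ["A man is eating food.", "A man is playing guitar."],
        "hypothesis": ["A man eats something.", "A woman watches TV."],
        "label": [0, 2],
    }
)

# SoftmaxLoss holds a trainable nn.Linear classifier; with transformers >= 4.43.0
# and this fix, its weights are included in the optimizer and actually updated.
loss = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3,
)

trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()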
39 changes: 39 additions & 0 deletions sentence_transformers/trainer.py
@@ -3,6 +3,7 @@
import logging
import os
import warnings
from collections import OrderedDict
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable

@@ -761,3 +762,41 @@ def create_model_card(
            self.model.model_card_data.add_tags(tags)

        self.model._create_model_card(self.args.output_dir, model_name=model_name)

    def get_optimizer_cls_and_kwargs(
        self, args: SentenceTransformerTrainingArguments, model: SentenceTransformer | None = None
    ) -> tuple[Any, Any]:
        """
        We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model`
        itself, but the SentenceTransformer losses can have weights that should be updated as well, e.g.
        SoftmaxLoss (see #2872).

        This method requires `transformers` >= 4.43.0.
        """
        if isinstance(self.loss, dict):
            loss_model = nn.Sequential(OrderedDict(self.loss))
        else:
            loss_model = self.loss
        optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model)

        # If the kwargs were not overridden by the super() call, then we should override them here so that
        # the potential weights in the loss(es) can also be updated.
        if not {"params", "model", "optimizer_dict"} & set(optimizer_kwargs.keys()):
            decay_parameters = self.get_decay_parameter_names(loss_model)
            optimizer_kwargs["optimizer_dict"] = [
                {
                    "params": [
                        p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        return optimizer_cls, optimizer_kwargs
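The grouping above mirrors how the Trainer normally splits model parameters into weight-decay and no-decay groups, only applied to the loss module(s) instead of the model. A standalone sketch of that idea (the linear layer, decay rule, and hyperparameters here are illustrative stand-ins, not the trainer's actual code):

import torch
from torch import nn

# Stand-in for a loss module with trainable weights, e.g. the classifier inside
# SoftmaxLoss that maps the concatenated embeddings (u, v, |u - v|) to label logits.
loss_model = nn.Linear(768 * 3, 3)

# Illustrative decay rule: apply weight decay to everything except biases.
decay_parameters = {n for n, _ in loss_model.named_parameters() if not n.endswith("bias")}

param_groups = [
    {
        "params": [p for n, p in loss_model.named_parameters() if n in decay_parameters and p.requires_grad],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in loss_model.named_parameters() if n not in decay_parameters and p.requires_grad],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(param_groups, lr=2e-5)

When several losses are passed as a dict (one per training dataset), wrapping them in nn.Sequential(OrderedDict(...)) simply gives named_parameters() a single module to walk, so every loss's weights end up in these groups.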
