In Batch Negatives #3072

Open
riyajatar37003 opened this issue Nov 19, 2024 · 24 comments
Labels
bug Something isn't working

Comments

@riyajatar37003

riyajatar37003 commented Nov 19, 2024

Hi,
Is there any way to disable in-batch negatives during training in Sentence Transformers?

Thanks
@tomaarsen

@riyajatar37003
Author

riyajatar37003 commented Nov 19, 2024

How can I create a dataset where, for each query, I have one positive and k negatives?

    dataset_final[t] = Dataset.from_dict({
        "query": query,
        "positive": positive,
        "negative":[negative1,negative2,....,negative_k]
    })

Can I do it this way?

@tomaarsen
Collaborator

tomaarsen commented Nov 19, 2024

Hello!

Yes, this is possible. If you have k negatives, then you'll have to use a custom loss function, as there's no non-IBN loss (i.e. a loss without in-batch negatives) that accepts more than triplets. That should be fine, though. Here is an example:

from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch
from torch import Tensor, nn
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.evaluation import TripletEvaluator

class KTupleLoss(nn.Module):
    def __init__(
        self, model: SentenceTransformer, scale: float = 20.0
    ) -> None:
        super().__init__()
        self.model = model
        self.scale = scale
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]

        # Collect the anchor, positive, and negative embeddings
        anchor_embeddings = embeddings[0]  # [batch_size, embedding_dim]
        positive_embeddings = embeddings[1]  # [batch_size, embedding_dim]
        negative_embeddings = torch.stack(embeddings[2:], dim=1)  # [batch_size, num_negatives, embedding_dim]

        # Normalize them
        anchor_embeddings = torch.nn.functional.normalize(anchor_embeddings, p=2, dim=-1)
        positive_embeddings = torch.nn.functional.normalize(positive_embeddings, p=2, dim=-1)
        negative_embeddings = torch.nn.functional.normalize(negative_embeddings, p=2, dim=-1)

        # Compute the similarity scores, i.e. 1) pairwise cosine similarity between anchor and positive,
        # and 2) pairwise cosine similarity between anchor and negatives
        pos_similarity = (anchor_embeddings * positive_embeddings).sum(1, keepdim=True) # [batch_size, 1]
        anchor_embeddings_3d = anchor_embeddings.unsqueeze(1) # [batch_size, 1, embedding_dim]
        neg_similarity = torch.matmul(anchor_embeddings_3d, negative_embeddings.transpose(1, 2)).squeeze(1) # [batch_size, num_negatives]

        # Concatenate the positive and negative similarity scores so we have 1 + num_negatives similarity scores per anchor
        scores = torch.cat((pos_similarity, neg_similarity), dim=1) * self.scale # [batch_size, 1 + num_negatives]

        # Set the labels as 0, i.e. the positive sample is always the first one in the scores tensor
        labels = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)

        return self.cross_entropy_loss(scores, labels)

    def get_config_dict(self) -> dict[str, Any]:
        return {"scale": self.scale}

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on AllNLI triplets",
    )
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(100_000))
eval_dataset = dataset["dev"].select(range(1_000))
test_dataset = dataset["test"]

# This is a simple way to turn this into a dataset with k negative samples
def to_k_tuple(sample, k: int = 5):
    return {
        "anchor": sample["anchor"],
        "positive": sample["positive"],
        "negative": sample["negative"],
        **{
            f"negative_{i}": sample["negative"] for i in range(k - 1)
        }
    }

train_dataset = train_dataset.map(to_k_tuple, fn_kwargs={"k": 5})
eval_dataset = eval_dataset.map(to_k_tuple, fn_kwargs={"k": 5})
test_dataset = test_dataset.map(to_k_tuple, fn_kwargs={"k": 5})

# 4. Define a loss function
loss = KTupleLoss(model)

# 5. (Optional) Specify training arguments
run_name = "mpnet-base-all-nli-ktuple"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="all-nli-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the test set
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)
test_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

The KTupleLoss here computes the similarity score for the anchor vs. the positive, as well as the anchor vs. the k negatives. These are then concatenated so that for every single anchor you get k+1 similarity scores. The first similarity score corresponds to the positive, so we can set the labels to a list of zeros (representing the index of the true positive score).

The loss can probably be optimized (I think you can just take anchor_embeddings = embeddings[0] and then other_embeddings = torch.stack(embeddings[1:], dim=1), with no need to separate the positive and the negatives; see the sketch below), but this should work.

This is akin to MultipleNegativesRankingLoss, except without the in-batch negatives.
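
For reference, a rough sketch of that streamlined variant (my own rewrite, not part of the tested script above; it assumes the same column layout, with the anchor first, then the positive, then the negatives):

class KTupleLossV2(KTupleLoss):
    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        embeddings = [self.model(features)["sentence_embedding"] for features in sentence_features]

        # Anchors: [batch_size, embedding_dim]; candidates: positive + negatives stacked together
        anchors = torch.nn.functional.normalize(embeddings[0], p=2, dim=-1)
        candidates = torch.nn.functional.normalize(torch.stack(embeddings[1:], dim=1), p=2, dim=-1)  # [batch_size, 1 + num_negatives, embedding_dim]

        # Cosine similarity of each anchor against its own positive and negatives
        scores = torch.einsum("bd,bkd->bk", anchors, candidates) * self.scale  # [batch_size, 1 + num_negatives]

        # The positive is always candidate 0
        labels = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
        return self.cross_entropy_loss(scores, labels)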


You can create a k-negative dataset like so:

    dataset_final[t] = Dataset.from_dict({
        "query": query,
        "positive": positive,
        "negative_1": negative_1,
        "negative_2": negative_2,
        ...,
        "negative_k": negative_k,
    })
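
For concreteness, a small runnable sketch (using made-up example data and k = 3; the column naming just follows the pattern above):

    from datasets import Dataset

    queries = ["what is the capital of france?", "who wrote hamlet?"]
    positives = ["Paris is the capital of France.", "Hamlet was written by William Shakespeare."]
    negatives = [
        ["Berlin is the capital of Germany.", "Lyon is a city in France.", "France borders Spain."],
        ["Macbeth is a tragedy.", "Shakespeare was born in 1564.", "The Globe Theatre is in London."],
    ]

    train_dataset = Dataset.from_dict({
        "query": queries,
        "positive": positives,
        # One column per negative: negative_1, negative_2, negative_3
        **{f"negative_{i + 1}": [negs[i] for negs in negatives] for i in range(3)},
    })
    print(train_dataset.column_names)  # ['query', 'positive', 'negative_1', 'negative_2', 'negative_3']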

Here are my first logs; as you can see, the model indeed learns (even though this test script creates the k negatives in a bit of a hacky way, by just repeating the single actual negative):

{'loss': 1.8269, 'grad_norm': 14.627801895141602, 'learning_rate': 3.04e-06, 'epoch': 0.02}
{'eval_loss': 1.4666913747787476, 'eval_all-nli-dev_cosine_accuracy': 0.719, 'eval_runtime': 20.5417, 'eval_samples_per_second': 48.682, 'eval_steps_per_second': 3.067, 'epoch': 0.02}                                                                                                                                               
{'loss': 1.0463, 'grad_norm': 60.64706802368164, 'learning_rate': 6.176000000000001e-06, 'epoch': 0.03}                                                            
{'eval_loss': 0.7802979946136475, 'eval_all-nli-dev_cosine_accuracy': 0.859, 'eval_runtime': 20.9557, 'eval_samples_per_second': 47.72, 'eval_steps_per_second': 3.006, 'epoch': 0.03}                                                                                                                                                
{'loss': 0.6681, 'grad_norm': 35.55289840698242, 'learning_rate': 9.376000000000001e-06, 'epoch': 0.05}                                                            
{'eval_loss': 0.540037989616394, 'eval_all-nli-dev_cosine_accuracy': 0.902, 'eval_runtime': 20.9779, 'eval_samples_per_second': 47.669, 'eval_steps_per_second': 3.003, 'epoch': 0.05}                                                                                                                                                
{'loss': 0.428, 'grad_norm': 22.288776397705078, 'learning_rate': 1.2576000000000001e-05, 'epoch': 0.06}                                                           
{'eval_loss': 0.49275851249694824, 'eval_all-nli-dev_cosine_accuracy': 0.913, 'eval_runtime': 20.7686, 'eval_samples_per_second': 48.15, 'eval_steps_per_second': 3.033, 'epoch': 0.06}                                                                                                                                               
  6%|███████▊                                                                                                                | 405/6250 [05:50<3:57:49,  2.44s/it]

I do want to say that in-batch negatives often help.

  • Tom Aarsen

@riyajatar37003
Author

riyajatar37003 commented Nov 19, 2024 via email

@riyajatar37003
Author

riyajatar37003 commented Nov 20, 2024

Hi,
I am trying to train the https://huggingface.co/nvidia/NV-Embed-v2/tree/main model with LoRA using Sentence Transformers.
After a few iterations a checkpoint was saved, but when loading it with SentenceTransformer I get this error traceback:
/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00, 1.31it/s]
TypeError Traceback (most recent call last)
Cell In[1], line 43
41 # get the embeddings
42 batch_size = 2
---> 43 query_embeddings = model.encode(add_eos(queries), batch_size=batch_size, prompt=query_prefix, normalize_embeddings=True)
44 passage_embeddings = model.encode(add_eos(passages), batch_size=batch_size, normalize_embeddings=True)
46 scores = (query_embeddings @ passage_embeddings.T) * 100

File ~/.local/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:623, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
620 features.update(extra_features)
622 with torch.no_grad():
--> 623 out_features = self.forward(features, **kwargs)
624 if self.device.type == "hpu":
625 out_features = copy.deepcopy(out_features)

File ~/.local/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:690, in SentenceTransformer.forward(self, input, **kwargs)
688 module_kwarg_keys = self.module_kwargs.get(module_name, [])
689 module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 690 input = module(input, **module_kwargs)
691 return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File ~/.local/lib/python3.10/site-packages/sentence_transformers/models/Transformer.py:393, in Transformer.forward(self, features, **kwargs)
390 if "token_type_ids" in features:
391 trans_features["token_type_ids"] = features["token_type_ids"]
--> 393 output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
394 output_tokens = output_states[0]
396 # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, then it may have added virtual tokens
397 # We need to extend the attention mask to include these virtual tokens, or the pooling will fail

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1849, in PeftModelForFeatureExtraction.forward(self, input_ids, attention_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, **kwargs)
1847 peft_config = self.active_peft_config
1848 if not peft_config.is_prompt_learning:
-> 1849 return self.base_model(
1850 input_ids=input_ids,
1851 attention_mask=attention_mask,
1852 inputs_embeds=inputs_embeds,
1853 output_attentions=output_attentions,
1854 output_hidden_states=output_hidden_states,
1855 return_dict=return_dict,
1856 **kwargs,
1857 )
1859 batch_size = _get_batch_size(input_ids, inputs_embeds)
1860 if attention_mask is not None:
1861 # concat prompt attention mask

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:103, in BaseTuner.forward(self, *args, **kwargs)
102 def forward(self, *args: Any, **kwargs: Any):
--> 103 return self.model.forward(*args, **kwargs)

TypeError: NVEmbedModel.forward() got an unexpected keyword argument 'inputs_embeds'

@tomaarsen
Collaborator

I think this is because PEFT expects models to have the "standard" signature, e.g. https://github.com/huggingface/transformers/blob/f297af55dfc27485189f352cd36b4683de12e0b3/src/transformers/models/qwen2/modeling_qwen2.py#L808-L820
with inputs_embeds as a valid input.

But NV-Embed-v2 does not seem to have this parameter: https://huggingface.co/nvidia/NV-Embed-v2/blob/main/modeling_nvembed.py#L397

I assume the only solution is to fix this in the modeling_nvembed.py file.
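
If you want to experiment with patching that remote code, here is a rough sketch of the kind of change (hypothetical: the real signature and body of NVEmbedModel.forward in modeling_nvembed.py differ, this only illustrates accepting the extra keyword arguments that PEFT forwards):

    # Hypothetical sketch of a patched NVEmbedModel.forward in modeling_nvembed.py.
    # The extra parameters mirror the "standard" transformers signature linked above;
    # the original NV-Embed forward logic is omitted.
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,          # accepted so PEFT can pass it, even if unsupported
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        if inputs_embeds is not None:
            raise NotImplementedError("NV-Embed does not support inputs_embeds")
        # ... original forward logic using input_ids and attention_mask ...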

cc @BenjaminBossan as this is related to PEFT with a custom architecture - feel free to correct me if my above hypothesis is wrong.

  • Tom Aarsen

@riyajatar37003
Author

Doing it this way resolves the issue, but I don't know whether this is the correct way or not:

model = SentenceTransformer('NV-Embed-v2', trust_remote_code=True)
model.max_seq_length = 4096
model.tokenizer.padding_side = "right"
model.load_adapter('nv-embed-v2-ft/checkpoint-150')

@riyajatar37003
Author

@tomaarsen

@tomaarsen
Collaborator

If that path contains an adapter (i.e. adapter_config.json and adapter_model.safetensors), then this is indeed the correct way to load an adapter.
I believe you can also load the adapter directly:

model = SentenceTransformer('nv-embed-v2-ft/checkpoint-150', trust_remote_code=True)
model.max_seq_length = 4096
model.tokenizer.padding_side="right"

Which should be equivalent, but I'm not 100% sure.
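
A quick way to check (my own sketch; the model and checkpoint paths are the ones from above) is to encode the same sentence via both loading paths and compare the embeddings:

    import numpy as np
    from sentence_transformers import SentenceTransformer

    # Load the base model and attach the trained adapter explicitly
    model_a = SentenceTransformer('NV-Embed-v2', trust_remote_code=True)
    model_a.load_adapter('nv-embed-v2-ft/checkpoint-150')

    # Load directly from the checkpoint directory
    model_b = SentenceTransformer('nv-embed-v2-ft/checkpoint-150', trust_remote_code=True)

    sentences = ["A quick sanity check sentence."]
    emb_a = model_a.encode(sentences)
    emb_b = model_b.encode(sentences)
    print(np.allclose(emb_a, emb_b, atol=1e-5))  # True if the two loading paths are equivalent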

  • Tom Aarsen

@riyajatar37003
Author

riyajatar37003 commented Nov 20, 2024

model = SentenceTransformer('nv-embed-v2-ft/checkpoint-150', trust_remote_code=True)
model.max_seq_length = 4096
model.tokenizer.padding_side="right"
I tried it this way, but it throws the same error that I posted above.

@tomaarsen added the bug (Something isn't working) label on Nov 20, 2024
@tomaarsen
Collaborator

Okay, that may be a bug in Sentence Transformers; I'll look into it shortly.

@riyajatar37003
Author

Thanks.

@BenjaminBossan

I think this is because PEFT expects models to have the "standard" signature, e.g. https://github.com/huggingface/transformers/blob/f297af55dfc27485189f352cd36b4683de12e0b3/src/transformers/models/qwen2/modeling_qwen2.py#L808-L820 with inputs_embeds as a valid input.

But NV-Embed-v2 does not seem to have this parameter: https://huggingface.co/nvidia/NV-Embed-v2/blob/main/modeling_nvembed.py#L397

I think the only solution is to fix it in this modeling_nvembed.py file, I assume.

This depends. If you use some PEFT methods like prefix-tuning or p-tuning (all "prompt learning" methods), then yes, we need to make some assumptions about the underlying model, like inputs_embeds existing. However, @riyajatar37003 mentions that they use LoRA, which should be much more agnostic towards the base model. Could you please share the code so that I can try to reproduce this?

@riyajatar37003
Author

riyajatar37003 commented Nov 20, 2024

# See https://huggingface.co/collections/tomaarsen/training-with-prompts-672ce423c85b4d39aed52853 for some already trained models

import logging
import random
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
import numpy
import torch
from datasets import Dataset, load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
random.seed(12)
torch.manual_seed(12)
numpy.random.seed(12)


# Each query needs to be accompanied by a corresponding instruction describing the task.
task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question"}

query_prefix = "Instruct: "+task_name_to_instruct["example"]+"\nQuery: "

# Feel free to adjust these variables:
use_prompts = True
include_prompts_in_pooling = True

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    'nvidia/NV-Embed-v2',
    trust_remote_code=True,
)
model.set_pooling_include_prompt(include_prompts_in_pooling)
model.max_seq_length = 4096 #32768
model.tokenizer.padding_side="right"

# 2. Create a LoRA adapter for the model & add it
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj'],
)
model.add_adapter(peft_config, adapter_name='adaptor_1')

# 2. (Optional) Define prompts
if use_prompts:
    query_prompt = query_prefix
    corpus_prompt = ""
    prompts = {
        "query": query_prompt,
        "answer": corpus_prompt,
    }
from datasets import load_from_disk

# Load the saved datasets back into Dataset objects
# (replace these with any MS MARCO triplet dataset from Sentence Transformers)
train_dataset = load_from_disk("train_triplet_ours")
eval_dataset = load_from_disk("test_triplet_ours")


# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=4)

# 5. (Optional) Specify training arguments
run_name = "nv-embed-v2-nq"
if use_prompts:
    run_name += "-prompts"
if not include_prompts_in_pooling:
    run_name += "-exclude-pooling-prompts"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_steps=500,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=5,
    logging_steps=1,
    logging_first_step=True,
    dataloader_drop_last=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    seed=12,
    prompts=prompts if use_prompts else None,
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = NanoBEIREvaluator(
    query_prompts=query_prompt if use_prompts else None,
    corpus_prompts=corpus_prompt if use_prompts else None,
)
# dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
if __name__ == "__main__":
    trainer.train()
    # 8. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")
    # (Optional) Evaluate the trained model on the evaluator after training
    dev_evaluator(model)

    # 9. (Optional) Push it to the Hugging Face Hub
    # model.push_to_hub(run_name)

@BenjaminBossan

@riyajatar37003 I took your code with some small changes (bfloat16, smaller batch size to fit in memory, using this dataset) and it passed for me locally. Where exactly do you get the error? Do you have additional code where you load the model for inference and that's where it fails?

@riyajatar37003
Author

riyajatar37003 commented Nov 20, 2024 via email

@riyajatar37003
Author

riyajatar37003 commented Nov 20, 2024 via email

@BenjaminBossan

I could successfully run:

model = SentenceTransformer(<path>, trust_remote_code=True, model_kwargs=dict(torch_dtype=torch.bfloat16))
model.max_seq_length = 4096
model.tokenizer.padding_side="right"

The versions I use:

  • PEFT installed from source
  • transformers installed from source
  • sentence-transformers 3.3.1
  • torch 2.5.1

@riyajatar37003
Author

riyajatar37003 commented Nov 20, 2024 via email

@pranavjadhav001

Hello!

Yes, this is possible. If you have k negatives, then you'll have to use a custom loss function, as there's no non-IBN loss (i.e. a loss without in-batch negatives) that accepts more than triplets. [...] I do want to say that in-batch negatives often help.

(Tom Aarsen's earlier KTupleLoss reply, quoted in full; the code, explanation, and logs are identical to the comment above and are omitted here.)

I ran into a similar problem recently where I wanted to avoid in-batch negatives, and I created a solution that was almost identical to the one you have shared. The issue is that the MNR loss creates [batch_size, embedding_size] tensors, but for the above case it's [batch_size, n_negatives+2, embedding_size], which quickly blows up; I currently have to train with 1/4 of the MNR batch size. @tomaarsen, any tips for optimization?

@tomaarsen
Collaborator

Hmm, I believe MNRL also uses [batch_size, embedding_size] tensors for each column, so it is effectively equivalent to [batch_size, n_negatives+2, embedding_size]. Perhaps there's some optimization still possible by reusing tensors; I can't really say, as I'm only vaguely familiar with the best practices for memory optimization in torch.

  • Tom Aarsen

@riyajatar37003
Author

Hi @tomaarsen,
Can you explain what happens when we set this flag to True or False?
include_prompts_in_pooling = False
model.set_pooling_include_prompt(include_prompts_in_pooling)

During encoding, which tokens' representations will be considered for the embedding?
Thanks

@tomaarsen
Collaborator

Sentence Transformer models consist of a few steps:
text -> tokens -> token embeddings -> text embeddings

In the last transition, i.e. token embeddings -> text embeddings, we do pooling. For example mean pooling (the text embedding is the average of all token embeddings), or CLS pooling (the text embedding is the first token embedding).
Some researchers add a prompt or instruction text in front of their text, like query: or Represent this sentence for searching relevant passages: , and some of them want to exclude the token embeddings of the prompt/instruction from the eventual pooling. If you call model.set_pooling_include_prompt(False), then the prompt will not be included in the pooling.

In my tests (see https://sbert.net/examples/training/prompts/README.html#training-script - Experiments with bert-base-uncased), I got the best performance when keeping include_prompt at its default of True.

Details: https://sbert.net/examples/training/prompts/README.html
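
As a small illustration (my own sketch; the model name, prompt, and sentence are just examples), the flag only changes which token embeddings are pooled:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    prompt = "Represent this sentence for searching relevant passages: "
    sentence = "What is the capital of France?"

    # Default: the prompt tokens are included in the mean pooling
    emb_including_prompt = model.encode(sentence, prompt=prompt)

    # Exclude the prompt tokens from pooling: only the sentence tokens are averaged
    model.set_pooling_include_prompt(False)
    emb_excluding_prompt = model.encode(sentence, prompt=prompt)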

  • Tom Aarsen

@riyajatar37003
Author

riyajatar37003 commented Nov 26, 2024 via email
