
Modify GISTEmbedLoss and CachedGISTEmbedLoss to automatically remove duplicate positives from being considered negatives #2756

Open
tomaarsen opened this issue Jun 17, 2024 · 10 comments · May be fixed by #3074

@tomaarsen
Collaborator

Hello!

The (Cached)GISTEmbedLoss classes mask away certain in-batch negatives that might actually be positives, right here:

# Find which samples cannot be used as negatives because they are
# more similar to the query than the assigned positive as deemed by the guide model.
# For these samples, we mask them with -inf to basically ignore their contribution to
# the loss.
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf

and here:

# Find which samples cannot be used as negatives because they are
# more similar to the query than the assigned positive as deemed by the guide model.
# For these samples, we mask them with -inf to basically ignore their contribution to
# the loss.
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf
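
For context, masking a score to -torch.inf effectively removes that candidate from the softmax denominator of the cross-entropy, so it no longer acts as an in-batch negative. A minimal, self-contained illustration (not the library code, just the mechanism):

import torch
import torch.nn.functional as F

# Scores for one anchor against four in-batch candidates; index 0 is the
# assigned positive, the rest are treated as in-batch negatives.
scores = torch.tensor([[0.9, 0.8, 0.3, 0.1]])
target = torch.tensor([0])
print(F.cross_entropy(scores, target))  # candidate 1 pushes the loss up

# Masking candidate 1 to -inf gives it zero weight in the softmax, so it no
# longer contributes to the loss at all.
masked = scores.clone()
masked[0, 1] = -torch.inf
print(F.cross_entropy(masked, target))  # strictly lower loss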

However, consider a scenario with (anchor, positive) pairs where the same positive text occurs multiple times in the batch. This is quite bad, as that sample is now used both as a positive and as an in-batch negative. Fortunately, (Cached)GISTEmbedLoss should be able to detect this, as the guided_pp_sim and the guided_sim will be identical here. So, I think we can safely replace pp_sim[guided_pp_sim > guided_sim] = -torch.inf with pp_sim[guided_pp_sim >= guided_sim] = -torch.inf to automatically prevent duplicate positives from being labeled as in-batch negatives.
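
As a small, self-contained sketch of the proposed change (the numbers are made up; only the equality case at [0, 1] matters, and only the pp_sim comparison changes):

import torch

# Hypothetical guided similarities for a batch of 2 (anchor, positive) pairs.
# Entry [0, 1] is the equality case described above: per the guide model,
# positive 1 is exactly as similar to positive 0 as anchor 0 is to positive 0.
guided_pp_sim = torch.tensor([[1.00, 0.70],
                              [0.70, 1.00]])
guided_sim = torch.tensor([[0.70],   # guide similarity of (anchor 0, positive 0)
                           [0.65]])  # guide similarity of (anchor 1, positive 1)
pp_sim = torch.tensor([[0.95, 0.72],
                       [0.72, 0.90]])

current = pp_sim.clone()
current[guided_pp_sim > guided_sim] = -torch.inf    # current masking
proposed = pp_sim.clone()
proposed[guided_pp_sim >= guided_sim] = -torch.inf  # proposed masking

print(current)   # the 0.72 at [0, 1] survives and acts as an in-batch negative
print(proposed)  # it is masked to -inf and ignored by the loss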

I haven't made this PR myself yet, because it will require some experimentation/testing to verify that it doesn't accidentally hurt performance. Conceptually, though, it should only improve models.

cc @avsolatorio

  • Tom Aarsen
@tomaarsen tomaarsen added the enhancement New feature or request label Jun 17, 2024
@daegonYu
Contributor

There may be some things I don't know well, but shouldn't we also remove duplicates between anchors, as well as duplicates between anchors and positives?

@tomaarsen
Collaborator Author

Removing duplicate anchors might indeed be smart: if they're included in the same batch, then the positive from the other (but identical) anchor will be used as a negative right now.

Duplicates between anchors and positives shouldn't matter too much I think: the loss for that sample is 0, so it won't learn from it.
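
A hypothetical batch to illustrate the duplicate-anchor case (the texts are made up):

# Hypothetical batch of (anchor, positive) pairs; the texts are made up.
batch = [
    ("what is the capital of france", "Paris is the capital of France."),   # row 0
    ("what is the capital of france", "France's capital city is Paris."),   # row 1, duplicate anchor
]
# For row 0, row 1's positive is currently treated as an in-batch negative,
# even though it also answers row 0's (identical) anchor. It is only masked
# if the guide model happens to score it higher than row 0's assigned positive.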

  • Tom Aarsen

@tomaarsen tomaarsen added the good first issue Good for newcomers label Jun 20, 2024
@JINO-ROHIT
Contributor

@tomaarsen is this still open? Can I take this up?

@tomaarsen
Collaborator Author

Definitely! Feel free to work on it.

  • Tom Aarsen

@JINO-ROHIT
Contributor

Is there some script or dataset I should use to benchmark the changes, to make sure the performance doesn't drop?

@tomaarsen
Collaborator Author

tomaarsen commented Nov 13, 2024

I would take a script like this one: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/prompts/training_nq_prompts.py

I modified that script here: removed the prompts, used GISTEmbedLoss, and added an updated_loss variable. I also changed the training dataset to 10k samples (feel free to increase/decrease), which is then repeated 5 times. The idea is that we want duplicates in a batch, because that's exactly the case that we want to fix.

You can run this first with the original GISTEmbedLoss & CachedGISTEmbedLoss, and then also when you're updating it. If you have wandb or tensorboard, then you can compare results during training easily. Ideally, when you're done, you'll have 4 models, and the 2 new models should be either roughly equivalent or slightly better. If you upload them with the final line of this script, then you can easily make them public later and share that it works (plus, they might look nice on your profile).

Here it is, feel free to modify it to your liking (e.g. different base model, datasets, anything):

import logging
import random

import numpy
import torch
from datasets import Dataset, load_dataset, concatenate_datasets

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedGISTEmbedLoss, GISTEmbedLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
random.seed(12)
torch.manual_seed(12)
numpy.random.seed(12)

# Set this to True when you've updated the loss function(s)
updated_loss = False

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on Natural Questions pairs",
    ),
)

# 2. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
train_dataset: Dataset = dataset_dict["train"].select(range(10_000)) # Select 10k training samples, feel free to increase
# We then duplicate the 10k samples 5 times so that there's a decent chance that some batches have duplicate samples. 
# After all, that's the case we want to test/update
train_dataset: Dataset = concatenate_datasets([train_dataset] * 5)
eval_dataset: Dataset = dataset_dict["test"]

# 3. Define a loss function
guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Experiment with both GISTEmbedLoss and CachedGISTEmbedLoss
loss = GISTEmbedLoss(model, guide)

# 4. (Optional) Specify training arguments
run_name = "mpnet-base-nq"
if isinstance(loss, GISTEmbedLoss):
    run_name += "-gist"
elif isinstance(loss, CachedGISTEmbedLoss):
    run_name += "-cgist"
if updated_loss:
    run_name += "-new"
else:
    run_name += "-old"

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    # batch_sampler=BatchSamplers.NO_DUPLICATES,  # Although the loss benefits from having no duplicate samples
    # in a batch, we want to specifically test with duplicate samples, as those should start being ignored.
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    logging_steps=5,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    seed=12,
)

# 5. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = NanoBEIREvaluator()
dev_evaluator(model)

# 6. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 7. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 8. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

P.S. I just modified this script in GitHub itself, so there might be a bug/typo somewhere.

  • Tom Aarsen

@JINO-ROHIT
Contributor

gotcha, thanks so much!

@JINO-ROHIT
Contributor

Hi @tomaarsen, I was able to benchmark the changes and logged them here: https://wandb.ai/jinooo/sentence-transformers

Basically, I did 2 sets of evals.

Set 1
Batch size - 16
Epochs - 1
The eval scores for the old and updated losses were similar, so I felt I had to try larger batch sizes.

Set 2
Batch size - 256
Epochs - 1
The eval scores with the updated losses are slightly better on some benchmarks and exactly on par for the others. I think this should probably help over longer epochs. This was a good find!

Let me know if you need me to raise a PR for this.

@tomaarsen
Collaborator Author

Okay, fair enough. Are your changes similar to #3063?

I'm definitely interested in a PR.

  • Tom Aarsen

@JINO-ROHIT
Contributor

Yep, similar changes. I'll raise a PR for this.
