Modify GISTEmbedLoss and CachedGISTEmbedLoss to automatically remove duplicate positives from being considered negatives #2756
Comments
There may be some things I don't know well, but shouldn't we also remove duplicates among the anchors, and duplicates between anchors and positives?
Removing duplicate anchors might indeed be smart: if they're included in the same batch, then the positive from the other (but identical) anchor will be used as a negative right now. Duplicates between anchors and positives shouldn't matter too much I think: the loss for that sample is 0, so it won't learn from it.
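A tiny hypothetical sketch of that first point, using a plain in-batch-negatives setup (deliberately simpler than what GISTEmbedLoss actually computes, and with made-up texts and embeddings): with the labels on the diagonal, every off-diagonal column is pushed down as a negative, including the positive that belongs to the duplicated anchor.

import torch
import torch.nn.functional as F

# Anchors 0 and 1 are identical, so positive 1 is also a valid answer for anchor 0 (and vice versa).
anchors = ["what is the capital of france?", "what is the capital of france?", "who wrote hamlet?"]
positives = ["Paris is the capital of France.", "France's capital city is Paris.", "Hamlet was written by Shakespeare."]

# Stand-in embeddings; in practice these would come from the model being trained.
torch.manual_seed(12)
anchor_emb = F.normalize(torch.randn(3, 8), dim=-1)
positive_emb = F.normalize(torch.randn(3, 8), dim=-1)

scores = anchor_emb @ positive_emb.T  # (3, 3): row i = anchor i vs. every in-batch positive
labels = torch.arange(3)              # the "correct" column for row i is i

# Cross-entropy treats scores[0, 1] as a negative for anchor 0, even though anchor 0 == anchor 1
# and positive 1 is a genuine positive for it; that is the false negative described above.
loss = F.cross_entropy(scores / 0.05, labels)
print(loss)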
@tomaarsen is this still open? Can I take this up?
Definitely! Feel free to work on it.
Is there some script or dataset I should benchmark the changes on, to make sure the performance doesn't drop?
I would take a script like this one: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/prompts/training_nq_prompts.py

I modified that script here: I removed the prompts, used the GISTEmbedLoss, and added an updated_loss flag to distinguish the runs. You can run it first with the original GISTEmbedLoss & CachedGISTEmbedLoss, and then again once you've updated them. If you have wandb installed, the runs will be logged there under separate run names. Here it is, feel free to modify it to your liking (e.g. different base model, datasets, anything):

import logging
import random
import numpy
import torch
from datasets import Dataset, load_dataset, concatenate_datasets
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerModelCardData,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedGISTEmbedLoss, GISTEmbedLoss
from sentence_transformers.training_args import BatchSamplers
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
random.seed(12)
torch.manual_seed(12)
numpy.random.seed(12)
# Set this to True when you've updated the loss function(s)
updated_loss = False
# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
"microsoft/mpnet-base",
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name="MPNet base trained on Natural Questions pairs",
),
)
# 2. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
train_dataset: Dataset = dataset_dict["train"].select(range(10_000)) # Select 10k training samples, feel free to increase
# We then duplicate the 10k samples 5 times so that there's a decent chance that some batches have duplicate samples.
# After all, that's the case we want to test/update
train_dataset: Dataset = concatenate_datasets([train_dataset] * 5)
eval_dataset: Dataset = dataset_dict["test"]
# 3. Define a loss function
guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Experiment with both GISTEmbedLoss and CachedGISTEmbedLoss
loss = GISTEmbedLoss(model, guide)
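# loss = CachedGISTEmbedLoss(model, guide)  # uncomment to benchmark the cached variant instead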
# 4. (Optional) Specify training arguments
run_name = "mpnet-base-nq"
if isinstance(loss, GISTEmbedLoss):
run_name += "-gist"
elif isinstance(loss, CachedGISTEmbedLoss):
run_name += "-cgist"
if updated_loss:
run_name += "-new"
else:
run_name += "-old"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=256,
per_device_eval_batch_size=256,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
# batch_sampler=BatchSamplers.NO_DUPLICATES, # Although the loss benefits from having no duplicate samples in a batch
# we want to specifically test with duplicate samples as those should start being ignored.
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=50,
save_total_limit=2,
logging_steps=5,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=12,
)
# 5. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = NanoBEIREvaluator()
dev_evaluator(model)
# 6. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)
# 7. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# 8. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

P.S. I just modified this script in GitHub itself, so there might be a bug/typo somewhere.
Gotcha, thanks so much!
Hi @tomaarsen, I was able to benchmark the changes and logged them here: https://wandb.ai/jinooo/sentence-transformers
Basically I did 2 sets of evals (Set 1 and Set 2). Let me know if you need me to raise a PR for this.
Okay, fair enough. Are your changes similar to #3063? I'm definitely interested in a PR.
Yep, similar changes. I'll raise a PR for this.
Hello!
The (Cached)GISTEmbedLoss classes mask away certain in-batch negatives as they might actually be positives right here:
sentence-transformers/sentence_transformers/losses/CachedGISTEmbedLoss.py, lines 263 to 269 (commit f012ab3)
and here:
sentence-transformers/sentence_transformers/losses/GISTEmbedLoss.py, lines 132 to 138 (commit f012ab3)
However, consider a scenario with (anchor, positive) pairs where the same positive text occurs multiple times in the batch. This is quite bad, as this sample is now used both as a positive and as an in-batch negative. However, (Cached)GISTEmbedLoss should be able to detect this, as the guided_pp_sim and the guided_sim will be identical here. So, I think we can safely replace pp_sim[guided_pp_sim > guided_sim] = -torch.inf with pp_sim[guided_pp_sim >= guided_sim] = -torch.inf to automatically prevent duplicate positives from being labeled as in-batch negatives.

I haven't made this PR myself yet, because it will require some experimentation/testing to see if this doesn't accidentally hurt performance. However, conceptually it should only improve models.
cc @avsolatorio