'scale' hyperparameter in MultipleNegativesRankingLoss #3054
Hello! I'm not actually super sure on the origin of this parameter; Nils Reimers added it before I took over. My understanding is that the scale simply multiplies the similarity scores before they are passed to the cross-entropy loss, which sharpens the softmax over the in-batch candidates.
Here's an example script of manually going through the loss:

from torch import nn
import torch
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2")
# Let's take 1 sample and send it through our loss, except now "manually"
anchor = "is toprol xl the same as metoprolol?"
positive = "Metoprolol succinate is also known by the brand name Toprol XL. It is the extended-release form of metoprolol. Metoprolol succinate is approved to treat high blood pressure, chronic chest pain, and congestive heart failure."
negative_1 = "The Are You Experienced album was apparently mastered from the original stereo UK master tapes (according to Steve Hoffman - one of the very few who has heard both the master tapes and the CDs produced over the years). ... The CD booklets were a little sparse, but at least they stayed true to the album's original design."
negative_2 = "Matryoshka dolls are made of wood from lime, balsa, alder, aspen, and birch trees; lime is probably the most common wood type. ... After cutting, the trees are stripped of most of their bark, although a few inner rings of bark are left to bind the wood and keep it from splitting."
negative_3 = "The eyes are always the same size from birth to death. Baby eyes are proportionally larger than adult eyes, but they are still smaller."
# For now we assume that these negatives are in the same sample, so we train with 5 columns: anchor, positive, negative_1, negative_2, negative_3
# We now encode both the anchor, and the "candidate positives" out of which we want to find the real positive
anchor_embedding = model.encode(anchor)
candidate_embeddings = model.encode([positive, negative_1, negative_2, negative_3])
print(anchor_embedding.shape)
# (384,) a.k.a. 1 embedding of 384 dimensions
print(candidate_embeddings.shape)
# (4, 384) a.k.a. 4 embeddings of 384 dimensions
similarities = model.similarity(anchor_embedding, candidate_embeddings)
print(similarities)
# tensor([[0.7811, 0.0835, 0.0644, 0.0639]])
# a.k.a anchor is most similar to positive, not very similar to the 3 negatives
# Let's set up our loss
cross_entropy_loss = nn.CrossEntropyLoss()
# And we need a label, i.e. we need to know which of the 4 non-anchor embeddings is the positive one
# We can do this by setting label as the index of the true positive in the candidate_embeddings.
# In this case, the true positive is the first one, so the label is 0
label = 0
# And let's iterate over the scales to calculate the loss:
for scale in range(30):
    # Now we can calculate the loss
    loss = cross_entropy_loss(similarities * scale, torch.tensor([label]))
    print(f"Loss with scale {scale}: {loss.item():.4f}")

With these results:
Let's manually create some similarities and go through those:

from torch import nn
import torch
similarities_list = [
    torch.tensor([[0.7811, 0.0835, 0.0644, 0.0639]]),
    torch.tensor([[0.5842, 0.5243, 0.5351, 0.5124]]),
    torch.tensor([[0.4842, 0.5243, 0.5351, 0.5124]]),
    torch.tensor([[0.2424, 0.4243, 0.5382, 0.4244]]),
]
descriptions = [
    "Great similarity to positive",
    "Slightly more similarity to positive",
    "Slightly less similarity to positive",
    "Low similarity to positive",
]
# Let's set up our loss
cross_entropy_loss = nn.CrossEntropyLoss()
# And we need a label, i.e. we need to know which of the 4 non-anchor embeddings is the positive one
# We can do this by setting label as the index of the true positive in the candidate_embeddings.
# In this case, the true positive is the first one, so the label is 0
label = 0
for similarities, description in zip(similarities_list, descriptions):
    # And let's iterate over the scales to calculate the loss:
    print(description)
    for scale in range(30):
        # Now we can calculate the loss
        loss = cross_entropy_loss(similarities * scale, torch.tensor([label]))
        print(f"Loss with scale {scale}: {loss.item():.4f}")
    print()
So: a higher scale is harsher when the performance is bad (last case), while it's softer when the performance is good (first case). Beyond that, a higher scale is softer when the positive is slightly better than the negatives, and harsher when the positive is slightly worse than the negatives. @daegonYu once asked what'd happen if we set the scale even higher, e.g. to 50. These are the results:

Great similarity to positive
Loss with scale 50: 0.0000
Slightly more similarity to positive
Loss with scale 50: 0.1514
Slightly less similarity to positive
Loss with scale 50: 3.2294
Low similarity to positive
Loss with scale 50: 14.7967

What scale results in the best overall performance is an unanswered question. I actually think it would be really fascinating if someone set up a training script that trains e.g. 40 small models with different scales, as I'm not sure if the best performance would be around 20, the default.
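Such a sweep could look roughly like the sketch below, using the SentenceTransformerTrainer API. The base model, dataset, scale grid, and training arguments are placeholders, not recommendations:

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Placeholder (anchor, positive) pair dataset; swap in whatever data you actually train on
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(10_000))

for scale in [1, 5, 10, 20, 30, 50]:
    # Start each run from the same base model so that only the scale differs
    model = SentenceTransformer("microsoft/mpnet-base")
    loss = MultipleNegativesRankingLoss(model, scale=scale)
    args = SentenceTransformerTrainingArguments(
        output_dir=f"models/mnrl-scale-{scale}",
        num_train_epochs=1,
        per_device_train_batch_size=32,
    )
    trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
    trainer.train()
    model.save(f"models/mnrl-scale-{scale}/final")

Since MultipleNegativesRankingLoss draws its negatives from the batch, the batch size interacts with the scale, so a sweep like this should keep the batch size fixed across runs.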
|
I started some training jobs for testing the different scale values. |
Ok, so I'm trying to intuitively understand this. If I am using a well-curated dataset for training where each anchor has a positive example and the negatives are generated from the other positive examples in the batch (in-batch negatives), then wouldn't it make sense to use a higher scale? Performance should already be good in that setting, since the positive example's similarity score should be higher than those of the in-batch negatives (assuming the in-batch negative sentences are significantly different from the positive sentence for a given example). |
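For context, with in-batch negatives the candidate set for each anchor is every positive in the batch, which is roughly what MultipleNegativesRankingLoss does internally. A simplified sketch of how the scores and labels line up (ignoring extra hard-negative columns; the batch contents are made up, and this is illustrative only since model.encode does not track gradients):

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
anchors = ["first query", "second query", "third query"]
positives = [
    "passage answering the first query",
    "passage answering the second query",
    "passage answering the third query",
]

anchor_embeddings = model.encode(anchors, convert_to_tensor=True)
positive_embeddings = model.encode(positives, convert_to_tensor=True)

# (3, 3) score matrix: row i holds anchor i's similarity to every positive in the batch
scores = model.similarity(anchor_embeddings, positive_embeddings) * 20  # scale = 20
# For anchor i the true positive sits at column i, so the labels are simply 0, 1, 2
labels = torch.arange(len(anchors), device=scores.device)
loss = torch.nn.CrossEntropyLoss()(scores, labels)

The "good performance" case from the earlier examples then corresponds to each row's diagonal entry being clearly larger than its off-diagonal entries.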
In my opinion, the intuition is a tad hard to understand, and it really seems like the best approach for now is to just run some tests. Speaking of which, these are the findings from my experiments yesterday:

I suspect that the difference between the scale parameters shrinks a lot once you add more training data, but perhaps it's worthwhile to consider a higher scale under similar settings? I'd consider testing with your data to see what works best for you.
|
Thank you for sharing the good experimental results. I saw the experimental results with the different scale values above. What's interesting is that when the scale is 0, the loss is the same, 1.3863, in every case. This means that if you train with a scale of 0, you get the same loss value when training on data corresponding to "Great similarity to positive" as when training on data corresponding to "Low similarity to positive". I wonder if that means the model is not training properly.

In addition, I posted a GitHub issue (microsoft/unilm#1588) about Microsoft's E5 model. Can you tell me why the following phenomenon occurs? "The logits are calculated with cosine_similarity / t. Therefore, the logits will fall in [-100, 100] with t = 0.01 and [-50, 50] with t = 0.02, etc. However, this does not mean the learned cosine similarity will be in a wider range. On the contrary, the cosine similarity tends to concentrate as the temperature becomes lower."

I understand that the logits will fall in [-100, 100], i.e. that a lower temperature lets the logits vary over a wider range, but I still do not understand why the cosine similarity tends to concentrate as the temperature becomes lower. Is this just an experimental observation with an unknown cause? |
Indeed, it's because we do similarities * scale: with a scale of 0, every logit becomes 0, so the softmax over the 4 candidates is uniform and the loss is always ln(4) ≈ 1.3863, regardless of how good or bad the similarities are. As for the cosine similarity concentration: I'm actually not sure why this happens. My intuition, however, is that because
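A quick way to see the scale-0 behaviour, reusing the manual similarities from earlier:

import math
import torch
from torch import nn

cross_entropy_loss = nn.CrossEntropyLoss()
label = torch.tensor([0])

for similarities in [
    torch.tensor([[0.7811, 0.0835, 0.0644, 0.0639]]),  # great similarity to positive
    torch.tensor([[0.2424, 0.4243, 0.5382, 0.4244]]),  # low similarity to positive
]:
    # Multiplying by 0 wipes out all information, leaving 4 equal logits
    print(cross_entropy_loss(similarities * 0, label).item())

print(math.log(4))  # 1.3862943611198906, the value printed in both cases above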
|
Thank you so much for your great advice. Your intuition has helped me a lot. Thanks! |
I am looking through the MultipleNegativesRankingLoss.py code and I have a question about the 'scale' hyperparameter. Closely related to the 'temperature' (the scale is effectively 1 / temperature), the scale is used to stretch or compress the range of output values from the similarity function. A larger scale creates a greater distinction between positive and negative examples in terms of similarity score differences. The line below shows how the scale is used in the forward function of the loss:
scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
Currently, the scale defaults to 20 when cosine similarity is used as the similarity function.
Why was 20 selected as the scale when using cosine similarity on the embeddings? Is this the optimal scale value for cosine similarity? Would this hyperparameter need to be tuned during fine-tuning?
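If you want to experiment with it, the scale can be passed directly to the loss. A minimal sketch (the value 50.0 is just an example, not a recommendation):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")
# The default is scale=20.0 with cosine similarity; override it to test other values
loss = MultipleNegativesRankingLoss(model, scale=50.0)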