SigLIP-style loss for better DDP #3119

Open
kddubey opened this issue Dec 5, 2024 · 5 comments

Comments

@kddubey
Contributor

kddubey commented Dec 5, 2024

Hello,

SigLIP demonstrates that formulating contrastive learning as a collection of independent binary classifications works well.

One benefit is that the loss makes it easy to train with more negatives, because every summand in the loss is independent of the rest: there's no softmax/normalization like there is in MNRL. CachedMultipleNegativesRankingLoss already enables large batch sizes for single-device training. But for DDP, IIUC, any implementation of MNRL requires more communication overhead than SigLIP because of the softmax.
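To make the independence argument concrete, here is a minimal PyTorch sketch, assuming L2-normalized embeddings, one positive per anchor (on the diagonal), and a fixed scale and bias of 10 and -10 (the SigLIP paper's initial values). The function names are hypothetical, not Sentence Transformers APIs. The sigmoid loss over any block of candidate columns can be computed without seeing the other blocks, while MNRL's softmax needs every column's logit for each row's normalizer.

```python
import torch
import torch.nn.functional as F


def mnrl_loss(a: torch.Tensor, c: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    # log-softmax inside cross_entropy normalizes each row over *all* candidates,
    # so this loss cannot be split column-wise into independent partial sums.
    logits = scale * a @ c.T
    return F.cross_entropy(logits, torch.arange(a.shape[0]))


def sigmoid_pair_loss_sum(a, c, scale=10.0, bias=-10.0, pos_offset=0):
    """Sum of independent terms -log sigmoid(y_ij * z_ij) over one block of candidates.

    pos_offset is the global index of c[0]; anchor i's positive is global column i.
    (SigLIP additionally divides the total by the batch size.)
    """
    logits = scale * a @ c.T + bias
    rows = torch.arange(a.shape[0]).unsqueeze(1)
    cols = torch.arange(c.shape[0]).unsqueeze(0) + pos_offset
    signs = 2.0 * (rows == cols).float() - 1.0  # +1 for the positive pair, -1 otherwise
    return -F.logsigmoid(signs * logits).sum()


torch.manual_seed(0)
a = F.normalize(torch.randn(4, 8), dim=-1)
c = F.normalize(torch.randn(4, 8), dim=-1)
full = sigmoid_pair_loss_sum(a, c)
# Splitting the candidates into two chunks and adding the partial sums gives the same loss.
split = sigmoid_pair_loss_sum(a, c[:2], pos_offset=0) + sigmoid_pair_loss_sum(a, c[2:], pos_offset=2)
```

Because the sigmoid loss is a plain sum, each chunk could live on a different device and only scalar partial losses would need to be combined.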

The SigLIP paper studies image-text data. To see if their sigmoid-style loss works for text data, I ran a tiny experiment here. It demonstrates that the performance is on par with MNRL on STS. (The notebook doesn't implement the actual, distributed SigLIP training scheme; it just checks that sigmoid instead of softmax works well for STS.)

I'm wondering if you or others are interested in incorporating a SigLIP-style training scheme into SentenceTransformers.

Thanks!

@tomaarsen
Collaborator

Hello!

CachedMultipleNegativesRankingLoss is indeed our "solution of choice" for training with large batch sizes currently. It doesn't share negatives across devices (see the discussion in #2831). As a result, presumably there's not a significant amount of communication overhead, right? I may be off here, though.

For example, say you want to train with a batch size of 16384 and you have 8 GPUs, each of which can handle 64 samples at a time before OOM-ing. Then you can use DDP with 8 processes, a per_device_train_batch_size of 16384, and CachedMultipleNegativesRankingLoss with a mini_batch_size of 64.

The total batch size will be 131072, and that many samples are collected at the start of the step. These are then split up into 8 subbatches of 16384 samples and divided across the 8 GPUs. Each GPU will process these with CMNRL as if it's the only GPU, and the only communication is once all 8 are finished: when the gradients are averaged across all GPUs.
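A quick sanity check of that arithmetic, written as a hypothetical helper (not a Sentence Transformers function) and assuming plain (anchor, positive) pairs with no hard negatives:

```python
import math


def ddp_cmnrl_plan(per_device_batch: int, n_gpus: int, mini_batch: int) -> dict:
    """Hypothetical helper: batch bookkeeping for DDP + CMNRL without cross-device negatives."""
    return {
        # Samples collected per optimizer step across all processes.
        "global_batch": per_device_batch * n_gpus,
        # Sequential GradCache chunks each GPU runs (the last chunk may be smaller).
        "mini_batches_per_device": math.ceil(per_device_batch / mini_batch),
        # In-batch candidates per anchor: all come from the same device's batch.
        "candidates_per_anchor": per_device_batch,
    }


print(ddp_cmnrl_plan(per_device_batch=16384, n_gpus=8, mini_batch=64))
# {'global_batch': 131072, 'mini_batches_per_device': 256, 'candidates_per_anchor': 16384}
```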

Please let me know if I'm overlooking something!


Also, looking at your script - the inputs (i.e. scores) must be Sigmoid-ed before they're fed through the BCELoss, right? The scale of 20 means that the scores range between 0 and 20, and with the bias it's between -10 and 10.

Another note: I like the idea of a dynamic scale by making it a Parameter. I experimented with this briefly, but to turn it into something "meaningful", I'd have to override the optimizer kwargs with a much higher learning rate here:

```python
optimizer_kwargs["optimizer_dict"] = [
    {
        "params": [
            p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
        ],
        "weight_decay": self.args.weight_decay,
    },
    {
        "params": [
            p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
        ],
        "weight_decay": 0.0,
    },
]
```
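One way to give a learnable scale and bias their own, much higher learning rate, sketched in plain PyTorch rather than through the Sentence Transformers trainer. It assumes a hypothetical `SigmoidLoss` module with `nn.Parameter`s initialized as in the SigLIP paper (log-scale log 10, bias -10), a `nn.Linear` standing in for the embedding model, and AdamW with two parameter groups; the learning rates are illustrative placeholders, not tuned values.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SigmoidLoss(nn.Module):
    """Hypothetical SigLIP-style loss with a learnable log-scale and bias."""

    def __init__(self) -> None:
        super().__init__()
        self.log_scale = nn.Parameter(torch.tensor(math.log(10.0)))
        self.bias = nn.Parameter(torch.tensor(-10.0))

    def forward(self, a: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        logits = self.log_scale.exp() * a @ c.T + self.bias
        signs = 2.0 * torch.eye(a.shape[0], c.shape[0]) - 1.0  # +1 on the diagonal positives
        return -F.logsigmoid(signs * logits).sum() / a.shape[0]


encoder = nn.Linear(16, 8)  # stand-in for the embedding model
loss_fn = SigmoidLoss()

optimizer = torch.optim.AdamW(
    [
        {"params": encoder.parameters(), "lr": 2e-5, "weight_decay": 0.01},
        # Adam moves each parameter by roughly `lr` per step, so at an encoder-sized LR
        # the scalar scale/bias barely change; give them a far larger LR and no decay.
        {"params": loss_fn.parameters(), "lr": 1e-2, "weight_decay": 0.0},
    ]
)

x, y = torch.randn(4, 16), torch.randn(4, 16)
a, c = F.normalize(encoder(x), dim=-1), F.normalize(encoder(y), dim=-1)
loss = loss_fn(a, c)
loss.backward()
optimizer.step()
```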

  • Tom Aarsen

@kddubey
Contributor Author

kddubey commented Dec 5, 2024

I think your explanation is right; CMNRL + DDP will basically do plain DDP (plus accumulating gradients w/in each device under the hood) and won't share negatives. The motivation of the issue is to enable this sharing of negatives in a DDP setup.

Here's what I mean. Say we have G GPUs. Our global batch size is B, and each anchor in our dataset has N negatives. Each GPU is responsible for computing the loss for its B/G anchors. With (C)MNRL + DDP, each anchor in each GPU only sees its local (B/G) * (1 + N) - 1 negatives.

SigLIP is not exactly plain DDP. It allows each anchor to efficiently see all other negatives in the global batch (see Figure 1 and Section 3.3 of SigLIP), i.e., each of the B/G anchors in each GPU sees B * (1 + N) - 1 negatives, ~G times as many as before. The result is hopefully that you get a better update after each batch, which gets you a better model in less time. (Also, we should still be able to accumulate mini-batches w/ the sigmoid loss, as already done in CMNRL.)
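The negative counts above, as a small sketch with hypothetical example values (B = 1024, G = 8, N = 1) and assuming G divides B:

```python
def negatives_per_anchor(B: int, G: int, N: int) -> tuple[int, int]:
    """(local, global) negatives seen by one anchor.

    B: global batch size, G: number of GPUs, N: hard negatives per anchor.
    """
    local = (B // G) * (1 + N) - 1  # (C)MNRL + DDP: only this device's candidates
    global_ = B * (1 + N) - 1       # SigLIP-style: every candidate in the global batch
    return local, global_


local, global_ = negatives_per_anchor(B=1024, G=8, N=1)
print(local, global_)  # 255 2047, a ratio just over G = 8
```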

Ofc, we can see as many negatives as we want by accumulating mini-batches in CMNRL. But that accumulation happens in sequence instead of in parallel.

Does this motivation sound right?

Now that I'm writing this out though, I realized that SigLIP will do G times as many similarity calculations: (C)MNRL + DDP does (B/G) * (B/G) * (1 + N) * G while SigLIP does B * B * (1 + N). This fact plus the discussion in #2831 makes it definitely possible that CMNRL is not slower than SigLIP. So we'd probably need to have some concrete SigLIP DDP experiments before anchoring (ha) too hard on it.
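Writing that count out (GPUs × anchors per GPU × candidates per anchor, one similarity per anchor-candidate pair, same B, G, N as above):

```latex
\underbrace{G}_{\text{GPUs}} \cdot \underbrace{\frac{B}{G}}_{\text{anchors/GPU}} \cdot \underbrace{\frac{B}{G}(1+N)}_{\text{local candidates}} = \frac{B^2(1+N)}{G}
\qquad \text{vs.} \qquad
\underbrace{G}_{\text{GPUs}} \cdot \underbrace{\frac{B}{G}}_{\text{anchors/GPU}} \cdot \underbrace{B(1+N)}_{\text{global candidates}} = B^2(1+N)
```

So the SigLIP total is G times larger, and a single GPU's share, B²(1+N)/G, already equals the whole cluster's (C)MNRL + DDP total.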

I'm now less convinced about it / will probably happily stick w/ CMNRL for now :-)

> Also, looking at your script - the inputs (i.e. scores) must be Sigmoid-ed before they're fed through the BCELoss, right?

The script uses BCEWithLogitsLoss, which assumes the inputs are logits. (I should've called the object bce_with_logits_loss instead of bce_loss.)
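A quick check of that distinction in PyTorch, using made-up logits and 0/1 pair labels. It shows that `BCEWithLogitsLoss` applies the sigmoid itself, and that it equals the SigLIP form -log σ(y·z) with labels y ∈ {-1, +1}:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up scaled-and-biased similarity scores (logits) and 0/1 pair labels.
logits = torch.tensor([-4.0, -1.5, 0.0, 2.0, 3.5, -0.5])
targets = torch.tensor([0.0, 0.0, 1.0, 1.0, 0.0, 1.0])

# BCEWithLogitsLoss applies the sigmoid internally (in a numerically stable way) ...
with_logits = nn.BCEWithLogitsLoss()(logits, targets)
# ... so it matches BCELoss on explicitly sigmoid-ed inputs.
explicit = nn.BCELoss()(torch.sigmoid(logits), targets)
# It is also the SigLIP form -log sigmoid(y * z), mapping labels {0, 1} to {-1, +1}.
siglip_form = -F.logsigmoid((2 * targets - 1) * logits).mean()

print(with_logits.item(), explicit.item(), siglip_form.item())
```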

> Another note: I like the idea of a dynamic scale by making it a Parameter. I experimented with this briefly, but to turn it into something "meaningful", I'd have to override the optimizer kwargs with a much higher learning rate here:

Good to know! That's something I missed in my experiment; the scale and bias only shifted by 0.002 lol

@tomaarsen
Collaborator

The motivation is indeed sound: sharing negatives allows for bigger batches, but I think it does not necessarily make sense to use cross-device negatives if we can also arbitrarily increase the batch size per device. All it really does is introduce cross-device communication, right?

> The script uses BCEWithLogitsLoss, which assumes the inputs are logits. (I should've called the object bce_with_logits_loss instead of bce_loss.)

Oops, I totally missed that. I was indeed stuck on bce_loss. I was already a bit surprised you got such good results while seemingly using the loss incorrectly.


I think there is merit in providing more elaboration in the documentation somewhere (either the Distributed Training section, or the CMNRL API reference) about really large batch size cases. I see a lot of papers use cross-device batches to try and increase their batch size, not realising that GradCache can solve all of their problems.

So researchers seem to look for cross-device solutions in Sentence Transformers too, not realising that we purposefully don't support it (as I think CMNRL is an equivalent or superior option).
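For readers unfamiliar with GradCache, here is a minimal single-device sketch of the idea behind CMNRL. It assumes a deterministic encoder (with dropout, the second forward pass would have to reproduce the first pass's random state, which this sketch sidesteps) and plain in-batch MNRL without hard negatives. The function names are hypothetical; this is not the Sentence Transformers implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mnrl(a: torch.Tensor, p: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    # In-batch negatives: anchor i's positive is p[i]; every other p[j] is a negative.
    logits = scale * F.normalize(a, dim=-1) @ F.normalize(p, dim=-1).T
    return F.cross_entropy(logits, torch.arange(a.shape[0]))


def gradcache_step(encoder: nn.Module, x_a, x_p, mini_batch: int) -> torch.Tensor:
    """Accumulate full-batch MNRL gradients into encoder.grad using mini-batch forwards."""
    chunks_a, chunks_p = x_a.split(mini_batch), x_p.split(mini_batch)
    # 1) Embed every mini-batch without building an autograd graph (low memory).
    with torch.no_grad():
        emb_a = torch.cat([encoder(c) for c in chunks_a])
        emb_p = torch.cat([encoder(c) for c in chunks_p])
    # 2) Full-batch loss on the cached embeddings; take gradients w.r.t. them only.
    emb_a.requires_grad_(True)
    emb_p.requires_grad_(True)
    loss = mnrl(emb_a, emb_p)
    grad_a, grad_p = torch.autograd.grad(loss, [emb_a, emb_p])
    # 3) Re-embed each mini-batch with a graph and backprop its slice of the cached gradient.
    for chunk, g in zip(chunks_a + chunks_p, grad_a.split(mini_batch) + grad_p.split(mini_batch)):
        encoder(chunk).backward(g)
    return loss.detach()


torch.manual_seed(0)
enc = nn.Linear(16, 8)  # stand-in for the embedding model
x_a, x_p = torch.randn(12, 16), torch.randn(12, 16)
cached_loss = gradcache_step(enc, x_a, x_p, mini_batch=4)
```

Only one mini-batch's activations are held in memory at a time, yet every anchor is scored against the full batch, which is why per-device batch size can grow without cross-device communication.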

  • Tom Aarsen

@kddubey
Contributor Author

kddubey commented Dec 6, 2024

> All it really does is introduce cross-device communication, right?

It does, but SigLIP's cross-device communication is relatively efficient. SigLIP shifts embeddings from each device to its neighbor ("collective permute" in the paper) instead of doing all-gathers. I believe this shift has to happen G - 1 times per batch though. So it's definitely possible that the increase in negatives doesn't justify the cost in communication time. There would need to be hard experiments to say either way.
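A single-process simulation of that schedule, to show the bookkeeping rather than real `torch.distributed` communication: the "devices" are just list slots, and it assumes G equal-sized chunks, one positive per anchor, and a fixed scale and bias. Each device keeps its anchors while candidate blocks rotate one step around the ring; the sketch checks that G - 1 shifts let every anchor see every candidate and that the summed block losses match the full-batch sigmoid loss. All names are hypothetical.

```python
import torch
import torch.nn.functional as F


def block_sigmoid_loss(a, c, is_own_block, scale=10.0, bias=-10.0):
    """Sum of -log sigmoid(y * z) for one (anchors, candidates) block.

    Positives (y = +1) exist only on the diagonal of a device's own block.
    """
    logits = scale * a @ c.T + bias
    if is_own_block:
        signs = 2.0 * torch.eye(a.shape[0]) - 1.0
    else:
        signs = -torch.ones_like(logits)
    return -F.logsigmoid(signs * logits).sum()


def simulated_ring_loss(anchors, candidates, G):
    a_blocks = list(anchors.chunk(G))
    c_blocks = list(candidates.chunk(G))  # c_blocks[d] currently sits on device d
    owner = list(range(G))                # which device's candidates device d holds
    total, shifts = torch.tensor(0.0), 0
    for step in range(G):
        for d in range(G):
            total = total + block_sigmoid_loss(a_blocks[d], c_blocks[d], owner[d] == d)
        if step < G - 1:
            # "Collective permute": device d receives the block held by device d + 1 (mod G).
            c_blocks = c_blocks[1:] + c_blocks[:1]
            owner = owner[1:] + owner[:1]
            shifts += 1
    return total / anchors.shape[0], shifts


torch.manual_seed(0)
B, G = 12, 4
anchors = F.normalize(torch.randn(B, 8), dim=-1)
candidates = F.normalize(torch.randn(B, 8), dim=-1)
ring, shifts = simulated_ring_loss(anchors, candidates, G)
```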

Overall, I think my concerns are mostly washed away. Thank you for the discussion!

@kddubey
Contributor Author

kddubey commented Dec 8, 2024

Another application of the sigmoid/multi-label loss is that it seamlessly supports multi-positive and multi-negative data. (Pretty sure both positives and negatives can come in variable numbers.) I'm not yet certain about the benefits of this over unfolding the positives and feeding them to MNRL. But it seems like a potential use case. Maybe gets higher training throughput. What do you think?
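A minimal sketch of what that could look like, assuming a boolean positives mask as input (a hypothetical format, not something Sentence Transformers provides): each anchor may have any number of positives, and every candidate not marked positive counts as a negative. It reduces to binary cross-entropy with logits summed over all anchor-candidate pairs.

```python
import torch
import torch.nn.functional as F


def multi_positive_sigmoid_loss(a, c, pos_mask, scale=10.0, bias=-10.0):
    """Sigmoid loss where pos_mask[i, j] is True if candidate j is a positive for anchor i."""
    logits = scale * a @ c.T + bias
    signs = 2.0 * pos_mask.float() - 1.0  # +1 for positives, -1 for everything else
    return -F.logsigmoid(signs * logits).sum() / a.shape[0]


torch.manual_seed(0)
a = F.normalize(torch.randn(3, 8), dim=-1)  # 3 anchors
c = F.normalize(torch.randn(5, 8), dim=-1)  # 5 shared candidates
mask = torch.zeros(3, 5, dtype=torch.bool)
mask[0, [0, 1]] = True  # anchor 0: two positives
mask[1, 2] = True       # anchor 1: one positive
mask[2, [3, 4]] = True  # anchor 2: two positives
loss = multi_positive_sigmoid_loss(a, c, mask)
```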

Update: I'm investigating this in https://github.com/kddubey/mpnrl
