Question about Group DRO implementation #33
Hi Nicholas, you're correct that the code is slightly different from the Group DRO loss. The weighting scheme in the code follows the PyTorch CrossEntropyLoss convention for weighting losses across classes: it is equivalent to the Group DRO loss as the batch size grows, but is a biased estimate with finite samples. I'd expect similar behavior from either one, though. I would also take a look at https://github.com/kohpangwei/group_DRO as another reference implementation of Group DRO.
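To make the distinction concrete, here is a rough sketch of the two estimators, reusing the toy batch from the quoted question below (one scalar excess loss per sample; the tensors and names are illustrative, not the actual trainer code):

```python
import torch

def group_dro_loss(excess, domain_ids, alpha):
    # Eq. 1 in the paper: a weighted sum of per-domain *averages*.
    # Domains absent from the batch contribute nothing.
    domains = domain_ids.unique()
    return sum(alpha[d] * excess[domain_ids == d].mean() for d in domains)

def weighted_mean_loss(excess, domain_ids, alpha):
    # CrossEntropyLoss-style weighting: weight each sample by its
    # domain's alpha, then normalize by the total weight in the batch
    # (the convention of torch.nn.CrossEntropyLoss(weight=...) with
    # reduction='mean').
    w = alpha[domain_ids]
    return (w * excess).sum() / w.sum()

# With equal per-domain counts the two coincide; with unbalanced
# batches they differ, which is the finite-sample bias mentioned above.
excess = torch.tensor([1.0, 1.0, 1.0, 5.0])
domain_ids = torch.tensor([0, 0, 0, 1])
alpha = torch.tensor([0.5, 0.5])
print(group_dro_loss(excess, domain_ids, alpha))      # tensor(3.)
print(weighted_mean_loss(excess, domain_ids, alpha))  # tensor(2.)
```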
On Wed, Oct 16, 2024 at 1:55 PM Nicholas E. Corrado wrote:
First, thank you for the solid work and for making this code public -- the paper offers some great insights and the code is very clean! I'm using this codebase as a reference for implementing a variant of Group DRO, and I had a clarifying question about the loss computation.
The Group DRO loss stated in the DoReMi paper <https://arxiv.org/abs/2305.10429> is (Eq. 1):

$$\min_{\theta} \max_{\alpha \in \Delta^k} \mathcal{L}(\theta, \alpha) := \sum_{i=1}^k \alpha_i \left[\frac{1}{\sum_{x\in D_i}|x|} \sum_{x \in D_i}\big(\ell_\theta(x) - \ell_\text{ref}(x)\big)\right]$$
However, it looks like the code is actually optimizing
$$\min_{\theta} \max_{\alpha \in \Delta^k} \mathcal{L}(\theta, \alpha) := \frac{1}{\sum_{x\in D}|x|}\sum_{i=1}^k \alpha_i \left[\sum_{x \in D_i}\big(\ell_\theta(x) - \ell_\text{ref}(x)\big)\right]$$
In particular, it looks like the loss is a reweighted average *over all samples across all domains* rather than a reweighted *sum of per-domain average losses*.
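For concreteness, here is a toy example (numbers invented): take two domains with $\alpha = (0.5, 0.5)$, where domain 1 contributes three samples with excess loss $1$ each and domain 2 contributes one sample with excess loss $5$ (treating $|x| = 1$ for every sample). Then Eq. 1 gives

$$0.5 \cdot \tfrac{1+1+1}{3} + 0.5 \cdot \tfrac{5}{1} = 3,$$

while normalizing by the total sample count gives

$$\tfrac{0.5 \cdot (1+1+1) + 0.5 \cdot 5}{4} = 1,$$

so the two objectives differ whenever the domains are unevenly represented in the batch.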
The domain weight update computes the per-domain average losses here:
https://github.com/sangmichaelxie/doremi/blob/7cde52d1848737aa967ecbdb9e643cf334de160d/doremi/trainer.py#L252C22-L252C110
I would expect to see a similar computation for the model parameter updates, but it looks like the code computes the total loss across all domains, reweights it by the domain weights, and then normalizes by a constant (yielding a reweighted average loss over all samples in all domains):
https://github.com/sangmichaelxie/doremi/blob/7cde52d1848737aa967ecbdb9e643cf334de160d/doremi/trainer.py#L363C17-L363C89
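Here is a minimal sketch of the computation as I read it, on the same toy batch as above (variable names are mine, not the trainer's):

```python
import torch

# Same toy batch as above: three samples from domain 0, one from domain 1.
excess = torch.tensor([1.0, 1.0, 1.0, 5.0])
domain_ids = torch.tensor([0, 0, 0, 1])
alpha = torch.tensor([0.5, 0.5])

# Reweight each sample by its domain weight, then divide by the
# constant total sample count rather than by per-domain counts.
loss = (alpha[domain_ids] * excess).sum() / excess.numel()
print(loss)  # tensor(1.) -- vs. 3.0 under Eq. 1 for the same batch
```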
Could you please clarify whether the Group DRO loss stated in Eq. 1 (a weighted sum of per-domain averages) is indeed what the code implements, or whether it is the globally normalized variant above? Thank you!