[bug] use_gather_object is not respected after the first eval in trainer #36213

Open
ducha-aiki opened this issue Feb 15, 2025 · 1 comment · May be fixed by #36214

@ducha-aiki
Contributor

Hi,

Consider this a mix of a bug report and a PR. Unfortunately, it would take me too long to put together a toy example that reproduces the problem, and the issue is simple and obvious enough without one.

There is an eval_use_gather_object argument in TrainingArguments that allows using non-tensors during evaluation, or tensors with different shapes.
It is handled here:

https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L5103

        # create accelerator object
        self.accelerator = Accelerator(**args)
        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
        self.gather_function = self.accelerator.gather_for_metrics

        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
            self.gather_function = functools.partial(
                self.gather_function, use_gather_object=self.args.eval_use_gather_object
            )
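The wrapping above relies on two standard-library tools: inspect.signature, to check whether the installed Accelerate version's gather_for_metrics accepts a use_gather_object parameter, and functools.partial, to bind the flag once. A minimal stdlib-only sketch of that pattern, where gather_for_metrics and wrap_gather are hypothetical stand-ins rather than Accelerate's or Trainer's real code:

```python
import functools
import inspect


def gather_for_metrics(data, use_gather_object=False):
    # Hypothetical stand-in for Accelerator.gather_for_metrics: with the flag,
    # the per-batch objects are returned untouched; without it, batches are
    # "concatenated", which fails when their lengths (shapes) differ.
    if use_gather_object:
        return list(data)
    lengths = {len(batch) for batch in data}
    if len(lengths) > 1:
        raise ValueError(f"cannot concatenate batches with lengths {sorted(lengths)}")
    return [x for batch in data for x in batch]


def wrap_gather(fn, use_gather_object):
    # Bind the flag only if the function actually accepts it, mirroring the
    # inspect.signature check in the Trainer excerpt.
    if "use_gather_object" in inspect.signature(fn).parameters:
        return functools.partial(fn, use_gather_object=use_gather_object)
    return fn


gather = wrap_gather(gather_for_metrics, use_gather_object=True)
print(gather([[1, 2], [3]]))  # [[1, 2], [3]]
```

Functions without the parameter are returned unchanged, which is why the real code can run against older Accelerate versions.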

However, I noticed that after the first evaluation, the second evaluation crashes while trying to concatenate batches with different shapes, as if the eval_use_gather_object flag had stopped working. Indeed it has: at the end of evaluation_loop, self.gather_function is reset to the plain gather_for_metrics, and eval_use_gather_object is not applied again:
https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L4359

        # After all calls to `.gather_function`, reset to `gather_for_metrics`:
        self.gather_function = self.accelerator.gather_for_metrics
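To illustrate why only the second evaluation fails, here is a toy model of that lifecycle. ToyTrainer and gather_for_metrics are hypothetical stand-ins, not the real Trainer or Accelerate code: the flag is bound once in __init__, and evaluation_loop resets gather_function to the unwrapped function on exit, so only the first call honours the flag.

```python
import functools
import inspect


def gather_for_metrics(data, use_gather_object=False):
    # Hypothetical stand-in: concatenation fails on ragged batches unless
    # use_gather_object=True, in which case the objects are returned as-is.
    if use_gather_object:
        return list(data)
    if len({len(batch) for batch in data}) > 1:
        raise ValueError("cannot concatenate batches with different shapes")
    return [x for batch in data for x in batch]


class ToyTrainer:
    def __init__(self, eval_use_gather_object):
        self.eval_use_gather_object = eval_use_gather_object
        self.gather_function = gather_for_metrics
        if "use_gather_object" in inspect.signature(self.gather_function).parameters:
            self.gather_function = functools.partial(
                self.gather_function, use_gather_object=self.eval_use_gather_object
            )

    def evaluation_loop(self, batches):
        result = self.gather_function(batches)
        # Buggy reset: the partial (and with it the flag) is discarded.
        self.gather_function = gather_for_metrics
        return result


trainer = ToyTrainer(eval_use_gather_object=True)
ragged = [[1, 2], [3]]
print(trainer.evaluation_loop(ragged))  # first eval works: [[1, 2], [3]]
try:
    trainer.evaluation_loop(ragged)  # second eval ignores the flag
except ValueError as err:
    print("second eval failed:", err)
```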

Suggested fix (I am using it myself): right after that reset, add the same wrapping so the flag is applied again:
https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L4359

        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
            self.gather_function = functools.partial(
                self.gather_function, use_gather_object=self.args.eval_use_gather_object
            )
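A sketch of the same toy model with the suggested fix applied, again using hypothetical stand-ins (FixedToyTrainer, _wrap_gather, gather_for_metrics) rather than the real Trainer code. The reset now re-binds eval_use_gather_object, so every evaluation sees the flag:

```python
import functools
import inspect


def gather_for_metrics(data, use_gather_object=False):
    # Hypothetical stand-in: ragged batches only work with use_gather_object=True.
    if use_gather_object:
        return list(data)
    if len({len(batch) for batch in data}) > 1:
        raise ValueError("cannot concatenate batches with different shapes")
    return [x for batch in data for x in batch]


class FixedToyTrainer:
    def __init__(self, eval_use_gather_object):
        self.eval_use_gather_object = eval_use_gather_object
        self.gather_function = self._wrap_gather(gather_for_metrics)

    def _wrap_gather(self, fn):
        # Bind the flag whenever the gather function accepts it.
        if "use_gather_object" in inspect.signature(fn).parameters:
            return functools.partial(fn, use_gather_object=self.eval_use_gather_object)
        return fn

    def evaluation_loop(self, batches):
        result = self.gather_function(batches)
        # Fixed reset: restore gather_for_metrics and re-apply the flag.
        self.gather_function = self._wrap_gather(gather_for_metrics)
        return result


trainer = FixedToyTrainer(eval_use_gather_object=True)
ragged = [[1, 2], [3]]
for _ in range(3):
    print(trainer.evaluation_loop(ragged))  # [[1, 2], [3]] each time
```

Putting the wrapping in a single helper, rather than copying the if-block into both places, is one way to keep the initial setup and the reset from drifting apart again.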
ducha-aiki linked a pull request Feb 15, 2025 that will close this issue
@SunMarc
Member

SunMarc commented Feb 17, 2025

Hi @ducha-aiki, thanks for the report! I just reviewed your PR.
