[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫 #36180

Open
wants to merge 9 commits into
base: main

Conversation

@gante (Member) commented Feb 13, 2025

What does this PR do?

If a model inherits from GenerationMixin, then it can call generate. In those models, GenerationTesterMixin must be added to the tester in order to test generate. Sometimes we forget to do it. Not anymore 😈

This PR follows up on #33212 and removes the last piece of common human error. It adds a meta-test to ensure we have the correct inheritance, which enables us to confirm that the right tests are being run automatically. In other words, if the model inherits from GenerationMixin, this test confirms that the tester inherits from GenerationTesterMixin (and the other way around too). This test also disincentivizes complex tester inheritance hierarchies, as in idefics, bark, or musicgen, which hurt long-term maintenance.
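A minimal sketch of the consistency rule described above, using stand-in classes rather than the real transformers ones. FakeGenerationMixin, FakeGenerationTesterMixin, and inheritance_is_consistent are hypothetical names invented for illustration; the actual test added in this PR may be structured differently.

```python
class FakeGenerationMixin:
    """Stand-in for GenerationMixin (hypothetical, illustration only)."""

    def generate(self):
        return "generated"


class FakeGenerationTesterMixin:
    """Stand-in for GenerationTesterMixin (hypothetical, illustration only)."""


class GenerativeModel(FakeGenerationMixin):
    pass


class EncoderOnlyModel:
    pass


def inheritance_is_consistent(tester_cls, model_classes):
    """Return True when the tester inherits the generation tester mixin
    if and only if at least one of its model classes can generate."""
    has_generative_model = any(
        issubclass(model_cls, FakeGenerationMixin) for model_cls in model_classes
    )
    tester_has_mixin = issubclass(tester_cls, FakeGenerationTesterMixin)
    return has_generative_model == tester_has_mixin


class GoodTester(FakeGenerationTesterMixin):
    pass


class ForgetfulTester:
    pass
```

The check is symmetric: a tester that forgets the mixin for a generative model fails, and so does a tester that adds the mixin when none of its models can generate.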

In the process, it uncovered *many* cases of incorrect/missing inheritance, which are dealt with in this PR. Look at all these models that should have been tested with generate and haven't been since they were added 😭

👉 reviewer: please start by reviewing tests/test_modeling_common.py and modeling_utils.py, then check the other files.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante changed the title from "[CI] Check test if the GenerationTesterMixin inheritance is correct" to "[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫" on Feb 14, 2025
Comment on lines -1656 to -1658
# Model class overwrites `generate` (e.g. time series models) -> can generate
if str(cls.__name__) in str(cls.generate):
    return True
Member Author

can_generate() is only used in GenerationMixin-related code. Let's remove the time series model special case from this function.

Member

Is it completely different, or does it use part of generate(), like some audio models?

Member Author

it's completely different ☠️
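For readers unfamiliar with the removed lines, here is a small sketch of what that string-based check does, written with stand-in classes. FakeGenerationMixin and the two model classes are hypothetical; only the expression inside the helper mirrors the removed code.

```python
class FakeGenerationMixin:
    """Stand-in for GenerationMixin (hypothetical, illustration only)."""

    def generate(self):
        return "mixin generate"


class InheritsGenerate(FakeGenerationMixin):
    """Uses the mixin's generate unchanged."""


class OverridesGenerate(FakeGenerationMixin):
    """Defines its own generate, as the comment says time series models do."""

    def generate(self):
        return "custom generate"


def overrides_generate_by_name(cls):
    # Same expression as the removed lines: the string form of a function
    # includes its qualified name, e.g. "<function OverridesGenerate.generate at 0x...>",
    # so the class name appears only when the class defines `generate` itself.
    return str(cls.__name__) in str(cls.generate)
```

Under these assumptions, the helper returns True for OverridesGenerate and False for InheritsGenerate.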

@@ -2753,6 +2753,41 @@ def test_speculative_sampling_target_distribution(self):
self.assertTrue(last_token_counts[8] > last_token_counts[3])


global_rng = random.Random()
Member Author

copied from test_modeling_common, otherwise we would have circular dependencies

Member

I think a "copied from" comment can be added here.

Collaborator

Are these used in a lot of places in this file, or just inside one method?

If it's just one method, we can probably avoid circular dependencies by importing them within that (single) method?

Member Author

That's a good idea, moving to an internal import to prevent code bloat

Member Author

uhmmm local imports would be needed in many places, will go with # Copied from instead

@ydshieh (Collaborator), Feb 14, 2025

Another possible approach is not to use

issubclass(self.__class__, GenerationTesterMixin),

but to check the __name__ of all parent classes (recursively; maybe there is no built-in function to do this?) and see if GenerationTesterMixin is in that set.

Up to you.

P.S. # Copied from is less maintained and we would like to move away from it (since we now rely more on modular).
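On the question of a built-in way to walk all parent classes: every Python class exposes __mro__, a tuple holding the class itself and all of its ancestors, so no manual recursion is needed. The sketch below is one way the name-based check could look. has_ancestor_named and the tester classes are hypothetical, and the GenerationTesterMixin defined here is a local stand-in, not the real one.

```python
def has_ancestor_named(cls, name):
    """Return True if cls or any of its ancestors has the given __name__.

    `cls.__mro__` already lists direct and indirect parents, so the whole
    hierarchy is covered without importing the class being looked for.
    """
    return any(base.__name__ == name for base in cls.__mro__)


class GenerationTesterMixin:
    """Local stand-in for the real mixin (hypothetical, illustration only)."""


class BaseGenerativeTester(GenerationTesterMixin):
    pass


class DeepTester(BaseGenerativeTester):
    """Inherits the mixin indirectly, two levels up."""


class PlainTester:
    pass
```

The trade-off is that matching on names would also accept an unrelated class that happens to be called GenerationTesterMixin, which issubclass would reject.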

Member Author

It's also okay to have a pure copy of the short functions :P It's just a handful of lines, I don't think it's worth the extra work for now -- I will have to refactor these lines when we remove TF (i.e. very soon) 👀

Collaborator

OK 👍

@gante marked this pull request as ready for review on February 14, 2025, 12:33
@zucchini-nlp (Member) left a comment

Great, thanks for aligning the generation tests on a single approach. I left a few questions about the parts I didn't understand.

Comment on lines +454 to +455
# Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante
all_generative_model_classes = ()
Member

for my understanding: do we need to have empty all_generative_model_classes if no GenerationTesterMixin is added?

Member Author

If a model inherits GenerationMixin, one of these two must happen:

  1. GenerationTesterMixin must be in the tester
  2. [when the tests are broken, this should NEVER happen on new models] we specify all_generative_model_classes = () to make it very clear we're NOT running generation tests

Option 2 is intentionally annoying (we are forced to overwrite a property), so that skipping generation tests is always explicit. We don't want skips to happen unless we're very intentional about it.
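The two options can be sketched as follows, assuming all_generative_model_classes is a property derived from the model classes on a shared tester base, as the comment above suggests. All class names and the check_generation_tester_inheritance helper are hypothetical stand-ins, not the real test code.

```python
class FakeGenerationMixin:
    """Stand-in for GenerationMixin (hypothetical, illustration only)."""


class FakeGenerationTesterMixin:
    """Stand-in for GenerationTesterMixin (hypothetical, illustration only)."""


class GenerativeModel(FakeGenerationMixin):
    pass


class FakeModelTesterMixin:
    """Stand-in for the common tester base (hypothetical)."""

    all_model_classes = ()

    @property
    def all_generative_model_classes(self):
        # Derived automatically: every model class that can generate.
        return tuple(
            m for m in self.all_model_classes if issubclass(m, FakeGenerationMixin)
        )

    def check_generation_tester_inheritance(self):
        has_mixin = isinstance(self, FakeGenerationTesterMixin)
        if self.all_generative_model_classes:
            return has_mixin
        return not has_mixin


class CoveredTester(FakeModelTesterMixin, FakeGenerationTesterMixin):
    # Option 1: the generation tester mixin is present.
    all_model_classes = (GenerativeModel,)


class ForgetfulTester(FakeModelTesterMixin):
    # Neither option: the check should flag this tester.
    all_model_classes = (GenerativeModel,)


class OptedOutTester(FakeModelTesterMixin):
    # Option 2: explicitly shadow the property to state that generation
    # tests are NOT run for this model.
    all_model_classes = (GenerativeModel,)
    all_generative_model_classes = ()
```

In this sketch the plain class attribute in OptedOutTester shadows the inherited property, so the opt-out is visible in the tester itself instead of being a silent omission.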

Comment on lines +423 to +425
@unittest.skip(reason="TODO: Fix me @joao")
def test_generate_with_head_masking(self):
    pass
Member

head-mask test can be deleted after #35786 is merged

@ydshieh (Collaborator) left a comment

Nice!

@gante (Member Author) commented Feb 14, 2025

@zucchini-nlp PR comments addressed/replied 🤗

4 participants