[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫
#36180
# Model class overwrites `generate` (e.g. time series models) -> can generate
if str(cls.__name__) in str(cls.generate):
    return True
can_generate() is only used in GenerationMixin-related code. Let's remove time series models from this function.
Is it completely different, or does it use part of generate(), like some audio models?
it's completely different ☠️
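For readers unfamiliar with the check being removed, here is a minimal sketch of what the name-in-string test does. The classes are hypothetical stand-ins, not transformers classes, and the `vars`-based alternative is only an illustration, not something proposed in this PR.

```python
class Base:
    def generate(self):
        return "base"


class Inherits(Base):
    """Does not define its own generate."""


class Overrides(Base):
    def generate(self):
        return "custom"


class ModelBase:
    def generate(self):
        return "base"


class Model(ModelBase):
    """Inherits generate, but its name is a substring of the parent's name."""


def overrides_generate_by_name(cls):
    # str(cls.generate) looks like "<function Overrides.generate at 0x...>",
    # so the class name appears only if the function is defined on that class
    # (or on a class whose name happens to contain it).
    return cls.__name__ in str(cls.generate)


def overrides_generate(cls):
    # Stricter alternative: look only at attributes defined directly on cls.
    return "generate" in vars(cls)


print(overrides_generate_by_name(Overrides))  # True
print(overrides_generate_by_name(Inherits))   # False
print(overrides_generate_by_name(Model))      # True (substring false positive)
print(overrides_generate(Model))              # False
```

The substring match is one reason a string-based check like this is fragile.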
@@ -2753,6 +2753,41 @@ def test_speculative_sampling_target_distribution(self):
        self.assertTrue(last_token_counts[8] > last_token_counts[3])


global_rng = random.Random()
Copied from test_modeling_common, otherwise we would have circular dependencies.
a comment with "copied from" can be added i think
Are these used in a lot of places in this file, or just inside one method?
If it's just one method, we can probably avoid circular dependencies by importing them within that (single) method?
That's a good idea, moving to an internal import to prevent code bloat
Uhmmm, local imports would be needed in many places, so I will go with # Copied from instead.
Another possible approach is not to use issubclass(self.__class__, GenerationTesterMixin), but to check the __name__ of all parent classes (recursively; maybe there is no built-in function to do this?) and see if GenerationTesterMixin is in that set.
Up to you.
P.S. # copied is less maintained and we would like to get away from it (since we now rely more on modular).
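On the question of a built-in: every Python class exposes its full linearized ancestry (itself plus all direct and indirect parents) as `__mro__`, so no manual recursion is needed. A minimal sketch of the name-based check, using hypothetical stand-in classes rather than the real test mixins:

```python
class GenerationTesterMixin:
    """Stand-in for the real mixin, defined locally to avoid any import."""


class ModelTesterMixin:
    """Stand-in for the common tester mixin."""


class IntermediateTester(GenerationTesterMixin):
    """An indirect parent, to show that the check is transitive."""


class GenerativeModelTest(ModelTesterMixin, IntermediateTester):
    pass


class PlainModelTest(ModelTesterMixin):
    pass


def inherits_by_name(cls, name):
    # cls.__mro__ contains cls and every ancestor, so comparing names here
    # avoids importing the mixin (and the circular dependency that comes with it).
    return any(base.__name__ == name for base in cls.__mro__)


print(inherits_by_name(GenerativeModelTest, "GenerationTesterMixin"))  # True
print(inherits_by_name(PlainModelTest, "GenerationTesterMixin"))       # False
```

The trade-off is that a name comparison would also match an unrelated class that happens to share the name, which issubclass would not.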
It's also okay to have a pure copy of the short functions :P It's just a handful of lines, I don't think it's worth the extra work for now -- I will have to refactor these lines when we remove TF (i.e. very soon) 👀
OK 👍
Great, thanks for aligning generation tests into one way. I left a few questions, for parts I didn't understand
# Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante
all_generative_model_classes = ()
For my understanding: do we need to have an empty all_generative_model_classes if no GenerationTesterMixin is added?
If a model inherits GenerationMixin, one of these two must happen:
1. GenerationTesterMixin must be in the tester
2. [when the tests are broken, this should NEVER happen on new models] we specify all_generative_model_classes = () to make it very clear we're NOT running generation tests

Option 2 is intentionally annoying (we are forced to overwrite a property), so we are very explicit about skipping tests. We don't want skips to happen unless we're very intentional about it.
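A minimal sketch of what "overwriting a property" looks like in option 2. It assumes the default value is a property on a base tester class, derived from can_generate(); the class names and the exact derivation are illustrative assumptions, not the library's code.

```python
class FakeGenerativeModel:
    """Hypothetical model that can call generate."""

    @classmethod
    def can_generate(cls):
        return True


class FakeEncoderModel:
    """Hypothetical model that cannot call generate."""

    @classmethod
    def can_generate(cls):
        return False


class BaseTester:
    """Hypothetical base tester: the default is computed, not hand-written."""

    all_model_classes = ()

    @property
    def all_generative_model_classes(self):
        return tuple(m for m in self.all_model_classes if m.can_generate())


class RegularTester(BaseTester):
    all_model_classes = (FakeGenerativeModel, FakeEncoderModel)


class OptOutTester(BaseTester):
    all_model_classes = (FakeGenerativeModel, FakeEncoderModel)
    # Explicit opt-out: this class attribute shadows the inherited property,
    # so anyone reading the tester sees that generation tests are skipped.
    all_generative_model_classes = ()


print(RegularTester().all_generative_model_classes == (FakeGenerativeModel,))  # True
print(OptOutTester().all_generative_model_classes)  # ()
```

Because the opt-out is a visible line in the tester, it can be found with a simple search and removed once the underlying tests are fixed.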
@unittest.skip(reason="TODO: Fix me @joao")
def test_generate_with_head_masking(self):
    pass
head-mask test can be deleted after #35786 is merged
Nice!
@zucchini-nlp PR comments addressed/replied 🤗
What does this PR do?

If a model inherits from GenerationMixin, then it can call generate. In those models, GenerationTesterMixin must be added to the tester in order to test generate. Sometimes we forget to do it. Not anymore 😈

This PR follows up on #33212 and removes the last piece of common human error. It adds a meta-test to ensure we have the correct inheritance, which enables us to confirm that the right tests are being run automatically. In other words, if the model inherits from GenerationMixin, this test confirms that the tester inherits from GenerationTesterMixin (and the other way around too). This test also disincentivizes complex tester inheritance hierarchies, as in idefics, bark, or musicgen, which hurt long-term maintenance.

In the process, it uncovered *many* cases of incorrect/missing inheritance, which are dealt with in this PR. Look at all these models that should have been tested with generate and haven't been since they were added 😭

👉 reviewer: please start by reviewing tests/test_modeling_common.py and modeling_utils.py, then check the other files.
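The two-way inheritance check described above can be sketched roughly as follows. This is a simplified illustration with local stand-in classes; the helper name and the use of `any` over the model classes are assumptions, and the actual meta-test in tests/test_modeling_common.py may differ.

```python
class GenerationMixin:
    """Stand-in for the model-side mixin."""


class GenerationTesterMixin:
    """Stand-in for the tester-side mixin."""


class GenerativeModel(GenerationMixin):
    pass


class EncoderModel:
    pass


class GoodGenerativeTester(GenerationTesterMixin):
    pass


class TesterMissingMixin:
    pass


def inheritance_is_consistent(tester_cls, model_classes):
    # Hypothetical helper: the tester should carry GenerationTesterMixin
    # exactly when some tested model carries GenerationMixin.
    any_generative = any(issubclass(m, GenerationMixin) for m in model_classes)
    has_tester_mixin = issubclass(tester_cls, GenerationTesterMixin)
    return any_generative == has_tester_mixin


print(inheritance_is_consistent(GoodGenerativeTester, (GenerativeModel,)))  # True
print(inheritance_is_consistent(TesterMissingMixin, (GenerativeModel,)))    # False
print(inheritance_is_consistent(GoodGenerativeTester, (EncoderModel,)))     # False
print(inheritance_is_consistent(TesterMissingMixin, (EncoderModel,)))       # True
```

Comparing the two booleans for equality covers both directions at once: a missing tester mixin and a tester mixin attached to a non-generative model both fail.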