Introducing a generic `ModelConverter` interface. #823
Conversation
torchtitan/config_manager.py (outdated review context):

```python
# self.parser.add_argument(
#     "--float8.enable_float8_linear",
#     action="store_true",
#     help="""
#         If true, swaps `torch.nn.Linear` with `Float8Linear`.
#         This feature requires you to install 'torchao' which can be found
#         here: https://github.com/pytorch/ao
#     """,
# )
```
I somehow think we can keep this as is in this PR, and by default include "float8" in the `model.handlers` config.
My idea is to decouple handler registration from turning handlers on/off, so that it enables your use case but doesn't require anyone else to adapt their code or mental model.
Happy to keep this one as it is, so as not to break the current float8 workflow.
But `model.handlers` is not a registration list; it is the list of handlers to apply to the model. Registration of a new model handler is done in code with `register_model_handler`.
If we don't have `model.handlers`, we are adding complexity for the end user and the codebase: for every new handler, one needs to add a `my_handler.enable` flag in the config, plus `if/else` logic in every handler to check whether it is activated.
That feels like the wrong design pattern to me: it is much simpler for the user to specify a list of handlers to apply (and in the float8 case, I would advocate raising an error, not just a warning, if the hardware does not support it).
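For readers following the design discussion, here is a minimal sketch of the split described above: registration happens once in code, while the `model.handlers` list in the config decides what actually gets applied. Only the names `register_model_handler` and `model.handlers` come from this thread (later renamed to converters); the registry internals and factory signature are assumptions for illustration, not torchtitan code.

```python
# Minimal sketch, not torchtitan code: a dict-backed registry plus a
# config-driven list of handlers to apply.
from typing import Callable, Dict, List

import torch.nn as nn


class ModelHandler:
    """Base interface: a handler transforms the model (e.g. float8 swap)."""

    def convert(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError


# Registration is done in code, independently of whether a handler is enabled.
_MODEL_HANDLER_REGISTRY: Dict[str, Callable[[], ModelHandler]] = {}


def register_model_handler(name: str, factory: Callable[[], ModelHandler]) -> None:
    if name in _MODEL_HANDLER_REGISTRY:
        raise ValueError(f"Handler '{name}' is already registered")
    _MODEL_HANDLER_REGISTRY[name] = factory


def apply_handlers(model: nn.Module, handler_names: List[str]) -> nn.Module:
    """Apply the handlers listed under `model.handlers` in the job config."""
    for name in handler_names:
        if name not in _MODEL_HANDLER_REGISTRY:
            # Unknown handlers are a hard error, no silent skipping.
            raise ValueError(f"Unknown handler '{name}'; registered: {list(_MODEL_HANDLER_REGISTRY)}")
        model = _MODEL_HANDLER_REGISTRY[name]().convert(model)
    return model
```

With this split there is no per-handler `enable` flag: leaving a name out of `model.handlers` is what disables it.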
Yeah, for sure, please don't get me wrong.
As a user of torchtitan, I needed to run a lot of experiments with and without Float8. In the past, to enable or disable Float8 I just had to flip the three flags to True/False in the toml file, so I'd prefer not to have to go to two different places (`model.converters` and `float8.enable...`) to enable/disable the float8 configs.
But I agree it's not a good pattern to keep it as is. So let's follow your original change.
Thanks 👍
I have two general comments:

- `ModelHandler` is too vague. Sorry, I am not good at naming either, so I asked GPT and `ModelConverter` is the recommended name. It is also consistent with the `convert()` method.
- We should not create new optimizer hooks, since `torch.optim.Optimizer` already supports hooks (see the sketch after this list). The major benefit is that other components, like TorchFT, may modify the behavior of `torch.optim.Optimizer.step()`; adding additional hooks is likely to be incompatible with (or make integration hard for) those libraries. I'll submit a new PR to demonstrate how to register optimizer hooks based on Add Dynamic Model Import and ModelSpec Definition #814.
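For context, here is a minimal sketch of the built-in hook mechanism referred to above, using `torch.optim.Optimizer.register_step_post_hook`; the toy model and the hook body are illustrative assumptions, not torchtitan code.

```python
# Sketch of PyTorch's built-in post-step optimizer hook.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


def post_optimizer_hook(optimizer, args, kwargs):
    # In a converter this is where post-step work would go, e.g. recomputing
    # float8 dynamic scales after the weights have been updated.
    print("optimizer.step() finished")


# Because the hook lives on torch.optim.Optimizer itself, libraries that wrap
# or patch Optimizer.step() still trigger it.
handle = optimizer.register_step_post_hook(post_optimizer_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()  # prints "optimizer.step() finished"
handle.remove()   # the hook can be removed if the converter is disabled
```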
* Using `string_list` for the `model.converters` argument (see the sketch below);
* Renaming to `ModelConverter`;
* Basic unit test coverage for `ModelConvertersContainer`.
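As a rough illustration of the first bullet: a `string_list` argument type could parse a comma-separated value into a Python list. The implementation below is an assumption for illustration, not the actual torchtitan helper.

```python
# Hypothetical sketch of a `string_list` argparse type for --model.converters.
import argparse
from typing import List


def string_list(raw: str) -> List[str]:
    """Parse a comma-separated value like "float8,my_converter" into a list."""
    return [item.strip() for item in raw.split(",") if item.strip()]


parser = argparse.ArgumentParser()
parser.add_argument(
    "--model.converters",
    type=string_list,
    default=[],
    help="Comma-separated list of converters to apply to the model, e.g. 'float8'.",
)

args = parser.parse_args(["--model.converters", "float8"])
print(getattr(args, "model.converters"))  # ['float8']
```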
@tianyu-l I have pushed a couple of improvements following your review, reverting the …

@tianyu-l @fegin On the optimizer hook: it feels to me there is no obvious simple solution. The optimizer is created after parallelism is applied to the model, whereas the converter/handler has to be built before that (so it can't receive the optimizer hook). In its current form, this PR adds nothing new on this side; it just does a generic renaming of the float8 logic. I would be in favor of getting this feature merged as it is, following the integration of #814 on main, and then opening a PR to see which solution is the most elegant for integrating the PyTorch optimizer hook with ModelConverter. It is otherwise complicated to understand how these two PRs interact with each other, and it may just delay things without necessarily finding a consensus.
In case you want to integrate with optimizer hooks, I believe the simpler way for now is to pass the `ModelConverter` to the `OptimizersContainer`:

```python
class OptimizersContainer(Stateful):
    """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages
    and saving/loading optimizer state_dict at checkpoint.
    """

    def __init__(
        self,
        model_parts: List[nn.Module],
        model_converter: ModelConverter,
        optimizer_kwargs: Dict[str, Any],
        name: str,
    ) -> None:
        ...
        for model in self.model_parts:
            ...
            optimizer.register_step_post_hook(...)  # lambda calling model_converter.post_optimizer_hook
```

Happy to add that as a current solution, which may still be improved in the future (following the #814 merge). I am not 100% sure myself what the cleanest way is, which is why I feel a dedicated PR may be beneficial.
> I would be in favor of getting this feature merged as it is, following the integration of #814 on main, and then opening a PR to see which solution is the most elegant for integrating the PyTorch optimizer hook with ModelConverter. It is otherwise complicated to understand how these two PRs interact with each other, and it may just delay things without necessarily finding a consensus.

Sounds good to me. We can do the `register_step_post_hook` migration in another PR.
Thanks for adding tests!
Quoted review context:

```python
    job config, and apply them to the model sequentially.
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
```
Regarding #814 (comment), I think we can call `apply_to_train_specs` to register hooks on the optimizers here.
That could indeed be an option, happy to discuss it in a new PR. The small downside is my hunch that registries should be immutable; I have a bad feeling about modifying an existing entry! But maybe it wouldn't be an issue.
Yeah, the intertwined logic is bad -- can we just specify it in the `TrainSpec` construction and let the spec handle the registration and usage?
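A hypothetical sketch of that suggestion, assuming a dataclass-style `TrainSpec` (the real definition is introduced in #814); the field names below are made up for illustration.

```python
# Hypothetical sketch only; not the actual TrainSpec from #814.
from dataclasses import dataclass
from typing import Callable, Optional

import torch.nn as nn


class ModelConverter:
    """Placeholder for the converter interface discussed in this PR."""

    def convert(self, model: nn.Module) -> None:
        pass


@dataclass
class TrainSpec:
    name: str
    model_builder: Callable[[], nn.Module]
    # The spec carries its own converter factory, so the trainer can build and
    # apply it without a separate global registry lookup.
    model_converter_builder: Optional[Callable[[], ModelConverter]] = None
```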
@tianyu-l I merged the latest `main`. Please tell me if there is any additional improvement in mind. Once merged, I'll look at how to integrate #814.
lgtm! please address final comments :)
Quoted review context:

```python
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.enabled = False

        float8_config = job_config.float8
        if not float8_config.enable_float8_linear:
```
Please remove this in `config_manager.py` too.
My mistake, forgot to add it to the commit! Now fixed.
Quoted review context (toml config):

```
@@ -54,4 +54,3 @@ async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
mode = 'full'

[float8]
```
When float8 is disabled in a config (debugmodel, 8b, 70b), let's add the two float8 options (as `False`) and have a commented-out line `converters = "float8"`. The point is to make it easier to keep using float8, especially for people who are not aware of this PR.
Just added to debug, 8B and 70B models
@tianyu-l Should be good hopefully :) I merged the latest `main`. Thanks for your feedback on the PR!
Looks awesome, thank you for contributing!
Let's work together on the next PR to integrate `ModelConverter` into `TrainSpec`.
This model handler interface should cover most cases in quantization, fused layer optimization, ...
This PR adds:

- A generic `ModelConverter` class, transforming a model (see the sketch below);
- A `model.converters` config entry where the user can add a list of converters to apply to the model (e.g. `float8`);
- Conversion of `Float8Handler` to the `ModelConverter` interface.

Related issue: #790
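A rough sketch of the interface this description refers to, pieced together from the thread (the `convert()` method, `post_optimizer_hook`, and `ModelConvertersContainer` are all mentioned above); the exact signatures in torchtitan may differ.

```python
# Sketch inferred from the PR discussion; signatures are approximate.
from typing import List, Protocol

import torch.nn as nn


class ModelConverter(Protocol):
    """A converter transforms the model, e.g. float8 swap or fused layers."""

    def convert(self, model: nn.Module) -> None:
        """Modify the model in place before training starts."""
        ...

    def post_optimizer_hook(self, model: nn.Module) -> None:
        """Optional work after optimizer.step(), e.g. recomputing float8 scales."""
        ...


class ModelConvertersContainer:
    """Applies the converters listed under `model.converters` sequentially."""

    def __init__(self, converters: List[ModelConverter]) -> None:
        self.converters = converters

    def convert(self, model: nn.Module) -> None:
        for converter in self.converters:
            converter.convert(model)

    def post_optimizer_hook(self, model: nn.Module) -> None:
        for converter in self.converters:
            converter.post_optimizer_hook(model)
```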