
Introducing a generic ModelConverter interface. #823

Conversation

@balancap balancap (Contributor) commented Feb 6, 2025

This model handler interface should cover most cases in quantization, fused layer optimization, ...

This PR adds:

  • A generic interface for a ModelConverter class, transforming a model;
  • A model.converters argument where the user can specify a list of converters to apply to the model (e.g. float8);
  • Conversion of Float8Handler to the ModelConverter interface.

Related issue: #790
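For illustration, a minimal sketch of what such a converter interface could look like (the exact names and signatures below are assumptions for readability, not necessarily the merged API):

from typing import List, Protocol, Union

import torch.nn as nn


class ModelConverter(Protocol):
    """Transforms a model, e.g. swapping nn.Linear for Float8Linear."""

    def convert(self, model: nn.Module) -> None:
        """In-place conversion of the model."""
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]) -> None:
        """Optional work to run after each optimizer step (e.g. float8 scale sync)."""
        ...

A container class can then iterate over the configured converters and apply them to the model sequentially.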

@facebook-github-bot facebook-github-bot added the CLA Signed label Feb 6, 2025
@balancap balancap marked this pull request as draft February 6, 2025 15:14
@balancap balancap force-pushed the introduce-generic-model-handler-interface branch 4 times, most recently from 159f03b to 64a5338 on February 6, 2025 16:50
@balancap balancap marked this pull request as ready for review February 6, 2025 20:05
@balancap balancap marked this pull request as draft February 6, 2025 20:05
@tianyu-l tianyu-l (Contributor) left a comment

In general this looks quite reasonable to me. I left several suggestions; please see if they make sense.
cc @fegin @wconstab for any additional feedback.

torchtitan/model_handler.py (outdated review comments, resolved)
Comment on lines 544 to 552
# self.parser.add_argument(
# "--float8.enable_float8_linear",
# action="store_true",
# help="""
# If true, swaps `torch.nn.Linear` with `Float8Linear`.
# This feature requires you to install 'torchao' which can be found
# here: https://github.com/pytorch/ao
# """,
# )
Contributor

I somehow think we can keep this as is in this PR, and by default include "float8" in the "model.handlers" configs.
My idea is to decouple handler registration from turning handlers on/off.
That way it enables your use case, but doesn't require anyone else to adapt their code / mental model.

Contributor Author

Happy to keep this one as is, so as not to break the current float8 workflow.

But model.handlers is not a registration list; it is the list of handlers to apply to the model. Registration of a new model handler is done in code with register_model_handler.

If we don't have model.handlers, we add complexity for both the end user and the codebase: for every new handler, one needs to add a my_handler.enable flag in the config, plus if/else logic in every handler to check whether it is activated.

It feels to me like the wrong design pattern: it is much simpler for the user to specify a list of handlers to apply (and in the float8 case, I would advocate raising an error, not just a warning, if the hardware does not support it).
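As a rough sketch of the registration-plus-selection split being argued for above (using the post-rename converter naming; the helper names here are assumptions, not necessarily the real torchtitan API):

from typing import Callable, Dict, List

_MODEL_CONVERTERS: Dict[str, Callable] = {}


def register_model_converter(name: str, factory: Callable) -> None:
    # Registration happens once, in code, per converter implementation.
    _MODEL_CONVERTERS[name] = factory


def build_model_converters(names: List[str], job_config, parallel_dims) -> list:
    # Selection happens in config: model.converters = ["float8", ...]
    return [_MODEL_CONVERTERS[name](job_config, parallel_dims) for name in names]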

@tianyu-l tianyu-l (Contributor) commented Feb 7, 2025

Yeah, for sure, please don't get me wrong.

As a user of torchtitan, I needed to run a lot of experiments with & without Float8. In the past, to enable/disable Float8, I just had to flip three flags to True/False in the toml file, so I'd prefer not to have to go to two different places (model.converters and float8.enable...) to enable/disable float8 configs.

But I agree keeping it as is isn't a good pattern, so let's go with your original change.

Contributor Author

Thanks 👍

torchtitan/parallelisms/parallelize_llama.py (outdated, resolved)
train.py (outdated, resolved)
train_configs/debug_model.toml (outdated, resolved)
torchtitan/model_handler.py (outdated, resolved)
@fegin fegin (Contributor) left a comment

I have two general comments:

  1. ModelHandler is too vague. Sorry, I am not good at naming either, so I asked GPT and ModelConverter was the recommendation; it is also consistent with the convert() method.

  2. We should not create new optimizer hooks, since torch.optim.Optimizer already supports hooks (see the sketch below). The major benefit is that other components, like TorchFT, may modify the behavior of torch.optim.Optimizer.step(); adding our own hooks is likely to be incompatible with (or make the integration hard for) those libraries. I'll submit a new PR to demonstrate how to register optimizer hooks based on Add Dynamic Model Import and ModelSpec Definition #814.
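For context, a minimal sketch of the built-in hook mechanism being referred to (the NoOpConverter below is just a placeholder for illustration):

import torch
from torch import nn


class NoOpConverter:
    # Placeholder: a real converter would e.g. sync float8 scales here.
    def post_optimizer_hook(self, model: nn.Module) -> None:
        pass


model = nn.Linear(8, 8)
converter = NoOpConverter()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# register_step_post_hook is part of torch.optim.Optimizer, so no extra hook
# plumbing is needed; the hook runs right after every optimizer.step().
optimizer.register_step_post_hook(
    lambda _optim, _args, _kwargs: converter.post_optimizer_hook(model)
)

model(torch.randn(2, 8)).sum().backward()
optimizer.step()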

torchtitan/model_handler.py (outdated, resolved)
@balancap balancap (Contributor Author) commented Feb 7, 2025

Following @fegin's suggestion, I'll do the renaming to ModelConverter. I hope you're fine with it too, @tianyu-l.

@balancap balancap force-pushed the introduce-generic-model-handler-interface branch from 45e10b7 to 182a7ae on February 7, 2025 16:33
This model handler interface should cover most cases in quantization, fused layer optimization, ...
@balancap balancap force-pushed the introduce-generic-model-handler-interface branch from 182a7ae to a3a20e4 on February 7, 2025 17:09
* Using `string_list` for `model.converters` argument;
* Renaming to `ModelConverter`;
* Basic unit test coverage for `ModelConvertersContainer`
@balancap balancap force-pushed the introduce-generic-model-handler-interface branch from a3a20e4 to ed24d73 on February 7, 2025 17:42
@balancap balancap marked this pull request as ready for review February 7, 2025 17:43
@balancap balancap (Contributor Author) commented Feb 7, 2025

@tianyu-l I have pushed a couple of improvements following your review: reverting the float8 flag and adding basic unit test coverage.

@tianyu-l @fegin On the optimizer hook: it feels to me there is no obvious simple solution. The optimizer is created after parallelism is applied to the model, whereas the converter/handler has to be built before that (so it can't receive the optimizer at construction time). In its current form, this PR does not add anything new on this side; it just generalizes the existing float8 logic. I would be in favor of merging this feature as it is, then, once #814 has landed on main, opening a follow-up PR to figure out the most elegant way to integrate ModelConverter with the PyTorch optimizer hooks. It is otherwise complicated to understand how these two PRs interact with each other, and it may just delay things without necessarily reaching a consensus.

@balancap balancap (Contributor Author) commented Feb 7, 2025

In case you want to integrate with optimizer hooks, I believe the simplest way for now is to pass the model_converter to OptimizersContainer.__init__:

class OptimizersContainer(Stateful):
    """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages
    and saving/loading optimizer state_dict at checkpoint.
    """

    def __init__(
        self,
        model_parts: List[nn.Module],
        model_converter: ModelConverter,
        optimizer_kwargs: Dict[str, Any],
        name: str,
    ) -> None:
        ...
        for model in self.model_parts:
            ...
            # Lambda calling model_converter.post_optimizer_hook after each step;
            # model=model binds the current model part for this optimizer.
            optimizer.register_step_post_hook(
                lambda _optim, _args, _kwargs, model=model: model_converter.post_optimizer_hook(model)
            )

Happy to add that as the current solution, which can still be improved in the future (after #814 merges). I am not 100% sure myself what the cleanest way is, which is why I feel a dedicated PR may be beneficial.

@tianyu-l tianyu-l (Contributor) left a comment

> I would be in favor of merging this feature as it is, then, once #814 has landed on main, opening a follow-up PR to figure out the most elegant way to integrate ModelConverter with the PyTorch optimizer hooks. It is otherwise complicated to understand how these two PRs interact with each other, and it may just delay things without necessarily reaching a consensus.

Sounds good to me. We can do the register_step_post_hook migration in another PR.

torchtitan/float8.py (resolved)
train.py (outdated, resolved)
torchtitan/float8.py (outdated, resolved)
Contributor

Thanks for adding tests!

@fegin fegin (Contributor) commented Feb 11, 2025

  1. Can you also change the title to say converter?

  2. https://github.com/pytorch/torchtitan/pull/814/files#r1951428540 demonstrates how to use ModelSpec to register the optimizer hooks.

job config, and apply them to the model sequentially.
"""

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
Contributor

Regarding #814 (comment), I think we can call apply_to_train_specs here to register hooks on the optimizers.

Contributor Author

That could indeed be an option. Happy to discuss it in a new PR; the small downside is my hunch that registries should be immutable, and I have a bad feeling about modifying an existing entry! But maybe it wouldn't be an issue.

Contributor

yeah, the intertwined logic is bad -- can we just specify it in the TrainSpec construction and let the spec handle the registration and usage?
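A purely hypothetical sketch of what "specify it in the TrainSpec construction" could mean (the field and method names below are illustrative assumptions, not the actual #814 definitions):

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class TrainSpec:
    # Hypothetical: the spec carries its converter factories, so registration
    # and optimizer-hook wiring can live inside the spec rather than in train.py.
    name: str
    model_converter_factories: List[Callable] = field(default_factory=list)

    def build_converters(self, job_config, parallel_dims) -> list:
        return [f(job_config, parallel_dims) for f in self.model_converter_factories]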

@balancap balancap changed the title Introducing a generic ModelHandler interface. Introducing a generic ModelConverter interface. Feb 13, 2025
@balancap balancap force-pushed the introduce-generic-model-handler-interface branch from db04aa4 to 211d8f9 on February 13, 2025 16:51
@balancap balancap force-pushed the introduce-generic-model-handler-interface branch from 211d8f9 to a210898 on February 13, 2025 16:52
@balancap (Contributor Author) commented

@tianyu-l I merged main to remove conflicts, reset to the model.converters = "float8" logic only as discussed, and fixed estimation.py to use ModelConverter as well.

@balancap balancap (Contributor Author) commented Feb 13, 2025

Please tell me if you have any additional improvements in mind.

Once merged, I'll look at how to integrate the #814 ModelSpec with apply_to_train_specs (there are a couple of things around pipelining model parts, and so on, where I prefer to make sure I have it right).

@tianyu-l tianyu-l (Contributor) left a comment

lgtm! please address final comments :)

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
    self.enabled = False

    float8_config = job_config.float8
    if not float8_config.enable_float8_linear:
Contributor

please remove this in config_manager.py too

Contributor Author

My mistake, I forgot to add it to the commit! Now fixed.


@@ -54,4 +54,3 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
mode = 'full'

[float8]
@tianyu-l tianyu-l (Contributor) commented Feb 13, 2025

When float8 is disabled in a config (debugmodel, 8b, 70b), let's still add the two float8 options (set to False), and include a commented-out line converters = "float8". The point is to make it easier to continue using float8, especially for people who are not aware of this PR.

Contributor Author

Just added to debug, 8B and 70B models

@tianyu-l tianyu-l added this to the torchtitan v1.0.0 release milestone Feb 14, 2025
@balancap (Contributor Author) commented

@tianyu-l Should be good now, hopefully :) I merged the latest main and fixed the small import conflicts.

Thanks for your feedback on the PR!

@tianyu-l tianyu-l (Contributor) left a comment

Looks awesome, thank you for contributing!
Let's work together on the next PR to integrate ModelConverter with TrainSpec.

@tianyu-l tianyu-l merged commit 57387af into pytorch:main Feb 15, 2025
6 checks passed