
Enable torch.compile with ZeRO (Experimental) #4878

Merged · 49 commits · Feb 6, 2024
Conversation

@tohtana (Contributor) commented Dec 28, 2023

This PR enables `torch.compile` with ZeRO stages 1/2/3. You need to add a `compile` section to your DeepSpeed config. The fields in the section are passed to `torch.compile`.

  "compile": {
    "disable": false,
    "backend": "inductor"
  }

To enable a custom backend, you can pass the fully qualified name of the backend function. For example, if you have a backend function `my_backend` in `my_backend.py` in the current directory, you can enable it with `"backend": "my_backend.my_backend"`. You can find an example in a unit test.
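For illustration, such a backend might look like the following minimal sketch (the `my_backend` name and file are the hypothetical ones from the example above; a `torch.compile` backend receives the captured FX graph and example inputs and returns a callable):

```python
# my_backend.py - a minimal sketch of a custom torch.compile backend
import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Inspect or transform the captured graph here; returning
    # gm.forward simply runs the captured graph as-is.
    print(gm.graph)
    return gm.forward
```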

Currently, we have validated the results with Megatron-DeepSpeed. See the example for details.

NOTICE: This PR is a draft. We will need to validate the coverage and accuracy with many more examples.

@tjruwase (Contributor):

@stas00, FYI

@stas00 (Collaborator) commented Dec 31, 2023

Amazing work, @tohtana! I'm looking forward to trying it out.

Here is some quick feedback:

Could we please flip `disable` to `enabled` so that the logic is consistent with other config values?

  • no double-negation logic
  • a consistent `enabled` (not `enable`), as all other config sections use that name

@stas00 (Collaborator) commented Jan 3, 2024

Tried it out, and the compiled engine doesn't seem to forward some (all?) custom methods to the unwrapped model; e.g., it's failing:

```
[28:7]:  File "/data/env/lib/repos/retro-llama/tr043-dawn-llama-3/DeepSpeed/deepspeed/runtime/engine.py", line 468, in __getattr__
[28:7]:    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[28:7]:AttributeError: 'DeepSpeedEngine' object has no attribute 'get_model_tflops_per_batch_per_gpu'
```

`get_model_tflops_per_batch_per_gpu` is a normal attribute of the model, and the same setup works if I set `"disable": true` in the `compile` section.
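For context, wrapper engines typically forward unknown attribute lookups to the wrapped module; a minimal sketch of that pattern (an illustration, not DeepSpeed's actual implementation) follows:

```python
class WrapperSketch:
    """Hypothetical wrapper illustrating attribute forwarding to a wrapped module."""

    def __init__(self, module):
        self.module = module

    def __getattr__(self, name):
        # Only invoked when normal lookup on the wrapper fails, so custom
        # methods defined on the wrapped module stay reachable.
        module = self.__dict__.get("module")
        if module is not None and hasattr(module, name):
            return getattr(module, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
```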

@stas00 (Collaborator) commented Jan 3, 2024

I hacked around it via model.module.method..., and then I get many warnings and errors with the inductor backend, and then it fails. I have attached the log.

This is just training Llama-2 on a single node using Accelerate with torch-nightly from last night.

The Llama model is the same as HF Transformers with some additional methods. https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

ds-compile.txt

@stas00 (Collaborator) commented Jan 3, 2024

If I disable the DS profiler, then it runs despite the compilation errors/warnings - same log as in the previous comment, other than the last traceback where it crashes.

@stas00 (Collaborator) commented Jan 3, 2024

I'm also observing a very strange performance-cycling behavior:

the TFLOPS go like this per iteration: 196, 196, 192, 196, 196, 192, 196, 196, 192 - two fast, one slower - very exactly.

Without compile it was a consistent 194.

So this tells me something gets recompiled every three iterations.
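One way to confirm such a suspicion is to ask Dynamo to log every recompilation; a sketch assuming a recent PyTorch build:

```python
import torch

# Log the failed guard each time a compiled function is recompiled.
torch._logging.set_logs(recompiles=True)

@torch.compile
def step(x):
    return x * 2

for i in range(6):
    # Each new input rank triggers a recompilation, which now gets logged.
    step(torch.randn([4] * (i % 3 + 1)))
```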

@tohtana (Contributor, Author) commented Jan 5, 2024

@stas00 Thank you for your feedback! This PR is still experimental. Let me address the issues one by one.

The `disable` configuration is what I specifically sought feedback on. Currently, all configuration items under `compile` are passed to `torch.compile`, which accepts `disable`, not `enable`. This design was chosen for its simplicity, given the uncertainty of future changes in `torch.compile`. But we can define `enable` and flip it before passing it to `torch.compile`.

Do you have any further comments on this? If not, I will switch it to `enable` as you suggested. Actually, it is also my personal preference.
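A minimal sketch of that flip, assuming the `compile` section arrives as a plain dict:

```python
import torch

def compile_with_config(module, compile_config: dict):
    # Translate a user-facing `enabled` flag into torch.compile's
    # `disable` kwarg; every other field passes through untouched.
    kwargs = dict(compile_config)
    enabled = kwargs.pop("enabled", True)
    return torch.compile(module, disable=not enabled, **kwargs)
```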

@stas00 (Collaborator) commented Jan 5, 2024

That's totally understandable, Masahiro. Tunji made that clear when he tagged me. If it's too early to provide feedback, please ping me when you're ready for it.

disable vs. enabled:

Ideally, DeepSpeed users will never need to know anything about `torch.compile` specifics - many frameworks integrate this feature without having the user interact with it directly. So its API doesn't have to impact DeepSpeed's API.

Since most (all?) DeepSpeed config sections use `enabled`, I'd say it'd be most consistent to continue with that convention.

But this is the opinion of a single person, so please seek out the opinions of others.

@tohtana (Contributor, Author) commented Jan 5, 2024

@stas00 Thank you for your quick reply. It is probably difficult to reach a clear conclusion for now, so I will simply switch it to `enable`; otherwise, many other users would have the same question as yours.
For a clearer answer, we need more experience to learn what options DeepSpeed users need in their applications. Even the options of `torch.compile` may change.

@stas00 (Collaborator) commented Jan 5, 2024

  • please note that it's `enabled` that DS uses everywhere else, not `enable`

  • w.r.t. other options, I'd say use the minimal number of options:

  1. Let's perhaps start with only `backend` and then pick the most sensible default for that option.

  2. Then provide the user an API where they can preset their own `**torch_compile_kwargs` to be passed to `torch.compile`. That way you're future-proofing the DeepSpeed API while allowing torch to do as they please: DeepSpeed will sync with future changes to keep the defaults sensible, and power users will always be able to override them.

deepspeed_engine.set_torch_compile_kwargs(**kwargs)

2a. I don't know if the current config schema allows for a non-predefined dict, so perhaps this could be possible:

  "compile": {
    "enabled": true,
    "backend": "inductor",
    "kwargs": {"key1"=value, "key2"=value}
  }

this should definitely work:

  "compile": {
    "enabled": true,
    "backend": "inductor",
    "kwargs": "key1=value;key2=value"
  }

but I don't know if all `torch.compile` kwargs can be stringified,

so providing a programmatic API for power users would be the most fool-proof:
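A sketch of what that hypothetical power-user API could look like (`set_torch_compile_kwargs` is the name proposed above, not an existing DeepSpeed method):

```python
import torch

class EngineSketch:
    """Hypothetical engine fragment for the proposed kwargs-override API."""

    def __init__(self):
        self._compile_kwargs = {"backend": "inductor"}  # sensible defaults

    def set_torch_compile_kwargs(self, **kwargs):
        # Power users override or extend the defaults before compilation.
        self._compile_kwargs.update(kwargs)

    def compile(self, module):
        return torch.compile(module, **self._compile_kwargs)
```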

@tohtana enabled auto-merge Feb 6, 2024 04:34
@tohtana added this pull request to the merge queue Feb 6, 2024
Merged via the queue into master with commit c3cfe96, Feb 6, 2024
14 checks passed
mrwyattii added a commit that referenced this pull request Feb 12, 2024
Tests running an older version of torch will fail the compile tests added in #4878.
mauryaavinash95 pushed two commits to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
On the following lines of the diff:

```python
    return backend
elif isinstance(backend, str):
    if backend in torch._dynamo.list_backends():
```
A contributor commented:

@tohtana The default `list_backends` call will exclude debug and experimental backends, e.g. `eager`. I think it's better to use `list_backends(exclude_tags=())` here.
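For reference, a quick way to see the difference (a sketch against `torch._dynamo`'s public signature):

```python
import torch._dynamo

# The default call filters out backends tagged "debug" or "experimental".
print(torch._dynamo.list_backends())
# An empty exclude_tags tuple lists everything, including e.g. "eager".
print(torch._dynamo.list_backends(exclude_tags=()))
```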

@tohtana (Contributor, Author) replied:

Thank you for the comment. I opened #5191.

github-merge-queue bot pushed a commit that referenced this pull request Feb 26, 2024
As mentioned at
#4878 (comment),
we are currently unable to enable debug or experimental backends for the
compiler. This PR enables users to utilize these backends.
ShellyNR pushed a commit to ShellyNR/DeepSpeed that referenced this pull request Mar 11, 2024
rraminen pushed three commits to ROCm/DeepSpeed that referenced this pull request May 9, 2024