
[PyTorch] Minor optimizations to reduce CPU overheads in modules #1191

Merged: 20 commits into NVIDIA:main on Oct 4, 2024

Conversation


@timmoon10 timmoon10 commented Sep 18, 2024

Description

We have observed that TE modules experience non-trivial CPU overhead, which often becomes a performance bottleneck in the forward pass of small models. For example, we measured the CPU runtime for Megatron-core modules with BF16 compute and TP=1; the timings below refer to this configuration.

Unfortunately, this overhead is distributed throughout the forward pass. Many basic PyTorch operations, e.g. getting attributes from a torch.Tensor, each incur on the order of 1 us of overhead, so even the basic checks needed to support all of our advanced features eventually add up to something non-trivial.

This PR makes a few minor optimizations:

  • Avoid importing from te.pytorch.cpu_offload in every forward pass
  • Memoize NCCL process group properties
  • Avoid custom logic in torch.nn.Module.__setattr__ when possible
  • Avoid custom logic for accessing params in torch.nn.Module when possible
  • Avoid accessing tensor attrs more than necessary

I see a 1.22x speedup, with 115 us of CPU time per forward pass.
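
To illustrate the memoization item above: NCCL process group properties such as world size and rank are fixed once the group is created, so they can be computed once and cached instead of being queried from torch.distributed on every forward pass. A minimal sketch of the idea (not the actual TE implementation; the helper name get_cached_world_size is hypothetical):

    import functools
    from typing import Optional

    import torch.distributed as dist


    @functools.lru_cache(maxsize=None)
    def get_cached_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
        """World size of a process group, cached per group object.

        The size of a process group never changes after construction, so the
        torch.distributed query only needs to run once per group.
        """
        if not dist.is_initialized():
            return 1
        return dist.get_world_size(group=group)

Since functools.lru_cache keys on its arguments, each distinct process group object gets its own cached entry.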

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Avoid importing from te.pytorch.cpu_offload in every forward pass
  • Memoize NCCL process group properties
  • Avoid custom logic in torch.nn.Module.__setattr__ when possible
  • Avoid custom logic for accessing params in torch.nn.Module when possible
  • Avoid accessing tensor attrs more than necessary

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Avoid enable_grad context when possible in cast function. Cache distributed group properties.

Signed-off-by: Tim Moon <[email protected]>
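
As a sketch of the enable_grad change mentioned in the commit above (the actual TE cast function differs; this only illustrates the idea): enter torch.enable_grad() only when the input actually requires grad, and skip the context manager entirely otherwise.

    import contextlib

    import torch


    def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Cast a tensor to dtype, avoiding the enable_grad context when possible."""
        if tensor.dtype == dtype:
            return tensor  # nothing to cast, no context manager needed
        # Only pay for torch.enable_grad() when autograd needs to record the cast.
        ctx = torch.enable_grad() if tensor.requires_grad else contextlib.nullcontext()
        with ctx:
            return tensor.to(dtype=dtype)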
Avoid torch.nn.Module impl of __setattr__.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 added the enhancement (New feature or request) label on Sep 18, 2024
@yaox12 (Collaborator) left a comment:


Can you propagate the CPU offloading importing fix to GroupedLinear as well?

from ..cpu_offload import CPUOffloadEnabled
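
For reference, the importing fix being requested here amounts to binding CPUOffloadEnabled once at module load time rather than executing the import statement inside every forward call. A stand-in microbenchmark (not TE code; os.path merely stands in for the cpu_offload import so the snippet is self-contained) shows the difference:

    import timeit


    def forward_with_local_import():
        # Old pattern: the import statement runs on every call, paying for the
        # sys.modules lookup and name binding each time.
        from os import path as cpu_offload_flag
        return cpu_offload_flag


    from os import path as cpu_offload_flag  # bound once at module load


    def forward_with_module_level_import():
        # New pattern: the name is read from module globals; no import machinery.
        return cpu_offload_flag


    print("per-call import:    ", timeit.timeit(forward_with_local_import, number=100_000))
    print("module-level import:", timeit.timeit(forward_with_module_level_import, number=100_000))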

self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self._fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
Member


I wonder if we couldn't instead just do something like

te_params = self.get_te_params()  # calls _fast_getattr internally, te_params is a normal object
te_params.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
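
A minimal sketch of what such a "normal object" could look like (hypothetical _TEParams and SketchModule names, assuming get_te_params simply hands back a plain container held by the module): attribute writes then land on an ordinary Python object rather than going through torch.nn.Module.__setattr__.

    import torch


    class _TEParams:
        """Plain container for frequently updated per-forward state (hypothetical)."""

        __slots__ = ("fp8", "fp8_calibration", "fp8_parameters")

        def __init__(self) -> None:
            self.fp8 = False
            self.fp8_calibration = False
            self.fp8_parameters = False


    class SketchModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self._te_params = _TEParams()  # stored once; later writes skip Module bookkeeping

        def get_te_params(self) -> _TEParams:
            return self._te_params

        def forward(self, inp: torch.Tensor) -> torch.Tensor:
            te_params = self.get_te_params()
            te_params.fp8 = False  # plain attribute store on a normal object
            return inp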

Member


Otherwise it will be hard to enforce that everybody uses only _fast_get/setattr, I think.

Collaborator Author

@timmoon10 commented Sep 20, 2024


Even better, we could store these attrs in fp8_meta or some other dict. I feel like the behavior of torch.nn.Module is a hint that we shouldn't change its attrs frequently.

Collaborator Author


I tried wrapping these attrs in a property that internally calls _fast_setattr, but that added ~10 us overhead (probably from the extra indirection when getting the attrs). I think it's a good idea to refactor these frequently changed attrs so they are not held directly by the module, but I think that would be beyond the scope of this PR.

Collaborator Author


Moving this logic into the __setattr__ function makes things a little cleaner. It adds ~2 us overhead, but it's still a win of ~6 us compared to the baseline.
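
A minimal sketch of that approach (attribute names taken from the snippets above; the PR's actual fast path may differ): keep a set of known, frequently updated, plain-Python attribute names and write those straight into the instance dict, falling back to torch.nn.Module.__setattr__ for everything else.

    import torch


    class FastSetattrModule(torch.nn.Module):
        # Attributes that never hold Parameters, buffers, or submodules, so they
        # can safely skip torch.nn.Module's bookkeeping in __setattr__.
        _fast_setattr_names = {"fp8", "fp8_calibration", "fp8_parameters"}

        def __setattr__(self, name: str, value) -> None:
            if name in self._fast_setattr_names:
                self.__dict__[name] = value  # direct instance-dict store
            else:
                super().__setattr__(name, value)  # full Module logic for params, buffers, submodules

The fallback keeps torch.nn.Module's normal handling for every other attribute name.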

@timmoon10 (Collaborator Author): /te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>

@timmoon10 (Collaborator Author): /te-ci pytorch

@timmoon10 (Collaborator Author): /te-ci pytorch

@timmoon10 (Collaborator Author): /te-ci pytorch

@timmoon10 (Collaborator Author): /te-ci pytorch

Comment on lines +221 to +224:

    if tensor is None:
        return None
    if tensor.dtype == dtype:
        return tensor

Member

Suggested change: collapse the two early returns into one (returning tensor when it is None is equivalent to returning None):

    if tensor is None or tensor.dtype == dtype:
        return tensor

@timmoon10 merged commit 9d976bc into NVIDIA:main on Oct 4, 2024
14 of 15 checks passed
Labels: enhancement (New feature or request)

4 participants