[PyTorch] Minor optimizations to reduce CPU overheads in modules #1191
Conversation
Avoid enable_grad context when possible in cast function. Cache distributed group properties. Signed-off-by: Tim Moon <[email protected]>
Avoid torch.nn.Module impl of __setattr__. Signed-off-by: Tim Moon <[email protected]>
Can you propagate the CPU offloading import fix to GroupedLinear as well?

```python
from ..cpu_offload import CPUOffloadEnabled
```
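For context, the import fix in question presumably hoists the import out of the per-call path. A minimal generic sketch of that pattern, with `math` standing in for the actual `te.pytorch.cpu_offload` import:

```python
# Slow: the import statement re-runs on every call (a sys.modules
# lookup plus name binding, on the order of 1 us per forward pass).
def forward_slow(x):
    import math  # stand-in for `from ..cpu_offload import CPUOffloadEnabled`
    return math.sqrt(x)

# Fast: bind the name once, when the enclosing module is first loaded.
import math

def forward_fast(x):
    return math.sqrt(x)

assert forward_slow(4.0) == forward_fast(4.0) == 2.0
```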
```diff
- self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
- self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
- self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+ self._fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
```
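For readers outside the PR: a plausible reading of `_fast_setattr` is a thin wrapper over `object.__setattr__`. A sketch under that assumption (not the actual TE implementation):

```python
import torch

class FastAttrModule(torch.nn.Module):
    def _fast_setattr(self, name: str, value) -> None:
        # Write straight into __dict__, skipping nn.Module.__setattr__'s
        # checks against _parameters, _buffers, and _modules. Only safe
        # for plain Python attrs, never for Parameters or submodules.
        object.__setattr__(self, name, value)

m = FastAttrModule()
m._fast_setattr("fp8_parameters", False)
assert m.fp8_parameters is False
```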
I wonder if we couldn't instead just do something like

```python
te_params = self.get_te_params()  # calls _fast_getattr internally; te_params is a normal object
te_params.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
```
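A minimal sketch of that idea; `get_te_params` and the `_TEParams` container are hypothetical names for illustration, not code from this PR:

```python
import torch

class _TEParams:
    """Plain container; reads and writes bypass nn.Module entirely."""
    def __init__(self):
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_parameters = False

class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stored once via object.__setattr__; later accesses are plain
        # instance __dict__ lookups, so nn.Module hooks are never hit.
        object.__setattr__(self, "_te_params", _TEParams())

    def get_te_params(self) -> _TEParams:
        return self._te_params

m = Example()
m.get_te_params().fp8_parameters = True
```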
Otherwise it will be hard to enforce that everybody uses only `_fast_getattr`/`_fast_setattr`, I think.
Even better, we could store these attrs in `fp8_meta` or some other dict. I feel like the behavior of `torch.nn.Module` is a hint that we shouldn't change its attrs frequently.
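A sketch of that dict-based variant, assuming `fp8_meta` (or a similar dict) is an ordinary attribute assigned once at init:

```python
import torch

class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fp8_meta = {}  # one slow nn.Module.__setattr__ here, at init

    def forward(self, x):
        # Per-iteration updates are plain dict stores; no __setattr__ hook.
        self.fp8_meta["fp8_parameters"] = False
        return x

m = Example()
m(torch.zeros(1))
```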
I tried wrapping these attrs in a property that internally calls `_fast_setattr`, but that added ~10 us overhead (probably from the extra indirection when getting the attrs). I think it's a good idea to refactor these frequently changed attrs so they are not held directly by the module, but I think that would be beyond the scope of this PR.
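Roughly what that experiment may have looked like (a sketch, assuming the property wraps a private attr; the actual attempt isn't shown in the thread):

```python
import torch

class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()
        object.__setattr__(self, "_fp8", False)

    @property
    def fp8(self):
        # Every read now takes a descriptor-protocol detour through this
        # Python-level getter, the likely source of the extra ~10 us.
        return self._fp8

    @fp8.setter
    def fp8(self, value):
        object.__setattr__(self, "_fp8", value)

m = Example()
m.fp8 = True
assert m.fp8 is True
```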
Moving this logic into the `__setattr__` function makes things a little cleaner. It adds ~2 us overhead, but it's still a win of ~6 us compared to the baseline.
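A sketch of that `__setattr__` fast-path idea; the set of hot attribute names here is an assumption, and the real TE override is likely more involved:

```python
import torch

_FAST_ATTRS = frozenset({"fp8", "fp8_calibration", "fp8_parameters"})

class Example(torch.nn.Module):
    def __setattr__(self, name, value):
        if name in _FAST_ATTRS:
            # Known hot flags go straight into __dict__, skipping the
            # _parameters/_buffers/_modules bookkeeping entirely.
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)

m = Example()
m.fp8 = True                                   # fast path
m.weight = torch.nn.Parameter(torch.ones(1))   # normal nn.Module path
assert m.fp8 and "weight" in dict(m.named_parameters())
```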
Force-pushed from 4d9379d to 92c443e.
```python
if tensor is None:
    return None
if tensor.dtype == dtype:
    return tensor
```
Suggested change:

```python
if tensor is None or tensor.dtype == dtype:
    return tensor
```
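Folding the suggestion into a self-contained helper; the name `cast_if_needed` and the `.to()` fallback are assumptions about the surrounding function, which isn't fully shown here:

```python
from typing import Optional
import torch

def cast_if_needed(tensor: Optional[torch.Tensor],
                   dtype: torch.dtype) -> Optional[torch.Tensor]:
    # A single early-out covers both cases from the suggestion: None
    # passes through unchanged, and a matching dtype avoids a redundant
    # .to() call (and the autograd bookkeeping that comes with it).
    if tensor is None or tensor.dtype == dtype:
        return tensor
    return tensor.to(dtype=dtype)

x = torch.zeros(2, dtype=torch.float32)
assert cast_if_needed(x, torch.float32) is x
assert cast_if_needed(None, torch.bfloat16) is None
assert cast_if_needed(x, torch.bfloat16).dtype == torch.bfloat16
```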
Description
We have observed that TE modules experience non-trivial CPU overhead, which often becomes a performance bottleneck in the forward pass of small models. For example, measuring the CPU runtime for Megatron-core modules with BF16 compute and TP=1:

- `ColumnParallelLinear`: 74 us per forward pass
- `TEColumnParallelLinear`: 140 us per forward pass

Unfortunately, this overhead is distributed throughout the forward pass. Many basic PyTorch operations, e.g. getting attributes from `torch.Tensor`, involve O(1 us) overhead, so even basic checks to handle all of our advanced features eventually add up to something non-trivial. This PR makes a few minor optimizations:

- Avoid importing `te.pytorch.cpu_offload` in every forward pass
- Avoid `torch.nn.Module.__setattr__` when possible
- Avoid getting attributes from `torch.nn.Module` when possible

With these changes I see a 1.22x speedup, with 115 us per forward pass.
Changes

- Avoid importing `te.pytorch.cpu_offload` in every forward pass
- Avoid `torch.nn.Module.__setattr__` when possible
- Avoid getting attributes from `torch.nn.Module` when possible