[PyTorch] Debug checkpointing with operation-based API #1063
Conversation
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
These changes in the Linear op are orthogonal to the rest of this PR. PyTorch modules save checkpoints recursively, so checkpointing a fused op (e.g. Linear) will also checkpoint its constituent basic ops (e.g. BasicLinear, Bias). By registering the weight and bias params with the Linear op, the checkpoints were saving two copies of the params. Converting weight and bias into Python properties avoids this behavior while retaining the existing API.
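For illustration, here is a minimal sketch of that pattern; the class and attribute names are simplified stand-ins, not the actual Transformer Engine ops. Because only the basic ops register the parameters, the state dict contains a single copy of each, while the property accessors preserve the existing weight/bias API:

```python
import torch


class BasicLinearOp(torch.nn.Module):
    """Basic op that owns the weight parameter."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))


class BiasOp(torch.nn.Module):
    """Basic op that owns the bias parameter."""

    def __init__(self, out_features):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.empty(out_features))


class FusedLinearOp(torch.nn.Module):
    """Fused op composed of basic ops.

    The weight and bias are exposed as read-only properties that delegate
    to the basic ops, so each parameter is registered (and checkpointed)
    only once, by the basic op that owns it.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.basic_linear = BasicLinearOp(in_features, out_features)
        self.bias_op = BiasOp(out_features)

    @property
    def weight(self):
        return self.basic_linear.weight

    @property
    def bias(self):
        return self.bias_op.bias


op = FusedLinearOp(4, 8)
# Only the basic ops' parameters appear in the state dict:
# ['basic_linear.weight', 'bias_op.bias']
print(list(op.state_dict().keys()))
```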
/te-ci pytorch
* Debug checkpointing with operation-based API
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Store checkpoint FP8 state on CPU
  Signed-off-by: Tim Moon <[email protected]>
* Fix bug where linear op was saving params multiple times
  Signed-off-by: Tim Moon <[email protected]>
* Fix linter warnings
  Signed-off-by: Tim Moon <[email protected]>
---------
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
This PR debugs checkpointing with the operation-based API (see #707), in particular by adding logic to include FP8 scaling factors in the checkpoint. The checkpointing logic is very similar to the module-based API:
TransformerEngine/transformer_engine/pytorch/module/base.py, line 555 (commit 5b6546c)
TransformerEngine/transformer_engine/pytorch/module/base.py, line 587 (commit 5b6546c)
This logic is admittedly rather unintuitive, but I've added comments to justify the weird behavior.
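As a rough illustration of the general mechanism, the sketch below uses PyTorch's get_extra_state / set_extra_state hooks to carry FP8 scaling metadata in the state dict. The field names (fp8_meta, scale, amax_history) are hypothetical stand-ins, and the real Transformer Engine logic in base.py and in the operation-based API differs in its details; this only shows the shape of the approach, including keeping the stored FP8 state on CPU.

```python
import torch


class FP8ModuleSketch(torch.nn.Module):
    """Toy module carrying FP8 scaling metadata (hypothetical field names)."""

    def __init__(self):
        super().__init__()
        # Hypothetical FP8 metadata: a per-tensor scale and an amax history.
        self.fp8_meta = {
            "scale": torch.ones(1),
            "amax_history": torch.zeros(16),
        }

    def get_extra_state(self):
        # Called by state_dict(); the returned object is stored under the
        # module's "_extra_state" key. Move tensors to CPU so the checkpoint
        # is not tied to a particular device.
        return {k: v.detach().cpu() for k, v in self.fp8_meta.items()}

    def set_extra_state(self, state):
        # Called by load_state_dict(); restore the FP8 metadata.
        if state is None:
            return
        self.fp8_meta = {k: v.clone() for k, v in state.items()}


# Round trip: the FP8 metadata travels with the regular checkpoint.
mod = FP8ModuleSketch()
sd = mod.state_dict()          # includes the "_extra_state" entry
restored = FP8ModuleSketch()
restored.load_state_dict(sd)   # set_extra_state repopulates fp8_meta
```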
I've also fixed an orthogonal bug where the linear op was including two copies of its params in its checkpoint.
Type of change
Changes
Checklist: