
[PyTorch] Debug checkpointing with operation-based API #1063

Merged: 13 commits into NVIDIA:main on Nov 5, 2024

Conversation

@timmoon10 (Collaborator) commented on Jul 31, 2024

Description

This PR debugs checkpointing with the operation-based API (see #707), in particular by adding logic to include FP8 scaling factors in the checkpoint. The checkpointing logic is very similar to that of the module-based API:

def get_extra_state(self) -> torch.Tensor:

def set_extra_state(self, state: torch.Tensor) -> None:

It is admittedly rather unintuitive, but I've added comments to justify the weird behavior.
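For illustration only, here is a minimal sketch of the pattern, assuming the op keeps its FP8 metadata (scales and amax histories) in a hypothetical _fp8_metas dict of tensors. The attribute name and the byte-packing scheme are assumptions, not the actual Transformer Engine implementation; the CPU placement mirrors the "Store checkpoint FP8 state on CPU" change in this PR.

import io

import torch


class FusedOpCheckpointSketch(torch.nn.Module):
    """Sketch only: serialize FP8 scaling state into the module's extra state."""

    def get_extra_state(self) -> torch.Tensor:
        # Pack FP8 scales and amax histories into a byte tensor so they are
        # saved alongside the regular state dict. Keeping everything on CPU
        # makes the checkpoint device-agnostic.
        state = {
            name: {"scale": meta["scale"].cpu(), "amax": meta["amax"].cpu()}
            for name, meta in self._fp8_metas.items()  # hypothetical attribute
        }
        buffer = io.BytesIO()
        torch.save(state, buffer)
        return torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8).clone()

    def set_extra_state(self, state: torch.Tensor) -> None:
        # Unpack the byte tensor and copy values into the live FP8 metadata,
        # so tensors already referenced elsewhere remain valid.
        if state is None or state.numel() == 0:
            return
        buffer = io.BytesIO(state.detach().cpu().numpy().tobytes())
        loaded = torch.load(buffer)
        for name, meta in loaded.items():
            self._fp8_metas[name]["scale"].copy_(meta["scale"])
            self._fp8_metas[name]["amax"].copy_(meta["amax"])

Serializing into a single tensor is what makes the extra state ride along with load_state_dict/state_dict without any special handling by the caller, which is also why the behavior looks unintuitive at first glance.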

I've also fixed an orthogonal bug where the linear op was including two copies of its params in its checkpoint.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Add logic for FP8 scales when checkpointing with operation-based API
  • Fix bug where linear op checkpoint saves params twice

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 added the bug Something isn't working label Jul 31, 2024
@timmoon10 timmoon10 requested a review from ksivaman July 31, 2024 01:56
@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Sep 27, 2024
Signed-off-by: Tim Moon <[email protected]>
timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Oct 2, 2024
Signed-off-by: Tim Moon <[email protected]>
@timmoon10 (Collaborator, Author) commented on the Linear op changes:
These changes in the Linear op are orthogonal to the rest of this PR. PyTorch modules save checkpoints recursively, so checkpointing a fused op (e.g. Linear) will also checkpoint its constituent basic ops (e.g. BasicLinear, Bias). By registering the weight and bias params with the Linear op, the checkpoints were saving two copies of the params. Converting weight and bias into Python properties avoids this behavior while retaining the existing API.
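As a hedged sketch of that property-based approach (the class and attribute names here are illustrative, not the actual Transformer Engine ops):

import torch


class BasicLinearSketch(torch.nn.Module):
    """Basic op that actually owns the weight parameter."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))


class BiasSketch(torch.nn.Module):
    """Basic op that owns the bias parameter."""

    def __init__(self, features: int) -> None:
        super().__init__()
        self.bias = torch.nn.Parameter(torch.empty(features))


class LinearSketch(torch.nn.Module):
    """Fused op that exposes params via properties instead of registering them.

    The basic ops are registered submodules, so their parameters already
    appear once in the state dict. Registering weight/bias directly on this
    module as well would save each parameter a second time.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.basic_linear = BasicLinearSketch(in_features, out_features)
        self.bias_op = BiasSketch(out_features)

    @property
    def weight(self) -> torch.nn.Parameter:
        # Delegate to the basic op; no duplicate registration on the fused op.
        return self.basic_linear.weight

    @property
    def bias(self) -> torch.nn.Parameter:
        return self.bias_op.bias

With this layout, LinearSketch(4, 8).state_dict() contains each parameter exactly once, under the submodule keys (e.g. basic_linear.weight), while callers can still access op.weight and op.bias through the same attribute names as before.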

@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

@timmoon10 (Collaborator, Author) commented:

Merging with approval from @ptrendx and @ksivaman.

@timmoon10 timmoon10 merged commit f20d3dd into NVIDIA:main Nov 5, 2024
13 of 14 checks passed
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request Nov 5, 2024
* Debug checkpointing with operation-based API

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Store checkpoint FP8 state on CPU

Signed-off-by: Tim Moon <[email protected]>

* Fix bug where linear op was saving params multiple times

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>