
[PyTorch] Proxy class for low-precision tensor #1127

Merged: 25 commits, Sep 11, 2024

Conversation

@timmoon10 (Collaborator) commented Aug 21, 2024

Description

The Float8Tensor class effectively implements a proxy design pattern: it internally encodes data in FP8 with FP32 scaling factors, but externally presents the interface of a plain PyTorch tensor in FP32/FP16/BF16. This PR generalizes this behavior by moving the proxy logic into an abstract QuantizedTensor class (initially named ProxyTensor and renamed during review). I envision implementing other quantization schemes (e.g. block scaling) by subclassing QuantizedTensor.
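As a rough illustration of the pattern described above (a toy sketch only, not Transformer Engine's actual API — the class and method names below are hypothetical), a quantized tensor stores low-precision data plus an FP32 scaling factor internally and dequantizes on demand:

```python
import torch
from abc import ABC, abstractmethod


class QuantizedTensorSketch(ABC):
    """Toy stand-in for the abstract proxy class.

    The real QuantizedTensor subclasses torch.Tensor so it can pass
    through PyTorch ops transparently; this sketch only shows the
    quantize/dequantize contract.
    """

    @abstractmethod
    def dequantize(self) -> torch.Tensor:
        """Return a plain high-precision (FP32/FP16/BF16) tensor."""


class ToyInt8Tensor(QuantizedTensorSketch):
    """Encodes data as int8 plus an FP32 scale, analogous to FP8 + scale."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data    # int8 payload
        self._scale = scale  # FP32 scaling factor

    @classmethod
    def quantize(cls, t: torch.Tensor) -> "ToyInt8Tensor":
        # Scale so the largest magnitude maps to the int8 range.
        scale = t.abs().max().clamp(min=1e-12) / 127.0
        q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
        return cls(q, scale)

    def dequantize(self) -> torch.Tensor:
        return self._data.to(torch.float32) * self._scale
```

Other quantization schemes (e.g. block scaling) would then slot in as further subclasses overriding the quantize/dequantize pair, which is the extensibility the PR description is after.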

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Add base class for tensor proxies

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 (Collaborator, Author)
/te-ci pytorch

@timmoon10 timmoon10 marked this pull request as ready for review August 30, 2024 02:00
@timmoon10 timmoon10 changed the title [WIP] [PyTorch] Proxy class for low-precision tensor [PyTorch] Proxy class for low-precision tensor Aug 30, 2024
@timmoon10 (Collaborator, Author)
/te-ci pytorch

@timmoon10 (Collaborator, Author)
/te-ci pytorch

@timmoon10 (Collaborator, Author)
/te-ci pytorch

@ksivaman (Member)
/te-ci pytorch

@ksivaman (Member) left a comment:

LGTM, pipeline is clean on rerun: 18307147

@ksivaman ksivaman merged commit 2d57db8 into NVIDIA:main Sep 11, 2024
25 of 26 checks passed
yaox12 pushed a commit to yaox12/TransformerEngine that referenced this pull request Sep 12, 2024
* Add base class for tensor proxies

Signed-off-by: Tim Moon <[email protected]>

* Move tensor detaching logic to tensor proxy base class

Signed-off-by: Tim Moon <[email protected]>

* Use Python wrappers to PyTorch extensions

Signed-off-by: Tim Moon <[email protected]>

* Include transpose caching logic in proxy encode function

Signed-off-by: Tim Moon <[email protected]>

* Debug dimension mismatch with amax history

Signed-off-by: Tim Moon <[email protected]>

* Move dequantize logic to proxy_decode func

Signed-off-by: Tim Moon <[email protected]>

* Rename to "QuantizedTensor"

Signed-off-by: Tim Moon <[email protected]>

* Rename "proxy_detach" to "detach"

Signed-off-by: Tim Moon <[email protected]>

* Include transpose cache in detach and clone funcs

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update FP8 workspaces with QuantizedTensor functions

Signed-off-by: Tim Moon <[email protected]>

* Move logic for FP8 transpose cache in FP8 workspaces to base class

Signed-off-by: Tim Moon <[email protected]>

* Remove cast-transpose logic from linear op

Signed-off-by: Tim Moon <[email protected]>

* Remove unnecessary args for Float8Tensor when using FP8 attr dict

Signed-off-by: Tim Moon <[email protected]>

* Remove __torch_function__ to QuantizedTensor

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Update tests/pytorch/test_float8tensor.py

Signed-off-by: Tim Moon <[email protected]>

* Debug FP8 transpose test

Signed-off-by: Tim Moon <[email protected]>

* Debug cast functions

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
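Several of the commits above concern carrying an FP8 transpose cache through detach and clone. A minimal sketch of that idea (hypothetical names, not the actual Float8Tensor implementation): FP8 GEMMs often need both the tensor and its transpose, so the proxy computes the transpose lazily, caches it, and preserves the cache across detach/clone rather than recomputing it.

```python
from typing import Optional

import torch


class CachedTransposeSketch:
    """Toy proxy that lazily caches its 2-D transpose across detach/clone."""

    def __init__(self, data: torch.Tensor,
                 transpose_cache: Optional[torch.Tensor] = None):
        self._data = data
        self._t_cache = transpose_cache

    def transpose_2d(self) -> torch.Tensor:
        # Compute once, reuse afterwards.
        if self._t_cache is None:
            self._t_cache = self._data.t().contiguous()
        return self._t_cache

    def detach(self) -> "CachedTransposeSketch":
        # Keep the cache so autograd-detached copies avoid recomputing it.
        return CachedTransposeSketch(self._data.detach(), self._t_cache)

    def clone(self) -> "CachedTransposeSketch":
        cache = None if self._t_cache is None else self._t_cache.clone()
        return CachedTransposeSketch(self._data.clone(), cache)
```

This mirrors the intent of "Include transpose cache in detach and clone funcs": the expensive cast-transpose happens at most once per tensor version, regardless of how many detached or cloned views exist.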
3 participants