[PyTorch] Proxy class for low-precision tensor (#1127)
* Add base class for tensor proxies

Signed-off-by: Tim Moon <[email protected]>

* Move tensor detaching logic to tensor proxy base class

Signed-off-by: Tim Moon <[email protected]>

* Use Python wrappers to PyTorch extensions

Signed-off-by: Tim Moon <[email protected]>

* Include transpose caching logic in proxy encode function

Signed-off-by: Tim Moon <[email protected]>

* Debug dimension mismatch with amax history

Signed-off-by: Tim Moon <[email protected]>

* Move dequantize logic to proxy_decode func

Signed-off-by: Tim Moon <[email protected]>

* Rename to "QuantizedTensor"

Signed-off-by: Tim Moon <[email protected]>

* Rename "proxy_detach" to "detach"

Signed-off-by: Tim Moon <[email protected]>

* Include transpose cache in detach and clone funcs

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update FP8 workspaces with QuantizedTensor functions

Signed-off-by: Tim Moon <[email protected]>

* Move logic for FP8 transpose cache in FP8 workspaces to base class

Signed-off-by: Tim Moon <[email protected]>

* Remove cast-transpose logic from linear op

Signed-off-by: Tim Moon <[email protected]>

* Remove unnecessary args for Float8Tensor when using FP8 attr dict

Signed-off-by: Tim Moon <[email protected]>

* Move __torch_function__ to QuantizedTensor

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Update tests/pytorch/test_float8tensor.py

Signed-off-by: Tim Moon <[email protected]>

* Debug FP8 transpose test

Signed-off-by: Tim Moon <[email protected]>

* Debug cast functions

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
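
Taken together, the bullets above describe a small proxy interface: a quantized payload plus scaling metadata that can stand in for a regular tensor, with detach/clone preserving quantization state. The sketch below only illustrates that shape and is not the actual TransformerEngine implementation; everything beyond the names QuantizedTensor, dequantize, detach, and clone (which appear in the commit message or diffs) is an assumption.

    import torch

    class QuantizedTensor(torch.Tensor):
        """Sketch of a proxy base class for low-precision tensors.

        A subclass (e.g. Float8Tensor) stores the quantized payload and
        scaling metadata, and converts to/from plain high-precision tensors.
        """

        def dequantize(self) -> torch.Tensor:
            # Decode the low-precision payload into a high-precision tensor.
            raise NotImplementedError("subclass must implement dequantize")

        def detach(self) -> "QuantizedTensor":
            # Detach from autograd while keeping quantization state and,
            # per the commit message, any cached transpose.
            raise NotImplementedError("subclass must implement detach")

        def clone(self) -> "QuantizedTensor":
            # Deep copy that likewise carries quantization state and caches.
            raise NotImplementedError("subclass must implement clone")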
3 people committed Sep 11, 2024
1 parent 40dda92 commit 2d57db8
Showing 19 changed files with 1,352 additions and 1,242 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -31,7 +31,7 @@ jobs:
         run: |
           sudo apt-get update
           sudo apt-get install pip -y
-          pip install torch
+          pip install torch numpy
           export PYTHON_ONLY=1
           export TE_PATH=.
           bash ./qa/L0_pytorch_lint/test.sh
9 changes: 4 additions & 5 deletions tests/pytorch/test_float8tensor.py
@@ -293,7 +293,7 @@ def test_transpose(
         with pytest.raises(AssertionError):
             torch.testing.assert_close(x_fp8_t, x, **tols)

-        # Caching test.
+        # Caching test
         assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
         x_fp8 += 0.5
         x = x_fp8.from_float8()
@@ -302,14 +302,13 @@
         torch.testing.assert_close(x_fp8_t, x_t, **tols)
         assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."

-        # Inplace update test.
+        # Inplace update test
         x_fp8 += 0.5
-        assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
+        assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
         x = x_fp8.from_float8()
-        x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
+        x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
         x_t = x.transpose(0, 1)
         torch.testing.assert_close(x_fp8_t, x_t, **tols)
-        assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."

     def test_serialization(
         self,
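
The rewritten assertions change the contract: an in-place update on a Float8Tensor with a filled transpose cache now refreshes the cache instead of invalidating it, and the cached value is read straight from _transpose. A minimal usage sketch of that behavior, assuming a CUDA device and this era's import path (both assumptions; with_transpose_cache comes from the tests/pytorch/test_fusible_ops.py diff below):

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor  # path assumed

    x_fp8 = Float8Tensor.to_float8(
        torch.randn(32, 32, device="cuda"),
        with_transpose_cache=True,  # fill the transpose cache at cast time
    )
    x_fp8 += 0.5                    # in-place update refreshes the cache
    assert not x_fp8._transpose_invalid
    x_t = x_fp8._transpose          # cached FP8 transpose, no recompute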
5 changes: 1 addition & 4 deletions tests/pytorch/test_fusible_ops.py
@@ -88,10 +88,7 @@ def make_reference_and_test_tensors(
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
     test = ref.to(device=test_device, dtype=test_dtype)
     if test_is_fp8:
-        test = Float8Tensor.to_float8(test)
-        test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
-        test._transpose = test._transpose.contiguous()
-        test._transpose_invalid = False
+        test = Float8Tensor.to_float8(test, with_transpose_cache=True)
     elif test.data_ptr() == ref.data_ptr():
         test = test.clone()
     ref.copy_(test)
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/cpp_extensions/_common.py
@@ -68,13 +68,13 @@ def canonicalize_fp8_scales(
     # Force offsets to be the same if needed
     if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset:
         if scale_offset != 0:
-            scale = scale[scale_offset]
+            scale = scale[scale_offset:]
             scale_offset = 0
         if amax_offset != 0:
-            amax = amax[0][amax_offset]
+            amax = amax[:, amax_offset:]
             amax_offset = 0
         if scale_inv_offset != 0:
-            scale_inv = scale_inv[scale_inv_offset]
+            scale_inv = scale_inv[scale_inv_offset:]
             scale_inv_offset = 0

     # Pack tensors and offsets into dicts
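
The fix swaps 0-D indexing for slicing. A standalone illustration of why that matters, with assumed shapes (TE amax histories are 2-D, [history_length, num_tensors], which also matches the "dimension mismatch with amax history" bullet above):

    import torch

    scale = torch.ones(4)      # one scale factor per quantized tensor
    amax = torch.zeros(16, 4)  # amax history: [history_length, num_tensors]

    scale[1].shape     # torch.Size([])      -- indexing yields a 0-D scalar
    amax[0][1].shape   # torch.Size([])      -- and drops the history axis
    scale[1:].shape    # torch.Size([3])     -- slicing keeps a 1-D view, offset re-based to 0
    amax[:, 1:].shape  # torch.Size([16, 3]) -- full history axis retained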
5 changes: 2 additions & 3 deletions transformer_engine/pytorch/cpp_extensions/cast.py
@@ -8,7 +8,7 @@
 import torch

 import transformer_engine_torch as tex
-from ._common import canonicalize_fp8_scales, empty_tensor
+from ._common import canonicalize_fp8_scales

 __all__ = ["cast_to_fp8", "cast_from_fp8"]

@@ -81,8 +81,7 @@ def cast_from_fp8(

     # Construct empty tensors if needed
     if scale_inv is None:
-        scale_inv = empty_tensor()
-        scale_inv_offset = 0
+        raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`")

     # Launch FP8 cast kernel
     return torch.ops.tex_ts.cast_from_fp8_ts(
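
A usage note on this last hunk: with empty_tensor gone from the imports, cast_from_fp8 now fails fast. A caller that supplies neither scale_inv nor an fp8_meta_tensor gets a ValueError up front, rather than the kernel silently dequantizing against an uninitialized placeholder scale-inverse.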
(Diffs for the remaining 14 changed files not shown.)
