
[torchao float8tensor] #1415

Draft: crcrpar wants to merge 67 commits into crpa/subclass-tensor-ops from crpa/subclass-torchao_float8tensor

Conversation

@crcrpar (Collaborator) commented Nov 8, 2024

What does this PR do?

Improves the tensor subclass support of #1394 for torchao float8.

Note: pytorch/ao#1339 is needed.

My environment:

  • torch: 2.6.0a0+git62eea62
  • nvfuser: 0.2.23+gitbb05859
  • torchao: 0.7.0+gitb2e42ff6
  • CUDA device: RTX 6000 Ada Generation
  • Driver Version: 560.35.03
  • CUDA Version: 12.6
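
For context, here is a minimal sketch of the kind of program this PR is meant to support. The layer sizes and variable names are illustrative; the torchao and thunder entry points are the same ones used in the repro script further down.

import torch
import torch.nn as nn
import thunder
from torchao.float8 import convert_to_float8_training

# Convert the linear layers for torchao float8 training so that weights and
# activations become Float8Tensor subclasses.
model = nn.Linear(32, 64, bias=False, device="cuda")
fp8_model = convert_to_float8_training(model)

# What this PR works toward: thunder.jit tracing through the Float8Tensor ops.
jitted = thunder.jit(fp8_model)
x = torch.randn(16, 32, device="cuda")
out = jitted(x)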

Three comments from @crcrpar were marked as outdated.

@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch 2 times, most recently from 896b631 to 316327f on November 24, 2024 16:13
@t-vi (Collaborator) commented Nov 25, 2024

@crcrpar if you merge main, the PyTorch nightly distributed CI tests should be fixed.

crcrpar (Collaborator, Author):

This change should be in #1394

@@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace(
     trace = TraceCtx()
     trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
     func_result = unwrap(wrapped_func_result)
-    if shallow_copy_output:
+    if shallow_copy_output and not trace.bound_symbols:
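
Purely as an illustration of the new guard (this is not the real thunder code; only `shallow_copy_output` and `bound_symbols` come from the hunk above):

# Hypothetical helper showing the behavior change of the condition.
def should_shallow_copy(shallow_copy_output: bool, bound_symbols: list) -> bool:
    # Before: `return shallow_copy_output`
    # After: only take the shallow-copy path when the converted function
    # recorded no bound symbols of its own.
    return shallow_copy_output and not bound_symbols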
crcrpar (Collaborator, Author):

Comment on lines +774 to +794

added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
    cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
    import_ctx.update(cur_import_ctx)
    call_ctx.update(cur_call_ctx)
    object_ctx.update(cur_object_ctx)

if import_ctx:
    added_bsym._import_ctx.update(import_ctx)
if call_ctx:
    if added_bsym._call_ctx is not None:
        added_bsym._call_ctx.update(call_ctx)
    else:
        added_bsym._call_ctx = call_ctx
if object_ctx:
    added_bsym._object_ctx.update(object_ctx)
crcrpar (Collaborator, Author):

should be in #1394

crcrpar (Collaborator, Author):

This change should also be in #1394

crcrpar (Collaborator, Author):

should be in #1394

@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 15c8d12 to 70dc6ba on November 28, 2024 12:31
@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from 04d528a to 804bc99 on November 28, 2024 12:32
Comment on lines 275 to 294
if executor == DynamoThunderExecutor:
    with pytest.raises(AssertionError):
        torch.testing.assert_close(actual, expected)
crcrpar (Collaborator, Author):

This failure doesn't feel easy to fix to me, so I turned it into a standalone script:

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.splitter import SubgraphInfo
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64

    device = torch.device("cuda")
    dtype = torch.float32

    model = nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)
    expected = fp8_model(x)

    backend = ThunderCompiler()
    jitted = torch.compile(fp8_model, backend=backend)
    actual = jitted(x)

    backend.save_reproducer_to_folder("./debug_torchao_with_thunderfx", use_pytest_benchmark=True)
    print(f"{len(backend.subgraph_infos) = }")
    subgraph: SubgraphInfo
    for subgraph in backend.subgraph_infos:
        print(f"# {len(subgraph.thunder_compiled_fns) = }")

    torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    main()

Note that pytorch/ao#1339 is needed at the moment.

Below is the console output of the script:

% python debug_thunderfx_torchao_fp8.py
/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/dynamo/compiler.py:21: UserWarning: The ThunderCompiler is in active development and may not work as expected. Please report any issues you encounter to the Lightning Thunder team.
  warnings.warn(
len(backend.subgraph_infos) = 1
# len(subgraph.thunder_compiled_fns) = 0
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 34, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 30, in main
    torch.testing.assert_close(actual, expected)
  File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 388 / 1024 (37.9%)
Greatest absolute difference: 0.18639898300170898 at index (1, 61) (up to 1e-05 allowed)
Greatest relative difference: 1.9664803743362427 at index (10, 33) (up to 1.3e-06 allowed)

So it seems that thunder.jit isn't used for this program, yet the numerics still diverge.

Collaborator:

Can you check whether the results stay the same between different invocations? (Maybe due to low precision, the results could differ.)

expected = fp8_model(x)
actual = fp8_model(x)
torch.testing.assert_close(actual, expected)

Collaborator:

But please add a comment explaining why expected and actual both come from calling the same model, rather than from one model and a reference.
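
For illustration (my wording, not the author's), the requested comment could read roughly like this:

# Compare two eager invocations of the same fp8 model rather than comparing
# against a separate higher-precision reference model: the goal here is to
# check run-to-run determinism, and low-precision float8 outputs may
# legitimately deviate from an unquantized reference.
expected = fp8_model(x)
actual = fp8_model(x)
torch.testing.assert_close(actual, expected)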

crcrpar (Collaborator, Author):

For ThunderFX, I updated the test to check parity between Inductor and ThunderCompiler, rather than between eager and ThunderFX.
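
A rough sketch of what that parity check could look like, reusing the names from the repro script above (my approximation, not the test as committed):

from thunder.dynamo import ThunderCompiler

# ThunderFX path: torch.compile with the ThunderCompiler backend.
backend = ThunderCompiler()
thunderfx_model = torch.compile(fp8_model, backend=backend)
actual = thunderfx_model(x)

# Reference path: torch.compile with the default Inductor backend.
inductor_model = torch.compile(fp8_model)
expected = inductor_model(x)

torch.testing.assert_close(actual, expected)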

@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from 7c1fea6 to 8475ff7 on November 30, 2024 07:02
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 70dc6ba to fc6d8a9 on December 7, 2024 07:22
@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch 2 times, most recently from ca3b5f7 to 2b30049 on December 9, 2024 09:28
Comment on lines +303 to +310
# TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
# Currently no subgraphs go to thunder.jit.
if is_thunderfx:
    for subgraph in backend.subgraph_infos:
        if not bias and dtype == thunder.core.dtypes.bfloat16:
            assert not subgraph.thunder_compiled_fns
        else:
            assert subgraph.thunder_compiled_fns
crcrpar (Collaborator, Author):

I feel #1539 is related, in that both of them somehow mistakenly push things to the fallback instead of thunder.jit.

@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from fc6d8a9 to ce3edbc on December 12, 2024 23:23
Signed-off-by: Masaki Kozuki <[email protected]>
next, function with tensor creation in it

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar and others added 23 commits December 13, 2024 08:24
Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@crcrpar crcrpar force-pushed the crpa/subclass-torchao_float8tensor branch from 2b30049 to 6b73636 on December 12, 2024 23:25
between torch and thunder proxy

Signed-off-by: Masaki Kozuki <[email protected]>
since the outputs of subclass flattening would be replaceable with the
args of the ctor/unflatten of those subclass tensors.

Signed-off-by: Masaki Kozuki <[email protected]>