[torchao float8tensor] #1415
base: crpa/subclass-tensor-ops
Conversation
@crcrpar if you merge main, the PT nightly distributed CI tests should be fixed.
thunder/__init__.py (Outdated)
This change should be in #1394.
thunder/core/jit_ext.py (Outdated)
@@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace(
     trace = TraceCtx()
     trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
     func_result = unwrap(wrapped_func_result)
-    if shallow_copy_output:
+    if shallow_copy_output and not trace.bound_symbols:
added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
    cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
    import_ctx.update(cur_import_ctx)
    call_ctx.update(cur_call_ctx)
    object_ctx.update(cur_object_ctx)

if import_ctx:
    added_bsym._import_ctx.update(import_ctx)
if call_ctx:
    if added_bsym._call_ctx is not None:
        added_bsym._call_ctx.update(call_ctx)
    else:
        added_bsym._call_ctx = call_ctx
if object_ctx:
    added_bsym._object_ctx.update(object_ctx)
This should be in #1394.
thunder/core/prims.py (Outdated)
This change should also be in #1394.
thunder/executors/torch_autograd.py (Outdated)
This should be in #1394.
if executor == DynamoThunderExecutor:
    with pytest.raises(AssertionError):
        torch.testing.assert_close(actual, expected)
This failure doesn't look easy to fix to me, so I turned it into a script:
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

from thunder.dynamo import ThunderCompiler
from thunder.dynamo.splitter import SubgraphInfo
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64
    device = torch.device("cuda")
    dtype = torch.float32

    model = nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)

    expected = fp8_model(x)

    backend = ThunderCompiler()
    jitted = torch.compile(fp8_model, backend=backend)
    actual = jitted(x)

    backend.save_reproducer_to_folder("./debug_torchao_with_thunderfx", use_pytest_benchmark=True)
    print(f"{len(backend.subgraph_infos) = }")
    subgraph: SubgraphInfo
    for subgraph in backend.subgraph_infos:
        print(f"# {len(subgraph.thunder_compiled_fns) = }")

    torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    main()
Note that pytorch/ao#1339 is needed at the moment.
Below is the console output of the script above:
% python debug_thunderfx_torchao_fp8.py
/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/dynamo/compiler.py:21: UserWarning: The ThunderCompiler is in active development and may not work as expected. Please report any issues you encounter to the Lightning Thunder team.
warnings.warn(
len(backend.subgraph_infos) = 1
# len(subgraph.thunder_compiled_fns) = 0
Traceback (most recent call last):
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 34, in <module>
main()
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/debug_thunderfx_torchao_fp8.py", line 30, in main
torch.testing.assert_close(actual, expected)
File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 388 / 1024 (37.9%)
Greatest absolute difference: 0.18639898300170898 at index (1, 61) (up to 1e-05 allowed)
Greatest relative difference: 1.9664803743362427 at index (10, 33) (up to 1.3e-06 allowed)
So it seems that thunder.jit isn't used for this program, yet the numerics are diverging.
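For context on why the assertion fires: torch.testing.assert_close uses an allclose-style criterion, |actual - expected| <= atol + rtol * |expected|, and the tolerances shown in the log (atol=1e-05, rtol=1.3e-06) are the float32 defaults. A pure-Python sketch of that check against the reported mismatch:

```python
def within_tol(actual: float, expected: float,
               atol: float = 1e-5, rtol: float = 1.3e-6) -> bool:
    # allclose-style criterion used by torch.testing.assert_close
    # (defaults shown here are the float32 ones from the log above)
    return abs(actual - expected) <= atol + rtol * abs(expected)

# The greatest absolute difference reported above (~0.186) is orders of
# magnitude outside these tolerances, so the assertion fails.
print(within_tol(1.0 + 0.186, 1.0))  # False
print(within_tol(1.0 + 1e-6, 1.0))   # True
```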
Can you check whether the results stay the same between different invocations? (Due to the low precision, the results could differ.)

expected = fp8_model(x)
actual = fp8_model(x)
torch.testing.assert_close(actual, expected)
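To illustrate the point that low precision alone can push results outside float32 tolerances, here is a pure-Python sketch. This is not torchao's actual float8 format, just a crude mantissa-truncation stand-in, and all names are illustrative:

```python
import math

def quantize(x: float, mantissa_bits: int = 3) -> float:
    """Crudely round x to a float8-e4m3-like mantissa (illustrative only)."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2.0 ** mantissa_bits
    return math.ldexp(round(m * scale) / scale, e)

a = [0.1 * i for i in range(1, 9)]
b = [0.05 * i for i in range(1, 9)]

exact = sum(x * y for x, y in zip(a, b))
approx = sum(quantize(x) * quantize(y) for x, y in zip(a, b))
rel_err = abs(approx - exact) / abs(exact)
# rel_err lands far above the 1.3e-06 rtol seen in the failure log
print(f"{rel_err = }")
```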
But please add a comment explaining why expected and actual both come from calling the same model, rather than comparing one model against a reference.
For ThunderFX, I updated the test to check parity between inductor and ThunderCompiler, not between eager and ThunderFX.
# TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
# Currently no subgraphs go to thunder.jit.
if is_thunderfx:
    for subgraph in backend.subgraph_infos:
        if not bias and dtype == thunder.core.dtypes.bfloat16:
            assert not subgraph.thunder_compiled_fns
        else:
            assert subgraph.thunder_compiled_fns
I feel #1539 is related, in that both of them somehow mistakenly push things to the fallback, not thunder.jit.
What does this PR do?
Improve the tensor subclass support of #1394 for TorchAO float8.
Note: pytorch/ao#1339 is needed.
my environment