From 2f51b44c1bb277ab496adcfc6ac2abc3ad90d56a Mon Sep 17 00:00:00 2001
From: rittik9
Date: Mon, 9 Dec 2024 12:38:36 +0000
Subject: [PATCH] refactor: autocast.py

---
 thunder/tests/test_autocast.py | 21 ++-------------------
 thunder/transforms/autocast.py | 21 +++------------------
 2 files changed, 5 insertions(+), 37 deletions(-)

diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 1d80bf583..7c1b647a5 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -27,6 +27,7 @@ def test_thunder_autocast_transform(executor, device, dtype):
     def f(a, b, c):
         return a @ (b + c)
 
+    # The following functions need to be updated as autocast_impls grows.
     def g(a, b, c):
         return a + b - c
 
@@ -58,6 +59,7 @@ def h(a, b, c):
     out = compiled(x, y, z)
 
     devicetype = torch.device(device).type
+    # note(crcrpar): This test could be broken in the future as thunder autocast develops.
     with torch.autocast(device_type=devicetype, dtype=autocast_torch_dtype):
         torch_output = func(x, y, z)
     assert out.dtype == torch_output.dtype
@@ -309,22 +311,3 @@ def foo(a, b, c, d):
 
     for eg, jg in zip(eager_grads, jit_grads):
         torch.testing.assert_close(eg, jg, rtol=5e-3, atol=5e-3)
-
-
-# def simple_addition(x, y):
-#     return x + y
-
-
-# def test_autocast_transform():
-#     autocast_transform = AutocastTransform(dtype=torch.bfloat16)
-#     jitted_fn = jit(simple_addition, transforms=[autocast_transform])
-
-#     x = torch.randn(2, 2, dtype=torch.float32)
-#     y = torch.randn(2, 2, dtype=torch.float32)
-
-#     result = jitted_fn(x, y)
-
-#     assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}"
-
-#     expected_result = simple_addition(x, y).to(torch.bfloat16)
-#     assert torch.allclose(result, expected_result), "The output values do not match the expected results."
diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py
index 0d860cabc..545f731e2 100644
--- a/thunder/transforms/autocast.py
+++ b/thunder/transforms/autocast.py
@@ -347,55 +347,40 @@ def __init__(self, trace, dtype, *args, **kwargs):
                 self.dtype = dtype
 
             def process_bsym(self, bsym):
-                # Skip special symbols that shouldn't be processed
                 if bsym.sym.id in trace_interpreter_skip_list:
                     self.new_trace.bound_symbols.append(bsym.from_bsym())
                     return
 
-                # Check if symbol has an autocast implementation
                 autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)
 
                 if autocast_impl is not None:
-                    # Read the arguments with potential autocast conversion
                     args = tree_map(self.read, bsym.args)
                     kwargs = tree_map(self.read, bsym.kwargs)
 
-                    # Apply the autocast implementation
                     with disable_autocast():
                         result = autocast_impl(*args, **kwargs, dtype=self.dtype)
 
                     self.set_result(result)
                 else:
-                    # No autocast rule, process normally
                     args = tree_map(self.read, bsym.args)
                     kwargs = tree_map(self.read, bsym.kwargs)
                     result = bsym.sym(*args, **kwargs)
                     self.set_result(result)
 
-                # Add the bound symbol to new trace
                 new_bsym = bsym.from_bsym()
                 new_bsym.args = args
                 new_bsym.kwargs = kwargs
                 self.add_processed_bsyms([new_bsym])
 
-        # Process the computation trace
         if computation_trace is not None:
             processor = AutocastProcessor(computation_trace, self.dtype)
 
-            # Get the actual args and kwargs from the kwargs dict
             args = kwargs.get("args", ())
             kw = kwargs.get("kwargs", {})
 
-            with tracectx(processor.new_trace):
-                # Initialize the processor's environment with input arguments
-                for trace_arg, arg in zip(computation_trace.args, args):
-                    processor.env[trace_arg.name] = arg
+            processor.process_args(*args, **kw)
 
-                # Initialize kwargs if any
-                for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()):
-                    processor.env[trace_kwarg.name] = kwarg
-
-                new_trace, _ = processor()
-                computation_trace = new_trace
+            new_trace, outputs = processor()
+            computation_trace = new_trace
 
         return prologue_trace, computation_trace, epilogue_trace
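Note: the block deleted at the bottom of test_autocast.py above was a commented-out smoke test for AutocastTransform, and the transforms/autocast.py hunk replaces the manual population of processor.env under tracectx with a single processor.process_args call, which the diff shows registering the trace inputs before the processor runs. For reviewers who want to exercise the refactored transform by hand, here is a minimal Python sketch along the lines of the deleted test. It is a sketch, not part of the patch: it assumes AutocastTransform is importable from thunder.transforms.autocast and that thunder.jit accepts a transforms= list, neither of which is confirmed beyond what the deleted test code implied.

    import torch

    import thunder
    from thunder.transforms.autocast import AutocastTransform


    def simple_addition(x, y):
        return x + y


    # Jit the function with the autocast transform requesting bfloat16.
    jitted_fn = thunder.jit(simple_addition, transforms=[AutocastTransform(dtype=torch.bfloat16)])

    x = torch.randn(2, 2, dtype=torch.float32)
    y = torch.randn(2, 2, dtype=torch.float32)

    result = jitted_fn(x, y)
    # The deleted test expected the transform to downcast the output to bfloat16.
    assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}"

    # Compare against the eager result downcast to bfloat16.
    expected_result = simple_addition(x, y).to(torch.bfloat16)
    assert torch.allclose(result, expected_result), "The output values do not match the expected results."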