diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index b72429aed..0d860cabc 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -8,7 +8,7 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol from thunder.core.proxies import TensorProxy -from thunder.core.trace import TraceCtx,tracectx +from thunder.core.trace import TraceCtx, tracectx from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list from thunder.core.transforms import construct_trace, eval_trace, Transform from thunder.clang import ( @@ -317,18 +317,19 @@ def is_cpu_tensor(p): return None + class AutocastTransform(Transform): """Transform that applies automatic mixed precision (autocast) to eligible operations.""" - + def __init__(self, dtype: dtypes.dtype): """Initialize the autocast transform. - + Args: dtype: The target dtype to cast eligible operations to (float16 or bfloat16) """ if not isinstance(dtype, dtypes.dtype): raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") - + if dtype not in _allowed_downcast_types: raise ValueError( f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}" @@ -336,19 +337,15 @@ def __init__(self, dtype: dtypes.dtype): self.dtype = dtype def transform_traces_pre_prologue( - self, - prologue_trace: TraceCtx, - computation_trace: TraceCtx, - epilogue_trace: TraceCtx | None, - **kwargs + self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: """Transform the computation trace to apply autocast rules.""" - + class AutocastProcessor(TraceSubstitutionProcessor): def __init__(self, trace, dtype, *args, **kwargs): super().__init__(trace, *args, **kwargs) self.dtype = dtype - + def process_bsym(self, bsym): # Skip special symbols that shouldn't be processed if bsym.sym.id in trace_interpreter_skip_list: @@ -357,24 +354,24 @@ def process_bsym(self, bsym): # Check if symbol has an autocast implementation autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) - + if autocast_impl is not None: # Read the arguments with potential autocast conversion args = tree_map(self.read, bsym.args) kwargs = tree_map(self.read, bsym.kwargs) - + # Apply the autocast implementation with disable_autocast(): result = autocast_impl(*args, **kwargs, dtype=self.dtype) - + self.set_result(result) else: # No autocast rule, process normally - args = tree_map(self.read, bsym.args) + args = tree_map(self.read, bsym.args) kwargs = tree_map(self.read, bsym.kwargs) result = bsym.sym(*args, **kwargs) self.set_result(result) - + # Add the bound symbol to new trace new_bsym = bsym.from_bsym() new_bsym.args = args @@ -384,21 +381,21 @@ def process_bsym(self, bsym): # Process the computation trace if computation_trace is not None: processor = AutocastProcessor(computation_trace, self.dtype) - + # Get the actual args and kwargs from the kwargs dict - args = kwargs.get('args', ()) - kw = kwargs.get('kwargs', {}) - + args = kwargs.get("args", ()) + kw = kwargs.get("kwargs", {}) + with tracectx(processor.new_trace): # Initialize the processor's environment with input arguments for trace_arg, arg in zip(computation_trace.args, args): processor.env[trace_arg.name] = arg - + # Initialize kwargs if any for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()): processor.env[trace_kwarg.name] = kwarg - + new_trace, _ = processor() computation_trace = new_trace - return prologue_trace, computation_trace, epilogue_trace \ No newline at end of file + return prologue_trace, computation_trace, epilogue_trace