Skip to content

Commit

Permalink
update autocast.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Dec 11, 2024
1 parent 13b7973 commit bce7372
Showing 1 changed file with 81 additions and 46 deletions.
127 changes: 81 additions & 46 deletions thunder/transforms/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from thunder.core.proxies import TensorProxy
from thunder.core.symbol import BoundSymbolInterface, Symbol
from thunder.core.proxies import TensorProxy
from thunder.core.trace import TraceCtx
from thunder.core.trace_interpreter import TraceSubstitutionProcessor
from thunder.core.trace import TraceCtx,tracectx
from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list
from thunder.core.transforms import construct_trace, eval_trace, Transform
from thunder.clang import (
maybe_convert_to_dtype,
Expand Down Expand Up @@ -317,53 +317,88 @@ def is_cpu_tensor(p):

return None


class AutocastTraceSubstitutionProcessor(TraceSubstitutionProcessor):
    """Trace processor that routes each bound symbol through its autocast rule.

    Symbols without an autocast rule are re-emitted by calling the original
    symbol on the substituted arguments.

    Args:
        trace: The trace to process; forwarded to ``TraceSubstitutionProcessor``.
        dtype: Value passed as the ``dtype`` keyword to every autocast rule.
    """

    def __init__(self, trace, dtype):
        super().__init__(trace)
        self.dtype = dtype

    def process_bsym(self, bsym):
        """Process one bound symbol for the autocast transformation.

        Invoked by ``TraceSubstitutionProcessor.__call__`` for each bound symbol.
        """
        rule = _maybe_get_autocast_rule_for_symbol(bsym.sym)

        # Both paths operate on the same substituted inputs, so read them once.
        read_args = tree_map(self.read, bsym.args)
        read_kwargs = tree_map(self.read, bsym.kwargs)

        if rule is None:
            # No autocast rule registered: replay the original symbol.
            outcome = bsym.sym(*read_args, **read_kwargs)
        else:
            outcome = rule(*read_args, dtype=self.dtype, **read_kwargs)

        self.set_result(outcome)


class AutocastTransform(Transform):
    """Transform that applies automatic mixed precision (autocast) to eligible operations.

    Args:
        dtype: The target dtype handed to each autocast rule. It must be a
            ``dtypes.dtype`` contained in ``_allowed_downcast_types``
            (``thunder.float16`` or ``thunder.bfloat16`` per the error message).
    """

    def __init__(self, dtype: dtypes.dtype):
        """Validate and store the autocast target dtype.

        Raises:
            ValueError: If ``dtype`` is not a ``dtypes.dtype`` instance, or is
                not one of ``_allowed_downcast_types``.
        """
        super().__init__()
        if not isinstance(dtype, dtypes.dtype):
            raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but got {type(dtype)}")

        # NOTE(review): the helper presumably performs the same membership test as
        # the explicit check below; both are kept so neither validation is lost.
        _check_valid_autocast_dtype(dtype)
        if dtype not in _allowed_downcast_types:
            raise ValueError(
                f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}"
            )
        self.dtype = dtype

    def transform_traces_pre_prologue(
        self,
        prologue_trace: TraceCtx,
        computation_trace: TraceCtx,
        epilogue_trace: TraceCtx | None,
        **kwargs,
    ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]:
        """Rewrite the computation trace so symbols with an autocast rule use it.

        Args:
            prologue_trace: Returned unchanged.
            computation_trace: Trace whose bound symbols are re-processed. If it
                is ``None`` all three traces are returned untouched.
            epilogue_trace: Returned unchanged.
            **kwargs: Optional ``args`` (sequence) and ``kwargs`` (mapping) entries
                are used to pre-populate the processor environment.

        Returns:
            ``(prologue_trace, new_computation_trace, epilogue_trace)``.
        """

        class AutocastProcessor(TraceSubstitutionProcessor):
            """Processor that substitutes autocast rules while walking a trace."""

            def __init__(self, trace, dtype, *args, **kwargs):
                super().__init__(trace, *args, **kwargs)
                self.dtype = dtype

            def process_bsym(self, bsym):
                # Symbols on the interpreter skip list are copied through verbatim.
                if bsym.sym.id in trace_interpreter_skip_list:
                    self.new_trace.bound_symbols.append(bsym.from_bsym())
                    return

                autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)

                # Both paths consume the same substituted inputs; read them once.
                read_args = tree_map(self.read, bsym.args)
                read_kwargs = tree_map(self.read, bsym.kwargs)

                if autocast_impl is not None:
                    with disable_autocast():
                        result = autocast_impl(*read_args, **read_kwargs, dtype=self.dtype)
                else:
                    # No autocast rule: replay the original symbol.
                    result = bsym.sym(*read_args, **read_kwargs)
                self.set_result(result)

                # NOTE(review): this re-adds a copy of the original bound symbol
                # after the result was already set above; confirm it does not
                # duplicate the operation in the new trace.
                new_bsym = bsym.from_bsym()
                new_bsym.args = read_args
                new_bsym.kwargs = read_kwargs
                self.add_processed_bsyms([new_bsym])

        if computation_trace is None:
            return prologue_trace, computation_trace, epilogue_trace

        processor = AutocastProcessor(computation_trace, self.dtype)

        # The caller may hand over the call's positional/keyword inputs.
        call_args = kwargs.get("args", ())
        call_kwargs = kwargs.get("kwargs", {})

        with tracectx(processor.new_trace):
            # Seed the processor environment with the supplied positional inputs.
            for trace_arg, arg in zip(computation_trace.args, call_args):
                processor.env[trace_arg.name] = arg

            # Pair keyword inputs by key rather than by dict position so that a
            # different ordering (or a missing key) cannot bind the wrong value.
            for key, trace_kwarg in computation_trace.kwargs.items():
                if key in call_kwargs:
                    processor.env[trace_kwarg.name] = call_kwargs[key]

            new_computation_trace, _ = processor()

        # Tag the new trace so the transform shows up in its provenance.
        new_computation_trace.set_provenance("Autocast Transform")
        return prologue_trace, new_computation_trace, epilogue_trace

0 comments on commit bce7372

Please sign in to comment.