Skip to content

Commit

Permalink
add tensor subclass transform output to traces
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Dec 12, 2024
1 parent 832bf79 commit 6b73636
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ def get_computation_and_inputs(*args, **kwargs):
_tensor_subclass_transform_applied = True
if not _tensor_subclass_transform_applied:
computation_trc = flatten_tensor_subclasses(computation_trc)
computation_traces.append(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
Expand Down

0 comments on commit 6b73636

Please sign in to comment.