Skip to content

Commit

Permalink
change iteration order
Browse files Browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Dec 18, 2024
1 parent 77df8bb commit 3f40ab9
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions cute_kernels/cute_inductor/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

class CuteInductor:
def __init__(
self, use_inductor: bool = True, replace_functions=[replace_rmsnorm, replace_swiglu_unchunked, replace_swiglu]
self,
use_torch_inductor_after_cute_inductor: bool = True,
replace_functions=[replace_rmsnorm, replace_swiglu_unchunked, replace_swiglu],
) -> None:
self.use_inductor = use_inductor
self.use_torch_inductor_after_cute_inductor = use_torch_inductor_after_cute_inductor
self.replace_functions = replace_functions

def compiler(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]) -> Callable:
Expand All @@ -34,7 +36,7 @@ def compiler(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor])
print("graph after cute inductor")
gm.print_readable()

if self.use_inductor:
if self.use_torch_inductor_after_cute_inductor:
inductor = lookup_backend("inductor")
compiled = inductor(gm, example_inputs)
else:
Expand Down

0 comments on commit 3f40ab9

Please sign in to comment.