From 3f40ab90a4fd744c7e7c10693f65f40498577dca Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 18 Dec 2024 11:44:55 -0500 Subject: [PATCH] change iteration order Signed-off-by: Mayank Mishra --- cute_kernels/cute_inductor/compiler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cute_kernels/cute_inductor/compiler.py b/cute_kernels/cute_inductor/compiler.py index 0468fc3c..2de78709 100644 --- a/cute_kernels/cute_inductor/compiler.py +++ b/cute_kernels/cute_inductor/compiler.py @@ -14,9 +14,11 @@ class CuteInductor: def __init__( - self, use_inductor: bool = True, replace_functions=[replace_rmsnorm, replace_swiglu_unchunked, replace_swiglu] + self, + use_torch_inductor_after_cute_inductor: bool = True, + replace_functions=[replace_rmsnorm, replace_swiglu_unchunked, replace_swiglu], ) -> None: - self.use_inductor = use_inductor + self.use_torch_inductor_after_cute_inductor = use_torch_inductor_after_cute_inductor self.replace_functions = replace_functions def compiler(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]) -> Callable: @@ -34,7 +36,7 @@ def compiler(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]) print("graph after cute inductor") gm.print_readable() - if self.use_inductor: + if self.use_torch_inductor_after_cute_inductor: inductor = lookup_backend("inductor") compiled = inductor(gm, example_inputs) else: