diff --git a/cute_kernels/cute_inductor/swiglu_unchunked.py b/cute_kernels/cute_inductor/swiglu_unchunked.py index de4a324d..c0f93b45 100644 --- a/cute_kernels/cute_inductor/swiglu_unchunked.py +++ b/cute_kernels/cute_inductor/swiglu_unchunked.py @@ -8,8 +8,13 @@ def replace_swiglu_unchunked(gm: GraphModule, node: Node) -> None: - if node.op == CALL_METHOD and node.target == torch.chunk.__name__: - print(node.args, node.kwargs) + if not (node.op == CALL_METHOD and node.target == torch.chunk.__name__): + return + + chunks = node.kwargs.get("chunks", node.args[1]) + if chunks != 2: + return + # if len(node.args) == 2 and node.args[1] == 2: # with gm.graph.inserting_after(node): # # Create a new node for the custom chunk_silu function