From 42c00967c6bbb90d31a3529fd0c96502ada25132 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 18 Dec 2024 01:44:09 -0500 Subject: [PATCH] rename Signed-off-by: Mayank Mishra --- cute_kernels/cute_inductor/rmsnorm.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cute_kernels/cute_inductor/rmsnorm.py b/cute_kernels/cute_inductor/rmsnorm.py index 91336d90..6a185b70 100644 --- a/cute_kernels/cute_inductor/rmsnorm.py +++ b/cute_kernels/cute_inductor/rmsnorm.py @@ -9,12 +9,10 @@ def replace_rmsnorm(gm: GraphModule, node: Node) -> None: if node.op == CALL_FUNCTION and node.target == torch.rms_norm: with gm.graph.inserting_after(node): - args = list(node.args) - kwargs = dict(node.kwargs) + # delete normalized_shape from the args (position 1) + args = node.args[:1] + node.args[2:] - # delete normalized_shape from the args - if len(args) > 1: - del args[1] + kwargs = {key: value for key, value in node.kwargs.items()} # delete normalized_shape from the kwargs kwargs.pop("normalized_shape", None) @@ -24,7 +22,7 @@ def replace_rmsnorm(gm: GraphModule, node: Node) -> None: if input is not None: kwargs["x"] = input - new_node = gm.graph.call_function(rmsnorm_cute, args=args, kwargs=kwargs) + new_node = gm.graph.call_function(rmsnorm_cute, args=tuple(args), kwargs=kwargs) node.replace_all_uses_with(new_node) gm.graph.erase_node(node)