Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Dec 18, 2024
1 parent 77cbce8 commit 42c0096
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions cute_kernels/cute_inductor/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
def replace_rmsnorm(gm: GraphModule, node: Node) -> None:
if node.op == CALL_FUNCTION and node.target == torch.rms_norm:
with gm.graph.inserting_after(node):
args = list(node.args)
kwargs = dict(node.kwargs)
# delete normalized_shape from the args (position 1)
args = node.args[:1] + node.args[2:]

# delete normalized_shape from the args
if len(args) > 1:
del args[1]
kwargs = {key: value for key, value in node.kwargs.items()}

# delete normalized_shape from the kwargs
kwargs.pop("normalized_shape", None)
Expand All @@ -24,7 +22,7 @@ def replace_rmsnorm(gm: GraphModule, node: Node) -> None:
if input is not None:
kwargs["x"] = input

new_node = gm.graph.call_function(rmsnorm_cute, args=args, kwargs=kwargs)
new_node = gm.graph.call_function(rmsnorm_cute, args=tuple(args), kwargs=kwargs)

node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)

0 comments on commit 42c0096

Please sign in to comment.