Skip to content

Commit

Permalink
Extract getitem replacement and meta value fixing into insert_fused_node
Browse files Browse the repository at this point in the history
Signed-off-by: luka <[email protected]>
  • Loading branch information
ProExpertProg committed Dec 5, 2024
1 parent 14b1902 commit e191d58
Showing 1 changed file with 101 additions and 82 deletions.
183 changes: 101 additions & 82 deletions vllm/compilation/fusion.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import abc
import operator
from abc import abstractmethod
from typing import Callable, Iterable, List, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
import torch._inductor.pattern_matcher as pm
# TODO(luka) use vllm.utils once #10836 landed
from compressed_tensors.quantization import FP8_DTYPE
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

Expand All @@ -30,33 +31,31 @@ def empty_fp32(*args, **kwargs):


# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
op) -> Optional[torch.fx.Node]:
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None


# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
def find_auto_fn(nodes: Iterable[fx.Node], op) -> fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node


# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: torch.fx.Node,
idx: int) -> Optional[torch.fx.Node]:
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None


# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
Expand Down Expand Up @@ -104,14 +103,14 @@ def process(self):
raise NotImplementedError

@property
def nodes(self) -> List[torch.fx.Node]:
def nodes(self) -> List[fx.Node]:
return self.match.nodes

@property
def graph(self) -> torch.fx.Graph:
def graph(self) -> fx.Graph:
return self.match.graph

def find_auto_fn(self, op) -> torch.fx.Node:
def find_auto_fn(self, op) -> fx.Node:
"""
Find the first auto_functionalized node with the given op in the match.
"""
Expand All @@ -134,8 +133,8 @@ def inserting_after_match(self):

return self.graph.inserting_after(last_node_in_match)

def insert_getitems(self, tuple_node: torch.fx.Node,
indices: Tuple[int, ...]) -> Tuple[torch.fx.Node, ...]:
def insert_getitems(self, tuple_node: fx.Node,
indices: Iterable[int]) -> Tuple[fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.
Expand All @@ -160,7 +159,6 @@ def insert_auto_fn(self, op, kwargs):
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Key: (fp8/int8, static/dynamic, per-tensor/per-token, symmetric/asymmetric)
# Value: the torch op
QUANT_OPS = {
(FP8_DTYPE, True, True, True):
torch.ops._C.static_scaled_fp8_quant.default,
Expand All @@ -183,6 +181,66 @@ def insert_auto_fn(self, op, kwargs):
}


class QuantMultiOutputMatch(MultiOutputMatch):

def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
self.QUANT_OP = quant_op
self.FUSED_OP = fused_op

def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
int]],
**kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.
:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node
Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)
with
_, x1, y2, x2 = auto_fn(FUSED_OP)
we would call:
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
Note that the 0th element is None for auto-functionalized in-place ops.
Hence others appear 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)

# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)

# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]

# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]

# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]

# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)


class RMSNormQuantPattern:

def __init__(self,
Expand Down Expand Up @@ -212,13 +270,6 @@ def __init__(self,
f" for quant scheme {keystr()})")
self.FUSED_OP = FUSED_OPS[key2]

class Match(MultiOutputMatch):

def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
self.QUANT_OP = quant_op
self.FUSED_OP = fused_op


class RMSNormStaticQuantPattern(RMSNormQuantPattern):

Expand Down Expand Up @@ -339,7 +390,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))

class Match(RMSNormQuantPattern.Match):
class Match(QuantMultiOutputMatch):

def process(self):
# Find the nodes in the match that we need to rebind
Expand All @@ -358,26 +409,14 @@ def process(self):
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]

# TODO simplify
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
getitem_nodes = self.insert_getitems(fused_node, (1, 2))
result_node_new, residual_node_new = getitem_nodes

# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)

# Finally, fix meta["val"] for de-functionalization.
# See MultiOutputMatch.process for more details.
rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"]
# Result of fused node is (None, result, residual)
fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2])
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
**kwargs)


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
Expand Down Expand Up @@ -446,7 +485,7 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))

class Match(RMSNormQuantPattern.Match):
class Match(QuantMultiOutputMatch):

def process(self):
# Find the nodes in the match that we need to rebind
Expand All @@ -465,28 +504,17 @@ def process(self):
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]
kwargs["scale_ub"] = None # not used but required
kwargs["residual"] = None # not used but required
del kwargs["result_rms"] # not used in the fused op

# TODO simplify
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs)
getitem_nodes = self.insert_getitems(fused_node, (1, 2))
result_node_new, scale_node_new = getitem_nodes

# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new)

# Finally, fix meta["val"] for de-functionalization.
# See MultiOutputMatch.process for more details.
# Result of fused node is (None, result, scale)
fused_node.meta["val"] = quant_node.meta["val"]
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
residual=None, # not used but required
**kwargs)


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
Expand Down Expand Up @@ -555,7 +583,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))

class Match(RMSNormQuantPattern.Match):
class Match(QuantMultiOutputMatch):

def process(self):
# Find the nodes in the match that we need to rebind
Expand All @@ -575,28 +603,19 @@ def process(self):
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]
kwargs["scale_ub"] = None # not used but required

fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs)
getitem_ns = self.insert_getitems(fused_node, (1, 2, 3))
result_node_new, scale_node_new, residual_node_new = getitem_ns

# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new)

# Finally, fix meta["val"] for de-functionalization.
# See MultiOutputMatch.process for more details.
rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"]
# Result of fused node is (None, result, scale, residual)
fused_node.meta["val"] = (None, quant_tup[1], quant_tup[2],
rms_tup[2])
fused_return_mapping = {
1: (quant_node, 1), # result
2: (quant_node, 2), # scale
3: (rms_node, 2), # residual
}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
**kwargs)


class FusionPass(VllmInductorPass):
Expand Down Expand Up @@ -671,7 +690,7 @@ def record_match(self, match: MultiOutputMatch) -> bool:
# Return False to prevent automatic replacement.
return False

def process_matches(self, graph: torch.fx.Graph):
def process_matches(self, graph: fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
See MultiOutputMatch for more details.
Expand All @@ -684,7 +703,7 @@ def process_matches(self, graph: torch.fx.Graph):
assert all(node not in graph.nodes for match in self.matches
for node in match.match.nodes)

def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")

Expand Down

0 comments on commit e191d58

Please sign in to comment.