[fx] Implement auto_functionalized higher order op. (llvm#3063)
* Also adds the basic scaffolding for handling more of these, which will
be needed for cond, while, etc.
* Refactors some of the support in the generic OpOverload emitter so it
can be shared with these other special forms.

This has been on my list for a while, but it just so happens that as
part of upgrading to PyTorch 2.3 and a pure upstream flow in Turbine, we
were using a feature that required integration with auto_functionalized.
This is perhaps the "weirdest" of the higher-order ops and a poor place
to start, but needs must. We have testing for this in Turbine.

Full support in Turbine involves an entire custom ops facility; I've
reduced that down to a unit test in torch-mlir.
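
For reference, a hand-written sketch (not importer output) of the calling
convention this handles: functionalization rewrites a call to a mutating
op into an auto_functionalized node whose results are the wrapped op's
declared results followed by the mutated arguments.

    # Given a custom op with the schema
    #   torch_mlir_test::inplace_modify_calc(Tensor(a!) x) -> (Tensor)
    # the exported FX graph contains, roughly:
    calc_result, mutated_x = torch.ops.higher_order.auto_functionalized(
        torch.ops.torch_mlir_test.inplace_modify_calc.default, x=x
    )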
stellaraccident authored Mar 27, 2024
1 parent 11eaba3 commit e2343cf
Showing 4 changed files with 246 additions and 49 deletions.
193 changes: 147 additions & 46 deletions python/torch_mlir/extras/fx_importer.py
@@ -44,6 +44,7 @@

from torch._ops import (
OpOverload as TorchOpOverload,
HigherOrderOperator,
)

from torch._subclasses import (
@@ -1201,6 +1202,8 @@ def import_nodes(
elif isinstance(target, TorchOpOverload):
# Dispatch to an ATen op.
self._import_torch_op_overload(loc, node, target)
elif isinstance(target, HigherOrderOperator):
self._import_hop(loc, node, target)
else:
raise NotImplementedError(
f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
@@ -1293,6 +1296,78 @@ def _import_symbolic_torch_op(
), f"Unable to parse symbolic operation: {target} with args {node.args}"
self._import_torch_op_overload(loc, node, concrete_target)

def _import_hop(self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator):
# Imports a higher-order operator.
# See: https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565
assert hop.namespace == "higher_order"
hop_name = hop.name()
handler_name = f"_import_hop_{hop_name}"
handler = getattr(self, handler_name, None)
if handler is None:
raise NotImplementedError(
f"Higher-order operation '{hop_name}' not "
f"implemented in the FxImporter "
f"(tried '{handler_name}')"
)
handler(loc, node, hop)
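
# New HOPs gain support by adding a method named after the op. A
# hypothetical, not-yet-implemented handler for the `cond` HOP mentioned
# in the commit message would follow the same pattern:
#
#   def _import_hop_cond(
#       self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
#   ):
#       ...  # translate the predicate and branch graphs here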

def _import_hop_auto_functionalized(
self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
):
# Imports the torch._higher_order_ops.auto_functionalize.auto_functionalized HOP.
# This op wraps a target OpOverload with args/kwargs dispatched to it.
# Even though the OpOverload will return None, this op returns the
# mutated arguments. Note that the generic op overload importer can't
# be used here because these ops use a special encoding for everything.
# See: torch/_higher_order_ops/auto_functionalize.py
(op_overload,) = node.args
schema = op_overload._schema
assert isinstance(schema, FunctionSchema)
mlir_op_name = _get_mlir_op_name_for_schema(schema)

# Functionalization transforms the results to (*actual, *aliased).
# If the schema actually has zero returns, then the first "val" type
# will be None, and we need to bind that as a result of the node even
# though it never makes it into the IR. This special casing is
# annoying.
node_result_types = [
(None if v is None else self._cc.tensor_metadata_to_type(v))
for v in node.meta["val"]
]

if len(schema.returns) == 0:
assert node_result_types[0] is None
ir_result_types = node_result_types[1:]
bind_none = 1
else:
ir_result_types = node_result_types
bind_none = 0
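
# For example, using the schemas from the unit tests below: for
#   inplace_modify_calc(Tensor(a!) x) -> (Tensor)
# node.meta["val"] is [calc_result, mutated_x] and both entries become
# IR results, while for
#   inplace_modify(Tensor(a!) x) -> ()
# it is [None, mutated_x] and the leading None is bound but never
# emitted.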

# The auto_functionalized op maps all arguments by name (as opposed
# to the mixed positional/keyword form of a generic OpOverload).
# Linearize them.
operands = []
for parameter in schema.arguments:
operand = self._import_argument(
loc, node.kwargs[parameter.name], parameter.type
)
operands.append(operand)

operation = _emit_operation(
mlir_op_name,
result_types=ir_result_types,
operands=operands,
loc=loc,
)

# Special case: if the schema declared zero returns, then we bind a
# None for future node access.
self._multi_result_nodes.add(node)
if bind_none:
self.bind_node_value(node, None, 0)
# Record value mappings for remainder.
for i, value in enumerate(operation.results):
self.bind_node_value(node, value, i + bind_none)

def _import_torch_op_overload(
self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
):
@@ -1322,13 +1397,7 @@ def _import_torch_op_overload(

schema = target._schema
assert isinstance(schema, FunctionSchema)

# Map to a `torch` dialect name.
namespace, sep, unqualified_name = schema.name.partition("::")
assert sep, f"Malformed Torch op name {schema.name}"
mlir_op_name = f"torch.{namespace}.{unqualified_name}"
if schema.overload_name != "":
mlir_op_name += f".{schema.overload_name}"
mlir_op_name = _get_mlir_op_name_for_schema(schema)

# Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments.
if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and (
@@ -1347,28 +1416,11 @@
op_overload = getattr(op_overload, op_attrs[i])
schema = op_overload._schema

return_count = len(schema.returns)
if return_count == 1:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
# if "tensor_meta" is None, this will throw unsupported placeholder node error
result_types = [self._cc.node_val_to_type(node)]
elif return_count == 0:
# Some torch ops do have 0 returns, and these are supported with the
# ZeroResults op trait. Python bindings for IR creation allow us to pass
# empty result_types for such ops, so we do that here.
result_types = []
else:
# Multi-return will unpack the meta["val"] and trigger our getitem subscripting
# short-circuit above. Note that if we ever choose to also fully reify Python
# level result tuples, we will need to create a tuple-boxed version of this and
# redirect to it for generic object access.

result_types = []
for v in node.meta["val"]:
result_types.append(self._cc.tensor_metadata_to_type(v))
result_types = tuple(result_types)

# Convert result types.
result_types = self._unpack_node_result_types(node, schema)
if len(result_types) > 1:
self._multi_result_nodes.add(node)

# Unroll operands from formal parameters, args and kwargs.
operands = []
for i, parameter in enumerate(schema.arguments):
@@ -1389,24 +1441,9 @@
)
)

# Support unregistered torch ops using torch.operator.
# torch.operator is used to represent ops from registry
# which haven't been generated by torch_ods_gen.py.
if not self._c.is_registered_operation(mlir_op_name):
operation = Operation.create(
"torch.operator",
attributes={"name": StringAttr.get(mlir_op_name)},
results=result_types,
operands=operands,
loc=loc,
)
else:
operation = Operation.create(
mlir_op_name,
results=result_types,
operands=operands,
loc=loc,
)
operation = _emit_operation(
mlir_op_name, result_types=result_types, operands=operands, loc=loc
)

# Record value mapping.
for i, value in enumerate(operation.results):
@@ -1571,6 +1608,29 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
with loc:
return cvt(arg, self, self._cc)

def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema):
return_count = len(schema.returns)
if return_count == 1:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
# if "tensor_meta" is None, this will throw unsupported placeholder node error
result_types = [self._cc.node_val_to_type(node)]
elif return_count == 0:
# Some torch ops do have 0 returns, and these are supported with the
# ZeroResults op trait. Python bindings for IR creation allow us to pass
# empty result_types for such ops, so we do that here.
result_types = []
else:
# Multi-return will unpack the meta["val"] and trigger our getitem subscripting
# short-circuit above. Note that if we ever choose to also fully reify Python
# level result tuples, we will need to create a tuple-boxed version of this and
# redirect to it for generic object access.

result_types = []
for v in node.meta["val"]:
result_types.append(self._cc.tensor_metadata_to_type(v))
result_types = tuple(result_types)
return result_types


def _make_constant_op(
op_name: str, value_attr: Attribute, result_type: Optional[IrType] = None
@@ -1686,6 +1746,47 @@ def lookup(self, t: type) -> Any:
return None


###############################################################################
# Utilities
###############################################################################


def _get_mlir_op_name_for_schema(schema: FunctionSchema) -> str:
# Returns a fully-qualified MLIR operation name (e.g. 'torch.foobar')
# for a function schema.
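# For example, "aten::add" with overload "Tensor" maps to
# "torch.aten.add.Tensor".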
namespace, sep, unqualified_name = schema.name.partition("::")
assert sep, f"Malformed Torch op name {schema.name}"
mlir_op_name = f"torch.{namespace}.{unqualified_name}"
if schema.overload_name != "":
mlir_op_name += f".{schema.overload_name}"
return mlir_op_name


def _emit_operation(
mlir_op_name: str, result_types: List[IrType], operands: List[Value], loc: Location
) -> Operation:
# Supports unregistered torch ops via torch.operator, which represents
# ops from the registry that haven't been generated by
# torch_ods_gen.py.
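# For example, a registered op such as torch.aten.mul.Tensor is created
# directly, while the unregistered test op below is wrapped as:
#   torch.operator "torch.torch_mlir_test.inplace_modify"(...)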
context = loc.context
if not context.is_registered_operation(mlir_op_name):
operation = Operation.create(
"torch.operator",
attributes={"name": StringAttr.get(mlir_op_name)},
results=result_types,
operands=operands,
loc=loc,
)
else:
operation = Operation.create(
mlir_op_name,
results=result_types,
operands=operands,
loc=loc,
)
return operation


###############################################################################
# Reference mapping
###############################################################################
8 changes: 6 additions & 2 deletions python/torch_mlir/fx.py
@@ -16,13 +16,15 @@
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table


def export_and_import(
f,
*args,
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[list] = None,
func_name: str = "main",
**kwargs,
):
@@ -32,8 +34,10 @@ def export_and_import(
if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks)
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
decomp_table = get_decomposition_table()
prog = prog.run_decompositions(decomp_table)
if decomposition_table is None:
decomposition_table = get_decomposition_table()
if decomposition_table:
prog = prog.run_decompositions(decomposition_table)
if experimental_support_mutation:
if torch.__version__ < "2.3.0.dev20240207":
warnings.warn("Mutable program import only supported on PyTorch 2.3+")
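A minimal usage sketch of the new decomposition_table parameter (the
module and inputs are illustrative, not from this commit). An empty table
is falsy, so run_decompositions() is skipped entirely; the new tests
below rely on this.

    import torch
    from torch_mlir import fx


    class Scale(torch.nn.Module):
        def forward(self, x):
            return x * 2.0


    # Passing an empty table imports the program exactly as exported,
    # without running any decompositions.
    m = fx.export_and_import(Scale(), torch.randn(2, 3), decomposition_table=[])
    print(m)
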
92 changes: 92 additions & 0 deletions test/python/fx_importer/v2.3/auto_functionalized.py
@@ -0,0 +1,92 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

from typing import Optional

import torch
import torch.export
import torch.nn as nn

from torch_mlir import fx

from torch_mlir.ir import (
Operation,
)


LIBRARY = torch.library.Library("torch_mlir_test", "DEF")

LIBRARY.define("inplace_modify(Tensor(a!) x) -> ()")
LIBRARY.define("inplace_modify_calc(Tensor(a!) x) -> (Tensor)")


def inplace_modify_calc_meta(x):
return torch.empty_like(x)


LIBRARY.impl("inplace_modify_calc", inplace_modify_calc_meta, "Meta")
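# The Meta impl gives torch.export shape/dtype inference when tracing
# with fake tensors; inplace_modify, which returns nothing, gets by
# without one here.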


def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()


# CHECK-LABEL: test_auto_functionalized_hop
@run
def test_auto_functionalized_hop():
class Basic(nn.Module):
def forward(self, x):
torch.ops.torch_mlir_test.inplace_modify(x)
return x * x

m = fx.export_and_import(
Basic(),
torch.randn(3, 4),
experimental_support_mutation=True,
# TODO: ExportedProgram.run_decompositions() seems to have trouble
# with mode selection and Python higher order op implementations.
# Isolate and report upstream.
# Raises:
# File "/home/stella/v/Dev/lib/python3.11/site-packages/torch/_ops.py", line 323, in dispatch
# assert (
# AssertionError: Current active mode <torch._subclasses.functional_tensor.FunctionalTensorMode object at 0x7a1106504fd0> not registered
decomposition_table=[],
)
# CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]]
print(m)
m.operation.verify()


# CHECK-LABEL: test_auto_functionalized_one_ret
@run
def test_auto_functionalized_one_ret():
class Basic(nn.Module):
def forward(self, x):
y = torch.ops.torch_mlir_test.inplace_modify_calc(x)
return x * y

m = fx.export_and_import(
Basic(),
torch.randn(3, 4),
experimental_support_mutation=True,
# TODO: ExportedProgram.run_decompositions() seems to have trouble
# with mode selection and Python higher order op implementations.
# Isolate and report upstream.
# Raises:
# File "/home/stella/v/Dev/lib/python3.11/site-packages/torch/_ops.py", line 323, in dispatch
# assert (
# AssertionError: Current active mode <torch._subclasses.functional_tensor.FunctionalTensorMode object at 0x7a1106504fd0> not registered
decomposition_table=[],
)
# CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
# CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0
print(m)
m.operation.verify()
2 changes: 1 addition & 1 deletion test/python/fx_importer/v2.3/lit.local.cfg
@@ -2,7 +2,7 @@ config.unsupported = True

try:
import torch
if torch.__version__ >= "2.3.0.dev20240207":
if torch.__version__ >= "2.3.0":
print("Enabling Torch v2.3+ tests")
config.unsupported = False
except ModuleNotFoundError: