diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index ba5ecc2895d7..d49524b77a13 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -44,6 +44,7 @@
 from torch._ops import (
     OpOverload as TorchOpOverload,
+    HigherOrderOperator,
 )
 
 from torch._subclasses import (
@@ -1201,6 +1202,8 @@ def import_nodes(
                 elif isinstance(target, TorchOpOverload):
                     # Dispatch to an ATen op.
                     self._import_torch_op_overload(loc, node, target)
+                elif isinstance(target, HigherOrderOperator):
+                    self._import_hop(loc, node, target)
                 else:
                     raise NotImplementedError(
                         f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
@@ -1293,6 +1296,78 @@ def _import_symbolic_torch_op(
         ), f"Unable to parse symbolic operation: {target} with args {node.args}"
         self._import_torch_op_overload(loc, node, concrete_target)
 
+    def _import_hop(self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator):
+        # Imports a higher-order operator.
+        # See: https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565
+        assert hop.namespace == "higher_order"
+        hop_name = hop.name()
+        handler_name = f"_import_hop_{hop_name}"
+        handler = getattr(self, handler_name, None)
+        if handler is None:
+            raise NotImplementedError(
+                f"Higher-order operation '{hop_name}' not "
+                f"implemented in the FxImporter "
+                f"(tried '{handler_name}')"
+            )
+        handler(loc, node, hop)
+
+    def _import_hop_auto_functionalized(
+        self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
+    ):
+        # Imports the torch._higher_order_ops.auto_functionalize.auto_functionalized HOP.
+        # This op wraps a target OpOverload with args/kwargs dispatched to it.
+        # Even though the OpOverload will return None, this op returns the
+        # mutated arguments. Note that the generic op overload import path
+        # can't be used here because these ops use a special encoding for
+        # everything.
+        # See: torch/_higher_order_ops/auto_functionalize.py
+        (op_overload,) = node.args
+        schema = op_overload._schema
+        assert isinstance(schema, FunctionSchema)
+        mlir_op_name = _get_mlir_op_name_for_schema(schema)
+
+        # Functionalization transforms the results to (*actual, *aliased).
+        # If the schema actually has zero returns, then the first "val"
+        # type will be None, and we need to bind that as a result of the
+        # node even though it never makes it into the IR. This special
+        # casing is annoying.
+        node_result_types = [
+            (None if v is None else self._cc.tensor_metadata_to_type(v))
+            for v in node.meta["val"]
+        ]
+
+        if len(schema.returns) == 0:
+            assert node_result_types[0] is None
+            ir_result_types = node_result_types[1:]
+            bind_none = 1
+        else:
+            ir_result_types = node_result_types
+            bind_none = 0
+
+        # The auto_functionalized op maps all arguments by name (as opposed
+        # to the mixed positional/keyword form of a generic OpOverload).
+        # Linearize them.
+        operands = []
+        for parameter in schema.arguments:
+            operand = self._import_argument(
+                loc, node.kwargs[parameter.name], parameter.type
+            )
+            operands.append(operand)
+
+        operation = _emit_operation(
+            mlir_op_name,
+            result_types=ir_result_types,
+            operands=operands,
+            loc=loc,
+        )
+
+        # Special case: if the schema declared zero returns, bind a None
+        # for future node access.
+        self._multi_result_nodes.add(node)
+        if bind_none:
+            self.bind_node_value(node, None, 0)
+        # Record value mappings for the remainder.
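+        # The result index is offset by bind_none so bindings stay aligned
+        # with the node's meta["val"] entries (slot 0 keeps the None bound
+        # for the zero-return case above).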
+        for i, value in enumerate(operation.results):
+            self.bind_node_value(node, value, i + bind_none)
+
     def _import_torch_op_overload(
         self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
     ):
@@ -1322,13 +1397,7 @@ def _import_torch_op_overload(
         schema = target._schema
         assert isinstance(schema, FunctionSchema)
-
-        # Map to a `torch` dialect name.
-        namespace, sep, unqualified_name = schema.name.partition("::")
-        assert sep, f"Malformed Torch op name {schema.name}"
-        mlir_op_name = f"torch.{namespace}.{unqualified_name}"
-        if schema.overload_name != "":
-            mlir_op_name += f".{schema.overload_name}"
+        mlir_op_name = _get_mlir_op_name_for_schema(schema)
 
         # Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments.
         if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and (
@@ -1347,28 +1416,11 @@ def _import_torch_op_overload(
                 op_overload = getattr(op_overload, op_attrs[i])
             schema = op_overload._schema
 
-        return_count = len(schema.returns)
-        if return_count == 1:
-            # Unary return directly maps a single meta["val"] and cannot be subscripted.
-            # if "tensor_meta" is None, this will throw unsupported placeholder node error
-            result_types = [self._cc.node_val_to_type(node)]
-        elif return_count == 0:
-            # Some torch ops do have 0 returns, and these are supported with ZeroResults
-            # op trait. Python bindings for IR creation allow us to pass empty result_types
-            # for such ops. Therefore, we pass an empty result types for these cases.
-            result_types = []
-        else:
-            # Multi-return will unpack the meta["val"] and trigger our getitem subscripting
-            # short-circuit above. Note that if we ever choose to also fully reify Python
-            # level result tuples, we will need to create a tuple-boxed version of this and
-            # redirect to it for generic object access.
-
-            result_types = []
-            for v in node.meta["val"]:
-                result_types.append(self._cc.tensor_metadata_to_type(v))
-            result_types = tuple(result_types)
-
+        # Convert result types.
+        result_types = self._unpack_node_result_types(node, schema)
+        if len(result_types) > 1:
             self._multi_result_nodes.add(node)
+
         # Unroll operands from formal parameters, args and kwargs.
         operands = []
         for i, parameter in enumerate(schema.arguments):
@@ -1389,24 +1441,9 @@ def _import_torch_op_overload(
                 )
             )
 
-        # Support unregistered torch ops using torch.operator.
-        # torch.operator is used to represent ops from registry
-        # which haven't been generated by torch_ods_gen.py.
-        if not self._c.is_registered_operation(mlir_op_name):
-            operation = Operation.create(
-                "torch.operator",
-                attributes={"name": StringAttr.get(mlir_op_name)},
-                results=result_types,
-                operands=operands,
-                loc=loc,
-            )
-        else:
-            operation = Operation.create(
-                mlir_op_name,
-                results=result_types,
-                operands=operands,
-                loc=loc,
-            )
+        operation = _emit_operation(
+            mlir_op_name, result_types=result_types, operands=operands, loc=loc
+        )
 
         # Record value mapping.
         for i, value in enumerate(operation.results):
@@ -1571,6 +1608,29 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
         with loc:
             return cvt(arg, self, self._cc)
 
+    def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema):
+        return_count = len(schema.returns)
+        if return_count == 1:
+            # Unary return directly maps a single meta["val"] and cannot be subscripted.
+ # if "tensor_meta" is None, this will throw unsupported placeholder node error + result_types = [self._cc.node_val_to_type(node)] + elif return_count == 0: + # Some torch ops do have 0 returns, and these are supported with ZeroResults + # op trait. Python bindings for IR creation allow us to pass empty result_types + # for such ops. Therefore, we pass an empty result types for these cases. + result_types = [] + else: + # Multi-return will unpack the meta["val"] and trigger our getitem subscripting + # short-circuit above. Note that if we ever choose to also fully reify Python + # level result tuples, we will need to create a tuple-boxed version of this and + # redirect to it for generic object access. + + result_types = [] + for v in node.meta["val"]: + result_types.append(self._cc.tensor_metadata_to_type(v)) + result_types = tuple(result_types) + return result_types + def _make_constant_op( op_name: str, value_attr: Attribute, result_type: Optional[IrType] = None @@ -1686,6 +1746,47 @@ def lookup(self, t: type) -> Any: return None +############################################################################### +# Utilities +############################################################################### + + +def _get_mlir_op_name_for_schema(schema: FunctionSchema) -> str: + # Returns a fully-qualified MLIR operation name (i.e. 'torch.foobar') + # for a function schema. + namespace, sep, unqualified_name = schema.name.partition("::") + assert sep, f"Malformed Torch op name {schema.name}" + mlir_op_name = f"torch.{namespace}.{unqualified_name}" + if schema.overload_name != "": + mlir_op_name += f".{schema.overload_name}" + return mlir_op_name + + +def _emit_operation( + mlir_op_name: str, result_types: List[IrType], operands: List[Value], loc: Location +) -> Operation: + # Support unregistered torch ops using torch.operator. + # torch.operator is used to represent ops from registry + # which haven't been generated by torch_ods_gen.py. 
+    context = loc.context
+    if not context.is_registered_operation(mlir_op_name):
+        operation = Operation.create(
+            "torch.operator",
+            attributes={"name": StringAttr.get(mlir_op_name)},
+            results=result_types,
+            operands=operands,
+            loc=loc,
+        )
+    else:
+        operation = Operation.create(
+            mlir_op_name,
+            results=result_types,
+            operands=operands,
+            loc=loc,
+        )
+    return operation
+
+
 ###############################################################################
 # Reference mapping
 ###############################################################################
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index 298c66e2681b..9a3ea3850477 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -16,6 +16,7 @@
 from torch_mlir.dialects import torch as torch_d
 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 
+
 def export_and_import(
     f,
     *args,
@@ -23,6 +24,7 @@ def export_and_import(
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
     hooks: Optional[FxImporterHooks] = None,
+    decomposition_table: Optional[list] = None,
     func_name: str = "main",
     **kwargs,
 ):
@@ -32,8 +34,10 @@ def export_and_import(
     if fx_importer is None:
         fx_importer = FxImporter(context=context, hooks=hooks)
     prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
-    decomp_table = get_decomposition_table()
-    prog = prog.run_decompositions(decomp_table)
+    if decomposition_table is None:
+        decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     if experimental_support_mutation:
         if torch.__version__ < "2.3.0.dev20240207":
             warnings.warn("Mutable program import only supported on PyTorch 2.3+")
diff --git a/test/python/fx_importer/v2.3/auto_functionalized.py b/test/python/fx_importer/v2.3/auto_functionalized.py
new file mode 100644
index 000000000000..ab7401dcc2fb
--- /dev/null
+++ b/test/python/fx_importer/v2.3/auto_functionalized.py
@@ -0,0 +1,92 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Optional
+
+import torch
+import torch.export
+import torch.nn as nn
+
+from torch_mlir import fx
+
+from torch_mlir.ir import (
+    Operation,
+)
+
+
+LIBRARY = torch.library.Library("torch_mlir_test", "DEF")
+
+LIBRARY.define("inplace_modify(Tensor(a!) x) -> ()")
+LIBRARY.define("inplace_modify_calc(Tensor(a!) x) -> (Tensor)")
+
+
+def inplace_modify_calc_meta(x):
+    return torch.empty_like(x)
+
+
+LIBRARY.impl("inplace_modify_calc", inplace_modify_calc_meta, "Meta")
+
+
+def run(f):
+    print(f"{f.__name__}")
+    print("-" * len(f.__name__))
+    f()
+    print()
+
+
+# CHECK-LABEL: test_auto_functionalized_hop
+@run
+def test_auto_functionalized_hop():
+    class Basic(nn.Module):
+        def forward(self, x):
+            torch.ops.torch_mlir_test.inplace_modify(x)
+            return x * x
+
+    m = fx.export_and_import(
+        Basic(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+        # TODO: ExportedProgram.run_decompositions() seems to have trouble
+        # with mode selection and Python higher order op implementations.
+        # Isolate and report upstream.
+        # Raises:
+        #   File "/home/stella/v/Dev/lib/python3.11/site-packages/torch/_ops.py", line 323, in dispatch
+        #     assert (
+        # AssertionError: Current active mode not registered
+        decomposition_table=[],
+    )
+    # CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+    # CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]]
+    print(m)
+    m.operation.verify()
+
+
+# CHECK-LABEL: test_auto_functionalized_one_ret
+@run
+def test_auto_functionalized_one_ret():
+    class Basic(nn.Module):
+        def forward(self, x):
+            y = torch.ops.torch_mlir_test.inplace_modify_calc(x)
+            return x * y
+
+    m = fx.export_and_import(
+        Basic(),
+        torch.randn(3, 4),
+        experimental_support_mutation=True,
+        # TODO: ExportedProgram.run_decompositions() seems to have trouble
+        # with mode selection and Python higher order op implementations.
+        # Isolate and report upstream.
+        # Raises:
+        #   File "/home/stella/v/Dev/lib/python3.11/site-packages/torch/_ops.py", line 323, in dispatch
+        #     assert (
+        # AssertionError: Current active mode not registered
+        decomposition_table=[],
+    )
+    # CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
+    # CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0
+    print(m)
+    m.operation.verify()
diff --git a/test/python/fx_importer/v2.3/lit.local.cfg b/test/python/fx_importer/v2.3/lit.local.cfg
index b10b239f8b3a..00c613754f64 100644
--- a/test/python/fx_importer/v2.3/lit.local.cfg
+++ b/test/python/fx_importer/v2.3/lit.local.cfg
@@ -2,7 +2,7 @@
 config.unsupported = True
 try:
     import torch
-    if torch.__version__ >= "2.3.0.dev20240207":
+    if torch.__version__ >= "2.3.0":
         print("Enabling Torch v2.3+ tests")
         config.unsupported = False
 except ModuleNotFoundError:
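
Usage sketch for the new decomposition_table parameter (MyModule is a hypothetical stand-in, not part of this change): passing None, the default, applies get_decomposition_table(), while passing an empty list skips run_decompositions() entirely, which the tests above use to sidestep the decomposition failure noted in the TODOs.

    import torch
    import torch.nn as nn

    from torch_mlir import fx


    class MyModule(nn.Module):  # hypothetical example module
        def forward(self, x):
            return x * x


    # Default: decomposition_table=None applies get_decomposition_table().
    m_decomposed = fx.export_and_import(MyModule(), torch.randn(3, 4))

    # Empty table: run_decompositions() is skipped entirely.
    m_raw = fx.export_and_import(
        MyModule(), torch.randn(3, 4), decomposition_table=[]
    )
    print(m_raw)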