From 9f64748f97fa543a2b6b227cd26f570622cd26f1 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 29 Apr 2024 10:09:00 +0800 Subject: [PATCH] [FxImporter] Synchronize the collection of symbolic torch ops (#3236) --- python/torch_mlir/extras/fx_importer.py | 16 ++++------------ python/torch_mlir/fx.py | 4 ++-- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c1eec37aab00..9acf4ad03a77 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -236,12 +236,6 @@ # set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP if _IS_TORCH_2_1_OR_EARLIER: - SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size, - torch.ops.aten.sym_stride, - torch.ops.aten.sym_numel, - } - SYMBOLIC_OP_TO_TORCH_OP = { (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, @@ -249,13 +243,9 @@ (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, } -else: - SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_stride.int, - torch.ops.aten.sym_numel.default, - } + SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP} +else: SYMBOLIC_OP_TO_TORCH_OP = { torch.ops.aten.sym_size.default: torch.ops.aten.size.default, torch.ops.aten.sym_size.int: torch.ops.aten.size.int, @@ -264,6 +254,8 @@ torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, } + SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP} + @dataclass(frozen=True) class SparsityMeta: diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 0879dbe31218..651ccae673a6 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import Optional, Union, Dict, Tuple, Any +from typing import Optional, Union, Dict, Tuple, Any, Callable import warnings @@ -25,7 +25,7 @@ def export_and_import( dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, hooks: Optional[FxImporterHooks] = None, - decomposition_table: Optional[list] = None, + decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, func_name: str = "main", enable_graph_printing: bool = False, **kwargs,