From d85c3a572c66dfb0e05528e1679e9af25e548157 Mon Sep 17 00:00:00 2001 From: Scott Lee <40186387+xTayEx@users.noreply.github.com> Date: Tue, 24 Oct 2023 20:07:30 +0800 Subject: [PATCH] [OSPP] Implementation of TorchDynamo compiler (#208) * [frontend] Add tosa operators for python frontend * [frontend] Support keyword arguments in importer * [frontend] Add docstring for tosa operators * [frontend] Add README.md for `BuddyPython` * [frontend] Add tests for tosa operator conversion functions --- examples/BuddyPython/README.md | 105 ++++ examples/BuddyPython/bert.py | 18 + examples/BuddyPython/module_gen.py | 2 +- examples/MLIRPython/.style.yapf | 4 - examples/MLIRPython/addmm.py | 20 - examples/MLIRPython/arith_add.py | 17 - examples/MLIRPython/buddy/compiler.py | 177 ------ examples/MLIRPython/buddy/operators_gen.py | 81 --- examples/MLIRPython/matmul.py | 11 - frontend/Python/frontend.py | 7 +- frontend/Python/ops/tosa.py | 672 ++++++++++++++++++++- tests/Python/test_addmm.py | 35 ++ tests/Python/test_amax.py | 34 ++ tests/Python/test_arith_div.py | 34 ++ tests/Python/test_arith_mul.py | 33 + tests/Python/test_arith_sub.py | 33 + tests/Python/test_bmm.py | 33 + tests/Python/test_clone.py | 32 + tests/Python/test_convert_element_type.py | 33 + tests/Python/test_embedding.py | 58 ++ tests/Python/test_exp.py | 32 + tests/Python/test_expand.py | 32 + tests/Python/test_permute.py | 33 + tests/Python/test_reshape.py | 33 + tests/Python/test_rsqrt.py | 32 + tests/Python/test_select.py | 35 ++ tests/Python/test_slice.py | 35 ++ tests/Python/test_sum.py | 34 ++ tests/Python/test_tanh.py | 32 + tests/Python/test_unsqueeze.py | 33 + tests/Python/test_var_mean.py | 66 ++ tests/Python/test_view.py | 33 + 32 files changed, 1552 insertions(+), 317 deletions(-) create mode 100644 examples/BuddyPython/README.md create mode 100644 examples/BuddyPython/bert.py delete mode 100644 examples/MLIRPython/.style.yapf delete mode 100644 examples/MLIRPython/addmm.py delete mode 100644 examples/MLIRPython/arith_add.py delete mode 100644 examples/MLIRPython/buddy/compiler.py delete mode 100644 examples/MLIRPython/buddy/operators_gen.py delete mode 100644 examples/MLIRPython/matmul.py create mode 100644 tests/Python/test_addmm.py create mode 100644 tests/Python/test_amax.py create mode 100644 tests/Python/test_arith_div.py create mode 100644 tests/Python/test_arith_mul.py create mode 100644 tests/Python/test_arith_sub.py create mode 100644 tests/Python/test_bmm.py create mode 100644 tests/Python/test_clone.py create mode 100644 tests/Python/test_convert_element_type.py create mode 100644 tests/Python/test_embedding.py create mode 100644 tests/Python/test_exp.py create mode 100644 tests/Python/test_expand.py create mode 100644 tests/Python/test_permute.py create mode 100644 tests/Python/test_reshape.py create mode 100644 tests/Python/test_rsqrt.py create mode 100644 tests/Python/test_select.py create mode 100644 tests/Python/test_slice.py create mode 100644 tests/Python/test_sum.py create mode 100644 tests/Python/test_tanh.py create mode 100644 tests/Python/test_unsqueeze.py create mode 100644 tests/Python/test_var_mean.py create mode 100644 tests/Python/test_view.py diff --git a/examples/BuddyPython/README.md b/examples/BuddyPython/README.md new file mode 100644 index 0000000000..a84a4e9fd5 --- /dev/null +++ b/examples/BuddyPython/README.md @@ -0,0 +1,105 @@ +# Buddy Compiler Python Importer +## Introduction +This package serves as the PyTorch importer of Buddy Compiler. It is built on top of TorchDynamo, a Python-level JIT compiler introduced in PyTorch 2.0. Using this importer, one can convert a PyTorch function/model to corresponding MLIR code. + +## Quick Start + +### Prerequisites +MLIR Python Bindings is required for this importer. Run below commands to build it. + +```bash +## Build MLIR Python Bindings + +Build MLIR Python Binding in Buddy-MLIR. + +// [Option] Enter your Python virtual environment. +$ cd llvm +$ python3 -m pip install -r mlir/python/requirements.txt +$ cd build +$ cmake -G Ninja ../llvm \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DLLVM_TARGETS_TO_BUILD="host;RISCV" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=[path_to_python_executable] +$ ninja check-mlir +``` + +Add MLIR Python bindings to your Python path. +```bash +// In the LLVM build dirctory. +$ export PYTHONPATH=$(pwd)/tools/mlir/python_packages/mlir_core +``` + +Test the MLIR python bindings environment. + +```python +$ python3 +>>> from mlir.ir import Context, Module +>>> ... +``` + +### Demo +Run the following code to generate MLIR code for the `foo` function. +```python +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + +# Define the target function or model. +def foo(x, y): + return x * y + x + + +# Define the input tensors +in1 = torch.randn(10) +in2 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +module, _ = dynamo_compiler.importer(foo, *(in1, in2)) + +print(module) +``` +If everything works well, the output should be as below. +```mlir +module { + func.func @forward(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { + %0 = "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %1 = "tosa.add"(%0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + return %1 : tensor<10xf32> + } +} +``` + +For more demos, please refer to [examples/BuddyPython](https://github.com/buddy-compiler/buddy-mlir/tree/main/examples/BuddyPython). We currently offer two demos below. + +* `module_gen.py`: A more detailed version of the quick start demo. +* `bert.py`: Import a [bert-base-uncased](https://huggingface.co/bert-base-uncased) model, convert it to MLIR code. + +## Methodology +[TorchDynamo](https://pytorch.org/docs/stable/dynamo/index.html) is a cutting-edge Python-level JIT compiler introduced in PyTorch 2.0, designed to make unmodified PyTorch programs faster. It achieves this by hooking into the frame evaluation API of CPython to rewrite the bytecode before it's executed. This process extract the sequences of PyTorch operations into a FX graph which is then just-in-time compiled with a compiler backend. While TorchInductor serves as the default backend, PyTorch 2.0 also offers an interface for custom compiler backends. This is the main entry point that help us implement this importer. + +### Operator + +* **Operator Mappings**: What this importer do is to convert a piece of PyTorch code to the corresponding MLIR code. To achieve it, we write some conversion functions that map PyTorch's operators to MLIR code snippets. Currently, we've mapped about 20 operators. For what operators are supported, please refer to the [frontend/Python/ops](https://github.com/buddy-compiler/buddy-mlir/tree/main/frontend/Python/ops) directory. + +* **Operator Registries**: We organize the operator mapping functions using operator registries. Each operator registry is a Python dict that maps the PyTorch operator's name to its corresponding mapping function. Currently, we've offer three operator registries, i.e. `tosa`, `math` and `linalg`. The registry name stands for the main MLIR dialect that used to implement a operator. + + +### Symbol Table +In PyTorch FX graph, there exist dependencies between operators. These dependencies represent the inputs and outpus of each operator. To handle the dependencies between operators and generate MLIR code for the whole FX graph, during the importing process, the importer will build a symbol table. This symbol table is a Python dict that maps the operator's name to the their corresponding MLIR operation. When a new PyTorch operator is going to be imported, the importer will search the symbol table for its inputs, i.e. the operator's argument(s), and the inputs' MLIR code snippet. After that, the importer will generate the MLIR code snippet for the operator and add it to the symbol table. This process will be repeated until the whole FX graph are imported. + +### Import Strategy +In order to make the importing procedure more robust, we've implement a fallback importing strategy. This machenism is consisted of two parts, i.e. primary registry and fallback registry. When importer is going to import a PyTorch operator, it will first search the primary registry for the operator's mapping function. If the operator is not found in the primary registry, the importer will try to search the fallback registry. By default, the importer will use `tosa` registry as the primary registry, and all the other registries as the fallback registry. + +## Limitations +Currently, we only support AOT execution of the generated MLIR code. To execute the generated MLIR code, one need to use the llvm tooltrain to compile it to an executable binary. We are working on the JIT execution of the generated MLIR code. diff --git a/examples/BuddyPython/bert.py b/examples/BuddyPython/bert.py new file mode 100644 index 0000000000..d7c254cebe --- /dev/null +++ b/examples/BuddyPython/bert.py @@ -0,0 +1,18 @@ +from torch._inductor.decomposition import decompositions as inductor_decomp +from transformers import BertModel, BertTokenizer +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + +model = BertModel.from_pretrained("bert-base-uncased") +model.eval() +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp +) + +tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +text = "Replace me by any text you'd like." +encoded_text = tokenizer(text, return_tensors="pt") +module, params = dynamo_compiler.importer(model, **encoded_text) +print(module) +print(params) diff --git a/examples/BuddyPython/module_gen.py b/examples/BuddyPython/module_gen.py index 81642dd93d..10a1e2ee1c 100644 --- a/examples/BuddyPython/module_gen.py +++ b/examples/BuddyPython/module_gen.py @@ -46,7 +46,7 @@ def foo(x, y): # The first way to generate an MLIR Module: # Pass the function and input data to the dynamo compiler's importer, # and accepts the generated module and weight parameters. -module, params = dynamo_compiler.importer(foo, (float32_in1, float32_in2)) +module, params = dynamo_compiler.importer(foo, *(float32_in1, float32_in2)) print(module) print(params) diff --git a/examples/MLIRPython/.style.yapf b/examples/MLIRPython/.style.yapf deleted file mode 100644 index 9ef1dc15ba..0000000000 --- a/examples/MLIRPython/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] - based_on_style = google - column_limit = 80 - indent_width = 2 diff --git a/examples/MLIRPython/addmm.py b/examples/MLIRPython/addmm.py deleted file mode 100644 index 29cf695635..0000000000 --- a/examples/MLIRPython/addmm.py +++ /dev/null @@ -1,20 +0,0 @@ -from buddy.compiler import DynamoCompiler -import torch -import torch._dynamo as dynamo - - -def foo(c, a, b): - return torch.addmm(c, a, b) - - -foo_mlir = dynamo.optimize(DynamoCompiler)(foo) - -a_float32 = torch.randn(3, 2) -b_float32 = torch.randn(2, 3) -c_float32 = torch.randn(3, 3) -foo_mlir(c_float32, a_float32, b_float32) - -a_int32 = torch.randint(10, (3, 2)).to(torch.int32) -b_int32 = torch.randint(10, (2, 3)).to(torch.int32) -c_int32 = torch.randint(10, (3, 3)).to(torch.int32) -foo_mlir(c_int32, a_int32, b_int32) diff --git a/examples/MLIRPython/arith_add.py b/examples/MLIRPython/arith_add.py deleted file mode 100644 index 0c2b7a11cf..0000000000 --- a/examples/MLIRPython/arith_add.py +++ /dev/null @@ -1,17 +0,0 @@ -from buddy import compiler -import torch -import torch._dynamo as dynamo - - -def foo(x, y): - return x + y - - -foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo) -float32_in1 = torch.randn(10).to(torch.float32) -float32_in2 = torch.randn(10).to(torch.float32) -foo_mlir(float32_in1, float32_in2) - -int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32) -int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32) -foo_mlir(int32_in1, int32_in2) diff --git a/examples/MLIRPython/buddy/compiler.py b/examples/MLIRPython/buddy/compiler.py deleted file mode 100644 index 6a24fcfef3..0000000000 --- a/examples/MLIRPython/buddy/compiler.py +++ /dev/null @@ -1,177 +0,0 @@ -"""The buddy compiler backend for torch dynamo. -""" -import operator -from typing import List, Union, Callable - -import torch -from torch._functorch.aot_autograd import aot_module_simplified -import mlir.ir as ir -import mlir.dialects.func as func -from mlir.passmanager import PassManager - -from .operators_gen import operation_func - - -def DynamoCompiler(gm: torch.fx.GraphModule, - inputs: List[torch.Tensor]) -> Callable: - """The main entry point of buddy compiler for torch dynamo. It takes a FX - graph module and a list of inputs as parameters. The compiler will first use - PyTorch's AOT autograd to lower FX graph in Torch IR to Aten/Prims IR. Then - it will map the operators in Aten/Prims IR to MLIR operations and generate an - MLIR module. Finally, It will lower the MLIR module to LLVM dialect. - - Args: - gm (torch.fx.GraphModule): The FX graph module to be compiled. - inputs (List[torch.Tensor]): The inputs of the FX graph module. - - Returns: - Callable: A compiled function that equivalent to the FX graph. - - """ - - def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): - """Compile a FX graph in Aten/Prims IR to MLIR.""" - print("Custom Compiler from FX Graph to MLIR:") - print("-------------------------------------------------------------------") - gm.graph.print_tabular() - # Initialize the MLIR context. - ctx = ir.Context() - with ir.Location.unknown(ctx): - fx_importer = FXGraphImporter(gm, inputs) - module = fx_importer.import_graph() - module = Lowering(module) - return gm.forward - - return aot_module_simplified(gm, inputs, fw_compiler=_compiler) - - -class FXGraphImporter: - """The FX graph importer class.""" - - def __init__( - self, - gm: torch.fx.GraphModule, - inputs: List[torch.Tensor], - func_name: str = "main", - ): - """ - Args: - gm (torch.fx.GraphModule): The FX graph module that will be imported. - inputs (List[torch.Tensor]): Input tensor(s) of the FX graph. - func_name (str): Name of the generated MLIR func. - - """ - self._symbol_table = {} - self._gm = gm - self._func_name = func_name - self._inputs = inputs - self._num_input_visited = 0 - self._module = ir.Module.create() - - def import_graph(self) -> ir.Module: - """Import the FX graph, generate an MLIR module in high-level dialects. - - Returns: - mlir.ir.Module: An MLIR module in high-level dialects. - - """ - with ir.InsertionPoint(self._module.body): - arguments = [] - for arg in self._inputs: - shape_list = list(arg.shape) - dtype = arg.dtype - match dtype: - case torch.int32: - mlir_dtype = ir.IntegerType.get_signless(32) - case torch.float32: - mlir_dtype = ir.F32Type.get() - case _: - raise NotImplementedError( - f"Unsupported dtype {dtype} for argument {arg}") - tensor_arg = ir.RankedTensorType.get(shape_list, mlir_dtype) - arguments.append(tensor_arg) - - @func.FuncOp.from_py_func(*arguments, name=self._func_name) - def generated_func(*args): - args_list = list(args) - for node in self._gm.graph.nodes: - if node.op == "output": - output_node_args = node.args[0] - returns = [] - for output_arg in output_node_args: - op = self._symbol_table.get((str(output_arg), 0)) - returns.append(op) - - self._symbol_table[("output", 0)] = returns - elif node.op == "placeholder": - self._import_placeholder(node, args_list) - else: - if node.target is operator.getitem: - self._symbol_table[(str(node.name), - 0)] = self._symbol_table[(node.args[0], - node.args[1])] - else: - self._import_op(node) - - return self._symbol_table.get(("output", 0)) - - print("Printing the generated MLIR...") - print(self._module) - return self._module - - def _import_placeholder(self, node: torch.fx.Node, args_list): - placeholder_name = args_list[self._num_input_visited] - self._symbol_table[(str(node.name), 0)] = placeholder_name - self._num_input_visited += 1 - - def _import_op(self, node: torch.fx.Node): - op_name = node.target.__name__ - - op_ret: Union[ir.Operation, - tuple] = operation_func[op_name](node, self._symbol_table) - if isinstance(op_ret, tuple): - for i, operation in op_ret: - self._symbol_table[(str(node.name), i)] = operation.result - else: - self._symbol_table[(str(node.name), 0)] = op_ret.result - - -def Lowering(module: ir.Module): - """Lower an MLIR module to LLVM dialect. - - Args: - module (mlir.ir.Module): An MLIR module that need to be lowered. - - Returns: - mlir.ir.Module: An MLIR module in LLVM dialect. - - """ - print("-------------------------------------------------------------------") - print("Bufferizing the module ...") - pm = PassManager("builtin.module") - pm.add("func.func(tosa-to-linalg-named)") - pm.add("func.func(tosa-to-linalg)") - pm.add("func.func(tosa-to-tensor)") - pm.add("func.func(tosa-to-arith)") - pm.add("empty-tensor-to-alloc-tensor") - pm.add("convert-elementwise-to-linalg") - pm.add("arith-bufferize") - pm.add("func.func(linalg-bufferize)") - pm.add("func.func(tensor-bufferize)") - pm.add("func-bufferize") - pm.run(module.operation) - print(module) - print("-------------------------------------------------------------------") - print("Lowering the module to LLVM dialect ...") - pm.add("func.func(buffer-deallocation)") - pm.add("func.func(convert-linalg-to-loops)") - pm.add("convert-scf-to-cf") - pm.add("convert-linalg-to-llvm") - pm.add("convert-arith-to-llvm") - pm.add("expand-strided-metadata") - pm.add("finalize-memref-to-llvm") - pm.add("convert-func-to-llvm") - pm.add("reconcile-unrealized-casts") - pm.run(module.operation) - print(module) - return module diff --git a/examples/MLIRPython/buddy/operators_gen.py b/examples/MLIRPython/buddy/operators_gen.py deleted file mode 100644 index 378d40ee92..0000000000 --- a/examples/MLIRPython/buddy/operators_gen.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Generate the MLIR operations for the operators in the FX graph. -""" -from typing import Dict, Tuple, List - -import torch - -import mlir.ir as ir -from mlir.dialects import tosa, linalg, arith - - -def _broadcast_shape(tensor_input1: ir.Value, - tensor_input2: ir.Value) -> List[int]: - """Calculate the broadcast shape of two tensors with broadcastable shapes - according to PyTorch's broadcast semantics: https://pytorch.org/docs/stable/notes/broadcasting.html""" - shp1 = ir.RankedTensorType(tensor_input1.type).shape - shp2 = ir.RankedTensorType(tensor_input2.type).shape - if len(shp1) < len(shp2): - shp1, shp2 = shp2, shp1 - while len(shp2) < len(shp1): - shp2.insert(0, 1) - for idx, (dim1, dim2) in enumerate(zip(shp1, shp2)): - shp1[idx] = shp2[idx] = max(dim1, dim2) - - return shp1 - - -def AddOp(node: torch.fx.Node, - symbol_table: Dict[Tuple[str, int], ir.Operation]) -> ir.Operation: - """Map aten.add.Tensor to tosa.add. - - Args: - node: A FX graph containing the aten.add.Tensor operator and its parameter. - symbol_table: The symbol table that records the mapping between symbols and operations. - - Returns: - ir.Operation: The generated tosa.add operation. - - """ - input1 = symbol_table.get((str(node.args[0]), 0)) - input2 = symbol_table.get((str(node.args[1]), 0)) - broadcasted_shp = _broadcast_shape(input1, input2) - sizes = broadcasted_shp - result_element_type = ir.RankedTensorType(input1.type).element_type - add_result_tensor_type = ir.RankedTensorType.get(sizes, result_element_type) - op = tosa.AddOp(add_result_tensor_type, input1, input2) - return op - - -def AddMMOp(node: torch.fx.Node, - symbol_table: Dict[Tuple[str, int], ir.Operation]) -> ir.Operation: - """Map aten.addmm.default to MLIR operation. - - Args: - node (torch.fx.Node): A FX graph containing the aten.addmm.default operator and its parameter. - symbol_table (Dict[Tuple[str, int], ir.Operation]): The symbol table that records the mapping between symbols and operations. - - Returns: - ir.Operation: The generated MLIR operation representing aten.addmm.default - - """ - input_ = symbol_table.get((str(node.args[0]), 0)) - mat1 = symbol_table.get((str(node.args[1]), 0)) - mat2 = symbol_table.get((str(node.args[2]), 0)) - mat1_shp = ir.RankedTensorType(mat1.type).shape - mat2_shp = ir.RankedTensorType(mat2.type).shape - mat1 = tosa.ReshapeOp(mat1, [1, *mat1_shp]).output - mat2 = tosa.ReshapeOp(mat2, [1, *mat2_shp]).output - - matmul_result_shp = [1, mat1_shp[0], mat2_shp[1]] - result_element_type = ir.RankedTensorType(input_.type).element_type - matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, result_element_type) - matmul_op = tosa.MatMulOp(matmul_result_type, mat1, mat2) - matmul_result = tosa.ReshapeOp(matmul_op.c, matmul_result_shp[1:]) - - add_result_shp = [mat1_shp[0], mat2_shp[1]] - add_result_tensor_type = ir.RankedTensorType.get(add_result_shp, result_element_type) - op = tosa.AddOp(add_result_tensor_type, input_, matmul_result) - return op - - -operation_func = {"add.Tensor": AddOp, "addmm.default": AddMMOp} diff --git a/examples/MLIRPython/matmul.py b/examples/MLIRPython/matmul.py deleted file mode 100644 index d15ae1df68..0000000000 --- a/examples/MLIRPython/matmul.py +++ /dev/null @@ -1,11 +0,0 @@ -from buddy import compiler -import torch -import torch._dynamo as dynamo - -def foo(x, y): - return torch.matmul(x, y) - -foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo) -in1 = torch.randn(2, 3) -in2 = torch.randn(3, 5) -foo_mlir(in1, in2) diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 5f60a5c28e..684bc7ec15 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -132,20 +132,21 @@ def __call__( """ return self._compile_fx(gm, inputs) - def importer(self, model, data): + def importer(self, model, *args, **kwargs): """ Imports the provided model as MLIR module and flat parameters. Args: model: The model to be imported. - data: The data for the model. + args: Arguments for the model. + kwargs: Keyword arguments for the model. Returns: module: The imported MLIR module. params: The imported flat parameters. """ model_opt = dynamo.optimize(self._compile_fx)(model) - model_opt(*data) + model_opt(*args, **kwargs) module = self._imported_module params = self._imported_params return module, params diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 68c278d1c4..3accb44473 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -18,15 +18,16 @@ # # ===--------------------------------------------------------------------------- +import torch import array -from typing import Dict, List, Union, Tuple +from typing import Dict, List, Tuple, Union import mlir.ir as ir -from mlir.dialects import tosa, tensor -import torch +from mlir.dialects import tensor, tosa def _normalize_binary_operator_shape(shp1, shp2): + """Normalize the shape of two input tensors according to the broadcasting rule""" shp1 = list(shp1) shp2 = list(shp2) while len(shp1) < len(shp2): @@ -38,6 +39,8 @@ def _normalize_binary_operator_shape(shp1, shp2): def _gen_arith_binary_op(input1, input2, op_func): + """Generate arithmetic binary operation. Most binary operations follow the same pattern. + So we can use one function to generate them, avoiding code duplication.""" input1, input2 = _normalize_binary_operator_args(input1, input2) input1_shape = ir.RankedTensorType(input1.type).shape @@ -70,6 +73,8 @@ def _gen_arith_binary_op(input1, input2, op_func): def _scalar_to_tensor( scalar: Union[float, int], element_type: ir.Type, shape: List[int] ): + """PyTorch allow the binary operation between tensor and scalar. But MLIR does not. + So we need to convert scalars to the corresponding tensors.""" element = ( ir.FloatAttr.get(element_type, float(scalar)) if str(element_type) == "f32" @@ -82,6 +87,7 @@ def _scalar_to_tensor( def _normalize_binary_operator_args(arg1, arg2): + """Normalize the types of binary operator arguments.""" if isinstance(arg1, ir.Value) and ( isinstance(arg2, float) or isinstance(arg2, int) ): @@ -118,13 +124,102 @@ def _normalize_binary_operator_args(arg1, arg2): ) +def addmm_op( + node, symbol_table: Dict[Tuple[str, int], ir.Operation] +) -> ir.Operation: + """ + Import matrix multiplication operation. + From PyTorch `aten.addmm.default` operator to MLIR TOSA `matmul` operation. + + Note: this function first reshapes the input matrices to 3D tensors + (since tosa.MatMulOp requires it). Then it multiplies these reshaped matrices. + Finally, it adds the input tensor to the matrix multiplication result. + + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding operations. + + Returns: + op: The operation representing the result of adding the matrix multiplication + to the input tensor. + """ + # get input + input_ = symbol_table.get((str(node.args[0]), 0)) + mat1 = symbol_table.get((str(node.args[1]), 0)) + mat2 = symbol_table.get((str(node.args[2]), 0)) + # get input shape + mat1_shp = ir.RankedTensorType(mat1.type).shape + mat2_shp = ir.RankedTensorType(mat2.type).shape + # append index because tosa.MatMulOp doesn't accept 2D tensor + mat1_reshape_op = tosa.ReshapeOp( + mat1, memoryview(array.array("i", [1, *mat1_shp])) + ) + mat2_reshape_op = tosa.ReshapeOp( + mat2, memoryview(array.array("i", [1, *mat2_shp])) + ) + # do matmul + result_element_type = ir.RankedTensorType(mat1.type).element_type + matmul_result_shp = [1, mat1_shp[0], mat2_shp[1]] + matmul_result_type = ir.RankedTensorType.get( + matmul_result_shp, result_element_type + ) + matmul_op = tosa.MatMulOp( + matmul_result_type, mat1_reshape_op.result, mat2_reshape_op.result + ) + # restore the shape + final_result_shape = [mat1_shp[0], mat2_shp[1]] + matmul_result_reshape_op = tosa.ReshapeOp( + matmul_op.c, memoryview(array.array("i", final_result_shape)) + ) + + op = _gen_arith_binary_op( + input_, matmul_result_reshape_op.result, tosa.AddOp + ) + return op + + +def bmm_op(node, symbol_table) -> ir.Operation: + """ + Import batch matrix multiplication operation. + From PyTorch `aten.bmm.default` operator to MLIR TOSA `matmul` operation. + """ + input_ = symbol_table.get((str(node.args[0]), 0)) + mat2 = symbol_table.get((str(node.args[1]), 0)) + input_shp = ir.RankedTensorType(input_.type).shape + mat2_shp = ir.RankedTensorType(mat2.type).shape + sizes = [input_shp[0], input_shp[1], mat2_shp[2]] + result_element_type = ir.RankedTensorType(input_.type).element_type + result_type = ir.RankedTensorType.get(sizes, result_element_type) + op = tosa.MatMulOp(result_type, input_, mat2) + return op + + def add_op(node, symbol_table): + """ + Import tensor addition operation. + From PyTorch `aten.add.Tensor` operator to MLIR TOSA `add` operation. + """ input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) return _gen_arith_binary_op(input1, input2, tosa.AddOp) +def sub_op(node, symbol_table): + """ + Import tensor subtraction operation. + From PyTorch `aten.sub.Tensor` operator to MLIR TOSA `sub` operation. + """ + input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) + return _gen_arith_binary_op(input1, input2, tosa.SubOp) + + def mul_op(node, symbol_table): + """ + Import tensor multiplication operation. + From PyTorch `aten.mul.Tensor` operator to MLIR TOSA `mul` operation. + """ + def _inner_op(result_type, input1, input2): return tosa.MulOp( result_type, @@ -139,7 +234,578 @@ def _inner_op(result_type, input1, input2): return _gen_arith_binary_op(input1, input2, _inner_op) +def div_op(node, symbol_table): + """ + Import tensor division operation. + From PyTorch `aten.div.Tensor` operator to MLIR TOSA `div` operation. + """ + + def _inner_op(result_type, input1, input2): + return tosa.MulOp( + result_type, + input1, + tosa.ReciprocalOp(input2.type, input2).result, + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) + + return _gen_arith_binary_op(input1, input2, _inner_op) + + +def tanh_op(node, symbol_table): + """ + Import elementwise tanh operation. + From PyTorch `aten.tanh.default` operator to MLIR TOSA `tanh` operation. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input1.type).shape + result_element_type = ir.RankedTensorType(input1.type).element_type + tanhResultTensorType = ir.RankedTensorType.get(sizes, result_element_type) + op = tosa.TanhOp(tanhResultTensorType, input1) + return op + + +def exp_op(node, symbol_table): + """ + Import elementwise exponential operation. + From PyTorch `aten.exp.default` operator to MLIR TOSA `exp` operation. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input1.type).shape + result_element_type = ir.RankedTensorType(input1.type).element_type + expResultTensorType = ir.RankedTensorType.get(sizes, result_element_type) + op = tosa.ExpOp(expResultTensorType, input1) + return op + + +def rsqrt_op(node, symbol_table): + """ + Import elementwise reciprocal square root operation. + From PyTorch `aten.rsqrt.default` operator to MLIR TOSA `rsqrt` operation. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input1.type).shape + result_element_type = ir.RankedTensorType(input1.type).element_type + rsqrt_result_tensor_type = ir.RankedTensorType.get( + sizes, result_element_type + ) + op = tosa.RsqrtOp(rsqrt_result_tensor_type, input1) + return op + + +def amax_op(node, symbol_table): + """ + Import the amax operation. + From PyTorch `aten.amax.default` operator to MLIR TOSA `reduce_max` operation. + + Note: This conversion function returns the maximum value of each slice + of the input tensor in the given dimension(s). This is consistent + with PyTorch's `torch.amax` operator. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + dim_val = node.args[1][0] + if dim_val < 0: + dim_val += len(ir.RankedTensorType(input1.type).shape) + signless_type = ir.IntegerType.get_signless(64) + dim_attr = ir.IntegerAttr.get(signless_type, dim_val) + op = tosa.ReduceMaxOp(input1, dim_attr) + return op + + +def reshape_op(node, symbol_table): + """ + Import the reshape operation. + From PyTorch `aten.reshape.default` operator to MLIR TOSA `reshape` operation. + + Note: If the new shape contains one and only one `-1`, the size of the new shape will be inferred automatically. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + new_shape = node.args[1] + total_size = 1 + now_shape = ir.RankedTensorType(input1.type).shape + for dim_siz in now_shape: + total_size *= dim_siz + + neg_one_cnt = 0 + rest_size = 1 + for dim_siz in new_shape: + if dim_siz == -1: + neg_one_cnt += 1 + continue + rest_size *= dim_siz + + if neg_one_cnt != 0: + if neg_one_cnt > 1 or total_size % rest_size != 0: + raise ValueError("Can not infer the new shape!") + infer_dim_size = total_size // rest_size + for i, _ in enumerate(new_shape): + if new_shape[i] == -1: + new_shape[i] = infer_dim_size + + new_shape_content = array.array("i", new_shape) + new_shape_content = memoryview(new_shape_content) + op = tosa.ReshapeOp(input1, new_shape_content) + + return op + + +def unsqueeze_op(node, symbol_table): + """ + Import the unsqueeze operation. + From PyTorch `aten.unsqueeze.default` operator to MLIR TOSA `reshape` operation. + + Note: "unsqueeze" means inserting a new dimension of size 1 at the specified + position. For more information, please refer to + https://pytorch.org/docs/stable/generated/torch.unsqueeze.html + """ + input_tensor = symbol_table.get((str(node.args[0]), 0)) + dim = node.args[1] + sizes = ir.RankedTensorType(input_tensor.type).shape + sizes.insert(dim, 1) + new_shape_content = array.array("i", sizes) + new_shape_content = memoryview(new_shape_content) + op = tosa.ReshapeOp(input_tensor, new_shape_content) + return op + + +def select_op(node, symbol_table): + """ + Import the select operation. + From PyTorch `aten.select.int` operator to MLIR TOSA `reshape` operation. + + Note: "select" means slicing the input tensor along the selected dimension at + the given index. For more information, please refer to + https://pytorch.org/docs/stable/generated/torch.select.html + """ + input_tensor = symbol_table.get((str(node.args[0]), 0)) + dim = node.args[1] + index = node.args[2] + + sizes = ir.RankedTensorType(input_tensor.type).shape + + new_sizes = sizes[:dim] + [1] + sizes[dim + 1 :] + new_sizes_attr = ir._denseI64ArrayAttr(new_sizes, None) + + start = [0] * len(sizes) + start[dim] = index + start_attr = ir._denseI64ArrayAttr(start, None) + + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + output_type = ir.RankedTensorType.get(new_sizes, result_element_type) + op = tosa.SliceOp(output_type, input_tensor, start_attr, new_sizes_attr) + + reshape_sizes = sizes[:dim] + sizes[dim + 1 :] + reshape_sizes_content = array.array("i", reshape_sizes) + reshape_sizes_content = memoryview(reshape_sizes_content) + op = tosa.ReshapeOp(op.results[0], reshape_sizes_content) + + return op + + +def slice_op(node, symbol_table): + """ + Import the slice operation. + From PyTorch `aten.slice.Tensor` operator to MLIR tensor `extract_slice` operation. + + Note: "slice" means slicing the input tensor along the selected dimension from a + given start index to an end index. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0)) + dim = node.args[1] + start_idx = node.args[2] + end_idx = node.args[3] + + sizes = ir.RankedTensorType(input_tensor.type).shape + + if start_idx < 0: + start_idx += sizes[dim] + + if end_idx < 0: + end_idx += sizes[dim] + + if start_idx < 0: + start_idx = 0 + elif start_idx >= sizes[dim]: + start_idx = sizes[dim] + + if end_idx < start_idx: + end_idx = start_idx + elif end_idx >= sizes[dim]: + end_idx = sizes[dim] + + new_sizes = [x for x in sizes] + new_sizes[dim] = end_idx - start_idx + new_sizes_attr = ir._denseI64ArrayAttr(new_sizes, None) + + offsets = [0] * len(sizes) + offsets[dim] = start_idx + offsets_attr = ir._denseI64ArrayAttr(offsets, None) + + strides = [1] * len(sizes) + strides_attr = ir._denseI64ArrayAttr(strides, None) + + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + extract_slice_result_type = ir.RankedTensorType.get( + new_sizes, result_element_type + ) + op = tensor.ExtractSliceOp( + extract_slice_result_type, + input_tensor, + [], + [], + [], + offsets_attr, + new_sizes_attr, + strides_attr, + ) + + return op + + +def convert_element_type_op(node, symbol_table): + """ + Import the element type conversion operation. + From PyTorch `prims.convert_element_type.default` operator to + MLIR TOSA `cast` operation. + """ + # maintain a mapping of torch types and mlir types + types_mapping = { + torch.float64: ir.F64Type.get(), + torch.float32: ir.F32Type.get(), + torch.float16: ir.F16Type.get(), + } + input_tensor = symbol_table.get((str(node.args[0]), 0)) + to_cast_type = types_mapping[node.args[1]] + sizes = ir.RankedTensorType(input_tensor.type).shape + output_type = ir.RankedTensorType.get(sizes, to_cast_type) + return tosa.CastOp(output_type, input_tensor) + + +def clone_op(node, symbol_table): + """ + Import the clone operation. + From PyTorch `aten.clone.default` operator to MLIR TOSA `identity` operation. + + Note: Since MLIR follow the SSA form, when using the `identity` operation, we + actually deep-copies the original tensor. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0)) + sizes = ir.RankedTensorType(input_tensor.type).shape + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + output_type = ir.RankedTensorType.get(sizes, result_element_type) + + return tosa.IdentityOp(output_type, input_tensor) + + +def var_mean_op(node, symbol_table): + """ + Import the variance & mean operation. + From PyTorch `aten.var_mean.default` operator to two MLIR TOSA `mul` operation. + + Note: The conversion procedure can be splited into two steps: + 1. In the first part, we calculate the mean value along the given dimension(s) + in `mean_dim_op` function. We first reduce the input tensor along the given + dimension(s) using tosa's `reduce_sum` operation. Then we calculate the mean + value by multiplying the reciprocal of the total size of the reduced dimension(s). + 2. In the second part, we calculate the variance value. We follow the formula in + this link: https://pytorch.org/docs/stable/generated/torch.var_mean.html. We first + calculate (\bar{x} - x_i), where \bar{x} is the mean value we calculated in the first + step. By applying tosa's `mul` operation, we get (\bar{x} - x_i) ^ 2. Then we reduce + the multiplication result to get \sum_{i=0}^{N}(\bar{x} - x_i) ^ 2. Finally, we divide + the reduction sum result by the total size of the reduced dimension(s) minus the + correction. + + `keepdim` argument is supported. It's handled by the applying a `reshape` operation. + + """ + + def mean_dim_op(_input_tensor: ir.Value, _dim) -> ir.Operation: + if isinstance(_dim, int): + _dim = [_dim] + + # `_input_tensor` is the first tensor we need to reduce + reduce_sum_result = _input_tensor + + # reduce along each dimension in `_dim` + for _dim_item in _dim: + reduce_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), _dim_item + ) + reduce_sum_op: ir.Operation = tosa.ReduceSumOp( + reduce_sum_result, reduce_dim_attr + ) + # Next reduction is executed based on this time's reduction result + reduce_sum_result = reduce_sum_op.results[0] + + tensor_shp = ir.RankedTensorType(_input_tensor.type).shape + dim_size = 1 + # calculate the total size on all reduction dimensions to get the denominator + for _dim_item in _dim: + dim_size *= tensor_shp[_dim_item] + + denominator_const_op: ir.Operation = tosa.ConstOp( + ir.DenseElementsAttr.get(memoryview(array.array("f", [dim_size]))) + ) + + reciprocal_op: ir.Operation = tosa.ReciprocalOp( + denominator_const_op.results[0].type, + denominator_const_op.results[0], + ) + + return tosa.MulOp( + reduce_sum_op.results[0].type, + reciprocal_op.results[0], + reduce_sum_op.results[0], + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + def var_dim_op( + _input_tensor: ir.Value, _mean_tensor: ir.Value, _dim, _correction + ) -> ir.Operation: + if isinstance(_dim, int): + _dim = [_dim] + # get (\bar{x} - x_i) + sub_op: ir.Operation = tosa.SubOp( + _input_tensor.type, _input_tensor, _mean_tensor + ) + + # get (\bar{x} - x_i) ^ 2 + mul_op: ir.Operation = tosa.MulOp( + _input_tensor.type, + sub_op.results[0], + sub_op.results[0], + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + # the result of `mul_op` is the first tensor we need to reduce + reduce_sum_op = mul_op + for _dim_item in _dim: + reduce_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), _dim_item + ) + reduce_sum_op: ir.Operation = tosa.ReduceSumOp( + reduce_sum_op.results[0], reduce_dim_attr + ) + + tensor_shp = ir.RankedTensorType(_input_tensor.type).shape + dim_size = 1 + # calculate the denominator + for _dim_item in _dim: + dim_size *= tensor_shp[_dim_item] + biased_denominator_const_op: ir.Operation = tosa.ConstOp( + ir.DenseElementsAttr.get( + memoryview(array.array("f", [dim_size - _correction])) + ) + ) + reciprocal_op: ir.Operation = tosa.ReciprocalOp( + biased_denominator_const_op.results[0].type, + biased_denominator_const_op.results[0], + ) + + return tosa.MulOp( + reduce_sum_op.results[0].type, + reciprocal_op.results[0], + reduce_sum_op.results[0], + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0), + ) + + mean_input_tensor = symbol_table.get((str(node.args[0]), 0)) + var_input_tensor = symbol_table.get((str(node.args[0]), 0)) + + kwargs = node.kwargs + keepdim = kwargs.get("keepdim", False) + correction = kwargs.get("correction", 1.0) + + mean_op = None + var_op = None + if len(node.args) == 1: + calc_dims = range( + len(ir.RankedTensorType(mean_input_tensor.type).shape) + ) + else: + calc_dims = node.args[1] + + mean_op = mean_dim_op(mean_input_tensor, calc_dims) + var_op = var_dim_op( + var_input_tensor, mean_op.results[0], calc_dims, correction + ) + mean_input_tensor = mean_op.results[0] + var_input_tensor = var_op.results[0] + + if not keepdim: + result_shp = ir.RankedTensorType(var_op.results[0].type).shape + result_shp = [siz for siz in result_shp if siz != 1] + var_op = tosa.ReshapeOp( + var_op.results[0], memoryview(array.array("i", result_shp)) + ) + mean_op = tosa.ReshapeOp( + mean_op.results[0], memoryview(array.array("i", result_shp)) + ) + + return var_op, mean_op + + +def permute_op(node, symbol_table): + """ + Import the permute operation. + From PyTorch `aten.permute.default` operator to MLIR TOSA `transpose` operation. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0)) + perm = node.args[1] + perm_const_op = tosa.ConstOp( + ir.DenseElementsAttr.get(memoryview(array.array("i", perm))) + ) + result_element_type = ir.RankedTensorType(input_tensor.type).element_type + init_shape = ir.RankedTensorType(input_tensor.type).shape + new_shape = [] + for perm_item in perm: + new_shape.append(init_shape[perm_item]) + + permute_result_type = ir.RankedTensorType.get( + new_shape, result_element_type + ) + permute_op = tosa.TransposeOp( + permute_result_type, input_tensor, perm_const_op.results[0] + ) + return permute_op + + +def embedding_op(node, symbol_table): + """ + Import the embedding operation. + From PyTorch `aten.embedding.default` operator to MLIR TOSA `reshape` operation. + + Note: Althought this conversion function will finally return a `reshape` operation, + the core is the `gather` operation. It can generate a tensor for which each + element in the output is a slice of the values tensor based on the value of + indices. In this case, we use `gather` to extract elements from the weight + tensor based on the `indices` argument. + """ + indices = symbol_table.get((str(node.args[1]), 0)) + weight = symbol_table.get((str(node.args[0]), 0)) + + indices_size = ir.RankedTensorType(indices.type).shape + weight_size = ir.RankedTensorType(weight.type).shape + result_element_type = ir.RankedTensorType(weight.type).element_type + assert len(indices_size) == 2 + + if indices_size[0] != 1: + total_size = 1 + for x in indices_size: + total_size *= x + indices_reshape_op = tosa.ReshapeOp( + indices, memoryview(array.array("i", [1, total_size])) + ) + indices = indices_reshape_op.result + gather_result_type = ir.RankedTensorType.get( + [1, total_size, weight_size[1]], result_element_type + ) + else: + gather_result_type = ir.RankedTensorType.get( + [*indices_size, weight_size[1]], result_element_type + ) + + # tosa.gather doesn't support i64, so we need to cast it to i32 + if str(ir.RankedTensorType(indices.type).element_type) != "i32": + indices = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(indices.type).shape, + ir.IntegerType.get_signless(32), + ), + indices, + ) + + weight_reshape_op = tosa.ReshapeOp( + weight, memoryview(array.array("i", [1, *weight_size])) + ) + + gather_op = tosa.GatherOp( + gather_result_type, weight_reshape_op.result, indices + ) + op = tosa.ReshapeOp( + gather_op.output, + memoryview(array.array("i", [*indices_size, weight_size[1]])), + ) + + return op + + +def expand_op(node, symbol_table) -> ir.Operation: + """ + Import the expand operation. + From PyTorch `aten.expand.default` operator to MLIR TOSA `add` operation. + + Note: This conversion is implemented using the broadcast machanism of TOSA + `add` operation. We allocate a tensor with the shape to expand and + elements in this tensor is all zero. Then we add the original tensor + to this all-zero tensor. After the applying the broadcasting, we get + the result. + """ + to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) + new_size = node.args[1] + result_element_type = ir.RankedTensorType( + to_expand_tensor.type + ).element_type + element = ir.FloatAttr.get(result_element_type, 0.0) + new_size_tensor_type = ir.RankedTensorType.get( + new_size, result_element_type + ) + new_size_attr = ir.DenseElementsAttr.get_splat( + new_size_tensor_type, element + ) + new_size_tensor = tosa.ConstOp(new_size_attr).results[0] + op = _gen_arith_binary_op(to_expand_tensor, new_size_tensor, tosa.AddOp) + return op + + +def sum_op(node, symbol_table): + """ + Import the sum operation. + From PyTorch `aten.sum.dim_IntList` operator to MLIR TOSA `reduce_sum` operation. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0)) + reduce_sum_dims = node.args[1] + dim_cnt = len(ir.RankedTensorType(input_tensor.type).shape) + reduce_sum_dims = [ + dim if dim >= 0 else dim_cnt + dim for dim in reduce_sum_dims + ] + _reduce_sum_input_tensor = input_tensor + reduce_sum_op = None + for dim in reduce_sum_dims: + reduce_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), dim + ) + reduce_sum_op = tosa.ReduceSumOp( + _reduce_sum_input_tensor, reduce_dim_attr + ) + _reduce_sum_input_tensor = reduce_sum_op.results[0] + + return reduce_sum_op + + ops_registry = { "add.Tensor": add_op, "mul.Tensor": mul_op, + "sub.Tensor": sub_op, + "sum.dim_IntList": sum_op, + "tanh.default": tanh_op, + "amax.default": amax_op, + "rsqrt.default": rsqrt_op, + "bmm.default": bmm_op, + "clone.default": clone_op, + "div.Tensor": div_op, + "exp.default": exp_op, + "expand.default": expand_op, + "var_mean.correction": var_mean_op, + "addmm.default": addmm_op, + "reshape.default": reshape_op, + "view.default": reshape_op, + "select.int": select_op, + "slice.Tensor": slice_op, + "embedding.default": embedding_op, + "convert_element_type.default": convert_element_type_op, + "permute.default": permute_op, + "unsqueeze.default": unsqueeze_op, } diff --git a/tests/Python/test_addmm.py b/tests/Python/test_addmm.py new file mode 100644 index 0000000000..75b33a78d5 --- /dev/null +++ b/tests/Python/test_addmm.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y, z): + return torch.ops.aten.addmm(z, x, y) + + +in1 = torch.randn(4, 2) +in2 = torch.randn(2, 4) +in3 = torch.randn(4, 4) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, in2, in3) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.matmul" +# CHECK: %{{.*}} = "tosa.add" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_amax.py b/tests/Python/test_amax.py new file mode 100644 index 0000000000..ebcb440463 --- /dev/null +++ b/tests/Python/test_amax.py @@ -0,0 +1,34 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import random +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, dim): + return torch.ops.aten.amax(x, dim, True) + + +in1 = torch.randn(4, 5, 2, 9) +dim = [random.randint(0, 3)] + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, dim) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reduce_max" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_arith_div.py b/tests/Python/test_arith_div.py new file mode 100644 index 0000000000..6e2d4f5c09 --- /dev/null +++ b/tests/Python/test_arith_div.py @@ -0,0 +1,34 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return x / y + + +in1 = torch.randn(10) +in2 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, in2) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reciprocal" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_arith_mul.py b/tests/Python/test_arith_mul.py new file mode 100644 index 0000000000..b07160117c --- /dev/null +++ b/tests/Python/test_arith_mul.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return x * y + + +in1 = torch.randn(10) +in2 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, in2) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_arith_sub.py b/tests/Python/test_arith_sub.py new file mode 100644 index 0000000000..92dcfdf954 --- /dev/null +++ b/tests/Python/test_arith_sub.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return x - y + + +in1 = torch.randn(10) +in2 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, in2) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.sub" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_bmm.py b/tests/Python/test_bmm.py new file mode 100644 index 0000000000..e1f9010147 --- /dev/null +++ b/tests/Python/test_bmm.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return torch.ops.aten.bmm(x, y) + + +in1 = torch.randn(10, 3, 2) +in2 = torch.randn(10, 2, 3) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, in2) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.matmul" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_clone.py b/tests/Python/test_clone.py new file mode 100644 index 0000000000..732b30cd33 --- /dev/null +++ b/tests/Python/test_clone.py @@ -0,0 +1,32 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x): + return torch.ops.aten.clone(x) + + +in1 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.identity" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_convert_element_type.py b/tests/Python/test_convert_element_type.py new file mode 100644 index 0000000000..d55d7b7553 --- /dev/null +++ b/tests/Python/test_convert_element_type.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, to_cast_type): + return torch.ops.prims.convert_element_type(x, to_cast_type) + + +in1 = torch.randn(10).to(torch.float32) +to_cast_type = torch.float16 + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, to_cast_type) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.cast" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_embedding.py b/tests/Python/test_embedding.py new file mode 100644 index 0000000000..b00c218e14 --- /dev/null +++ b/tests/Python/test_embedding.py @@ -0,0 +1,58 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(weight, indices): + return torch.ops.aten.embedding(weight, indices) + + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +# test trivial case +weight = torch.randn(10, 5) +indices = torch.randint(10, (3, 3)).to(torch.int32) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(weight, indices) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: %{{.*}} = "tosa.gather" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) + + +# test cast case +weight = torch.randn(10, 5) +indices = torch.randint(10, (3, 3)).to(torch.int64) + + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(weight, indices) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: %{{.*}} = "tosa.cast" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: %{{.*}} = "tosa.gather" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_exp.py b/tests/Python/test_exp.py new file mode 100644 index 0000000000..ceb286f849 --- /dev/null +++ b/tests/Python/test_exp.py @@ -0,0 +1,32 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x): + return torch.ops.aten.exp(x) + + +in1 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.exp" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_expand.py b/tests/Python/test_expand.py new file mode 100644 index 0000000000..3a6b9a5913 --- /dev/null +++ b/tests/Python/test_expand.py @@ -0,0 +1,32 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, new_size): + return torch.ops.aten.expand(x, new_size) + +x = torch.randn(1, 3) +new_size = (6, 3) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, new_size) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.add" +# CHECK: return %{{.*}} : tensor<6x3xf32> +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_permute.py b/tests/Python/test_permute.py new file mode 100644 index 0000000000..8855fb4995 --- /dev/null +++ b/tests/Python/test_permute.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return torch.ops.aten.permute(x, y) + + +x = torch.randn(3, 2, 4) +perm = (2, 0, 1) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, perm) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.transpose" +# CHECK: return %{{.*}} : tensor<4x3x2xf32> +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_reshape.py b/tests/Python/test_reshape.py new file mode 100644 index 0000000000..213da8a9a3 --- /dev/null +++ b/tests/Python/test_reshape.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, new_shape): + return torch.ops.aten.reshape(x, new_shape) + + +x = torch.randn(2, 3) +new_shape = (3, 2) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, new_shape) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_rsqrt.py b/tests/Python/test_rsqrt.py new file mode 100644 index 0000000000..cfcdbdee20 --- /dev/null +++ b/tests/Python/test_rsqrt.py @@ -0,0 +1,32 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x): + return torch.ops.aten.rsqrt(x) + + +x = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.rsqrt" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_select.py b/tests/Python/test_select.py new file mode 100644 index 0000000000..41a51b5752 --- /dev/null +++ b/tests/Python/test_select.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, dim, index): + return torch.ops.aten.select(x, dim, index) + + +x = torch.randn(3, 5, 2) +dim = 1 +index = 2 + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, dim, index) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.slice" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: return %{{.*}} : tensor<3x2xf32> +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_slice.py b/tests/Python/test_slice.py new file mode 100644 index 0000000000..61a8658e1b --- /dev/null +++ b/tests/Python/test_slice.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, dim, start_idx, end_idx): + return torch.ops.aten.slice(x, dim, start_idx, end_idx) + + +x = torch.randn(3, 5, 2) +dim = 1 +start_idx = 1 +end_idx = 3 + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, dim, start_idx, end_idx) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = tensor.extract_slice +# CHECK: return %{{.*}} : tensor<3x2x2xf32> +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_sum.py b/tests/Python/test_sum.py new file mode 100644 index 0000000000..1a744a6b9e --- /dev/null +++ b/tests/Python/test_sum.py @@ -0,0 +1,34 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import random +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, dim): + return torch.ops.aten.sum(x, dim, True) + + +x = torch.randn(4, 5, 2, 9) +dim = [random.randint(0, 3)] + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, dim) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reduce_sum" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_tanh.py b/tests/Python/test_tanh.py new file mode 100644 index 0000000000..7866e83542 --- /dev/null +++ b/tests/Python/test_tanh.py @@ -0,0 +1,32 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x): + return torch.ops.aten.tanh(x) + + +x = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.tanh" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_unsqueeze.py b/tests/Python/test_unsqueeze.py new file mode 100644 index 0000000000..d1d9c2fca2 --- /dev/null +++ b/tests/Python/test_unsqueeze.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, dim): + return torch.ops.aten.unsqueeze(x, dim) + + +x = torch.randn(10) +dim = 0 + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x, dim) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: return %{{.*}} : tensor<1x10xf32> +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_var_mean.py b/tests/Python/test_var_mean.py new file mode 100644 index 0000000000..8319817ebf --- /dev/null +++ b/tests/Python/test_var_mean.py @@ -0,0 +1,66 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x): + return torch.ops.aten.var_mean(x) + + +def foo_keepdim(x): + return torch.ops.aten.var_mean(x, keepdim=True) + + +x = torch.randn(10, 2, 4) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(x) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reduce_sum" +# CHECK: %{{.*}} = "tosa.const" +# CHECK: %{{.*}} = "tosa.reciprocal" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: %{{.*}} = "tosa.sub" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: %{{.*}} = "tosa.reduce_sum" +# CHECK: %{{.*}} = "tosa.const" +# CHECK: %{{.*}} = "tosa.reciprocal" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: %{{.*}} = "tosa.reshape" +# CHECK: return %{{.*}} : tensor, tensor +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) + +foo_keepdim_mlir = dynamo.optimize(dynamo_compiler)(foo_keepdim) +foo_keepdim_mlir(x) +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.reduce_sum" +# CHECK: %{{.*}} = "tosa.const" +# CHECK: %{{.*}} = "tosa.reciprocal" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: %{{.*}} = "tosa.sub" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: %{{.*}} = "tosa.reduce_sum" +# CHECK: %{{.*}} = "tosa.const" +# CHECK: %{{.*}} = "tosa.reciprocal" +# CHECK: %{{.*}} = "tosa.mul" +# CHECK: return %{{.*}} : tensor<1x1x1xf32>, tensor<1x1x1xf32> +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module) diff --git a/tests/Python/test_view.py b/tests/Python/test_view.py new file mode 100644 index 0000000000..abd905d483 --- /dev/null +++ b/tests/Python/test_view.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return x + y + + +in1 = torch.randn(10) +in2 = torch.randn(10) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = dynamo.optimize(dynamo_compiler)(foo) +foo_mlir(in1, in2) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.add" +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } +print(dynamo_compiler.imported_module)