diff --git a/examples/MLIRPython/.style.yapf b/examples/MLIRPython/.style.yapf
deleted file mode 100644
index 9ef1dc15ba..0000000000
--- a/examples/MLIRPython/.style.yapf
+++ /dev/null
@@ -1,4 +0,0 @@
-[style]
-based_on_style = google
-column_limit = 80
-indent_width = 2
diff --git a/examples/MLIRPython/addmm.py b/examples/MLIRPython/addmm.py
index 29cf695635..e31edb7de0 100644
--- a/examples/MLIRPython/addmm.py
+++ b/examples/MLIRPython/addmm.py
@@ -1,13 +1,13 @@
-from buddy.compiler import DynamoCompiler
+from buddy.compiler import dynamo_compiler
 import torch
 import torch._dynamo as dynamo
 
 
 def foo(c, a, b):
-  return torch.addmm(c, a, b)
+    return torch.addmm(c, a, b)
 
 
-foo_mlir = dynamo.optimize(DynamoCompiler)(foo)
+foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
 a_float32 = torch.randn(3, 2)
 b_float32 = torch.randn(2, 3)
diff --git a/examples/MLIRPython/arith_add.py b/examples/MLIRPython/arith_add.py
index 0c2b7a11cf..0a3679bab2 100644
--- a/examples/MLIRPython/arith_add.py
+++ b/examples/MLIRPython/arith_add.py
@@ -4,10 +4,10 @@
 
 
 def foo(x, y):
-  return x + y
+    return x + y
 
 
-foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo)
+foo_mlir = dynamo.optimize(compiler.dynamo_compiler)(foo)
 float32_in1 = torch.randn(10).to(torch.float32)
 float32_in2 = torch.randn(10).to(torch.float32)
 foo_mlir(float32_in1, float32_in2)
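Note: the example scripts above change only two things: the backend is renamed from `DynamoCompiler` to `dynamo_compiler` (PEP 8 function naming, since it is a function, not a class), and the bodies are re-indented from the removed yapf style (2 spaces) to 4 spaces. A minimal sketch of how these scripts drive the backend; the tensor shapes here are illustrative, not taken from the diff:

import torch
import torch._dynamo as dynamo
from buddy.compiler import dynamo_compiler

def foo(x, y):
    return x + y

# dynamo traces foo into an FX graph and hands it to dynamo_compiler,
# which imports the graph into MLIR and lowers it to the LLVM dialect.
foo_mlir = dynamo.optimize(dynamo_compiler)(foo)
foo_mlir(torch.randn(10), torch.randn(10))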
diff --git a/examples/MLIRPython/buddy/compiler.py b/examples/MLIRPython/buddy/compiler.py
index 3f66d04f99..f7effe777b 100644
--- a/examples/MLIRPython/buddy/compiler.py
+++ b/examples/MLIRPython/buddy/compiler.py
@@ -1,211 +1,211 @@
-"""The buddy compiler backend for torch dynamo.
-"""
-import json
+# ===- frontend.py -------------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This is the entry of the Buddy Compiler frontend.
+#
+# ===---------------------------------------------------------------------------
+
 import operator
 from typing import Callable, List, Union
 
 import mlir.dialects.func as func
 import mlir.ir as ir
 import torch
-from iree import compiler as ireec
-from iree import runtime as ireert
 from mlir.passmanager import PassManager
 from torch._functorch.aot_autograd import aot_module_simplified
 from torch._inductor.decomposition import decompositions as inductor_decomp
 
-from .operators_gen import operation_func
-
-
-def DynamoCompiler(gm: torch.fx.GraphModule,
-                   inputs: List[torch.Tensor]) -> Callable:
-  """The main entry point of buddy compiler for torch dynamo. It takes a FX
-  graph module and a list of inputs as parameters. The compiler will first use
-  PyTorch's AOT autograd to lower FX graph in Torch IR to Aten/Prims IR. Then
-  it will map the operators in Aten/Prims IR to MLIR operations and generate an
-  MLIR module. Finally, It will lower the MLIR module to LLVM dialect.
-
-  Args:
-    gm (torch.fx.GraphModule): The FX graph module to be compiled.
-    inputs (List[torch.Tensor]): The inputs of the FX graph module.
-
-  Returns:
-    Callable: A compiled function that equivalent to the FX graph.
-
-  """
-
-  def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
-    """Compile a FX graph in Aten/Prims IR to MLIR."""
-    # print("Custom Compiler from FX Graph to MLIR:")
-    # print("-------------------------------------------------------------------")
-    # gm.graph.print_tabular()
-    # Initialize the MLIR context.
-    ctx = ir.Context()
-    with ir.Location.unknown(ctx):
-      fx_importer = FXGraphImporter(gm, inputs)
-      module = fx_importer.import_graph()
-      module = Lowering(module)
-
-    return gm.forward
-
-  # compiled_flatbuffer = ireec.compile_str(str(module), target_backends=["vmvx"])
-  # runtime_config = ireert.Config("local-task")
-  # ctx = ireert.SystemContext(config=runtime_config)
-  # vm_module = ireert.VmModule.copy_buffer(ctx.instance, compiled_flatbuffer)
-  # ctx.add_vm_module(vm_module)
-
-  # return lambda *args: ctx.modules.module["main"](*args).to_host()
-
-  args_dict = {
-      **gm.state_dict(), "input_ids": inputs[0],
-      "token_type_ids": inputs[1],
-      "attention_mask": inputs[2]
-  }
-  with open("bert_parameters.bin", "wb") as args_f, \
-       open("bert_parameters_shape.txt", "w+") as shape_f, \
-       open("bert_parameters_dtype.txt", "w+") as dtype_f:
-    for value in args_dict.values():
-      dtype_f.write(f"{value.dtype}\n")
-      shape_f.write(" ".join([str(dim) for dim in value.shape]) + "\n")
-      args_f.write(value.numpy().tobytes())
-
-  return aot_module_simplified(gm,
-                               inputs,
-                               fw_compiler=_compiler,
-                               decompositions=inductor_decomp.copy())
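Note: the removed entry point hard-coded a BERT-specific parameter dump (bert_parameters.bin and the companion shape/dtype text files) and carried dead, commented-out IREE compile-and-run code; the rewrite drops both, along with the now-unused `iree` and `json` imports. A minimal sketch of the aot_autograd plumbing that the new entry point keeps, using only the imports already present in this file:

from typing import Callable, List

import torch
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.decomposition import decompositions as inductor_decomp

def inspect_backend(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]) -> Callable:
    # aot_module_simplified re-traces the Torch-IR FX graph into Aten/Prims IR
    # (applying the inductor decompositions) before fw_compiler ever sees it.
    def fw_compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        gm.graph.print_tabular()  # the Aten/Prims-level graph
        return gm.forward  # run the original graph unchanged
    return aot_module_simplified(
        gm, inputs, fw_compiler=fw_compiler, decompositions=inductor_decomp.copy()
    )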
+from buddy.operators_gen import operation_func
 
-class FXGraphImporter:
-  """The FX graph importer class."""
 
-  def __init__(
-      self,
-      gm: torch.fx.GraphModule,
-      inputs: List[torch.Tensor],
-      func_name: str = "main",
-  ):
-    """
+def dynamo_compiler(
+    gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
+) -> Callable:
+    """The main entry point of buddy compiler for torch dynamo. It takes an FX
+    graph module and a list of inputs as parameters. The compiler will first
+    use PyTorch's AOT autograd to lower the FX graph in Torch IR to Aten/Prims
+    IR. Then it will map the operators in Aten/Prims IR to MLIR operations and
+    generate an MLIR module. Finally, it will lower the MLIR module to LLVM
+    dialect.
+
     Args:
-      gm (torch.fx.GraphModule): The FX graph module that will be imported.
-      inputs (List[torch.Tensor]): Input tensor(s) of the FX graph.
-      func_name (str): Name of the generated MLIR func.
+        gm (torch.fx.GraphModule): The FX graph module to be compiled.
+        inputs (List[torch.Tensor]): The inputs of the FX graph module.
+
+    Returns:
+        Callable: A compiled function that is equivalent to the FX graph.
 
     """
-    self._symbol_table = {}
-    self._gm = gm
-    self._func_name = func_name
-    self._inputs = inputs
-    self._num_input_visited = 0
-    self._module = ir.Module.create()
 
-  def import_graph(self) -> ir.Module:
-    """Import the FX graph, generate an MLIR module in high-level dialects.
+    def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
+        """Compile an FX graph in Aten/Prims IR to MLIR."""
+        print("Custom Compiler from FX Graph to MLIR:")
+        print(
+            "-------------------------------------------------------------------"
+        )
+        gm.graph.print_tabular()
+        # Initialize the MLIR context.
+        ctx = ir.Context()
+        with ir.Location.unknown(ctx):
+            fx_importer = FXGraphImporter(gm, inputs)
+            module = fx_importer.import_graph()
+            module = Lowering(module)
 
-    Returns:
-      mlir.ir.Module: An MLIR module in high-level dialects.
+        return gm.forward
 
-    """
-    with ir.InsertionPoint(self._module.body):
-      arguments = []
-      for arg in self._inputs:
-        shape_list = list(arg.shape)
-        dtype = arg.dtype
-        match dtype:
-          case torch.int32:
-            mlir_dtype = ir.IntegerType.get_signless(32)
-          case torch.int64:
-            mlir_dtype = ir.IntegerType.get_signless(64)
-          case torch.float32:
-            mlir_dtype = ir.F32Type.get()
-          case _:
-            raise NotImplementedError(
-                f"Unsupported dtype {dtype} for argument {arg}")
-        tensor_arg = ir.RankedTensorType.get(shape_list, mlir_dtype)
-        arguments.append(tensor_arg)
-
-      @func.FuncOp.from_py_func(*arguments, name=self._func_name)
-      def generated_func(*args):
-        args_list = list(args)
-        for node in self._gm.graph.nodes:
-          if node.op == "output":
-            output_node_args = node.args[0]
-            returns = []
-            for output_arg in output_node_args:
-              op = self._symbol_table.get((str(output_arg), 0))
-              returns.append(op)
-
-            self._symbol_table[("output", 0)] = returns
-          elif node.op == "placeholder":
-            self._import_placeholder(node, args_list)
-          else:
-            if node.target is operator.getitem:
-              self._symbol_table[(str(node.name),
-                                  0)] = self._symbol_table[(str(node.args[0]),
-                                                            node.args[1])]
-            else:
-              self._import_op(node)
-
-        return self._symbol_table.get(("output", 0))
-
-    # print("Printing the generated MLIR...")
-    # print(self._module)
-    return self._module
-
-  def _import_placeholder(self, node: torch.fx.Node, args_list):
-    placeholder_name = args_list[self._num_input_visited]
-    self._symbol_table[(str(node.name), 0)] = placeholder_name
-    self._num_input_visited += 1
-
-  def _import_op(self, node: torch.fx.Node):
-    op_name = node.target.__name__
-
-    op_ret: Union[ir.Operation,
-                  tuple] = operation_func[op_name](node, self._symbol_table)
-    if isinstance(op_ret, tuple):
-      for i, operation in enumerate(op_ret):
-        self._symbol_table[(str(node.name), i)] = operation.result
-    else:
-      self._symbol_table[(str(node.name), 0)] = op_ret.result
+    return aot_module_simplified(
+        gm, inputs, fw_compiler=_compiler, decompositions=inductor_decomp.copy()
+    )
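Note: the importer keys its symbol table by (node_name, result_index) so that multi-result ops can be addressed; `operator.getitem` nodes therefore become pure table lookups instead of MLIR ops. A hypothetical sketch of what one entry of `buddy.operators_gen.operation_func` might look like; the generated table itself is not part of this diff, and the handler name and TOSA choice are assumptions:

import mlir.dialects.tosa as tosa

def add_op(node, symbol_table):
    # Fetch the MLIR values produced for the node's two operands.
    lhs = symbol_table[(str(node.args[0]), 0)]
    rhs = symbol_table[(str(node.args[1]), 0)]
    # Emit a TOSA add; a real handler would also compute the broadcast
    # result type instead of reusing the lhs type.
    return tosa.AddOp(lhs.type, lhs, rhs)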
+ + """ + with ir.InsertionPoint(self._module.body): + arguments = [] + for arg in self._inputs: + shape_list = list(arg.shape) + dtype = arg.dtype + match dtype: + case torch.int32: + mlir_dtype = ir.IntegerType.get_signless(32) + case torch.int64: + mlir_dtype = ir.IntegerType.get_signless(64) + case torch.float32: + mlir_dtype = ir.F32Type.get() + case _: + raise NotImplementedError( + f"Unsupported dtype {dtype} for argument {arg}" + ) + tensor_arg = ir.RankedTensorType.get(shape_list, mlir_dtype) + arguments.append(tensor_arg) + + @func.FuncOp.from_py_func(*arguments, name=self._func_name) + def generated_func(*args): + args_list = list(args) + for node in self._gm.graph.nodes: + if node.op == "output": + output_node_args = node.args[0] + returns = [] + for output_arg in output_node_args: + op = self._symbol_table.get((str(output_arg), 0)) + returns.append(op) + + self._symbol_table[("output", 0)] = returns + elif node.op == "placeholder": + self._import_placeholder(node, args_list) + else: + if node.target is operator.getitem: + self._symbol_table[ + (str(node.name), 0) + ] = self._symbol_table[ + (str(node.args[0]), node.args[1]) + ] + else: + self._import_op(node) + + return self._symbol_table.get(("output", 0)) + + print("Printing the generated MLIR...") + print(self._module) + return self._module + + def _import_placeholder(self, node: torch.fx.Node, args_list): + placeholder_name = args_list[self._num_input_visited] + self._symbol_table[(str(node.name), 0)] = placeholder_name + self._num_input_visited += 1 + + def _import_op(self, node: torch.fx.Node): + op_name = node.target.__name__ + + op_ret: Union[ir.Operation, tuple] = operation_func[op_name]( + node, self._symbol_table + ) + if isinstance(op_ret, tuple): + for i, operation in enumerate(op_ret): + self._symbol_table[(str(node.name), i)] = operation.result + else: + self._symbol_table[(str(node.name), 0)] = op_ret.result def Lowering(module: ir.Module): - """Lower an MLIR module to LLVM dialect. - - Args: - module (mlir.ir.Module): An MLIR module that need to be lowered. - - Returns: - mlir.ir.Module: An MLIR module in LLVM dialect. - - """ - # print("-------------------------------------------------------------------") - # print("Bufferizing the module ...") - pm = PassManager("builtin.module") - pm.add("func.func(tosa-to-linalg-named)") - pm.add("func.func(tosa-to-linalg)") - pm.add("func.func(tosa-to-tensor)") - pm.add("func.func(tosa-to-arith)") - pm.add("empty-tensor-to-alloc-tensor") - pm.add("convert-elementwise-to-linalg") - pm.add("arith-bufferize") - pm.add("func.func(linalg-bufferize)") - pm.add("func.func(tensor-bufferize)") - pm.add("func-bufferize") - pm.run(module.operation) - # print(module) - # print("-------------------------------------------------------------------") - # print("Lowering the module to LLVM dialect ...") - pm.add("func.func(buffer-deallocation)") - pm.add("func.func(convert-linalg-to-loops)") - pm.add("convert-math-to-llvm") - pm.add("convert-math-to-libm") - pm.add("convert-scf-to-cf") - pm.add("convert-linalg-to-llvm") - pm.add("convert-arith-to-llvm") - pm.add("expand-strided-metadata") - pm.add("finalize-memref-to-llvm") - pm.add("func.func(llvm-request-c-wrappers)") - pm.add("convert-func-to-llvm") - pm.add("reconcile-unrealized-casts") - pm.run(module.operation) - # print(module) - return module + """Lower an MLIR module to LLVM dialect. + + Args: + module (mlir.ir.Module): An MLIR module that need to be lowered. + + Returns: + mlir.ir.Module: An MLIR module in LLVM dialect. 
+ + """ + print("-------------------------------------------------------------------") + print("Bufferizing the module ...") + pm = PassManager("builtin.module") + pm.add("func.func(tosa-to-linalg-named)") + pm.add("func.func(tosa-to-linalg)") + pm.add("func.func(tosa-to-tensor)") + pm.add("func.func(tosa-to-arith)") + pm.add("empty-tensor-to-alloc-tensor") + pm.add("convert-elementwise-to-linalg") + pm.add("arith-bufferize") + pm.add("func.func(linalg-bufferize)") + pm.add("func.func(tensor-bufferize)") + pm.add("func-bufferize") + pm.run(module.operation) + print(module) + print("-------------------------------------------------------------------") + print("Lowering the module to LLVM dialect ...") + pm.add("func.func(buffer-deallocation)") + pm.add("func.func(convert-linalg-to-loops)") + pm.add("convert-math-to-llvm") + pm.add("convert-math-to-libm") + pm.add("convert-scf-to-cf") + pm.add("convert-linalg-to-llvm") + pm.add("convert-arith-to-llvm") + pm.add("expand-strided-metadata") + pm.add("finalize-memref-to-llvm") + pm.add("func.func(llvm-request-c-wrappers)") + pm.add("convert-func-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.run(module.operation) + print(module) + return module diff --git a/examples/MLIRPython/matmul.py b/examples/MLIRPython/matmul.py index d15ae1df68..8779eb1a01 100644 --- a/examples/MLIRPython/matmul.py +++ b/examples/MLIRPython/matmul.py @@ -2,10 +2,12 @@ import torch import torch._dynamo as dynamo + def foo(x, y): - return torch.matmul(x, y) + return torch.matmul(x, y) + -foo_mlir = dynamo.optimize(compiler.DynamoCompiler)(foo) +foo_mlir = dynamo.optimize(compiler.dynamo_compiler)(foo) in1 = torch.randn(2, 3) in2 = torch.randn(3, 5) foo_mlir(in1, in2)