diff --git a/CMakeLists.txt b/CMakeLists.txt
index febd797e91..0678c0f92a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,6 +43,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR BUDDY_MLIR_OUT_OF_TREE_
   set(LLVM_MLIR_BINARY_DIR ${MLIR_DIR}/../../../bin)
   set(LLVM_MLIR_SOURCE_DIR ${MLIR_DIR}/../../../../mlir)
+  set(LLVM_PROJECT_BUILD_DIR ${MLIR_DIR}/../../../)
   set(LLVM_PROJECT_SOURCE_DIR ${MLIR_DIR}/../../../../)
 
   list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
 
@@ -86,6 +87,7 @@ set(BUDDY_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/bin)
 set(BUDDY_EXAMPLES_DIR ${BUDDY_SOURCE_DIR}/examples)
 set(BUDDY_MIDEND_INCLUDE_DIR ${BUDDY_SOURCE_DIR}/midend/include/)
 set(BUDDY_THIRDPARTY_INCLUDE_DIR ${BUDDY_SOURCE_DIR}/thirdparty/include/)
+set(BUDDY_MLIR_PYTHON_PACKAGES_DIR ${BUDDY_BUILD_DIR}/python_packages)
 
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${BUDDY_BINARY_DIR})
 
@@ -154,6 +156,17 @@ if(BUDDY_DSL_EXAMPLES)
   # add macros to generate ANTLR Cpp code from grammar
   find_package(ANTLR REQUIRED)
 endif()
+
+#-------------------------------------------------------------------------------
+# Initialize Python packages
+#-------------------------------------------------------------------------------
+if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES)
+  file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy)
+  file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler)
+  file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/__init__.py "")
+  file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/__init__.py "")
+endif()
+
 #-------------------------------------------------------------------------------
 # Directory setup
 #-------------------------------------------------------------------------------
diff --git a/README.md b/README.md
index d228884cde..a1f466f8c3 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,18 @@ $ cmake -G Ninja ../llvm \
     -DCMAKE_BUILD_TYPE=RELEASE
 ```
 
+To enable the MLIR Python bindings, please use the following configuration:
+
+```
+$ cmake -G Ninja ../llvm \
+    -DLLVM_ENABLE_PROJECTS="mlir;clang" \
+    -DLLVM_TARGETS_TO_BUILD="host;RISCV" \
+    -DLLVM_ENABLE_ASSERTIONS=ON \
+    -DCMAKE_BUILD_TYPE=RELEASE \
+    -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+    -DPython3_EXECUTABLE=[path_to_python_executable]
+```
+
 If your target machine has lld installed, you can use the following configuration:
 
 ```
@@ -74,6 +86,18 @@ $ ninja
 $ ninja check-buddy
 ```
 
+To use the Buddy Compiler Python package, please ensure that the MLIR Python bindings are enabled and use the following configuration:
+
+```
+$ cmake -G Ninja .. \
+    -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \
+    -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \
+    -DLLVM_ENABLE_ASSERTIONS=ON \
+    -DCMAKE_BUILD_TYPE=RELEASE \
+    -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \
+    -DPython3_EXECUTABLE=[path_to_python_executable]
+```
+
 If you want to add domain-specific framework support, please add the following cmake options:
 
 | Framework | Enable Option | Other Options |
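Before configuring buddy-mlir against this build, it can help to sanity-check that the bindings import; a minimal check (not part of the patch), assuming the default `llvm/build` layout from the configuration above, where the upstream build places the `mlir` package under `tools/mlir/python_packages/mlir_core`:

```
$ cd llvm/build
$ PYTHONPATH=$PWD/tools/mlir/python_packages/mlir_core \
  python3 -c "import mlir.ir; print('MLIR Python bindings OK')"
```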
diff --git a/frontend/CMakeLists.txt b/frontend/CMakeLists.txt
index 714a8fb2e2..39e683c04b 100644
--- a/frontend/CMakeLists.txt
+++ b/frontend/CMakeLists.txt
@@ -1,2 +1,5 @@
 add_subdirectory(FrontendGen)
 add_subdirectory(Interfaces)
+if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES)
+  add_subdirectory(Python)
+endif()
diff --git a/frontend/Python/CMakeLists.txt b/frontend/Python/CMakeLists.txt
new file mode 100644
index 0000000000..ba2ec3eb06
--- /dev/null
+++ b/frontend/Python/CMakeLists.txt
@@ -0,0 +1,5 @@
+# Get all `.py` files in the current path.
+file(GLOB PY_FILES "*.py")
+
+# Copy all `.py` files to the destination.
+file(COPY ${PY_FILES} DESTINATION ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/)
diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py
new file mode 100644
index 0000000000..ff833fb966
--- /dev/null
+++ b/frontend/Python/frontend.py
@@ -0,0 +1,147 @@
+# ===- frontend.py -------------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This is the entry of the Buddy Compiler frontend.
+#
+# ===---------------------------------------------------------------------------
+
+import operator
+from typing import Callable, List, Union
+
+import mlir.dialects.func as func
+import mlir.ir as ir
+from mlir.passmanager import PassManager
+import torch
+from torch._functorch.aot_autograd import aot_module_simplified
+
+from .ops_gen import operation_func
+
+
+def dynamo_importer(
+    gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
+) -> Callable:
+    """The main entry point of the Buddy Compiler for TorchDynamo. It takes an
+    FX graph module and a list of inputs as parameters. The compiler first uses
+    PyTorch's AOT autograd to lower the FX graph from Torch IR to Aten/Prims
+    IR, then maps the operators in Aten/Prims IR to MLIR operations and
+    generates an MLIR module. Finally, it lowers the MLIR module to the LLVM
+    dialect.
+
+    Args:
+        gm (torch.fx.GraphModule): The FX graph module to be compiled.
+        inputs (List[torch.Tensor]): The inputs of the FX graph module.
+
+    Returns:
+        Callable: A compiled function that is equivalent to the FX graph.
+
+    """
+
+    def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
+        """Compile an FX graph in Aten/Prims IR to MLIR."""
+        # Custom compiler from FX graph to MLIR.
+        # Initialize the MLIR context.
+        ctx = ir.Context()
+        with ir.Location.unknown(ctx):
+            fx_importer = FXGraphImporter(gm, inputs)
+            module = fx_importer.import_graph()
+            # TODO: design an interface that returns the `module` directly.
+            module.dump()
+        return gm.forward
+
+    return aot_module_simplified(gm, inputs, fw_compiler=_compiler)
+
+
+class FXGraphImporter:
+    """The FX graph importer class."""
+
+    def __init__(
+        self,
+        gm: torch.fx.GraphModule,
+        inputs: List[torch.Tensor],
+        func_name: str = "forward",
+    ):
+        """
+        Args:
+            gm (torch.fx.GraphModule): The FX graph module that will be imported.
+            inputs (List[torch.Tensor]): Input tensor(s) of the FX graph.
+            func_name (str): Name of the generated MLIR func.
+
+        """
+        self._symbol_table = {}
+        self._gm = gm
+        self._func_name = func_name
+        self._inputs = inputs
+        self._num_input_visited = 0
+        self._module = ir.Module.create()
+
+    def import_graph(self) -> ir.Module:
+        """Import the FX graph and generate an MLIR module in high-level
+        dialects.
+
+        Returns:
+            mlir.ir.Module: An MLIR module in high-level dialects.
+
+        """
+        with ir.InsertionPoint(self._module.body):
+            arguments = []
+            for arg in self._inputs:
+                shape_list = list(arg.shape)
+                f32 = ir.F32Type.get()
+                tensor_arg = ir.RankedTensorType.get(shape_list, f32)
+                arguments.append(tensor_arg)
+
+            @func.FuncOp.from_py_func(*arguments, name=self._func_name)
+            def generated_func(*args):
+                args_list = list(args)
+                for node in self._gm.graph.nodes:
+                    if node.op == "output":
+                        output_node_args = node.args[0]
+                        returns = []
+                        for output_arg in output_node_args:
+                            op = self._symbol_table.get((str(output_arg), 0))
+                            returns.append(op)
+                        self._symbol_table[("output", 0)] = returns
+                    elif node.op == "placeholder":
+                        self._import_placeholder(node, args_list)
+                    else:
+                        if node.target is operator.getitem:
+                            self._symbol_table[
+                                (str(node.name), 0)
+                            ] = self._symbol_table[
+                                (str(node.args[0]), node.args[1])
+                            ]
+                        else:
+                            self._import_op(node)
+                return self._symbol_table.get(("output", 0))
+
+        return self._module
+
+    def _import_placeholder(self, node: torch.fx.Node, args_list):
+        # Bind the FX placeholder to its corresponding MLIR block argument.
+        placeholder_value = args_list[self._num_input_visited]
+        self._symbol_table[(str(node.name), 0)] = placeholder_value
+        self._num_input_visited += 1
+
+    def _import_op(self, node: torch.fx.Node):
+        op_name = node.target.__name__
+        op_ret: Union[ir.Operation, tuple] = operation_func[op_name](
+            node, self._symbol_table
+        )
+        if isinstance(op_ret, tuple):
+            for i, operation in enumerate(op_ret):
+                self._symbol_table[(str(node.name), i)] = operation.result
+        else:
+            self._symbol_table[(str(node.name), 0)] = op_ret.result
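For orientation (not part of the patch), a minimal use of this entry point, mirroring the lit test added later in this patch; `module.dump()` prints the imported MLIR module to stderr:

```python
import torch
import torch._dynamo as dynamo

from buddy.compiler import frontend


def foo(x, y):
    return x + y


# TorchDynamo traces `foo` into an FX graph and hands it to the importer;
# calling the wrapped function prints the generated MLIR module.
foo_mlir = dynamo.optimize(frontend.dynamo_importer)(foo)
foo_mlir(torch.randn(10), torch.randn(10))
```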
diff --git a/frontend/Python/ops_gen.py b/frontend/Python/ops_gen.py
new file mode 100644
index 0000000000..7abb28aa15
--- /dev/null
+++ b/frontend/Python/ops_gen.py
@@ -0,0 +1,112 @@
+# ===- ops_gen.py --------------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This is the collection of operation generators.
+#
+# ===---------------------------------------------------------------------------
+
+from typing import Dict, List, Tuple
+
+import mlir.ir as ir
+from mlir.dialects import arith, linalg, tosa
+import torch
+
+
+def _broadcast_shape(
+    tensor_input1: ir.Value, tensor_input2: ir.Value
+) -> List[int]:
+    """Calculate the broadcast shape of two tensors with broadcastable shapes
+    according to PyTorch's broadcast semantics:
+    https://pytorch.org/docs/stable/notes/broadcasting.html
+
+    For example, shapes [3, 1, 5] and [4, 1] broadcast to [3, 4, 5].
+
+    """
+    shp1 = ir.RankedTensorType(tensor_input1.type).shape
+    shp2 = ir.RankedTensorType(tensor_input2.type).shape
+    # Make `shp1` the longer shape, then left-pad `shp2` with ones.
+    if len(shp1) < len(shp2):
+        shp1, shp2 = shp2, shp1
+    while len(shp2) < len(shp1):
+        shp2.insert(0, 1)
+    # Each result dimension is the maximum of the two aligned dimensions.
+    for idx, (dim1, dim2) in enumerate(zip(shp1, shp2)):
+        shp1[idx] = shp2[idx] = max(dim1, dim2)
+
+    return shp1
+
+
+def add_op(
+    node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation]
+) -> ir.Operation:
+    """Map aten.add.Tensor to tosa.add.
+
+    Args:
+        node: An FX graph node containing the aten.add.Tensor operator and
+            its parameters.
+        symbol_table: The symbol table that records the mapping between
+            symbols and operations.
+
+    Returns:
+        ir.Operation: The generated tosa.add operation.
+
+    """
+    input1 = symbol_table.get((str(node.args[0]), 0))
+    input2 = symbol_table.get((str(node.args[1]), 0))
+    broadcasted_shp = _broadcast_shape(input1, input2)
+    f32 = ir.F32Type.get()
+    add_result_tensor_type = ir.RankedTensorType.get(broadcasted_shp, f32)
+    op = tosa.AddOp(add_result_tensor_type, input1, input2)
+    return op
+
+
+def addmm_op(
+    node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation]
+) -> ir.Operation:
+    """Map aten.addmm.default to a linalg.matmul followed by a tosa.add.
+
+    Args:
+        node (torch.fx.Node): An FX graph node containing the
+            aten.addmm.default operator and its parameters.
+        symbol_table (Dict[Tuple[str, int], ir.Operation]): The symbol table
+            that records the mapping between symbols and operations.
+
+    Returns:
+        ir.Operation: The generated MLIR operation representing
+            aten.addmm.default.
+
+    """
+    input_ = symbol_table.get((str(node.args[0]), 0))
+    mat1 = symbol_table.get((str(node.args[1]), 0))
+    mat2 = symbol_table.get((str(node.args[2]), 0))
+    mat1_shp = ir.RankedTensorType(mat1.type).shape
+    mat2_shp = ir.RankedTensorType(mat2.type).shape
+    result_shp = [mat1_shp[0], mat2_shp[1]]
+    f32 = ir.F32Type.get()
+    element = ir.FloatAttr.get(f32, 0.0)
+    tensor_type = ir.RankedTensorType.get(result_shp, f32)
+    attr = ir.DenseElementsAttr.get_splat(tensor_type, element)
+    # Use a zero-filled constant tensor as the output buffer of the matmul.
+    matmul_result_buffer = arith.ConstantOp(tensor_type, attr).result
+    # Generate the matmul operation.
+    matmul_op_result = linalg.matmul(mat1, mat2, outs=[matmul_result_buffer])
+    # Add the bias operand `input_` to the matmul result.
+    add_result_tensor_type = ir.RankedTensorType.get(result_shp, f32)
+    op = tosa.AddOp(add_result_tensor_type, input_, matmul_op_result)
+    return op
+
+
+operation_func = {"add.Tensor": add_op, "addmm.default": addmm_op}
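To support a new ATen operator, one defines a generator with the same signature and registers it in `operation_func`. A hypothetical sketch (not part of this patch) for `aten.sub.Tensor`, assuming the generated `tosa.SubOp` builder takes the same operands as `tosa.AddOp`:

```python
def sub_op(
    node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation]
) -> ir.Operation:
    """Map aten.sub.Tensor to tosa.sub (illustrative sketch)."""
    input1 = symbol_table.get((str(node.args[0]), 0))
    input2 = symbol_table.get((str(node.args[1]), 0))
    # The result shape follows PyTorch broadcast semantics, as in `add_op`.
    result_type = ir.RankedTensorType.get(
        _broadcast_shape(input1, input2), ir.F32Type.get()
    )
    return tosa.SubOp(result_type, input1, input2)


operation_func["sub.Tensor"] = sub_op
```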
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..83c116eb5f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 80
diff --git a/tests/Python/test_arith_add.py b/tests/Python/test_arith_add.py
new file mode 100644
index 0000000000..fbf351a9f2
--- /dev/null
+++ b/tests/Python/test_arith_add.py
@@ -0,0 +1,23 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import torch
+import torch._dynamo as dynamo
+
+from buddy.compiler import frontend
+
+
+def foo(x, y):
+    return x + y
+
+
+foo_mlir = dynamo.optimize(frontend.dynamo_importer)(foo)
+in1 = torch.randn(10)
+in2 = torch.randn(10)
+
+# CHECK: module {
+# CHECK-LABEL: func.func @forward
+# CHECK: %{{.*}} = "tosa.add"
+# CHECK: return %{{.*}}
+# CHECK: }
+# CHECK: }
+foo_mlir(in1, in2)
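With the lit plumbing below in place, this test runs as part of the usual check target. It can also be run by hand by reproducing the PYTHONPATH that lit.cfg.py sets up; a sketch, assuming the build layout from the README (`llvm/build` inside the buddy-mlir checkout):

```
$ cd buddy-mlir/build
$ ninja check-buddy

# Or, outside of lit:
$ export PYTHONPATH=$PWD/../llvm/build/tools/mlir/python_packages/mlir_core:$PWD/python_packages
$ python3 ../tests/Python/test_arith_add.py
```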
diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py
index cae228c83e..4cf5e245f7 100644
--- a/tests/lit.cfg.py
+++ b/tests/lit.cfg.py
@@ -16,56 +16,95 @@
 # Configuration file for the 'lit' test runner.
 
 # name: The name of this test suite.
-config.name = 'BUDDY'
+config.name = "BUDDY"
 
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir', '.c', '.cpp']
+config.suffixes = [".mlir", ".c", ".cpp"]
+if config.buddy_mlir_enable_python_packages == "ON":
+    config.suffixes.append(".py")
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.buddy_obj_root, 'tests')
+config.test_exec_root = os.path.join(config.buddy_obj_root, "tests")
 
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
 
-llvm_config.with_system_environment(
-    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
 
 llvm_config.use_default_substitutions()
 
 # excludes: A list of directories to exclude from the testsuite. The 'Inputs'
 # subdirectories contain auxiliary inputs for various tests in their parent
 # directories.
-config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
+config.excludes = [
+    "Inputs",
+    "Examples",
+    "CMakeLists.txt",
+    "README.txt",
+    "LICENSE.txt",
+    "lit.cfg.py",
+    "lit.site.cfg.py",
+]
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.buddy_obj_root, 'tests')
+config.test_exec_root = os.path.join(config.buddy_obj_root, "tests")
 
 # config.buddy_tools_dir = os.path.join(config.buddy_obj_root, 'bin')
 
 # Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.buddy_tools_dir, config.llvm_tools_dir]
 tools = [
-    'buddy-opt',
-    'buddy-translate',
-    'buddy-container-test',
-    'buddy-audio-container-test',
-    'buddy-text-container-test',
-    'mlir-cpu-runner',
+    "buddy-opt",
+    "buddy-translate",
+    "buddy-container-test",
+    "buddy-audio-container-test",
+    "buddy-text-container-test",
+    "mlir-cpu-runner",
 ]
 
-tools.extend([
-    ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'),
-])
+tools.extend(
+    [
+        ToolSubst(
+            "%mlir_runner_utils_dir",
+            config.mlir_runner_utils_dir,
+            unresolved="ignore",
+        ),
+    ]
+)
+
+python_executable = config.python_executable
+tools.extend(
+    [
+        ToolSubst("%PYTHON", python_executable, unresolved="ignore"),
+    ]
+)
+
+# Add the Python paths for both the upstream MLIR and Buddy Compiler packages.
+if config.buddy_mlir_enable_python_packages == "ON":
+    llvm_config.with_environment(
+        "PYTHONPATH",
+        [
+            os.path.join(
+                config.llvm_build_dir,
+                "tools",
+                "mlir",
+                "python_packages",
+                "mlir_core",
+            ),
+            config.buddy_python_packages_dir,
+        ],
+        append_path=True,
+    )
 
 if config.buddy_enable_opencv == "ON":
-    tools.append('buddy-image-container-test')
+    tools.append("buddy-image-container-test")
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/tests/lit.site.cfg.py.in b/tests/lit.site.cfg.py.in
index ab97e755b5..6a5e5f37e3 100644
--- a/tests/lit.site.cfg.py.in
+++ b/tests/lit.site.cfg.py.in
@@ -12,7 +12,7 @@ config.llvm_shlib_dir = "@SHLIBDIR@"
 config.llvm_shlib_ext = "@SHLIBEXT@"
 config.llvm_exe_ext = "@EXEEXT@"
 config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
-config.python_executable = "@PYTHON_EXECUTABLE@"
+config.python_executable = "@Python3_EXECUTABLE@"
 config.gold_executable = "@GOLD_EXECUTABLE@"
 config.ld64_executable = "@LD64_EXECUTABLE@"
 config.enable_shared = @ENABLE_SHARED@
@@ -28,11 +28,14 @@ config.enable_libcxx = "@LLVM_ENABLE_LIBCXX@"
 config.host_ldflags = '@HOST_LDFLAGS@'
 config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@"
 config.llvm_host_triple = '@LLVM_HOST_TRIPLE@'
+config.llvm_build_dir = "@LLVM_PROJECT_BUILD_DIR@"
 config.host_arch = "@HOST_ARCH@"
 config.buddy_src_root = "@CMAKE_SOURCE_DIR@"
 config.buddy_obj_root = "@CMAKE_BINARY_DIR@"
 config.buddy_tools_dir = "@BUDDY_BINARY_DIR@"
 config.buddy_enable_opencv = "@BUDDY_ENABLE_OPENCV@"
+config.buddy_mlir_enable_python_packages = "@BUDDY_MLIR_ENABLE_PYTHON_PACKAGES@"
+config.buddy_python_packages_dir = "@BUDDY_MLIR_PYTHON_PACKAGES_DIR@"
 config.mlir_runner_utils_dir = "@LLVM_LIBS_DIR@"
 
 # Support substitution of the tools_dir with user parameters. This is