Skip to content

Commit

Permalink
[frontend] Initial frontend Python packages.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Oct 13, 2023
1 parent d43b463 commit 17d6e35
Show file tree
Hide file tree
Showing 10 changed files with 381 additions and 21 deletions.
13 changes: 13 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR BUDDY_MLIR_OUT_OF_TREE_

set(LLVM_MLIR_BINARY_DIR ${MLIR_DIR}/../../../bin)
set(LLVM_MLIR_SOURCE_DIR ${MLIR_DIR}/../../../../mlir)
set(LLVM_PROJECT_BUILD_DIR ${MLIR_DIR}/../../../)
set(LLVM_PROJECT_SOURCE_DIR ${MLIR_DIR}/../../../../)

list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
Expand Down Expand Up @@ -86,6 +87,7 @@ set(BUDDY_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/bin)
set(BUDDY_EXAMPLES_DIR ${BUDDY_SOURCE_DIR}/examples)
set(BUDDY_MIDEND_INCLUDE_DIR ${BUDDY_SOURCE_DIR}/midend/include/)
set(BUDDY_THIRDPARTY_INCLUDE_DIR ${BUDDY_SOURCE_DIR}/thirdparty/include/)
set(BUDDY_MLIR_PYTHON_PACKAGES_DIR ${BUDDY_BUILD_DIR}/python_packages)

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${BUDDY_BINARY_DIR})

Expand Down Expand Up @@ -154,6 +156,17 @@ if(BUDDY_DSL_EXAMPLES)
# add macros to generate ANTLR Cpp code from grammar
find_package(ANTLR REQUIRED)
endif()

#-------------------------------------------------------------------------------
# Initialize Python packages
#-------------------------------------------------------------------------------
if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES)
file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy)
file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler)
file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/__init__.py "")
file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/__init__.py "")
endif()

#-------------------------------------------------------------------------------
# Directory setup
#-------------------------------------------------------------------------------
Expand Down
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ $ cmake -G Ninja ../llvm \
-DCMAKE_BUILD_TYPE=RELEASE
```

To enable MLIR Python bindings, please use the following configuration:

```
$ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_TARGETS_TO_BUILD="host;RISCV" \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DCMAKE_BUILD_TYPE=RELEASE \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DPython3_EXECUTABLE=[path_to_python_executable]
```

If your target machine has lld installed, you can use the following configuration:

```
Expand Down Expand Up @@ -74,6 +86,18 @@ $ ninja
$ ninja check-buddy
```

To utilize the Buddy Compiler Python package, please ensure that the MLIR Python bindings are enabled and use the following configuration:

```
$ cmake -G Ninja .. \
-DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \
-DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DCMAKE_BUILD_TYPE=RELEASE \
-DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \
-DPython3_EXECUTABLE=[path_to_python_executable]
```

If you want to add domain-specific framework support, please add the following cmake options:

| Framework | Enable Option | Other Options |
Expand Down
3 changes: 3 additions & 0 deletions frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
add_subdirectory(FrontendGen)
add_subdirectory(Interfaces)
if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES)
add_subdirectory(Python)
endif()
5 changes: 5 additions & 0 deletions frontend/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Get all `.py` files in the current path
file(GLOB PY_FILES "*.py")

# Copy all `.py` files to the destination
file(COPY ${PY_FILES} DESTINATION ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/)
142 changes: 142 additions & 0 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# ===- frontend.py -------------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This is the entry of the Buddy Compiler frontend.
#
# ===---------------------------------------------------------------------------

import operator
from typing import Callable, List, Union

import mlir.dialects.func as func
import mlir.ir as ir
from mlir.passmanager import PassManager
import torch
from torch._functorch.aot_autograd import aot_module_simplified

from .ops_gen import operation_func


def dynamo_importer(
gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
) -> Callable:
"""The main entry point of buddy compiler for torch dynamo. It takes a FX
graph module and a list of inputs as parameters. The compiler will first use
PyTorch's AOT autograd to lower FX graph in Torch IR to Aten/Prims IR. Then
it will map the operators in Aten/Prims IR to MLIR operations and generate
an MLIR module. Finally, It will lower the MLIR module to LLVM dialect.
Args:
gm (torch.fx.GraphModule): The FX graph module to be compiled.
inputs (List[torch.Tensor]): The inputs of the FX graph module.
Returns:
Callable: A compiled function that equivalent to the FX graph.
"""

def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
"""Compile a FX graph in Aten/Prims IR to MLIR."""
# Custom Compiler from FX Graph to MLIR
# Initialize the MLIR context.
ctx = ir.Context()
with ir.Location.unknown(ctx):
fx_importer = FXGraphImporter(gm, inputs)
module = fx_importer.import_graph()
# TODO: design an interface that return the `module` directly.
module.dump()
return gm.forward

return aot_module_simplified(gm, inputs, fw_compiler=_compiler)


class FXGraphImporter:
"""The FX graph importer class."""

def __init__(
self,
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
func_name: str = "forward",
):
"""
Args:
gm (torch.fx.GraphModule): The FX graph module that will be imported.
inputs (List[torch.Tensor]): Input tensor(s) of the FX graph.
func_name (str): Name of the generated MLIR func.
"""
self._symbol_table = {}
self._gm = gm
self._func_name = func_name
self._inputs = inputs
self._num_input_visited = 0
self._module = ir.Module.create()

def import_graph(self) -> ir.Module:
"""Import the FX graph, generate an MLIR module in high-level dialects.
Returns:
mlir.ir.Module: An MLIR moduel in high-level dialects.
"""
with ir.InsertionPoint(self._module.body):
arguments = []
for arg in self._inputs:
shape_list = list(arg.shape)
f32 = ir.F32Type.get()
tensor_arg = ir.RankedTensorType.get(shape_list, f32)
arguments.append(tensor_arg)

@func.FuncOp.from_py_func(*arguments, name=self._func_name)
def generated_func(*args):
args_list = list(args)
for node in self._gm.graph.nodes:
if node.op == "output":
output_node_args = node.args[0]
returns = []
for output_arg in output_node_args:
op = self._symbol_table.get((str(output_arg), 0))
returns.append(op)
self._symbol_table[("output", 0)] = returns
elif node.op == "placeholder":
self._import_placeholder(node, args_list)
else:
if node.target is operator.getitem:
self._symbol_table[
(str(node.name), 0)
] = self._symbol_table[(node.args[0], node.args[1])]
else:
self._import_op(node)
return self._symbol_table.get(("output", 0))

return self._module

def _import_placeholder(self, node: torch.fx.Node, args_list):
placeholder_name = args_list[self._num_input_visited]
self._symbol_table[(str(node.name), 0)] = placeholder_name
self._num_input_visited += 1

def _import_op(self, node: torch.fx.Node):
op_name = node.target.__name__
op_ret: Union[ir.Operation, tuple] = operation_func[op_name](
node, self._symbol_table
)
if isinstance(op_ret, tuple):
for i, operation in op_ret:
self._symbol_table[(str(node.name), i)] = operation.result
else:
self._symbol_table[(str(node.name), 0)] = op_ret.result
106 changes: 106 additions & 0 deletions frontend/Python/ops_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# ===- ops_gen.py --------------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This is the collection of operation generators.
#
# ===---------------------------------------------------------------------------

from typing import Dict, List, Tuple

import mlir.ir as ir
from mlir.dialects import arith, linalg, tosa
import torch


def _broadcast_shape(
tensor_input1: ir.Value, tensor_input2: ir.Value
) -> List[int]:
"""Calculate the broadcast shape of two tensors with broadcastable shapes
according to PyTorch's broadcast semantics:
https://pytorch.org/docs/stable/notes/broadcasting.html
"""
shp1 = ir.RankedTensorType(tensor_input1.type).shape
shp2 = ir.RankedTensorType(tensor_input2.type).shape
if len(shp1) < len(shp2):
shp1, shp2 = shp2, shp1
while len(shp2) < len(shp1):
shp2.insert(0, 1)
for idx, (dim1, dim2) in enumerate(zip(shp1, shp2)):
shp1[idx] = shp2[idx] = max(dim1, dim2)

return shp1


def add_op(
node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation]
) -> ir.Operation:
"""Map aten.add.Tensor to tosa.add.
Args:
node: A FX graph containing the aten.add.Tensor operator and its parameter.
symbol_table: The symbol table that records the mapping between symbols
and operations.
Returns:
ir.Operation: The generated tosa.add operation.
"""
input1 = symbol_table.get((str(node.args[0]), 0))
input2 = symbol_table.get((str(node.args[1]), 0))
broadcasted_shp = _broadcast_shape(input1, input2)
sizes = broadcasted_shp
f32 = ir.F32Type.get()
addResultTensorType = ir.RankedTensorType.get(sizes, f32)
op = tosa.AddOp(addResultTensorType, input1, input2)
return op


def addmm_op(
node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation]
) -> ir.Operation:
"""Map aten.addmm.default to MLIR operation.
Args:
node (torch.fx.Node): A FX graph containing the aten.addmm.default
operator and its parameter.
symbol_table (Dict[Tuple[str, int], ir.Operation]): The symbol table that
records the mapping between symbols and operations.
Returns:
ir.Operation: The generated MLIR operation representing aten.addmm.default
"""
input_ = symbol_table.get((str(node.args[0]), 0))
mat1 = symbol_table.get((str(node.args[1]), 0))
mat2 = symbol_table.get((str(node.args[2]), 0))
mat1_shp = ir.RankedTensorType(mat1.type).shape
mat2_shp = ir.RankedTensorType(mat2.type).shape
result_shp = [mat1_shp[0], mat2_shp[1]]
f32 = ir.F32Type.get()
element = ir.FloatAttr.get(f32, 0.0)
tensor_type = ir.RankedTensorType.get(result_shp, f32)
attr = ir.DenseElementsAttr.get_splat(tensor_type, element)
matmul_result_buffer = arith.ConstantOp(tensor_type, attr).result
# Generate matmul operation.
matmul_op_result = linalg.matmul(mat1, mat2, outs=[matmul_result_buffer])

add_result_tensor_type = ir.RankedTensorType.get(result_shp, f32)
op = tosa.AddOp(add_result_tensor_type, input_, matmul_op_result)
return op


operation_func = {"add.Tensor": add_op, "addmm.default": addmm_op}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
line-length = 80
23 changes: 23 additions & 0 deletions tests/Python/test_arith_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo

from buddy.compiler import frontend


def foo(x, y):
return x + y


foo_mlir = dynamo.optimize(frontend.dynamo_importer)(foo)
in1 = torch.randn(10)
in2 = torch.randn(10)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = "tosa.add"
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
foo_mlir(in1, in2)
Loading

0 comments on commit 17d6e35

Please sign in to comment.