-
Notifications
You must be signed in to change notification settings - Fork 173
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[frontend] Initial frontend Python packages.
- Loading branch information
Showing
10 changed files
with
381 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
add_subdirectory(FrontendGen) | ||
add_subdirectory(Interfaces) | ||
if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) | ||
add_subdirectory(Python) | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Get all `.py` files in the current path | ||
file(GLOB PY_FILES "*.py") | ||
|
||
# Copy all `.py` files to the destination | ||
file(COPY ${PY_FILES} DESTINATION ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# ===- frontend.py ------------------------------------------------------------- | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# ===--------------------------------------------------------------------------- | ||
# | ||
# This is the entry of the Buddy Compiler frontend. | ||
# | ||
# ===--------------------------------------------------------------------------- | ||
|
||
import operator | ||
from typing import Callable, List, Union | ||
|
||
import mlir.dialects.func as func | ||
import mlir.ir as ir | ||
from mlir.passmanager import PassManager | ||
import torch | ||
from torch._functorch.aot_autograd import aot_module_simplified | ||
|
||
from .ops_gen import operation_func | ||
|
||
|
||
def dynamo_importer( | ||
gm: torch.fx.GraphModule, inputs: List[torch.Tensor] | ||
) -> Callable: | ||
"""The main entry point of buddy compiler for torch dynamo. It takes a FX | ||
graph module and a list of inputs as parameters. The compiler will first use | ||
PyTorch's AOT autograd to lower FX graph in Torch IR to Aten/Prims IR. Then | ||
it will map the operators in Aten/Prims IR to MLIR operations and generate | ||
an MLIR module. Finally, It will lower the MLIR module to LLVM dialect. | ||
Args: | ||
gm (torch.fx.GraphModule): The FX graph module to be compiled. | ||
inputs (List[torch.Tensor]): The inputs of the FX graph module. | ||
Returns: | ||
Callable: A compiled function that equivalent to the FX graph. | ||
""" | ||
|
||
def _compiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): | ||
"""Compile a FX graph in Aten/Prims IR to MLIR.""" | ||
# Custom Compiler from FX Graph to MLIR | ||
# Initialize the MLIR context. | ||
ctx = ir.Context() | ||
with ir.Location.unknown(ctx): | ||
fx_importer = FXGraphImporter(gm, inputs) | ||
module = fx_importer.import_graph() | ||
# TODO: design an interface that return the `module` directly. | ||
module.dump() | ||
return gm.forward | ||
|
||
return aot_module_simplified(gm, inputs, fw_compiler=_compiler) | ||
|
||
|
||
class FXGraphImporter: | ||
"""The FX graph importer class.""" | ||
|
||
def __init__( | ||
self, | ||
gm: torch.fx.GraphModule, | ||
inputs: List[torch.Tensor], | ||
func_name: str = "forward", | ||
): | ||
""" | ||
Args: | ||
gm (torch.fx.GraphModule): The FX graph module that will be imported. | ||
inputs (List[torch.Tensor]): Input tensor(s) of the FX graph. | ||
func_name (str): Name of the generated MLIR func. | ||
""" | ||
self._symbol_table = {} | ||
self._gm = gm | ||
self._func_name = func_name | ||
self._inputs = inputs | ||
self._num_input_visited = 0 | ||
self._module = ir.Module.create() | ||
|
||
def import_graph(self) -> ir.Module: | ||
"""Import the FX graph, generate an MLIR module in high-level dialects. | ||
Returns: | ||
mlir.ir.Module: An MLIR moduel in high-level dialects. | ||
""" | ||
with ir.InsertionPoint(self._module.body): | ||
arguments = [] | ||
for arg in self._inputs: | ||
shape_list = list(arg.shape) | ||
f32 = ir.F32Type.get() | ||
tensor_arg = ir.RankedTensorType.get(shape_list, f32) | ||
arguments.append(tensor_arg) | ||
|
||
@func.FuncOp.from_py_func(*arguments, name=self._func_name) | ||
def generated_func(*args): | ||
args_list = list(args) | ||
for node in self._gm.graph.nodes: | ||
if node.op == "output": | ||
output_node_args = node.args[0] | ||
returns = [] | ||
for output_arg in output_node_args: | ||
op = self._symbol_table.get((str(output_arg), 0)) | ||
returns.append(op) | ||
self._symbol_table[("output", 0)] = returns | ||
elif node.op == "placeholder": | ||
self._import_placeholder(node, args_list) | ||
else: | ||
if node.target is operator.getitem: | ||
self._symbol_table[ | ||
(str(node.name), 0) | ||
] = self._symbol_table[(node.args[0], node.args[1])] | ||
else: | ||
self._import_op(node) | ||
return self._symbol_table.get(("output", 0)) | ||
|
||
return self._module | ||
|
||
def _import_placeholder(self, node: torch.fx.Node, args_list): | ||
placeholder_name = args_list[self._num_input_visited] | ||
self._symbol_table[(str(node.name), 0)] = placeholder_name | ||
self._num_input_visited += 1 | ||
|
||
def _import_op(self, node: torch.fx.Node): | ||
op_name = node.target.__name__ | ||
op_ret: Union[ir.Operation, tuple] = operation_func[op_name]( | ||
node, self._symbol_table | ||
) | ||
if isinstance(op_ret, tuple): | ||
for i, operation in op_ret: | ||
self._symbol_table[(str(node.name), i)] = operation.result | ||
else: | ||
self._symbol_table[(str(node.name), 0)] = op_ret.result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# ===- ops_gen.py -------------------------------------------------------------- | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# ===--------------------------------------------------------------------------- | ||
# | ||
# This is the collection of operation generators. | ||
# | ||
# ===--------------------------------------------------------------------------- | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
import mlir.ir as ir | ||
from mlir.dialects import arith, linalg, tosa | ||
import torch | ||
|
||
|
||
def _broadcast_shape( | ||
tensor_input1: ir.Value, tensor_input2: ir.Value | ||
) -> List[int]: | ||
"""Calculate the broadcast shape of two tensors with broadcastable shapes | ||
according to PyTorch's broadcast semantics: | ||
https://pytorch.org/docs/stable/notes/broadcasting.html | ||
""" | ||
shp1 = ir.RankedTensorType(tensor_input1.type).shape | ||
shp2 = ir.RankedTensorType(tensor_input2.type).shape | ||
if len(shp1) < len(shp2): | ||
shp1, shp2 = shp2, shp1 | ||
while len(shp2) < len(shp1): | ||
shp2.insert(0, 1) | ||
for idx, (dim1, dim2) in enumerate(zip(shp1, shp2)): | ||
shp1[idx] = shp2[idx] = max(dim1, dim2) | ||
|
||
return shp1 | ||
|
||
|
||
def add_op( | ||
node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation] | ||
) -> ir.Operation: | ||
"""Map aten.add.Tensor to tosa.add. | ||
Args: | ||
node: A FX graph containing the aten.add.Tensor operator and its parameter. | ||
symbol_table: The symbol table that records the mapping between symbols | ||
and operations. | ||
Returns: | ||
ir.Operation: The generated tosa.add operation. | ||
""" | ||
input1 = symbol_table.get((str(node.args[0]), 0)) | ||
input2 = symbol_table.get((str(node.args[1]), 0)) | ||
broadcasted_shp = _broadcast_shape(input1, input2) | ||
sizes = broadcasted_shp | ||
f32 = ir.F32Type.get() | ||
addResultTensorType = ir.RankedTensorType.get(sizes, f32) | ||
op = tosa.AddOp(addResultTensorType, input1, input2) | ||
return op | ||
|
||
|
||
def addmm_op( | ||
node: torch.fx.Node, symbol_table: Dict[Tuple[str, int], ir.Operation] | ||
) -> ir.Operation: | ||
"""Map aten.addmm.default to MLIR operation. | ||
Args: | ||
node (torch.fx.Node): A FX graph containing the aten.addmm.default | ||
operator and its parameter. | ||
symbol_table (Dict[Tuple[str, int], ir.Operation]): The symbol table that | ||
records the mapping between symbols and operations. | ||
Returns: | ||
ir.Operation: The generated MLIR operation representing aten.addmm.default | ||
""" | ||
input_ = symbol_table.get((str(node.args[0]), 0)) | ||
mat1 = symbol_table.get((str(node.args[1]), 0)) | ||
mat2 = symbol_table.get((str(node.args[2]), 0)) | ||
mat1_shp = ir.RankedTensorType(mat1.type).shape | ||
mat2_shp = ir.RankedTensorType(mat2.type).shape | ||
result_shp = [mat1_shp[0], mat2_shp[1]] | ||
f32 = ir.F32Type.get() | ||
element = ir.FloatAttr.get(f32, 0.0) | ||
tensor_type = ir.RankedTensorType.get(result_shp, f32) | ||
attr = ir.DenseElementsAttr.get_splat(tensor_type, element) | ||
matmul_result_buffer = arith.ConstantOp(tensor_type, attr).result | ||
# Generate matmul operation. | ||
matmul_op_result = linalg.matmul(mat1, mat2, outs=[matmul_result_buffer]) | ||
|
||
add_result_tensor_type = ir.RankedTensorType.get(result_shp, f32) | ||
op = tosa.AddOp(add_result_tensor_type, input_, matmul_op_result) | ||
return op | ||
|
||
|
||
operation_func = {"add.Tensor": add_op, "addmm.default": addmm_op} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[tool.black] | ||
line-length = 80 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# RUN: %PYTHON %s 2>&1 | FileCheck %s | ||
|
||
import torch | ||
import torch._dynamo as dynamo | ||
|
||
from buddy.compiler import frontend | ||
|
||
|
||
def foo(x, y): | ||
return x + y | ||
|
||
|
||
foo_mlir = dynamo.optimize(frontend.dynamo_importer)(foo) | ||
in1 = torch.randn(10) | ||
in2 = torch.randn(10) | ||
|
||
# CHECK: module { | ||
# CHECK-LABEL: func.func @forward | ||
# CHECK: %{{.*}} = "tosa.add" | ||
# CHECK: return %{{.*}} | ||
# CHECK: } | ||
# CHECK: } | ||
foo_mlir(in1, in2) |
Oops, something went wrong.