From 7c10e0315461d2bdf1f1f8ff64dd0856a24a04f9 Mon Sep 17 00:00:00 2001 From: R-Tars Date: Mon, 11 Nov 2024 02:01:35 +0000 Subject: [PATCH 1/6] update to torch2.4 --- examples/BuddyWhisper/import-whisper.py | 28 ++- examples/BuddyWhisper/whisper-main.cpp | 12 +- frontend/Python/frontend.py | 15 ++ frontend/Python/graph/operation.py | 19 ++ frontend/Python/ops/linalg.py | 256 ++++++++++++++++++++- frontend/Python/ops/tosa.py | 285 +++++++++++++++++++++++- 6 files changed, 598 insertions(+), 17 deletions(-) diff --git a/examples/BuddyWhisper/import-whisper.py b/examples/BuddyWhisper/import-whisper.py index 449646a676..533764eb8b 100644 --- a/examples/BuddyWhisper/import-whisper.py +++ b/examples/BuddyWhisper/import-whisper.py @@ -22,6 +22,7 @@ import torch import torch._dynamo as dynamo from torch._inductor.decomposition import decompositions as inductor_decomp +import transformers from transformers import WhisperForConditionalGeneration import numpy @@ -29,7 +30,11 @@ from buddy.compiler.ops import tosa from buddy.compiler.graph import GraphDriver from buddy.compiler.graph.transform import simply_fuse +from torch._decomp import get_decompositions + +print(torch.__version__) +print(transformers.__version__) # Retrieve the Whisper model path from environment variables. model_path = os.environ.get("WHISPER_MODEL_PATH") if model_path is None: @@ -40,17 +45,36 @@ model.config.use_cache = False # Generate placeholder for inputs. -input_features = torch.zeros(size=(1, 80, 3000), dtype=torch.float32) -decoder_input_ids = torch.zeros(size=(1, 448), dtype=torch.long) +input_features = torch.ones(size=(1, 80, 3000), dtype=torch.float32) +decoder_input_ids = torch.ones(size=(1, 448), dtype=torch.long) * 50258 inputs = { "input_features": input_features, "decoder_input_ids": decoder_input_ids, } +out = model(**inputs) +print(out.logits.flatten()[0:10]) +print(out.logits.shape) +print(out.encoder_last_hidden_state.shape) + +# DEFAULT_DECOMPOSITIONS = [ +# torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, +# ] + +# decomp = get_decompositions(DEFAULT_DECOMPOSITIONS) + +# # Initialize Dynamo Compiler with specific configurations as an importer. +# dynamo_compiler = DynamoCompiler( +# primary_registry=tosa.ops_registry, +# aot_autograd_decomposition={**inductor_decomp, **decomp}, +# # verbose=True +# ) + # Initialize Dynamo Compiler with specific configurations as an importer. dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, aot_autograd_decomposition=inductor_decomp, + # verbose=True ) # Import the model into MLIR module and parameters. diff --git a/examples/BuddyWhisper/whisper-main.cpp b/examples/BuddyWhisper/whisper-main.cpp index 7d69ea3074..42e75f2c38 100644 --- a/examples/BuddyWhisper/whisper-main.cpp +++ b/examples/BuddyWhisper/whisper-main.cpp @@ -125,8 +125,9 @@ int main() { Text outputContainer; Audio rawAudioContainer("../../examples/BuddyWhisper/audio.wav"); MemRef audioInput({1, 80, 3000}); - MemRef resultContainer[2] = { + MemRef resultContainer[3] = { MemRef({1, 1500, 512}, false, 0), + MemRef({1, 448, 512}, false, 0), MemRef({1, 448, MaxVocabSize}, false, 0), }; MemRef textContainer({1, MaxTokenLength}, 50258); @@ -155,7 +156,7 @@ int main() { inferenceEnd - inferenceStart; // Determine the generated token. 
- const float *startPtr = resultContainer[1].getData() + i * MaxVocabSize; + const float *startPtr = resultContainer[2].getData() + i * MaxVocabSize; const float *endPtr = startPtr + MaxVocabSize; int maxIndex = findMaxIndex(startPtr, endPtr); @@ -171,8 +172,9 @@ int main() { textContainer.getData()[i + 1] = maxIndex; outputContainer.appendTokenIdx(maxIndex); - free(resultContainer[0].release()); - free(resultContainer[1].release()); + // free(resultContainer[0].release()); + // free(resultContainer[1].release()); + // free(resultContainer[2].release()); } /// Print the final result @@ -180,4 +182,4 @@ int main() { << std::endl; return 0; -} +} \ No newline at end of file diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 9d8c80f014..cac4275630 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -124,6 +124,7 @@ def __init__( "mean.dim": MeanOp, "rsqrt.default": RsqrtOp, "mul.Tensor": MulOp, + "mul.Scalar": MulOp, "t.default": TOp, "mm.default": MatmulOp, "transpose.int": TransposeOp, @@ -167,6 +168,9 @@ def __init__( "split.Tensor":SplitOp, "max.default":MaxOp, "gt.Scalar":GtOp, + "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp, + "ge.Scalar": GeOp, + "gt.Tensor": GreaterThanOp, } @property @@ -286,6 +290,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): self._func_name, self._verbose ) + # with open('/home/zhuxinye/buddy-mlir/examples/TestOp/op.txt', 'w') as f: + # for gm_node in _gm.graph.nodes: + # f.write(f"{gm_node.name}\n") for gm_node in _gm.graph.nodes: node_users = [] for user in gm_node.users.keys(): @@ -325,6 +332,14 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): ) else: + # if str(gm_node.target) == "aten._scaled_dot_product_flash_attention_for_cpu.default": + # print(gm_node) + # print(gm_node.target) + # print(gm_node.target._schema.returns) + # print(gm_node.meta.get("val")) + # print(gm_node.meta.get("tensor_meta")) + # print(gm_node.kwargs) + # print(gm_node.args) tensor_meta = gm_node.meta.get("tensor_meta") val = gm_node.meta.get("val") # num_returns = len(gm_node.target._schema.returns) diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 0eb31fd961..511adf6e35 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -534,3 +534,22 @@ class GtOp(Op): def __init__(self) -> None: super().__init__() self._op_type = OpType.ElementwiseType + + +class ScaledDotProductFlashAttentionForCpuOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType + + +class GeOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType + + +class GreaterThanOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.BroadcastType + diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index b561b3433a..f73d219862 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1073,6 +1073,26 @@ def mul_op( element = mlir_element_attr_get(dtype, node.args[1]) attr = ir.DenseElementsAttr.get_splat(tensor_type, element) input2 = arith.ConstantOp(tensor_type, attr).result + + input1_dtype = ir.RankedTensorType(input1.type).element_type + input2_dtype = ir.RankedTensorType(input2.type).element_type + if input1_dtype != mlir_dtype: + input1 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input1.type).shape, + mlir_dtype, + 
), + input1, + ) + if input2_dtype != mlir_dtype: + input2 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input2.type).shape, + mlir_dtype, + ), + input2, + ) + if input1 is None or input2 is None: return mul_result_tensor_type = ir.RankedTensorType.get(shape, mlir_dtype) @@ -1781,18 +1801,28 @@ def where_op( input3 = symbol_table.get((str(node.args[2]), 0)) if input1 is None or input2 is None or input3 is None: return - output_shape = list(node.tensor_meta["shape"]) dtype = node.tensor_meta["dtype"] mlir_dtype = mlir_element_type_get(dtype) tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) output = tensor.EmptyOp(output_shape, mlir_dtype) + + # print("input:") + # print(input1) + # print(input2) + # print(input3) + # print(input3.type) + if not isinstance(input2.type, ir.RankedTensorType): + input2 = tensor.SplatOp(tensor_type, input2).result + if not isinstance(input3.type, ir.RankedTensorType): + input3 = tensor.SplatOp(tensor_type, input3).result + generic_map = ir.AffineMap.get_permutation( [i for i in range(len(output_shape))] ) op = linalg.GenericOp( [tensor_type], - [input1, input3], + [input1, input2, input3], [output], ir.ArrayAttr.get( [ @@ -1811,6 +1841,11 @@ def where_op( [i for i in range(len(output_shape))] ) ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(len(output_shape))] + ) + ), ] ), ir.ArrayAttr.get( @@ -1822,11 +1857,12 @@ def where_op( op.region, [ ir.RankedTensorType(input1.type).element_type, + ir.RankedTensorType(input2.type).element_type, ir.RankedTensorType(input3.type).element_type, ir.RankedTensorType(output.result.type).element_type, ], ) - select_op = arith.SelectOp(block.arguments[0], input2, block.arguments[1]) + select_op = arith.SelectOp(block.arguments[0], block.arguments[1], block.arguments[2]) block.append(select_op) block.append(linalg.YieldOp([select_op.result])) @@ -1965,6 +2001,218 @@ def gt_op(node: GtOp, symbol_table): return cmp_op +def ge_op( + node: GeOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor greater equal operation. + From buddy GreaterEqualOp to MLIR arith `constant` operation. + Note: This op, campare two input nodes, and output bool tensor to represent + compare result. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input_dtype = ir.RankedTensorType(input_tensor.type).element_type + input_shape = ir.RankedTensorType(input_tensor.type).shape + tensor_type = ir.RankedTensorType.get(input_shape, input_dtype) + + scalar = arith.ConstantOp(input_dtype, node.args[1]) + rhs = tensor.SplatOp(tensor_type, scalar) + + if str(input_dtype).find("i") != -1: + cmp_op = arith.CmpIOp(5, input_tensor, rhs) + else: + cmp_op = arith.CmpFOp(3, input_tensor, rhs) + + return cmp_op + +def greater_than_op( + node: GreaterThanOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor greater than operation. + From buddy GreaterThanOp to MLIR arith `constant` operation. + Note: This op, campare two input nodes, and output bool tensor to represent + compare result. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. 
+ """ + input1 = symbol_table.get((str(node.args[0]), 0)) + input2 = symbol_table.get((str(node.args[1]), 0)) + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + # value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4) + shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) + shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) + dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, dtype) + output = tensor.EmptyOp(output_shape, dtype) + if len(shp1) < len(shp2): + if int(shp1[-1]) > 1 and shp2[-1] == 1: + generic_map = ir.AffineMap.get_permutation( + [i for i in range(len(shp2) + 1)] + ) + op = linalg.GenericOp( + [tensor_type], + [input1, input2], + [output], + ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + i + for i in range( + len(shp2) - len(shp1), len(shp2) + ) + ] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2) - 1)] + + [len(shp2)] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2))] + ) + ), + ] + ), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * len(shp2) + + [ir.Attribute.parse("#linalg.iterator_type")] + ), + ) + block = ir.Block.create_at_start( + op.region, + [ + ir.RankedTensorType(input2.type).element_type, + ir.RankedTensorType(input2.type).element_type, + dtype, + ], + ) + if ( + str(ir.RankedTensorType(input2.type).element_type).find("i") + != -1 + ): + cmpop = arith.CmpIOp( + 4, block.arguments[0], block.arguments[1] + ) + else: + cmpop = arith.CmpFOp( + 2, block.arguments[0], block.arguments[1] + ) + block.append(cmpop) + block.append(linalg.YieldOp([cmpop.result])) + + return op + + # print(node.args) + # input1 = symbol_table.get((str(node.args[0]), 0)) + # input2 = symbol_table.get((str(node.args[1]), 0)) + # shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) + # shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) + # output_shape = list(node.tensor_meta["shape"]) + # dtype = node.tensor_meta["dtype"] + # mlir_dtype = mlir_element_type_get(dtype) + # tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + # print(tensor_type) + # print(mlir_dtype) + # print(input1.type) + # print(ir.RankedTensorType(input1.type).element_type) + # output = tensor.EmptyOp(output_shape, mlir_dtype) + + # generic_map = ir.AffineMap.get_permutation( + # [i for i in range(len(output_shape))] + # ) + # # print(generic_map2) + + # op = linalg.GenericOp( + # [tensor_type], + # [input1, input2], + # [output], + # ir.ArrayAttr.get( + # [ + # ir.AffineMapAttr.get( + # generic_map.get_submap( + # [i for i in range(1, len(output_shape))] + # ) + # ), + # ir.AffineMapAttr.get( + # generic_map.get_submap( + # [i for i in range(len(shp2))] + # ) + # ), + # ir.AffineMapAttr.get( + # generic_map.get_submap( + # [i for i in range(len(output_shape))] + # ) + # ), + # ] + # ), + # ir.ArrayAttr.get( + # [ir.Attribute.parse("#linalg.iterator_type")] + # * len(output_shape) + # ), + # ) + # block = ir.Block.create_at_start( + # op.region, + # [ + # ir.RankedTensorType(input1.type).element_type, + # ir.RankedTensorType(input2.type).element_type, + # ir.RankedTensorType(output.result.type).element_type, + # ], + # ) + + # print("here1") + # if str(input1.type).find("i") != -1: + # lhs_index = [] + # rhs_index = [] + # for i in range(0, len(shp1)): + # lhs_indexcast_op = arith.IndexCastOp(ir.IndexType.get(), block.arguments[0]) + # 
block.append(lhs_indexcast_op) + # lhs_index.append(lhs_indexcast_op) + # print("here2") + # # index_op1 = linalg.IndexOp(ir._i64Attr(i, None)) + + # for i in range(0, len(shp2) - 1): + # rhs_indexcast_op = arith.IndexCastOp(ir.IndexType.get(), block.arguments[1]) + # rhs_index.append(rhs_indexcast_op) + # block.append(rhs_indexcast_op) + # print("here3") + # rhs_index_op = linalg.IndexOp(ir._i64Attr(len(shp2) - 1, None)) + # rhs_index.append(rhs_index_op) + # block.append(rhs_index_op) + + # print("here4") + # lhs = tensor.ExtractOp(input1, lhs_index) + # rhs = tensor.ExtractOp(input2, rhs_index) + # block.append(lhs) + # block.append(rhs) + # cmp_op = arith.CmpIOp(4, lhs, rhs) + # print("here5") + # block.append(cmp_op) + # block.append(linalg.YieldOp([cmp_op.result])) + # print("here6") + # else: + # cmp_op = arith.CmpFOp(2, block.arguments[0], block.arguments[1]) + + # return op ops_registry = { "MatmulOp": matmul_op, @@ -2001,4 +2249,6 @@ def gt_op(node: GtOp, symbol_table): "SplitOp": split_op, "MaxOp": max_op, "GtOp": gt_op, + "GeOp": ge_op, + "GreaterThanOp": greater_than_op, } diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 797fdfd6d2..9eb0929055 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -18,13 +18,13 @@ # # ===--------------------------------------------------------------------------- -import array +import array, copy from typing import Dict, List, Tuple, Union import numpy import sys import mlir.ir as ir -from mlir.dialects import tensor, tosa, arith, linalg +from mlir.dialects import tensor, tosa, arith, linalg, math from ..graph import TensorDType from ..graph import ( @@ -62,6 +62,7 @@ ClampMaxOp, RandIntLowOp, ArgMaxOp, + ScaledDotProductFlashAttentionForCpuOp, ) from .utils import * @@ -522,9 +523,56 @@ def convert_element_type_op(node: ConvertElementTypeOp, symbol_table): } input_tensor = symbol_table.get((str(node.args[0]), 0)) to_cast_type = types_mapping[node.args[1]] - sizes = ir.RankedTensorType(input_tensor.type).shape - output_type = ir.RankedTensorType.get(sizes, to_cast_type) - return tosa.CastOp(output_type, input_tensor) + input_type = ir.RankedTensorType(input_tensor.type).element_type + # When converting float to int, tosa.cast lowers to math.roundeven, but we don't need rounding. 
+ if str(to_cast_type).find("i") != -1 and str(input_type).find("f") != -1: + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(len(output_shape))] + ) + op = linalg.GenericOp( + [tensor_type], + [input_tensor], + [output], + ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(len(output_shape))] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(len(output_shape))] + ) + ), + ] + ), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * len(output_shape) + ), + ) + block = ir.Block.create_at_start( + op.region, + [ + input_type, + to_cast_type, + ], + ) + fptosi_op = arith.FPToSIOp(to_cast_type, block.arguments[0]) + block.append(fptosi_op) + block.append(linalg.YieldOp([fptosi_op.result])) + else: + sizes = ir.RankedTensorType(input_tensor.type).shape + output_type = ir.RankedTensorType.get(sizes, to_cast_type) + op = tosa.CastOp(output_type, input_tensor) + + return op def clone_op(node: CloneOp, symbol_table): @@ -788,6 +836,41 @@ def embedding_op(node: EmbeddingOp, symbol_table): return op +# def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: +# """ +# Import the expand operation. +# From buddy graph ir's `ExpandOp` operator to MLIR TOSA `add` operation. + +# Note: This conversion is implemented using the broadcast machanism of TOSA +# `add` operation. We allocate a tensor with the shape to expand and +# elements in this tensor is all zero. Then we add the original tensor +# to this all-zero tensor. After the applying the broadcasting, we get +# the result. +# """ +# to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) +# new_size = node.args[1] +# result_element_type = ir.RankedTensorType( +# to_expand_tensor.type +# ).element_type +# if result_element_type in ( +# ir.IntegerType.get_signless(1), +# ir.IntegerType.get_signless(64), +# ): +# element = ir.IntegerAttr.get(result_element_type, 0) +# elif result_element_type == ir.F32Type.get(): +# element = ir.FloatAttr.get(result_element_type, 0.0) +# else: +# raise NotImplementedError("Unsupported element type!") +# new_size_tensor_type = ir.RankedTensorType.get( +# new_size, result_element_type +# ) +# new_size_attr = ir.DenseElementsAttr.get_splat( +# new_size_tensor_type, element +# ) +# new_size_tensor = tosa.ConstOp(new_size_attr).results[0] +# op = _gen_arith_binary_op(to_expand_tensor, new_size_tensor, tosa.AddOp) +# return op + def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: """ Import the expand operation. @@ -800,6 +883,7 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: the result. 
""" to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) + original_size = to_expand_tensor.type.shape new_size = node.args[1] result_element_type = ir.RankedTensorType( to_expand_tensor.type @@ -813,8 +897,14 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: element = ir.FloatAttr.get(result_element_type, 0.0) else: raise NotImplementedError("Unsupported element type!") + expanded_size = [] + for dim, size in zip(original_size, new_size): + if size == -1: + expanded_size.append(dim) + else: + expanded_size.append(size) new_size_tensor_type = ir.RankedTensorType.get( - new_size, result_element_type + expanded_size, result_element_type ) new_size_attr = ir.DenseElementsAttr.get_splat( new_size_tensor_type, element @@ -1479,9 +1569,189 @@ def argmax_op(node: ArgMaxOp, symbol_table): return op +def scaled_dot_product_flash_attention_for_cpu_op( + node: ScaledDotProductFlashAttentionForCpuOp, symbol_table +): + """ + Perform scaled dot-product attention computation. + Args: + node (ScaledDotProductFlashAttentionForCpuOp): The scaled dot-product attention operation node with metadata. + symbol_table: Mapping of variable names to tensor references. + Returns: + result_reshape_op: Reshaped result tensor of the attention operation. + log_sumexp_op: Log-sum-exp constant operation. + """ + print(node.args) + print(node.kwargs) + query = symbol_table.get((str(node.args[0]), 0), node.args[0]) + key = symbol_table.get((str(node.args[1]), 0), node.args[1]) + value = symbol_table.get((str(node.args[2]), 0), node.args[2]) + if len(node.args) == 4: + dropout_p = node.args[3] + assert dropout_p != 0.0 + if len(node.args) == 5: + dropout_p = node.args[3] + is_causal = node.args[4] + assert dropout_p != 0.0 + assert is_causal == True + + attn_mask = node.kwargs.get("attn_mask", None) + scale = node.kwargs.get("scale", None) + + + # print("attn_mask") + # print(attn_mask) + + + query_shape = query.type.shape + key_shape = key.type.shape + value_shape = value.type.shape + output_shape = list(node.tensor_meta["shape"]) + L, S = query_shape[-2], key_shape[-2] + scale_factor = 1 / numpy.sqrt(query.type.shape[-1]) if scale is None else scale + + # Initialize attention bias + dtype = node.tensor_meta["dtype"][0] + attn_bias_shape = [L, S] + mlir_dtype = mlir_element_type_get(dtype) + attn_bias_type = ir.RankedTensorType.get(attn_bias_shape, mlir_dtype) + zero_constant = arith.ConstantOp(mlir_dtype, 0.0) + attn_bias = tensor.SplatOp(attn_bias_type, zero_constant) + if attn_mask is not None: + attn_mask = symbol_table.get((str(attn_mask), 0), attn_mask) + if attn_mask.type.element_type == ir.IntegerType.get_signless(1): + assert attn_mask.type.element_type == ir.IntegerType.get_signless(1) + tensor_type = ir.RankedTensorType.get(attn_mask.type.shape, ir.IntegerType.get_signless(1)) + true_tensor = arith.ConstantOp(tensor_type, ir.DenseElementsAttr.get_splat(tensor_type, ir.BoolAttr.get(True))) + attn_mask = arith.XOrIOp(attn_mask, true_tensor) + minus_inf_tensor = arith.ConstantOp(attn_mask.type, ir.DenseElementsAttr.get_splat(attn_mask.type, ir.FloatAttr.get(f32_type, float('-inf')))) + attn_bias = tensor.SelectOp(attn_mask, minus_inf_tensor, attn_bias) + else: + if attn_mask.type.shape != attn_bias.result.type.shape: + attn_mask = tosa.ReshapeOp(attn_mask, memoryview(array.array("i",attn_bias.result.type.shape))) + attn_bias = tosa.AddOp(attn_bias.result.type, attn_bias, attn_mask) + + # Transpose key tensor + key_shape = list(key.type.shape) + perm_list = list(range(len(key_shape))) + 
perm_list[-1], perm_list[-2] = perm_list[-2], perm_list[-1] + perm_const_op = tosa.ConstOp( + ir.DenseElementsAttr.get(memoryview(array.array("i", perm_list))) + ) + perm_shape = [] + perm_shape.append(key_shape[0]) + perm_shape.append(key_shape[1]) + perm_shape.append(key_shape[3]) + perm_shape.append(key_shape[2]) + permute_result_type = ir.RankedTensorType.get(perm_shape, mlir_dtype) + key = tosa.TransposeOp( + permute_result_type, key, perm_const_op.results[0] + ).result + + # Matrix multiplication of query and key + query_reshape_op = tosa.ReshapeOp( + query, + memoryview( + array.array( + "i", + [ + query_shape[0] * query_shape[1], + query_shape[2], + query_shape[3], + ], + ) + ), + ) + key_reshape_op = tosa.ReshapeOp( + key, + memoryview( + array.array( + "i", [key_shape[0] * key_shape[1], key_shape[3], key_shape[2]] + ) + ), + ) + matmul_result_shp = [ + key_shape[0] * key_shape[1], + query_shape[2], + key_shape[2], + ] + matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, mlir_dtype) + matmul_op = tosa.MatMulOp( + matmul_result_type, query_reshape_op.result, key_reshape_op.result + ) + # Multiply result by scale factor + scale_factor_constant = arith.ConstantOp(mlir_dtype, scale_factor) + scale_factor = tensor.SplatOp(matmul_result_type, scale_factor_constant) + mul_op = tosa.MulOp( + matmul_result_type, + matmul_op, + scale_factor, + ir.IntegerAttr.get(ir.IntegerType.get_signless(8), 0), + ) + + # Add attention bias to the result + add_op = tosa.AddOp(matmul_result_type, mul_op.result, attn_bias) + # Apply softmax to the result + softmax_output_shape = list(add_op.result.type.shape) + softmax_dim = len(softmax_output_shape) - 1 + + # Subtract the maximum value along the dimension where softmax is applied to prevent overflow during the exp operation. + max_vals = tosa.ReduceMaxOp(add_op.result, softmax_dim) + sub_op = tosa.SubOp(add_op.result.type, add_op, max_vals) + exp_op = math.ExpOp(sub_op) + reduce_sum_op = tosa.ReduceSumOp(exp_op, softmax_dim) + log_op = tosa.LogOp(reduce_sum_op.result.type, reduce_sum_op) + log_sumexp = tosa.AddOp(max_vals.result.type, max_vals, log_op) + log_weights = tosa.SubOp(add_op.result.type, add_op, log_sumexp) + softmax_result = math.ExpOp(log_weights) + log_sumexp = tosa.ReshapeOp( + log_sumexp, + memoryview( + array.array( + "i", + output_shape[1], + ) + ), + ) + + # This step includes dropout during training. + # Multiply the result by the value tensor. 
+ value_reshape_op = tosa.ReshapeOp( + value, + memoryview( + array.array( + "i", + [key_shape[0] * key_shape[1], value_shape[2], value_shape[3]], + ) + ), + ) + matmul_result_shp = matmul_result_shp = [ + key_shape[0] * key_shape[1], + query_shape[2], + value_shape[3], + ] + matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, mlir_dtype) + matmul_op = tosa.MatMulOp( + matmul_result_type, softmax_result.result, value_reshape_op.result + ) + + result_reshape_op = tosa.ReshapeOp( + matmul_op.result, + memoryview( + array.array( + "i", + [key_shape[0], key_shape[1], query_shape[2], value_shape[3]], + ) + ), + ) + + return result_reshape_op, log_sumexp + + + ops_registry = { "AddOp": add_op, - "MulOp": mul_op, + # "MulOp": mul_op, "SubOp": sub_op, "SumDimOp": sum_op, "TanhOp": tanh_op, @@ -1515,4 +1785,5 @@ def argmax_op(node: ArgMaxOp, symbol_table): "ClampMaxOp": clamp_max_op, "RandIntLowOp": randint_low_op, "ArgMaxOp": argmax_op, + "ScaledDotProductFlashAttentionForCpuOp": scaled_dot_product_flash_attention_for_cpu_op, } From 9a45d70a4624f9f36440327afd71494a88b7b1db Mon Sep 17 00:00:00 2001 From: R-Tars Date: Tue, 12 Nov 2024 06:46:20 +0000 Subject: [PATCH 2/6] Suppport for torch2.4 --- examples/BuddyBert/bert-main.cpp | 21 +++- examples/BuddyLeNet/buddy-lenet-import.py | 4 +- examples/BuddyLlama/CMakeLists.txt | 1 + examples/BuddyLlama/llama-main.cpp | 2 +- examples/BuddyWhisper/import-whisper.py | 28 +----- frontend/Python/frontend.py | 15 +-- frontend/Python/ops/linalg.py | 97 ------------------ frontend/Python/ops/tosa.py | 114 ++++++++++++---------- tests/Python/test_max_pool2d.py | 4 +- tests/Python/test_mean.py | 2 +- 10 files changed, 87 insertions(+), 201 deletions(-) diff --git a/examples/BuddyBert/bert-main.cpp b/examples/BuddyBert/bert-main.cpp index c75ea9d8a9..902c702c15 100644 --- a/examples/BuddyBert/bert-main.cpp +++ b/examples/BuddyBert/bert-main.cpp @@ -24,9 +24,18 @@ using namespace buddy; +// Define ResultContainer +struct ResultContainer { + MemRef memRef3D; + MemRef memRef2D; + + ResultContainer(MemRef m1, MemRef m2) + : memRef3D(m1), memRef2D(m2) {} +}; + // Declare BERT forward function. extern "C" void -_mlir_ciface_forward(MemRef *result, MemRef *arg0, +_mlir_ciface_forward(ResultContainer *result, MemRef *arg0, MemRef *arg1, MemRef *arg2, MemRef *arg3, MemRef *arg4); @@ -85,7 +94,9 @@ int main() { pureStrContainer.tokenizeBert(vocabDir, 5); /// Initialize data containers. - MemRef result({1, 6}); + MemRef result1({1, 5, 768}); + MemRef result2({1, 6}); + ResultContainer result(result1, result2); MemRef attention_mask({1, 5}, 1LL); MemRef token_type_ids({1, 5}, 0LL); @@ -93,7 +104,7 @@ int main() { /// Execute forward inference of the model. 
_mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer, - &attention_mask, &token_type_ids); + &token_type_ids, &attention_mask); const auto inferenceEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration inferenceTime = @@ -102,8 +113,8 @@ int main() { int predict_label = -1; float max_logits = std::numeric_limits::min(); for (int i = 0; i < 6; i++) { - if (max_logits < result.getData()[i]) { - max_logits = result.getData()[i]; + if (max_logits < result.memRef2D.getData()[i]) { + max_logits = result.memRef2D.getData()[i]; predict_label = i; } } diff --git a/examples/BuddyLeNet/buddy-lenet-import.py b/examples/BuddyLeNet/buddy-lenet-import.py index 95e76de253..c787061a55 100644 --- a/examples/BuddyLeNet/buddy-lenet-import.py +++ b/examples/BuddyLeNet/buddy-lenet-import.py @@ -23,7 +23,6 @@ import numpy as np import torch -from torch._inductor.decomposition import decompositions as inductor_decomp from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.graph import GraphDriver @@ -39,13 +38,12 @@ ) model = LeNet() -model = torch.load(model_path + "/lenet-model.pth") +model = torch.load(model_path + "/lenet-model.pth", weights_only=False) model = model.eval() # Initialize Dynamo Compiler with specific configurations as an importer. dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, - aot_autograd_decomposition=inductor_decomp, ) data = torch.randn([1, 1, 28, 28]) diff --git a/examples/BuddyLlama/CMakeLists.txt b/examples/BuddyLlama/CMakeLists.txt index a6bfc2f742..6953b7de7d 100644 --- a/examples/BuddyLlama/CMakeLists.txt +++ b/examples/BuddyLlama/CMakeLists.txt @@ -53,6 +53,7 @@ add_custom_command( COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | ${BUDDY_BINARY_DIR}/buddy-opt + -convert-elementwise-to-linalg -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor diff --git a/examples/BuddyLlama/llama-main.cpp b/examples/BuddyLlama/llama-main.cpp index 0bfc1e5d2f..61c42f0db2 100644 --- a/examples/BuddyLlama/llama-main.cpp +++ b/examples/BuddyLlama/llama-main.cpp @@ -24,7 +24,7 @@ using namespace buddy; -constexpr size_t ParamsSize = 6755192832; +constexpr size_t ParamsSize = 6738415680; constexpr size_t MaxVocabSize = 32000; constexpr size_t MaxTokenLength = 40; constexpr size_t HiddenSize = 4096; diff --git a/examples/BuddyWhisper/import-whisper.py b/examples/BuddyWhisper/import-whisper.py index 533764eb8b..449646a676 100644 --- a/examples/BuddyWhisper/import-whisper.py +++ b/examples/BuddyWhisper/import-whisper.py @@ -22,7 +22,6 @@ import torch import torch._dynamo as dynamo from torch._inductor.decomposition import decompositions as inductor_decomp -import transformers from transformers import WhisperForConditionalGeneration import numpy @@ -30,11 +29,7 @@ from buddy.compiler.ops import tosa from buddy.compiler.graph import GraphDriver from buddy.compiler.graph.transform import simply_fuse -from torch._decomp import get_decompositions - -print(torch.__version__) -print(transformers.__version__) # Retrieve the Whisper model path from environment variables. model_path = os.environ.get("WHISPER_MODEL_PATH") if model_path is None: @@ -45,36 +40,17 @@ model.config.use_cache = False # Generate placeholder for inputs. 
-input_features = torch.ones(size=(1, 80, 3000), dtype=torch.float32) -decoder_input_ids = torch.ones(size=(1, 448), dtype=torch.long) * 50258 +input_features = torch.zeros(size=(1, 80, 3000), dtype=torch.float32) +decoder_input_ids = torch.zeros(size=(1, 448), dtype=torch.long) inputs = { "input_features": input_features, "decoder_input_ids": decoder_input_ids, } -out = model(**inputs) -print(out.logits.flatten()[0:10]) -print(out.logits.shape) -print(out.encoder_last_hidden_state.shape) - -# DEFAULT_DECOMPOSITIONS = [ -# torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, -# ] - -# decomp = get_decompositions(DEFAULT_DECOMPOSITIONS) - -# # Initialize Dynamo Compiler with specific configurations as an importer. -# dynamo_compiler = DynamoCompiler( -# primary_registry=tosa.ops_registry, -# aot_autograd_decomposition={**inductor_decomp, **decomp}, -# # verbose=True -# ) - # Initialize Dynamo Compiler with specific configurations as an importer. dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, aot_autograd_decomposition=inductor_decomp, - # verbose=True ) # Import the model into MLIR module and parameters. diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index cac4275630..9cf44948c9 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -290,9 +290,6 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): self._func_name, self._verbose ) - # with open('/home/zhuxinye/buddy-mlir/examples/TestOp/op.txt', 'w') as f: - # for gm_node in _gm.graph.nodes: - # f.write(f"{gm_node.name}\n") for gm_node in _gm.graph.nodes: node_users = [] for user in gm_node.users.keys(): @@ -332,18 +329,10 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): ) else: - # if str(gm_node.target) == "aten._scaled_dot_product_flash_attention_for_cpu.default": - # print(gm_node) - # print(gm_node.target) - # print(gm_node.target._schema.returns) - # print(gm_node.meta.get("val")) - # print(gm_node.meta.get("tensor_meta")) - # print(gm_node.kwargs) - # print(gm_node.args) tensor_meta = gm_node.meta.get("tensor_meta") val = gm_node.meta.get("val") - # num_returns = len(gm_node.target._schema.returns) - num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) + num_returns = len(gm_node.target._schema.returns) + # num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) if num_returns == 1: node_dtype = self._torch_dtype_translate( str(tensor_meta.dtype) diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index f73d219862..0081a117a2 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1807,11 +1807,6 @@ def where_op( tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) output = tensor.EmptyOp(output_shape, mlir_dtype) - # print("input:") - # print(input1) - # print(input2) - # print(input3) - # print(input3.type) if not isinstance(input2.type, ir.RankedTensorType): input2 = tensor.SplatOp(tensor_type, input2).result if not isinstance(input3.type, ir.RankedTensorType): @@ -2122,98 +2117,6 @@ def greater_than_op( return op - # print(node.args) - # input1 = symbol_table.get((str(node.args[0]), 0)) - # input2 = symbol_table.get((str(node.args[1]), 0)) - # shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) - # shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) - # output_shape = list(node.tensor_meta["shape"]) - # dtype = node.tensor_meta["dtype"] - # mlir_dtype = 
mlir_element_type_get(dtype) - # tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) - # print(tensor_type) - # print(mlir_dtype) - # print(input1.type) - # print(ir.RankedTensorType(input1.type).element_type) - # output = tensor.EmptyOp(output_shape, mlir_dtype) - - # generic_map = ir.AffineMap.get_permutation( - # [i for i in range(len(output_shape))] - # ) - # # print(generic_map2) - - # op = linalg.GenericOp( - # [tensor_type], - # [input1, input2], - # [output], - # ir.ArrayAttr.get( - # [ - # ir.AffineMapAttr.get( - # generic_map.get_submap( - # [i for i in range(1, len(output_shape))] - # ) - # ), - # ir.AffineMapAttr.get( - # generic_map.get_submap( - # [i for i in range(len(shp2))] - # ) - # ), - # ir.AffineMapAttr.get( - # generic_map.get_submap( - # [i for i in range(len(output_shape))] - # ) - # ), - # ] - # ), - # ir.ArrayAttr.get( - # [ir.Attribute.parse("#linalg.iterator_type")] - # * len(output_shape) - # ), - # ) - # block = ir.Block.create_at_start( - # op.region, - # [ - # ir.RankedTensorType(input1.type).element_type, - # ir.RankedTensorType(input2.type).element_type, - # ir.RankedTensorType(output.result.type).element_type, - # ], - # ) - - # print("here1") - # if str(input1.type).find("i") != -1: - # lhs_index = [] - # rhs_index = [] - # for i in range(0, len(shp1)): - # lhs_indexcast_op = arith.IndexCastOp(ir.IndexType.get(), block.arguments[0]) - # block.append(lhs_indexcast_op) - # lhs_index.append(lhs_indexcast_op) - # print("here2") - # # index_op1 = linalg.IndexOp(ir._i64Attr(i, None)) - - # for i in range(0, len(shp2) - 1): - # rhs_indexcast_op = arith.IndexCastOp(ir.IndexType.get(), block.arguments[1]) - # rhs_index.append(rhs_indexcast_op) - # block.append(rhs_indexcast_op) - # print("here3") - # rhs_index_op = linalg.IndexOp(ir._i64Attr(len(shp2) - 1, None)) - # rhs_index.append(rhs_index_op) - # block.append(rhs_index_op) - - # print("here4") - # lhs = tensor.ExtractOp(input1, lhs_index) - # rhs = tensor.ExtractOp(input2, rhs_index) - # block.append(lhs) - # block.append(rhs) - # cmp_op = arith.CmpIOp(4, lhs, rhs) - # print("here5") - # block.append(cmp_op) - # block.append(linalg.YieldOp([cmp_op.result])) - # print("here6") - # else: - # cmp_op = arith.CmpFOp(2, block.arguments[0], block.arguments[1]) - - # return op - ops_registry = { "MatmulOp": matmul_op, "ArangeOp": arange_op, diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 9eb0929055..d3f6ba1631 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -273,9 +273,49 @@ def _inner_op(result_type, input1, input2): input2, ir.IntegerAttr.get(ir.IntegerType.get_signless(8), 0), ) + + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) - input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) - input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) + if isinstance(node.args[0], str): + input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) + else: + data = [node.args[0]] + input1_shape = numpy.array(data).shape + tensor_type = ir.RankedTensorType.get(input1_shape, mlir_dtype) + element = mlir_element_attr_get(dtype, node.args[0]) + attr = ir.DenseElementsAttr.get_splat(tensor_type, element) + input2 = arith.ConstantOp(tensor_type, attr).result + + if isinstance(node.args[1], str): + input2 = symbol_table.get((str(node.args[1]), 0), node.args[1]) + else: + data = [node.args[1]] + input2_shape = numpy.array(data).shape + tensor_type = 
ir.RankedTensorType.get(input2_shape, mlir_dtype) + element = mlir_element_attr_get(dtype, node.args[1]) + attr = ir.DenseElementsAttr.get_splat(tensor_type, element) + input2 = arith.ConstantOp(tensor_type, attr).result + + input1_dtype = ir.RankedTensorType(input1.type).element_type + input2_dtype = ir.RankedTensorType(input2.type).element_type + if input1_dtype != mlir_dtype: + input1 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input1.type).shape, + mlir_dtype, + ), + input1, + ).result + if input2_dtype != mlir_dtype: + input2 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input2.type).shape, + mlir_dtype, + ), + input2, + ).result return _gen_arith_binary_op(input1, input2, _inner_op) @@ -527,10 +567,14 @@ def convert_element_type_op(node: ConvertElementTypeOp, symbol_table): # When converting float to int, tosa.cast lowers to math.roundeven, but we don't need rounding. if str(to_cast_type).find("i") != -1 and str(input_type).find("f") != -1: output_shape = list(node.tensor_meta["shape"]) - dtype = node.tensor_meta["dtype"] - mlir_dtype = mlir_element_type_get(dtype) - tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) - output = tensor.EmptyOp(output_shape, mlir_dtype) + tensor_type = ir.RankedTensorType.get(output_shape, to_cast_type) + output = tensor.EmptyOp(output_shape, to_cast_type) + + if str(to_cast_type) == "i1": + false_val = arith.ConstantOp(to_cast_type, 0) + true_val = arith.ConstantOp(to_cast_type, 1) + zero_val = arith.ConstantOp(input_type, 0.0) + generic_map = ir.AffineMap.get_permutation( [i for i in range(len(output_shape))] ) @@ -564,9 +608,16 @@ def convert_element_type_op(node: ConvertElementTypeOp, symbol_table): to_cast_type, ], ) - fptosi_op = arith.FPToSIOp(to_cast_type, block.arguments[0]) - block.append(fptosi_op) - block.append(linalg.YieldOp([fptosi_op.result])) + if str(to_cast_type) == "i1": + is_zero = arith.CmpFOp(1, block.arguments[0], zero_val) + result = arith.SelectOp(is_zero, false_val, true_val) + block.append(is_zero) + block.append(result) + block.append(linalg.YieldOp([result.result])) + else: + fptosi_op = arith.FPToSIOp(to_cast_type, block.arguments[0]) + block.append(fptosi_op) + block.append(linalg.YieldOp([fptosi_op.result])) else: sizes = ir.RankedTensorType(input_tensor.type).shape output_type = ir.RankedTensorType.get(sizes, to_cast_type) @@ -836,41 +887,6 @@ def embedding_op(node: EmbeddingOp, symbol_table): return op -# def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: -# """ -# Import the expand operation. -# From buddy graph ir's `ExpandOp` operator to MLIR TOSA `add` operation. - -# Note: This conversion is implemented using the broadcast machanism of TOSA -# `add` operation. We allocate a tensor with the shape to expand and -# elements in this tensor is all zero. Then we add the original tensor -# to this all-zero tensor. After the applying the broadcasting, we get -# the result. 
-# """ -# to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) -# new_size = node.args[1] -# result_element_type = ir.RankedTensorType( -# to_expand_tensor.type -# ).element_type -# if result_element_type in ( -# ir.IntegerType.get_signless(1), -# ir.IntegerType.get_signless(64), -# ): -# element = ir.IntegerAttr.get(result_element_type, 0) -# elif result_element_type == ir.F32Type.get(): -# element = ir.FloatAttr.get(result_element_type, 0.0) -# else: -# raise NotImplementedError("Unsupported element type!") -# new_size_tensor_type = ir.RankedTensorType.get( -# new_size, result_element_type -# ) -# new_size_attr = ir.DenseElementsAttr.get_splat( -# new_size_tensor_type, element -# ) -# new_size_tensor = tosa.ConstOp(new_size_attr).results[0] -# op = _gen_arith_binary_op(to_expand_tensor, new_size_tensor, tosa.AddOp) -# return op - def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: """ Import the expand operation. @@ -1581,11 +1597,10 @@ def scaled_dot_product_flash_attention_for_cpu_op( result_reshape_op: Reshaped result tensor of the attention operation. log_sumexp_op: Log-sum-exp constant operation. """ - print(node.args) - print(node.kwargs) query = symbol_table.get((str(node.args[0]), 0), node.args[0]) key = symbol_table.get((str(node.args[1]), 0), node.args[1]) value = symbol_table.get((str(node.args[2]), 0), node.args[2]) + if len(node.args) == 4: dropout_p = node.args[3] assert dropout_p != 0.0 @@ -1597,11 +1612,6 @@ def scaled_dot_product_flash_attention_for_cpu_op( attn_mask = node.kwargs.get("attn_mask", None) scale = node.kwargs.get("scale", None) - - - # print("attn_mask") - # print(attn_mask) - query_shape = query.type.shape key_shape = key.type.shape @@ -1751,7 +1761,7 @@ def scaled_dot_product_flash_attention_for_cpu_op( ops_registry = { "AddOp": add_op, - # "MulOp": mul_op, + "MulOp": mul_op, "SubOp": sub_op, "SumDimOp": sum_op, "TanhOp": tanh_op, diff --git a/tests/Python/test_max_pool2d.py b/tests/Python/test_max_pool2d.py index eecfc73d93..cac892761d 100644 --- a/tests/Python/test_max_pool2d.py +++ b/tests/Python/test_max_pool2d.py @@ -1,7 +1,6 @@ # RUN: %PYTHON %s 2>&1 | FileCheck %s import torch -from torch._inductor.decomposition import decompositions as inductor_decomp from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.ops import tosa @@ -19,7 +18,6 @@ def forward(self, a): model = TestModule() dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, - aot_autograd_decomposition=inductor_decomp, ) in1 = torch.randn((1, 3, 640, 480)) @@ -27,7 +25,7 @@ def forward(self, a): model_opt = torch.compile(model, backend=dynamo_compiler) assert torch.allclose(model_opt(in1), model(in1), equal_nan=True) -graphs = dynamo_compiler.importer(model, in1) +graphs = dynamo_compiler._imported_graphs assert len(graphs) == 1 graph = graphs[0] graph.lower_to_top_level_ir() diff --git a/tests/Python/test_mean.py b/tests/Python/test_mean.py index 0595619d18..54cc092b48 100644 --- a/tests/Python/test_mean.py +++ b/tests/Python/test_mean.py @@ -24,7 +24,7 @@ def foo(x, y, keepdim): assert torch.allclose( foo_mlir(in1, in2, keepdim=in3), foo(in1, in2, keepdim=in3), equal_nan=True ) -graphs = dynamo_compiler.importer(foo, in1, in2, in3) +graphs = dynamo_compiler._imported_graphs assert len(graphs) == 1 graph = graphs[0] graph.lower_to_top_level_ir() From f45d990f43cc9632ceab0014d701cf2a090928b8 Mon Sep 17 00:00:00 2001 From: R-Tars Date: Tue, 12 Nov 2024 07:50:19 +0000 Subject: [PATCH 3/6] Support for troch2.4. 
--- examples/BuddyLlama/import-llama2.py | 2 +- frontend/Python/ops/linalg.py | 19 ++++++++------- frontend/Python/ops/tosa.py | 36 ++++++++++++++++++++-------- requirements.txt | 11 +++++---- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/examples/BuddyLlama/import-llama2.py b/examples/BuddyLlama/import-llama2.py index 2903d6bd81..d893ee87f6 100644 --- a/examples/BuddyLlama/import-llama2.py +++ b/examples/BuddyLlama/import-llama2.py @@ -38,7 +38,7 @@ ) # Initialize the tokenizer and model from the specified model path. -tokenizer = LlamaTokenizer.from_pretrained(model_path) +tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True) model = LlamaForCausalLM.from_pretrained(model_path, torchscript=True) model.config.use_cache = False diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index 0081a117a2..00a2ccda57 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1073,7 +1073,7 @@ def mul_op( element = mlir_element_attr_get(dtype, node.args[1]) attr = ir.DenseElementsAttr.get_splat(tensor_type, element) input2 = arith.ConstantOp(tensor_type, attr).result - + input1_dtype = ir.RankedTensorType(input1.type).element_type input2_dtype = ir.RankedTensorType(input2.type).element_type if input1_dtype != mlir_dtype: @@ -1806,7 +1806,7 @@ def where_op( mlir_dtype = mlir_element_type_get(dtype) tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) output = tensor.EmptyOp(output_shape, mlir_dtype) - + if not isinstance(input2.type, ir.RankedTensorType): input2 = tensor.SplatOp(tensor_type, input2).result if not isinstance(input3.type, ir.RankedTensorType): @@ -1857,7 +1857,9 @@ def where_op( ir.RankedTensorType(output.result.type).element_type, ], ) - select_op = arith.SelectOp(block.arguments[0], block.arguments[1], block.arguments[2]) + select_op = arith.SelectOp( + block.arguments[0], block.arguments[1], block.arguments[2] + ) block.append(select_op) block.append(linalg.YieldOp([select_op.result])) @@ -1996,6 +1998,7 @@ def gt_op(node: GtOp, symbol_table): return cmp_op + def ge_op( node: GeOp, symbol_table: Dict[Tuple[str, int], ir.Operation], @@ -2027,6 +2030,7 @@ def ge_op( return cmp_op + def greater_than_op( node: GreaterThanOp, symbol_table: Dict[Tuple[str, int], ir.Operation], @@ -2105,18 +2109,15 @@ def greater_than_op( str(ir.RankedTensorType(input2.type).element_type).find("i") != -1 ): - cmpop = arith.CmpIOp( - 4, block.arguments[0], block.arguments[1] - ) + cmpop = arith.CmpIOp(4, block.arguments[0], block.arguments[1]) else: - cmpop = arith.CmpFOp( - 2, block.arguments[0], block.arguments[1] - ) + cmpop = arith.CmpFOp(2, block.arguments[0], block.arguments[1]) block.append(cmpop) block.append(linalg.YieldOp([cmpop.result])) return op + ops_registry = { "MatmulOp": matmul_op, "ArangeOp": arange_op, diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index d3f6ba1631..8ba1a834ec 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -273,7 +273,7 @@ def _inner_op(result_type, input1, input2): input2, ir.IntegerAttr.get(ir.IntegerType.get_signless(8), 0), ) - + output_shape = list(node.tensor_meta["shape"]) dtype = node.tensor_meta["dtype"] mlir_dtype = mlir_element_type_get(dtype) @@ -287,7 +287,7 @@ def _inner_op(result_type, input1, input2): element = mlir_element_attr_get(dtype, node.args[0]) attr = ir.DenseElementsAttr.get_splat(tensor_type, element) input2 = arith.ConstantOp(tensor_type, attr).result - + if isinstance(node.args[1], str): input2 = 
symbol_table.get((str(node.args[1]), 0), node.args[1]) else: @@ -297,7 +297,7 @@ def _inner_op(result_type, input1, input2): element = mlir_element_attr_get(dtype, node.args[1]) attr = ir.DenseElementsAttr.get_splat(tensor_type, element) input2 = arith.ConstantOp(tensor_type, attr).result - + input1_dtype = ir.RankedTensorType(input1.type).element_type input2_dtype = ir.RankedTensorType(input2.type).element_type if input1_dtype != mlir_dtype: @@ -1618,7 +1618,9 @@ def scaled_dot_product_flash_attention_for_cpu_op( value_shape = value.type.shape output_shape = list(node.tensor_meta["shape"]) L, S = query_shape[-2], key_shape[-2] - scale_factor = 1 / numpy.sqrt(query.type.shape[-1]) if scale is None else scale + scale_factor = ( + 1 / numpy.sqrt(query.type.shape[-1]) if scale is None else scale + ) # Initialize attention bias dtype = node.tensor_meta["dtype"][0] @@ -1631,16 +1633,31 @@ def scaled_dot_product_flash_attention_for_cpu_op( attn_mask = symbol_table.get((str(attn_mask), 0), attn_mask) if attn_mask.type.element_type == ir.IntegerType.get_signless(1): assert attn_mask.type.element_type == ir.IntegerType.get_signless(1) - tensor_type = ir.RankedTensorType.get(attn_mask.type.shape, ir.IntegerType.get_signless(1)) - true_tensor = arith.ConstantOp(tensor_type, ir.DenseElementsAttr.get_splat(tensor_type, ir.BoolAttr.get(True))) + tensor_type = ir.RankedTensorType.get( + attn_mask.type.shape, ir.IntegerType.get_signless(1) + ) + true_tensor = arith.ConstantOp( + tensor_type, + ir.DenseElementsAttr.get_splat( + tensor_type, ir.BoolAttr.get(True) + ), + ) attn_mask = arith.XOrIOp(attn_mask, true_tensor) - minus_inf_tensor = arith.ConstantOp(attn_mask.type, ir.DenseElementsAttr.get_splat(attn_mask.type, ir.FloatAttr.get(f32_type, float('-inf')))) + minus_inf_tensor = arith.ConstantOp( + attn_mask.type, + ir.DenseElementsAttr.get_splat( + attn_mask.type, ir.FloatAttr.get(f32_type, float("-inf")) + ), + ) attn_bias = tensor.SelectOp(attn_mask, minus_inf_tensor, attn_bias) else: if attn_mask.type.shape != attn_bias.result.type.shape: - attn_mask = tosa.ReshapeOp(attn_mask, memoryview(array.array("i",attn_bias.result.type.shape))) + attn_mask = tosa.ReshapeOp( + attn_mask, + memoryview(array.array("i", attn_bias.result.type.shape)), + ) attn_bias = tosa.AddOp(attn_bias.result.type, attn_bias, attn_mask) - + # Transpose key tensor key_shape = list(key.type.shape) perm_list = list(range(len(key_shape))) @@ -1758,7 +1775,6 @@ def scaled_dot_product_flash_attention_for_cpu_op( return result_reshape_op, log_sumexp - ops_registry = { "AddOp": add_op, "MulOp": mul_op, diff --git a/requirements.txt b/requirements.txt index 9818b8ec74..5efad98526 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ --pre --extra-index-url https://download.pytorch.org/whl/cpu -torch == 2.1.2 +torch == 2.4.0 numpy < 2 -transformers == 4.33.1 -tokenizers == 0.13.3 -sentencepiece == 0.1.99 +transformers == 4.46.2 +tokenizers >= 0.20 +sentencepiece == 0.2.0 accelerate protobuf pybind11 == 2.11.1 @@ -12,3 +12,6 @@ tabulate datasets soundfile librosa +PyYAML +certifi +idna \ No newline at end of file From b3809738c812235564c957138b73dc273516c762 Mon Sep 17 00:00:00 2001 From: Xinye_Zhu <62270271+R-Tars@users.noreply.github.com> Date: Mon, 18 Nov 2024 09:45:46 +0800 Subject: [PATCH 4/6] Update frontend.py --- frontend/Python/frontend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 9cf44948c9..5d4f256f86 100644 --- 
a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -331,8 +331,8 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): else: tensor_meta = gm_node.meta.get("tensor_meta") val = gm_node.meta.get("val") - num_returns = len(gm_node.target._schema.returns) - # num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) + # num_returns = len(gm_node.target._schema.returns) + num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) if num_returns == 1: node_dtype = self._torch_dtype_translate( str(tensor_meta.dtype) From d1dc009be37daa608fe92c7cc0084c0a127c0f42 Mon Sep 17 00:00:00 2001 From: R-Tars Date: Mon, 2 Dec 2024 01:31:17 +0000 Subject: [PATCH 5/6] support for torch2.5 --- examples/BuddyBert/bert-main.cpp | 21 +- examples/BuddyWhisper/whisper-main.cpp | 12 +- frontend/Python/frontend.py | 45 ++++- frontend/Python/graph/operation.py | 5 + frontend/Python/ops/linalg.py | 253 +++++++++++++++++++++++-- requirements.txt | 2 +- 6 files changed, 293 insertions(+), 45 deletions(-) diff --git a/examples/BuddyBert/bert-main.cpp b/examples/BuddyBert/bert-main.cpp index 902c702c15..d3f0075491 100644 --- a/examples/BuddyBert/bert-main.cpp +++ b/examples/BuddyBert/bert-main.cpp @@ -24,18 +24,9 @@ using namespace buddy; -// Define ResultContainer -struct ResultContainer { - MemRef memRef3D; - MemRef memRef2D; - - ResultContainer(MemRef m1, MemRef m2) - : memRef3D(m1), memRef2D(m2) {} -}; - // Declare BERT forward function. extern "C" void -_mlir_ciface_forward(ResultContainer *result, MemRef *arg0, +_mlir_ciface_forward(MemRef *result, MemRef *arg0, MemRef *arg1, MemRef *arg2, MemRef *arg3, MemRef *arg4); @@ -94,9 +85,7 @@ int main() { pureStrContainer.tokenizeBert(vocabDir, 5); /// Initialize data containers. - MemRef result1({1, 5, 768}); - MemRef result2({1, 6}); - ResultContainer result(result1, result2); + MemRef result({1, 6}); MemRef attention_mask({1, 5}, 1LL); MemRef token_type_ids({1, 5}, 0LL); @@ -104,7 +93,7 @@ int main() { /// Execute forward inference of the model. 
_mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer, - &token_type_ids, &attention_mask); + &token_type_ids, &attention_mask); const auto inferenceEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration inferenceTime = @@ -113,8 +102,8 @@ int main() { int predict_label = -1; float max_logits = std::numeric_limits::min(); for (int i = 0; i < 6; i++) { - if (max_logits < result.memRef2D.getData()[i]) { - max_logits = result.memRef2D.getData()[i]; + if (max_logits < result.getData()[i]) { + max_logits = result.getData()[i]; predict_label = i; } } diff --git a/examples/BuddyWhisper/whisper-main.cpp b/examples/BuddyWhisper/whisper-main.cpp index 42e75f2c38..011b5c847e 100644 --- a/examples/BuddyWhisper/whisper-main.cpp +++ b/examples/BuddyWhisper/whisper-main.cpp @@ -33,7 +33,7 @@ using namespace std; using namespace buddy; using namespace dap; -constexpr size_t ParamsSize = 99148800; +constexpr size_t ParamsSize = 72593920; constexpr size_t MaxVocabSize = 51865; constexpr size_t MaxTokenLength = 448; @@ -125,9 +125,8 @@ int main() { Text outputContainer; Audio rawAudioContainer("../../examples/BuddyWhisper/audio.wav"); MemRef audioInput({1, 80, 3000}); - MemRef resultContainer[3] = { + MemRef resultContainer[2] = { MemRef({1, 1500, 512}, false, 0), - MemRef({1, 448, 512}, false, 0), MemRef({1, 448, MaxVocabSize}, false, 0), }; MemRef textContainer({1, MaxTokenLength}, 50258); @@ -156,7 +155,7 @@ int main() { inferenceEnd - inferenceStart; // Determine the generated token. - const float *startPtr = resultContainer[2].getData() + i * MaxVocabSize; + const float *startPtr = resultContainer[1].getData() + i * MaxVocabSize; const float *endPtr = startPtr + MaxVocabSize; int maxIndex = findMaxIndex(startPtr, endPtr); @@ -172,9 +171,8 @@ int main() { textContainer.getData()[i + 1] = maxIndex; outputContainer.appendTokenIdx(maxIndex); - // free(resultContainer[0].release()); - // free(resultContainer[1].release()); - // free(resultContainer[2].release()); + free(resultContainer[0].release()); + free(resultContainer[1].release()); } /// Print the final result diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 5d4f256f86..69a46e7842 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -171,6 +171,7 @@ def __init__( "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp, "ge.Scalar": GeOp, "gt.Tensor": GreaterThanOp, + "_unsafe_index.Tensor": UnsafeIndexOp, } @property @@ -261,11 +262,26 @@ def _compile_fx( return for torchdynamo's call. 
""" - params = { - **dict(gm.named_parameters(remove_duplicate=False)), - **dict(gm.named_buffers(remove_duplicate=False)), - } - params_flat, _ = pytree.tree_flatten(params) + # params = { + # # **dict(gm.named_parameters(remove_duplicate=False)), + # **dict(gm.named_buffers(remove_duplicate=False)), + # } + # print(len(params)) + # params_flat, _ = pytree.tree_flatten(params) + inputs_pos = [] + params_pos = [] + buffers_pos = [] + for i, node in enumerate(gm.graph.nodes): + if i >= len(inputs): + break + if not str(node).startswith("l_self"): + inputs_pos.append(i) + elif "buffer" in str(node): + buffers_pos.append(i) + else: + params_pos.append(i) + + params_flat = [inputs[i] for i in params_pos + buffers_pos] if self._verbose: print("Graph in tabular form:") @@ -275,7 +291,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): """Compile a FX graph in Aten/Prims IR to MLIR.""" nonlocal params_flat func_inputs = [] - for inp in _inputs[len(params_flat) :]: + for i in inputs_pos: + # for inp in _inputs[len(params_flat) :]: + inp = _inputs[i] inp_shape = inp.shape inp_dtype = self._torch_dtype_translate(str(inp.dtype)) func_inputs.append(TensorMeta(inp_shape, inp_dtype)) @@ -290,7 +308,20 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): self._func_name, self._verbose ) - for gm_node in _gm.graph.nodes: + param_nodes = [] + buffers_nodes = [] + input_nodes = [] + for i, node in enumerate(_gm.graph.nodes): + if i in params_pos: + param_nodes.append(node) + elif i in buffers_pos: + buffers_nodes.append(node) + elif i in inputs_pos: + input_nodes.append(node) + + gm_nodes = param_nodes + buffers_nodes + input_nodes + + for gm_node in gm_nodes: node_users = [] for user in gm_node.users.keys(): node_users.append(str(user)) diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 511adf6e35..c1a7b09746 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -553,3 +553,8 @@ def __init__(self) -> None: super().__init__() self._op_type = OpType.BroadcastType + +class UnsafeIndexOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ReshapeType diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index 00a2ccda57..ec6c827e6c 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1231,28 +1231,51 @@ def index_op( return input1_shape = ir.RankedTensorType(input1.type).shape input2 = node.args[1] + input2_dim_sum = 0 + for i in range(len(input2)): + input2_dim_sum += len(symbol_table.get((str(input2[i]), 0)).type.shape) output_shape = list(node.tensor_meta["shape"]) + input_shape = input1.type.shape dtype = node.tensor_meta["dtype"] mlir_dtype = mlir_element_type_get(dtype) if len(input2) < len(input1_shape): tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) output = tensor.EmptyOp(output_shape, mlir_dtype) - loops = ir.RankedTensorType( - symbol_table.get((str(input2[0]), 0)).type - ).shape generic_map = ir.AffineMap.get_permutation( - [i for i in range(len(output_shape))] + [i for i in range(max(len(output_shape), len(input_shape)))] ) - input_map = [ - ir.AffineMapAttr.get( - generic_map.get_submap([j for j in range(len(loops))]) + input_map = [] + for i in range(len(input2)): + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(i, i + len(input2_shape))] + ) + ) ) - for i in 
range(len(input2)) - ] + [ - ir.AffineMapAttr.get( - generic_map.get_submap([j for j in range(len(output_shape))]) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) ) - ] operands = [symbol_table.get((str(i), 0)) for i in input2] op = linalg.GenericOp( [tensor_type], @@ -1261,7 +1284,7 @@ def index_op( ir.ArrayAttr.get(input_map), ir.ArrayAttr.get( [ir.Attribute.parse("#linalg.iterator_type")] - * len(output_shape) + * max(len(output_shape), len(input_shape)) ), ) arguments = [ @@ -1273,7 +1296,9 @@ def index_op( indexcast_op = arith.IndexCastOp(ir.IndexType.get(), i) block.append(indexcast_op) index.append(indexcast_op.result) - for i in range(len(loops), len(output_shape) - len(input2) + 1): + for i in range( + input2_dim_sum, max(len(input_shape), len(output_shape)) + ): index_op = linalg.IndexOp(ir._i64Attr(i, None)) block.append(index_op) index.append(index_op.result) @@ -1573,6 +1598,9 @@ def softmax_op( if dim < 0: dim += len(output_shape) mlir_dtype = mlir_element_type_get(dtype) + max_vals = tosa.ReduceMaxOp(input1, dim) + sub_op_output = ir.RankedTensorType.get(input1.type.shape, mlir_dtype) + input1 = tosa.SubOp(sub_op_output, input1, max_vals) # tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) # output = tensor.EmptyOp(output_shape, mlir_dtype) # op = linalg.softmax( @@ -2118,6 +2146,202 @@ def greater_than_op( return op +def unsafe_index_op( + node: UnsafeIndexOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor _unsafe_index operation. + From buddy UnsafeIndexOp to MLIR linalg `generic` + operation. + Note: This op, get input node slice result by input index. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. 
+ """ + assert len(node.args) == 2 + input1 = symbol_table.get((str(node.args[0]), 0)) + if input1 is None: + return + input1_shape = ir.RankedTensorType(input1.type).shape + input2 = node.args[1] + have_none = False + for i in input2: + if i == None: + have_none = True + break + input2_dim_sum = 0 + for i in range(len(input2)): + input2_dim_sum += ( + len(symbol_table.get((str(input2[i]), 0)).type.shape) + if input2[i] != None + else 0 + ) + output_shape = list(node.tensor_meta["shape"]) + input_shape = input1.type.shape + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + if len(input2) < len(input1_shape): + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(max(len(output_shape), len(input_shape)))] + ) + input_map = [] + for i in range(len(input2)): + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(i, i + len(input2_shape))] + ) + ) + ) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) + ) + operands = [symbol_table.get((str(i), 0)) for i in input2] + op = linalg.GenericOp( + [tensor_type], + operands, + [output], + ir.ArrayAttr.get(input_map), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * max(len(output_shape), len(input_shape)) + ), + ) + arguments = [ + ir.RankedTensorType(i.type).element_type for i in operands + ] + [ir.RankedTensorType(output.result.type).element_type] + block = ir.Block.create_at_start(op.region, arguments) + index = [] + for i in block.arguments[:-1]: + indexcast_op = arith.IndexCastOp(ir.IndexType.get(), i) + block.append(indexcast_op) + index.append(indexcast_op.result) + for i in range( + input2_dim_sum, max(len(input_shape), len(output_shape)) + ): + index_op = linalg.IndexOp(ir._i64Attr(i, None)) + block.append(index_op) + index.append(index_op.result) + value = tensor.ExtractOp(input1, index) + block.append(value) + block.append(linalg.YieldOp([value.result])) + else: + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(max(len(output_shape), len(input_shape)))] + ) + input_map = [] + for i in range(len(input2)): + if input2[i] == None: + continue + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + if have_none: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap([j for j in range(i, i + 1)]) + ) + ) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) + ) + if have_none: + operands = [] + for i in input2: + if i == None: + continue + input2_ = symbol_table.get((str(i), 0)) + input2_shape = input2_.type.shape + if i != None and len(input2_shape) > 1: + total_size = 1 + for x in input2_shape: + total_size *= x + reshape_op = 
tosa.ReshapeOp(
+ input2_, memoryview(array.array("i", [total_size]))
+ )
+ operands.append(reshape_op.result)
+
+ else:
+ operands = [symbol_table.get((str(i), 0)) for i in input2]
+ op = linalg.GenericOp(
+ [tensor_type],
+ operands,
+ [output],
+ ir.ArrayAttr.get(input_map),
+ ir.ArrayAttr.get(
+ [ir.Attribute.parse("#linalg.iterator_type")]
+ * max(len(output_shape), len(input_shape))
+ ),
+ )
+ arguments = [
+ ir.RankedTensorType(i.type).element_type for i in operands
+ ] + [ir.RankedTensorType(output.result.type).element_type]
+ block = ir.Block.create_at_start(op.region, arguments)
+ index = []
+ None_count = 0
+ for i in range(len(input2)):
+ if input2[i] == None:
+ None_count += 1
+ index_op = linalg.IndexOp(ir._i64Attr(i, None))
+ block.append(index_op)
+ index.append(index_op.result)
+ else:
+ indexcast_op = arith.IndexCastOp(
+ ir.IndexType.get(), block.arguments[i - None_count]
+ )
+ block.append(indexcast_op)
+ index.append(indexcast_op.result)
+ value = tensor.ExtractOp(input1, index)
+ block.append(value)
+ block.append(linalg.YieldOp([value.result]))
+ return op
+
+
ops_registry = {
"MatmulOp": matmul_op,
"ArangeOp": arange_op,
@@ -2155,4 +2379,5 @@ def greater_than_op(
"GtOp": gt_op,
"GeOp": ge_op,
"GreaterThanOp": greater_than_op,
+ "UnsafeIndexOp": unsafe_index_op,
}
diff --git a/requirements.txt b/requirements.txt
index 5efad98526..6b2fd250c0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
--pre --extra-index-url https://download.pytorch.org/whl/cpu
-torch == 2.4.0
+torch == 2.5.1
numpy < 2
transformers == 4.46.2
tokenizers >= 0.20

From 5ac08eaeae72677012ac575102226809fb7a6423 Mon Sep 17 00:00:00 2001
From: R-Tars
Date: Mon, 2 Dec 2024 03:41:35 +0000
Subject: [PATCH 6/6] update

---
 frontend/Python/frontend.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py
index 69a46e7842..6441b23de3 100644
--- a/frontend/Python/frontend.py
+++ b/frontend/Python/frontend.py
@@ -311,6 +311,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
param_nodes = []
buffers_nodes = []
input_nodes = []
+ other_nodes = []
for i, node in enumerate(_gm.graph.nodes):
if i in params_pos:
param_nodes.append(node)
@@ -318,8 +319,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
buffers_nodes.append(node)
elif i in inputs_pos:
input_nodes.append(node)
-
- gm_nodes = param_nodes + buffers_nodes + input_nodes
+ else:
+ other_nodes.append(node)
+ gm_nodes = param_nodes + buffers_nodes + input_nodes + other_nodes
for gm_node in gm_nodes:
node_users = []
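
Note on the placeholder bookkeeping introduced in _compile_fx/_compiler above: with torch 2.4+ the Dynamo backend no longer receives parameters pre-flattened via named_parameters/named_buffers alone, so the patch recovers them positionally from the names of the captured graph's placeholder nodes. The sketch below is not part of the patch series; it only replays that classification rule in plain Python so it can be read in isolation. The placeholder names and stand-in tensors are made up for illustration; only the "l_self"/"buffer" prefix convention and the params_flat reordering are taken from the patch.

# Illustrative sketch, not part of the patch: classify Dynamo placeholder
# nodes by name, the same way _compile_fx does, and rebuild params_flat.
def classify_placeholders(node_names, num_flat_inputs):
    inputs_pos, params_pos, buffers_pos = [], [], []
    for i, name in enumerate(node_names):
        if i >= num_flat_inputs:
            break
        if not name.startswith("l_self"):   # real model input
            inputs_pos.append(i)
        elif "buffer" in name:              # registered buffer
            buffers_pos.append(i)
        else:                               # parameter
            params_pos.append(i)
    return inputs_pos, params_pos, buffers_pos

# Hypothetical placeholder order produced by Dynamo for a toy module.
names = [
    "l_input_features_",                        # runtime input
    "l_self_modules_proj_parameters_weight_",   # parameter
    "l_self_buffers_position_ids_",             # buffer
]
flat_args = ["input_tensor", "proj_weight", "position_ids"]  # stand-ins for real tensors
inputs_pos, params_pos, buffers_pos = classify_placeholders(names, len(flat_args))
params_flat = [flat_args[i] for i in params_pos + buffers_pos]
print(inputs_pos, params_pos, buffers_pos)  # [0] [1] [2]
print(params_flat)                          # ['proj_weight', 'position_ids']

In the patch itself the same positions are then used to reorder the FX nodes (param_nodes + buffers_nodes + input_nodes + other_nodes) before import, so the MLIR function sees parameters first, buffers second, and runtime inputs last, with the remaining graph nodes following.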