Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for torch2.5 #431

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/BuddyBert/bert-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ int main() {

/// Execute forward inference of the model.
_mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer,
&attention_mask, &token_type_ids);
&token_type_ids, &attention_mask);

const auto inferenceEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> inferenceTime =
Expand Down
4 changes: 1 addition & 3 deletions examples/BuddyLeNet/buddy-lenet-import.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import numpy as np
import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
Expand All @@ -39,13 +38,12 @@
)

model = LeNet()
model = torch.load(model_path + "/lenet-model.pth")
model = torch.load(model_path + "/lenet-model.pth", weights_only=False)
model = model.eval()

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

data = torch.randn([1, 1, 28, 28])
Expand Down
1 change: 1 addition & 0 deletions examples/BuddyLlama/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ add_custom_command(
COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" |
${BUDDY_BINARY_DIR}/buddy-opt
-convert-elementwise-to-linalg
-arith-expand
-eliminate-empty-tensors
-empty-tensor-to-alloc-tensor
Expand Down
2 changes: 1 addition & 1 deletion examples/BuddyLlama/import-llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)

# Initialize the tokenizer and model from the specified model path.
tokenizer = LlamaTokenizer.from_pretrained(model_path)
tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_path, torchscript=True)
model.config.use_cache = False

Expand Down
2 changes: 1 addition & 1 deletion examples/BuddyLlama/llama-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

using namespace buddy;

constexpr size_t ParamsSize = 6755192832;
constexpr size_t ParamsSize = 6738415680;
constexpr size_t MaxVocabSize = 32000;
constexpr size_t MaxTokenLength = 40;
constexpr size_t HiddenSize = 4096;
Expand Down
4 changes: 2 additions & 2 deletions examples/BuddyWhisper/whisper-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ using namespace std;
using namespace buddy;
using namespace dap;

constexpr size_t ParamsSize = 99148800;
constexpr size_t ParamsSize = 72593920;
constexpr size_t MaxVocabSize = 51865;
constexpr size_t MaxTokenLength = 448;

Expand Down Expand Up @@ -180,4 +180,4 @@ int main() {
<< std::endl;

return 0;
}
}
51 changes: 44 additions & 7 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
"mean.dim": MeanOp,
"rsqrt.default": RsqrtOp,
"mul.Tensor": MulOp,
"mul.Scalar": MulOp,
"t.default": TOp,
"mm.default": MatmulOp,
"transpose.int": TransposeOp,
Expand Down Expand Up @@ -167,6 +168,10 @@ def __init__(
"split.Tensor":SplitOp,
"max.default":MaxOp,
"gt.Scalar":GtOp,
"_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp,
"ge.Scalar": GeOp,
"gt.Tensor": GreaterThanOp,
"_unsafe_index.Tensor": UnsafeIndexOp,
}

@property
Expand Down Expand Up @@ -257,11 +262,26 @@ def _compile_fx(
return for torchdynamo's call.
"""

params = {
**dict(gm.named_parameters(remove_duplicate=False)),
**dict(gm.named_buffers(remove_duplicate=False)),
}
params_flat, _ = pytree.tree_flatten(params)
# params = {
# # **dict(gm.named_parameters(remove_duplicate=False)),
# **dict(gm.named_buffers(remove_duplicate=False)),
# }
# print(len(params))
# params_flat, _ = pytree.tree_flatten(params)
inputs_pos = []
params_pos = []
buffers_pos = []
for i, node in enumerate(gm.graph.nodes):
if i >= len(inputs):
break
if not str(node).startswith("l_self"):
inputs_pos.append(i)
elif "buffer" in str(node):
buffers_pos.append(i)
else:
params_pos.append(i)

params_flat = [inputs[i] for i in params_pos + buffers_pos]

if self._verbose:
print("Graph in tabular form:")
Expand All @@ -271,7 +291,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
"""Compile a FX graph in Aten/Prims IR to MLIR."""
nonlocal params_flat
func_inputs = []
for inp in _inputs[len(params_flat) :]:
for i in inputs_pos:
# for inp in _inputs[len(params_flat) :]:
inp = _inputs[i]
inp_shape = inp.shape
inp_dtype = self._torch_dtype_translate(str(inp.dtype))
func_inputs.append(TensorMeta(inp_shape, inp_dtype))
Expand All @@ -286,7 +308,22 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
self._func_name,
self._verbose
)
for gm_node in _gm.graph.nodes:
param_nodes = []
buffers_nodes = []
input_nodes = []
other_nodes = []
for i, node in enumerate(_gm.graph.nodes):
if i in params_pos:
param_nodes.append(node)
elif i in buffers_pos:
buffers_nodes.append(node)
elif i in inputs_pos:
input_nodes.append(node)
else:
other_nodes.append(node)
gm_nodes = param_nodes + buffers_nodes + input_nodes + other_nodes

for gm_node in gm_nodes:
node_users = []
for user in gm_node.users.keys():
node_users.append(str(user))
Expand Down
24 changes: 24 additions & 0 deletions frontend/Python/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,27 @@ class GtOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ElementwiseType


class ScaledDotProductFlashAttentionForCpuOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ElementwiseType


class GeOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ElementwiseType


class GreaterThanOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.BroadcastType


class UnsafeIndexOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ReshapeType
Loading