[BuddyLLaMA] Add GPU lowering example for LLaMA inference. #229

Closed · wants to merge 37 commits

Changes from all commits · 37 commits
de8d569
Add python venv to gitignore.
SForeKeeper Nov 1, 2023
f49ae86
Prevent generation of llama weight.
SForeKeeper Nov 1, 2023
8e3d880
Add mmap to Memref Container.
SForeKeeper Nov 3, 2023
cbf09a1
Generate powf instead of fpowi in linalg lowering.
SForeKeeper Nov 3, 2023
9c038b5
Add instructions for lowering GPU code.
SForeKeeper Nov 3, 2023
36f095e
Fix pipeline filename.
SForeKeeper Dec 3, 2023
de8064a
Fix compilation instruction.
SForeKeeper Dec 3, 2023
4038377
Add llama-gpu pipeline in cmake.
SForeKeeper Dec 4, 2023
43faa06
Add GPUHostRegister pass.
SForeKeeper Dec 5, 2023
64184f0
Add temp pass test.
SForeKeeper Dec 25, 2023
afbbc18
Add function skeleton for alloc checking.
SForeKeeper Dec 25, 2023
8a0f925
Update host registration process.
SForeKeeper Dec 26, 2023
d6fcd04
Update host register logic.
SForeKeeper Dec 26, 2023
679a500
Update host registration so it works correctly again.
SForeKeeper Dec 26, 2023
4db76cc
[tests] Fix ContainerTest check error on MacOS + Apple Silicon. (#268)
Lester-1 Dec 25, 2023
f751446
[Git] Add shallow option to speed up pulling code (#259)
harrisonGPU Dec 26, 2023
1e88a5c
[DAP] Add IIR vectorization pass. (#241)
taiqzheng Dec 26, 2023
cc7eab4
[VectorExp] Add AOT pipeline.
zhanghb97 Dec 26, 2023
2298a50
[DAP] Fix LowerDAPPass (#245)
taiqzheng Dec 26, 2023
8ca5b72
[examples] Improve Tosa Dialect examples (#271)
meshtag Jan 3, 2024
d2a1d52
[Python] Add pybind11 dependency (#272)
BHbean Jan 4, 2024
f5aeb8f
[test] Exclude bert examples.
zhanghb97 Jan 4, 2024
4a76ae8
[NFC] Fix example paths in examples/README.md
meshtag Jan 13, 2024
5ba44b5
[VectorExp] Add GetVL and SetVL operations to VectorExp dialect.
zhanghb97 Jan 16, 2024
531460a
[examples] Fix examples of MLIR Tensor dialect.
zhanghb97 Jan 17, 2024
8777287
[Benchmark] Fix conv-opt lowering flops benchmark
meshtag Jan 20, 2024
482b463
[frontend] Add Initial Graph Infra.
weilinquan Dec 9, 2023
a679d1a
[DIP] Modify imgcodecs and ImageContainer to update buddy-benchmark (…
meshtag Jan 30, 2024
574fcaa
[NFC] Format the code and eliminate unused variables.
zhanghb97 Jan 30, 2024
9051ae5
Update gitignores.
SForeKeeper Feb 4, 2024
fb18ab1
Add files.
SForeKeeper Feb 6, 2024
b41981b
Fix incorrect order.
SForeKeeper Feb 6, 2024
7e52db1
Add working example.
SForeKeeper Feb 6, 2024
68b04d6
Add openmp to current implementation.
SForeKeeper Feb 6, 2024
8380ffd
Use non-packed version for GPU.
SForeKeeper Feb 6, 2024
c58bfcd
First version of gpu memcpy.
SForeKeeper Feb 7, 2024
a3f0845
Finish memref to gpu operation.
SForeKeeper Feb 7, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -12,3 +12,6 @@

# Clangd cache
.cache
+
+# Python venv
+venv*
2 changes: 2 additions & 0 deletions .gitmodules
@@ -2,6 +2,8 @@
path = llvm
url = https://github.com/llvm/llvm-project.git
branch = main
+shallow = true
[submodule "thirdparty/mimalloc"]
path = thirdparty/mimalloc
url = https://github.com/microsoft/mimalloc.git
+shallow = true
2 changes: 1 addition & 1 deletion benchmark/makefile
@@ -33,7 +33,7 @@ all:$(OUT)
$(shell rm -rf tempFile)

BUDDY_OPT_OPTIONS := -conv-vectorization="strip-mining=${STRIP}" -lower-affine -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts
-MLIR_OPT_OPTIONS := -convert-linalg-to-loops -convert-linalg-to-llvm -lower-affine -convert-scf-to-cf -convert-scf-to-cf -convert-vector-to-llvm --finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts
+MLIR_OPT_OPTIONS := -convert-linalg-to-loops -lower-affine -convert-scf-to-cf -convert-scf-to-cf -convert-vector-to-llvm --finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts

$(OUT):$(SOURCE)
@echo $*
53 changes: 53 additions & 0 deletions docs/IIRVectorizationAlgorithm.md
@@ -0,0 +1,53 @@
# Algorithm Explanation

This document describes the algorithms used in the DAPVectorization pass.

## IIR Vectorization Implementation

An IIR filter can be represented in several forms, typically ZPK (Zero-Pole-Gain) form or SOS (Second-Order Sections) form. A filter is often defined in ZPK form and then transformed to SOS form.
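
For intuition, the sketch below designs a filter in ZPK form and converts it to SOS form. SciPy is used purely as an illustration here; it is not a dependency of the pass.

```python
# Illustration only: design a filter in ZPK form, then convert it to SOS form.
from scipy import signal

# 4th-order Butterworth low-pass filter, cutoff at 0.3 of the Nyquist rate.
z, p, k = signal.butter(4, 0.3, output="zpk")
sos = signal.zpk2sos(z, p, k)
# Each row is one second-order section: [b0, b1, b2, a0, a1, a2] with a0 = 1.
print(sos)
```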

### Scalar Computation for IIR Operation

Currently, our IIR operation supports filters in SOS form. When the filter has only one set of parameters, denoted as {$b_0, b_1, b_2, a_1, a_2$} (subscripts distinguish the parameters), the computation takes the following form:

**IIR with one set of params:**
$$ y_n = b_0 x_n + b_1 x_{n-1} + b_2 x_{n-2} - a_1 y_{n-1} - a_2 y_{n-2} $$

When the filter has multiple sets of parameters, the operation uses a cascade method: the output of one section becomes the input of the next. Take two parameter sets as an example, denoted {$b_0^0, b_1^0, b_2^0, a_1^0, a_2^0$} and {$b_0^1, b_1^1, b_2^1, a_1^1, a_2^1$}, where the superscript indicates the set. The process is:

**IIR with two sets of params:**
$$ y_n^0 = b_0^0 x_n^0 + b_1^0 x_{n-1}^0 + b_2^0 x_{n-2}^0 - a_1^0 y_{n-1}^0 - a_2^0 y_{n-2}^0 $$
$$ x_n^1 = y_n^0 $$
$$ y_n^1 = b_0^1 x_n^1 + b_1^1 x_{n-1}^1 + b_2^1 x_{n-2}^1 - a_1^1 y_{n-1}^1 - a_2^1 y_{n-2}^1 $$
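
A minimal scalar sketch of this cascade (names are illustrative; this is a reference model, not the pass implementation):

```python
import numpy as np

def iir_sos_scalar(x, sos):
    """Reference cascade: each row of `sos` is one set [b0, b1, b2, a1, a2]."""
    y = np.asarray(x, dtype=np.float64)
    for b0, b1, b2, a1, a2 in sos:
        out = np.zeros_like(y)
        xm1 = xm2 = ym1 = ym2 = 0.0  # x_{n-1}, x_{n-2}, y_{n-1}, y_{n-2}
        for n, xn in enumerate(y):
            yn = b0 * xn + b1 * xm1 + b2 * xm2 - a1 * ym1 - a2 * ym2
            out[n] = yn
            xm2, xm1 = xm1, xn
            ym2, ym1 = ym1, yn
        y = out  # the output of this section feeds the next section
    return y
```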

### Vectorization for IIR Operation

This section describes the IIR vectorization algorithm. The example below uses 4 sets of parameters, with superscripts {$0, 1, 2, 3$} identifying each set.

1. **Segment IIR Equation & Generate Vector Params**
![Segment IIR Equation to three parts due to different time moment](./Images/IIRSegmentation.png)
The IIR equation is segmented into 3 parts, each computed at a different time moment. When $S2$ is calculated at time $t_i$, it is used to calculate $S1$ at time $t_{i+1}$, which in turn produces the final result at time $t_{i+2}$.

![Generate SOS params in vector form](./Images/IIRVectorParams.png)
In the image above, vector $B0$ collects all the $b_0$ parameters; the other vectors $B1, B2, A1, A2$ each collect their corresponding parameters.

2. **Computing One Set of Params**
![Computing step 1](./Images/IIRComputing1.png)
The first step calculates $y_0^0$:
$$ y_0^0 = b_0^0 x_0 + s_1^0 $$
At time moment $0$, the initial values of $S1$ and $S2$ are set to $0$.
![Computing step 2](./Images/IIRComputing2.png)
The second step calculates $s_1^0$:
$$ s_1^0 = b_1^0 x_0 - a_1^0 y_0^0 + s_2^0 $$
![Computing step 3](./Images/IIRComputing3.png)
The third step calculates $s_2^0$:
$$ s_2^0 = b_2^0 x_0 - a_2^0 y_0^0 $$

These three steps happen at the same time moment $t$, i.e., within the same loop iteration of the program. Their order cannot change, because the values read from vectors $S1$ and $S2$ were produced before time moment $t$.
3. **Cascade Method**
![Cascade step 1](./Images/IIRCascade1.png)
With the values $y_0^0$, $s_1^0$, and $s_2^0$ produced, the system takes a new input $x_1$ and continues the computation.
![Cascade step 2](./Images/IIRCascade2.png)
$y_0^0$ is shifted right and the new input $x_1$ is pushed in. The values in vectors $S1$ and $S2$ are unchanged, and the computation jumps back to the second step. In the next iteration two parameter sets are active at once, and this is where the performance improvement comes from.

By the fourth iteration of this example, the computation uses all the parameter sets at once, and this fully packed situation holds for the vast majority of the computation. With longer vector lengths (4, 8, 16, 32, and 64 are currently supported), the pass can achieve up to a 10x performance improvement. A sketch of the whole pipeline follows.
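
Pulling the three steps and the cascade shift together, a NumPy model of the pipeline might look like this (an illustrative sketch, not the generated MLIR; the padding convention for unused lanes is an assumption):

```python
import numpy as np

def iir_sos_vectorized(x, sos):
    """Sketch of the vectorized cascade; `sos` has one row per vector lane,
    [b0, b1, b2, a1, a2]. Pad unused lanes with pass-through [1, 0, 0, 0, 0]."""
    V = sos.shape[0]                   # number of parameter sets / lanes
    B0, B1, B2, A1, A2 = (sos[:, i] for i in range(5))
    S1, S2 = np.zeros(V), np.zeros(V)
    xv = np.zeros(V)                   # lane i holds the input of section i
    out = np.zeros(len(x))
    for n in range(len(x) + V - 1):    # extra iterations drain the pipeline
        xv[0] = x[n] if n < len(x) else 0.0
        y = B0 * xv + S1               # step 1
        S1 = B1 * xv - A1 * y + S2     # step 2
        S2 = B2 * xv - A2 * y          # step 3
        if n >= V - 1:
            out[n - V + 1] = y[-1]     # the last lane emits the final result
        xv[1:] = y[:-1]                # cascade: lane i's output feeds lane i+1
    return out
```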
Binary file added docs/Images/IIRCascade1.png
Binary file added docs/Images/IIRCascade2.png
Binary file added docs/Images/IIRComputing1.png
Binary file added docs/Images/IIRComputing2.png
Binary file added docs/Images/IIRComputing3.png
Binary file added docs/Images/IIRSegmentation.png
Binary file added docs/Images/IIRVectorParams.png
8 changes: 6 additions & 2 deletions examples/BuddyBert/import-bert.py
@@ -46,12 +46,16 @@
"attention_mask": torch.tensor([[1 for _ in range(5)]], dtype=torch.int64),
}
with torch.no_grad():
-    module, params = dynamo_compiler.importer(model, **inputs)
+    graphs = dynamo_compiler.importer(model, **inputs)

+assert len(graphs) == 1
+graph = graphs[0]
+params = dynamo_compiler.imported_params[graph]
+graph.lower_to_top_level_ir(do_params_pack=True)
current_path = os.path.dirname(os.path.abspath(__file__))

with open(Path(current_path) / "bert.mlir", "w") as module_file:
-    module_file.write(str(module))
+    module_file.write(str(graph._imported_module))

float32_param = np.concatenate(
[param.detach().numpy().reshape([-1]) for param in params[:-1]]
8 changes: 8 additions & 0 deletions examples/BuddyGPU/matmul.mlir
@@ -0,0 +1,8 @@
module {
func.func @forward(%arg0: tensor<5376x2048xf32>, %arg1: tensor<2048x5376xf32>) -> tensor<5376x5376xf32> {
%cst = arith.constant dense<0.000000e+00> : tensor<5376x5376xf32>
%0 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5376x2048xf32>, tensor<2048x5376xf32>) outs(%cst : tensor<5376x5376xf32>) -> tensor<5376x5376xf32>
return %0 : tensor<5376x5376xf32>
}
}

54 changes: 54 additions & 0 deletions examples/BuddyGPU/matmul.py
@@ -0,0 +1,54 @@
# ===- matmul.py --------------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===--------------------------------------------------------------------------
#
# This file demonstrates the usage of Buddy's frontend for PyTorch module.
#
# ===--------------------------------------------------------------------------

import os

import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y):
return torch.matmul(x, y)

in1 = torch.ones([5376, 2048], dtype=torch.float32)
in2 = torch.ones([2048, 5376], dtype=torch.float32)
# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

graphs = dynamo_compiler.importer(foo, in1, in2)
assert len(graphs) == 1
graph = graphs[0]
graph.lower_to_top_level_ir()

path_prefix = os.path.dirname(os.path.abspath(__file__))
# Write the MLIR module to the file.
with open(os.path.join(path_prefix, "matmul.mlir"), "w") as module_file:
print(graph._imported_module, file=module_file)
195 changes: 195 additions & 0 deletions examples/BuddyGPU/run-module.py
@@ -0,0 +1,195 @@
import mlir.ir as ir
import mlir.dialects.func as func
import mlir.dialects.memref as memref
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir import runtime as rt
from mlir.ir import *
import numpy as np
import ctypes
import os
import torch


def to_numpy(element_type: str) -> np.dtype:
match element_type:
case "f16":
return np.float16
case "f32":
return np.float32
case "f64":
return np.float64
case "i8":
return np.int8
case "i16":
return np.int16
case "i32":
return np.int32
case "i64":
return np.int64
case "bf16":
            raise ValueError("bf16 is not supported by numpy")
case _:
raise ValueError(f"Unsupported type: {element_type}")


def to_mlir(dtype: np.dtype) -> ir.Type:
match dtype:
case np.float16:
return ir.F16Type.get()
case np.float32:
return ir.F32Type.get()
case np.float64:
return ir.F64Type.get()
case np.int8:
return ir.IntegerType.get_signless(8)
case np.int16:
return ir.IntegerType.get_signless(16)
case np.int32:
return ir.IntegerType.get_signless(32)
case np.int64:
return ir.IntegerType.get_signless(64)
case _:
raise ValueError(f"Unsupported type: {dtype}")


def lower_to_llvm_cpu(module: Module) -> Module:
pm = PassManager("builtin.module")
pm.add("func.func(tosa-to-linalg-named)")
pm.add("func.func(tosa-to-linalg)")
pm.add("func.func(tosa-to-tensor)")
pm.add("func.func(tosa-to-arith)")
pm.add("arith-expand")
pm.add("eliminate-empty-tensors")
pm.add("empty-tensor-to-alloc-tensor")
pm.add("convert-elementwise-to-linalg")
pm.add("one-shot-bufferize")
pm.add("func.func(convert-linalg-to-affine-loops)")
pm.add("affine-loop-fusion")
pm.add("func.func(affine-parallelize)")
pm.add("lower-affine")
pm.add("convert-scf-to-openmp")
pm.add("func-bufferize")
pm.add("arith-bufferize")
pm.add("func.func(tensor-bufferize)")
pm.add("func.func(buffer-deallocation)")
pm.add("func.func(finalizing-bufferize)")
pm.add("expand-strided-metadata")
pm.add("convert-vector-to-llvm")
pm.add("memref-expand")
pm.add("arith-expand")
pm.add("convert-arith-to-llvm")
pm.add("finalize-memref-to-llvm")
pm.add("convert-scf-to-cf")
pm.add("func.func(llvm-request-c-wrappers)")
pm.add("convert-openmp-to-llvm")
pm.add("convert-math-to-llvm")
pm.add("convert-math-to-libm")
pm.add("convert-func-to-llvm")
pm.add("reconcile-unrealized-casts")
pm.run(module.operation)
return module


def new_ranked_memref_descriptor(nparray: np.ndarray):
ctp = rt.as_ctype(nparray.dtype)
if nparray.ndim == 0:
x = rt.make_zero_d_memref_descriptor(ctp)()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
return x

x = rt.make_nd_memref_descriptor(nparray.ndim, ctp)()
nbytes = nparray.nbytes
buffer = ctypes.create_string_buffer(nbytes)
ctypes.memmove(buffer, nparray.ctypes.data, nbytes)
x.allocated = ctypes.cast(buffer, ctypes.c_void_p).value
x.aligned = ctypes.cast(buffer, ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
x.shape = nparray.ctypes.shape

    # NumPy expresses strides in bytes, whereas MLIR memref descriptors
    # express strides in element counts, so convert via itemsize.
strides_ctype_t = ctypes.c_longlong * nparray.ndim
x.strides = strides_ctype_t(
*[x // nparray.itemsize for x in nparray.strides]
)
return x


def testMemrefAdd():
with Context():
module = Module.parse(
"""
module {
func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xf32>
%2 = memref.load %arg1[] : memref<f32>
%3 = arith.addf %1, %2 : f32
memref.store %3, %arg2[%0] : memref<1xf32>
return
}
} """
)
arg1 = np.array([32.5]).astype(np.float32)
arg2 = np.array(6).astype(np.float32)
res = np.array([0]).astype(np.float32)

arg1_memref_ptr = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(arg1))
)
arg2_memref_ptr = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(arg2))
)
res_memref_ptr = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(res))
)

execution_engine = ExecutionEngine(lower_to_llvm_cpu(module))
execution_engine.invoke(
"main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
)
npout = rt.ranked_memref_to_numpy(res_memref_ptr[0])
print(npout)

def get_memref_descriptors(args: list[Type]):
memref_ptrs = []
for arg in args:
elem_type = to_numpy(str(arg.element_type))
np_arg = np.random.rand(*arg.shape).astype(elem_type)
memref_ptrs.append(
ctypes.pointer(
ctypes.pointer(new_ranked_memref_descriptor(np_arg))
)
)
return memref_ptrs

def test():
with Context() as ctx:
        file = open(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "matmul.mlir"),
            "r",
        )
module: Module = Module.parse(file.read())
funcOp: func.FuncOp = (
module.operation.regions[0].blocks[0].operations[0]
)
funcName = str(funcOp.name).replace('"', "")
assert isinstance(funcOp, func.FuncOp)
args_type: list[Type] = [arg.type for arg in funcOp.arguments]
res_type = funcOp.type.results

newModule = lower_to_llvm_cpu(module)
memref_ptrs = get_memref_descriptors(res_type+args_type)

        # The libomp path is system-specific; adjust it for your machine.
        engine = ExecutionEngine(newModule, shared_libs=["/usr/lib/libomp.so"])
engine.invoke(funcName, *memref_ptrs)
out = rt.ranked_memref_to_numpy(memref_ptrs[0][0])
print(out)
input1 = rt.ranked_memref_to_numpy(memref_ptrs[1][0])
input2 = rt.ranked_memref_to_numpy(memref_ptrs[2][0])
numpy_out = np.matmul(input1, input2)
print(f"MLIR equal to PyTorch? {np.allclose(out, numpy_out)}")

test()
23 changes: 23 additions & 0 deletions examples/BuddyGraph/README.md
@@ -0,0 +1,23 @@
# Buddy Graph Representation Examples

## Run the Examples

0. Enter your Python environment
```
(base)$ conda activate buddy
(buddy)$ ...
```
1. Build the Python packages (see the sketch below)
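A sketch of this step, assuming the standard buddy-mlir CMake options; verify the flag names against the top-level README of your checkout.
```
(buddy)$ cd buddy-mlir/build
# Assumed flags -- check your checkout's README before running.
(buddy)$ cmake -G Ninja .. \
    -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \
    -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \
    -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \
    -DPython3_EXECUTABLE=$(which python3)
(buddy)$ ninja
```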
2. Configure Python Path
```
(buddy)$ cd buddy-mlir/build
(buddy)$ export BUDDY_MLIR_BUILD_DIR=$PWD
(buddy)$ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build
(buddy)$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}

```
3. Run the Examples
```
(buddy)$ cd examples/BuddyGraph
(buddy)$ python import-dynamo-break.py
```