Commit 9355ecb
[examples] Add bert emotional classification example (#253)
xTayEx authored Dec 23, 2023
1 parent 148c8f3 commit 9355ecb
Showing 9 changed files with 30,747 additions and 1 deletion.
3 changes: 3 additions & 0 deletions examples/BuddyBert/.gitignore
@@ -0,0 +1,3 @@
arg0.data
arg1.data
bert.mlir
29 changes: 29 additions & 0 deletions examples/BuddyBert/CMakeLists.txt
@@ -0,0 +1,29 @@
add_custom_command(
  OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyBert/bert.mlir ${BUDDY_EXAMPLES_DIR}/BuddyBert/arg0.data ${BUDDY_EXAMPLES_DIR}/BuddyBert/arg1.data
  COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyBert/import-bert.py
  COMMENT "Generating bert.mlir and parameter files"
)


add_custom_command(
  OUTPUT bert.o
  COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyBert/bert.mlir
            -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
          ${LLVM_MLIR_BINARY_DIR}/mlir-opt
            -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" |
          ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
          ${LLVM_MLIR_BINARY_DIR}/llvm-as |
          ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyBert/bert.o
  DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyBert/bert.mlir
  COMMENT "Building bert.o"
  VERBATIM)

add_library(BERT STATIC bert.o)

SET_TARGET_PROPERTIES(BERT PROPERTIES LINKER_LANGUAGE C)

add_executable(buddy-bert-run bert-main.cpp)
target_link_directories(buddy-bert-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR})

set(BUDDY_BERT_LIBS BERT mlir_c_runner_utils)
target_link_libraries(buddy-bert-run ${BUDDY_BERT_LIBS})
23 changes: 23 additions & 0 deletions examples/BuddyBert/README.md
@@ -0,0 +1,23 @@
# Buddy Compiler BERT Emotion Classification Example

## Introduction
This example shows how to use Buddy Compiler to compile a BERT model to MLIR code and then run it. The [model](https://huggingface.co/bhadresh-savani/bert-base-uncased-emotion) is trained to classify the emotion of a sentence into one of the following classes: sadness, joy, love, anger, fear, and surprise.
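
If you want to double-check which output index corresponds to which emotion, a minimal sketch (assuming the `transformers` Python package is installed and the model configuration can be downloaded) is to print the label mapping stored in the model configuration:

```python
# Hedged sketch: print the id-to-label mapping of the pretrained model.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bhadresh-savani/bert-base-uncased-emotion")
# e.g. {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
print(config.id2label)
```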


## How to run
1. Ensure that LLVM, Buddy Compiler, and the Buddy Compiler Python packages are installed properly. Refer to [buddy-mlir](https://github.com/buddy-compiler/buddy-mlir) for more information and double-check your setup.

2. Set the `PYTHONPATH` environment variable.
```bash
$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH}
```

3. Build and run the BERT example
```bash
$ cmake -G Ninja .. -DBUDDY_BERT_EXAMPLES=ON
$ ninja buddy-bert-run
$ cd bin
$ ./buddy-bert-run
```

4. Enjoy it!
100 changes: 100 additions & 0 deletions examples/BuddyBert/bert-main.cpp
@@ -0,0 +1,100 @@
//===- bert-main.cpp -----------------------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include <buddy/Core/Container.h>
#include <buddy/LLM/TextContainer.h>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

using namespace buddy;

// Declare BERT forward function.
extern "C" void
_mlir_ciface_forward(MemRef<float, 2> *result, MemRef<float, 1> *arg0,
MemRef<long long, 1> *arg1, MemRef<long long, 2> *arg2,
MemRef<long long, 2> *arg3, MemRef<long long, 2> *arg4);

void loadParameters(const std::string &floatParamPath,
                    const std::string &int64ParamPath,
                    MemRef<float, 1> &floatParam,
                    MemRef<long long, 1> &int64Param) {
  std::ifstream floatParamFile(floatParamPath, std::ios::in | std::ios::binary);
  if (!floatParamFile.is_open()) {
    std::string errMsg = "Failed to open float param file: " +
                         std::filesystem::canonical(floatParamPath).string();
    throw std::runtime_error(errMsg);
  }
  floatParamFile.read(reinterpret_cast<char *>(floatParam.getData()),
                      floatParam.getSize() * sizeof(float));
  if (floatParamFile.fail()) {
    throw std::runtime_error("Failed to read float param file");
  }
  floatParamFile.close();

  std::ifstream int64ParamFile(int64ParamPath, std::ios::in | std::ios::binary);
  if (!int64ParamFile.is_open()) {
    std::string errMsg = "Failed to open int64 param file: " +
                         std::filesystem::canonical(int64ParamPath).string();
    throw std::runtime_error(errMsg);
  }
  int64ParamFile.read(reinterpret_cast<char *>(int64Param.getData()),
                      int64Param.getSize() * sizeof(long long));
  if (int64ParamFile.fail()) {
    throw std::runtime_error("Failed to read int64 param file");
  }
  int64ParamFile.close();
}

int main() {
  // Load the packed model parameters generated by import-bert.py.
  MemRef<float, 1> arg0({109486854});
  MemRef<long long, 1> arg1({512});
  loadParameters("../../examples/BuddyBert/arg0.data",
                 "../../examples/BuddyBert/arg1.data", arg0, arg1);

  std::cout << "This BERT model will guess the emotion of your sentence."
            << std::endl;
  std::cout << "What sentence do you want to say to BERT?" << std::endl;

  std::string vocabDir = "../../examples/BuddyBert/vocab.txt";
  std::string pureStr;
  std::getline(std::cin, pureStr);
  // Tokenize the input sentence into a fixed-length sequence of 5 tokens.
  Text<long long, 2> pureStrContainer(pureStr);
  pureStrContainer.tokenizeBert(vocabDir, 5);

  MemRef<float, 2> result({1, 6});
  MemRef<long long, 2> attention_mask({1, 5}, 1LL);
  MemRef<long long, 2> token_type_ids({1, 5}, 0LL);
  _mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer,
                       &attention_mask, &token_type_ids);

  // Pick the class with the largest logit. Use lowest() rather than min():
  // std::numeric_limits<float>::min() is the smallest positive float, which
  // would break the argmax when all logits are negative.
  int predict_label = -1;
  float max_logits = std::numeric_limits<float>::lowest();
  for (int i = 0; i < 6; i++) {
    if (max_logits < result.getData()[i]) {
      max_logits = result.getData()[i];
      predict_label = i;
    }
  }

  std::vector<std::string> emotion = {"sadness", "joy", "love",
                                      "anger", "fear", "surprise"};
  std::cout << "The emotion of this sentence is \"" << emotion[predict_label]
            << "\"" << std::endl;
  return 0;
}
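
For comparison, the same prediction can be reproduced directly with PyTorch and `transformers`. The following is a hedged sketch, not part of the committed example; it assumes the model can be downloaded and mirrors the fixed sequence length of 5 used by `bert-main.cpp`:

```python
# Hedged reference run: classify a sentence with the original PyTorch model
# so the result can be compared against the compiled buddy-bert-run binary.
import torch
from transformers import BertForSequenceClassification, BertTokenizer

model_id = "bhadresh-savani/bert-base-uncased-emotion"
model = BertForSequenceClassification.from_pretrained(model_id)
tokenizer = BertTokenizer.from_pretrained(model_id)
model.eval()

sentence = "i am happy today"  # example input (assumption)
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    padding="max_length",
    max_length=5,      # matches tokenizeBert(vocabDir, 5) in bert-main.cpp
    truncation=True,
)
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, 6), one logit per emotion
predicted = logits.argmax(dim=-1).item()
print(model.config.id2label[predicted])
```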
63 changes: 63 additions & 0 deletions examples/BuddyBert/import-bert.py
@@ -0,0 +1,63 @@
# ===- import-bert.py --------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This is the importer of the BERT emotion classification model.
#
# ===---------------------------------------------------------------------------

import os
from pathlib import Path

import numpy as np
import torch
from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa
from torch._inductor.decomposition import decompositions as inductor_decomp
from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained(
    "bhadresh-savani/bert-base-uncased-emotion"
)
model.eval()
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

tokenizer = BertTokenizer.from_pretrained(
    "bhadresh-savani/bert-base-uncased-emotion"
)
inputs = {
    "input_ids": torch.tensor([[1 for _ in range(5)]], dtype=torch.int64),
    "token_type_ids": torch.tensor([[0 for _ in range(5)]], dtype=torch.int64),
    "attention_mask": torch.tensor([[1 for _ in range(5)]], dtype=torch.int64),
}
with torch.no_grad():
    module, params = dynamo_compiler.importer(model, **inputs)

current_path = os.path.dirname(os.path.abspath(__file__))

with open(Path(current_path) / "bert.mlir", "w") as module_file:
    module_file.write(str(module))

float32_param = np.concatenate(
    [param.detach().numpy().reshape([-1]) for param in params[:-1]]
)

float32_param.tofile(Path(current_path) / "arg0.data")

int64_param = params[-1].detach().numpy().reshape([-1])
int64_param.tofile(Path(current_path) / "arg1.data")
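
As a quick sanity check (a hedged sketch, not part of this commit), the exported parameter files can be reloaded with NumPy to confirm that their element counts match the sizes hard-coded in `bert-main.cpp`:

```python
# Hedged sketch: verify the exported parameter files.
# Assumes import-bert.py has already been run in the same directory.
import numpy as np

float32_params = np.fromfile("arg0.data", dtype=np.float32)
int64_params = np.fromfile("arg1.data", dtype=np.int64)

print(float32_params.size)  # expected 109486854, the size of MemRef<float, 1> arg0
print(int64_params.size)    # expected 512, the size of MemRef<long long, 1> arg1
```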
