Skip to content

Commit

Permalink
[frontend][LLM] Update TextContainer (#223)
Browse files Browse the repository at this point in the history
* [frontend][LLM] Completed all todos in TextContainer

* [frontend][LLM] Change vocab path

* [frontend][LLM] Change variable name

* [examples][BuddyLlama] Change function name in llama-main.cpp

* [frontend][LLM] Delete unnecessary function
  • Loading branch information
Lester-1 authored Oct 31, 2023
1 parent 3f1ccf5 commit 022b26e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 50 deletions.
4 changes: 2 additions & 2 deletions examples/BuddyLlama/llama-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ int main() {
auto buddyTokenizeTime =
duration_cast<milliseconds>(buddyTokenizeEnd - buddyTokenizeStart);
// Print the tokenized result
cout << "Get User input:" << pureStrContainer.revert(pureStrContainer)
cout << "Get User input:" << pureStrContainer.revertLlama(pureStrContainer)
<< endl;
cout << "[Buddy] Tokenize input time: " << buddyTokenizeTime.count() << "ms"
<< endl;
Expand Down Expand Up @@ -108,7 +108,7 @@ int main() {
auto buddyEnd = system_clock::now();
buddyReadTime = duration_cast<milliseconds>(buddyEnd - buddyStart);
// Print the result
cout << "[Buddy] Result: " << pureStrContainer.revert(pureStrContainer)
cout << "[Buddy] Result: " << pureStrContainer.revertLlama(pureStrContainer)
<< endl;
cout << "[Buddy] Llama exection time: "
<< (double)(buddyReadTime.count()) / 1000 << "s" << endl;
Expand Down
74 changes: 29 additions & 45 deletions frontend/Interfaces/buddy/LLM/TextContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "buddy/Core/Container.h"
#include <fstream>
#include <iostream>
#include <cctype>
#include <unordered_map>

namespace buddy {
Expand Down Expand Up @@ -66,7 +67,7 @@ template <typename T, size_t N> class Text : public MemRef<T, N> {
// This function initializes the conversion from Text memref to a string.
// Tokens are identified by ids and thick underlines are replaced with
// whitespaces.
std::string revert(Text<T, 2> input);
std::string revertLlama(Text<T, 2> input);

// Get sequence length
size_t getTokenCnt() { return this->tokenCnt; }
Expand All @@ -79,34 +80,18 @@ template <typename T, size_t N> class Text : public MemRef<T, N> {
}

private:
// Report whether `s` is one of the four whitespace characters recognized
// here: space, horizontal tab, line feed, or carriage return.
// TODO-HIGH: Consider using standard library functions like `isspace`.
bool isWhitespace(char s) const {
  switch (s) {
  case ' ':
  case '\t':
  case '\n':
  case '\r':
    return true;
  default:
    return false;
  }
}
// Report whether `s` is an ASCII punctuation character, i.e. a printable,
// non-space character (codes 33..126) that is neither a digit nor a letter.
// TODO-HIGH: Consider using standard library functions like `ispunct`.
bool isPunctuation(char s) const {
  // Anything outside the printable, non-space ASCII range is rejected.
  if (s < 33 || s > 126)
    return false;
  // Within that range, punctuation is whatever is not alphanumeric.
  const bool digit = s >= 48 && s <= 57;
  const bool upper = s >= 65 && s <= 90;
  const bool lower = s >= 97 && s <= 122;
  return !(digit || upper || lower);
}
// Map an ASCII uppercase letter to its lowercase counterpart; every other
// character is returned unchanged.
// TODO-HIGH: Consider using standard library functions like `tolower`.
char toLower(char s) const {
  const bool isUpper = s >= 'A' && s <= 'Z';
  // Lowercase letters sit exactly 32 code points above uppercase in ASCII.
  return isUpper ? static_cast<char>(s + 32) : s;
}
// Check if a char is a chinese character
// TODO-MID: find more accurate strategy and write more comments.
bool isChineseChar(char s) {
// Check whether a byte starts a multi-byte character.
// The high four bits of the byte index a lookup table holding the length of
// the sequence that byte introduces; the table matches the UTF-8 leading-byte
// layout (0xC/0xD -> 2 bytes, 0xE -> 3 bytes, 0xF -> 4 bytes).
// Returns 0 when the table length is 1 (high nibble 0x0-0xB), otherwise the
// number of bytes in the character, so the result can be used both as a
// truth value and as the count of bytes to consume.
int isMutiBytesChar(char s) {
  // Sequence length indexed by the high nibble of the byte.
  static constexpr int lookup[] = {1, 1, 1, 1, 1, 1, 1, 1,
                                   1, 1, 1, 1, 2, 2, 3, 4};
  // Convert through unsigned char so bytes >= 0x80 do not sign-extend.
  const unsigned highbits = static_cast<unsigned char>(s) >> 4;
  const int bytes = lookup[highbits];
  // No early "== 3" return here: 2- and 4-byte leads must report their
  // lengths as well.
  return bytes == 1 ? 0 : bytes;
}
// Replace all " " with "▁"
std::string replaceAllSpace(const std::string &str) {
Expand Down Expand Up @@ -224,7 +209,6 @@ void Text<T, N>::tokenizeLlama(const std::string &vocab, size_t length) {

this->aligned[0] = cls;
tokenCnt = 1;

// Directly fill this->aligned in reverse order.
for (auto it = res.rbegin(); it != res.rend(); ++it) {
this->aligned[tokenCnt++] = *it;
Expand All @@ -247,7 +231,7 @@ void Text<T, N>::tokenizeBert(const std::string &vocab, size_t length,
this->size = this->product(this->sizes);
this->allocated = new T[this->size];
this->aligned = this->allocated;
this->pad = 0;
this->pad = 102;
this->unk = 100;
this->cls = 101;
this->sep = 102;
Expand All @@ -260,26 +244,26 @@ void Text<T, N>::tokenizeBert(const std::string &vocab, size_t length,
for (size_t i = 0; i < str.size(); i++) {
char s = str[i];
if (lower) {
s = toLower(s);
s = tolower(s);
}
if (isWhitespace(s) || isPunctuation(s) || isChineseChar(s)) {
if (isspace(s) || ispunct(s) || isMutiBytesChar(s)) {
if (!token.empty()) {
processToken(token, tokenCnt, affix);
token.clear();
}
if (isPunctuation(s)) {
if (ispunct(s)) {
token = s;
processToken(token, tokenCnt, false);
token.clear();
}
if (isChineseChar(s)) {
token.append(str, i, 3);
if (int bytes = isMutiBytesChar(s)) {
token.append(str, i, bytes);
// When affix splitting is disabled, emit each multi-byte character as its own token.
if (!affix) {
processToken(token, tokenCnt, false);
token.clear();
}
i += 2;
i += bytes - 1;
}
} else {
token += s;
Expand All @@ -295,14 +279,14 @@ void Text<T, N>::tokenizeBert(const std::string &vocab, size_t length,
this->aligned[tokenCnt++] = sep;
// Padding the rest text container.
for (size_t i = tokenCnt; i < length; i++) {
// TODO-HIGH: considering use `pad` here.
this->aligned[i] = sep;
this->aligned[i] = pad;
}
}

// TODO-HIGH: consider using `revertLlama` here.
// The revert function is used to convert the tokenized sequence back to a
// full string.
template <typename T, size_t N>
std::string Text<T, N>::revert(Text<T, 2> input) {
std::string Text<T, N>::revertLlama(Text<T, 2> input) {
std::string dst;

const int PAD_ID = 0;
Expand All @@ -315,14 +299,14 @@ std::string Text<T, N>::revert(Text<T, 2> input) {
continue;
if (id == SEP_ID)
break;
// Replace each "▁" with a space.
std::string token = this->idToTokenVec[id];
if (token.find("") != std::string::npos) {
dst.append(" ");
// TODO-HIGH: consider whether the `3` is reasonable here.
dst.append(token, 3);
} else {
dst.append(token);
size_t pos = token.find("");
while (pos != std::string::npos) {
token.replace(pos, 3, " ");
pos = token.find("", pos + 1);
}
dst.append(token);
}
if (dst[0] == ' ') {
dst.erase(0, 1);
Expand Down
8 changes: 5 additions & 3 deletions tests/Interface/core/TextContainerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ int main() {
fprintf(stderr, "%ld\n", cornerStrContainer.getData()[11]);

//===--------------------------------------------------------------------===//
// Test text constructor for chinese cases.
// Test text constructor for multi-byte character cases.
// In particular, Chinese characters are included.
// Chinese characters are selected for testing.
//===--------------------------------------------------------------------===//
std::string chineseStr = "我,中国北京人!";
Text<size_t, 2> chineseStrBertContainer(chineseStr);
Expand Down Expand Up @@ -232,7 +234,7 @@ int main() {
Text<size_t, 2> pureStrLlamaContainer(pureStrLlama);
pureStrLlamaContainer.tokenizeLlama(vocabDir, 12);
std::string pureStrLlamaResult =
pureStrLlamaContainer.revert(pureStrLlamaContainer);
pureStrLlamaContainer.revertLlama(pureStrLlamaContainer);
// CHECK: 1
fprintf(stderr, "%ld\n", pureStrLlamaContainer.getData()[0]);
// CHECK: 8619
Expand Down Expand Up @@ -266,7 +268,7 @@ int main() {
Text<size_t, 2> puncStrLlamaContainer(puncStrLlama);
puncStrLlamaContainer.tokenizeLlama(vocabDir, 12);
std::string puncStrLlamaResult =
puncStrLlamaContainer.revert(puncStrLlamaContainer);
puncStrLlamaContainer.revertLlama(puncStrLlamaContainer);
// CHECK: 1
fprintf(stderr, "%ld\n", puncStrLlamaContainer.getData()[0]);
// CHECK: 8619
Expand Down

0 comments on commit 022b26e

Please sign in to comment.