From 022b26ec62994c0a899c709cd35d781073cc3587 Mon Sep 17 00:00:00 2001 From: Lester <44289481+Lester-1@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:47:53 +0800 Subject: [PATCH] [frontend][LLM] Update TextContainer (#223) * [frontend][LLM] Completed all todos in TextContainer * [frontend][LLM] Change vocab path * [frontend][LLM] Change variable name * [examples][BuddyLlama] Change function name in llama-main.cpp * [frontend][LLM] Delete unnecessary function --- examples/BuddyLlama/llama-main.cpp | 4 +- frontend/Interfaces/buddy/LLM/TextContainer.h | 74 ++++++++----------- tests/Interface/core/TextContainerTest.cpp | 8 +- 3 files changed, 36 insertions(+), 50 deletions(-) diff --git a/examples/BuddyLlama/llama-main.cpp b/examples/BuddyLlama/llama-main.cpp index 85d4fef944..537b36fb9d 100644 --- a/examples/BuddyLlama/llama-main.cpp +++ b/examples/BuddyLlama/llama-main.cpp @@ -45,7 +45,7 @@ int main() { auto buddyTokenizeTime = duration_cast(buddyTokenizeEnd - buddyTokenizeStart); // Print the tokenized result - cout << "Get User input:" << pureStrContainer.revert(pureStrContainer) + cout << "Get User input:" << pureStrContainer.revertLlama(pureStrContainer) << endl; cout << "[Buddy] Tokenize input time: " << buddyTokenizeTime.count() << "ms" << endl; @@ -108,7 +108,7 @@ int main() { auto buddyEnd = system_clock::now(); buddyReadTime = duration_cast(buddyEnd - buddyStart); // Print the result - cout << "[Buddy] Result: " << pureStrContainer.revert(pureStrContainer) + cout << "[Buddy] Result: " << pureStrContainer.revertLlama(pureStrContainer) << endl; cout << "[Buddy] Llama exection time: " << (double)(buddyReadTime.count()) / 1000 << "s" << endl; diff --git a/frontend/Interfaces/buddy/LLM/TextContainer.h b/frontend/Interfaces/buddy/LLM/TextContainer.h index b5a76bff37..7c7860001a 100644 --- a/frontend/Interfaces/buddy/LLM/TextContainer.h +++ b/frontend/Interfaces/buddy/LLM/TextContainer.h @@ -26,6 +26,7 @@ #include "buddy/Core/Container.h" #include #include +#include #include namespace buddy { @@ -66,7 +67,7 @@ template class Text : public MemRef { // This function initializes the conversion from Text memref to a string. // Tokens are identified by ids and thick underlines are replaced with // whitespaces. - std::string revert(Text input); + std::string revertLlama(Text input); // Get sequence length size_t getTokenCnt() { return this->tokenCnt; } @@ -79,34 +80,18 @@ template class Text : public MemRef { } private: - // Check if a character is a whitespace character. - bool isWhitespace(char s) const { - // TODO-HIGH: Consider using standard library functions like `isspace`. - // return isspace(static_cast(s)); - return s == ' ' || s == '\t' || s == '\n' || s == '\r'; - } - // Check if a character is a punctuation character. - bool isPunctuation(char s) const { - // TODO-HIGH: Consider using standard library functions like `ispunct`. - // return ispunct(static_cast(s)); - return (s >= 33 && s <= 47) || (s >= 58 && s <= 64) || - (s >= 91 && s <= 96) || (s >= 123 && s <= 126); - } - // Change character from uppercase to lowercase - char toLower(char s) const { - // TODO-HIGH: Consider using standard library functions like `tolower`. - // return static_cast(tolower(static_cast(s))); - if (s >= 65 && s <= 90) - return s + 32; - else - return s; - } - // Check if a char is a chinese character - // TODO-MID: find more accurate strategy and write more comments. - bool isChineseChar(char s) { + // Check if a char is component of multi-bytes string. + // Using lookup table to determine the number of bytes of a character. + // If the number of bytes is 1, return false(0), otherwise return the + // number of bytes. + int isMutiBytesChar(char s) { const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; int8_t highbits = static_cast(s) >> 4; - return lookup[highbits] == 3; + if(lookup[highbits] == 1) { + return 0; + } + else + return lookup[highbits]; } // Replace all " " with "▁" std::string replaceAllSpace(const std::string &str) { @@ -224,7 +209,6 @@ void Text::tokenizeLlama(const std::string &vocab, size_t length) { this->aligned[0] = cls; tokenCnt = 1; - // Directly fill this->aligned in reverse order. for (auto it = res.rbegin(); it != res.rend(); ++it) { this->aligned[tokenCnt++] = *it; @@ -247,7 +231,7 @@ void Text::tokenizeBert(const std::string &vocab, size_t length, this->size = this->product(this->sizes); this->allocated = new T[this->size]; this->aligned = this->allocated; - this->pad = 0; + this->pad = 102; this->unk = 100; this->cls = 101; this->sep = 102; @@ -260,26 +244,26 @@ void Text::tokenizeBert(const std::string &vocab, size_t length, for (size_t i = 0; i < str.size(); i++) { char s = str[i]; if (lower) { - s = toLower(s); + s = tolower(s); } - if (isWhitespace(s) || isPunctuation(s) || isChineseChar(s)) { + if (isspace(s) || ispunct(s) || isMutiBytesChar(s)) { if (!token.empty()) { processToken(token, tokenCnt, affix); token.clear(); } - if (isPunctuation(s)) { + if (ispunct(s)) { token = s; processToken(token, tokenCnt, false); token.clear(); } - if (isChineseChar(s)) { - token.append(str, i, 3); + if (int bytes = isMutiBytesChar(s)) { + token.append(str, i, bytes); // If it doesn't divide by affix, divide the Chinese words one by one. if (!affix) { processToken(token, tokenCnt, false); token.clear(); } - i += 2; + i += bytes - 1; } } else { token += s; @@ -295,14 +279,14 @@ void Text::tokenizeBert(const std::string &vocab, size_t length, this->aligned[tokenCnt++] = sep; // Padding the rest text container. for (size_t i = tokenCnt; i < length; i++) { - // TODO-HIGH: considering use `pad` here. - this->aligned[i] = sep; + this->aligned[i] = pad; } } -// TODO-HIGH: consider using `revertLlama` here. +// The revert function is used to convert the tokenized sequence back to a +// full string. template -std::string Text::revert(Text input) { +std::string Text::revertLlama(Text input) { std::string dst; const int PAD_ID = 0; @@ -315,14 +299,14 @@ std::string Text::revert(Text input) { continue; if (id == SEP_ID) break; + // Replace each "▁" with a space. std::string token = this->idToTokenVec[id]; - if (token.find("▁") != std::string::npos) { - dst.append(" "); - // TODO-HIGH: consider whether the `3` is reasonable here. - dst.append(token, 3); - } else { - dst.append(token); + size_t pos = token.find("▁"); + while (pos != std::string::npos) { + token.replace(pos, 3, " "); + pos = token.find("▁", pos + 1); } + dst.append(token); } if (dst[0] == ' ') { dst.erase(0, 1); diff --git a/tests/Interface/core/TextContainerTest.cpp b/tests/Interface/core/TextContainerTest.cpp index 9d5841f914..8a2bf10440 100644 --- a/tests/Interface/core/TextContainerTest.cpp +++ b/tests/Interface/core/TextContainerTest.cpp @@ -129,7 +129,9 @@ int main() { fprintf(stderr, "%ld\n", cornerStrContainer.getData()[11]); //===--------------------------------------------------------------------===// - // Test text constructor for chinese cases. + // Test text constructor for mutibyteschar cases. + // Specially, the Chinese characters are included. + // Select Chinese characters for testing. //===--------------------------------------------------------------------===// std::string chineseStr = "我,中国北京人!"; Text chineseStrBertContainer(chineseStr); @@ -232,7 +234,7 @@ int main() { Text pureStrLlamaContainer(pureStrLlama); pureStrLlamaContainer.tokenizeLlama(vocabDir, 12); std::string pureStrLlamaResult = - pureStrLlamaContainer.revert(pureStrLlamaContainer); + pureStrLlamaContainer.revertLlama(pureStrLlamaContainer); // CHECK: 1 fprintf(stderr, "%ld\n", pureStrLlamaContainer.getData()[0]); // CHECK: 8619 @@ -266,7 +268,7 @@ int main() { Text puncStrLlamaContainer(puncStrLlama); puncStrLlamaContainer.tokenizeLlama(vocabDir, 12); std::string puncStrLlamaResult = - puncStrLlamaContainer.revert(puncStrLlamaContainer); + puncStrLlamaContainer.revertLlama(puncStrLlamaContainer); // CHECK: 1 fprintf(stderr, "%ld\n", puncStrLlamaContainer.getData()[0]); // CHECK: 8619