Skip to content

Commit

Permalink
Fix RAG analyzer (#2048)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

1. Fix bug on converting full width to half width characters.
2. Fix corruption for multi-threaded analyzer 

Issue link: #1973

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
  • Loading branch information
yingfeng authored Oct 15, 2024
1 parent 443f9f4 commit 4b619c2
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 87 deletions.
166 changes: 80 additions & 86 deletions src/common/analyzer/rag_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,27 @@ void Split(const String &input, const String &split_pattern, std::vector<String>
result.push_back(String(token.data(), token.size()));
}
if (keep_delim)
result.push_back(std::string(extracted_delim_token.data(), extracted_delim_token.size()));
result.push_back(String(extracted_delim_token.data(), extracted_delim_token.size()));
last_end = leftover;
}

if (!leftover.empty()) {
result.push_back(String(leftover.data(), leftover.size()));
}
}

void Split(const String &input, const RE2 &pattern, std::vector<String> &result, bool keep_delim = false) {
re2::StringPiece leftover(input.data());
re2::StringPiece last_end = leftover;
re2::StringPiece extracted_delim_token;

while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) {
std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data());
if (!token.empty()) {
result.push_back(String(token.data(), token.size()));
}
if (keep_delim)
result.push_back(String(extracted_delim_token.data(), extracted_delim_token.size()));
last_end = leftover;
}

Expand All @@ -104,8 +124,6 @@ String Replace(const RE2 &re, const String &replacement, const String &input) {
return output;
}

bool RegexMatch(const String &str, const String &pattern) { return RE2::PartialMatch(str, RE2(pattern)); }

String Join(const Vector<String> &tokens, int start, int end, const String &delim = " ") {
std::ostringstream oss;
for (int i = start; i < end; ++i) {
Expand All @@ -128,59 +146,6 @@ String Join(const TermList &tokens, int start, int end, const String &delim = "
return oss.str();
}

// Decodes a UTF-8 string into a wide string, one code point per wchar_t.
// Supports 1- to 4-byte sequences; throws std::runtime_error on truncated
// or malformed input. The previous version rejected valid 4-byte sequences
// (code points above U+FFFF) even though WideCharToUTF8 can produce them.
// NOTE(review): on platforms where wchar_t is 16 bits (e.g. Windows),
// code points above U+FFFF do not fit in a single wchar_t and are
// truncated by the final cast — confirm target platforms if that matters.
std::wstring UTF8ToWide(const String &utf8) {
    std::wstring result;
    int i = 0, length = utf8.length();

    while (i < length) {
        uint32_t codepoint;
        unsigned char byte1 = utf8[i];

        if (byte1 <= 0x7F) {
            // 1-byte sequence (ASCII).
            codepoint = byte1;
            i += 1;
        } else if ((byte1 & 0xE0) == 0xC0) {
            // 2-byte sequence: 5 + 6 payload bits.
            if (i + 1 >= length)
                throw std::runtime_error("Invalid UTF-8 string");
            codepoint = (byte1 & 0x1F) << 6 | (utf8[i + 1] & 0x3F);
            i += 2;
        } else if ((byte1 & 0xF0) == 0xE0) {
            // 3-byte sequence: 4 + 6 + 6 payload bits.
            if (i + 2 >= length)
                throw std::runtime_error("Invalid UTF-8 string");
            codepoint = (byte1 & 0x0F) << 12 | (utf8[i + 1] & 0x3F) << 6 | (utf8[i + 2] & 0x3F);
            i += 3;
        } else if ((byte1 & 0xF8) == 0xF0) {
            // 4-byte sequence: 3 + 6 + 6 + 6 payload bits (U+10000..U+10FFFF).
            if (i + 3 >= length)
                throw std::runtime_error("Invalid UTF-8 string");
            codepoint = (byte1 & 0x07) << 18 | (utf8[i + 1] & 0x3F) << 12 | (utf8[i + 2] & 0x3F) << 6 | (utf8[i + 3] & 0x3F);
            i += 4;
        } else {
            // Stray continuation byte or invalid lead byte.
            throw std::runtime_error("Invalid UTF-8 string");
        }

        result += static_cast<wchar_t>(codepoint);
    }

    return result;
}

// Encodes a single wide character into its UTF-8 byte sequence.
// Values above U+10FFFF (outside the Unicode range) yield an empty string.
String WideCharToUTF8(wchar_t wchar) {
    String utf8;

    if (wchar <= 0x7F) {
        // ASCII: one byte, stored verbatim.
        utf8.push_back(static_cast<char>(wchar));
    } else if (wchar <= 0x7FF) {
        // Two bytes: 5 + 6 payload bits.
        utf8.push_back(static_cast<char>(0xC0 | ((wchar >> 6) & 0x1F)));
        utf8.push_back(static_cast<char>(0x80 | (wchar & 0x3F)));
    } else if (wchar <= 0xFFFF) {
        // Three bytes: 4 + 6 + 6 payload bits.
        utf8.push_back(static_cast<char>(0xE0 | ((wchar >> 12) & 0x0F)));
        utf8.push_back(static_cast<char>(0x80 | ((wchar >> 6) & 0x3F)));
        utf8.push_back(static_cast<char>(0x80 | (wchar & 0x3F)));
    } else if (wchar <= 0x10FFFF) {
        // Four bytes: 3 + 6 + 6 + 6 payload bits.
        utf8.push_back(static_cast<char>(0xF0 | ((wchar >> 18) & 0x07)));
        utf8.push_back(static_cast<char>(0x80 | ((wchar >> 12) & 0x3F)));
        utf8.push_back(static_cast<char>(0x80 | ((wchar >> 6) & 0x3F)));
        utf8.push_back(static_cast<char>(0x80 | (wchar & 0x3F)));
    }

    return utf8;
}

bool IsChinese(const String &str) {
for (std::size_t i = 0; i < str.length(); ++i) {
unsigned char c = str[i];
Expand Down Expand Up @@ -254,18 +219,23 @@ class RegexTokenizer {
};

RAGAnalyzer::RAGAnalyzer(const String &path)
: dict_path_(path), lowercase_string_buffer_(term_string_buffer_limit_), regex_tokenizer_(MakeUnique<RegexTokenizer>()) {}
: dict_path_(path), stemmer_(MakeUnique<Stemmer>()), lowercase_string_buffer_(term_string_buffer_limit_),
regex_tokenizer_(MakeUnique<RegexTokenizer>()) {
InitStemmer(STEM_LANG_ENGLISH);
}

RAGAnalyzer::RAGAnalyzer(const RAGAnalyzer &other)
: own_dict_(false), trie_(other.trie_), pos_table_(other.pos_table_), lemma_(other.lemma_), stemmer_(other.stemmer_), opencc_(other.opencc_),
lowercase_string_buffer_(term_string_buffer_limit_), fine_grained_(other.fine_grained_), regex_tokenizer_(MakeUnique<RegexTokenizer>()) {}
: own_dict_(false), trie_(other.trie_), pos_table_(other.pos_table_), lemma_(other.lemma_), stemmer_(MakeUnique<Stemmer>()),
opencc_(other.opencc_), lowercase_string_buffer_(term_string_buffer_limit_), fine_grained_(other.fine_grained_),
regex_tokenizer_(MakeUnique<RegexTokenizer>()) {
InitStemmer(STEM_LANG_ENGLISH);
}

// Destructor: dictionary resources may be shared between analyzer copies
// (see the copy constructor, which sets own_dict_ = false), so only the
// instance that loaded them releases them.
RAGAnalyzer::~RAGAnalyzer() {
    if (own_dict_) {
        delete trie_;
        delete pos_table_;
        delete lemma_;
        delete stemmer_;
        delete opencc_;
    }
}
Expand Down Expand Up @@ -330,9 +300,6 @@ Status RAGAnalyzer::Load() {
}
lemma_ = new Lemmatizer(lemma_path.string());

stemmer_ = new Stemmer();
InitStemmer(STEM_LANG_ENGLISH);

fs::path opencc_path(root / OPENCC_PATH);

if (!fs::exists(opencc_path)) {
Expand All @@ -348,22 +315,50 @@ Status RAGAnalyzer::Load() {
}

// Converts full-width (zenkaku) characters to half-width (hankaku):
// U+FF01..U+FF5E map to ASCII 0x21..0x7E and the ideographic space U+3000
// maps to ' '. All other characters are copied through unchanged.
// The input is decoded as UTF-8 inline; 4-byte lead bytes, stray
// continuation bytes, and truncated sequences (which would otherwise read
// past the end of the string) are passed through byte-by-byte — none of
// them can be full-width conversion targets.
String RAGAnalyzer::StrQ2B(const String &input) {
    String output;
    size_t i = 0;

    while (i < input.size()) {
        unsigned char c = input[i];

        uint32_t codepoint = 0;
        if (c < 0x80) {
            // ASCII fast path.
            codepoint = c;
            i += 1;
        } else if ((c & 0xE0) == 0xC0 && i + 1 < input.size()) {
            // 2-byte sequence (bounds-checked against truncated input).
            codepoint = (c & 0x1F) << 6;
            codepoint |= (input[i + 1] & 0x3F);
            i += 2;
        } else if ((c & 0xF0) == 0xE0 && i + 2 < input.size()) {
            // 3-byte sequence (bounds-checked against truncated input).
            codepoint = (c & 0x0F) << 12;
            codepoint |= (input[i + 1] & 0x3F) << 6;
            codepoint |= (input[i + 2] & 0x3F);
            i += 3;
        } else {
            // 4-byte sequences and malformed bytes: copy through verbatim.
            output += static_cast<char>(c);
            i += 1;
            continue;
        }

        if (codepoint >= 0xFF01 && codepoint <= 0xFF5E) {
            // Full-width ASCII variant -> half-width ASCII.
            output += static_cast<char>(codepoint - 0xFEE0);
        } else if (codepoint == 0x3000) {
            // Ideographic space -> ordinary space.
            output += ' ';
        } else {
            // Not a conversion target: re-encode the code point as UTF-8.
            if (codepoint < 0x80) {
                output += static_cast<char>(codepoint);
            } else if (codepoint < 0x800) {
                output += static_cast<char>(0xC0 | (codepoint >> 6));
                output += static_cast<char>(0x80 | (codepoint & 0x3F));
            } else {
                // Only 1-3 byte sequences are decoded above, so the code
                // point is always below 0x10000 here.
                output += static_cast<char>(0xE0 | (codepoint >> 12));
                output += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F));
                output += static_cast<char>(0x80 | (codepoint & 0x3F));
            }
        }
    }

    return output;
}

i32 RAGAnalyzer::Freq(const String &key) {
Expand Down Expand Up @@ -544,11 +539,10 @@ int RAGAnalyzer::DFS(const String &chars, int s, Vector<Pair<String, int>> &pre_
String RAGAnalyzer::Merge(const String &tks_str) {
String tks = tks_str;

RE2 re_space(R"#(([ ]+))#");
tks = Replace(re_space, " ", tks);
tks = Replace(replace_space_pattern_, " ", tks);

Vector<String> tokens;
Split(tks, "( )", tokens);
Split(tks, blank_pattern_, tokens);
Vector<String> res;
std::size_t s = 0;
while (true) {
Expand All @@ -558,7 +552,7 @@ String RAGAnalyzer::Merge(const String &tks_str) {
std::size_t E = s + 1;
for (std::size_t e = s + 2; e < std::min(tokens.size() + 1, s + 6); ++e) {
String tk = Join(tokens, s, e, "");
if (RE2::PartialMatch(tk, REGEX_SPLIT_CHAR)) {
if (RE2::PartialMatch(tk, regex_split_pattern_)) {
if (Freq(tk) > 0) {
E = e;
}
Expand All @@ -573,7 +567,7 @@ String RAGAnalyzer::Merge(const String &tks_str) {

void RAGAnalyzer::EnglishNormalize(const Vector<String> &tokens, Vector<String> &res) {
for (auto &t : tokens) {
if (RegexMatch(t, "[a-zA-Z_-]+$")) {
if (RE2::PartialMatch(t, pattern1_)) { //"[a-zA-Z_-]+$"
String lemma_term = lemma_->Lemmatize(t);
char *lowercase_term = lowercase_string_buffer_.data();
ToLower(lemma_term.c_str(), lemma_term.size(), lowercase_term, term_string_buffer_limit_);
Expand Down Expand Up @@ -616,9 +610,9 @@ String RAGAnalyzer::Tokenize(const String &line) {
}

Vector<String> arr;
Split(strline, REGEX_SPLIT_CHAR, arr, true);
Split(strline, regex_split_pattern_, arr, true);
for (const auto &L : arr) {
if (UTF8Length(L) < 2 || RegexMatch(L, "[a-z\\.-]+$") || RegexMatch(L, "[0-9\\.-]+$")) {
if (UTF8Length(L) < 2 || RE2::PartialMatch(L, pattern2_) || RE2::PartialMatch(L, pattern3_)) { //[a-z\\.-]+$ [0-9\\.-]+$
res.push_back(L);
continue;
}
Expand Down Expand Up @@ -674,7 +668,7 @@ String RAGAnalyzer::Tokenize(const String &line) {

void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector<String> &result) {
Vector<String> tks;
Split(tokens, "( )", tks);
Split(tokens, blank_pattern_, tks);
Vector<String> res;
std::size_t zh_num = 0;
for (auto &token : tks) {
Expand All @@ -699,7 +693,7 @@ void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector<String> &resu
}

for (auto &token : tks) {
if (UTF8Length(token) < 3 || RegexMatch(token, "[0-9,\\.-]+$")) {
if (UTF8Length(token) < 3 || RE2::PartialMatch(token, pattern4_)) { //[0-9,\\.-]+$
res.push_back(token);
continue;
}
Expand All @@ -722,7 +716,7 @@ void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector<String> &resu
String s_token;
if (stk.size() == token.length()) {
s_token = token;
} else if (RegexMatch(token, "[a-z\\.-]+")) {
} else if (RE2::PartialMatch(token, pattern5_)) { // [a-z\\.-]+
for (auto &t : stk) {
if (UTF8Length(t) < 3) {
s_token = token;
Expand All @@ -746,7 +740,7 @@ int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
if (fine_grained_) {
FineGrainedTokenize(output, tokens);
} else
Split(output, "( )", tokens);
Split(output, blank_pattern_, tokens);
unsigned offset = 0;
for (auto &t : tokens) {
func(data, t.c_str(), t.size(), offset++, 0, Term::AND, level, false);
Expand Down
18 changes: 17 additions & 1 deletion src/common/analyzer/rag_analyzer.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public:

Lemmatizer *lemma_{nullptr};

Stemmer *stemmer_{nullptr};
UniquePtr<Stemmer> stemmer_;

OpenCC *opencc_{nullptr};

Expand All @@ -102,5 +102,21 @@ public:
bool fine_grained_{false};

UniquePtr<RegexTokenizer> regex_tokenizer_;

RE2 pattern1_{"[a-zA-Z_-]+$"};

RE2 pattern2_{"[a-z\\.-]+$"};

RE2 pattern3_{"[0-9\\.-]+$"};

RE2 pattern4_{"[0-9,\\.-]+$"};

RE2 pattern5_{"[a-z\\.-]+"};

RE2 regex_split_pattern_{R"#(([ ,\.<>/?;'\[\]\`!@#$%^&*$$\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z\.-]+|[0-9,\.-]+))#"};

RE2 blank_pattern_{"( )"};

RE2 replace_space_pattern_{R"#(([ ]+))#"};
};
} // namespace infinity

0 comments on commit 4b619c2

Please sign in to comment.