Skip to content

Commit

Permalink
[frontend][LLM] Assign SEP to all the tail elements.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Aug 29, 2023
1 parent b572cd6 commit 19fb275
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
5 changes: 4 additions & 1 deletion frontend/Interfaces/buddy/LLM/TextContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ void Text<T, N>::tokenize(const std::string &vocab, long long length) {
processToken(token, tokenCnt);
token.clear();
}
- this->aligned[tokenCnt] = 102; // [SEP] NLP Separator Marker
+ // [SEP] NLP Separator Marker
+ for (long long i = tokenCnt; i < length; i++) {
+ this->aligned[i] = 102;
+ }
}

template <typename T, size_t N>
Expand Down
12 changes: 12 additions & 0 deletions tests/Interface/core/TextContainerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ int main() {
fprintf(stderr, "%lld\n", pureStrContainer.getData()[7]);
// CHECK: 102
fprintf(stderr, "%lld\n", pureStrContainer.getData()[8]);
+ // CHECK: 102
+ fprintf(stderr, "%lld\n", pureStrContainer.getData()[9]);
+ // CHECK: 102
+ fprintf(stderr, "%lld\n", pureStrContainer.getData()[10]);
+ // CHECK: 102
+ fprintf(stderr, "%lld\n", pureStrContainer.getData()[11]);

//===--------------------------------------------------------------------===//
// Test text constructor for punctuation.
Expand Down Expand Up @@ -84,6 +90,10 @@ int main() {
fprintf(stderr, "%lld\n", puncStrContainer.getData()[8]);
// CHECK: 102
fprintf(stderr, "%lld\n", puncStrContainer.getData()[9]);
+ // CHECK: 102
+ fprintf(stderr, "%lld\n", puncStrContainer.getData()[10]);
+ // CHECK: 102
+ fprintf(stderr, "%lld\n", puncStrContainer.getData()[11]);

//===--------------------------------------------------------------------===//
// Test text constructor for corner cases.
Expand Down Expand Up @@ -113,4 +123,6 @@ int main() {
fprintf(stderr, "%lld\n", cornerStrContainer.getData()[9]);
// CHECK: 102
fprintf(stderr, "%lld\n", cornerStrContainer.getData()[10]);
+ // CHECK: 102
+ fprintf(stderr, "%lld\n", cornerStrContainer.getData()[11]);
}

0 comments on commit 19fb275

Please sign in to comment.