fix different stop/bad words length in batch
irexyc committed Mar 5, 2024
1 parent a6e8188 commit efb2396
Showing 1 changed file with 41 additions and 8 deletions.
src/turbomind/models/llama/LlamaBatch.cc (41 additions & 8 deletions)
@@ -768,12 +768,14 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
 template<typename T>
 void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
 {
-    d_stop_words_ = (int*)allocator_->reMalloc(d_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
-    d_bad_words_  = (int*)allocator_->reMalloc(d_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    d_stop_words_ =
+        (int*)allocator_->reMalloc(d_stop_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true);
+    d_bad_words_ =
+        (int*)allocator_->reMalloc(d_bad_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true);
     h_stop_words_ =
-        (int*)allocator_->reMalloc(h_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
+        (int*)allocator_->reMalloc(h_stop_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true, true);
     h_bad_words_ =
-        (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
+        (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen, true, true);

     h_min_length_    = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
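The hunk above doubles the persistent stop/bad-words buffers: each request's list is a [1, 2, len] int tensor (row 0 holds the flattened word token ids, row 1 the offsets, the layout the shape check in the next hunk enforces), so a batch buffer padded to kMaxStopBadWordsLen needs 2 * kMaxStopBadWordsLen ints per sequence. A minimal sketch of the corrected size math follows; the constant's value and the helper name are assumptions for illustration, not lmdeploy's API.

// Sketch only: the factor of 2 covers the two rows (token ids, offsets)
// of each request's word list; the pre-fix reMalloc calls omitted it,
// undersizing the buffers for lists longer than kMaxStopBadWordsLen / 2.
#include <cstddef>

constexpr std::size_t kMaxStopBadWordsLen = 32;  // assumed value, for illustration

std::size_t stop_bad_words_buffer_bytes(std::size_t max_batch_size)
{
    return sizeof(int) * max_batch_size * 2 * kMaxStopBadWordsLen;
}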
@@ -1050,17 +1052,48 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
         FT_CHECK(shape[0] == 1);
         shape[0] = batch_size;
         const int size_in_bytes = ref.sizeBytes();
-        memset(h_ptr, 0, size_in_bytes * batch_size);
+
+        int max_list_length = 0;
+        if (name == "bad_words_list" || name == "stop_words_list") {
+            for (int i = 0; i < batch_size; ++i) {
+                if (state_->requests[i]->inputs[rank_].isExist(name)) {
+                    Tensor& src = state_->requests[i]->inputs[rank_].at(name);
+                    FT_CHECK(src.shape.size() == 3 && src.shape[1] == 2 && src.shape[2] <= kMaxStopBadWordsLen);
+                    max_list_length = std::max(max_list_length, (int)src.shape[2]);
+                }
+            }
+            std::fill_n((int*)h_ptr, batch_size * 2 * max_list_length, -1);
+            shape[2] = max_list_length;
+        }
+        else {
+            memset(h_ptr, 0, size_in_bytes * batch_size);
+        }
         for (int i = 0; i < batch_size; ++i) {
             FT_CHECK(state_->requests[i] != nullptr);
             if (state_->requests[i]->inputs[rank_].isExist(name)) {
                 Tensor& src = state_->requests[i]->inputs[rank_].at(name);
-                FT_CHECK(ref.shape == src.shape);
-                std::copy_n(src.getPtr<std::byte>(), size_in_bytes, h_ptr + size_in_bytes * i);
+                if (name == "bad_words_list" || name == "stop_words_list") {
+                    int list_length = src.shape[2];
+                    std::copy_n(src.getPtr<std::byte>(),
+                                sizeof(int) * list_length,
+                                h_ptr + i * sizeof(int) * 2 * max_list_length);
+                    std::copy_n(src.getPtr<std::byte>() + sizeof(int) * list_length,
+                                sizeof(int) * list_length,
+                                h_ptr + i * sizeof(int) * 2 * max_list_length + sizeof(int) * max_list_length);
+                }
+                else {
+                    FT_CHECK(ref.shape == src.shape);
+                    std::copy_n(src.getPtr<std::byte>(), size_in_bytes, h_ptr + size_in_bytes * i);
+                }
             }
         }
         if (d_ptr) {
-            Copy(h_ptr, batch_size * size_in_bytes, d_ptr);
+            if (name == "bad_words_list" || name == "stop_words_list") {
+                Copy(h_ptr, batch_size * sizeof(int) * 2 * max_list_length, d_ptr);
+            }
+            else {
+                Copy(h_ptr, batch_size * size_in_bytes, d_ptr);
+            }
         }
         inputs.insert({name, {d_ptr ? MEMORY_GPU : MEMORY_CPU, ref.type, shape, d_ptr ? d_ptr : h_ptr}});
         if (debug_ && rank_ == 0) {
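The hunk above is the substance of the fix: requests in one batch may carry stop/bad-word lists of different lengths, so instead of FT_CHECK-ing every list against ref.shape, InitializeSampling now pads each [2, len] list to the batch-wide maximum and fills unused slots with -1. A self-contained sketch of that packing scheme, with std::vector standing in for the Tensor/allocator machinery (function and type names here are illustrative, not lmdeploy's API):

// Pack variable-length stop/bad-word lists into one padded batch buffer,
// mirroring the logic added in the hunk above.
#include <algorithm>
#include <cstddef>
#include <vector>

// One request's word list: row 0 = flattened token ids, row 1 = offsets.
using WordList = std::vector<std::vector<int>>;  // shape [2][len]

std::vector<int> pack_word_lists(const std::vector<WordList>& lists)
{
    // Pad every list to the longest one in the batch.
    std::size_t max_len = 0;
    for (const auto& l : lists) {
        max_len = std::max(max_len, l[0].size());
    }
    // [batch, 2, max_len], pre-filled with the -1 padding value.
    std::vector<int> packed(lists.size() * 2 * max_len, -1);
    for (std::size_t i = 0; i < lists.size(); ++i) {
        const std::size_t len = lists[i][0].size();  // rows 0 and 1 have equal length
        std::copy_n(lists[i][0].data(), len, packed.data() + i * 2 * max_len);
        std::copy_n(lists[i][1].data(), len, packed.data() + i * 2 * max_len + max_len);
    }
    return packed;
}

Packing two lists of lengths 3 and 1, for instance, gives max_len = 3 and a [2, 2, 3] buffer in which both rows of the shorter list end in two -1 entries, the same fill value the diff writes via std::fill_n so downstream sampling code can recognize the padding.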
