From 9a972b6c2eb2d67efd1eb38331ea718087011db6 Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Fri, 17 Nov 2023 07:57:50 +0000
Subject: [PATCH] fix eos

---
 src/turbomind/models/llama/LlamaBatch.cc | 15 ++++++++++++---
 src/turbomind/models/llama/LlamaBatch.h  |  3 ++-
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 0937e6b419..8e30ee9ad1 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -574,7 +574,6 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len)
     token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);
 
-    end_ids_buf_   = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false);
     finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
     seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
 
@@ -613,6 +612,9 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size)
     d_curand_state_ =
         (curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);
 
+    d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
+    h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
+
     sampling_params_ = {
         {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
         {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
@@ -714,7 +716,9 @@ void LlamaBatch::FreeBuffer()
     allocator_->free((void**)&token_ids_buf_);
 
-    allocator_->free((void**)&end_ids_buf_);
+    allocator_->free((void**)&d_end_ids_buf_);
+    allocator_->free((void**)&h_end_ids_buf_, true);
+
     allocator_->free((void**)&finished_buf_);
     allocator_->free((void**)&seq_limit_len_);
 
@@ -842,6 +846,11 @@ void LlamaBatch::InitializeSampling()
         }
     }
 
+    // init for eos
+    std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
+    Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
+    inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)batch_size}, d_end_ids_buf_}});
+
     inputs_ = std::move(inputs);
 
     model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
@@ -993,7 +1002,7 @@ bool LlamaBatch::Generate(GenerationState& g)
                                           logits_buf_,
                                           seq_limit_len_,
                                           context_length_buf_,
-                                          end_ids_buf_,
+                                          d_end_ids_buf_,
                                           g.step,
                                           0,
                                           g.max_init_ctx_len,
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 83b054030e..eca1270cbc 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -225,9 +225,10 @@ class LlamaBatch {
 
     // used by dynamic decoder
     int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
-    int*      end_ids_buf_{};
     bool*     finished_buf_{};
     uint32_t* seq_limit_len_{};
+    int*      h_end_ids_buf_{};
+    int*      d_end_ids_buf_{};
 
     int**    request_output_ids_ptrs_{};
     int*     request_output_ids_lens_{};