Commit

fix eos
lzhangzz committed Nov 17, 2023
1 parent 68bd386 commit 9a972b6
Showing 2 changed files with 14 additions and 4 deletions.
15 changes: 12 additions & 3 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -574,7 +574,6 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
 
     token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);
 
-    end_ids_buf_   = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false);
     finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
     seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
 
@@ -613,6 +612,9 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     d_curand_state_ =
         (curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);
 
+    d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
+    h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
+
     sampling_params_ = {
         {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
         {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
@@ -714,7 +716,9 @@ void LlamaBatch<T>::FreeBuffer()
 
     allocator_->free((void**)&token_ids_buf_);
 
-    allocator_->free((void**)&end_ids_buf_);
+    allocator_->free((void**)&d_end_ids_buf_);
+    allocator_->free((void**)&h_end_ids_buf_, true);
+
     allocator_->free((void**)&finished_buf_);
     allocator_->free((void**)&seq_limit_len_);
 
@@ -842,6 +846,11 @@ void LlamaBatch<T>::InitializeSampling()
         }
     }
 
+    // init for eos
+    std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
+    Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
+    inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)batch_size}, d_end_ids_buf_}});
+
     inputs_ = std::move(inputs);
 
     model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
Expand Down Expand Up @@ -993,7 +1002,7 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
logits_buf_,
seq_limit_len_,
context_length_buf_,
end_ids_buf_,
d_end_ids_buf_,
g.step,
0,
g.max_init_ctx_len,
Expand Down
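The heart of the fix is the InitializeSampling hunk above: the model's EOS token id is fanned out to one slot per request in a host buffer, mirrored to the device, and registered as the "end_id" input tensor that the dynamic decode layer checks sampled tokens against. Below is a minimal sketch of that stage-and-copy pattern in plain CUDA; the repository routes these steps through its own allocator_ and Copy() helpers, so the stage_end_ids name and the raw cudaMemcpy call are illustrative assumptions, not the codebase's API.

    #include <cuda_runtime.h>

    #include <algorithm>

    // Sketch only: stage one EOS id per request on the host, then mirror
    // the buffer to the device for the decode kernels. Error handling omitted.
    void stage_end_ids(int* h_end_ids, int* d_end_ids, int batch_size, int eos_id)
    {
        std::fill_n(h_end_ids, batch_size, eos_id);  // same stop token for every slot
        cudaMemcpy(d_end_ids, h_end_ids, sizeof(int) * batch_size, cudaMemcpyHostToDevice);
    }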
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/LlamaBatch.h
@@ -225,9 +225,10 @@ class LlamaBatch {
 
     // used by dynamic decoder
     int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
-    int*      end_ids_buf_{};
     bool*     finished_buf_{};
     uint32_t* seq_limit_len_{};
+    int*      h_end_ids_buf_{};
+    int*      d_end_ids_buf_{};
 
     int**     request_output_ids_ptrs_{};
     int*      request_output_ids_lens_{};
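In the header, the single end_ids_buf_ gives way to an explicit host/device pair, matching the h_*/d_* naming convention already used for h_stop_words_/d_stop_words_. The trailing true argument in the h_end_ids_buf_ reMalloc call appears to select a host-side allocation; the following rough raw-CUDA equivalent of managing such a pair uses pinned memory via cudaMallocHost, which is an assumption about the allocator's behavior, not something the diff states.

    #include <cuda_runtime.h>

    int main()
    {
        const size_t max_batch_size = 64;  // illustrative capacity (assumed)
        int* h_end_ids = nullptr;          // host staging copy, filled with std::fill_n
        int* d_end_ids = nullptr;          // device copy handed to the decode layer

        cudaMallocHost((void**)&h_end_ids, sizeof(int) * max_batch_size);  // page-locked host memory
        cudaMalloc((void**)&d_end_ids, sizeof(int) * max_batch_size);      // device memory

        // ... fill h_end_ids, copy to d_end_ids, run decoding ...

        cudaFreeHost(h_end_ids);
        cudaFree(d_end_ids);
        return 0;
    }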
