diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index cf5ff7062..66c4d7681 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -30,6 +30,8 @@ class GenerationConfig:
         random_seed (int): Seed used when sampling a token
         stop_words (List[str]): Words that stop generating further tokens
         bad_words (List[str]): Words that the engine will never generate
+        min_new_tokens (int): The minimum number of tokens to generate,
+            ignoring the number of tokens in the prompt.
     """
 
     n: int = 1
@@ -42,6 +44,7 @@ class GenerationConfig:
     random_seed: int = None
     stop_words: List[str] = None
     bad_words: List[str] = None
+    min_new_tokens: int = None
 
 
 @dataclass
@@ -65,7 +68,7 @@ def From(gen_config: GenerationConfig, tokenizer: Tokenizer):
             >>> tokenizer = Tokenizer('internlm/internlm-chat-7b')
             >>> gen_config = GenerationConfig(stop_words=[''])
             >>> gen_config = EngineGenerationConfig.From(gen_config, tokenizer)
-        """ # noqa E501
+        """  # noqa E501
 
     def special_word_token_ids(words):
         if words is not None:
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 4a4dc9157..9febedefe 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -648,6 +648,10 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             inputs['input_embeddings'] = input_embeddings
             inputs['input_embedding_ranges'] = input_embedding_ranges
 
+        if gen_config.min_new_tokens is not None:
+            inputs['min_length'] = _broadcast_np(gen_config.min_new_tokens,
+                                                 np.int32)
+
         bad_words = []
         if gen_config.bad_words is not None:
            bad_words.extend(gen_config.bad_words)
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index f7ebfeff0..28bf43aac 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -497,9 +497,8 @@ __global__ void batchApplyMinLengthPenalty(T* logits,
                                            const int vocab_size_padded)
 {
     int bid = threadIdx.x + blockIdx.x * blockDim.x;  // batch index
-    // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
-    // which is equal to the length of k/v caches.
-    if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) {
+    // In the decoder, sequence_lengths is the length of the sequence whose k/v cache has already been computed
+    if (sequence_lengths[bid] + 1 < min_lengths[bid]) {
         T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
         logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
     }
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
index 1c9ae099d..91b6809f3 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.cc
@@ -45,6 +45,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
     repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float));
     min_lengths_        = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int));
     skip_decode_        = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool));
+    context_length_     = (int*)std::realloc((void*)context_length_, batch_size * sizeof(int));
 
     is_allocate_buffer_ = true;
 }
@@ -63,6 +64,7 @@ void BaseSamplingLayer<T>::freeBuffer()
         std::free(repetition_penalty_);
         std::free(min_lengths_);
         std::free(skip_decode_);
+        std::free(context_length_);
         is_allocate_buffer_ = false;
     }
 }
@@ -161,16 +163,23 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
         repetition_penalty_type_ = RepetitionPenaltyType::None;
     }
 
-    const int default_min_length = 0;
-    Tensor    min_lengths = runtime_args->at("min_length", Tensor(MEMORY_CPU, TYPE_INT32, {1}, &default_min_length));
-    if (min_lengths.size() == 1) {
-        int minlen = min_lengths.getVal<int>();
-        deviceFill(min_lengths_buf_, batch_size, minlen, stream_);
-        std::fill_n(min_lengths_, batch_size, minlen);
+    // min_length
+    if (runtime_args->isExist("min_length")) {
+        Tensor min_lengths     = runtime_args->at("min_length");
+        Tensor context_lengths = runtime_args->at("context_length");
+        Tensor prompt_lengths  = runtime_args->at("prompt_length");
+        auto   p1              = min_lengths.getPtr<int>();
+        auto   p2              = prompt_lengths.getPtr<int>();
+        for (int i = 0; i < batch_size; i++) {
+            min_lengths_[i] = p1[i] + p2[i];
+        }
+        cudaAutoCpy(min_lengths_buf_, min_lengths_, batch_size, stream_);
+        std::copy_n(context_lengths.getPtr<int>(), batch_size, context_length_);
     }
     else {
-        cudaAutoCpy(min_lengths_buf_, min_lengths.getPtr<int>(), batch_size, stream_);
-        std::copy_n(min_lengths.getPtr<int>(), batch_size, min_lengths_);
+        std::fill_n(min_lengths_, batch_size, 0);
+        deviceFill(min_lengths_buf_, batch_size, 0, stream_);
+        std::fill_n(context_length_, batch_size, 0);
     }
 }
 
@@ -300,10 +309,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
         }
     }
 
-    const int  num_generated_tokens = step - max_input_length;
-    const int* min_lengths          = min_lengths_ + ite * local_batch_size;
+    const int        num_generated_tokens = step - max_input_length;
+    const int*       min_lengths          = min_lengths_ + ite * local_batch_size;
+    std::vector<int> index(local_batch_size);
+    std::iota(index.begin(), index.end(), 0);
     const bool invoke_min_length_penalty = std::any_of(
-        min_lengths, min_lengths + local_batch_size, [&](int min_length) { return min_length > num_generated_tokens; });
+        index.begin(), index.end(), [&](int i) { return min_lengths[i] > context_length_[i] + num_generated_tokens; });
     if (invoke_min_length_penalty) {
         FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penlaty");
         invokeMinLengthPenalty(logits,
diff --git a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
index 29462e16a..68cf79c87 100644
--- a/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
+++ b/src/turbomind/layers/sampling_layers/BaseSamplingLayer.h
@@ -47,6 +47,7 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer {
     int*   min_lengths_        = nullptr;
     bool*  skip_decode_        = nullptr;
     bool   skip_any_           = false;
+    int*   context_length_     = nullptr;
 
     RepetitionPenaltyType repetition_penalty_type_ = RepetitionPenaltyType::None;
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 89e0b45bf..c6bfc1502 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -328,6 +328,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         }
 
         // total context length (history + input)
+        state.h_prompt_length[idx]  = output_ids - output_ids_base;
         state.h_context_length[idx] = output_ids - output_ids_base;
         state.h_finished[idx]       = false;
 
@@ -698,6 +699,7 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta
+        d->h_prompt_length[di]  = s->h_prompt_length[si];
         d->h_context_length[di] = s->h_context_length[si];
         d->h_finished[di]       = s->h_finished[si];
         d->h_rope_theta[di]     = s->h_rope_theta[si];
@@ -774,6 +776,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     h_bad_words_ =
         (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
 
+    h_min_length_    = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
     h_temperature_   = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
@@ -796,6 +799,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     sampling_params_ = {
         {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
         {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
+        {"min_length", (std::byte*)h_min_length_, nullptr},
         {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
         {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
         {"temperature", (std::byte*)h_temperature_, nullptr},
@@ -828,6 +832,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
         (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
 
     for (auto& s : states_) {
+        s.h_prompt_length =
+            (int*)allocator_->reMalloc(s.h_prompt_length, sizeof(int) * max_batch_size, false, true);
         s.h_context_length =
             (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
         s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
@@ -1060,6 +1066,12 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
         }
     }
 
+    // MinLengthPenalty
+    if (inputs.isExist("min_length")) {
+        inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
+        inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
+    }
+
     // init for eos
     std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
     Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 01caaefb3..f04044830 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -20,6 +20,7 @@ namespace turbomind {
 
 struct BatchState {
+    int*  h_prompt_length;  // history + input, ignore generated
     int*  h_context_length;
     bool* h_finished;
 
@@ -249,6 +250,7 @@ class LlamaBatch {
     uintptr_t* h_k_block_ptrs_{};
     uintptr_t* h_v_block_ptrs_{};
 
+    int*   h_min_length_{};
    int*   h_runtime_top_k_{};
     float* h_runtime_top_p_{};
     float* h_temperature_{};
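
Usage sketch (reviewer note, not part of the patch): the snippet below mirrors the EngineGenerationConfig.From docstring above, with the new `min_new_tokens` field filled in. The import paths for Tokenizer, GenerationConfig and EngineGenerationConfig are assumed here, not introduced by this diff.

```python
from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
from lmdeploy.tokenizer import Tokenizer

tokenizer = Tokenizer('internlm/internlm-chat-7b')
# Require at least 32 generated tokens; turbomind.py broadcasts this value as
# the per-sequence `min_length` tensor consumed by BaseSamplingLayer.
gen_config = GenerationConfig(min_new_tokens=32)
engine_gen_config = EngineGenerationConfig.From(gen_config, tokenizer)
```

With this in place, BaseSamplingLayer<T>::setup offsets `min_length` by the prompt length, and the EOS logit stays masked until `context_length + num_generated_tokens` reaches `prompt_length + min_new_tokens`.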