diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 4cf9d53362..a932a12d3e 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -207,7 +207,6 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests)
         auto& seq = *state.sequences[idx];
 
         if (int step = r->inputs[rank_].getVal("step", -1); step >= 0) {
-            /// TODO: revise step setting
             if (step <= seq.tokens.size()) {
                 seq.tokens.resize(step);
                 seq.cache_len = std::min(seq.cache_len, step);
@@ -1258,7 +1257,17 @@ auto LlamaBatch::Finish(GenerationState& g, int& finished_count) -> std::vect
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
-    // invariant: context_length = sequence_length + 1
+    // `SequenceManager` needs real-time value of cache length
+    // ! Must be done before incrementing `h_context_length` because the generated token is NOT kv-cached yet
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len = state_->h_context_length[i];
+        }
+    }
+
+    // invariant: context_length = sequence_length + 1, so that h_context_length includes all tokens (including the
+    // one just generated)
     for (int i = 0; i < batch_size; ++i) {
         ++state_->h_context_length[i];
     }
@@ -1267,7 +1276,7 @@ auto LlamaBatch::Finish(GenerationState& g, int& finished_count) -> std::vect
         int* output_ptr = h_output_ids_;
         for (int i = 0; i < batch_size; ++i) {
             if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
-                const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
+                const int count = state_->h_context_length[i];
                 // TODO: sync history output tokens at when receiving the request and copy only the last token here
                 std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
                 *h_request_seqlen_ptrs_[i] = count;
@@ -1284,14 +1293,6 @@ auto LlamaBatch::Finish(GenerationState& g, int& finished_count) -> std::vect
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }
 
-    // `SequenceManager` needs real-time value of cache length
-    for (int i = 0; i < batch_size; ++i) {
-        if (state_->requests[i]) {
-            FT_CHECK(state_->sequences[i]);
-            state_->sequences[i]->cache_len = state_->h_context_length[i];
-        }
-    }
-
     std::vector<Signal> signals;
     {
         NvtxScope _("stream_and_completion_signal");
@@ -1343,8 +1344,7 @@ auto LlamaBatch::Interrupt(int index, bool force_stop, bool force_end) -> Sig
         FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
     }
     else {
-        // Account for the last generated token if not a stop request (which doesn't generate)
-        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
+        const int output_len = state_->h_context_length[index];
 
         auto& seq = *state_->sequences[index];
 
         // Update token IDs
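
Note: the core behavioral change in `Finish` above is an ordering one: `cache_len` is now synced from `h_context_length` before the increment, because the token sampled in the current step has not been written to the KV cache yet. The standalone sketch below uses simplified stand-in types (a bare `Sequence` struct and plain vectors, not the actual TurboMind structures) to illustrate the relationship the new ordering preserves: after the step, `cache_len + 1 == context_length`.

// Minimal, self-contained sketch (simplified stand-ins, not the TurboMind API)
// showing why cache_len must be synced before h_context_length is incremented:
// the token sampled in the current step has no KV-cache entry yet, so the
// pre-increment context length is exactly the number of kv-cached tokens.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

struct Sequence {
    int cache_len = 0;  // number of tokens with entries in the KV cache
};

int main() {
    // Pretend a batch of three sequences just finished one decoding step.
    std::vector<Sequence> sequences(3);
    std::vector<int> h_context_length = {5, 9, 12};  // pre-increment lengths

    // 1) Sync cache_len first: it equals the pre-increment context length.
    for (std::size_t i = 0; i < sequences.size(); ++i) {
        sequences[i].cache_len = h_context_length[i];
    }

    // 2) Restore the invariant context_length = sequence_length + 1 by also
    //    counting the token generated in this step.
    for (int& len : h_context_length) {
        ++len;
    }

    // The just-generated token is counted in context_length but not cache_len.
    for (std::size_t i = 0; i < sequences.size(); ++i) {
        assert(sequences[i].cache_len + 1 == h_context_length[i]);
        std::printf("seq %zu: cache_len=%d context_length=%d\n",
                    i, sequences[i].cache_len, h_context_length[i]);
    }
    return 0;
}

With the old ordering (increment first, sync afterwards), `cache_len` would have over-counted by one, telling `SequenceManager` that a KV-cache entry exists for a token that has not been cached yet.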