Skip to content

Commit

Permalink
Fix cache/output length calculation (#738)
Browse files Browse the repository at this point in the history
  • Loading branch information
lzhangzz authored Nov 23, 2023
1 parent 6b00f62 commit 434961c
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/turbomind/models/llama/LlamaBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
auto& seq = *state.sequences[idx];

if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
/// TODO: revise step setting
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
Expand Down Expand Up @@ -1258,7 +1257,17 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect

check_cuda_error(cudaStreamSynchronize(stream_));

// invariant: context_length = sequence_length + 1
// `SequenceManager` needs real-time value of cache length
// ! Must be done before incrementing `h_context_length` because the generated token is NOT kv-cached yet
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i];
}
}

// invariant: context_length = sequence_length + 1, so that h_context_length include all (including the one just
// generated) tokens
for (int i = 0; i < batch_size; ++i) {
++state_->h_context_length[i];
}
Expand All @@ -1267,7 +1276,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
int* output_ptr = h_output_ids_;
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
const int count = state_->h_context_length[i];
// TODO: sync history output tokens at when receiving the request and copy only the last token here
std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
*h_request_seqlen_ptrs_[i] = count;
Expand All @@ -1284,14 +1293,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
}

// `SequenceManager` needs real-time value of cache length
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i];
}
}

std::vector<Signal> signals;
{
NvtxScope _("stream_and_completion_signal");
Expand Down Expand Up @@ -1343,8 +1344,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
}
else {
// Account for the last generated token if not a stop request (which doesn't generate)
const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
const int output_len = state_->h_context_length[index];
auto& seq = *state_->sequences[index];

// Update token IDs
Expand Down

0 comments on commit 434961c

Please sign in to comment.