From 55dcb8bd7d3d37c4011cbef3663dd9a6a47cece8 Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Tue, 7 Nov 2023 12:24:14 +0000
Subject: [PATCH] fix batch initialization

---
 src/turbomind/models/llama/LlamaBatch.cc | 55 ++++++++++++------------
 src/turbomind/models/llama/LlamaV2.cc    | 15 +++++--
 2 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 9ca3fdc761..a0109b91e4 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -25,6 +25,13 @@
 
 namespace turbomind {
 
+void ClearState(BatchState& s)
+{
+    std::fill_n(s.requests.begin(), s.size, nullptr);
+    std::fill_n(s.sequences.begin(), s.size, nullptr);
+    s.size = s.active_size = 0;
+}
+
 template<typename T>
 void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
 {
@@ -184,6 +191,8 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         // sanity check, incoming request in previous iter should have been moved to `state_`
         FT_CHECK(!state.requests[i]);
 
+        TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
+
         state.requests[i] = r;
 
         // get sequence for the request
@@ -328,11 +337,6 @@ bool LlamaBatch<T>::Initialize()
     process(state_);
     process(incoming_);
 
-    // dbg(sequences);
-    // dbg(context_lengths);
-    // dbg(priorities);
-    // dbg(step_length_);
-
     auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);
 
     if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
@@ -344,7 +348,7 @@ bool LlamaBatch<T>::Initialize()
     std::vector<int> idxs(sequences.size());
     std::iota(idxs.begin(), idxs.end(), 0);
 
-    if (exchange) {
+    if (exchange || holes || incoming_->size) {
         // put active ones first
         auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
             return sequences[idx]->status == Sequence::kActive;  // present status
@@ -366,11 +370,9 @@ bool LlamaBatch<T>::Initialize()
             }
             std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
         }
-    }
 
-    if (exchange || holes) {
-        // Copy sequence states to the back state buffer
-        back_->size = back_->active_size = 0;
+        // Copy sequence states to the back buffer
+        FT_CHECK(back_->size == 0 && back_->active_size == 0);
         for (const auto& i : idxs) {
             auto& s = *sequences[i];
             if (exchange) {
@@ -379,6 +381,7 @@ bool LlamaBatch<T>::Initialize()
                 if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
                     SaveRandomState(*state, idx);
                 }
+                // mark swap-ins
                 if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
                     state->is_swap_in[idx] = 1;
                 }
@@ -390,10 +393,15 @@ bool LlamaBatch<T>::Initialize()
         }
         // Swap the buffers
         std::swap(state_, back_);
-    }
 
-    const int batch_size = state_->active_size;
+        ClearState(*back_);
+        ClearState(*incoming_);
+    }
 
+    /// Update block ptrs when there were
+    //  1. swap-in or swap-out
+    //  2. holes in the active buffer
+    //  3. new allocations (for existing active sequences)
     if (exchange || active_holes || outcome.allocation) {
         // Prepare intermediate buffers
         h_cu_block_counts_[0] = 0;
@@ -401,6 +409,8 @@ bool LlamaBatch<T>::Initialize()
         auto k_ptrs = h_k_block_ptrs_;
         auto v_ptrs = h_v_block_ptrs_;
 
+        const int batch_size = state_->active_size;
+
         for (int i = 0; i < batch_size; ++i) {
             const auto& seq = *state_->sequences[i];
 
@@ -415,28 +425,17 @@ bool LlamaBatch<T>::Initialize()
             });
         }
 
-        // if (1) {
-        //     std::vector<int> cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
-        //     dbg(cu_block_cnts);
-        // }
-        // dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
-        // dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
-        // dbg(h_cu_block_counts_[batch_size]);
+        static_assert(sizeof(uintptr_t) == sizeof(void*));
 
         Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
         Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
         Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
-
-        static_assert(sizeof(uintptr_t) == sizeof(void*));
     }
 
-    // clear incoming buffer
-    std::fill_n(incoming_->requests.begin(), incoming_->size, nullptr);
-    std::fill_n(incoming_->sequences.begin(), incoming_->size, nullptr);
-    incoming_->size = 0;
-
-    // in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
-    // generation & sampling need to be re-initialized for correctness
+    /// The buffer layout has changed and generation & sampling need to be
+    /// re-initialized for correctness when there were
+    //  1. swap-in or swap-out
+    //  2. holes in the active buffer
     return exchange || active_holes;
 }
 
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index fca323f0ad..f7a6e784f4 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -37,7 +37,6 @@
 #include
 #include
 #include
-#include
 
 namespace turbomind {
 
@@ -545,15 +544,25 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
     bool has_error = 0;
     if (rank == 0) {
         TM_LOG_INFO("[forward] Enqueue requests");
+
+        std::vector<uint64_t> ids;
+        for (const auto& r : requests) {
+            ids.push_back(r->id);
+        }
+
         auto futures = shared_state_->request_queue.enqueue(std::move(requests));
+
+        FT_CHECK_WITH_INFO(ids.size() == futures.size(), "check failed");
+
         TM_LOG_INFO("[forward] Wait for requests to complete ...");
-        for (auto& f : futures) {
-            auto ec = f.get();
+
+        for (int i = 0; i < futures.size(); ++i) {
+            auto ec = futures[i].get();
             error_codes.push_back(ec);
             if (ec) {
                 has_error = true;
             }
+            TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
         }
     }
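
Note: the heart of this patch is the rotation of the three BatchState buffers in
LlamaBatch<T>::Initialize(). The re-ordered batch is staged in `back_`, swapped
into `state_`, and then both `back_` and `incoming_` are wiped immediately via
the new ClearState() helper, so the FT_CHECK at the top of the copy loop can
assert that the staging buffer is empty instead of re-clearing it in place. The
sketch below is a minimal, self-contained illustration of that pattern; the
BatchState element types here are hypothetical stand-ins for turbomind's
Request/Sequence pointers, and the copy step itself is elided.

    #include <algorithm>
    #include <cassert>
    #include <memory>
    #include <utility>
    #include <vector>

    // Hypothetical stand-ins; the real types live in turbomind.
    using RequestPtr  = std::shared_ptr<int>;
    using SequencePtr = const int*;

    struct BatchState {
        std::vector<RequestPtr>  requests  = std::vector<RequestPtr>(8);
        std::vector<SequencePtr> sequences = std::vector<SequencePtr>(8);
        int size        = 0;
        int active_size = 0;
    };

    // Mirrors the helper added by the patch: null out the occupied slots
    // and reset both occupancy counters.
    void ClearState(BatchState& s)
    {
        std::fill_n(s.requests.begin(), s.size, nullptr);
        std::fill_n(s.sequences.begin(), s.size, nullptr);
        s.size = s.active_size = 0;
    }

    int main()
    {
        BatchState a, b, c;
        BatchState* state_    = &a;  // batch used by the current iteration
        BatchState* back_     = &b;  // staging buffer for the re-ordered batch
        BatchState* incoming_ = &c;  // requests arrived since the last iteration

        // Invariant the patch introduces: the staging buffer is empty on entry
        // (FT_CHECK(back_->size == 0 && back_->active_size == 0) in the real code).
        assert(back_->size == 0 && back_->active_size == 0);

        // ... copy sequence states from *state_ / *incoming_ into *back_,
        //     active sequences first ...

        std::swap(state_, back_);  // publish the new batch layout
        ClearState(*back_);        // old front buffer becomes the next staging buffer
        ClearState(*incoming_);    // incoming requests were consumed above
        return 0;
    }

Clearing both buffers right after the swap is what lets the patch delete the
ad-hoc fill_n block that used to clear `incoming_` at the end of Initialize(),
and it is what the sanity check at the top of ProcessInferRequests() relies on.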