Commit 55dcb8b

fix batch initialization

lzhangzz committed Nov 7, 2023
1 parent d3a1356 commit 55dcb8b
Showing 2 changed files with 39 additions and 31 deletions.
55 changes: 27 additions & 28 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -25,6 +25,13 @@

namespace turbomind {

void ClearState(BatchState& s)
{
std::fill_n(s.requests.begin(), s.size, nullptr);
std::fill_n(s.sequences.begin(), s.size, nullptr);
s.size = s.active_size = 0;
}

template<typename T>
void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
{
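The new ClearState helper centralizes the buffer reset that Initialize() previously performed inline (see the removed fill_n block later in this file's diff). A minimal standalone sketch, assuming a pared-down BatchState with only the members ClearState touches:

#include <algorithm>
#include <memory>
#include <vector>

struct Request {};
struct Sequence {};

// Pared-down stand-in for turbomind's BatchState; only the members used by
// ClearState are modeled here (assumption, not the real definition).
struct BatchState {
    std::vector<std::shared_ptr<Request>> requests;
    std::vector<const Sequence*>          sequences;
    int size        = 0;  // occupied slots
    int active_size = 0;  // prefix of slots holding active sequences
};

void ClearState(BatchState& s)
{
    // Only the first `size` slots can hold live entries, so clearing stops
    // there; the vectors keep their capacity for reuse.
    std::fill_n(s.requests.begin(), s.size, nullptr);
    std::fill_n(s.sequences.begin(), s.size, nullptr);
    s.size = s.active_size = 0;
}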
@@ -184,6 +191,8 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
// sanity check, incoming request in previous iter should have been moved to `state_`
FT_CHECK(!state.requests[i]);

TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);

state.requests[i] = r;

// get sequence for the request
@@ -328,11 +337,6 @@ bool LlamaBatch<T>::Initialize()
process(state_);
process(incoming_);

// dbg(sequences);
// dbg(context_lengths);
// dbg(priorities);
// dbg(step_length_);

auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);

if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
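For orientation, the outcome checked here aggregates what Materialize() did to the batch. A hypothetical shape inferred from its use in this function; the field names match the usage, and the comments are assumptions:

// Hypothetical sketch of Materialize()'s result type, inferred from the
// checks above; not the actual turbomind definition.
struct Outcome {
    int allocation;  // KV-cache blocks newly allocated this step
    int swap_in;     // sequences brought (back) into device memory
    int swap_out;    // sequences evicted to free blocks
};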
@@ -344,7 +348,7 @@
std::vector<int> idxs(sequences.size());
std::iota(idxs.begin(), idxs.end(), 0);

if (exchange) {
if (exchange || holes || incoming_->size) {
// put active ones first
auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
return sequences[idx]->status == Sequence::kActive; // present status
@@ -366,11 +370,9 @@
}
std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
}
}

if (exchange || holes) {
// Copy sequence states to the back state buffer
back_->size = back_->active_size = 0;
// Copy sequence states to back buffer
FT_CHECK(back_->size == 0 && back_->active_size == 0);
for (const auto& i : idxs) {
auto& s = *sequences[i];
if (exchange) {
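This hunk carries the actual fix: the copy into the back buffer now also runs when there are holes in the state or new incoming requests, not only on swap-in/swap-out exchanges. The reordering itself is a stable partition (active sequences first) followed by a stable sort by missing token count. A simplified sketch, assuming `active` and `missing_len` stand in for the per-sequence data, and sorting the whole active prefix rather than only the swap-in subrange as the real code does:

#include <algorithm>
#include <numeric>
#include <vector>

// Simplified sketch of the index ordering used above.
std::vector<int> OrderIndices(const std::vector<bool>& active,
                              const std::vector<int>&  missing_len)
{
    std::vector<int> idxs(active.size());
    std::iota(idxs.begin(), idxs.end(), 0);  // 0, 1, 2, ...
    // Active sequences first, preserving relative order within each group.
    const auto active_end = std::stable_partition(
        idxs.begin(), idxs.end(), [&](int i) { return active[i]; });
    // Schedule sequences missing fewer tokens earlier.
    std::stable_sort(idxs.begin(), active_end,
                     [&](int i, int j) { return missing_len[i] < missing_len[j]; });
    return idxs;
}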
@@ -379,6 +381,7 @@
if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
SaveRandomState(*state, idx);
}
// mark swap-ins
if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
state->is_swap_in[idx] = 1;
}
@@ -390,17 +393,24 @@
}
// Swap the buffers
std::swap(state_, back_);
}

const int batch_size = state_->active_size;
ClearState(*back_);
ClearState(*incoming_);
}
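With ClearState available, finishing the copy reduces to a swap-and-reset. A minimal sketch of the rotation, reusing the BatchState and ClearState stand-ins from the sketch above:

#include <utility>

// state_ and back_ trade roles whenever the batch layout changes; incoming_
// is consumed every iteration (names as in LlamaBatch; sketch only).
void RotateBuffers(BatchState*& state_, BatchState*& back_, BatchState* incoming_)
{
    std::swap(state_, back_);  // back_ now holds the stale front buffer
    ClearState(*back_);        // reset it for reuse next iteration
    ClearState(*incoming_);    // its requests were moved out; drop the refs
}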

/// Update block ptrs when there were
// 1. swap-in or swap-out
// 2. holes in the active buffer
// 3. new allocations (for existing active sequences)
if (exchange || active_holes || outcome.allocation) {
// Prepare intermediate buffers
h_cu_block_counts_[0] = 0;

auto k_ptrs = h_k_block_ptrs_;
auto v_ptrs = h_v_block_ptrs_;

const int batch_size = state_->active_size;

for (int i = 0; i < batch_size; ++i) {
const auto& seq = *state_->sequences[i];

@@ -415,28 +425,17 @@
});
}

// if (1) {
// std::vector cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
// dbg(cu_block_cnts);
// }
// dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
// dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
// dbg(h_cu_block_counts_[batch_size]);
static_assert(sizeof(uintptr_t) == sizeof(void*));

Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);

static_assert(sizeof(uintptr_t) == sizeof(void*));
}

// clear incoming buffer
std::fill_n(incoming_->requests.begin(), incoming_->size, nullptr);
std::fill_n(incoming_->sequences.begin(), incoming_->size, nullptr);
incoming_->size = 0;

// in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
// generation & sampling need to be re-initialized for correctness
/// Layout of the buffers is changed, generation & sampling need to be re-initialized for correctness when there
/// were
// 1. swap-in or swap-out
// 2. holes in the active buffer
return exchange || active_holes;
}
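The h_cu_block_counts_ array populated above is an exclusive prefix sum over per-sequence block counts, so its last element is the total block count across the batch, which is why it doubles as the copy length for the k/v pointer arrays. A sketch under that reading, with a hypothetical per-sequence block layout:

#include <vector>

// Sketch: cumulative block counts, assuming each sequence's KV-cache blocks
// are given as a vector of device pointers (hypothetical layout).
std::vector<int> CuBlockCounts(const std::vector<std::vector<void*>>& blocks)
{
    std::vector<int> cu(blocks.size() + 1, 0);  // cu[0] = 0, as in the diff
    for (size_t i = 0; i < blocks.size(); ++i) {
        cu[i + 1] = cu[i] + static_cast<int>(blocks[i].size());
    }
    return cu;  // cu[batch_size] is the total, usable as a copy length
}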

15 changes: 12 additions & 3 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -37,7 +37,6 @@
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>

namespace turbomind {

@@ -545,15 +544,25 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
bool has_error = 0;
if (rank == 0) {
TM_LOG_INFO("[forward] Enqueue requests");

std::vector<uint64_t> ids;
for (const auto& r : requests) {
ids.push_back(r->id);
}

auto futures = shared_state_->request_queue.enqueue(std::move(requests));

FT_CHECK_WITH_INFO(ids.size() == futures.size(), "check failed");

TM_LOG_INFO("[forward] Wait for requests to complete ...");
for (auto& f : futures) {
auto ec = f.get();

for (int i = 0; i < futures.size(); ++i) {
auto ec = futures[i].get();
error_codes.push_back(ec);
if (ec) {
has_error = true;
}
TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
}
}
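The reworked wait loop pairs each future with the request id recorded before the enqueue, so completion can be logged per request instead of anonymously. A sketch of the pattern, assuming enqueue() yields one std::future<int> error code per request:

#include <cstdint>
#include <cstdio>
#include <future>
#include <vector>

// Sketch: wait on per-request futures and report each error code
// (hypothetical helper; mirrors the loop above).
bool WaitAll(const std::vector<uint64_t>& ids, std::vector<std::future<int>>& futures)
{
    bool has_error = false;
    for (size_t i = 0; i < futures.size(); ++i) {
        const int ec = futures[i].get();  // blocks until request i completes
        std::printf("request %llu complete, ec = %d\n",
                    (unsigned long long)ids[i], ec);
        has_error = has_error || (ec != 0);
    }
    return has_error;
}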

