From a7c5007c238830238f68aa88bc37cc5e424fa82b Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Thu, 23 Nov 2023 21:00:42 +0800
Subject: [PATCH] [Fix] Skip empty batch (#747)

---
 src/turbomind/models/llama/LlamaBatch.cc | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index a932a12d3..76394e862 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -475,6 +475,10 @@ bool LlamaBatch<T>::Initialize()
 template<typename T>
 void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
 {
+    if (desc.empty()) {
+        return;
+    }
+
     std::vector<int> idxs(desc.size());
     std::iota(idxs.begin(), idxs.end(), 0);
 
@@ -1430,18 +1434,21 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
         // finished sequences is handled by `Initialize()`
         finished_count = 0;
 
-        ContextDecode();
-        if (state_->active_size) {
+
+        ContextDecode();
+        if (modified) {
             g = InitializeGeneration();
             InitializeSampling();
         }
+
         for (int i = 0; i < step_length_; ++i) {
             if (!Generate(g)) {
                 break;
             }
         }
+
         if (auto signals = Finish(g, finished_count); !signals.empty()) {
             if (finished_count) {
                 // Finished requests and corresponding output tensors will be released when notified