diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 9c48e4f818..8a1de364b0 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -126,6 +126,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
 template<typename T>
 LlamaV2<T>::~LlamaV2()
 {
+    shared_state_->request_queue.close();
     internal_thread_.join();
 
     delete decoder_;
@@ -448,12 +449,24 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
 
             request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);
 
+            // request queue was closed
+            // and there are no unprocessed requests in the queue
+            if (is_empty && infer_requests.empty() && stop_requests.empty()) {
+                // rank 0 sets flag
+                shared_state_->should_stop = true;
+            }
+
             batch_.verifyRequests(stop_requests, infer_requests);
         }
 
         // wait while rank-0 is dequeueing
         shared_state_->barrier->wait();
 
+        // exit if job is done
+        if (shared_state_->should_stop) {
+            return;
+        }
+
         bool modified = false;
 
         if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
@@ -486,8 +499,6 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
             batch_.finish();
         }
     }
-
-    FT_CHECK(0);
 }
 
 template<typename T>
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index c52a02db0c..40633b0a22 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -46,6 +46,9 @@ class LlamaV2 {
         std::vector<std::shared_ptr<Request>> stop_requests;
         RequestQueue                          request_queue;
         std::shared_ptr<Barrier>              barrier;
+
+        // rank 0 sets flag to true if there are no more tasks in the request_queue
+        bool should_stop = false;
     };
 
     ~LlamaV2();
diff --git a/src/turbomind/models/llama/Request.h b/src/turbomind/models/llama/Request.h
index cb2d1858a3..0bccf84a57 100644
--- a/src/turbomind/models/llama/Request.h
+++ b/src/turbomind/models/llama/Request.h
@@ -44,6 +44,11 @@ class RequestQueue {
         futures.reserve(requests.size());
         {
             std::lock_guard<std::mutex> lock(mutex_);
+
+            if (closed_) {
+                throw std::runtime_error("Queue is closed");
+            }
+
             for (auto& r : requests) {
                 futures.push_back(r->signal.get_future());
                 if (r->stop_flag) {
@@ -65,7 +70,7 @@ class RequestQueue {
     {
         std::unique_lock<std::mutex> lock(mutex_);
         if (blocking) {
-            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()); });
+            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty() && closed_ == false); });
         }
 
         stop_requests.clear();
@@ -81,11 +86,19 @@ class RequestQueue {
         }
     }
 
+    void close()
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        closed_ = true;
+        cv_.notify_all();
+    }
+
 private:
     std::queue<std::shared_ptr<Request>> stop_queue_;
    std::queue<std::shared_ptr<Request>> infer_queue_;
     std::mutex                           mutex_;
     std::condition_variable              cv_;
+    bool                                 closed_ = false;
 };
 
 }  // namespace turbomind
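The LlamaV2/Request.h changes above follow the standard closable-queue shutdown pattern: the destructor closes the queue before joining the worker, close() wakes every consumer blocked in dequeue(), rank 0 observes the closed-and-drained state and sets should_stop, and the barrier publishes that decision to the other ranks before they read it. Below is a minimal self-contained sketch of the same pattern under simplified assumptions; ClosableQueue, the int payload, and main are illustrative names, not part of the patch.

// Sketch of the closable-queue shutdown pattern used by the patch.
// Requires C++17 (std::optional). Names here are illustrative only.
#include <condition_variable>
#include <mutex>
#include <optional>
#include <queue>
#include <stdexcept>
#include <thread>

class ClosableQueue {
public:
    void push(int item)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        if (closed_) {
            throw std::runtime_error("Queue is closed");  // mirrors RequestQueue::enqueue
        }
        queue_.push(item);
        cv_.notify_one();
    }

    // Blocks until an item arrives or the queue is closed; nullopt signals shutdown.
    std::optional<int> pop()
    {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return !queue_.empty() || closed_; });
        if (queue_.empty()) {
            return std::nullopt;  // closed and drained -> consumer exits
        }
        int item = queue_.front();
        queue_.pop();
        return item;
    }

    void close()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        closed_ = true;
        cv_.notify_all();  // wake every blocked consumer so it can observe closed_
    }

private:
    std::queue<int>         queue_;
    std::mutex              mutex_;
    std::condition_variable cv_;
    bool                    closed_ = false;
};

int main()
{
    ClosableQueue q;
    std::thread consumer([&] {
        while (auto item = q.pop()) { /* process *item */ }
    });
    q.push(1);
    q.push(2);
    q.close();  // like ~LlamaV2: close the queue, then join the worker
    consumer.join();
}

Because closed_ is flipped under the same mutex the waiters' predicate reads, a consumer cannot check the predicate, miss the flag, and then sleep forever; this also explains why the patched dequeue() predicate adds the closed_ term, so the replaced FT_CHECK(0) unreachable-exit assertion is no longer needed.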
diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h
index a87efcd73b..1ba191d211 100644
--- a/src/turbomind/utils/allocator.h
+++ b/src/turbomind/utils/allocator.h
@@ -125,9 +125,15 @@ class Allocator;
 template<>
 class Allocator<AllocatorType::CUDA>: public IAllocator {
 private:
-    const int                          device_id_;
-    cudaStream_t                       stream_ = 0;  // initialize as default stream
-    std::unordered_map<void*, size_t>* pointer_mapping_;
+    enum class MemoryType
+    {
+        HOST,
+        DEVICE
+    };
+
+    const int                                                  device_id_;
+    cudaStream_t                                               stream_ = 0;  // initialize as default stream
+    std::unordered_map<void*, std::pair<size_t, MemoryType>>* pointer_mapping_;
 
     bool isExist(void* address) const
     {
@@ -136,10 +142,10 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
     ReallocType isReMalloc(void* address, size_t size) const
     {
         FT_CHECK(isExist(address));
-        if (pointer_mapping_->at(address) < size) {
+        if (pointer_mapping_->at(address).first < size) {
             return ReallocType::INCREASE;
         }
-        else if (pointer_mapping_->at(address) == size) {
+        else if (pointer_mapping_->at(address).first == size) {
             return ReallocType::REUSE;
         }
         else {
@@ -151,7 +157,7 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
     Allocator(int device_id): device_id_(device_id)
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-        pointer_mapping_ = new std::unordered_map<void*, size_t>();
+        pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
 #if defined(CUDA_MEMORY_POOL_DISABLED)
         TM_LOG_WARNING(
             "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
@@ -188,7 +194,9 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         while (!pointer_mapping_->empty()) {
-            free((void**)(&pointer_mapping_->begin()->first));
+            auto ptr           = pointer_mapping_->begin()->first;
+            auto size_and_type = pointer_mapping_->begin()->second;
+            free(&ptr, size_and_type.second == MemoryType::HOST);
         }
         delete pointer_mapping_;
     }
@@ -229,18 +237,19 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
             check_cuda_error(getSetDevice(o_device));
         }
         TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size);
 
-        pointer_mapping_->insert({getAddress(ptr), size});
+        pointer_mapping_->insert({getAddress(ptr), {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
         return ptr;
     }
 
-    void free(void** ptr, bool is_host = false) const
+    void free(void** ptr, bool _ = false) const
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         void* address = getAddress(*ptr);
         if (*ptr != nullptr) {
             int o_device = 0;
             if (pointer_mapping_->count(address)) {
+                const auto is_host = pointer_mapping_->at(address).second == MemoryType::HOST;
                 TM_LOG_DEBUG("Free buffer %p", address);
                 check_cuda_error(getSetDevice(device_id_, &o_device));
                 if (is_host) {
@@ -361,7 +370,7 @@ class Allocator<AllocatorType::TF>: public IAllocator {
     {
         while (!pointer_mapping_->empty()) {
             void* ptr = pointer_mapping_->begin()->second.flat().data();
-            free((void**)(&ptr));
+            free(&ptr);
         }
         pointer_mapping_->clear();
         delete pointer_mapping_;
@@ -454,7 +463,7 @@ class Allocator<AllocatorType::TH>: public IAllocator {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         while (!pointer_mapping_->empty()) {
             void* ptr = pointer_mapping_->begin()->second.data_ptr();
-            free((void**)(&ptr));
+            free(&ptr);
         }
         pointer_mapping_->clear();
         delete pointer_mapping_;
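The allocator change makes the pointer map the single source of truth for how each buffer was allocated: malloc() records a {size, MemoryType} pair, free() ignores the caller-supplied flag and looks the type up, and the destructor can release host allocations through the host free path instead of the device one. Below is a self-contained sketch of the same bookkeeping idea under stated assumptions: TrackingAllocator is an illustrative name, and plain std::malloc/std::free stand in for the CUDA host/device calls so the sketch runs without CUDA.

// Sketch of per-pointer {size, memory type} bookkeeping, as introduced above.
// TrackingAllocator is hypothetical; std::malloc/std::free stand in for
// cudaMallocHost/cudaMallocAsync and cudaFreeHost/cudaFreeAsync.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>
#include <utility>

class TrackingAllocator {
    enum class MemoryType { HOST, DEVICE };
    std::unordered_map<void*, std::pair<size_t, MemoryType>> pointer_mapping_;

public:
    void* malloc(size_t size, bool is_host)
    {
        void* ptr = std::malloc(size);  // stand-in for the host/device allocation call
        pointer_mapping_.insert({ptr, {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
        return ptr;
    }

    // The caller no longer needs to remember is_host; it is recovered from the map.
    void free(void** ptr)
    {
        if (*ptr == nullptr) {
            return;
        }
        auto it = pointer_mapping_.find(*ptr);
        if (it != pointer_mapping_.end()) {
            const bool is_host = it->second.second == MemoryType::HOST;
            std::printf("freeing %zu bytes of %s memory\n",
                        it->second.first, is_host ? "host" : "device");
            std::free(*ptr);  // stand-in for the matching host/device free call
            pointer_mapping_.erase(it);
        }
        *ptr = nullptr;
    }

    ~TrackingAllocator()
    {
        // The destructor can now free both kinds correctly; previously host
        // allocations were released as if they were device memory.
        while (!pointer_mapping_.empty()) {
            void* p = pointer_mapping_.begin()->first;
            free(&p);
        }
    }
};

This is also why the patched free() renames its second parameter to `_`: the flag is kept only for source compatibility with existing call sites, while the recorded MemoryType decides which deallocation path runs.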