From d44a8bfea49cda9b74960e7cfc61f16ae0e59808 Mon Sep 17 00:00:00 2001
From: Chen Xin
Date: Mon, 18 Sep 2023 20:03:44 +0800
Subject: [PATCH] Reduce gil switching (#407)

* reduce gil switching

* ffi lock func

* remove unused

* remove unused

* remove unused
---
 src/turbomind/models/llama/LlamaBatch.cc     | 11 +++++++++--
 src/turbomind/models/llama/LlamaV2.h         |  9 +++++++++
 src/turbomind/python/bind.cpp                | 16 ++++++++++++++--
 .../triton_backend/llama/LlamaTritonModel.cc |  1 +
 .../triton_backend/llama/LlamaTritonModel.h  |  7 +++++++
 5 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 83db7ad65d..995f15b710 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -899,8 +899,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
 
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp = llama_->tensor_para_.world_size_;
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -938,12 +939,18 @@ void LlamaBatch<T>::finish()
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }
 
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index ed13aa40f4..c52a02db0c 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include <unordered_map>
 
+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {
 
 template<typename T>
@@ -91,6 +93,11 @@ class LlamaV2 {
         return vocab_size_;
     }
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:
     friend class Batch;
@@ -188,6 +195,8 @@ class LlamaV2 {
     std::shared_ptr<SharedState> shared_state_;

     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };

 } // namespace turbomind
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index 592b2b30e6..b55ed040af 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -344,13 +344,25 @@ PYBIND11_MODULE(_turbomind, m)
            size_t      pipeline_para_size,
            int         enable_custom_all_reduce,
            std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+            auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                if (op) {
+                    state = PyGILState_Ensure();
+                }
+                else {
+                    PyGILState_Release(state);
+                }
+            };
             if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                return std::make_shared<LlamaTritonModel<half>>(
+                auto model = std::make_shared<LlamaTritonModel<half>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
             else {
-                return std::make_shared<LlamaTritonModel<float>>(
+                auto model = std::make_shared<LlamaTritonModel<float>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
         },
         "model_dir"_a,
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 456f5f41c4..57d5c9be5b 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -276,6 +276,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index f3a3a327a9..332000ce62 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     void handleMissingParams();
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int         getTensorParaSize() override;
     int         getPipelineParaSize() override;
@@ -112,4 +117,6 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
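
Note on the pattern (not asserted by the patch itself, only illustrated by it): the pybind11 layer installs a lock-control callback of type ffi_api_lock_ctrl_t, i.e. std::function<void(int)>, and LlamaBatch<T>::finish() then acquires the GIL once around the whole block of per-request stream callbacks instead of toggling it for every callback, which is the "reduce GIL switching" of the title. The sketch below restates that pattern in isolation. Engine, callbacks, run_callbacks and make_gil_control are hypothetical names invented for illustration; only the callback signature and the PyGILState_Ensure / PyGILState_Release pairing come from the patch.

// Minimal standalone sketch of the GIL-control pattern, assuming an embedded or
// pybind11-hosted interpreter. "Engine" stands in for LlamaV2 / LlamaBatch here.
#include <Python.h>
#include <functional>
#include <vector>

// Same shape as the patch's ffi_api_lock_ctrl_t: op != 0 acquires the GIL, op == 0 releases it.
using ffi_api_lock_ctrl_t = std::function<void(int)>;

struct Engine {
    ffi_api_lock_ctrl_t ffi_lock_;                 // installed from the binding layer; may be empty
    std::vector<std::function<void()>> callbacks;  // stand-ins for the per-request stream_cb

    void run_callbacks()
    {
        // Acquire the GIL once for the whole batch of Python callbacks ...
        if (ffi_lock_) {
            ffi_lock_(1);
        }
        for (auto& cb : callbacks) {
            cb();  // each callback re-enters Python; the GIL is already held
        }
        // ... and release it once at the end, instead of ensure/release per callback.
        if (ffi_lock_) {
            ffi_lock_(0);
        }
    }
};

// What the binding side installs: a stateful lambda that remembers the
// PyGILState token between the "acquire" and "release" calls.
inline ffi_api_lock_ctrl_t make_gil_control()
{
    return [state = PyGILState_STATE{}](int op) mutable {
        if (op) {
            state = PyGILState_Ensure();
        }
        else {
            PyGILState_Release(state);
        }
    };
}

// Usage (assumes the interpreter is already initialized, e.g. inside a pybind11 module):
//   Engine engine;
//   engine.ffi_lock_ = make_gil_control();   // analogous to model->setFfiLock(gil_control)
//   engine.run_callbacks();

PyGILState_Ensure and PyGILState_Release are documented to work from threads not created by Python, which is why a single acquire/release pair can safely be hoisted around the callback loop running on the model's internal worker thread.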