diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 0ca152c96d..829096da5f 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -940,6 +940,9 @@ void LlamaBatch::finish() check_cuda_error(cudaStreamSynchronize(stream_)); + if (rank_ == 0 && llama_->ffi_lock_) { + llama_->ffi_lock_(1); + } for (int i = 0; i < batch_size_; ++i) { FT_CHECK(requests_[i] != nullptr); if (requests_[i]->stream_cb && rank_ == 0) { @@ -947,6 +950,9 @@ void LlamaBatch::finish() requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get()); } } + if (rank_ == 0 && llama_->ffi_lock_) { + llama_->ffi_lock_(0); + } if (debug_ && rank_ == 0) { std::stringstream ss; diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index ed13aa40f4..c52a02db0c 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -34,6 +34,8 @@ #include "src/turbomind/utils/nccl_utils.h" #include +using ffi_api_lock_ctrl_t = std::function; + namespace turbomind { template @@ -91,6 +93,11 @@ class LlamaV2 { return vocab_size_; } + void setFfiLock(ffi_api_lock_ctrl_t func) + { + ffi_lock_ = func; + } + private: friend class Batch; @@ -188,6 +195,8 @@ class LlamaV2 { std::shared_ptr shared_state_; std::thread internal_thread_; + + ffi_api_lock_ctrl_t ffi_lock_ = nullptr; }; } // namespace turbomind diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index 982d4914e8..1018e016fe 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -330,18 +330,7 @@ PYBIND11_MODULE(_turbomind, m) .def( "register_callback", [](AbstractTransformerModelInstance* self, triton_stream_cb_t cb, py::object ctx) { - auto callback = [=](std::shared_ptr> outputs, - void* ctx) { - thread_local PyGILState_STATE gstate; - if (ft::is_first_in_batch()) { - gstate = PyGILState_Ensure(); - } - cb(outputs, ctx); - if (ft::is_last_in_batch()) { - PyGILState_Release(gstate); - } - }; - self->registerCallback(callback, ctx.ptr()); + self->registerCallback(cb, ctx.ptr()); }, "callback"_a, "context"_a = nullptr) @@ -356,13 +345,25 @@ PYBIND11_MODULE(_turbomind, m) size_t pipeline_para_size, int enable_custom_all_reduce, std::string data_type) -> std::shared_ptr { + auto gil_control = [state = PyGILState_STATE{}](int op) mutable { + if (op) { + state = PyGILState_Ensure(); + } + else { + PyGILState_Release(state); + } + }; if (data_type == "half" || data_type == "fp16" || data_type == "int4") { - return std::make_shared>( + auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir); + model->setFfiLock(gil_control); + return model; } else { - return std::make_shared>( + auto model = std::make_shared>( tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir); + model->setFfiLock(gil_control); + return model; } }, "model_dir"_a, diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 169d6cbdba..49235833e4 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -275,6 +275,7 @@ LlamaTritonModel::createModelInstance(int instance = shared_instances_[device_id].lock(); if (!instance) { instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm); + instance->llm->setFfiLock(ffi_lock_); shared_instances_[device_id] = instance; } } diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index f3a3a327a9..332000ce62 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel { void handleMissingParams(); + void setFfiLock(ffi_api_lock_ctrl_t func) + { + ffi_lock_ = func; + } + std::string toString() override; int getTensorParaSize() override; int getPipelineParaSize() override; @@ -112,4 +117,6 @@ struct LlamaTritonModel: public AbstractTransformerModel { std::string model_name_; std::string model_dir_; + + ffi_api_lock_ctrl_t ffi_lock_ = nullptr; };