
Commit

Reduce gil switching (#407)
* reduce gil switching

* ffi lock func

* remove unused

* remove unused

* remove unused
irexyc authored Sep 18, 2023
1 parent 2dec28a commit d44a8bf
Showing 5 changed files with 40 additions and 4 deletions.
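
The idea behind this change: instead of letting every streaming callback toggle the Python GIL on its own, the batch thread takes the GIL once before the whole callback loop and releases it once afterwards, and the hook that does so is passed in as a plain std::function<void(int)> so the model code never has to include Python headers. A minimal sketch of the two patterns using pybind11 (the function and variable names here are illustrative, not part of this commit):

#include <pybind11/pybind11.h>
#include <vector>

namespace py = pybind11;

// Before: each callback acquires and releases the GIL on its own, so a
// batch of N callbacks pays for N GIL switches.
void notify_each(std::vector<py::function>& callbacks)
{
    for (auto& cb : callbacks) {
        py::gil_scoped_acquire gil;  // acquire per callback
        cb();
    }                                // ... and release per callback
}

// After: acquire once around the whole batch of callbacks.
void notify_batched(std::vector<py::function>& callbacks)
{
    py::gil_scoped_acquire gil;      // single acquire
    for (auto& cb : callbacks) {
        cb();
    }
}                                    // single release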
11 changes: 9 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -899,8 +899,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
 
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp = llama_->tensor_para_.world_size_;
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -938,12 +939,18 @@ void LlamaBatch<T>::finish()
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }
 
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
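
The reason finish() needs the GIL at all is that each request's stream_cb ultimately wraps a Python callable, so invoking it from the C++ batch thread touches interpreter state. A rough sketch of such a wrapper (the wrap() helper and its signature are hypothetical, not the repo's actual request plumbing):

#include <pybind11/pybind11.h>
#include <functional>

namespace py = pybind11;

// Turn a Python callable into a C++ callback. Whoever invokes the returned
// std::function from a non-Python thread must hold the GIL -- which is what
// the ffi_lock_(1) / ffi_lock_(0) bracket above arranges for the whole loop.
std::function<void(int)> wrap(py::function py_cb)
{
    return [py_cb](int step) { py_cb(step); };
}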
9 changes: 9 additions & 0 deletions src/turbomind/models/llama/LlamaV2.h
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include <unordered_map>
 
+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {
 
 template<typename T>
@@ -91,6 +93,11 @@ class LlamaV2 {
         return vocab_size_;
     }
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:
     friend class Batch;
 
@@ -188,6 +195,8 @@ class LlamaV2 {
     std::shared_ptr<SharedState> shared_state_;
 
     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
 
 }  // namespace turbomind
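
Because ffi_api_lock_ctrl_t is just std::function<void(int)>, with 1 meaning "acquire" and 0 meaning "release", the bracketed calls in LlamaBatch<T>::finish() could also be expressed as an RAII guard. The guard below is only an illustration of that contract, not part of the commit:

#include <functional>

using ffi_api_lock_ctrl_t = std::function<void(int)>;

struct FfiLockGuard {
    explicit FfiLockGuard(const ffi_api_lock_ctrl_t& lock): lock_(lock)
    {
        if (lock_) {
            lock_(1);  // e.g. PyGILState_Ensure() on the Python side
        }
    }
    ~FfiLockGuard()
    {
        if (lock_) {
            lock_(0);  // e.g. PyGILState_Release()
        }
    }
    const ffi_api_lock_ctrl_t& lock_;
};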
16 changes: 14 additions & 2 deletions src/turbomind/python/bind.cpp
@@ -344,13 +344,25 @@ PYBIND11_MODULE(_turbomind, m)
            size_t      pipeline_para_size,
            int         enable_custom_all_reduce,
            std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+            auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                if (op) {
+                    state = PyGILState_Ensure();
+                }
+                else {
+                    PyGILState_Release(state);
+                }
+            };
             if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                return std::make_shared<LlamaTritonModel<half>>(
+                auto model = std::make_shared<LlamaTritonModel<half>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
             else {
-                return std::make_shared<LlamaTritonModel<float>>(
+                auto model = std::make_shared<LlamaTritonModel<float>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
         },
         "model_dir"_a,
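
The gil_control lambda stores the PyGILState_STATE returned by PyGILState_Ensure() in a mutable capture so that the matching token can be handed back to PyGILState_Release() on the next call with op == 0. Unrolled, one acquire/release cycle on the batch thread is equivalent to the following plain CPython sequence (a sketch; the comments map it to the hook calls made in LlamaBatch<T>::finish()):

#include <Python.h>

void run_callbacks_with_gil()
{
    PyGILState_STATE state = PyGILState_Ensure();  // ffi_lock_(1)
    // ... invoke every registered Python stream callback for this batch ...
    PyGILState_Release(state);                     // ffi_lock_(0)
}

The hook therefore has to be driven in strict 1/0 pairs, which is how finish() calls it.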
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -276,6 +276,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
7 changes: 7 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     void handleMissingParams();
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int getTensorParaSize() override;
     int getPipelineParaSize() override;
@@ -112,4 +117,6 @@
 
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
