Reduce gil switching #407

Merged: 5 commits, Sep 18, 2023
11 changes: 9 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -899,8 +899,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
 
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp = llama_->tensor_para_.world_size_;
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -938,12 +939,18 @@ void LlamaBatch<T>::finish()
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }
 
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
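A note on the finish() hunk above: each stream_cb call re-enters Python, and the diff does not show how those callbacks are wrapped. Assuming a conventional pybind11-style wrapper that takes the GIL on every invocation, the sketch below illustrates the per-call cost that bracketing the whole loop with ffi_lock_(1)/ffi_lock_(0) is meant to amortize: with rank 0 already holding the GIL, the inner acquire no longer waits for a fresh GIL switch on each request. The wrapper name and signature are hypothetical, not the project's actual binding code.

#include <pybind11/pybind11.h>

#include <functional>

namespace py = pybind11;

// Hypothetical wrapper shape: every call re-enters Python and must hold
// the GIL for the duration of the call.
std::function<void(int)> wrap_stream_cb(py::function f)
{
    return [f](int step) {
        py::gil_scoped_acquire acquire;  // a per-call GIL switch when nothing outer holds it
        f(step);
    };
}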
9 changes: 9 additions & 0 deletions src/turbomind/models/llama/LlamaV2.h
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include <unordered_map>
 
+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {
 
 template<typename T>
@@ -91,6 +93,11 @@ class LlamaV2 {
         return vocab_size_;
     }
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:
     friend class Batch;
 
@@ -188,6 +195,8 @@ class LlamaV2 {
     std::shared_ptr<SharedState> shared_state_;
 
     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
 
 } // namespace turbomind
16 changes: 14 additions & 2 deletions src/turbomind/python/bind.cpp
@@ -344,13 +344,25 @@ PYBIND11_MODULE(_turbomind, m)
                size_t pipeline_para_size,
                int enable_custom_all_reduce,
                std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+                auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                    if (op) {
+                        state = PyGILState_Ensure();
+                    }
+                    else {
+                        PyGILState_Release(state);
+                    }
+                };
                 if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                    return std::make_shared<LlamaTritonModel<half>>(
+                    auto model = std::make_shared<LlamaTritonModel<half>>(
                         tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                    model->setFfiLock(gil_control);
+                    return model;
                 }
                 else {
-                    return std::make_shared<LlamaTritonModel<float>>(
+                    auto model = std::make_shared<LlamaTritonModel<float>>(
                         tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                    model->setFfiLock(gil_control);
+                    return model;
                 }
             },
             "model_dir"_a,
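The gil_control lambda added above is a two-state switch over the GIL: calling it with 1 stores the handle returned by PyGILState_Ensure, and calling it with 0 hands that handle back to PyGILState_Release. Below is a minimal standalone sketch of the same contract, using only the CPython C API, with a worker thread standing in for turbomind's internal batch thread (illustrative only, not project code).

#include <Python.h>

#include <thread>

int main()
{
    Py_Initialize();
    // Drop the GIL on the main thread so the worker has to take it explicitly.
    PyThreadState* main_state = PyEval_SaveThread();

    std::thread worker([] {
        // Same contract as gil_control in bind.cpp: op == 1 acquires, op == 0 releases.
        auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
            if (op) {
                state = PyGILState_Ensure();
            }
            else {
                PyGILState_Release(state);
            }
        };

        gil_control(1);                             // one acquire...
        PyRun_SimpleString("print('callback 1')");  // ...covers any number of
        PyRun_SimpleString("print('callback 2')");  // Python calls in between
        gil_control(0);                             // ...then one release
    });
    worker.join();

    PyEval_RestoreThread(main_state);
    Py_FinalizeEx();
    return 0;
}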
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -275,6 +275,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
7 changes: 7 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     void handleMissingParams();
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int getTensorParaSize() override;
     int getPipelineParaSize() override;
@@ -112,4 +117,6 @@
 
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
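Taken together, the plumbing added by this PR is: bind.cpp builds gil_control and hands it to LlamaTritonModel::setFfiLock, createModelInstance forwards it to LlamaV2::setFfiLock, and LlamaBatch::finish() wraps its streaming-callback loop in ffi_lock_(1)/ffi_lock_(0). The stub below traces that flow with stand-in types (hypothetical names, not the real classes) so the call chain can be read end to end.

#include <functional>
#include <iostream>
#include <vector>

// Stand-in for the alias added to LlamaV2.h.
using ffi_api_lock_ctrl_t = std::function<void(int)>;

struct LlamaV2Stub {
    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
    void setFfiLock(ffi_api_lock_ctrl_t func) { ffi_lock_ = func; }
};

struct LlamaBatchStub {
    LlamaV2Stub* llama_;
    std::vector<std::function<void()>> stream_cbs_;

    void finish()
    {
        if (llama_->ffi_lock_) {
            llama_->ffi_lock_(1);  // acquire once per finish(), not once per callback
        }
        for (auto& cb : stream_cbs_) {
            cb();
        }
        if (llama_->ffi_lock_) {
            llama_->ffi_lock_(0);
        }
    }
};

int main()
{
    LlamaV2Stub llm;
    // In the real binding this is the PyGILState_Ensure/Release lambda from bind.cpp.
    llm.setFfiLock([](int op) { std::cout << (op ? "ensure GIL\n" : "release GIL\n"); });

    LlamaBatchStub batch{&llm, {[] { std::cout << "stream_cb 1\n"; }, [] { std::cout << "stream_cb 2\n"; }}};
    batch.finish();
    return 0;
}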