Commit f579ae7: ffi lock func
irexyc committed Sep 13, 2023
1 parent fcd70ab
Showing 5 changed files with 38 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -940,13 +940,19 @@ void LlamaBatch<T>::finish()
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
             set_batch_info(i, batch_size_);
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }
 
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
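The convention this loop relies on: ffi_lock_(1) acquires whatever lock protects the foreign callbacks, and ffi_lock_(0) releases it, so all stream callbacks for a batch run under a single acquire/release pair. A minimal sketch of that contract, assuming nothing beyond the diff (a std::mutex stands in for the Python GIL that bind.cpp below actually guards):

// Sketch only: the ffi_api_lock_ctrl_t contract assumed by LlamaBatch<T>::finish().
// A std::mutex is a stand-in for the Python GIL handled in bind.cpp.
#include <functional>
#include <iostream>
#include <mutex>

using ffi_api_lock_ctrl_t = std::function<void(int)>;

int main()
{
    std::mutex gil_stand_in;

    ffi_api_lock_ctrl_t ffi_lock = [&gil_stand_in](int op) {
        if (op) {
            gil_stand_in.lock();  // op == 1: acquire (PyGILState_Ensure in bind.cpp)
        }
        else {
            gil_stand_in.unlock();  // op == 0: release (PyGILState_Release in bind.cpp)
        }
    };

    ffi_lock(1);  // once before the callback loop
    std::cout << "all stream callbacks for the batch run here\n";
    ffi_lock(0);  // once after the loop
}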
9 changes: 9 additions & 0 deletions src/turbomind/models/llama/LlamaV2.h
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include <unordered_map>
 
+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {
 
 template<typename T>
@@ -91,6 +93,11 @@ class LlamaV2 {
         return vocab_size_;
     }
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:
     friend class Batch;
 
@@ -188,6 +195,8 @@ class LlamaV2 {
     std::shared_ptr<SharedState> shared_state_;
 
     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
 
 }  // namespace turbomind
29 changes: 15 additions & 14 deletions src/turbomind/python/bind.cpp
@@ -330,18 +330,7 @@ PYBIND11_MODULE(_turbomind, m)
         .def(
             "register_callback",
             [](AbstractTransformerModelInstance* self, triton_stream_cb_t cb, py::object ctx) {
-                auto callback = [=](std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> outputs,
-                                    void* ctx) {
-                    thread_local PyGILState_STATE gstate;
-                    if (ft::is_first_in_batch()) {
-                        gstate = PyGILState_Ensure();
-                    }
-                    cb(outputs, ctx);
-                    if (ft::is_last_in_batch()) {
-                        PyGILState_Release(gstate);
-                    }
-                };
-                self->registerCallback(callback, ctx.ptr());
+                self->registerCallback(cb, ctx.ptr());
             },
             "callback"_a,
             "context"_a = nullptr)
@@ -356,13 +345,25 @@ PYBIND11_MODULE(_turbomind, m)
            size_t      pipeline_para_size,
            int         enable_custom_all_reduce,
            std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+            auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                if (op) {
+                    state = PyGILState_Ensure();
+                }
+                else {
+                    PyGILState_Release(state);
+                }
+            };
             if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                return std::make_shared<LlamaTritonModel<half>>(
+                auto model = std::make_shared<LlamaTritonModel<half>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
             else {
-                return std::make_shared<LlamaTritonModel<float>>(
+                auto model = std::make_shared<LlamaTritonModel<float>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
            }
        },
        "model_dir"_a,
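A note on gil_control above: it holds the PyGILState_STATE in a mutable init-capture, so the state written on op == 1 is the same one released on op == 0. Because LlamaBatch only calls the lock from rank 0, once around the whole loop, the calls arrive in strictly alternating pairs. The same stateful-lambda shape in isolation (a sketch, with an int standing in for PyGILState_STATE):

// Sketch of the mutable init-capture pattern used by gil_control in bind.cpp.
#include <cassert>
#include <functional>

int main()
{
    auto lock_ctrl = [state = 0](int op) mutable {
        if (op) {
            state = 1;  // stands in for: state = PyGILState_Ensure();
        }
        else {
            assert(state == 1);  // a release must follow an acquire
            state = 0;  // stands in for: PyGILState_Release(state);
        }
    };

    // std::function copies the lambda, and each copy carries its own state slot;
    // only the copy stored in LlamaV2 is ever invoked, so acquire and release
    // always see the same slot.
    std::function<void(int)> ffi_lock = lock_ctrl;

    ffi_lock(1);  // acquire
    ffi_lock(0);  // release
}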
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -275,6 +275,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
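Taken together, the plumbing is: bind.cpp installs gil_control on LlamaTritonModel, createModelInstance forwards it to the LlamaV2 engine, and LlamaBatch<T>::finish() invokes it around the callback loop. A compressed sketch with stand-in types (Model and Engine are illustrative only, not the repo's real classes):

// Stand-in types only; LlamaTritonModel<T> and LlamaV2<T> are far larger in the repo.
#include <functional>
#include <memory>

using ffi_api_lock_ctrl_t = std::function<void(int)>;

struct Engine {  // plays the role of LlamaV2<T>
    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
    void setFfiLock(ffi_api_lock_ctrl_t func) { ffi_lock_ = func; }
};

struct Model {  // plays the role of LlamaTritonModel<T>
    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
    void setFfiLock(ffi_api_lock_ctrl_t func) { ffi_lock_ = func; }

    std::shared_ptr<Engine> createModelInstance()
    {
        auto engine = std::make_shared<Engine>();
        engine->setFfiLock(ffi_lock_);  // mirrors instance->llm->setFfiLock(ffi_lock_)
        return engine;
    }
};

int main()
{
    Model model;
    model.setFfiLock([](int op) { (void)op; /* acquire on 1, release on 0 */ });

    auto engine = model.createModelInstance();
    if (engine->ffi_lock_) {
        engine->ffi_lock_(1);  // LlamaBatch<T>::finish() wraps its callback loop like this
        engine->ffi_lock_(0);
    }
}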
7 changes: 7 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     void handleMissingParams();
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int         getTensorParaSize() override;
     int         getPipelineParaSize() override;
@@ -112,4 +117,6 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
