From fcd70ab37d062293716de4f2b3c5d44b1c6a57a8 Mon Sep 17 00:00:00 2001
From: chenxin
Date: Tue, 12 Sep 2023 13:51:10 +0800
Subject: [PATCH 1/5] reduce gil switching

---
 src/turbomind/models/llama/CMakeLists.txt |  1 +
 src/turbomind/models/llama/LlamaBatch.cc  |  7 +++--
 src/turbomind/python/bind.cpp             | 14 +++++++++-
 src/turbomind/utils/CMakeLists.txt        |  3 +++
 src/turbomind/utils/pycb_utils.cc         | 31 +++++++++++++++++++++++
 src/turbomind/utils/pycb_utils.h          | 15 +++++++++++
 6 files changed, 68 insertions(+), 3 deletions(-)
 create mode 100644 src/turbomind/utils/pycb_utils.cc
 create mode 100644 src/turbomind/utils/pycb_utils.h

diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 10b93fb9ec..3d05389a31 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
                       nccl_utils
                       cuda_utils
                       logger
+                      pycb_utils
                       llama_fmha)
 
 if (NOT MSVC)
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 83db7ad65d..0ca152c96d 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -9,6 +9,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/logger.h"
+#include "src/turbomind/utils/pycb_utils.h"
 #include
 #include
 #include
@@ -899,8 +900,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
 
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp       = llama_->tensor_para_.world_size_;
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -941,6 +943,7 @@ void LlamaBatch<T>::finish()
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
+            set_batch_info(i, batch_size_);
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index 592b2b30e6..982d4914e8 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -4,6 +4,7 @@
 #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
+#include "src/turbomind/utils/pycb_utils.h"
 #include
 #include
 #include
@@ -329,7 +330,18 @@ PYBIND11_MODULE(_turbomind, m)
         .def(
             "register_callback",
             [](AbstractTransformerModelInstance* self, triton_stream_cb_t cb, py::object ctx) {
-                self->registerCallback(cb, ctx.ptr());
+                auto callback = [=](std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> outputs,
+                                    void* ctx) {
+                    thread_local PyGILState_STATE gstate;
+                    if (ft::is_first_in_batch()) {
+                        gstate = PyGILState_Ensure();
+                    }
+                    cb(outputs, ctx);
+                    if (ft::is_last_in_batch()) {
+                        PyGILState_Release(gstate);
+                    }
+                };
+                self->registerCallback(callback, ctx.ptr());
             },
             "callback"_a,
             "context"_a = nullptr)
diff --git a/src/turbomind/utils/CMakeLists.txt b/src/turbomind/utils/CMakeLists.txt
index 113ef4f25a..00f83769e6 100644
--- a/src/turbomind/utils/CMakeLists.txt
+++ b/src/turbomind/utils/CMakeLists.txt
@@ -109,3 +109,6 @@ add_library(tensor STATIC Tensor.cc)
 set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 target_link_libraries(tensor PUBLIC cuda_utils logger)
+
+add_library(pycb_utils STATIC pycb_utils.cc)
+set_property(TARGET pycb_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/src/turbomind/utils/pycb_utils.cc b/src/turbomind/utils/pycb_utils.cc
new file mode 100644
index 0000000000..5fb72d4553
--- /dev/null
+++ b/src/turbomind/utils/pycb_utils.cc
@@ -0,0 +1,31 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "pycb_utils.h"
+#include <memory>
+
+namespace turbomind {
+
+thread_local std::shared_ptr<int> _current;
+thread_local std::shared_ptr<int> _total;
+
+void set_batch_info(int current, int total)
+{
+    if (!_current) {
+        _current = std::make_shared<int>();
+        _total   = std::make_shared<int>();
+    }
+    *_current = current;
+    *_total   = total;
+}
+
+int is_first_in_batch()
+{
+    return *_current == 0;
+}
+
+int is_last_in_batch()
+{
+    return *_current == (*_total - 1);
+}
+
+}  // namespace turbomind
diff --git a/src/turbomind/utils/pycb_utils.h b/src/turbomind/utils/pycb_utils.h
new file mode 100644
index 0000000000..45757379b2
--- /dev/null
+++ b/src/turbomind/utils/pycb_utils.h
@@ -0,0 +1,15 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include
+
+namespace turbomind {
+
+void set_batch_info(int current, int total);
+
+int is_first_in_batch();
+
+int is_last_in_batch();
+
+}  // namespace turbomind

From f579ae7e56995174ea1c098ae01ddd22b7221a66 Mon Sep 17 00:00:00 2001
From: chenxin
Date: Wed, 13 Sep 2023 07:55:58 +0000
Subject: [PATCH 2/5] ffi lock func

---
 src/turbomind/models/llama/LlamaBatch.cc     |  6 ++++
 src/turbomind/models/llama/LlamaV2.h         |  9 ++++++
 src/turbomind/python/bind.cpp                | 29 ++++++++++---------
 .../triton_backend/llama/LlamaTritonModel.cc |  1 +
 .../triton_backend/llama/LlamaTritonModel.h  |  7 +++++
 5 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 0ca152c96d..829096da5f 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -940,6 +940,9 @@ void LlamaBatch<T>::finish()
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
@@ -947,6 +950,9 @@ void LlamaBatch<T>::finish()
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }
 
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index ed13aa40f4..c52a02db0c 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include
 
+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {
 template<typename T>
 class LlamaV2 {
@@ -91,6 +93,11 @@
         return vocab_size_;
     }
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:
     friend class Batch;
 
@@ -188,6 +195,8 @@ class LlamaV2 {
     std::shared_ptr<SharedState> shared_state_;
 
     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index 982d4914e8..1018e016fe 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -330,18 +330,7 @@ PYBIND11_MODULE(_turbomind, m)
         .def(
             "register_callback",
             [](AbstractTransformerModelInstance* self, triton_stream_cb_t cb, py::object ctx) {
-                auto callback = [=](std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> outputs,
-                                    void* ctx) {
-                    thread_local PyGILState_STATE gstate;
-                    if (ft::is_first_in_batch()) {
-                        gstate = PyGILState_Ensure();
-                    }
-                    cb(outputs, ctx);
-                    if (ft::is_last_in_batch()) {
-                        PyGILState_Release(gstate);
-                    }
-                };
-                self->registerCallback(callback, ctx.ptr());
+                self->registerCallback(cb, ctx.ptr());
             },
             "callback"_a,
             "context"_a = nullptr)
@@ -356,13 +345,25 @@ PYBIND11_MODULE(_turbomind, m)
            size_t      pipeline_para_size,
            int         enable_custom_all_reduce,
            std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+            auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                if (op) {
+                    state = PyGILState_Ensure();
+                }
+                else {
+                    PyGILState_Release(state);
+                }
+            };
             if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                return std::make_shared<LlamaTritonModel<half>>(
+                auto model = std::make_shared<LlamaTritonModel<half>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
             else {
-                return std::make_shared<LlamaTritonModel<float>>(
+                auto model = std::make_shared<LlamaTritonModel<float>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
         },
         "model_dir"_a,
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 169d6cbdba..49235833e4 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -275,6 +275,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index f3a3a327a9..332000ce62 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     void handleMissingParams();
 
+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int         getTensorParaSize() override;
     int         getPipelineParaSize() override;
@@ -112,4 +117,6 @@ struct LlamaTritonModel: public AbstractTransformerModel {
 
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };

From ba1b25578afdaef98bad0d1eed059e1562c6b985 Mon Sep 17 00:00:00 2001
From: chenxin
Date: Wed, 13 Sep 2023 09:08:03 +0000
Subject: [PATCH 3/5] remove unused

---
 src/turbomind/models/llama/LlamaBatch.cc |  1 -
 src/turbomind/python/bind.cpp            |  1 -
 src/turbomind/utils/CMakeLists.txt       |  3 ---
 src/turbomind/utils/pycb_utils.cc        | 31 ------------------------
 src/turbomind/utils/pycb_utils.h         | 15 ------------
 5 files changed, 51 deletions(-)
 delete mode 100644 src/turbomind/utils/pycb_utils.cc
 delete mode 100644 src/turbomind/utils/pycb_utils.h

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 829096da5f..f689b90ddb 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -9,7 +9,6 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/logger.h"
-#include "src/turbomind/utils/pycb_utils.h"
 #include
 #include
 #include
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index 1018e016fe..b55ed040af 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -4,7 +4,6 @@
 #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
-#include "src/turbomind/utils/pycb_utils.h"
 #include
 #include
 #include
diff --git a/src/turbomind/utils/CMakeLists.txt b/src/turbomind/utils/CMakeLists.txt
index 00f83769e6..113ef4f25a 100644
--- a/src/turbomind/utils/CMakeLists.txt
+++ b/src/turbomind/utils/CMakeLists.txt
@@ -109,6 +109,3 @@ add_library(tensor STATIC Tensor.cc)
 set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 target_link_libraries(tensor PUBLIC cuda_utils logger)
-
-add_library(pycb_utils STATIC pycb_utils.cc)
-set_property(TARGET pycb_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/src/turbomind/utils/pycb_utils.cc b/src/turbomind/utils/pycb_utils.cc
deleted file mode 100644
index 5fb72d4553..0000000000
--- a/src/turbomind/utils/pycb_utils.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright (c) OpenMMLab. All rights reserved.
-
-#include "pycb_utils.h"
-#include <memory>
-
-namespace turbomind {
-
-thread_local std::shared_ptr<int> _current;
-thread_local std::shared_ptr<int> _total;
-
-void set_batch_info(int current, int total)
-{
-    if (!_current) {
-        _current = std::make_shared<int>();
-        _total   = std::make_shared<int>();
-    }
-    *_current = current;
-    *_total   = total;
-}
-
-int is_first_in_batch()
-{
-    return *_current == 0;
-}
-
-int is_last_in_batch()
-{
-    return *_current == (*_total - 1);
-}
-
-}  // namespace turbomind
diff --git a/src/turbomind/utils/pycb_utils.h b/src/turbomind/utils/pycb_utils.h
deleted file mode 100644
index 45757379b2..0000000000
--- a/src/turbomind/utils/pycb_utils.h
+++ /dev/null
@@ -1,15 +0,0 @@
-// Copyright (c) OpenMMLab. All rights reserved.
-
-#pragma once
-
-#include
-
-namespace turbomind {
-
-void set_batch_info(int current, int total);
-
-int is_first_in_batch();
-
-int is_last_in_batch();
-
-}  // namespace turbomind

From 4080a3a3dee97754ddfdadf6eecbcfbfa8bfea26 Mon Sep 17 00:00:00 2001
From: chenxin
Date: Wed, 13 Sep 2023 09:09:21 +0000
Subject: [PATCH 4/5] remove unused

---
 src/turbomind/models/llama/LlamaBatch.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index f689b90ddb..995f15b710 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -945,7 +945,6 @@ void LlamaBatch<T>::finish()
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
-            set_batch_info(i, batch_size_);
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }

From 7b7a1a8ed837c01ebbecd950a84ca704b423a566 Mon Sep 17 00:00:00 2001
From: chenxin
Date: Wed, 13 Sep 2023 09:11:31 +0000
Subject: [PATCH 5/5] remove unused

---
 src/turbomind/models/llama/CMakeLists.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 3d05389a31..10b93fb9ec 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -39,7 +39,6 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
                       nccl_utils
                       cuda_utils
                       logger
-                      pycb_utils
                       llama_fmha)
 
 if (NOT MSVC)