diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 10b93fb9ec..3d05389a31 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
         nccl_utils
         cuda_utils
         logger
+        pycb_utils
         llama_fmha)
 
 if (NOT MSVC)
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 83db7ad65d..0ca152c96d 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -9,6 +9,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/logger.h"
+#include "src/turbomind/utils/pycb_utils.h"
 #include
 #include
 #include
@@ -899,8 +900,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
 
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp       = llama_->tensor_para_.world_size_;
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -941,6 +943,7 @@ void LlamaBatch<T>::finish()
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
+            set_batch_info(i, batch_size_);
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index 592b2b30e6..982d4914e8 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -4,6 +4,7 @@
 #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
+#include "src/turbomind/utils/pycb_utils.h"
 #include
 #include
 #include
@@ -329,7 +330,18 @@ PYBIND11_MODULE(_turbomind, m)
         .def(
             "register_callback",
             [](AbstractTransformerModelInstance* self, triton_stream_cb_t cb, py::object ctx) {
-                self->registerCallback(cb, ctx.ptr());
+                auto callback = [=](std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> outputs,
+                                    void*                                                            ctx) {
+                    thread_local PyGILState_STATE gstate;
+                    if (ft::is_first_in_batch()) {
+                        gstate = PyGILState_Ensure();
+                    }
+                    cb(outputs, ctx);
+                    if (ft::is_last_in_batch()) {
+                        PyGILState_Release(gstate);
+                    }
+                };
+                self->registerCallback(callback, ctx.ptr());
             },
             "callback"_a,
             "context"_a = nullptr)
diff --git a/src/turbomind/utils/CMakeLists.txt b/src/turbomind/utils/CMakeLists.txt
index 113ef4f25a..00f83769e6 100644
--- a/src/turbomind/utils/CMakeLists.txt
+++ b/src/turbomind/utils/CMakeLists.txt
@@ -109,3 +109,6 @@ add_library(tensor STATIC Tensor.cc)
 set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 target_link_libraries(tensor PUBLIC cuda_utils logger)
+
+add_library(pycb_utils STATIC pycb_utils.cc)
+set_property(TARGET pycb_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/src/turbomind/utils/pycb_utils.cc b/src/turbomind/utils/pycb_utils.cc
new file mode 100644
index 0000000000..5fb72d4553
--- /dev/null
+++ b/src/turbomind/utils/pycb_utils.cc
@@ -0,0 +1,31 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "pycb_utils.h"
+#include <memory>
+
+namespace turbomind {
+
+thread_local std::shared_ptr<int> _current;
+thread_local std::shared_ptr<int> _total;
+
+void set_batch_info(int current, int total)
+{
+    if (!_current) {
+        _current = std::make_shared<int>();
+        _total   = std::make_shared<int>();
+    }
+    *_current = current;
+    *_total   = total;
+}
+
+int is_first_in_batch()
+{
+    return *_current == 0;
+}
+
+int is_last_in_batch()
+{
+    return *_current == (*_total - 1);
+}
+
+} // namespace turbomind
diff --git a/src/turbomind/utils/pycb_utils.h b/src/turbomind/utils/pycb_utils.h
new file mode 100644
index 0000000000..45757379b2
--- /dev/null
+++ b/src/turbomind/utils/pycb_utils.h
@@ -0,0 +1,15 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include
+
+namespace turbomind {
+
+void set_batch_info(int current, int total);
+
+int is_first_in_batch();
+
+int is_last_in_batch();
+
+} // namespace turbomind
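
For illustration only (not part of the patch): a minimal sketch of the per-batch GIL protocol the diff establishes, built on the `pycb_utils` API added above. `on_token()` is a hypothetical stand-in for the wrapped Python callback installed in `bind.cpp`; the `printf` calls mark where the real wrapper calls `PyGILState_Ensure()` / `PyGILState_Release()`, so the GIL is acquired once before the first request's callback in a batch and released once after the last, rather than once per request.

```cpp
// Hypothetical driver mimicking LlamaBatch<T>::finish(): it records each
// request's slot index via set_batch_info() before invoking the stream
// callback, which is what lets the wrapper in bind.cpp bracket the whole
// batch with a single GIL acquire/release pair.
#include "src/turbomind/utils/pycb_utils.h"
#include <cstdio>

static void on_token(int request_id)  // stand-in for the wrapped Python callback
{
    if (turbomind::is_first_in_batch()) {
        std::printf("PyGILState_Ensure()\n");  // done once per batch in bind.cpp
    }
    std::printf("stream callback for request %d\n", request_id);
    if (turbomind::is_last_in_batch()) {
        std::printf("PyGILState_Release()\n");  // done once per batch in bind.cpp
    }
}

int main()
{
    const int batch_size = 3;
    for (int i = 0; i < batch_size; ++i) {
        turbomind::set_batch_info(i, batch_size);
        on_token(i);
    }
    return 0;
}
```

Because `_current` and `_total` are `thread_local`, the first/last checks are only meaningful on the thread that called `set_batch_info()`, which matches how `finish()` drives the callbacks from the single rank-0 thread.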