
Mutual exclusion for vision encoder & LLM #3126

Open · wants to merge 1 commit into base: main

4 changes: 4 additions & 0 deletions lmdeploy/serve/vl_async_engine.py
@@ -30,6 +30,10 @@ def __init__(self,
        try_import_deeplink(backend_config.device_type)
        self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config)
        super().__init__(model_path, backend=backend, backend_config=backend_config, **kwargs)
        if backend == 'turbomind':
            # for mutual exclusion with LLM inference
            # reading `model_comm.mutex` will trigger its creation in tm
            self.vl_encoder.tm_mutex = self.engine.model_comm.mutex
        if self.model_name == 'base':
            raise RuntimeError(
                'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template'  # noqa: E501
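For context, this wiring only takes effect for the TurboMind backend, and end users reach it through the regular VL pipeline. A sketch of a typical invocation that would exercise the new mutex path, following the documented lmdeploy VL pipeline usage (the model name and image URL are placeholders):

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

# Any TurboMind-backed VL model goes through VLAsyncEngine and thus the mutex;
# the model name and image URL below are placeholders.
pipe = pipeline('OpenGVLab/InternVL2-8B')
image = load_image('https://example.com/tiger.jpeg')
response = pipe(('describe this image', image))
print(response.text)
```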
16 changes: 14 additions & 2 deletions lmdeploy/vl/engine.py
@@ -39,6 +39,7 @@ def __init__(
        self.vision_config = vision_config
        self.max_batch_size = vision_config.max_batch_size
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.tm_mutex = None  # set by `VLAsyncEngine` later if needed
        torch.cuda.empty_cache()

    async def preprocess(self, messages: List[Dict]) -> List[Dict]:
@@ -55,8 +56,19 @@ async def async_infer(self, messages: List[Dict]) -> List[Dict]:
            messages (List[Dict]): a list of messages, which is the output
                of `preprocess()`
        """
        future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.forward, messages,
                                                          self.max_batch_size)

        def forward():
            if self.tm_mutex:
                self.tm_mutex.lock()
            try:
                return self.model.forward(messages, self.max_batch_size)
            finally:
                # TODO: make sure work on GPU is actually done at this point
                if self.tm_mutex:
                    self.tm_mutex.unlock()

        future = asyncio.get_event_loop().run_in_executor(self.executor, forward)

        future.add_done_callback(_raise_exception_on_finish)
        outputs = await future
        return outputs
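The lock/try/finally dance inside `forward()` could equally be packaged as a small context manager. A minimal sketch, assuming only that the mutex object exposes `lock()`/`unlock()` as bound in `bind.cpp` below; `maybe_locked` and `DummyMutex` are hypothetical names used so the snippet runs on its own:

```python
from contextlib import contextmanager


@contextmanager
def maybe_locked(mutex):
    """Acquire `mutex` if it is set (TurboMind backend), else no-op."""
    if mutex is not None:
        mutex.lock()
    try:
        yield
    finally:
        if mutex is not None:
            mutex.unlock()


class DummyMutex:
    """Stand-in for the bound TurboMind mutex; only lock()/unlock() exist."""

    def lock(self):
        print('locked')

    def unlock(self):
        print('unlocked')


# `forward()` above would then reduce to:
#     with maybe_locked(self.tm_mutex):
#         return self.model.forward(messages, self.max_batch_size)
with maybe_locked(DummyMutex()):
    print('vision forward runs here; the LLM main loop cannot take the lock')
```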
18 changes: 14 additions & 4 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -1557,7 +1557,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
}

template<typename T>
void LlamaBatch<T>::InternalThreadEntry()
void LlamaBatch<T>::InternalThreadEntry() noexcept
{
    // TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
    check_cuda_error(cudaSetDevice(device_id_));
@@ -1587,6 +1587,9 @@ void LlamaBatch<T>::InternalThreadEntry()

        NvtxScope scope("mainloop");

        if (rank_ == 0 && shared_state_->mutex) {
            shared_state_->mutex->lock();
        }
        // 1. Wait while rank-0 is dequeueing
        // 2. Broadcast `ec` from rank-0
        shared_state_->barrier->wait();
@@ -1637,6 +1640,15 @@ void LlamaBatch<T>::InternalThreadEntry()
                gateway_->notify(std::move(signals));
            }
        }

        if (shared_state_->mutex) {
            check_cuda_error(cudaStreamSynchronize(stream_));
            shared_state_->barrier->wait();
            if (rank_ == 0) {
                // release the lock to external modules such as VL encoder
                shared_state_->mutex->unlock();
            }
        }
    }

    // Unreachable
@@ -1871,15 +1883,13 @@ struct TuningContext {
    {
        linear_.set_measure(false);
        isTuning() = false;
        // This will catch async errors during tuning
        check_cuda_error(cudaStreamSynchronize(stream_));
    }
};

} // namespace

template<class T>
void LlamaBatch<T>::tune()
void LlamaBatch<T>::tune() noexcept
{
    auto& linear = *context_->linear;
    if (auto str = std::getenv("TM_GEMM_IMPORT")) {
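To make the handshake in `InternalThreadEntry` concrete: rank 0 takes the shared mutex before the dequeue/broadcast barrier of each step, every rank finishes its GPU work (`cudaStreamSynchronize`) and meets at the barrier, and only then does rank 0 release the lock so an external module such as the VL encoder can run. A CPU-only model of that ordering, using Python threading primitives in place of the TurboMind barrier and mutex (all names, counts, and sleeps are illustrative):

```python
import threading
import time

N_RANKS = 2
gpu_mutex = threading.Lock()          # plays the role of shared_state_->mutex
barrier = threading.Barrier(N_RANKS)  # plays the role of shared_state_->barrier


def llm_rank(rank: int, steps: int = 3) -> None:
    for _ in range(steps):
        if rank == 0:
            gpu_mutex.acquire()       # rank 0 locks before the dequeue barrier
        barrier.wait()                # all ranks enter the step together
        time.sleep(0.01)              # stands in for one forward/decode step
        barrier.wait()                # "stream synchronized" on every rank
        if rank == 0:
            gpu_mutex.release()       # hand the GPU back, e.g. to the VL encoder


def vision_encoder(steps: int = 3) -> None:
    for _ in range(steps):
        with gpu_mutex:               # only runs between LLM steps
            time.sleep(0.01)          # stands in for the vision forward pass


threads = [threading.Thread(target=llm_rank, args=(r,)) for r in range(N_RANKS)]
threads.append(threading.Thread(target=vision_encoder))
for t in threads:
    t.start()
for t in threads:
    t.join()
```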
7 changes: 3 additions & 4 deletions src/turbomind/models/llama/LlamaBatch.h
@@ -23,6 +23,7 @@ struct SharedState {
    std::vector<std::shared_ptr<Request>> infer_reqs;
    std::vector<std::shared_ptr<Request>> kill_reqs;
    std::shared_ptr<Barrier> barrier;
    std::shared_ptr<std::mutex> mutex;
    bool abort;
    std::atomic<size_t> free_size{std::numeric_limits<size_t>::max()};
};
@@ -127,16 +128,14 @@ class LlamaBatch {
        return session_len_;
    }

    void tune();
    void tune() noexcept;

private:
    void BroadcastCancelFlags();

    void ProcessCancelRequests(std::vector<Signal>& signals);

    void InternalThreadEntry();

    void OutputThreadEntry();
    void InternalThreadEntry() noexcept;

    void CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc);

7 changes: 6 additions & 1 deletion src/turbomind/python/bind.cpp
@@ -350,6 +350,10 @@ PYBIND11_MODULE(_turbomind, m)
            return oss.str();
        });

    py::class_<std::mutex, std::shared_ptr<std::mutex>>(m, "Mutex")
        .def("lock", [](std::mutex& mutex) { mutex.lock(); })
        .def("unlock", [](std::mutex& mutex) { mutex.unlock(); });

    py::class_<ft::RequestState, std::unique_ptr<ft::RequestState>>(m, "RequestState")
        .def_readonly("status", &ft::RequestState::status)
        .def_readonly("seq_len", &ft::RequestState::seq_len);
@@ -618,5 +622,6 @@ PYBIND11_MODULE(_turbomind, m)
.def("__str__", &AbstractTransformerModel::toString)
.def("__repr__", &AbstractTransformerModel::toString)
.def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize)
.def("get_pipeline_para_size", &AbstractTransformerModel::getPipelineParaSize);
.def("get_pipeline_para_size", &AbstractTransformerModel::getPipelineParaSize)
.def_property_readonly("mutex", [](AbstractTransformerModel* model) { return model->mutex(); });
}
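On the Python side, the new binding surfaces as an opaque object with just `lock()` and `unlock()`, reached through the read-only `mutex` property, which (per `LlamaTritonModel::mutex()` below) creates the underlying `std::mutex` on first access. A hedged, CPU-only stand-in for that contract; `Mutex` and `FakeModelComm` here are hypothetical substitutes for the real `_turbomind` objects:

```python
import threading


class Mutex:
    """Mimics the pybind11 `Mutex` wrapper: only lock()/unlock() are exposed."""

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def lock(self) -> None:
        self._lock.acquire()

    def unlock(self) -> None:
        self._lock.release()


class FakeModelComm:
    """Hypothetical stand-in for the model handle exposed by `_turbomind`."""

    def __init__(self) -> None:
        self._mutex = None

    @property
    def mutex(self) -> Mutex:
        # mirrors LlamaTritonModel::mutex(): created lazily on first read
        if self._mutex is None:
            self._mutex = Mutex()
        return self._mutex


model_comm = FakeModelComm()
tm_mutex = model_comm.mutex   # what `vl_async_engine.py` stores on the encoder
tm_mutex.lock()               # blocks while the LLM is inside a decode step
tm_mutex.unlock()             # lets the LLM main loop take the lock again
```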
8 changes: 8 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -64,6 +64,14 @@ struct LlamaTritonModel: public AbstractTransformerModel {
    int getTensorParaSize() override;
    int getPipelineParaSize() override;

    std::shared_ptr<std::mutex> mutex() const noexcept override
    {
        if (!shared_state_->mutex) {
            shared_state_->mutex = std::make_shared<std::mutex>();
        }
        return shared_state_->mutex;
    }

private:
    std::unique_ptr<Engine<T>>
    createSharedModelInstance(int deviceId,
5 changes: 5 additions & 0 deletions src/turbomind/triton_backend/transformer_triton_backend.hpp
@@ -86,6 +86,11 @@ struct AbstractTransformerModel {
                              std::pair<std::vector<NcclParam>, std::vector<NcclParam>> nccl_params,
                              std::shared_ptr<AbstractCustomComm>) = 0;

    virtual std::shared_ptr<std::mutex> mutex() const noexcept
    {
        return {};
    }

    virtual std::string toString() = 0;
    virtual int getTensorParaSize() = 0;
    virtual int getPipelineParaSize() = 0;