From c9186691b5b14edbe70405a20738ecff4a230e9f Mon Sep 17 00:00:00 2001
From: q yao
Date: Mon, 21 Oct 2024 10:59:32 +0800
Subject: [PATCH] Add barrier to prevent TP nccl kernel waiting. (#2607)

* add mp.barrier

* remove old exit mechanism of exit_flag (#4)

* fix exit problem on ascend platform

* remove exit_flag in tp exit

* set log level

---------

Co-authored-by: CyCle1024
---
 lmdeploy/pytorch/engine/model_agent.py | 106 ++++++++++++-------
 1 file changed, 52 insertions(+), 54 deletions(-)
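Notes: the crux of the change is that idle TP workers no longer block inside dist.broadcast_object_list, a collective that can hang forever once the driver process goes away. Each forward step is now gated by a multiprocessing Barrier that rank 0 releases only when it actually has inputs to broadcast, so shutdown no longer needs to push an exit_flag through the collective; the atexit handler just drops the model. Below is a minimal, standard-library-only sketch of that barrier-before-broadcast pattern; worker_loop, drive_step and the Queue standing in for the broadcast are illustrative names, not lmdeploy APIs.

    # Sketch of the barrier-before-broadcast pattern (standard library only).
    import multiprocessing as mp
    import time


    def worker_loop(barrier, inputs_queue):
        """TP worker: park on the barrier until rank 0 has work.

        Waiting here instead of inside a collective means an exiting rank 0
        leaves the daemon worker parked on the barrier, where it can simply
        be killed, rather than stuck inside an NCCL kernel.
        """
        while True:
            barrier.wait()               # released only when work exists
            inputs = inputs_queue.get()  # stand-in for broadcast_object_list
            _ = inputs                   # ... run the forward step here ...


    def drive_step(barrier, inputs_queue, inputs):
        """Rank 0 side: release the workers, then ship the real inputs."""
        barrier.wait()
        inputs_queue.put(inputs)


    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        world_size = 2                   # rank 0 plus one worker
        bar = ctx.Barrier(world_size)
        q = ctx.Queue()
        worker = ctx.Process(target=worker_loop, args=(bar, q), daemon=True)
        worker.start()
        drive_step(bar, q, {'prompt_ids': [1, 2, 3]})
        time.sleep(0.2)  # let the worker pick up the step before we exit
        # On interpreter exit the daemon worker is killed while parked on the
        # barrier, so no exit_flag broadcast is needed.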
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 4c902dbe2..1daf614c8 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -2,6 +2,7 @@
 import asyncio
 import atexit
 import os
+import threading
 from datetime import timedelta
 from typing import Any, Callable, Dict, List

@@ -400,7 +401,7 @@ def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
     """get input tensor parallel."""
     # broadcast meta info
     if rank != 0:
-        inputs = [None, None, None, None]
+        inputs = [None, None, None]

     with torch.cuda.stream(stream):
         dist.broadcast_object_list(inputs)
@@ -415,6 +416,7 @@ def _tp_model_loop(
     backend_config: BackendConfig,
     adapters: Dict[str, str],
     world_size: int,
+    barrier: mp.Barrier,
 ):
     """Start model loops for tensor parallel model inference.

@@ -438,12 +440,10 @@ def _tp_model_loop(
         world_size=world_size)

     while True:
-        inputs, swap_in_map, swap_out_map, exit_flag = _broadcast_inputs(
+        barrier.wait()
+        inputs, swap_in_map, swap_out_map = _broadcast_inputs(
             rank, None, stream)

-        if exit_flag:
-            break
-
         cache_swapping(cache_engine,
                        swap_in_map=swap_in_map,
                        swap_out_map=swap_out_map)
@@ -460,6 +460,7 @@ def _tp_model_loop(
 def _start_tp_process(proc_id: int,
                       world_size: int,
                       func: Callable,
+                      log_level: int,
                       device_context: DeviceContext,
                       args: List = None,
                       kwargs: Dict = None):
@@ -473,6 +474,7 @@ def _start_tp_process(proc_id: int,
         kwargs (Dict): The keyword arguments of the func.
     """
     rank = proc_id + 1
+    logger.setLevel(log_level)
     try:
         from lmdeploy.pytorch.check_env import check_env_deeplink
         check_env_deeplink(device_context.device_type)
@@ -499,14 +501,16 @@ def _check_context_alive(mp_context: mp.ProcessContext):
     """check context alive."""
     procs: List[mp.Process] = mp_context.processes
     failed_ranks = list(idx for idx, p in enumerate(procs) if not p.is_alive())
-    if len(failed_ranks) > 0:
-        for p in procs:
-            if p.is_alive():
-                p.terminate()
-            else:
-                p.close()
-        logger.error(f'TP process Rank{failed_ranks} failed.')
-        exit(1)
+    if len(failed_ranks) == 0:
+        return
+    for p in procs:
+        if p.is_alive():
+            p.terminate()
+        else:
+            p.close()
+    logger.error(f'TP process {failed_ranks} failed.')
+    # TODO: not safe exit.
+    os._exit(1)


 def _find_available_port() -> bool:
@@ -561,13 +565,14 @@ def __signal_term_handler(sig, frame):

         self.world_size = world_size
         self.backend_config = backend_config
+        self.mp_bar = self.mp_ctx.Barrier(world_size)
         self._start_sub_process(model_path,
                                 model_config=model_config,
                                 cache_config=cache_config,
                                 backend_config=backend_config,
                                 adapters=adapters,
                                 world_size=world_size,
-                                trust_remote_code=trust_remote_code)
+                                barrier=self.mp_bar)

         model, cache_engine, cache_config = self._build_model(
             model_path=model_path,
@@ -575,18 +580,29 @@ def __signal_term_handler(sig, frame):
             cache_config=cache_config,
             backend_config=backend_config,
             adapters=adapters,
-            world_size=world_size,
-        )
+            world_size=world_size)
         self.patched_model = model
         self.cache_config = cache_config
         self.cache_engine = cache_engine
         self.stream = torch.cuda.Stream()

+    def _mp_watchdog(self, mp_context: mp.ProcessContext, timeout: int = 1):
+        """watch dog of mp context.
+
+        Args:
+            mp_context: context of multiprocess.
+            timeout: timeout
+        """
+        import time
+        while True:
+            _check_context_alive(mp_context)
+            time.sleep(timeout)
+
     def _start_sub_process(self, model_path: str, model_config: ModelConfig,
                            cache_config: CacheConfig,
                            backend_config: BackendConfig,
                            adapters: Dict[str, str],
-                           world_size: int, trust_remote_code: bool):
+                           world_size: int, barrier: mp.Barrier):
         """Start tensor parallel sub process."""
         port = _find_available_port()
         os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
@@ -601,19 +617,27 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
             args=(
                 world_size,
                 _tp_model_loop,
+                logger.level,
                 device_context,
                 (model_path, ),
-                dict(model_config=model_config,
-                     cache_config=cache_config,
-                     backend_config=backend_config,
-                     adapters=adapters,
-                     world_size=world_size),
+                dict(
+                    model_config=model_config,
+                    cache_config=cache_config,
+                    backend_config=backend_config,
+                    adapters=adapters,
+                    world_size=world_size,
+                    barrier=barrier,
+                ),
             ),
             nprocs=world_size - 1,
             join=False,
             daemon=True,
         )
-        _check_context_alive(self.mp_context)
+
+        t_watchdog = threading.Thread(target=self._mp_watchdog,
+                                      args=[self.mp_context, 1.0],
+                                      daemon=True)
+        t_watchdog.start()

         rank = 0
         try:
@@ -628,8 +652,7 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
             if dist.is_initialized():
                 dist.destroy_process_group()
             raise e
-        # Please see Note [Exit By Sending Exit Flag]
-        atexit.register(_exit_by_sending_exit_flag, rank, self)
+        atexit.register(_exit_handler, self)

     @torch.inference_mode()
     def _build_model(
@@ -642,7 +665,6 @@
         world_size: int,
     ):
         """build model."""
-        _check_context_alive(self.mp_context)
         rank = 0
         model, cache_engine, cache_config = _tp_build_model(
             rank,
@@ -664,10 +686,9 @@ def get_block_numel(self):
     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
                       swap_out_map: SwapMap):
         """forward impl."""
-        _check_context_alive(self.mp_context)
+        self.mp_bar.wait()
         rank = 0
-        exit_flag = False
-        _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map, exit_flag],
+        _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
                           self.stream)
         cache_swapping(self.cache_engine,
                        swap_in_map=swap_in_map,
@@ -717,32 +738,9 @@ def get_logits(self, hidden_states: torch.Tensor):
         return self.patched_model.get_logits(hidden_states)


-def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent):
-    """[Note] Exit By Sending Exit Flag: the registration to `atexit` of this
-    function should be called after importing torch.multiprocessing and the
-    initialization of distributed process group."""
-    if not hasattr(agent, 'stream'):
-        # agent is not initialized, just exits normally
-        if hasattr(agent, 'patched_model'):
-            del agent.patched_model
-        return
-
-    import sys
-    if 'torch_npu' in sys.modules and 'uvicorn.server' in sys.modules:
-        # Workaround for CLI serve mode with device_type ascend:
-        # using uvicorn server causes ascend low-level backend of subprocesses
-        # corrupted, and using _broadcast_inputs in this case leads to
-        # main process hanging, just exits normally
+def _exit_handler(agent: TPModelAgent):
+    if hasattr(agent, 'patched_model'):
         del agent.patched_model
-        return
-
-    # send exit_flag to all subprocess relying on all subprocess are alive
-    # and wait at _broadcast_inputs
-    exit_flag = True
-    _broadcast_inputs(rank, [None, None, None, exit_flag], agent.stream)
-    agent.stream.synchronize()
-
-    del agent.patched_model


 def build_model_agent(model_path: str,
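The exit path is covered by the other half of the change: a daemon watchdog thread polls the TP subprocesses roughly once per second and tears the whole process down with os._exit(1) if any of them has died, since the surviving ranks could otherwise wait forever at the barrier or inside a collective. A rough standalone sketch of that watchdog, using illustrative names (watch_procs, start_watchdog) rather than lmdeploy's actual helpers:

    # Sketch of the watchdog pattern: kill everything if a TP worker dies.
    import os
    import threading
    import time


    def watch_procs(procs, poll_interval=1.0):
        """Poll subprocesses; terminate the rest and exit if any has died.

        os._exit skips normal Python cleanup on purpose: surviving ranks may
        be blocked at a barrier or inside a collective, so a regular exit
        could hang.
        """
        while True:
            dead = [idx for idx, p in enumerate(procs) if not p.is_alive()]
            if dead:
                for p in procs:
                    if p.is_alive():
                        p.terminate()
                print(f'TP process {dead} failed.')
                os._exit(1)
            time.sleep(poll_interval)


    def start_watchdog(procs):
        """Run the poll loop on a daemon thread so it dies with the process."""
        t = threading.Thread(target=watch_procs, args=(procs, 1.0), daemon=True)
        t.start()
        return t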