Add barrier to prevent TP nccl kernel waiting. (#2607)
* add mp.barrier

* remove old exit mechanism of exit_flag (#4)

* fix exit problem on ascend platform

* remove exit_flag in tp exit

* set log level

---------

Co-authored-by: CyCle1024 <[email protected]>
grimoire and CyCle1024 authored Oct 21, 2024
1 parent a465e60 commit c918669
Showing 1 changed file with 52 additions and 54 deletions.
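
The core of the change is a CPU-side multiprocessing barrier paired with the existing broadcast: rank 0 releases the barrier only when it actually has a batch to run, so the worker ranks idle on the barrier instead of inside dist.broadcast_object_list and the NCCL kernel behind it. Below is a minimal sketch of that pattern, not lmdeploy code; it assumes a CPU-only gloo backend, two ranks, and an arbitrary hard-coded port.

# Sketch: workers park on a cheap mp.Barrier and only enter the collective
# broadcast once rank 0 has work for them.
import os

import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2


def _worker(rank, barrier):
    dist.init_process_group('gloo', rank=rank, world_size=WORLD_SIZE)
    for _ in range(3):
        barrier.wait()                      # wait here, outside the collective
        inputs = [None]
        dist.broadcast_object_list(inputs)  # rank 0 is ready: receive the work
        print(f'rank {rank} got {inputs[0]}')
    dist.destroy_process_group()


if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29555')  # assumed free port
    ctx = mp.get_context('spawn')
    barrier = ctx.Barrier(WORLD_SIZE)
    worker = ctx.Process(target=_worker, args=(1, barrier), daemon=True)
    worker.start()

    dist.init_process_group('gloo', rank=0, world_size=WORLD_SIZE)
    for step in range(3):
        barrier.wait()                                 # release the worker
        dist.broadcast_object_list([f'batch-{step}'])  # then send the work
    dist.destroy_process_group()
    worker.join()

In the real agent the same two calls appear on both sides of the diff below: self.mp_bar.wait() plus _broadcast_inputs in _forward_impl on rank 0, and barrier.wait() plus _broadcast_inputs in _tp_model_loop on the worker ranks.
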
106 changes: 52 additions & 54 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -2,6 +2,7 @@
 import asyncio
 import atexit
 import os
+import threading
 from datetime import timedelta
 from typing import Any, Callable, Dict, List
 
@@ -400,7 +401,7 @@ def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
     """get input tensor parallel."""
     # broadcast meta info
     if rank != 0:
-        inputs = [None, None, None, None]
+        inputs = [None, None, None]
 
     with torch.cuda.stream(stream):
         dist.broadcast_object_list(inputs)
@@ -415,6 +416,7 @@ def _tp_model_loop(
     backend_config: BackendConfig,
     adapters: Dict[str, str],
     world_size: int,
+    barrier: mp.Barrier,
 ):
     """Start model loops for tensor parallel model inference.
 
@@ -438,12 +440,10 @@
         world_size=world_size)
 
     while True:
-        inputs, swap_in_map, swap_out_map, exit_flag = _broadcast_inputs(
+        barrier.wait()
+        inputs, swap_in_map, swap_out_map = _broadcast_inputs(
             rank, None, stream)
 
-        if exit_flag:
-            break
-
         cache_swapping(cache_engine,
                        swap_in_map=swap_in_map,
                        swap_out_map=swap_out_map)
@@ -460,6 +460,7 @@ def _tp_model_loop(
 def _start_tp_process(proc_id: int,
                       world_size: int,
                       func: Callable,
+                      log_level: int,
                       device_context: DeviceContext,
                       args: List = None,
                       kwargs: Dict = None):
@@ -473,6 +474,7 @@ def _start_tp_process(proc_id: int,
         kwargs (Dict): The keyword arguments of the func.
     """
     rank = proc_id + 1
+    logger.setLevel(log_level)
     try:
         from lmdeploy.pytorch.check_env import check_env_deeplink
         check_env_deeplink(device_context.device_type)
@@ -499,14 +501,16 @@ def _check_context_alive(mp_context: mp.ProcessContext):
     """check context alive."""
     procs: List[mp.Process] = mp_context.processes
     failed_ranks = list(idx for idx, p in enumerate(procs) if not p.is_alive())
-    if len(failed_ranks) > 0:
-        for p in procs:
-            if p.is_alive():
-                p.terminate()
-            else:
-                p.close()
-        logger.error(f'TP process Rank{failed_ranks} failed.')
-        exit(1)
+    if len(failed_ranks) == 0:
+        return
+    for p in procs:
+        if p.is_alive():
+            p.terminate()
+        else:
+            p.close()
+    logger.error(f'TP process {failed_ranks} failed.')
+    # TODO: not safe exit.
+    os._exit(1)
 
 
 def _find_available_port() -> bool:
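
A note on the os._exit(1) above: _check_context_alive is now called from a watchdog thread (added further down), and raising SystemExit via exit(1) in a non-main thread only ends that thread, not the process; os._exit terminates the whole process immediately and skips cleanup handlers, which is what the "not safe exit" TODO refers to. A small illustration, not tied to lmdeploy:

import os
import threading
import time


def kill_via_sys_exit():
    raise SystemExit(1)  # ends only this thread; the default excepthook ignores it


def kill_via_os_exit():
    os._exit(1)  # ends the whole process at once, no atexit/cleanup handlers run


if __name__ == '__main__':
    threading.Thread(target=kill_via_sys_exit).start()
    time.sleep(0.2)
    print('process survived SystemExit raised in a thread')
    threading.Thread(target=kill_via_os_exit).start()
    time.sleep(0.2)
    print('never printed: os._exit already terminated the process')
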
@@ -561,32 +565,44 @@ def __signal_term_handler(sig, frame):
         self.world_size = world_size
         self.backend_config = backend_config
 
+        self.mp_bar = self.mp_ctx.Barrier(world_size)
         self._start_sub_process(model_path,
                                 model_config=model_config,
                                 cache_config=cache_config,
                                 backend_config=backend_config,
                                 adapters=adapters,
                                 world_size=world_size,
-                                trust_remote_code=trust_remote_code)
+                                barrier=self.mp_bar)
 
         model, cache_engine, cache_config = self._build_model(
             model_path=model_path,
             model_config=model_config,
             cache_config=cache_config,
             backend_config=backend_config,
             adapters=adapters,
-            world_size=world_size,
-        )
+            world_size=world_size)
         self.patched_model = model
         self.cache_config = cache_config
         self.cache_engine = cache_engine
         self.stream = torch.cuda.Stream()
 
+    def _mp_watchdog(self, mp_context: mp.ProcessContext, timeout: int = 1):
+        """watch dog of mp context.
+
+        Args:
+            mp_context: context of multiprocess.
+            timeout: timeout
+        """
+        import time
+        while True:
+            _check_context_alive(mp_context)
+            time.sleep(timeout)
+
     def _start_sub_process(self, model_path: str, model_config: ModelConfig,
                            cache_config: CacheConfig,
                            backend_config: BackendConfig, adapters: Dict[str,
                                                                          str],
-                           world_size: int, trust_remote_code: bool):
+                           world_size: int, barrier: mp.Barrier):
         """Start tensor parallel sub process."""
         port = _find_available_port()
         os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
@@ -601,19 +617,27 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
             args=(
                 world_size,
                 _tp_model_loop,
+                logger.level,
                 device_context,
                 (model_path, ),
-                dict(model_config=model_config,
-                     cache_config=cache_config,
-                     backend_config=backend_config,
-                     adapters=adapters,
-                     world_size=world_size),
+                dict(
+                    model_config=model_config,
+                    cache_config=cache_config,
+                    backend_config=backend_config,
+                    adapters=adapters,
+                    world_size=world_size,
+                    barrier=barrier,
+                ),
             ),
             nprocs=world_size - 1,
             join=False,
             daemon=True,
         )
-        _check_context_alive(self.mp_context)
+
+        t_watchdog = threading.Thread(target=self._mp_watchdog,
+                                      args=[self.mp_context, 1.0],
+                                      daemon=True)
+        t_watchdog.start()
 
         rank = 0
         try:
@@ -628,8 +652,7 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
             if dist.is_initialized():
                 dist.destroy_process_group()
             raise e
-        # Please see Note [Exit By Sending Exit Flag]
-        atexit.register(_exit_by_sending_exit_flag, rank, self)
+        atexit.register(_exit_handler, self)
 
     @torch.inference_mode()
     def _build_model(
@@ -642,7 +665,6 @@ def _build_model(
         world_size: int,
     ):
         """build model."""
-        _check_context_alive(self.mp_context)
         rank = 0
         model, cache_engine, cache_config = _tp_build_model(
             rank,
@@ -664,10 +686,9 @@ def get_block_numel(self):
     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
                       swap_out_map: SwapMap):
         """forward impl."""
-        _check_context_alive(self.mp_context)
+        self.mp_bar.wait()
         rank = 0
-        exit_flag = False
-        _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map, exit_flag],
+        _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
                           self.stream)
         cache_swapping(self.cache_engine,
                        swap_in_map=swap_in_map,
@@ -717,32 +738,9 @@ def get_logits(self, hidden_states: torch.Tensor):
         return self.patched_model.get_logits(hidden_states)
 
 
-def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent):
-    """[Note] Exit By Sending Exit Flag: the registration to `atexit` of this
-    function should be called after importing torch.multiprocessing and the
-    initialization of distributed process group."""
-    if not hasattr(agent, 'stream'):
-        # agent is not initialized, just exits normally
-        if hasattr(agent, 'patched_model'):
-            del agent.patched_model
-        return
-
-    import sys
-    if 'torch_npu' in sys.modules and 'uvicorn.server' in sys.modules:
-        # Workaround for CLI serve mode with device_type ascend:
-        # using uvicorn server causes ascend low-level backend of subprocesses
-        # corrupted, and using _broadcast_inputs in this case leads to
-        # main process hanging, just exits normally
+def _exit_handler(agent: TPModelAgent):
+    if hasattr(agent, 'patched_model'):
         del agent.patched_model
-        return
-
-    # send exit_flag to all subprocess relying on all subprocess are alive
-    # and wait at _broadcast_inputs
-    exit_flag = True
-    _broadcast_inputs(rank, [None, None, None, exit_flag], agent.stream)
-    agent.stream.synchronize()
-
-    del agent.patched_model
 
 
 def build_model_agent(model_path: str,
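
Taken together, the commit swaps the in-band exit flag for out-of-band supervision: the TP workers are daemon processes, the atexit handler only drops the patched model, and liveness is checked by the _mp_watchdog thread instead of on every forward call. A standalone sketch of that watchdog idea in plain multiprocessing terms (names here are illustrative, not lmdeploy's API):

import os
import threading
import time
from multiprocessing import Process
from typing import List


def _watch_workers(procs: List[Process], interval: float = 1.0):
    """Poll worker liveness; tear the whole group down if any worker died."""
    while True:
        dead = [i for i, p in enumerate(procs) if not p.is_alive()]
        if dead:
            for p in procs:
                if p.is_alive():
                    p.terminate()
            print(f'worker(s) {dead} died, aborting')
            os._exit(1)  # mirrors the "not safe exit" path in the diff above
        time.sleep(interval)


def start_watchdog(procs: List[Process]) -> threading.Thread:
    """Run the poll loop in a daemon thread so it never blocks shutdown."""
    watchdog = threading.Thread(target=_watch_workers, args=(procs, 1.0),
                                daemon=True)
    watchdog.start()
    return watchdog

Because the thread is daemonic, a normal interpreter shutdown does not wait for it; only a worker that actually died triggers the hard exit.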