Skip to content

Commit

Permalink
Merge branch 'main' into remove-threadsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Dec 18, 2024
2 parents a1caab8 + 8afb84c commit 2de5db2
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 17 deletions.
3 changes: 1 addition & 2 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,7 @@ def __update_inputs(next_token_ids):
logger.debug('<ForwardTask>: '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)}')
if self.gpu_count == 1:
inputs = inputs.to_device('cuda')
inputs = inputs.to_device('cuda')
is_decoding = inputs.is_decoding
if all_ids is not None:
all_ids = all_ids.cuda()
Expand Down
36 changes: 27 additions & 9 deletions lmdeploy/pytorch/engine/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def model_forward(
stream = stream or torch.cuda.current_stream()
with torch.cuda.stream(stream):
# forward
inputs = inputs.to_device('cuda')
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
Expand Down Expand Up @@ -372,14 +371,26 @@ def _broadcast_config(cache_config):
return patched_model, cache_engine, cache_config


def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
def _broadcast_inputs(rank: int, inputs: Any, group: dist.group,
stream: torch.cuda.Stream):
"""get input tensor parallel."""
# broadcast meta info
if rank != 0:
inputs = [None, None, None]
else:
device_inputs = inputs[0]
meta_inputs = device_inputs.to_device('meta')
inputs[0] = meta_inputs

with torch.cuda.stream(stream):
dist.broadcast_object_list(inputs)
dist.broadcast_object_list(inputs, group=group)
if rank == 0:
device_inputs.broadcast()
else:
device_inputs = inputs[0].broadcast()

inputs[0] = device_inputs

return inputs


Expand All @@ -392,6 +403,7 @@ def _tp_model_loop(
adapters: Dict[str, str],
world_size: int,
barrier: mp.Barrier,
cpu_group: dist.group,
):
"""Start model loops for tensor parallel model inference.
Expand All @@ -417,11 +429,12 @@ def _tp_model_loop(
while True:
barrier.wait()
inputs, swap_in_map, swap_out_map = _broadcast_inputs(
rank, None, stream)
rank, None, cpu_group, stream)

cache_swapping(cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
inputs = inputs.to_device('cuda')

model_forward(
patched_model,
Expand Down Expand Up @@ -453,10 +466,13 @@ def _start_tp_process(proc_id: int,
try:
from lmdeploy.pytorch.check_env import check_env_deeplink
check_env_deeplink(device_context.device_type)
timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
timeout=timedelta(days=35600))
timeout=timeout)
cpu_group = dist.new_group(timeout=timeout, backend='gloo')
kwargs['cpu_group'] = cpu_group
dist_ctx = DistContext(rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
with get_dist_manager().context(dist_ctx), get_device_manager(
Expand Down Expand Up @@ -626,12 +642,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,

rank = 0
try:
timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
timeout=timedelta(days=35600))
timeout=timeout)
cpu_group = dist.new_group(timeout=timeout, backend='gloo')
dist_ctx = DistContext(rank=rank, world_size=world_size)
self._dist_ctx = dist_ctx
self._cpu_group = cpu_group
except Exception as e:
from traceback import print_exc
logger.error(f'Rank[{rank}] failed.')
Expand Down Expand Up @@ -673,7 +692,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
self.mp_bar.wait()
rank = 0
_broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
self.stream)
self._cpu_group, self.stream)

cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
Expand All @@ -699,8 +719,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
while not self.stream.query():
await asyncio.sleep(0)
return output

def get_logits(self, hidden_states: torch.Tensor):
Expand Down
64 changes: 63 additions & 1 deletion lmdeploy/pytorch/model_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@
from typing import Any, Dict, List, Literal

import torch
from torch import distributed as dist

from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.config import ModelConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
"""broadcast tensor."""
if value.device.type == 'meta':
value = torch.empty_like(value, device=device)
dist.broadcast(value, src)
return value


@dataclass
class VisionModelInputs:
"""Vision model inputs."""
Expand All @@ -36,10 +45,45 @@ def to_device(self, device: str):
elif k == 'input_embeddings':
v = [[e.to(device) for e in li] for li in v]
elif k == 'input_multimodals':
new_v = []
for mm_datas in v:
new_mm_datas = dict()
for modal_type, data in mm_datas.items():
data = [d.to_device(device) for d in data]
mm_datas[modal_type] = data
new_mm_datas[modal_type] = data
new_v.append(new_mm_datas)
v = new_v
out_dict[k] = v

return VisionModelInputs(**out_dict)

def broadcast(self):
"""broadcast inputs.
Do `dist.broadcast_object_list(inputs.to_device('meta'))`
before broadcast tensors.
"""
out_dict = dict()
for f in fields(self):
k = f.name
v = getattr(self, k)
if v is None:
continue
if isinstance(v, torch.Tensor):
v = _broadcast_tensor(v)
elif k == 'input_embedding_ranges':
v = [_broadcast_tensor(e) for e in v]
elif k == 'input_embeddings':
v = [[_broadcast_tensor(e) for e in li] for li in v]
elif k == 'input_multimodals':
new_v = []
for mm_datas in v:
new_mm_datas = dict()
for modal_type, data in mm_datas.items():
data = [d.broadcast() for d in data]
new_mm_datas[modal_type] = data
new_v.append(new_mm_datas)
v = new_v
out_dict[k] = v

return VisionModelInputs(**out_dict)
Expand Down Expand Up @@ -202,6 +246,24 @@ def to_device(self, device: str):

return ModelInputs(**out_dict)

def broadcast(self):
"""broadcast inputs.
Do `dist.broadcast_object_list(inputs.to_device('meta'))`
before broadcast tensors.
"""
out_dict = dict()
for f in fields(self):
k = f.name
v = getattr(self, k)
if isinstance(v, torch.Tensor):
v = _broadcast_tensor(v)
elif isinstance(v, VisionModelInputs):
v = v.broadcast()
out_dict[k] = v

return ModelInputs(**out_dict)


@dataclass
class StepContext:
Expand Down
63 changes: 58 additions & 5 deletions lmdeploy/pytorch/multimodal/data_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Union

import torch
from torch import Tensor
from torch import distributed as dist


class MultiModalData:
Expand All @@ -14,6 +16,14 @@ class MultiModalData:
NestedTensor = Union[Tensor, List[Tensor]]


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
"""broadcast tensor."""
if value.device.type == 'meta':
value = torch.empty_like(value, device=device)
dist.broadcast(value, src)
return value


@dataclass
class MultiModalTensor:
data: NestedTensor
Expand All @@ -28,24 +38,67 @@ def __post_init__(self):

def to_device(self, device: str, non_blocking: bool = False):
"""to device."""
out_dict = dict()
for f in fields(self):
k = f.name
if k in ('data', 'meta'):
continue
v = getattr(self, k)
out_dict[k] = v

if isinstance(self.data, Tensor):
self.data = self.data.to(device=device, non_blocking=non_blocking)
data = self.data.to(device=device, non_blocking=non_blocking)
else:
data = [
d.to(device=device, non_blocking=non_blocking)
for d in self.data
]
self.data = data
out_dict['data'] = data

new_meta = None
if self.meta is not None:
new_meta = dict()
for k, v in self.meta.items():
if isinstance(v, Tensor):
v = v.to(device=device, non_blocking=non_blocking)
self.meta[k] = v
elif hasattr(v, 'to_device'):
v = v.to_device(device=device, non_blocking=non_blocking)
new_meta[k] = v

out_dict['meta'] = new_meta
return MultiModalTensor(**out_dict)

def broadcast(self):
"""broadcast inputs tensors."""
out_dict = dict()
for f in fields(self):
k = f.name
if k in ('data', 'meta'):
continue
v = getattr(self, k)
out_dict[k] = v

if isinstance(self.data, Tensor):
data = _broadcast_tensor(self.data)
else:
data = [_broadcast_tensor(d) for d in self.data]
out_dict['data'] = data

new_meta = None
if self.meta is not None:
new_meta = dict()
for k, v in self.meta.items():
if isinstance(v, Tensor):
v = _broadcast_tensor(v)
self.meta[k] = v
elif hasattr(v, 'to_device'):
assert hasattr(v, 'broadcast')
v = v.broadcast()
self.meta[k] = v
return self
new_meta[k] = v

out_dict['meta'] = new_meta
return MultiModalTensor(**out_dict)


MultiModalInputs = Dict[str, List[MultiModalTensor]]

0 comments on commit 2de5db2

Please sign in to comment.