Optimize tp broadcast #2889

Merged · 62 commits · Dec 17, 2024

Commits (changes from all commits)
464d451
refactor VL modules for internvl and qwen2-vl (#2764)
lvhan028 Nov 18, 2024
3ba8309
Refactor VL modules for glm4v, deepseek-vl, llava-hf, cogvlm (#2772)
lvhan028 Nov 19, 2024
accfc00
Refactor VL modules for qwen-vl, llava and llava_next (#2773)
lvhan028 Nov 20, 2024
463b508
Refactor VL modules for qwen2-vl (#2777)
lvhan028 Nov 20, 2024
9e18529
Fix side-effect to internvl (#2778)
lvhan028 Nov 20, 2024
b2c29be
Refactor VL modules for phi3-vision (#2779)
lvhan028 Nov 20, 2024
aca0f2d
Refactor VL modules for mllama and yi-vl (#2781)
lvhan028 Nov 20, 2024
fff7b91
Refactor VLM module for minicpm and molmo (#2794)
lvhan028 Nov 22, 2024
45cf22d
Refactor VLM modules for xcomposer series (#2796)
lvhan028 Nov 22, 2024
38eec0d
Refactor VLM modules for internvl-llava (#2797)
lvhan028 Nov 22, 2024
c4638d5
Refactor VLM modules v2 (#2806)
lvhan028 Nov 25, 2024
b07211b
Remove vl template (#2809)
lvhan028 Nov 25, 2024
e7b5a7e
Resolve conflicts (#2811)
lvhan028 Nov 25, 2024
70875eb
resolve conflicts
lvhan028 Nov 25, 2024
787f765
Merge pull request #2812 from lvhan028/resolve-conflicts
lvhan028 Nov 25, 2024
099721a
PytorchEngine refactor multimodal (#2742)
grimoire Nov 25, 2024
36a15e3
minor-fix
lvhan028 Nov 25, 2024
4e6760e
minor-fix (#2813)
lvhan028 Nov 25, 2024
f7c167e
fix
grimoire Dec 3, 2024
e977361
fix
lvhan028 Dec 4, 2024
5dc967d
fix mono
grimoire Dec 4, 2024
ae7015a
fix docs
grimoire Dec 4, 2024
c746fd3
read norm_type
grimoire Dec 4, 2024
dc9757c
Merge branch 'main' into refactor-vl
lvhan028 Dec 4, 2024
d577acb
super().collect_images->self.collect_images
lvhan028 Dec 4, 2024
30ea075
Merge branch 'refactor-vl' of github.com:InternLM/lmdeploy into refac…
lvhan028 Dec 4, 2024
bbcf9a5
add note in supported models
lvhan028 Dec 5, 2024
b3a2887
define the parameters clearly
lvhan028 Dec 5, 2024
8f7a56f
better streaming
grimoire Dec 5, 2024
c2a5b44
merge main
grimoire Dec 6, 2024
29c3558
fix molmo
lvhan028 Dec 6, 2024
240f352
Merge branch 'refactor-vl' of github.com:InternLM/lmdeploy into refac…
grimoire Dec 6, 2024
74e7bf8
Fix vision model batch inference (#2868)
lvhan028 Dec 9, 2024
fb3f8cc
warn glm4v does not support multi images
lvhan028 Dec 9, 2024
a65007b
unconst
grimoire Dec 9, 2024
88f99d4
Merge branch 'refactor-vl' of github.com:InternLM/lmdeploy into refac…
grimoire Dec 9, 2024
ed2efb3
fix deepseek-vl
lvhan028 Dec 9, 2024
d8828cf
merge main
lvhan028 Dec 9, 2024
18b38e9
fix internvl
lvhan028 Dec 9, 2024
3cb7e4d
Merge branch 'refactor-vl' into fix-refactor-vl
lvhan028 Dec 9, 2024
92b09d0
fix llava
grimoire Dec 10, 2024
db367f4
fix minicpm 2.6
lvhan028 Dec 10, 2024
4e2f1f8
fix callback
grimoire Dec 10, 2024
715fbb3
fix minicpm v2.5
lvhan028 Dec 10, 2024
1a6d88f
fix minicpm v2.6
lvhan028 Dec 10, 2024
8ee759f
Merge branch 'refactor-vl' into fix-refactor-vl
lvhan028 Dec 10, 2024
f0a4422
update llava_next.py
lvhan028 Dec 10, 2024
ee022ad
remove hardcode from xcomposer2.py
lvhan028 Dec 10, 2024
02a25eb
Merge pull request #2879 from lvhan028/fix-refactor-vl
lvhan028 Dec 10, 2024
a21abe3
rollback supported_models
lvhan028 Dec 10, 2024
6a9342e
change to staticmethod
lvhan028 Dec 10, 2024
798298b
optimize tp
grimoire Dec 12, 2024
1b6ea24
solve conflict
grimoire Dec 12, 2024
a9aacda
Merge branch 'refactor-vl' into optimize-tp-broadcast
grimoire Dec 12, 2024
d005bc8
Merge branch 'refactor-vl' of github.com:InternLM/lmdeploy into refac…
grimoire Dec 12, 2024
e9517e1
Merge branch 'refactor-vl' into optimize-tp-broadcast
grimoire Dec 12, 2024
4107d4f
fix vlm quantization
lvhan028 Dec 12, 2024
fdaa601
Merge branch 'refactor-vl' of github.com:InternLM/lmdeploy into refac…
lvhan028 Dec 12, 2024
e5a5085
update doc
lvhan028 Dec 13, 2024
bc93e73
update
lvhan028 Dec 13, 2024
67188eb
Merge branch 'refactor-vl' into optimize-tp-broadcast
grimoire Dec 13, 2024
6cdc76a
solve conflict
grimoire Dec 13, 2024
3 changes: 1 addition & 2 deletions lmdeploy/pytorch/engine/engine.py
@@ -782,8 +782,7 @@ def __update_inputs(next_token_ids):
logger.debug('<ForwardTask>: '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)}')
if self.gpu_count == 1:
inputs = inputs.to_device('cuda')
inputs = inputs.to_device('cuda')
is_decoding = inputs.is_decoding
if all_ids is not None:
all_ids = all_ids.cuda()
36 changes: 27 additions & 9 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -135,7 +135,6 @@ def model_forward(
stream = stream or torch.cuda.current_stream()
with torch.cuda.stream(stream):
# forward
inputs = inputs.to_device('cuda')
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
@@ -372,14 +371,26 @@ def _broadcast_config(cache_config):
return patched_model, cache_engine, cache_config


def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
def _broadcast_inputs(rank: int, inputs: Any, group: dist.group,
stream: torch.cuda.Stream):
"""get input tensor parallel."""
# broadcast meta info
if rank != 0:
inputs = [None, None, None]
else:
device_inputs = inputs[0]
meta_inputs = device_inputs.to_device('meta')
inputs[0] = meta_inputs

with torch.cuda.stream(stream):
dist.broadcast_object_list(inputs)
dist.broadcast_object_list(inputs, group=group)
if rank == 0:
device_inputs.broadcast()
else:
device_inputs = inputs[0].broadcast()

inputs[0] = device_inputs

return inputs


@@ -392,6 +403,7 @@ def _tp_model_loop(
adapters: Dict[str, str],
world_size: int,
barrier: mp.Barrier,
cpu_group: dist.group,
):
"""Start model loops for tensor parallel model inference.

@@ -417,11 +429,12 @@ def _tp_model_loop(
while True:
barrier.wait()
inputs, swap_in_map, swap_out_map = _broadcast_inputs(
rank, None, stream)
rank, None, cpu_group, stream)

cache_swapping(cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
inputs = inputs.to_device('cuda')

model_forward(
patched_model,
@@ -453,10 +466,13 @@ def _start_tp_process(proc_id: int,
try:
from lmdeploy.pytorch.check_env import check_env_deeplink
check_env_deeplink(device_context.device_type)
timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
timeout=timedelta(days=35600))
timeout=timeout)
cpu_group = dist.new_group(timeout=timeout, backend='gloo')
kwargs['cpu_group'] = cpu_group
dist_ctx = DistContext(rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
with get_dist_manager().context(dist_ctx), get_device_manager(
@@ -626,12 +642,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,

rank = 0
try:
timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
timeout=timedelta(days=35600))
timeout=timeout)
cpu_group = dist.new_group(timeout=timeout, backend='gloo')
dist_ctx = DistContext(rank=rank, world_size=world_size)
self._dist_ctx = dist_ctx
self._cpu_group = cpu_group
except Exception as e:
from traceback import print_exc
logger.error(f'Rank[{rank}] failed.')
@@ -673,7 +692,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
self.mp_bar.wait()
rank = 0
_broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
self.stream)
self._cpu_group, self.stream)

cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
@@ -699,8 +719,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
while not self.stream.query():
await asyncio.sleep(0)
return output

def get_logits(self, hidden_states: torch.Tensor):
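
The change to _broadcast_inputs is the heart of the optimization: rank 0 swaps the device inputs for a 'meta' copy so that dist.broadcast_object_list over the new gloo cpu_group only pickles shapes and dtypes, while the tensor payloads themselves travel over NCCL through the broadcast() methods added below. A minimal standalone sketch of that two-phase pattern follows, assuming dist.init_process_group('nccl') has already run and cpu_group is a gloo group created with dist.new_group, as in _start_tp_process above; the function name and signature are illustrative, not the PR's API.

# Standalone sketch of the two-phase broadcast (illustrative names, not the PR's API).
import torch
import torch.distributed as dist


def two_phase_broadcast(rank: int, cpu_group, tensor: torch.Tensor = None):
    """Broadcast a CUDA tensor from rank 0 without pickling its storage."""
    # Phase 1: ship only metadata (shape/dtype) through the gloo CPU group.
    payload = [tensor.to('meta')] if rank == 0 else [None]
    dist.broadcast_object_list(payload, src=0, group=cpu_group)

    # Phase 2: move the bytes over NCCL; replica ranks first allocate a real
    # CUDA buffer matching the metadata they just received.
    data = tensor if rank == 0 else torch.empty_like(payload[0], device='cuda')
    dist.broadcast(data, src=0)
    return data

Splitting the transfer this way removes the pickling of large vision tensors from broadcast_object_list, which is the tensor-parallel broadcast cost the PR title targets.
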
64 changes: 63 additions & 1 deletion lmdeploy/pytorch/model_inputs.py
@@ -4,12 +4,21 @@
from typing import Any, Dict, List, Literal

import torch
from torch import distributed as dist

from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.config import ModelConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
"""broadcast tensor."""
if value.device.type == 'meta':
value = torch.empty_like(value, device=device)
dist.broadcast(value, src)
return value


@dataclass
class VisionModelInputs:
"""Vision model inputs."""
@@ -36,10 +45,45 @@ def to_device(self, device: str):
elif k == 'input_embeddings':
v = [[e.to(device) for e in li] for li in v]
elif k == 'input_multimodals':
new_v = []
for mm_datas in v:
new_mm_datas = dict()
for modal_type, data in mm_datas.items():
data = [d.to_device(device) for d in data]
mm_datas[modal_type] = data
new_mm_datas[modal_type] = data
new_v.append(new_mm_datas)
v = new_v
out_dict[k] = v

return VisionModelInputs(**out_dict)

def broadcast(self):
"""broadcast inputs.

Do `dist.broadcast_object_list(inputs.to_device('meta'))`
before broadcast tensors.
"""
out_dict = dict()
for f in fields(self):
k = f.name
v = getattr(self, k)
if v is None:
continue
if isinstance(v, torch.Tensor):
v = _broadcast_tensor(v)
elif k == 'input_embedding_ranges':
v = [_broadcast_tensor(e) for e in v]
elif k == 'input_embeddings':
v = [[_broadcast_tensor(e) for e in li] for li in v]
elif k == 'input_multimodals':
new_v = []
for mm_datas in v:
new_mm_datas = dict()
for modal_type, data in mm_datas.items():
data = [d.broadcast() for d in data]
new_mm_datas[modal_type] = data
new_v.append(new_mm_datas)
v = new_v
out_dict[k] = v

return VisionModelInputs(**out_dict)
@@ -202,6 +246,24 @@ def to_device(self, device: str):

return ModelInputs(**out_dict)

def broadcast(self):
"""broadcast inputs.

Do `dist.broadcast_object_list(inputs.to_device('meta'))`
before broadcast tensors.
"""
out_dict = dict()
for f in fields(self):
k = f.name
v = getattr(self, k)
if isinstance(v, torch.Tensor):
v = _broadcast_tensor(v)
elif isinstance(v, VisionModelInputs):
v = v.broadcast()
out_dict[k] = v

return ModelInputs(**out_dict)


@dataclass
class StepContext:
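
Both _broadcast_tensor helpers rely on the same 'meta' device trick: a meta tensor carries shape and dtype but no storage, so it pickles cheaply inside broadcast_object_list, and torch.empty_like(..., device='cuda') rebuilds a correctly sized receive buffer on replica ranks before dist.broadcast fills it in place. A small self-contained demonstration follows; it is not code from the PR and uses 'cpu' so it runs without a GPU.

# Demonstration of the 'meta' device trick used by _broadcast_tensor.
import torch

x = torch.randn(4, 3, 336, 336)          # stands in for a batch of vision inputs
m = x.to('meta')                         # keeps shape/dtype, drops the storage
print(m.shape, m.dtype, m.device)        # torch.Size([4, 3, 336, 336]) torch.float32 meta
buf = torch.empty_like(m, device='cpu')  # receiver-side re-materialization
# On a replica rank this buffer is allocated on 'cuda' instead and then filled
# in place by dist.broadcast(buf, src=0).
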
63 changes: 58 additions & 5 deletions lmdeploy/pytorch/multimodal/data_type.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Union

import torch
from torch import Tensor
from torch import distributed as dist


class MultiModalData:
@@ -14,6 +16,14 @@ class MultiModalData:
NestedTensor = Union[Tensor, List[Tensor]]


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
"""broadcast tensor."""
if value.device.type == 'meta':
value = torch.empty_like(value, device=device)
dist.broadcast(value, src)
return value


@dataclass
class MultiModalTensor:
data: NestedTensor
@@ -28,24 +38,67 @@ def __post_init__(self):

def to_device(self, device: str, non_blocking: bool = False):
"""to device."""
out_dict = dict()
for f in fields(self):
k = f.name
if k in ('data', 'meta'):
continue
v = getattr(self, k)
out_dict[k] = v

if isinstance(self.data, Tensor):
self.data = self.data.to(device=device, non_blocking=non_blocking)
data = self.data.to(device=device, non_blocking=non_blocking)
else:
data = [
d.to(device=device, non_blocking=non_blocking)
for d in self.data
]
self.data = data
out_dict['data'] = data

new_meta = None
if self.meta is not None:
new_meta = dict()
for k, v in self.meta.items():
if isinstance(v, Tensor):
v = v.to(device=device, non_blocking=non_blocking)
self.meta[k] = v
elif hasattr(v, 'to_device'):
v = v.to_device(device=device, non_blocking=non_blocking)
new_meta[k] = v

out_dict['meta'] = new_meta
return MultiModalTensor(**out_dict)

def broadcast(self):
"""broadcast inputs tensors."""
out_dict = dict()
for f in fields(self):
k = f.name
if k in ('data', 'meta'):
continue
v = getattr(self, k)
out_dict[k] = v

if isinstance(self.data, Tensor):
data = _broadcast_tensor(self.data)
else:
data = [_broadcast_tensor(d) for d in self.data]
out_dict['data'] = data

new_meta = None
if self.meta is not None:
new_meta = dict()
for k, v in self.meta.items():
if isinstance(v, Tensor):
v = _broadcast_tensor(v)
self.meta[k] = v
elif hasattr(v, 'to_device'):
assert hasattr(v, 'broadcast')
v = v.broadcast()
self.meta[k] = v
return self
new_meta[k] = v

out_dict['meta'] = new_meta
return MultiModalTensor(**out_dict)


MultiModalInputs = Dict[str, List[MultiModalTensor]]
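
Since MultiModalInputs is a Dict[str, List[MultiModalTensor]], MultiModalTensor.broadcast() gives the engine a uniform way to walk arbitrarily nested vision inputs and broadcast every tensor inside, including those stored under meta. Below is a generic sketch of that traversal, with plain tensors standing in for MultiModalTensor and an illustrative function name; it assumes the default NCCL process group is already initialized.

# Generic sketch: broadcast every tensor in a dict-of-lists container,
# mirroring how MultiModalTensor.broadcast walks its nested data.
from typing import Dict, List

import torch
from torch import distributed as dist


def broadcast_nested(values: Dict[str, List[torch.Tensor]],
                     src: int = 0) -> Dict[str, List[torch.Tensor]]:
    """Broadcast every tensor in a dict-of-lists, materializing meta tensors
    on replica ranks before receiving."""
    out: Dict[str, List[torch.Tensor]] = dict()
    for key, tensors in values.items():
        new_list = []
        for t in tensors:
            if t.device.type == 'meta':
                # replica rank: allocate a real buffer before receiving
                t = torch.empty_like(t, device='cuda')
            dist.broadcast(t, src)
            new_list.append(t)
        out[key] = new_list
    return out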