Skip to content

Commit

Permalink
Optimize tp broadcast (#2889)
Browse files Browse the repository at this point in the history
* refactor VL modules for internvl and qwen2-vl (#2764)

* qwen2-vl

* internvl

* qwen2

* Refactor VL modules for glm4v, deepseek-vl, llava-hf, cogvlm (#2772)

* qwen2-vl

* internvl

* qwen2

* get image_tokens_per_patch for internvl2

* deepseek-vl

* cogvlm

* glm4v

* update internvl

* internvl_llava

* llava

* glm4v

* upate internvl

* cogvlm

* deepseek

* llava_hf

* rollback llava, internvl-llava

* Refactor VL modules for qwen-vl, llava and llava_next (#2773)

* qwen2-vl

* internvl

* qwen2

* get image_tokens_per_patch for internvl2

* deepseek-vl

* cogvlm

* glm4v

* update internvl

* internvl_llava

* llava

* glm4v

* upate internvl

* cogvlm

* deepseek

* llava_hf

* rollback llava, internvl-llava

* refactor qwen

* update internvl

* update llava_hf

* update qwen2-vl

* llava_next

* update llava_next

* update llava

* update llava

* update llava

* Refactor VL modules for qwen2-vl (#2777)

* qwen2-vl

* internvl

* qwen2

* get image_tokens_per_patch for internvl2

* deepseek-vl

* cogvlm

* glm4v

* update internvl

* internvl_llava

* llava

* glm4v

* upate internvl

* cogvlm

* deepseek

* llava_hf

* rollback llava, internvl-llava

* refactor qwen

* update internvl

* update llava_hf

* update qwen2-vl

* llava_next

* update llava_next

* update llava

* update llava

* update llava

* qwen2

* Fix side-effect to internvl (#2778)

* qwen2-vl

* internvl

* qwen2

* get image_tokens_per_patch for internvl2

* deepseek-vl

* cogvlm

* glm4v

* update internvl

* internvl_llava

* llava

* glm4v

* upate internvl

* cogvlm

* deepseek

* llava_hf

* rollback llava, internvl-llava

* refactor qwen

* update internvl

* update llava_hf

* update qwen2-vl

* llava_next

* update llava_next

* update llava

* update llava

* update llava

* qwen2

* fix internvl

* Refactor VL modules for phi3-vision (#2779)

* qwen2-vl

* internvl

* qwen2

* get image_tokens_per_patch for internvl2

* deepseek-vl

* cogvlm

* glm4v

* update internvl

* internvl_llava

* llava

* glm4v

* upate internvl

* cogvlm

* deepseek

* llava_hf

* rollback llava, internvl-llava

* refactor qwen

* update internvl

* update llava_hf

* update qwen2-vl

* llava_next

* update llava_next

* update llava

* update llava

* update llava

* qwen2

* fix internvl

* phi3-vision

* Refactor VL modules for mllama and yi-vl (#2781)

* qwen2-vl

* internvl

* qwen2

* get image_tokens_per_patch for internvl2

* deepseek-vl

* cogvlm

* glm4v

* update internvl

* internvl_llava

* llava

* glm4v

* upate internvl

* cogvlm

* deepseek

* llava_hf

* rollback llava, internvl-llava

* refactor qwen

* update internvl

* update llava_hf

* update qwen2-vl

* llava_next

* update llava_next

* update llava

* update llava

* update llava

* qwen2

* fix internvl

* phi3-vision

* refactor yi-vl

* refactor mllama

* Refactor VLM module for minicpm and molmo (#2794)

* Refactor VLM modules for xcomposer series (#2796)

* Refactor VLM modules for internvl-llava (#2797)

* Refactor VLM modules v2 (#2806)

* internvl2 v2

* cogvlm

* deepseek-vl

* glm-4v

* llava-hf

* llava-next

* llava

* internvl-llava

* mllama

* phi3-vision

* qwen

* qwen2

* yi-vl

* xcomposer

* minicpm

* molmo

* update

* update

* Remove vl template (#2809)

* Resolve conflicts (#2811)

* feature: support qwen2.5 fuction_call (#2737)

* feat: support qwen2.5 tools_call

* fix: npe bug

* fix: 模版不一致

* fix: adopting review suggestions

* fix: adopting review suggestions

* fix: adopting review suggestions

* fix: adopting review suggestions

* feat: Support multi tools calling

* feat: Support multi tools calling

* fix: Add '\n' between each tool

* fix: Add ensure_ascii=False

* bugfix: rfind

* bugfix: tools_call -> tool_calls

* bugfix: add toolName in tool_response

* fix: some '\n' error

* fix: remove toolname

* fix: replace '\n' to self.separator

* feat: add doc with multiple tool calling

* fix:update doc

* feat: add qwen2.5 prompt template test

* feat: add qwen2.5 no tool call prompt test

---------

Co-authored-by: gaozixiang <[email protected]>

* Update supported models & Ascend doc (#2765)

* update ascend supported model list

* fix markdown

* fix markdown

* fix lint

* Update get_started.md

* Update get_started.md

* [CI] Split vl testcases into turbomind and pytorch backend (#2751)

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* [Feature] support minicpm-v_2_6 for pytorch engine. (#2767)

* support minicpmv_2_6.

* update supported_models.

* update supported_models.

* Support qwen2-vl AWQ quantization (#2787)

* Support qwen2-vl AWQ quantization

* Update config.yaml


* [dlinfer] Fix qwenvl rope error for dlinfer backend (#2795)

* Optimize update_step_ctx on Ascend (#2804)

* opt update_ctx for ascend

* fix lint


* PytorchEngine refactor multimodal (#2742)

* WIP

* support mrope

* support long context

* support causal=false

* fix mask

* flash attn bound

* optimize

* Moskau, Moskau, wirf die Gläser an die Wand

* YMCA

* optimize mllama

* update processor

* support cogvlm

* all work and no play make jack a dull boy

* upgrade triton

* support qwen2vl

* support internvl

* phi3-v WIP

* glm4v WIP

* support chatglm and cogvlm

* use image tokens

* support llava

* support internvl-mono

* phi3v, mllama

* add llavanext

* use img token ids

* support multiimage chatglm cogvlm

* fix ut

* minor-fix

* minor-fix (#2813)

* fix

* fix mono

* fix docs

* read norm_type

* super().collect_images->self.collect_images

* add note in supported models

* define the parameters clearly

* better streaming

* fix molmo

* Fix vision model batch inference (#2868)

* remove forward from vl models that are not supported by tm

* support max_batch_size

* fix

* warn glm4v does not support multi images

* unconst

* fix deepseek-vl

* fix internvl

* fix llava

* fix minicpm 2.6

* fix callback

* fix minicpm v2.5

* fix minicpm v2.6

* update llava_next.py

* remove hardcode from xcomposer2.py

* rollback supported_models

* change to staticmethod

* optimize tp

* fix vlm quantization

* update doc

* update
  • Loading branch information
grimoire authored Dec 17, 2024
1 parent 1efed79 commit 8afb84c
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 17 deletions.
3 changes: 1 addition & 2 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,8 +782,7 @@ def __update_inputs(next_token_ids):
logger.debug('<ForwardTask>: '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)}')
if self.gpu_count == 1:
inputs = inputs.to_device('cuda')
inputs = inputs.to_device('cuda')
is_decoding = inputs.is_decoding
if all_ids is not None:
all_ids = all_ids.cuda()
Expand Down
36 changes: 27 additions & 9 deletions lmdeploy/pytorch/engine/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def model_forward(
stream = stream or torch.cuda.current_stream()
with torch.cuda.stream(stream):
# forward
inputs = inputs.to_device('cuda')
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
Expand Down Expand Up @@ -372,14 +371,26 @@ def _broadcast_config(cache_config):
return patched_model, cache_engine, cache_config


def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
def _broadcast_inputs(rank: int, inputs: Any, group: dist.group,
stream: torch.cuda.Stream):
"""get input tensor parallel."""
# broadcast meta info
if rank != 0:
inputs = [None, None, None]
else:
device_inputs = inputs[0]
meta_inputs = device_inputs.to_device('meta')
inputs[0] = meta_inputs

with torch.cuda.stream(stream):
dist.broadcast_object_list(inputs)
dist.broadcast_object_list(inputs, group=group)
if rank == 0:
device_inputs.broadcast()
else:
device_inputs = inputs[0].broadcast()

inputs[0] = device_inputs

return inputs


Expand All @@ -392,6 +403,7 @@ def _tp_model_loop(
adapters: Dict[str, str],
world_size: int,
barrier: mp.Barrier,
cpu_group: dist.group,
):
"""Start model loops for tensor parallel model inference.
Expand All @@ -417,11 +429,12 @@ def _tp_model_loop(
while True:
barrier.wait()
inputs, swap_in_map, swap_out_map = _broadcast_inputs(
rank, None, stream)
rank, None, cpu_group, stream)

cache_swapping(cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
inputs = inputs.to_device('cuda')

model_forward(
patched_model,
Expand Down Expand Up @@ -453,10 +466,13 @@ def _start_tp_process(proc_id: int,
try:
from lmdeploy.pytorch.check_env import check_env_deeplink
check_env_deeplink(device_context.device_type)
timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
timeout=timedelta(days=35600))
timeout=timeout)
cpu_group = dist.new_group(timeout=timeout, backend='gloo')
kwargs['cpu_group'] = cpu_group
dist_ctx = DistContext(rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
with get_dist_manager().context(dist_ctx), get_device_manager(
Expand Down Expand Up @@ -626,12 +642,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,

rank = 0
try:
timeout = timedelta(days=35600)
dist.init_process_group('nccl',
rank=rank,
world_size=world_size,
timeout=timedelta(days=35600))
timeout=timeout)
cpu_group = dist.new_group(timeout=timeout, backend='gloo')
dist_ctx = DistContext(rank=rank, world_size=world_size)
self._dist_ctx = dist_ctx
self._cpu_group = cpu_group
except Exception as e:
from traceback import print_exc
logger.error(f'Rank[{rank}] failed.')
Expand Down Expand Up @@ -673,7 +692,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
self.mp_bar.wait()
rank = 0
_broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
self.stream)
self._cpu_group, self.stream)

cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
Expand All @@ -699,8 +719,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
while not self.stream.query():
await asyncio.sleep(0)
return output

def get_logits(self, hidden_states: torch.Tensor):
Expand Down
64 changes: 63 additions & 1 deletion lmdeploy/pytorch/model_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@
from typing import Any, Dict, List, Literal

import torch
from torch import distributed as dist

from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.config import ModelConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
"""broadcast tensor."""
if value.device.type == 'meta':
value = torch.empty_like(value, device=device)
dist.broadcast(value, src)
return value


@dataclass
class VisionModelInputs:
"""Vision model inputs."""
Expand All @@ -36,10 +45,45 @@ def to_device(self, device: str):
elif k == 'input_embeddings':
v = [[e.to(device) for e in li] for li in v]
elif k == 'input_multimodals':
new_v = []
for mm_datas in v:
new_mm_datas = dict()
for modal_type, data in mm_datas.items():
data = [d.to_device(device) for d in data]
mm_datas[modal_type] = data
new_mm_datas[modal_type] = data
new_v.append(new_mm_datas)
v = new_v
out_dict[k] = v

return VisionModelInputs(**out_dict)

def broadcast(self):
"""broadcast inputs.
Do `dist.broadcast_object_list(inputs.to_device('meta'))`
before broadcast tensors.
"""
out_dict = dict()
for f in fields(self):
k = f.name
v = getattr(self, k)
if v is None:
continue
if isinstance(v, torch.Tensor):
v = _broadcast_tensor(v)
elif k == 'input_embedding_ranges':
v = [_broadcast_tensor(e) for e in v]
elif k == 'input_embeddings':
v = [[_broadcast_tensor(e) for e in li] for li in v]
elif k == 'input_multimodals':
new_v = []
for mm_datas in v:
new_mm_datas = dict()
for modal_type, data in mm_datas.items():
data = [d.broadcast() for d in data]
new_mm_datas[modal_type] = data
new_v.append(new_mm_datas)
v = new_v
out_dict[k] = v

return VisionModelInputs(**out_dict)
Expand Down Expand Up @@ -202,6 +246,24 @@ def to_device(self, device: str):

return ModelInputs(**out_dict)

def broadcast(self):
"""broadcast inputs.
Do `dist.broadcast_object_list(inputs.to_device('meta'))`
before broadcast tensors.
"""
out_dict = dict()
for f in fields(self):
k = f.name
v = getattr(self, k)
if isinstance(v, torch.Tensor):
v = _broadcast_tensor(v)
elif isinstance(v, VisionModelInputs):
v = v.broadcast()
out_dict[k] = v

return ModelInputs(**out_dict)


@dataclass
class StepContext:
Expand Down
63 changes: 58 additions & 5 deletions lmdeploy/pytorch/multimodal/data_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Union

import torch
from torch import Tensor
from torch import distributed as dist


class MultiModalData:
Expand All @@ -14,6 +16,14 @@ class MultiModalData:
NestedTensor = Union[Tensor, List[Tensor]]


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
"""broadcast tensor."""
if value.device.type == 'meta':
value = torch.empty_like(value, device=device)
dist.broadcast(value, src)
return value


@dataclass
class MultiModalTensor:
data: NestedTensor
Expand All @@ -28,24 +38,67 @@ def __post_init__(self):

def to_device(self, device: str, non_blocking: bool = False):
"""to device."""
out_dict = dict()
for f in fields(self):
k = f.name
if k in ('data', 'meta'):
continue
v = getattr(self, k)
out_dict[k] = v

if isinstance(self.data, Tensor):
self.data = self.data.to(device=device, non_blocking=non_blocking)
data = self.data.to(device=device, non_blocking=non_blocking)
else:
data = [
d.to(device=device, non_blocking=non_blocking)
for d in self.data
]
self.data = data
out_dict['data'] = data

new_meta = None
if self.meta is not None:
new_meta = dict()
for k, v in self.meta.items():
if isinstance(v, Tensor):
v = v.to(device=device, non_blocking=non_blocking)
self.meta[k] = v
elif hasattr(v, 'to_device'):
v = v.to_device(device=device, non_blocking=non_blocking)
new_meta[k] = v

out_dict['meta'] = new_meta
return MultiModalTensor(**out_dict)

def broadcast(self):
"""broadcast inputs tensors."""
out_dict = dict()
for f in fields(self):
k = f.name
if k in ('data', 'meta'):
continue
v = getattr(self, k)
out_dict[k] = v

if isinstance(self.data, Tensor):
data = _broadcast_tensor(self.data)
else:
data = [_broadcast_tensor(d) for d in self.data]
out_dict['data'] = data

new_meta = None
if self.meta is not None:
new_meta = dict()
for k, v in self.meta.items():
if isinstance(v, Tensor):
v = _broadcast_tensor(v)
self.meta[k] = v
elif hasattr(v, 'to_device'):
assert hasattr(v, 'broadcast')
v = v.broadcast()
self.meta[k] = v
return self
new_meta[k] = v

out_dict['meta'] = new_meta
return MultiModalTensor(**out_dict)


MultiModalInputs = Dict[str, List[MultiModalTensor]]

0 comments on commit 8afb84c

Please sign in to comment.