Torch dp support (WIP) #3207

Draft · wants to merge 89 commits into base: main

Changes from all commits · 89 commits (all by grimoire):
0d006e2 better dist context · Jan 15, 2025
e5baae6 can not exit · Jan 16, 2025
e561e6d Merge branch 'main' into torch-multinode · Jan 16, 2025
8a61faf multinode support · Jan 17, 2025
ae7a742 better exception · Jan 17, 2025
2f75ee9 Merge branch 'main' into torch-multinode · Jan 21, 2025
709d293 merge main · Jan 21, 2025
cefbf98 refactor · Jan 22, 2025
ada2cc3 fix local rank · Jan 23, 2025
713aad5 replace group · Jan 24, 2025
2ec93c8 merge main · Feb 6, 2025
374b52a fix dist · Feb 6, 2025
727ea86 remove useless code · Feb 6, 2025
b454c39 remove finish flag · Feb 7, 2025
3baab2d Merge branch 'main' into torch-multinode · Feb 10, 2025
742af3b refactor engine and model agent · Feb 10, 2025
420ab0f uni executor · Feb 10, 2025
e2b82a7 wip · Feb 11, 2025
40251d5 tp · Feb 11, 2025
300263e fix · Feb 11, 2025
ec10731 less async · Feb 12, 2025
2ee3ca8 Merge branch 'main' into torch-multinode-v2 · Feb 12, 2025
ece1313 circle buf · Feb 12, 2025
911d9ab event per block · Feb 12, 2025
3b6fa54 fast mp · Feb 12, 2025
a560cd9 fix error handler · Feb 12, 2025
6071e55 remove safe wait · Feb 12, 2025
13d1187 context in model agent · Feb 12, 2025
894edb8 fix on stop · Feb 13, 2025
1b8d35e check before init · Feb 13, 2025
6db53ab support close · Feb 13, 2025
c3bf202 fix tp close · Feb 14, 2025
7ab7ff3 ray ver0 · Feb 14, 2025
8bfd5cb fix close · Feb 15, 2025
5596b53 fix remote code · Feb 16, 2025
343eb78 optimize ray · Feb 16, 2025
f3444ce better checker and logger · Feb 16, 2025
a35dc75 pack tensor · Feb 16, 2025
f47996a auto check dist · Feb 17, 2025
da27d28 fix mp gloo · Feb 17, 2025
fc02b67 Merge branch 'main' into torch-multinode-v2 · Feb 17, 2025
5366f5f add timer tools · Feb 17, 2025
230228d better scheduler · Feb 18, 2025
14bb0ae fix mp hang · Feb 18, 2025
f66510a fix mp · Feb 18, 2025
8ee9e86 fix chat · Feb 18, 2025
f293e64 less output · Feb 19, 2025
f41bc0b Merge branch 'main' into torch-multinode-v2 · Feb 19, 2025
b74e649 merge main · Feb 19, 2025
e8d1606 optimize ray get output · Feb 19, 2025
4ce9343 remove nsight runtime env · Feb 19, 2025
fbdffd2 dag · Feb 19, 2025
5871c27 optimize mp & lint · Feb 20, 2025
5c46f97 merge main · Feb 20, 2025
adfa3f3 optimize mp · Feb 20, 2025
0f63336 add base workerwrapper · Feb 21, 2025
ff64f11 fix gather, update flags · Feb 21, 2025
4ba2561 better return mask · Feb 21, 2025
026ebe4 add choice · Feb 21, 2025
a4639e9 enable mp,ray with worldsize=1 · Feb 21, 2025
c76dcfc fix mp exit · Feb 21, 2025
c5661ed fix mp vlm · Feb 21, 2025
af83ff2 chat exit · Feb 21, 2025
f1a8a08 add docs · Feb 21, 2025
730c9e9 lint · Feb 24, 2025
ef2811b doc · Feb 24, 2025
82f0f21 dp check · Feb 24, 2025
d28d690 fix blocked fp8 moe · Feb 24, 2025
95b2249 remove mask · Feb 24, 2025
efcb5df support dp, async · Feb 25, 2025
d34145d remove debug line · Feb 25, 2025
8960d29 fix model tp · Feb 25, 2025
411d12e Merge branch 'main' into torch-multinode-v2 · Feb 27, 2025
3481f7a support sync execute · Feb 27, 2025
8f8e708 Merge branch 'torch-multinode-v2' into torch-dp-support · Feb 27, 2025
3647cae fix chat stopwords · Feb 28, 2025
e371647 refactor chat · Feb 28, 2025
fae62ad merge main · Mar 3, 2025
ade7bce add warmup · Mar 3, 2025
8d65226 Merge branch 'torch-multinode-v2' into torch-dp-support · Mar 3, 2025
49e7f4c merge main · Mar 4, 2025
e4339e7 disable warmup · Mar 4, 2025
f7515e4 dp support · Mar 6, 2025
1f10133 Merge branch 'main' into torch-dp-support · Mar 6, 2025
32cd198 fix ut, merge main, force eager · Mar 6, 2025
d4a42ee support qwen2/internlm2/internlm3 · Mar 6, 2025
db8b965 merge main · Mar 9, 2025
4ab0ce2 support blocked fp8 all gather · Mar 9, 2025
c0cd790 add more model support · Mar 9, 2025

8 changes: 8 additions & 0 deletions lmdeploy/messages.py
@@ -248,6 +248,8 @@ class PytorchEngineConfig:
The `auto` option will use FP16 precision for FP32 and FP16
models, and BF16 precision for BF16 models.
tp (int): Tensor Parallelism. default 1.
dp (int): Data Parallelism. default 1.
dp_rank (int): rank of the dp group. default 0.
session_len (int): Max session length. Default None.
max_batch_size (int): Max batch size. If it is not specified,
the engine will automatically set it according to the device
@@ -280,9 +282,13 @@ class PytorchEngineConfig:
bit, set it to 4 or 8, respectively
distributed_executor_backend (str): backend of the distributed executor,
options: ['uni', 'mp', 'ray']
should_execute_dummy_batch (bool): execute a dummy batch when this dp
rank has no request.
"""
dtype: str = 'auto'
tp: int = 1
dp: int = 1
dp_rank: int = 0
session_len: int = None
max_batch_size: int = None
cache_max_entry_count: float = 0.8
@@ -301,11 +307,13 @@ class PytorchEngineConfig:
revision: str = None
quant_policy: Literal[0, 4, 8] = 0
distributed_executor_backend: str = None
should_execute_dummy_batch: bool = False

def __post_init__(self):
"""Check input validation."""
assert self.dtype in ['auto', 'float16', 'bfloat16']
assert self.tp >= 1, 'invalid tp'
assert self.dp >= 1, 'invalid dp'
assert 0 < self.cache_max_entry_count < 1, \
'invalid cache_max_entry_count'
assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks'
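For context, a minimal sketch of how the new fields compose (the model name and the one-process-per-rank launch are assumptions for illustration; the diff itself only adds the config fields):

```python
# Hypothetical usage: one engine process per dp rank, each with its own
# dp_rank; tensor parallelism is applied within each dp rank.
from lmdeploy import pipeline, PytorchEngineConfig

config = PytorchEngineConfig(
    tp=2,       # tensor parallelism inside this dp rank
    dp=2,       # total number of data-parallel ranks
    dp_rank=0,  # this process serves dp rank 0; a second process would use 1
)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=config)
```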
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional

import torch

@@ -18,7 +18,9 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
raise NotImplementedError

4 changes: 2 additions & 2 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -5,7 +5,7 @@

import torch

from lmdeploy.pytorch.distributed import get_world_rank
from lmdeploy.pytorch.distributed import get_tp_world_rank

from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata

@@ -69,7 +69,7 @@ def __init__(
self.flash_attention_fwd = flash_attention_fwd

# for alibi attention
world_size, rank = get_world_rank()
world_size, rank = get_tp_world_rank()
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size

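The switch from get_world_rank to get_tp_world_rank scopes the alibi head partition to the tensor-parallel group rather than the full world, which may now include dp ranks. A sketch of the offset arithmetic with illustrative numbers (head counts assumed, not from the PR):

```python
# Illustrative only: how the alibi offsets above partition heads across a
# tp group of size 2 when each rank holds 8 local heads.
num_heads = 8       # heads owned by this rank
world_size = 2      # size of the tp group
for rank in range(world_size):
    alibi_head_offset = num_heads * rank       # global index of first local head
    alibi_num_heads = num_heads * world_size   # total heads across the tp group
    print(rank, alibi_head_offset, alibi_num_heads)
# rank 0 -> offset 0, rank 1 -> offset 8; both see 16 total heads
```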
20 changes: 17 additions & 3 deletions lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from typing import List, Optional

import torch

@@ -9,6 +9,15 @@
from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""reduce scatter."""
outs = out.split(tp_sizes, -2)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
return out


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
"""triton linear blocked f8 implementation."""

@@ -23,7 +32,9 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
x_shape = x.shape
x = x.flatten(0, -2)
@@ -34,7 +45,10 @@
out += bias

if all_reduce:
dist.all_reduce(out)
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
else:
dist.all_reduce(out)

out = out.unflatten(0, x_shape[:-1])
return out
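The `_reduce_scatter_input` helper above fuses the tp all-reduce with a scatter over possibly uneven per-rank slices (`scatter_size` can differ per dp rank because each rank may hold a different number of tokens). A standalone sketch of the same pattern, assuming an NCCL process group launched via torchrun; the helper name here is ours, not the PR's:

```python
# Minimal sketch of uneven reduce-scatter (run with: torchrun --nproc-per-node=2).
# torch.distributed.reduce_scatter accepts unequal chunk sizes as long as each
# rank's output tensor matches input_list[rank] on every rank.
import torch
import torch.distributed as dist

def reduce_scatter_uneven(out: torch.Tensor, rank: int, sizes: list):
    """Sum `out` across ranks, keeping only this rank's slice along dim -2."""
    outs = list(out.split(sizes, -2))  # one chunk per rank, sizes may differ
    local = outs[rank]                 # this rank's output buffer (a view)
    dist.reduce_scatter(local, outs)
    return local

if __name__ == '__main__':
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    sizes = [3, 5]                     # uneven token counts per rank (assumed)
    x = torch.ones(sum(sizes), 8, device='cuda')
    y = reduce_scatter_uneven(x, rank, sizes)
    print(rank, y.shape)               # rank 0: (3, 8); rank 1: (5, 8)
```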
24 changes: 21 additions & 3 deletions lmdeploy/pytorch/backends/default/linear.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from typing import List, Optional

import torch
import torch.nn.functional as F
@@ -9,14 +9,32 @@
from ..linear import LinearBuilder, LinearImpl


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""reduce scatter."""
outs = out.split(tp_sizes, -2)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
return out


class DefaultLinearImpl(LinearImpl):
"""Linear implementation api."""

def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
out = F.linear(x, weight, bias)
if all_reduce:
dist.all_reduce(out)
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
else:
dist.all_reduce(out)
return out


10 changes: 8 additions & 2 deletions lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional
from typing import List, Optional

import torch

@@ -18,7 +18,13 @@ def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = No
weight = weight.data.t().contiguous()
return weight, bias

def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
return linear(x, weight, bias, all_reduce)

10 changes: 8 additions & 2 deletions lmdeploy/pytorch/backends/linear.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional

import torch

@@ -13,7 +13,13 @@ def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = No
return weight, bias

@abstractmethod
def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
raise NotImplementedError

5 changes: 4 additions & 1 deletion lmdeploy/pytorch/check_env/dist.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from lmdeploy.pytorch.config import DistConfig

from .base import BaseChecker


@@ -10,7 +12,8 @@ def __init__(self, tp: int, dp: int, distributed_executor_backend: str, device_t
super().__init__(logger)
self.tp = tp
self.dp = dp
self.world_size = tp * dp
self.dist_config = DistConfig(dp=dp, tp=tp)
self.world_size = self.dist_config.world_size
self.distributed_executor_backend = distributed_executor_backend
self.device_type = device_type

45 changes: 42 additions & 3 deletions lmdeploy/pytorch/config.py
@@ -89,6 +89,35 @@ def __post_init__(self):
self.enable_prefix_caching = False


@dataclass
class DistConfig:
dp: int = 1
tp: int = 1
ep: int = 1
dp_rank: int = 0
world_size: int = None
attn_config: 'DistConfig' = None

def __post_init__(self):
"""post init."""
assert self.dp_rank < self.dp
assert self.dp >= 1
if self.dp == 1:
world_size = max(self.tp, self.ep)
attn_config = self
else:
world_size = self.dp
attn_config = DistConfig(dp=1, tp=1, ep=1, dp_rank=0)
self.world_size = world_size
self.attn_config = attn_config

def need_dummy_batch(self):
"""need dummy batch."""
if self.dp == 1:
return False
return self.tp > 1 or self.ep > 1


@dataclass
class ModelConfig:
"""Config of model."""
@@ -118,7 +147,7 @@ def from_pretrained(cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
dtype: str = 'auto',
tp: int = 1):
dist_config: DistConfig = None):
"""Instantiate one of the configuration classes of the library from a
pretrained model configuration.

@@ -134,12 +163,22 @@
if getattr(hf_config, 'model_type', None) in ['phi3']:
# phi3 + trust_remote_code leads to error when tp.
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
return cls.from_hf_config(hf_config, pretrained_model_name_or_path, dtype=dtype, tp=tp)
return cls.from_hf_config(hf_config, pretrained_model_name_or_path, dtype=dtype, dist_config=dist_config)

@classmethod
def from_hf_config(cls, hf_config: Any, model_path: str = None, dtype: str = 'auto', tp: int = 1):
def from_hf_config(cls,
hf_config: Any,
model_path: str = None,
dtype: str = 'auto',
dist_config: DistConfig = None):
"""from huggingface config."""
from lmdeploy.pytorch.configurations import AutoModelConfigBuilder
if dist_config is None:
dist_config = DistConfig()
if dist_config.dp == 1:
tp = dist_config.tp
else:
tp = 1

model_config = AutoModelConfigBuilder.build(hf_config, model_path, tp=tp)

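The DistConfig added above derives world_size and the attention-group config from dp/tp/ep: with dp == 1 the world size is max(tp, ep) and attention shares the same config; with dp > 1 the world size equals dp and each dp rank gets a private tp=1 attention config. A quick check of those semantics as read from the diff (asserted values follow the code above, not a run of the PR):

```python
from lmdeploy.pytorch.config import DistConfig

plain_tp = DistConfig(dp=1, tp=4)
assert plain_tp.world_size == 4          # max(tp, ep) when dp == 1
assert plain_tp.attn_config is plain_tp  # attention shares this config
assert not plain_tp.need_dummy_batch()

dp_tp = DistConfig(dp=2, tp=2, dp_rank=0)
assert dp_tp.world_size == 2             # world size == dp when dp > 1
assert dp_tp.attn_config.tp == 1         # per-dp-rank attention runs without tp
assert dp_tp.need_dummy_batch()          # tp > 1 under dp needs dummy batches
```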