6 changes: 6 additions & 0 deletions lmdeploy/messages.py
@@ -289,6 +289,9 @@ class PytorchEngineConfig:
session_len (int): Max session length. Default None.
max_batch_size (int): Max batch size. If it is not specified,
the engine will automatically set it according to the device
attn_tp_size (int): tp size for attention; takes effect only when dp > 1
mlp_tp_size (int): tp size for mlp; takes effect only when dp > 1
moe_tp_size (int): tp size for moe; takes effect only when dp > 1
cache_max_entry_count (float): the percentage of gpu memory occupied
by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
it defaults to 0.8, signifying the percentage of FREE GPU memory
@@ -350,6 +353,9 @@ class PytorchEngineConfig:
ep: int = 1
session_len: int = None
max_batch_size: int = None
attn_tp_size: int = None
mlp_tp_size: int = None
moe_tp_size: int = None
cache_max_entry_count: float = 0.8
prefill_interval: int = 16
block_size: int = 64
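For context, a minimal sketch of how the three new fields might be set, assuming they compose with the existing `tp`/`dp`/`ep` fields of `PytorchEngineConfig`; the model path and the `pipeline` entry point here are illustrative, not taken from this diff:

```python
from lmdeploy import PytorchEngineConfig, pipeline

# Hypothetical configuration: the per-module TP sizes below only take
# effect when dp > 1, per the new docstring entries.
backend_config = PytorchEngineConfig(
    dp=2,
    tp=4,
    attn_tp_size=2,  # TP degree used by attention layers
    mlp_tp_size=4,   # TP degree used by dense MLP layers
    moe_tp_size=4,   # TP degree used by MoE experts
)
pipe = pipeline('/path/to/model', backend_config=backend_config)
```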
7 changes: 6 additions & 1 deletion lmdeploy/pytorch/backends/awq_modules.py
@@ -17,7 +17,12 @@ def update_weights(self,
return qweight, scales, qzeros, bias

@abstractmethod
def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
raise NotImplementedError

2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -3,6 +3,7 @@
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearBlockedF8Impl(ABC):
@@ -19,6 +20,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/attention.py
@@ -90,7 +90,7 @@ def __init__(
self.flash_attention_fwd = flash_attention_fwd

# for alibi attention
world_size, rank = get_tp_world_rank()
world_size, rank = get_tp_world_rank('attn')
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size
self.block_sparse_size = block_sparse_size
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/awq_modules.py
@@ -55,12 +55,13 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out_features = scales.size(1)
out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out
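The recurring change across these backends is threading an explicit process group into the collective, so the reduction runs over the module's own TP ranks rather than the default world group. A standalone `torch.distributed` sketch of that distinction (the group construction here is illustrative; lmdeploy obtains the group from its distributed context):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group('nccl') has already been called on every rank.
world_size = dist.get_world_size()
rank = dist.get_rank()

tp = 2  # illustrative attention-TP degree
# new_group must be called identically on all ranks, even non-members.
tp_groups = [dist.new_group(ranks=list(range(start, start + tp)))
             for start in range(0, world_size, tp)]
group = tp_groups[rank // tp]

out = torch.ones(8, device='cuda')
dist.all_reduce(out, group=group)  # sums only across this rank's TP subgroup
```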


20 changes: 6 additions & 14 deletions lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -13,15 +13,6 @@
logger = get_logger('lmdeploy')


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""Reduce scatter."""
outs = out.split(tp_sizes, -2)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
return out


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
"""Triton linear blocked f8 implementation."""

@@ -37,6 +28,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -52,7 +44,7 @@ def forward(self,

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)
return out
@@ -117,6 +109,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -128,12 +121,11 @@ def forward(self,
out = out[:x.size(0)]
if bias is not None:
out += bias
out = out.unflatten(0, x_shape[:-1])

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)

out = out.unflatten(0, x_shape[:-1])
dist.all_reduce(out, group=group)
return out
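The deleted local helper gives a good picture of what `reduce_scatter_by_tp_sizes` is expected to do; a sketch reconstructed from the removed `_reduce_scatter_input` (the real implementation now lives in `lmdeploy.pytorch.distributed` and may differ):

```python
from typing import List, Optional

import torch
import torch.distributed as dist


def reduce_scatter_by_tp_sizes(out: torch.Tensor,
                               rank: int,
                               tp_sizes: List[int],
                               group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    """Reduce-scatter along dim -2 with unequal shard sizes per rank."""
    shards = list(out.split(tp_sizes, dim=-2))
    local = shards[rank]  # this rank keeps tp_sizes[rank] rows of the summed result
    dist.reduce_scatter(local, shards, group=group)
    return local
```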
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -258,7 +258,7 @@ def update_inputs(self, inputs):
meta = self.get_meta()
padding_batch_size = meta.padding_batch_size
tp_size = self._get_capture_tokens(padding_batch_size)
dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
dp_meta.sync_tp_size(tp_size)
return inputs

def get_capture_batch_sizes(self) -> List[int]:
44 changes: 44 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -619,6 +619,45 @@ def ep_expert_list(self, world_size: int, rank: int):
else:
return super().ep_expert_list(world_size=world_size, rank=rank)

def _split_inputs_by_attn_tp(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
):
"""Split input by attn tp."""
dist_ctx = get_dist_manager().current_context()
attn_tp = dist_ctx.dist_config.attn_tp
attn_rank = dist_ctx.attn_tp_group.rank
num_states = hidden_states.size(0)

if attn_tp == 1 or attn_tp > num_states:
return hidden_states, topk_weights, topk_ids, None

# split size
base = num_states // attn_tp
remain = num_states % attn_tp
split_size = [base + 1] * remain + [base] * (attn_tp - remain)

# split inputs
hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank]
topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank]
topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank]

return hidden_states, topk_weights, topk_ids, split_size

def _gather_outputs_by_attn_tp(self, out_states: torch.Tensor, split_size: List[int]):
"""Gather output by attn tp."""
if split_size is None:
return out_states

dist_ctx = get_dist_manager().current_context()
gpu_group = dist_ctx.attn_tp_group.gpu_group
new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1]))
new_out_states_list = list(new_out_states.split(split_size, dim=0))
dist.all_gather(new_out_states_list, out_states, group=gpu_group)
return new_out_states

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
@@ -633,12 +672,17 @@ def forward(self,
act_func: Callable = None,
**kwargs):
"""forward."""
hidden_states, topk_weights, topk_ids, split_size = self._split_inputs_by_attn_tp(
hidden_states, topk_weights, topk_ids)

topk_weights = self.do_renormalize(topk_weights)
step_ctx = get_step_ctx_manager().current_context()
low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
moe = self.fusedmoe_build(low_latency_mode)
out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
down_scale, expert_list)

out_states = self._gather_outputs_by_attn_tp(out_states, split_size)
return out_states

def do_renormalize(self, topk_weights):
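The uneven split in `_split_inputs_by_attn_tp` is easiest to see with concrete numbers; a standalone sketch of the same arithmetic (not the engine code itself):

```python
from typing import List


def attn_tp_split_size(num_states: int, attn_tp: int) -> List[int]:
    """Rows per attention-TP rank; the first `remain` ranks take one extra."""
    base = num_states // attn_tp
    remain = num_states % attn_tp
    return [base + 1] * remain + [base] * (attn_tp - remain)


assert attn_tp_split_size(10, 4) == [3, 3, 2, 2]
assert sum(attn_tp_split_size(10, 4)) == 10
# Each rank computes MoE output for its shard only; _gather_outputs_by_attn_tp
# then all_gathers the shards back into the full (num_states, hidden) tensor.
```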
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/qmodules.py
@@ -63,7 +63,8 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
if isinstance(x, torch.Tensor):
input_quant, input_scale = per_token_quant_int8(x, 1e-7, quant_dtype=self.quant_dtype)
@@ -79,7 +80,7 @@
bias=bias)

if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/default/awq_modules.py
@@ -62,7 +62,8 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out_shape = x.shape[:-1] + (self.out_features, )
input_dtype = x.dtype
@@ -77,7 +78,7 @@
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


22 changes: 5 additions & 17 deletions lmdeploy/pytorch/backends/default/linear.py
@@ -2,26 +2,12 @@
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F

import lmdeploy.pytorch.distributed as dist

from ..linear import LinearBuilder, LinearImpl


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""Reduce scatter."""
out = out.transpose(0, -2)
if not out.is_contiguous():
out = out.contiguous()
outs = out.split(tp_sizes, 0)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
out = out.transpose(0, -2)
return out


class DefaultLinearImpl(LinearImpl):
"""Linear implementation api."""

@@ -30,15 +16,17 @@ def forward(self,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: dist.ProcessGroup = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
out = F.linear(x, weight, bias)
if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
from lmdeploy.pytorch.distributed import reduce_scatter_by_tp_sizes
out = reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/awq_modules.py
@@ -23,7 +23,8 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out = awq_linear(x, qweight, scales, qzeros, bias, all_reduce, self.group_size)
return out
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -3,8 +3,8 @@
from typing import List, Optional

import torch
import torch.distributed as dist

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.kernels.dlinfer import linear

from ..linear import LinearBuilder, LinearImpl
@@ -32,12 +32,13 @@ def forward(self,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: dist.ProcessGroup = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
out = linear(x, weight, bias, False)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/dlinfer/qmodules.py
@@ -36,7 +36,8 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
if isinstance(x, torch.Tensor):
input_quant, input_scale = dynamic_quant(x, self.quant_dtype)
@@ -46,7 +47,7 @@ def forward(self,

out = linear_w8a8(input_quant, weight, input_scale, scale, self.out_dtype, self.quant_dtype, bias)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/linear.py
@@ -3,6 +3,7 @@
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearImpl(ABC):
@@ -18,6 +19,7 @@ def forward(self,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: dist.ProcessGroup = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/qmodules.py
@@ -47,7 +47,8 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
raise NotImplementedError

4 changes: 2 additions & 2 deletions lmdeploy/pytorch/check_env/dist.py
@@ -44,9 +44,9 @@ def check(self):
if self.device_type == 'cuda' and not is_dlblas_installed():
self.log_and_exit(mod_name='Dist',
message='ep>1 requires install dlblas(https://github.com/DeepLink-org/dlBLAS).')
if self.dp % self.ep != 0:
if self.ep % self.dp != 0:
self.log_and_exit(mod_name='Dist',
message=f'ep>1 requires dp % ep == 0. Get dp={self.dp} and ep={self.ep}.')
message=f'ep>1 requires ep % dp == 0. Get dp={self.dp} and ep={self.ep}.')
elif self.dist_config.enable_eplb:
self.log_and_exit(mod_name='Dist', message=f'Enable eplb requires ep > 1. Get ep={self.ep}.')
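With the corrected condition, the expert-parallel world must divide evenly across the data-parallel ranks; a quick standalone check of the arithmetic:

```python
def ep_dp_compatible(ep: int, dp: int) -> bool:
    """Mirror of the corrected check: ep must be a multiple of dp."""
    return ep % dp == 0


assert ep_dp_compatible(ep=8, dp=2)        # 8 EP ranks across 2 DP ranks -> ok
assert not ep_dp_compatible(ep=4, dp=3)    # rejected by the checker
```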
