diff --git a/lmdeploy/pytorch_poc/engine/engine.py b/lmdeploy/pytorch_poc/engine/engine.py
index 0ae4b51ec1..045cba7a24 100644
--- a/lmdeploy/pytorch_poc/engine/engine.py
+++ b/lmdeploy/pytorch_poc/engine/engine.py
@@ -25,8 +25,8 @@
                                          SchedulerConfig)
 from lmdeploy.pytorch_poc.messages import (MessageStatus, SamplingParam,
                                            SchedulerMessage, SchedulerSession)
+from lmdeploy.pytorch_poc.models import patch
 from lmdeploy.pytorch_poc.paging import Scheduler
-from lmdeploy.pytorch_poc.patch import patch
 from lmdeploy.pytorch_poc.utils import get_gpu_memory
 from lmdeploy.utils import get_logger

@@ -1007,6 +1007,11 @@ def __init__(self, engine: Engine):
         self.req_count = 0
         self.owned_sessions: List[int] = list()

+    def __del__(self):
+        """Destructor."""
+        for session_id in self.owned_sessions:
+            self.end(session_id)
+
     def _send_req(self, req_type: RequestType, data: Any):
         """Send request to engine.

diff --git a/lmdeploy/pytorch_poc/kernels/__init__.py b/lmdeploy/pytorch_poc/kernels/__init__.py
index 5b24fd7e0e..ffbf3da8a9 100644
--- a/lmdeploy/pytorch_poc/kernels/__init__.py
+++ b/lmdeploy/pytorch_poc/kernels/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .context_alibi_pagedattention import alibi_paged_attention_fwd
-from .context_biased_pagedattention import biased_paged_attention_fwd
-from .context_flashattention_nopad import context_attention_fwd
-from .context_pagedattention import paged_attention_fwd
+from .alibi_pagedattention import alibi_paged_attention_fwd
+from .biased_pagedattention import biased_paged_attention_fwd
+from .flashattention_nopad import context_attention_fwd
+from .pagedattention import paged_attention_fwd

 __all__ = [
     'context_attention_fwd',
diff --git a/lmdeploy/pytorch_poc/kernels/context_alibi_pagedattention.py b/lmdeploy/pytorch_poc/kernels/alibi_pagedattention.py
similarity index 100%
rename from lmdeploy/pytorch_poc/kernels/context_alibi_pagedattention.py
rename to lmdeploy/pytorch_poc/kernels/alibi_pagedattention.py
diff --git a/lmdeploy/pytorch_poc/kernels/context_biased_pagedattention.py b/lmdeploy/pytorch_poc/kernels/biased_pagedattention.py
similarity index 100%
rename from lmdeploy/pytorch_poc/kernels/context_biased_pagedattention.py
rename to lmdeploy/pytorch_poc/kernels/biased_pagedattention.py
diff --git a/lmdeploy/pytorch_poc/kernels/context_flashattention_nopad.py b/lmdeploy/pytorch_poc/kernels/flashattention_nopad.py
similarity index 100%
rename from lmdeploy/pytorch_poc/kernels/context_flashattention_nopad.py
rename to lmdeploy/pytorch_poc/kernels/flashattention_nopad.py
diff --git a/lmdeploy/pytorch_poc/kernels/context_pagedattention.py b/lmdeploy/pytorch_poc/kernels/pagedattention.py
similarity index 100%
rename from lmdeploy/pytorch_poc/kernels/context_pagedattention.py
rename to lmdeploy/pytorch_poc/kernels/pagedattention.py
diff --git a/lmdeploy/pytorch_poc/patch/__init__.py b/lmdeploy/pytorch_poc/models/__init__.py
similarity index 100%
rename from lmdeploy/pytorch_poc/patch/__init__.py
rename to lmdeploy/pytorch_poc/models/__init__.py
diff --git a/lmdeploy/pytorch_poc/patch/baichuan.py b/lmdeploy/pytorch_poc/models/baichuan.py
similarity index 99%
rename from lmdeploy/pytorch_poc/patch/baichuan.py
rename to lmdeploy/pytorch_poc/models/baichuan.py
index 8471e5f3c7..575edd7789 100644
--- a/lmdeploy/pytorch_poc/patch/baichuan.py
+++ b/lmdeploy/pytorch_poc/models/baichuan.py
@@ -9,9 +9,8 @@

 from lmdeploy.pytorch_poc.dist_utils import (rowwise_parallelize_linear_fn,
                                              try_to_local)
-from lmdeploy.pytorch_poc.patch.functional import \
-    attention_forward_with_paged_attention
+from .functional import attention_forward_with_paged_attention

 from .llama import apply_rotary_pos_emb


diff --git a/lmdeploy/pytorch_poc/patch/chatglm2.py b/lmdeploy/pytorch_poc/models/chatglm2.py
similarity index 75%
rename from lmdeploy/pytorch_poc/patch/chatglm2.py
rename to lmdeploy/pytorch_poc/models/chatglm2.py
index fcb72b9d6a..7aecdd8fae 100644
--- a/lmdeploy/pytorch_poc/patch/chatglm2.py
+++ b/lmdeploy/pytorch_poc/models/chatglm2.py
@@ -4,12 +4,18 @@
 from typing import List, Optional, Tuple

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.utils.checkpoint
+from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
 from transformers.modeling_outputs import BaseModelOutputWithPast

+from lmdeploy.pytorch_poc.dist_utils import (rowwise_parallelize_linear_fn,
+                                             try_to_local)
 from lmdeploy.pytorch_poc.kernels import paged_attention_fwd

+from .functional import fill_kv_cache
+

 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
@@ -70,6 +76,39 @@ class PatchedSelfAttention(nn.Module):
     the same size.
     """

+    def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
+                                 device_mesh: DeviceMesh):
+        """Distribution partition callback."""
+        if mod_name in ['query_key_value']:
+            sections = [
+                self.num_attention_heads_per_partition *
+                self.hidden_size_per_attention_head,
+                self.num_multi_query_groups_per_partition *
+                self.hidden_size_per_attention_head,
+                self.num_multi_query_groups_per_partition *
+                self.hidden_size_per_attention_head,
+            ]
+            for name, param in mod.named_parameters():
+                splited_param = param.split(sections, dim=0)
+                updated_param = []
+                for p in splited_param:
+                    dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
+                    dist_tensor = try_to_local(dist_tensor)
+                    updated_param.append(dist_tensor)
+                param = torch.cat(updated_param)
+                dist_param = torch.nn.Parameter(param)
+                mod.register_parameter(name, dist_param)
+        elif mod_name in ['dense']:
+            rowwise_parallelize_linear_fn(mod,
+                                          device_mesh=device_mesh,
+                                          to_local=True)
+
+    @classmethod
+    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
+        """Distribution output hook."""
+        dist.all_reduce(outputs[0])
+        return outputs
+
     def _contiguous_batching_forward(
         self,
         hidden_states: torch.Tensor,
@@ -90,40 +129,42 @@ def _contiguous_batching_forward(
         # =====================

         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]

+        world_size = 1
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
         origin_self = self
         context = self.context.context
         history_lengths = context.history_lengths
-
         mixed_x_layer = origin_self.query_key_value(hidden_states)

         if origin_self.multi_query_attention:
             (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                 [
                     origin_self.num_attention_heads_per_partition *
-                    origin_self.hidden_size_per_attention_head,
+                    origin_self.hidden_size_per_attention_head // world_size,
                     origin_self.num_multi_query_groups_per_partition *
-                    origin_self.hidden_size_per_attention_head,
+                    origin_self.hidden_size_per_attention_head // world_size,
                     origin_self.num_multi_query_groups_per_partition *
-                    origin_self.hidden_size_per_attention_head,
+                    origin_self.hidden_size_per_attention_head // world_size,
                 ],
                 dim=-1,
             )
             query_layer = query_layer.view(query_layer.size()[:-1] + (
-                origin_self.num_attention_heads_per_partition,
+                origin_self.num_attention_heads_per_partition // world_size,
                 origin_self.hidden_size_per_attention_head,
             ))
             key_layer = key_layer.view(key_layer.size()[:-1] + (
-                origin_self.num_multi_query_groups_per_partition,
+                origin_self.num_multi_query_groups_per_partition // world_size,
                 origin_self.hidden_size_per_attention_head,
             ))
             value_layer = value_layer.view(value_layer.size()[:-1] + (
-                origin_self.num_multi_query_groups_per_partition,
+                origin_self.num_multi_query_groups_per_partition // world_size,
                 origin_self.hidden_size_per_attention_head,
             ))
         else:
             new_tensor_shape = mixed_x_layer.size()[:-1] + (
-                origin_self.num_attention_heads_per_partition,
+                origin_self.num_attention_heads_per_partition // world_size,
                 3 * origin_self.hidden_size_per_attention_head,
             )
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
@@ -131,7 +172,6 @@ def _contiguous_batching_forward(
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer, key_layer,
              value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
             query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
@@ -151,15 +191,14 @@ def _contiguous_batching_forward(
         history_lengths = q_seq_length.new_tensor(history_lengths)
         kv_seq_length = q_seq_length + history_lengths
         max_seq_len = q_seq_length.max().item()
-
-        context.fill_cache(
-            key_layer[0],
-            value_layer[0],
-            q_start_loc,
-            q_seq_length,
-            cache_k,
-            cache_v,
-        )
+        fill_kv_cache(key_layer[0],
+                      value_layer[0],
+                      cache_k,
+                      cache_v,
+                      q_start_loc,
+                      q_seq_length,
+                      block_offsets=context.block_offsets,
+                      history_lengths=history_lengths)

         if use_cache:
             kv_cache = (key_layer, value_layer)
@@ -225,6 +264,31 @@ def forward(
         )


+class MLP(nn.Module):
+
+    @classmethod
+    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
+                                 device_mesh: DeviceMesh):
+        """Distribution partition callback."""
+        if mod_name in ['dense_h_to_4h']:
+            for name, param in mod.named_parameters():
+                dist_tensor = distribute_tensor(param.unflatten(0, (2, -1)),
+                                                device_mesh, [Shard(1)])
+                dist_tensor = try_to_local(dist_tensor)
+                dist_param = torch.nn.Parameter(dist_tensor.flatten(0, 1))
+                mod.register_parameter(name, dist_param)
+        elif mod_name in ['dense_4h_to_h']:
+            rowwise_parallelize_linear_fn(mod,
+                                          device_mesh=device_mesh,
+                                          to_local=True)
+
+    @classmethod
+    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
+        """Distribution output hook."""
+        dist.all_reduce(outputs)
+        return outputs
+
+
 class PatchedChatGLMModel(nn.Module):

     def _contiguous_batching_forward(
diff --git a/lmdeploy/pytorch_poc/patch/functional.py b/lmdeploy/pytorch_poc/models/functional.py
similarity index 98%
rename from lmdeploy/pytorch_poc/patch/functional.py
rename to lmdeploy/pytorch_poc/models/functional.py
index 5d6bba8323..d1598590a3 100644
--- a/lmdeploy/pytorch_poc/patch/functional.py
+++ b/lmdeploy/pytorch_poc/models/functional.py
@@ -78,7 +78,8 @@ def fill_kv_cache(
     """
     block_size = k_caches.size(1)

-    history_lengths = torch.tensor(history_lengths)
+    if not isinstance(history_lengths, torch.Tensor):
+        history_lengths = torch.tensor(history_lengths)
     first_free_block_offsets = history_lengths // block_size
     first_token_offsets = history_lengths % block_size

diff --git a/lmdeploy/pytorch_poc/patch/internlm.py b/lmdeploy/pytorch_poc/models/internlm.py
similarity index 96%
rename from lmdeploy/pytorch_poc/patch/internlm.py
rename to lmdeploy/pytorch_poc/models/internlm.py
index f84fb7d212..db34e7ef1d 100644
--- a/lmdeploy/pytorch_poc/patch/internlm.py
+++ b/lmdeploy/pytorch_poc/models/internlm.py
@@ -8,8 +8,9 @@

 from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
                                              rowwise_parallelize_linear_fn)
-from lmdeploy.pytorch_poc.patch.functional import (
-    apply_rotary_pos_emb, attention_forward_with_paged_attention)
+
+from .functional import (apply_rotary_pos_emb,
+                         attention_forward_with_paged_attention)


 class PatchedInternLMAttention(nn.Module):
diff --git a/lmdeploy/pytorch_poc/patch/llama.py b/lmdeploy/pytorch_poc/models/llama.py
similarity index 98%
rename from lmdeploy/pytorch_poc/patch/llama.py
rename to lmdeploy/pytorch_poc/models/llama.py
index b4cd93e0ff..92d13e3372 100644
--- a/lmdeploy/pytorch_poc/patch/llama.py
+++ b/lmdeploy/pytorch_poc/models/llama.py
@@ -9,8 +9,9 @@

 from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
                                              rowwise_parallelize_linear_fn)
-from lmdeploy.pytorch_poc.patch.functional import (
-    apply_rotary_pos_emb, attention_forward_with_paged_attention)
+
+from .functional import (apply_rotary_pos_emb,
+                         attention_forward_with_paged_attention)


 class LlamaAttention(nn.Module):
diff --git a/lmdeploy/pytorch_poc/patch/patch.py b/lmdeploy/pytorch_poc/models/patch.py
similarity index 83%
rename from lmdeploy/pytorch_poc/patch/patch.py
rename to lmdeploy/pytorch_poc/models/patch.py
index a7862136b2..ecefc49ed7 100644
--- a/lmdeploy/pytorch_poc/patch/patch.py
+++ b/lmdeploy/pytorch_poc/models/patch.py
@@ -13,46 +13,50 @@
 from lmdeploy.pytorch_poc.dist_utils import partition_module, replicate_module
 from lmdeploy.utils import get_logger

+LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch_poc.models'
+
 # llama
 MODULE_MAP = {
     'transformers.models.llama.modeling_llama.LlamaAttention':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaAttention',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
     'transformers.models.llama.modeling_llama.LlamaModel':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
     'transformers.models.llama.modeling_llama.LlamaMLP':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
 }

 # baichuan
 MODULE_MAP.update({
     'modeling_baichuan.Model':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',  # noqa
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',  # noqa
     'modeling_baichuan.BaichuanModel':
-    'lmdeploy.pytorch_poc.patch.baichuan.BaichuanModel',  # noqa
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel',  # noqa
     'modeling_baichuan.Attention':
-    'lmdeploy.pytorch_poc.patch.baichuan.Attention',  # noqa
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.Attention',  # noqa
     'modeling_baichuan.BaichuanAttention':
-    'lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention',  # noqa
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanAttention',  # noqa
     'modeling_baichuan.MLP':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',  # noqa
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',  # noqa
 })

 # chatglm2
 MODULE_MAP.update({
     'modeling_chatglm.SelfAttention':
-    'lmdeploy.pytorch_poc.patch.chatglm2.PatchedSelfAttention',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedSelfAttention',
     'modeling_chatglm.ChatGLMModel':
-    'lmdeploy.pytorch_poc.patch.chatglm2.PatchedChatGLMModel',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMModel',
+    'modeling_chatglm.MLP':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.MLP',
 })

 # internlm
 MODULE_MAP.update({
     'modeling_internlm.InternLMAttention':
-    'lmdeploy.pytorch_poc.patch.internlm.PatchedInternLMAttention',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.PatchedInternLMAttention',
     'modeling_internlm.InternLMModel':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
     'modeling_internlm.InternLMMLP':
-    'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
 })


@@ -169,6 +173,17 @@ def _update_model(model: torch.nn.Module):
         model._update_model_fn()


+def _params_to_meta(model: torch.nn.Module):
+    """move parameters to meta device."""
+    # recursive over children
+    for _, child in model.named_children():
+        _params_to_meta(child)
+
+    for k, v in model.named_parameters(recurse=False):
+        model.register_parameter(
+            k, torch.nn.Parameter(v.to('meta'), requires_grad=False))
+
+
 def _load_state_dict(
     model: torch.nn.Module,
     state_dict: Dict[str, torch.Tensor] = None,
@@ -232,13 +247,18 @@ def _load_state_dict(
         if not in_state_dict:
             continue

-        if rank == 0:
-            new_param = torch.nn.Parameter(state_dict[full_k],
-                                           requires_grad=False).to(device)
-        else:
-            new_param = torch.nn.Parameter(torch.empty_like(v, device=device),
-                                           requires_grad=False)
-        model.register_parameter(k, new_param)
+        param_names = [
+            name for name, _ in model.named_parameters(recurse=False)
+        ]
+        if k in param_names:
+            if rank == 0:
+                new_param = torch.nn.Parameter(state_dict[full_k].to(v.dtype),
+                                               requires_grad=False).to(device)
+            else:
+                new_param = torch.nn.Parameter(torch.empty_like(v,
+                                                                device=device),
+                                               requires_grad=False)
+            model.register_parameter(k, new_param)

     # distribute module
     if world_size > 1:
@@ -312,6 +332,7 @@ def patch(

     # load checkpoint
     if checkpoints is not None:
+        _params_to_meta(model)
         device_mesh = DeviceMesh('cuda', list(range(world_size)))
         for ckpt in checkpoints:
             if rank == 0: