
rename modules (#504)
Co-authored-by: grimoire <[email protected]>
q.yao and grimoire authored Oct 9, 2023
1 parent 8123c8e commit 8085fbc
Showing 13 changed files with 142 additions and 50 deletions.
7 changes: 6 additions & 1 deletion lmdeploy/pytorch_poc/engine/engine.py
@@ -25,8 +25,8 @@
SchedulerConfig)
from lmdeploy.pytorch_poc.messages import (MessageStatus, SamplingParam,
SchedulerMessage, SchedulerSession)
from lmdeploy.pytorch_poc.models import patch
from lmdeploy.pytorch_poc.paging import Scheduler
from lmdeploy.pytorch_poc.patch import patch
from lmdeploy.pytorch_poc.utils import get_gpu_memory
from lmdeploy.utils import get_logger

@@ -1007,6 +1007,11 @@ def __init__(self, engine: Engine):
self.req_count = 0
self.owned_sessions: List[int] = list()

def __del__(self):
"""Destructor."""
for session_id in self.owned_sessions:
self.end(session_id)

def _send_req(self, req_type: RequestType, data: Any):
"""Send request to engine.
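The new `__del__` hook makes the instance release every session it still owns when it is garbage-collected. A minimal sketch of the same cleanup pattern with a hypothetical `SessionClient`, iterating over a copy of the list so that `end()` may freely mutate it:

```python
from typing import List


class SessionClient:
    """Hypothetical client mirroring the destructor pattern above."""

    def __init__(self):
        self.owned_sessions: List[int] = []

    def start(self, session_id: int) -> None:
        # Track the session so it can be released later.
        self.owned_sessions.append(session_id)

    def end(self, session_id: int) -> None:
        # Tell the engine to drop the session (engine call omitted here).
        self.owned_sessions.remove(session_id)

    def __del__(self):
        # Release anything the caller forgot to end explicitly.
        for session_id in list(self.owned_sessions):
            self.end(session_id)
```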
8 changes: 4 additions & 4 deletions lmdeploy/pytorch_poc/kernels/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .context_alibi_pagedattention import alibi_paged_attention_fwd
from .context_biased_pagedattention import biased_paged_attention_fwd
from .context_flashattention_nopad import context_attention_fwd
from .context_pagedattention import paged_attention_fwd
from .alibi_pagedattention import alibi_paged_attention_fwd
from .biased_pagedattention import biased_paged_attention_fwd
from .flashattention_nopad import context_attention_fwd
from .pagedattention import paged_attention_fwd

__all__ = [
'context_attention_fwd',
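The kernel files drop their `context_` prefix, but the public names are still re-exported from the package `__init__`, so downstream callers are unaffected. A quick sanity check, assuming `lmdeploy` is installed from this revision:

```python
# These imports resolve exactly as before the rename; only the backing
# files changed (e.g. context_pagedattention.py -> pagedattention.py).
from lmdeploy.pytorch_poc.kernels import (alibi_paged_attention_fwd,
                                          biased_paged_attention_fwd,
                                          context_attention_fwd,
                                          paged_attention_fwd)
```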
File renamed without changes.
@@ -9,9 +9,8 @@

from lmdeploy.pytorch_poc.dist_utils import (rowwise_parallelize_linear_fn,
try_to_local)
from lmdeploy.pytorch_poc.patch.functional import \
attention_forward_with_paged_attention

from .functional import attention_forward_with_paged_attention
from .llama import apply_rotary_pos_emb


@@ -4,12 +4,18 @@
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast

from lmdeploy.pytorch_poc.dist_utils import (rowwise_parallelize_linear_fn,
try_to_local)
from lmdeploy.pytorch_poc.kernels import paged_attention_fwd

from .functional import fill_kv_cache


def split_tensor_along_last_dim(
tensor: torch.Tensor,
@@ -70,6 +76,39 @@ class PatchedSelfAttention(nn.Module):
the same size.
"""

def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['query_key_value']:
sections = [
self.num_attention_heads_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
]
for name, param in mod.named_parameters():
splited_param = param.split(sections, dim=0)
updated_param = []
for p in splited_param:
dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
dist_tensor = try_to_local(dist_tensor)
updated_param.append(dist_tensor)
param = torch.cat(updated_param)
dist_param = torch.nn.Parameter(param)
mod.register_parameter(name, dist_param)
elif mod_name in ['dense']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs

def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
@@ -90,48 +129,49 @@ def _contiguous_batching_forward(
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]

world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
origin_self = self

context = self.context.context
history_lengths = context.history_lengths

mixed_x_layer = origin_self.query_key_value(hidden_states)

if origin_self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
origin_self.num_attention_heads_per_partition *
origin_self.hidden_size_per_attention_head,
origin_self.hidden_size_per_attention_head // world_size,
origin_self.num_multi_query_groups_per_partition *
origin_self.hidden_size_per_attention_head,
origin_self.hidden_size_per_attention_head // world_size,
origin_self.num_multi_query_groups_per_partition *
origin_self.hidden_size_per_attention_head,
origin_self.hidden_size_per_attention_head // world_size,
],
dim=-1,
)
query_layer = query_layer.view(query_layer.size()[:-1] + (
origin_self.num_attention_heads_per_partition,
origin_self.num_attention_heads_per_partition // world_size,
origin_self.hidden_size_per_attention_head,
))
key_layer = key_layer.view(key_layer.size()[:-1] + (
origin_self.num_multi_query_groups_per_partition,
origin_self.num_multi_query_groups_per_partition // world_size,
origin_self.hidden_size_per_attention_head,
))
value_layer = value_layer.view(value_layer.size()[:-1] + (
origin_self.num_multi_query_groups_per_partition,
origin_self.num_multi_query_groups_per_partition // world_size,
origin_self.hidden_size_per_attention_head,
))
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
origin_self.num_attention_heads_per_partition,
origin_self.num_attention_heads_per_partition // world_size,
3 * origin_self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer,
value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
@@ -151,15 +191,14 @@ def _contiguous_batching_forward(
history_lengths = q_seq_length.new_tensor(history_lengths)
kv_seq_length = q_seq_length + history_lengths
max_seq_len = q_seq_length.max().item()

context.fill_cache(
key_layer[0],
value_layer[0],
q_start_loc,
q_seq_length,
cache_k,
cache_v,
)
fill_kv_cache(key_layer[0],
value_layer[0],
cache_k,
cache_v,
q_start_loc,
q_seq_length,
block_offsets=context.block_offsets,
history_lengths=history_lengths)

if use_cache:
kv_cache = (key_layer, value_layer)
@@ -225,6 +264,31 @@ def forward(
)


class MLP(nn.Module):

@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['dense_h_to_4h']:
for name, param in mod.named_parameters():
dist_tensor = distribute_tensor(param.unflatten(0, (2, -1)),
device_mesh, [Shard(1)])
dist_tensor = try_to_local(dist_tensor)
dist_param = torch.nn.Parameter(dist_tensor.flatten(0, 1))
mod.register_parameter(name, dist_param)
elif mod_name in ['dense_4h_to_h']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs


class PatchedChatGLMModel(nn.Module):

def _contiguous_batching_forward(
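The new `_distribute_partition_fn` splits the fused `query_key_value` weight into its query, key and value sections, shards each section along dim 0 across the device mesh, and re-concatenates the local shards, which is why the forward pass above divides the head counts by `world_size`. A minimal single-process sketch of that shape bookkeeping, using ChatGLM2-6B-like sizes (32 attention heads, 2 multi-query groups, head dim 128) and plain `torch.chunk` in place of `distribute_tensor` + `try_to_local`:

```python
import torch

num_heads, num_kv_groups, head_dim, hidden = 32, 2, 128, 4096
world_size, rank = 2, 0  # pretend we are rank 0 of a 2-GPU mesh

# Fused QKV projection weight: [q | k | v] stacked along dim 0.
sections = [num_heads * head_dim,
            num_kv_groups * head_dim,
            num_kv_groups * head_dim]
qkv_weight = torch.randn(sum(sections), hidden)

# Split into q/k/v, keep this rank's shard of each, and re-fuse.
local_parts = [part.chunk(world_size, dim=0)[rank]
               for part in qkv_weight.split(sections, dim=0)]
local_qkv = torch.cat(local_parts)

assert local_qkv.shape == (sum(sections) // world_size, hidden)
```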
@@ -78,7 +78,8 @@ def fill_kv_cache(
"""
block_size = k_caches.size(1)

history_lengths = torch.tensor(history_lengths)
if not isinstance(history_lengths, torch.Tensor):
history_lengths = torch.tensor(history_lengths)
first_free_block_offsets = history_lengths // block_size
first_token_offsets = history_lengths % block_size

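The added `isinstance` guard simply avoids re-wrapping `history_lengths` when the caller already passes a tensor, as the patched chatglm2 attention above now does. The paging arithmetic itself is unchanged; for instance, with `block_size = 64` a sequence that already has 150 cached tokens resumes writing at block 2, slot 22:

```python
import torch

block_size = 64
history_lengths = torch.tensor([0, 64, 150])

first_free_block_offsets = history_lengths // block_size  # tensor([0, 1, 2])
first_token_offsets = history_lengths % block_size        # tensor([0, 0, 22])
```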
@@ -8,8 +8,9 @@

from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from lmdeploy.pytorch_poc.patch.functional import (
apply_rotary_pos_emb, attention_forward_with_paged_attention)

from .functional import (apply_rotary_pos_emb,
attention_forward_with_paged_attention)


class PatchedInternLMAttention(nn.Module):
@@ -9,8 +9,9 @@

from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from lmdeploy.pytorch_poc.patch.functional import (
apply_rotary_pos_emb, attention_forward_with_paged_attention)

from .functional import (apply_rotary_pos_emb,
attention_forward_with_paged_attention)


class LlamaAttention(nn.Module):
@@ -13,46 +13,50 @@
from lmdeploy.pytorch_poc.dist_utils import partition_module, replicate_module
from lmdeploy.utils import get_logger

LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch_poc.models'

# llama
MODULE_MAP = {
'transformers.models.llama.modeling_llama.LlamaAttention':
'lmdeploy.pytorch_poc.patch.llama.LlamaAttention',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
'transformers.models.llama.modeling_llama.LlamaModel':
'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'transformers.models.llama.modeling_llama.LlamaMLP':
'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
}

# baichuan
MODULE_MAP.update({
'modeling_baichuan.Model':
'lmdeploy.pytorch_poc.patch.llama.LlamaModel', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', # noqa
'modeling_baichuan.BaichuanModel':
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanModel', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel', # noqa
'modeling_baichuan.Attention':
'lmdeploy.pytorch_poc.patch.baichuan.Attention', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.Attention', # noqa
'modeling_baichuan.BaichuanAttention':
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanAttention', # noqa
'modeling_baichuan.MLP':
'lmdeploy.pytorch_poc.patch.llama.LlamaMLP', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', # noqa
})

# chatglm2
MODULE_MAP.update({
'modeling_chatglm.SelfAttention':
'lmdeploy.pytorch_poc.patch.chatglm2.PatchedSelfAttention',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedSelfAttention',
'modeling_chatglm.ChatGLMModel':
'lmdeploy.pytorch_poc.patch.chatglm2.PatchedChatGLMModel',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMModel',
'modeling_chatglm.MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.MLP',
})

# internlm
MODULE_MAP.update({
'modeling_internlm.InternLMAttention':
'lmdeploy.pytorch_poc.patch.internlm.PatchedInternLMAttention',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.PatchedInternLMAttention',
'modeling_internlm.InternLMModel':
'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'modeling_internlm.InternLMMLP':
'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
})
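Every value in `MODULE_MAP` now points into `lmdeploy.pytorch_poc.models` via the shared `LMDEPLOY_PYTORCH_MODEL_PATH` prefix instead of hard-coding `...patch`. A rough sketch of how such a mapping can be resolved at patch time (a simplified illustration only; the real `patch()` below also handles key matching for remote-code modules, distribution, and weight loading):

```python
import importlib


def _resolve(qualified_name: str) -> type:
    """Import 'pkg.module.ClassName' and return the class object."""
    module_name, _, cls_name = qualified_name.rpartition('.')
    return getattr(importlib.import_module(module_name), cls_name)


def _patched_class_for(origin_cls: type) -> type:
    """Look up the rewrite target for an original model class, if any."""
    key = f'{origin_cls.__module__}.{origin_cls.__qualname__}'
    target = MODULE_MAP.get(key)
    return _resolve(target) if target is not None else origin_cls
```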


@@ -169,6 +173,17 @@ def _update_model(model: torch.nn.Module):
model._update_model_fn()


def _params_to_meta(model: torch.nn.Module):
"""move parameters to meta device."""
# recursive over children
for _, child in model.named_children():
_params_to_meta(child)

for k, v in model.named_parameters(recurse=False):
model.register_parameter(
k, torch.nn.Parameter(v.to('meta'), requires_grad=False))


def _load_state_dict(
model: torch.nn.Module,
state_dict: Dict[str, torch.Tensor] = None,
@@ -232,13 +247,18 @@ def _load_state_dict(
if not in_state_dict:
continue

if rank == 0:
new_param = torch.nn.Parameter(state_dict[full_k],
requires_grad=False).to(device)
else:
new_param = torch.nn.Parameter(torch.empty_like(v, device=device),
requires_grad=False)
model.register_parameter(k, new_param)
param_names = [
name for name, _ in model.named_parameters(recurse=False)
]
if k in param_names:
if rank == 0:
new_param = torch.nn.Parameter(state_dict[full_k].to(v.dtype),
requires_grad=False).to(device)
else:
new_param = torch.nn.Parameter(torch.empty_like(v,
device=device),
requires_grad=False)
model.register_parameter(k, new_param)

# distribute module
if world_size > 1:
@@ -312,6 +332,7 @@ def patch(

# load checkpoint
if checkpoints is not None:
_params_to_meta(model)
device_mesh = DeviceMesh('cuda', list(range(world_size)))
for ckpt in checkpoints:
if rank == 0:
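Taken together, `_params_to_meta` and the reworked `_load_state_dict` implement a lazy materialization flow: parameters are first parked on the `meta` device, then rank 0 copies real weights out of the checkpoint (cast to the original dtype) while the other ranks only allocate storage, and the existing `partition_module` / `replicate_module` step distributes them. A minimal single-module sketch of the meta-then-materialize half, with a hypothetical `state_dict` and a hard-coded `rank` (no process group needed):

```python
import torch
import torch.nn as nn

rank = 0
device = 'cuda' if torch.cuda.is_available() else 'cpu'

linear = nn.Linear(16, 16, bias=False)
state_dict = {'weight': torch.randn(16, 16)}  # stand-in checkpoint shard

# 1. Park parameters on the meta device: shapes/dtypes stay, storage is freed.
for name, param in list(linear.named_parameters(recurse=False)):
    linear.register_parameter(
        name, nn.Parameter(param.to('meta'), requires_grad=False))

# 2. Materialize: rank 0 loads real values, other ranks just allocate.
for name, param in list(linear.named_parameters(recurse=False)):
    if rank == 0:
        new_param = nn.Parameter(state_dict[name].to(param.dtype),
                                 requires_grad=False).to(device)
    else:
        new_param = nn.Parameter(torch.empty_like(param, device=device),
                                 requires_grad=False)
    linear.register_parameter(name, new_param)
```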

