rename modules #504

Merged
merged 1 commit on Oct 9, 2023
7 changes: 6 additions & 1 deletion lmdeploy/pytorch_poc/engine/engine.py
@@ -25,8 +25,8 @@
SchedulerConfig)
from lmdeploy.pytorch_poc.messages import (MessageStatus, SamplingParam,
SchedulerMessage, SchedulerSession)
from lmdeploy.pytorch_poc.models import patch
from lmdeploy.pytorch_poc.paging import Scheduler
from lmdeploy.pytorch_poc.patch import patch
from lmdeploy.pytorch_poc.utils import get_gpu_memory
from lmdeploy.utils import get_logger

@@ -1007,6 +1007,11 @@ def __init__(self, engine: Engine):
self.req_count = 0
self.owned_sessions: List[int] = list()

def __del__(self):
"""Destructor."""
for session_id in self.owned_sessions:
self.end(session_id)

def _send_req(self, req_type: RequestType, data: Any):
"""Send request to engine.

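The new `__del__` above ties session lifetime to the wrapper object: any session the instance still owns is ended when the wrapper is garbage-collected. A minimal standalone sketch of that pattern (hypothetical class and method names, not the actual lmdeploy engine API):

```python
from typing import List


class SessionClient:
    """Toy client that releases its sessions when it is destroyed."""

    def __init__(self):
        self.owned_sessions: List[int] = []

    def begin(self, session_id: int):
        self.owned_sessions.append(session_id)

    def end(self, session_id: int):
        # pretend this also notifies the engine
        if session_id in self.owned_sessions:
            self.owned_sessions.remove(session_id)

    def __del__(self):
        # iterate over a copy, since end() removes ids from the list
        for session_id in list(self.owned_sessions):
            self.end(session_id)
```

If `end()` removes the id from `owned_sessions`, iterating over a copy (as in the sketch) avoids skipping entries while the list shrinks.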
8 changes: 4 additions & 4 deletions lmdeploy/pytorch_poc/kernels/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .context_alibi_pagedattention import alibi_paged_attention_fwd
from .context_biased_pagedattention import biased_paged_attention_fwd
from .context_flashattention_nopad import context_attention_fwd
from .context_pagedattention import paged_attention_fwd
from .alibi_pagedattention import alibi_paged_attention_fwd
from .biased_pagedattention import biased_paged_attention_fwd
from .flashattention_nopad import context_attention_fwd
from .pagedattention import paged_attention_fwd

__all__ = [
'context_attention_fwd',
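Because `kernels/__init__.py` re-exports the kernels under their original function names, callers that import through the package are unaffected by the file renames; only imports that referenced the old `context_*` module files directly need updating. For example:

```python
# Importing through the package keeps working after the rename,
# since __init__.py re-exports the same function names.
from lmdeploy.pytorch_poc.kernels import paged_attention_fwd

# Imports that named the old module file directly must be updated:
# old: from lmdeploy.pytorch_poc.kernels.context_pagedattention import paged_attention_fwd
# new: from lmdeploy.pytorch_poc.kernels.pagedattention import paged_attention_fwd
```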
Changes in another file (path not shown):
@@ -9,9 +9,8 @@

from lmdeploy.pytorch_poc.dist_utils import (rowwise_parallelize_linear_fn,
try_to_local)
from lmdeploy.pytorch_poc.patch.functional import \
attention_forward_with_paged_attention

from .functional import attention_forward_with_paged_attention
from .llama import apply_rotary_pos_emb


Changes in another file (path not shown):
@@ -4,12 +4,18 @@
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast

from lmdeploy.pytorch_poc.dist_utils import (rowwise_parallelize_linear_fn,
try_to_local)
from lmdeploy.pytorch_poc.kernels import paged_attention_fwd

from .functional import fill_kv_cache


def split_tensor_along_last_dim(
tensor: torch.Tensor,
@@ -70,6 +76,39 @@ class PatchedSelfAttention(nn.Module):
the same size.
"""

def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['query_key_value']:
sections = [
self.num_attention_heads_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
]
for name, param in mod.named_parameters():
splited_param = param.split(sections, dim=0)
updated_param = []
for p in splited_param:
dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
dist_tensor = try_to_local(dist_tensor)
updated_param.append(dist_tensor)
param = torch.cat(updated_param)
dist_param = torch.nn.Parameter(param)
mod.register_parameter(name, dist_param)
elif mod_name in ['dense']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs

def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
@@ -90,48 +129,49 @@ def _contiguous_batching_forward(
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]

world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
origin_self = self

context = self.context.context
history_lengths = context.history_lengths

mixed_x_layer = origin_self.query_key_value(hidden_states)

if origin_self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
origin_self.num_attention_heads_per_partition *
origin_self.hidden_size_per_attention_head,
origin_self.hidden_size_per_attention_head // world_size,
origin_self.num_multi_query_groups_per_partition *
origin_self.hidden_size_per_attention_head,
origin_self.hidden_size_per_attention_head // world_size,
origin_self.num_multi_query_groups_per_partition *
origin_self.hidden_size_per_attention_head,
origin_self.hidden_size_per_attention_head // world_size,
],
dim=-1,
)
query_layer = query_layer.view(query_layer.size()[:-1] + (
origin_self.num_attention_heads_per_partition,
origin_self.num_attention_heads_per_partition // world_size,
origin_self.hidden_size_per_attention_head,
))
key_layer = key_layer.view(key_layer.size()[:-1] + (
origin_self.num_multi_query_groups_per_partition,
origin_self.num_multi_query_groups_per_partition // world_size,
origin_self.hidden_size_per_attention_head,
))
value_layer = value_layer.view(value_layer.size()[:-1] + (
origin_self.num_multi_query_groups_per_partition,
origin_self.num_multi_query_groups_per_partition // world_size,
origin_self.hidden_size_per_attention_head,
))
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
origin_self.num_attention_heads_per_partition,
origin_self.num_attention_heads_per_partition // world_size,
3 * origin_self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer,
value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
Expand All @@ -151,15 +191,14 @@ def _contiguous_batching_forward(
history_lengths = q_seq_length.new_tensor(history_lengths)
kv_seq_length = q_seq_length + history_lengths
max_seq_len = q_seq_length.max().item()

context.fill_cache(
key_layer[0],
value_layer[0],
q_start_loc,
q_seq_length,
cache_k,
cache_v,
)
fill_kv_cache(key_layer[0],
value_layer[0],
cache_k,
cache_v,
q_start_loc,
q_seq_length,
block_offsets=context.block_offsets,
history_lengths=history_lengths)

if use_cache:
kv_cache = (key_layer, value_layer)
@@ -225,6 +264,31 @@ def forward(
)


class MLP(nn.Module):

@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['dense_h_to_4h']:
for name, param in mod.named_parameters():
dist_tensor = distribute_tensor(param.unflatten(0, (2, -1)),
device_mesh, [Shard(1)])
dist_tensor = try_to_local(dist_tensor)
dist_param = torch.nn.Parameter(dist_tensor.flatten(0, 1))
mod.register_parameter(name, dist_param)
elif mod_name in ['dense_4h_to_h']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)

@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs


class PatchedChatGLMModel(nn.Module):

def _contiguous_batching_forward(
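The new `_distribute_partition_fn` above splits the fused `query_key_value` weight into its query, key, and value sections, shards each section across the tensor-parallel ranks on dim 0, and concatenates the local shards back into one parameter; `_distribute_output_fn` then all-reduces the attention (and MLP) output. A single-process sketch of the partition arithmetic (the sizes below are assumptions for illustration, not read from a model config):

```python
# Standalone sketch of partitioning a fused QKV weight, single process.
import torch

num_heads, num_kv_groups, head_dim, hidden = 32, 2, 128, 4096
world_size, rank = 4, 0  # assumed tensor-parallel layout

sections = [
    num_heads * head_dim,      # query rows
    num_kv_groups * head_dim,  # key rows
    num_kv_groups * head_dim,  # value rows
]
fused_qkv = torch.randn(sum(sections), hidden)

local_parts = []
for part in fused_qkv.split(sections, dim=0):
    # Shard(0): each rank keeps a contiguous slice of this section
    local_parts.append(part.chunk(world_size, dim=0)[rank])
local_qkv = torch.cat(local_parts)  # what each rank would register as its parameter

assert local_qkv.shape[0] == sum(sections) // world_size
```

Sharding query heads and key/value groups section by section is what lets the forward pass above divide `num_attention_heads_per_partition` and `num_multi_query_groups_per_partition` by `world_size`.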
Changes in another file (path not shown):
@@ -78,7 +78,8 @@ def fill_kv_cache(
"""
block_size = k_caches.size(1)

history_lengths = torch.tensor(history_lengths)
if not isinstance(history_lengths, torch.Tensor):
history_lengths = torch.tensor(history_lengths)
first_free_block_offsets = history_lengths // block_size
first_token_offsets = history_lengths % block_size

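The guard above lets `fill_kv_cache` accept `history_lengths` either as a plain Python list or as a tensor that is already on device. The block/offset arithmetic that follows it is easy to check by hand (block size and lengths below are made-up values):

```python
import torch

block_size = 16
history_lengths = [0, 5, 16, 37]  # may arrive as a plain list

if not isinstance(history_lengths, torch.Tensor):
    history_lengths = torch.tensor(history_lengths)

# first cache block with free slots, and the offset inside that block
first_free_block_offsets = history_lengths // block_size  # tensor([0, 0, 1, 2])
first_token_offsets = history_lengths % block_size        # tensor([0, 5, 0, 5])
```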
Changes in another file (path not shown):
@@ -8,8 +8,9 @@

from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from lmdeploy.pytorch_poc.patch.functional import (
apply_rotary_pos_emb, attention_forward_with_paged_attention)

from .functional import (apply_rotary_pos_emb,
attention_forward_with_paged_attention)


class PatchedInternLMAttention(nn.Module):
Changes in another file (path not shown):
@@ -9,8 +9,9 @@

from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from lmdeploy.pytorch_poc.patch.functional import (
apply_rotary_pos_emb, attention_forward_with_paged_attention)

from .functional import (apply_rotary_pos_emb,
attention_forward_with_paged_attention)


class LlamaAttention(nn.Module):
Changes in another file (path not shown):
@@ -13,46 +13,50 @@
from lmdeploy.pytorch_poc.dist_utils import partition_module, replicate_module
from lmdeploy.utils import get_logger

LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch_poc.models'

# llama
MODULE_MAP = {
'transformers.models.llama.modeling_llama.LlamaAttention':
'lmdeploy.pytorch_poc.patch.llama.LlamaAttention',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention',
'transformers.models.llama.modeling_llama.LlamaModel':
'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'transformers.models.llama.modeling_llama.LlamaMLP':
'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
}

# baichuan
MODULE_MAP.update({
'modeling_baichuan.Model':
'lmdeploy.pytorch_poc.patch.llama.LlamaModel', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', # noqa
'modeling_baichuan.BaichuanModel':
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanModel', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel', # noqa
'modeling_baichuan.Attention':
'lmdeploy.pytorch_poc.patch.baichuan.Attention', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.Attention', # noqa
'modeling_baichuan.BaichuanAttention':
'lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanAttention', # noqa
'modeling_baichuan.MLP':
'lmdeploy.pytorch_poc.patch.llama.LlamaMLP', # noqa
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', # noqa
})

# chatglm2
MODULE_MAP.update({
'modeling_chatglm.SelfAttention':
'lmdeploy.pytorch_poc.patch.chatglm2.PatchedSelfAttention',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedSelfAttention',
'modeling_chatglm.ChatGLMModel':
'lmdeploy.pytorch_poc.patch.chatglm2.PatchedChatGLMModel',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMModel',
'modeling_chatglm.MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.MLP',
})

# internlm
MODULE_MAP.update({
'modeling_internlm.InternLMAttention':
'lmdeploy.pytorch_poc.patch.internlm.PatchedInternLMAttention',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.PatchedInternLMAttention',
'modeling_internlm.InternLMModel':
'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',
'modeling_internlm.InternLMMLP':
'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
})
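With the rename, every MODULE_MAP value is built from `LMDEPLOY_PYTORCH_MODEL_PATH` instead of hard-coding `lmdeploy.pytorch_poc.patch`, so a future move of the models package only needs one constant changed. A hedged sketch (not the actual patch loader) of how such a dotted entry can be resolved to the patched class, assuming this branch of lmdeploy is importable:

```python
import importlib

LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch_poc.models'
entry = f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention'

# split the dotted path into module path and class name, then import
module_path, _, class_name = entry.rpartition('.')
patched_cls = getattr(importlib.import_module(module_path), class_name)
```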


@@ -169,6 +173,17 @@ def _update_model(model: torch.nn.Module):
model._update_model_fn()


def _params_to_meta(model: torch.nn.Module):
"""move parameters to meta device."""
# recursive over children
for _, child in model.named_children():
_params_to_meta(child)

for k, v in model.named_parameters(recurse=False):
model.register_parameter(
k, torch.nn.Parameter(v.to('meta'), requires_grad=False))


def _load_state_dict(
model: torch.nn.Module,
state_dict: Dict[str, torch.Tensor] = None,
@@ -232,13 +247,18 @@ def _load_state_dict(
if not in_state_dict:
continue

if rank == 0:
new_param = torch.nn.Parameter(state_dict[full_k],
requires_grad=False).to(device)
else:
new_param = torch.nn.Parameter(torch.empty_like(v, device=device),
requires_grad=False)
model.register_parameter(k, new_param)
param_names = [
name for name, _ in model.named_parameters(recurse=False)
]
if k in param_names:
if rank == 0:
new_param = torch.nn.Parameter(state_dict[full_k].to(v.dtype),
requires_grad=False).to(device)
else:
new_param = torch.nn.Parameter(torch.empty_like(v,
device=device),
requires_grad=False)
model.register_parameter(k, new_param)

# distribute module
if world_size > 1:
@@ -312,6 +332,7 @@ def patch(

# load checkpoint
if checkpoints is not None:
_params_to_meta(model)
device_mesh = DeviceMesh('cuda', list(range(world_size)))
for ckpt in checkpoints:
if rank == 0:
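Beyond the path changes, this file now moves every parameter to the `meta` device before checkpoints are loaded (`_params_to_meta`), and `_load_state_dict` only swaps parameters that are registered directly on the module, casting the checkpoint tensor to the parameter's original dtype. A minimal standalone sketch of that meta-device flow, using a toy `nn.Linear` rather than lmdeploy's patch machinery:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)
orig_dtype = model.weight.dtype

# 1) free the storage: replace each direct parameter with a meta tensor
for name, param in list(model.named_parameters(recurse=False)):
    model.register_parameter(
        name, nn.Parameter(param.to('meta'), requires_grad=False))

# 2) later, materialize from a (here: fake) state-dict entry,
#    cast to the original dtype before registering
state_dict = {'weight': torch.randn(4, 4, dtype=torch.float64)}
new_param = nn.Parameter(state_dict['weight'].to(orig_dtype),
                         requires_grad=False)
model.register_parameter('weight', new_param)

assert model.weight.device.type == 'cpu' and model.weight.dtype == orig_dtype
```

Parameters on `meta` hold no storage, so the patched model skeleton can be built without allocating the full weights on every rank before the real (possibly sharded) checkpoint tensors arrive.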