support tp > n_kv_heads for pt engine (#2872)
* support tp > n_kv_heads for pt models

* fix conflicts and remove replicate_kv

* remove llava config
RunningLeon authored Dec 18, 2024
1 parent 1b219e3 commit 7deb69c
Showing 28 changed files with 297 additions and 223 deletions.
25 changes: 18 additions & 7 deletions lmdeploy/pytorch/config.py
@@ -37,8 +37,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
# change to user specified data type if it is not 'auto'
if dtype == 'auto':
torch_dtype = torch_dtype if torch_dtype in [
torch.float16, torch.bfloat16
] else torch.float16
'float16', 'bfloat16'
] else 'float16'
else:
torch_dtype = dtype
config.dtype = eval(f'torch.{torch_dtype}')
@@ -104,7 +104,6 @@ class ModelConfig:
v_head_dim: int = None
sliding_window: int = -1
dtype: torch.dtype = torch.float16
multi_query_attention: bool = False
vocab_size: int = 40000
hf_config: Any = None
cogvlm_style: bool = False
@@ -118,7 +117,8 @@ def get_head_size(self):
def from_pretrained(cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
dtype: str = 'auto'):
dtype: str = 'auto',
tp: int = 1):
"""Instantiate one of the configuration classes of the library from a
pretrained model configuration.
@@ -138,17 +138,21 @@ def from_pretrained(cls,
pretrained_model_name_or_path)
return cls.from_hf_config(hf_config,
pretrained_model_name_or_path,
dtype=dtype)
dtype=dtype,
tp=tp)

@classmethod
def from_hf_config(cls,
hf_config: Any,
model_path: str = None,
dtype: str = 'auto'):
dtype: str = 'auto',
tp: int = 1):
"""from huggingface config."""
from lmdeploy.pytorch.configurations import AutoModelConfigBuilder

model_config = AutoModelConfigBuilder.build(hf_config, model_path)
model_config = AutoModelConfigBuilder.build(hf_config,
model_path,
tp=tp)

if model_config.k_head_dim is None:
assert model_config.head_dim is not None
@@ -157,6 +161,13 @@ def from_hf_config(cls,
assert model_config.head_dim is not None
model_config.v_head_dim = model_config.head_dim

# check for tp
assert model_config.num_attention_heads % tp == 0
if model_config.num_key_value_heads >= tp:
assert model_config.num_key_value_heads % tp == 0
else:
assert tp % model_config.num_key_value_heads == 0

# should after setting `hf_config` and `model_arch` attributes
model_config = _update_torch_dtype(model_config, dtype)

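The new check in from_hf_config allows the tensor-parallel degree to exceed the number of kv heads as long as the head counts line up. A minimal standalone sketch of the rule (plain Python for illustration, not the lmdeploy API):

def check_tp(num_attention_heads: int, num_key_value_heads: int, tp: int) -> None:
    # attention heads must split evenly across tp ranks
    assert num_attention_heads % tp == 0
    if num_key_value_heads >= tp:
        # enough kv heads: shard them evenly across ranks
        assert num_key_value_heads % tp == 0
    else:
        # tp > n_kv_heads: each kv head is replicated on tp // n_kv_heads ranks
        assert tp % num_key_value_heads == 0

check_tp(num_attention_heads=32, num_key_value_heads=2, tp=8)  # ok: each kv head on 4 ranks
check_tp(num_attention_heads=32, num_key_value_heads=8, tp=4)  # ok: 2 kv heads per rank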
17 changes: 15 additions & 2 deletions lmdeploy/pytorch/configurations/builder.py
@@ -27,7 +27,7 @@ def condition(cls, hf_config):
f'`condition` of {cls.__name__} not implemented.')

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from .default import DefaultModelConfigBuilder

@@ -46,8 +46,21 @@ def build(cls, hf_config, model_path: str = None):

logger.debug(f'build model config with {valid_builder.__name__}')

cfg = valid_builder.build(hf_config, model_path)
cfg = valid_builder.build(hf_config, model_path, **kwargs)
if cfg.hf_config is None:
cfg.hf_config = hf_config

return cfg

@classmethod
def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads):
"""update num kv heads."""
# update num_kv_heads for tp mode
if tp > 1 and tp > num_key_value_heads:
assert tp % num_key_value_heads == 0
n_replicate = tp // num_key_value_heads
hf_config.num_replicate_key_value_heads = n_replicate
num_key_value_heads = tp

hf_config.num_key_value_heads = num_key_value_heads
return num_key_value_heads
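In effect, when tp exceeds the model's kv-head count the builder pads num_key_value_heads up to tp and records how many copies of each real kv head that implies. A self-contained sketch of the same logic (SimpleNamespace stands in for a HF config object; not the lmdeploy class):

from types import SimpleNamespace

def update_num_kv_heads(hf_config, tp, num_key_value_heads):
    if tp > 1 and tp > num_key_value_heads:
        assert tp % num_key_value_heads == 0
        # each real kv head gets n_replicate copies so every rank owns one head
        hf_config.num_replicate_key_value_heads = tp // num_key_value_heads
        num_key_value_heads = tp
    hf_config.num_key_value_heads = num_key_value_heads
    return num_key_value_heads

cfg = SimpleNamespace()
print(update_num_kv_heads(cfg, tp=8, num_key_value_heads=2))  # 8
print(cfg.num_replicate_key_value_heads)                      # 4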
15 changes: 13 additions & 2 deletions lmdeploy/pytorch/configurations/chatglm.py
@@ -12,16 +12,27 @@ def condition(cls, hf_config):
return hf_config.model_type == 'chatglm'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
bos_token_id = hf_config.bos_token_id
if bos_token_id is None:
bos_token_id = hf_config.pad_token_id

if hf_config.multi_query_attention:
num_key_value_heads = hf_config.multi_query_group_num
else:
num_key_value_heads = hf_config.num_attention_heads

tp = kwargs.get('tp', 1)
# update num_kv_heads for tp mode
num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
num_key_value_heads)

cfg = ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_layers,
num_attention_heads=hf_config.num_attention_heads,
num_key_value_heads=hf_config.multi_query_group_num,
num_key_value_heads=num_key_value_heads,
bos_token_id=bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
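Worked example with assumed values (chatglm3-6b style, multi_query_group_num=2): with tp=4 each of the 2 MQA groups ends up replicated on 2 ranks, so the engine-side config reports num_key_value_heads=4 and num_replicate_key_value_heads=2.

multi_query_attention, multi_query_group_num, num_attention_heads, tp = True, 2, 32, 4
num_key_value_heads = (multi_query_group_num
                       if multi_query_attention else num_attention_heads)  # 2
num_replicate_key_value_heads = tp // num_key_value_heads                  # 2
num_key_value_heads = tp                                                   # padded from 2 to 4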
9 changes: 6 additions & 3 deletions lmdeploy/pytorch/configurations/cogvlm.py
@@ -12,12 +12,15 @@ def condition(cls, hf_config):
return model_arch == 'CogVLMForCausalLM'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from lmdeploy.utils import is_bf16_supported
cfg = DefaultModelConfigBuilder.build(hf_config)
if getattr(hf_config, 'num_multi_query_heads', None):
cfg.num_key_value_heads = hf_config.num_multi_query_heads
hf_config.num_key_value_heads = hf_config.num_multi_query_heads
else:
hf_config.num_key_value_heads = hf_config.num_attention_heads

cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
cfg.cogvlm_style = True
torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'
hf_config.torch_dtype = torch_dtype
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/configurations/dbrx.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.model_type == 'dbrx'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
hidden_size = hf_config.d_model
num_heads = hf_config.n_heads
11 changes: 8 additions & 3 deletions lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -12,13 +12,19 @@ def condition(cls, hf_config):
return hf_config.model_type == 'deepseek_v2'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim)
k_head_dim = head_dim
v_head_dim = 0
num_attention_heads = hf_config.num_attention_heads
# multi query attn
num_key_value_heads = 1
tp = kwargs.get('tp', 1)
# update num_kv_heads for tp mode
num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
num_key_value_heads)

return ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
num_attention_heads=num_attention_heads,
@@ -28,5 +34,4 @@ def build(cls, hf_config, model_path: str = None):
head_dim=head_dim,
k_head_dim=k_head_dim,
v_head_dim=v_head_dim,
vocab_size=hf_config.vocab_size,
multi_query_attention=True)
vocab_size=hf_config.vocab_size)
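Worked example with assumed DeepSeek-V2 style values (kv_lora_rank=512, qk_rope_head_dim=64): MLA caches a single compressed latent "kv head" and no separate value tensor, and under tp that one head is now replicated onto every rank instead of being special-cased via multi_query_attention.

kv_lora_rank, qk_rope_head_dim, tp = 512, 64, 8
head_dim = kv_lora_rank + qk_rope_head_dim  # 576, also used as k_head_dim
v_head_dim = 0                              # values are reconstructed from the latent
num_key_value_heads = 1                     # padded to tp=8, i.e. replicated 8x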
7 changes: 6 additions & 1 deletion lmdeploy/pytorch/configurations/default.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return True

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
num_attention_heads = hf_config.num_attention_heads
@@ -23,6 +23,11 @@ def build(cls, hf_config, model_path: str = None):
if use_sliding_window:
sliding_window = getattr(hf_config, 'sliding_window',
sliding_window) or -1
tp = kwargs.get('tp', 1)
# update num_kv_heads for tp mode
num_key_value_heads = cls.update_num_kv_heads(hf_config, tp,
num_key_value_heads)

return ModelConfig(
hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
9 changes: 7 additions & 2 deletions lmdeploy/pytorch/configurations/falcon.py
@@ -12,7 +12,7 @@ def condition(cls, hf_config):
return hf_config.model_type == 'falcon'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build falcon."""
num_attention_heads = hf_config.num_attention_heads
if hf_config.new_decoder_architecture:
@@ -24,6 +24,12 @@ def build(cls, hf_config, model_path: str = None):
else:
# rw-1b, MHA
kv_head = num_attention_heads

tp = kwargs.get('tp', 1)
# update num_kv_heads for tp mode
kv_head = cls.update_num_kv_heads(hf_config, tp, kv_head)
hf_config.num_kv_heads = kv_head

head_dim = hf_config.hidden_size // num_attention_heads
return ModelConfig(
hidden_size=hf_config.hidden_size,
@@ -33,6 +39,5 @@ def build(cls, hf_config, model_path: str = None):
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
multi_query_attention=hf_config.multi_query,
vocab_size=hf_config.vocab_size,
)
16 changes: 5 additions & 11 deletions lmdeploy/pytorch/configurations/gemma.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class GemmaModelConfigBuilder(AutoModelConfigBuilder):
@@ -12,13 +11,8 @@ def condition(cls, hf_config):
return hf_config.model_type in ['gemma', 'gemma2']

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build gemma."""
return ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
num_attention_heads=hf_config.num_attention_heads,
num_key_value_heads=hf_config.num_key_value_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=hf_config.head_dim,
vocab_size=hf_config.vocab_size)
cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
cfg.head_dim = hf_config.head_dim
return cfg
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/configurations/internvl.py
@@ -11,8 +11,9 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'InternVLChatModel'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
cfg = DefaultModelConfigBuilder.build(hf_config.llm_config)
cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path,
**kwargs)
cfg.hf_config = hf_config
return cfg
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/configurations/llava_hf.py
@@ -15,7 +15,7 @@ def condition(cls, hf_config):
]

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
text_config = hf_config.text_config
hidden_size = getattr(text_config, 'hidden_size', 4096)
26 changes: 9 additions & 17 deletions lmdeploy/pytorch/configurations/minicpm3.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder):
@@ -12,21 +12,13 @@ def condition(cls, hf_config):
return hf_config.architectures[0] in ['MiniCPM3ForCausalLM']

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim)
k_head_dim = head_dim
v_head_dim = head_dim
num_attention_heads = hf_config.num_attention_heads
num_key_value_heads = hf_config.num_key_value_heads
return ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
k_head_dim=k_head_dim,
v_head_dim=v_head_dim,
vocab_size=hf_config.vocab_size,
multi_query_attention=False)

cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
cfg.head_dim = head_dim
cfg.k_head_dim = head_dim
cfg.v_head_dim = head_dim

return cfg
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/configurations/mllama.py
@@ -11,8 +11,9 @@ def condition(cls, hf_config):
return hf_config.architectures[0] == 'MllamaForConditionalGeneration'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build llava hf."""
cfg = DefaultModelConfigBuilder.build(hf_config.text_config)
cfg = DefaultModelConfigBuilder.build(hf_config.text_config,
model_path, **kwargs)
cfg.hf_config = hf_config
return cfg
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/configurations/qwen.py
@@ -11,10 +11,10 @@ def condition(cls, hf_config):
return hf_config.model_type == 'qwen'

@classmethod
def build(cls, hf_config, model_path: str = None):
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from lmdeploy.utils import is_bf16_supported
cfg = DefaultModelConfigBuilder.build(hf_config)
cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
if cfg.bos_token_id is None:
cfg.bos_token_id = 151644
if cfg.eos_token_id is None:
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -98,7 +98,7 @@ def _get_key_block_shape_impl(cls,
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
if local and not model_config.multi_query_attention:
if local:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
@@ -121,7 +121,7 @@ def _get_value_block_shape_impl(cls,
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
if local and not model_config.multi_query_attention:
if local:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
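With num_key_value_heads already padded up to tp by the config builders, the cache engine no longer needs the multi_query_attention escape hatch: the per-rank division always comes out even. Illustrative arithmetic only (not the real CacheEngine API):

def per_rank_kv_heads(num_key_value_heads: int, world_size: int) -> int:
    assert num_key_value_heads % world_size == 0, \
        f'num_heads: {num_key_value_heads}, world_size: {world_size}'
    return num_key_value_heads // world_size

print(per_rank_kv_heads(8, 8))  # e.g. 1 kv head padded to tp=8 -> 1 head per rank
print(per_rank_kv_heads(8, 4))  # e.g. 8 kv heads, tp=4 -> 2 heads per rank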
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/model_agent.py
@@ -756,7 +756,7 @@ def build_model_agent(model_path: str,
custom_module_map (str): customized nn module map
"""
model_config = ModelConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code, dtype=dtype)
model_path, trust_remote_code=trust_remote_code, dtype=dtype, tp=tp)
model_config.custom_module_map = custom_module_map
if tp == 1:
model_agent = BaseModelAgent(model_path,
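Hypothetical usage of the extended entry point (model id assumed): passing tp into ModelConfig.from_pretrained lets the kv-head padding happen before the cache engine sizes its blocks.

from lmdeploy.pytorch.config import ModelConfig

# Llama-3-8B style GQA with 8 kv heads; with tp=16 each kv head lands on 2 ranks
model_config = ModelConfig.from_pretrained(
    'meta-llama/Meta-Llama-3-8B-Instruct',  # assumed model id
    trust_remote_code=True,
    dtype='auto',
    tp=16)
print(model_config.num_key_value_heads)  # expected 16 after padding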
9 changes: 4 additions & 5 deletions lmdeploy/pytorch/models/chatglm2.py
@@ -40,11 +40,10 @@ def __init__(self,

self.projection_size = config.kv_channels * config.num_attention_heads
self.num_attention_heads = config.num_attention_heads
self.num_kv_heads = self.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_size = (self.projection_size // config.num_attention_heads)
self.multi_query_attention = config.multi_query_attention
if self.multi_query_attention:
self.num_kv_heads = config.multi_query_group_num
num_replicate_kv_heads = getattr(config,
'num_replicate_key_value_heads', 1)
self.query_key_value = build_qkv_proj(
config.hidden_size,
num_q_heads=self.num_attention_heads,
Expand All @@ -54,7 +53,7 @@ def __init__(self,
quant_config=quantization_config,
dtype=dtype,
device=device,
)
num_replicate_kv_heads=num_replicate_kv_heads)

# apply rotary
self.apply_rotary_pos_emb = ApplyRotaryEmb()
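One way to read num_replicate_key_value_heads on the model side (a sketch of an assumed sharding scheme, not lmdeploy's actual build_qkv_proj internals): with 2 real kv heads replicated 4 times, ranks 0-3 all serve kv head 0 and ranks 4-7 serve kv head 1.

def kv_head_for_rank(rank: int, num_replicate_kv_heads: int) -> int:
    # consecutive ranks share the same original kv head
    return rank // num_replicate_kv_heads

print([kv_head_for_rank(r, 4) for r in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]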