From ede55e5a5ec50dd8fd4a6ce1a78d12e62d3a0db9 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 1 Apr 2024 12:08:42 +0800 Subject: [PATCH] support dbrx in pytorch engine (#1367) * support dbrx * fix match * update zh-cn readme * update supported_models.py * update mode.py * new device map * update eoh --------- Co-authored-by: grimoire --- README.md | 1 + README_zh-CN.md | 1 + docs/en/supported_models/supported_models.md | 1 + .../supported_models/supported_models.md | 1 + lmdeploy/model.py | 57 ++++ lmdeploy/pytorch/config.py | 47 ++- lmdeploy/pytorch/engine/model_agent.py | 13 +- lmdeploy/pytorch/models/dbrx.py | 267 ++++++++++++++++++ lmdeploy/pytorch/models/module_map.py | 12 + lmdeploy/pytorch/models/patch.py | 1 + lmdeploy/pytorch/supported_models.py | 2 + requirements/runtime.txt | 2 +- 12 files changed, 388 insertions(+), 17 deletions(-) create mode 100644 lmdeploy/pytorch/models/dbrx.py diff --git a/README.md b/README.md index b810f0a64..61d1f56c7 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ For detailed inference benchmarks in more devices and more settings, please refe | DeepSeek-MoE | 16B | | Mixtral | 8x7B | | Gemma | 2B-7B | +| Dbrx | 132B | LMDeploy has developed two inference engines - [TurboMind](./docs/en/inference/turbomind.md) and [PyTorch](./docs/en/inference/pytorch.md), each with a different focus. The former strives for ultimate optimization of inference performance, while the latter, developed purely in Python, aims to decrease the barriers for developers. diff --git a/README_zh-CN.md b/README_zh-CN.md index 720b03f59..4ad8700ec 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -105,6 +105,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型 | DeepSeek-MoE | 16B | | Mixtral | 8x7B | | Gemma | 2B-7B | +| Dbrx | 132B | LMDeploy 支持 2 种推理引擎: [TurboMind](./docs/zh_cn/inference/turbomind.md) 和 [PyTorch](./docs/zh_cn/inference/pytorch.md),它们侧重不同。前者追求推理性能的极致优化,后者纯用python开发,着重降低开发者的门槛。 diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 0d61f61c4..dc44335cd 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -34,3 +34,4 @@ | QWen1.5 | 7B - 72B | Yes | No | No | | DeepSeek-MoE | 16B | Yes | No | No | | Gemma | 2B-7B | Yes | No | No | +| Dbrx | 132B | Yes | No | No | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 643cb49b4..6770073bc 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -34,3 +34,4 @@ | QWen1.5 | 7B - 72B | Yes | No | No | | DeepSeek-MoE | 16B | Yes | No | No | | Gemma | 2B-7B | Yes | No | No | +| Dbrx | 132B | Yes | No | No | diff --git a/lmdeploy/model.py b/lmdeploy/model.py index afe31cad0..e729aa5b9 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -942,6 +942,63 @@ def match(cls, model_path: str) -> Optional[str]: return 'yi-vl' +# flake8: noqa: E501 +def dbrx_system_prompt(): + # This is inspired by the Claude3 prompt. + # source: https://twitter.com/AmandaAskell/status/1765207842993434880 + # Identity and knowledge + prompt = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\n' + prompt += 'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\n' + # Capabilities (and reminder to use ``` for JSON blocks and tables, which it can forget). Also a reminder that it can't browse the internet or run code. + prompt += 'You assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n' + prompt += '(You do not have real-time data access or code execution capabilities. ' + # Ethical guidelines + prompt += 'You avoid stereotyping and provide balanced perspectives on controversial topics. ' + # Data: the model doesn't know what it was trained on; it thinks that everything that it is aware of was in its training data. This is a reminder that it wasn't. + # We also encourage it not to try to generate lyrics or poems + prompt += 'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\n' + # The model really wants to talk about its system prompt, to the point where it is annoying, so encourage it not to + prompt += 'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\n' + prompt += 'You do not mention any of this information about yourself unless the information is directly pertinent to the user\\\'s query.'.upper( + ) + return prompt + + +@MODELS.register_module(name=['dbrx']) +class DbrxInstruct(BaseChatTemplate): + + def __init__(self, + system='<|im_start|>system\n', + meta_instruction=dbrx_system_prompt(), + eosys='<|im_end|>\n', + user='<|im_start|>user\n', + eoh='<|im_end|>\n', + assistant='<|im_start|>assistant\n', + eoa='<|im_end|>', + separator='\n', + **kwargs): + super().__init__(system, + meta_instruction=meta_instruction, + eosys=eosys, + user=user, + eoh=eoh, + assistant=assistant, + eoa=eoa, + separator=separator, + **kwargs) + + @classmethod + def match(cls, model_path: str) -> Optional[str]: + """Return the model_name that was registered to MODELS. + + Args: + model_path (str): the model path used for matching. + """ + path = model_path.lower() + if 'dbrx' in path: + return 'dbrx' + + def best_match_model(query: str) -> Optional[str]: """Get the model that matches the query. diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 9d27e0f5c..1a8266c80 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -12,6 +12,18 @@ def _get_torch_dtype(config: Any, default: str = 'float16'): config: Config of the hf model. default (str): default device type. """ + + def __hack_qwen(hf_config: Any): + if hf_config.model_type == 'qwen' and hf_config.torch_dtype is None: + torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported( + ) else 'float16' + if hf_config.bf16: + torch_dtype = 'bfloat16' + elif hf_config.fp16: + torch_dtype = 'float16' + setattr(hf_config, 'torch_dtype', torch_dtype) + + __hack_qwen(config) torch_dtype = getattr(config, 'torch_dtype', default) # torch_dtype in config could be none torch_dtype = torch_dtype or default @@ -135,6 +147,26 @@ def __build_gemma(): head_dim=hf_config.head_dim, vocab_size=hf_config.vocab_size) + def __build_dbrx(): + hidden_size = hf_config.d_model + num_heads = hf_config.n_heads + head_dim = hidden_size // num_heads + eos_token_id = getattr(hf_config, 'eos_token_id', None) + if eos_token_id is None: + eos_token_id = 100257 + bos_token_id = getattr(hf_config, 'bos_token_id', None) + if bos_token_id is None: + bos_token_id = eos_token_id + return ModelConfig( + hidden_size=hidden_size, + num_layers=hf_config.n_layers, + num_attention_heads=num_heads, + num_key_value_heads=hf_config.attn_config.kv_n_heads, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + head_dim=head_dim, + vocab_size=hf_config.vocab_size) + def __build_default(): head_dim = hf_config.hidden_size // hf_config.num_attention_heads num_attention_heads = hf_config.num_attention_heads @@ -156,24 +188,17 @@ def __build_default(): head_dim=head_dim, vocab_size=hf_config.vocab_size) - if 'falcon' in model_path: + if hf_config.model_type == 'falcon': model_config = __build_falcon() - elif 'chatglm' in model_path: + elif hf_config.model_type == 'chatglm': model_config = __build_chatglm() elif hf_config.model_type == 'gemma': model_config = __build_gemma() + elif hf_config.model_type == 'dbrx': + model_config = __build_dbrx() else: model_config = __build_default() - if hf_config.model_type == 'qwen' and hf_config.torch_dtype is None: - torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported( - ) else 'float16' - if hf_config.bf16: - torch_dtype = 'bfloat16' - elif hf_config.fp16: - torch_dtype = 'float16' - setattr(hf_config, 'torch_dtype', torch_dtype) - model_config.dtype = _get_torch_dtype(hf_config) model_config.hf_config = hf_config diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 66ee36a8e..068feff6d 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -654,15 +654,18 @@ def _create_device_map(model: torch.nn.Module, world_size: int, device_map: dict = None): """Distribute params to each devices.""" + free_mems = [get_gpu_memory(gpu_id)[0] for gpu_id in range(world_size)] + free_mems = torch.tensor(free_mems) if device_map is None: device_map = dict() - device_id = 0 - for name, _ in model.named_parameters(): + for name, param in model.named_parameters(): + device_id = free_mems.argmin().item() device_map[name] = device_id - device_id = (device_id + 1) % world_size - for name, _ in model.named_buffers(): + free_mems[device_id] += param.numel() * param.element_size() + for name, param in model.named_buffers(): + device_id = free_mems.argmin().item() device_map[name] = device_id - device_id = (device_id + 1) % world_size + free_mems[device_id] += param.numel() * param.element_size() return device_map diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py new file mode 100644 index 000000000..e3d2b6ab9 --- /dev/null +++ b/lmdeploy/pytorch/models/dbrx.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.utils.checkpoint +from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor +from transformers.cache_utils import Cache +from transformers.modeling_outputs import MoeModelOutputWithPast + +from ..dist_utils import rowwise_parallelize_linear_fn, try_to_local +from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd + + +def _colwise_split_parallelize_linear(mod: nn.Module, sections: List[int], + device_mesh: DeviceMesh): + """split and colwise parallelize.""" + for name, param in mod.named_parameters(): + splited_param = param.split(sections, dim=0) + updated_param = [] + for p in splited_param: + dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)]) + dist_tensor = try_to_local(dist_tensor) + updated_param.append(dist_tensor) + param = torch.cat(updated_param) + dist_param = torch.nn.Parameter(param) + mod.register_parameter(name, dist_param) + + +class PatchedDbrxAttention(nn.Module): + + def _distribute_qkv_linear(self, mod: nn.Module, device_mesh: DeviceMesh): + """distribute qkv linear.""" + sections = [ + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ] + return _colwise_split_parallelize_linear(mod, sections, device_mesh) + + def _distribute_partition_fn(self, mod_name: str, mod: nn.Module, + device_mesh: DeviceMesh): + """Distribution partition callback.""" + if mod_name in ['Wqkv']: + self._distribute_qkv_linear(mod, device_mesh) + elif mod_name in ['out_proj']: + rowwise_parallelize_linear_fn(mod, + device_mesh=device_mesh, + to_local=True) + + @classmethod + def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh): + """Distribution output hook.""" + dist.all_reduce(outputs[0]) + return outputs + + def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + world_size: int = 1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Implement of attention forward.""" + context = self.context.context + q_start_loc = context.q_start_loc + q_seq_length = context.q_seq_length + kv_seq_length = context.kv_seq_length + block_offsets = context.block_offsets + max_q_seq_length = context.max_q_seq_length + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_key_value_heads // world_size + head_dim = self.head_dim + + def __qkv_proj(hidden_states): + """qkv_proj.""" + qkv_states = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv_states = qkv_states.clamp(min=-self.clip_qkv, + max=self.clip_qkv) + + query_states, key_states, value_states = qkv_states.split( + [ + num_heads * head_dim, + num_kv_heads * head_dim, + num_kv_heads * head_dim, + ], + dim=-1, + ) + + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + return query_states, key_states, value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + scaling_factor = 1.0 + inv_freq = self.rotary_emb.inv_freq + query_states, key_states = fused_rotary_emb( + query_states[None], + key_states[None], + context.position_ids_1d[None], + inv_freq=inv_freq, + scaling_factor=scaling_factor, + out_q=query_states[None], + out_k=key_states[None]) + return query_states[0], key_states[0], value_states + + query_states, key_states, value_states = __qkv_proj(hidden_states) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + + fill_kv_cache( + key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + ) + + attn_output = query_states + paged_attention_fwd( + query_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + """forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + return self._contiguous_batching_forward_impl( + hidden_states, + position_ids, + past_key_value, + world_size=world_size, + ) + + +class PatchedDbrxExpertGLU(nn.Module): + + def _distribute_partition_fn(self, mod_name: str, mod: nn.Module, + device_mesh: DeviceMesh): + """Distribution partition callback.""" + + world_size = dist.get_world_size() + + def __partiton_moe(weight: nn.Parameter, name: str): + weight = weight.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size) + weight = distribute_tensor(weight, device_mesh, [Shard(1)]) + weight = try_to_local(weight) + weight = weight.flatten(0, 1) + self.register_parameter(name, nn.Parameter(weight)) + + if getattr(self, '__finish_partition', False): + return + + __partiton_moe(self.w1, 'w1') + __partiton_moe(self.v1, 'v1') + __partiton_moe(self.w2, 'w2') + + self.ffn_hidden_size = self.ffn_hidden_size // world_size + self.__finish_partition = True + + @classmethod + def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh): + """Distribution output hook.""" + dist.all_reduce(outputs) + return outputs + + +class PatchedDbrxModel(nn.Module): + + def _continuous_batching_forward( + self, + input_ids: Optional[torch.LongTensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """forward impl.""" + output_attentions = False + use_cache = True + output_router_logits = False + + inputs_embeds = self.wte(input_ids) + + # Attention mask is not necessary in continuous batching + attention_mask = None + cache_position = None + + hidden_states = inputs_embeds + + for idx, block in enumerate(self.blocks): + past_key_value = past_key_values[idx] + block_outputs = block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = block_outputs[0] + + hidden_states = self.norm_f(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=None, + attentions=None, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """Rewrite of LlamaModel.forward.""" + return self._continuous_batching_forward( + input_ids, + position_ids, + past_key_values, + ) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 11c11ec84..209eac41e 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -198,3 +198,15 @@ 'transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', }) + +# dbrx +MODULE_MAP.update({ + 'modeling_dbrx.DbrxAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxAttention', + 'modeling_dbrx.DbrxFlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxAttention', + 'modeling_dbrx.DbrxModel': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxModel', + 'modeling_dbrx.DbrxExpertGLU': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxExpertGLU' +}) diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 9cb6cde35..ed33f333a 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -183,6 +183,7 @@ def _init_params(): else: new_param = torch.empty_like(param, device=device) model.register_buffer(name, new_param) + torch.cuda.synchronize() def _dist_params(): """dist params.""" diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py index 675e11c19..c3326d047 100644 --- a/lmdeploy/pytorch/supported_models.py +++ b/lmdeploy/pytorch/supported_models.py @@ -38,6 +38,8 @@ QWenLMHeadModel=True, # Qwen1.5 7B-72B Qwen2ForCausalLM=True, + # Dbrx 132B + DbrxForCausalLM=True, ) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index e119fe859..44c1b575b 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -11,6 +11,6 @@ sentencepiece shortuuid tiktoken torch<=2.1.2,>=2.0.0 -transformers>=4.33.0,<=4.38.1 +transformers>=4.33.0,<=4.38.2 triton>=2.1.0,<=2.2.0 uvicorn