support dbrx in pytorch engine (#1367)
* support dbrx

* fix match

* update zh-cn readme

* update supported_models.py

* update model.py

* new device map

* update eoh

---------

Co-authored-by: grimoire <[email protected]>
grimoire and grimoire authored Apr 1, 2024
1 parent 08231ea commit ede55e5
Showing 12 changed files with 388 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -104,6 +104,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
 | DeepSeek-MoE | 16B |
 | Mixtral | 8x7B |
 | Gemma | 2B-7B |
+| Dbrx | 132B |

 LMDeploy has developed two inference engines - [TurboMind](./docs/en/inference/turbomind.md) and [PyTorch](./docs/en/inference/pytorch.md), each with a different focus. The former strives for ultimate optimization of inference performance, while the latter, developed purely in Python, aims to decrease the barriers for developers.
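For context, the newly supported DBRX model is served through the PyTorch engine. A minimal sketch, assuming LMDeploy's `pipeline` entry point and the Hugging Face model id `databricks/dbrx-instruct`; the tensor-parallel degree below is illustrative, not a recommendation:

```python
# Hedged sketch: run DBRX on the PyTorch engine via the high-level pipeline API.
# The model id and tp value are illustrative; pick values that fit your GPUs.
from lmdeploy import PytorchEngineConfig, pipeline

if __name__ == '__main__':
    pipe = pipeline('databricks/dbrx-instruct',
                    backend_config=PytorchEngineConfig(tp=8))
    responses = pipe(['Briefly introduce yourself.'])
    print(responses[0].text)
```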
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -105,6 +105,7 @@ The LMDeploy TurboMind engine has excellent inference capabilities on models of all sizes
 | DeepSeek-MoE | 16B |
 | Mixtral | 8x7B |
 | Gemma | 2B-7B |
+| Dbrx | 132B |

 LMDeploy supports two inference engines: [TurboMind](./docs/zh_cn/inference/turbomind.md) and [PyTorch](./docs/zh_cn/inference/pytorch.md), each with a different focus. The former strives for ultimate optimization of inference performance, while the latter, developed purely in Python, aims to lower the barrier for developers.
1 change: 1 addition & 0 deletions docs/en/supported_models/supported_models.md
@@ -34,3 +34,4 @@
 | QWen1.5 | 7B - 72B | Yes | No | No |
 | DeepSeek-MoE | 16B | Yes | No | No |
 | Gemma | 2B-7B | Yes | No | No |
+| Dbrx | 132B | Yes | No | No |
1 change: 1 addition & 0 deletions docs/zh_cn/supported_models/supported_models.md
@@ -34,3 +34,4 @@
 | QWen1.5 | 7B - 72B | Yes | No | No |
 | DeepSeek-MoE | 16B | Yes | No | No |
 | Gemma | 2B-7B | Yes | No | No |
+| Dbrx | 132B | Yes | No | No |
57 changes: 57 additions & 0 deletions lmdeploy/model.py
@@ -942,6 +942,63 @@ def match(cls, model_path: str) -> Optional[str]:
         return 'yi-vl'


+# flake8: noqa: E501
+def dbrx_system_prompt():
+    # This is inspired by the Claude3 prompt.
+    # source: https://twitter.com/AmandaAskell/status/1765207842993434880
+    # Identity and knowledge
+    prompt = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\n'
+    prompt += 'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\n'
+    # Capabilities (and reminder to use ``` for JSON blocks and tables, which it can forget). Also a reminder that it can't browse the internet or run code.
+    prompt += 'You assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n'
+    prompt += '(You do not have real-time data access or code execution capabilities. '
+    # Ethical guidelines
+    prompt += 'You avoid stereotyping and provide balanced perspectives on controversial topics. '
+    # Data: the model doesn't know what it was trained on; it thinks that everything that it is aware of was in its training data. This is a reminder that it wasn't.
+    # We also encourage it not to try to generate lyrics or poems
+    prompt += 'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\n'
+    # The model really wants to talk about its system prompt, to the point where it is annoying, so encourage it not to
+    prompt += 'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\n'
+    prompt += 'You do not mention any of this information about yourself unless the information is directly pertinent to the user\'s query.'.upper(
+    )
+    return prompt
+
+
+@MODELS.register_module(name=['dbrx'])
+class DbrxInstruct(BaseChatTemplate):
+
+    def __init__(self,
+                 system='<|im_start|>system\n',
+                 meta_instruction=dbrx_system_prompt(),
+                 eosys='<|im_end|>\n',
+                 user='<|im_start|>user\n',
+                 eoh='<|im_end|>\n',
+                 assistant='<|im_start|>assistant\n',
+                 eoa='<|im_end|>',
+                 separator='\n',
+                 **kwargs):
+        super().__init__(system,
+                         meta_instruction=meta_instruction,
+                         eosys=eosys,
+                         user=user,
+                         eoh=eoh,
+                         assistant=assistant,
+                         eoa=eoa,
+                         separator=separator,
+                         **kwargs)
+
+    @classmethod
+    def match(cls, model_path: str) -> Optional[str]:
+        """Return the model_name that was registered to MODELS.
+
+        Args:
+            model_path (str): the model path used for matching.
+        """
+        path = model_path.lower()
+        if 'dbrx' in path:
+            return 'dbrx'
+
+
 def best_match_model(query: str) -> Optional[str]:
     """Get the model that matches the query.
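As an illustration of the template registered above (not part of the diff), the helper below approximates the prompt that `DbrxInstruct` assembles for a single user turn by concatenating the fields from `__init__`; the actual formatting is handled by `BaseChatTemplate`, so treat this as a simplified stand-in:

```python
# Simplified stand-in for the DBRX chat template: concatenate the ChatML-style
# fields from DbrxInstruct.__init__ for one system + user turn.
def build_dbrx_prompt(user_msg: str, meta_instruction: str) -> str:
    system, eosys = '<|im_start|>system\n', '<|im_end|>\n'
    user, eoh = '<|im_start|>user\n', '<|im_end|>\n'
    assistant = '<|im_start|>assistant\n'
    # system block, then the user turn, then an open assistant block to generate into
    return system + meta_instruction + eosys + user + user_msg + eoh + assistant


print(build_dbrx_prompt('Who are you?', 'You are DBRX, created by Databricks.'))
```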
47 changes: 36 additions & 11 deletions lmdeploy/pytorch/config.py
@@ -12,6 +12,18 @@ def _get_torch_dtype(config: Any, default: str = 'float16'):
         config: Config of the hf model.
         default (str): default device type.
     """
+
+    def __hack_qwen(hf_config: Any):
+        if hf_config.model_type == 'qwen' and hf_config.torch_dtype is None:
+            torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported(
+            ) else 'float16'
+            if hf_config.bf16:
+                torch_dtype = 'bfloat16'
+            elif hf_config.fp16:
+                torch_dtype = 'float16'
+            setattr(hf_config, 'torch_dtype', torch_dtype)
+
+    __hack_qwen(config)
     torch_dtype = getattr(config, 'torch_dtype', default)
     # torch_dtype in config could be none
     torch_dtype = torch_dtype or default
@@ -135,6 +147,26 @@ def __build_gemma():
                                head_dim=hf_config.head_dim,
                                vocab_size=hf_config.vocab_size)

+        def __build_dbrx():
+            hidden_size = hf_config.d_model
+            num_heads = hf_config.n_heads
+            head_dim = hidden_size // num_heads
+            eos_token_id = getattr(hf_config, 'eos_token_id', None)
+            if eos_token_id is None:
+                eos_token_id = 100257
+            bos_token_id = getattr(hf_config, 'bos_token_id', None)
+            if bos_token_id is None:
+                bos_token_id = eos_token_id
+            return ModelConfig(
+                hidden_size=hidden_size,
+                num_layers=hf_config.n_layers,
+                num_attention_heads=num_heads,
+                num_key_value_heads=hf_config.attn_config.kv_n_heads,
+                bos_token_id=bos_token_id,
+                eos_token_id=eos_token_id,
+                head_dim=head_dim,
+                vocab_size=hf_config.vocab_size)
+
         def __build_default():
             head_dim = hf_config.hidden_size // hf_config.num_attention_heads
             num_attention_heads = hf_config.num_attention_heads
@@ -156,24 +188,17 @@ def __build_default():
                                head_dim=head_dim,
                                vocab_size=hf_config.vocab_size)

-        if 'falcon' in model_path:
+        if hf_config.model_type == 'falcon':
             model_config = __build_falcon()
-        elif 'chatglm' in model_path:
+        elif hf_config.model_type == 'chatglm':
             model_config = __build_chatglm()
         elif hf_config.model_type == 'gemma':
             model_config = __build_gemma()
+        elif hf_config.model_type == 'dbrx':
+            model_config = __build_dbrx()
         else:
             model_config = __build_default()

-        if hf_config.model_type == 'qwen' and hf_config.torch_dtype is None:
-            torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported(
-            ) else 'float16'
-            if hf_config.bf16:
-                torch_dtype = 'bfloat16'
-            elif hf_config.fp16:
-                torch_dtype = 'float16'
-            setattr(hf_config, 'torch_dtype', torch_dtype)
-
         model_config.dtype = _get_torch_dtype(hf_config)

         model_config.hf_config = hf_config
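`__build_dbrx` above reads the shapes it needs from the DBRX Hugging Face config: `d_model`, `n_heads`, `n_layers`, `vocab_size`, and the grouped-query `attn_config.kv_n_heads`, deriving `head_dim` as `d_model // n_heads`. A stand-alone sketch with placeholder numbers (not the real `databricks/dbrx-instruct` values) shows that derivation:

```python
# Placeholder sketch of the field mapping used by __build_dbrx. The numbers are
# made up for illustration; real values come from the model's config.json.
from types import SimpleNamespace

hf_config = SimpleNamespace(
    d_model=6144,                               # hidden size (placeholder)
    n_heads=48,                                 # query heads (placeholder)
    n_layers=40,                                # decoder layers (placeholder)
    vocab_size=100352,                          # vocabulary size (placeholder)
    attn_config=SimpleNamespace(kv_n_heads=8),  # grouped KV heads (placeholder)
)

head_dim = hf_config.d_model // hf_config.n_heads
print(head_dim, hf_config.attn_config.kv_n_heads)  # 128 8
```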
13 changes: 8 additions & 5 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -654,15 +654,18 @@ def _create_device_map(model: torch.nn.Module,
                        world_size: int,
                        device_map: dict = None):
     """Distribute params to each devices."""
+    free_mems = [get_gpu_memory(gpu_id)[0] for gpu_id in range(world_size)]
+    free_mems = torch.tensor(free_mems)
     if device_map is None:
         device_map = dict()
-    device_id = 0
-    for name, _ in model.named_parameters():
+    for name, param in model.named_parameters():
+        device_id = free_mems.argmin().item()
         device_map[name] = device_id
-        device_id = (device_id + 1) % world_size
-    for name, _ in model.named_buffers():
+        free_mems[device_id] += param.numel() * param.element_size()
+    for name, param in model.named_buffers():
+        device_id = free_mems.argmin().item()
         device_map[name] = device_id
-        device_id = (device_id + 1) % world_size
+        free_mems[device_id] += param.numel() * param.element_size()
     return device_map
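The `_create_device_map` change replaces round-robin placement: per-device counters are seeded from `get_gpu_memory`, each parameter and buffer is assigned to the device whose counter is currently smallest, and that counter then grows by the tensor's byte size, so placements even out as loading proceeds. A self-contained sketch of the greedy rule, with made-up sizes instead of a real model and GPUs:

```python
# Minimal sketch of the greedy placement rule from _create_device_map above.
# Seed counters and tensor sizes are made up; the real code seeds them with
# get_gpu_memory() and walks model.named_parameters()/named_buffers().
import torch


def greedy_device_map(tensor_sizes: dict, seed_counters: list) -> dict:
    counters = torch.tensor(seed_counters, dtype=torch.float64)
    device_map = {}
    for name, nbytes in tensor_sizes.items():
        device_id = counters.argmin().item()  # device with the smallest counter
        device_map[name] = device_id
        counters[device_id] += nbytes         # account for the placed tensor
    return device_map


sizes = {'layer0.weight': 4096, 'layer1.weight': 8192, 'layer2.bias': 1024}
print(greedy_device_map(sizes, seed_counters=[0.0, 2048.0]))
```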
(Diffs for the remaining 5 changed files are not shown.)