feat(model): Support Phi-3 models (#1554)
fangyinc authored May 23, 2024
1 parent 47430f2 commit 7f55aa4
Showing 5 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -158,6 +158,7 @@ At present, we have introduced several key features to showcase our current capa
We offer extensive model support, including dozens of large language models (LLMs), both open-source and API-proxied, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.

- News
- 🔥🔥🔥 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- 🔥🔥🔥 [Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)
- 🔥🔥🔥 [Yi-1.5-9B-Chat](https://huggingface.co/01-ai/Yi-1.5-9B-Chat)
- 🔥🔥🔥 [Yi-1.5-6B-Chat](https://huggingface.co/01-ai/Yi-1.5-6B-Chat)
1 change: 1 addition & 0 deletions README.zh.md
@@ -152,6 +152,7 @@
Extensive model support: dozens of large language models, both open-source and API-proxied, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and more. The following models are currently supported:

- Newly supported models
- 🔥🔥🔥 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- 🔥🔥🔥 [Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)
- 🔥🔥🔥 [Yi-1.5-9B-Chat](https://huggingface.co/01-ai/Yi-1.5-9B-Chat)
- 🔥🔥🔥 [Yi-1.5-6B-Chat](https://huggingface.co/01-ai/Yi-1.5-6B-Chat)
10 changes: 10 additions & 0 deletions dbgpt/configs/model_config.py
@@ -187,6 +187,16 @@ def get_device() -> str:
"gemma-2b-it": os.path.join(MODEL_PATH, "gemma-2b-it"),
"starling-lm-7b-beta": os.path.join(MODEL_PATH, "Starling-LM-7B-beta"),
"deepseek-v2-lite-chat": os.path.join(MODEL_PATH, "DeepSeek-V2-Lite-Chat"),
"sailor-14b-chat": os.path.join(MODEL_PATH, "Sailor-14B-Chat"),
# https://huggingface.co/microsoft/Phi-3-medium-128k-instruct
"phi-3-medium-128k-instruct": os.path.join(
MODEL_PATH, "Phi-3-medium-128k-instruct"
),
"phi-3-medium-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-medium-4k-instruct"),
"phi-3-small-128k-instruct": os.path.join(MODEL_PATH, "Phi-3-small-128k-instruct"),
"phi-3-small-8k-instruct": os.path.join(MODEL_PATH, "Phi-3-small-8k-instruct"),
"phi-3-mini-128k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-128k-instruct"),
"phi-3-mini-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-4k-instruct"),
}

EMBEDDING_MODEL_CONFIG = {
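As a hedged illustration of how these new entries are consumed (the dict name `LLM_MODEL_CONFIG` is an assumption, since the hunk above starts mid-dict, and the weights are assumed to be downloaded already):

```python
# Hedged sketch: resolving a Phi-3 path from the config dict above and
# loading it with Hugging Face transformers. Assumes the weights were
# already placed under MODEL_PATH and that the dict shown above is
# LLM_MODEL_CONFIG (its name sits outside this hunk).
from transformers import AutoModelForCausalLM, AutoTokenizer

from dbgpt.configs.model_config import LLM_MODEL_CONFIG

model_path = LLM_MODEL_CONFIG["phi-3-mini-4k-instruct"]
# trust_remote_code=True mirrors what PhiAdapter.load injects below.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
```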
57 changes: 57 additions & 0 deletions dbgpt/model/adapter/hf_adapter.py
@@ -396,6 +396,61 @@ def load(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class SailorAdapter(QwenAdapter):
"""
https://huggingface.co/sail/Sailor-14B-Chat
"""

def do_match(self, lower_model_name_or_path: Optional[str] = None):
return (
lower_model_name_or_path
and "sailor" in lower_model_name_or_path
and "chat" in lower_model_name_or_path
)


class PhiAdapter(NewHFChatModelAdapter):
"""
https://huggingface.co/microsoft/Phi-3-medium-128k-instruct
"""

support_4bit: bool = True
support_8bit: bool = True
support_system_message: bool = False

def do_match(self, lower_model_name_or_path: Optional[str] = None):
return (
lower_model_name_or_path
and "phi-3" in lower_model_name_or_path
and "instruct" in lower_model_name_or_path
)

def load(self, model_path: str, from_pretrained_kwargs: dict):
if not from_pretrained_kwargs:
from_pretrained_kwargs = {}
if "trust_remote_code" not in from_pretrained_kwargs:
from_pretrained_kwargs["trust_remote_code"] = True
return super().load(model_path, from_pretrained_kwargs)

def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
convert_to_compatible_format: bool = False,
) -> Optional[str]:
str_prompt = super().get_str_prompt(
params,
messages,
tokenizer,
prompt_template,
convert_to_compatible_format,
)
params["custom_stop_words"] = ["<|end|>"]
return str_prompt
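For context, Phi-3 instruct checkpoints terminate each turn with the `<|end|>` token, which is why it is registered as a custom stop word above. A minimal sketch of the rendered prompt, assuming Hub access or a local copy of one of the supported variants:

```python
# Minimal sketch (assumes network access to the Hugging Face Hub or a
# local checkpoint). Phi-3's chat template ends each turn with "<|end|>",
# matching the custom stop word set in get_str_prompt above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True
)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
# Roughly: "<|user|>\nHello<|end|>\n<|assistant|>\n"
```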


# The following code is used to register the model adapter
# The last registered model adapter is matched first
register_model_adapter(YiAdapter)
@@ -408,3 +463,5 @@ def load(self, model_path: str, from_pretrained_kwargs: dict):
register_model_adapter(QwenMoeAdapter)
register_model_adapter(Llama3Adapter)
register_model_adapter(DeepseekV2Adapter)
register_model_adapter(SailorAdapter)
register_model_adapter(PhiAdapter)
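The comment above notes that the last registered adapter is matched first. A self-contained sketch of that precedence rule (illustrative class and function bodies, not DB-GPT's actual internals):

```python
# Self-contained sketch: the registry is scanned in reverse registration
# order, so an adapter registered last (like PhiAdapter) is tried before
# older, more generic adapters. Names here are illustrative.
from typing import List, Optional


class BaseAdapter:
    def do_match(self, lower_model_name_or_path: Optional[str] = None) -> bool:
        return False


class PhiLikeAdapter(BaseAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None) -> bool:
        return bool(
            lower_model_name_or_path
            and "phi-3" in lower_model_name_or_path
            and "instruct" in lower_model_name_or_path
        )


_registry: List[BaseAdapter] = []


def register_model_adapter(adapter_cls: type) -> None:
    _registry.append(adapter_cls())


def get_model_adapter(model_name_or_path: str) -> Optional[BaseAdapter]:
    lower = model_name_or_path.lower()
    # Reverse order: the last registered adapter is matched first.
    for adapter in reversed(_registry):
        if adapter.do_match(lower):
            return adapter
    return None


register_model_adapter(BaseAdapter)     # registered early, generic fallback
register_model_adapter(PhiLikeAdapter)  # registered last, wins on match

assert isinstance(get_model_adapter("Phi-3-mini-4k-instruct"), PhiLikeAdapter)
```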
5 changes: 5 additions & 0 deletions dbgpt/model/llm_out/hf_chat_llm.py
@@ -22,6 +22,7 @@ def huggingface_chat_generate_stream(
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_token_ids = params.get("stop_token_ids", [])
do_sample = params.get("do_sample", None)
custom_stop_words = params.get("custom_stop_words", [])

input_ids = tokenizer(prompt).input_ids
# input_ids = input_ids.to(device)
@@ -62,4 +63,8 @@ def huggingface_chat_generate_stream(
out = ""
for new_text in streamer:
out += new_text
if custom_stop_words:
for stop_word in custom_stop_words:
if out.endswith(stop_word):
out = out[: -len(stop_word)]
yield out
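A standalone sketch of the trimming behavior above, using a toy chunk stream and an illustrative function name:

```python
# Standalone sketch of the loop above: each accumulated partial output is
# yielded with any trailing custom stop word (e.g. Phi-3's "<|end|>")
# stripped, so downstream consumers never see the stop token itself.
from typing import Iterable, Iterator, List


def stream_and_trim(
    chunks: Iterable[str], custom_stop_words: List[str]
) -> Iterator[str]:
    out = ""
    for new_text in chunks:
        out += new_text
        for stop_word in custom_stop_words:
            if out.endswith(stop_word):
                out = out[: -len(stop_word)]
        yield out


for partial in stream_and_trim(["Hello", ", world", "<|end|>"], ["<|end|>"]):
    print(partial)
# Hello
# Hello, world
# Hello, world   <- final chunk's "<|end|>" was trimmed before yielding
```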
