Skip to content

Commit

Permalink
feat(model): Support Yi-34B-Chat
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Nov 24, 2023
1 parent 5075668 commit bd3ab7f
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 22 deletions.
3 changes: 3 additions & 0 deletions pilot/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def get_device() -> str:
"xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
# https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
# https://huggingface.co/01-ai/Yi-34B-Chat
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}

EMBEDDING_MODEL_CONFIG = {
Expand Down
5 changes: 4 additions & 1 deletion pilot/model/cluster/worker/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
params,
self.model_name,
self.model_path,
self.tokenizer,
prompt_template=self.ml.prompt_template,
)
stream_type = ""
Expand All @@ -269,7 +270,9 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
self.model, self.model_path
)
str_prompt = params.get("prompt")
print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
print(
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
)

generate_stream_func_str_name = "{}.{}".format(
generate_stream_func.__module__, generate_stream_func.__name__
Expand Down
54 changes: 54 additions & 0 deletions pilot/model/llm_out/hf_chat_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

logger = logging.getLogger(__name__)


@torch.inference_mode()
def huggingface_chat_generate_stream(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    params,
    device,
    context_len=4096,
):
    """Stream text generated by a Hugging Face model for a ready-made prompt.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator drains the streamer.

    Args:
        model: Loaded model whose ``generate`` method is invoked.
        tokenizer: Tokenizer used to encode the prompt and decode the stream.
        params: Request parameters. Reads ``prompt`` (required),
            ``temperature`` (default 0.7), ``top_p`` (default 1.0),
            ``echo`` (default False) and ``max_new_tokens`` (default 2048).
        device: Device on which the input id tensor is created.
        context_len: Context window used to budget the prompt length.

    Yields:
        The accumulated output text so far (the full text, not the delta).
        The prompt is included only when ``echo`` is true.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 0.7))
    top_p = float(params.get("top_p", 1.0))
    echo = params.get("echo", False)
    max_new_tokens = int(params.get("max_new_tokens", 2048))

    input_ids = tokenizer(prompt).input_ids
    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:
        # Leave room in the context window for the tokens to be generated.
        max_src_len = context_len - max_new_tokens - 1
    # A non-positive budget would turn the slice below into "drop tokens from
    # the front" (or a no-op for 0), so only truncate when the budget is valid.
    if max_src_len > 0:
        input_ids = input_ids[-max_src_len:]
    input_ids = torch.as_tensor([input_ids], device=device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=not echo, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        # Honour the caller's generation limit; the prompt was truncated above
        # so that prompt + new tokens stays inside ``context_len``.
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        # NOTE(review): temperature/top_p only take effect when sampling is
        # enabled (e.g. by the model's generation config) - confirm per model.
        "top_p": top_p,
        "streamer": streamer,
    }

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    out = ""
    for new_text in streamer:
        out += new_text
        yield out
166 changes: 148 additions & 18 deletions pilot/model/model_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
import dataclasses
import logging
import threading
Expand Down Expand Up @@ -41,6 +41,7 @@
thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"


_OLD_MODELS = [
"llama-cpp",
"proxyllm",
Expand All @@ -51,6 +52,14 @@
"codellama-13b",
]

# Lower-case model-name fragments. get_llm_model_adapter() returns a
# NewHFChatModelAdapter for any model whose lower-cased name contains one of
# these fragments.
_NEW_HF_CHAT_MODELS = [
"yi-34b",
"yi-6b",
]

# Some fastchat model adapters interfere with how DB-GPT loads models, so they
# are temporarily blacklisted. Entries are conversation-template names: fastchat
# adapters whose default conversation template has one of these names are
# removed by _remove_black_list_model_of_fastchat().
_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]


class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
Expand Down Expand Up @@ -99,26 +108,25 @@ def get_default_conv_template(
"""Get the default conv template"""
raise NotImplementedError

def model_adaptation(
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
return None

def get_prompt_with_template(
self,
params: Dict,
messages: List[ModelMessage],
model_name: str,
model_path: str,
model_context: Dict,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
):
conv = self.get_default_conv_template(model_name, model_path)
messages = params.get("messages")
# Some model scontext to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}

if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages

if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
Expand All @@ -128,7 +136,7 @@ def model_adaptation(
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context
return None, None, None

conv = conv.copy()
system_messages = []
Expand Down Expand Up @@ -180,6 +188,41 @@ def model_adaptation(
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
return new_prompt, conv.stop_str, conv.stop_token_ids

def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
tokenizer: Any,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
messages = params.get("messages")
# Some model context to pass to the dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages

new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
conv_stop_str, conv_stop_token_ids = None, None
if not new_prompt:
(
new_prompt,
conv_stop_str,
conv_stop_token_ids,
) = self.get_prompt_with_template(
params, messages, model_name, model_path, model_context, prompt_template
)
if not new_prompt:
return params, model_context

# Overwrite the original prompt
# TODO remove bos token and eos token from tokenizer_config.json of model
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
Expand All @@ -192,8 +235,8 @@ def model_adaptation(
custom_stop_token_ids = params.get("stop_token_ids")

# Prefer the value passed in from the input parameter
params["stop"] = custom_stop or conv.stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
params["stop"] = custom_stop or conv_stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids

return params, model_context

Expand Down Expand Up @@ -270,6 +313,69 @@ def __str__(self) -> str:
)


class NewHFChatModelAdapter(LLMModelAdaper):
    """Adapter for chat models loaded directly with Hugging Face transformers.

    The prompt is rendered by the tokenizer's own chat template
    (``tokenizer.apply_chat_template``) and text is generated through
    ``huggingface_chat_generate_stream``.
    """

    # Oldest transformers release accepted by ``load`` (major, minor, patch).
    _MIN_TRANSFORMERS_VERSION = (4, 34, 0)

    @staticmethod
    def _parse_version(version: str) -> Tuple[int, ...]:
        """Return the leading ``major.minor.patch`` of *version* as ints.

        Non-numeric suffixes (``"0rc1"``, ``"dev0"``) are ignored and missing
        components count as 0, so ``"4.34"`` becomes ``(4, 34, 0)``.
        """
        numbers = []
        for piece in version.split(".")[:3]:
            digits = ""
            for char in piece:
                if char not in "0123456789":
                    break
                digits += char
            numbers.append(int(digits) if digits else 0)
        numbers.extend([0] * (3 - len(numbers)))
        return tuple(numbers)

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the model and tokenizer stored at *model_path*.

        Args:
            model_path: Path or identifier handed to ``from_pretrained``.
            from_pretrained_kwargs: Extra keyword arguments forwarded to the
                model's ``from_pretrained``; ``revision`` (default ``"main"``)
                is also used for the tokenizer.

        Returns:
            A ``(model, tokenizer)`` tuple.

        Raises:
            ValueError: If transformers is not installed or is older than
                4.34.0.
        """
        try:
            import transformers
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
        except ImportError as exc:
            raise ValueError(
                "Could not import depend python package "
                "Please install it with `pip install transformers`."
            ) from exc
        # Compare numerically: comparing the raw strings is lexicographic, so
        # "4.4.0" would pass and "4.100.0" would be rejected.
        installed = self._parse_version(transformers.__version__)
        if installed < self._MIN_TRANSFORMERS_VERSION:
            raise ValueError(
                "Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
            )
        revision = from_pretrained_kwargs.get("revision", "main")
        try:
            # NOTE(review): if ``use_fast_tokenizer`` is a method on the base
            # class it should probably be called here - confirm.
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                use_fast=self.use_fast_tokenizer,
                revision=revision,
                trust_remote_code=True,
            )
        except TypeError:
            # Retry with the slow tokenizer when the first attempt is rejected.
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, use_fast=False, revision=revision, trust_remote_code=True
            )
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
            )
        except NameError:
            # Fall back to the generic auto class.
            model = AutoModel.from_pretrained(
                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
            )
        # tokenizer.use_default_system_prompt = False
        return model, tokenizer

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        from pilot.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream

        return huggingface_chat_generate_stream

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Optional[str]:
        """Render *messages* into a prompt string with the tokenizer's chat template.

        Args:
            params: Request parameters (not read here).
            messages: Conversation to render.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            prompt_template: Unused by this adapter.

        Returns:
            The prompt text, ending with the generation prompt for the
            assistant turn.

        Raises:
            ValueError: If *tokenizer* is missing.
        """
        if not tokenizer:
            raise ValueError("tokenizer is is None")

        openai_messages = ModelMessage.to_openai_messages(messages)
        return tokenizer.apply_chat_template(
            openai_messages, tokenize=False, add_generation_prompt=True
        )


def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
from fastchat.conversation import get_conv_template
Expand Down Expand Up @@ -298,6 +404,11 @@ def get_llm_model_adapter(
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
return VLLMModelAdaperWrapper()

use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
if use_new_hf_chat_models:
logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
return NewHFChatModelAdapter()

must_use_old = any(m in model_name for m in _OLD_MODELS)
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
Expand Down Expand Up @@ -334,6 +445,7 @@ def _get_fastchat_model_adapter(
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
_remove_black_list_model_of_fastchat()
if caller:
return caller(model_path)
finally:
Expand Down Expand Up @@ -377,6 +489,24 @@ def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
)


@cache
def _remove_black_list_model_of_fastchat():
    """Drop blacklisted adapters from fastchat's ``model_adapters`` registry.

    An adapter is blacklisted when the name of its default conversation
    template is listed in ``_BLACK_LIST_MODLE_PROMPT``. Adapters that fail
    while being probed are left untouched.
    """
    from fastchat.model.model_adapter import model_adapters

    def _is_black_listed(candidate) -> bool:
        # Probe with a dummy path; any failure means "not blacklisted".
        try:
            template = candidate.get_default_conv_template(
                "/data/not_exist_model_path"
            )
            return template.name in _BLACK_LIST_MODLE_PROMPT
        except Exception:
            return False

    # Collect first, then remove, so the registry is not mutated mid-iteration.
    doomed = [candidate for candidate in model_adapters if _is_black_listed(candidate)]
    for candidate in doomed:
        model_adapters.remove(candidate)


def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import (
Expand Down
28 changes: 28 additions & 0 deletions pilot/scene/base_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,34 @@ def from_openai_messages(
raise ValueError(f"Unknown role: {msg_role}")
return result

@staticmethod
def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
    """Convert messages to the OpenAI chat format.

    The same format is used by Hugging Face
    [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating).

    Human, system and AI messages map to the ``user``, ``system`` and
    ``assistant`` roles; messages with any other role are dropped. The last
    user message is then moved to the end of the list.

    Args:
        messages: Conversation messages, in order.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    history = []
    # Add history conversation
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass
    # Move the last user's information to the end.
    # Locate it by position: list.remove() matches by equality and would take
    # out an earlier user turn that happens to have the same content.
    last_user_index = None
    for index in range(len(history) - 1, -1, -1):
        if history[index]["role"] == "user":
            last_user_index = index
            break
    if last_user_index is not None:
        history.append(history.pop(last_user_index))
    return history

@staticmethod
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages))
Expand Down
Loading

0 comments on commit bd3ab7f

Please sign in to comment.