From a92f34081c6144294a2920b8450b733cd4611745 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Fri, 24 Nov 2023 20:05:09 +0800 Subject: [PATCH] feat(model): Support Yi-34B-Chat (#837) --- README.md | 1 + README.zh.md | 1 + pilot/configs/model_config.py | 3 + pilot/model/cluster/worker/default_worker.py | 5 +- pilot/model/llm_out/hf_chat_llm.py | 54 ++++++ pilot/model/model_adapter.py | 166 +++++++++++++++++-- pilot/scene/base_message.py | 28 ++++ setup.py | 17 +- 8 files changed, 253 insertions(+), 22 deletions(-) create mode 100644 pilot/model/llm_out/hf_chat_llm.py diff --git a/README.md b/README.md index fb8180ed0..524541209 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,7 @@ At present, we have introduced several key features to showcase our current capa - [openchat_3.5](https://huggingface.co/openchat/openchat_3.5) - [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) - [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) + - [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat) - Support API Proxy LLMs - [x] [ChatGPT](https://api.openai.com/) diff --git a/README.zh.md b/README.zh.md index 6f85ac605..00ead0cc8 100644 --- a/README.zh.md +++ b/README.zh.md @@ -133,6 +133,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模 - [openchat_3.5](https://huggingface.co/openchat/openchat_3.5) - [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) - [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) + - [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat) - 支持在线代理模型 - [x] [ChatGPT](https://api.openai.com/) diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 51c1bf2d9..890d380f9 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -126,6 +126,9 @@ def get_device() -> str: "xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"), # https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1 "xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"), + # https://huggingface.co/01-ai/Yi-34B-Chat + "yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"), + "yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"), } EMBEDDING_MODEL_CONFIG = { diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py index 8cc447fce..89d064bcf 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -253,6 +253,7 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str): params, self.model_name, self.model_path, + self.tokenizer, prompt_template=self.ml.prompt_template, ) stream_type = "" @@ -269,7 +270,9 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str): self.model, self.model_path ) str_prompt = params.get("prompt") - print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n") + print( + f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n" + ) generate_stream_func_str_name = "{}.{}".format( generate_stream_func.__module__, generate_stream_func.__name__ diff --git a/pilot/model/llm_out/hf_chat_llm.py b/pilot/model/llm_out/hf_chat_llm.py new file mode 100644 index 000000000..570b20b09 --- /dev/null +++ b/pilot/model/llm_out/hf_chat_llm.py @@ -0,0 +1,54 @@ +import logging +import torch +from threading import Thread +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType + +logger = logging.getLogger(__name__) + + +@torch.inference_mode() +def huggingface_chat_generate_stream( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + params, + device, + context_len=4096, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 0.7)) + top_p = float(params.get("top_p", 1.0)) + echo = params.get("echo", False) + max_new_tokens = int(params.get("max_new_tokens", 2048)) + + input_ids = tokenizer(prompt).input_ids + # input_ids = input_ids.to(device) + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + input_ids = input_ids[-max_src_len:] + input_echo_len = len(input_ids) + input_ids = torch.as_tensor([input_ids], device=device) + + # messages = params["messages"] + # messages = ModelMessage.to_openai_messages(messages) + # input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt') + # input_ids = input_ids.to(device) + + streamer = TextIteratorStreamer( + tokenizer, skip_prompt=not echo, skip_special_tokens=True + ) + generate_kwargs = { + "input_ids": input_ids, + "max_length": context_len, + "temperature": temperature, + "streamer": streamer, + } + + thread = Thread(target=model.generate, kwargs=generate_kwargs) + thread.start() + out = "" + for new_text in streamer: + out += new_text + yield out diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py index 3b809e669..adf6321f0 100644 --- a/pilot/model/model_adapter.py +++ b/pilot/model/model_adapter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING +from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional import dataclasses import logging import threading @@ -41,6 +41,7 @@ thread_local = threading.local() _IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true" + _OLD_MODELS = [ "llama-cpp", "proxyllm", @@ -51,6 +52,14 @@ "codellama-13b", ] +_NEW_HF_CHAT_MODELS = [ + "yi-34b", + "yi-6b", +] + +# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist. +_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"] + class LLMModelAdaper: """New Adapter for DB-GPT LLM models""" @@ -99,26 +108,25 @@ def get_default_conv_template( """Get the default conv template""" raise NotImplementedError - def model_adaptation( + def get_str_prompt( + self, + params: Dict, + messages: List[ModelMessage], + tokenizer: Any, + prompt_template: str = None, + ) -> Optional[str]: + return None + + def get_prompt_with_template( self, params: Dict, + messages: List[ModelMessage], model_name: str, model_path: str, + model_context: Dict, prompt_template: str = None, - ) -> Tuple[Dict, Dict]: - """Params adaptation""" + ): conv = self.get_default_conv_template(model_name, model_path) - messages = params.get("messages") - # Some model scontext to dbgpt server - model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} - - if messages: - # Dict message to ModelMessage - messages = [ - m if isinstance(m, ModelMessage) else ModelMessage(**m) - for m in messages - ] - params["messages"] = messages if prompt_template: logger.info(f"Use prompt template {prompt_template} from config") @@ -128,7 +136,7 @@ def model_adaptation( logger.info( f"No conv from model_path {model_path} or no messages in params, {self}" ) - return params, model_context + return None, None, None conv = conv.copy() system_messages = [] @@ -180,6 +188,41 @@ def model_adaptation( # Add a blank message for the assistant. conv.append_message(conv.roles[1], None) new_prompt = conv.get_prompt() + return new_prompt, conv.stop_str, conv.stop_token_ids + + def model_adaptation( + self, + params: Dict, + model_name: str, + model_path: str, + tokenizer: Any, + prompt_template: str = None, + ) -> Tuple[Dict, Dict]: + """Params adaptation""" + messages = params.get("messages") + # Some model scontext to dbgpt server + model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} + if messages: + # Dict message to ModelMessage + messages = [ + m if isinstance(m, ModelMessage) else ModelMessage(**m) + for m in messages + ] + params["messages"] = messages + + new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template) + conv_stop_str, conv_stop_token_ids = None, None + if not new_prompt: + ( + new_prompt, + conv_stop_str, + conv_stop_token_ids, + ) = self.get_prompt_with_template( + params, messages, model_name, model_path, model_context, prompt_template + ) + if not new_prompt: + return params, model_context + # Overwrite the original prompt # TODO remote bos token and eos token from tokenizer_config.json of model prompt_echo_len_char = len(new_prompt.replace("", "").replace("", "")) @@ -192,8 +235,8 @@ def model_adaptation( custom_stop_token_ids = params.get("stop_token_ids") # Prefer the value passed in from the input parameter - params["stop"] = custom_stop or conv.stop_str - params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids + params["stop"] = custom_stop or conv_stop_str + params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids return params, model_context @@ -270,6 +313,69 @@ def __str__(self) -> str: ) +class NewHFChatModelAdapter(LLMModelAdaper): + def load(self, model_path: str, from_pretrained_kwargs: dict): + try: + import transformers + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel + except ImportError as exc: + raise ValueError( + "Could not import depend python package " + "Please install it with `pip install transformers`." + ) from exc + if not transformers.__version__ >= "4.34.0": + raise ValueError( + "Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0" + ) + revision = from_pretrained_kwargs.get("revision", "main") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + revision=revision, + trust_remote_code=True, + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision, trust_remote_code=True + ) + try: + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + except NameError: + model = AutoModel.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + # tokenizer.use_default_system_prompt = False + return model, tokenizer + + def get_generate_stream_function(self, model, model_path: str): + """Get the generate stream function of the model""" + from pilot.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream + + return huggingface_chat_generate_stream + + def get_str_prompt( + self, + params: Dict, + messages: List[ModelMessage], + tokenizer: Any, + prompt_template: str = None, + ) -> Optional[str]: + from transformers import AutoTokenizer + + if not tokenizer: + raise ValueError("tokenizer is is None") + tokenizer: AutoTokenizer = tokenizer + + messages = ModelMessage.to_openai_messages(messages) + str_prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return str_prompt + + def get_conv_template(name: str) -> "Conversation": """Get a conversation template.""" from fastchat.conversation import get_conv_template @@ -298,6 +404,11 @@ def get_llm_model_adapter( logger.info("Current model type is vllm, return VLLMModelAdaperWrapper") return VLLMModelAdaperWrapper() + use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS) + if use_new_hf_chat_models: + logger.info(f"Current model {model_name} use NewHFChatModelAdapter") + return NewHFChatModelAdapter() + must_use_old = any(m in model_name for m in _OLD_MODELS) if use_fastchat and not must_use_old: logger.info("Use fastcat adapter") @@ -334,6 +445,7 @@ def _get_fastchat_model_adapter( if use_fastchat_monkey_patch: model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch thread_local.model_name = model_name + _remove_black_list_model_of_fastchat() if caller: return caller(model_path) finally: @@ -377,6 +489,24 @@ def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None): ) +@cache +def _remove_black_list_model_of_fastchat(): + from fastchat.model.model_adapter import model_adapters + + black_list_models = [] + for adapter in model_adapters: + try: + if ( + adapter.get_default_conv_template("/data/not_exist_model_path").name + in _BLACK_LIST_MODLE_PROMPT + ): + black_list_models.append(adapter) + except Exception: + pass + for adapter in black_list_models: + model_adapters.remove(adapter) + + def _dynamic_model_parser() -> Callable[[None], List[Type]]: from pilot.utils.parameter_utils import _SimpleArgParser from pilot.model.parameter import ( diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py index bca03acf1..c4c10459c 100644 --- a/pilot/scene/base_message.py +++ b/pilot/scene/base_message.py @@ -113,6 +113,34 @@ def from_openai_messages( raise ValueError(f"Unknown role: {msg_role}") return result + @staticmethod + def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]: + """Convert to OpenAI message format and + hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating) + """ + history = [] + # Add history conversation + for message in messages: + if message.role == ModelMessageRoleType.HUMAN: + history.append({"role": "user", "content": message.content}) + elif message.role == ModelMessageRoleType.SYSTEM: + history.append({"role": "system", "content": message.content}) + elif message.role == ModelMessageRoleType.AI: + history.append({"role": "assistant", "content": message.content}) + else: + pass + # Move the last user's information to the end + temp_his = history[::-1] + last_user_input = None + for m in temp_his: + if m["role"] == "user": + last_user_input = m + break + if last_user_input: + history.remove(last_user_input) + history.append(last_user_input) + return history + @staticmethod def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]: return list(map(lambda m: m.dict(), messages)) diff --git a/setup.py b/setup.py index 6de96aa27..ca0bd818f 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,10 @@ LLAMA_CPP_GPU_ACCELERATION = ( os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true" ) +BUILD_FROM_SOURCE = os.getenv("BUILD_FROM_SOURCE", "false").lower() == "true" +BUILD_FROM_SOURCE_URL_FAST_CHAT = os.getenv( + "BUILD_FROM_SOURCE_URL_FAST_CHAT", "git+https://github.com/lm-sys/FastChat.git" +) def parse_requirements(file_name: str) -> List[str]: @@ -298,7 +302,6 @@ def core_requires(): ] setup_spec.extras["framework"] = [ - "fschat", "coloredlogs", "httpx", "sqlparse==0.4.4", @@ -315,7 +318,8 @@ def core_requires(): "duckdb-engine", "jsonschema", # TODO move transformers to default - "transformers>=4.31.0", + # "transformers>=4.31.0", + "transformers>=4.34.0", "alembic==1.12.0", # for excel "openpyxl==3.1.2", @@ -324,6 +328,12 @@ def core_requires(): # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit. "pympler", ] + if BUILD_FROM_SOURCE: + setup_spec.extras["framework"].append( + f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}" + ) + else: + setup_spec.extras["framework"].append("fschat") def knowledge_requires(): @@ -426,7 +436,8 @@ def default_requires(): pip install "db-gpt[default]" """ setup_spec.extras["default"] = [ - "tokenizers==0.13.3", + # "tokenizers==0.13.3", + "tokenizers>=0.14", "accelerate>=0.20.3", "sentence-transformers", "protobuf==3.20.3",