From f9d9f8faea86f545d547f92248ade18ae44d5fcb Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 23 Feb 2024 12:37:13 +0800 Subject: [PATCH] feat(model): Support gemma model (#1187) --- README.md | 2 ++ README.zh.md | 2 ++ dbgpt/configs/model_config.py | 4 +++ dbgpt/model/adapter/hf_adapter.py | 51 +++++++++++++++++++++++++++++-- 4 files changed, 56 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 39ef6e25e..8caa5556d 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,8 @@ At present, we have introduced several key features to showcase our current capa We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more. - News + - 🔥🔥🔥 [gemma-7b-it](https://huggingface.co/google/gemma-7b-it) + - 🔥🔥🔥 [gemma-2b-it](https://huggingface.co/google/gemma-2b-it) - 🔥🔥🔥 [SOLAR-10.7B](https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0) - 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) - 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) diff --git a/README.zh.md b/README.zh.md index 18b8985e5..1a662e4e4 100644 --- a/README.zh.md +++ b/README.zh.md @@ -139,6 +139,8 @@ DB-GPT是一个开源的AI原生数据应用开发框架(AI Native Data App Deve 海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。当前已支持如下模型: - 新增支持模型 + - 🔥🔥🔥 [gemma-7b-it](https://huggingface.co/google/gemma-7b-it) + - 🔥🔥🔥 [gemma-2b-it](https://huggingface.co/google/gemma-2b-it) - 🔥🔥🔥 [SOLAR-10.7B](https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0) - 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) - 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 187481553..c6d475cb8 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -148,6 +148,10 @@ def get_device() -> str: # https://huggingface.co/01-ai/Yi-34B-Chat-4bits "yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"), "yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"), + # https://huggingface.co/google/gemma-7b-it + "gemma-7b-it": os.path.join(MODEL_PATH, "gemma-7b-it"), + # https://huggingface.co/google/gemma-2b-it + "gemma-2b-it": os.path.join(MODEL_PATH, "gemma-2b-it"), } EMBEDDING_MODEL_CONFIG = { diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py index 7e3faabcb..05429af3b 100644 --- a/dbgpt/model/adapter/hf_adapter.py +++ b/dbgpt/model/adapter/hf_adapter.py @@ -39,19 +39,38 @@ def match( def do_match(self, lower_model_name_or_path: Optional[str] = None): raise NotImplementedError() - def load(self, model_path: str, from_pretrained_kwargs: dict): + def check_dependencies(self) -> None: + """Check if the dependencies are installed + + Raises: + ValueError: If the dependencies are not installed + """ try: import transformers - from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer except ImportError as exc: raise ValueError( "Could not import depend python package " "Please install it with `pip install transformers`." ) from exc - if not transformers.__version__ >= "4.34.0": + self.check_transformer_version(transformers.__version__) + + def check_transformer_version(self, current_version: str) -> None: + if not current_version >= "4.34.0": raise ValueError( "Current model (Load by NewHFChatModelAdapter) require transformers.__version__>=4.34.0" ) + + def load(self, model_path: str, from_pretrained_kwargs: dict): + try: + import transformers + from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + except ImportError as exc: + raise ValueError( + "Could not import depend python package " + "Please install it with `pip install transformers`." + ) from exc + self.check_dependencies() + revision = from_pretrained_kwargs.get("revision", "main") try: tokenizer = AutoTokenizer.from_pretrained( @@ -149,6 +168,32 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None): ) +class GemmaAdapter(NewHFChatModelAdapter): + """ + https://huggingface.co/google/gemma-7b-it + + TODO: There are problems with quantization. + """ + + support_4bit: bool = False + support_8bit: bool = False + support_system_message: bool = False + + def check_transformer_version(self, current_version: str) -> None: + if not current_version >= "4.38.0": + raise ValueError( + "Gemma require transformers.__version__>=4.38.0, please upgrade your transformers package." + ) + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return ( + lower_model_name_or_path + and "gemma-" in lower_model_name_or_path + and "it" in lower_model_name_or_path + ) + + register_model_adapter(YiAdapter) register_model_adapter(Mixtral8x7BAdapter) register_model_adapter(SOLARAdapter) +register_model_adapter(GemmaAdapter)