diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e1575ea03..0e1fb3d40 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -78,6 +78,10 @@ def get_device() -> str:
     "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
     "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
     "internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+    "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
+    "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
+    "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
+    "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
     # For test now
     "opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
 }
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 69b159a13..5ce5b2173 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -320,6 +320,19 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
+class CodeLlamaAdapter(BaseLLMAdaper):
+    """The model adapter for CodeLlama"""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+
 class BaichuanAdapter(BaseLLMAdaper):
     """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
 
@@ -420,6 +433,7 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
 register_llm_model_adapters(GorillaAdapter)
 register_llm_model_adapters(GPT4AllAdapter)
 register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
 register_llm_model_adapters(BaichuanAdapter)
 register_llm_model_adapters(WizardLMAdapter)
 register_llm_model_adapters(LlamaCppAdapater)
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index b3674e946..5d4309d9f 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -339,6 +339,27 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+
+# codellama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference 2: https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+    Conversation(
+        name="codellama",
+        system="[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me. Below is an instruction that describes a task. Write a response that appropriately completes the request."
+ "If you don't know the answer to the request, please don't share false information.\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + system_formatter=lambda msg: f"[INST] <>\n{msg}\n<>\n\n", + ) +) + + # Alpaca default template register_conv_template( Conversation( diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py index 1580e8863..e09b868e7 100644 --- a/pilot/model/model_adapter.py +++ b/pilot/model/model_adapter.py @@ -45,6 +45,10 @@ "llama-cpp", "proxyllm", "gptj-6b", + "codellama-13b-sql-sft", + "codellama-7b", + "codellama-7b-sql-sft", + "codellama-13b", ] @@ -148,8 +152,12 @@ def model_adaptation( conv.append_message(conv.roles[1], content) else: raise ValueError(f"Unknown role: {role}") + if system_messages: - conv.set_system_message("".join(system_messages)) + if isinstance(conv, Conversation): + conv.set_system_message("".join(system_messages)) + else: + conv.update_system_message("".join(system_messages)) # Add a blank message for the assistant. conv.append_message(conv.roles[1], None) diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index cb486021b..64b72739b 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -215,6 +215,16 @@ def get_conv_template(self, model_path: str) -> Conversation: return get_conv_template("llama-2") +class CodeLlamaChatAdapter(BaseChatAdpter): + """The model ChatAdapter for codellama .""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def get_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("codellama") + + class BaichuanChatAdapter(BaseChatAdpter): def match(self, model_path: str): return "baichuan" in model_path.lower() @@ -268,6 +278,7 @@ def get_conv_template(self, model_path: str) -> Conversation: register_llm_model_chat_adapter(GorillaChatAdapter) register_llm_model_chat_adapter(GPT4AllChatAdapter) register_llm_model_chat_adapter(Llama2ChatAdapter) +register_llm_model_chat_adapter(CodeLlamaChatAdapter) register_llm_model_chat_adapter(BaichuanChatAdapter) register_llm_model_chat_adapter(WizardLMChatAdapter) register_llm_model_chat_adapter(LlamaCppChatAdapter)