Skip to content

Commit

Permalink
feat(model) Support the DB-GPT-Hub trained mode (#760)
Browse files Browse the repository at this point in the history
Close #745
Support models trained by DB-GPT-Hub (CodeLlama series).
  • Loading branch information
fangyinc authored Oct 31, 2023
2 parents f19b600 + 3233e26 commit 852cf67
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pilot/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def get_device() -> str:
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
"codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
"codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
"codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
Expand Down
14 changes: 14 additions & 0 deletions pilot/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,19 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class CodeLlamaAdapter(BaseLLMAdaper):
    """LLM adapter for the CodeLlama family (incl. DB-GPT-Hub SQL fine-tunes)."""

    def match(self, model_path: str):
        # Case-insensitive substring check, e.g. "CodeLlama-7b-Instruct-hf"
        # or "codellama-13b-sql-sft".
        lowered = model_path.lower()
        return "codellama" in lowered

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # Delegate loading to the base adapter, then mirror the tokenizer's
        # special-token ids onto the model config so generation stops
        # and pads consistently with the tokenizer.
        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
        config = model.config
        config.eos_token_id = tokenizer.eos_token_id
        config.pad_token_id = tokenizer.pad_token_id
        return model, tokenizer


class BaichuanAdapter(BaseLLMAdaper):
"""The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""

Expand Down Expand Up @@ -420,6 +433,7 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
# Register the available LLM adapters.
# NOTE(review): adapters appear to be selected by substring `match()` checks;
# if selection follows registration order, "codellama" paths also contain
# "llama" — confirm Llama2Adapter's match does not shadow CodeLlamaAdapter.
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(CodeLlamaAdapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
Expand Down
21 changes: 21 additions & 0 deletions pilot/model/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,27 @@ def get_conv_template(name: str) -> Conversation:
)
)


# CodeLlama conversation template.
# Prompt layout follows Meta's llama reference implementation:
# https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
# and the DB-GPT-Hub README:
# https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
# NOTE(review): the system prompt text below is kept byte-for-byte as trained
# (including the "me.Below" / "request.If" run-ons) — do not "fix" it without
# re-validating model output.
_CODELLAMA_SYSTEM = (
    "<s>[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request."
    "If you don't know the answer to the request, please don't share false information.\n<</SYS>>\n\n"
)
register_conv_template(
    Conversation(
        name="codellama",
        system=_CODELLAMA_SYSTEM,
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],  # presumably the </s> token id for llama tokenizers — confirm
        system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
    )
)


# Alpaca default template
register_conv_template(
Conversation(
Expand Down
10 changes: 9 additions & 1 deletion pilot/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
"llama-cpp",
"proxyllm",
"gptj-6b",
"codellama-13b-sql-sft",
"codellama-7b",
"codellama-7b-sql-sft",
"codellama-13b",
]


Expand Down Expand Up @@ -148,8 +152,12 @@ def model_adaptation(
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")

if system_messages:
conv.set_system_message("".join(system_messages))
if isinstance(conv, Conversation):
conv.set_system_message("".join(system_messages))
else:
conv.update_system_message("".join(system_messages))

# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
Expand Down
11 changes: 11 additions & 0 deletions pilot/server/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")


class CodeLlamaChatAdapter(BaseChatAdpter):
    """Chat adapter that routes CodeLlama models to the "codellama" template."""

    def match(self, model_path: str):
        # Case-insensitive match so hub names like "CodeLlama-13b-Instruct-hf"
        # are recognized as well.
        return "codellama" in model_path.lower()

    def get_conv_template(self, model_path: str) -> Conversation:
        # Every CodeLlama variant shares the one registered template.
        return get_conv_template("codellama")


class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "baichuan" in model_path.lower()
Expand Down Expand Up @@ -268,6 +278,7 @@ def get_conv_template(self, model_path: str) -> Conversation:
# Register the available chat adapters.
# NOTE(review): keep this list consistent with the LLM-adapter registrations in
# pilot/model/adapter.py; selection appears to rely on substring `match()`
# checks, and "codellama" paths also contain "llama" — confirm
# Llama2ChatAdapter does not shadow CodeLlamaChatAdapter.
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(CodeLlamaChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)
Expand Down

0 comments on commit 852cf67

Please sign in to comment.