diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 73fb541f1..fcf882c5c 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -804,6 +804,20 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# Baichuan2-13B-Chat template
+register_conv_template(
+    # source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773
+    # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json
+    # https://github.com/baichuan-inc/Baichuan2/issues/62
+    Conversation(
+        name="baichuan2-chat",
+        roles=("<reserved_106>", "<reserved_107>"),
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        sep="",
+        stop_token_ids=[],
+    )
+)
+
 # llama2 template
 # reference: https://huggingface.co/blog/codellama#conversational-instructions
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 423308455..296b53c8f 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1172,6 +1172,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
 
     def get_default_conv_template(self, model_path: str) -> Conversation:
         # for Baichuan-13B-Chat
         if "chat" in model_path.lower():
+            if "baichuan2" in model_path.lower():
+                return get_conv_template("baichuan2-chat")
             return get_conv_template("baichuan-chat")
         return get_conv_template("zero_shot")