Skip to content

Commit

Permalink
Add support for baichuan2 models (#2408)
Browse files Browse the repository at this point in the history
  • Loading branch information
obitoquilt authored Sep 13, 2023
1 parent aa153d5 commit 3149253
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
14 changes: 14 additions & 0 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,20 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Baichuan2-13B-Chat template
register_conv_template(
# source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json
# https://github.com/baichuan-inc/Baichuan2/issues/62
Conversation(
name="baichuan2-chat",
roles=("<reserved_106>", "<reserved_107>"),
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="",
stop_token_ids=[],
)
)

# llama2 template
# reference: https://huggingface.co/blog/codellama#conversational-instructions
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
Expand Down
2 changes: 2 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
def get_default_conv_template(self, model_path: str) -> Conversation:
# for Baichuan-13B-Chat
if "chat" in model_path.lower():
if "baichuan2" in model_path.lower():
return get_conv_template("baichuan2-chat")
return get_conv_template("baichuan-chat")
return get_conv_template("zero_shot")

Expand Down

0 comments on commit 3149253

Please sign in to comment.