diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index e3601be43..6067710e8 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -28,6 +28,7 @@ class SeparatorStyle(IntEnum):
     PHOENIX = auto()
     ROBIN = auto()
     FALCON_CHAT = auto()
+    CHATGLM3 = auto()
 
 
 @dataclasses.dataclass
@@ -163,6 +164,16 @@ def get_prompt(self) -> str:
             else:
                 ret += role + "\n"
             return ret
+        elif self.sep_style == SeparatorStyle.CHATGLM3:
+            ret = ""
+            if self.system_message:
+                ret += system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + " " + message
+                else:
+                    ret += role
+            return ret
         elif self.sep_style == SeparatorStyle.CHATINTERN:
             # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
             seps = [self.sep, self.sep2]
@@ -448,6 +459,21 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# ChatGLM3 default template
+register_conv_template(
+    Conversation(
+        name="chatglm3",
+        system_template="<|system|>\n {system_message}",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.CHATGLM3,
+        stop_token_ids=[
+            64795,
+            64797,
+            2,
+        ],  # "<|user|>", "<|observation|>", "</s>"
+    )
+)
+
 # CodeGeex(2) Template
 register_conv_template(
     Conversation(
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index ca03e69e4..9a2424d57 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -754,9 +754,17 @@ def match(self, model_path: str):
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         revision = from_pretrained_kwargs.get("revision", "main")
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
+        if "chatglm3" in model_path.lower():
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path,
+                encode_special_tokens=True,
+                trust_remote_code=True,
+                revision=revision,
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path, trust_remote_code=True, revision=revision
+            )
         model = AutoModel.from_pretrained(
             model_path, trust_remote_code=True, **from_pretrained_kwargs
         )
@@ -766,6 +774,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         model_path = model_path.lower()
         if "chatglm2" in model_path.lower():
             return get_conv_template("chatglm2")
+        if "chatglm3" in model_path.lower():
+            return get_conv_template("chatglm3")
         return get_conv_template("chatglm")