diff --git a/docs/zh_cn/serving/restful_api.md b/docs/zh_cn/serving/restful_api.md
index 0a40700153..931c4a21a9 100644
--- a/docs/zh_cn/serving/restful_api.md
+++ b/docs/zh_cn/serving/restful_api.md
@@ -156,6 +156,13 @@ openaoe -f /path/to/your/config-template.yaml
 
 For details, please refer to the [deployment guide](https://github.com/InternLM/OpenAOE/blob/main/docs/tech-report/model_serving_by_lmdeploy/model_serving_by_lmdeploy.md).
 
+### Custom chat templates
+
+LMDeploy supports two ways of adding a chat template:
+
+- Define a custom Python chat template class modeled on LMDeploy's existing templates and register it; once registered, it can be used directly. This route offers a high degree of customization and strong control.
+- Pass in a Hugging Face chat template, i.e. a Jinja template.
+
 ### FAQ
 
 1. When a returned result ends with `"finish_reason":"length"`, the session length has exceeded the maximum. To adjust the maximum session length, set the `--session_len` argument when launching `api_server`.
diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index ed9cc5a17d..9392f61cee 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -126,6 +126,7 @@ def add_parser_api_server():
 
         # chat template args
         ArgumentHelper.meta_instruction(parser)
+        ArgumentHelper.jinja_template(parser)
         ArgumentHelper.cap(parser)
 
         # pytorch engine args
@@ -210,6 +211,7 @@ def gradio(args):
         chat_template_config = ChatTemplateConfig(
             model_name=args.model_name,
             meta_instruction=args.meta_instruction,
+            jinja_template=args.jinja_template,
             capability=args.cap)
         run(args.model_path_or_server,
             server_name=args.server_name,
@@ -244,7 +246,8 @@ def api_server(args):
         chat_template_config = ChatTemplateConfig(
             model_name=args.model_name,
             meta_instruction=args.meta_instruction,
-            capability=args.cap)
+            capability=args.cap,
+            jinja_template=args.jinja_template)
         run_api_server(args.model_path,
                        backend=args.backend,
                        backend_config=backend_config,
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index 1eb7bbbc60..db22298145 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -326,6 +326,18 @@ def meta_instruction(parser):
                                    default=None,
                                    help='System prompt for ChatTemplateConfig')
 
+    @staticmethod
+    def jinja_template(parser):
+        """Add argument jinja_template to parser."""
+
+        return parser.add_argument(
+            '--jinja-template',
+            type=str,
+            default=None,
+            help=\
+            'The file path to the chat template, or the template itself in single-line form. Refer to https://huggingface.co/docs/transformers/main/en/chat_templating'  # noqa
+        )
+
     @staticmethod
     def cache_max_entry_count(parser):
         """Add argument cache_max_entry_count to parser."""
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index ed7b08e8b5..3868572b01 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import codecs
 import dataclasses
 import os
 from abc import abstractmethod
+from copy import deepcopy
 from typing import List, Literal, Optional
 
 from fuzzywuzzy import fuzz, process
@@ -34,7 +36,7 @@ class ChatTemplateConfig:
         capability: ('completion' | 'infilling' | 'chat' | 'python') = None
     """ # noqa: E501
 
-    model_name: str
+    model_name: Optional[str] = None
     system: Optional[str] = None
     meta_instruction: Optional[str] = None
    eosys: Optional[str] = None
@@ -44,17 +46,35 @@ class ChatTemplateConfig:
     eoa: Optional[str] = None
     capability: Optional[Literal['completion', 'infilling', 'chat',
                                  'python']] = None
+    jinja_template: Optional[str] = None
 
     @property
     def chat_template(self):
         attrs = {
             key: value
             for key, value in dataclasses.asdict(self).items()
-            if value is not None
+            if value is not None and key != 'jinja_template'
         }
         model: BaseModel = MODELS.get(self.model_name).from_config(**attrs)
         return model
 
+    def get_jinja_template(self):
+        """Get the jinja template."""
+        if self.jinja_template is not None:
+            try:
+                with open(self.jinja_template, 'r') as f:
+                    template = f.read()
+            except OSError:
+                # opening the value as a file failed; treat it as the template
+                # itself and decode it so escape sequences are interpreted
+                template = codecs.decode(self.jinja_template, 'unicode_escape')
+            return template
+        return None
+
+    def copy(self):
+        """Get a copy of the class."""
+        return deepcopy(self)
+
 
 @MODELS.register_module(name='internlm')
 @MODELS.register_module(name='llama')
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index bd81372c9e..2cc8ba3635 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -3,6 +3,7 @@
 import dataclasses
 import random
 from argparse import ArgumentError
+from asyncio.log import logger
 from contextlib import contextmanager
 from queue import Empty, Queue
 from threading import Thread
@@ -65,23 +66,26 @@ def __init__(self,
                  tp: int = 1,
                  **kwargs) -> None:
         if backend == 'turbomind':
-            self._build_turbomind(model_path=model_path,
-                                  model_name=model_name,
-                                  backend_config=backend_config,
-                                  chat_template_config=chat_template_config,
-                                  tp=tp,
-                                  **kwargs)
+            self._build_turbomind(
+                model_path=model_path,
+                model_name=model_name,
+                backend_config=backend_config,
+                chat_template_config=(chat_template_config.copy()
+                                      if chat_template_config else None),
+                tp=tp,
+                **kwargs)
         elif backend == 'pytorch':
-            self._build_pytorch(model_path=model_path,
-                                model_name=model_name,
-                                backend_config=backend_config,
-                                chat_template_config=chat_template_config,
-                                **kwargs)
+            self._build_pytorch(
+                model_path=model_path,
+                model_name=model_name,
+                backend_config=backend_config,
+                chat_template_config=(chat_template_config.copy()
+                                      if chat_template_config else None),
+                **kwargs)
         else:
             raise ValueError(f'unsupported backend {backend}')
         self.backend = backend
         self.instance_num = self.backend_config.max_batch_size
         self.tokenizer = self.engine.tokenizer
+        self._load_chat_template(chat_template_config)
         self.id2step = {}
         self.id2generator = {}
         self.loop = asyncio.get_event_loop()
@@ -89,6 +93,31 @@ def __init__(self,
         for i in range(self.instance_num):
             self.gens_set.add(self.engine.create_instance())
 
+    def _load_chat_template(self, chat_template_config: ChatTemplateConfig):
+        """Load a chat template from chat_template_config.
+
+        Priority:
+        1. chat_template_config.model_name
+        2. chat_template_config.jinja_template
+        3. jinja template in tokenizer_config.json
+        4. deduced chat_template in lmdeploy
+        """
+        # if model_name is given, the lmdeploy template will be applied,
+        # no matter what Jinja template the model ships
+        if chat_template_config and chat_template_config.model_name:
+            return
+        # if no model_name is passed in, the tokenizer's template is chosen:
+        # it will be a Jinja template if one exists in tokenizer_config.json;
+        # if there is no Jinja template in tokenizer_config.json, a deduced
+        # lmdeploy template will be applied
+        if isinstance(self.tokenizer.model.chat_template, str):
+            self.chat_template = self.tokenizer.model.chat_template
+
+        # a user-defined Jinja template is applied once the user passes
+        # a Jinja template in instead of a model name
+        if chat_template_config and chat_template_config.jinja_template:
+            self.chat_template = chat_template_config.get_jinja_template()
+
     def _build_turbomind(
         self,
         model_path: str,
@@ -440,7 +469,21 @@ async def generate(
             gen_config.random_seed = random.getrandbits(64)
         prompt = messages
         if do_preprocess:
-            prompt = self.chat_template.messages2prompt(prompt, sequence_start)
+            if isinstance(prompt, str):
+                if hasattr(self.chat_template, 'messages2prompt'):
+                    prompt = self.chat_template.messages2prompt(
+                        prompt, sequence_start)
+                else:
+                    # TODO better logger
+                    logger.warning('A Jinja chat template cannot be used for '
+                                   'interactive chat. Please use an lmdeploy '
+                                   'chat template by passing in a model name '
+                                   'instead.')
+            else:
+                # OpenAI-format messages: render them with the chat template
+                prompt = self.tokenizer.apply_chat_template(prompt,
+                                                            self.chat_template,
+                                                            tokenize=False)
         input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
         finish_reason = None
         if self.id2step[str(session_id)] + len(
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 2a645f2d75..8d1bf20927 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -957,7 +957,6 @@ def serve(model_path: str,
         backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
             config instance. Default to None.
         chat_template_config (ChatTemplateConfig): chat template configuration.
-            Default to None.
         server_name (str): host ip for serving
         server_port (int): server port
         tp (int): tensor parallel
diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py
index b37a5a2483..1c8eb8482f 100644
--- a/lmdeploy/tokenizer.py
+++ b/lmdeploy/tokenizer.py
@@ -3,10 +3,11 @@
 import os.path as osp
 from collections import deque
 from dataclasses import dataclass
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 
+from lmdeploy.model import MODELS, BaseModel, best_match_model
 from lmdeploy.utils import get_logger
 
 # this file will be copied to triton server, make sure all
@@ -195,6 +196,16 @@ def __init__(self, model_dir: str):
         self.logger = get_logger('lmdeploy')
         self.model = AutoTokenizer.from_pretrained(model_dir,
                                                    trust_remote_code=True)
+
+        # get the chat template from the HF tokenizer config, if any
+        self.chat_template = self.model.chat_template
+        deduced_name = best_match_model(model_dir)
+        self.lmdeploy_chat_template = None  # for interactive chat
+        if deduced_name is not None:
+            self.lmdeploy_chat_template = MODELS.get(deduced_name)()
+            # the deduced lmdeploy template only applies when the model
+            # does not ship an HF Jinja chat template
+            if self.chat_template is None:
+                self.chat_template = self.lmdeploy_chat_template
+
         self._prefix_space_tokens = None
 
         if self.model.eos_token_id is None:
@@ -465,6 +476,49 @@ def __call__(self, s: Union[str, Sequence[str]]):
             add_special_tokens = False
         return self.model(s, add_special_tokens=add_special_tokens)
 
+    def apply_chat_template(
+        self,
+        conversation: Union[List[Dict[str, str]], str],
+        chat_template: Optional[Union[str, BaseModel]] = None,
+        add_generation_prompt: bool = False,
+        tokenize: bool = True,
+        padding: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[str] = None,
+        **tokenizer_kwargs,
+    ) -> Union[str, List[int]]:
+        """This is a function compatible with huggingface
+        AutoTokenizer.apply_chat_template.
+
+        Args:
+            conversation (List[Dict] | str): a list of OpenAI-format messages,
+                or a plain prompt string for interactive chat.
+        """
+        if chat_template is None:
+            chat_template = self.chat_template
+        if hasattr(chat_template, 'messages2prompt'):
+            prompt = chat_template.messages2prompt(conversation)
+            if tokenize:
+                return self.encode(prompt)
+            else:
+                return prompt
+        elif isinstance(chat_template, str) or chat_template is None:
+            # apply the hf chat template
+            return self.model.apply_chat_template(
+                conversation,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+                tokenize=tokenize,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                **tokenizer_kwargs)
+        else:
+            raise TypeError(f'Unsupported chat_template type: {chat_template}'
+                            f' for {conversation}')
+
 
 class Tokenizer:
     """Tokenize prompts or de-tokenize tokens into texts.
@@ -574,3 +628,28 @@ def indexes_containing_token(self, token):
                 'than 1. Currently, it can not be used as stop words')
             return []
         return self.model.indexes_containing_token(token)
+
+    def apply_chat_template(
+        self,
+        conversation: Union[List[Dict[str, str]], str],
+        chat_template: Optional[Union[str, BaseModel]] = None,
+        add_generation_prompt: bool = False,
+        tokenize: bool = True,
+        padding: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[str] = None,
+        **tokenizer_kwargs,
+    ) -> Union[str, List[int]]:
+        """This is a function compatible with huggingface
+        AutoTokenizer.apply_chat_template."""
+        return self.model.apply_chat_template(
+            conversation,
+            chat_template=chat_template,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=tokenize,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            return_tensors=return_tensors,
+            **tokenizer_kwargs)
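
The two routes described in the translated docs section above can be sketched briefly. First, the Python-class route: the snippet below is a minimal, hypothetical example — the `my-model` alias, the class name and the `<|user|>`/`<|assistant|>` markers are invented for illustration; only the `MODELS` registry, `BaseModel` and the `messages2prompt` hook come from this diff.

```python
# A minimal sketch of registering a custom chat template class with LMDeploy.
# The alias and prompt markers below are hypothetical.
from lmdeploy.model import MODELS, BaseModel


@MODELS.register_module(name='my-model')  # hypothetical alias
class MyChatTemplate(BaseModel):
    """Render prompts as <|role|> blocks followed by an assistant header."""

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        # interactive chat passes a plain string
        if isinstance(messages, str):
            return f'<|user|>\n{messages}\n<|assistant|>\n'
        # otherwise `messages` is a list of OpenAI-style role/content dicts
        prompt = ''
        for message in messages:
            prompt += f"<|{message['role']}|>\n{message['content']}\n"
        return prompt + '<|assistant|>\n'
```

Once registered, the template is selected like a built-in one, e.g. via `ChatTemplateConfig(model_name='my-model')` or `--model-name my-model` on the command line.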
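
Second, the Jinja route: `ChatTemplateConfig.get_jinja_template()` first tries to read its value as a file and otherwise unicode-escape-decodes it, so both a file path and an inline single-line template should work. The template body and file name below are illustrative assumptions, not templates shipped with any model.

```python
# A sketch of the Jinja route; the inline template is illustrative only.
from lmdeploy.model import ChatTemplateConfig

inline_template = ("{% for m in messages %}"
                   "{{ m['role'] }}: {{ m['content'] }}\\n"
                   "{% endfor %}assistant: ")

cfg = ChatTemplateConfig(jinja_template=inline_template)
# The value is not a readable file, so get_jinja_template() falls back to
# codecs.decode(..., 'unicode_escape'), turning the literal '\n' above into
# a real newline.
print(cfg.get_jinja_template())

# A file path works the same way, e.g. (hypothetical path):
#   cfg = ChatTemplateConfig(jinja_template='./my_template.jinja')
# or on the command line:
#   lmdeploy serve api_server <model_path> --jinja-template ./my_template.jinja
```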
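
Finally, a usage sketch of the new `Tokenizer.apply_chat_template` wrapper. The model directory is a placeholder; by default the wrapper uses the template resolved in `HuggingFaceTokenizer.__init__` (the model's own HF Jinja template if it ships one, otherwise the deduced lmdeploy template).

```python
# A usage sketch of the new apply_chat_template wrapper; the model path is a
# placeholder for a local HF model directory.
from lmdeploy.tokenizer import Tokenizer

tokenizer = Tokenizer('/path/to/hf/model')
messages = [{'role': 'user', 'content': 'Hello!'}]

# tokenize=False returns the rendered prompt string
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# the default tokenize=True returns token ids instead
input_ids = tokenizer.apply_chat_template(messages)
```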