Support huggingface chat template
AllentDan committed Feb 1, 2024
1 parent 453cc25 commit f5738a6
Showing 7 changed files with 180 additions and 17 deletions.
7 changes: 7 additions & 0 deletions docs/zh_cn/serving/restful_api.md
@@ -156,6 +156,13 @@ openaoe -f /path/to/your/config-template.yaml

For details, please refer to the [deployment guide](https://github.com/InternLM/OpenAOE/blob/main/docs/tech-report/model_serving_by_lmdeploy/model_serving_by_lmdeploy.md).

### Custom chat templates

LMDeploy supports two ways of adding a chat template:

- One is to write a custom Python chat-template class modeled on LMDeploy's existing chat templates and register it; once registered, it can be used directly. The upside is a high degree of customization and strong control.
- The other is to pass in a Huggingface chat template, i.e. a Jinja template (see the sketch after this list).
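
For instance, a minimal sketch of configuring a Jinja template through `ChatTemplateConfig` (the template path is a placeholder):

```python
from lmdeploy.model import ChatTemplateConfig

# `jinja_template` accepts either a file path or the template itself as a
# single-line string; if the path cannot be opened, the string is unescaped
# (literal '\n' becomes a newline) and used as the template directly.
config = ChatTemplateConfig(jinja_template='/path/to/template.jinja')
print(config.get_jinja_template())  # the resolved template string
```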

### FAQ

1. When a response ends with `"finish_reason":"length"`, it means the session length has exceeded the maximum. To adjust the maximum supported session length, set the `--session_len` option when launching `api_server`.
5 changes: 4 additions & 1 deletion lmdeploy/cli/serve.py
@@ -126,6 +126,7 @@ def add_parser_api_server():

# chat template args
ArgumentHelper.meta_instruction(parser)
ArgumentHelper.jinja_template(parser)
ArgumentHelper.cap(parser)

# pytorch engine args
@@ -210,6 +211,7 @@ def gradio(args):
chat_template_config = ChatTemplateConfig(
model_name=args.model_name,
meta_instruction=args.meta_instruction,
jinja_template=args.jinja_template,
capability=args.cap)
run(args.model_path_or_server,
server_name=args.server_name,
@@ -244,7 +246,8 @@ def api_server(args):
chat_template_config = ChatTemplateConfig(
model_name=args.model_name,
meta_instruction=args.meta_instruction,
capability=args.cap)
capability=args.cap,
jinja_template=args.jinja_template)
run_api_server(args.model_path,
backend=args.backend,
backend_config=backend_config,
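
For reference, a hedged sketch of the equivalent Python-API launch that the `api_server` command above wires up (model and template paths are placeholders):

```python
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.openai.api_server import serve

# Mirrors the ChatTemplateConfig assembled from the CLI flags above.
serve('/path/to/model',
      chat_template_config=ChatTemplateConfig(
          jinja_template='/path/to/template.jinja'))
```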
12 changes: 12 additions & 0 deletions lmdeploy/cli/utils.py
@@ -326,6 +326,18 @@ def meta_instruction(parser):
default=None,
help='System prompt for ChatTemplateConfig')

@staticmethod
def jinja_template(parser):
"""Add argument jinjia template to parser."""

return parser.add_argument(
'--jinja-template',
type=str,
default=None,
help=\
'The file path to the chat template, or the template itself in single-line form. For reference, see https://huggingface.co/docs/transformers/main/en/chat_templating' # noqa
)

@staticmethod
def cache_max_entry_count(parser):
"""Add argument cache_max_entry_count to parser."""
24 changes: 22 additions & 2 deletions lmdeploy/model.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import codecs
import dataclasses
import os
from abc import abstractmethod
from copy import deepcopy
from typing import List, Literal, Optional

from fuzzywuzzy import fuzz, process
@@ -34,7 +36,7 @@ class ChatTemplateConfig:
capability: ('completion' | 'infilling' | 'chat' | 'python') = None
""" # noqa: E501

model_name: str
model_name: str = None
system: Optional[str] = None
meta_instruction: Optional[str] = None
eosys: Optional[str] = None
@@ -44,17 +46,35 @@
eoa: Optional[str] = None
capability: Optional[Literal['completion', 'infilling', 'chat',
'python']] = None
jinja_template: Optional[str] = None

@property
def chat_template(self):
attrs = {
key: value
for key, value in dataclasses.asdict(self).items()
if value is not None
if value is not None and key != 'jinja_template'
}
model: BaseModel = MODELS.get(self.model_name).from_config(**attrs)
return model

def get_jinja_template(self):
"""Get the jinja template."""
if self.jinja_template is not None:
try:
with open(self.jinja_template, 'r') as f:
template = f.read()
except OSError:
# If opening the file fails, treat the argument itself as the
# template string and unescape it so escapes are interpreted
# correctly
template = codecs.decode(self.jinja_template, 'unicode_escape')
return template
return None

def copy(self):
"""Get a copy of the class."""
return deepcopy(self)


@MODELS.register_module(name='internlm')
@MODELS.register_module(name='llama')
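
The `except OSError` fallback in `get_jinja_template` above means the option can also carry an inline template; a small sketch of that behaviour (the template body is illustrative):

```python
import codecs

# When the argument is not a readable file path, the string itself is the
# template; unicode_escape turns literal '\n' sequences (as typed in a
# shell) into real newlines before the template is used.
inline = ('{% for message in messages %}'
          '{{ message["role"] }}: {{ message["content"] }}\\n'
          '{% endfor %}')
print(codecs.decode(inline, 'unicode_escape'))
```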
67 changes: 55 additions & 12 deletions lmdeploy/serve/async_engine.py
@@ -3,6 +3,7 @@
import dataclasses
import random
from argparse import ArgumentError
from asyncio.log import logger
from contextlib import contextmanager
from queue import Empty, Queue
from threading import Thread
@@ -65,30 +66,58 @@ def __init__(self,
tp: int = 1,
**kwargs) -> None:
if backend == 'turbomind':
self._build_turbomind(model_path=model_path,
model_name=model_name,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
self._build_turbomind(
model_path=model_path,
model_name=model_name,
backend_config=backend_config,
chat_template_config=chat_template_config.copy(),
tp=tp,
**kwargs)
elif backend == 'pytorch':
self._build_pytorch(model_path=model_path,
model_name=model_name,
backend_config=backend_config,
chat_template_config=chat_template_config,
**kwargs)
self._build_pytorch(
model_path=model_path,
model_name=model_name,
backend_config=backend_config,
chat_template_config=chat_template_config.copy(),
**kwargs)
else:
raise ValueError(f'unsupported backend {backend}')
self.backend = backend
self.instance_num = self.backend_config.max_batch_size
self.tokenizer = self.engine.tokenizer
self._load_chat_template(chat_template_config)
self.id2step = {}
self.id2generator = {}
self.loop = asyncio.get_event_loop()
self.gens_set = set()
for i in range(self.instance_num):
self.gens_set.add(self.engine.create_instance())

def _load_chat_template(self, chat_template_config: ChatTemplateConfig):
"""Load a chat template from chat_template_config.
Priority:
1. chat_template_config.model_name
2. chat_template_config.jinja_template
3. jinja template in tokenizer_config.json
4. deduced chat_template in lmdeploy
"""
# if model_name is given, the lmdeploy template is applied
# regardless of any Jinja template
if chat_template_config and chat_template_config.model_name:
return
# if no model_name is passed in, the tokenizer's template is chosen:
# a Jinja template if one exists in tokenizer_config.json, otherwise a
# deduced lmdeploy template
if isinstance(self.tokenizer.model.chat_template, str):
self.chat_template = self.tokenizer.model.chat_template

# a user-defined Jinja template is applied when the user passes in a
# Jinja template instead of a model name
if chat_template_config and chat_template_config.jinja_template:
self.chat_template = chat_template_config.get_jinja_template()

def _build_turbomind(
self,
model_path: str,
@@ -440,7 +469,21 @@ async def generate(
gen_config.random_seed = random.getrandbits(64)
prompt = messages
if do_preprocess:
prompt = self.chat_template.messages2prompt(prompt, sequence_start)
if isinstance(prompt, str):
if hasattr(self.chat_template, 'messages2prompt'):
prompt = self.chat_template.messages2prompt(
prompt, sequence_start)
else:
# TODO better logger
logger.warning(f'The Jinja chat template {self.chat_template}'
' cannot be used for interactive chat. '
'Please use an lmdeploy-defined chat template '
'by passing in a model name.')
else:
# OpenAI-format messages are rendered by the chat template
prompt = self.tokenizer.apply_chat_template(prompt,
self.chat_template,
tokenize=False)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = None
if self.id2step[str(session_id)] + len(
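
To recap `_load_chat_template`, a hedged sketch of the effective resolution order as a standalone function (not the actual method):

```python
def resolve_chat_template(tokenizer, cfg):
    """Sketch of the documented priority; names mirror the code above."""
    if cfg and cfg.model_name:                 # 1. explicit model_name wins:
        return cfg.chat_template               #    lmdeploy's registered template
    if cfg and cfg.jinja_template:             # 2. user-supplied Jinja template
        return cfg.get_jinja_template()
    if isinstance(tokenizer.model.chat_template, str):
        return tokenizer.model.chat_template   # 3. Jinja from tokenizer_config.json
    return tokenizer.lmdeploy_chat_template    # 4. deduced lmdeploy template
```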
1 change: 0 additions & 1 deletion lmdeploy/serve/openai/api_server.py
@@ -957,7 +957,6 @@ def serve(model_path: str,
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
server_name (str): host ip for serving
server_port (int): server port
tp (int): tensor parallel
81 changes: 80 additions & 1 deletion lmdeploy/tokenizer.py
@@ -3,10 +3,11 @@
import os.path as osp
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch

from lmdeploy.model import MODELS, BaseModel, best_match_model
from lmdeploy.utils import get_logger

# this file will be copied to triton server, make sure all
@@ -195,6 +196,16 @@ def __init__(self, model_dir: str):
self.logger = get_logger('lmdeploy')
self.model = AutoTokenizer.from_pretrained(model_dir,
trust_remote_code=True)

# get chat template from hf
self.chat_template = self.model.chat_template
deduced_name = best_match_model(model_dir)
self.lmdeploy_chat_template = None # for interactive chat
if deduced_name is not None:
# the deduced lmdeploy template applies only when the hf chat
# template is None
self.lmdeploy_chat_template = MODELS.get(deduced_name)()
if self.chat_template is None:
self.chat_template = self.lmdeploy_chat_template

self._prefix_space_tokens = None

if self.model.eos_token_id is None:
@@ -465,6 +476,49 @@ def __call__(self, s: Union[str, Sequence[str]]):
add_special_tokens = False
return self.model(s, add_special_tokens=add_special_tokens)

def apply_chat_template(
self,
conversation: List[Dict[str, str]],
chat_template: Optional[Union[str, BaseModel]] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
padding: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[str] = None,
**tokenizer_kwargs,
) -> Union[str, List[int]]:
"""This is a function compatible with huggingface
AutoTokenizer.apply_chat_template.
Args:
conversation (List | str): type string for interactive chat.
List refers to OpenAI format messages.
"""
if chat_template is None:
chat_template = self.chat_template
if hasattr(chat_template, 'messages2prompt'):
prompt = chat_template.messages2prompt(conversation)
if tokenize:
return self.encode(prompt)
else:
return prompt
elif isinstance(chat_template, str) or chat_template is None:
# apply hf chat template
return self.model.apply_chat_template(
conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
**tokenizer_kwargs)
else:
raise TypeError(f'Unsupported chat_template type: {chat_template}'
f' for {conversation}')


class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Expand Down Expand Up @@ -574,3 +628,28 @@ def indexes_containing_token(self, token):
'than 1. Currently, it can not be used as stop words')
return []
return self.model.indexes_containing_token(token)

def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], str],
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
padding: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[str] = None,
**tokenizer_kwargs,
) -> Union[str, List[int]]:
"""This is a function compatible with huggingface
AutoTokenizer.apply_chat_template."""
return self.model.apply_chat_template(
conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
**tokenizer_kwargs)
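
Finally, a short usage sketch of the new `apply_chat_template` wrapper (the model directory and messages are placeholders):

```python
from lmdeploy.tokenizer import Tokenizer

tokenizer = Tokenizer('/path/to/model')  # hypothetical HF model directory
messages = [{'role': 'user', 'content': 'Hello!'}]

# With chat_template=None, the tokenizer's own template is used: the Jinja
# template from tokenizer_config.json if present, otherwise the deduced
# lmdeploy template.
prompt = tokenizer.apply_chat_template(messages,
                                       add_generation_prompt=True,
                                       tokenize=False)
print(prompt)
```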
