expose stop words and filter eoa #352

Merged
10 commits merged on Sep 26, 2023
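At a glance, this PR replaces the per-model stop_words properties (which returned hard-coded token ids) with a stop_words argument that takes plain strings, converts those strings to token ids inside the turbomind/Triton layers, and strips a trailing stop word such as <eoa> from the finished response. A minimal usage sketch of the new interface, using only names confirmed by the diff below:

from lmdeploy.model import MODELS

model = MODELS.get('codellama')(capability='infilling')
print(model.stop_words)  # ['<EOT>'] -- strings now, not hard-coded token ids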
44 changes: 14 additions & 30 deletions lmdeploy/model.py
@@ -29,12 +29,14 @@ def __init__(self,
temperature=0.8,
repetition_penalty=1.0,
capability='chat',
stop_words=None,
**kwargs):
self.session_len = session_len
self.top_p = top_p
self.top_k = top_k
self.temperature = temperature
self.repetition_penalty = repetition_penalty
self.stop_words = stop_words
self.capability = capability

def get_prompt(self, prompt, sequence_start=True):
@@ -101,11 +103,6 @@ def messages2prompt(self, messages, sequence_start=True):
return self.get_prompt(messages)
# chat history processing in derived classes

@property
def stop_words(self):
"""Return the stop-words' token ids."""
return None

@property
def sampling_param(self):
return SamplingParam(top_p=self.top_p,
@@ -180,13 +177,15 @@ def __init__(self,
eoh='',
eoa='<eoa>',
assistant='<|Bot|>',
stop_words=['<eoa>'],
**kwargs):
super().__init__(**kwargs)
self.system = system
self.user = user
self.eoh = eoh
self.eoa = eoa
self.assistant = assistant
self.stop_words = stop_words

def decorate_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
@@ -202,7 +201,7 @@ def decorate_prompt(self, prompt, sequence_start=True):
assert self.capability == 'chat', \
f'{type(self).__name__} has no capability of {self.capability}'
if sequence_start:
return f'<BOS>{self.user}:{prompt}{self.eoh}\n' \
return f'<BOS>{self.system}{self.user}:{prompt}{self.eoh}\n' \
f'{self.assistant}:'
else:
return f'\n{self.user}:{prompt}{self.eoh}\n' \
@@ -220,7 +219,7 @@ def messages2prompt(self, messages, sequence_start=True):
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
ret = '<BOS>'
ret = '<BOS>' + self.system if system is None else '<BOS>' + system
for user, assistant in zip(users, assistants):
if assistant:
ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:' \
@@ -229,11 +228,6 @@ def messages2prompt(self, messages, sequence_start=True):
ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:'
return ret

@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [103028]


@MODELS.register_module(name='internlm-chat-20b')
@MODELS.register_module(name='internlm-chat-7b-8k')
@@ -332,12 +326,14 @@ def __init__(self,
eoh='',
assistant='',
eoa='',
stop_words=None,
**kwargs):
super().__init__(**kwargs)
self.meta_instruction = meta_instruction
self.system = system
self.user = user
self.assistant = assistant
self.stop_words = stop_words
self.eosys = eosys
self.eoh = eoh
self.eoa = eoa
@@ -375,11 +371,6 @@ def messages2prompt(self, messages, sequence_start=True):
ret += f'{self.user}{user}{self.eoh}{self.assistant}'
return ret

@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [45623]


@MODELS.register_module(name='llama2')
class Llama2(BaseModel):
@@ -461,6 +452,7 @@ def __init__(self,
im_start='<|im_start|>',
im_end='<|im_end|>',
system='You are a helpful assistant.',
stop_words=['<|im_end|>'],
**kwargs):
super().__init__(**kwargs)
self.session_len = session_len
@@ -471,6 +463,7 @@ def __init__(self,
self.im_start = im_start
self.im_end = im_end
self.system = system
self.stop_words = stop_words

def decorate_prompt(self, prompt, sequence_start=True):
assert self.capability == 'chat', \
@@ -506,11 +499,6 @@ def messages2prompt(self, messages, sequence_start=True):
f'\n{self.im_start}assistant\n'
return ret

@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [151645] # <|im_end|>


@MODELS.register_module(name='codellama')
class CodeLlama(Llama2):
@@ -519,6 +507,7 @@ def __init__(self,
system='',
session_len=4096,
suffix_first=False,
stop_words=None,
**kwargs):
super().__init__(**kwargs)
caps = ['completion', 'infilling', 'chat', 'python']
@@ -528,6 +517,7 @@ def __init__(self,
self.default_sys_prompt = system
self.session_len = session_len
self.suffix_first = suffix_first
self.stop_words = stop_words

# The following sampling parameters refers to https://github.com/facebookresearch/codellama # noqa: E501
if self.capability == 'completion' or self.capability == 'python':
Expand All @@ -539,6 +529,8 @@ def __init__(self,
elif self.capability == 'infilling':
self.top_p = kwargs.get('top_p', 0.9)
self.temperature = kwargs.get('temperature', 0.0)
if self.stop_words is None:
self.stop_words = ['<EOT>']

def decorate_prompt(self, prompt, sequence_start=True):
if self.capability == 'infilling':
@@ -567,14 +559,6 @@ def _get_prompt(self, prompt, sequence_start):

return f'{self.b_inst} {prompt} {self.e_inst}'

@property
def stop_words(self):
if self.capability == 'infilling':
# EOT ID
return [32010]
else:
return None

def messages2prompt(self, messages, sequence_start=True):
assert self.capability == 'chat', \
f'codellama message2prompt only supports chat mode ' \
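In short, every chat template now stores its stop words as strings on the instance, set in __init__ and overridable by the caller, instead of exposing a read-only property of token ids. A small sketch of how a caller might override the default (assuming the 'internlm-chat-7b-8k' class registered above keeps the ['<eoa>'] default from its base __init__):

from lmdeploy.model import MODELS

chat_model = MODELS.get('internlm-chat-7b-8k')()                       # defaults from __init__, e.g. ['<eoa>']
chat_model = MODELS.get('internlm-chat-7b-8k')(stop_words=['<eoa>'])   # explicit override with strings
print(chat_model.stop_words)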
16 changes: 14 additions & 2 deletions lmdeploy/serve/turbomind/chatbot.py
@@ -18,6 +18,7 @@
from lmdeploy.model import MODELS
from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
prepare_tensor)
from lmdeploy.utils import filter_suffix


@dataclass
@@ -157,6 +158,8 @@ def stream_infer(self,
request_output_len,
sequence_start,
sequence_end):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.model.stop_words)
if status.value < 0:
break
else:
@@ -346,6 +349,8 @@ def infer(self,
sequence_end):
if status.value < 0:
break
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.model.stop_words)
if status.value == 0:
self._session.histories = \
self._session.histories + self._session.prompt + \
@@ -386,16 +391,23 @@ def _get_eos(self):
token_ids, _ = self.preprocess('<EOS>')
return token_ids[0][0]

def _stop_words(self, stop_words: List[int]):
def _stop_words(self, stop_words: List[str]):
"""return stop-words' token ids."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'
# each id in stop_words represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about turbomind's stop_words
stop_words = [
int(self.preprocess(stop_word)[0][0][-1])
for stop_word in stop_words
]
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
'invalid stop_words'
stop_word_offsets = range(1, len(stop_words) + 1)
stop_words = np.array([[stop_words,
stop_word_offsets]]).astype(np.int32)
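The _stop_words helper above now accepts strings, maps each one to the id of its last token via the Triton preprocessor, and packs the ids plus cumulative offsets into the (1, 2, n) int32 array layout that turbomind expects. A standalone sketch of the same packing, with a generic encode callable standing in for self.preprocess:

import numpy as np

def build_stop_words_tensor(stop_words, encode):
    # encode(str) -> list of token ids; only the last id of each stop word is kept
    if not stop_words:
        return None
    ids = [int(encode(word)[-1]) for word in stop_words]
    offsets = list(range(1, len(ids) + 1))
    return np.array([[ids, offsets]]).astype(np.int32)  # shape (1, 2, len(ids))

# e.g. build_stop_words_tensor(['<eoa>'], tokenizer.encode)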
18 changes: 14 additions & 4 deletions lmdeploy/turbomind/turbomind.py
@@ -14,6 +14,7 @@

import lmdeploy
from lmdeploy.model import MODELS
from lmdeploy.turbomind import Tokenizer
from lmdeploy.utils import get_logger

# TODO: find another way import _turbomind
@@ -22,14 +23,16 @@
import _turbomind as _tm # noqa: E402


def _stop_words(stop_words: List[int]):
def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
"""return list of stop-words to numpy.ndarray."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'

stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
assert isinstance(stop_words, List) and all(
isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
# each id in stop_words represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about fastertransformer's stop_words
@@ -106,7 +109,10 @@ def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
self.model_name = parser.get(section_name, 'model_name')
data_type = parser.get(section_name, 'weight_type')
model = MODELS.get(self.model_name)()
self.stop_words = _stop_words(model.stop_words)
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
self.stop_words = _stop_words(model.stop_words, tokenizer)

# params
self.node_id = node_id
@@ -162,6 +168,8 @@ def __init__(self, tm_model, cuda_stream_id=0):
self.gpu_count = tm_model.gpu_count

self.stop_words = tm_model.stop_words
self.stop_tokens = [] if self.stop_words is None else \
self.stop_words.flatten().tolist()
self.eos_id = tm_model.eos_id
self.session_len = tm_model.session_len

@@ -346,6 +354,8 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
output, len_ = output, len_.item()
if len(output) > 0 and output[-1].item() == self.eos_id:
outputs.append((output[:-1], len_ - 1))
elif len(output) > 0 and output[-1].item() in self.stop_tokens:
outputs.append((output[:-1], len_))
else:
outputs.append((output, len_))

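The branch added to the decoding loop drops a trailing stop token from the returned ids (a trailing EOS additionally shortens the reported length, whereas a stop word keeps it). A plain-Python sketch of that trimming rule, with lists of ints standing in for the torch tensors used in the real code:

def trim_output(token_ids, length, eos_id, stop_tokens):
    # mirrors the eos/stop-token handling added above
    if token_ids and token_ids[-1] == eos_id:
        return token_ids[:-1], length - 1   # EOS: drop token, shrink length
    if token_ids and token_ids[-1] in stop_tokens:
        return token_ids[:-1], length       # stop word: drop token, keep length
    return token_ids, length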
20 changes: 19 additions & 1 deletion lmdeploy/utils.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional
from typing import List, Optional

logger_initialized = {}

@@ -77,3 +77,21 @@ def get_logger(name: str,
logger_initialized[name] = True

return logger


def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
"""Filter response with suffixes.

Args:
response (str): generated response by LLMs.
suffixes (List[str]): a list of suffixes to be removed.

Return:
str: a clean response.
"""
if suffixes is None:
return response
for item in suffixes:
if response.endswith(item):
response = response[:len(response) - len(item)]
return response
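A quick usage sketch of the helper above, stripping InternLM's <eoa> marker from a finished response:

from lmdeploy.utils import filter_suffix

assert filter_suffix('Hi there!<eoa>', ['<eoa>']) == 'Hi there!'
assert filter_suffix('Hi there!', None) == 'Hi there!'  # no suffixes given, response unchanged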
2 changes: 1 addition & 1 deletion tests/test_lmdeploy/test_model.py
@@ -133,7 +133,7 @@ def test_codellama_infilling():
'''
_prompt = model.get_prompt(prompt)
assert _prompt.find('<FILL>') == -1
assert model.stop_words == [32010]
assert model.stop_words == ['<EOT>']

model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
_prompt = model.get_prompt(prompt)