')
+ if self.suffix_first:
+ # format as "<PRE> <SUF>{suf} <MID> {pre}"
+ prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
+ else:
+ # format as "<PRE> {pre} <SUF>{suf} <MID>"
+ prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
+ return prompt
+
+ def _get_prompt(self, prompt, sequence_start):
+ prompt = prompt.strip()
+ if sequence_start:
+ return f'{self.b_inst} ' \
+ f'{self.b_sys}{self.default_sys_prompt}{self.e_sys}' \
+ f'{prompt} {self.e_inst}'
+
+ return f'{self.b_inst} {prompt} {self.e_inst}'
+
+ @property
+ def stop_words(self):
+ if self.capability == 'infilling':
+ # EOT ID
+ return [32010]
+ else:
+ return None
+
+ def messages2prompt(self, messages, sequence_start=True):
+ assert self.capability == 'chat', \
+ f'codellama messages2prompt only supports chat mode ' \
+ f'but got {self.capability} mode'
+ return super().messages2prompt(messages, sequence_start)
+
+
def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
diff --git a/lmdeploy/serve/client.py b/lmdeploy/serve/client.py
index 1d22d4ba3..283e96e29 100644
--- a/lmdeploy/serve/client.py
+++ b/lmdeploy/serve/client.py
@@ -6,16 +6,23 @@
from lmdeploy.serve.turbomind.chatbot import Chatbot
-def input_prompt():
- """Input a prompt in the console interface."""
- print('\ndouble enter to end input >>> ', end='')
- sentinel = '' # ends when this string is seen
+def input_prompt(model_name):
+ """Input a prompt in the consolo interface."""
+ if model_name == 'codellama':
+ print('\nenter !! to end the input >>>\n', end='')
+ sentinel = '!!'
+ else:
+ print('\ndouble enter to end input >>> ', end='')
+ sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def main(tritonserver_addr: str,
session_id: int = 1,
- stream_output: bool = True):
+ cap: str = 'chat',
+ sys_instruct: str = None,
+ stream_output: bool = True,
+ **kwargs):
"""An example to communicate with inference server through the command line
interface.
@@ -23,15 +30,22 @@ def main(tritonserver_addr: str,
tritonserver_addr (str): the address in format "ip:port" of
triton inference server
session_id (int): the identical id of a session
+ cap (str): the capability of a model. For example, codellama has
+ the ability among ['completion', 'infilling', 'chat', 'python']
+ sys_instruct (str): the content of the 'system' role, which is used by
+ conversational models
stream_output (bool): indicator for streaming output or not
+ **kwargs (dict): other arguments for initializing the model's chat template
"""
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
+ kwargs.update(capability=cap, system=sys_instruct)
chatbot = Chatbot(tritonserver_addr,
log_level=log_level,
- display=stream_output)
+ display=stream_output,
+ **kwargs)
nth_round = 1
while True:
- prompt = input_prompt()
+ prompt = input_prompt(chatbot.model_name)
if prompt == 'exit':
exit(0)
elif prompt == 'end':
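
The net effect of the new `cap` and `sys_instruct` flags is that they are folded into the kwargs handed to `Chatbot`, which forwards them to the model's chat template as `capability` and `system`. A minimal sketch under that assumption (the server address is a placeholder):

    from lmdeploy.serve.turbomind.chatbot import Chatbot

    # mirrors what main() builds from --cap / --sys_instruct
    chatbot = Chatbot('0.0.0.0:33337',          # hypothetical triton address
                      log_level='WARNING',
                      display=True,
                      capability='chat',
                      system='Provide answers in Python')
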
diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py
index 1212d3459..eb532e260 100644
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -149,6 +149,7 @@ def stream_infer(self,
self._session.status = 1
self._session.request_id = request_id
self._session.response = ''
+ self.cfg.update(**kwargs)
self._session.prompt = self._get_prompt(prompt, sequence_start)
for status, res, tokens in self._stream_infer(self._session,
@@ -507,7 +508,7 @@ def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
server
session (Session): an instance of a session
que (multiprocessing.Queue): response queue
- cfg:
+ cfg (dict): parameters for sampling
input_ids (numpy.ndarray): token ids of input prompt
input_lengths (numpy.ndarray): length of input_ids
request_output_len (int): the max number of tokens to be generated
diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py
index 516afd793..1c2b1becc 100644
--- a/lmdeploy/serve/turbomind/deploy.py
+++ b/lmdeploy/serve/turbomind/deploy.py
@@ -122,6 +122,7 @@ def export(model_name: str,
max_position_embeddings: int = 0,
use_dynamic_ntk: int = 0,
use_logn_attn: int = 0,
+ rope_theta: float = 10000.0,
tokenizer_info=tokenizer_info_sp):
"""Export deploying information to a config file.
@@ -213,6 +214,7 @@ def save_bin(param: torch.Tensor, name):
vocab_size=_vocab_size,
num_layer=num_layer,
rotary_embedding=size_per_head,
+ rope_theta=rope_theta,
inter_size=inter_size,
norm_eps=norm_eps,
attn_bias=int(attn_bias),
@@ -233,7 +235,8 @@ def save_bin(param: torch.Tensor, name):
# extra attention params
max_position_embeddings=max_position_embeddings,
use_dynamic_ntk=int(use_dynamic_ntk),
- use_logn_attn=int(use_logn_attn)))
+ use_logn_attn=int(use_logn_attn),
+ ))
config = configparser.ConfigParser()
for section, key_values in cfg.items():
@@ -415,6 +418,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
+ rope_theta = float(model_arg.get('rope_theta', 10000.0))
+ max_position_embeddings = int(
+ model_arg.get('max_position_embeddings', 0))
+ rope_scaling = bool(model_arg.get('rope_scaling', False))
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
@@ -525,13 +532,23 @@ def get_tensor_transposed(name: str):
for ft, hf in other:
model_params[ft] = get_tensor(hf)
- if model_name == 'baichuan2-7b-chat':
+ if model_name == 'baichuan2-7b':
+ # https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/modeling_baichuan.py#L507
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
model_params['output.weight'] = torch.nn.functional.normalize(
model_params['output.weight'])
- return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
- tokenizer_path, triton_models_path, tp)
+ return export(model_name,
+ num_layer,
+ norm_eps,
+ kv_head_num,
+ model_params,
+ tokenizer_path,
+ triton_models_path,
+ tp,
+ max_position_embeddings=max_position_embeddings,
+ use_dynamic_ntk=rope_scaling,
+ rope_theta=rope_theta)
def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
@@ -574,6 +591,7 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
+ rope_theta = float(model_arg.get('rope_theta', 10000.0))
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
@@ -761,7 +779,8 @@ def tp_m_s4(x: torch.Tensor, tp: int):
triton_models_path,
tp,
weight_type='int4',
- group_size=group_size)
+ group_size=group_size,
+ rope_theta=rope_theta)
def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
@@ -802,6 +821,7 @@ def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
config = json.load(f)
num_layer = config['num_hidden_layers']
norm_eps = config['layer_norm_epsilon']
+ rope_theta = float(config.get('rotary_emb_base', 10000.0))
if 'num_key_value_heads' in config:
kv_head_num = config['num_key_value_heads']
else:
@@ -889,6 +909,7 @@ def get_tensor(name, trans=True):
max_position_embeddings=seq_length,
use_dynamic_ntk=use_dynamic_ntk,
use_logn_attn=use_logn_attn,
+ rope_theta=rope_theta,
tokenizer_info=tokenizer_info_qwen)
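
The `rope_theta` plumbing above follows the same pattern in all three deploy paths: read the value from the HF config with a 10000.0 fallback and hand it to `export`. A sketch of what that lookup yields for a CodeLlama checkpoint (the local path is hypothetical; CodeLlama publishes rope_theta = 1e6 and max_position_embeddings = 16384 in its config.json):

    import json

    with open('codellama-7b-hf/config.json') as f:   # hypothetical local checkout
        model_arg = json.load(f)

    rope_theta = float(model_arg.get('rope_theta', 10000.0))      # 1000000.0 for CodeLlama
    max_pos = int(model_arg.get('max_position_embeddings', 0))    # 16384 for CodeLlama
    rope_scaling = bool(model_arg.get('rope_scaling', False))     # False unless NTK scaling is set
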
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index 68692a840..d617b1983 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import dataclasses
import os
import os.path as osp
import random
@@ -12,10 +13,26 @@
os.environ['TM_LOG_LEVEL'] = 'ERROR'
-def input_prompt():
+@dataclasses.dataclass
+class GenParam:
+ top_p: float
+ top_k: float
+ temperature: float
+ repetition_penalty: float
+ sequence_start: bool = False
+ sequence_end: bool = False
+ step: int = 0
+ request_output_len: int = 512
+
+
+def input_prompt(model_name):
"""Input a prompt in the consolo interface."""
- print('\ndouble enter to end input >>> ', end='')
- sentinel = '' # ends when this string is seen
+ if model_name == 'codellama':
+ print('\nenter !! to end the input >>>\n', end='')
+ sentinel = '!!'
+ else:
+ print('\ndouble enter to end input >>> ', end='')
+ sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
@@ -29,20 +46,50 @@ def valid_str(string, coding='utf-8'):
return ret
+def get_gen_param(cap,
+ sampling_param,
+ nth_round,
+ step,
+ request_output_len=512,
+ **kwargs):
+ """return parameters used by token generation."""
+ gen_param = GenParam(**dataclasses.asdict(sampling_param),
+ request_output_len=request_output_len)
+ # Fix me later. turbomind.py doesn't support None top_k
+ if gen_param.top_k is None:
+ gen_param.top_k = 40
+
+ if cap == 'chat':
+ gen_param.sequence_start = (nth_round == 1)
+ gen_param.sequence_end = False
+ gen_param.step = step
+ else:
+ gen_param.sequence_start = True
+ gen_param.sequence_end = True
+ gen_param.step = 0
+ return gen_param
+
+
def main(model_path,
session_id: int = 1,
- repetition_penalty: float = 1.0,
+ cap: str = 'chat',
+ sys_instruct: str = None,
tp=1,
- stream_output=True):
+ stream_output=True,
+ **kwargs):
"""An example to perform model inference through the command line
interface.
Args:
model_path (str): the path of the deployed model
session_id (int): the identical id of a session
- repetition_penalty (float): parameter to penalize repetition
+ cap (str): the capability of a model. For example, codellama has
+ the ability among ['completion', 'infilling', 'chat', 'python']
+ sys_instruct (str): the content of the 'system' role, which is used by
+ conversational models
tp (int): GPU number used in tensor parallelism
stream_output (bool): indicator for streaming output or not
+ **kwargs (dict): other arguments for initializing the model's chat template
"""
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
@@ -53,10 +100,13 @@ def main(model_path,
step = 0
seed = random.getrandbits(64)
model_name = tm_model.model_name
- model = MODELS.get(model_name)()
+ model = MODELS.get(model_name)(capability=cap, **kwargs) \
+ if sys_instruct is None else MODELS.get(model_name)(
+ capability=cap, system=sys_instruct, **kwargs)
+ print(f'session {session_id}')
while True:
- prompt = input_prompt()
+ prompt = input_prompt(model_name)
if prompt == 'exit':
exit(0)
elif prompt == 'end':
@@ -73,28 +123,23 @@ def main(model_path,
step = 0
seed = random.getrandbits(64)
else:
- print(f'session {session_id}')
- prompt = model.get_prompt(prompt, nth_round == 1)
+ prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt)
if step + len(input_ids) >= tm_model.session_len:
print('WARNING: exceed session max length.'
' Please end the session.')
continue
+
+ gen_param = get_gen_param(cap, model.sampling_param, nth_round,
+ step, **kwargs)
+
print(f'{prompt} ', end='', flush=True)
response_size = 0
for outputs in generator.stream_infer(
session_id=session_id,
input_ids=[input_ids],
stream_output=stream_output,
- request_output_len=512,
- sequence_start=(nth_round == 1),
- sequence_end=False,
- step=step,
- stop=False,
- top_k=40,
- top_p=0.8,
- temperature=0.8,
- repetition_penalty=repetition_penalty,
+ **dataclasses.asdict(gen_param),
ignore_eos=False,
random_seed=seed if nth_round == 1 else None):
res, tokens = outputs[0]
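
`get_gen_param` is what makes the non-chat capabilities stateless: anything other than 'chat' restarts the sequence on every request, while chat carries `step` forward and only starts the sequence on round one. A small sketch, assuming `SamplingParam` exposes exactly the four fields unpacked by `dataclasses.asdict` above:

    from lmdeploy.model import SamplingParam
    from lmdeploy.turbomind.chat import get_gen_param

    sp = SamplingParam(top_p=0.9, top_k=40, temperature=0.2, repetition_penalty=1.0)

    # chat: continue the running session from `step`, start only on round 1
    print(get_gen_param('chat', sp, nth_round=2, step=128))
    # -> sequence_start=False, sequence_end=False, step=128

    # completion / infilling / python: every request is its own sequence
    print(get_gen_param('completion', sp, nth_round=2, step=128))
    # -> sequence_start=True, sequence_end=True, step=0
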
diff --git a/lmdeploy/turbomind/tokenizer.py b/lmdeploy/turbomind/tokenizer.py
index bb7f95e9e..98db9c2b6 100644
--- a/lmdeploy/turbomind/tokenizer.py
+++ b/lmdeploy/turbomind/tokenizer.py
@@ -111,7 +111,8 @@ class HuggingFaceTokenizer:
"""
def __init__(self, model_dir: str):
- from transformers import AutoTokenizer, LlamaTokenizerFast
+ from transformers import (AutoTokenizer, CodeLlamaTokenizerFast,
+ LlamaTokenizerFast)
model_file = osp.join(model_dir, 'tokenizer.model')
backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
model_file_exists = osp.exists(model_file)
@@ -120,7 +121,8 @@ def __init__(self, model_dir: str):
'It may take long time to initialize the tokenizer.')
self.model = AutoTokenizer.from_pretrained(model_dir,
trust_remote_code=True)
- self.need_padding = isinstance(self.model, LlamaTokenizerFast)
+ self.need_padding = isinstance(self.model, LlamaTokenizerFast) \
+ or isinstance(self.model, CodeLlamaTokenizerFast)
self._no_prefix_space_tokens = None
# save tokenizer.json to reuse
if not osp.exists(backend_tokenizer_file) and model_file_exists:
diff --git a/requirements.txt b/requirements.txt
index c0cd48396..861623c04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,6 @@ setuptools
shortuuid
tiktoken
torch
-transformers
+transformers>=4.33.0
tritonclient[all]
uvicorn
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention.h b/src/turbomind/kernels/decoder_masked_multihead_attention.h
index dba396bf4..b44332090 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention.h
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention.h
@@ -121,6 +121,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base {
int max_position_embeddings = 0;
bool use_dynamic_ntk = false;
bool use_logn_attn = false;
+ float rotary_embedding_base = 10000.0f;
};
template
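
The only job of the new `rotary_embedding_base` field is to replace the hard-coded 10000.f base in the RoPE frequency computation below. A standalone sketch of the underlying math, illustrating why the base matters for CodeLlama (which was trained with rope_theta = 1e6):

    import numpy as np

    def rope_inv_freq(dim: int, base: float) -> np.ndarray:
        """Per-pair rotary frequencies: base ** (-2i / dim) for i in [0, dim/2)."""
        return base ** (-np.arange(0, dim, 2, dtype=np.float64) / dim)

    # a larger base stretches the longest wavelengths, which is what lets
    # CodeLlama handle much longer contexts than the 10000.0 default
    print(rope_inv_freq(128, 10000.0)[-1])   # ~1.2e-4
    print(rope_inv_freq(128, 1.0e6)[-1])     # ~1.2e-6
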
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
index 6b9101abb..c2b6039d6 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -1378,19 +1378,20 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
q = add(q, q_bias);
k = add(k, k_bias);
- float rotary_emb_base = 10000.f;
+ float rotary_embedding_base = params.rotary_embedding_base;
if (params.use_dynamic_ntk) {
// +1 because of `length_per_sample == context_length - 1`
- rotary_emb_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
- params.max_position_embeddings,
- params.rotary_embedding_dim,
- rotary_emb_base);
+ rotary_embedding_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
+ params.max_position_embeddings,
+ params.rotary_embedding_dim,
+ rotary_embedding_base);
}
// Padded len
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0) {
- apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_emb_base, params.timestep - padd_len);
+ apply_rotary_embedding(
+ q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base, params.timestep - padd_len);
}
if (params.use_logn_attn) {
diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu
index 536175ccf..b2450c867 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.cu
+++ b/src/turbomind/kernels/unfused_attention_kernels.cu
@@ -863,6 +863,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
int kv_head_num,
int size_per_head,
int rotary_embedding_dim,
+ float rotary_embedding_base,
int max_position_embeddings,
bool use_dynamic_ntk,
bool use_logn_attn)
@@ -931,14 +932,13 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
const int context_len = history_len + input_length[batch_idx];
const int timestep = history_len + seq_idx;
- float rotary_emb_base = 10000.f;
if (use_dynamic_ntk) {
- rotary_emb_base = mmha::rotary_embedding_get_base(
- context_len, max_position_embeddings, rotary_embedding_dim, rotary_emb_base);
+ rotary_embedding_base = mmha::rotary_embedding_get_base(
+ context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base);
}
// TODO: unused computation on k if GQA is used
- mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_emb_base, timestep);
+ mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep);
if (use_logn_attn) {
// +1 to convert to context length at the timestep
@@ -990,6 +990,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
kv_head_num, \
size_per_head, \
rotary_embedding_dim, \
+ rotary_embedding_base, \
max_position_embeddings, \
use_dynamic_ntk, \
use_logn_attn);
@@ -1010,6 +1011,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
+ float rotary_embedding_base,
int max_position_embeddings,
bool use_dynamic_ntk,
bool use_logn_attn,
@@ -1039,6 +1041,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int kv_head_num, \
const int size_per_head, \
const int rotary_embedding_dim, \
+ float rotary_embedding_base, \
int max_position_embeddings, \
bool use_dynamic_ntk, \
bool use_logn_attn, \
diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h
index 50069fc33..b5c37b5d4 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.h
+++ b/src/turbomind/kernels/unfused_attention_kernels.h
@@ -79,6 +79,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
+ float rotary_embedding_base,
int max_position_embeddings,
bool use_dynamic_ntk,
bool use_logn_attn,
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 66bcf7570..e8f77e1c7 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -175,6 +175,7 @@ inline void LlamaContextAttentionLayer::forward(TensorMap*
local_kv_head_num_,
size_per_head_,
params_.rotray_embedding_dim,
+ params_.rotary_embedding_base,
params_.max_position_embeddings,
params_.use_dynamic_ntk,
params_.use_logn_attn,
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index eec9a7fbd..3caaf5906 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -61,6 +61,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
+ const float rotary_embedding_base,
const int max_position_embeddings,
const bool use_dynamic_ntk,
const bool use_logn_attn,
@@ -129,6 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim;
+ params.rotary_embedding_base = rotary_embedding_base;
params.max_position_embeddings = max_position_embeddings;
params.use_dynamic_ntk = use_dynamic_ntk;
params.use_logn_attn = use_logn_attn;
@@ -261,6 +263,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o
local_kv_head_num_,
size_per_head_,
params_.rotray_embedding_dim,
+ params_.rotary_embedding_base,
params_.max_position_embeddings,
params_.use_dynamic_ntk,
params_.use_logn_attn,
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index a2387e44e..8f8c96837 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -5,10 +5,11 @@
namespace turbomind {
struct LlamaAttentionParams {
- int rotray_embedding_dim;
- int max_position_embeddings;
- bool use_dynamic_ntk;
- bool use_logn_attn;
+ int rotray_embedding_dim;
+ float rotary_embedding_base;
+ int max_position_embeddings;
+ bool use_dynamic_ntk;
+ bool use_logn_attn;
};
} // namespace turbomind
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 169d6cbdb..456f5f41c 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -137,6 +137,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size,
group_size_ = reader.GetInteger("llama", "group_size", 0);
attn_params_.rotray_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
+ attn_params_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f);
attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0);
attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0);
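
On the C++ side the value arrives through the same config.ini that deploy.py writes, read from the [llama] section above. A quick way to inspect it after deployment (the workspace path is a placeholder for wherever deploy.py wrote the triton models):

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read('./workspace/triton_models/weights/config.ini')   # hypothetical output dir
    print(cfg.getfloat('llama', 'rope_theta', fallback=10000.0))
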
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
new file mode 100644
index 000000000..83487f1f0
--- /dev/null
+++ b/tests/test_lmdeploy/test_model.py
@@ -0,0 +1,205 @@
+import pytest
+
+from lmdeploy.model import MODELS, SamplingParam
+
+
+def test_base_model():
+ model = MODELS.get('llama')()
+ assert model is not None
+ assert model.capability == 'chat'
+ assert model.get_prompt('test') is None
+ assert model.stop_words is None
+
+ model = MODELS.get('internlm')(capability='completion')
+ assert model.capability == 'completion'
+ assert model.get_prompt('hi') == 'hi'
+ assert model.messages2prompt('test') == 'test'
+
+
+def test_vicuna():
+ prompt = 'hello, can u introduce yourself'
+ model = MODELS.get('vicuna')(capability='completion')
+ assert model.get_prompt(prompt, sequence_start=True) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is None
+ assert model.system is not None
+
+ model = MODELS.get('vicuna')(capability='chat',
+ system='Provide answers in Python')
+ assert model.get_prompt(prompt, sequence_start=True) != prompt
+ assert model.get_prompt(prompt, sequence_start=False) != prompt
+ assert model.system == 'Provide answers in Python'
+
+ model = MODELS.get('vicuna')(capability='voice')
+ _prompt = None
+ with pytest.raises(AssertionError):
+ _prompt = model.get_prompt(prompt, sequence_start=True)
+ assert _prompt is None
+
+
+def test_internlm_chat():
+ prompt = 'hello, can u introduce yourself'
+ model = MODELS.get('internlm-chat-7b')(capability='completion')
+ assert model.get_prompt(prompt, sequence_start=True) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is not None
+ assert model.system == ''
+ assert model.session_len == 2048
+
+ model = MODELS.get('internlm-chat-7b')(capability='chat',
+ system='Provide answers in Python')
+ assert model.get_prompt(prompt, sequence_start=True) != prompt
+ assert model.get_prompt(prompt, sequence_start=False) != prompt
+ assert model.system == 'Provide answers in Python'
+
+ model = MODELS.get('internlm-chat-7b')(capability='voice')
+ _prompt = None
+ with pytest.raises(AssertionError):
+ _prompt = model.get_prompt(prompt, sequence_start=True)
+ assert _prompt is None
+
+ model = MODELS.get('internlm-chat-7b-8k')()
+ assert model.session_len == 8192
+
+
+def test_baichuan():
+ prompt = 'hello, can u introduce yourself'
+ model = MODELS.get('baichuan-7b')(capability='completion')
+ assert model.get_prompt(prompt, sequence_start=True) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is None
+ assert model.repetition_penalty == 1.1
+
+ model = MODELS.get('baichuan-7b')(capability='chat')
+ _prompt = model.get_prompt(prompt, sequence_start=True)
+ assert _prompt is None
+
+
+def test_llama2():
+ prompt = 'hello, can u introduce yourself'
+ model = MODELS.get('llama2')(capability='completion')
+ assert model.get_prompt(prompt, sequence_start=True) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is None
+ assert model.default_sys_prompt is not None
+
+ model = MODELS.get('llama2')(capability='chat',
+ system='Provide answers in Python')
+ assert model.get_prompt(prompt, sequence_start=True) != prompt
+ assert model.get_prompt(prompt, sequence_start=False) != prompt
+ assert model.default_sys_prompt == 'Provide answers in Python'
+
+ model = MODELS.get('llama2')(capability='voice')
+ _prompt = None
+ with pytest.raises(AssertionError):
+ _prompt = model.get_prompt(prompt, sequence_start=True)
+ assert _prompt is None
+
+
+def test_qwen():
+ prompt = 'hello, can u introduce yourself'
+ model = MODELS.get('qwen-7b')(capability='completion')
+ assert model.get_prompt(prompt, sequence_start=True) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is not None
+
+ model = MODELS.get('qwen-7b')(capability='chat')
+ assert model.get_prompt(prompt, sequence_start=True) != prompt
+ assert model.get_prompt(prompt, sequence_start=False) != prompt
+
+ model = MODELS.get('qwen-7b')(capability='voice')
+ _prompt = None
+ with pytest.raises(AssertionError):
+ _prompt = model.get_prompt(prompt, sequence_start=True)
+ assert _prompt is None
+
+
+def test_codellama_completion():
+ model = MODELS.get('codellama')(capability='completion')
+ prompt = """\
+import socket
+
+def ping_exponential_backoff(host: str):"""
+ assert model.get_prompt(prompt) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is None
+
+
+def test_codellama_infilling():
+ model = MODELS.get('codellama')(capability='infilling')
+ prompt = '''def remove_non_ascii(s: str) -> str:
+ """
+ return result
+'''
+ _prompt = model.get_prompt(prompt)
+ assert _prompt.find('<FILL>') == -1
+ assert model.stop_words == [32010]
+
+ model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
+ _prompt = model.get_prompt(prompt)
+ assert _prompt.find('<FILL>') == -1
+
+
+def test_codellama_chat():
+ model = MODELS.get('codellama')(capability='chat',
+ system='Provide answers in Python')
+ prompt = 'Write a function that computes the set of sums of all contiguous sublists of a given list.' # noqa: E501
+ _prompt = model.get_prompt(prompt, sequence_start=True)
+ assert _prompt.find('Provide answers in Python') != -1
+
+ _prompt = model.get_prompt(prompt, sequence_start=False)
+ assert _prompt.find('Provide answers in Python') == -1
+ assert model.stop_words is None
+
+
+def test_codellama_python_specialist():
+ model = MODELS.get('codellama')(capability='python')
+ prompt = """
+ def remove_non_ascii(s: str) -> str:
+"""
+ assert model.get_prompt(prompt, sequence_start=True) == prompt
+ assert model.get_prompt(prompt, sequence_start=False) == prompt
+ assert model.stop_words is None
+
+
+def test_codellama_others():
+ model = None
+ with pytest.raises(AssertionError):
+ model = MODELS.get('codellama')(capability='java')
+ assert model is None
+
+
+def test_sampling_param():
+ model = MODELS.get('llama')()
+ default_sampling_param = SamplingParam()
+ assert model.sampling_param == default_sampling_param
+
+ model = MODELS.get('llama')(top_p=0.1, top_k=10)
+ assert model.sampling_param.top_p == 0.1 and \
+ model.sampling_param.top_k == 10
+ assert model.sampling_param.temperature == 0.8 and \
+ model.sampling_param.repetition_penalty == 1.0
+
+ model = MODELS.get('codellama')(capability='completion')
+ assert model.sampling_param.top_p == 0.9 and \
+ model.sampling_param.top_k is None and \
+ model.sampling_param.temperature == 0.2 and \
+ model.sampling_param.repetition_penalty == 1.0
+
+ model = MODELS.get('codellama')(capability='chat')
+ assert model.sampling_param.top_p == 0.95 and \
+ model.sampling_param.top_k is None and \
+ model.sampling_param.temperature == 0.2 and \
+ model.sampling_param.repetition_penalty == 1.0
+
+ model = MODELS.get('codellama')(capability='infilling')
+ assert model.sampling_param.top_p == 0.9 and \
+ model.sampling_param.top_k is None and \
+ model.sampling_param.temperature == 0.0 and \
+ model.sampling_param.repetition_penalty == 1.0
+
+ model = MODELS.get('codellama')(capability='python')
+ assert model.sampling_param.top_p == 0.9 and \
+ model.sampling_param.top_k is None and \
+ model.sampling_param.temperature == 0.2 and \
+ model.sampling_param.repetition_penalty == 1.0