diff --git a/README.md b/README.md index 3ed91d925..42bcbee8b 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ ______________________________________________________________________ ## News 🎉 +- \[2023/09\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and python specialist. Click [here](./docs/en/supported_models/codellama.md) for deployment guide - \[2023/09\] TurboMind supports Baichuan2-7B - \[2023/08\] TurboMind supports flash-attention2. - \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling @@ -56,12 +57,15 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by > **Note**
> W4A16 inference requires Nvidia GPU with Ampere architecture or above. -| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 | -| :-------: | :-------------: | :--: | :-----: | :---: | :--: | -| Llama | Yes | Yes | Yes | Yes | No | -| Llama2 | Yes | Yes | Yes | Yes | No | -| InternLM | Yes | Yes | Yes | Yes | No | -| Baichuan2 | Yes | Yes | No | No | No | +| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 | +| :----------: | :-------------: | :--: | :-----: | :---: | :--: | +| Llama | Yes | Yes | Yes | Yes | No | +| Llama2 | Yes | Yes | Yes | Yes | No | +| InternLM | Yes | Yes | Yes | Yes | No | +| QWen-7B | Yes | Yes | Yes | No | No | +| Baichuan-7B | Yes | Yes | Yes | Yes | No | +| Baichuan2-7B | Yes | Yes | No | No | No | +| Code Llama | Yes | Yes | No | No | No | ### Pytorch diff --git a/README_zh-CN.md b/README_zh-CN.md index 1e3b101d3..35cae96eb 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -20,6 +20,7 @@ ______________________________________________________________________ ## 更新 🎉 +- \[2023/09\] TurboMind 支持 Code Llama 所有功能:代码续写、填空、对话、Python专项。点击[这里](./docs/zh_cn/supported_models/codellama.md)阅读部署方法 - \[2023/09\] TurboMind 支持 Baichuan2-7B - \[2023/08\] TurboMind 支持 flash-attention2 - \[2023/08\] TurboMind 支持 Qwen-7B,动态NTK-RoPE缩放,动态logN缩放 @@ -57,12 +58,15 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht > **Note**
> W4A16 推理需要 Ampere 及以上架构的 Nvidia GPU -| 模型 | 模型并行 | FP16 | KV INT8 | W4A16 | W8A8 | -| :-------: | :------: | :--: | :-----: | :---: | :--: | -| Llama | Yes | Yes | Yes | Yes | No | -| Llama2 | Yes | Yes | Yes | Yes | No | -| InternLM | Yes | Yes | Yes | Yes | No | -| Baichuan2 | Yes | Yes | No | No | No | +| 模型 | 模型并行 | FP16 | KV INT8 | W4A16 | W8A8 | +| :----------: | :------: | :--: | :-----: | :---: | :--: | +| Llama | Yes | Yes | Yes | Yes | No | +| Llama2 | Yes | Yes | Yes | Yes | No | +| InternLM | Yes | Yes | Yes | Yes | No | +| QWen-7B | Yes | Yes | Yes | No | No | +| Baichuan-7B | Yes | Yes | Yes | Yes | No | +| Baichuan2-7B | Yes | Yes | No | No | No | +| Code Llama | Yes | Yes | No | No | No | ### Pytorch diff --git a/docs/en/supported_models/codellama.md b/docs/en/supported_models/codellama.md new file mode 100644 index 000000000..1b5140205 --- /dev/null +++ b/docs/en/supported_models/codellama.md @@ -0,0 +1,112 @@ +# codellama + +## Introduction + +[codellama](https://github.com/facebookresearch/codellama) features enhanced coding capabilities. It can generate code and natural language about code, from both code and natural language prompts (e.g., “Write me a function that outputs the fibonacci sequence”). It can also be used for code completion and debugging. It supports many of the most popular programming languages used today, including Python, C++, Java, PHP, Typescript (Javascript), C#, Bash and more. + +There are three sizes (7b, 13b, 34b) as well as three flavours (base model, Python fine-tuned, and instruction tuned) released on [HuggingFace](https://huggingface.co/codellama). + +| Base Model | Python | Instruct | +| ------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | +| [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf) | [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf) | [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) | +| [codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf) | [codellama/CodeLlama-13b-Python-hf](https://huggingface.co/codellama/CodeLlama-13b-Python-hf) | [codellama/CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) | +| [codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | [codellama/CodeLlama-34b-Python-hf](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) | [codellama/CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) | + +The correspondence between the model and capabilities is: + +| models | code completion | infilling | instructions / chat | python specialist | +| ---------- | --------------- | ----------------- | ------------------- | ----------------- | +| Base Model | Y | Y(7B,13B), N(34B) | N | N | +| Python | Y | N | N | Y | +| Instruct | Y | Y(7B,13B), N(34B) | Y | N | + +## Inference + +Based on the above table, download the model that meets your requirements. 
Execute the following commands to install lmdeploy and convert the model weights into the format TurboMind requires:
+
+```shell
+# install lmdeploy
+python3 -m pip install lmdeploy
+
+# convert weight layout
+python3 -m lmdeploy.serve.turbomind.deploy codellama /the/path/of/codellama/model
+```
+
+Then, you can communicate with codellama in the console by following the instructions in the next sections.
+
+**Note**:
+
+- the minimum required version of `transformers` is **v4.33.0**
+- lmdeploy supports copying code blocks to the console. To end the prompt, press enter, input "!!", and press enter again. The way of inputting prompts for other supported models remains unchanged, i.e., pressing enter twice.
+
+### Completion
+
+```shell
+python3 -m lmdeploy.turbomind.chat ./workspace --cap completion
+```
+
+### Infilling
+
+```shell
+python3 -m lmdeploy.turbomind.chat ./workspace --cap infilling
+```
+
+The input code is supposed to have a special placeholder `<FILL>`. For example,
+
+```
+def remove_non_ascii(s: str) -> str:
+    """ <FILL>
+    return result
+```
+
+The code generated by `turbomind.chat` is the piece to be filled in at `<FILL>`.
+
+### Chat
+
+```shell
+python3 -m lmdeploy.turbomind.chat ./workspace --cap chat --sys-instruct "Provide answers in Python"
+```
+
+The `--sys-instruct` instruction can be changed to any other programming language that codellama supports.
+
+### Python specialist
+
+```shell
+python3 -m lmdeploy.turbomind.chat ./workspace --cap python
+```
+
+The Python fine-tuned model is highly recommended when the 'python specialist' capability is required.
+
+## Quantization
+
+TBD
+
+## Serving
+
+**The LMDeploy server only supports the `chat` capability**. The rest will be supported soon.
+
+Launch the inference server by:
+
+```shell
+# --instance_num: number of instances used to perform inference, which can be viewed as the max request concurrency
+# --tp: the number of GPUs used in tensor parallelism
+python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1
+```
+
+Then, you can communicate with it through the command line,
+
+```shell
+# restful_api_url is what api_server.py prints, e.g. http://localhost:23333
+python -m lmdeploy.serve.openai.api_client restful_api_url
+```
+
+or through the web UI after launching gradio,
+
+```shell
+# restful_api_url is what api_server.py prints, e.g. http://localhost:23333
+# server_ip and server_port here are for the gradio UI
+# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
+python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
+```
+
+For detailed information about the RESTful API, please refer to [restful_api.md](../restful_api.md).
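
For readers who want to script the infilling flow rather than use the console, the sketch below shows how a snippet containing the `<FILL>` placeholder is rearranged into Code Llama's infilling format; it mirrors the `CodeLlama._infill_prompt` logic added to `lmdeploy/model.py` in this change, and the `build_infill_prompt` helper name is illustrative only.

```python
# Minimal sketch (illustrative helper, not part of lmdeploy's public API):
# rearrange a snippet containing <FILL> into Code Llama's infilling format,
# mirroring CodeLlama._infill_prompt introduced in lmdeploy/model.py below.
def build_infill_prompt(code_with_fill: str, suffix_first: bool = False) -> str:
    prefix, suffix = code_with_fill.split('<FILL>')
    if suffix_first:
        # "<PRE> <SUF>{suf} <MID> {pre}"
        return f'<PRE> <SUF>{suffix} <MID> {prefix}'
    # "<PRE> {pre} <SUF>{suf} <MID>"
    return f'<PRE> {prefix} <SUF>{suffix} <MID>'


snippet = '''def remove_non_ascii(s: str) -> str:
    """ <FILL>
    return result
'''
print(build_infill_prompt(snippet))
```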
diff --git a/docs/zh_cn/supported_models/codellama.md b/docs/zh_cn/supported_models/codellama.md new file mode 100644 index 000000000..ca9029a52 --- /dev/null +++ b/docs/zh_cn/supported_models/codellama.md @@ -0,0 +1,114 @@ +# Code Llama + +## 模型介绍 + +[codellama](https://github.com/facebookresearch/codellama) 支持很多种编程语言,包括 Python, C++, Java, PHP, Typescript (Javascript), C#, Bash 等等。具备代码续写、代码填空、对话、python专项等 4 种能力。 + +它在 [HuggingFace](https://huggingface.co/codellama) 上发布了基座模型,Python模型和指令微调模型: + +| 基座模型 | Python微调模型 | 指令模型 | +| ------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | +| [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf) | [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf) | [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) | +| [codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf) | [codellama/CodeLlama-13b-Python-hf](https://huggingface.co/codellama/CodeLlama-13b-Python-hf) | [codellama/CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) | +| [codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | [codellama/CodeLlama-34b-Python-hf](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) | [codellama/CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) | + +模型和能力的对应关系为: + +| 模型 | 代码续写 | 代码填空 | 对话 | Python专项 | +| -------------- | -------- | ----------------- | ---- | ---------- | +| 基座模型 | Y | Y(7B,13B), N(34B) | N | N | +| Python微调模型 | Y | N | N | Y | +| 指令微调模型 | Y | Y(7B,13B), N(34B) | Y | N | + +## 推理 + +根据上述的模型和能力关系表,下载感兴趣的模型。执行如下的命令,把模型权重转成 turbomind 要求的格式: + +```shell +# 安装 lmdeploy +python3 -m pip install lmdeploy + +# 转模型格式 +python3 -m lmdeploy.serve.turbomind.deploy codellama /path/of/codellama/model +``` + +接下来,可参考如下章节,在控制台与 codellama 进行交互式对话。 + +**注意**: + +- **transformers最低要求 v4.33.0** +- `lmdeploy.turbomind.chat` 支持把代码块拷贝到控制台,**结束输出的方式为回车,再输入"!!",再回车**。其他非 codellama 模型,仍然是两次回车结束输入。 + +### 代码续写 + +```shell +python3 -m lmdeploy.turbomind.chat ./workspace --cap completion +``` + +### 代码填空 + +```shell +python3 -m lmdeploy.turbomind.chat ./workspace --cap infilling +``` + +输入的代码块中要包含 ``,比如: + +``` +def remove_non_ascii(s: str) -> str: + """ + return result +``` + +`turbomind.chat` 输出的代码即是要填到 `` 中的内容 + +### 对话 + +``` +python3 -m lmdeploy.turbomind.chat ./workspace --cap chat --sys-instruct "Provide answers in Python" +``` + +可以把 `--sys-instruct` 的指令换成 codellama 支持的其他变成语言。 + +### Python 专项 + +``` +python3 -m lmdeploy.turbomind.chat ./workspace --cap python +``` + +建议这里部署 Python 微调模型 + +## 量化 + +TBD + +## 服务 + +**目前,server 支持的是对话功能**,其余功能后续再加上。 + +启动 sever 的方式是: + +```shell +# --instance_num: turbomind推理实例的个数。可理解为支持的最大并发数 +# --tp: 在 tensor parallel时,使用的GPU数量 +python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1 +``` + +打开 `http://{server_ip}:{server_port}`,即可访问 swagger,查阅 RESTful API 的详细信息。 + +你可以用命令行,在控制台与 server 通信: + +```shell +# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333 +python -m lmdeploy.serve.openai.api_client restful_api_url +``` + +或者,启动 gradio,在 webui 的聊天对话框中,与 codellama 交流: + +```shell +# restful_api_url 就是 api_server 产生的,比如 
http://localhost:23333 +# server_ip 和 server_port 是用来提供 gradio ui 访问服务的 +# 例子: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True +python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True +``` + +关于 RESTful API的详细介绍,请参考[这份](../restful_api.md)文档。 diff --git a/lmdeploy/model.py b/lmdeploy/model.py index b3706a59d..bf89e3906 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import dataclasses from abc import abstractmethod from typing import List @@ -7,7 +8,17 @@ MODELS = Registry('model', locations=['lmdeploy.model']) +@dataclasses.dataclass +class SamplingParam: + top_p: float = 0.8 + top_k: float = None + temperature: float = 0.8 + repetition_penalty: float = 1.0 + + +@MODELS.register_module(name='internlm') @MODELS.register_module(name='llama') +@MODELS.register_module(name='base') class BaseModel: """Base model.""" @@ -17,15 +28,16 @@ def __init__(self, top_k=None, temperature=0.8, repetition_penalty=1.0, + capability='chat', **kwargs): self.session_len = session_len self.top_p = top_p self.top_k = top_k self.temperature = temperature self.repetition_penalty = repetition_penalty + self.capability = capability - @staticmethod - def get_prompt(prompt, sequence_start=True): + def get_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the chat template. @@ -36,7 +48,14 @@ def get_prompt(prompt, sequence_start=True): Returns: str: the concatenated prompt """ - return prompt + if self.capability == 'completion': + return prompt + else: + return self.decorate_prompt(prompt, sequence_start) + + @abstractmethod + def decorate_prompt(self, prompt, sequence_start): + pass @staticmethod def _translate_messages(messages: List): @@ -87,6 +106,13 @@ def stop_words(self): """Return the stop-words' token ids.""" return None + @property + def sampling_param(self): + return SamplingParam(top_p=self.top_p, + top_k=self.top_k, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty) + @MODELS.register_module(name='vicuna') class Vicuna(BaseModel): @@ -103,7 +129,7 @@ def __init__( self.user = user self.assistant = assistant - def get_prompt(self, prompt, sequence_start=True): + def decorate_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the chat template. @@ -114,6 +140,8 @@ def get_prompt(self, prompt, sequence_start=True): Returns: str: the concatenated prompt """ + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: return f'{self.system} {self.user}: {prompt} {self.assistant}: ' else: @@ -141,13 +169,6 @@ def messages2prompt(self, messages, sequence_start=True): return ret -@MODELS.register_module(name='internlm') -class InternLM(BaseModel): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - @MODELS.register_module(name='internlm-chat-7b') class InternLMChat7B(BaseModel): """Chat template of InternLM model.""" @@ -166,7 +187,7 @@ def __init__(self, self.eoa = eoa self.assistant = assistant - def get_prompt(self, prompt, sequence_start=True): + def decorate_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the chat template. 
@@ -177,6 +198,8 @@ def get_prompt(self, prompt, sequence_start=True): Returns: str: the concatenated prompt """ + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: return f'{self.user}:{prompt}{self.eoh}\n' \ f'{self.assistant}:' @@ -227,8 +250,8 @@ def __init__(self, repetition_penalty=1.1, **kwargs): self.repetition_penalty = repetition_penalty -@MODELS.register_module(name='baichuan2-7b-chat') -class Baichuan2_7BChat(BaseModel): +@MODELS.register_module(name='baichuan2-7b') +class Baichuan2_7B(BaseModel): def __init__(self, temperature=0.3, @@ -244,7 +267,7 @@ def __init__(self, self.user_token = '' # id = 195 self.assistant_token = '' # id = 196 - def get_prompt(self, prompt, sequence_start=True): + def decorate_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the chat template. @@ -255,6 +278,8 @@ def get_prompt(self, prompt, sequence_start=True): Returns: str: the concatenated prompt """ + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' return f'{self.user_token}{prompt}{self.assistant_token}' def messages2prompt(self, messages, sequence_start=True): @@ -283,22 +308,24 @@ class Puyu(BaseModel): AI Laboratory.""" def __init__(self, - meta_instruction='', + system='', user='<|Human|>: ', eoh='', eosys='', assistant='<|Assistant|>: ', - system='<|System|>: ', + system_role='<|System|>: ', **kwargs): super().__init__(**kwargs) - self.meta_instruction = meta_instruction + self.meta_instruction = system self.user = user self.eoh = eoh self.eosys = eosys self.assistant = assistant - self.system = system + self.system = system_role - def get_prompt(self, prompt, sequence_start=True): + def decorate_prompt(self, prompt, sequence_start=True): + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: return f'{self.system}{self.meta_instruction}{self.eosys}\n' \ f'{self.user}{prompt}{self.eoh}\n' \ @@ -345,7 +372,7 @@ def __init__( e_inst='[/INST]', b_sys='<>\n', e_sys='\n<>\n\n', - default_sys_prompt="""\ + system="""\ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", # noqa: E501 @@ -356,10 +383,10 @@ def __init__( self.e_inst = e_inst self.b_sys = b_sys self.e_sys = e_sys - self.default_sys_prompt = default_sys_prompt + self.default_sys_prompt = system self.session_len = session_len - def get_prompt(self, prompt, sequence_start=True): + def decorate_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the chat template. 
@@ -370,6 +397,8 @@ def get_prompt(self, prompt, sequence_start=True): Returns: str: the concatenated prompt """ + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: return f'{self.b_inst} ' \ f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \ @@ -424,7 +453,9 @@ def __init__(self, self.im_end = im_end self.system = system - def get_prompt(self, prompt, sequence_start=True): + def decorate_prompt(self, prompt, sequence_start=True): + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' if sequence_start: return f'{self.im_start}system\n{self.system}{self.im_end}' \ f'\n{self.im_start}user\n{prompt}{self.im_end}' \ @@ -462,6 +493,76 @@ def stop_words(self): return [151645] # <|im_end|> +@MODELS.register_module(name='codellama') +class CodeLlama(Llama2): + + def __init__(self, + system='', + session_len=4096, + suffix_first=False, + **kwargs): + super().__init__(**kwargs) + caps = ['completion', 'infilling', 'chat', 'python'] + assert self.capability in caps, \ + f'{self.capability} is not supported. ' \ + f'The supported capabilities are: {caps}' + self.default_sys_prompt = system + self.session_len = session_len + self.suffix_first = suffix_first + + # The following sampling parameters refers to https://github.com/facebookresearch/codellama # noqa: E501 + if self.capability == 'completion' or self.capability == 'python': + self.top_p = kwargs.get('top_p', 0.9) + self.temperature = kwargs.get('temperature', 0.2) + if self.capability == 'chat': + self.top_p = kwargs.get('top_p', 0.95) + self.temperature = kwargs.get('temperature', 0.2) + elif self.capability == 'infilling': + self.top_p = kwargs.get('top_p', 0.9) + self.temperature = kwargs.get('temperature', 0.0) + + def decorate_prompt(self, prompt, sequence_start=True): + if self.capability == 'infilling': + return self._infill_prompt(prompt) + elif self.capability == 'chat': + return self._get_prompt(prompt, sequence_start) + else: # python speicalist + return prompt + + def _infill_prompt(self, prompt): + prefix, suffix = prompt.split('') + if self.suffix_first: + # format as "
<PRE> <SUF>{suf} <MID> {pre}"
+            prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
+        else:
+            # format as "<PRE> {pre} <SUF>{suf} <MID>"
+            prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
+        return prompt
+
+    def _get_prompt(self, prompt, sequence_start):
+        prompt = prompt.strip()
+        if sequence_start:
+            return f'{self.b_inst} ' \
+                   f'{self.b_sys}{self.default_sys_prompt}{self.e_sys}' \
+                   f'{prompt} {self.e_inst}'
+
+        return f'{self.b_inst} {prompt} {self.e_inst}'
+
+    @property
+    def stop_words(self):
+        if self.capability == 'infilling':
+            # EOT ID
+            return [32010]
+        else:
+            return None
+
+    def messages2prompt(self, messages, sequence_start=True):
+        assert self.capability == 'chat', \
+            f'codellama messages2prompt only supports chat mode ' \
+            f'but got {self.capability} mode'
+        return super().messages2prompt(messages, sequence_start)
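
A quick sketch of how the new `capability` switch is meant to be used, consistent with the unit tests added at the end of this diff; it assumes lmdeploy is importable with this patch applied:

```python
from lmdeploy.model import MODELS

# The capability chosen at construction time drives both prompt decoration
# and the default sampling parameters of the codellama template.
chat_model = MODELS.get('codellama')(capability='chat',
                                     system='Provide answers in Python')
print(chat_model.sampling_param)    # top_p=0.95, temperature=0.2

infill_model = MODELS.get('codellama')(capability='infilling')
print(infill_model.stop_words)      # [32010], the EOT token id
print(infill_model.sampling_param)  # top_p=0.9, temperature=0.0
```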
+
+
 def main(model_name: str = 'test'):
     assert model_name in MODELS.module_dict.keys(), \
         f"'{model_name}' is not supported. " \
diff --git a/lmdeploy/serve/client.py b/lmdeploy/serve/client.py
index 1d22d4ba3..283e96e29 100644
--- a/lmdeploy/serve/client.py
+++ b/lmdeploy/serve/client.py
@@ -6,16 +6,23 @@
 from lmdeploy.serve.turbomind.chatbot import Chatbot
 
 
-def input_prompt():
-    """Input a prompt in the console interface."""
-    print('\ndouble enter to end input >>> ', end='')
-    sentinel = ''  # ends when this string is seen
+def input_prompt(model_name):
+    """Input a prompt in the consolo interface."""
+    if model_name == 'codellama':
+        print('\nenter !! to end the input >>>\n', end='')
+        sentinel = '!!'
+    else:
+        print('\ndouble enter to end input >>> ', end='')
+        sentinel = ''  # ends when this string is seen
     return '\n'.join(iter(input, sentinel))
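
Side note on the two-argument `iter` used here: it keeps calling `input()` until a line equals the sentinel, which is why codellama prompts are terminated by a line containing only `!!` while other models still end on an empty line. A minimal illustration:

```python
# Keep reading lines until a line equal to the sentinel ('!!') is entered;
# blank lines inside pasted code blocks no longer terminate the prompt.
lines = iter(input, '!!')
prompt = '\n'.join(lines)
```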
 
 
 def main(tritonserver_addr: str,
          session_id: int = 1,
-         stream_output: bool = True):
+         cap: str = 'chat',
+         sys_instruct: str = None,
+         stream_output: bool = True,
+         **kwargs):
     """An example to communicate with inference server through the command line
     interface.
 
@@ -23,15 +30,22 @@ def main(tritonserver_addr: str,
         tritonserver_addr (str): the address in format "ip:port" of
           triton inference server
         session_id (int): the identical id of a session
+        cap (str): the capability of a model. For example, codellama has
+            the ability among ['completion', 'infilling', 'chat', 'python']
+        sys_instruct (str): the content of 'system' role, which is used by
+            conversational model
         stream_output (bool): indicator for streaming output or not
+        **kwargs (dict): other arguments for initializing the model's chat template
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
+    kwargs.update(capability=cap, system=sys_instruct)
     chatbot = Chatbot(tritonserver_addr,
                       log_level=log_level,
-                      display=stream_output)
+                      display=stream_output,
+                      **kwargs)
     nth_round = 1
     while True:
-        prompt = input_prompt()
+        prompt = input_prompt(chatbot.model_name)
         if prompt == 'exit':
             exit(0)
         elif prompt == 'end':
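
One thing to double-check here: when `--sys-instruct` is omitted, `system=None` is still forwarded to the chat template. A None-safe variant (hypothetical, not part of this change) could look like the following sketch:

```python
from typing import Optional


def build_template_kwargs(cap: str,
                          sys_instruct: Optional[str] = None,
                          **kwargs) -> dict:
    """Hypothetical helper: forward the system prompt only when one was
    actually supplied, so the template keeps its default otherwise."""
    kwargs.update(capability=cap)
    if sys_instruct is not None:
        kwargs['system'] = sys_instruct
    return kwargs
```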
diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py
index 1212d3459..eb532e260 100644
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -149,6 +149,7 @@ def stream_infer(self,
         self._session.status = 1
         self._session.request_id = request_id
         self._session.response = ''
+        self.cfg.update(**kwargs)
 
         self._session.prompt = self._get_prompt(prompt, sequence_start)
         for status, res, tokens in self._stream_infer(self._session,
@@ -507,7 +508,7 @@ def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
                 server
             session (Session): an instance of a session
             que (multiprocessing.Queue): response queue
-            cfg:
+            cfg (dict): parameters for sampling
             input_ids (numpy.ndarray): token ids of input prompt
             input_lengths (numpy.ndarray): length of input_ids
             request_output_len (int): the max number of tokens to be generated
diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py
index 516afd793..1c2b1becc 100644
--- a/lmdeploy/serve/turbomind/deploy.py
+++ b/lmdeploy/serve/turbomind/deploy.py
@@ -122,6 +122,7 @@ def export(model_name: str,
            max_position_embeddings: int = 0,
            use_dynamic_ntk: int = 0,
            use_logn_attn: int = 0,
+           rope_theta: float = 10000.0,
            tokenizer_info=tokenizer_info_sp):
     """Export deploying information to a config file.
 
@@ -213,6 +214,7 @@ def save_bin(param: torch.Tensor, name):
         vocab_size=_vocab_size,
         num_layer=num_layer,
         rotary_embedding=size_per_head,
+        rope_theta=rope_theta,
         inter_size=inter_size,
         norm_eps=norm_eps,
         attn_bias=int(attn_bias),
@@ -233,7 +235,8 @@ def save_bin(param: torch.Tensor, name):
         # extra attention params
         max_position_embeddings=max_position_embeddings,
         use_dynamic_ntk=int(use_dynamic_ntk),
-        use_logn_attn=int(use_logn_attn)))
+        use_logn_attn=int(use_logn_attn),
+    ))
 
     config = configparser.ConfigParser()
     for section, key_values in cfg.items():
@@ -415,6 +418,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
             model_arg = json.load(f)
             num_layer = model_arg['num_hidden_layers']
             norm_eps = model_arg['rms_norm_eps']
+            rope_theta = float(model_arg.get('rope_theta', 10000.0))
+            max_position_embeddings = int(
+                model_arg.get('max_position_embeddings', 0))
+            rope_scaling = bool(model_arg.get('rope_scaling', False))
             if 'num_key_value_heads' in model_arg:
                 kv_head_num = model_arg['num_key_value_heads']
             else:
@@ -525,13 +532,23 @@ def get_tensor_transposed(name: str):
     for ft, hf in other:
         model_params[ft] = get_tensor(hf)
 
-    if model_name == 'baichuan2-7b-chat':
+    if model_name == 'baichuan2-7b':
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/modeling_baichuan.py#L507
         # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
         model_params['output.weight'] = torch.nn.functional.normalize(
             model_params['output.weight'])
 
-    return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
-                  tokenizer_path, triton_models_path, tp)
+    return export(model_name,
+                  num_layer,
+                  norm_eps,
+                  kv_head_num,
+                  model_params,
+                  tokenizer_path,
+                  triton_models_path,
+                  tp,
+                  max_position_embeddings=max_position_embeddings,
+                  use_dynamic_ntk=rope_scaling,
+                  rope_theta=rope_theta)
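
For reference, the sketch below reads the same fields from a HuggingFace `config.json` with the same fallbacks as `deploy_hf` above; the checkpoint path is a placeholder.

```python
import json
import os.path as osp

model_path = '/path/to/CodeLlama-7b-Instruct-hf'  # placeholder path
with open(osp.join(model_path, 'config.json')) as f:
    model_arg = json.load(f)

# Code Llama checkpoints ship rope_theta=1e6; older Llama configs omit the
# key, so the 10000.0 default keeps their behaviour unchanged.
rope_theta = float(model_arg.get('rope_theta', 10000.0))
max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
use_dynamic_ntk = bool(model_arg.get('rope_scaling', False))
print(rope_theta, max_position_embeddings, use_dynamic_ntk)
```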
 
 
 def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
@@ -574,6 +591,7 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
             model_arg = json.load(f)
             num_layer = model_arg['num_hidden_layers']
             norm_eps = model_arg['rms_norm_eps']
+            rope_theta = float(model_arg.get('rope_theta', 10000.0))
             if 'num_key_value_heads' in model_arg:
                 kv_head_num = model_arg['num_key_value_heads']
             else:
@@ -761,7 +779,8 @@ def tp_m_s4(x: torch.Tensor, tp: int):
                   triton_models_path,
                   tp,
                   weight_type='int4',
-                  group_size=group_size)
+                  group_size=group_size,
+                  rope_theta=rope_theta)
 
 
 def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
@@ -802,6 +821,7 @@ def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
             config = json.load(f)
             num_layer = config['num_hidden_layers']
             norm_eps = config['layer_norm_epsilon']
+            rope_theta = float(config.get('rotary_emb_base', 10000.0))
             if 'num_key_value_heads' in config:
                 kv_head_num = config['num_key_value_heads']
             else:
@@ -889,6 +909,7 @@ def get_tensor(name, trans=True):
                   max_position_embeddings=seq_length,
                   use_dynamic_ntk=use_dynamic_ntk,
                   use_logn_attn=use_logn_attn,
+                  rope_theta=rope_theta,
                   tokenizer_info=tokenizer_info_qwen)
 
 
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index 68692a840..d617b1983 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import dataclasses
 import os
 import os.path as osp
 import random
@@ -12,10 +13,26 @@
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
 
 
-def input_prompt():
+@dataclasses.dataclass
+class GenParam:
+    top_p: float
+    top_k: float
+    temperature: float
+    repetition_penalty: float
+    sequence_start: bool = False
+    sequence_end: bool = False
+    step: int = 0
+    request_output_len: int = 512
+
+
+def input_prompt(model_name):
     """Input a prompt in the consolo interface."""
-    print('\ndouble enter to end input >>> ', end='')
-    sentinel = ''  # ends when this string is seen
+    if model_name == 'codellama':
+        print('\nenter !! to end the input >>>\n', end='')
+        sentinel = '!!'
+    else:
+        print('\ndouble enter to end input >>> ', end='')
+        sentinel = ''  # ends when this string is seen
     return '\n'.join(iter(input, sentinel))
 
 
@@ -29,20 +46,50 @@ def valid_str(string, coding='utf-8'):
     return ret
 
 
+def get_gen_param(cap,
+                  sampling_param,
+                  nth_round,
+                  step,
+                  request_output_len=512,
+                  **kwargs):
+    """return parameters used by token generation."""
+    gen_param = GenParam(**dataclasses.asdict(sampling_param),
+                         request_output_len=request_output_len)
+    # Fix me later. turbomind.py doesn't support None top_k
+    if gen_param.top_k is None:
+        gen_param.top_k = 40
+
+    if cap == 'chat':
+        gen_param.sequence_start = (nth_round == 1)
+        gen_param.sequence_end = False
+        gen_param.step = step
+    else:
+        gen_param.sequence_start = True
+        gen_param.sequence_end = True
+        gen_param.step = 0
+    return gen_param
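
A worked example of `get_gen_param` (assuming `SamplingParam` from `lmdeploy.model`): chat sessions keep accumulating state, while one-shot capabilities such as completion or infilling start and end a sequence on every request.

```python
from lmdeploy.model import SamplingParam

sampling = SamplingParam(top_p=0.9, top_k=None, temperature=0.2)

chat_round2 = get_gen_param('chat', sampling, nth_round=2, step=128)
assert chat_round2.sequence_start is False and chat_round2.step == 128

completion = get_gen_param('completion', sampling, nth_round=2, step=128)
assert completion.sequence_start and completion.sequence_end
assert completion.step == 0 and completion.top_k == 40  # None top_k mapped to 40
```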
+
+
 def main(model_path,
          session_id: int = 1,
-         repetition_penalty: float = 1.0,
+         cap: str = 'chat',
+         sys_instruct: str = None,
          tp=1,
-         stream_output=True):
+         stream_output=True,
+         **kwargs):
     """An example to perform model inference through the command line
     interface.
 
     Args:
         model_path (str): the path of the deployed model
         session_id (int): the identical id of a session
-        repetition_penalty (float): parameter to penalize repetition
+        cap (str): the capability of a model. For example, codellama has
+            the ability among ['completion', 'infilling', 'chat', 'python']
+        sys_instruct (str): the content of 'system' role, which is used by
+            conversational model
         tp (int): GPU number used in tensor parallelism
         stream_output (bool): indicator for streaming output or not
+        **kwargs (dict): other arguments for initializing the model's chat template
     """
     tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
     tokenizer = Tokenizer(tokenizer_model_path)
@@ -53,10 +100,13 @@ def main(model_path,
     step = 0
     seed = random.getrandbits(64)
     model_name = tm_model.model_name
-    model = MODELS.get(model_name)()
+    model = MODELS.get(model_name)(capability=cap, **kwargs) \
+        if sys_instruct is None else MODELS.get(model_name)(
+            capability=cap, system=sys_instruct, **kwargs)
 
+    print(f'session {session_id}')
     while True:
-        prompt = input_prompt()
+        prompt = input_prompt(model_name)
         if prompt == 'exit':
             exit(0)
         elif prompt == 'end':
@@ -73,28 +123,23 @@ def main(model_path,
             step = 0
             seed = random.getrandbits(64)
         else:
-            print(f'session {session_id}')
-            prompt = model.get_prompt(prompt, nth_round == 1)
+            prompt = model.get_prompt(prompt, nth_round == 1)
             input_ids = tokenizer.encode(prompt)
             if step + len(input_ids) >= tm_model.session_len:
                 print('WARNING: exceed session max length.'
                       ' Please end the session.')
                 continue
+
+            gen_param = get_gen_param(cap, model.sampling_param, nth_round,
+                                      step, **kwargs)
+
             print(f'{prompt} ', end='', flush=True)
             response_size = 0
             for outputs in generator.stream_infer(
                     session_id=session_id,
                     input_ids=[input_ids],
                     stream_output=stream_output,
-                    request_output_len=512,
-                    sequence_start=(nth_round == 1),
-                    sequence_end=False,
-                    step=step,
-                    stop=False,
-                    top_k=40,
-                    top_p=0.8,
-                    temperature=0.8,
-                    repetition_penalty=repetition_penalty,
+                    **dataclasses.asdict(gen_param),
                     ignore_eos=False,
                     random_seed=seed if nth_round == 1 else None):
                 res, tokens = outputs[0]
diff --git a/lmdeploy/turbomind/tokenizer.py b/lmdeploy/turbomind/tokenizer.py
index bb7f95e9e..98db9c2b6 100644
--- a/lmdeploy/turbomind/tokenizer.py
+++ b/lmdeploy/turbomind/tokenizer.py
@@ -111,7 +111,8 @@ class HuggingFaceTokenizer:
     """
 
     def __init__(self, model_dir: str):
-        from transformers import AutoTokenizer, LlamaTokenizerFast
+        from transformers import (AutoTokenizer, CodeLlamaTokenizerFast,
+                                  LlamaTokenizerFast)
         model_file = osp.join(model_dir, 'tokenizer.model')
         backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
         model_file_exists = osp.exists(model_file)
@@ -120,7 +121,8 @@ def __init__(self, model_dir: str):
                   'It may take long time to initialize the tokenizer.')
         self.model = AutoTokenizer.from_pretrained(model_dir,
                                                    trust_remote_code=True)
-        self.need_padding = isinstance(self.model, LlamaTokenizerFast)
+        self.need_padding = isinstance(self.model, LlamaTokenizerFast) \
+            or isinstance(self.model, CodeLlamaTokenizerFast)
         self._no_prefix_space_tokens = None
         # save tokenizer.json to reuse
         if not osp.exists(backend_tokenizer_file) and model_file_exists:
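
Minor style note: `isinstance` accepts a tuple of classes, so the padding check can be expressed in one call. The helper below is just an equivalent standalone form (requires transformers >= 4.33 for `CodeLlamaTokenizerFast`):

```python
from transformers import CodeLlamaTokenizerFast, LlamaTokenizerFast


def needs_padding(tokenizer) -> bool:
    # Equivalent to the two chained isinstance checks above.
    return isinstance(tokenizer, (LlamaTokenizerFast, CodeLlamaTokenizerFast))
```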
diff --git a/requirements.txt b/requirements.txt
index c0cd48396..861623c04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,6 @@ setuptools
 shortuuid
 tiktoken
 torch
-transformers
+transformers>=4.33.0
 tritonclient[all]
 uvicorn
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention.h b/src/turbomind/kernels/decoder_masked_multihead_attention.h
index dba396bf4..b44332090 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention.h
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention.h
@@ -121,6 +121,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base {
     int        max_position_embeddings    = 0;
     bool       use_dynamic_ntk            = false;
     bool       use_logn_attn              = false;
+    float      rotary_embedding_base      = 10000.0f;
 };
 
 template
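
For context, `rotary_embedding_base` carries the RoPE base (`rope_theta`) that was previously hardcoded to 10000; Code Llama is trained with a base of 1e6 to reach long contexts, which is why the value now has to be configurable. For reference, the standard RoPE frequency definition, with $d$ the rotary embedding dimension, is:

```math
\theta_i = \text{rope\_theta}^{-2i/d}, \qquad i = 0, 1, \dots, \tfrac{d}{2} - 1
```

Each position $m$ rotates the $i$-th query/key pair by the angle $m\,\theta_i$, so a larger base slows the rotation and stretches the usable context window.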
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
index 6b9101abb..c2b6039d6 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -1378,19 +1378,20 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
     q = add(q, q_bias);
     k = add(k, k_bias);
 
-    float rotary_emb_base = 10000.f;
+    float rotary_embedding_base = params.rotary_embedding_base;
     if (params.use_dynamic_ntk) {
         // +1 because of `length_per_sample == context_length - 1`
-        rotary_emb_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
-                                                    params.max_position_embeddings,
-                                                    params.rotary_embedding_dim,
-                                                    rotary_emb_base);
+        rotary_embedding_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
+                                                          params.max_position_embeddings,
+                                                          params.rotary_embedding_dim,
+                                                          rotary_embedding_base);
     }
 
     // Padded len
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0) {
-        apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_emb_base, params.timestep - padd_len);
+        apply_rotary_embedding(
+            q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base, params.timestep - padd_len);
     }
 
     if (params.use_logn_attn) {
diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu
index 536175ccf..b2450c867 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.cu
+++ b/src/turbomind/kernels/unfused_attention_kernels.cu
@@ -863,6 +863,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                    int        kv_head_num,
                                                    int        size_per_head,
                                                    int        rotary_embedding_dim,
+                                                   float      rotary_embedding_base,
                                                    int        max_position_embeddings,
                                                    bool       use_dynamic_ntk,
                                                    bool       use_logn_attn)
@@ -931,14 +932,13 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
     const int context_len = history_len + input_length[batch_idx];
     const int timestep    = history_len + seq_idx;
 
-    float rotary_emb_base = 10000.f;
     if (use_dynamic_ntk) {
-        rotary_emb_base = mmha::rotary_embedding_get_base(
-            context_len, max_position_embeddings, rotary_embedding_dim, rotary_emb_base);
+        rotary_embedding_base = mmha::rotary_embedding_get_base(
+            context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base);
     }
 
     // TODO: unused computation on k if GQA is used
-    mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_emb_base, timestep);
+    mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep);
 
     if (use_logn_attn) {
         // +1 to convert to context length at the timestep
@@ -990,6 +990,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                                                              kv_head_num,              \
                                                                                              size_per_head,            \
                                                                                              rotary_embedding_dim,     \
+                                                                                             rotary_embedding_base,    \
                                                                                              max_position_embeddings,  \
                                                                                              use_dynamic_ntk,          \
                                                                                              use_logn_attn);
@@ -1010,6 +1011,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     const int    kv_head_num,
                                     const int    size_per_head,
                                     const int    rotary_embedding_dim,
+                                    float        rotary_embedding_base,
                                     int          max_position_embeddings,
                                     bool         use_dynamic_ntk,
                                     bool         use_logn_attn,
@@ -1039,6 +1041,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                                  const int    kv_head_num,                                             \
                                                  const int    size_per_head,                                           \
                                                  const int    rotary_embedding_dim,                                    \
+                                                 float        rotary_embedding_base,                                   \
                                                  int          max_position_embeddings,                                 \
                                                  bool         use_dynamic_ntk,                                         \
                                                  bool         use_logn_attn,                                           \
diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h
index 50069fc33..b5c37b5d4 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.h
+++ b/src/turbomind/kernels/unfused_attention_kernels.h
@@ -79,6 +79,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     const int    kv_head_num,
                                     const int    size_per_head,
                                     const int    rotary_embedding_dim,
+                                    float        rotary_embedding_base,
                                     int          max_position_embeddings,
                                     bool         use_dynamic_ntk,
                                     bool         use_logn_attn,
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 66bcf7570..e8f77e1c7 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -175,6 +175,7 @@ inline void LlamaContextAttentionLayer::forward(TensorMap*
                                    local_kv_head_num_,
                                    size_per_head_,
                                    params_.rotray_embedding_dim,
+                                   params_.rotary_embedding_base,
                                    params_.max_position_embeddings,
                                    params_.use_dynamic_ntk,
                                    params_.use_logn_attn,
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index eec9a7fbd..3caaf5906 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -61,6 +61,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T*     qkv_buf,
                                                       const int    kv_head_num,
                                                       const int    size_per_head,
                                                       const int    rotary_embedding_dim,
+                                                      const float  rotary_embedding_base,
                                                       const int    max_position_embeddings,
                                                       const bool   use_dynamic_ntk,
                                                       const bool   use_logn_attn,
@@ -129,6 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T*     qkv_buf,
 
     params.hidden_size_per_head    = size_per_head;
     params.rotary_embedding_dim    = rotary_embedding_dim;
+    params.rotary_embedding_base   = rotary_embedding_base;
     params.max_position_embeddings = max_position_embeddings;
     params.use_dynamic_ntk         = use_dynamic_ntk;
     params.use_logn_attn           = use_logn_attn;
@@ -261,6 +263,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap*                     o
         local_kv_head_num_,
         size_per_head_,
         params_.rotray_embedding_dim,
+        params_.rotary_embedding_base,
         params_.max_position_embeddings,
         params_.use_dynamic_ntk,
         params_.use_logn_attn,
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index a2387e44e..8f8c96837 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -5,10 +5,11 @@
 namespace turbomind {
 
 struct LlamaAttentionParams {
-    int  rotray_embedding_dim;
-    int  max_position_embeddings;
-    bool use_dynamic_ntk;
-    bool use_logn_attn;
+    int   rotray_embedding_dim;
+    float rotary_embedding_base;
+    int   max_position_embeddings;
+    bool  use_dynamic_ntk;
+    bool  use_logn_attn;
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 169d6cbdb..456f5f41c 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -137,6 +137,7 @@ LlamaTritonModel::LlamaTritonModel(size_t      tensor_para_size,
     group_size_            = reader.GetInteger("llama", "group_size", 0);
 
     attn_params_.rotray_embedding_dim    = reader.GetInteger("llama", "rotary_embedding");
+    attn_params_.rotary_embedding_base   = reader.GetFloat("llama", "rope_theta", 10000.0f);
     attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
     attn_params_.use_dynamic_ntk         = reader.GetInteger("llama", "use_dynamic_ntk", 0);
     attn_params_.use_logn_attn           = reader.GetInteger("llama", "use_logn_attn", 0);
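
End to end, the value round-trips through the converted workspace: `deploy.py` writes `rope_theta` into the `[llama]` section of the generated `config.ini`, and `LlamaTritonModel` reads it back with the same 10000 default. A quick way to inspect a converted model, assuming the usual `./workspace` layout produced by the deploy command:

```python
import configparser

config = configparser.ConfigParser()
# Path assumes the default ./workspace layout produced by deploy.py.
config.read('./workspace/triton_models/weights/config.ini')
print(config.getfloat('llama', 'rope_theta', fallback=10000.0))
```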
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
new file mode 100644
index 000000000..83487f1f0
--- /dev/null
+++ b/tests/test_lmdeploy/test_model.py
@@ -0,0 +1,205 @@
+import pytest
+
+from lmdeploy.model import MODELS, SamplingParam
+
+
+def test_base_model():
+    model = MODELS.get('llama')()
+    assert model is not None
+    assert model.capability == 'chat'
+    assert model.get_prompt('test') is None
+    assert model.stop_words is None
+
+    model = MODELS.get('internlm')(capability='completion')
+    assert model.capability == 'completion'
+    assert model.get_prompt('hi') == 'hi'
+    assert model.messages2prompt('test') == 'test'
+
+
+def test_vicuna():
+    prompt = 'hello, can u introduce yourself'
+    model = MODELS.get('vicuna')(capability='completion')
+    assert model.get_prompt(prompt, sequence_start=True) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is None
+    assert model.system is not None
+
+    model = MODELS.get('vicuna')(capability='chat',
+                                 system='Provide answers in Python')
+    assert model.get_prompt(prompt, sequence_start=True) != prompt
+    assert model.get_prompt(prompt, sequence_start=False) != prompt
+    assert model.system == 'Provide answers in Python'
+
+    model = MODELS.get('vicuna')(capability='voice')
+    _prompt = None
+    with pytest.raises(AssertionError):
+        _prompt = model.get_prompt(prompt, sequence_start=True)
+    assert _prompt is None
+
+
+def test_internlm_chat():
+    prompt = 'hello, can u introduce yourself'
+    model = MODELS.get('internlm-chat-7b')(capability='completion')
+    assert model.get_prompt(prompt, sequence_start=True) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is not None
+    assert model.system == ''
+    assert model.session_len == 2048
+
+    model = MODELS.get('internlm-chat-7b')(capability='chat',
+                                           system='Provide answers in Python')
+    assert model.get_prompt(prompt, sequence_start=True) != prompt
+    assert model.get_prompt(prompt, sequence_start=False) != prompt
+    assert model.system == 'Provide answers in Python'
+
+    model = MODELS.get('internlm-chat-7b')(capability='voice')
+    _prompt = None
+    with pytest.raises(AssertionError):
+        _prompt = model.get_prompt(prompt, sequence_start=True)
+    assert _prompt is None
+
+    model = MODELS.get('internlm-chat-7b-8k')()
+    assert model.session_len == 8192
+
+
+def test_baichuan():
+    prompt = 'hello, can u introduce yourself'
+    model = MODELS.get('baichuan-7b')(capability='completion')
+    assert model.get_prompt(prompt, sequence_start=True) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is None
+    assert model.repetition_penalty == 1.1
+
+    model = MODELS.get('baichuan-7b')(capability='chat')
+    _prompt = model.get_prompt(prompt, sequence_start=True)
+    assert _prompt is None
+
+
+def test_llama2():
+    prompt = 'hello, can u introduce yourself'
+    model = MODELS.get('llama2')(capability='completion')
+    assert model.get_prompt(prompt, sequence_start=True) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is None
+    assert model.default_sys_prompt is not None
+
+    model = MODELS.get('llama2')(capability='chat',
+                                 system='Provide answers in Python')
+    assert model.get_prompt(prompt, sequence_start=True) != prompt
+    assert model.get_prompt(prompt, sequence_start=False) != prompt
+    assert model.default_sys_prompt == 'Provide answers in Python'
+
+    model = MODELS.get('llama2')(capability='voice')
+    _prompt = None
+    with pytest.raises(AssertionError):
+        _prompt = model.get_prompt(prompt, sequence_start=True)
+    assert _prompt is None
+
+
+def test_qwen():
+    prompt = 'hello, can u introduce yourself'
+    model = MODELS.get('qwen-7b')(capability='completion')
+    assert model.get_prompt(prompt, sequence_start=True) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is not None
+
+    model = MODELS.get('qwen-7b')(capability='chat')
+    assert model.get_prompt(prompt, sequence_start=True) != prompt
+    assert model.get_prompt(prompt, sequence_start=False) != prompt
+
+    model = MODELS.get('qwen-7b')(capability='voice')
+    _prompt = None
+    with pytest.raises(AssertionError):
+        _prompt = model.get_prompt(prompt, sequence_start=True)
+    assert _prompt is None
+
+
+def test_codellama_completion():
+    model = MODELS.get('codellama')(capability='completion')
+    prompt = """\
+import socket
+
+def ping_exponential_backoff(host: str):"""
+    assert model.get_prompt(prompt) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is None
+
+
+def test_codellama_infilling():
+    model = MODELS.get('codellama')(capability='infilling')
+    prompt = '''def remove_non_ascii(s: str) -> str:
+    """ 
+    return result
+'''
+    _prompt = model.get_prompt(prompt)
+    assert _prompt.find('<FILL>') == -1
+    assert model.stop_words == [32010]
+
+    model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
+    _prompt = model.get_prompt(prompt)
+    assert _prompt.find('<FILL>') == -1
+
+
+def test_codellama_chat():
+    model = MODELS.get('codellama')(capability='chat',
+                                    system='Provide answers in Python')
+    prompt = 'Write a function that computes the set of sums of all contiguous sublists of a given list.'  # noqa: E501
+    _prompt = model.get_prompt(prompt, sequence_start=True)
+    assert _prompt.find('Provide answers in Python') != -1
+
+    _prompt = model.get_prompt(prompt, sequence_start=False)
+    assert _prompt.find('Provide answers in Python') == -1
+    assert model.stop_words is None
+
+
+def test_codellama_python_specialist():
+    model = MODELS.get('codellama')(capability='python')
+    prompt = """
+    def remove_non_ascii(s: str) -> str:
+"""
+    assert model.get_prompt(prompt, sequence_start=True) == prompt
+    assert model.get_prompt(prompt, sequence_start=False) == prompt
+    assert model.stop_words is None
+
+
+def test_codellama_others():
+    model = None
+    with pytest.raises(AssertionError):
+        model = MODELS.get('codellama')(capability='java')
+    assert model is None
+
+
+def test_sampling_param():
+    model = MODELS.get('llama')()
+    default_sampling_param = SamplingParam()
+    assert model.sampling_param == default_sampling_param
+
+    model = MODELS.get('llama')(top_p=0.1, top_k=10)
+    assert model.sampling_param.top_p == 0.1 and \
+        model.sampling_param.top_k == 10
+    assert model.sampling_param.temperature == 0.8 and \
+        model.sampling_param.repetition_penalty == 1.0
+
+    model = MODELS.get('codellama')(capability='completion')
+    assert model.sampling_param.top_p == 0.9 and \
+        model.sampling_param.top_k is None and \
+        model.sampling_param.temperature == 0.2 and \
+        model.sampling_param.repetition_penalty == 1.0
+
+    model = MODELS.get('codellama')(capability='chat')
+    assert model.sampling_param.top_p == 0.95 and \
+        model.sampling_param.top_k is None and \
+        model.sampling_param.temperature == 0.2 and \
+        model.sampling_param.repetition_penalty == 1.0
+
+    model = MODELS.get('codellama')(capability='infilling')
+    assert model.sampling_param.top_p == 0.9 and \
+        model.sampling_param.top_k is None and \
+        model.sampling_param.temperature == 0.0 and \
+        model.sampling_param.repetition_penalty == 1.0
+
+    model = MODELS.get('codellama')(capability='python')
+    assert model.sampling_param.top_p == 0.9 and \
+        model.sampling_param.top_k is None and \
+        model.sampling_param.temperature == 0.2 and \
+        model.sampling_param.repetition_penalty == 1.0
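
One extra assertion that could be added (hypothetical, assuming the `<FILL>` placeholder convention used above): with `suffix_first=True`, the suffix should be rendered before the prefix in the infilling prompt.

```python
def test_codellama_infilling_suffix_first_order():
    """Hypothetical extra check: suffix_first=True emits the suffix first."""
    model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
    prompt = 'def add(a, b):\n    <FILL>\n    return c\n'
    _prompt = model.get_prompt(prompt)
    assert _prompt.index('return c') < _prompt.index('def add')
```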