Merge pull request #11 from vtuber-plan/optimum
Optimum inference support
jstzwj authored Jun 30, 2023
2 parents 7ff7fbe + bed90c0 commit 1715fa2
Showing 19 changed files with 458 additions and 31 deletions.
20 changes: 19 additions & 1 deletion README.md
@@ -30,13 +30,31 @@ The core features include:
- HuggingFace-compatible RESTful APIs.
- Tabby-compatible RESTful APIs.

## Supported Model Architectures
* LLaMA
* GLM
* Bloom
* OPT
* GPT2
* GPT Neo
* GPT Big Code

## Tested Models
* LLaMA
* Vicuna
* ChatGLM
* ChatGLM2
* Falcon
* StarCoder
* WizardLM
* OpenBuddy

## Benchmark
We use a single RTX 3090 to run a fine-tuned 7B LLaMA model (OpenBuddy V0.9) in the bf16 setting.
We create 32 threads to submit chat tasks to the server; the figure below shows the Queries Per Second (QPS) and Tokens Per Second (TPS) of FastChat and LangPort under different maximum model-concurrency settings.

![benchmark_chat](assets/benchmark_chat.jpg)
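To reproduce a run like this against your own endpoint, the chat benchmark script added in this PR should be a reasonable starting point, e.g. `python benchmark/bench_chat.py --url http://localhost:8000/v1 --model-name vicuna --n-thread 32 --total-task 200` (the URL and model name are placeholders; the flags follow the script's argparse options shown below).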


## News
- [2023/06/18] Add ggml (llama.cpp, gpt.cpp, starcoder.cpp, etc.) worker support.
- [2023/06/09] Add llama.cpp worker support.
43 changes: 35 additions & 8 deletions benchmark/bench_chat.py
@@ -1,20 +1,20 @@
import argparse
import random
import time
import traceback
import openai
import threading
import tqdm
import datasets
from concurrent.futures import ThreadPoolExecutor

def start_session(i: int, url: str, model: str, stream: bool=False, max_tokens: int=2048, random_len: int=0) -> str:
def start_session(i: int, url: str, model: str, dataset, stream: bool=False, max_tokens: int=2048, random_len: int=0) -> str:
try:
openai.api_key = "EMPTY" # Not supported yet
openai.api_base = url

if random_len != 0 :
messages = [{"role": "user", "content": "Hello! What is your name?" + "a" * random.randint(1, random_len)}]
else:
messages = [{"role": "user", "content": "Hello! What is your name?"}]
messages = dataset[i]

# create a chat completion
response = openai.ChatCompletion.create(
model=model,
@@ -32,17 +32,44 @@ def start_session(i: int, url: str, model: str, stream: bool=False, max_tokens:
total_tokens = response.usage.total_tokens
completion_tokens = response.usage.completion_tokens
except Exception as e:
print(e)
traceback.print_exc()
return "", 0, 0

return out, total_tokens, completion_tokens


def get_prompt(raw_dataset):
dataset = []
for conversations in raw_dataset["conversations"]:
messages = []
for data in conversations:
out_data = {"role": "system", "content": ""}
if data["user"] == "human":
out_data["role"] = "user"
if data["user"] == "gpt":
out_data["role"] = "assitant"

out_data["content"] = data["text"]
messages.append(out_data)

# drop a trailing assistant turn so the model is the one to answer
if messages and messages[-1]["role"] == "assistant":
messages = messages[:-1]

# flatten the turns only to apply a rough length filter; the chat
# benchmark sends the structured messages, not this string
prompt = "\n###".join([msg["role"] + ": " + msg["content"] for msg in messages]) + "\n### assistant: "
if len(prompt) > 2048:
continue
dataset.append(messages)
return dataset
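As an illustration, here is a minimal sketch of what `get_prompt` produces for a hypothetical ShareGPT-style record (the `user`/`text` keys follow the loop above; a trailing assistant turn is dropped so the model has something to generate):

```python
raw = {"conversations": [[
    {"user": "human", "text": "Hello! What is your name?"},
    {"user": "gpt", "text": "I am Vicuna."},
]]}
print(get_prompt(raw))
# [[{'role': 'user', 'content': 'Hello! What is your name?'}]]
```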

def main(args):
dataset = datasets.load_dataset("theblackcat102/sharegpt-english", split="train")
dataset = get_prompt(dataset)

tik = time.time()
tasks = []
with ThreadPoolExecutor(max_workers=args.n_thread) as t:
for i in range(args.total_task):
task = t.submit(start_session, i=i, url=args.url, model=args.model_name, stream=False, max_tokens=args.max_tokens, random_len=args.random_len)
task = t.submit(start_session, i=i, url=args.url, model=args.model_name, dataset=dataset, stream=False, max_tokens=args.max_tokens, random_len=args.random_len)
tasks.append(task)

results = []
@@ -63,7 +90,7 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:8000/v1")
parser.add_argument("--model-name", type=str, default="vicuna")
parser.add_argument("--max-tokens", type=int, default=1024)
parser.add_argument("--max-tokens", type=int, default=512)
parser.add_argument("--total-task", type=int, default=200)
parser.add_argument("--n-thread", type=int, default=32)
parser.add_argument("--random-len", type=int, default=0)
97 changes: 97 additions & 0 deletions benchmark/bench_completions.py
@@ -0,0 +1,97 @@
import argparse
import random
import time
import traceback
import openai
import threading
import tqdm
import datasets
from concurrent.futures import ThreadPoolExecutor

def start_session(i: int, url: str, model: str, dataset, stream: bool=False, max_tokens: int=2048) -> str:
try:
openai.api_key = "EMPTY" # Not supported yet
openai.api_base = url

prompt = dataset[i]
# create a chat completion
response = openai.Completion.create(
model=model,
prompt=prompt,
stream=stream,
max_tokens=max_tokens,
temperature=0.9,
)
# print the completion
if stream:
out = ""
for chunk in response:
out += str(chunk)
else:
out = response.choices[0].text
total_tokens = response.usage.total_tokens
completion_tokens = response.usage.completion_tokens
except Exception as e:
traceback.print_exc()
return "", 0, 0

return out, total_tokens, completion_tokens

def get_prompt(raw_dataset):
dataset = []
for conversations in raw_dataset["conversations"]:
messages = []
for data in conversations:
out_data = {"role": "system", "content": ""}
if data["user"] == "human":
out_data["role"] = "user"
if data["user"] == "gpt":
out_data["role"] = "assitant"

out_data["content"] = data["text"]
messages.append(out_data)

if messages and messages[-1]["role"] == "assistant":
messages = messages[:-1]

prompt = "\n###".join([msg["role"] + ": " + msg["content"] for msg in messages]) + "\n### assistant: "
if len(prompt) > 2048:
continue
dataset.append(prompt)
return dataset

def main(args):
dataset = datasets.load_dataset("theblackcat102/sharegpt-english", split="train")
dataset = get_prompt(dataset)

tik = time.time()
tasks = []
with ThreadPoolExecutor(max_workers=args.n_thread) as t:
for i in range(args.total_task):
task = t.submit(start_session, i=i, url=args.url, model=args.model_name, dataset=dataset, stream=False, max_tokens=args.max_tokens)
tasks.append(task)

results = []
for task in tqdm.tqdm(tasks):
results.append(task.result())

n_tokens = sum([ret[2] for ret in results])
n_queries = sum([1 for ret in results if ret[2] != 0])
time_seconds = time.time() - tik
print(
f"Successful number: {n_queries} / {args.total_task}. "
f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
f"throughput: {n_tokens / time_seconds} tokens/s."
f"QPS: {n_queries / time_seconds} queries/s."
)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:8000/v1")
parser.add_argument("--model-name", type=str, default="vicuna")
parser.add_argument("--max-tokens", type=int, default=512)
parser.add_argument("--total-task", type=int, default=64)
parser.add_argument("--n-thread", type=int, default=4)
args = parser.parse_args()

main(args)
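Unlike the chat benchmark, this script sends the flattened prompt string through `openai.Completion.create`; using the same hypothetical `raw` record sketched above, `get_prompt` here yields a plain string per entry:

```python
prompt = get_prompt(raw)[0]
print(repr(prompt))
# 'user: Hello! What is your name?\n### assistant: '
```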
11 changes: 11 additions & 0 deletions langport/data/conversation/__init__.py
@@ -19,6 +19,7 @@ class SeparatorStyle(Enum):
DOLLY = auto()
RWKV = auto()
PHOENIX = auto()
CHATGLM = auto()


@dataclasses.dataclass
@@ -128,6 +129,16 @@ def get_prompt(self) -> str:
else:
ret += role + ": " + "<s>"
return ret
elif self.settings.sep_style == SeparatorStyle.CHATGLM:
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
if i % 2 == 0:
# a round is one user/assistant exchange, so the index advances every two messages
ret += f"[Round {i // 2 + 1}]\n\n"
ret += role + ":" + message + self.settings.sep
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.settings.sep_style}")
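A minimal sketch of the string this branch renders, assuming the `chatglm` settings added below (roles `问`/`答`, `sep="\n\n"`) and that a pending reply is represented by a falsy message, as the `if message:` test suggests:

```python
from langport.data.conversation import ConversationHistory
from langport.data.conversation.conversation_settings import get_conv_settings

history = ConversationHistory(
    system="",
    messages=[("问", "你好"), ("答", "你好！"), ("问", "再见"), ("答", None)],
    offset=0,
    settings=get_conv_settings("chatglm"),
)
print(history.get_prompt())
# [Round 1]
#
# 问:你好
#
# 答:你好！
#
# [Round 2]
#
# 问:再见
#
# 答:
```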

13 changes: 13 additions & 0 deletions langport/data/conversation/settings/chatglm.py
@@ -0,0 +1,13 @@
from langport.data.conversation import (
ConversationSettings,
SeparatorStyle,
)

# ChatGLM default template
chatglm = ConversationSettings(
name="chatglm",
roles=("问", "答"),
sep_style=SeparatorStyle.CHATGLM,
sep="\n\n",
stop_str="\n\n",
)
13 changes: 13 additions & 0 deletions langport/data/conversation/settings/wizardlm.py
@@ -0,0 +1,13 @@
from langport.data.conversation import (
ConversationSettings,
SeparatorStyle,
)


# WizardLM template (Vicuna v1.1 style)
wizardlm = ConversationSettings(
name="wizardlm",
roles=("USER", "ASSISTANT"),
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep=" ",
)
19 changes: 19 additions & 0 deletions langport/model/adapters/baichuan.py
@@ -0,0 +1,19 @@
from typing import List, Optional
from langport.data.conversation import ConversationHistory, SeparatorStyle
from langport.data.conversation.conversation_settings import get_conv_settings
from langport.model.model_adapter import BaseAdapter

class BaichuanAdapter(BaseAdapter):
"""The model adapter for baichuan-inc/baichuan-7B"""

def match(self, model_path: str):
return "baichuan" in model_path

def get_default_conv_template(self, model_path: str) -> ConversationHistory:
settings = get_conv_settings("one_shot")
return ConversationHistory(
system="",
messages=(),
offset=0,
settings=settings,
)
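Note that `match` is a case-sensitive substring test on the model path, so it relies on the path containing lowercase "baichuan"; a quick sketch:

```python
adapter = BaichuanAdapter()
print(adapter.match("baichuan-inc/baichuan-7B"))  # True
print(adapter.match("THUDM/chatglm-6b"))          # False
```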
16 changes: 11 additions & 5 deletions langport/model/adapters/chatglm.py
@@ -1,12 +1,18 @@
from transformers import (
AutoModel,
AutoTokenizer,
)

from langport.data.conversation import ConversationHistory
from langport.data.conversation.conversation_settings import get_conv_settings
from langport.model.model_adapter import BaseAdapter

class ChatGLMAdapter(BaseAdapter):
"""The model adapter for THUDM/chatglm-6b"""

def match(self, model_path: str):
return "chatglm" in model_path

def get_default_conv_template(self, model_path: str) -> ConversationHistory:
settings = get_conv_settings("chatglm")
return ConversationHistory(
system="",
messages=[],
offset=0,
settings=settings,
)
19 changes: 19 additions & 0 deletions langport/model/adapters/wizardlm.py
@@ -0,0 +1,19 @@
from typing import List, Optional
from langport.data.conversation import ConversationHistory, SeparatorStyle
from langport.data.conversation.conversation_settings import get_conv_settings
from langport.model.model_adapter import BaseAdapter

class WizardLMAdapter(BaseAdapter):
"""The model adapter for WizardLM/WizardLM-13B-V1.0"""

def match(self, model_path: str):
return "wizardlm" in model_path

def get_default_conv_template(self, model_path: str) -> ConversationHistory:
settings = get_conv_settings("wizardlm")
return ConversationHistory(
system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. ",
messages=(),
offset=0,
settings=settings,
)
11 changes: 11 additions & 0 deletions langport/model/executor/generation/huggingface.py
@@ -206,8 +206,19 @@ def generate(self, inputs: BatchingTask,
past_key_values=past_key_values,
)
else:
if step > 0:
# with the KV cache, each step feeds only the newest token ids, but the
# attention mask must still cover every cached position: append one mask
# column per token generated so far
dynamic_attention_mask = torch.cat(
(attention_mask,
torch.ones(
inputs.batch_size, step,
dtype=torch.long, device=decoder_input_ids.device
)), dim=1
)
else:
dynamic_attention_mask = attention_mask
out = self.model(
input_ids=decoder_input_ids,
attention_mask=dynamic_attention_mask,
use_cache=self.model.generation_config.use_cache,
past_key_values=past_key_values,
)
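A standalone sketch of the pattern this hunk implements (illustrative names, not the executor's real variables): with a KV cache, each forward pass feeds only the newest token ids, so the attention mask has to grow by one column per generated token to keep covering the cached positions:

```python
import torch

batch_size, prompt_len, step = 2, 7, 3  # 3 tokens generated so far
attention_mask = torch.ones(batch_size, prompt_len, dtype=torch.long)

dynamic_attention_mask = torch.cat(
    (attention_mask, torch.ones(batch_size, step, dtype=torch.long)),
    dim=1,
)
print(dynamic_attention_mask.shape)  # torch.Size([2, 10])
```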
