Skip to content

Commit

Permalink
Fix the openai benchmarking requests to work with latest OpenAI apis (v…
Browse files Browse the repository at this point in the history
…llm-project#2992)

Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
wangchen615 and ywang96 authored Mar 4, 2024
1 parent ff578ca commit 9a4548b
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,80 @@ async def async_request_openai_completions(
return output


async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
"v1/chat/completions"
), "OpenAI Chat API URL must end with 'v1/chat/completions'."

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
payload = {
"model": request_func_input.model,
"messages": [
{
"role": "user",
"content": request_func_input.prompt,
},
],
"temperature": 0.0,
"max_tokens": request_func_input.output_len,
"stream": True,
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}

output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

generated_text = ""
ttft = 0
st = time.perf_counter()
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
if ttft == 0:
ttft = time.perf_counter() - st
output.ttft = ttft

chunk = chunk.strip()
if not chunk:
continue

chunk = chunk.decode("utf-8").lstrip("data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
body = json.loads(chunk)
if "content" in body["choices"][0]["delta"]:
generated_text += body["choices"][0]["delta"][
"content"]

output.generated_text = generated_text
output.success = True
output.latency = latency
else:
output.success = False
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
output.success = False

if pbar:
pbar.update(1)
return output


ASYNC_REQUEST_FUNCS = {
"tgi": async_request_tgi,
"vllm": async_request_vllm,
"deepspeed-mii": async_request_deepspeed_mii,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"tensorrt-llm": async_request_trt_llm,
}

0 comments on commit 9a4548b

Please sign in to comment.