Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(web): Add incremental response to streaming response for /v1/chat/completion request #611

Merged
merged 3 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pilot/model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
# TODO: vicuna-v1.5 8-bit quantization info is slow
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
# TODO: support internlm quantization
model_name = model_params.model_name.lower()
supported_models = ["llama", "baichuan", "vicuna"]
return any(m in model_name for m in supported_models)
Expand Down
44 changes: 38 additions & 6 deletions pilot/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
ConversationVo,
MessageVo,
ChatSceneVo,
ChatCompletionResponseStreamChoice,
DeltaMessage,
ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
Expand Down Expand Up @@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
)
else:
return StreamingResponse(
stream_generator(chat),
stream_generator(chat, dialogue.incremental, dialogue.model_name),
headers=headers,
media_type="text/plain",
)
Expand Down Expand Up @@ -421,19 +424,48 @@ async def no_stream_generator(chat):
yield f"data: {msg}\n\n"


async def stream_generator(chat):
async def stream_generator(chat, incremental: bool, model_name: str):
"""Generate streaming responses

Our goal is to generate an openai-compatible streaming responses.
Currently, the incremental response is compatible, and the full response will be transformed in the future.

Args:
chat (BaseChat): Chat instance.
incremental (bool): Used to control whether the content is returned incrementally or in full each time.
model_name (str): The model name

Yields:
_type_: streaming responses
"""
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

stream_id = f"chatcmpl-{str(uuid.uuid1())}"
previous_response = ""
async for chunk in chat.stream_call():
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)

msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
msg = msg.replace("\ufffd", "")
if incremental:
incremental_output = msg[len(previous_response) :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=incremental_output),
)
chunk = ChatCompletionStreamResponse(
id=stream_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
else:
# TODO generate an openai-compatible streaming responses
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)

if incremental:
yield "data: [DONE]\n\n"
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)
Expand Down
27 changes: 26 additions & 1 deletion pilot/openapi/api_view_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
from typing import TypeVar, Generic, Any, Optional, Literal, List
import uuid
import time

T = TypeVar("T")

Expand Down Expand Up @@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
"""
model_name: str = None

"""Used to control whether the content is returned incrementally or in full each time.
If this parameter is not provided, the default is full return.
"""
incremental: bool = False


class MessageVo(BaseModel):
"""
Expand All @@ -83,3 +90,21 @@ class MessageVo(BaseModel):
model_name
"""
model_name: str


class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
6 changes: 4 additions & 2 deletions pilot/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@


def _clear_torch_cache(device="cuda"):
try:
import torch
except ImportError:
return
fangyinc marked this conversation as resolved.
Show resolved Hide resolved
import gc

import torch

gc.collect()
if device != "cpu":
if torch.has_mps:
Expand Down