diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 2078c55e7..e72e9ba31 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -8,7 +8,7 @@
 from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
 logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ def stop(self) -> None:
         del self.tokenizer
         self.model = None
         self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
         torch_imported = False
diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py
index a9f934a1c..62b799864 100644
--- a/pilot/model/cluster/worker/embedding_worker.py
+++ b/pilot/model/cluster/worker/embedding_worker.py
@@ -11,7 +11,7 @@
 )
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
 logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ def stop(self) -> None:
             return
         del self._embeddings_impl
         self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict):
         """Generate stream result, chat scene"""
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index a6019c129..63a484151 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -18,6 +18,7 @@
 def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
     # TODO: vicuna-v1.5 8-bit quantization info is slow
     # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
+    # TODO: support internlm quantization
     model_name = model_params.model_name.lower()
     supported_models = ["llama", "baichuan", "vicuna"]
     return any(m in model_name for m in supported_models)
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 7b198c49a..24bee6cdb 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -26,6 +26,9 @@
     ConversationVo,
     MessageVo,
     ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
 )
 from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
 from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         )
     else:
         return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
             headers=headers,
             media_type="text/plain",
         )
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
     yield f"data: {msg}\n\n"
 
 
-async def stream_generator(chat):
+async def stream_generator(chat, incremental: bool, model_name: str):
+    """Generate streaming responses
+
+    Our goal is to generate OpenAI-compatible streaming responses.
+    Currently only the incremental response is compatible; the full response will be made compatible in the future.
+
+    Args:
+        chat (BaseChat): Chat instance.
+        incremental (bool): Used to control whether the content is returned incrementally or in full each time.
+        model_name (str): The model name
+
+    Yields:
+        str: streaming responses
+    """
     msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
 
+    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
+    previous_response = ""
     async for chunk in chat.stream_call():
         if chunk:
             msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                 chunk, chat.skip_echo_len
             )
-
-            msg = msg.replace("\n", "\\n")
-            yield f"data:{msg}\n\n"
+            msg = msg.replace("\ufffd", "")
+            if incremental:
+                incremental_output = msg[len(previous_response) :]
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant", content=incremental_output),
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=stream_id, choices=[choice_data], model=model_name
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            else:
+                # TODO generate an OpenAI-compatible streaming response here as well
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+            previous_response = msg
         await asyncio.sleep(0.02)
-
+    if incremental:
+        yield "data: [DONE]\n\n"
     chat.current_message.add_ai_message(msg)
     chat.current_message.add_view_message(msg)
     chat.memory.append(chat.current_message)
diff --git a/pilot/openapi/api_view_model.py b/pilot/openapi/api_view_model.py
index d03beec8d..60065f2f2 100644
--- a/pilot/openapi/api_view_model.py
+++ b/pilot/openapi/api_view_model.py
@@ -1,5 +1,7 @@
 from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time
 
 T = TypeVar("T")
 
@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
     """
     model_name: str = None
+    """Used to control whether the content is returned incrementally or in full each time.
+    If this parameter is not provided, the default is full return.
+    """
+    incremental: bool = False
+
 
 
 class MessageVo(BaseModel):
     """
@@ -83,3 +90,21 @@
     model_name
     """
     model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
diff --git a/pilot/utils/model_utils.py b/pilot/utils/model_utils.py
index a7a51ad32..d9527118e 100644
--- a/pilot/utils/model_utils.py
+++ b/pilot/utils/model_utils.py
@@ -1,10 +1,22 @@
 import logging
 
+logger = logging.getLogger(__name__)
 
-def _clear_torch_cache(device="cuda"):
-    import gc
 
+def _clear_model_cache(device="cuda"):
+    try:
+        # clear torch cache
+        import torch
+
+        _clear_torch_cache(device)
+    except ImportError:
+        logger.warn("Torch not installed, skip clear torch cache")
+    # TODO clear other cache
+
+
+def _clear_torch_cache(device="cuda"):
     import torch
+    import gc
 
     gc.collect()
     if device != "cpu":
@@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
 
                 empty_cache()
             except Exception as e:
-                logging.warn(f"Clear mps torch cache error, {str(e)}")
+                logger.warn(f"Clear mps torch cache error, {str(e)}")
         elif torch.has_cuda:
             device_count = torch.cuda.device_count()
             for device_id in range(device_count):
                 cuda_device = f"cuda:{device_id}"
-                logging.info(f"Clear torch cache of device: {cuda_device}")
+                logger.info(f"Clear torch cache of device: {cuda_device}")
                 with torch.cuda.device(cuda_device):
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
         else:
-            logging.info("No cuda or mps, not support clear torch cache yet")
+            logger.info("No cuda or mps, not support clear torch cache yet")