From 5109c89daa0de420306c18bd1bd97e4b1e6e5be4 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Fri, 17 Nov 2023 11:06:36 +0800
Subject: [PATCH] feat(cache): Not cache the failed model output

---
 pilot/model/cluster/base.py                  |  2 ++
 pilot/model/cluster/worker/default_worker.py | 37 ++++++++++++++++----
 pilot/model/cluster/worker/manager.py        | 12 +++++--
 pilot/model/cluster/worker/remote_manager.py |  5 ++-
 pilot/model/model_adapter.py                 |  9 +++--
 pilot/model/operator/model_operator.py       | 20 +++++++++--
 pilot/scene/base_chat.py                     |  2 +-
 7 files changed, 69 insertions(+), 18 deletions(-)

diff --git a/pilot/model/cluster/base.py b/pilot/model/cluster/base.py
index 9d22161b1..36c4779b8 100644
--- a/pilot/model/cluster/base.py
+++ b/pilot/model/cluster/base.py
@@ -17,6 +17,8 @@ class PromptRequest(BaseModel):
     temperature: float = None
     max_new_tokens: int = None
     stop: str = None
+    stop_token_ids: List[int] = []
+    context_len: int = None
     echo: bool = True
     span_id: str = None

diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 44a476f20..d6663fc9f 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -1,6 +1,6 @@
 import os
 import logging
-from typing import Dict, Iterator, List
+from typing import Dict, Iterator, List, Optional

 from pilot.configs.model_config import get_device
 from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
@@ -60,7 +60,7 @@ def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
         self.ml: ModelLoader = ModelLoader(
             model_path=self.model_path, model_name=self.model_name
         )
-        # TODO read context len from model config
+        # Default model context len
         self.context_len = 2048

     def model_param_class(self) -> ModelParameters:
@@ -111,6 +111,12 @@ def start(
         self.model, self.tokenizer = self.ml.loader_with_params(
             model_params, self.llm_adapter
         )
+        model_max_length = _parse_model_max_length(self.model, self.tokenizer)
+        if model_max_length:
+            logger.info(
+                f"Parse model max length {model_max_length} from model {self.model_name}."
+            )
+            self.context_len = model_max_length

     def stop(self) -> None:
         if not self.model:
@@ -138,9 +144,9 @@ def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
             )

             previous_response = ""
-
+            context_len = params.get("context_len") or self.context_len
             for output in generate_stream_func(
-                self.model, self.tokenizer, params, get_device(), self.context_len
+                self.model, self.tokenizer, params, get_device(), context_len
             ):
                 model_output, incremental_output, output_str = self._handle_output(
                     output, previous_response, model_context
@@ -183,9 +189,10 @@ async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
             )

             previous_response = ""
+            context_len = params.get("context_len") or self.context_len

             async for output in generate_stream_func(
-                self.model, self.tokenizer, params, get_device(), self.context_len
+                self.model, self.tokenizer, params, get_device(), context_len
             ):
                 model_output, incremental_output, output_str = self._handle_output(
                     output, previous_response, model_context
@@ -279,11 +286,27 @@ def _handle_exception(self, e):
         # Check if the exception is a torch.cuda.CudaError and if torch was imported.
         if _torch_imported and isinstance(e, torch.cuda.CudaError):
             model_output = ModelOutput(
-                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+                text="**GPU OutOfMemory, Please Refresh.**", error_code=1
             )
         else:
             model_output = ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         return model_output
+
+
+def _parse_model_max_length(model, tokenizer) -> Optional[int]:
+    if not (tokenizer or model):
+        return None
+    try:
+        if tokenizer and hasattr(tokenizer, "model_max_length"):
+            return tokenizer.model_max_length
+        if model and hasattr(model, "config"):
+            model_config = model.config
+            if hasattr(model_config, "max_sequence_length"):
+                return model_config.max_sequence_length
+            if hasattr(model_config, "max_position_embeddings"):
+                return model_config.max_position_embeddings
+    except Exception:
+        return None
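The default_worker changes above probe the loaded model for its real context window instead of hard-coding 2048, and let a request override it through params["context_len"]. Below is a minimal, self-contained sketch of the same resolution order (request value, then tokenizer or config limit, then the default); the _Fake* classes and resolve_context_len are illustrative stand-ins, not DB-GPT or Hugging Face types.

from typing import Optional


class _FakeConfig:
    # Stand-in for a Hugging Face style model config; only the probed fields matter.
    max_position_embeddings = 4096


class _FakeModel:
    config = _FakeConfig()


class _FakeTokenizer:
    model_max_length = 8192


def resolve_context_len(model, tokenizer, params: dict, default_len: int = 2048) -> int:
    """Mirror the worker's lookup: request param > parsed model limit > default."""
    parsed: Optional[int] = None
    if tokenizer is not None and hasattr(tokenizer, "model_max_length"):
        parsed = tokenizer.model_max_length
    elif model is not None and hasattr(model, "config"):
        config = model.config
        parsed = getattr(config, "max_sequence_length", None) or getattr(
            config, "max_position_embeddings", None
        )
    return params.get("context_len") or parsed or default_len


if __name__ == "__main__":
    # Request-level override wins; otherwise the tokenizer or config limit is used.
    print(resolve_context_len(_FakeModel(), _FakeTokenizer(), {"context_len": 3000}))  # 3000
    print(resolve_context_len(_FakeModel(), _FakeTokenizer(), {}))  # 8192
    print(resolve_context_len(_FakeModel(), None, {}))  # 4096

Probing the tokenizer's model_max_length first, then max_sequence_length and max_position_embeddings on the model config, follows the order used by _parse_model_max_length in the patch.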
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index d67519f59..2dd402920 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -119,7 +119,10 @@ async def start(self):
             _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
         )
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         if not self.run_data.stop_event.is_set():
@@ -325,7 +328,7 @@ async def generate_stream(
         except Exception as e:
             yield ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
             return
         async with worker_run_data.semaphore:
@@ -355,7 +358,7 @@ async def generate(self, params: Dict) -> ModelOutput:
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         async with worker_run_data.semaphore:
             if worker_run_data.worker.support_async():
@@ -996,6 +999,7 @@ def run_worker_manager(
     port: int = None,
     embedding_model_name: str = None,
     embedding_model_path: str = None,
+    start_listener: Callable[["WorkerManager"], None] = None,
 ):
     global worker_manager
@@ -1029,6 +1033,8 @@
         worker_manager, embedding_model_name, embedding_model_path
     )

+    worker_manager.after_start(start_listener)
+
     if include_router:
         app.include_router(router, prefix="/api")

diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py
index 61b608cc7..4047f428e 100644
--- a/pilot/model/cluster/worker/remote_manager.py
+++ b/pilot/model/cluster/worker/remote_manager.py
@@ -15,7 +15,10 @@ def __init__(self, model_registry: ModelRegistry = None) -> None:

     async def start(self):
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         pass
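Both worker managers now dispatch start listeners through asyncio.iscoroutinefunction, so the new start_listener argument of run_worker_manager can be either a plain callable or a coroutine function. A small runnable sketch of that dispatch pattern, assuming an illustrative Manager class rather than the real WorkerManager:

import asyncio
from typing import Callable, List


class Manager:
    def __init__(self) -> None:
        self.start_listeners: List[Callable] = []

    def after_start(self, listener) -> None:
        # Register a listener to run once the manager has started.
        if listener is not None:
            self.start_listeners.append(listener)

    async def start(self) -> None:
        # Await coroutine listeners, call plain functions directly.
        for listener in self.start_listeners:
            if asyncio.iscoroutinefunction(listener):
                await listener(self)
            else:
                listener(self)


def sync_listener(manager: "Manager") -> None:
    print("sync listener fired")


async def async_listener(manager: "Manager") -> None:
    await asyncio.sleep(0)
    print("async listener fired")


if __name__ == "__main__":
    m = Manager()
    m.after_start(sync_listener)
    m.after_start(async_listener)
    asyncio.run(m.start())

Note that asyncio.iscoroutinefunction only detects functions defined with async def; a synchronous callable that merely returns an awaitable would still be invoked without await, which matches the behavior in the patch.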
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index e2deeaa02..8fd242882 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -170,9 +170,12 @@ def model_adaptation(
         model_context["has_format_prompt"] = True
         params["prompt"] = new_prompt

-        # Overwrite model params:
-        params["stop"] = conv.stop_str
-        params["stop_token_ids"] = conv.stop_token_ids
+        custom_stop = params.get("stop")
+        custom_stop_token_ids = params.get("stop_token_ids")
+
+        # Prefer the value passed in from the input parameter
+        params["stop"] = custom_stop or conv.stop_str
+        params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids

     return params, model_context

diff --git a/pilot/model/operator/model_operator.py b/pilot/model/operator/model_operator.py
index 7cf6395ec..6486e8373 100644
--- a/pilot/model/operator/model_operator.py
+++ b/pilot/model/operator/model_operator.py
@@ -1,4 +1,4 @@
-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Dict, List, Union
 import logging
 from pilot.awel import (
     BranchFunc,
@@ -227,7 +227,7 @@ async def transform_stream(
             )
             outputs.append(out)
             yield out
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(outputs):
             llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
             await self._client.set(llm_cache_key, llm_cache_value)

@@ -258,7 +258,7 @@ async def map(self, input_value: ModelOutput) -> ModelOutput:
             _LLM_MODEL_INPUT_VALUE_KEY
         )
         llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(input_value):
             await self._client.set(llm_cache_key, llm_cache_value)
         return input_value

@@ -284,3 +284,17 @@ def _parse_cache_key_dict(input_value: Dict) -> Dict:
         # TODO pass model_type
         "model_type": input_value.get("model_type", "huggingface"),
     }
+
+
+def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
+    if not out:
+        return False
+    if isinstance(out, list):
+        # check last model output
+        out = out[-1]
+    error_code = 0
+    if isinstance(out, ModelOutput):
+        error_code = out.error_code
+    else:
+        error_code = int(out.get("error_code", 0))
+    return error_code == 0

diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 4f5d8c5d0..9a19f3255 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -173,7 +173,7 @@ async def __call_base(self):
             "messages": llm_messages,
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
-            "stop": self.prompt_template.sep,
+            # "stop": self.prompt_template.sep,
             "echo": self.llm_echo,
         }
         return payload
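The model_operator changes implement the commit's main point: a cache entry is written only when the final ModelOutput carries error_code == 0, which the worker error paths now report as 1 instead of 0. A self-contained sketch of that gate, using a stand-in dataclass and an in-memory dict instead of the real ModelOutput and cache client:

from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class StubModelOutput:
    # Stand-in for pilot's ModelOutput; only text and error_code matter here.
    text: str
    error_code: int = 0


def is_success_output(out: Union[Dict, StubModelOutput, List[StubModelOutput]]) -> bool:
    """Succeed only when the last (or only) output carries error_code == 0."""
    if not out:
        return False
    if isinstance(out, list):
        out = out[-1]
    if isinstance(out, StubModelOutput):
        return out.error_code == 0
    return int(out.get("error_code", 0)) == 0


def cache_if_success(cache: Dict, key: str, outputs: List[StubModelOutput]) -> None:
    # Failed generations (error_code != 0) never reach the cache.
    if key and is_success_output(outputs):
        cache[key] = outputs


if __name__ == "__main__":
    cache: Dict[str, List[StubModelOutput]] = {}
    cache_if_success(cache, "q1", [StubModelOutput("partial"), StubModelOutput("done")])
    cache_if_success(cache, "q2", [StubModelOutput("**GPU OutOfMemory**", error_code=1)])
    print(sorted(cache))  # ['q1'] -- the failed stream for q2 was not cached

Checking only the last element of a streamed output list mirrors _is_success_model_output: intermediate chunks are ignored, and the terminal status decides whether the stream is cacheable.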