feat(cache): Do not cache failed model output
fangyinc committed Nov 17, 2023
1 parent 9957720 commit 5109c89
Showing 7 changed files with 69 additions and 18 deletions.
pilot/model/cluster/base.py (2 changes: 2 additions & 0 deletions)

@@ -17,6 +17,8 @@ class PromptRequest(BaseModel):
     temperature: float = None
     max_new_tokens: int = None
     stop: str = None
+    stop_token_ids: List[int] = []
+    context_len: int = None
     echo: bool = True
     span_id: str = None
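The two new request fields let a caller pass stop token ids and a per-request context window. A minimal standalone sketch of a request model carrying them; PromptRequestSketch below is an illustration only, not the project's PromptRequest, and shows just the fields visible in this hunk:

from typing import List, Optional
from pydantic import BaseModel


class PromptRequestSketch(BaseModel):
    # Illustration only: mirrors the fields shown in the hunk above.
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None
    stop: Optional[str] = None
    stop_token_ids: List[int] = []      # new: token ids that terminate generation
    context_len: Optional[int] = None   # new: per-request context window override
    echo: bool = True
    span_id: Optional[str] = None


req = PromptRequestSketch(stop_token_ids=[2], context_len=4096)
print(req.stop_token_ids, req.context_len)  # [2] 4096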
pilot/model/cluster/worker/default_worker.py (37 changes: 30 additions & 7 deletions)

@@ -1,6 +1,6 @@
 import os
 import logging
-from typing import Dict, Iterator, List
+from typing import Dict, Iterator, List, Optional

 from pilot.configs.model_config import get_device
 from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
@@ -60,7 +60,7 @@ def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
         self.ml: ModelLoader = ModelLoader(
             model_path=self.model_path, model_name=self.model_name
         )
-        # TODO read context len from model config
+        # Default model context len
        self.context_len = 2048

     def model_param_class(self) -> ModelParameters:
@@ -111,6 +111,12 @@ def start(
         self.model, self.tokenizer = self.ml.loader_with_params(
             model_params, self.llm_adapter
         )
+        model_max_length = _parse_model_max_length(self.model, self.tokenizer)
+        if model_max_length:
+            logger.info(
+                f"Parse model max length {model_max_length} from model {self.model_name}."
+            )
+            self.context_len = model_max_length

     def stop(self) -> None:
         if not self.model:
@@ -138,9 +144,9 @@ def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
         )

         previous_response = ""
-
+        context_len = params.get("context_len") or self.context_len
         for output in generate_stream_func(
-            self.model, self.tokenizer, params, get_device(), self.context_len
+            self.model, self.tokenizer, params, get_device(), context_len
         ):
             model_output, incremental_output, output_str = self._handle_output(
                 output, previous_response, model_context
@@ -183,9 +189,10 @@ async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
         )

         previous_response = ""
+        context_len = params.get("context_len") or self.context_len

         async for output in generate_stream_func(
-            self.model, self.tokenizer, params, get_device(), self.context_len
+            self.model, self.tokenizer, params, get_device(), context_len
         ):
             model_output, incremental_output, output_str = self._handle_output(
                 output, previous_response, model_context
@@ -279,11 +286,27 @@ def _handle_exception(self, e):
         # Check if the exception is a torch.cuda.CudaError and if torch was imported.
         if _torch_imported and isinstance(e, torch.cuda.CudaError):
             model_output = ModelOutput(
-                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+                text="**GPU OutOfMemory, Please Refresh.**", error_code=1
             )
         else:
             model_output = ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         return model_output
+
+
+def _parse_model_max_length(model, tokenizer) -> Optional[int]:
+    if not (tokenizer or model):
+        return None
+    try:
+        if tokenizer and hasattr(tokenizer, "model_max_length"):
+            return tokenizer.model_max_length
+        if model and hasattr(model, "config"):
+            model_config = model.config
+            if hasattr(model_config, "max_sequence_length"):
+                return model_config.max_sequence_length
+            if hasattr(model_config, "max_position_embeddings"):
+                return model_config.max_position_embeddings
+    except Exception:
+        return None
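Taken together, the worker now resolves the context length as: explicit context_len in the request, else the limit parsed from the tokenizer or model config, else the 2048 default. A minimal standalone sketch of that fallback chain; the _FakeConfig/_FakeModel stand-ins are illustrative only, not real transformers objects:

from typing import Optional


class _FakeConfig:
    # Stand-in for a HF-style model config; only the probed attribute is set.
    def __init__(self, max_position_embeddings: int):
        self.max_position_embeddings = max_position_embeddings


class _FakeModel:
    def __init__(self, config: _FakeConfig):
        self.config = config


def parse_model_max_length(model, tokenizer) -> Optional[int]:
    # Same probing order as _parse_model_max_length above: tokenizer.model_max_length,
    # then config.max_sequence_length, then config.max_position_embeddings.
    if not (tokenizer or model):
        return None
    try:
        if tokenizer and hasattr(tokenizer, "model_max_length"):
            return tokenizer.model_max_length
        if model and hasattr(model, "config"):
            cfg = model.config
            if hasattr(cfg, "max_sequence_length"):
                return cfg.max_sequence_length
            if hasattr(cfg, "max_position_embeddings"):
                return cfg.max_position_embeddings
    except Exception:
        return None
    return None


DEFAULT_CONTEXT_LEN = 2048
parsed = parse_model_max_length(_FakeModel(_FakeConfig(4096)), tokenizer=None)
params = {"prompt": "hello"}  # no explicit context_len in the request
context_len = params.get("context_len") or parsed or DEFAULT_CONTEXT_LEN
print(context_len)  # 4096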
pilot/model/cluster/worker/manager.py (12 changes: 9 additions & 3 deletions)

@@ -119,7 +119,10 @@ async def start(self):
             _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
         )
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         if not self.run_data.stop_event.is_set():
@@ -325,7 +328,7 @@ async def generate_stream(
         except Exception as e:
             yield ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
             return
         async with worker_run_data.semaphore:
@@ -355,7 +358,7 @@ async def generate(self, params: Dict) -> ModelOutput:
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
+                error_code=1,
             )
         async with worker_run_data.semaphore:
             if worker_run_data.worker.support_async():
@@ -996,6 +999,7 @@ def run_worker_manager(
     port: int = None,
     embedding_model_name: str = None,
    embedding_model_path: str = None,
+    start_listener: Callable[["WorkerManager"], None] = None,
 ):
     global worker_manager

@@ -1029,6 +1033,8 @@
         worker_manager, embedding_model_name, embedding_model_path
     )

+    worker_manager.after_start(start_listener)
+
     if include_router:
         app.include_router(router, prefix="/api")

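Both this manager and the RemoteWorkerManager in the next file now accept coroutine start listeners, and run_worker_manager gains a start_listener argument that is registered via after_start. A minimal standalone sketch of the dispatch pattern; the ListenerHost class is illustrative only, not the project's API, though it mirrors the after_start call shown above:

import asyncio
from typing import Callable, List


class ListenerHost:
    # Illustrative stand-in for the manager: it only keeps listeners and fires
    # them on start, awaiting coroutine functions as in the diff above.
    def __init__(self):
        self.start_listeners: List[Callable] = []

    def after_start(self, listener: Callable = None):
        if listener is not None:
            self.start_listeners.append(listener)

    async def start(self):
        for listener in self.start_listeners:
            if asyncio.iscoroutinefunction(listener):
                await listener(self)
            else:
                listener(self)


def sync_listener(manager):
    print("sync listener called")


async def async_listener(manager):
    await asyncio.sleep(0)
    print("async listener awaited")


host = ListenerHost()
host.after_start(sync_listener)
host.after_start(async_listener)
asyncio.run(host.start())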
pilot/model/cluster/worker/remote_manager.py (5 changes: 4 additions & 1 deletion)

@@ -15,7 +15,10 @@ def __init__(self, model_registry: ModelRegistry = None) -> None:

     async def start(self):
         for listener in self.start_listeners:
-            listener(self)
+            if asyncio.iscoroutinefunction(listener):
+                await listener(self)
+            else:
+                listener(self)

     async def stop(self, ignore_exception: bool = False):
         pass
pilot/model/model_adapter.py (9 changes: 6 additions & 3 deletions)

@@ -170,9 +170,12 @@ def model_adaptation(
             model_context["has_format_prompt"] = True
             params["prompt"] = new_prompt

-        # Overwrite model params:
-        params["stop"] = conv.stop_str
-        params["stop_token_ids"] = conv.stop_token_ids
+        custom_stop = params.get("stop")
+        custom_stop_token_ids = params.get("stop_token_ids")
+
+        # Prefer the value passed in from the input parameter
+        params["stop"] = custom_stop or conv.stop_str
+        params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids

         return params, model_context

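The adapter now treats caller-supplied stop and stop_token_ids as overrides and only falls back to the conversation template's values. A minimal sketch of that precedence rule; merge_stop_params is a hypothetical helper, not project code. Note that because `or` is used, empty values such as "" or [] also fall back to the template defaults:

def merge_stop_params(params: dict, conv_stop_str=None, conv_stop_token_ids=None) -> dict:
    # Caller-supplied values win; the conversation template is only a fallback.
    params["stop"] = params.get("stop") or conv_stop_str
    params["stop_token_ids"] = params.get("stop_token_ids") or conv_stop_token_ids
    return params


print(merge_stop_params({"stop": "###"}, conv_stop_str="</s>", conv_stop_token_ids=[2]))
# {'stop': '###', 'stop_token_ids': [2]}
print(merge_stop_params({}, conv_stop_str="</s>", conv_stop_token_ids=[2]))
# {'stop': '</s>', 'stop_token_ids': [2]}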
pilot/model/operator/model_operator.py (20 changes: 17 additions & 3 deletions)

@@ -1,4 +1,4 @@
-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Dict, List, Union
 import logging
 from pilot.awel import (
     BranchFunc,
@@ -227,7 +227,7 @@ async def transform_stream(
             )
             outputs.append(out)
             yield out
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(outputs):
             llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
             await self._client.set(llm_cache_key, llm_cache_value)

@@ -258,7 +258,7 @@ async def map(self, input_value: ModelOutput) -> ModelOutput:
             _LLM_MODEL_INPUT_VALUE_KEY
         )
         llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
-        if llm_cache_key:
+        if llm_cache_key and _is_success_model_output(input_value):
             await self._client.set(llm_cache_key, llm_cache_value)
         return input_value

@@ -284,3 +284,17 @@ def _parse_cache_key_dict(input_value: Dict) -> Dict:
         # TODO pass model_type
         "model_type": input_value.get("model_type", "huggingface"),
     }
+
+
+def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
+    if not out:
+        return False
+    if isinstance(out, list):
+        # check last model output
+        out = out[-1]
+    error_code = 0
+    if isinstance(out, ModelOutput):
+        error_code = out.error_code
+    else:
+        error_code = int(out.get("error_code", 0))
+    return error_code == 0
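This is the core of the commit: a cache entry is written only when the output (for streams, the last chunk) carries error_code == 0, which lines up with the worker now reporting failures with error_code=1. A minimal standalone sketch of the guard; ModelOutputSketch is a stand-in for the project's ModelOutput:

from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class ModelOutputSketch:
    # Stand-in for pilot's ModelOutput; illustration only.
    text: str
    error_code: int = 0


def is_success_model_output(out: Union[Dict, ModelOutputSketch, List[ModelOutputSketch]]) -> bool:
    if not out:
        return False
    if isinstance(out, list):
        out = out[-1]  # for streaming output, the last chunk decides
    if isinstance(out, ModelOutputSketch):
        error_code = out.error_code
    else:
        error_code = int(out.get("error_code", 0))
    return error_code == 0


stream = [ModelOutputSketch("Hel"), ModelOutputSketch("**GPU OutOfMemory, Please Refresh.**", error_code=1)]
print(is_success_model_output(stream))                   # False -> skip the cache write
print(is_success_model_output(ModelOutputSketch("Hi")))  # True  -> safe to cache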
pilot/scene/base_chat.py (2 changes: 1 addition & 1 deletion)

@@ -173,7 +173,7 @@ async def __call_base(self):
             "messages": llm_messages,
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
-            "stop": self.prompt_template.sep,
+            # "stop": self.prompt_template.sep,
             "echo": self.llm_echo,
         }
         return payload
