diff --git a/OAI/types/lora.py b/OAI/types/lora.py
index 018bf061..10c8327f 100644
--- a/OAI/types/lora.py
+++ b/OAI/types/lora.py
@@ -32,6 +32,7 @@ class LoraLoadRequest(BaseModel):
     """Represents a Lora load request."""
 
     loras: List[LoraLoadInfo]
+    skip_queue: bool = False
 
 
 class LoraLoadResponse(BaseModel):
diff --git a/OAI/types/model.py b/OAI/types/model.py
index 71b8ab3d..5301ca83 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -93,6 +93,7 @@ class ModelLoadRequest(BaseModel):
     use_cfg: Optional[bool] = None
     fasttensors: Optional[bool] = False
     draft: Optional[DraftModelLoadRequest] = None
+    skip_queue: Optional[bool] = False
 
 
 class ModelLoadResponse(BaseModel):
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index eb13356f..43b86334 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -55,6 +55,7 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
 
     # Load state
+    model_is_loading: bool = False
     model_loaded: bool = False
 
     def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
@@ -350,6 +351,9 @@ def load_gen(self, progress_callback=None):
             def progress(loaded_modules: int, total_modules: int)
         """
 
+        # Notify that the model is being loaded
+        self.model_is_loading = True
+
         # Load tokenizer
         self.tokenizer = ExLlamaV2Tokenizer(self.config)
@@ -439,6 +443,7 @@ def progress(loaded_modules: int, total_modules: int)
         torch.cuda.empty_cache()
 
         # Update model load state
+        self.model_is_loading = False
         self.model_loaded = True
         logger.info("Model successfully loaded.")
 
@@ -472,7 +477,7 @@ def unload(self, loras_only: bool = False):
 
         # Update model load state
         self.model_loaded = False
-        logger.info("Model unloaded.")
+        logger.info("Loras unloaded." if loras_only else "Model unloaded.")
 
     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
diff --git a/common/utils.py b/common/utils.py
index 6fec00dc..b01be0f8 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -45,8 +45,10 @@ def handle_request_error(message: str):
    request_error = TabbyRequestError(error=error_message)

    # Log the error and provided message to the console
-    logger.error(error_message.trace)
-    logger.error(message)
+    if error_message.trace:
+        logger.error(error_message.trace)
+
+    logger.error(f"Sent to request: {message}")

    return request_error
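The new `skip_queue` fields on `LoraLoadRequest` and `ModelLoadRequest` let an admin client bypass the completions queue during a load. A minimal sketch of exercising the flag from a client; the host, port, and `x-admin-key` header are assumed tabbyAPI defaults, not anything this diff defines:

```python
# Hypothetical client call for the new skip_queue flag; URL and header
# values are assumptions based on tabbyAPI defaults.
import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    headers={"x-admin-key": "<your admin key>"},
    json={"name": "my-model", "skip_queue": True},
    stream=True,  # the endpoint streams load progress as server-sent events
)

for line in response.iter_lines():
    if line:
        print(line.decode())
```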
diff --git a/main.py b/main.py
index c60916f3..737b668a 100644
--- a/main.py
+++ b/main.py
@@ -1,9 +1,9 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
-import asyncio
 import os
 import pathlib
 import signal
 import sys
+import time
 import uvicorn
 import threading
 from asyncio import CancelledError
@@ -93,7 +93,9 @@ def _check_model_container():
-    if MODEL_CONTAINER is None or not MODEL_CONTAINER.model_loaded:
+    if MODEL_CONTAINER is None or not (
+        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
+    ):
         error_message = handle_request_error(
             "No models are currently loaded."
         ).error.message
@@ -183,24 +185,13 @@ def list_draft_models():
 
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""
-    global MODEL_CONTAINER
 
+    # Verify request parameters
     if not data.name:
         raise HTTPException(400, "A model name was not provided.")
 
-    # Unload the existing model
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
-            )
-        else:
-            MODEL_CONTAINER.unload()
-
     model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
     model_path = model_path / data.name
 
@@ -219,10 +210,24 @@ def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
 
-    MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
+    # Check if the model is already loaded
+    if MODEL_CONTAINER and MODEL_CONTAINER.model:
+        loaded_model_name = MODEL_CONTAINER.get_model_path().name
+
+        if loaded_model_name == data.name:
+            raise HTTPException(
+                400, f'Model "{loaded_model_name}" is already loaded! Aborting.'
+            )
 
     async def generator():
         """Generator for the loading process."""
+        global MODEL_CONTAINER
+
+        # Unload the existing model
+        if MODEL_CONTAINER and MODEL_CONTAINER.model:
+            unload_model()
+
+        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
         model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
         load_status = MODEL_CONTAINER.load_gen(load_progress)
 
@@ -230,6 +235,7 @@ async def generator():
         try:
             for module, modules in load_status:
                 if await request.is_disconnected():
+                    unload_model()
                     break
 
                 if module == 0:
@@ -269,7 +275,17 @@ async def generator():
         except Exception as exc:
             yield get_generator_error(str(exc))
 
-    return StreamingResponse(generator(), media_type="text/event-stream")
+    # Determine whether to use or skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Model load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        generator_callback = generator
+    else:
+        generator_callback = partial(generate_with_semaphore, generator)
+
+    return StreamingResponse(generator_callback(), media_type="text/event-stream")
 
 
 # Unload model endpoint
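`load_model` now routes its streaming generator through `generate_with_semaphore` unless `skip_queue` is set. Those helpers are defined elsewhere in tabbyAPI and are not shown in this diff; a speculative sketch of what such queueing helpers might look like, assuming a single global `asyncio.Semaphore`:

```python
import asyncio

# Speculative sketch only; the real helpers live elsewhere in tabbyAPI
# and may differ in detail.
generate_semaphore = asyncio.Semaphore(1)

async def call_with_semaphore(callback):
    # Queue a one-shot awaitable behind any in-flight completion
    async with generate_semaphore:
        return await callback()

async def generate_with_semaphore(generator):
    # Hold the semaphore for the entire lifetime of a streaming generator
    async with generate_semaphore:
        async for item in generator():
            yield item
```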
@@ -398,7 +414,7 @@ def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
 )
-def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")
@@ -411,14 +427,27 @@ def load_lora(data: LoraLoadRequest):
         )
 
     # Clean-up existing loras if present
-    if len(MODEL_CONTAINER.active_loras) > 0:
-        MODEL_CONTAINER.unload(True)
+    def load_loras_internal():
+        if len(MODEL_CONTAINER.active_loras) > 0:
+            unload_loras()
+
+        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        return LoraLoadResponse(
+            success=unwrap(result.get("success"), []),
+            failure=unwrap(result.get("failure"), []),
+        )
 
-    result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
-    return LoraLoadResponse(
-        success=unwrap(result.get("success"), []),
-        failure=unwrap(result.get("failure"), []),
-    )
+    internal_callback = partial(run_in_threadpool, load_loras_internal)
+
+    # Determine whether to skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Lora load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        return await internal_callback()
+    else:
+        return await call_with_semaphore(internal_callback)
 
 
 # Unload lora endpoint
@@ -428,7 +457,7 @@
 )
 def unload_loras():
     """Unloads the currently loaded loras."""
-    MODEL_CONTAINER.unload(True)
+    MODEL_CONTAINER.unload(loras_only=True)
 
 
 # Encode tokens endpoint
@@ -498,7 +527,8 @@ async def generator():
             )
 
         return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
+            generate_with_semaphore(generator),
+            media_type="text/event-stream",
         )
 
     try:
@@ -515,7 +545,8 @@
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Please check the server console."
+            "Completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
        ).error.message
 
        # Server error if there's a generation exception
@@ -617,7 +648,8 @@ async def generator():
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Please check the server console."
+            "Chat completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
        ).error.message
 
        # Server error if there's a generation exception
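`load_lora` offloads its blocking body with `partial(run_in_threadpool, ...)` so the event loop stays responsive while weights load. The diff assumes `functools.partial` and a `run_in_threadpool` helper (presumably Starlette's) are imported elsewhere in main.py. A self-contained sketch of the same pattern, with a hypothetical `blocking_lora_load` standing in for the real loader:

```python
import asyncio
import time
from functools import partial

from starlette.concurrency import run_in_threadpool

def blocking_lora_load(name: str) -> str:
    # Hypothetical stand-in for MODEL_CONTAINER.load_loras(...)
    time.sleep(0.1)
    return f"loaded {name}"

async def main():
    # Mirrors internal_callback = partial(run_in_threadpool, load_loras_internal)
    internal_callback = partial(run_in_threadpool, blocking_lora_load, "my-lora")
    print(await internal_callback())

asyncio.run(main())
```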
@@ -636,7 +668,6 @@ def start_api(host: str, port: int):
         app,
         host=host,
         port=port,
-        log_level="debug",
     )
 
 
@@ -733,15 +764,13 @@ def entrypoint(args: Optional[dict] = None):
     host = unwrap(network_config.get("host"), "127.0.0.1")
     port = unwrap(network_config.get("port"), 5000)
 
-    # Start the API in a daemon thread
-    # This allows for command signals to be passed and properly shut down the program
-    # Otherwise the program will hang
     # TODO: Replace this with abortables, async via producer consumer, or something else
-    threading.Thread(target=partial(start_api, host, port), daemon=True).start()
+    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
+    api_thread.start()
 
     # Keep the program alive
-    loop = asyncio.get_event_loop()
-    loop.run_forever()
+    while api_thread.is_alive():
+        time.sleep(0.5)
 
 
 if __name__ == "__main__":
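Swapping `loop.run_forever()` for a polling loop keeps the main thread interruptible: Ctrl+C raises `KeyboardInterrupt` in the sleeping main thread, which then exits and takes the daemonized API thread down with it. A minimal, self-contained illustration of the pattern, with a stub `serve` in place of `uvicorn.run`:

```python
import threading
import time

def serve():
    # Stand-in for uvicorn.run(); blocks forever in the worker thread
    while True:
        time.sleep(1)

api_thread = threading.Thread(target=serve, daemon=True)
api_thread.start()

try:
    # Polling (instead of a blocking event loop) keeps the main thread
    # responsive to signals such as SIGINT
    while api_thread.is_alive():
        time.sleep(0.5)
except KeyboardInterrupt:
    print("Shutting down.")
```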