diff --git a/config_sample.yml b/config_sample.yml
index 85bb1df4..3b4f2479 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -83,6 +83,9 @@ model:
   # Enable this if the program is looking for a specific OAI model
   #use_dummy_models: False
 
+  # Allow direct loading of models from a completion or chat completion request
+  inline_model_loading: False
+
   # An initial model to load. Make sure the model is located in the model directory!
   # A model can be loaded later via the API.
   # REQUIRED: This must be filled out to load a model on startup!
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 66bc759f..1b98f41d 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -21,6 +21,7 @@
 )
 from endpoints.OAI.utils.completion import (
     generate_completion,
+    load_inline_model,
     stream_generate_completion,
 )
 from endpoints.OAI.utils.embeddings import get_embeddings
@@ -41,7 +42,7 @@ def setup():
 # Completions endpoint
 @router.post(
     "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def completion_request(
     request: Request, data: CompletionRequest
@@ -52,6 +53,11 @@ async def completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
     model_path = model.container.model_dir
 
     if isinstance(data.prompt, list):
@@ -86,7 +92,7 @@
 # Chat completions endpoint
 @router.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def chat_completion_request(
     request: Request, data: ChatCompletionRequest
@@ -97,6 +103,11 @@ async def chat_completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
     if model.container.prompt_template is None:
         error_message = handle_request_error(
             "Chat completions are disabled because a prompt template is not set.",
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 89777923..30ec7699 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -56,6 +56,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
     response_prefix: Optional[str] = None
+    model: Optional[str] = None
 
     # tools is follows the format OAI schema, functions is more flexible
     # both are available in the chat template.
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 52c2bb46..cc752c52 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -1,4 +1,8 @@
-"""Completion utilities for OAI server."""
+"""
+Completion utilities for OAI server.
+
+Also serves as a common module for completions and chat completions.
+"""
 
 import asyncio
 import pathlib
@@ -9,7 +13,8 @@
 from loguru import logger
 
-from common import model
+from common import config, model
+from common.auth import get_key_permission
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -103,6 +108,48 @@
         await gen_queue.put(e)
 
 
+async def load_inline_model(model_name: str, request: Request):
+    """Load a model from the data.model parameter"""
+
+    # Return if the requested model is already loaded
+    if model.container and model.container.model_dir.name == model_name:
+        return
+
+    model_config = config.model_config()
+
+    # Inline model loading isn't enabled or the user isn't an admin
+    if get_key_permission(request) != "admin":
+        error_message = handle_request_error(
+            f"Unable to switch model to {model_name} because "
+            + "an admin key isn't provided",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(401, error_message)
+
+    if not unwrap(model_config.get("inline_model_loading"), False):
+        logger.warning(
+            f"Unable to switch model to {model_name} because "
+            '"inline_model_loading" is not True in config.yml.'
+        )
+
+        return
+
+    model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
+    model_path = model_path / model_name
+
+    # Model path doesn't exist
+    if not model_path.exists():
+        logger.warning(
+            f"Could not find model path {str(model_path)}. Skipping inline model load."
+        )
+
+        return
+
+    # Load the model
+    await model.load_model(model_path)
+
+
 async def stream_generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):