From 279e900ea5ec095eff97d2996a966a1ea0aa663f Mon Sep 17 00:00:00 2001 From: Colin Kealty Date: Tue, 4 Jun 2024 13:35:48 -0400 Subject: [PATCH 1/3] Add on the fly model loading to requests --- endpoints/OAI/router.py | 49 +++++++++++++++++++++++++- endpoints/OAI/types/chat_completion.py | 1 + 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 0e4f27be..f4cc5169 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,5 +1,6 @@ import asyncio import pathlib +from loguru import logger from fastapi import APIRouter, Depends, HTTPException, Header, Request from sse_starlette import EventSourceResponse from sys import maxsize @@ -118,7 +119,7 @@ async def completion_request( # Chat completions endpoint @router.post( "/v1/chat/completions", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key)], ) async def chat_completion_request( request: Request, data: ChatCompletionRequest @@ -129,6 +130,52 @@ async def chat_completion_request( If stream = true, this returns an SSE stream. """ + if data.model is not None and ( + model.container is None or model.container.get_model_path().name != data.model + ): + adminValid = False + if "x_admin_key" in request.headers.keys(): + try: + await check_admin_key( + x_admin_key=request.headers.get("x_admin_key"), authorization=None + ) + adminValid = True + except HTTPException: + pass + + if not adminValid and "authorization" in request.headers.keys(): + try: + await check_admin_key( + x_admin_key=None, authorization=request.headers.get("authorization") + ) + adminValid = True + except HTTPException: + pass + + if adminValid: + logger.info( + f"New request for {data.model} which is not loaded, proper admin key provided, loading new model" + ) + + model_path = pathlib.Path( + unwrap(config.model_config().get("model_dir"), "models") + ) + model_path = model_path / data.model + + if not model_path.exists(): + error_message = handle_request_error( + "Could not find the model path for load. Check model name.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + await model.load_model(model_path) + else: + logger.info(f"No valid admin key found to change loaded model, ignoring") + else: + await check_model_container() + if model.container.prompt_template is None: error_message = handle_request_error( "Chat completions are disabled because a prompt template is not set.", diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index be5cfea9..b66277b8 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -47,6 +47,7 @@ class ChatCompletionRequest(CommonCompletionRequest): add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = {} response_prefix: Optional[str] = None + model: Optional[str] = None class ChatCompletionResponse(BaseModel): From 21f14d431883009263bac40c17e4e634aff591e7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 3 Sep 2024 23:37:28 -0400 Subject: [PATCH 2/3] API: Update inline load - Add a config flag - Migrate support to /v1/completions - Unify the load function Signed-off-by: kingbri --- config_sample.yml | 3 ++ endpoints/OAI/router.py | 55 +++++-------------------------- endpoints/OAI/utils/completion.py | 50 ++++++++++++++++++++++++++-- 3 files changed, 60 insertions(+), 48 deletions(-) diff --git a/config_sample.yml b/config_sample.yml index 85bb1df4..3b4f2479 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -83,6 +83,9 @@ model: # Enable this if the program is looking for a specific OAI model #use_dummy_models: False + # Allow direct loading of models from a completion or chat completion request + inline_model_loading: False + # An initial model to load. Make sure the model is located in the model directory! # A model can be loaded later via the API. # REQUIRED: This must be filled out to load a model on startup! diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index ca91b0a8..1b98f41d 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,6 +1,4 @@ import asyncio -import pathlib -from loguru import logger from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize @@ -23,6 +21,7 @@ ) from endpoints.OAI.utils.completion import ( generate_completion, + load_inline_model, stream_generate_completion, ) from endpoints.OAI.utils.embeddings import get_embeddings @@ -43,7 +42,7 @@ def setup(): # Completions endpoint @router.post( "/v1/completions", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key)], ) async def completion_request( request: Request, data: CompletionRequest @@ -54,6 +53,11 @@ async def completion_request( If stream = true, this returns an SSE stream. """ + if data.model: + await load_inline_model(data.model, request) + else: + await check_model_container() + model_path = model.container.model_dir if isinstance(data.prompt, list): @@ -99,49 +103,8 @@ async def chat_completion_request( If stream = true, this returns an SSE stream. """ - if data.model is not None and ( - model.container is None or model.container.get_model_path().name != data.model - ): - adminValid = False - if "x_admin_key" in request.headers.keys(): - try: - await check_admin_key( - x_admin_key=request.headers.get("x_admin_key"), authorization=None - ) - adminValid = True - except HTTPException: - pass - - if not adminValid and "authorization" in request.headers.keys(): - try: - await check_admin_key( - x_admin_key=None, authorization=request.headers.get("authorization") - ) - adminValid = True - except HTTPException: - pass - - if adminValid: - logger.info( - f"New request for {data.model} which is not loaded, proper admin key provided, loading new model" - ) - - model_path = pathlib.Path( - unwrap(config.model_config().get("model_dir"), "models") - ) - model_path = model_path / data.model - - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. Check model name.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - await model.load_model(model_path) - else: - logger.info(f"No valid admin key found to change loaded model, ignoring") + if data.model: + await load_inline_model(data.model, request) else: await check_model_container() diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 52c2bb46..5fdf81f4 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -1,4 +1,8 @@ -"""Completion utilities for OAI server.""" +""" +Completion utilities for OAI server. + +Also serves as a common module for completions and chat completions. +""" import asyncio import pathlib @@ -9,7 +13,8 @@ from loguru import logger -from common import model +from common import config, model +from common.auth import get_key_permission from common.networking import ( get_generator_error, handle_request_disconnect, @@ -173,6 +178,47 @@ async def stream_generate_completion( ) +async def load_inline_model(model_name: str, request: Request): + """Load a model from the data.model parameter""" + + # Return if the model container already exists + if model.container and model.container.model_dir.name == model_name: + return + + model_config = config.model_config() + + # Inline model loading isn't enabled or the user isn't an admin + if not get_key_permission(request) == "admin": + logger.warning( + f"Unable to switch model to {model_name} " + "because an admin key isn't provided." + ) + + return + + if not unwrap(model_config.get("inline_model_loading"), False): + logger.warning( + f"Unable to switch model to {model_name} because " + '"inline_model_load" is not True in config.yml.' + ) + + return + + model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) + model_path = model_path / model_name + + # Model path doesn't exist + if not model_path.exists(): + logger.warning( + f"Could not find model path {str(model_path)}. Skipping inline model load." + ) + + return + + # Load the model + await model.load_model(model_path) + + async def generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ): From 9c10789ca1095e0571c534e5e5535f3b337ec604 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 4 Sep 2024 21:44:14 -0400 Subject: [PATCH 3/3] API: Error on invalid key permissions and cleanup format If a user requesting a model change isn't admin, error. Better to place the load function before the generate functions. Signed-off-by: kingbri --- endpoints/OAI/utils/completion.py | 83 ++++++++++++++++--------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 5fdf81f4..cc752c52 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -108,6 +108,48 @@ async def _stream_collector( await gen_queue.put(e) +async def load_inline_model(model_name: str, request: Request): + """Load a model from the data.model parameter""" + + # Return if the model container already exists + if model.container and model.container.model_dir.name == model_name: + return + + model_config = config.model_config() + + # Inline model loading isn't enabled or the user isn't an admin + if not get_key_permission(request) == "admin": + error_message = handle_request_error( + f"Unable to switch model to {model_name} because " + + "an admin key isn't provided", + exc_info=False, + ).error.message + + raise HTTPException(401, error_message) + + if not unwrap(model_config.get("inline_model_loading"), False): + logger.warning( + f"Unable to switch model to {model_name} because " + '"inline_model_load" is not True in config.yml.' + ) + + return + + model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) + model_path = model_path / model_name + + # Model path doesn't exist + if not model_path.exists(): + logger.warning( + f"Could not find model path {str(model_path)}. Skipping inline model load." + ) + + return + + # Load the model + await model.load_model(model_path) + + async def stream_generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ): @@ -178,47 +220,6 @@ async def stream_generate_completion( ) -async def load_inline_model(model_name: str, request: Request): - """Load a model from the data.model parameter""" - - # Return if the model container already exists - if model.container and model.container.model_dir.name == model_name: - return - - model_config = config.model_config() - - # Inline model loading isn't enabled or the user isn't an admin - if not get_key_permission(request) == "admin": - logger.warning( - f"Unable to switch model to {model_name} " - "because an admin key isn't provided." - ) - - return - - if not unwrap(model_config.get("inline_model_loading"), False): - logger.warning( - f"Unable to switch model to {model_name} because " - '"inline_model_load" is not True in config.yml.' - ) - - return - - model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) - model_path = model_path / model_name - - # Model path doesn't exist - if not model_path.exists(): - logger.warning( - f"Could not find model path {str(model_path)}. Skipping inline model load." - ) - - return - - # Load the model - await model.load_model(model_path) - - async def generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ):