From 51f01d6df2aca9fd538a3c482b3f87166e5153d8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 4 Mar 2024 22:59:10 -0500 Subject: [PATCH] API: Back to async According to the FastAPI docs, a route or dependency function that does no blocking work is more performant when declared with `async def`: FastAPI runs plain `def` route functions in a threadpool, whereas `async def` functions run directly on the event loop and skip that hand-off. Tested and everything works fine. Signed-off-by: kingbri --- common/auth.py | 8 ++++++-- main.py | 36 ++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/common/auth.py b/common/auth.py index 7c3975fa..ded42b97 100644 --- a/common/auth.py +++ b/common/auth.py @@ -76,7 +76,9 @@ def load_auth_keys(disable_from_config: bool): ) -def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)): +async def check_api_key( + x_api_key: str = Header(None), authorization: str = Header(None) +): """Check if the API key is valid.""" # Allow request if auth is disabled @@ -102,7 +104,9 @@ def check_api_key(x_api_key: str = Header(None), authorization: str = Header(Non raise HTTPException(401, "Please provide an API key") -def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)): +async def check_admin_key( + x_admin_key: str = Header(None), authorization: str = Header(None) +): """Check if the admin key is valid.""" # Allow request if auth is disabled diff --git a/main.py b/main.py index 737b668a..b71a4ed5 100644 --- a/main.py +++ b/main.py @@ -92,7 +92,7 @@ MODEL_CONTAINER: Optional[ExllamaV2Container] = None -def _check_model_container(): +async def _check_model_container(): if MODEL_CONTAINER is None or not ( MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded ): @@ -116,7 +116,7 @@ def _check_model_container(): # Model list endpoint @app.get("/v1/models", dependencies=[Depends(check_api_key)]) @app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -def list_models(): +async 
def list_models(): """Lists all models in the model directory.""" model_config = get_model_config() model_dir = unwrap(model_config.get("model_dir"), "models") @@ -140,7 +140,7 @@ def list_models(): "/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)], ) -def get_current_model(): +async def get_current_model(): """Returns the currently loaded model.""" model_name = MODEL_CONTAINER.get_model_path().name prompt_template = MODEL_CONTAINER.prompt_template @@ -173,7 +173,7 @@ def get_current_model(): @app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) -def list_draft_models(): +async def list_draft_models(): """Lists all draft models in the model directory.""" draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models") draft_model_path = pathlib.Path(draft_model_dir) @@ -225,7 +225,7 @@ async def generator(): # Unload the existing model if MODEL_CONTAINER and MODEL_CONTAINER.model: - unload_model() + await unload_model() MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data) @@ -235,7 +235,7 @@ async def generator(): try: for module, modules in load_status: if await request.is_disconnected(): - unload_model() + await unload_model() break if module == 0: @@ -293,7 +293,7 @@ async def generator(): "/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)], ) -def unload_model(): +async def unload_model(): """Unloads the currently loaded model.""" global MODEL_CONTAINER @@ -303,7 +303,7 @@ def unload_model(): @app.get("/v1/templates", dependencies=[Depends(check_api_key)]) @app.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -def get_templates(): +async def get_templates(): templates = get_all_templates() template_strings = list(map(lambda template: template.stem, templates)) return TemplateList(data=template_strings) @@ -313,7 +313,7 @@ def get_templates(): "/v1/template/switch", 
dependencies=[Depends(check_admin_key), Depends(_check_model_container)], ) -def switch_template(data: TemplateSwitchRequest): +async def switch_template(data: TemplateSwitchRequest): """Switch the currently loaded template""" if not data.name: raise HTTPException(400, "New template name not found.") @@ -329,7 +329,7 @@ def switch_template(data: TemplateSwitchRequest): "/v1/template/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)], ) -def unload_template(): +async def unload_template(): """Unloads the currently selected template""" MODEL_CONTAINER.prompt_template = None @@ -338,7 +338,7 @@ def unload_template(): # Sampler override endpoints @app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) @app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) -def list_sampler_overrides(): +async def list_sampler_overrides(): """API wrapper to list all currently applied sampler overrides""" return get_sampler_overrides() @@ -348,7 +348,7 @@ def list_sampler_overrides(): "/v1/sampling/override/switch", dependencies=[Depends(check_admin_key)], ) -def switch_sampler_override(data: SamplerOverrideSwitchRequest): +async def switch_sampler_override(data: SamplerOverrideSwitchRequest): """Switch the currently loaded override preset""" if data.preset: @@ -370,7 +370,7 @@ def switch_sampler_override(data: SamplerOverrideSwitchRequest): "/v1/sampling/override/unload", dependencies=[Depends(check_admin_key)], ) -def unload_sampler_override(): +async def unload_sampler_override(): """Unloads the currently selected override preset""" set_overrides_from_dict({}) @@ -379,7 +379,7 @@ def unload_sampler_override(): # Lora list endpoint @app.get("/v1/loras", dependencies=[Depends(check_api_key)]) @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -def get_all_loras(): +async def get_all_loras(): """Lists all LoRAs in the lora directory.""" lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), 
"loras")) loras = get_lora_list(lora_path.resolve()) @@ -392,7 +392,7 @@ def get_all_loras(): "/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)], ) -def get_active_loras(): +async def get_active_loras(): """Returns the currently loaded loras.""" active_loras = LoraList( data=list( @@ -455,7 +455,7 @@ def load_loras_internal(): "/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)], ) -def unload_loras(): +async def unload_loras(): """Unloads the currently loaded loras.""" MODEL_CONTAINER.unload(loras_only=True) @@ -465,7 +465,7 @@ def unload_loras(): "/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)], ) -def encode_tokens(data: TokenEncodeRequest): +async def encode_tokens(data: TokenEncodeRequest): """Encodes a string into tokens.""" raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params()) tokens = unwrap(raw_tokens, []) @@ -479,7 +479,7 @@ def encode_tokens(data: TokenEncodeRequest): "/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)], ) -def decode_tokens(data: TokenDecodeRequest): +async def decode_tokens(data: TokenDecodeRequest): """Decodes tokens into a string.""" message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params()) response = TokenDecodeResponse(text=unwrap(message, ""))