Skip to content

Commit

Permalink
API: Back to async
Browse files Browse the repository at this point in the history
According to FastAPI docs, if you're using a generic function, running
it in async will make it more performant (which makes sense since
running def functions for routes will automatically run the caller
through a threadpool).

Tested and everything works fine.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Mar 5, 2024
1 parent d6602c5 commit 51f01d6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
8 changes: 6 additions & 2 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def load_auth_keys(disable_from_config: bool):
)


def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the API key is valid."""

# Allow request if auth is disabled
Expand All @@ -102,7 +104,9 @@ def check_api_key(x_api_key: str = Header(None), authorization: str = Header(Non
raise HTTPException(401, "Please provide an API key")


def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
async def check_admin_key(
x_admin_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the admin key is valid."""

# Allow request if auth is disabled
Expand Down
36 changes: 18 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
MODEL_CONTAINER: Optional[ExllamaV2Container] = None


def _check_model_container():
async def _check_model_container():
if MODEL_CONTAINER is None or not (
MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
):
Expand All @@ -116,7 +116,7 @@ def _check_model_container():
# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
def list_models():
async def list_models():
"""Lists all models in the model directory."""
model_config = get_model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
Expand All @@ -140,7 +140,7 @@ def list_models():
"/v1/internal/model/info",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def get_current_model():
async def get_current_model():
"""Returns the currently loaded model."""
model_name = MODEL_CONTAINER.get_model_path().name
prompt_template = MODEL_CONTAINER.prompt_template
Expand Down Expand Up @@ -173,7 +173,7 @@ def get_current_model():


@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
def list_draft_models():
async def list_draft_models():
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
Expand Down Expand Up @@ -225,7 +225,7 @@ async def generator():

# Unload the existing model
if MODEL_CONTAINER and MODEL_CONTAINER.model:
unload_model()
await unload_model()

MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)

Expand All @@ -235,7 +235,7 @@ async def generator():
try:
for module, modules in load_status:
if await request.is_disconnected():
unload_model()
await unload_model()
break

if module == 0:
Expand Down Expand Up @@ -293,7 +293,7 @@ async def generator():
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def unload_model():
async def unload_model():
"""Unloads the currently loaded model."""
global MODEL_CONTAINER

Expand All @@ -303,7 +303,7 @@ def unload_model():

@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
def get_templates():
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)
Expand All @@ -313,7 +313,7 @@ def get_templates():
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def switch_template(data: TemplateSwitchRequest):
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template"""
if not data.name:
raise HTTPException(400, "New template name not found.")
Expand All @@ -329,7 +329,7 @@ def switch_template(data: TemplateSwitchRequest):
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def unload_template():
async def unload_template():
"""Unloads the currently selected template"""

MODEL_CONTAINER.prompt_template = None
Expand All @@ -338,7 +338,7 @@ def unload_template():
# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
def list_sampler_overrides():
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""

return get_sampler_overrides()
Expand All @@ -348,7 +348,7 @@ def list_sampler_overrides():
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
def switch_sampler_override(data: SamplerOverrideSwitchRequest):
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""

if data.preset:
Expand All @@ -370,7 +370,7 @@ def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
def unload_sampler_override():
async def unload_sampler_override():
"""Unloads the currently selected override preset"""

set_overrides_from_dict({})
Expand All @@ -379,7 +379,7 @@ def unload_sampler_override():
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
def get_all_loras():
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
Expand All @@ -392,7 +392,7 @@ def get_all_loras():
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def get_active_loras():
async def get_active_loras():
"""Returns the currently loaded loras."""
active_loras = LoraList(
data=list(
Expand Down Expand Up @@ -455,7 +455,7 @@ def load_loras_internal():
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def unload_loras():
async def unload_loras():
"""Unloads the currently loaded loras."""
MODEL_CONTAINER.unload(loras_only=True)

Expand All @@ -465,7 +465,7 @@ def unload_loras():
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def encode_tokens(data: TokenEncodeRequest):
async def encode_tokens(data: TokenEncodeRequest):
"""Encodes a string into tokens."""
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
tokens = unwrap(raw_tokens, [])
Expand All @@ -479,7 +479,7 @@ def encode_tokens(data: TokenEncodeRequest):
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def decode_tokens(data: TokenDecodeRequest):
async def decode_tokens(data: TokenDecodeRequest):
"""Decodes tokens into a string."""
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))
Expand Down

0 comments on commit 51f01d6

Please sign in to comment.