[FEAT] Add extra concurrency to API #76

Merged: 4 commits, Mar 5, 2024
Changes from all commits
1 change: 1 addition & 0 deletions OAI/types/lora.py
@@ -32,6 +32,7 @@ class LoraLoadRequest(BaseModel):
"""Represents a Lora load request."""

loras: List[LoraLoadInfo]
skip_queue: bool = False


class LoraLoadResponse(BaseModel):
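For illustration, a minimal sketch of a lora load request that opts out of the queue. Only the skip_queue field comes from this diff; the LoraLoadInfo fields and their values are assumptions about the existing schema.

from OAI.types.lora import LoraLoadInfo, LoraLoadRequest

# skip_queue=True asks the server to load immediately instead of waiting
# behind the shared generation semaphore
request = LoraLoadRequest(
    loras=[LoraLoadInfo(name="example-lora", scaling=1.0)],  # hypothetical lora entry
    skip_queue=True,
)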
1 change: 1 addition & 0 deletions OAI/types/model.py
@@ -93,6 +93,7 @@ class ModelLoadRequest(BaseModel):
use_cfg: Optional[bool] = None
fasttensors: Optional[bool] = False
draft: Optional[DraftModelLoadRequest] = None
skip_queue: Optional[bool] = False


class ModelLoadResponse(BaseModel):
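Likewise, a hedged sketch of a client triggering a model load with the new flag. The /v1/model/load path, the name field value, and the x-admin-key header are assumptions about the surrounding API; only skip_queue is introduced by this diff.

import httpx

payload = {
    "name": "example-model",  # hypothetical model directory name
    "skip_queue": True,       # new field: load without queuing behind generations
}

# Assumed admin endpoint and key header; adjust for a real deployment
httpx.post(
    "http://localhost:5000/v1/model/load",
    json=payload,
    headers={"x-admin-key": "example-admin-key"},
)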
14 changes: 13 additions & 1 deletion backends/exllamav2/model.py
@@ -54,6 +54,10 @@ class ExllamaV2Container:
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 * 1024**2]

# Load state
model_is_loading: bool = False
model_loaded: bool = False

def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
@@ -347,6 +351,9 @@ def load_gen(self, progress_callback=None):
def progress(loaded_modules: int, total_modules: int)
"""

# Notify that the model is being loaded
self.model_is_loading = True

# Load tokenizer
self.tokenizer = ExLlamaV2Tokenizer(self.config)

@@ -435,6 +442,9 @@ def progress(loaded_modules: int, total_modules: int)
gc.collect()
torch.cuda.empty_cache()

# Update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")

def unload(self, loras_only: bool = False):
@@ -465,7 +475,9 @@ def unload(self, loras_only: bool = False):
gc.collect()
torch.cuda.empty_cache()

logger.info("Model unloaded.")
# Update model load state
self.model_loaded = False
logger.info("Loras unloaded." if loras_only else "Model unloaded.")

def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
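A possible consumer of the new load-state flags, sketched for context. The check_model_container helper, its container argument, and the exact error responses are illustrative and not part of this diff.

from fastapi import HTTPException

def check_model_container(container):
    # Reject requests while a load is in progress or before any model is ready
    if container is None or container.model_is_loading:
        raise HTTPException(503, "Model is still loading, please try again shortly")
    if not container.model_loaded:
        raise HTTPException(400, "No model is currently loaded")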
8 changes: 6 additions & 2 deletions common/auth.py
@@ -76,7 +76,9 @@ def load_auth_keys(disable_from_config: bool):
)


def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the API key is valid."""

# Allow request if auth is disabled
@@ -102,7 +104,9 @@ def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
raise HTTPException(401, "Please provide an API key")


def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
async def check_admin_key(
x_admin_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the admin key is valid."""

# Allow request if auth is disabled
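Since both checks are now coroutines, FastAPI awaits them when they are used as dependencies. A minimal sketch, assuming a route wires them in via Depends; the path and handler body are illustrative.

from fastapi import Depends, FastAPI
from common.auth import check_api_key

app = FastAPI()

@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def list_models():
    # The async dependency is awaited before this handler runs
    ...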
23 changes: 13 additions & 10 deletions common/generators.py
@@ -1,15 +1,16 @@
"""Generator functions for the tabbyAPI."""
"""Generator handling"""

import asyncio
import inspect
from asyncio import Semaphore
from functools import partialmethod
from typing import AsyncGenerator
from typing import AsyncGenerator, Generator, Union

generate_semaphore = Semaphore(1)
generate_semaphore = asyncio.Semaphore(1)


# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
"""Generate with a semaphore."""

async with generate_semaphore:
if inspect.isasyncgenfunction(generator):
async for result in generator():
@@ -19,9 +20,11 @@ async def generate_with_semaphore(generator: AsyncGenerator):
yield result


# Block a function with semaphore
async def call_with_semaphore(callback: partialmethod):
if inspect.iscoroutinefunction(callback):
return await callback()
"""Call with a semaphore."""

async with generate_semaphore:
return callback()
if inspect.iscoroutinefunction(callback):
return await callback()
else:
return callback()
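A hedged usage sketch of the two helpers: non-streaming work is funneled through call_with_semaphore, while streaming generators are wrapped by generate_with_semaphore so only one generation holds the semaphore at a time. The container and generate_chunks names are hypothetical.

from functools import partial

from common.generators import call_with_semaphore, generate_with_semaphore

async def unload_model_queued(container):
    # The sync unload callback runs once the shared semaphore is acquired
    return await call_with_semaphore(partial(container.unload))

async def stream_completion(generate_chunks):
    # generate_chunks is passed uncalled; the wrapper invokes it inside the semaphore
    async for chunk in generate_with_semaphore(generate_chunks):
        yield chunk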
39 changes: 26 additions & 13 deletions common/utils.py
@@ -1,8 +1,8 @@
"""Common utilities for the tabbyAPI"""
import traceback
from typing import Optional
"""Common utility functions"""

import traceback
from pydantic import BaseModel
from typing import Optional

from common.logger import init_logger

@@ -14,30 +14,43 @@ def load_progress(module, modules):
yield module, modules


class TabbyGeneratorErrorMessage(BaseModel):
"""Common error types."""
class TabbyRequestErrorMessage(BaseModel):
"""Common request error type."""

message: str
trace: Optional[str] = None


class TabbyGeneratorError(BaseModel):
"""Common error types."""
class TabbyRequestError(BaseModel):
"""Common request error type."""

error: TabbyGeneratorErrorMessage
error: TabbyRequestErrorMessage


def get_generator_error(message: str):
"""Get a generator error."""
error_message = TabbyGeneratorErrorMessage(

generator_error = handle_request_error(message)

return get_sse_packet(generator_error.model_dump_json())


def handle_request_error(message: str):
"""Log a request error to the console."""

error_message = TabbyRequestErrorMessage(
message=message, trace=traceback.format_exc()
)

generator_error = TabbyGeneratorError(error=error_message)
request_error = TabbyRequestError(error=error_message)

# Log and send the exception
logger.error(generator_error.error.trace)
return get_sse_packet(generator_error.model_dump_json())
# Log the error and provided message to the console
if error_message.trace:
logger.error(error_message.trace)

logger.error(f"Sent to request: {message}")

return request_error


def get_sse_packet(json_data: str):
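Finally, a sketch of how the reworked error helpers might be used by a request handler. The report_error wrapper and its streaming flag are illustrative; only handle_request_error and get_generator_error come from this file.

from common.utils import get_generator_error, handle_request_error

def report_error(message: str, streaming: bool):
    if streaming:
        # SSE path: return the structured error as a server-sent event packet
        return get_generator_error(message)

    # Non-streaming path: log the trace and return the TabbyRequestError model
    return handle_request_error(message)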