API: Split functions into their own files
Previously, generation functions were bundled with the request functions,
making the overall code structure and API ugly and unreadable.

Split these up and clean up many of the methods that were previously
overlooked in the API itself.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Mar 12, 2024
1 parent cdffa8f commit 46caa50
Showing 5 changed files with 375 additions and 293 deletions.
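The commit message describes replacing the endpoints' inline generator closures with callbacks built via functools.partial and handed to the shared semaphore helpers. The sketch below illustrates that call pattern only; the semaphore functions are simplified stand-ins for common.generators, and completion_request plus the two helper stubs are hypothetical, not code from this repository.

```python
# Minimal sketch of the call pattern this commit introduces: endpoints build a
# callback with functools.partial and hand it to shared semaphore helpers
# instead of defining inline generators. Everything here is illustrative.
import asyncio
from functools import partial

_semaphore = asyncio.Semaphore(1)  # stand-in for the project's generation semaphore


async def call_with_semaphore(callback):
    """Simplified stand-in: await a callback while holding the semaphore."""
    async with _semaphore:
        return await callback()


async def generate_with_semaphore(generator):
    """Simplified stand-in: drain an async generator while holding the semaphore."""
    async with _semaphore:
        async for item in generator():
            yield item


async def stream_generate_completion(request, data, model_path):
    """Hypothetical streaming helper, standing in for the split-out module."""
    for chunk in ("one", "two", "[DONE]"):
        yield chunk


async def generate_completion(data, model_path):
    """Hypothetical non-streaming helper, standing in for the split-out module."""
    return {"text": "full response"}


async def completion_request(request, data, model_path, stream: bool):
    """Endpoint shape after the split: build a partial, hand it to a helper."""
    if stream:
        # Streaming path: wrap the callback so the semaphore guards the whole stream
        generator_callback = partial(stream_generate_completion, request, data, model_path)
        return [chunk async for chunk in generate_with_semaphore(generator_callback)]

    # Non-streaming path: run the callback behind the semaphore
    return await call_with_semaphore(partial(generate_completion, data, model_path))


if __name__ == "__main__":
    print(asyncio.run(completion_request(None, None, None, stream=True)))
    print(asyncio.run(completion_request(None, None, None, stream=False)))
```

The same shape recurs throughout the diff below: build a partial, optionally re-wrap it in a semaphore helper, then either await it or feed it to EventSourceResponse.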
13 changes: 13 additions & 0 deletions common/model.py
@@ -74,3 +74,16 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
async def load_model(model_path: pathlib.Path, **kwargs):
    async for _, _, _ in load_model_gen(model_path, **kwargs):
        pass


def load_loras(lora_dir, **kwargs):
    """Wrapper to load loras."""
    if len(container.active_loras) > 0:
        unload_loras()

    return container.load_loras(lora_dir, **kwargs)


def unload_loras():
    """Wrapper to unload loras"""
    container.unload(loras_only=True)
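For context on how these new wrappers are consumed, here is a sketch of the endpoint-side call that mirrors the app.py hunk further down; load_lora_endpoint and its arguments are illustrative, while model.load_loras, call_with_semaphore, and run_in_threadpool are the names the diff actually uses.

```python
# Sketch (assumes this repo's modules): the blocking lora load runs in a
# threadpool and, unless the queue is skipped, behind the generation semaphore.
from functools import partial

from fastapi.concurrency import run_in_threadpool

from common import model
from common.generators import call_with_semaphore


async def load_lora_endpoint(lora_dir, data, skip_queue: bool = False):
    """Illustrative endpoint body; the real one lives in endpoints/OAI/app.py."""
    load_callback = partial(run_in_threadpool, model.load_loras, lora_dir, **data)

    # Wrap in the semaphore unless the request explicitly skips the queue
    if not skip_queue:
        load_callback = partial(call_with_semaphore, load_callback)

    return await load_callback()
```

Centralizing the unload-before-reload check in common/model.py keeps the endpoint free of container bookkeeping.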
242 changes: 49 additions & 193 deletions endpoints/OAI/app.py
@@ -1,30 +1,24 @@
import pathlib
from sse_starlette import EventSourceResponse
import uvicorn
from asyncio import CancelledError
from uuid import uuid4
from jinja2 import TemplateError
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse

from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key
from common.generators import (
call_with_semaphore,
generate_with_semaphore,
release_semaphore,
)
from common.logger import UVICORN_LOG_CONFIG
from common.templating import (
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.utils import (
get_generator_error,
handle_request_error,
unwrap,
)
@@ -39,7 +33,6 @@
from endpoints.OAI.types.model import (
ModelCard,
ModelLoadRequest,
ModelLoadResponse,
ModelCardParameters,
)
from endpoints.OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
@@ -50,12 +43,16 @@
TokenDecodeRequest,
TokenDecodeResponse,
)
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
create_completion_response,
create_chat_completion_response,
create_chat_completion_stream_chunk,
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.model import get_model_list
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list

app = FastAPI(
@@ -169,73 +166,34 @@ async def load_model(request: Request, data: ModelLoadRequest):
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = model_path / data.name

load_data = data.model_dump()

draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(
400, "draft_model_name was not found inside the draft object."
)

load_data["draft"]["draft_model_dir"] = unwrap(
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)

if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")

async def generator():
"""Request generation wrapper for the loading process."""

load_status = model.load_model_gen(model_path, **load_data)
try:
async for module, modules, model_type in load_status:
if await request.is_disconnected():
release_semaphore()
logger.error(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)
return

if module != 0:
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="processing",
)

yield response.model_dump_json()

if module == modules:
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="finished",
)

yield response.model_dump_json()
except CancelledError:
logger.error(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)
except Exception as exc:
yield get_generator_error(str(exc))
load_callback = partial(
stream_model_load, request, data, model_path, draft_model_path
)

# Determine whether to use or skip the queue
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
logger.warning(
"Model load request is skipping the completions queue. "
"Unexpected results may occur."
)
generator_callback = generator
else:
generator_callback = partial(generate_with_semaphore, generator)
load_callback = partial(generate_with_semaphore, load_callback)

return EventSourceResponse(generator_callback())
return EventSourceResponse(load_callback())


# Unload model endpoint
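The endpoint above now delegates its SSE generator to stream_model_load, imported from endpoints/OAI/utils/model.py. Below is a plausible reconstruction of that helper pieced together from the inline generator() this hunk deletes; the committed file may differ, and the handling of draft_model_path as well as the nesting of the two status yields are assumptions.

```python
# Reconstruction, not the committed file: behavior mirrors the deleted inline
# generator() above; the draft_model_dir handling is an assumption.
from asyncio import CancelledError

from loguru import logger

from common import model
from common.generators import release_semaphore
from common.utils import get_generator_error
from endpoints.OAI.types.model import ModelLoadResponse


async def stream_model_load(request, data, model_path, draft_model_path):
    """Request generation wrapper for the model loading process."""

    load_data = data.model_dump()
    if draft_model_path:
        # Assumption: the helper re-attaches the draft model dir that the
        # endpoint used to set on load_data directly.
        load_data["draft"]["draft_model_dir"] = draft_model_path

    load_status = model.load_model_gen(model_path, **load_data)
    try:
        async for module, modules, model_type in load_status:
            if await request.is_disconnected():
                release_semaphore()
                logger.error(
                    "Model load cancelled by user. "
                    "Please make sure to run unload to free up resources."
                )
                return

            if module != 0:
                response = ModelLoadResponse(
                    model_type=model_type,
                    module=module,
                    modules=modules,
                    status="processing",
                )
                yield response.model_dump_json()

            if module == modules:
                response = ModelLoadResponse(
                    model_type=model_type,
                    module=module,
                    modules=modules,
                    status="finished",
                )
                yield response.model_dump_json()
    except CancelledError:
        logger.error(
            "Model load cancelled by user. "
            "Please make sure to run unload to free up resources."
        )
    except Exception as exc:
        yield get_generator_error(str(exc))
```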
@@ -363,6 +321,7 @@ async def get_active_loras():
)
async def load_lora(data: LoraLoadRequest):
"""Loads a LoRA into the model container."""

if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")

@@ -373,28 +332,25 @@ async def load_lora(data: LoraLoadRequest):
"A parent lora directory does not exist. Check your config.yml?",
)

# Clean-up existing loras if present
def load_loras_internal():
if len(model.container.active_loras) > 0:
unload_loras()

result = model.container.load_loras(lora_dir, **data.model_dump())
return LoraLoadResponse(
success=unwrap(result.get("success"), []),
failure=unwrap(result.get("failure"), []),
)

internal_callback = partial(run_in_threadpool, load_loras_internal)
load_callback = partial(
run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
)

# Determine whether to skip the queue
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
logger.warning(
"Lora load request is skipping the completions queue. "
"Unexpected results may occur."
)
return await internal_callback()
else:
return await call_with_semaphore(internal_callback)
load_callback = partial(call_with_semaphore, load_callback)

load_result = await load_callback()

return LoraLoadResponse(
success=unwrap(load_result.get("success"), []),
failure=unwrap(load_result.get("failure"), []),
)


# Unload lora endpoint
Expand All @@ -404,7 +360,8 @@ def load_loras_internal():
)
async def unload_loras():
"""Unloads the currently loaded loras."""
model.container.unload(loras_only=True)

model.unload_loras()


# Encode tokens endpoint
@@ -439,7 +396,7 @@ async def decode_tokens(data: TokenDecodeRequest):
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_completion(request: Request, data: CompletionRequest):
async def completion_request(request: Request, data: CompletionRequest):
"""Generates a completion from a prompt."""
model_path = model.container.get_model_path()

@@ -451,60 +408,25 @@ async def generate_completion(request: Request, data: CompletionRequest):
)

if data.stream and not disable_request_streaming:
generator_callback = partial(
stream_generate_completion, request, data, model_path
)

async def generator():
try:
new_generation = model.container.generate_gen(
data.prompt, **data.to_gen_params()
)
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Completion generation cancelled by user.")
return

response = create_completion_response(generation, model_path.name)

yield response.model_dump_json()

# Yield a finish response on successful generation
yield "[DONE]"
except Exception:
yield get_generator_error(
"Completion aborted. Please check the server console."
)

return EventSourceResponse(generate_with_semaphore(generator))

try:
generation = await call_with_semaphore(
partial(
run_in_threadpool,
model.container.generate,
data.prompt,
**data.to_gen_params(),
)
return EventSourceResponse(generate_with_semaphore(generator_callback))
else:
response = await call_with_semaphore(
partial(generate_completion, data, model_path)
)

response = create_completion_response(generation, model_path.name)
return response
except Exception as exc:
error_message = handle_request_error(
"Completion aborted. Maybe the model was unloaded? "
"Please check the server console."
).error.message

# Server error if there's a generation exception
raise HTTPException(503, error_message) from exc
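The completion endpoint now defers to generate_completion and stream_generate_completion from endpoints/OAI/utils/completion.py. A plausible reconstruction of those helpers, based on the inline code this hunk removes, follows; the committed module may differ, and create_completion_response is assumed to remain in the same file as before the split.

```python
# Reconstruction, not the committed file: the threadpool call, disconnect
# handling, and error messages mirror the deleted inline code above.
from fastapi.concurrency import run_in_threadpool
from loguru import logger

from common import model
from common.generators import release_semaphore
from common.utils import get_generator_error


async def generate_completion(data, model_path):
    """Non-streaming completion, run in a threadpool off the event loop."""
    generation = await run_in_threadpool(
        model.container.generate, data.prompt, **data.to_gen_params()
    )

    # create_completion_response is assumed to still live in this module
    return create_completion_response(generation, model_path.name)


async def stream_generate_completion(request, data, model_path):
    """Streaming completion generator consumed by EventSourceResponse."""
    try:
        new_generation = model.container.generate_gen(
            data.prompt, **data.to_gen_params()
        )
        for generation in new_generation:
            # Get out if the request gets disconnected
            if await request.is_disconnected():
                release_semaphore()
                logger.error("Completion generation cancelled by user.")
                return

            response = create_completion_response(generation, model_path.name)
            yield response.model_dump_json()

        # Yield a finish response on successful generation
        yield "[DONE]"
    except Exception:
        yield get_generator_error(
            "Completion aborted. Please check the server console."
        )
```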


# Chat completions endpoint
@app.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
async def chat_completion_request(request: Request, data: ChatCompletionRequest):
"""Generates a chat completion from a prompt."""

if model.container.prompt_template is None:
@@ -518,90 +440,24 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
if isinstance(data.messages, str):
prompt = data.messages
else:
try:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True),
unwrap(data.ban_eos_token, False),
)

prompt = get_prompt_from_template(
data.messages,
model.container.prompt_template,
data.add_generation_prompt,
special_tokens_dict,
)
except KeyError as exc:
raise HTTPException(
400,
"Could not find a Conversation from prompt template "
f"'{model.container.prompt_template.name}'. "
"Check your spelling?",
) from exc
except TemplateError as exc:
raise HTTPException(
400,
f"TemplateError: {str(exc)}",
) from exc
prompt = format_prompt_with_template(data)

disable_request_streaming = unwrap(
config.developer_config().get("disable_request_streaming"), False
)

if data.stream and not disable_request_streaming:
const_id = f"chatcmpl-{uuid4().hex}"

async def generator():
"""Generator for the generation process."""
try:
new_generation = model.container.generate_gen(
prompt, **data.to_gen_params()
)
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Chat completion generation cancelled by user.")
return

response = create_chat_completion_stream_chunk(
const_id, generation, model_path.name
)

yield response.model_dump_json()

# Yield a finish response on successful generation
finish_response = create_chat_completion_stream_chunk(
const_id, finish_reason="stop"
)

yield finish_response.model_dump_json()
except Exception:
yield get_generator_error(
"Chat completion aborted. Please check the server console."
)

return EventSourceResponse(generate_with_semaphore(generator))
generator_callback = partial(
stream_generate_chat_completion, prompt, request, data, model_path
)

try:
generation = await call_with_semaphore(
partial(
run_in_threadpool,
model.container.generate,
prompt,
**data.to_gen_params(),
)
return EventSourceResponse(generate_with_semaphore(generator_callback))
else:
response = await call_with_semaphore(
partial(generate_chat_completion, prompt, request, data, model_path)
)
response = create_chat_completion_response(generation, model_path.name)

return response
except Exception as exc:
error_message = handle_request_error(
"Chat completion aborted. Maybe the model was unloaded? "
"Please check the server console."
).error.message

# Server error if there's a generation exception
raise HTTPException(503, error_message) from exc
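Likewise, the chat endpoint now leans on endpoints/OAI/utils/chat_completion.py. The sketch below reconstructs format_prompt_with_template and stream_generate_chat_completion from the inline code this hunk removes; the non-streaming generate_chat_completion is omitted since it parallels the completion helper shown earlier. The committed module may differ, and the import location of create_chat_completion_stream_chunk is an assumption based on the pre-refactor import block.

```python
# Reconstruction, not the committed file: template error handling and the
# stream chunk sequence mirror the deleted inline code above.
from uuid import uuid4

from fastapi import HTTPException
from jinja2 import TemplateError
from loguru import logger

from common import model
from common.generators import release_semaphore
from common.templating import get_prompt_from_template
from common.utils import get_generator_error, unwrap

# Assumption: the chunk builder keeps its pre-refactor home in utils/completion.py
from endpoints.OAI.utils.completion import create_chat_completion_stream_chunk


def format_prompt_with_template(data):
    """Render the chat messages through the model's prompt template."""
    try:
        special_tokens_dict = model.container.get_special_tokens(
            unwrap(data.add_bos_token, True),
            unwrap(data.ban_eos_token, False),
        )

        return get_prompt_from_template(
            data.messages,
            model.container.prompt_template,
            data.add_generation_prompt,
            special_tokens_dict,
        )
    except KeyError as exc:
        raise HTTPException(
            400,
            "Could not find a Conversation from prompt template "
            f"'{model.container.prompt_template.name}'. "
            "Check your spelling?",
        ) from exc
    except TemplateError as exc:
        raise HTTPException(400, f"TemplateError: {str(exc)}") from exc


async def stream_generate_chat_completion(prompt, request, data, model_path):
    """Streaming chat completion generator consumed by EventSourceResponse."""
    const_id = f"chatcmpl-{uuid4().hex}"

    try:
        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
        for generation in new_generation:
            # Get out if the request gets disconnected
            if await request.is_disconnected():
                release_semaphore()
                logger.error("Chat completion generation cancelled by user.")
                return

            response = create_chat_completion_stream_chunk(
                const_id, generation, model_path.name
            )
            yield response.model_dump_json()

        # Yield a finish response on successful generation
        finish_response = create_chat_completion_stream_chunk(
            const_id, finish_reason="stop"
        )
        yield finish_response.model_dump_json()
    except Exception:
        yield get_generator_error(
            "Chat completion aborted. Please check the server console."
        )
```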


def start_api(host: str, port: int):
