OAI: Log request errors to console
Previously, some request errors were only sent to the client, but some
clients don't log the full error, so log it to the console as well.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Mar 24, 2024
1 parent 26496c4 commit db62d1e
Showing 1 changed file with 61 additions and 22 deletions.
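
For context, every call site in this diff follows the pattern
handle_request_error(message, exc_info=False).error.message, so the helper
evidently logs the message server-side and returns a wrapper whose
error.message field carries the client-facing text. Below is a minimal
sketch of such a helper, assuming standard-library logging and pydantic
models; the class names, logger setup, and trace field are illustrative,
not the repository's actual code.

import logging
import traceback
from typing import Optional

from pydantic import BaseModel

logger = logging.getLogger(__name__)


class RequestErrorMessage(BaseModel):
    """Inner payload; endpoints re-use its message for the HTTPException."""

    message: str
    trace: Optional[str] = None


class RequestError(BaseModel):
    """Wrapper so call sites can write handle_request_error(...).error.message."""

    error: RequestErrorMessage


def handle_request_error(message: str, exc_info: bool = True) -> RequestError:
    """Log a request error to the console and return a client-safe payload."""

    # The validation failures in this diff pass exc_info=False because no
    # exception is in flight; only caught exceptions carry a traceback.
    trace = traceback.format_exc() if exc_info else None

    # The point of the commit: the message always reaches the server
    # console, even when the client drops or truncates the response body.
    logger.error(message)

    return RequestError(error=RequestErrorMessage(message=message, trace=trace))

Logging once inside the helper keeps the console output and the
HTTPException detail in sync instead of duplicating the string at every
call site.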
83 changes: 61 additions & 22 deletions endpoints/OAI/app.py
@@ -162,24 +162,37 @@ async def load_model(request: Request, data: ModelLoadRequest):
 
     # Verify request parameters
     if not data.name:
-        raise HTTPException(400, "A model name was not provided.")
+        error_message = handle_request_error(
+            "A model name was not provided for load.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
 
     model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
     model_path = model_path / data.name
 
     draft_model_path = None
     if data.draft:
         if not data.draft.draft_model_name:
-            raise HTTPException(
-                400, "draft_model_name was not found inside the draft object."
-            )
+            error_message = handle_request_error(
+                "Could not find the draft model name for model load.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(400, error_message)
 
         draft_model_path = unwrap(
             config.draft_model_config().get("draft_model_dir"), "models"
         )
 
     if not model_path.exists():
-        raise HTTPException(400, "model_path does not exist. Check model_name?")
+        error_message = handle_request_error(
+            "Could not find the model path for load. Check model name or config.yml?",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
 
     load_callback = partial(stream_model_load, data, model_path, draft_model_path)

@@ -220,13 +233,23 @@ async def get_templates():
 async def switch_template(data: TemplateSwitchRequest):
     """Switch the currently loaded template"""
     if not data.name:
-        raise HTTPException(400, "New template name not found.")
+        error_message = handle_request_error(
+            "New template name not found.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
 
     try:
         template = get_template_from_file(data.name)
         model.container.prompt_template = template
     except FileNotFoundError as e:
-        raise HTTPException(400, "Template does not exist. Check the name?") from e
+        error_message = handle_request_error(
+            f"The template name {data.name} doesn't exist. Check the spelling?",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message) from e
 
 
 @app.post(
@@ -259,15 +282,22 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
         try:
             sampling.overrides_from_file(data.preset)
         except FileNotFoundError as e:
-            raise HTTPException(
-                400, "Sampler override preset does not exist. Check the name?"
-            ) from e
+            error_message = handle_request_error(
+                f"Sampler override preset with name {data.preset} does not exist. "
+                + "Check the spelling?",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(400, error_message) from e
     elif data.overrides:
         sampling.overrides_from_dict(data.overrides)
     else:
-        raise HTTPException(
-            400, "A sampler override preset or dictionary wasn't provided."
-        )
+        error_message = handle_request_error(
+            "A sampler override preset or dictionary wasn't provided.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
 
 
 @app.post(
@@ -322,14 +352,21 @@ async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
 
     if not data.loras:
-        raise HTTPException(400, "List of loras to load is not found.")
+        error_message = handle_request_error(
+            "List of loras to load is not found.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
 
     lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
     if not lora_dir.exists():
-        raise HTTPException(
-            400,
-            "A parent lora directory does not exist. Check your config.yml?",
-        )
+        error_message = handle_request_error(
+            "A parent lora directory does not exist for load. Check your config.yml?",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
 
     load_callback = partial(model.load_loras, lora_dir, **data.model_dump())

@@ -459,10 +496,12 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest):
    """Generates a chat completion from a prompt."""
 
     if model.container.prompt_template is None:
-        raise HTTPException(
-            422,
-            "This endpoint is disabled because a prompt template is not set.",
-        )
+        error_message = handle_request_error(
+            "Chat completions are disabled because a prompt template is not set.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(422, error_message)
 
     model_path = model.container.get_model_path()

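A quick hypothetical smoke test for the new behavior; the URL, port, and
x-admin-key header below are assumptions about a local setup, not part of
this commit.

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/model/load",  # assumed local default port
    headers={"x-admin-key": "example-admin-key"},  # assumed admin auth header
    json={},  # deliberately omit "name"
)

print(resp.status_code)  # expected: 400
print(resp.json())  # the client-side copy of the error message

After this commit, the same "A model name was not provided for load."
message should also appear in the server console, which is the behavior
the commit message describes.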
