API: Add inline exception for dummy models
If a request sends a dummy model name, it shouldn't error, since the
server is catering to clients that expect specific OAI model names.
This is a problem with inline model loading, where these names would
error by default. Therefore, add an exception when the provided name
is in the list of dummy model names (which also doubles as the list
of strict-mode exceptions for inline loading).

However, the dummy model names weren't configurable, so add a new
option to specify exception names; otherwise, the default is gpt-3.5-turbo.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Nov 18, 2024
1 parent b94c646 commit bd9e78e
Showing 5 changed files with 51 additions and 11 deletions.
17 changes: 14 additions & 3 deletions common/config_models.py
@@ -141,14 +141,25 @@ class ModelConfig(BaseConfigModel):
         False,
         description=(
             "Allow direct loading of models "
-            "from a completion or chat completion request (default: False)."
+            "from a completion or chat completion request (default: False).\n"
+            "This method of loading is strict by default.\n"
+            "Enable dummy models to add exceptions for invalid model names."
         ),
     )
     use_dummy_models: Optional[bool] = Field(
         False,
         description=(
-            "Sends dummy model names when the models endpoint is queried.\n"
-            "Enable this if the client is looking for specific OAI models."
+            "Sends dummy model names when the models endpoint is queried. "
+            "(default: False)\n"
+            "Enable this if the client is looking for specific OAI models.\n"
         ),
     )
+    dummy_model_names: List[str] = Field(
+        default=["gpt-3.5-turbo"],
+        description=(
+            "A list of fake model names that are sent via the /v1/models endpoint. "
+            '(default: ["gpt-3.5-turbo"])\n'
+            "Also used as bypasses for strict mode if inline_model_loading is true."
+        ),
+    )
     model_name: Optional[str] = Field(
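For context, a minimal, self-contained sketch of how these Pydantic fields behave (not part of the commit: BaseModel stands in for BaseConfigModel, the surrounding ModelConfig fields are omitted, and the override names are made up):

from typing import List, Optional

from pydantic import BaseModel, Field

class ModelConfig(BaseModel):
    inline_model_loading: Optional[bool] = Field(False)
    use_dummy_models: Optional[bool] = Field(False)
    dummy_model_names: List[str] = Field(default=["gpt-3.5-turbo"])

# With no overrides, the OAI-compatible default name is used
print(ModelConfig().dummy_model_names)  # ['gpt-3.5-turbo']

# Operators can swap in their own exception list
print(ModelConfig(dummy_model_names=["gpt-4", "gpt-4o"]).dummy_model_names)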
8 changes: 7 additions & 1 deletion config_sample.yml
@@ -49,12 +49,18 @@ model:
   model_dir: models
 
   # Allow direct loading of models from a completion or chat completion request (default: False).
+  # This method of loading is strict by default.
+  # Enable dummy models to add exceptions for invalid model names.
   inline_model_loading: false
 
-  # Sends dummy model names when the models endpoint is queried.
+  # Sends dummy model names when the models endpoint is queried. (default: False)
   # Enable this if the client is looking for specific OAI models.
   use_dummy_models: false
 
+  # A list of fake model names that are sent via the /v1/models endpoint. (default: ["gpt-3.5-turbo"])
+  # Also used as bypasses for strict mode if inline_model_loading is true.
+  dummy_model_names: ["gpt-3.5-turbo"]
+
   # An initial model to load.
   # Make sure the model is located in the model directory!
   # REQUIRED: This must be filled out to load a model on startup.
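As a usage sketch, a deployment whose clients are pinned to OpenAI model names might combine the three options like this (hypothetical excerpt; the second model name is illustrative):

model:
  inline_model_loading: true
  use_dummy_models: true
  # Requests for either name below get a dummy model card and bypass strict inline loading
  dummy_model_names: ["gpt-3.5-turbo", "gpt-4"]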
27 changes: 21 additions & 6 deletions endpoints/OAI/utils/completion.py
@@ -130,18 +130,33 @@ async def load_inline_model(model_name: str, request: Request):
 
         return
 
+    is_dummy_model = (
+        config.model.use_dummy_models and model_name in config.model.dummy_model_names
+    )
+
     # Error if an invalid key is passed
+    # If a dummy model is provided, don't error
     if get_key_permission(request) != "admin":
-        error_message = handle_request_error(
-            f"Unable to switch model to {model_name} because "
-            + "an admin key isn't provided",
-            exc_info=False,
-        ).error.message
+        if not is_dummy_model:
+            error_message = handle_request_error(
+                f"Unable to switch model to {model_name} because "
+                + "an admin key isn't provided",
+                exc_info=False,
+            ).error.message
 
-        raise HTTPException(401, error_message)
+            raise HTTPException(401, error_message)
+        else:
+            return
 
     # Start inline loading
+    # Past here, user is assumed to be admin
+
+    # Skip if the model is a dummy
+    if is_dummy_model:
+        logger.warning(f"Dummy model {model_name} provided. Skipping inline load.")
+
+        return
 
     model_path = pathlib.Path(config.model.model_dir)
     model_path = model_path / model_name
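The net effect: a non-admin key requesting a configured dummy name now gets a silent no-op instead of a 401, and an admin key logs a warning and skips the load. A condensed, self-contained sketch of that gate (assumed simplifications: plain arguments instead of the config object, PermissionError instead of handle_request_error/HTTPException):

from typing import List

def should_skip_inline_load(
    model_name: str,
    key_permission: str,
    use_dummy_models: bool,
    dummy_model_names: List[str],
) -> bool:
    """True means return without loading the model."""
    is_dummy_model = use_dummy_models and model_name in dummy_model_names

    if key_permission != "admin":
        if not is_dummy_model:
            # Mirrors the 401 path above
            raise PermissionError(
                f"Unable to switch model to {model_name} because "
                "an admin key isn't provided"
            )
        # Dummy name from a non-admin key: silently skip
        return True

    # Admin + dummy name: skip the load (the real code logs a warning here)
    return is_dummy_model

# A client pinned to "gpt-3.5-turbo" no longer triggers a 401
assert should_skip_inline_load("gpt-3.5-turbo", "user", True, ["gpt-3.5-turbo"])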
3 changes: 2 additions & 1 deletion endpoints/core/router.py
@@ -39,6 +39,7 @@
 from endpoints.core.utils.model import (
     get_current_model,
     get_current_model_list,
+    get_dummy_models,
     get_model_list,
     stream_model_load,
 )
@@ -82,7 +83,7 @@ async def list_models(request: Request) -> ModelList:
     models = await get_current_model_list()
 
     if config.model.use_dummy_models:
-        models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
+        models.data[:0] = get_dummy_models()
 
     return models
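The slice assignment replaces the old single insert so that every configured dummy card is prepended while preserving list order; illustrated with plain lists (the model names here are placeholders):

# list[:0] = iterable prepends all items in order
data = ["my-local-model"]
data[:0] = ["gpt-3.5-turbo", "gpt-4"]
print(data)  # ['gpt-3.5-turbo', 'gpt-4', 'my-local-model']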
7 changes: 7 additions & 0 deletions endpoints/core/utils/model.py
@@ -92,6 +92,13 @@ def get_current_model():
     return model_card
 
 
+def get_dummy_models():
+    if config.model.dummy_model_names:
+        return [ModelCard(id=dummy_id) for dummy_id in config.model.dummy_model_names]
+    else:
+        return [ModelCard(id="gpt-3.5-turbo")]
+
+
 async def stream_model_load(
     data: ModelLoadRequest,
     model_path: pathlib.Path,
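Note the fallback: even if dummy_model_names is explicitly emptied, /v1/models still advertises the old hardcoded default. In isolation (a dataclass standing in for the real ModelCard, a list argument standing in for the config lookup):

from dataclasses import dataclass
from typing import List

@dataclass
class ModelCard:
    id: str

def get_dummy_models(dummy_model_names: List[str]) -> List[ModelCard]:
    # Fall back to the pre-commit hardcoded name if the list is empty
    if dummy_model_names:
        return [ModelCard(id=dummy_id) for dummy_id in dummy_model_names]
    return [ModelCard(id="gpt-3.5-turbo")]

print(get_dummy_models([]))            # [ModelCard(id='gpt-3.5-turbo')]
print(get_dummy_models(["my-alias"]))  # [ModelCard(id='my-alias')]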
