Skip to content

Commit

Permalink
API: Add sampler override switching
Browse files Browse the repository at this point in the history
Allow users to switch the currently overriden samplers via the API
so a restart isn't required to switch the overrides.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Jan 24, 2024
1 parent 4cf231d commit a9a128c
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 6 deletions.
26 changes: 26 additions & 0 deletions OAI/types/sampler_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from typing import Optional


class SamplerOverrideSwitchRequest(BaseModel):
"""Sampler override switch request"""

preset: Optional[str] = Field(
default=None, description="Pass a sampler override preset name"
)

overrides: Optional[dict] = Field(
default=None,
description=(
"Sampling override parent takes in individual keys and overrides."
+ "Ignored if preset is provided."
),
examples=[
{
"top_p": {
"override": 1.5,
"force": False,
}
}
],
)
14 changes: 10 additions & 4 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def to_gen_params(self):
DEFAULT_OVERRIDES = {}


def get_sampler_overrides():
return DEFAULT_OVERRIDES


def set_overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""

Expand All @@ -174,10 +178,10 @@ def set_overrides_from_dict(new_overrides: dict):
if isinstance(new_overrides, dict):
DEFAULT_OVERRIDES = new_overrides
else:
raise TypeError("new sampler overrides must be a dict!")
raise TypeError("New sampler overrides must be a dict!")


def get_overrides_from_file(preset_name: str):
def set_overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""

preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
Expand All @@ -188,11 +192,13 @@ def get_overrides_from_file(preset_name: str):

logger.info("Applied sampler overrides from file.")
else:
logger.warn(
f"Sampler override file named \"{preset_name}\" was not found. "
error_message = (
f'Sampler override file named "{preset_name}" was not found. '
+ "Make sure it's located in the sampler_overrides folder."
)

raise FileNotFoundError(error_message)


# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories
Expand Down
53 changes: 51 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
get_network_config,
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.sampling import get_overrides_from_file
from common.sampling import (
get_sampler_overrides,
set_overrides_from_file,
set_overrides_from_dict,
)
from common.templating import (
get_all_templates,
get_prompt_from_template,
Expand All @@ -43,6 +47,7 @@
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
from OAI.types.template import TemplateList, TemplateSwitchRequest
from OAI.types.token import (
TokenEncodeRequest,
Expand Down Expand Up @@ -288,6 +293,47 @@ async def unload_template():
MODEL_CONTAINER.prompt_template = None


# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""

return get_sampler_overrides()


@app.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""

if data.preset:
try:
set_overrides_from_file(data.preset)
except FileNotFoundError as e:
raise HTTPException(
400, "Sampler override preset does not exist. Check the name?"
) from e
elif data.overrides:
set_overrides_from_dict(data.overrides)
else:
raise HTTPException(
400, "A sampler override preset or dictionary wasn't provided."
)


@app.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
"""Unloads the currently selected override preset"""

set_overrides_from_dict({})


# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
Expand Down Expand Up @@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None):
sampling_config = get_sampling_config()
sampling_override_preset = sampling_config.get("override_preset")
if sampling_override_preset:
get_overrides_from_file(sampling_override_preset)
try:
set_overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))

# If an initial model name is specified, create a container
# and load the model
Expand Down

0 comments on commit a9a128c

Please sign in to comment.