Add sampler override API endpoints.

* OAI/types/sampler_overrides.py: new SamplerOverrideSwitchRequest model with
  optional `preset` and `overrides` fields.
* common/sampling.py: add get_sampler_overrides(), which returns the
  module-level DEFAULT_OVERRIDES; rename get_overrides_from_file() to
  set_overrides_from_file(), which now raises FileNotFoundError instead of
  logging a warning when the preset file does not exist.
* main.py: add the list / switch / unload endpoints under
  /v1/sampling/override; at startup a missing override preset is caught and
  logged as a warning.

Review notes: a separating space was added between the two concatenated
sentences of the `overrides` field description, and a docstring was added to
get_sampler_overrides() (the common/sampling.py hunk headers account for the
two extra lines). The `index` lines are carried over from the original patch
and were not recomputed for the edited post-images.

diff --git a/OAI/types/sampler_overrides.py b/OAI/types/sampler_overrides.py
new file mode 100644
index 00000000..c4b75d82
--- /dev/null
+++ b/OAI/types/sampler_overrides.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional
+
+
+class SamplerOverrideSwitchRequest(BaseModel):
+    """Sampler override switch request"""
+
+    preset: Optional[str] = Field(
+        default=None, description="Pass a sampler override preset name"
+    )
+
+    overrides: Optional[dict] = Field(
+        default=None,
+        description=(
+            "Sampling override parent takes in individual keys and overrides. "
+            + "Ignored if preset is provided."
+        ),
+        examples=[
+            {
+                "top_p": {
+                    "override": 1.5,
+                    "force": False,
+                }
+            }
+        ],
+    )
diff --git a/common/sampling.py b/common/sampling.py
index 01d11f10..53defcc1 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -166,6 +166,12 @@ def to_gen_params(self):
 DEFAULT_OVERRIDES = {}
 
 
+def get_sampler_overrides():
+    """Returns the currently applied sampler overrides"""
+
+    return DEFAULT_OVERRIDES
+
+
 def set_overrides_from_dict(new_overrides: dict):
     """Wrapper function to update sampler overrides"""
 
@@ -174,10 +180,10 @@ def set_overrides_from_dict(new_overrides: dict):
     if isinstance(new_overrides, dict):
         DEFAULT_OVERRIDES = new_overrides
     else:
-        raise TypeError("new sampler overrides must be a dict!")
+        raise TypeError("New sampler overrides must be a dict!")
 
 
-def get_overrides_from_file(preset_name: str):
+def set_overrides_from_file(preset_name: str):
     """Fetches an override preset from a file"""
 
     preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
@@ -188,11 +194,13 @@ def get_overrides_from_file(preset_name: str):
 
             logger.info("Applied sampler overrides from file.")
     else:
-        logger.warn(
-            f"Sampler override file named \"{preset_name}\" was not found. "
+        error_message = (
+            f'Sampler override file named "{preset_name}" was not found. '
             + "Make sure it's located in the sampler_overrides folder."
         )
 
+        raise FileNotFoundError(error_message)
+
 
 # TODO: Maybe move these into the class
 # Classmethods aren't recognized in pydantic default_factories
diff --git a/main.py b/main.py
index df0d13df..cbd8f600 100644
--- a/main.py
+++ b/main.py
@@ -26,7 +26,11 @@
     get_network_config,
 )
 from common.generators import call_with_semaphore, generate_with_semaphore
-from common.sampling import get_overrides_from_file
+from common.sampling import (
+    get_sampler_overrides,
+    set_overrides_from_file,
+    set_overrides_from_dict,
+)
 from common.templating import (
     get_all_templates,
     get_prompt_from_template,
@@ -43,6 +47,7 @@
     ModelLoadResponse,
     ModelCardParameters,
 )
+from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
 from OAI.types.template import TemplateList, TemplateSwitchRequest
 from OAI.types.token import (
     TokenEncodeRequest,
@@ -288,6 +293,47 @@ async def unload_template():
     MODEL_CONTAINER.prompt_template = None
 
 
+# Sampler override endpoints
+@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
+@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
+async def list_sampler_overrides():
+    """API wrapper to list all currently applied sampler overrides"""
+
+    return get_sampler_overrides()
+
+
+@app.post(
+    "/v1/sampling/override/switch",
+    dependencies=[Depends(check_admin_key)],
+)
+async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
+    """Switch the currently loaded override preset"""
+
+    if data.preset:
+        try:
+            set_overrides_from_file(data.preset)
+        except FileNotFoundError as e:
+            raise HTTPException(
+                400, "Sampler override preset does not exist. Check the name?"
+            ) from e
+    elif data.overrides:
+        set_overrides_from_dict(data.overrides)
+    else:
+        raise HTTPException(
+            400, "A sampler override preset or dictionary wasn't provided."
+        )
+
+
+@app.post(
+    "/v1/sampling/override/unload",
+    dependencies=[Depends(check_admin_key)],
+)
+async def unload_sampler_override():
+    """Unloads the currently selected override preset"""
+
+    set_overrides_from_dict({})
+
+
 # Lora list endpoint
 @app.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
@@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None):
     sampling_config = get_sampling_config()
     sampling_override_preset = sampling_config.get("override_preset")
     if sampling_override_preset:
-        get_overrides_from_file(sampling_override_preset)
+        try:
+            set_overrides_from_file(sampling_override_preset)
+        except FileNotFoundError as e:
+            logger.warning(str(e))
 
     # If an initial model name is specified, create a container
     # and load the model