From 1979ca319f5df3fc1824911111964c342bded641 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 18 Jan 2024 00:42:52 -0500 Subject: [PATCH 1/7] Tree: Refactor code organization Move common functions into their own folder and refactor the backends to use their own folder as well. Also clean up imports and alphabetize the import statements themselves. Finally, move colab and docker into their own folders as well. Signed-off-by: kingbri --- OAI/types/chat_completion.py | 5 ++-- OAI/types/completion.py | 3 +-- OAI/types/lora.py | 3 +-- OAI/types/model.py | 5 ++-- OAI/types/token.py | 3 +-- OAI/utils_oai.py | 3 +-- model.py => backends/exllamav2/model.py | 12 ++++----- .../TabbyAPI_Colab_Example.ipynb | 0 args.py => common/args.py | 0 auth.py => common/auth.py | 7 +++-- config.py => common/config.py | 4 +-- gen_logging.py => common/gen_logging.py | 2 +- generators.py => common/generators.py | 0 logger.py => common/logger.py | 0 templating.py => common/templating.py | 0 utils.py => common/utils.py | 2 +- .dockerignore => docker/.dockerignore | 0 Dockerfile => docker/Dockerfile | 0 .../docker-compose.yml | 3 ++- main.py | 27 ++++++++++--------- start.py | 2 +- tests/model_test.py | 2 +- 22 files changed, 41 insertions(+), 42 deletions(-) rename model.py => backends/exllamav2/model.py (99%) rename TabbyAPI_Colab_Example.ipynb => colab/TabbyAPI_Colab_Example.ipynb (100%) rename args.py => common/args.py (100%) rename auth.py => common/auth.py (99%) rename config.py => common/config.py (97%) rename gen_logging.py => common/gen_logging.py (98%) rename generators.py => common/generators.py (100%) rename logger.py => common/logger.py (100%) rename templating.py => common/templating.py (100%) rename utils.py => common/utils.py (97%) rename .dockerignore => docker/.dockerignore (100%) rename Dockerfile => docker/Dockerfile (100%) rename docker-compose.yml => docker/docker-compose.yml (86%) diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py index 5e0e80bb..ba0a9685 100644 --- a/OAI/types/chat_completion.py +++ b/OAI/types/chat_completion.py @@ -1,7 +1,8 @@ -from uuid import uuid4 -from time import time from pydantic import BaseModel, Field +from time import time from typing import Union, List, Optional, Dict +from uuid import uuid4 + from OAI.types.common import UsageStats, CommonCompletionRequest diff --git a/OAI/types/completion.py b/OAI/types/completion.py index 15e84a73..4fa380c7 100644 --- a/OAI/types/completion.py +++ b/OAI/types/completion.py @@ -1,10 +1,9 @@ """ Completion API protocols """ +from pydantic import BaseModel, Field from time import time from typing import List, Optional, Union from uuid import uuid4 -from pydantic import BaseModel, Field - from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats diff --git a/OAI/types/lora.py b/OAI/types/lora.py index 841c3a8d..018bf061 100644 --- a/OAI/types/lora.py +++ b/OAI/types/lora.py @@ -1,9 +1,8 @@ """ Lora types """ +from pydantic import BaseModel, Field from time import time from typing import Optional, List -from pydantic import BaseModel, Field - class LoraCard(BaseModel): """Represents a single Lora card.""" diff --git a/OAI/types/model.py b/OAI/types/model.py index 483c41f4..9096d41d 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -1,10 +1,9 @@ """ Contains model card types.
""" +from pydantic import BaseModel, Field, ConfigDict from time import time from typing import List, Optional -from pydantic import BaseModel, Field, ConfigDict - -from gen_logging import LogPreferences +from common.gen_logging import LogPreferences class ModelCardParameters(BaseModel): diff --git a/OAI/types/token.py b/OAI/types/token.py index 98cbc989..8467aa6c 100644 --- a/OAI/types/token.py +++ b/OAI/types/token.py @@ -1,7 +1,6 @@ """ Tokenization types """ -from typing import List - from pydantic import BaseModel +from typing import List class CommonTokenRequest(BaseModel): diff --git a/OAI/utils_oai.py b/OAI/utils_oai.py index b3c59d69..5ad2873f 100644 --- a/OAI/utils_oai.py +++ b/OAI/utils_oai.py @@ -2,6 +2,7 @@ import pathlib from typing import Optional +from common.utils import unwrap from OAI.types.chat_completion import ( ChatCompletionMessage, ChatCompletionRespChoice, @@ -14,8 +15,6 @@ from OAI.types.lora import LoraList, LoraCard from OAI.types.model import ModelList, ModelCard -from utils import unwrap - def create_completion_response( text: str, diff --git a/model.py b/backends/exllamav2/model.py similarity index 99% rename from model.py rename to backends/exllamav2/model.py index cef53c34..6ee186b4 100644 --- a/model.py +++ b/backends/exllamav2/model.py @@ -13,17 +13,17 @@ ExLlamaV2Lora, ) from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler - -from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union -from templating import ( + +from common.gen_logging import log_generation_params, log_prompt, log_response +from common.templating import ( PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file, ) -from utils import coalesce, unwrap -from logger import init_logger +from common.utils import coalesce, unwrap +from common.logger import init_logger logger = init_logger(__name__) @@ -31,7 +31,7 @@ AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2 -class ModelContainer: +class ExllamaV2Container: """The model container class for ExLlamaV2 models.""" config: Optional[ExLlamaV2Config] = None diff --git a/TabbyAPI_Colab_Example.ipynb b/colab/TabbyAPI_Colab_Example.ipynb similarity index 100% rename from TabbyAPI_Colab_Example.ipynb rename to colab/TabbyAPI_Colab_Example.ipynb diff --git a/args.py b/common/args.py similarity index 100% rename from args.py rename to common/args.py diff --git a/auth.py b/common/auth.py similarity index 99% rename from auth.py rename to common/auth.py index 4185ddb6..ea421687 100644 --- a/auth.py +++ b/common/auth.py @@ -3,13 +3,12 @@ application, it should be fine. 
""" import secrets -from typing import Optional - +import yaml from fastapi import Header, HTTPException from pydantic import BaseModel -import yaml +from typing import Optional -from logger import init_logger +from common.logger import init_logger logger = init_logger(__name__) diff --git a/config.py b/common/config.py similarity index 97% rename from config.py rename to common/config.py index 178977bc..e46be62d 100644 --- a/config.py +++ b/common/config.py @@ -1,8 +1,8 @@ import yaml import pathlib -from logger import init_logger -from utils import unwrap +from common.logger import init_logger +from common.utils import unwrap logger = init_logger(__name__) diff --git a/gen_logging.py b/common/gen_logging.py similarity index 98% rename from gen_logging.py rename to common/gen_logging.py index a82cea35..a20e45c2 100644 --- a/gen_logging.py +++ b/common/gen_logging.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from typing import Dict, Optional -from logger import init_logger +from common.logger import init_logger logger = init_logger(__name__) diff --git a/generators.py b/common/generators.py similarity index 100% rename from generators.py rename to common/generators.py diff --git a/logger.py b/common/logger.py similarity index 100% rename from logger.py rename to common/logger.py diff --git a/templating.py b/common/templating.py similarity index 100% rename from templating.py rename to common/templating.py diff --git a/utils.py b/common/utils.py similarity index 97% rename from utils.py rename to common/utils.py index 529afe0c..2db97e90 100644 --- a/utils.py +++ b/common/utils.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from logger import init_logger +from common.logger import init_logger logger = init_logger(__name__) diff --git a/.dockerignore b/docker/.dockerignore similarity index 100% rename from .dockerignore rename to docker/.dockerignore diff --git a/Dockerfile b/docker/Dockerfile similarity index 100% rename from Dockerfile rename to docker/Dockerfile diff --git a/docker-compose.yml b/docker/docker-compose.yml similarity index 86% rename from docker-compose.yml rename to docker/docker-compose.yml index c553612a..d50682e2 100644 --- a/docker-compose.yml +++ b/docker/docker-compose.yml @@ -2,7 +2,8 @@ version: '3.8' services: tabbyapi: build: - context: . + context: .. 
+ dockerfile: ./docker/Dockerfile ports: - "5000:5000" environment: diff --git a/main.py b/main.py index bbc7ad8a..0ba96505 100644 --- a/main.py +++ b/main.py @@ -11,10 +11,11 @@ from functools import partial from progress.bar import IncrementalBar -import gen_logging -from args import convert_args_to_dict, init_argparser -from auth import check_admin_key, check_api_key, load_auth_keys -from config import ( +import common.gen_logging as gen_logging +from backends.exllamav2.model import ExllamaV2Container +from common.args import convert_args_to_dict, init_argparser +from common.auth import check_admin_key, check_api_key, load_auth_keys +from common.config import ( override_config_from_args, read_config_from_file, get_gen_logging_config, @@ -23,8 +24,10 @@ get_lora_config, get_network_config, ) -from generators import call_with_semaphore, generate_with_semaphore -from model import ModelContainer +from common.generators import call_with_semaphore, generate_with_semaphore +from common.templating import get_all_templates, get_prompt_from_template +from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap +from common.logger import init_logger from OAI.types.completion import CompletionRequest from OAI.types.chat_completion import ChatCompletionRequest from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse @@ -48,9 +51,6 @@ create_chat_completion_response, create_chat_completion_stream_chunk, ) -from templating import get_all_templates, get_prompt_from_template -from utils import get_generator_error, get_sse_packet, load_progress, unwrap -from logger import init_logger logger = init_logger(__name__) @@ -64,7 +64,7 @@ ) # Globally scoped variables. Undefined until initialized in main -MODEL_CONTAINER: Optional[ModelContainer] = None +MODEL_CONTAINER: Optional[ExllamaV2Container] = None def _check_model_container(): @@ -182,7 +182,7 @@ async def load_model(request: Request, data: ModelLoadRequest): if not model_path.exists(): raise HTTPException(400, "model_path does not exist.
Check model_name?") - MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data) + MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data) async def generator(): """Generator for the loading process.""" @@ -530,7 +530,9 @@ def entrypoint(args: Optional[dict] = None): model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) model_path = model_path / model_name - MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config) + MODEL_CONTAINER = ExllamaV2Container( + model_path.resolve(), False, **model_config + ) load_status = MODEL_CONTAINER.load_gen(load_progress) for module, modules in load_status: if module == 0: @@ -550,6 +552,7 @@ def entrypoint(args: Optional[dict] = None): host = unwrap(network_config.get("host"), "127.0.0.1") port = unwrap(network_config.get("port"), 5000) + # TODO: Move OAI API to a separate folder logger.info(f"Developer documentation: http://{host}:{port}/docs") logger.info(f"Completions: http://{host}:{port}/v1/completions") logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") diff --git a/start.py b/start.py index c4b93344..9bfae3e0 100644 --- a/start.py +++ b/start.py @@ -3,7 +3,7 @@ import os import pathlib import subprocess -from args import convert_args_to_dict, init_argparser +from common.args import convert_args_to_dict, init_argparser def get_requirements_file(): diff --git a/tests/model_test.py b/tests/model_test.py index b4ac158f..b47449ee 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -1,5 +1,5 @@ """ Test the model container. """ -from model import ModelContainer +from backends.exllamav2.model import ExllamaV2Container as ModelContainer def progress(module, modules): From 1a8198dc222f8040af6387f1a83b00306bc6eb79 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 21 Jan 2024 23:34:44 -0500 Subject: [PATCH 2/7] Tree: Unify sampler parameters and add override support Unify API sampler params into a superclass, which makes them easier to manage and lets them inherit generic functions. Not all frontends expose every sampling parameter, since they connect through the OAI API (which handles sampling itself with the exception of a few sliders). Add the ability for the user to customize fallback parameters server-side. In addition, parameters can be forced to a certain value server-side in case a frontend automatically sets other sampler values in the background that the user doesn't want. Signed-off-by: kingbri --- .gitignore | 4 + OAI/types/common.py | 92 +----------- common/config.py | 5 + common/sampling.py | 212 ++++++++++++++++++++++++++++ config_sample.yml | 8 ++ main.py | 8 ++ sampler_overrides/sample_preset.yml | 94 ++++++++++++ 7 files changed, 337 insertions(+), 86 deletions(-) create mode 100644 common/sampling.py create mode 100644 sampler_overrides/sample_preset.yml diff --git a/.gitignore b/.gitignore index 8dde2c2b..f77b7f9e 100644 --- a/.gitignore +++ b/.gitignore @@ -192,3 +192,7 @@ templates/* !templates/place_your_templates_here.txt !templates/alpaca.jinja !templates/chatml.jinja + +# Sampler overrides folder +sampler_overrides/* +!sampler_overrides/sample_preset.yml diff --git a/OAI/types/common.py b/OAI/types/common.py index df543495..e90919e7 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -1,6 +1,8 @@ """ Common types for OAI.
""" -from pydantic import BaseModel, Field, AliasChoices -from typing import List, Dict, Optional, Union +from pydantic import BaseModel, Field +from typing import List, Dict, Optional + +from common.sampling import SamplerParams class LogProbs(BaseModel): @@ -20,7 +22,7 @@ class UsageStats(BaseModel): total_tokens: int -class CommonCompletionRequest(BaseModel): +class CommonCompletionRequest(SamplerParams): """Represents a common completion request.""" # Model information @@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel): description="Not parsed. Only used for OAI compliance.", default=None ) - # Generation info - # seed: Optional[int] = -1 + # Generation info (remainder is in SamplerParams superclass) stream: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = [] - - # Default to 150 as 16 makes no sense as a default - max_tokens: Optional[int] = 150 - - # Sampling params - token_healing: Optional[bool] = False - temperature: Optional[float] = 1.0 - temperature_last: Optional[bool] = False - top_k: Optional[int] = 0 - top_p: Optional[float] = 1.0 - top_a: Optional[float] = 0.0 - min_p: Optional[float] = 0.0 - tfs: Optional[float] = 1.0 - frequency_penalty: Optional[float] = 0.0 - presence_penalty: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - repetition_decay: Optional[int] = 0 - mirostat_mode: Optional[int] = 0 - mirostat_tau: Optional[float] = 1.5 - mirostat_eta: Optional[float] = 0.1 - add_bos_token: Optional[bool] = True - ban_eos_token: Optional[bool] = False - logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]]) - negative_prompt: Optional[str] = None - - # Aliased variables - typical: Optional[float] = Field( - default=1.0, - validation_alias=AliasChoices("typical", "typical_p"), - description="Aliases: typical_p", - ) - - penalty_range: Optional[int] = Field( - default=-1, - validation_alias=AliasChoices( - "penalty_range", - "repetition_range", - "repetition_penalty_range", - ), - description="Aliases: repetition_range, repetition_penalty_range", - ) - - cfg_scale: Optional[float] = Field( - default=1.0, - validation_alias=AliasChoices("cfg_scale", "guidance_scale"), - description="Aliases: guidance_scale", - ) - - def to_gen_params(self): - """Converts to internal generation parameters.""" - # Convert stop to an array of strings - if isinstance(self.stop, str): - self.stop = [self.stop] - - return { - "stop": self.stop, - "max_tokens": self.max_tokens, - "add_bos_token": self.add_bos_token, - "ban_eos_token": self.ban_eos_token, - "token_healing": self.token_healing, - "logit_bias": self.logit_bias, - "temperature": self.temperature, - "temperature_last": self.temperature_last, - "top_k": self.top_k, - "top_p": self.top_p, - "top_a": self.top_a, - "typical": self.typical, - "min_p": self.min_p, - "tfs": self.tfs, - "frequency_penalty": self.frequency_penalty, - "presence_penalty": self.presence_penalty, - "repetition_penalty": self.repetition_penalty, - "penalty_range": self.penalty_range, - "repetition_decay": self.repetition_decay, - "mirostat": self.mirostat_mode == 2, - "mirostat_tau": self.mirostat_tau, - "mirostat_eta": self.mirostat_eta, - "cfg_scale": self.cfg_scale, - "negative_prompt": self.negative_prompt, - } diff --git a/common/config.py b/common/config.py index e46be62d..9a4b7b17 100644 --- a/common/config.py +++ b/common/config.py @@ -56,6 +56,11 @@ def override_config_from_args(args: dict): } +def get_sampling_config(): + """Returns the sampling parameter config from the global 
config""" + return unwrap(GLOBAL_CONFIG.get("sampling"), {}) + + def get_model_config(): """Returns the model config from the global config""" return unwrap(GLOBAL_CONFIG.get("model"), {}) diff --git a/common/sampling.py b/common/sampling.py new file mode 100644 index 00000000..01d11f10 --- /dev/null +++ b/common/sampling.py @@ -0,0 +1,212 @@ +"""Common functions for sampling parameters""" + +import pathlib +from typing import Dict, List, Optional, Union +from pydantic import AliasChoices, BaseModel, Field +import yaml + +from common.logger import init_logger +from common.utils import unwrap + + +logger = init_logger(__name__) + + +# Common class for sampler params +class SamplerParams(BaseModel): + """Common class for sampler params that are used in APIs""" + + max_tokens: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("max_tokens", 150) + ) + + stop: Optional[Union[str, List[str]]] = Field( + default_factory=lambda: get_default_sampler_value("stop", []) + ) + + token_healing: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("token_healing", False) + ) + + temperature: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("temperature", 1.0) + ) + + temperature_last: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("temperature_last", False) + ) + + top_k: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("top_k", 0) + ) + + top_p: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("top_p", 1.0) + ) + + top_a: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("top_a", 0.0) + ) + + min_p: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("min_p", 0.0) + ) + + tfs: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("tfs", 0.0) + ) + + frequency_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0) + ) + + presence_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0) + ) + + repetition_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0) + ) + + repetition_decay: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("repetition_decay", 0) + ) + + mirostat_mode: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) + ) + + mirostat_tau: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5) + ) + + mirostat_eta: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3) + ) + + add_bos_token: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("add_bos_token", True) + ) + + ban_eos_token: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("ban_eos_token", False) + ) + + logit_bias: Optional[Dict[int, float]] = Field( + default_factory=lambda: get_default_sampler_value("logit_bias"), + examples=[[{"1": 10}]], + ) + + negative_prompt: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("negative_prompt") + ) + + # Aliased variables + typical: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("typical", 1.0), + validation_alias=AliasChoices("typical", "typical_p"), + description="Aliases: typical_p", + ) + + penalty_range: Optional[int] = Field( + 
default_factory=lambda: get_default_sampler_value("penalty_range", -1), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + ), + description="Aliases: repetition_range, repetition_penalty_range", + ) + + cfg_scale: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), + validation_alias=AliasChoices("cfg_scale", "guidance_scale"), + description="Aliases: guidance_scale", + ) + + def to_gen_params(self): + """Converts samplers to internal generation params""" + + # Add forced overrides if present + apply_forced_sampler_overrides(self) + + # Convert stop to an array of strings + if isinstance(self.stop, str): + self.stop = [self.stop] + + return { + "stop": self.stop, + "max_tokens": self.max_tokens, + "add_bos_token": self.add_bos_token, + "ban_eos_token": self.ban_eos_token, + "token_healing": self.token_healing, + "logit_bias": self.logit_bias, + "temperature": self.temperature, + "temperature_last": self.temperature_last, + "top_k": self.top_k, + "top_p": self.top_p, + "top_a": self.top_a, + "typical": self.typical, + "min_p": self.min_p, + "tfs": self.tfs, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "repetition_penalty": self.repetition_penalty, + "penalty_range": self.penalty_range, + "repetition_decay": self.repetition_decay, + "mirostat": self.mirostat_mode == 2, + "mirostat_tau": self.mirostat_tau, + "mirostat_eta": self.mirostat_eta, + "cfg_scale": self.cfg_scale, + "negative_prompt": self.negative_prompt, + } + + +# Global for default overrides +DEFAULT_OVERRIDES = {} + + +def set_overrides_from_dict(new_overrides: dict): + """Wrapper function to update sampler overrides""" + + global DEFAULT_OVERRIDES + + if isinstance(new_overrides, dict): + DEFAULT_OVERRIDES = new_overrides + else: + raise TypeError("new sampler overrides must be a dict!") + + +def get_overrides_from_file(preset_name: str): + """Fetches an override preset from a file""" + + preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") + if preset_path.exists(): + with open(preset_path, "r", encoding="utf8") as raw_preset: + preset = yaml.safe_load(raw_preset) + set_overrides_from_dict(preset) + + logger.info("Applied sampler overrides from file.") + else: + logger.warn( + f"Sampler override file named \"{preset_name}\" was not found. " + + "Make sure it's located in the sampler_overrides folder." + ) + + +# TODO: Maybe move these into the class +# Classmethods aren't recognized in pydantic default_factories +def get_default_sampler_value(key, fallback=None): + """Gets an overridden default sampler value""" + + return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback) + + +def apply_forced_sampler_overrides(params: SamplerParams): + """Forcefully applies overrides if specified by the user""" + + for var, value in DEFAULT_OVERRIDES.items(): + override = value.get("override") + force = unwrap(value.get("force"), False) + if force and override: + setattr(params, var, override) diff --git a/config_sample.yml b/config_sample.yml index 7f88d946..89368acf 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -27,6 +27,14 @@ logging: # Enable generation parameter logging (default: False) generation_params: False +# Options for sampling +sampling: + # Override preset name. 
Find this in the sampler_overrides folder (default: None) + # This overrides default fallbacks for sampler values that are passed to the API + # Server-side overrides are NOT needed by default + # WARNING: Using this can result in a generation speed penalty + #override_preset: + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/main.py b/main.py index 0ba96505..218d9c09 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from common.args import convert_args_to_dict, init_argparser from common.auth import check_admin_key, check_api_key, load_auth_keys from common.config import ( + get_sampling_config, override_config_from_args, read_config_from_file, get_gen_logging_config, @@ -25,6 +26,7 @@ get_network_config, ) from common.generators import call_with_semaphore, generate_with_semaphore +from common.sampling import get_overrides_from_file from common.templating import get_all_templates, get_prompt_from_template from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap from common.logger import init_logger @@ -522,6 +524,12 @@ def entrypoint(args: Optional[dict] = None): gen_logging.broadcast_status() + # Set sampler parameter overrides if provided + sampling_config = get_sampling_config() + sampling_override_preset = sampling_config.get("override_preset") + if sampling_override_preset: + get_overrides_from_file(sampling_override_preset) + # If an initial model name is specified, create a container # and load the model model_config = get_model_config() diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml new file mode 100644 index 00000000..9c661a14 --- /dev/null +++ b/sampler_overrides/sample_preset.yml @@ -0,0 +1,94 @@ +# Sample YAML file for override presets. +# Each block corresponds to a sampler fallback override. Remove ones that you don't need. +# "force" always overrides the sampler to the specified value. +# For example, a top_p override of 1.5 with force = true will make every API request have a top_p value of 1.5 + +# You can use https://www.yamllint.com/ if you want to check your YAML formatting.
+ +# TODO: Improve documentation for each field + +# MARK: Misc generation parameters +max_tokens: + override: 150 + force: false +stop: + override: [] + force: false +token_healing: + override: false + force: false + +# MARK: Temperature +temperature: + override: 1.0 + force: false +temperature_last: + override: false + force: false + +# MARK: Alphabet soup +top_k: + override: 0 + force: false +top_p: + override: 1.0 + force: false +top_a: + override: 0.0 + force: false +min_p: + override: 0.0 + force: false +tfs: + override: 0.0 + force: false +typical: + override: 1.0 + force: false + +# MARK: Penalty settings +frequency_penalty: + override: 0.0 + force: false +presence_penalty: + override: 0.0 + force: false +repetition_penalty: + override: 1.0 + force: false +repetition_decay: + override: 0 + force: false +penalty_range: + override: -1 + force: false + +# MARK: Mirostat +mirostat_mode: + override: 0 + force: false +mirostat_tau: + override: 1.5 + force: false +mirostat_eta: + override: 0.3 + force: false + +# MARK: Token options +add_bos_token: + override: true + force: false +ban_eos_token: + override: false + force: false +logit_bias: + override: + force: false + +# MARK: CFG scale +cfg_scale: + override: 1.0 + force: false +negative_prompt: + override: + force: false From 4cf231d85ea9717d6567cdd17d26e75ade388b40 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 22 Jan 2024 23:13:52 -0500 Subject: [PATCH 3/7] API: Add template switching and unload endpoints Templates can be switched and unloaded without reloading the entire model. Signed-off-by: kingbri --- OAI/types/template.py | 6 ++++++ backends/exllamav2/model.py | 31 ++++++++++++++----------------- common/templating.py | 15 +++++++++------ main.py | 34 ++++++++++++++++++++++++++++++++-- 4 files changed, 61 insertions(+), 25 deletions(-) diff --git a/OAI/types/template.py b/OAI/types/template.py index 03745473..d72d6210 100644 --- a/OAI/types/template.py +++ b/OAI/types/template.py @@ -7,3 +7,9 @@ class TemplateList(BaseModel): object: str = "list" data: List[str] = Field(default_factory=list) + + +class TemplateSwitchRequest(BaseModel): + """Request to switch a template.""" + + name: str diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6ee186b4..0f91b052 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -163,30 +163,27 @@ def progress(loaded_modules: int, total_modules: int, if prompt_template_name: logger.info("Loading prompt template with name " f"{prompt_template_name}") # Read the template - self.prompt_template = get_template_from_file(prompt_template_name) - else: - # Then try finding the template from the tokenizer_config.json - self.prompt_template = get_template_from_model_json( - pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - "chat_template", - "from_tokenizer_config", - ) + try: + self.prompt_template = get_template_from_file(prompt_template_name) + except FileNotFoundError: + self.prompt_template = None - # Try finding the chat template from the model's config.json - # TODO: This may not even be used with huggingface models, - # mark for removal. 
- if self.prompt_template is None: + # Then try finding the template from the tokenizer_config.json + try: self.prompt_template = get_template_from_model_json( - pathlib.Path(self.config.model_config), + pathlib.Path(self.config.model_dir) / "tokenizer_config.json", "chat_template", - "from_model_config", + "from_tokenizer_config", ) + except FileNotFoundError: + self.prompt_template = None # If that fails, attempt fetching from model name - if self.prompt_template is None: + try: template_match = find_template_from_model(model_directory) - if template_match: - self.prompt_template = get_template_from_file(template_match) + self.prompt_template = get_template_from_file(template_match) + except (LookupError, FileNotFoundError): + self.prompt_template = None # Catch all for template lookup errors if self.prompt_template: diff --git a/common/templating.py b/common/templating.py index ddc0ca1b..bbd0596b 100644 --- a/common/templating.py +++ b/common/templating.py @@ -68,26 +68,29 @@ def find_template_from_model(model_path: pathlib.Path): """Find a matching template name from a model path.""" model_name = model_path.name template_files = get_all_templates() + for filepath in template_files: template_name = filepath.stem.lower() # Check if the template name is present in the model name if template_name in model_name.lower(): return template_name - - return None + else: + raise LookupError("Could not find template from model name.") def get_template_from_file(prompt_template_name: str): """Get a template from a jinja file.""" + template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") if template_path.exists(): with open(template_path, "r", encoding="utf8") as raw_template: return PromptTemplate( name=prompt_template_name, template=raw_template.read() ) - - return None + else: + # Let the user know if the template file isn't found + raise FileNotFoundError(f'Template "{prompt_template_name}" not found.') # Get a template from a JSON file @@ -100,5 +103,5 @@ def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str): chat_template = model_config.get(key) if chat_template: return PromptTemplate(name=name, template=chat_template) - - return None + else: + raise FileNotFoundError(f'Model JSON path "{json_path}" not found.') diff --git a/main.py b/main.py index 218d9c09..df0d13df 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,11 @@ ) from common.generators import call_with_semaphore, generate_with_semaphore from common.sampling import get_overrides_from_file -from common.templating import get_all_templates, get_prompt_from_template +from common.templating import ( + get_all_templates, + get_prompt_from_template, + get_template_from_file, +) from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap from common.logger import init_logger from OAI.types.completion import CompletionRequest @@ -39,7 +43,7 @@ ModelLoadResponse, ModelCardParameters, ) -from OAI.types.template import TemplateList +from OAI.types.template import TemplateList, TemplateSwitchRequest from OAI.types.token import ( TokenEncodeRequest, TokenEncodeResponse, @@ -258,6 +262,32 @@ async def get_templates(): return TemplateList(data=template_strings) +@app.post( + "/v1/template/switch", + dependencies=[Depends(check_admin_key), Depends(_check_model_container)], +) +async def switch_template(data: TemplateSwitchRequest): + """Switch the currently loaded template""" + if not data.name: + raise HTTPException(400, "New template name not found.") + + try: + template = 
get_template_from_file(data.name) + MODEL_CONTAINER.prompt_template = template + except FileNotFoundError as e: + raise HTTPException(400, "Template does not exist. Check the name?") from e + + +@app.post( + "/v1/template/unload", + dependencies=[Depends(check_admin_key), Depends(_check_model_container)], +) +async def unload_template(): + """Unloads the currently selected template""" + + MODEL_CONTAINER.prompt_template = None + + # Lora list endpoint @app.get("/v1/loras", dependencies=[Depends(check_api_key)]) @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) From a9a128cbefd1232dfb5e2de94d3f06923b4c9bbd Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jan 2024 01:20:58 -0500 Subject: [PATCH 4/7] API: Add sampler override switching Allow users to switch the currently overridden samplers via the API so a restart isn't required to switch the overrides. Signed-off-by: kingbri --- OAI/types/sampler_overrides.py | 26 +++++++++++++++++ common/sampling.py | 14 ++++++--- main.py | 53 ++++++++++++++++++++++++++++++++-- 3 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 OAI/types/sampler_overrides.py diff --git a/OAI/types/sampler_overrides.py b/OAI/types/sampler_overrides.py new file mode 100644 index 00000000..c4b75d82 --- /dev/null +++ b/OAI/types/sampler_overrides.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class SamplerOverrideSwitchRequest(BaseModel): + """Sampler override switch request""" + + preset: Optional[str] = Field( + default=None, description="Pass a sampler override preset name" + ) + + overrides: Optional[dict] = Field( + default=None, + description=( + "Sampling override parent takes in individual keys and overrides. " + + "Ignored if preset is provided." + ), + examples=[ + { + "top_p": { + "override": 1.5, + "force": False, + } + } + ], + ) diff --git a/common/sampling.py b/common/sampling.py index 01d11f10..53defcc1 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -166,6 +166,10 @@ def to_gen_params(self): DEFAULT_OVERRIDES = {} +def get_sampler_overrides(): + return DEFAULT_OVERRIDES + + def set_overrides_from_dict(new_overrides: dict): """Wrapper function to update sampler overrides""" @@ -174,10 +178,10 @@ def set_overrides_from_dict(new_overrides: dict): if isinstance(new_overrides, dict): DEFAULT_OVERRIDES = new_overrides else: - raise TypeError("new sampler overrides must be a dict!") + raise TypeError("New sampler overrides must be a dict!") -def get_overrides_from_file(preset_name: str): +def set_overrides_from_file(preset_name: str): """Fetches an override preset from a file""" preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") if preset_path.exists(): with open(preset_path, "r", encoding="utf8") as raw_preset: preset = yaml.safe_load(raw_preset) set_overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") else: - logger.warn( - f"Sampler override file named \"{preset_name}\" was not found. " + error_message = ( + f'Sampler override file named "{preset_name}" was not found. ' + "Make sure it's located in the sampler_overrides folder."
) + raise FileNotFoundError(error_message) + # TODO: Maybe move these into the class # Classmethods aren't recognized in pydantic default_factories diff --git a/main.py b/main.py index df0d13df..cbd8f600 100644 --- a/main.py +++ b/main.py @@ -26,7 +26,11 @@ get_network_config, ) from common.generators import call_with_semaphore, generate_with_semaphore -from common.sampling import get_overrides_from_file +from common.sampling import ( + get_sampler_overrides, + set_overrides_from_file, + set_overrides_from_dict, +) from common.templating import ( get_all_templates, get_prompt_from_template, @@ -43,6 +47,7 @@ ModelLoadResponse, ModelCardParameters, ) +from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest from OAI.types.template import TemplateList, TemplateSwitchRequest from OAI.types.token import ( TokenEncodeRequest, @@ -288,6 +293,47 @@ async def unload_template(): MODEL_CONTAINER.prompt_template = None +# Sampler override endpoints +@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) +@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +async def list_sampler_overrides(): + """API wrapper to list all currently applied sampler overrides""" + + return get_sampler_overrides() + + +@app.post( + "/v1/sampling/override/switch", + dependencies=[Depends(check_admin_key)], +) +async def switch_sampler_override(data: SamplerOverrideSwitchRequest): + """Switch the currently loaded override preset""" + + if data.preset: + try: + set_overrides_from_file(data.preset) + except FileNotFoundError as e: + raise HTTPException( + 400, "Sampler override preset does not exist. Check the name?" + ) from e + elif data.overrides: + set_overrides_from_dict(data.overrides) + else: + raise HTTPException( + 400, "A sampler override preset or dictionary wasn't provided." + ) + + +@app.post( + "/v1/sampling/override/unload", + dependencies=[Depends(check_admin_key)], +) +async def unload_sampler_override(): + """Unloads the currently selected override preset""" + + set_overrides_from_dict({}) + + # Lora list endpoint @app.get("/v1/loras", dependencies=[Depends(check_api_key)]) @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) @@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None): sampling_config = get_sampling_config() sampling_override_preset = sampling_config.get("override_preset") if sampling_override_preset: - get_overrides_from_file(sampling_override_preset) + try: + set_overrides_from_file(sampling_override_preset) + except FileNotFoundError as e: + logger.warning(str(e)) # If an initial model name is specified, create a container # and load the model From 243acfe70975aa33c35de6d52a5bc4c880584e40 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jan 2024 01:26:38 -0500 Subject: [PATCH 5/7] Model: Dynamically scale generate_window Allows for adjustment of reservation space at the end of the context before rolling it. This should be scaled as a model's max_seq_len goes up. 
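For example, with the new calculation in the diff below, the window becomes max(generate_window, max_tokens // 8) with generate_window defaulting to 512: a request with max_tokens = 8192 now reserves a 1024 token window at the end of the context, where the old min() clamp kept it at 512.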
Signed-off-by: kingbri --- backends/exllamav2/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 0f91b052..d0c9f442 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -554,7 +554,9 @@ def generate_gen(self, prompt: str, **kwargs): token_healing = unwrap(kwargs.get("token_healing"), False) max_tokens = unwrap(kwargs.get("max_tokens"), 150) stream_interval = unwrap(kwargs.get("stream_interval"), 0) - generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens) + generate_window = max( + unwrap(kwargs.get("generate_window"), 512), max_tokens // 8 + ) # Sampler settings gen_settings = ExLlamaV2Sampler.Settings() From 97d4f40e13f74162ae7a501d5fed4b61b6779cc7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 24 Jan 2024 23:36:35 -0500 Subject: [PATCH 6/7] Model: Fix prompt template initialization The previous commit iterated through multiple try conditions which made it so the user has to provide a dummy prompt template. Now, template loading is fallback based. Run through a loop of functions and return if one of them succeeds. Signed-off-by: kingbri --- backends/exllamav2/model.py | 58 ++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index d0c9f442..ac939d44 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -158,32 +158,10 @@ def progress(loaded_modules: int, total_modules: int, self.config.set_low_mem() """ - # Set prompt template override if provided - prompt_template_name = kwargs.get("prompt_template") - if prompt_template_name: - logger.info("Loading prompt template with name " f"{prompt_template_name}") - # Read the template - try: - self.prompt_template = get_template_from_file(prompt_template_name) - except FileNotFoundError: - self.prompt_template = None - - # Then try finding the template from the tokenizer_config.json - try: - self.prompt_template = get_template_from_model_json( - pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - "chat_template", - "from_tokenizer_config", - ) - except FileNotFoundError: - self.prompt_template = None - - # If that fails, attempt fetching from model name - try: - template_match = find_template_from_model(model_directory) - self.prompt_template = get_template_from_file(template_match) - except (LookupError, FileNotFoundError): - self.prompt_template = None + # Try to set prompt template + self.prompt_template = self.find_prompt_template( + kwargs.get("prompt_template"), model_directory + ) # Catch all for template lookup errors if self.prompt_template: @@ -250,6 +228,34 @@ def progress(loaded_modules: int, total_modules: int, self.draft_config.max_input_len = kwargs["chunk_size"] self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2 + def find_prompt_template(self, prompt_template_name, model_directory): + """Tries to find a prompt template using various methods""" + + logger.info("Loading prompt template with name " f"{prompt_template_name}") + + find_template_functions = [ + lambda: get_template_from_model_json( + pathlib.Path(self.config.model_dir) / "tokenizer_config.json", + "chat_template", + "from_tokenizer_config", + ), + lambda: get_template_from_file(find_template_from_model(model_directory)), + ] + + # Add lookup from prompt template name if provided + if prompt_template_name: + find_template_functions.insert( + 0, lambda: 
get_template_from_file(prompt_template_name) + ) + + for func in find_template_functions: + try: + prompt_template = func() + if prompt_template is not None: + return prompt_template + except (FileNotFoundError, LookupError): + continue + def calculate_rope_alpha(self, base_seq_len): """Calculate the rope alpha value for a given sequence length.""" ratio = self.config.max_seq_len / base_seq_len From 52918fd51a1452093cf90fc719132ab2b5506731 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 25 Jan 2024 00:11:30 -0500 Subject: [PATCH 7/7] API + Model: Add new parameters and clean up documentation The example JSON fields broke because of the new sampler default strategy. Fix them by manually setting the example values. Also add support for fasttensors and expose generate_window to the API. It's recommended not to adjust generate_window, as it's dynamically scaled based on the request's max_tokens by default. Signed-off-by: kingbri --- backends/exllamav2/model.py | 17 ++++++++++++++-- common/sampling.py | 30 +++++++++++++++++++++------- config_sample.yml | 3 +++ sampler_overrides/sample_preset.yml | 5 +++++ 4 files changed, 45 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ac939d44..52764e22 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -138,13 +138,25 @@ def progress(loaded_modules: int, total_modules: int, kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len) ) + # Enable CFG if present + use_cfg = unwrap(kwargs.get("use_cfg"), False) if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"): - self.use_cfg = unwrap(kwargs.get("use_cfg"), False) - else: + self.use_cfg = use_cfg + elif use_cfg: logger.warning( "CFG is not supported by the currently installed ExLlamaV2 version." ) + # Enable fasttensors loading if present + use_fasttensors = unwrap(kwargs.get("fasttensors"), False) + if hasattr(ExLlamaV2Config, "fasttensors"): + self.config.fasttensors = use_fasttensors + elif use_fasttensors: + logger.warning( + "fasttensors is not supported by " + "the currently installed ExLlamaV2 version.
+ ) + # Turn off flash attention if CFG is on # Workaround until batched FA2 is fixed in exllamav2 upstream self.config.no_flash_attn = ( @@ -668,6 +680,7 @@ def generate_gen(self, prompt: str, **kwargs): **vars(gen_settings), token_healing=token_healing, auto_scale_penalty_range=auto_scale_penalty_range, + generate_window=generate_window, add_bos_token=add_bos_token, ban_eos_token=ban_eos_token, stop_conditions=stop_conditions, diff --git a/common/sampling.py b/common/sampling.py index 53defcc1..8c28002d 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -17,7 +17,13 @@ class SamplerParams(BaseModel): """Common class for sampler params that are used in APIs""" max_tokens: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("max_tokens", 150) + default_factory=lambda: get_default_sampler_value("max_tokens", 150), + examples=[150], + ) + + generate_window: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("generate_window"), + examples=[512], ) stop: Optional[Union[str, List[str]]] = Field( @@ -29,7 +35,8 @@ class SamplerParams(BaseModel): ) temperature: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("temperature", 1.0) + default_factory=lambda: get_default_sampler_value("temperature", 1.0), + examples=[1.0], ) temperature_last: Optional[bool] = Field( @@ -41,7 +48,7 @@ class SamplerParams(BaseModel): ) top_p: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("top_p", 1.0) + default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0] ) top_a: Optional[float] = Field( @@ -65,7 +72,8 @@ class SamplerParams(BaseModel): ) repetition_penalty: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0) + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + examples=[1.0], ) repetition_decay: Optional[int] = Field( @@ -77,11 +85,13 @@ class SamplerParams(BaseModel): ) mirostat_tau: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5) + default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5), + examples=[1.5], ) mirostat_eta: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3) + default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3), + examples=[0.3], ) add_bos_token: Optional[bool] = Field( @@ -89,7 +99,8 @@ class SamplerParams(BaseModel): ) ban_eos_token: Optional[bool] = Field( - default_factory=lambda: get_default_sampler_value("ban_eos_token", False) + default_factory=lambda: get_default_sampler_value("ban_eos_token", False), + examples=[False], ) logit_bias: Optional[Dict[int, float]] = Field( @@ -106,6 +117,7 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("typical", 1.0), validation_alias=AliasChoices("typical", "typical_p"), description="Aliases: typical_p", + examples=[1.0], ) penalty_range: Optional[int] = Field( @@ -122,6 +134,7 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), description="Aliases: guidance_scale", + examples=[1.0], ) def to_gen_params(self): @@ -135,8 +148,9 @@ def to_gen_params(self): self.stop = [self.stop] return { - "stop": self.stop, "max_tokens": self.max_tokens, + "generate_window": self.generate_window, + "stop": self.stop, "add_bos_token": self.add_bos_token, "ban_eos_token": 
self.ban_eos_token, "token_healing": self.token_healing, diff --git a/config_sample.yml b/config_sample.yml index 89368acf..cf1ddb53 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -97,6 +97,9 @@ model: # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream) #use_cfg: False + # Enables fasttensors to possibly increase model loading speeds (default: False) + #fasttensors: true + # Options for draft models (speculative decoding). This will use more VRAM! #draft: # Overrides the directory to look for draft (default: models) diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index 9c661a14..eae17ab4 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -18,6 +18,11 @@ token_healing: override: false force: false +# Commented out because the default is dynamically scaled +#generate_window: + #override: 512 + #force: false + # MARK: Temperature temperature: override: 1.0