diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 0f9ba3a2..b1eb127b 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -55,7 +55,7 @@
     TemplateLoadError,
     find_template_from_model,
 )
-from common.transformers_utils import GenerationConfig, HuggingFaceConfig
+from common.transformers_utils import GenerationConfig
 from common.utils import coalesce, unwrap
 
 
@@ -84,7 +84,6 @@ class ExllamaV2Container:
     draft_cache_mode: str = "FP16"
     max_batch_size: Optional[int] = None
     generation_config: Optional[GenerationConfig] = None
-    hf_config: Optional[HuggingFaceConfig] = None
 
     # GPU split vars
     gpu_split: Optional[list] = None
@@ -129,9 +128,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
 
-        # Create the hf_config
-        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
-
         # Load generation config overrides
         generation_config_path = model_directory / "generation_config.json"
         if generation_config_path.exists():
diff --git a/common/sampling.py b/common/sampling.py
index 8172067f..c2b4c3c7 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -3,10 +3,16 @@
 import aiofiles
 import json
 import pathlib
 from ruamel.yaml import YAML
 from copy import deepcopy
 from loguru import logger
-from pydantic import AliasChoices, BaseModel, Field
+from pydantic import (
+    AliasChoices,
+    BaseModel,
+    Field,
+    field_validator,
+    model_validator,
+)
 from typing import Dict, List, Optional, Union
 
 from common.utils import filter_none_values, unwrap
@@ -21,6 +28,7 @@ class BaseSamplerRequest(BaseModel):
         validation_alias=AliasChoices("max_tokens", "max_length"),
         description="Aliases: max_length",
         examples=[150],
+        ge=0,
     )
 
     min_tokens: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("min_tokens", 0),
         validation_alias=AliasChoices("min_tokens", "min_length"),
         description="Aliases: min_length",
         examples=[0],
+        ge=0,
     )
 
     generate_window: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("generate_window"),
         examples=[512],
+        ge=0,
     )
 
     stop: Optional[Union[str, List[Union[str, int]]]] = Field(
@@ -66,22 +76,28 @@
     temperature: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("temperature", 1.0),
         examples=[1.0],
+        ge=0,
+        le=10,
     )
 
     temperature_last: Optional[bool] = Field(
-        default_factory=lambda: get_default_sampler_value("temperature_last", False)
+        default_factory=lambda: get_default_sampler_value("temperature_last", False),
     )
 
     smoothing_factor: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
+        ge=0,
    )
 
     top_k: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("top_k", 0),
+        ge=-1,
     )
 
     top_p: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("top_p", 1.0),
+        ge=0,
+        le=1,
         examples=[1.0],
     )
@@ -103,6 +119,8 @@
         validation_alias=AliasChoices("typical", "typical_p"),
         description="Aliases: typical_p",
         examples=[1.0],
+        gt=0,
+        le=1,
     )
 
     skew: Optional[float] = Field(
@@ -119,11 +137,13 @@
     )
 
     frequency_penalty: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
+        default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0),
+        ge=0,
     )
 
     presence_penalty: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
+        default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0),
+        ge=0,
     )
 
     repetition_penalty: Optional[float] = Field(
@@ -131,6 +151,7 @@
         validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
         description="Aliases: rep_pen",
         examples=[1.0],
+        gt=0,
     )
 
     penalty_range: Optional[int] = Field(
@@ -164,7 +185,7 @@
     dry_range: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("dry_range", 0),
-        alias=AliasChoices("dry_range", "dry_penalty_last_n"),
+        validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"),
         description=("Aliases: dry_penalty_last_n"),
     )
@@ -172,6 +193,8 @@
         default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
     )
 
+    mirostat: Optional[bool] = False
+
     mirostat_mode: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
     )
@@ -239,6 +262,7 @@
         validation_alias=AliasChoices("max_temp", "dynatemp_high"),
         description="Aliases: dynatemp_high",
         examples=[1.0],
+        ge=0,
     )
 
     min_temp: Optional[float] = Field(
@@ -246,158 +270,82 @@
         validation_alias=AliasChoices("min_temp", "dynatemp_low"),
         description="Aliases: dynatemp_low",
         examples=[1.0],
+        ge=0,
     )
 
     temp_exponent: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
         validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
         examples=[1.0],
+        ge=0,
     )
 
-    # TODO: Return back to adaptable class-based validation But that's just too much
-    # abstraction compared to simple if statements at the moment
-    def validate_params(self):
-        """
-        Validates sampler parameters to be within sane ranges.
-        """
+    @field_validator("top_k", mode="before")
+    def convert_top_k(cls, v):
+        """Convert a top_k of -1 to 0 (both values disable top-k)."""
 
-        # Temperature
-        if self.temperature < 0.0:
-            raise ValueError(
-                "Temperature must be a non-negative value. " f"Got {self.temperature}"
-            )
+        if v == -1:
+            logger.warning("Provided a top-k value of -1. Converting to 0 instead.")
+            return 0
 
-        # Smoothing factor
-        if self.smoothing_factor < 0.0:
-            raise ValueError(
-                "Smoothing factor must be a non-negative value. "
-                f"Got {self.smoothing_factor}"
-            )
+        return v
 
-        # Top K
-        if self.top_k < 0:
-            raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")
+    @field_validator("stop", "banned_strings", mode="before")
+    def convert_str_to_list(cls, v):
+        """Convert single string to list of strings."""
 
-        # Top P
-        if self.top_p < 0.0 or self.top_p > 1.0:
-            raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")
+        if isinstance(v, str):
+            return [v]
 
-        # Repetition Penalty
-        if self.repetition_penalty <= 0.0:
-            raise ValueError(
-                "Repetition penalty must be a positive value. "
-                f"Got {self.repetition_penalty}"
-            )
+        return v
 
-        # Typical
-        if self.typical <= 0 and self.typical > 1:
-            raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")
+    @field_validator("banned_tokens", "allowed_tokens", mode="before")
+    def convert_tokens_to_int_list(cls, v):
+        """Convert comma-separated string of numbers to a list of integers."""
 
-        # Dynatemp values
-        if self.max_temp < 0.0:
-            raise ValueError(
-                "Max temp must be a non-negative value. ", f"Got {self.max_temp}"
", f"Got {self.max_temp}" - ) + if isinstance(v, str): + return [int(x) for x in v.replace(" ", "").split(",") if x.isdigit()] - if self.min_temp < 0.0: - raise ValueError( - "Min temp must be a non-negative value. ", f"Got {self.min_temp}" - ) + return v - if self.temp_exponent < 0.0: - raise ValueError( - "Temp exponent must be a non-negative value. ", - f"Got {self.temp_exponent}", + @field_validator("dry_sequence_breakers", mode="before") + def parse_json_if_needed(cls, v): + """Parse dry_sequence_breakers string to JSON array.""" + + if isinstance(v, str) and not v.startswith("["): + v = f"[{v}]" + + try: + return json.loads(v) if isinstance(v, str) else v + except Exception: + logger.warning( + "Could not parse DRY sequence breakers. Using an empty array." ) + return [] # Return empty list if parsing fails + + @field_validator("mirostat_mode", mode="before") + def convert_mirostat(cls, v, field_info): + """Mirostat is enabled if mirostat_mode == 2.""" - def to_gen_params(self, **kwargs): - """Converts samplers to internal generation params""" + if v == 2: + field_info.data["mirostat"] = True - # Add forced overrides if present + return v + + @model_validator(mode="after") + def after_validate(self): + # FIXME: find a better way to register this + # Maybe make a function to assign values to the + # model if they do not exist post creation apply_forced_sampler_overrides(self) - self.validate_params() - - # Convert stop to an array of strings - if self.stop and isinstance(self.stop, str): - self.stop = [self.stop] - - # Convert banned_strings to an array of strings - if self.banned_strings and isinstance(self.banned_strings, str): - self.banned_strings = [self.banned_strings] - - # Convert string banned and allowed tokens to an integer list - if self.banned_tokens and isinstance(self.banned_tokens, str): - self.banned_tokens = [ - int(x) for x in self.banned_tokens.split(",") if x.isdigit() - ] - - if self.allowed_tokens and isinstance(self.allowed_tokens, str): - self.allowed_tokens = [ - int(x) for x in self.allowed_tokens.split(",") if x.isdigit() - ] - - # Convert sequence breakers into an array of strings - # NOTE: This sampler sucks to parse. 
-        if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
-            if not self.dry_sequence_breakers.startswith("["):
-                self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
-
-            try:
-                self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
-            except Exception:
-                self.dry_sequence_breakers = []
-
-        gen_params = {
-            "max_tokens": self.max_tokens,
-            "min_tokens": self.min_tokens,
-            "generate_window": self.generate_window,
-            "stop": self.stop,
-            "banned_strings": self.banned_strings,
-            "add_bos_token": self.add_bos_token,
-            "ban_eos_token": self.ban_eos_token,
-            "skip_special_tokens": self.skip_special_tokens,
-            "token_healing": self.token_healing,
-            "logit_bias": self.logit_bias,
-            "banned_tokens": self.banned_tokens,
-            "allowed_tokens": self.allowed_tokens,
-            "temperature": self.temperature,
-            "temperature_last": self.temperature_last,
-            "min_temp": self.min_temp,
-            "max_temp": self.max_temp,
-            "temp_exponent": self.temp_exponent,
-            "smoothing_factor": self.smoothing_factor,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "top_a": self.top_a,
-            "typical": self.typical,
-            "min_p": self.min_p,
-            "tfs": self.tfs,
-            "skew": self.skew,
-            "xtc_probability": self.xtc_probability,
-            "xtc_threshold": self.xtc_threshold,
-            "frequency_penalty": self.frequency_penalty,
-            "presence_penalty": self.presence_penalty,
-            "repetition_penalty": self.repetition_penalty,
-            "penalty_range": self.penalty_range,
-            "dry_multiplier": self.dry_multiplier,
-            "dry_base": self.dry_base,
-            "dry_allowed_length": self.dry_allowed_length,
-            "dry_sequence_breakers": self.dry_sequence_breakers,
-            "dry_range": self.dry_range,
-            "repetition_decay": self.repetition_decay,
-            "mirostat": self.mirostat_mode == 2,
-            "mirostat_tau": self.mirostat_tau,
-            "mirostat_eta": self.mirostat_eta,
-            "cfg_scale": self.cfg_scale,
-            "negative_prompt": self.negative_prompt,
-            "json_schema": self.json_schema,
-            "regex_pattern": self.regex_pattern,
-            "grammar_string": self.grammar_string,
-            "speculative_ngram": self.speculative_ngram,
-        }
-
-        return {**gen_params, **kwargs}
+        if self.min_temp and self.max_temp and self.min_temp > self.max_temp:
+            raise ValueError("min_temp cannot be greater than max_temp")
+
+        if self.min_tokens and self.max_tokens and self.min_tokens > self.max_tokens:
+            raise ValueError("min_tokens cannot be greater than max_tokens")
+
+        return self
 
 
 class SamplerOverridesContainer(BaseModel):
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index c00fef41..a765b9f8 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -2,7 +2,6 @@
 import json
 import pathlib
 from typing import List, Optional, Union
 
-from loguru import logger
 from pydantic import BaseModel
@@ -13,7 +12,6 @@ class GenerationConfig(BaseModel):
     """
 
     eos_token_id: Optional[Union[int, List[int]]] = None
-    bad_words_ids: Optional[List[List[int]]] = None
 
     @classmethod
     async def from_file(cls, model_directory: pathlib.Path):
@@ -38,12 +36,12 @@ def eos_tokens(self):
 
 
 class HuggingFaceConfig(BaseModel):
     """
+    DEPRECATED: Currently a stub and doesn't do anything.
+
     An abridged version of HuggingFace's model config.
     Will be expanded as needed.
""" - badwordsids: Optional[str] = None - @classmethod async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" @@ -55,20 +53,3 @@ async def from_file(cls, model_directory: pathlib.Path): contents = await hf_config_json.read() hf_config_dict = json.loads(contents) return cls.model_validate(hf_config_dict) - - def get_badwordsids(self): - """Wrapper method to fetch badwordsids.""" - - if self.badwordsids: - try: - bad_words_list = json.loads(self.badwordsids) - return bad_words_list - except json.JSONDecodeError: - logger.warning( - "Skipping badwordsids from config.json " - "since it's not a valid array." - ) - - return [] - else: - return [] diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 310a3809..ea894ead 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -137,7 +137,7 @@ async def get_version(): async def get_extra_version(): """Impersonate Koboldcpp.""" - return {"result": "KoboldCpp", "version": "1.71"} + return {"result": "KoboldCpp", "version": "1.74"} @kai_router.get("/config/soft_prompts_list") diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 310484b4..5432130b 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -1,9 +1,9 @@ +from functools import partial +from pydantic import BaseModel, Field, field_validator from typing import List, Optional -from pydantic import BaseModel, Field -from common import model from common.sampling import BaseSamplerRequest, get_default_sampler_value -from common.utils import flat_map, unwrap +from common.utils import unwrap class GenerateRequest(BaseSamplerRequest): @@ -11,29 +11,31 @@ class GenerateRequest(BaseSamplerRequest): genkey: Optional[str] = None use_default_badwordsids: Optional[bool] = False dynatemp_range: Optional[float] = Field( - default_factory=get_default_sampler_value("dynatemp_range") + default_factory=partial(get_default_sampler_value, "dynatemp_range") ) - def to_gen_params(self, **kwargs): - # Exl2 uses -1 to include all tokens in repetition penalty - if self.penalty_range == 0: - self.penalty_range = -1 + # Validate on the parent class's fields + @field_validator("penalty_range", mode="before") + def validate_penalty_range(cls, v): + return -1 if v == 0 else v - if self.dynatemp_range: - self.min_temp = self.temperature - self.dynatemp_range - self.max_temp = self.temperature + self.dynatemp_range + @field_validator("dynatemp_range", mode="before") + def validate_temp_range(cls, v, field_info): + if v > 0: + # A default temperature is always 1 + temperature = unwrap(field_info.data.get("temperature"), 1) - # Move badwordsids into banned tokens for generation - if self.use_default_badwordsids: - bad_words_ids = unwrap( - model.container.generation_config.bad_words_ids, - model.container.hf_config.get_badwordsids(), - ) + field_info.data["min_temp"] = temperature - v + field_info.data["max_temp"] = temperature + v - if bad_words_ids: - self.banned_tokens += flat_map(bad_words_ids) + return v - return super().to_gen_params(**kwargs) + # Currently only serves to ban EOS token, but can change + @field_validator("use_default_badwordsids", mode="before") + def validate_badwordsids(cls, v, field_info): + field_info.data["ban_eos_token"] = v + + return v class GenerateResponseResult(BaseModel): diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py index 5febcffe..8ffbf0d8 100644 --- 
--- a/endpoints/Kobold/utils/generation.py
+++ b/endpoints/Kobold/utils/generation.py
@@ -53,7 +53,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
     logger.info(f"Received Kobold generation request {data.genkey}")
 
     generator = model.container.generate_gen(
-        data.prompt, data.genkey, abort_event, **data.to_gen_params()
+        request_id=data.genkey, abort_event=abort_event, **data.model_dump()
     )
     async for generation in generator:
         if disconnect_task.done():
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index 6970adf7..640ead7b 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -33,12 +33,16 @@ class CommonCompletionRequest(BaseSamplerRequest):
     stream: Optional[bool] = False
     stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("logprobs", 0)
+        default_factory=lambda: get_default_sampler_value("logprobs", 0),
+        ge=0,
     )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )
-    n: Optional[int] = Field(default_factory=lambda: get_default_sampler_value("n", 1))
+    n: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("n", 1),
+        ge=1,
+    )
 
     # Extra OAI request stuff
     best_of: Optional[int] = Field(
@@ -53,18 +57,3 @@ class CommonCompletionRequest(BaseSamplerRequest):
     user: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
-
-    def validate_params(self):
-        # Temperature
-        if self.n < 1:
-            raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
-
-        return super().validate_params()
-
-    def to_gen_params(self):
-        extra_gen_params = {
-            "stream": self.stream,
-            "logprobs": self.logprobs,
-        }
-
-        return super().to_gen_params(**extra_gen_params)
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 32e06969..3b5c07ff 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -3,7 +3,6 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from copy import deepcopy
 from typing import List, Optional
 
 import json
@@ -291,13 +290,8 @@
     try:
         logger.info(f"Received chat completion streaming request {request.state.id}")
 
-        gen_params = data.to_gen_params()
-
         for n in range(0, data.n):
-            if n > 0:
-                task_gen_params = deepcopy(gen_params)
-            else:
-                task_gen_params = gen_params
+            task_gen_params = data.model_copy(deep=True)
 
             gen_task = asyncio.create_task(
                 _stream_collector(
@@ -306,7 +300,7 @@
                     prompt,
                     request.state.id,
                     abort_event,
-                    **task_gen_params,
+                    **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )
@@ -381,21 +375,13 @@ async def generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     gen_tasks: List[asyncio.Task] = []
-    gen_params = data.to_gen_params()
 
     try:
-        for n in range(0, data.n):
-            # Deepcopy gen params above the first index
-            # to ensure nested structures aren't shared
-            if n > 0:
-                task_gen_params = deepcopy(gen_params)
-            else:
-                task_gen_params = gen_params
-
+        for _ in range(0, data.n):
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        prompt, request.state.id, **task_gen_params
+                        prompt, request.state.id, **data.model_dump(exclude={"prompt"})
                     )
                 )
             )
@@ -433,9 +419,9 @@ async def generate_tool_calls(
     # Copy to make sure the parent JSON schema doesn't get modified
     # FIXME: May not be necessary depending on how the codebase evolves
-    tool_data = deepcopy(data)
+    tool_data = data.model_copy(deep=True)
     tool_data.json_schema = tool_data.tool_call_schema
-    gen_params = tool_data.to_gen_params()
+    gen_params = tool_data.model_dump()
 
     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 59f3844b..e9395255 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -7,7 +7,6 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from copy import deepcopy
 from fastapi import HTTPException, Request
 from typing import List, Union
@@ -169,13 +168,8 @@
     try:
         logger.info(f"Received streaming completion request {request.state.id}")
 
-        gen_params = data.to_gen_params()
-
         for n in range(0, data.n):
-            if n > 0:
-                task_gen_params = deepcopy(gen_params)
-            else:
-                task_gen_params = gen_params
+            task_gen_params = data.model_copy(deep=True)
 
             gen_task = asyncio.create_task(
                 _stream_collector(
@@ -184,7 +178,7 @@
                     data.prompt,
                     request.state.id,
                     abort_event,
-                    **task_gen_params,
+                    **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )
@@ -232,23 +226,19 @@ async def generate_completion(
     """Non-streaming generate for completions"""
 
     gen_tasks: List[asyncio.Task] = []
-    gen_params = data.to_gen_params()
 
     try:
         logger.info(f"Received completion request {request.state.id}")
 
-        for n in range(0, data.n):
-            # Deepcopy gen params above the first index
-            # to ensure nested structures aren't shared
-            if n > 0:
-                task_gen_params = deepcopy(gen_params)
-            else:
-                task_gen_params = gen_params
+        for _ in range(0, data.n):
+            task_gen_params = data.model_copy(deep=True)
 
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        data.prompt, request.state.id, **task_gen_params
+                        data.prompt,
+                        request.state.id,
+                        **task_gen_params.model_dump(exclude={"prompt"}),
                     )
                 )
             )
diff --git a/endpoints/server.py b/endpoints/server.py
index d6723a19..3555a5b4 100644
--- a/endpoints/server.py
+++ b/endpoints/server.py
@@ -36,27 +36,28 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
     )
 
     api_servers = config.network.api_servers
+    api_servers = (
+        api_servers
+        if api_servers
+        else [
+            "oai",
+        ]
+    )
 
     # Map for API id to server router
     router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
 
     # Include the OAI api by default
-    if api_servers:
-        for server in api_servers:
-            selected_server = router_mapping.get(server.lower())
-
-            if selected_server:
-                app.include_router(selected_server.setup())
-
-                logger.info(f"Starting {selected_server.api_name} API")
-                for path, url in selected_server.urls.items():
-                    formatted_url = url.format(host=host, port=port)
-                    logger.info(f"{path}: {formatted_url}")
-    else:
-        app.include_router(OAIRouter.setup())
-        for path, url in OAIRouter.urls.items():
-            formatted_url = url.format(host=host, port=port)
-            logger.info(f"{path}: {formatted_url}")
+    for server in api_servers:
+        selected_server = router_mapping.get(server.lower())
+
+        if selected_server:
+            app.include_router(selected_server.setup())
+
+            logger.info(f"Starting {selected_server.api_name} API")
+            for path, url in selected_server.urls.items():
+                formatted_url = url.format(host=host, port=port)
+                logger.info(f"{path}: {formatted_url}")
 
     # Include core API request paths
     app.include_router(CoreRouter)
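
Reviewer note — the sketch below is not part of the patch. It is a minimal, self-contained condensation of the Pydantic v2 pattern this PR migrates to: range checks from the old validate_params() become field constraints (ge/le/gt), ad-hoc input coercion from to_gen_params() becomes mode="before" field validators, cross-field checks move into a mode="after" model validator that raises ValueError (which Pydantic surfaces to callers as a ValidationError), and model_dump(exclude=...) replaces the hand-built gen_params dict. The SamplerParams model, its fields, and its defaults are illustrative stand-ins, not names from this codebase.

# Minimal sketch of the validation style adopted by this PR (Pydantic v2).
# SamplerParams and its fields are hypothetical examples, not the real
# BaseSamplerRequest model.
from typing import List, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator


class SamplerParams(BaseModel):
    # Range checks are declared as field constraints instead of if-statements
    max_tokens: Optional[int] = Field(default=None, ge=0)
    min_tokens: Optional[int] = Field(default=0, ge=0)
    temperature: Optional[float] = Field(default=1.0, ge=0, le=10)
    top_k: Optional[int] = Field(default=0, ge=-1)
    stop: Optional[Union[str, List[str]]] = None

    # "before" validators see the raw input, so they can coerce shapes
    # that the declared type would otherwise reject
    @field_validator("top_k", mode="before")
    def convert_top_k(cls, v):
        """Map the -1 'disabled' sentinel onto 0."""
        return 0 if v == -1 else v

    @field_validator("stop", mode="before")
    def convert_str_to_list(cls, v):
        """Wrap a bare string into a single-item list."""
        return [v] if isinstance(v, str) else v

    # "after" validators run on the constructed model, so cross-field checks
    # are plain attribute comparisons; raise ValueError here and Pydantic
    # wraps it into a ValidationError for the caller
    @model_validator(mode="after")
    def check_token_bounds(self):
        if self.min_tokens and self.max_tokens and self.min_tokens > self.max_tokens:
            raise ValueError("min_tokens cannot be greater than max_tokens")
        return self


params = SamplerParams(top_k=-1, stop="###", max_tokens=100)
print(params.top_k)   # 0
print(params.stop)    # ['###']
# model_dump() replaces the hand-built to_gen_params() dict; exclude mirrors
# the **data.model_dump(exclude={"prompt"}) call sites in this diff
print(params.model_dump(exclude={"stop"}))

One design consequence worth noting during review: because validation now happens at model construction time, malformed sampler values are rejected when the request body is parsed (producing a 422 from FastAPI) rather than later inside generation, and the deepcopy-per-task logic collapses into model_copy(deep=True).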