Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the sampling class #199

Merged
merged 23 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
88d4082
improve validation
SecretiveShell Sep 12, 2024
c5a06db
Merge branch 'theroyallab:main' into refactor-sampling
SecretiveShell Sep 20, 2024
035269c
remove to_gen_params functions
SecretiveShell Sep 21, 2024
cef9a03
update changes for all endpoint types
SecretiveShell Sep 21, 2024
4c17b87
Merge remote-tracking branch 'upstream/main' into refactor-sampling
SecretiveShell Sep 26, 2024
402898b
OAI: Fix calls to generation
bdashore3 Oct 25, 2024
a4e03ed
Merge branch 'main' of https://github.com/theroyallab/tabbyapi into s…
bdashore3 Oct 25, 2024
0936d1a
Sampling: Convert Top-K values of -1 to 0
bdashore3 Oct 25, 2024
f389178
Sampling: Format and space out
bdashore3 Oct 25, 2024
95f36e2
Sampling: Fix mirostat
bdashore3 Oct 26, 2024
dc08edc
Sampling: Format
bdashore3 Oct 26, 2024
e85706c
Sampling: Fix banned_tokens and allowed_tokens conversion
bdashore3 Oct 26, 2024
8a91c9b
Sampling: Add helpful log to dry_sequence_breakers
bdashore3 Oct 27, 2024
d8ad68d
Sampling: Apply validators in right order
bdashore3 Oct 27, 2024
4bce8d7
Endpoints: Format
bdashore3 Oct 27, 2024
9681e9c
Kobold: Update validators and fix parameter application
bdashore3 Oct 27, 2024
a749ea3
Sampling: Remove validate defaults and fix mirostat
bdashore3 Oct 27, 2024
8c24ee5
Kobold: Rework badwordsids
bdashore3 Oct 27, 2024
4ca8bc7
Model: Remove HuggingfaceConfig
bdashore3 Oct 27, 2024
e2b2947
Kobold: Bump kcpp impersonation
bdashore3 Oct 27, 2024
1dcb7d4
Sampling: Change alias to validation_alias
bdashore3 Oct 27, 2024
c7272f1
OAI: Use constraints for validation
bdashore3 Oct 27, 2024
a73a5a1
Tree: Lint
bdashore3 Oct 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
TemplateLoadError,
find_template_from_model,
)
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
from common.transformers_utils import GenerationConfig
from common.utils import coalesce, unwrap


Expand Down Expand Up @@ -84,7 +84,6 @@ class ExllamaV2Container:
draft_cache_mode: str = "FP16"
max_batch_size: Optional[int] = None
generation_config: Optional[GenerationConfig] = None
hf_config: Optional[HuggingFaceConfig] = None

# GPU split vars
gpu_split: Optional[list] = None
Expand Down Expand Up @@ -129,9 +128,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
Expand Down
218 changes: 83 additions & 135 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
import aiofiles
import json
import pathlib
from pydantic_core import ValidationError
from ruamel.yaml import YAML
from copy import deepcopy
from loguru import logger
from pydantic import AliasChoices, BaseModel, Field
from pydantic import (
AliasChoices,
BaseModel,
Field,
field_validator,
model_validator,
)
from typing import Dict, List, Optional, Union

from common.utils import filter_none_values, unwrap
Expand All @@ -21,18 +28,21 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("max_tokens", "max_length"),
description="Aliases: max_length",
examples=[150],
ge=0,
)

min_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
validation_alias=AliasChoices("min_tokens", "min_length"),
description="Aliases: min_length",
examples=[0],
ge=0,
)

generate_window: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("generate_window"),
examples=[512],
ge=0,
)

stop: Optional[Union[str, List[Union[str, int]]]] = Field(
Expand Down Expand Up @@ -66,22 +76,28 @@ class BaseSamplerRequest(BaseModel):
temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
ge=0,
le=10,
)

temperature_last: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("temperature_last", False)
default_factory=lambda: get_default_sampler_value("temperature_last", False),
)

smoothing_factor: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
ge=0,
)

top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0),
ge=-1,
)

top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
ge=0,
le=1,
examples=[1.0],
)

Expand All @@ -103,6 +119,8 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
gt=0,
le=1,
)

skew: Optional[float] = Field(
Expand All @@ -119,18 +137,21 @@ class BaseSamplerRequest(BaseModel):
)

frequency_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0),
ge=0,
)

presence_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0),
ge=0,
)

repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
description="Aliases: rep_pen",
examples=[1.0],
gt=0,
)

penalty_range: Optional[int] = Field(
Expand Down Expand Up @@ -164,14 +185,16 @@ class BaseSamplerRequest(BaseModel):

dry_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_range", 0),
alias=AliasChoices("dry_range", "dry_penalty_last_n"),
validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"),
description=("Aliases: dry_penalty_last_n"),
)

dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
)

mirostat: Optional[bool] = False

mirostat_mode: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
Expand Down Expand Up @@ -239,165 +262,90 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
examples=[1.0],
ge=0,
)

min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
examples=[1.0],
ge=0,
)

temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
ge=0,
)

# TODO: Return to adaptable class-based validation. For now, that's just too much
# abstraction compared to simple if statements.
def validate_params(self):
"""
Validates sampler parameters to be within sane ranges.
"""
@field_validator("top_k", mode="before")
def convert_top_k(cls, v):
"""Fixes instance if Top-K is -1."""

# Temperature
if self.temperature < 0.0:
raise ValueError(
"Temperature must be a non-negative value. " f"Got {self.temperature}"
)
if v == -1:
logger.warning("Provided a top-k value of -1. Converting to 0 instead.")
return 0

# Smoothing factor
if self.smoothing_factor < 0.0:
raise ValueError(
"Smoothing factor must be a non-negative value. "
f"Got {self.smoothing_factor}"
)
return v

# Top K
if self.top_k < 0:
raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")
@field_validator("stop", "banned_strings", mode="before")
def convert_str_to_list(cls, v):
"""Convert single string to list of strings."""

# Top P
if self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")
if isinstance(v, str):
return [v]

# Repetition Penalty
if self.repetition_penalty <= 0.0:
raise ValueError(
"Repetition penalty must be a positive value. "
f"Got {self.repetition_penalty}"
)
return v

# Typical
if self.typical <= 0 and self.typical > 1:
raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")
@field_validator("banned_tokens", "allowed_tokens", mode="before")
def convert_tokens_to_int_list(cls, v):
"""Convert comma-separated string of numbers to a list of integers."""

# Dynatemp values
if self.max_temp < 0.0:
raise ValueError(
"Max temp must be a non-negative value. ", f"Got {self.max_temp}"
)
if isinstance(v, str):
return [int(x) for x in v.replace(" ", "").split(",") if x.isdigit()]

if self.min_temp < 0.0:
raise ValueError(
"Min temp must be a non-negative value. ", f"Got {self.min_temp}"
)
return v

if self.temp_exponent < 0.0:
raise ValueError(
"Temp exponent must be a non-negative value. ",
f"Got {self.temp_exponent}",
@field_validator("dry_sequence_breakers", mode="before")
def parse_json_if_needed(cls, v):
"""Parse dry_sequence_breakers string to JSON array."""

if isinstance(v, str) and not v.startswith("["):
v = f"[{v}]"

try:
return json.loads(v) if isinstance(v, str) else v
except Exception:
logger.warning(
"Could not parse DRY sequence breakers. Using an empty array."
)
return [] # Return empty list if parsing fails

@field_validator("mirostat_mode", mode="before")
def convert_mirostat(cls, v, field_info):
"""Mirostat is enabled if mirostat_mode == 2."""

def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""
if v == 2:
field_info.data["mirostat"] = True

# Add forced overrides if present
return v

@model_validator(mode="after")
def after_validate(self):
# FIXME: find a better way to register this
# Maybe make a function to assign values to the
# model if they do not exist post creation
apply_forced_sampler_overrides(self)

self.validate_params()

# Convert stop to an array of strings
if self.stop and isinstance(self.stop, str):
self.stop = [self.stop]

# Convert banned_strings to an array of strings
if self.banned_strings and isinstance(self.banned_strings, str):
self.banned_strings = [self.banned_strings]

# Convert string banned and allowed tokens to an integer list
if self.banned_tokens and isinstance(self.banned_tokens, str):
self.banned_tokens = [
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
]

if self.allowed_tokens and isinstance(self.allowed_tokens, str):
self.allowed_tokens = [
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
]

# Convert sequence breakers into an array of strings
# NOTE: This sampler sucks to parse.
if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
if not self.dry_sequence_breakers.startswith("["):
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"

try:
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
except Exception:
self.dry_sequence_breakers = []

gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"generate_window": self.generate_window,
"stop": self.stop,
"banned_strings": self.banned_strings,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"skip_special_tokens": self.skip_special_tokens,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"banned_tokens": self.banned_tokens,
"allowed_tokens": self.allowed_tokens,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"min_temp": self.min_temp,
"max_temp": self.max_temp,
"temp_exponent": self.temp_exponent,
"smoothing_factor": self.smoothing_factor,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"skew": self.skew,
"xtc_probability": self.xtc_probability,
"xtc_threshold": self.xtc_threshold,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"dry_multiplier": self.dry_multiplier,
"dry_base": self.dry_base,
"dry_allowed_length": self.dry_allowed_length,
"dry_sequence_breakers": self.dry_sequence_breakers,
"dry_range": self.dry_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
"json_schema": self.json_schema,
"regex_pattern": self.regex_pattern,
"grammar_string": self.grammar_string,
"speculative_ngram": self.speculative_ngram,
}

return {**gen_params, **kwargs}
if self.min_temp and self.max_temp and self.min_temp > self.max_temp:
raise ValidationError("min temp cannot be more then max temp")

if self.min_tokens and self.max_tokens and self.min_tokens > self.max_tokens:
raise ValidationError("min tokens cannot be more then max tokens")

return self


class SamplerOverridesContainer(BaseModel):
Expand Down
Loading
Loading