Skip to content

Commit

Permalink
Sampling: Add universal validation system
Browse files Browse the repository at this point in the history
Rather than maintaining yet another function to validate sampler
ranges/values, embed them in fields which allows for less
maintenance in the future.

Also add validation for existing samplers that can corrupt
the sampling stack if set improperly.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 10, 2024
1 parent 9f1d891 commit 7e730e3
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class BaseSamplerRequest(BaseModel):
temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature must be a non-negative value",
)

temperature_last: Optional[bool] = Field(
Expand All @@ -45,14 +47,21 @@ class BaseSamplerRequest(BaseModel):

smoothing_factor: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
sample_validator=lambda value: value >= 0.0,
validation_error="Smoothing factor must be a non-negative value",
)

top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0)
default_factory=lambda: get_default_sampler_value("top_k", 0),
sample_validator=lambda value: value >= 0,
validation_error="Top K must be a non-negative value",
)

top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0]
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0 and value <= 1.0,
validation_error="Top P must be in [0, 1]",
)

top_a: Optional[float] = Field(
Expand All @@ -64,7 +73,8 @@ class BaseSamplerRequest(BaseModel):
)

tfs: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("tfs", 1.0)
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
examples=[1.0],
)

frequency_penalty: Optional[float] = Field(
Expand All @@ -78,6 +88,8 @@ class BaseSamplerRequest(BaseModel):
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
examples=[1.0],
sample_validator=lambda value: value > 0.0,
validation_error="Repetition penalty must be a positive value",
)

repetition_decay: Optional[int] = Field(
Expand Down Expand Up @@ -122,6 +134,8 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
sample_validator=lambda value: value > 0.0 and value <= 1.0,
validation_error="Typical must be in (0, 1]",
)

penalty_range: Optional[int] = Field(
Expand All @@ -145,26 +159,57 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Max temperature must be a non-negative value",
)

min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Min temperature must be a non-negative value",
)

temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature exponent must be a non-negative value",
)

def validate_params(self):
    """
    Run each field's embedded sampler validator against its current value.

    Validators are stored in the extras section of a Pydantic field
    (``json_schema_extra``) under the ``sample_validator`` key, next to an
    optional ``validation_error`` message. Adding a check for a new sampler
    therefore only requires two extra keyword arguments on its field.

    Fields without a validator, and fields whose value is ``None``, are
    skipped.

    Raises:
        ValueError: If a validator returns a falsy result for its field.
    """

    # Read the field map from the class rather than the instance; newer
    # Pydantic releases deprecate instance-level access to model_fields.
    for field_name, field_info in type(self).model_fields.items():
        extra_field_info = field_info.json_schema_extra

        # Extras may be absent (None) or a callable schema hook instead of
        # a dict; neither can carry a validator entry.
        if not isinstance(extra_field_info, dict):
            continue

        sample_validator = extra_field_info.get("sample_validator")
        if not sample_validator:
            continue

        value = getattr(self, field_name)

        # Sampler fields are declared Optional, while the validators only
        # compare numbers; calling one with None would raise TypeError
        # instead of producing a clean validation result.
        if value is None:
            continue

        if not sample_validator(value):
            # Fall back to a generic message so the error never starts
            # with a dangling ". Got ..." when no text was provided.
            validation_error = (
                extra_field_info.get("validation_error")
                or f"Invalid value for {field_name}"
            )
            raise ValueError(f"{validation_error}. Got {value}")

def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""

# Add forced overrides if present
apply_forced_sampler_overrides(self)

self.validate_params()

# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
Expand Down

0 comments on commit 7e730e3

Please sign in to comment.