Skip to content

Commit

Permalink
Sampling: Make validators simpler
Browse files Browse the repository at this point in the history
Injecting into Pydantic fields caused issues with serialization for
documentation rendering. Rather than reinvent the wheel again,
switch to a chain of if statements for now. This may change in the
future if subclasses from the base sampler request need to be
validated as well.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 11, 2024
1 parent f627485 commit a79c42f
Showing 1 changed file with 51 additions and 34 deletions.
85 changes: 51 additions & 34 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class BaseSamplerRequest(BaseModel):
temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature must be a non-negative value",
)

temperature_last: Optional[bool] = Field(
Expand All @@ -47,21 +45,15 @@ class BaseSamplerRequest(BaseModel):

smoothing_factor: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
sample_validator=lambda value: value >= 0.0,
validation_error="Smoothing factor must be a non-negative value",
)

top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0),
sample_validator=lambda value: value >= 0,
validation_error="Top K must be a non-negative value",
)

top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0 and value <= 1.0,
validation_error="Top P must be in [0, 1]",
)

top_a: Optional[float] = Field(
Expand All @@ -88,8 +80,6 @@ class BaseSamplerRequest(BaseModel):
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
examples=[1.0],
sample_validator=lambda value: value > 0.0,
validation_error="Repetition penalty must be a positive value",
)

repetition_decay: Optional[int] = Field(
Expand Down Expand Up @@ -134,8 +124,6 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
sample_validator=lambda value: value > 0.0 and value <= 1.0,
validation_error="Typical must be in (0, 1]",
)

penalty_range: Optional[int] = Field(
Expand All @@ -160,47 +148,76 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Max temperature must be a non-negative value",
)

min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Min temperature must be a non-negative value",
)

temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature exponent must be a non-negative value",
)

# TODO: Return back to adaptable class-based validation But that's just too much
# abstraction compared to simple if statements at the moment
def validate_params(self):
"""
Validates if the class field satisfies a condition if present.
Validators are present in the extras section of a Pydantic field
to make it easy for adding more samplers if needed.
Validates sampler parameters to be within sane ranges.
"""

for field_name, field_info in self.model_fields.items():
extra_field_info = unwrap(field_info.json_schema_extra, {})
if not extra_field_info:
continue

sample_validator = extra_field_info.get("sample_validator")
validation_error = unwrap(extra_field_info.get("validation_error"), "")

if sample_validator:
value = getattr(self, field_name)
if not sample_validator(value):
raise ValueError(f"{validation_error}. Got {value}")
# Temperature
if self.temperature < 0.0:
raise ValueError(
"Temperature must be a non-negative value. " f"Got {self.temperature}"
)

# Smoothing factor
if self.smoothing_factor < 0.0:
raise ValueError(
"Smoothing factor must be a non-negative value. "
f"Got {self.smoothing_factor}"
)

# Top K
if self.top_k < 0:
raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")

# Top P
if self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")

# Repetition Penalty
if self.repetition_penalty <= 0.0:
raise ValueError(
"Repetition penalty must be a positive value. "
f"Got {self.repetition_penalty}"
)

# Typical
if self.typical <= 0 and self.typical > 1:
raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")

# Dynatemp values
if self.max_temp < 0.0:
raise ValueError(
"Max temp must be a non-negative value. ", f"Got {self.max_temp}"
)

if self.min_temp < 0.0:
raise ValueError(
"Min temp must be a non-negative value. ", f"Got {self.min_temp}"
)

if self.temp_exponent < 0.0:
raise ValueError(
"Temp exponent must be a non-negative value. ",
f"Got {self.temp_exponent}",
)

def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""
Expand Down

0 comments on commit a79c42f

Please sign in to comment.