From a79c42ff4c686ef9c9b91eaf7413a40606b8f30d Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 11 Feb 2024 15:22:43 -0500
Subject: [PATCH] Sampling: Make validators simpler

Injecting into Pydantic fields caused issues with serialization for
documentation rendering. Rather than reinvent the wheel again,
switch to a chain of if statements for now. This may change in the
future if subclasses from the base sampler request need to be
validated as well.

Signed-off-by: kingbri
---
Review note (ignored by git-am): relative to the change as first posted,
the Typical range check now uses `or` instead of `and` (the original
condition could never be true), and the stray commas inside the
max_temp / min_temp / temp_exponent ValueError calls were removed so
each error carries a single message string instead of a 2-tuple.

 common/sampling.py | 85 +++++++++++++++++++++++++++-------------------
 1 file changed, 51 insertions(+), 34 deletions(-)

diff --git a/common/sampling.py b/common/sampling.py
index fb8e7fdb..3d0ce324 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -37,8 +37,6 @@ class BaseSamplerRequest(BaseModel):
     temperature: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("temperature", 1.0),
         examples=[1.0],
-        sample_validator=lambda value: value >= 0.0,
-        validation_error="Temperature must be a non-negative value",
     )
 
     temperature_last: Optional[bool] = Field(
@@ -47,21 +45,15 @@ class BaseSamplerRequest(BaseModel):
 
     smoothing_factor: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
-        sample_validator=lambda value: value >= 0.0,
-        validation_error="Smoothing factor must be a non-negative value",
     )
 
     top_k: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("top_k", 0),
-        sample_validator=lambda value: value >= 0,
-        validation_error="Top K must be a non-negative value",
     )
 
     top_p: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("top_p", 1.0),
         examples=[1.0],
-        sample_validator=lambda value: value >= 0.0 and value <= 1.0,
-        validation_error="Top P must be in [0, 1]",
     )
 
     top_a: Optional[float] = Field(
@@ -88,8 +80,6 @@ class BaseSamplerRequest(BaseModel):
     repetition_penalty: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
         examples=[1.0],
-        sample_validator=lambda value: value > 0.0,
-        validation_error="Repetition penalty must be a positive value",
     )
 
     repetition_decay: Optional[int] = Field(
@@ -134,8 +124,6 @@ class BaseSamplerRequest(BaseModel):
         validation_alias=AliasChoices("typical", "typical_p"),
         description="Aliases: typical_p",
         examples=[1.0],
-        sample_validator=lambda value: value > 0.0 and value <= 1.0,
-        validation_error="Typical must be in (0, 1]",
     )
 
     penalty_range: Optional[int] = Field(
@@ -160,8 +148,6 @@ class BaseSamplerRequest(BaseModel):
         validation_alias=AliasChoices("max_temp", "dynatemp_high"),
         description="Aliases: dynatemp_high",
         examples=[1.0],
-        sample_validator=lambda value: value >= 0.0,
-        validation_error="Max temperature must be a non-negative value",
     )
 
     min_temp: Optional[float] = Field(
@@ -169,38 +155,69 @@ class BaseSamplerRequest(BaseModel):
         validation_alias=AliasChoices("min_temp", "dynatemp_low"),
         description="Aliases: dynatemp_low",
         examples=[1.0],
-        sample_validator=lambda value: value >= 0.0,
-        validation_error="Min temperature must be a non-negative value",
     )
 
     temp_exponent: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
         validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
         examples=[1.0],
-        sample_validator=lambda value: value >= 0.0,
-        validation_error="Temperature exponent must be a non-negative value",
     )
 
+    # TODO: Return to adaptable class-based validation, but that's just too much
+    # abstraction compared to simple if statements at the moment.
     def validate_params(self):
         """
-        Validates if the class field satisfies a condition if present.
-
-        Validators are present in the extras section of a Pydantic field
-        to make it easy for adding more samplers if needed.
+        Validates sampler parameters to be within sane ranges.
         """
 
-        for field_name, field_info in self.model_fields.items():
-            extra_field_info = unwrap(field_info.json_schema_extra, {})
-            if not extra_field_info:
-                continue
-
-            sample_validator = extra_field_info.get("sample_validator")
-            validation_error = unwrap(extra_field_info.get("validation_error"), "")
-
-            if sample_validator:
-                value = getattr(self, field_name)
-                if not sample_validator(value):
-                    raise ValueError(f"{validation_error}. Got {value}")
+        # Temperature
+        if self.temperature < 0.0:
+            raise ValueError(
+                "Temperature must be a non-negative value. " f"Got {self.temperature}"
+            )
+
+        # Smoothing factor
+        if self.smoothing_factor < 0.0:
+            raise ValueError(
+                "Smoothing factor must be a non-negative value. "
+                f"Got {self.smoothing_factor}"
+            )
+
+        # Top K
+        if self.top_k < 0:
+            raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")
+
+        # Top P
+        if self.top_p < 0.0 or self.top_p > 1.0:
+            raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")
+
+        # Repetition Penalty
+        if self.repetition_penalty <= 0.0:
+            raise ValueError(
+                "Repetition penalty must be a positive value. "
+                f"Got {self.repetition_penalty}"
+            )
+
+        # Typical
+        if self.typical <= 0 or self.typical > 1:
+            raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")
+
+        # Dynatemp values
+        if self.max_temp < 0.0:
+            raise ValueError(
+                "Max temp must be a non-negative value. " f"Got {self.max_temp}"
+            )
+
+        if self.min_temp < 0.0:
+            raise ValueError(
+                "Min temp must be a non-negative value. " f"Got {self.min_temp}"
+            )
+
+        if self.temp_exponent < 0.0:
+            raise ValueError(
+                "Temp exponent must be a non-negative value. "
+                f"Got {self.temp_exponent}"
+            )
 
     def to_gen_params(self, **kwargs):
         """Converts samplers to internal generation params"""