diff --git a/common/sampling.py b/common/sampling.py index c818b833..fb8e7fdb 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -37,6 +37,8 @@ class BaseSamplerRequest(BaseModel): temperature: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("temperature", 1.0), examples=[1.0], + sample_validator=lambda value: value >= 0.0, + validation_error="Temperature must be a non-negative value", ) temperature_last: Optional[bool] = Field( @@ -45,14 +47,21 @@ class BaseSamplerRequest(BaseModel): smoothing_factor: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0), + sample_validator=lambda value: value >= 0.0, + validation_error="Smoothing factor must be a non-negative value", ) top_k: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("top_k", 0) + default_factory=lambda: get_default_sampler_value("top_k", 0), + sample_validator=lambda value: value >= 0, + validation_error="Top K must be a non-negative value", ) top_p: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0] + default_factory=lambda: get_default_sampler_value("top_p", 1.0), + examples=[1.0], + sample_validator=lambda value: value >= 0.0 and value <= 1.0, + validation_error="Top P must be in [0, 1]", ) top_a: Optional[float] = Field( @@ -64,7 +73,8 @@ class BaseSamplerRequest(BaseModel): ) tfs: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("tfs", 1.0) + default_factory=lambda: get_default_sampler_value("tfs", 1.0), + examples=[1.0], ) frequency_penalty: Optional[float] = Field( @@ -78,6 +88,8 @@ class BaseSamplerRequest(BaseModel): repetition_penalty: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), examples=[1.0], + sample_validator=lambda value: value > 0.0, + validation_error="Repetition penalty must be a positive value", ) repetition_decay: Optional[int] = Field( @@ -122,6 +134,8 @@ class BaseSamplerRequest(BaseModel): validation_alias=AliasChoices("typical", "typical_p"), description="Aliases: typical_p", examples=[1.0], + sample_validator=lambda value: value > 0.0 and value <= 1.0, + validation_error="Typical must be in (0, 1]", ) penalty_range: Optional[int] = Field( @@ -145,26 +159,57 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("max_temp", 1.0), validation_alias=AliasChoices("max_temp", "dynatemp_high"), description="Aliases: dynatemp_high", + examples=[1.0], + sample_validator=lambda value: value >= 0.0, + validation_error="Max temperature must be a non-negative value", ) min_temp: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("min_temp", 1.0), validation_alias=AliasChoices("min_temp", "dynatemp_low"), description="Aliases: dynatemp_low", + examples=[1.0], + sample_validator=lambda value: value >= 0.0, + validation_error="Min temperature must be a non-negative value", ) temp_exponent: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0), validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"), examples=[1.0], + sample_validator=lambda value: value >= 0.0, + validation_error="Temperature exponent must be a non-negative value", ) + def validate_params(self): + """ + Validates if the class field satisfies a condition if present. + + Validators are present in the extras section of a Pydantic field + to make it easy for adding more samplers if needed. + """ + + for field_name, field_info in self.model_fields.items(): + extra_field_info = unwrap(field_info.json_schema_extra, {}) + if not extra_field_info: + continue + + sample_validator = extra_field_info.get("sample_validator") + validation_error = unwrap(extra_field_info.get("validation_error"), "") + + if sample_validator: + value = getattr(self, field_name) + if not sample_validator(value): + raise ValueError(f"{validation_error}. Got {value}") + def to_gen_params(self, **kwargs): """Converts samplers to internal generation params""" # Add forced overrides if present apply_forced_sampler_overrides(self) + self.validate_params() + # Convert stop to an array of strings if isinstance(self.stop, str): self.stop = [self.stop]