From ee7e984384e2269c5e0b0c605d10d7516d434153 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Mon, 23 Dec 2024 22:46:23 +0000 Subject: [PATCH] validate default at schema definition --- Cargo.lock | 7 ++ Cargo.toml | 1 + python/pydantic_core/core_schema.py | 12 +-- src/validators/with_default.rs | 134 ++++++++++++++++++++++---- tests/validators/test_with_default.py | 86 ++++++++++++++++- 5 files changed, 212 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2fb6ee4a9..65a4008dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,6 +36,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "bitvec" version = "1.0.1" @@ -438,6 +444,7 @@ version = "2.27.2" dependencies = [ "ahash", "base64", + "bitflags", "enum_dispatch", "hex", "idna 1.0.3", diff --git a/Cargo.toml b/Cargo.toml index 0abfc5450..7de39a20c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ num-bigint = "0.4.6" uuid = "1.11.0" jiter = { version = "0.8.2", features = ["python"] } hex = "0.4.3" +bitflags = "2.6.0" [lib] name = "_pydantic_core" diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index a45d82f9c..aabed09a9 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -57,7 +57,7 @@ class CoreConfig(TypedDict, total=False): loc_by_alias: Whether to use the used alias (or first alias for "field required" errors) instead of `field_names` to construct error `loc`s. Default is `True`. revalidate_instances: Whether instances of models and dataclasses should re-validate. Default is 'never'. - validate_default: Whether to validate default values during validation. Default is `False`. + validate_default: Whether to validate default values during validation. Default is `never`. populate_by_name: Whether an aliased field may be populated by its name as given by the model attribute, as well as the alias. (Replaces 'allow_population_by_field_name' in Pydantic v1.) Default is `False`. str_max_length: The maximum length for string fields. @@ -92,8 +92,8 @@ class CoreConfig(TypedDict, total=False): loc_by_alias: bool # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never' revalidate_instances: Literal['always', 'never', 'subclass-instances'] - # whether to validate default values during validation, default False - validate_default: bool + # whether to validate default values during validation, default 'never' + validate_default: Union[bool, Literal['never', 'definition', 'init']] # used on typed-dicts and arguments populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 # fields related to string fields only @@ -2403,7 +2403,7 @@ class WithDefaultSchema(TypedDict, total=False): default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]] default_factory_takes_data: bool on_error: Literal['raise', 'omit', 'default'] # default: 'raise' - validate_default: bool # default: False + validate_default: Union[bool, Literal['never', 'definition', 'init']] # default: 'never' strict: bool ref: str metadata: Dict[str, Any] @@ -2417,7 +2417,7 @@ def with_default_schema( default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None, default_factory_takes_data: bool | None = None, on_error: Literal['raise', 'omit', 'default'] | None = None, - validate_default: bool | None = None, + validate_default: bool | Literal['never', 'definition', 'init'] | None = None, strict: bool | None = None, ref: str | None = None, metadata: Dict[str, Any] | None = None, @@ -2443,7 +2443,7 @@ def with_default_schema( default_factory: A callable that returns the default value to use default_factory_takes_data: Whether the default factory takes a validated data argument on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default' - validate_default: Whether the default value should be validated + validate_default: Whether the default value should be validated. One of 'never', 'definition', 'init' or True/False strict: Whether the underlying schema should be validated with strict mode ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index 55fbc6ce5..be47074cd 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -1,19 +1,20 @@ +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::intern; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; -use pyo3::types::PyDict; -use pyo3::types::PyString; +use pyo3::types::{PyBool, PyDict, PyString}; use pyo3::PyTraverseError; use pyo3::PyVisit; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -use crate::build_tools::py_schema_err; -use crate::build_tools::schema_or_config_same; +use crate::build_tools::{is_strict, py_schema_err, schema_or_config_same}; use crate::errors::{LocItem, ValError, ValResult}; use crate::input::Input; use crate::py_gc::PyGcTraverse; use crate::tools::SchemaDict; +use crate::validators::{Extra, InputType, RecursionState}; use crate::PydanticUndefinedType; +use crate::SchemaError; static COPY_DEEPCOPY: GILOnceCell = GILOnceCell::new(); @@ -84,12 +85,58 @@ enum OnError { Default, } +bitflags::bitflags! { + #[derive(Debug, Clone)] + struct ValidateDefaultFlag: u8 { + const NEVER = 0; + const INIT = 0x01; + const DEFINITION = 0x02; + } +} + +impl<'py> FromPyObject<'py> for ValidateDefaultFlag { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(bool_value) = ob.downcast::() { + Ok(bool_value.is_true().into()) + } else if let Ok(str_value) = ob.extract::<&str>() { + match str_value { + "never" => Ok(Self::NEVER), + "init" => Ok(Self::INIT), + "definition" => Ok(Self::DEFINITION), + _ => Err(PyValueError::new_err( + "Invalid value for option `validate_default`, should be `'init'`, `'definition'`, `'never'` or a `bool`", + )), + } + } else { + Err(PyTypeError::new_err( + "Invalid value for option `validate_default`, should be `'init'`, `'definition'`, `'never'` or a `bool`", + )) + } + } +} + +impl From for ValidateDefaultFlag { + fn from(mode: bool) -> Self { + if mode { + Self::INIT + } else { + Self::NEVER + } + } +} + +impl Default for ValidateDefaultFlag { + fn default() -> Self { + Self::NEVER + } +} + #[derive(Debug)] pub struct WithDefaultValidator { default: DefaultType, on_error: OnError, validator: Box, - validate_default: bool, + validate_default: ValidateDefaultFlag, copy_default: bool, name: String, undefined: PyObject, @@ -134,17 +181,21 @@ impl BuildValidator for WithDefaultValidator { }; let name = format!("{}[{}]", Self::EXPECTED_TYPE, validator.get_name()); - - Ok(Self { + let validate_default = + schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or_default(); + let validator = Self { default, on_error, validator, - validate_default: schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or(false), + validate_default, copy_default, name, undefined: PydanticUndefinedType::new(py).to_object(py), - } - .into()) + }; + + validator.validate_default_on_build(schema, config)?; + + Ok(validator.into()) } } @@ -188,17 +239,8 @@ impl Validator for WithDefaultValidator { } else { stored_dft }; - if self.validate_default { - match self.validate(py, dft.bind(py), state) { - Ok(v) => Ok(Some(v)), - Err(e) => { - if let Some(outer_loc) = outer_loc { - Err(e.with_outer_location(outer_loc)) - } else { - Err(e) - } - } - } + if self.validate_default.contains(ValidateDefaultFlag::INIT) { + self.validate_default(py, outer_loc, state, dft) } else { Ok(Some(dft)) } @@ -220,4 +262,54 @@ impl WithDefaultValidator { pub fn omit_on_error(&self) -> bool { matches!(self.on_error, OnError::Omit) } + + fn validate_default<'py>( + &self, + py: Python<'py>, + outer_loc: Option>, + state: &mut ValidationState<'_, 'py>, + dft: Py, + ) -> ValResult> { + match self.validate(py, dft.bind(py), state) { + Ok(v) => Ok(Some(v)), + Err(e) => { + if let Some(outer_loc) = outer_loc { + Err(e.with_outer_location(outer_loc)) + } else { + Err(e) + } + } + } + } + + fn validate_default_on_build( + &self, + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + ) -> PyResult<()> { + if self.validate_default.contains(ValidateDefaultFlag::DEFINITION) && self.has_default() { + // Since this method is called in `build` where validation state is not available, + // we need to craft a dummy one here. This setup is basically the same as in `SchemaValidator::get_default_value` + let mut recursion_guard = RecursionState::default(); + let mut state = ValidationState::new( + Extra::new( + Some(is_strict(schema, config)?), + None, + None, + None, + InputType::Python, + true.into(), + ), + &mut recursion_guard, + false.into(), + ); + let py = schema.py(); + if let Some(dft) = self.default.default_value(py, state.extra().data.as_ref())? { + if let Err(e) = self.validate_default(py, None::, &mut state, dft) { + return Err(SchemaError::from_val_error(py, e)); + } + } + } + Ok(()) + } } diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 105241140..ec02ef3b5 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -3,7 +3,7 @@ import sys import weakref from collections import deque -from typing import Any, Callable, Dict, List, Union, cast +from typing import Any, Callable, Dict, List, Literal, Union, cast import pytest @@ -375,6 +375,75 @@ def test_validate_default( assert v.validate_python({}) == {'x': expected} +@pytest.mark.parametrize('config_validate_default', [None, 'never', 'definition', 'init']) +@pytest.mark.parametrize( + 'schema_validate_default', + [ + None, + 'never', + 'definition', + 'init', + ], +) +def test_validate_default_flags( + config_validate_default: None | Literal['never', 'definition', 'init'], + schema_validate_default: None | Literal['never', 'definition', 'init'], +): + def create_schema(): + if config_validate_default is not None: + config = core_schema.CoreConfig(validate_default=config_validate_default) + else: + config = None + + v = SchemaValidator( + core_schema.typed_dict_schema( + { + 'x': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.int_schema(), default='a', validate_default=schema_validate_default + ) + ) + }, + config=config, + ), + ) + return v + + def validate(v: SchemaValidator): + assert v.validate_python({}) == {'x': 'a'} + + if ( + schema_validate_default == 'definition' + or schema_validate_default is None + and (config_validate_default == 'definition') + ): + with pytest.raises(SchemaError, match='Input should be a valid integer'): + create_schema() + else: + v = create_schema() + if schema_validate_default == 'init' or schema_validate_default is None and config_validate_default == 'init': + with pytest.raises(ValidationError, match='Input should be a valid integer'): + validate(v) + else: + validate(v) + + +def test_validate_default_flags_strict(): + with pytest.raises(SchemaError, match='Input should be a valid integer'): + SchemaValidator( + core_schema.typed_dict_schema( + { + 'x': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.int_schema(), default='1', validate_default='definition' + ) + ) + }, + config=core_schema.CoreConfig(strict=True), + ), + ) + + def test_validate_default_factory(): v = SchemaValidator( core_schema.tuple_positional_schema( @@ -818,3 +887,18 @@ def _raise(ex: Exception) -> None: v.validate_python(input_value) assert exc_info.value.errors(include_url=False, include_context=False) == expected + + +def test_validate_default_on_validator_creation(): + SchemaValidator( + { + 'type': 'typed-dict', + 'fields': { + 'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, + 'y': { + 'type': 'typed-dict-field', + 'schema': {'type': 'default', 'schema': {'type': 'str'}, 'default': 1}, + }, + }, + } + )