Skip to content

Commit

Permalink
validate default at schema definition
Browse files Browse the repository at this point in the history
  • Loading branch information
changhc committed Dec 23, 2024
1 parent 562fad3 commit ee7e984
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 28 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ num-bigint = "0.4.6"
uuid = "1.11.0"
jiter = { version = "0.8.2", features = ["python"] }
hex = "0.4.3"
bitflags = "2.6.0"

[lib]
name = "_pydantic_core"
Expand Down
12 changes: 6 additions & 6 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class CoreConfig(TypedDict, total=False):
loc_by_alias: Whether to use the used alias (or first alias for "field required" errors) instead of
`field_names` to construct error `loc`s. Default is `True`.
revalidate_instances: Whether instances of models and dataclasses should re-validate. Default is 'never'.
validate_default: Whether to validate default values during validation. Default is `False`.
validate_default: Whether to validate default values during validation. Default is `never`.
populate_by_name: Whether an aliased field may be populated by its name as given by the model attribute,
as well as the alias. (Replaces 'allow_population_by_field_name' in Pydantic v1.) Default is `False`.
str_max_length: The maximum length for string fields.
Expand Down Expand Up @@ -92,8 +92,8 @@ class CoreConfig(TypedDict, total=False):
loc_by_alias: bool
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
revalidate_instances: Literal['always', 'never', 'subclass-instances']
# whether to validate default values during validation, default False
validate_default: bool
# whether to validate default values during validation, default 'never'
validate_default: Union[bool, Literal['never', 'definition', 'init']]
# used on typed-dicts and arguments
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
# fields related to string fields only
Expand Down Expand Up @@ -2403,7 +2403,7 @@ class WithDefaultSchema(TypedDict, total=False):
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]]
default_factory_takes_data: bool
on_error: Literal['raise', 'omit', 'default'] # default: 'raise'
validate_default: bool # default: False
validate_default: Union[bool, Literal['never', 'definition', 'init']] # default: 'never'
strict: bool
ref: str
metadata: Dict[str, Any]
Expand All @@ -2417,7 +2417,7 @@ def with_default_schema(
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None,
default_factory_takes_data: bool | None = None,
on_error: Literal['raise', 'omit', 'default'] | None = None,
validate_default: bool | None = None,
validate_default: bool | Literal['never', 'definition', 'init'] | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Dict[str, Any] | None = None,
Expand All @@ -2443,7 +2443,7 @@ def with_default_schema(
default_factory: A callable that returns the default value to use
default_factory_takes_data: Whether the default factory takes a validated data argument
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
validate_default: Whether the default value should be validated
validate_default: Whether the default value should be validated. One of 'never', 'definition', 'init' or True/False
strict: Whether the underlying schema should be validated with strict mode
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand Down
134 changes: 113 additions & 21 deletions src/validators/with_default.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::PyDict;
use pyo3::types::PyString;
use pyo3::types::{PyBool, PyDict, PyString};
use pyo3::PyTraverseError;
use pyo3::PyVisit;

use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
use crate::build_tools::py_schema_err;
use crate::build_tools::schema_or_config_same;
use crate::build_tools::{is_strict, py_schema_err, schema_or_config_same};
use crate::errors::{LocItem, ValError, ValResult};
use crate::input::Input;
use crate::py_gc::PyGcTraverse;
use crate::tools::SchemaDict;
use crate::validators::{Extra, InputType, RecursionState};
use crate::PydanticUndefinedType;
use crate::SchemaError;

static COPY_DEEPCOPY: GILOnceCell<PyObject> = GILOnceCell::new();

Expand Down Expand Up @@ -84,12 +85,58 @@ enum OnError {
Default,
}

bitflags::bitflags! {
#[derive(Debug, Clone)]
struct ValidateDefaultFlag: u8 {
const NEVER = 0;
const INIT = 0x01;
const DEFINITION = 0x02;
}
}

impl<'py> FromPyObject<'py> for ValidateDefaultFlag {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
if let Ok(bool_value) = ob.downcast::<PyBool>() {
Ok(bool_value.is_true().into())
} else if let Ok(str_value) = ob.extract::<&str>() {
match str_value {
"never" => Ok(Self::NEVER),
"init" => Ok(Self::INIT),
"definition" => Ok(Self::DEFINITION),
_ => Err(PyValueError::new_err(
"Invalid value for option `validate_default`, should be `'init'`, `'definition'`, `'never'` or a `bool`",
)),
}
} else {
Err(PyTypeError::new_err(
"Invalid value for option `validate_default`, should be `'init'`, `'definition'`, `'never'` or a `bool`",
))
}
}
}

impl From<bool> for ValidateDefaultFlag {
fn from(mode: bool) -> Self {
if mode {
Self::INIT
} else {
Self::NEVER
}
}
}

impl Default for ValidateDefaultFlag {
fn default() -> Self {
Self::NEVER
}
}

#[derive(Debug)]
pub struct WithDefaultValidator {
default: DefaultType,
on_error: OnError,
validator: Box<CombinedValidator>,
validate_default: bool,
validate_default: ValidateDefaultFlag,
copy_default: bool,
name: String,
undefined: PyObject,
Expand Down Expand Up @@ -134,17 +181,21 @@ impl BuildValidator for WithDefaultValidator {
};

let name = format!("{}[{}]", Self::EXPECTED_TYPE, validator.get_name());

Ok(Self {
let validate_default =
schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or_default();
let validator = Self {
default,
on_error,
validator,
validate_default: schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or(false),
validate_default,
copy_default,
name,
undefined: PydanticUndefinedType::new(py).to_object(py),
}
.into())
};

validator.validate_default_on_build(schema, config)?;

Ok(validator.into())
}
}

Expand Down Expand Up @@ -188,17 +239,8 @@ impl Validator for WithDefaultValidator {
} else {
stored_dft
};
if self.validate_default {
match self.validate(py, dft.bind(py), state) {
Ok(v) => Ok(Some(v)),
Err(e) => {
if let Some(outer_loc) = outer_loc {
Err(e.with_outer_location(outer_loc))
} else {
Err(e)
}
}
}
if self.validate_default.contains(ValidateDefaultFlag::INIT) {
self.validate_default(py, outer_loc, state, dft)
} else {
Ok(Some(dft))
}
Expand All @@ -220,4 +262,54 @@ impl WithDefaultValidator {
pub fn omit_on_error(&self) -> bool {
matches!(self.on_error, OnError::Omit)
}

fn validate_default<'py>(
&self,
py: Python<'py>,
outer_loc: Option<impl Into<LocItem>>,
state: &mut ValidationState<'_, 'py>,
dft: Py<PyAny>,
) -> ValResult<Option<PyObject>> {
match self.validate(py, dft.bind(py), state) {
Ok(v) => Ok(Some(v)),
Err(e) => {
if let Some(outer_loc) = outer_loc {
Err(e.with_outer_location(outer_loc))
} else {
Err(e)
}
}
}
}

fn validate_default_on_build(
&self,
schema: &Bound<'_, PyDict>,
config: Option<&Bound<'_, PyDict>>,
) -> PyResult<()> {
if self.validate_default.contains(ValidateDefaultFlag::DEFINITION) && self.has_default() {
// Since this method is called in `build` where validation state is not available,
// we need to craft a dummy one here. This setup is basically the same as in `SchemaValidator::get_default_value`
let mut recursion_guard = RecursionState::default();
let mut state = ValidationState::new(
Extra::new(
Some(is_strict(schema, config)?),
None,
None,
None,
InputType::Python,
true.into(),
),
&mut recursion_guard,
false.into(),
);
let py = schema.py();
if let Some(dft) = self.default.default_value(py, state.extra().data.as_ref())? {
if let Err(e) = self.validate_default(py, None::<usize>, &mut state, dft) {
return Err(SchemaError::from_val_error(py, e));
}
}
}
Ok(())
}
}
86 changes: 85 additions & 1 deletion tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import weakref
from collections import deque
from typing import Any, Callable, Dict, List, Union, cast
from typing import Any, Callable, Dict, List, Literal, Union, cast

import pytest

Expand Down Expand Up @@ -375,6 +375,75 @@ def test_validate_default(
assert v.validate_python({}) == {'x': expected}


@pytest.mark.parametrize('config_validate_default', [None, 'never', 'definition', 'init'])
@pytest.mark.parametrize(
'schema_validate_default',
[
None,
'never',
'definition',
'init',
],
)
def test_validate_default_flags(
config_validate_default: None | Literal['never', 'definition', 'init'],
schema_validate_default: None | Literal['never', 'definition', 'init'],
):
def create_schema():
if config_validate_default is not None:
config = core_schema.CoreConfig(validate_default=config_validate_default)
else:
config = None

v = SchemaValidator(
core_schema.typed_dict_schema(
{
'x': core_schema.typed_dict_field(
core_schema.with_default_schema(
core_schema.int_schema(), default='a', validate_default=schema_validate_default
)
)
},
config=config,
),
)
return v

def validate(v: SchemaValidator):
assert v.validate_python({}) == {'x': 'a'}

if (
schema_validate_default == 'definition'
or schema_validate_default is None
and (config_validate_default == 'definition')
):
with pytest.raises(SchemaError, match='Input should be a valid integer'):
create_schema()
else:
v = create_schema()
if schema_validate_default == 'init' or schema_validate_default is None and config_validate_default == 'init':
with pytest.raises(ValidationError, match='Input should be a valid integer'):
validate(v)
else:
validate(v)


def test_validate_default_flags_strict():
with pytest.raises(SchemaError, match='Input should be a valid integer'):
SchemaValidator(
core_schema.typed_dict_schema(
{
'x': core_schema.typed_dict_field(
core_schema.with_default_schema(
core_schema.int_schema(), default='1', validate_default='definition'
)
)
},
config=core_schema.CoreConfig(strict=True),
),
)


def test_validate_default_factory():
v = SchemaValidator(
core_schema.tuple_positional_schema(
Expand Down Expand Up @@ -818,3 +887,18 @@ def _raise(ex: Exception) -> None:
v.validate_python(input_value)

assert exc_info.value.errors(include_url=False, include_context=False) == expected


def test_validate_default_on_validator_creation():
SchemaValidator(
{
'type': 'typed-dict',
'fields': {
'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}},
'y': {
'type': 'typed-dict-field',
'schema': {'type': 'default', 'schema': {'type': 'str'}, 'default': 1},
},
},
}
)

0 comments on commit ee7e984

Please sign in to comment.