From acce26b595f47e6a069e4120cc05b2de7855c20d Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 5 Feb 2025 18:31:41 +0100 Subject: [PATCH] Do not call default factories taking the data argument if a validation error already occurred --- src/errors/types.rs | 4 ++++ src/validators/mod.rs | 2 +- src/validators/model_fields.rs | 15 +++++++++----- src/validators/validation_state.rs | 5 +++++ src/validators/with_default.rs | 7 ++++++- tests/validators/test_with_default.py | 30 +++++++++++++++++++++++++++ 6 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/errors/types.rs b/src/errors/types.rs index 359f0c3de..855ebbc09 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -196,6 +196,9 @@ error_types! { class_name: {ctx_type: String, ctx_fn: field_from_context}, }, // --------------------- + // Default factory not called (happens when there's already an error and the factory takes data) + DefaultFactoryNotCalled {}, + // --------------------- // None errors NoneRequired {}, // --------------------- @@ -490,6 +493,7 @@ impl ErrorType { Self::ModelAttributesType {..} => "Input should be a valid dictionary or object to extract fields from", Self::DataclassType {..} => "Input should be a dictionary or an instance of {class_name}", Self::DataclassExactType {..} => "Input should be an instance of {class_name}", + Self::DefaultFactoryNotCalled {..} => "The default factory uses validated data, but at least one validation error occurred", Self::NoneRequired {..} => "Input should be None", Self::GreaterThan {..} => "Input should be greater than {gt}", Self::GreaterThanEqual {..} => "Input should be greater than or equal to {ge}", diff --git a/src/validators/mod.rs b/src/validators/mod.rs index bc1851ac5..896732733 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -618,7 +618,7 @@ pub fn build_validator( pub struct Extra<'a, 'py> { /// Validation mode pub input_type: InputType, - /// This is used as the `data` kwargs to validator functions + /// This is used as the `data` kwargs to validator functions and default factories (if they accept the argument) pub data: Option>, /// whether we're in strict or lax mode pub strict: Option, diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 392760964..3ff197044 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -191,13 +191,18 @@ impl Validator for ModelFieldsValidator { fields_set_vec.push(field.name_py.clone_ref(py)); fields_set_count += 1; } - Err(ValError::Omit) => continue, - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + Err(e) => { + state.has_field_error = true; + match e { + ValError::Omit => continue, + ValError::LineErrors(line_errors) => { + for err in line_errors { + errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + } + } + err => return Err(err), } } - Err(err) => return Err(err), } continue; } diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 8ee41f5de..a1d9b2f25 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -25,6 +25,10 @@ pub struct ValidationState<'a, 'py> { pub fields_set_count: Option, // True if `allow_partial=true` and we're validating the last element of a sequence or mapping. pub allow_partial: PartialMode, + // Whether at least one field had a validation error. This is used in the context of structured types + // (models, dataclasses, etc), where we need to know if a validation error occurred before calling + // a default factory that takes the validated data. + pub has_field_error: bool, // deliberately make Extra readonly extra: Extra<'a, 'py>, } @@ -36,6 +40,7 @@ impl<'a, 'py> ValidationState<'a, 'py> { exactness: None, fields_set_count: None, allow_partial, + has_field_error: false, extra, } } diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index 810008373..cb78fd34e 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -9,7 +9,7 @@ use pyo3::PyVisit; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; -use crate::errors::{LocItem, ValError, ValResult}; +use crate::errors::{ErrorTypeDefaults, LocItem, ValError, ValResult}; use crate::input::Input; use crate::py_gc::PyGcTraverse; use crate::tools::SchemaDict; @@ -180,6 +180,11 @@ impl Validator for WithDefaultValidator { outer_loc: Option>, state: &mut ValidationState<'_, 'py>, ) -> ValResult> { + if matches!(self.default, DefaultType::DefaultFactory(_, true)) && state.has_field_error { + // The default factory might use data from fields that failed to validate, and this results + // in an unhelpul error. + return Err(ValError::new(ErrorTypeDefaults::DefaultFactoryNotCalled, input)); + } match self.default.default_value(py, state.extra().data.as_ref())? { Some(stored_dft) => { let dft: Py = if self.copy_default { diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index a27b80268..6ec461942 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -818,3 +818,33 @@ def _raise(ex: Exception) -> None: v.validate_python(input_value) assert exc_info.value.errors(include_url=False, include_context=False) == expected + + +def test_default_factory_not_called_if_existing_error() -> None: + class Test: + def __init__(self, a: int, b: int): + self.a = a + self.b = b + + schema = core_schema.model_schema( + cls=Test, + schema=core_schema.model_fields_schema( + computed_fields=[], + fields={ + 'a': core_schema.model_field( + schema=core_schema.int_schema(), + ), + 'b': core_schema.model_field( + schema=core_schema.with_default_schema( + schema=core_schema.int_schema(), + default_factory=lambda data: data['a'], + default_factory_takes_data=True, + ), + ), + }, + ), + ) + + v = SchemaValidator(schema) + with pytest.raises(ValidationError): + v.validate_python({'a': 'not_an_int'})