Skip to content

Commit 3359241

Browse files
authored
move check from Extra to SerializationState (#1862)
1 parent 423264f commit 3359241

25 files changed

+124
-133
lines changed

src/serializers/extra.rs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ pub(crate) struct SerializationState<'py> {
3030
pub model: Option<Bound<'py, PyAny>>,
3131
/// The name of the field currently being serialized, if any
3232
pub field_name: Option<FieldName<'py>>,
33+
/// Inside unions, checks are applied to attempt to select a preferred branch
34+
pub check: SerCheck,
3335
pub include_exclude: (Option<Bound<'py, PyAny>>, Option<Bound<'py, PyAny>>),
3436
}
3537

@@ -69,6 +71,7 @@ impl<'py> SerializationState<'py> {
6971
config,
7072
model: None,
7173
field_name: None,
74+
check: SerCheck::None,
7275
include_exclude: (include, exclude),
7376
})
7477
}
@@ -86,19 +89,18 @@ impl SerializationState<'_> {
8689
})
8790
}
8891

89-
pub fn warn_fallback_py(&mut self, field_type: &str, value: &Bound<'_, PyAny>, extra: &Extra) -> PyResult<()> {
92+
pub fn warn_fallback_py(&mut self, field_type: &str, value: &Bound<'_, PyAny>) -> PyResult<()> {
9093
self.warnings
91-
.on_fallback_py(field_type, value, self.field_name.as_ref(), extra)
94+
.on_fallback_py(field_type, value, self.field_name.as_ref(), self.check)
9295
}
9396

94-
pub fn warn_fallback_ser<'py, S: serde::ser::Serializer>(
97+
pub fn warn_fallback_ser<S: serde::ser::Serializer>(
9598
&mut self,
9699
field_type: &str,
97-
value: &Bound<'py, PyAny>,
98-
extra: &Extra<'_, 'py>,
100+
value: &Bound<'_, PyAny>,
99101
) -> Result<(), S::Error> {
100102
self.warnings
101-
.on_fallback_ser::<S>(field_type, value, self.field_name.as_ref(), extra)
103+
.on_fallback_ser::<S>(field_type, value, self.field_name.as_ref(), self.check)
102104
}
103105

104106
pub fn final_check(&self, py: Python) -> PyResult<()> {
@@ -160,8 +162,6 @@ pub(crate) struct Extra<'a, 'py> {
160162
pub exclude_none: bool,
161163
pub exclude_computed_fields: bool,
162164
pub round_trip: bool,
163-
// the next two are used for union logic
164-
pub check: SerCheck,
165165
pub serialize_unknown: bool,
166166
pub fallback: Option<&'a Bound<'py, PyAny>>,
167167
pub serialize_as_any: bool,
@@ -193,7 +193,6 @@ impl<'a, 'py> Extra<'a, 'py> {
193193
exclude_none,
194194
exclude_computed_fields,
195195
round_trip,
196-
check: SerCheck::None,
197196
serialize_unknown,
198197
fallback,
199198
serialize_as_any,
@@ -283,7 +282,7 @@ impl ExtraOwned {
283282
round_trip: extra.round_trip,
284283
config: state.config,
285284
rec_guard: state.rec_guard.clone(),
286-
check: extra.check,
285+
check: state.check,
287286
model: state.model.as_ref().map(|model| model.clone().into()),
288287
field_name: state.field_name.as_ref().map(|name| match name {
289288
FieldName::Root => FieldNameOwned::Root,
@@ -308,7 +307,6 @@ impl ExtraOwned {
308307
exclude_none: self.exclude_none,
309308
exclude_computed_fields: self.exclude_computed_fields,
310309
round_trip: self.round_trip,
311-
check: self.check,
312310
serialize_unknown: self.serialize_unknown,
313311
fallback: self.fallback.as_ref().map(|m| m.bind(py)),
314312
serialize_as_any: self.serialize_as_any,
@@ -321,16 +319,17 @@ impl ExtraOwned {
321319
warnings: self.warnings.clone(),
322320
rec_guard: self.rec_guard.clone(),
323321
config: self.config,
322+
model: self.model.as_ref().map(|m| m.bind(py).clone()),
324323
field_name: match &self.field_name {
325324
Some(FieldNameOwned::Root) => Some(FieldName::Root),
326325
Some(FieldNameOwned::Regular(b)) => Some(FieldName::Regular(b.bind(py).clone())),
327326
None => None,
328327
},
328+
check: self.check,
329329
include_exclude: (
330330
self.include.as_ref().map(|m| m.bind(py).clone()),
331331
self.exclude.as_ref().map(|m| m.bind(py).clone()),
332332
),
333-
model: self.model.as_ref().map(|m| m.bind(py).clone()),
334333
}
335334
}
336335
}
@@ -443,17 +442,17 @@ impl CollectWarnings {
443442
}
444443
}
445444

446-
pub fn on_fallback_py<'py>(
445+
fn on_fallback_py(
447446
&mut self,
448447
field_type: &str,
449-
value: &Bound<'py, PyAny>,
448+
value: &Bound<'_, PyAny>,
450449
field_name: Option<&FieldName<'_>>,
451-
extra: &Extra<'_, 'py>,
450+
check: SerCheck,
452451
) -> PyResult<()> {
453452
// special case for None as it's very common e.g. as a default value
454453
if value.is_none() {
455454
Ok(())
456-
} else if extra.check.enabled() {
455+
} else if check.enabled() {
457456
Err(PydanticSerializationUnexpectedValue::new_from_parts(
458457
field_name.map(ToString::to_string),
459458
Some(field_type.to_string()),
@@ -466,17 +465,17 @@ impl CollectWarnings {
466465
}
467466
}
468467

469-
pub fn on_fallback_ser<'py, S: serde::ser::Serializer>(
468+
pub fn on_fallback_ser<S: serde::ser::Serializer>(
470469
&mut self,
471470
field_type: &str,
472-
value: &Bound<'py, PyAny>,
471+
value: &Bound<'_, PyAny>,
473472
field_name: Option<&FieldName<'_>>,
474-
extra: &Extra<'_, 'py>,
473+
check: SerCheck,
475474
) -> Result<(), S::Error> {
476475
// special case for None as it's very common e.g. as a default value
477476
if value.is_none() {
478477
Ok(())
479-
} else if extra.check.enabled() {
478+
} else if check.enabled() {
480479
// note: I think this should never actually happen since we use `to_python(..., mode='json')` during
481480
// JSON serialization to "try" union branches, but it's here for completeness/correctness
482481
// in particular, in future we could allow errors instead of warnings on fallback

src/serializers/fields.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ impl GeneralFieldsSerializer {
215215
.filter(|_| !extra.serialize_as_any)
216216
.unwrap_or_else(|| AnySerializer::get());
217217
(&key, serializer)
218-
} else if extra.check == SerCheck::Strict {
218+
} else if state.check == SerCheck::Strict {
219219
return Err(PydanticSerializationUnexpectedValue::new(
220220
Some(format!("Unexpected field `{key}`")),
221221
Some(key_str.to_string()),
@@ -233,7 +233,7 @@ impl GeneralFieldsSerializer {
233233
}
234234
}
235235

236-
if extra.check.enabled()
236+
if state.check.enabled()
237237
// If any of these are true we can't count fields
238238
&& !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none || extra.exclude_computed_fields || state.exclude().is_some())
239239
// Check for missing fields, we can't have extra fields here
@@ -399,7 +399,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
399399
let model = state.model.clone().unwrap_or_else(|| value.clone());
400400

401401
let Some((main_dict, extra_dict)) = self.extract_dicts(value) else {
402-
state.warn_fallback_py(self.get_name(), value, extra)?;
402+
state.warn_fallback_py(self.get_name(), value)?;
403403
return infer_to_python(value, state, extra);
404404
};
405405
let output_dict = self.main_to_python(
@@ -450,7 +450,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
450450
extra: &Extra<'_, 'py>,
451451
) -> Result<S::Ok, S::Error> {
452452
let Some((main_dict, extra_dict)) = self.extract_dicts(value) else {
453-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
453+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
454454
return infer_serialize(value, serializer, state, extra);
455455
};
456456
let missing_sentinel = get_missing_sentinel_object(value.py());

src/serializers/infer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ pub(crate) fn call_pydantic_serializer<'py, T, E: From<PyErr>>(
648648
model: state.model.clone(),
649649
field_name: state.field_name.clone(),
650650
include_exclude: state.include_exclude.clone(),
651+
check: state.check,
651652
};
652653

653654
// Avoid falling immediately back into inference because we need to use the serializer

src/serializers/shared.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Debug {
390390
match extra.ob_type_lookup.is_type(key, ObType::None) {
391391
IsType::Exact | IsType::Subclass => py_err!(PyTypeError; "`{}` not valid as object key", expected_type),
392392
IsType::False => {
393-
state.warn_fallback_py(self.get_name(), key, extra)?;
393+
state.warn_fallback_py(self.get_name(), key)?;
394394
infer_json_key(key, state, extra)
395395
}
396396
}
@@ -680,7 +680,7 @@ impl<'py> DoSerialize<'py, Py<PyAny>, PyErr> for SerializeToPython {
680680
state: &mut SerializationState<'py>,
681681
extra: &Extra<'_, 'py>,
682682
) -> PyResult<Py<PyAny>> {
683-
state.warn_fallback_py(name, value, extra)?;
683+
state.warn_fallback_py(name, value)?;
684684
infer_to_python(value, state, extra)
685685
}
686686
}
@@ -709,9 +709,7 @@ impl<'py, S: Serializer> DoSerialize<'py, S::Ok, WrappedSerError<S::Error>> for
709709
state: &mut SerializationState<'py>,
710710
extra: &Extra<'_, 'py>,
711711
) -> Result<S::Ok, WrappedSerError<S::Error>> {
712-
state
713-
.warn_fallback_ser::<S>(name, value, extra)
714-
.map_err(WrappedSerError)?;
712+
state.warn_fallback_ser::<S>(name, value).map_err(WrappedSerError)?;
715713
infer_serialize(value, self.serializer, state, extra).map_err(WrappedSerError)
716714
}
717715
}

src/serializers/type_serializers/bytes.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl TypeSerializer for BytesSerializer {
8282
_ => Ok(value.clone().unbind()),
8383
},
8484
Err(_) => {
85-
state.warn_fallback_py(self.get_name(), value, extra)?;
85+
state.warn_fallback_py(self.get_name(), value)?;
8686
infer_to_python(value, state, extra)
8787
}
8888
}
@@ -97,7 +97,7 @@ impl TypeSerializer for BytesSerializer {
9797
match key.downcast::<PyBytes>() {
9898
Ok(py_bytes) => self.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
9999
Err(_) => {
100-
state.warn_fallback_py(self.get_name(), key, extra)?;
100+
state.warn_fallback_py(self.get_name(), key)?;
101101
infer_json_key(key, state, extra)
102102
}
103103
}
@@ -113,7 +113,7 @@ impl TypeSerializer for BytesSerializer {
113113
match value.downcast::<PyBytes>() {
114114
Ok(py_bytes) => self.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
115115
Err(_) => {
116-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
116+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
117117
infer_serialize(value, serializer, state, extra)
118118
}
119119
}

src/serializers/type_serializers/complex.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ impl TypeSerializer for ComplexSerializer {
4242
_ => Ok(value.clone().unbind()),
4343
},
4444
Err(_) => {
45-
state.warn_fallback_py(self.get_name(), value, extra)?;
45+
state.warn_fallback_py(self.get_name(), value)?;
4646
infer_to_python(value, state, extra)
4747
}
4848
}
@@ -70,7 +70,7 @@ impl TypeSerializer for ComplexSerializer {
7070
Ok(serializer.collect_str::<String>(&s)?)
7171
}
7272
Err(_) => {
73-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
73+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
7474
infer_serialize(value, serializer, state, extra)
7575
}
7676
}

src/serializers/type_serializers/dataclass.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ impl BuildSerializer for DataclassSerializer {
122122
}
123123

124124
impl DataclassSerializer {
125-
fn allow_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> PyResult<bool> {
126-
match extra.check {
125+
fn allow_value(&self, value: &Bound<'_, PyAny>, state: &SerializationState<'_>) -> PyResult<bool> {
126+
match state.check {
127127
SerCheck::Strict => Ok(value.get_type().is(self.class.bind(value.py()))),
128128
SerCheck::Lax => value.is_instance(self.class.bind(value.py())),
129129
SerCheck::None => value.hasattr(intern!(value.py(), "__dataclass_fields__")),
@@ -152,7 +152,7 @@ impl TypeSerializer for DataclassSerializer {
152152
extra: &Extra<'_, 'py>,
153153
) -> PyResult<Py<PyAny>> {
154154
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
155-
if self.allow_value(value, extra)? {
155+
if self.allow_value(value, state)? {
156156
let py = value.py();
157157
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
158158
let output_dict: Bound<PyDict> =
@@ -166,7 +166,7 @@ impl TypeSerializer for DataclassSerializer {
166166
}
167167
} else {
168168
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
169-
state.warn_fallback_py(self.get_name(), value, extra)?;
169+
state.warn_fallback_py(self.get_name(), value)?;
170170
infer_to_python(value, state, extra)
171171
}
172172
}
@@ -177,10 +177,10 @@ impl TypeSerializer for DataclassSerializer {
177177
state: &mut SerializationState<'py>,
178178
extra: &Extra<'_, 'py>,
179179
) -> PyResult<Cow<'a, str>> {
180-
if self.allow_value(key, extra)? {
180+
if self.allow_value(key, state)? {
181181
infer_json_key_known(ObType::Dataclass, key, state, extra)
182182
} else {
183-
state.warn_fallback_py(&self.name, key, extra)?;
183+
state.warn_fallback_py(&self.name, key)?;
184184
infer_json_key(key, state, extra)
185185
}
186186
}
@@ -193,7 +193,7 @@ impl TypeSerializer for DataclassSerializer {
193193
extra: &Extra<'_, 'py>,
194194
) -> Result<S::Ok, S::Error> {
195195
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
196-
if self.allow_value(value, extra).map_err(py_err_se_err)? {
196+
if self.allow_value(value, state).map_err(py_err_se_err)? {
197197
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
198198
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
199199
let mut map = fields_serializer.main_serde_serialize(
@@ -211,7 +211,7 @@ impl TypeSerializer for DataclassSerializer {
211211
}
212212
} else {
213213
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
214-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
214+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
215215
infer_serialize(value, serializer, state, extra)
216216
}
217217
}

src/serializers/type_serializers/datetime_etc.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ macro_rules! build_temporal_serializer {
125125
_ => Ok(value.clone().unbind()),
126126
},
127127
_ => {
128-
state.warn_fallback_py(self.get_name(), value, extra)?;
128+
state.warn_fallback_py(self.get_name(), value)?;
129129
infer_to_python(value, state, extra)
130130
}
131131
}
@@ -140,7 +140,7 @@ macro_rules! build_temporal_serializer {
140140
match $downcast(key) {
141141
Ok(py_value) => Ok(self.temporal_mode.$json_key_fn(py_value)?),
142142
Err(_) => {
143-
state.warn_fallback_py(self.get_name(), key, extra)?;
143+
state.warn_fallback_py(self.get_name(), key)?;
144144
infer_json_key(key, state, extra)
145145
}
146146
}
@@ -156,7 +156,7 @@ macro_rules! build_temporal_serializer {
156156
match $downcast(value) {
157157
Ok(py_value) => self.temporal_mode.$serialize_fn(py_value, serializer),
158158
Err(_) => {
159-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
159+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
160160
infer_serialize(value, serializer, state, extra)
161161
}
162162
}

src/serializers/type_serializers/decimal.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ impl TypeSerializer for DecimalSerializer {
4444
match extra.ob_type_lookup.is_type(value, ObType::Decimal) {
4545
IsType::Exact | IsType::Subclass => infer_to_python_known(ObType::Decimal, value, state, extra),
4646
IsType::False => {
47-
state.warn_fallback_py(self.get_name(), value, extra)?;
47+
state.warn_fallback_py(self.get_name(), value)?;
4848
infer_to_python(value, state, extra)
4949
}
5050
}
@@ -59,7 +59,7 @@ impl TypeSerializer for DecimalSerializer {
5959
match extra.ob_type_lookup.is_type(key, ObType::Decimal) {
6060
IsType::Exact | IsType::Subclass => infer_json_key_known(ObType::Decimal, key, state, extra),
6161
IsType::False => {
62-
state.warn_fallback_py(self.get_name(), key, extra)?;
62+
state.warn_fallback_py(self.get_name(), key)?;
6363
infer_json_key(key, state, extra)
6464
}
6565
}
@@ -75,7 +75,7 @@ impl TypeSerializer for DecimalSerializer {
7575
match extra.ob_type_lookup.is_type(value, ObType::Decimal) {
7676
IsType::Exact | IsType::Subclass => infer_serialize_known(ObType::Decimal, value, serializer, state, extra),
7777
IsType::False => {
78-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
78+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
7979
infer_serialize(value, serializer, state, extra)
8080
}
8181
}

src/serializers/type_serializers/dict.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl TypeSerializer for DictSerializer {
105105
Ok(new_dict.into())
106106
}
107107
Err(_) => {
108-
state.warn_fallback_py(self.get_name(), value, extra)?;
108+
state.warn_fallback_py(self.get_name(), value)?;
109109
infer_to_python(value, state, extra)
110110
}
111111
}
@@ -145,7 +145,7 @@ impl TypeSerializer for DictSerializer {
145145
map.end()
146146
}
147147
Err(_) => {
148-
state.warn_fallback_ser::<S>(self.get_name(), value, extra)?;
148+
state.warn_fallback_ser::<S>(self.get_name(), value)?;
149149
infer_serialize(value, serializer, state, extra)
150150
}
151151
}

0 commit comments

Comments
 (0)