Skip to content

Commit bed9c14

Browse files
davidhewittViicos
andauthored
fix issue with field_serializers on nested typed dicts (#1879)
Co-authored-by: Victorien <[email protected]>
1 parent 20d576b commit bed9c14

File tree

6 files changed

+115
-40
lines changed

6 files changed

+115
-40
lines changed

src/serializers/extra.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ impl<'a, 'py> SerializationState<'a, 'py> {
149149
self.include_exclude.1.as_ref()
150150
}
151151

152-
pub(crate) fn model_type_name(&self) -> Option<Bound<'py, PyString>> {
153-
self.model.as_ref().and_then(|model| model.get_type().name().ok())
154-
}
155-
156152
pub fn serialize_infer<'slf>(
157153
&'slf mut self,
158154
value: &'slf Bound<'py, PyAny>,

src/serializers/fields.rs

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ impl GeneralFieldsSerializer {
171171
pub(crate) fn main_to_python<'py>(
172172
&self,
173173
py: Python<'py>,
174+
model: &Bound<'py, PyAny>,
174175
main_iter: impl Iterator<Item = PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>>,
175176
state: &mut SerializationState<'_, 'py>,
176177
) -> PyResult<Bound<'py, PyDict>> {
@@ -218,7 +219,7 @@ impl GeneralFieldsSerializer {
218219
return Err(PydanticSerializationUnexpectedValue::new(
219220
Some(format!("Unexpected field `{key}`")),
220221
Some(key_str.to_string()),
221-
state.model_type_name().map(|bound| bound.to_string()),
222+
model_type_name(model),
222223
None,
223224
)
224225
.to_py_err());
@@ -244,8 +245,8 @@ impl GeneralFieldsSerializer {
244245
Err(PydanticSerializationUnexpectedValue::new(
245246
Some(format!("Expected {required_fields} fields but got {used_req_fields}").to_string()),
246247
state.field_name.as_ref().map(ToString::to_string),
247-
state.model_type_name().map(|bound| bound.to_string()),
248-
state.model.clone().map(Bound::unbind),
248+
model_type_name(model),
249+
Some(model.clone().unbind()),
249250
)
250251
.to_py_err())
251252
} else {
@@ -353,7 +354,6 @@ impl GeneralFieldsSerializer {
353354
state: &mut SerializationState<'_, 'py>,
354355
) -> PyResult<()> {
355356
if let Some(ref computed_fields) = self.computed_fields {
356-
let state = &mut state.scoped_set(|s| &mut s.model, Some(model.clone()));
357357
computed_fields.to_python(model, output_dict, &self.filter, state)?;
358358
}
359359
Ok(())
@@ -366,7 +366,6 @@ impl GeneralFieldsSerializer {
366366
state: &mut SerializationState<'_, 'py>,
367367
) -> Result<(), S::Error> {
368368
if let Some(ref computed_fields) = self.computed_fields {
369-
// FIXME: need to match state.model setting above in `add_computed_fields_python`??
370369
computed_fields.serde_serialize::<S>(model, map, &self.filter, state)?;
371370
}
372371
Ok(())
@@ -390,21 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
390389
) -> PyResult<Py<PyAny>> {
391390
let py = value.py();
392391
let missing_sentinel = get_missing_sentinel_object(py);
393-
// If there is already a model registered (from a dataclass, BaseModel)
394-
// then do not touch it
395-
// If there is no model, we (a TypedDict) are the model
396-
let model = state.model.clone().unwrap_or_else(|| value.clone());
392+
393+
let model = get_model(state)?;
397394

398395
let Some((main_dict, extra_dict)) = self.extract_dicts(value) else {
399396
state.warn_fallback_py(self.get_name(), value)?;
400397
return infer_to_python(value, state);
401398
};
402-
let output_dict = self.main_to_python(
403-
py,
404-
dict_items(&main_dict),
405-
// FIXME: should also set model for extra serialization?
406-
&mut state.scoped_set(|s| &mut s.model, Some(model.clone())),
407-
)?;
399+
let output_dict = self.main_to_python(py, &model, dict_items(&main_dict), state)?;
408400

409401
// this is used to include `__pydantic_extra__` in serialization on models
410402
if let Some(extra_dict) = extra_dict {
@@ -448,24 +440,15 @@ impl TypeSerializer for GeneralFieldsSerializer {
448440
return infer_serialize(value, serializer, state);
449441
};
450442
let missing_sentinel = get_missing_sentinel_object(value.py());
451-
// If there is already a model registered (from a dataclass, BaseModel)
452-
// then do not touch it
453-
// If there is no model, we (a TypedDict) are the model
454-
let model = state.model.clone().unwrap_or_else(|| value.clone());
443+
let model = get_model(state).map_err(py_err_se_err)?;
455444

456445
let expected_len = match self.mode {
457446
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
458447
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
459448
};
460449
// NOTE! As above, we maintain the order of the input dict assuming that's right
461450
// we don't both with `used_req_fields` here because on unions, `to_python(..., mode='json')` is used
462-
let mut map = self.main_serde_serialize(
463-
dict_items(&main_dict),
464-
expected_len,
465-
serializer,
466-
// FIXME: should also set model for extra serialization?
467-
&mut state.scoped_set(|s| &mut s.model, Some(model.clone())),
468-
)?;
451+
let mut map = self.main_serde_serialize(dict_items(&main_dict), expected_len, serializer, state)?;
469452

470453
// this is used to include `__pydantic_extra__` in serialization on models
471454
if let Some(extra_dict) = extra_dict {
@@ -507,3 +490,19 @@ fn dict_items<'py>(
507490
let main_items: SmallVec<[_; 16]> = main_dict.iter().collect();
508491
main_items.into_iter().map(Ok)
509492
}
493+
494+
fn get_model<'py>(state: &mut SerializationState<'_, 'py>) -> PyResult<Bound<'py, PyAny>> {
495+
state.model.clone().ok_or_else(|| {
496+
PydanticSerializationUnexpectedValue::new(
497+
Some("No model found for fields serialization".to_string()),
498+
None,
499+
None,
500+
None,
501+
)
502+
.to_py_err()
503+
})
504+
}
505+
506+
fn model_type_name(model: &Bound<'_, PyAny>) -> Option<String> {
507+
model.get_type().name().ok().map(|s| s.to_string())
508+
}

src/serializers/shared.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ combined_serializer! {
110110
super::type_serializers::function::FunctionPlainSerializerBuilder;
111111
super::type_serializers::function::FunctionWrapSerializerBuilder;
112112
super::type_serializers::model::ModelFieldsBuilder;
113-
super::type_serializers::typed_dict::TypedDictBuilder;
114113
}
115114
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
116115
// `find_serializer` so they can be used via a `type` str.
@@ -151,6 +150,7 @@ combined_serializer! {
151150
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
152151
Tuple: super::type_serializers::tuple::TupleSerializer;
153152
Complex: super::type_serializers::complex::ComplexSerializer;
153+
TypedDict: super::type_serializers::typed_dict::TypedDictSerializer;
154154
}
155155
}
156156

@@ -356,6 +356,7 @@ impl PyGcTraverse for CombinedSerializer {
356356
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
357357
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
358358
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
359+
CombinedSerializer::TypedDict(inner) => inner.py_gc_traverse(visit),
359360
}
360361
}
361362
}

src/serializers/type_serializers/dataclass.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,21 +150,21 @@ impl TypeSerializer for DataclassSerializer {
150150
value: &Bound<'py, PyAny>,
151151
state: &mut SerializationState<'_, 'py>,
152152
) -> PyResult<Py<PyAny>> {
153-
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
154153
if self.allow_value(value, state)? {
154+
let model = value;
155+
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
155156
let py = value.py();
156157
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
157158
let output_dict: Bound<PyDict> =
158-
fields_serializer.main_to_python(py, known_dataclass_iter(&self.fields, value), state)?;
159+
fields_serializer.main_to_python(py, model, known_dataclass_iter(&self.fields, model), state)?;
159160

160-
fields_serializer.add_computed_fields_python(value, &output_dict, state)?;
161+
fields_serializer.add_computed_fields_python(model, &output_dict, state)?;
161162
Ok(output_dict.into())
162163
} else {
163164
let inner_value = self.get_inner_value(value)?;
164165
self.serializer.to_python(&inner_value, state)
165166
}
166167
} else {
167-
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
168168
state.warn_fallback_py(self.get_name(), value)?;
169169
infer_to_python(value, state)
170170
}
@@ -189,8 +189,8 @@ impl TypeSerializer for DataclassSerializer {
189189
serializer: S,
190190
state: &mut SerializationState<'_, 'py>,
191191
) -> Result<S::Ok, S::Error> {
192-
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
193192
if self.allow_value(value, state).map_err(py_err_se_err)? {
193+
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
194194
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
195195
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
196196
let mut map = fields_serializer.main_serde_serialize(
@@ -206,7 +206,6 @@ impl TypeSerializer for DataclassSerializer {
206206
self.serializer.serde_serialize(&inner_value, serializer, state)
207207
}
208208
} else {
209-
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
210209
state.warn_fallback_ser::<S>(self.get_name(), value)?;
211210
infer_serialize(value, serializer, state)
212211
}

src/serializers/type_serializers/typed_dict.rs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::borrow::Cow;
12
use std::sync::Arc;
23

34
use pyo3::intern;
@@ -9,14 +10,20 @@ use ahash::AHashMap;
910
use crate::build_tools::py_schema_err;
1011
use crate::build_tools::{py_schema_error_type, schema_or_config, ExtraBehavior};
1112
use crate::definitions::DefinitionsBuilder;
13+
use crate::serializers::shared::TypeSerializer;
14+
use crate::serializers::SerializationState;
1215
use crate::tools::SchemaDict;
1316

1417
use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField};
1518

1619
#[derive(Debug)]
17-
pub struct TypedDictBuilder;
20+
pub struct TypedDictSerializer {
21+
serializer: GeneralFieldsSerializer,
22+
}
23+
24+
impl_py_gc_traverse!(TypedDictSerializer { serializer });
1825

19-
impl BuildSerializer for TypedDictBuilder {
26+
impl BuildSerializer for TypedDictSerializer {
2027
const EXPECTED_TYPE: &'static str = "typed-dict";
2128

2229
fn build(
@@ -82,10 +89,51 @@ impl BuildSerializer for TypedDictBuilder {
8289
}
8390
}
8491

92+
// FIXME: computed fields do not work for TypedDict, and may never
93+
// see the closed https://github.com/pydantic/pydantic-core/pull/1018
8594
let computed_fields = ComputedFields::new(schema, config, definitions)?;
8695

8796
Ok(Arc::new(
88-
GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into(),
97+
Self {
98+
serializer: GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields),
99+
}
100+
.into(),
89101
))
90102
}
91103
}
104+
105+
impl TypeSerializer for TypedDictSerializer {
106+
fn to_python<'py>(
107+
&self,
108+
value: &Bound<'py, PyAny>,
109+
state: &mut SerializationState<'_, 'py>,
110+
) -> PyResult<Py<PyAny>> {
111+
self.serializer
112+
.to_python(value, &mut state.scoped_set(|s| &mut s.model, Some(value.clone())))
113+
}
114+
115+
fn json_key<'a, 'py>(
116+
&self,
117+
key: &'a Bound<'py, PyAny>,
118+
state: &mut SerializationState<'_, 'py>,
119+
) -> PyResult<Cow<'a, str>> {
120+
self.invalid_as_json_key(key, state, "typed-dict")
121+
}
122+
123+
fn serde_serialize<'py, S: serde::ser::Serializer>(
124+
&self,
125+
value: &Bound<'py, PyAny>,
126+
serializer: S,
127+
state: &mut SerializationState<'_, 'py>,
128+
) -> Result<S::Ok, S::Error> {
129+
self.serializer.serde_serialize(
130+
value,
131+
serializer,
132+
&mut state.scoped_set(|s| &mut s.model, Some(value.clone())),
133+
)
134+
}
135+
136+
fn get_name(&self) -> &'static str {
137+
"typed-dict"
138+
}
139+
}

tests/serializers/test_typed_dict.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,35 @@ class Model(TypedDict):
376376
)
377377
s = SchemaSerializer(schema, config=core_schema.CoreConfig(serialize_by_alias=config or False))
378378
assert s.to_python(Model(my_field=1), by_alias=runtime) == expected
379+
380+
381+
def test_nested_typed_dict_field_serializers():
382+
class Model(TypedDict):
383+
x: Any
384+
385+
class OuterModel(TypedDict):
386+
model: Model
387+
388+
schema = core_schema.typed_dict_schema(
389+
{
390+
'x': core_schema.typed_dict_field(
391+
core_schema.any_schema(
392+
serialization=core_schema.wrap_serializer_function_ser_schema(
393+
# in an incorrect core implementation, self could be OuterModel here
394+
lambda self, v, serializer: f'{list(self.keys())}',
395+
is_field_serializer=True,
396+
schema=core_schema.any_schema(),
397+
)
398+
)
399+
)
400+
}
401+
)
402+
outer_schema = core_schema.typed_dict_schema({'model': core_schema.typed_dict_field(schema)})
403+
404+
s = SchemaSerializer(schema)
405+
assert s.to_python(Model(x=None)) == {'x': "['x']"}
406+
407+
outer_s = SchemaSerializer(outer_schema)
408+
# if the inner field serializer incorrectly receives OuterModel as self, the keys
409+
# will be ['model'] instead of ['x']
410+
assert outer_s.to_python(OuterModel(model=Model(x=None))) == {'model': {'x': "['x']"}}

0 commit comments

Comments
 (0)