diff --git a/csp/impl/struct.py b/csp/impl/struct.py index fa5389157..38f061d53 100644 --- a/csp/impl/struct.py +++ b/csp/impl/struct.py @@ -151,6 +151,18 @@ def serializer(val, handler): class Struct(_csptypesimpl.PyStruct, metaclass=StructMeta): + @classmethod + def type_adapter(cls): + internal_type_adapter = getattr(cls, "_pydantic_type_adapter", None) + if internal_type_adapter: + return internal_type_adapter + + # Late import to avoid autogen issues + from pydantic import TypeAdapter + + cls._pydantic_type_adapter = TypeAdapter(cls) + return cls._pydantic_type_adapter + @classmethod def metadata(cls, typed=False): if typed: @@ -235,7 +247,9 @@ def _obj_from_python(cls, json, obj_type): return obj_type(json) @classmethod - def from_dict(cls, json: dict): + def from_dict(cls, json: dict, use_pydantic: bool = False): + if use_pydantic: + return cls.type_adapter().validate_python(json) return cls._obj_from_python(json, cls) def to_dict_depr(self): diff --git a/csp/impl/types/typing_utils.py b/csp/impl/types/typing_utils.py index 95dd76770..f852168d2 100644 --- a/csp/impl/types/typing_utils.py +++ b/csp/impl/types/typing_utils.py @@ -15,6 +15,33 @@ class FastList(typing.List, typing.Generic[T]): # Need to inherit from Generic[ def __init__(self): raise NotImplementedError("Can not init FastList class") + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core import core_schema + + # Late import to not interfere with autogen + args = typing.get_args(source_type) + if args: + inner_type = args[0] + list_schema = handler.generate_schema(typing.List[inner_type]) + else: + list_schema = handler.generate_schema(typing.List) + + def create_instance(raw_data, validator): + if isinstance(raw_data, FastList): + return raw_data + return validator(raw_data) # just return a list + + return core_schema.no_info_wrap_validator_function( + function=create_instance, + schema=list_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + lambda val: list(v for v in val), + return_schema=list_schema, + when_used="json", + ), + ) + class CspTypingUtils39: _ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple} diff --git a/csp/tests/impl/test_struct.py b/csp/tests/impl/test_struct.py index 4255d9e06..a6fdb908a 100644 --- a/csp/tests/impl/test_struct.py +++ b/csp/tests/impl/test_struct.py @@ -2,7 +2,6 @@ import json import pickle import sys -import typing import unittest from datetime import date, datetime, time, timedelta from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union @@ -35,6 +34,7 @@ class StructNoDefaults(csp.Struct): a2: List[str] a3: FastList[object] a4: List[bytes] + a5: Numpy1DArray[float] class StructWithDefaults(csp.Struct): @@ -797,6 +797,8 @@ class MyStruct(csp.Struct): def test_from_dict_with_enum(self): struct = StructWithDefaults.from_dict({"e": MyEnum.A}) self.assertEqual(MyEnum.A, getattr(struct, "e")) + struct = StructWithDefaults.from_dict({"e": MyEnum.A}, use_pydantic=True) + self.assertEqual(MyEnum.A, getattr(struct, "e")) def test_from_dict_with_list_derived_type(self): class ListDerivedType(list): @@ -810,32 +812,62 @@ class StructWithListDerivedType(csp.Struct): self.assertTrue(isinstance(s1.to_dict()["ldt"], ListDerivedType)) s2 = StructWithListDerivedType.from_dict(s1.to_dict()) self.assertEqual(s1, s2) + s3 = StructWithListDerivedType.from_dict(s1.to_dict(), use_pydantic=True) + self.assertEqual(s1, s3) def test_from_dict_loop_no_defaults(self): looped = StructNoDefaults.from_dict(StructNoDefaults(a1=[9, 10]).to_dict()) self.assertEqual(looped, StructNoDefaults(a1=[9, 10])) + looped = StructNoDefaults.from_dict(StructNoDefaults(a1=[9, 10]).to_dict(), use_pydantic=True) + self.assertEqual(looped, StructNoDefaults(a1=[9, 10])) def test_from_dict_loop_with_defaults(self): - looped = StructWithDefaults.from_dict(StructWithDefaults().to_dict()) - # Note that we cant compare numpy arrays, so we check them independently - comp = StructWithDefaults() - self.assertTrue(np.array_equal(looped.np_arr, comp.np_arr)) + for use_pydantic in [True, False]: + looped = StructWithDefaults.from_dict(StructWithDefaults().to_dict(), use_pydantic=use_pydantic) + # Note that we cant compare numpy arrays, so we check them independently + comp = StructWithDefaults() + self.assertTrue(np.array_equal(looped.np_arr, comp.np_arr)) + + del looped.np_arr + del comp.np_arr + self.assertEqual(looped, comp) + + def test_to_json_loop_with_no_defaults(self): + for use_pydantic in [True, False]: + base_struct = StructNoDefaults( + a1=[9, 10], + a5=np.array([1, 2, 3]), + bt=b"ab\001\000c", + ) + if not use_pydantic: + # Need the callback to handle the numpy array type + looped = StructNoDefaults.type_adapter().validate_json(base_struct.to_json(lambda x: x.tolist())) + else: + looped = StructNoDefaults.type_adapter().validate_json( + StructNoDefaults.type_adapter().dump_json(base_struct) + ) + # Note that we cant compare numpy arrays, so we check them independently + self.assertTrue(np.array_equal(looped.a5, base_struct.a5)) - del looped.np_arr - del comp.np_arr - self.assertEqual(looped, comp) + del looped.a5 + del base_struct.a5 + self.assertEqual(looped, base_struct) + self.assertFalse(isinstance(looped.a1, list)) def test_from_dict_loop_with_generic_typing(self): class MyStruct(csp.Struct): foo: Set[int] - bar: Tuple[str] + bar: Tuple[str, ...] np_arr: csp.typing.NumpyNDArray[float] - looped = MyStruct.from_dict(MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])).to_dict()) - expected = MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])) - self.assertEqual(looped.foo, expected.foo) - self.assertEqual(looped.bar, expected.bar) - self.assertTrue(np.all(looped.np_arr == expected.np_arr)) + for use_pydantic in [True, False]: + looped = MyStruct.from_dict( + MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])).to_dict(), use_pydantic=use_pydantic + ) + expected = MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])) + self.assertEqual(looped.foo, expected.foo) + self.assertEqual(looped.bar, expected.bar) + self.assertTrue(np.all(looped.np_arr == expected.np_arr)) def test_struct_yaml_serialization(self): class S1(csp.Struct): @@ -3011,7 +3043,7 @@ class SimpleStruct(csp.Struct): # Valid data valid_data = {"value": 11, "name": "ya", "scores": [1.1, 2.2, 3.3]} - result = TypeAdapter(SimpleStruct).validate_python(valid_data) + result = SimpleStruct.from_dict(valid_data, use_pydantic=True) self.assertIsInstance(result, SimpleStruct) self.assertEqual(result.value, 11) self.assertEqual(result.name, "ya") @@ -3020,11 +3052,11 @@ class SimpleStruct(csp.Struct): invalid_data = valid_data.copy() invalid_data["missing"] = False with self.assertRaises(ValidationError): - TypeAdapter(SimpleStruct).validate_python(invalid_data) # extra fields throw an error + SimpleStruct.from_dict(invalid_data, use_pydantic=True) # extra fields throw an error # Test that we can validate existing structs existing = SimpleStruct(value=1, scores=[1]) - new = TypeAdapter(SimpleStruct).validate_python(existing) + new = SimpleStruct.from_dict(existing, use_pydantic=True) self.assertTrue(existing is new) # we do not revalidate self.assertEqual(existing.value, 1) @@ -3033,7 +3065,7 @@ class SimpleStruct(csp.Struct): "value": "42", # string should convert to int "scores": ["1.1", 2, "3.3"], # mixed types should convert to float } - result = TypeAdapter(SimpleStruct).validate_python(coercion_data) + result = SimpleStruct.from_dict(coercion_data, use_pydantic=True) self.assertEqual(result.value, 42) self.assertEqual(result.scores, [1.1, 2.0, 3.3]) @@ -3043,7 +3075,7 @@ class NestedStruct(csp.Struct): tags: List[str] nested_data = {"simple": {"value": 11, "name": "ya", "scores": [1.1, 2.2, 3.3]}, "tags": ["test1", "test2"]} - result = TypeAdapter(NestedStruct).validate_python(nested_data) + result = NestedStruct.from_dict(nested_data, use_pydantic=True) self.assertIsInstance(result, NestedStruct) self.assertIsInstance(result.simple, SimpleStruct) self.assertEqual(result.simple.value, 11) @@ -3051,7 +3083,7 @@ class NestedStruct(csp.Struct): # 3. Test validation errors with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(SimpleStruct).validate_python({"value": "not an integer", "scores": [1.1, 2.2, "invalid"]}) + SimpleStruct.from_dict({"value": "not an integer", "scores": [1.1, 2.2, "invalid"]}, use_pydantic=True) self.assertIn("Input should be a valid integer", str(exc_info.exception)) # 4. Test with complex types @@ -3064,7 +3096,7 @@ class ComplexStruct(csp.Struct): "dates": ["2023-01-01", "2023-01-02"], # strings should convert to datetime "mapping": {"a": "1.1", "b": 2.2}, # mixed types should convert to float } - result = TypeAdapter(ComplexStruct).validate_python(complex_data) + result = ComplexStruct.from_dict(complex_data, use_pydantic=True) self.assertIsInstance(result.dates[0], datetime) self.assertEqual(result.mapping, {"a": 1.1, "b": 2.2}) @@ -3078,7 +3110,7 @@ class EnumStruct(csp.Struct): enum_list: List[MyEnum] enum_data = {"enum_field": "A", "enum_list": ["A", "B", "A"]} - result = TypeAdapter(EnumStruct).validate_python(enum_data) + result = EnumStruct.from_dict(enum_data, use_pydantic=True) self.assertEqual(result.enum_field, MyEnum.A) self.assertEqual(result.enum_list, [MyEnum.A, MyEnum.B, MyEnum.A]) @@ -3096,7 +3128,7 @@ class StructWithDummy(csp.Struct): val = DummyBlankClass() struct_as_dict = dict(x=12, y=val, z=[val], z1={val: val}, z2=None) - new_struct = TypeAdapter(StructWithDummy).validate_python(struct_as_dict) + new_struct = StructWithDummy.from_dict(struct_as_dict, use_pydantic=True) self.assertTrue(new_struct.y is val) self.assertTrue(new_struct.z[0] is val) self.assertTrue(new_struct.z1[val] is val) @@ -3114,7 +3146,7 @@ class StructWithDummy(csp.Struct): z3=z3_val, z4=z3_val, ) - new_struct = TypeAdapter(StructWithDummy).validate_python(struct_as_dict) + new_struct = StructWithDummy.from_dict(struct_as_dict, use_pydantic=True) self.assertTrue(new_struct.y is val) self.assertTrue(new_struct.z[0] is val) self.assertTrue(new_struct.z1[val] is val) @@ -3207,7 +3239,7 @@ class ProjectStruct(csp.Struct): } # 1. Test validation - result = TypeAdapter(ProjectStruct).validate_python(project_data) + result = ProjectStruct.from_dict(project_data, use_pydantic=True) # Verify the structure self.assertIsInstance(result, ProjectStruct) @@ -3250,14 +3282,14 @@ class ProjectStruct(csp.Struct): invalid_data["task_statuses"][99] = [] # Invalid enum value with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(ProjectStruct).validate_python(invalid_data) + ProjectStruct.from_dict(invalid_data, use_pydantic=True) # 4. Test validation errors with invalid nested types invalid_task_data = project_data.copy() invalid_task_data["task_statuses"][1][0]["metadata"]["priority"] = 99 # Invalid priority with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(ProjectStruct).validate_python(invalid_task_data) + ProjectStruct.from_dict(invalid_task_data, use_pydantic=True) def test_pydantic_models_with_csp_structs(self): """Test Pydantic BaseModels containing CSP Structs as attributes""" @@ -3406,13 +3438,13 @@ class OuterStruct(csp.Struct): inner: Annotated[InnerStruct, WrapValidator(struct_validator)] # Test simple value validation - inner = TypeAdapter(InnerStruct).validate_python({"value": "21"}) + inner = InnerStruct.from_dict({"value": "21"}, use_pydantic=True) self.assertEqual(inner.value, 42) # "21" -> 21 -> 42 self.assertEqual(inner.description, "default") self.assertFalse(hasattr(inner, "z")) # test existing instance - inner_new = TypeAdapter(InnerStruct).validate_python(inner) + inner_new = InnerStruct.from_dict(inner, use_pydantic=True) self.assertTrue(inner is inner_new) # No revalidation self.assertEqual(inner_new.value, 42) @@ -3420,26 +3452,26 @@ class OuterStruct(csp.Struct): # Test validation with invalid value in existing instance inner.value = -5 # Set invalid value # No revalidation, no error - self.assertTrue(inner is TypeAdapter(InnerStruct).validate_python(inner)) + self.assertTrue(inner is InnerStruct.from_dict(inner, use_pydantic=True)) with self.assertRaises(ValidationError) as cm: - TypeAdapter(InnerStruct).validate_python(inner.to_dict()) + InnerStruct.from_dict(inner.to_dict(), use_pydantic=True) self.assertIn("value must be positive", str(cm.exception)) # Test simple value validation - inner = TypeAdapter(InnerStruct).validate_python({"value": "21", "z": 17}) + inner = InnerStruct.from_dict({"value": "21", "z": 17}, use_pydantic=True) self.assertEqual(inner.value, 42) # "21" -> 21 -> 42 self.assertEqual(inner.description, "default") self.assertEqual(inner.z, 17) # Test struct validation with expansion - outer = TypeAdapter(OuterStruct).validate_python({"name": "test", "inner": {"value": 10, "z": 12}}) + outer = OuterStruct.from_dict({"name": "test", "inner": {"value": 10, "z": 12}}, use_pydantic=True) self.assertEqual(outer.inner.value, 20) # 10 -> 20 (doubled) self.assertEqual(outer.inner.description, "auto_generated") self.assertEqual(outer.inner.z, 12) # Test normal full structure still works - outer = TypeAdapter(OuterStruct).validate_python( - {"name": "test", "inner": {"value": "5", "description": "custom"}} + outer = OuterStruct.from_dict( + {"name": "test", "inner": {"value": "5", "description": "custom"}}, use_pydantic=True ) self.assertEqual(outer.inner.value, 10) # "5" -> 5 -> 10 (doubled) self.assertEqual(outer.inner.description, "custom") @@ -3458,50 +3490,55 @@ class MetricStruct(csp.Struct): tags: Union[str, List[str]] = "default" # Test with different value types - metric1 = TypeAdapter(MetricStruct).validate_python( + metric1 = MetricStruct.from_dict( { "value": 42, # int - } + }, + use_pydantic=True, ) self.assertEqual(metric1.value, 42) self.assertIsNone(metric1.name) self.assertEqual(metric1.tags, "default") - metric2 = TypeAdapter(MetricStruct).validate_python( + metric2 = MetricStruct.from_dict( { "value": 42.5, # float "name": "test", "tags": ["tag1", "tag2"], - } + }, + use_pydantic=True, ) self.assertEqual(metric2.value, 42.5) self.assertEqual(metric2.name, "test") self.assertEqual(metric2.tags, ["tag1", "tag2"]) # Test with string that should convert to float - metric3 = TypeAdapter(MetricStruct).validate_python( + metric3 = MetricStruct.from_dict( { "value": "42.5", # should convert to float "tags": "single_tag", # single string tag - } + }, + use_pydantic=True, ) self.assertEqual(metric3.value, 42.5) self.assertEqual(metric3.tags, "single_tag") # Test validation error with invalid type with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(MetricStruct).validate_python( + MetricStruct.from_dict( { "value": "not a number", - } + }, + use_pydantic=True, ) self.assertIn("Input should be a valid number", str(exc_info.exception)) # Test with string that should convert to float - metric3 = TypeAdapter(MetricStruct).validate_python( + metric3 = MetricStruct.from_dict( { "tags": "single_tag" # single string tag - } + }, + use_pydantic=True, ) self.assertFalse(hasattr(metric3, "value")) self.assertEqual(metric3.tags, "single_tag") @@ -3528,7 +3565,7 @@ class DataPoint(csp.Struct): # Test with MetricStruct metric_data = {"id": "metric-1", "data": {"value": 42.5, "unit": "celsius"}} - result = TypeAdapter(DataPoint).validate_python(metric_data) + result = DataPoint.from_dict(metric_data, use_pydantic=True) self.assertIsInstance(result.data, MetricStruct) self.assertEqual(result.data.value, 42.5) self.assertEqual(result.data.unit, "celsius") @@ -3542,14 +3579,14 @@ class DataPoint(csp.Struct): {"name": "previous_event", "timestamp": "2023-01-01T11:00:00"}, ], } - result = TypeAdapter(DataPoint).validate_python(event_data) + result = DataPoint.from_dict(event_data, use_pydantic=True) self.assertIsInstance(result.data, EventStruct) self.assertEqual(result.data.name, "system_start") self.assertIsInstance(result.history[0], MetricStruct) self.assertIsInstance(result.history[1], EventStruct) # Test serialization and deserialization - result = TypeAdapter(DataPoint).validate_python(event_data) + result = DataPoint.from_dict(event_data, use_pydantic=True) json_data = result.to_json() restored = TypeAdapter(DataPoint).validate_json(json_data) @@ -3590,14 +3627,14 @@ class DataPoint(csp.Struct): "precision": 1, # specific to TemperatureMetric }, } - result = TypeAdapter(DataPoint).validate_python(temp_data) + result = DataPoint.from_dict(temp_data, use_pydantic=True) self.assertIsInstance(result.metric, TemperatureMetric) # Should be TemperatureMetric, not BaseMetric self.assertEqual(result.metric.unit, "celsius") self.assertEqual(result.metric.precision, 1) # Test with PressureMetric data pressure_data = {"id": "pressure-1", "metric": {"name": "pressure", "value": 101.325, "altitude": 0.0}} - result = TypeAdapter(DataPoint).validate_python(pressure_data) + result = DataPoint.from_dict(pressure_data, use_pydantic=True) self.assertIsInstance(result.metric, PressureMetric) # Should be PressureMetric, not BaseMetric self.assertEqual(result.metric.unit, "pascal") self.assertEqual(result.metric.altitude, 0.0) @@ -3618,7 +3655,7 @@ class DataPoint(csp.Struct): }, ], } - result = TypeAdapter(DataPoint).validate_python(mixed_data) + result = DataPoint.from_dict(mixed_data, use_pydantic=True) self.assertIsInstance(result.metric, BaseMetric) # Should be base metric self.assertIsInstance(result.history[0], TemperatureMetric) # Should be temperature self.assertIsInstance(result.history[1], PressureMetric) # Should be pressure @@ -3783,7 +3820,7 @@ class NestedStruct(csp.Struct): self.assertEqual(enum_as_enum.name, enum_as_str) self.assertEqual( - nested, TypeAdapter(NestedStruct).validate_python(TypeAdapter(NestedStruct).dump_python(nested)) + nested, NestedStruct.from_dict(TypeAdapter(NestedStruct).dump_python(nested), use_pydantic=True) ) json_native = nested.to_json() @@ -3803,7 +3840,7 @@ class NPStruct(csp.Struct): NPStruct(arr=np.array([1, 3, "ab"])) # No error, even though the types are wrong with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(NPStruct).validate_python(dict(arr=[1, 3, "ab"])) + NPStruct.from_dict(dict(arr=[1, 3, "ab"]), use_pydantic=True) self.assertIn("could not convert string to float", str(exc_info.exception)) # We should be able to generate the json_schema TypeAdapter(NPStruct).json_schema() @@ -3852,7 +3889,7 @@ class DataPoint(csp.Struct): }, } - result = TypeAdapter(DataPoint).validate_python(metric_data) + result = DataPoint.from_dict(metric_data, use_pydantic=True) # Verify private fields are properly set including inherited ones self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0)) @@ -3894,7 +3931,7 @@ class DataPoint(csp.Struct): }, } - result = TypeAdapter(DataPoint).validate_python(event_data) + result = DataPoint.from_dict(event_data, use_pydantic=True) # Verify private fields are set but excluded from serialization self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0))