diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 166e5d0c34..6c9885e603 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -7,6 +7,13 @@ from attrs import define, field from cattrs import override +from cattrs.dispatch import ( + StructuredValue, + StructureHook, + TargetType, + UnstructuredValue, + UnstructureHook, +) from baybe.searchspace import SearchSpace from baybe.serialization import SerialMixin, converter, unstructure_base @@ -160,33 +167,35 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No """ -def _decode_onnx_str(raw_unstructure_hook): +def _decode_onnx_str(raw_unstructure_hook: UnstructureHook) -> UnstructureHook: """Decode ONNX string for serialization purposes.""" - def wrapper(obj): - dict_ = raw_unstructure_hook(obj) - if "onnx_str" in dict_: - dict_["onnx_str"] = dict_["onnx_str"].decode(_ONNX_ENCODING) + def wrapper(obj: StructuredValue) -> UnstructuredValue: + dct = raw_unstructure_hook(obj) + if "onnx_str" in dct: + dct["onnx_str"] = dct["onnx_str"].decode(_ONNX_ENCODING) - return dict_ + return dct return wrapper -def _encode_onnx_str(raw_structure_hook): +def _encode_onnx_str(raw_structure_hook: StructureHook) -> StructureHook: """Encode ONNX string for deserialization purposes.""" - def wrapper(dict_, _): - if (onnx_str := dict_.get("onnx_str")) and isinstance(onnx_str, str): - dict_["onnx_str"] = onnx_str.encode(_ONNX_ENCODING) - obj = raw_structure_hook(dict_, _) + def wrapper(dct: UnstructuredValue, _: TargetType) -> StructuredValue: + if (onnx_str := dct.get("onnx_str")) and isinstance(onnx_str, str): + dct["onnx_str"] = onnx_str.encode(_ONNX_ENCODING) + obj = raw_structure_hook(dct, _) return obj return wrapper -def _block_serialize_custom_architecture(raw_unstructure_hook): +def _block_serialize_custom_architecture( + raw_unstructure_hook: UnstructureHook +) -> UnstructureHook: """Raise error if attempt to serialize a custom architecture surrogate.""" # TODO: Ideally, this hook should be removed and unstructuring the Surrogate # base class should automatically invoke the blocking hook that is already @@ -196,7 +205,7 @@ def _block_serialize_custom_architecture(raw_unstructure_hook): # because the role of the subclass will probably be replaced with a surrogate # protocol. - def wrapper(obj): + def wrapper(obj: StructuredValue) -> UnstructuredValue: if obj.__class__.__name__ == "CustomArchitectureSurrogate": raise NotImplementedError( "Serializing objects of type 'CustomArchitectureSurrogate' "