From 530f17ca2c60fe299faacce213e5532238400b5b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 30 Apr 2024 10:07:21 +0200 Subject: [PATCH] Rename ONNX un-/structuring hooks and adjust docstrings --- baybe/surrogates/base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 43fb512010..5d2f447034 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -154,8 +154,10 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No """ -def _decode_onnx_str(raw_unstructure_hook: UnstructureHook) -> UnstructureHook: - """Decode ONNX string for serialization purposes.""" +def _make_hook_decode_onnx_str( + raw_unstructure_hook: UnstructureHook +) -> UnstructureHook: + """Wrap an unstructuring hook to let it also decode the contained ONNX string.""" def wrapper(obj: StructuredValue) -> UnstructuredValue: dct = raw_unstructure_hook(obj) @@ -167,8 +169,8 @@ def wrapper(obj: StructuredValue) -> UnstructuredValue: return wrapper -def _encode_onnx_str(raw_structure_hook: StructureHook) -> StructureHook: - """Encode ONNX string for deserialization purposes.""" +def _make_hook_encode_onnx_str(raw_structure_hook: StructureHook) -> StructureHook: + """Wrap a structuring hook to let it also encode the contained ONNX string.""" def wrapper(dct: UnstructuredValue, _: TargetType) -> StructuredValue: if (onnx_str := dct.get("onnx_str")) and isinstance(onnx_str, str): @@ -211,12 +213,12 @@ def wrapper(obj: StructuredValue) -> UnstructuredValue: # existing hooks of the concrete subclasses. converter.register_unstructure_hook( Surrogate, - _decode_onnx_str( + _make_hook_decode_onnx_str( _block_serialize_custom_architecture( lambda x: unstructure_base(x, overrides={"_model": override(omit=True)}) ) ), ) converter.register_structure_hook( - Surrogate, _encode_onnx_str(get_base_structure_hook(Surrogate)) + Surrogate, _make_hook_encode_onnx_str(get_base_structure_hook(Surrogate)) )