Skip to content

Commit

Permalink
Rename ONNX un-/structuring hooks and adjust docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Apr 30, 2024
1 parent 4efc004 commit 530f17c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,10 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
"""


def _decode_onnx_str(raw_unstructure_hook: UnstructureHook) -> UnstructureHook:
"""Decode ONNX string for serialization purposes."""
def _make_hook_decode_onnx_str(
raw_unstructure_hook: UnstructureHook
) -> UnstructureHook:
"""Wrap an unstructuring hook to let it also decode the contained ONNX string."""

def wrapper(obj: StructuredValue) -> UnstructuredValue:
dct = raw_unstructure_hook(obj)
Expand All @@ -167,8 +169,8 @@ def wrapper(obj: StructuredValue) -> UnstructuredValue:
return wrapper


def _encode_onnx_str(raw_structure_hook: StructureHook) -> StructureHook:
"""Encode ONNX string for deserialization purposes."""
def _make_hook_encode_onnx_str(raw_structure_hook: StructureHook) -> StructureHook:
"""Wrap a structuring hook to let it also encode the contained ONNX string."""

def wrapper(dct: UnstructuredValue, _: TargetType) -> StructuredValue:
if (onnx_str := dct.get("onnx_str")) and isinstance(onnx_str, str):
Expand Down Expand Up @@ -211,12 +213,12 @@ def wrapper(obj: StructuredValue) -> UnstructuredValue:
# existing hooks of the concrete subclasses.
converter.register_unstructure_hook(
Surrogate,
_decode_onnx_str(
_make_hook_decode_onnx_str(
_block_serialize_custom_architecture(
lambda x: unstructure_base(x, overrides={"_model": override(omit=True)})
)
),
)
converter.register_structure_hook(
Surrogate, _encode_onnx_str(get_base_structure_hook(Surrogate))
Surrogate, _make_hook_encode_onnx_str(get_base_structure_hook(Surrogate))
)

0 comments on commit 530f17c

Please sign in to comment.