Skip to content

Commit

Permalink
Add type hints for hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Apr 26, 2024
1 parent 3746777 commit a88a5c1
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

from attrs import define, field
from cattrs import override
from cattrs.dispatch import (
StructuredValue,
StructureHook,
TargetType,
UnstructuredValue,
UnstructureHook,
)

from baybe.searchspace import SearchSpace
from baybe.serialization import SerialMixin, converter, unstructure_base
Expand Down Expand Up @@ -160,33 +167,35 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
"""


def _decode_onnx_str(raw_unstructure_hook):
def _decode_onnx_str(raw_unstructure_hook: UnstructureHook) -> UnstructureHook:
"""Decode ONNX string for serialization purposes."""

def wrapper(obj):
dict_ = raw_unstructure_hook(obj)
if "onnx_str" in dict_:
dict_["onnx_str"] = dict_["onnx_str"].decode(_ONNX_ENCODING)
def wrapper(obj: StructuredValue) -> UnstructuredValue:
dct = raw_unstructure_hook(obj)
if "onnx_str" in dct:
dct["onnx_str"] = dct["onnx_str"].decode(_ONNX_ENCODING)

return dict_
return dct

return wrapper


def _encode_onnx_str(raw_structure_hook):
def _encode_onnx_str(raw_structure_hook: StructureHook) -> StructureHook:
"""Encode ONNX string for deserialization purposes."""

def wrapper(dict_, _):
if (onnx_str := dict_.get("onnx_str")) and isinstance(onnx_str, str):
dict_["onnx_str"] = onnx_str.encode(_ONNX_ENCODING)
obj = raw_structure_hook(dict_, _)
def wrapper(dct: UnstructuredValue, _: TargetType) -> StructuredValue:
if (onnx_str := dct.get("onnx_str")) and isinstance(onnx_str, str):
dct["onnx_str"] = onnx_str.encode(_ONNX_ENCODING)
obj = raw_structure_hook(dct, _)

return obj

return wrapper


def _block_serialize_custom_architecture(raw_unstructure_hook):
def _block_serialize_custom_architecture(
raw_unstructure_hook: UnstructureHook
) -> UnstructureHook:
"""Raise error if attempt to serialize a custom architecture surrogate."""
# TODO: Ideally, this hook should be removed and unstructuring the Surrogate
# base class should automatically invoke the blocking hook that is already
Expand All @@ -196,7 +205,7 @@ def _block_serialize_custom_architecture(raw_unstructure_hook):
# because the role of the subclass will probably be replaced with a surrogate
# protocol.

def wrapper(obj):
def wrapper(obj: StructuredValue) -> UnstructuredValue:
if obj.__class__.__name__ == "CustomArchitectureSurrogate":
raise NotImplementedError(
"Serializing objects of type 'CustomArchitectureSurrogate' "
Expand Down

0 comments on commit a88a5c1

Please sign in to comment.