Skip to content

Commit

Permalink
Add type error suppressions for upcoming upgrade
Browse files Browse the repository at this point in the history
Differential Revision: D65342765

fbshipit-source-id: 585996571a368ef0aaed2fcbfa940a038647443f
  • Loading branch information
generatedunixname89002005307016 authored and facebook-github-bot committed Nov 1, 2024
1 parent 4ffb0cd commit c11becd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions pearl/utils/functional_utils/learning/critic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class `TwinCritic` will be instantiated with the specified `network_type`. Note

# cast network_type to get around static Pyre type checking; the runtime check with
# `issubclass` ensures the network type is a sublcass of QValueNetwork
# pyre-fixme[22]: The cast is redundant.
network_type = cast(Type[QValueNetwork], network_type)

return TwinCritic(
Expand Down
4 changes: 4 additions & 0 deletions pearl/utils/instantiations/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def __init__(
seed: Random seed used to initialize the random number generator of the
underlying Gymnasium `Box` space.
"""
# pyre-fixme[9]: low has type `Union[float, Tensor]`; used as `ndarray[Any,
# Any]`.
low = low.numpy(force=True) if isinstance(low, Tensor) else np.array([low])
# pyre-fixme[9]: high has type `Union[float, Tensor]`; used as `ndarray[Any,
# Any]`.
high = high.numpy(force=True) if isinstance(high, Tensor) else np.array([high])
self._gym_space = Box(low=low, high=high, seed=seed)

Expand Down
4 changes: 4 additions & 0 deletions pearl/utils/instantiations/spaces/box_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,15 @@ def __init__(
seed: Random seed used to initialize the random number generator of the
underlying Gymnasium `Box` space.
"""
# pyre-fixme[9]: low has type `Union[float, Tensor]`; used as `ndarray[Any,
# Any]`.
low = (
reshape_to_1d_tensor(low).numpy(force=True)
if isinstance(low, Tensor)
else np.array([low])
)
# pyre-fixme[9]: high has type `Union[float, Tensor]`; used as `ndarray[Any,
# Any]`.
high = (
reshape_to_1d_tensor(high).numpy(force=True)
if isinstance(high, Tensor)
Expand Down

0 comments on commit c11becd

Please sign in to comment.