From afadb50c26297992189286b10f062f493d702a46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Tue, 13 Aug 2024 12:55:59 -0400 Subject: [PATCH] Update annotation and doctring --- .../utils/model_instantiators/jax/__init__.py | 12 +++++------ .../model_instantiators/torch/__init__.py | 20 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/skrl/utils/model_instantiators/jax/__init__.py b/skrl/utils/model_instantiators/jax/__init__.py index 632d753e..d1a0c98c 100644 --- a/skrl/utils/model_instantiators/jax/__init__.py +++ b/skrl/utils/model_instantiators/jax/__init__.py @@ -152,7 +152,7 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = "tanh", output_scale: float = 1.0, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a Gaussian model :param observation_space: Observation/state space or shape (default: None). @@ -191,7 +191,7 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Gaussian model instance + :return: Gaussian model instance or definition source :rtype: Model """ # network @@ -242,7 +242,7 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = "tanh", output_scale: float = 1.0, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a deterministic model :param observation_space: Observation/state space or shape (default: None). @@ -273,7 +273,7 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Deterministic model instance + :return: Deterministic model instance or definition source :rtype: Model """ # network @@ -318,7 +318,7 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa hidden_activation: list = ["relu", "relu"], output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = None, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a categorical model :param observation_space: Observation/state space or shape (default: None). @@ -349,7 +349,7 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Categorical model instance + :return: Categorical model instance or definition source :rtype: Model """ # network diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py index 53743870..ce6151e1 100644 --- a/skrl/utils/model_instantiators/torch/__init__.py +++ b/skrl/utils/model_instantiators/torch/__init__.py @@ -161,7 +161,7 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = "tanh", output_scale: float = 1.0, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a Gaussian model :param observation_space: Observation/state space or shape (default: None). @@ -200,7 +200,7 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Gaussian model instance + :return: Gaussian model instance or definition source :rtype: Model """ # network @@ -254,7 +254,7 @@ def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = "tanh", output_scale: float = 1.0, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a multivariate Gaussian model :param observation_space: Observation/state space or shape (default: None). @@ -293,7 +293,7 @@ def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Multivariate Gaussian model instance + :return: Multivariate Gaussian model instance or definition source :rtype: Model """ # network @@ -343,7 +343,7 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = "tanh", output_scale: float = 1.0, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a deterministic model :param observation_space: Observation/state space or shape (default: None). @@ -374,7 +374,7 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Deterministic model instance + :return: Deterministic model instance or definition source :rtype: Model """ # network @@ -418,7 +418,7 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa hidden_activation: list = ["relu", "relu"], output_shape: Shape = Shape.ACTIONS, output_activation: Optional[str] = None, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a categorical model :param observation_space: Observation/state space or shape (default: None). @@ -449,7 +449,7 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Categorical model instance + :return: Categorical model instance or definition source :rtype: Model """ # network @@ -489,7 +489,7 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g roles: Sequence[str] = [], parameters: Sequence[Mapping[str, Any]] = [], single_forward_pass: bool = True, - return_source: bool = False) -> Model: + return_source: bool = False) -> Union[Model, str]: """Instantiate a shared model :param observation_space: Observation/state space or shape (default: None). @@ -514,7 +514,7 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g instantiate the model rather than the model instance (default: False). :type return_source: bool, optional - :return: Shared model instance + :return: Shared model instance or definition source :rtype: Model """ # network