diff --git a/skrl/utils/model_instantiators/jax/gaussian.py b/skrl/utils/model_instantiators/jax/gaussian.py index 529be7fd..916dedf3 100644 --- a/skrl/utils/model_instantiators/jax/gaussian.py +++ b/skrl/utils/model_instantiators/jax/gaussian.py @@ -21,6 +21,7 @@ def gaussian_model( clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, + reduction: str = "sum", initial_log_std: float = 0, network: Sequence[Mapping[str, Any]] = [], output: Union[str, Sequence[str]] = "", @@ -47,6 +48,10 @@ def gaussian_model( :type min_log_std: float, optional :param max_log_std: Maximum value of the log standard deviation (default: 2) :type max_log_std: float, optional + :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``). + Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density + function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)`` + :type reduction: str, optional :param initial_log_std: Initial value for the log standard deviation (default: 0) :type initial_log_std: float, optional :param network: Network definition (default: []) @@ -117,4 +122,5 @@ def __call__(self, inputs, role): clip_log_std=clip_log_std, min_log_std=min_log_std, max_log_std=max_log_std, + reduction=reduction, ) diff --git a/skrl/utils/model_instantiators/torch/gaussian.py b/skrl/utils/model_instantiators/torch/gaussian.py index b37cdefc..4c378b86 100644 --- a/skrl/utils/model_instantiators/torch/gaussian.py +++ b/skrl/utils/model_instantiators/torch/gaussian.py @@ -20,6 +20,7 @@ def gaussian_model( clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, + reduction: str = "sum", initial_log_std: float = 0, network: Sequence[Mapping[str, Any]] = [], output: Union[str, Sequence[str]] = "", @@ -46,6 +47,10 @@ def gaussian_model( :type min_log_std: float, optional :param max_log_std: Maximum value of the 
log standard deviation (default: 2) :type max_log_std: float, optional + :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``). + Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density + function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)`` + :type reduction: str, optional :param initial_log_std: Initial value for the log standard deviation (default: 0) :type initial_log_std: float, optional :param network: Network definition (default: []) @@ -115,4 +120,5 @@ def compute(self, inputs, role=""): clip_log_std=clip_log_std, min_log_std=min_log_std, max_log_std=max_log_std, + reduction=reduction, )