Skip to content

Commit

Permalink
Add reduction parameter to gaussian_model instantiator
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Dec 21, 2024
1 parent d2aee9f commit 66842c4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions skrl/utils/model_instantiators/jax/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def gaussian_model(
clip_log_std: bool = True,
min_log_std: float = -20,
max_log_std: float = 2,
reduction: str = "sum",
initial_log_std: float = 0,
network: Sequence[Mapping[str, Any]] = [],
output: Union[str, Sequence[str]] = "",
Expand All @@ -47,6 +48,10 @@ def gaussian_model(
:type min_log_std: float, optional
:param max_log_std: Maximum value of the log standard deviation (default: 2)
:type max_log_std: float, optional
:param reduction: Reduction method for returning the log probability density function: (default: ``"sum"``).
Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If "``none"``, the log probability density
function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)``
:type reduction: str, optional
:param initial_log_std: Initial value for the log standard deviation (default: 0)
:type initial_log_std: float, optional
:param network: Network definition (default: [])
Expand Down Expand Up @@ -117,4 +122,5 @@ def __call__(self, inputs, role):
clip_log_std=clip_log_std,
min_log_std=min_log_std,
max_log_std=max_log_std,
reduction=reduction,
)
6 changes: 6 additions & 0 deletions skrl/utils/model_instantiators/torch/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def gaussian_model(
clip_log_std: bool = True,
min_log_std: float = -20,
max_log_std: float = 2,
reduction: str = "sum",
initial_log_std: float = 0,
network: Sequence[Mapping[str, Any]] = [],
output: Union[str, Sequence[str]] = "",
Expand All @@ -46,6 +47,10 @@ def gaussian_model(
:type min_log_std: float, optional
:param max_log_std: Maximum value of the log standard deviation (default: 2)
:type max_log_std: float, optional
:param reduction: Reduction method for returning the log probability density function: (default: ``"sum"``).
Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If "``none"``, the log probability density
function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)``
:type reduction: str, optional
:param initial_log_std: Initial value for the log standard deviation (default: 0)
:type initial_log_std: float, optional
:param network: Network definition (default: [])
Expand Down Expand Up @@ -115,4 +120,5 @@ def compute(self, inputs, role=""):
clip_log_std=clip_log_std,
min_log_std=min_log_std,
max_log_std=max_log_std,
reduction=reduction,
)

0 comments on commit 66842c4

Please sign in to comment.