Add parameter to define immutable log standard deviations
Toni-SM committed Jan 8, 2025
1 parent e5c6b81 commit dbdf374
Showing 5 changed files with 26 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Support for other model types than Gaussian and Deterministic in runners
- Support for automatic mixed precision training in PyTorch
- `init_state_dict` method to initialize model's lazy modules in PyTorch
+ - Model instantiators `fixed_log_std` parameter to define immutable log standard deviations

### Changed
- Call agent's `pre_interaction` method during evaluation
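For context, a minimal usage sketch of the new parameter (assuming an existing Gymnasium-style `env` and a `device`; the `network`/`output` keys follow skrl's model-instantiator schema):

```python
# Hypothetical usage sketch: a Gaussian policy whose log standard deviation
# never changes during training. `env` and `device` are assumed to exist.
from skrl.utils.model_instantiators.torch import gaussian_model

policy = gaussian_model(
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
    initial_log_std=-0.5,  # exp(-0.5) ~ 0.61 initial standard deviation
    fixed_log_std=True,    # the log-std stays immutable for the whole run
    network=[{"name": "net", "input": "STATES", "layers": [64, 64], "activations": "elu"}],
    output="ACTIONS",
)
```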
12 changes: 11 additions & 1 deletion skrl/utils/model_instantiators/jax/gaussian.py
@@ -23,6 +23,7 @@ def gaussian_model(
max_log_std: float = 2,
reduction: str = "sum",
initial_log_std: float = 0,
+ fixed_log_std: bool = False,
network: Sequence[Mapping[str, Any]] = [],
output: Union[str, Sequence[str]] = "",
return_source: bool = False,
@@ -54,6 +55,9 @@ def gaussian_model(
:type reduction: str, optional
:param initial_log_std: Initial value for the log standard deviation (default: 0)
:type initial_log_std: float, optional
+ :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False).
+     Fixed parameters will be excluded from model parameters.
+ :type fixed_log_std: bool, optional
:param network: Network definition (default: [])
:type network: list of dict, optional
:param output: Output expression (default: "")
@@ -90,6 +94,12 @@ def gaussian_model(
# build substitutions and indent content
networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:]
forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:]
+ if fixed_log_std:
+     log_std_parameter = f'jnp.full(shape={output["size"]}, fill_value={initial_log_std})'
+ else:
+     log_std_parameter = (
+         f'self.param("log_std_parameter", lambda _: jnp.full(shape={output["size"]}, fill_value={initial_log_std}))'
+     )

template = f"""class GaussianModel(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
@@ -99,7 +109,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
def setup(self):
{networks}
- self.log_std_parameter = self.param("log_std_parameter", lambda _: {initial_log_std} * jnp.ones({output["size"]}))
+ self.log_std_parameter = {log_std_parameter}
def __call__(self, inputs, role):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
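In the JAX branch above, the fixed case emits a plain `jnp.full(...)` array while the trainable case emits `self.param(...)`. A standalone Flax sketch (not skrl code) of why only the latter shows up among model parameters:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class TrainableLogStd(nn.Module):
    size: int = 2

    def setup(self):
        # self.param registers the array in the "params" collection,
        # so gradient transformations and optimizers can update it
        self.log_std_parameter = self.param("log_std_parameter", lambda _: jnp.full(self.size, 0.0))

    def __call__(self):
        return self.log_std_parameter

class FixedLogStd(nn.Module):
    size: int = 2

    def setup(self):
        # a plain array attribute is a constant baked into the module
        self.log_std_parameter = jnp.full(self.size, 0.0)

    def __call__(self):
        return self.log_std_parameter

print(TrainableLogStd().init(jax.random.PRNGKey(0)))  # {'params': {'log_std_parameter': ...}}
print(FixedLogStd().init(jax.random.PRNGKey(0)))      # {} -- nothing to train
```

This matches the docstring: in JAX, fixed parameters are excluded from the model parameters entirely.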
6 changes: 5 additions & 1 deletion skrl/utils/model_instantiators/torch/gaussian.py
@@ -22,6 +22,7 @@ def gaussian_model(
max_log_std: float = 2,
reduction: str = "sum",
initial_log_std: float = 0,
+ fixed_log_std: bool = False,
network: Sequence[Mapping[str, Any]] = [],
output: Union[str, Sequence[str]] = "",
return_source: bool = False,
@@ -53,6 +54,9 @@ def gaussian_model(
:type reduction: str, optional
:param initial_log_std: Initial value for the log standard deviation (default: 0)
:type initial_log_std: float, optional
+ :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False).
+     Fixed parameters have the gradient computation deactivated
+ :type fixed_log_std: bool, optional
:param network: Network definition (default: [])
:type network: list of dict, optional
:param output: Output expression (default: "")
@@ -97,7 +101,7 @@ def __init__(self, observation_space, action_space, device, clip_actions,
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
{networks}
- self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]}))
+ self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})
def compute(self, inputs, role=""):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
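In PyTorch the fixed log-std stays an `nn.Parameter`, but with gradient computation deactivated. A minimal sketch of the effect:

```python
import torch
import torch.nn as nn

fixed = nn.Parameter(torch.full(size=(2,), fill_value=0.0), requires_grad=False)
trainable = nn.Parameter(torch.full(size=(2,), fill_value=0.0))

loss = fixed.exp().sum() + trainable.exp().sum()
loss.backward()

print(trainable.grad)  # tensor([1., 1.]) -- receives gradients
print(fixed.grad)      # None -- gradient computation is deactivated
```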
6 changes: 5 additions & 1 deletion skrl/utils/model_instantiators/torch/multivariate_gaussian.py
@@ -21,6 +21,7 @@ def multivariate_gaussian_model(
min_log_std: float = -20,
max_log_std: float = 2,
initial_log_std: float = 0,
+ fixed_log_std: bool = False,
network: Sequence[Mapping[str, Any]] = [],
output: Union[str, Sequence[str]] = "",
return_source: bool = False,
@@ -48,6 +49,9 @@ def multivariate_gaussian_model(
:type max_log_std: float, optional
:param initial_log_std: Initial value for the log standard deviation (default: 0)
:type initial_log_std: float, optional
+ :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False).
+     Fixed parameters have the gradient computation deactivated
+ :type fixed_log_std: bool, optional
:param network: Network definition (default: [])
:type network: list of dict, optional
:param output: Output expression (default: "")
@@ -92,7 +96,7 @@ def __init__(self, observation_space, action_space, device, clip_actions,
MultivariateGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
{networks}
- self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]}))
+ self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})
def compute(self, inputs, role=""):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
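The multivariate variant relies on the same mechanism. Note that a fixed log-std still appears in `model.parameters()` and `state_dict()`; optimizers simply never update it because its `.grad` stays `None`. A common user-side pattern (an assumption, not something this commit adds) is to filter it out up front:

```python
import torch

# "model" is assumed to be an instantiated skrl model
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=3e-4,
)
```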
6 changes: 4 additions & 2 deletions skrl/utils/model_instantiators/torch/shared.py
@@ -95,10 +95,12 @@ def get_extra(class_name, parameter, role, model):
return ""
elif class_name.lower() == "gaussianmixin":
initial_log_std = float(parameter.get("initial_log_std", 0))
- return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}))'
+ fixed_log_std = parameter.get("fixed_log_std", False)
+ return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})'
elif class_name.lower() == "multivariategaussianmixin":
initial_log_std = float(parameter.get("initial_log_std", 0))
- return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}))'
+ fixed_log_std = parameter.get("fixed_log_std", False)
+ return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})'
raise ValueError(f"Unknown class: {class_name}")

# compatibility with versions prior to 1.3.0
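Because `get_extra` interpolates the flag into an f-string, the boolean is baked into the generated source as a literal. A small sketch of that code-generation step (the action size `6` is illustrative):

```python
# Reproduces the string get_extra would emit for a Gaussian mixin.
initial_log_std = -0.5
fixed_log_std = True
generated = (
    f"self.log_std_parameter = nn.Parameter("
    f"torch.full(size=(6,), fill_value={initial_log_std}), "
    f"requires_grad={not fixed_log_std})"
)
print(generated)
# self.log_std_parameter = nn.Parameter(torch.full(size=(6,), fill_value=-0.5), requires_grad=False)
```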