diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e2d634c..6025b604 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - Support for other model types than Gaussian and Deterministic in runners
 - Support for automatic mixed precision training in PyTorch
 - `init_state_dict` method to initialize model's lazy modules in PyTorch
+- Model instantiators `fixed_log_std` parameter to define immutable log standard deviations

 ### Changed
 - Call agent's `pre_interaction` method during evaluation
diff --git a/skrl/utils/model_instantiators/jax/gaussian.py b/skrl/utils/model_instantiators/jax/gaussian.py
index 916dedf3..22a05b29 100644
--- a/skrl/utils/model_instantiators/jax/gaussian.py
+++ b/skrl/utils/model_instantiators/jax/gaussian.py
@@ -23,6 +23,7 @@ def gaussian_model(
     max_log_std: float = 2,
     reduction: str = "sum",
     initial_log_std: float = 0,
+    fixed_log_std: bool = False,
     network: Sequence[Mapping[str, Any]] = [],
     output: Union[str, Sequence[str]] = "",
     return_source: bool = False,
@@ -54,6 +55,9 @@
     :type reduction: str, optional
     :param initial_log_std: Initial value for the log standard deviation (default: 0)
     :type initial_log_std: float, optional
+    :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False).
+                          Fixed parameters will be excluded from model parameters.
+    :type fixed_log_std: bool, optional
     :param network: Network definition (default: [])
     :type network: list of dict, optional
     :param output: Output expression (default: "")
@@ -90,6 +94,12 @@
     # build substitutions and indent content
     networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:]
     forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:]
+    if fixed_log_std:
+        log_std_parameter = f'jnp.full(shape={output["size"]}, fill_value={initial_log_std})'
+    else:
+        log_std_parameter = (
+            f'self.param("log_std_parameter", lambda _: jnp.full(shape={output["size"]}, fill_value={initial_log_std}))'
+        )

     template = f"""class GaussianModel(GaussianMixin, Model):
     def __init__(self, observation_space, action_space, device, clip_actions=False,
@@ -99,7 +109,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,

     def setup(self):
         {networks}
-        self.log_std_parameter = self.param("log_std_parameter", lambda _: {initial_log_std} * jnp.ones({output["size"]}))
+        self.log_std_parameter = {log_std_parameter}

     def __call__(self, inputs, role):
         states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
diff --git a/skrl/utils/model_instantiators/torch/gaussian.py b/skrl/utils/model_instantiators/torch/gaussian.py
index 4c378b86..886202ef 100644
--- a/skrl/utils/model_instantiators/torch/gaussian.py
+++ b/skrl/utils/model_instantiators/torch/gaussian.py
@@ -22,6 +22,7 @@ def gaussian_model(
     max_log_std: float = 2,
     reduction: str = "sum",
     initial_log_std: float = 0,
+    fixed_log_std: bool = False,
     network: Sequence[Mapping[str, Any]] = [],
     output: Union[str, Sequence[str]] = "",
     return_source: bool = False,
@@ -53,6 +54,9 @@
     :type reduction: str, optional
     :param initial_log_std: Initial value for the log standard deviation (default: 0)
     :type initial_log_std: float, optional
+    :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False).
+                          Fixed parameters have the gradient computation deactivated
+    :type fixed_log_std: bool, optional
     :param network: Network definition (default: [])
     :type network: list of dict, optional
     :param output: Output expression (default: "")
@@ -97,7 +101,7 @@ def __init__(self, observation_space, action_space, device, clip_actions,
         GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

         {networks}
-        self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]}))
+        self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})

     def compute(self, inputs, role=""):
         states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
diff --git a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py
index 41b62300..347567f6 100644
--- a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py
+++ b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py
@@ -21,6 +21,7 @@ def multivariate_gaussian_model(
     min_log_std: float = -20,
     max_log_std: float = 2,
     initial_log_std: float = 0,
+    fixed_log_std: bool = False,
     network: Sequence[Mapping[str, Any]] = [],
     output: Union[str, Sequence[str]] = "",
     return_source: bool = False,
@@ -48,6 +49,9 @@
     :type max_log_std: float, optional
     :param initial_log_std: Initial value for the log standard deviation (default: 0)
     :type initial_log_std: float, optional
+    :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False).
+                          Fixed parameters have the gradient computation deactivated
+    :type fixed_log_std: bool, optional
     :param network: Network definition (default: [])
     :type network: list of dict, optional
     :param output: Output expression (default: "")
@@ -92,7 +96,7 @@ def __init__(self, observation_space, action_space, device, clip_actions,
         MultivariateGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)

         {networks}
-        self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]}))
+        self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})

     def compute(self, inputs, role=""):
         states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
diff --git a/skrl/utils/model_instantiators/torch/shared.py b/skrl/utils/model_instantiators/torch/shared.py
index 193c4b42..b0bac07f 100644
--- a/skrl/utils/model_instantiators/torch/shared.py
+++ b/skrl/utils/model_instantiators/torch/shared.py
@@ -95,10 +95,12 @@ def get_extra(class_name, parameter, role, model):
             return ""
         elif class_name.lower() == "gaussianmixin":
             initial_log_std = float(parameter.get("initial_log_std", 0))
-            return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}))'
+            fixed_log_std = parameter.get("fixed_log_std", False)
+            return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})'
         elif class_name.lower() == "multivariategaussianmixin":
             initial_log_std = float(parameter.get("initial_log_std", 0))
-            return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}))'
+            fixed_log_std = parameter.get("fixed_log_std", False)
+            return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})'
         raise ValueError(f"Unknown class: {class_name}")

     # compatibility with versions prior to 1.3.0
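For context, a minimal usage sketch of the new option with the PyTorch instantiator; the `env` handle and the network specification below are illustrative assumptions, not part of this change:

```python
from skrl.utils.model_instantiators.torch import gaussian_model

# instantiate a Gaussian policy whose log standard deviation is held fixed:
# with fixed_log_std=True, log_std_parameter is created with requires_grad=False
# and therefore receives no gradient updates during training
policy = gaussian_model(
    observation_space=env.observation_space,  # hypothetical environment handle
    action_space=env.action_space,
    device=env.device,
    initial_log_std=-0.5,  # constant std of exp(-0.5) ~= 0.61
    fixed_log_std=True,
    network=[{"name": "net", "input": "STATES", "layers": [64, 64], "activations": "elu"}],
    output="ACTIONS",
)

# the parameter exists but gradient computation is deactivated
assert not policy.log_std_parameter.requires_grad
```

Note the backend asymmetry the docstrings above describe: in PyTorch the fixed value remains a registered `nn.Parameter` with gradient computation deactivated, while in JAX/flax it is generated as a plain `jnp.full(...)` array instead of via `self.param(...)`, so it is excluded from the model's parameters altogether.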