diff --git a/skrl/utils/model_instantiators/jax/common.py b/skrl/utils/model_instantiators/jax/common.py index b54b1dfc..c42c38e9 100644 --- a/skrl/utils/model_instantiators/jax/common.py +++ b/skrl/utils/model_instantiators/jax/common.py @@ -275,12 +275,15 @@ def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mappin logger.warning(f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. ' "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html") # network definition + activations = parameters.get("hidden_activation", []) + if type(activations) in [list, tuple] and len(set(activations)) == 1: + activations = activations[0] network = [ { "name": "net", "input": str(parameters.get("input_shape", "STATES")), "layers": parameters.get("hiddens", []), - "activations": parameters.get("hidden_activation", []), + "activations": activations, } ] # output diff --git a/skrl/utils/model_instantiators/torch/common.py b/skrl/utils/model_instantiators/torch/common.py index 348c0990..cc6e560f 100644 --- a/skrl/utils/model_instantiators/torch/common.py +++ b/skrl/utils/model_instantiators/torch/common.py @@ -282,12 +282,15 @@ def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mappin logger.warning(f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. ' "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html") # network definition + activations = parameters.get("hidden_activation", []) + if type(activations) in [list, tuple] and len(set(activations)) == 1: + activations = activations[0] network = [ { "name": "net", "input": str(parameters.get("input_shape", "STATES")), "layers": parameters.get("hiddens", []), - "activations": parameters.get("hidden_activation", []), + "activations": activations, } ] # output