diff --git a/skrl/utils/model_instantiators/torch/categorical.py b/skrl/utils/model_instantiators/torch/categorical.py index 3be54bb0..48d70b51 100644 --- a/skrl/utils/model_instantiators/torch/categorical.py +++ b/skrl/utils/model_instantiators/torch/categorical.py @@ -10,6 +10,7 @@ from skrl.models.torch import CategoricalMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, @@ -82,6 +83,8 @@ def __init__(self, observation_space, action_space, device, unnormalized_log_pro {networks} def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/common.py b/skrl/utils/model_instantiators/torch/common.py index 331427b9..18cadf54 100644 --- a/skrl/utils/model_instantiators/torch/common.py +++ b/skrl/utils/model_instantiators/torch/common.py @@ -53,7 +53,7 @@ def visit_Call(self, node: ast.Call): node.func = ast.Attribute(value=ast.Name("torch"), attr="cat") node.keywords = [ast.keyword(arg="dim", value=ast.Constant(value=1))] # operation: permute - if node.func.id == "permute": + elif node.func.id == "permute": node.func = ast.Attribute(value=ast.Name("torch"), attr="permute") return node @@ -62,11 +62,11 @@ def visit_Call(self, node: ast.Call): NodeTransformer().visit(tree) source = ast.unparse(tree) # enum substitutions - source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)') - source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)') - source = source.replace("Shape.STATES", "STATES").replace("STATES", 'inputs["states"]') - source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'inputs["states"]') - source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'inputs["taken_actions"]') + source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'torch.cat((states, taken_actions), dim=1)') + source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'torch.cat((states, taken_actions), dim=1)') + source = source.replace("Shape.STATES", "STATES").replace("STATES", 'states') + source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'states') + source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'taken_actions') return source def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]: diff --git a/skrl/utils/model_instantiators/torch/deterministic.py b/skrl/utils/model_instantiators/torch/deterministic.py index ca67c9b6..ce0475ab 100644 --- a/skrl/utils/model_instantiators/torch/deterministic.py +++ b/skrl/utils/model_instantiators/torch/deterministic.py @@ -10,6 +10,7 @@ from skrl.models.torch import DeterministicMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, @@ -79,6 +80,8 @@ def __init__(self, observation_space, action_space, device, clip_actions): {networks} def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/gaussian.py b/skrl/utils/model_instantiators/torch/gaussian.py index b3e5cfce..48b2ac37 100644 --- a/skrl/utils/model_instantiators/torch/gaussian.py +++ b/skrl/utils/model_instantiators/torch/gaussian.py @@ -10,6 +10,7 @@ from skrl.models.torch import GaussianMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, @@ -93,6 +94,8 @@ def __init__(self, observation_space, action_space, device, clip_actions, self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]})) def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, self.log_std_parameter, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py index e6d46f00..ea079f65 100644 --- a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py +++ b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py @@ -10,6 +10,7 @@ from skrl.models.torch import MultivariateGaussianMixin # noqa from skrl.models.torch import Model from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, @@ -93,6 +94,8 @@ def __init__(self, observation_space, action_space, device, clip_actions, self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]})) def compute(self, inputs, role=""): + states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) + taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) {forward} return output, self.log_std_parameter, {{}} """ diff --git a/skrl/utils/model_instantiators/torch/shared.py b/skrl/utils/model_instantiators/torch/shared.py index 861eea93..8e1de43b 100644 --- a/skrl/utils/model_instantiators/torch/shared.py +++ b/skrl/utils/model_instantiators/torch/shared.py @@ -10,6 +10,7 @@ from skrl.models.torch import Model # noqa from skrl.models.torch import DeterministicMixin, GaussianMixin # noqa from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers +from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None, @@ -68,6 +69,8 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g for container in containers_gaussian: networks_common.append(f'self.{container["name"]}_container = {container["sequential"]}') forward_common.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') + forward_common.insert(0, 'taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))') + forward_common.insert(0, 'states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))') # process output networks_gaussian = []