Skip to content

Commit

Permalink
Update model definitions to support different input spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Sep 29, 2024
1 parent 3324048 commit 45418b8
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 6 deletions.
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/torch/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skrl.models.torch import CategoricalMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa


def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -82,6 +83,8 @@ def __init__(self, observation_space, action_space, device, unnormalized_log_pro
{networks}
def compute(self, inputs, role=""):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, {{}}
"""
Expand Down
12 changes: 6 additions & 6 deletions skrl/utils/model_instantiators/torch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def visit_Call(self, node: ast.Call):
node.func = ast.Attribute(value=ast.Name("torch"), attr="cat")
node.keywords = [ast.keyword(arg="dim", value=ast.Constant(value=1))]
# operation: permute
if node.func.id == "permute":
elif node.func.id == "permute":
node.func = ast.Attribute(value=ast.Name("torch"), attr="permute")
return node

Expand All @@ -62,11 +62,11 @@ def visit_Call(self, node: ast.Call):
NodeTransformer().visit(tree)
source = ast.unparse(tree)
# enum substitutions
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)')
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)')
source = source.replace("Shape.STATES", "STATES").replace("STATES", 'inputs["states"]')
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'inputs["states"]')
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'inputs["taken_actions"]')
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'torch.cat((states, taken_actions), dim=1)')
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'torch.cat((states, taken_actions), dim=1)')
source = source.replace("Shape.STATES", "STATES").replace("STATES", 'states')
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'states')
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'taken_actions')
return source

def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]:
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/torch/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skrl.models.torch import DeterministicMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa


def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -79,6 +80,8 @@ def __init__(self, observation_space, action_space, device, clip_actions):
{networks}
def compute(self, inputs, role=""):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, {{}}
"""
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/torch/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skrl.models.torch import GaussianMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa


def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -93,6 +94,8 @@ def __init__(self, observation_space, action_space, device, clip_actions,
self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]}))
def compute(self, inputs, role=""):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, self.log_std_parameter, {{}}
"""
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/torch/multivariate_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skrl.models.torch import MultivariateGaussianMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa


def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -93,6 +94,8 @@ def __init__(self, observation_space, action_space, device, clip_actions,
self.log_std_parameter = nn.Parameter({initial_log_std} * torch.ones({output["size"]}))
def compute(self, inputs, role=""):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, self.log_std_parameter, {{}}
"""
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/torch/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skrl.models.torch import Model # noqa
from skrl.models.torch import DeterministicMixin, GaussianMixin # noqa
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa


def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -68,6 +69,8 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g
for container in containers_gaussian:
networks_common.append(f'self.{container["name"]}_container = {container["sequential"]}')
forward_common.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})')
forward_common.insert(0, 'taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))')
forward_common.insert(0, 'states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))')

# process output
networks_gaussian = []
Expand Down

0 comments on commit 45418b8

Please sign in to comment.