
Commit

Implement all model instantiators using dynamic execution of Python code in jax
Toni-SM committed Aug 13, 2024
1 parent 31fb727 commit bc0fdfb
Showing 1 changed file with 73 additions and 107 deletions.
skrl/utils/model_instantiators/jax/__init__.py
@@ -1,13 +1,12 @@
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
from typing import Optional, Tuple, Union

import sys
from enum import Enum
import gym
import gymnasium

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.numpy as jnp # noqa

from skrl.models.jax import Model # noqa
from skrl.models.jax import CategoricalMixin, DeterministicMixin, GaussianMixin # noqa
@@ -27,7 +26,7 @@ class Shape(Enum):
STATES_ACTIONS = -2


def _get_activation_function(activation: str, as_string: bool = False) -> nn.Module:
def _get_activation_function(activation: str, as_string: bool = False) -> Union[nn.Module, str]:
"""Get the activation function
Supported activation functions:
@@ -76,7 +75,7 @@ def _get_activation_function(activation: str, as_string: bool = False) -> nn.Mod
else:
raise ValueError(f"Unknown activation function: {activation}")

def _get_num_units_by_shape(model: Model, shape: Shape, as_string: bool = False) -> int:
def _get_num_units_by_shape(model: Model, shape: Shape, as_string: bool = False) -> Union[int, str]:
"""Get the number of units in a layer by shape
:param model: Model to get the number of units for
@@ -137,7 +136,7 @@ def _generate_sequential(model: Model,
modules.append(f"nn.Dense({_get_num_units_by_shape(None, output_shape, as_string=True)})")
if output_activation:
modules.append(_get_activation_function(output_activation, as_string=True))
return f'self.net = nn.Sequential([{", ".join(modules)}])'
return f'nn.Sequential([{", ".join(modules)}])'
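For illustration, the helper now returns only the nn.Sequential expression, and the caller writes the assignment inside the generated class source. A rough sketch of the string it might produce (assuming hiddens=[64, 64], ReLU activations and a one-unit output; the exact content comes from the helpers above):

net = 'nn.Sequential([nn.Dense(64), nn.relu, nn.Dense(64), nn.relu, nn.Dense(1)])'
setup_line = f"self.net = {net}"  # the assignment now lives in the class template, not in the helper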

def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -191,57 +190,41 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space,
:return: Gaussian model instance
:rtype: Model
"""
class GaussianModel(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

# override the hash method for Python versions prior to 3.8 to avoid the following error:
# TypeError: Failed to hash Flax Module. The module probably contains unhashable attributes.
if sys.version_info < (3, 8):
def __hash__(self):
return id(self)

def setup(self):
self.instantiator_output_scale = metadata["output_scale"]
self.instantiator_input_type = metadata["input_shape"].value

self.net = _generate_sequential(model=self,
input_shape=metadata["input_shape"],
hiddens=metadata["hiddens"],
hidden_activation=metadata["hidden_activation"],
output_shape=metadata["output_shape"],
output_activation=metadata["output_activation"],
output_scale=metadata["output_scale"])
self.log_std_parameter = self.param("log_std_parameter", lambda _: metadata["initial_log_std"] \
* jnp.ones(_get_num_units_by_shape(self, metadata["output_shape"])))

def __call__(self, inputs, role):
if self.instantiator_input_type == 0:
output = self.net(inputs["states"])
elif self.instantiator_input_type == -1:
output = self.net(inputs["taken_actions"])
elif self.instantiator_input_type == -2:
output = self.net(jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1))

return output * self.instantiator_output_scale, self.log_std_parameter, {}

metadata = {"input_shape": input_shape,
"hiddens": hiddens,
"hidden_activation": hidden_activation,
"output_shape": output_shape,
"output_activation": output_activation,
"output_scale": output_scale,
"initial_log_std": initial_log_std}

return GaussianModel(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions,
clip_log_std=clip_log_std,
min_log_std=min_log_std,
max_log_std=max_log_std)
# network
net = _generate_sequential(None, input_shape, hiddens, hidden_activation, output_shape, output_activation)

# compute
if input_shape == Shape.OBSERVATIONS:
forward = 'self.net(inputs["states"])'
elif input_shape == Shape.ACTIONS:
forward = 'self.net(inputs["taken_actions"])'
elif input_shape == Shape.STATES_ACTIONS:
forward = 'self.net(jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1))'
if output_scale != 1:
forward = f"{output_scale} * {forward}"

template = f"""class GaussianModel(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
def setup(self):
self.net = {net}
self.log_std_parameter = self.param("log_std_parameter", lambda _: {initial_log_std} * jnp.ones({_get_num_units_by_shape(None, output_shape, as_string=True)}))
def __call__(self, inputs, role):
return {forward}, self.log_std_parameter, {{}}
"""
_locals = {}
exec(template, globals(), _locals)
return _locals["GaussianModel"](observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions,
clip_log_std=clip_log_std,
min_log_std=min_log_std,
max_log_std=max_log_std)
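To make the dynamic-execution step concrete, this is roughly what the rendered template might look like once the {net}, {initial_log_std} and {forward} placeholders are substituted, for a hypothetical policy with two 64-unit tanh hidden layers, tanh output over the action dimension, observations as input and initial_log_std=0 (a sketch; the real string is whatever the helpers above emit):

# hypothetical rendered source; imports added only to make the sketch self-contained,
# the real exec call relies on the module globals instead
import flax.linen as nn
import jax.numpy as jnp
from skrl.models.jax import Model, GaussianMixin

class GaussianModel(GaussianMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
    def setup(self):
        # expanded {net}
        self.net = nn.Sequential([nn.Dense(64), nn.tanh, nn.Dense(64), nn.tanh,
                                  nn.Dense(self.num_actions), nn.tanh])
        # expanded {initial_log_std} and unit count (assumed to render as self.num_actions)
        self.log_std_parameter = self.param("log_std_parameter",
                                            lambda _: 0 * jnp.ones(self.num_actions))
    def __call__(self, inputs, role):
        # expanded {forward} for input_shape == Shape.OBSERVATIONS
        return self.net(inputs["states"]), self.log_std_parameter, {}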

def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -302,21 +285,17 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
DeterministicMixin.__init__(self, clip_actions)
def setup(self):
{net}
self.net = {net}
def __call__(self, inputs, role):
return {forward}, {{}}
"""
print(template)
_locals = {}
exec(template, globals(), _locals)
DeterministicModel = _locals["DeterministicModel"]
model = DeterministicModel(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions)
return model
return _locals["DeterministicModel"](observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions)
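As a side note, the exec pattern shared by all three builders can be reduced to a minimal, self-contained form; the made-up class below (not skrl code) shows why the generated class is fetched from the _locals dict while globals() supplies the names the generated source refers to:

template = """class Greeter:
    def __init__(self, name):
        self.name = name
    def __call__(self):
        return f"hello, {self.name}"
"""
_locals = {}
# globals() lets the generated source see module-level names (here none are needed);
# the class object created by exec ends up in _locals, so it is looked up there
exec(template, globals(), _locals)
greeter = _locals["Greeter"]("skrl")
print(greeter())  # hello, skrl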

def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -357,44 +336,31 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa
:return: Categorical model instance
:rtype: Model
"""
class CategoricalModel(CategoricalMixin, Model):
def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
CategoricalMixin.__init__(self, unnormalized_log_prob)

# override the hash method for Python versions prior to 3.8 to avoid the following error:
# TypeError: Failed to hash Flax Module. The module probably contains unhashable attributes.
if sys.version_info < (3, 8):
def __hash__(self):
return id(self)

def setup(self):
self.instantiator_input_type = metadata["input_shape"].value

self.net = _generate_sequential(model=self,
input_shape=metadata["input_shape"],
hiddens=metadata["hiddens"],
hidden_activation=metadata["hidden_activation"],
output_shape=metadata["output_shape"],
output_activation=metadata["output_activation"])

def __call__(self, inputs, role):
if self.instantiator_input_type == 0:
output = self.net(inputs["states"])
elif self.instantiator_input_type == -1:
output = self.net(inputs["taken_actions"])
elif self.instantiator_input_type == -2:
output = self.net(jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1))

return output, {}

metadata = {"input_shape": input_shape,
"hiddens": hiddens,
"hidden_activation": hidden_activation,
"output_shape": output_shape,
"output_activation": output_activation}

return CategoricalModel(observation_space=observation_space,
action_space=action_space,
device=device,
unnormalized_log_prob=unnormalized_log_prob)
# network
net = _generate_sequential(None, input_shape, hiddens, hidden_activation, output_shape, output_activation)

# compute
if input_shape == Shape.OBSERVATIONS:
forward = 'self.net(inputs["states"])'
elif input_shape == Shape.ACTIONS:
forward = 'self.net(inputs["taken_actions"])'
elif input_shape == Shape.STATES_ACTIONS:
forward = 'self.net(jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1))'

template = f"""class CategoricalModel(CategoricalMixin, Model):
def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
CategoricalMixin.__init__(self, unnormalized_log_prob)
def setup(self):
self.net = {net}
def __call__(self, inputs, role):
return {forward}, {{}}
"""
_locals = {}
exec(template, globals(), _locals)
return _locals["CategoricalModel"](observation_space=observation_space,
action_space=action_space,
device=device,
unnormalized_log_prob=unnormalized_log_prob)
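Finally, a usage sketch (the spaces, device and layer sizes are made up for illustration, and the keyword names mirror those used inside the function bodies above) showing that the exec-built classes are instantiated like the hand-written ones they replace:

import gymnasium
from skrl.utils.model_instantiators.jax import Shape, categorical_model

observation_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(8,))
action_space = gymnasium.spaces.Discrete(4)

policy = categorical_model(observation_space=observation_space,
                           action_space=action_space,
                           device="cpu",
                           unnormalized_log_prob=True,
                           input_shape=Shape.OBSERVATIONS,
                           hiddens=[64, 64],
                           hidden_activation=["relu", "relu"],
                           output_shape=Shape.ACTIONS,
                           output_activation=None)  # a CategoricalMixin + Model instance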
