Skip to content

Commit

Permalink
Define deterministic models using dynamic execution of Python code
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 11, 2024
1 parent 82ea390 commit 6909c92
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 141 deletions.
140 changes: 66 additions & 74 deletions skrl/utils/model_instantiators/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Shape(Enum):
STATES_ACTIONS = -2


def _get_activation_function(activation: str) -> nn.Module:
def _get_activation_function(activation: str, as_string: bool = False) -> nn.Module:
"""Get the activation function
Supported activation functions:
Expand All @@ -45,50 +45,54 @@ def _get_activation_function(activation: str) -> nn.Module:
:param activation: activation function name.
If activation is an empty string, a placeholder will be returned (``lambda x: x``)
:type activation: str
:param as_string: Whether to return the activation function as string.
:type as_string: bool
:raises: ValueError if activation is not a valid activation function
:return: activation function
:rtype: nn.Module
"""
if not activation:
return lambda x: x
return None if as_string else lambda x: x
elif activation == "relu":
return nn.relu
return "nn.relu" if as_string else nn.relu
elif activation == "tanh":
return nn.tanh
return "nn.tanh" if as_string else nn.tanh
elif activation == "sigmoid":
return nn.sigmoid
return "nn.sigmoid" if as_string else nn.sigmoid
elif activation == "leaky_relu":
return nn.leaky_relu
return "nn.leaky_relu" if as_string else nn.leaky_relu
elif activation == "elu":
return nn.elu
return "nn.elu" if as_string else nn.elu
elif activation == "softplus":
return nn.softplus
return "nn.softplus" if as_string else nn.softplus
elif activation == "softsign":
return nn.soft_sign
return "nn.soft_sign" if as_string else nn.soft_sign
elif activation == "selu":
return nn.selu
return "nn.selu" if as_string else nn.selu
elif activation == "softmax":
return nn.softmax
return "nn.softmax" if as_string else nn.softmax
else:
raise ValueError(f"Unknown activation function: {activation}")

def _get_num_units_by_shape(model: Model, shape: Shape) -> int:
def _get_num_units_by_shape(model: Model, shape: Shape, as_string: bool = False) -> int:
"""Get the number of units in a layer by shape
:param model: Model to get the number of units for
:type model: Model
:param shape: Shape of the layer
:type shape: Shape or int
:param as_string: Whether to return the activation function as string.
:type as_string: bool
:return: Number of units in the layer
:rtype: int
"""
num_units = {Shape.ONE: 1,
Shape.STATES: model.num_observations,
Shape.ACTIONS: model.num_actions,
Shape.STATES_ACTIONS: model.num_observations + model.num_actions}
num_units = {Shape.ONE: "1" if as_string else 1,
Shape.STATES: "self.num_observations" if as_string else model.num_observations,
Shape.ACTIONS: "self.num_actions" if as_string else model.num_actions,
Shape.STATES_ACTIONS: "self.num_observations + self.num_actions" if as_string else model.num_observations + model.num_actions}
try:
return num_units[shape]
except:
Expand Down Expand Up @@ -122,20 +126,18 @@ def _generate_sequential(model: Model,
:return: sequential model
:rtype: nn.Sequential
"""
# input layer
input_layer = [nn.Dense(hiddens[0])]
# hidden layers
hidden_layers = []
for i in range(len(hiddens) - 1):
hidden_layers.append(_get_activation_function(hidden_activation[i]))
hidden_layers.append(nn.Dense(hiddens[i + 1]))
hidden_layers.append(_get_activation_function(hidden_activation[-1]))
# output layer
output_layer = [nn.Dense(_get_num_units_by_shape(model, output_shape))]
if output_activation is not None:
output_layer.append(_get_activation_function(output_activation))

return nn.Sequential(input_layer + hidden_layers + output_layer)
modules = []
for i in range(len(hiddens)):
# first and middle layers
modules.append(f"nn.Dense({hiddens[i]})")
if hidden_activation[i]:
modules.append(_get_activation_function(hidden_activation[i], as_string=True))
# last layer
if i == len(hiddens) - 1:
modules.append(f"nn.Dense({_get_num_units_by_shape(None, output_shape, as_string=True)})")
if output_activation:
modules.append(_get_activation_function(output_activation, as_string=True))
return f'self.net = nn.Sequential([{", ".join(modules)}])'

def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -281,50 +283,40 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S
:return: Deterministic model instance
:rtype: Model
"""
class DeterministicModel(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

# override the hash method for Python versions prior to 3.8 to avoid the following error:
# TypeError: Failed to hash Flax Module. The module probably contains unhashable attributes.
if sys.version_info < (3, 8):
def __hash__(self):
return id(self)

def setup(self):
self.instantiator_output_scale = metadata["output_scale"]
self.instantiator_input_type = metadata["input_shape"].value

self.net = _generate_sequential(model=self,
input_shape=metadata["input_shape"],
hiddens=metadata["hiddens"],
hidden_activation=metadata["hidden_activation"],
output_shape=metadata["output_shape"],
output_activation=metadata["output_activation"],
output_scale=metadata["output_scale"])

def __call__(self, inputs, role):
if self.instantiator_input_type == 0:
output = self.net(inputs["states"])
elif self.instantiator_input_type == -1:
output = self.net(inputs["taken_actions"])
elif self.instantiator_input_type == -2:
output = self.net(jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1))

return output * self.instantiator_output_scale, {}

metadata = {"input_shape": input_shape,
"hiddens": hiddens,
"hidden_activation": hidden_activation,
"output_shape": output_shape,
"output_activation": output_activation,
"output_scale": output_scale}

return DeterministicModel(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions)
# network
net = _generate_sequential(None, input_shape, hiddens, hidden_activation, output_shape, output_activation)

# compute
if input_shape == Shape.OBSERVATIONS:
forward = 'self.net(inputs["states"])'
elif input_shape == Shape.ACTIONS:
forward = 'self.net(inputs["taken_actions"])'
elif input_shape == Shape.STATES_ACTIONS:
forward = 'self.net(jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1))'
if output_scale != 1:
forward = f"{output_scale} * {forward}"

template = f"""class DeterministicModel(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
def setup(self):
{net}
def __call__(self, inputs, role):
return {forward}, {{}}
"""
print(template)
_locals = {}
exec(template, globals(), _locals)
DeterministicModel = _locals["DeterministicModel"]
model = DeterministicModel(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions)
return model

def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down
137 changes: 70 additions & 67 deletions skrl/utils/model_instantiators/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Shape(Enum):
STATES_ACTIONS = -2


def _get_activation_function(activation: str) -> nn.Module:
def _get_activation_function(activation: str, as_string: bool = False) -> Union[nn.Module, str]:
"""Get the activation function
Supported activation functions:
Expand All @@ -43,50 +43,54 @@ def _get_activation_function(activation: str) -> nn.Module:
:param activation: activation function name.
If activation is an empty string, a placeholder will be returned (``torch.nn.Identity()``)
:type activation: str
:param as_string: Whether to return the activation function as string.
:type as_string: bool
:raises: ValueError if activation is not a valid activation function
:return: activation function
:rtype: nn.Module
"""
if not activation:
return torch.nn.Identity()
return "torch.nn.Identity()" if as_string else torch.nn.Identity()
elif activation == "relu":
return torch.nn.ReLU()
return "torch.nn.ReLU()" if as_string else torch.nn.ReLU()
elif activation == "tanh":
return torch.nn.Tanh()
return "torch.nn.Tanh()" if as_string else torch.nn.Tanh()
elif activation == "sigmoid":
return torch.nn.Sigmoid()
return "torch.nn.Sigmoid()" if as_string else torch.nn.Sigmoid()
elif activation == "leaky_relu":
return torch.nn.LeakyReLU()
return "torch.nn.LeakyReLU()" if as_string else torch.nn.LeakyReLU()
elif activation == "elu":
return torch.nn.ELU()
return "torch.nn.ELU()" if as_string else torch.nn.ELU()
elif activation == "softplus":
return torch.nn.Softplus()
return "torch.nn.Softplus()" if as_string else torch.nn.Softplus()
elif activation == "softsign":
return torch.nn.Softsign()
return "torch.nn.Softsign()" if as_string else torch.nn.Softsign()
elif activation == "selu":
return torch.nn.SELU()
return "torch.nn.SELU()" if as_string else torch.nn.SELU()
elif activation == "softmax":
return torch.nn.Softmax()
return "torch.nn.Softmax()" if as_string else torch.nn.Softmax()
else:
raise ValueError(f"Unknown activation function: {activation}")

def _get_num_units_by_shape(model: Model, shape: Shape) -> int:
def _get_num_units_by_shape(model: Model, shape: Shape, as_string: bool = False) -> Union[int, str]:
"""Get the number of units in a layer by shape
:param model: Model to get the number of units for
:type model: Model
:param shape: Shape of the layer
:type shape: Shape or int
:param as_string: Whether to return the activation function as string.
:type as_string: bool
:return: Number of units in the layer
:rtype: int
"""
num_units = {Shape.ONE: 1,
Shape.STATES: model.num_observations,
Shape.ACTIONS: model.num_actions,
Shape.STATES_ACTIONS: model.num_observations + model.num_actions}
num_units = {Shape.ONE: "1" if as_string else 1,
Shape.STATES: "self.num_observations" if as_string else model.num_observations,
Shape.ACTIONS: "self.num_actions" if as_string else model.num_actions,
Shape.STATES_ACTIONS: "self.num_observations + self.num_actions" if as_string else model.num_observations + model.num_actions}
try:
return num_units[shape]
except:
Expand Down Expand Up @@ -120,20 +124,24 @@ def _generate_sequential(model: Model,
:return: sequential model
:rtype: nn.Sequential
"""
# input layer
input_layer = [nn.Linear(_get_num_units_by_shape(model, input_shape), hiddens[0])]
# hidden layers
hidden_layers = []
for i in range(len(hiddens) - 1):
hidden_layers.append(_get_activation_function(hidden_activation[i]))
hidden_layers.append(nn.Linear(hiddens[i], hiddens[i + 1]))
hidden_layers.append(_get_activation_function(hidden_activation[-1]))
# output layer
output_layer = [nn.Linear(hiddens[-1], _get_num_units_by_shape(model, output_shape))]
if output_activation is not None:
output_layer.append(_get_activation_function(output_activation))

return nn.Sequential(*input_layer, *hidden_layers, *output_layer)
modules = []
for i in range(len(hiddens)):
# first layer
if not i:
modules.append(f"nn.Linear({_get_num_units_by_shape(None, input_shape, as_string=True)}, {hiddens[i]})")
if hidden_activation[i]:
modules.append(_get_activation_function(hidden_activation[i], as_string=True))
# last layer
if i == len(hiddens) - 1:
modules.append(f"nn.Linear({hiddens[i]}, {_get_num_units_by_shape(None, output_shape, as_string=True)})")
if output_activation:
modules.append(_get_activation_function(output_activation, as_string=True))
# hidden layers
else:
modules.append(f"nn.Linear({hiddens[i]}, {hiddens[i + 1]})")
if hidden_activation[i]:
modules.append(_get_activation_function(hidden_activation[i], as_string=True))
return f'self.net = nn.Sequential({", ".join(modules)})'

def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -369,43 +377,38 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S
:return: Deterministic model instance
:rtype: Model
"""
class DeterministicModel(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions):
Model.__init__(self, observation_space, action_space, device)
DeterministicMixin.__init__(self, clip_actions)

self.instantiator_output_scale = metadata["output_scale"]
self.instantiator_input_type = metadata["input_shape"].value

self.net = _generate_sequential(model=self,
input_shape=metadata["input_shape"],
hiddens=metadata["hiddens"],
hidden_activation=metadata["hidden_activation"],
output_shape=metadata["output_shape"],
output_activation=metadata["output_activation"],
output_scale=metadata["output_scale"])

def compute(self, inputs, role=""):
if self.instantiator_input_type == 0:
output = self.net(inputs["states"])
elif self.instantiator_input_type == -1:
output = self.net(inputs["taken_actions"])
elif self.instantiator_input_type == -2:
output = self.net(torch.cat((inputs["states"], inputs["taken_actions"]), dim=1))

return output * self.instantiator_output_scale, {}

metadata = {"input_shape": input_shape,
"hiddens": hiddens,
"hidden_activation": hidden_activation,
"output_shape": output_shape,
"output_activation": output_activation,
"output_scale": output_scale}

return DeterministicModel(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions)
# network
net = _generate_sequential(None, input_shape, hiddens, hidden_activation, output_shape, output_activation)

# compute
if input_shape == Shape.OBSERVATIONS:
forward = 'self.net(inputs["states"])'
elif input_shape == Shape.ACTIONS:
forward = 'self.net(inputs["taken_actions"])'
elif input_shape == Shape.STATES_ACTIONS:
forward = 'self.net(torch.cat((inputs["states"], inputs["taken_actions"]), dim=1))'
if output_scale != 1:
forward = f"{output_scale} * {forward}"

template = f"""class DeterministicModel(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions):
Model.__init__(self, observation_space, action_space, device)
DeterministicMixin.__init__(self, clip_actions)
{net}
def compute(self, inputs, role=""):
return {forward}, {{}}
"""
print(template)
_locals = {}
exec(template, globals(), _locals)
DeterministicModel = _locals["DeterministicModel"]
model = DeterministicModel(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=clip_actions)
return model

def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down

0 comments on commit 6909c92

Please sign in to comment.