Skip to content

Commit 693d799

Browse files
Adds convolutional actor critic module
1 parent cf71aa6 commit 693d799

File tree

7 files changed

+367
-2
lines changed

7 files changed

+367
-2
lines changed

rsl_rl/modules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""Definitions for neural-network components for RL-agents."""
77

88
from .actor_critic import ActorCritic
9+
from .actor_critic_conv2d import ActorCriticConv2d
910
from .actor_critic_recurrent import ActorCriticRecurrent
1011
from .rnd import *
1112
from .student_teacher import StudentTeacher
@@ -14,6 +15,7 @@
1415

1516
__all__ = [
1617
"ActorCritic",
18+
"ActorCriticConv2d",
1719
"ActorCriticRecurrent",
1820
"StudentTeacher",
1921
"StudentTeacherRecurrent",

rsl_rl/modules/actor_critic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
class ActorCritic(nn.Module):
1616
is_recurrent = False
17+
is_conv2d = False
1718

1819
def __init__(
1920
self,
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
import torch
7+
import torch.nn as nn
8+
from torch.distributions import Normal
9+
10+
from rsl_rl.networks import EmpiricalNormalization
11+
from rsl_rl.utils import resolve_nn_activation
12+
13+
14+
class ResidualBlock(nn.Module):
15+
def __init__(self, channels):
16+
super().__init__()
17+
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
18+
self.bn1 = nn.BatchNorm2d(channels)
19+
self.relu = nn.ReLU(inplace=True)
20+
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
21+
self.bn2 = nn.BatchNorm2d(channels)
22+
23+
def forward(self, x):
24+
residual = x
25+
out = self.conv1(x)
26+
out = self.bn1(out)
27+
out = self.relu(out)
28+
out = self.conv2(out)
29+
out = self.bn2(out)
30+
out += residual
31+
out = self.relu(out)
32+
return out
33+
34+
35+
class ConvolutionalNetwork(nn.Module):
36+
def __init__(
37+
self,
38+
proprio_input_dim,
39+
output_dim,
40+
image_input_shape,
41+
conv_layers_params,
42+
hidden_dims,
43+
activation_fn,
44+
conv_linear_output_size,
45+
):
46+
super().__init__()
47+
48+
self.image_input_shape = image_input_shape # (C, H, W)
49+
self.image_obs_size = torch.prod(torch.tensor(self.image_input_shape)).item()
50+
self.proprio_obs_size = proprio_input_dim
51+
self.input_dim = self.proprio_obs_size + self.image_obs_size
52+
self.activation_fn = activation_fn
53+
54+
# build conv network and get its output size
55+
self.conv_net = self.build_conv_net(conv_layers_params)
56+
with torch.no_grad():
57+
dummy_image = torch.zeros(1, *self.image_input_shape)
58+
conv_output = self.conv_net(dummy_image)
59+
self.image_feature_size = conv_output.view(1, -1).shape[1]
60+
61+
# connection layers between conv net and mlp
62+
self.conv_linear = nn.Linear(self.image_feature_size, conv_linear_output_size)
63+
self.layernorm = nn.LayerNorm(conv_linear_output_size)
64+
65+
# mlp
66+
self.mlp = nn.Sequential(
67+
nn.Linear(self.proprio_obs_size + conv_linear_output_size, hidden_dims[0]),
68+
self.activation_fn,
69+
*[
70+
layer
71+
for dim in zip(hidden_dims[:-1], hidden_dims[1:])
72+
for layer in (nn.Linear(dim[0], dim[1]), self.activation_fn)
73+
],
74+
nn.Linear(hidden_dims[-1], output_dim),
75+
)
76+
77+
# initialize weights
78+
self._initialize_weights()
79+
80+
def build_conv_net(self, conv_layers_params):
81+
layers = []
82+
in_channels = self.image_input_shape[0]
83+
for idx, params in enumerate(conv_layers_params[:-1]):
84+
layers.extend([
85+
nn.Conv2d(
86+
in_channels,
87+
params["out_channels"],
88+
kernel_size=params.get("kernel_size", 3),
89+
stride=params.get("stride", 1),
90+
padding=params.get("padding", 0),
91+
),
92+
nn.BatchNorm2d(params["out_channels"]),
93+
nn.ReLU(inplace=True),
94+
ResidualBlock(params["out_channels"]) if idx > 0 else nn.Identity(),
95+
])
96+
in_channels = params["out_channels"]
97+
last_params = conv_layers_params[-1]
98+
layers.append(
99+
nn.Conv2d(
100+
in_channels,
101+
last_params["out_channels"],
102+
kernel_size=last_params.get("kernel_size", 3),
103+
stride=last_params.get("stride", 1),
104+
padding=last_params.get("padding", 0),
105+
)
106+
)
107+
layers.append(nn.BatchNorm2d(last_params["out_channels"]))
108+
return nn.Sequential(*layers)
109+
110+
def _initialize_weights(self):
111+
for m in self.conv_net.modules():
112+
if isinstance(m, nn.Conv2d):
113+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
114+
elif isinstance(m, nn.BatchNorm2d):
115+
nn.init.constant_(m.weight, 1)
116+
nn.init.constant_(m.bias, 0)
117+
118+
nn.init.kaiming_normal_(self.conv_linear.weight, mode="fan_out", nonlinearity="tanh")
119+
nn.init.constant_(self.conv_linear.bias, 0)
120+
nn.init.constant_(self.layernorm.weight, 1.0)
121+
nn.init.constant_(self.layernorm.bias, 0.0)
122+
123+
for layer in self.mlp:
124+
if isinstance(layer, nn.Linear):
125+
nn.init.orthogonal_(layer.weight, gain=0.01)
126+
nn.init.zeros_(layer.bias) if layer.bias is not None else None
127+
128+
def forward(self, proprio_obs, image_obs):
129+
conv_features = self.conv_net(image_obs)
130+
flattened_conv_features = conv_features.reshape(conv_features.size(0), -1)
131+
normalized_conv_output = self.layernorm(self.conv_linear(flattened_conv_features))
132+
combined_input = torch.cat([proprio_obs, normalized_conv_output], dim=1)
133+
output = self.mlp(combined_input)
134+
return output
135+
136+
137+
class ActorCriticConv2d(nn.Module):
    """Actor-critic with convolutional image encoders for both networks.

    Observations are split into a proprioceptive (1D) part, which may be
    empirically normalized, and an image part coming from the ``"sensor"``
    observation group, which is fed to the conv encoder unnormalized.
    """

    is_recurrent = False
    is_conv2d = True

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        conv_layers_params,
        conv_linear_output_size,
        actor_hidden_dims,
        critic_hidden_dims,
        actor_obs_normalization=False,
        critic_obs_normalization=False,
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        """Initialize the actor and critic conv networks.

        Args:
            obs: Mapping of observation-group name to a sample tensor, used to
                infer input dimensions.
            obs_groups: Dict with "policy" and "critic" keys listing the
                observation groups each network consumes.
            num_actions: Dimension of the action space.
            conv_layers_params: Conv layer specs forwarded to
                ``ConvolutionalNetwork``.
            conv_linear_output_size: Size of the conv-feature projection.
            actor_hidden_dims: MLP hidden sizes for the actor.
            critic_hidden_dims: MLP hidden sizes for the critic.
            actor_obs_normalization: Empirically normalize actor proprio obs.
            critic_obs_normalization: Empirically normalize critic proprio obs.
            activation: Name of the MLP activation function.
            init_noise_std: Initial action-noise standard deviation.
            noise_std_type: "scalar" (std parameter) or "log" (log-std parameter).
            **kwargs: Ignored (a warning is printed).
        """
        if kwargs:
            print(
                "ActorCriticConv2d.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        super().__init__()

        self.obs_groups = obs_groups
        self.activation_fn = resolve_nn_activation(activation)

        # get observation dimensions
        self.num_actor_obs, self.actor_image_shape = self._calculate_obs_dims(obs, obs_groups["policy"])
        self.num_critic_obs, self.critic_image_shape = self._calculate_obs_dims(obs, obs_groups["critic"])

        # NOTE(review): the actor's image shape is used for BOTH networks; the
        # critic image shape is computed but unused — confirm they must match.
        self.image_input_shape = self.actor_image_shape
        if self.image_input_shape is None:
            raise ValueError("No image observations found. Conv2d networks require image inputs.")

        # actor
        self.actor = ConvolutionalNetwork(
            proprio_input_dim=self.num_actor_obs,
            output_dim=num_actions,
            image_input_shape=self.image_input_shape,
            conv_layers_params=conv_layers_params,
            hidden_dims=actor_hidden_dims,
            activation_fn=self.activation_fn,
            conv_linear_output_size=conv_linear_output_size,
        )

        # actor observation normalization (proprio part only)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(self.num_actor_obs)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()

        # critic
        self.critic = ConvolutionalNetwork(
            proprio_input_dim=self.num_critic_obs,
            output_dim=1,
            image_input_shape=self.image_input_shape,
            conv_layers_params=conv_layers_params,
            hidden_dims=critic_hidden_dims,
            activation_fn=self.activation_fn,
            conv_linear_output_size=conv_linear_output_size,
        )

        # critic observation normalization (proprio part only)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(self.num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        print(f"Actor ConvNet: {self.actor}")
        print(f"Critic ConvNet: {self.critic}")

        # action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # action distribution (populated in update_distribution)
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    def _calculate_obs_dims(self, obs, obs_group_list):
        """Calculate total proprioceptive obs dim and extract image shape.

        The group named "sensor" is treated as the (single) image input; its
        tensor is assumed NHWC and reported as a CHW shape. All other groups
        must be flat (batch, dim) tensors.
        """
        total_proprio_dim = 0
        image_shape = None

        for obs_group in obs_group_list:
            obs_tensor = obs[obs_group]
            if obs_group == "sensor":
                # NHWC -> CHW (batch dim stripped).
                image_shape = obs_tensor.permute(0, 3, 1, 2).shape[1:]
            else:
                assert len(obs_tensor.shape) == 2, f"Non-image observations must be 1D. Got {obs_tensor.shape} for {obs_group}"
                total_proprio_dim += obs_tensor.shape[-1]

        return total_proprio_dim, image_shape

    def reset(self, dones=None):
        """No-op: this module keeps no per-episode state."""
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the most recently updated action distribution."""
        return self.distribution.mean

    @property
    def action_std(self):
        """Stddev of the most recently updated action distribution."""
        return self.distribution.stddev

    @property
    def entropy(self):
        """Per-sample entropy summed over action dimensions."""
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, proprio_obs, image_obs):
        """Recompute the Gaussian action distribution from observations."""
        mean = self.actor(proprio_obs, image_obs)
        if self.noise_std_type == "scalar":
            std = self.std.expand_as(mean)
        elif self.noise_std_type == "log":
            std = torch.exp(self.log_std).expand_as(mean)
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        self.distribution = Normal(mean, std)

    def act(self, obs, **kwargs):
        """Sample actions from the current policy distribution."""
        proprio_obs, image_obs = self.get_actor_obs(obs)
        proprio_obs = self.actor_obs_normalizer(proprio_obs)
        self.update_distribution(proprio_obs, image_obs)
        return self.distribution.sample()

    def act_inference(self, obs):
        """Return the deterministic (mean) action."""
        proprio_obs, image_obs = self.get_actor_obs(obs)
        proprio_obs = self.actor_obs_normalizer(proprio_obs)
        return self.actor(proprio_obs, image_obs)

    def evaluate(self, obs, **kwargs):
        """Return the critic's value estimate."""
        proprio_obs, image_obs = self.get_critic_obs(obs)
        proprio_obs = self.critic_obs_normalizer(proprio_obs)
        return self.critic(proprio_obs, image_obs)

    def _split_obs(self, obs, set_name):
        """Split the observation groups of ``set_name`` into (proprio, image).

        Flat groups are concatenated along the feature dim; the "sensor"
        group is permuted NHWC -> NCHW. If no image group is present, a
        zero image of ``image_input_shape`` is substituted.
        """
        obs_list = []
        image_obs = None

        for obs_group in self.obs_groups[set_name]:
            if obs_group == "sensor":
                image_obs = obs[obs_group].permute(0, 3, 1, 2)
            else:
                obs_list.append(obs[obs_group])

        if obs_list:
            proprio_obs = torch.cat(obs_list, dim=-1)
        else:
            # NOTE(review): relies on `obs` exposing `.device` (TensorDict-like);
            # a plain dict would raise here — confirm against callers.
            proprio_obs = torch.zeros(obs[list(obs.keys())[0]].shape[0], 0, device=obs.device)

        if image_obs is None:
            image_obs = torch.zeros(proprio_obs.shape[0], *self.image_input_shape, device=proprio_obs.device)
        return proprio_obs, image_obs

    def get_actor_obs(self, obs):
        """Return (proprio, image) observations for the actor."""
        return self._split_obs(obs, "policy")

    def get_critic_obs(self, obs):
        """Return (proprio, image) observations for the critic."""
        return self._split_obs(obs, "critic")

    def get_actions_log_prob(self, actions):
        """Log-probability of ``actions`` summed over action dimensions."""
        return self.distribution.log_prob(actions).sum(dim=-1)

    def update_normalization(self, obs):
        """Update the empirical normalizers from the proprio observations."""
        if self.actor_obs_normalization:
            proprio_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(proprio_obs)
        if self.critic_obs_normalization:
            proprio_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(proprio_obs)

    def load_state_dict(self, state_dict, strict=True):
        """Load the parameters of the actor-critic model.

        Args:
            state_dict (dict): State dictionary of the model.
            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
                           module's state_dict() function.

        Returns:
            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
                  `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
        """
        super().load_state_dict(state_dict, strict=strict)
        return True  # training resumes

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
class ActorCriticRecurrent(nn.Module):
1717
is_recurrent = True
18+
is_conv2d = False
1819

1920
def __init__(
2021
self,

rsl_rl/modules/student_teacher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
class StudentTeacher(nn.Module):
1616
is_recurrent = False
17+
is_conv2d = False
1718

1819
def __init__(
1920
self,

0 commit comments

Comments
 (0)