Skip to content

Commit

Permalink
[Doc] Tutorial on exporting TorchRL models
Browse files Browse the repository at this point in the history
ghstack-source-id: b93146e22d8376563e7ac302b5cff95f09ae50d4
Pull Request resolved: #2557
  • Loading branch information
vmoens committed Nov 13, 2024
1 parent 165163a commit c0187a9
Show file tree
Hide file tree
Showing 5 changed files with 432 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ memory_profiler
pyrender
pytest
vmas
onnxscript
onnxruntime
onnx
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ Intermediate
tutorials/pretrained_models
tutorials/dqn_with_rnn
tutorials/rb_tutorial
tutorials/export

Advanced
--------
Expand Down
26 changes: 25 additions & 1 deletion torchrl/modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def __init__(
if not isinstance(out_features, Number):
_out_features_num = prod(out_features)
self.out_features = out_features
self._reshape_out = not isinstance(
self.out_features, (int, torch.SymInt, Number)
)
self._out_features_num = _out_features_num
self.activation_class = activation_class
self.norm_class = norm_class
Expand Down Expand Up @@ -302,7 +305,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
inputs = (torch.cat([*inputs], -1),)

out = super().forward(*inputs)
if not isinstance(self.out_features, Number):
if self._reshape_out:
out = out.view(*out.shape[:-1], *self.out_features)
return out

Expand Down Expand Up @@ -549,6 +552,27 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
out = out.unflatten(0, batch)
return out

@classmethod
def default_atari_dqn(cls, num_actions: int):
"""Returns the default DQN as presented in the seminal DQN paper.
Args:
num_actions (int): the action space of the atari game.
"""
cnn = ConvNet(
activation_class=torch.nn.ReLU,
num_cells=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
mlp = MLP(
activation_class=torch.nn.ReLU,
out_features=num_actions,
num_cells=[512],
)
return nn.Sequential(cnn, mlp)


Conv2dNet = ConvNet

Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
out = action_tensordict.get(action_key)
eps = self.eps.item()
cond = torch.rand(action_tensordict.shape, device=out.device) < eps
# cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
cond = expand_as_right(cond, out)
spec = self.spec
if spec is not None:
Expand Down
Loading

0 comments on commit c0187a9

Please sign in to comment.