Skip to content

Commit

Permalink
[BugFix] Fix tutos (#1648)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 24, 2023
1 parent e7630f1 commit f8788b1
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 14 deletions.
28 changes: 20 additions & 8 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,9 @@ def __init__(
self.policy_weights = TensorDict({}, [])

self.env: EnvBase = self.env.to(self.device)
self.max_frames_per_traj = max_frames_per_traj
self.max_frames_per_traj = (
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
)
if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0:
# let's check that there is no StepCounter yet
for key in self.env.output_spec.keys(True, True):
Expand Down Expand Up @@ -595,9 +597,13 @@ def __init__(
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
self.total_frames = total_frames
self.total_frames = (
int(total_frames) if total_frames != float("inf") else total_frames
)
self.reset_at_each_iter = reset_at_each_iter
self.init_random_frames = init_random_frames
self.init_random_frames = (
int(init_random_frames) if init_random_frames is not None else 0
)
if (
init_random_frames is not None
and init_random_frames % frames_per_batch != 0
Expand All @@ -620,7 +626,7 @@ def __init__(
f" ({-(-frames_per_batch // self.n_env) * self.n_env})."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
self.requested_frames_per_batch = frames_per_batch
self.requested_frames_per_batch = int(frames_per_batch)
self.frames_per_batch = -(-frames_per_batch // self.n_env)
self.exploration_type = (
exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
Expand Down Expand Up @@ -1234,11 +1240,15 @@ def device_err_msg(device_name, devices_list):
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
self.total_frames = total_frames
self.total_frames = (
int(total_frames) if total_frames != float("inf") else total_frames
)
self.reset_at_each_iter = reset_at_each_iter
self.postprocs = postproc
self.max_frames_per_traj = max_frames_per_traj
self.requested_frames_per_batch = frames_per_batch
self.max_frames_per_traj = (
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
)
self.requested_frames_per_batch = int(frames_per_batch)
self.reset_when_done = reset_when_done
if split_trajs is None:
split_trajs = False
Expand All @@ -1247,7 +1257,9 @@ def device_err_msg(device_name, devices_list):
"Cannot split trajectories when reset_when_done is False."
)
self.split_trajs = split_trajs
self.init_random_frames = init_random_frames
self.init_random_frames = (
int(init_random_frames) if init_random_frames is not None else 0
)
self.update_at_each_batch = update_at_each_batch
self.exploration_type = exploration_type
self.frames_per_worker = np.inf
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/r3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _init(self):
transforms.append(resize)

# R3M
if out_keys is None:
if out_keys in (None, []):
if stack_images:
out_keys = ["r3m_vec"]
else:
Expand Down
3 changes: 1 addition & 2 deletions torchrl/envs/transforms/vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union

import torch
Expand Down Expand Up @@ -277,7 +276,7 @@ def _init(self):
transforms.append(resize)

# VIP
if out_keys is None:
if out_keys in (None, []):
if stack_images:
out_keys = ["vip_vec"]
else:
Expand Down
11 changes: 8 additions & 3 deletions tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def get_replay_buffer(buffer_size, n_optim, batch_size):


def get_collector(
obs_norm_sd,
stats,
num_collectors,
actor_explore,
frames_per_batch,
Expand All @@ -399,7 +399,7 @@ def get_collector(
):
data_collector = MultiaSyncDataCollector(
[
make_env(parallel=True, obs_norm_sd=obs_norm_sd),
make_env(parallel=True, obs_norm_sd=stats),
]
* num_collectors,
policy=actor_explore,
Expand Down Expand Up @@ -566,7 +566,12 @@ def get_loss_module(actor, gamma):
loss_module, target_net_updater = get_loss_module(actor, gamma)

collector = get_collector(
stats, num_collectors, actor_explore, frames_per_batch, total_frames, device
stats=stats,
num_collectors=num_collectors,
actor_explore=actor_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
)
optimizer = torch.optim.Adam(
loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
Expand Down
12 changes: 12 additions & 0 deletions tutorials/sphinx-tutorials/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,12 @@ class SinTransform(Transform):
def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sine of ``obs``.

    Args:
        obs: tensor to transform.

    Returns:
        The tensor produced by ``obs.sin()``.
    """
    return obs.sin()

# The transform must also modify the data at reset time
def _reset(
    self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
    """Apply the transform to the data produced at reset time.

    ``tensordict`` is part of the signature but is not read here; only
    ``tensordict_reset`` is forwarded to ``self._call``, and the result of
    that call is returned.
    """
    transformed_reset = self._call(tensordict_reset)
    return transformed_reset

# _apply_to_composite will execute the observation spec transform across all
# in_keys/out_keys pairs and write the result in the observation_spec which
# is of type ``Composite``
Expand All @@ -670,6 +676,12 @@ class CosTransform(Transform):
def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
    """Return the element-wise cosine of ``obs``.

    Args:
        obs: tensor to transform.

    Returns:
        The tensor produced by ``obs.cos()``.
    """
    return obs.cos()

# The transform must also modify the data at reset time
def _reset(
    self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
    """Apply the transform to the data produced at reset time.

    ``tensordict`` is part of the signature but is not read here; only
    ``tensordict_reset`` is forwarded to ``self._call``, and the result of
    that call is returned.
    """
    transformed_reset = self._call(tensordict_reset)
    return transformed_reset

# _apply_to_composite will execute the observation spec transform across all
# in_keys/out_keys pairs and write the result in the observation_spec which
# is of type ``Composite``
Expand Down

0 comments on commit f8788b1

Please sign in to comment.