Skip to content

Commit

Permalink
fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 22, 2024
1 parent 63f7580 commit 46e8ac0
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 32 deletions.
48 changes: 22 additions & 26 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
# logger.record_video=True \
# logger.record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=200 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
model_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
collector.total_frames=200 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
replay_buffer.batch_size=10 \
env.n_parallel_envs=4 \
# env_per_collector=2 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
# record_frames=4 \
replay_buffer.buffer_size=120 \
networks.rssm_hidden_dim=17
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand Down Expand Up @@ -223,19 +221,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq

# With single envs
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=200 \
num_workers=2 \
env_per_collector=1 \
collector_device=cuda:0 \
model_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
collector.total_frames=200 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
replay_buffer.batch_size=10 \
env.n_parallel_envs=1 \
# env_per_collector=2 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
# record_frames=4 \
replay_buffer.buffer_size=120 \
networks.rssm_hidden_dim=17
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def compile_rssms(module):
with torch.autocast(
device_type=device.type,
dtype=torch.bfloat16,
) if use_autocast else contextlib.nullcontext():
) if use_autocast else contextlib.nullcontext():
model_loss_td, sampled_tensordict = world_model_loss(
sampled_tensordict
)
Expand Down Expand Up @@ -216,7 +216,7 @@ def compile_rssms(module):
t_loss_actor_init = time.time()
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
) if use_autocast else contextlib.nullcontext():
actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict)

actor_opt.zero_grad()
Expand All @@ -237,7 +237,7 @@ def compile_rssms(module):
t_loss_critic_init = time.time()
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
) if use_autocast else contextlib.nullcontext():
value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)

value_opt.zero_grad()
Expand Down
4 changes: 3 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3416,7 +3416,9 @@ def test_robohive(self, envname, from_pixels, from_depths):
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths)
env = RoboHiveEnv(
envname, from_pixels=from_pixels, from_depths=from_depths
)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
Expand Down
7 changes: 5 additions & 2 deletions torchrl/modules/models/model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import torch
from packaging import version
from tensordict import LazyStackedTensorDict
from tensordict.nn import (
NormalParamExtractor,
TensorDictModule,
TensorDictModuleBase,
TensorDictSequential,
)
from torch import nn
from tensordict import LazyStackedTensorDict

# from torchrl.modules.tensordict_module.rnn import GRUCell
from torch.nn import GRUCell
from torchrl._utils import timeit
Expand Down Expand Up @@ -263,7 +264,9 @@ def forward(self, tensordict):
_tensordict = update_values[t + 1].update(_tensordict)

out = torch.stack(tensordict_out, tensordict.ndim - 1)
assert not any(isinstance(val, LazyStackedTensorDict) for val in out.values(True)), out
assert not any(
isinstance(val, LazyStackedTensorDict) for val in out.values(True)
), out
return out


Expand Down

0 comments on commit 46e8ac0

Please sign in to comment.