From 46e8ac0daa4b0a7e1f442cfcc913e1740171121d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 17:34:35 +0100 Subject: [PATCH] fix examples --- .../linux_examples/scripts/run_test.sh | 48 +++++++++---------- sota-implementations/dreamer/dreamer.py | 6 +-- test/test_libs.py | 4 +- torchrl/modules/models/model_based.py | 7 ++- 4 files changed, 33 insertions(+), 32 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 1d11d481e3c..6ce551cb140 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -167,19 +167,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di # logger.record_video=True \ # logger.record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - model_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 + collector.total_frames=200 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + replay_buffer.batch_size=10 \ + env.n_parallel_envs=4 \ +# env_per_collector=2 \ + optimization.optim_steps_per_batch=1 \ + logger.video=True \ +# record_frames=4 \ + replay_buffer.buffer_size=120 \ + networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -223,19 +221,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq # With single envs python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - model_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 + collector.total_frames=200 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + replay_buffer.batch_size=10 \ + env.n_parallel_envs=1 \ +# env_per_collector=2 \ + optimization.optim_steps_per_batch=1 \ + logger.video=True \ +# record_frames=4 \ + replay_buffer.buffer_size=120 \ + networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a0679f65c05..f002c6420e7 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -188,7 +188,7 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16, - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -216,7 +216,7 @@ def compile_rssms(module): t_loss_actor_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() @@ -237,7 +237,7 @@ def compile_rssms(module): t_loss_critic_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() diff --git a/test/test_libs.py b/test/test_libs.py index 7ddb0d4fc02..cfcde85cc2c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3416,7 +3416,9 @@ def test_robohive(self, envname, from_pixels, from_depths): torchrl_logger.info("no camera") return try: - env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths) + env = RoboHiveEnv( + envname, from_pixels=from_pixels, from_depths=from_depths + ) except AttributeError as err: if "'MjData' object has no attribute 'get_body_xipos'" in str(err): torchrl_logger.info("tcdm are broken") diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index cc34d79bb5b..11ca9d12232 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -6,6 +6,7 @@ import torch from packaging import version +from tensordict import LazyStackedTensorDict from tensordict.nn import ( NormalParamExtractor, TensorDictModule, @@ -13,7 +14,7 @@ TensorDictSequential, ) from torch import nn -from tensordict import LazyStackedTensorDict + # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell from torchrl._utils import timeit @@ -263,7 +264,9 @@ def forward(self, tensordict): _tensordict = update_values[t + 1].update(_tensordict) out = torch.stack(tensordict_out, tensordict.ndim - 1) - assert not any(isinstance(val, LazyStackedTensorDict) for val in out.values(True)), out + assert not any( + isinstance(val, LazyStackedTensorDict) for val in out.values(True) + ), out return out