diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index bcc688b0a6d..075489b208d 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -86,8 +86,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= -# record_video=True \ -# record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/a2c/a2c_mujoco.py \ env.env_name=HalfCheetah-v4 \ collector.total_frames=40 \ @@ -125,7 +123,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re collector.env_per_collector=2 \ buffer.batch_size=10 \ optim.steps_per_batch=1 \ - logger.record_video=True \ + logger.video=True \ logger.record_frames=4 \ buffer.size=120 \ logger.backend= @@ -151,8 +149,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di replay_buffer.size=120 \ env.name=CartPole-v1 \ logger.backend= -# logger.record_video=True \ -# logger.record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ collector.total_frames=200 \ collector.init_random_frames=10 \ @@ -220,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= -# record_video=True \ -# record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dqn/dqn_atari.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -238,7 +232,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re collector.env_per_collector=1 \ buffer.batch_size=10 \ optim.steps_per_batch=1 \ - logger.record_video=True \ + logger.video=True \ logger.record_frames=4 \ buffer.size=120 \ logger.backend= diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index d6b1668aadf..c6b96db9292 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -76,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821 }, ) else: - logger = "" + logger = None key, init_env_steps, stats = None, None, None if not cfg.env.vecnorm and cfg.env.norm_stats: @@ -174,14 +174,14 @@ def main(cfg: "DictConfig"): # noqa: F821 t.loc.fill_(0.0) trainer = make_trainer( - collector, - loss_module, - recorder, - target_net_updater, - actor_model_explore, - replay_buffer, - logger, - cfg, + collector=collector, + loss_module=loss_module, + recorder=recorder, + target_net_updater=target_net_updater, + policy_exploration=actor_model_explore, + replay_buffer=replay_buffer, + logger=logger, + cfg=cfg, ) trainer.train() diff --git a/test/test_env.py b/test/test_env.py index f5f1f37ed8e..d0ccc2cff57 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1832,7 +1832,16 @@ def test_info_dict_reader(self, device, seed=0): import gym env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device) - env.set_info_dict_reader(default_info_dict_reader(["x_position"])) + env.set_info_dict_reader( + default_info_dict_reader( + ["x_position"], + spec=CompositeSpec( + x_position=UnboundedContinuousTensorSpec( + dtype=torch.float64, shape=() + ) + ), + ) + ) assert "x_position" in env.observation_spec.keys() assert isinstance( @@ -1842,15 +1851,21 @@ def test_info_dict_reader(self, device, seed=0): tensordict = env.reset() tensordict = env.rand_step(tensordict) - assert env.observation_spec["x_position"].is_in( - tensordict[("next", "x_position")] + x_position_data = tensordict["next", "x_position"] + assert env.observation_spec["x_position"].is_in(x_position_data), ( + x_position_data.shape, + x_position_data.dtype, + env.observation_spec["x_position"], ) for spec in ( - {"x_position": UnboundedContinuousTensorSpec(10)}, - None, - CompositeSpec(x_position=UnboundedContinuousTensorSpec(10), shape=[]), - [UnboundedContinuousTensorSpec(10)], + {"x_position": UnboundedContinuousTensorSpec((), dtype=torch.float64)}, + # None, + CompositeSpec( + x_position=UnboundedContinuousTensorSpec((), dtype=torch.float64), + shape=[], + ), + [UnboundedContinuousTensorSpec((), dtype=torch.float64)], ): env2 = GymWrapper(gym.make("HalfCheetah-v4")) env2.set_info_dict_reader( @@ -1859,9 +1874,12 @@ def test_info_dict_reader(self, device, seed=0): tensordict2 = env2.reset() tensordict2 = env2.rand_step(tensordict2) - - assert env2.observation_spec["x_position"].is_in( - tensordict2[("next", "x_position")] + data = tensordict2[("next", "x_position")] + assert env2.observation_spec["x_position"].is_in(data), ( + data.dtype, + data.device, + data.shape, + env2.observation_spec["x_position"], ) @pytest.mark.skipif(not _has_gym, reason="no gym") diff --git a/test/test_exploration.py b/test/test_exploration.py index e6493bd1804..89de8005555 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -254,7 +254,11 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= out_noexp = [] out = [] for i in range(n_steps): - tensordict_noexp = policy(tensordict.clone()) + tensordict_noexp = policy( + tensordict.clone().exclude( + *(key for key in tensordict.keys() if key.startswith("_")) + ) + ) tensordict = exploratory_policy(tensordict.clone()) if i == 0: assert (tensordict[exploratory_policy.ou.steps_key] == 1).all() diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 75ae396d95d..4527e32d16b 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -8,7 +8,7 @@ import pytest import torch from mocking_classes import CountingEnv, DiscreteActionVecMockEnv -from tensordict import pad, TensorDict, unravel_key_list +from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn from torchrl.data.tensor_specs import ( @@ -515,7 +515,7 @@ def test_sequential_partial(self, stack): ) if stack: - td = torch.stack( + td = LazyStackedTensorDict.maybe_dense_stack( [ TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []), TensorDict({"a": torch.randn(3), "c": torch.randn(4)}, []), diff --git a/test/test_transforms.py b/test/test_transforms.py index 36408bf4964..c9d2fb8c031 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10583,6 +10583,7 @@ def test_multistep_transform(self): outs_2 = [] td = env.reset().contiguous() + assert "reward" not in td for _ in range(1): rollout = env.rollout( 250, auto_reset=False, tensordict=td, break_when_any_done=False @@ -10626,7 +10627,7 @@ def test_multistep_transform(self): ).contiguous() assert "reward" not in rollout.keys() out = t._inv_call(rollout) - td = rollout[..., -1]["next"] + td = rollout[..., -1]["next"].exclude("reward") if out is not None: outs_3.append(out) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 8a4b3e517dd..f75eaaeedcf 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -39,6 +39,7 @@ ) from torchrl.data.replay_buffers.storages import ( _get_default_collate, + _stack_anything, ListStorage, Storage, StorageEnsemble, @@ -1541,8 +1542,10 @@ def __init__( num_buffer_sampled: int | None = None, **kwargs, ): + if collate_fn is None: - collate_fn = torch.stack + collate_fn = _stack_anything + if rbs: if storages is not None or samplers is not None or writers is not None: raise RuntimeError diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 7d4fb64babd..a1ada2eb72e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -18,7 +18,12 @@ import numpy as np import tensordict import torch -from tensordict import is_tensor_collection, TensorDict, TensorDictBase +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDict, + TensorDictBase, +) from tensordict.memmap import MemmapTensor, MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp @@ -1322,6 +1327,12 @@ def _collate_list_tensordict(x): return out +def _stack_anything(x): + if is_tensor_collection(x[0]): + return LazyStackedTensorDict.maybe_dense_stack(x) + return torch.stack(x) + + def _collate_id(x): return x diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c9d0683ad9c..105c13214e0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -974,14 +974,18 @@ def zero(self, shape=None) -> TensorDictBase: dim = self.dim + len(shape) else: dim = self.dim - return torch.stack([spec.zero(shape) for spec in self._specs], dim) + return LazyStackedTensorDict.maybe_dense_stack( + [spec.zero(shape) for spec in self._specs], dim + ) def rand(self, shape=None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: dim = self.dim - return torch.stack([spec.rand(shape) for spec in self._specs], dim) + return LazyStackedTensorDict.maybe_dense_stack( + [spec.rand(shape) for spec in self._specs], dim + ) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: if dest is None: @@ -4344,7 +4348,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase: vals.append(spec.project(subval)) else: vals.append(subval) - res = torch.stack(vals, dim=self.dim) + res = LazyStackedTensorDict.maybe_dense_stack(vals, dim=self.dim) if not isinstance(val, LazyStackedTensorDict): res = res.to_tensordict() return res diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 6c1dbbd0389..b27c1f795a2 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import re import warnings from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union @@ -506,13 +507,18 @@ def auto_register_info_dict(self): try: check_env_specs(self) return self - except AssertionError as err: - if "The keys of the specs and data do not match" in str(err): - result = TransformedEnv( - self, TensorDictPrimer(self.info_dict_reader[0].info_spec) - ) - check_env_specs(result) - return result + except (AssertionError, RuntimeError) as err: + patterns = [ + "The keys of the specs and data do not match", + "The sets of keys in the tensordicts to stack are exclusive", + ] + for pattern in patterns: + if re.search(pattern, str(err)): + result = TransformedEnv( + self, TensorDictPrimer(self.info_dict_reader[0].info_spec) + ) + check_env_specs(result) + return result raise err def __repr__(self) -> str: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index cd51b4fd23b..08615ea12ae 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -193,7 +193,6 @@ def _is_reset(key: NestedKey): expected = set(expected) self.validated = expected.intersection(actual) == expected if not self.validated: - raise RuntimeError warnings.warn( "The expected key set and actual key set differ. " "This will work but with a slower throughput than "