From b5e90c4f29ff782437f009b8a15dd511074033c9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 18:31:53 +0000 Subject: [PATCH] [Deprecation] Deprecate in prep for release (#1820) --- .../linux_examples/scripts/run_test.sh | 2 + examples/bandits/dqn.py | 17 +- examples/cql/discrete_cql_config.yaml | 4 +- examples/cql/discrete_cql_online.py | 2 +- examples/cql/offline_config.yaml | 2 +- examples/cql/online_config.yaml | 2 +- examples/cql/utils.py | 7 +- examples/ddpg/config.yaml | 2 +- examples/ddpg/ddpg.py | 2 +- examples/ddpg/utils.py | 6 +- examples/decision_transformer/dt_config.yaml | 2 +- examples/decision_transformer/odt_config.yaml | 2 +- examples/decision_transformer/utils.py | 2 +- examples/discrete_sac/config.yaml | 2 +- examples/discrete_sac/discrete_sac.py | 2 +- examples/discrete_sac/utils.py | 6 +- .../collectors/multi_nodes/ray_train.py | 2 +- examples/iql/utils.py | 6 +- examples/multiagent/iql.py | 22 +- examples/multiagent/maddpg_iddpg.py | 2 +- examples/multiagent/mappo_ippo.py | 2 +- examples/multiagent/qmix_vdn.py | 22 +- examples/multiagent/sac.py | 2 +- examples/redq/config.yaml | 2 +- examples/rlhf/train_rlhf.py | 2 +- examples/sac/config.yaml | 2 +- examples/sac/sac.py | 2 +- examples/sac/utils.py | 6 +- examples/td3/config.yaml | 2 +- examples/td3/td3.py | 2 +- examples/td3/utils.py | 7 +- test/mocking_classes.py | 7 +- test/test_collector.py | 65 +-- test/test_cost.py | 4 +- test/test_distributed.py | 3 +- test/test_exploration.py | 8 - test/test_helpers.py | 18 +- test/test_transforms.py | 452 ++++++++++++++---- torchrl/collectors/collectors.py | 113 +++-- torchrl/collectors/distributed/generic.py | 14 +- torchrl/collectors/distributed/ray.py | 19 +- torchrl/collectors/distributed/rpc.py | 14 +- torchrl/collectors/distributed/sync.py | 14 +- torchrl/data/datasets/d4rl.py | 31 +- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/tensor_specs.py | 6 +- torchrl/envs/gym_like.py | 2 +- torchrl/envs/transforms/transforms.py | 12 +- torchrl/modules/models/models.py | 9 +- torchrl/modules/tensordict_module/actors.py | 10 +- .../modules/tensordict_module/exploration.py | 99 +--- torchrl/modules/tensordict_module/rnn.py | 9 +- torchrl/objectives/a2c.py | 31 +- torchrl/objectives/common.py | 8 +- torchrl/objectives/cql.py | 8 +- torchrl/objectives/ddpg.py | 6 +- torchrl/objectives/deprecated.py | 6 +- torchrl/objectives/dqn.py | 5 +- torchrl/objectives/dreamer.py | 9 +- torchrl/objectives/iql.py | 17 +- torchrl/objectives/multiagent/qmixer.py | 5 +- torchrl/objectives/ppo.py | 8 +- torchrl/objectives/redq.py | 6 +- torchrl/objectives/reinforce.py | 31 +- torchrl/objectives/sac.py | 5 +- torchrl/objectives/td3.py | 6 +- torchrl/objectives/utils.py | 9 +- torchrl/objectives/value/advantages.py | 18 +- torchrl/trainers/helpers/logger.py | 2 + torchrl/trainers/trainers.py | 6 +- 70 files changed, 729 insertions(+), 513 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 0cbcb70ad15..e75f4b1bc1c 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -114,6 +114,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari. 
buffer.batch_size=10 \ device=cuda:0 \ loss.num_updates=1 \ + logger.backend= \ buffer.buffer_size=120 python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \ collector.total_frames=48 \ @@ -256,6 +257,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari. buffer.batch_size=10 \ device=cuda:0 \ loss.num_updates=1 \ + logger.backend= \ buffer.buffer_size=120 python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ num_workers=2 \ diff --git a/examples/bandits/dqn.py b/examples/bandits/dqn.py index 847cfbfc124..0d9ca828ee6 100644 --- a/examples/bandits/dqn.py +++ b/examples/bandits/dqn.py @@ -7,11 +7,12 @@ import torch import tqdm -from torch import nn +from tensordict.nn import TensorDictSequential +from torch import nn from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import DistributionalQValueActor, EGreedyWrapper, MLP, QValueActor +from torchrl.modules import DistributionalQValueActor, EGreedyModule, MLP, QValueActor from torchrl.objectives import DistributionalDQNLoss, DQNLoss parser = argparse.ArgumentParser() @@ -85,12 +86,14 @@ actor(env.reset()) loss = DQNLoss(actor, loss_function="smooth_l1", action_space=env.action_spec) loss.make_value_estimator(gamma=0.0) - policy = EGreedyWrapper( + policy = TensorDictSequential( actor, - eps_init=eps_greedy, - eps_end=0.0, - annealing_num_steps=n_steps, - spec=env.action_spec, + EGreedyModule( + eps_init=eps_greedy, + eps_end=0.0, + annealing_num_steps=n_steps, + spec=env.action_spec, + ), ) optim = torch.optim.Adam(loss.parameters(), lr, weight_decay=wd) diff --git a/examples/cql/discrete_cql_config.yaml b/examples/cql/discrete_cql_config.yaml index b7f8d527ba3..807479d45bd 100644 --- a/examples/cql/discrete_cql_config.yaml +++ b/examples/cql/discrete_cql_config.yaml @@ -2,7 +2,7 @@ env: name: CartPole-v1 task: "" - backend: gym + backend: gymnasium n_samples_stats: 1000 max_episode_steps: 200 seed: 0 @@ -36,7 +36,7 @@ replay_buffer: prb: 0 buffer_prefetch: 64 size: 1_000_000 - scratch_dir: ${env.exp_name}_${env.seed} + scratch_dir: null # Optimization optim: diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py index facbcc49bf9..107739f3aba 100644 --- a/examples/cql/discrete_cql_online.py +++ b/examples/cql/discrete_cql_online.py @@ -73,7 +73,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/cql/offline_config.yaml b/examples/cql/offline_config.yaml index d41db847077..0047b74d14c 100644 --- a/examples/cql/offline_config.yaml +++ b/examples/cql/offline_config.yaml @@ -5,7 +5,7 @@ env: library: gym n_samples_stats: 1000 seed: 0 - backend: gym # D4RL uses gym so we make sure gymnasium is hidden + backend: gymnasium # logger logger: diff --git a/examples/cql/online_config.yaml b/examples/cql/online_config.yaml index 367d4755cac..9b3e5b5bf24 100644 --- a/examples/cql/online_config.yaml +++ b/examples/cql/online_config.yaml @@ -6,7 +6,7 @@ env: seed: 0 train_num_envs: 1 eval_num_envs: 1 - backend: gym + backend: gymnasium # Collector collector: diff --git a/examples/cql/utils.py b/examples/cql/utils.py index 0af1a082e28..350b105b441 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -121,7 +121,7 
@@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -133,7 +133,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -144,7 +144,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -320,7 +320,6 @@ def make_discrete_loss(loss_cfg, model): model, loss_function=loss_cfg.loss_function, delay_value=True, - gamma=loss_cfg.gamma, ) loss_module.make_value_estimator(gamma=loss_cfg.gamma) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index fb4a3fa4725..7d17038330b 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -21,7 +21,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optimization optim: diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index ea5a1386e4f..92fdd850fbd 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 935fb426988..4006fc27b38 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -119,7 +119,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -131,7 +131,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -142,7 +142,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index 80915c4f93a..b42d8b58d35 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -36,7 +36,7 @@ replay_buffer: stacked_frames: 20 buffer_prefetch: 64 capacity: 1_000_000 - buffer_scratch_dir: + scratch_dir: device: cpu prefetch: 3 diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index b6137ac62a1..f06972fd46b 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -36,7 +36,7 @@ replay_buffer: stacked_frames: 20 buffer_prefetch: 64 capacity: 1_000_000 - buffer_scratch_dir: + scratch_dir: device: cuda:0 prefetch: 3 diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 8bd9f3bebbf..9d479a8118d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -296,7 +296,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): ) storage = LazyMemmapStorage( max_size=rb_cfg.capacity, - scratch_dir=rb_cfg.buffer_scratch_dir, 
+ scratch_dir=rb_cfg.scratch_dir, device=rb_cfg.device, ) diff --git a/examples/discrete_sac/config.yaml b/examples/discrete_sac/config.yaml index 03ae3999f87..df26c835ef0 100644 --- a/examples/discrete_sac/config.yaml +++ b/examples/discrete_sac/config.yaml @@ -22,7 +22,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optim optim: diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 2976cf8806d..16c5de80a64 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/discrete_sac/utils.py b/examples/discrete_sac/utils.py index 49ec8bc1204..5821ed53465 100644 --- a/examples/discrete_sac/utils.py +++ b/examples/discrete_sac/utils.py @@ -120,14 +120,14 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): with ( tempfile.TemporaryDirectory() - if buffer_scratch_dir is None - else nullcontext(buffer_scratch_dir) + if scratch_dir is None + else nullcontext(scratch_dir) ) as scratch_dir: if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 2db86b9f917..7d456367a5a 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -36,7 +36,7 @@ if __name__ == "__main__": # 1. 
Define Hyperparameters - device = "cpu" # if not torch.has_cuda else "cuda:0" + device = "cpu" # if not torch.cuda.device_count() else "cuda:0" num_cells = 256 max_grad_norm = 1.0 frame_skip = 1 diff --git a/examples/iql/utils.py b/examples/iql/utils.py index fe1e5ce32b8..997df401b82 100644 --- a/examples/iql/utils.py +++ b/examples/iql/utils.py @@ -125,7 +125,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -137,7 +137,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -148,7 +148,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 4af5da62c91..011e04cde77 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -8,7 +8,7 @@ import hydra import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer @@ -17,7 +17,7 @@ from torchrl.envs import RewardSum, TransformedEnv from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential +from torchrl.modules import EGreedyModule, QValueModule, SafeSequential from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators from utils.logging import init_logging, log_evaluation, log_training @@ -31,7 +31,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="iql") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding @@ -96,13 +96,15 @@ def train(cfg: "DictConfig"): # noqa: F821 ) qnet = SafeSequential(module, value_module) - qnet_explore = EGreedyWrapper( + qnet_explore = TensorDictSequential( qnet, - eps_init=0.3, - eps_end=0, - annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), - action_key=env.action_key, - spec=env.unbatched_action_spec, + EGreedyModule( + eps_init=0.3, + eps_end=0, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + spec=env.unbatched_action_spec, + ), ) collector = SyncDataCollector( @@ -174,7 +176,7 @@ def train(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() target_net_updater.step() - qnet_explore.step(frames=current_frames) # Update exploration annealing + qnet_explore[1].step(frames=current_frames) # Update exploration annealing collector.update_policy_weights_() training_time = time.time() - training_start diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py index 4e6b821604c..e4fd4a25e12 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/examples/multiagent/maddpg_iddpg.py @@ -36,7 +36,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="maddpg_iddpg") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if 
not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index b00bb18a2a0..d4481c93071 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -31,7 +31,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="mappo_ippo") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 5822bda39da..e53c47e04f4 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -8,7 +8,7 @@ import hydra import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer @@ -17,7 +17,7 @@ from torchrl.envs import RewardSum, TransformedEnv from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential +from torchrl.modules import EGreedyModule, QValueModule, SafeSequential from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer from torchrl.objectives import SoftUpdate, ValueEstimators from torchrl.objectives.multiagent.qmixer import QMixerLoss @@ -31,7 +31,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="qmix_vdn") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding @@ -96,13 +96,15 @@ def train(cfg: "DictConfig"): # noqa: F821 ) qnet = SafeSequential(module, value_module) - qnet_explore = EGreedyWrapper( + qnet_explore = TensorDictSequential( qnet, - eps_init=0.3, - eps_end=0, - annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), - action_key=env.action_key, - spec=env.unbatched_action_spec, + EGreedyModule( + eps_init=0.3, + eps_end=0, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + spec=env.unbatched_action_spec, + ), ) if cfg.loss.mixer_type == "qmix": @@ -209,7 +211,7 @@ def train(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() target_net_updater.step() - qnet_explore.step(frames=current_frames) # Update exploration annealing + qnet_explore[1].step(frames=current_frames) # Update exploration annealing collector.update_policy_weights_() training_time = time.time() - training_start diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index 1c01b5e50b7..528b5422921 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -33,7 +33,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="sac") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding diff --git a/examples/redq/config.yaml b/examples/redq/config.yaml index 
fc77974cb38..c67543716dc 100644 --- a/examples/redq/config.yaml +++ b/examples/redq/config.yaml @@ -68,7 +68,7 @@ buffer: prb: 1 sub_traj_len: size: 500_000 - scratch_dir: + scratch_dir: null prefetch: 64 network: diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 6f3e80649d7..a921e58bad6 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -62,7 +62,7 @@ def main(cfg): wandb_kwargs={ "config": dict(cfg), "project": cfg.io.project_name, - "group": cfg.logger.group_name, + "group": cfg.io.group_name, }, ) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index b6675ecc9a0..6546f1e30b7 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -20,7 +20,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optim optim: diff --git a/examples/sac/sac.py b/examples/sac/sac.py index a93e3a833dd..db23071867a 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 1e157ce85cd..afb731dcc95 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -108,7 +108,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -120,7 +120,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -131,7 +131,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 561766cd5a4..e94a5b6b774 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -21,7 +21,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optim optim: diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 1f42e7273d1..003a3bf228c 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 0abc769d365..fed055f98bf 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -121,14 +121,14 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): with ( tempfile.TemporaryDirectory() - if buffer_scratch_dir is None - else nullcontext(buffer_scratch_dir) + if scratch_dir is None + else nullcontext(scratch_dir) ) as scratch_dir: if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( @@ -248,7 +248,6 @@ def make_loss_module(cfg, model): loss_function=cfg.optim.loss_function, delay_actor=True, delay_qvalue=True, - gamma=cfg.optim.gamma, 
action_spec=model[0][1].spec, policy_noise=cfg.optim.policy_noise, noise_clip=cfg.optim.noise_clip, diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 9e5b2ff6879..7a32c9a38ef 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -646,7 +646,7 @@ def _obs_step(self, obs, a): return obs + a / self.maxstep -class DiscreteActionVecPolicy: +class DiscreteActionVecPolicy(TensorDictModuleBase): in_keys = ["observation"] out_keys = ["action"] @@ -979,10 +979,13 @@ def forward(self, observation, action): return self.linear(torch.cat([observation, action], dim=-1)) -class CountingEnvCountPolicy: +class CountingEnvCountPolicy(TensorDictModuleBase): def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() self.action_spec = action_spec self.action_key = action_key + self.in_keys = [] + self.out_keys = [action_key] def __call__(self, td: TensorDictBase) -> TensorDictBase: return td.set(self.action_key, self.action_spec.zero() + 1) diff --git a/test/test_collector.py b/test/test_collector.py index 027cf776ee4..8369be1578e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1491,37 +1491,38 @@ def test_auto_wrap_modules( collector.shutdown() del collector - def test_no_wrap_compatible_module(self, collector_class, env_maker): - policy = TensorDictCompatiblePolicy( - out_features=env_maker().action_spec.shape[-1] - ) - policy(env_maker().reset()) - - collector = collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy) - ) - - if collector_class is not SyncDataCollector: - # We now do the casting only on the remote workers - pass - else: - assert isinstance(collector.policy, TensorDictCompatiblePolicy) - assert collector.policy.out_keys == ["action"] - assert collector.policy is policy - - for i, data in enumerate(collector): - if i == 0: - assert (data["action"] != 0).any() - for p in policy.parameters(): - p.data.zero_() - assert p.device == torch.device("cpu") - collector.update_policy_weights_() - elif i == 4: - assert (data["action"] == 0).all() - break - - collector.shutdown() - del collector + # Deprecated as from v0.3 + # def test_no_wrap_compatible_module(self, collector_class, env_maker): + # policy = TensorDictCompatiblePolicy( + # out_features=env_maker().action_spec.shape[-1] + # ) + # policy(env_maker().reset()) + # + # collector = collector_class( + # **self._create_collector_kwargs(env_maker, collector_class, policy) + # ) + # + # if collector_class is not SyncDataCollector: + # # We now do the casting only on the remote workers + # pass + # else: + # assert isinstance(collector.policy, TensorDictCompatiblePolicy) + # assert collector.policy.out_keys == ["action"] + # assert collector.policy is policy + # + # for i, data in enumerate(collector): + # if i == 0: + # assert (data["action"] != 0).any() + # for p in policy.parameters(): + # p.data.zero_() + # assert p.device == torch.device("cpu") + # collector.update_policy_weights_() + # elif i == 4: + # assert (data["action"] == 0).all() + # break + # + # collector.shutdown() + # del collector def test_auto_wrap_error(self, collector_class, env_maker): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) @@ -2062,7 +2063,7 @@ def _reset(self, tensordict=None): def _set_seed(self, seed): return seed - class Policy(nn.Module): + class Policy(TensorDictModuleBase): def __init__(self): super().__init__() self.param = nn.Parameter(torch.zeros(())) diff --git a/test/test_cost.py b/test/test_cost.py index 
87e17eb252c..c6eb27172ee 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6091,7 +6091,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): else: raise NotImplementedError - loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2") + loss_fn = loss_class(actor, value, loss_critic_type="l2") params = TensorDict.from_module(loss_fn, as_module=True) @@ -11960,7 +11960,7 @@ def test_set_deprecated_keys(self, adv, kwargs): nn.Linear(3, 1), in_keys=["obs"], out_keys=["test_value"] ) - with pytest.warns(DeprecationWarning): + with pytest.raises(RuntimeError, match="via constructor is deprecated"): if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] diff --git a/test/test_distributed.py b/test/test_distributed.py index debfa058ace..6215abd7ceb 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -14,6 +14,7 @@ import time import pytest +from tensordict.nn import TensorDictModuleBase try: import ray @@ -49,7 +50,7 @@ pytest.skip("skipping windows tests in windows", allow_module_level=True) -class CountingPolicy(nn.Module): +class CountingPolicy(TensorDictModuleBase): """A policy for counting env. Returns a step of 1 by default but weights can be adapted. diff --git a/test/test_exploration.py b/test/test_exploration.py index 777f2714edb..d0735a53ae8 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -156,14 +156,6 @@ def test_egreedy_masked(self, module, eps_init, spec_class): assert not (action[~action_mask] == 0).all() assert (masked_action[~action_mask] == 0).all() - def test_egreedy_wrapper_deprecation(self): - torch.manual_seed(0) - spec = BoundedTensorSpec(1, 1, torch.Size([4])) - module = torch.nn.Linear(4, 4, bias=False) - policy = Actor(spec=spec, module=module) - with pytest.deprecated_call(): - EGreedyWrapper(policy) - def test_no_spec_error( self, ): diff --git a/test/test_helpers.py b/test/test_helpers.py index 1843a3f738f..eb9620001c7 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -501,9 +501,13 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial stats = {"loc": None, "scale": None} if initialized: stats = {"loc": 0.0, "scale": 1.0} - t_env.transform = ObservationNorm(standard_normal=True, **stats) + t_env.transform = ObservationNorm( + in_keys=["observation"], standard_normal=True, **stats + ) if composed: - t_env.append_transform(ObservationNorm(standard_normal=True, **stats)) + t_env.append_transform( + ObservationNorm(in_keys=["observation"], standard_normal=True, **stats) + ) if not initialized: with pytest.raises( ValueError, match="Attempted to use an uninitialized parameter" @@ -539,7 +543,7 @@ def test_initialize_stats_from_non_obs_transform(device): def test_initialize_obs_transform_stats_raise_exception(): env = ContinuousActionVecMockEnv() t_env = TransformedEnv(env) - t_env.transform = ObservationNorm() + t_env.transform = ObservationNorm(in_keys=["observation"]) with pytest.raises( RuntimeError, match="More than one key exists in the observation_specs" ): @@ -553,10 +557,14 @@ def test_retrieve_observation_norms_state_dict(device, composed): env.set_seed(1) t_env = TransformedEnv(env) - t_env.transform = ObservationNorm(standard_normal=True, loc=0.5, scale=0.2) + t_env.transform = ObservationNorm( + standard_normal=True, loc=0.5, scale=0.2, in_keys=["observation"] + ) if composed: t_env.append_transform( - ObservationNorm(standard_normal=True, loc=1.0, scale=0.3) + ObservationNorm( + 
standard_normal=True, loc=1.0, scale=0.3, in_keys=["observation"] + ) ) initialize_observation_norm_transforms(proof_environment=t_env, num_iter=100) state_dicts = retrieve_observation_norms_state_dict(t_env) diff --git a/test/test_transforms.py b/test/test_transforms.py index b325a1ccd99..725945ef113 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -229,22 +229,28 @@ def test_parallel_trans_env_check(self): env = ParallelEnv( 2, lambda: TransformedEnv(ContinuousActionVecMockEnv(), BinarizeReward()) ) - check_env_specs(env) - env.close() + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, lambda: ContinuousActionVecMockEnv()), BinarizeReward() ) - check_env_specs(env) - env.close() + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), BinarizeReward() ) - check_env_specs(env) - env.close() + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) @@ -546,7 +552,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -575,7 +584,10 @@ def test_trans_parallel_env_check(self): high=0.1, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = ContinuousActionVecMockEnv() @@ -618,7 +630,10 @@ def test_parallel_trans_env_check(self): CatFrames(dim=-1, N=3, in_keys=["observation"]), ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -639,7 +654,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), CatFrames(dim=-1, N=3, in_keys=["observation"]), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -1171,7 +1189,10 @@ def make_env(): ) transformed_env = ParallelEnv(2, make_env) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_serial_trans_env_check(self, model, device): if model != "resnet18": @@ -1213,7 +1234,10 @@ def test_trans_parallel_env_check(self, model, device): transformed_env = TransformedEnv( ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), r3m ) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_trans_serial_env_check(self, model, device): if model != "resnet18": @@ -1545,7 +1569,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(10)) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -1558,7 +1585,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), StepCounter(10) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), StepCounter(10)) @@ -1839,7 +1869,10 @@ def 
make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), ct) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): ct = CatTensors( @@ -1861,7 +1894,10 @@ def test_trans_parallel_env_check(self): ) env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), ct) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize( @@ -2210,7 +2246,10 @@ def make_env(): return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): keys = ["pixels"] @@ -2222,7 +2261,10 @@ def test_trans_parallel_env_check(self): keys = ["pixels"] ct = Compose(ToTensorImage(), CenterCrop(w=20, h=20, in_keys=keys)) env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="No Gym detected") @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) @@ -2266,7 +2308,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -2280,7 +2325,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), DiscreteActionProjection(7, 10), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("action_key", ["action", ("nested", "stuff")]) def test_transform_no_env(self, action_key): @@ -2526,7 +2574,10 @@ def test_trans_parallel_env_check(self, dtype_fixture): # noqa: F811 ParallelEnv(2, lambda: ContinuousActionVecMockEnv(dtype=torch.float64)), DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self, dtype_fixture): # noqa: F811 t = DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]) @@ -2681,7 +2732,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): t = Compose( @@ -2701,7 +2755,10 @@ def test_trans_parallel_env_check(self): ExcludeTransform("observation_copy"), ) env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -2907,7 +2964,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): t = Compose( @@ -2927,7 +2987,10 @@ def test_trans_parallel_env_check(self): SelectTransform("observation", "observation_orig"), ) env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -3094,6 +3157,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3113,7 +3180,10 @@ def 
test_trans_parallel_env_check(self): -1, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_tv, reason="no torchvision") @pytest.mark.parametrize("nchannels", [1, 3]) @@ -3265,7 +3335,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3277,7 +3350,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), FrameSkipTransform(2) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = FrameSkipTransform(2) @@ -3500,7 +3576,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): out_keys = None @@ -3516,7 +3595,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), GrayScale(out_keys=out_keys)), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) def test_transform_env(self, out_keys): @@ -3589,7 +3671,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), NoopResetEnv()) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), NoopResetEnv()) @@ -3759,7 +3844,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check( self, @@ -3785,7 +3873,10 @@ def test_trans_parallel_env_check( scale=1.0, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("standard_normal", [True, False]) @pytest.mark.parametrize("in_key", ["observation", ("some_other", "observation")]) @@ -4176,13 +4267,13 @@ def make_env(): ) def test_observationnorm_stats_already_initialized_error(self): - transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1) + transform = ObservationNorm(in_keys=["next_observation"], loc=0, scale=1) with pytest.raises(RuntimeError, match="Loc/Scale are already initialized"): transform.init_stats(num_iter=11) def test_observationnorm_wrong_catdim(self): - transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1) + transform = ObservationNorm(in_keys=["next_observation"], loc=0, scale=1) with pytest.raises( ValueError, match="cat_dim must be part of or equal to reduce_dim" @@ -4336,7 +4427,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4350,7 +4444,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), Resize(20, 21, in_keys=["pixels"])), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="No gym") @pytest.mark.parametrize("out_key", ["pixels", ("agents", "pixels")]) @@ -4406,7 +4503,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ 
-4418,7 +4518,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), RewardClipping(-0.1, 0.1) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("reward_key", ["reward", ("agents", "reward")]) def test_transform_no_env(self, reward_key): @@ -4535,7 +4638,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), RewardScaling(0.5, 1.5)) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4547,7 +4653,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), RewardScaling(0.5, 1.5) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("standard_normal", [True, False]) def test_transform_no_env(self, standard_normal): @@ -4660,9 +4769,12 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) - r = env.rollout(4) - assert r["next", "episode_reward"].unique().numel() > 1 + try: + check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4678,9 +4790,12 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, ContinuousActionVecMockEnv), Compose(RewardScaling(loc=-1, scale=1), RewardSum()), ) - check_env_specs(env) - r = env.rollout(4) - assert r["next", "episode_reward"].unique().numel() > 1 + try: + check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 + finally: + env.close() @pytest.mark.parametrize("has_in_keys,", [True, False]) @pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3]) @@ -5320,7 +5435,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -5334,7 +5452,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, ContinuousActionVecMockEnv), UnsqueezeTransform(-1, in_keys=["observation"]), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5619,19 +5740,28 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), self._circular_transform ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), self._circular_transform ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("squeeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5821,7 +5951,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -5831,7 +5964,10 @@ def test_trans_serial_env_check(self, mode, device): TargetReturn(target_return=10.0, mode=mode).to(device), device=device, ) - check_env_specs(env) + try: + check_env_specs(env) + 
finally: + env.close() @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -5841,7 +5977,10 @@ def test_trans_parallel_env_check(self, mode, device): TargetReturn(target_return=10.0, mode=mode), device=device, ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [SerialEnv, ParallelEnv]) @@ -6047,7 +6186,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6061,7 +6203,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ToTensorImage(in_keys=["pixels"], out_keys=None), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("out_keys", [None, ["stuff"], [("nested", "stuff")]]) @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) @@ -6189,9 +6334,12 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) - assert "mykey" in env.reset().keys() - assert ("next", "mykey") in env.rollout(3).keys(True) + try: + check_env_specs(env) + assert "mykey" in env.reset().keys() + assert ("next", "mykey") in env.rollout(3).keys(True) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -6201,20 +6349,26 @@ def make_env(): ) env = SerialEnv(2, make_env) - check_env_specs(env) - assert "mykey" in env.reset().keys() - assert ("next", "mykey") in env.rollout(3).keys(True) + try: + check_env_specs(env) + assert "mykey" in env.reset().keys() + assert ("next", "mykey") in env.rollout(3).keys(True) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), ) - check_env_specs(env) - assert "mykey" in env.reset().keys() - r = env.rollout(3) - assert ("next", "mykey") in r.keys(True) - assert r["next", "mykey"].shape == torch.Size([2, 3, 4]) + try: + check_env_specs(env) + assert "mykey" in env.reset().keys() + r = env.rollout(3) + assert ("next", "mykey") in r.keys(True) + assert r["next", "mykey"].shape == torch.Size([2, 3, 4]) + finally: + env.close() def test_trans_serial_env_check(self): with pytest.raises(RuntimeError, match="The leading shape of the primer specs"): @@ -6414,7 +6568,10 @@ def test_parallel_trans_env_check(self): ), ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6434,7 +6591,10 @@ def test_trans_parallel_env_check(self): T=3, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -6595,7 +6755,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): state_dim = 7 @@ -6610,7 +6773,10 @@ def test_trans_serial_env_check(self): SerialEnv(2, ContinuousActionVecMockEnv), gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): state_dim = 7 @@ -6619,7 +6785,10 @@ def 
test_trans_parallel_env_check(self): ParallelEnv(2, ContinuousActionVecMockEnv), gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): state_dim = 7 @@ -6736,7 +6905,10 @@ def test_trans_parallel_env_check(self, model, device): transformed_env = TransformedEnv( ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), vip ) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_serial_trans_env_check(self, model, device): in_keys = ["pixels"] @@ -6774,7 +6946,10 @@ def make_env(): ) transformed_env = ParallelEnv(2, make_env) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_transform_model(self, model, device): in_keys = ["pixels"] @@ -7194,7 +7369,10 @@ def test_trans_parallel_env_check(self, device): transformed_env = TransformedEnv( ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), vc1 ) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_serial_trans_env_check(self, device): in_keys = ["pixels"] @@ -7671,8 +7849,12 @@ def test_independent_obs_specs_from_shared_env(self): observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,))) ) base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec) - t1 = TransformedEnv(base_env, transform=ObservationNorm(loc=3, scale=2)) - t2 = TransformedEnv(base_env, transform=ObservationNorm(loc=1, scale=6)) + t1 = TransformedEnv( + base_env, transform=ObservationNorm(in_keys=["observation"], loc=3, scale=2) + ) + t2 = TransformedEnv( + base_env, transform=ObservationNorm(in_keys=["observation"], loc=1, scale=6) + ) t1_obs_spec = t1.observation_spec t2_obs_spec = t2.observation_spec @@ -8122,7 +8304,9 @@ def test_batch_unlocked_with_batch_size_transformed(device): ), pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), GrayScale, - ObservationNorm, + pytest.param( + partial(ObservationNorm, in_keys=["observation"]), id="ObservationNorm" + ), pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), FiniteTensorDictCheck, @@ -8308,7 +8492,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def make_env(): return TransformedEnv( @@ -8323,7 +8510,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self, create_copy): def make_env(): @@ -8362,7 +8552,10 @@ def make_env(): create_copy=create_copy, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() env = TransformedEnv( ParallelEnv(2, make_env), RenameTransform( @@ -8373,7 +8566,10 @@ def make_env(): create_copy=create_copy, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("mode", ["forward", "_call"]) @pytest.mark.parametrize( @@ -8572,7 +8768,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): def make_env(): @@ -8581,7 +8780,10 @@ def make_env(): env = SerialEnv(2, make_env) env = TransformedEnv(env, InitTracker()) - 
check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): def make_env(): @@ -8590,7 +8792,10 @@ def make_env(): env = ParallelEnv(2, make_env) env = TransformedEnv(env, InitTracker()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): with pytest.raises(ValueError, match="init_key can only be of type str"): @@ -8855,19 +9060,28 @@ def make_env(): return TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): out_key = "reward" base_env = SerialEnv(2, self.envclass) env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): out_key = "reward" base_env = ParallelEnv(2, self.envclass) env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_model(self): actor = self._make_actor() @@ -9003,15 +9217,24 @@ def test_serial_trans_env_check(self): def test_parallel_trans_env_check(self): env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = ActionMask() @@ -9084,7 +9307,10 @@ def make_env(): env = ParallelEnv(2, make_env) assert env.device == torch.device("cpu:1") - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): def make_env(): @@ -9100,7 +9326,10 @@ def make_env(): env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1")) assert env.device == torch.device("cpu:1") - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = DeviceCastTransform("cpu:1", "cpu:0") @@ -9189,21 +9418,30 @@ def test_parallel_trans_env_check(self): TestPermuteTransform.envclass(), TestPermuteTransform._get_permute() ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, TestPermuteTransform.envclass), TestPermuteTransform._get_permute(), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, TestPermuteTransform.envclass), TestPermuteTransform._get_permute(), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) def test_transform_compose(self, batch): @@ -9352,7 +9590,10 @@ def make(): ) env = ParallelEnv(2, make) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = EndOfLifeTransform() @@ -9752,7 +9993,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: 
+ env.close() def test_serial_trans_env_check(self): def make_env(): @@ -9776,7 +10020,10 @@ def test_trans_parallel_env_check(self): in_keys_inv=["observation_orig"], ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -9786,7 +10033,10 @@ def test_trans_serial_env_check(self): in_keys_inv=["observation_orig"], ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() class TestRemoveEmptySpecs(TransformBase): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index ffb8c0f5270..ef972fd343e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -26,10 +26,11 @@ from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union import numpy as np + import torch import torch.nn as nn - from tensordict import ( + is_tensor_collection, LazyStackedTensorDict, TensorDict, TensorDictBase, @@ -169,16 +170,13 @@ def _policy_is_tensordict_compatible(policy: nn.Module): and hasattr(policy, "in_keys") and hasattr(policy, "out_keys") ): - warnings.warn( - "Passing a policy that is not a TensorDictModuleBase subclass but has in_keys and out_keys " - "will soon be deprecated. We'd like to motivate our users to inherit from this class (which " - "has very few restrictions) to make the experience smoother.", - category=DeprecationWarning, + raise RuntimeError( + "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys " + "is deprecated. Users should inherit from this class (which " + "has very few restrictions) to make the experience smoother. " + "Simply change your policy from `class Policy(nn.Module)` to `Policy(tensordict.nn.TensorDictModuleBase)` " + "and this error should disappear.", ) - # if the policy is a TensorDictModule or takes a single argument and defines - # in_keys and out_keys then we assume it can already deal with TensorDict input - # to forward and we return True - return True elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): # if it's not a TensorDictModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. @@ -235,7 +233,15 @@ def _make_compatible_policy(self, policy, observation_spec=None): key: value for key, value in observation_spec.rand().items() } # we check if all the mandatory params are there - if set(sig.parameters) == {"tensordict"} or set(sig.parameters) == {"td"}: + params = list(sig.parameters.keys()) + if ( + set(sig.parameters) == {"tensordict"} + or set(sig.parameters) == {"td"} + or ( + len(params) == 1 + and is_tensor_collection(sig.parameters[params[0]].annotation) + ) + ): pass elif not required_kwargs.difference(set(next_observation)): in_keys = [str(k) for k in sig.parameters if k in next_observation] @@ -266,6 +272,7 @@ def _make_compatible_policy(self, policy, observation_spec=None): then the arguments to policy.forward must correspond one-to-one with entries in env.observation_spec that are prefixed with 'next_'. For more complex behaviour and more control you can consider writing your own TensorDictModule. +Check the collector documentation to know more about accepted policies. """ ) return policy @@ -385,6 +392,18 @@ class SyncDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. 
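A minimal sketch of the policy style that both the new ``RuntimeError`` above and the ``policy`` argument documented here expect: subclass ``tensordict.nn.TensorDictModuleBase``, declare ``in_keys``/``out_keys``, and read/write a TensorDict in ``forward``. The class name and feature sizes below are placeholders, not taken from the repository:

    import torch
    from tensordict import TensorDictBase
    from tensordict.nn import TensorDictModuleBase


    class Policy(TensorDictModuleBase):
        # Keys the collector uses to know what the policy reads and writes.
        in_keys = ["observation"]
        out_keys = ["action"]

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(3, 1)  # sizes are illustrative only

        def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            # Read the observation entry, write the action entry in place.
            tensordict.set("action", self.net(tensordict.get("observation")))
            return tensordict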
+ Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total @@ -978,10 +997,10 @@ def is_private(key): # >>> assert data0["done"] is not data1["done"] yield tensordict_out.clone() - def _update_traj_ids(self, tensordict) -> None: + def _update_traj_ids(self, env_output) -> None: # we can't use the reset keys because they're gone traj_sop = _aggregate_end_of_traj( - tensordict.get("next"), done_keys=self.env.done_keys + env_output.get("next"), done_keys=self.env.done_keys ) if traj_sop.any(): traj_ids = self._shuttle.get(("collector", "traj_ids")) @@ -1230,11 +1249,23 @@ class _MultiDataCollector(DataCollectorBase): Args: create_env_fn (List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable, optional): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the @@ -2299,8 +2330,23 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable, optional): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. 
+ If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the @@ -2497,7 +2543,7 @@ def _main_async_collector( ) -> None: pipe_parent.close() # init variables that will be cleared when closing - tensordict = data = d = data_in = inner_collector = dc_iter = None + collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None inner_collector = SyncDataCollector( create_env_fn, @@ -2571,42 +2617,45 @@ def _main_async_collector( else: inner_collector.init_random_frames = -1 - d = next(dc_iter) + next_data = next(dc_iter) if pipe_child.poll(_MIN_TIMEOUT): # in this case, main send a message to the worker while it was busy collecting trajectories. # In that case, we skip the collected trajectory and get the message from main. This is faster than # sending the trajectory in the queue until timeout when it's never going to be received. continue if j == 0: - tensordict = d - if storing_device is not None and tensordict.device != storing_device: + collected_tensordict = next_data + if ( + storing_device is not None + and collected_tensordict.device != storing_device + ): raise RuntimeError( - f"expected device to be {storing_device} but got {tensordict.device}" + f"expected device to be {storing_device} but got {collected_tensordict.device}" ) # If policy and env are on cpu, we put in shared mem, # if policy is on cuda and env on cuda, we are fine with this # If policy is on cuda and env on cpu (or opposite) we put tensors that # are on cpu in shared mem. - if tensordict.device is not None: + if collected_tensordict.device is not None: # placehoder in case we need different behaviours - if tensordict.device.type in ("cpu", "mps"): - tensordict.share_memory_() - elif tensordict.device.type == "cuda": - tensordict.share_memory_() + if collected_tensordict.device.type in ("cpu", "mps"): + collected_tensordict.share_memory_() + elif collected_tensordict.device.type == "cuda": + collected_tensordict.share_memory_() else: raise NotImplementedError( - f"Device {tensordict.device} is not supported in multi-collectors yet." + f"Device {collected_tensordict.device} is not supported in multi-collectors yet." 
) else: # make sure each cpu tensor is shared - assuming non-cpu devices are shared - tensordict.apply( + collected_tensordict.apply( lambda x: x.share_memory_() if x.device.type in ("cpu", "mps") else x ) - data = (tensordict, idx) + data = (collected_tensordict, idx) else: - if d is not tensordict: + if next_data is not collected_tensordict: raise RuntimeError( "SyncDataCollector should return the same tensordict modified in-place." ) @@ -2661,7 +2710,7 @@ def _main_async_collector( continue elif msg == "close": - del tensordict, data, d, data_in + del collected_tensordict, data, next_data, data_in inner_collector.shutdown() del inner_collector, dc_iter pipe_child.send("closed") diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 073d2f445ab..0c5c74b6510 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -260,8 +260,20 @@ class DistributedDataCollector(DataCollectorBase): policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index fa2d8e8191e..6788e48ee3a 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -118,8 +118,23 @@ class RayCollector(DataCollectorBase): Args: create_env_fn (Callable or List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. 
+ Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 50729038b4a..dbfc5a7dfd9 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -101,8 +101,20 @@ class RPCDataCollector(DataCollectorBase): policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index d7a5c94487d..7ea805248c9 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -141,8 +141,20 @@ class DistributedSyncDataCollector(DataCollectorBase): policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. 
+ - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 10b9767de8e..468fcb9150c 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -146,7 +146,7 @@ def __init__( prefetch: int | None = None, transform: "torchrl.envs.Transform" | None = None, # noqa-F821 split_trajs: bool = False, - from_env: bool = None, + from_env: bool = False, use_truncated_as_done: bool = True, direct_download: bool = None, terminate_on_end: bool = None, @@ -165,29 +165,16 @@ def __init__( direct_download = not self._has_d4rl if not direct_download: - if from_env is None: - warnings.warn( - "from_env will soon default to ``False``, ie the data will be " - "downloaded without relying on d4rl by default. " - "For now, ``True`` will still be the default. " - "To disable this warning, explicitly pass the ``from_env`` argument " - "during construction of the dataset.", - category=DeprecationWarning, - ) - from_env = True - else: - warnings.warn( - "You are using the D4RL library for collecting data. " - "We advise against this use, as D4RL formatting can be " - "inconsistent. " - "To download the D4RL data without the D4RL library, use " - "direct_download=True in the dataset constructor. " - "Recurring to `direct_download=False` will soon be deprecated." - ) + warnings.warn( + "You are using the D4RL library for collecting data. " + "We advise against this use, as D4RL formatting can be " + "inconsistent. " + "To download the D4RL data without the D4RL library, use " + "direct_download=True in the dataset constructor. " + "Recurring to `direct_download=False` will soon be deprecated." 
+ ) self.from_env = from_env else: - if from_env is None: - from_env = False self.from_env = from_env if (download == "force") or (download and not self._is_downloaded()): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index fd847f25c74..45c7be64a1a 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1087,7 +1087,7 @@ def _reset_batch_size(x): shape = x.get("_rb_batch_size", None) if shape is not None: warnings.warn( - "Reshaping nested tensordicts will be deprecated soon.", + "Reshaping nested tensordicts will be deprecated in v0.4.0.", category=DeprecationWarning, ) data = x.get("_data") diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1cfc970e61f..efe928856b9 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -379,7 +379,7 @@ def high(self, value): @property def minimum(self): warnings.warn( - f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low", + f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0", category=DeprecationWarning, ) return self._low.to(self.device) @@ -387,7 +387,7 @@ def minimum(self): @property def maximum(self): warnings.warn( - f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high", + f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0", category=DeprecationWarning, ) return self._high.to(self.device) @@ -1521,7 +1521,7 @@ class BoundedTensorSpec(TensorSpec): # SPEC_HANDLED_FUNCTIONS = {} DEPRECATED_KWARGS = ( "The `minimum` and `maximum` keyword arguments are now " - "deprecated in favour of `low` and `high`." + "deprecated in favour of `low` and `high` in v0.4.0." ) CONFLICTING_KWARGS = ( "The keyword arguments {} and {} conflict. Only one of these can be passed." diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 3ce3d2d630c..002b270cd84 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -520,7 +520,7 @@ def info_dict_reader(self, value: callable): warnings.warn( f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. " f"This method will append a reader to the list of existing readers (if any). " - f"Setting info_dict_reader directly will be soon deprecated.", + f"Setting info_dict_reader directly will be deprecated in v0.4.0.", category=DeprecationWarning, ) self._info_dict_reader.append(value) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e59c481419c..a661b152d39 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2372,15 +2372,9 @@ def __init__( standard_normal: bool = False, ): if in_keys is None: - warnings.warn( - "Not passing in_keys to ObservationNorm will soon be deprecated. " - "Ensure you specify the entries to be normalized", - category=DeprecationWarning, + raise RuntimeError( + "Not passing in_keys to ObservationNorm is a deprecated behaviour." ) - in_keys = [ - "observation", - "pixels", - ] if out_keys is None: out_keys = copy(in_keys) @@ -2719,7 +2713,7 @@ def __init__( raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}") if padding == "zeros": warnings.warn( - "Padding option 'zeros' will be deprecated in the future. " + "Padding option 'zeros' will be deprecated in v0.4.0. 
" "Please use 'constant' padding with padding_value 0 instead.", category=DeprecationWarning, ) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 1cc10316045..c610bb61350 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -872,9 +872,6 @@ class DistributionalDQNnet(TensorDictModuleBase): """Distributional Deep Q-Network. Args: - DQNet (nn.Module): (deprecated) Q-Network with output length equal - to the number of atoms: - output.shape = [*batch, atoms, actions]. in_keys (list of str or tuples of str): input keys to the log-softmax operation. Defaults to ``["action_value"]``. out_keys (list of str or tuples of str): output keys to the log-softmax @@ -888,11 +885,11 @@ class DistributionalDQNnet(TensorDictModuleBase): "instead." ) - def __init__(self, DQNet: nn.Module = None, in_keys=None, out_keys=None): + def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): super().__init__() if DQNet is not None: warnings.warn( - f"Passing a network to {type(self)} is going to be deprecated.", + f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.", category=DeprecationWarning, ) if not ( @@ -1280,7 +1277,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, ) -> None: warnings.warn( - "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.", + "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", category=DeprecationWarning, ) super().__init__() diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index bf81cfd5dfd..b7a044cae7d 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -445,7 +445,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -825,7 +825,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -922,7 +922,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -1043,7 +1043,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -1189,7 +1189,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index c8fa9cc040f..9a7f88844cc 100644 --- 
a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -247,105 +247,10 @@ def __init__( action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): - warnings.warn( - "EGreedyWrapper is deprecated and it will be removed in v0.3. " - "Please use torchrl.modules.EGreedyModule instead.", - category=DeprecationWarning, + raise RuntimeError( + "This class has been removed in favour of torchrl.modules.EGreedyModule." ) - super().__init__(policy) - self.register_buffer("eps_init", torch.as_tensor([eps_init])) - self.register_buffer("eps_end", torch.as_tensor([eps_end])) - if self.eps_end > self.eps_init: - raise RuntimeError("eps should decrease over time or be constant") - self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) - self.action_key = action_key - self.action_mask_key = action_mask_key - if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) - self._spec = spec - elif hasattr(self.td_module, "_spec"): - self._spec = self.td_module._spec.clone() - if action_key not in self._spec.keys(): - self._spec[action_key] = None - elif hasattr(self.td_module, "spec"): - self._spec = self.td_module.spec.clone() - if action_key not in self._spec.keys(): - self._spec[action_key] = None - else: - self._spec = spec - - @property - def spec(self): - return self._spec - - def step(self, frames: int = 1) -> None: - """A step of epsilon decay. - - After self.annealing_num_steps, this function is a no-op. - - Args: - frames (int): number of frames since last step. - - """ - for _ in range(frames): - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), - ) - - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = self.td_module.forward(tensordict) - if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: - if isinstance(self.action_key, tuple) and len(self.action_key) > 1: - action_tensordict = tensordict.get(self.action_key[:-1]) - action_key = self.action_key[-1] - else: - action_tensordict = tensordict - action_key = self.action_key - - out = action_tensordict.get(action_key) - eps = self.eps.item() - cond = ( - torch.rand(action_tensordict.shape, device=action_tensordict.device) - < eps - ).to(out.dtype) - cond = expand_as_right(cond, out) - spec = self.spec - if spec is not None: - if isinstance(spec, CompositeSpec): - spec = spec[self.action_key] - if spec.shape != out.shape: - # In batched envs if the spec is passed unbatched, the rand() will not - # cover all batched dims - if ( - not len(spec.shape) - or out.shape[-len(spec.shape) :] == spec.shape - ): - spec = spec.expand(out.shape) - else: - raise ValueError( - "Action spec shape does not match the action shape" - ) - if self.action_mask_key is not None: - action_mask = tensordict.get(self.action_mask_key, None) - if action_mask is None: - raise KeyError( - f"Action mask key {self.action_mask_key} not found in {tensordict}." - ) - spec.update_mask(action_mask) - out = cond * spec.rand().to(out.device) + (1 - cond) * out - else: - raise RuntimeError( - "spec must be provided by the policy or directly to the exploration wrapper."
- ) - action_tensordict.set(action_key, out) - return tensordict - class AdditiveGaussianWrapper(TensorDictModuleWrapper): """Additive Gaussian PO wrapper. diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index b05cbd55356..13cbd05e877 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from typing import Optional, Tuple import torch @@ -555,11 +554,9 @@ def recurrent_mode(self, value): @property def temporal_mode(self): - warnings.warn( + raise RuntimeError( "temporal_mode is deprecated, use recurrent_mode instead.", - category=DeprecationWarning, ) - return self.recurrent_mode def set_recurrent_mode(self, mode: bool = True): """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). @@ -1255,11 +1252,9 @@ def recurrent_mode(self, value): @property def temporal_mode(self): - warnings.warn( + raise RuntimeError( "temporal_mode is deprecated, use recurrent_mode instead.", - category=DeprecationWarning, ) - return self.recurrent_mode def set_recurrent_mode(self, mode: bool = True): """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index de963bcfdb9..6edcda5c800 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import contextlib -import logging import warnings from copy import deepcopy from dataclasses import dataclass @@ -18,7 +17,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -288,8 +287,7 @@ def __init__( ) self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device)) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.loss_critic_type = loss_critic_type @property @@ -298,41 +296,46 @@ def functional(self): @property def actor(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network @property def critic(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network @property def actor_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4." 
+ "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network_params @property def critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network_params @property def target_critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.target_critic_network_params diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 04a7708e7db..1f5edcf26ed 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -113,14 +113,11 @@ def __init__(self): # self.register_forward_pre_hook(_parameters_to_tensordict) def _set_deprecated_ctor_keys(self, **kwargs) -> None: - """Helper function to set a tensordict key from a constructor and raise a warning simultaneously.""" for key, value in kwargs.items(): if value is not None: - warnings.warn( + raise RuntimeError( f"Setting '{key}' via the constructor is deprecated, use .set_keys(='some_key') instead.", - category=DeprecationWarning, ) - self.set_keys(**{key: value}) def set_keys(self, **kwargs) -> None: """Set tensordict key names. @@ -217,7 +214,8 @@ def convert_to_functional( """ if kwargs.pop("funs_to_decorate", None) is not None: warnings.warn( - "funs_to_decorate is without effect with the new objective API.", + "funs_to_decorate is without effect with the new objective API. 
This " + "warning will be replaced by an error in v0.4.0.", category=DeprecationWarning, ) if kwargs: diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 3d90d0174b9..f963f0e0b52 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -26,7 +26,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -332,8 +332,7 @@ def __init__( self.target_entropy_buffer = None if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.temperature = temperature self.min_q_weight = min_q_weight @@ -1030,8 +1029,7 @@ def __init__( self.action_space = _find_action_space(action_space) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 70239ea62e9..6572084c8ec 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -5,7 +5,6 @@ from __future__ import annotations -import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -19,7 +18,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -230,8 +229,7 @@ def __init__( self.loss_function = loss_function if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 6ef7ab7386e..3ff093d445c 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import math -import warnings from dataclasses import dataclass from numbers import Number from typing import Tuple, Union @@ -22,7 +21,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -202,8 +201,7 @@ def __init__( self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @property def target_entropy(self): diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 623c3f7189a..37fd1cbdaea 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -24,7 +24,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -224,8 +224,7 @@ def __init__( self.action_space = _find_action_space(action_space) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 7bdfde573fa..9fd8a8a0bd2 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -15,7 +14,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, hold_out_net, @@ -247,11 +246,9 @@ def __init__( self.imagination_horizon = imagination_horizon self.discount_loss = discount_loss if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) if lmbda is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.lmbda = lmbda + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 1fd48675cb4..62d2a628af4 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -17,7 +17,7 @@ from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -285,22 +285,15 @@ def __init__( self.loss_function = loss_function if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) @property def device(self) -> torch.device: - warnings.warn( - "The device attributes of the 
looses will be deprecated in v0.3.", - category=DeprecationWarning, - ) - for p in self.parameters(): - return p.device raise RuntimeError( - "At least one of the networks of SACLoss must have trainable " "parameters." + "The device attributes of the losses have been deprecated since v0.3.", ) def _set_in_keys(self): @@ -407,7 +400,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: ) # assert has no gradient exp_a = torch.exp((min_q - value) * self.temperature) - exp_a = torch.min(exp_a, torch.FloatTensor([100.0]).to(self.device)) + exp_a = exp_a.clamp_max(100) # write log_prob in tensordict for alpha loss tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) @@ -775,7 +768,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: ) # assert has no gradient exp_a = torch.exp((min_Q - value) * self.temperature) - exp_a = torch.min(exp_a, torch.FloatTensor([100.0]).to(self.device)) + exp_a = exp_a.clamp_max(100) # write log_prob in tensordict for alpha loss tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 38f56108784..f7b9307a962 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -28,7 +28,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, ) @@ -265,8 +265,7 @@ def __init__( self.action_space = _find_action_space(action_space) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 0f7ea835949..ac2244b9a23 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import logging import math import warnings @@ -21,7 +20,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, ) @@ -331,8 +330,7 @@ def __init__( self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._set_deprecated_ctor_keys( advantage=advantage_key, value_target=value_target_key, @@ -363,7 +361,7 @@ def critic(self): @property def actor_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " "link will be removed in v0.4.", category=DeprecationWarning, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index af0a94cbc96..61aaf5990e4 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree.
import math -import warnings from dataclasses import dataclass from numbers import Number from typing import Union @@ -21,7 +20,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -313,8 +312,7 @@ def __init__( self.gSDE = gSDE if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index c9cc8f383ad..4613810d0d3 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import logging import warnings from copy import deepcopy from dataclasses import dataclass @@ -17,7 +16,7 @@ from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -281,8 +280,7 @@ def __init__( self.target_critic_network_params = None if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @property def functional(self): @@ -290,41 +288,46 @@ def functional(self): @property def actor(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network @property def critic(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network @property def actor_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network_params @property def critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network_params @property def target_critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4." 
+ "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.target_critic_network_params diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 431296e7486..053da9e53d2 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -25,7 +25,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -374,8 +374,7 @@ def __init__( self.actor_network, self.value_network ) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index e1aeb253681..877a8f0c819 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -18,7 +17,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -293,8 +292,7 @@ def __init__( self.register_buffer("max_action", high) self.register_buffer("min_action", low) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4c0b8ae67bd..43dfa65c0c4 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -25,9 +25,9 @@ raise err_ft from err from torchrl.envs.utils import step_mdp -_GAMMA_LMBDA_DEPREC_WARNING = ( +_GAMMA_LMBDA_DEPREC_ERROR = ( "Passing gamma / lambda parameters through the loss constructor " - "is deprecated and will be removed soon. To customize your value function, " + "is a deprecated feature. To customize your value function, " "run `loss_module.make_value_estimator(ValueEstimators., gamma=val)`." ) @@ -299,9 +299,8 @@ def __init__( tau: Optional[float] = None, ): if eps is None and tau is None: - warnings.warn( - "Neither eps nor tau was provided. Taking the default value " - "eps=0.999. This behaviour will soon be deprecated.", + raise RuntimeError( + "Neither eps nor tau was provided. 
" "This behaviour is deprecated.", category=DeprecationWarning, ) eps = 0.999 diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index fc2e58a19f6..dfa56e5c672 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -299,23 +299,17 @@ def __init__( self.shifted = shifted if advantage_key is not None: - warnings.warn( - "Setting 'advantage_key' via ctor is deprecated, use .set_keys(advantage_key='some_key') instead.", - category=DeprecationWarning, + raise RuntimeError( + "Setting 'advantage_key' via constructor is deprecated, use .set_keys(advantage_key='some_key') instead.", ) - self.dep_keys["advantage"] = advantage_key if value_target_key is not None: - warnings.warn( - "Setting 'value_target_key' via ctor is deprecated, use .set_keys(value_target_key='some_key') instead.", - category=DeprecationWarning, + raise RuntimeError( + "Setting 'value_target_key' via constructor is deprecated, use .set_keys(value_target_key='some_key') instead.", ) - self.dep_keys["value_target"] = value_target_key if value_key is not None: - warnings.warn( - "Setting 'value_key' via ctor is deprecated, use .set_keys(value_key='some_key') instead.", - category=DeprecationWarning, + raise RuntimeError( + "Setting 'value_key' via constructor is deprecated, use .set_keys(value_key='some_key') instead.", ) - self.dep_keys["value"] = value_key @property def tensor_keys(self) -> _AcceptedKeys: diff --git a/torchrl/trainers/helpers/logger.py b/torchrl/trainers/helpers/logger.py index 6e4e864aa2e..b0b37533519 100644 --- a/torchrl/trainers/helpers/logger.py +++ b/torchrl/trainers/helpers/logger.py @@ -28,3 +28,5 @@ class LoggerConfig: # Keys to log in the recorder offline_logging: bool = True # If True, Wandb will do the logging offline + project_name: str = "" + # The name of the project for WandB diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 96f8d98477d..0764bf9fb72 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -667,9 +667,11 @@ def __init__( self.device = device if flatten_tensordicts is None: warnings.warn( - "flatten_tensordicts default value will soon be changed " + "flatten_tensordicts default value has now changed " "to False for a faster execution. Make sure your " - "code is robust to this change.", + "code is robust to this change. To silence this warning, " + "pass flatten_tensordicts= in your code. " + "This warning will be removed in v0.4.", category=DeprecationWarning, ) flatten_tensordicts = True