diff --git a/.github/scripts/m1_script.sh b/.github/scripts/td_script.sh similarity index 71% rename from .github/scripts/m1_script.sh rename to .github/scripts/td_script.sh index 6552d8e4622..6da1cad5d79 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/td_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export TORCHRL_BUILD_VERSION=0.4.0 +export TORCHRL_BUILD_VERSION=0.5.0 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 075489b208d..f8b700c0410 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -9,9 +9,19 @@ # # -set -e +#set -e set -v +# Initialize an error flag +error_occurred=0 +# Function to handle errors +error_handler() { + echo "Error on line $1" + error_occurred=1 +} +# Trap ERR to call the error_handler function with the failing line number +trap 'error_handler $LINENO' ERR + export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" @@ -24,6 +34,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export CUDA_LAUNCH_BLOCKING=1 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 #python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 @@ -51,10 +62,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ logger.backend= - # ==================================================================================== # # ================================ Gymnasium ========================================= # +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \ + optim.gradient_steps=55 \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \ collector.total_frames=80 \ collector.frames_per_batch=20 \ @@ -149,18 +162,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di replay_buffer.size=120 \ env.name=CartPole-v1 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - collector.total_frames=200 \ +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \ + collector.total_frames=48 \ collector.init_random_frames=10 \ - collector.frames_per_batch=200 \ - env.n_parallel_envs=4 \ - optimization.optim_steps_per_batch=1 \ - logger.video=True \ - logger.backend=csv \ - replay_buffer.buffer_size=120 \ - replay_buffer.batch_size=24 \ - replay_buffer.batch_length=12 \ - networks.rssm_hidden_dim=17 + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.device= \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + network.device= \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -200,8 +213,8 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr 
collector.frames_per_batch=200 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ - logger.backend=csv \ logger.video=True \ + logger.backend=csv \ replay_buffer.buffer_size=120 \ replay_buffer.batch_size=24 \ replay_buffer.batch_length=12 \ @@ -298,3 +311,11 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ba coverage combine coverage xml -i + +# Check if any errors occurred during the script execution +if [ "$error_occurred" -ne 0 ]; then + echo "Errors occurred during script execution" + exit 1 +else + echo "Script executed successfully" +fi diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml index 5171a7c3e2a..f51c5ed79b6 100644 --- a/.github/workflows/build-wheels-linux.yml +++ b/.github/workflows/build-wheels-linux.yml @@ -45,3 +45,4 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 84fe79d09d2..73a365a79f2 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -46,4 +46,4 @@ jobs: runner-type: macos-m1-stable smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} - env-var-script: .github/scripts/m1_script.sh + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 683f2a93f69..1beef7318f4 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -46,3 +46,4 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh diff --git a/README.md b/README.md index 5ac72ff052e..2b250dac540 100644 --- a/README.md +++ b/README.md @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr - [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py) - [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py) - [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py) +- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3_bc/td3_bc.py) - [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py) - [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py) - [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index ccd6cb23ed0..b46d789ed15 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -317,6 +317,7 @@ Regular modules Conv3dNet SqueezeLayer Squeeze2dLayer + BatchRenorm1d Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index c2f43d8e9b6..96a887196aa 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -121,6 +121,15 @@ REDQ REDQLoss +CrossQ +------ + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + CrossQLoss + IQL ---- @@ -160,6 +169,15 @@ TD3 TD3Loss +TD3+BC +------ + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + TD3BCLoss + PPO --- diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 2f0982257eb..11384bda0e6 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -124,26 +124,26 @@ Checkpointing ------------- The trainer class and hooks support checkpointing, which can be achieved either -using the ``torchsnapshot ``_ backend or +using the `torchsnapshot `_ backend or the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``: .. code-block:: - $ CKPT_BACKEND=torch python script.py + $ CKPT_BACKEND=torchsnapshot python script.py -which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch +``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch is that it is a more flexible API, which supports distributed checkpointing and also allows users to load tensors from a file stored on disk to a tensor with a physical storage (which pytorch currently does not support). This allows, for instance, to load tensors from and to a replay buffer that would otherwise not fit in memory. -When building a trainer, one can provide a file path where the checkpoints are to +When building a trainer, one can provide a path where the checkpoints are to be written. With the ``torchsnapshot`` backend, a directory path is expected, whereas the ``torch`` backend expects a file path (typically a ``.pt`` file). .. code-block:: - >>> filepath = "path/to/dir/" + >>> filepath = "path/to/dir/or/file" >>> trainer = Trainer( ... collector=collector, ... total_frames=total_frames, diff --git a/setup.py b/setup.py index a439829db17..73541790e8f 100644 --- a/setup.py +++ b/setup.py @@ -172,7 +172,7 @@ def _main(argv): if is_nightly: tensordict_dep = "tensordict-nightly" else: - tensordict_dep = "tensordict>=0.4.0" + tensordict_dep = "tensordict>=0.5.0" if is_nightly: version = get_nightly_version() diff --git a/sota-check/run_crossq.sh b/sota-check/run_crossq.sh new file mode 100644 index 00000000000..2ae4ea51c49 --- /dev/null +++ b/sota-check/run_crossq.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=crossq +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/crossq_%j.txt +#SBATCH --error=slurm_errors/crossq_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="crossq" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/crossq/crossq.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_td3bc.sh b/sota-check/run_td3bc.sh new file mode 100644 index 00000000000..0fefb3ecd6f --- /dev/null +++ b/sota-check/run_td3bc.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=td3bc_offline +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/td3bc_offline_%j.txt +#SBATCH --error=slurm_errors/td3bc_offline_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="td3bc_offline" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh index cad2783c653..515ac06a50b 100755 --- a/sota-check/submitit-release-check.sh +++ b/sota-check/submitit-release-check.sh @@ -65,6 +65,7 @@ scripts=( run_ppo_mujoco.sh run_sac.sh run_td3.sh + run_td3bc.sh run_dt.sh run_dt_online.sh ) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 775dcfe206d..f8c18147306 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 0276039058f..d115174eb9c 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval final = collected_frames >= collector.total_frames diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index d8185c8091c..5ca70f83b53 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 5f8f81357c8..cf629ed0733 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -204,7 +204,7 @@ def main(cfg: 
"DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // evaluation_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index 4b6f14cd058..d0d6693eb97 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml new file mode 100644 index 00000000000..1dcbd3db92d --- /dev/null +++ b/sota-implementations/crossq/config.yaml @@ -0,0 +1,58 @@ +# environment and task +env: + name: HalfCheetah-v4 + task: "" + library: gym + max_episode_steps: 1000 + seed: 42 + +# collector +collector: + total_frames: 1_000_000 + init_random_frames: 25000 + frames_per_batch: 1000 + init_env_steps: 1000 + device: cpu + env_per_collector: 1 + reset_at_each_iter: False + +# replay buffer +replay_buffer: + size: 1000000 + prb: 0 # use prioritized experience replay + scratch_dir: null + +# optim +optim: + utd_ratio: 1.0 + policy_update_delay: 3 + gamma: 0.99 + loss_function: l2 + lr: 1.0e-3 + weight_decay: 0.0 + batch_size: 256 + alpha_init: 1.0 + adam_eps: 1.0e-8 + beta1: 0.5 + beta2: 0.999 + +# network +network: + batch_norm_momentum: 0.01 + warmup_steps: 100000 + critic_hidden_sizes: [2048, 2048] + actor_hidden_sizes: [256, 256] + critic_activation: relu + actor_activation: relu + default_policy_scale: 1.0 + scale_lb: 0.1 + device: "cuda:0" + +# logging +logger: + backend: wandb + project_name: torchrl_example_crossQ + group_name: null + exp_name: ${env.name}_CrossQ + mode: online + eval_iter: 25000 diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py new file mode 100644 index 00000000000..df34d4ae68d --- /dev/null +++ b/sota-implementations/crossq/crossq.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""CrossQ Example. + +This is a simple self-contained example of a CrossQ training script. + +It supports state environments like MuJoCo. + +The helper functions are coded in the utils.py associated with this script. 
+""" +import time + +import hydra + +import numpy as np +import torch +import torch.cuda +import tqdm +from torchrl._utils import logger as torchrl_logger +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils import ( + log_metrics, + make_collector, + make_crossQ_agent, + make_crossQ_optimizer, + make_environment, + make_loss_module, + make_replay_buffer, +) + + +@hydra.main(version_base="1.1", config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + device = torch.device(device) + + # Create logger + exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="crossq_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create environments + train_env, eval_env = make_environment(cfg) + + # Create agent + model, exploration_policy = make_crossQ_agent(cfg, train_env, device) + + # Create CrossQ loss + loss_module = make_loss_module(cfg, model) + + # Create off-policy collector + collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + scratch_dir=cfg.replay_buffer.scratch_dir, + device="cpu", + ) + + # Create optimizers + ( + optimizer_actor, + optimizer_critic, + optimizer_alpha, + ) = make_crossQ_optimizer(cfg, loss_module) + + # Main loop + start_time = time.time() + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.env.max_episode_steps + + sampling_start = time.time() + update_counter = 0 + delayed_updates = cfg.optim.policy_update_delay + for _, tensordict in enumerate(collector): + sampling_time = time.time() - sampling_start + + # Update weights of the inference policy + collector.update_policy_weights_() + + pbar.update(tensordict.numel()) + + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # Add to replay buffer + replay_buffer.extend(tensordict.cpu()) + collected_frames += current_frames + + # Optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: + ( + actor_losses, + alpha_losses, + q_losses, + ) = ([], [], []) + for _ in range(num_updates): + + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to(device) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) + q_loss = q_loss.mean() + # 
Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_losses.append(q_loss.detach().item()) + + if update_actor: + actor_loss, metadata_actor = loss_module.actor_loss( + sampled_tensordict + ) + actor_loss = actor_loss.mean() + alpha_loss = loss_module.alpha_loss( + log_prob=metadata_actor["log_prob"] + ).mean() + + # Update actor + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() + + actor_losses.append(actor_loss.detach().item()) + alpha_losses.append(alpha_loss.detach().item()) + + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses).item() + metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() + metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py new file mode 100644 index 00000000000..9883bc50b17 --- /dev/null +++ b/sota-implementations/crossq/utils.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from tensordict.nn import InteractionType, TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn, optim +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + ParallelEnv, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import MLP, ProbabilisticActor, ValueOperator +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.models.batchrenorm import BatchRenorm1d +from torchrl.objectives import CrossQLoss + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps=1000): + transformed_env = TransformedEnv( + env, + Compose( + InitTracker(), + StepCounter(max_episode_steps), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg): + """Make environments for training and evaluation.""" + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, + ), + train_env.transform.clone(), + ) + return train_env, eval_env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, train_env, actor_model_explore, device): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + ) + collector.set_seed(cfg.env.seed) + return collector + + +def make_replay_buffer( + batch_size, + prb=False, + buffer_size=1000000, + scratch_dir=None, + device="cpu", + prefetch=3, +): + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + ), + batch_size=batch_size, + ) + replay_buffer.append_transform(lambda x: x.to(device, non_blocking=True)) + return replay_buffer + + +# ==================================================================== +# Model +# 
----- + + +def make_crossQ_agent(cfg, train_env, device): + """Make CrossQ agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.actor_hidden_sizes, + "out_features": 2 * action_spec.shape[-1], + "activation_class": get_activation(cfg.network.actor_activation), + "norm_class": BatchRenorm1d, + "norm_kwargs": { + "momentum": cfg.network.batch_norm_momentum, + "num_features": cfg.network.actor_hidden_sizes[-1], + "warmup_steps": cfg.network.warmup_steps, + }, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "low": action_spec.space.low, + "high": action_spec.space.high, + "tanh_loc": False, + } + + actor_extractor = NormalParamExtractor( + scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ) + actor_net = nn.Sequential(actor_net, actor_extractor) + + in_keys_actor = in_keys + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.critic_hidden_sizes, + "out_features": 1, + "activation_class": get_activation(cfg.network.critic_activation), + "norm_class": BatchRenorm1d, + "norm_kwargs": { + "momentum": cfg.network.batch_norm_momentum, + "num_features": cfg.network.critic_hidden_sizes[-1], + "warmup_steps": cfg.network.warmup_steps, + }, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = train_env.fake_tensordict() + td = td.to(device) + for net in model: + net.eval() + net(td) + net.train() + del td + + return model, model[0] + + +# ==================================================================== +# CrossQ Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create CrossQ loss + loss_module = CrossQLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + loss_function=cfg.optim.loss_function, + alpha_init=cfg.optim.alpha_init, + ) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) + + return loss_module + + +def split_critic_params(critic_params): + critic1_params = [] + critic2_params = [] + + for param in critic_params: + data1, data2 = param.data.chunk(2, dim=0) + critic1_params.append(nn.Parameter(data1)) + critic2_params.append(nn.Parameter(data2)) + return critic1_params, critic2_params + + +def make_crossQ_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + betas=(cfg.optim.beta1, cfg.optim.beta2), + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + 
betas=(cfg.optim.beta1, cfg.optim.beta2), + ) + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=cfg.optim.lr, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(activation: str): + if activation == "relu": + return nn.ReLU + elif activation == "tanh": + return nn.Tanh + elif activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index eb0b88c26f7..a92ee6185c3 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 59dbcafd8c9..9cca9fd8af5 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -56,7 +56,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + test_env = make_env( + cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device + ) if cfg.logger.video: test_env = test_env.append_transform( VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) @@ -114,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 to_log = {"train/loss": loss_vals["loss"]} # Evaluation - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): if i % pretrain_log_interval == 0: eval_td = test_env.rollout( max_steps=eval_steps, diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 5cb297e5c0b..da2241ce9fa 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } # Evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 7c9500aa4e7..409833c75fa 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -57,7 +57,7 @@ # ----------------- -def make_base_env(env_cfg, from_pixels=False): +def make_base_env(env_cfg, from_pixels=False, device=None): set_gym_backend(env_cfg.backend).set() env_library = LIBS[env_cfg.library] @@ -73,7 +73,7 @@ def make_base_env(env_cfg, from_pixels=False): if env_library is DMControlEnv: env_task = env_cfg.task env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) + env = env_library(**env_kwargs, device=device) 
return env @@ -134,7 +134,9 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): return transformed_env -def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): +def make_parallel_env( + env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None +): if train: num_envs = env_cfg.num_train_envs else: @@ -142,10 +144,12 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False) def make_env(): with set_gym_backend(env_cfg.backend): - return make_base_env(env_cfg, from_pixels=from_pixels) + return make_base_env(env_cfg, from_pixels=from_pixels, device="cpu") env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True), + ParallelEnv( + num_envs, EnvCreator(make_env), serial_for_single=True, device=device + ), env_cfg, obs_loc, obs_std, @@ -154,11 +158,15 @@ def make_env(): return env -def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): - env = make_parallel_env( - env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels +def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None): + return make_parallel_env( + env_cfg, + obs_loc, + obs_std, + train=train, + from_pixels=from_pixels, + device=device, ) - return env # ==================================================================== diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 6e100f92dc3..386f743c7d3 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // eval_iter final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 90f93551d4d..906273ee2f5 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index ac3f17a9203..173f88f7028 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index ab101e8486a..604e1ac546a 100644 --- 
a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -9,17 +9,13 @@ env: image_size : 64 horizon: 500 n_parallel_envs: 8 - device: - _target_: dreamer_utils._default_device - device: null + device: cpu collector: total_frames: 5_000_000 init_random_frames: 3000 frames_per_batch: 1000 device: - _target_: dreamer_utils._default_device - device: null optimization: train_every: 1000 @@ -41,8 +37,6 @@ optimization: networks: exploration_noise: 0.3 device: - _target_: dreamer_utils._default_device - device: null state_dim: 30 rssm_hidden_dim: 200 hidden_dim: 400 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e7b346b2b22..e521b9df386 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -10,6 +10,7 @@ import torch.cuda import tqdm from dreamer_utils import ( + _default_device, dump_video, log_metrics, make_collector, @@ -17,7 +18,6 @@ make_environments, make_replay_buffer, ) -from hydra.utils import instantiate # mixed precision training from torch.cuda.amp import GradScaler @@ -38,7 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # cfg = correct_for_frame_skip(cfg) - device = torch.device(instantiate(cfg.networks.device)) + device = _default_device(cfg.networks.device) # Create logger exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) @@ -284,7 +284,7 @@ def compile_rssms(module): # Evaluation if (i % eval_iter) == 0: # Real env - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_rollout = test_env.rollout( eval_rollout_steps, policy, @@ -298,7 +298,9 @@ def compile_rssms(module): log_metrics(logger, eval_metrics, collected_frames) # Simulated env if model_based_env_eval is not None: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index ff14871b011..73baa310821 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from hydra.utils import instantiate from tensordict import NestedKey from tensordict.nn import ( InteractionType, @@ -88,6 +87,7 @@ def _make_env(cfg, device, from_pixels=False): cfg.env.task, from_pixels=cfg.env.from_pixels or from_pixels, pixels_only=cfg.env.from_pixels, + device=device, ) else: raise NotImplementedError(f"Unknown lib {lib}.") @@ -98,7 +98,6 @@ def _make_env(cfg, device, from_pixels=False): env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) ) - assert env is not None return env @@ -129,7 +128,7 @@ def transform_env(cfg, env): def make_environments(cfg, parallel_envs=1, logger=None): """Make environments for training and evaluation.""" - func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device) + func = functools.partial(_make_env, cfg=cfg, device=_default_device(cfg.env.device)) train_env = ParallelEnv( parallel_envs, EnvCreator(func), @@ -138,7 +137,10 @@ def make_environments(cfg, parallel_envs=1, logger=None): train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) func = functools.partial( - _make_env, cfg=cfg, device=cfg.env.device, from_pixels=cfg.logger.video 
+ _make_env, + cfg=cfg, + device=_default_device(cfg.env.device), + from_pixels=cfg.logger.video, ) eval_env = ParallelEnv( 1, @@ -332,7 +334,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - policy_device=instantiate(cfg.collector.device), + policy_device=_default_device(cfg.collector.device), env_device=train_env.device, storing_device="cpu", ) @@ -535,7 +537,7 @@ def _dreamer_make_actor_real( SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=[action_key], - default_interaction_type=InteractionType.MODE, + default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, spec=CompositeSpec( diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index 0482a595ffa..1998c044305 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -247,7 +247,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index ce96cf06ce8..fdee4256c42 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -239,7 +239,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index bb0f314197a..cf583909620 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 33513dd3973..ae1894379fd 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index d98724e1371..d1a16fd8192 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -130,7 +130,7 @@ def main(cfg: 
"DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index b66c6f9dcf2..d50ff806294 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index 81551ebefb7..a4d2b88a9d0 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index 9d14ff04b04..bd44bb0a043 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index e752c4d73f2..fa006a7d4a2 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -236,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index d294a9c783e..4e6a962c556 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index 30b7e7e98bc..f7b2523010b 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): 
evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 908cb7924a3..2b02254032a 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index e3e74971a49..219ae1b59b6 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -210,7 +210,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index f7a399cda72..9904fe072ab 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 97fd039c238..5fbc9b032d7 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/td3_bc/config.yaml b/sota-implementations/td3_bc/config.yaml new file mode 100644 index 00000000000..54275a94bc2 --- /dev/null +++ b/sota-implementations/td3_bc/config.yaml @@ -0,0 +1,45 @@ +# task and env +env: + name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency + task: "" + library: gymnasium + seed: 42 + max_episode_steps: 1000 + +# replay buffer +replay_buffer: + dataset: halfcheetah-medium-v2 + batch_size: 256 + +# optim +optim: + gradient_steps: 100000 + gamma: 0.99 + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 + adam_eps: 1e-4 + batch_size: 256 + target_update_polyak: 0.995 + policy_update_delay: 2 + policy_noise: 0.2 + noise_clip: 0.5 + alpha: 2.5 + +# network +network: + hidden_sizes: [256, 256] + activation: relu + device: null + +# logging +logger: + backend: wandb + project_name: td3+bc_${replay_buffer.dataset} + group_name: null + exp_name: TD3+BC_${replay_buffer.dataset} + mode: online + eval_iter: 5000 + eval_steps: 1000 + eval_envs: 1 + video: False diff --git 
a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py new file mode 100644 index 00000000000..7c43fdc1a12 --- /dev/null +++ b/sota-implementations/td3_bc/td3_bc.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""TD3+BC Example. + +This is a self-contained example of an offline RL TD3+BC training script. + +The helper functions are coded in the utils.py associated with this script. + +""" +import time + +import hydra +import numpy as np +import torch +import tqdm +from torchrl._utils import logger as torchrl_logger + +from torchrl.envs import set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.record.loggers import generate_exp_name, get_logger + +from utils import ( + dump_video, + log_metrics, + make_environment, + make_loss_module, + make_offline_replay_buffer, + make_optimizer, + make_td3_agent, +) + + +@hydra.main(config_path="", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + set_gym_backend(cfg.env.library).set() + + # Create logger + exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="td3bc_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) + + # Creante env + eval_env = make_environment( + cfg, + logger=logger, + ) + + # Create replay buffer + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + + # Create agent + model, _ = make_td3_agent(cfg, eval_env, device) + + # Create loss + loss_module, target_net_updater = make_loss_module(cfg.optim, model) + + # Create optimizer + optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module) + + gradient_steps = cfg.optim.gradient_steps + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + delayed_updates = cfg.optim.policy_update_delay + update_counter = 0 + pbar = tqdm.tqdm(range(gradient_steps)) + # Training loop + start_time = time.time() + for i in pbar: + pbar.update(1) + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to(device) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) + + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_loss.item() + + to_log = {"q_loss": q_loss.item()} + + # Update actor + if update_actor: + actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict) + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # Update target params + target_net_updater.step() + + to_log["actor_loss"] = actor_loss.item() + to_log.update(actorloss_metadata) + + # evaluation + if i % evaluation_interval == 0: + with 
set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_td = eval_env.rollout( + max_steps=eval_steps, policy=model[0], auto_cast_to_device=True + ) + eval_env.apply(dump_video) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward + if logger is not None: + log_metrics(logger, to_log, i) + + pbar.close() + torchrl_logger.info(f"Training time: {time.time() - start_time}") + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py new file mode 100644 index 00000000000..3772eefccde --- /dev/null +++ b/sota-implementations/td3_bc/utils.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import functools + +import torch + +from torch import nn, optim +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + InitTracker, + ParallelEnv, + RewardSum, + StepCounter, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ( + AdditiveGaussianWrapper, + MLP, + SafeModule, + SafeSequential, + TanhModule, + ValueOperator, +) + +from torchrl.objectives import SoftUpdate +from torchrl.objectives.td3_bc import TD3BCLoss +from torchrl.record import VideoRecorder + + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu", from_pixels=False): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + from_pixels=from_pixels, + pixels_only=False, + ) + elif lib == "dm_control": + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps): + transformed_env = TransformedEnv( + env, + Compose( + StepCounter(max_steps=max_episode_steps), + InitTracker(), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg, logger=None): + """Make environments for training and evaluation.""" + partial = functools.partial(env_maker, cfg=cfg) + parallel_env = ParallelEnv( + cfg.logger.eval_envs, + EnvCreator(partial), + serial_for_single=True, + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + return train_env + + +# ==================================================================== +# Replay buffer +# --------------------------- + + +def make_offline_replay_buffer(rb_cfg): + data = D4RLExperienceReplay( + dataset_id=rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, + ) + + data.append_transform(DoubleToFloat()) + + return data + + +# ==================================================================== +# Model +# ----- + + +def make_td3_agent(cfg, train_env, device): + """Make TD3 agent.""" + # Define 
Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": get_activation(cfg), + } + + actor_net = MLP(**actor_net_kwargs) + + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "param", + ], + ) + actor = SafeSequential( + actor_module, + TanhModule( + in_keys=["param"], + out_keys=["action"], + spec=action_spec, + ), + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": 1, + "activation_class": get_activation(cfg), + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = train_env.fake_tensordict() + td = td.to(device) + for net in model: + net(td) + del td + + # Exploration wrappers: + actor_model_explore = AdditiveGaussianWrapper( + model[0], + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device) + return model, actor_model_explore + + +# ==================================================================== +# TD3 Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create TD3 loss + loss_module = TD3BCLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + loss_function=cfg.loss_function, + delay_actor=True, + delay_qvalue=True, + action_spec=model[0][1].spec, + policy_noise=cfg.policy_noise, + noise_clip=cfg.noise_clip, + alpha=cfg.alpha, + ) + loss_module.make_value_estimator(gamma=cfg.gamma) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) + return loss_module, target_net_updater + + +def make_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.values(True, True)) + actor_params = list(loss_module.actor_network_params.values(True, True)) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + eps=cfg.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + eps=cfg.adam_eps, + ) + return optimizer_actor, optimizer_critic + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/test/test_cost.py b/test/test_cost.py index 76fc4e651f4..a318f5694cd 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -12,7 +12,7 @@ from dataclasses import asdict, dataclass from packaging import version as pack_version -from tensordict._tensordict import unravel_keys +from tensordict._C import unravel_keys from tensordict.nn import ( InteractionType, @@ -98,6 +98,7 @@ A2CLoss, 
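The `alpha` argument passed to `TD3BCLoss` in `make_loss_module` controls the behaviour-cloning regularisation that separates TD3+BC from plain TD3: the actor maximises a Q-value term scaled by a normaliser lambda = alpha / mean(|Q|) while being pulled towards the dataset actions by an MSE term (Fujimoto & Gu, 2021). A pure-PyTorch sketch of that published objective, not of the `TD3BCLoss` internals:

```python
import torch

def td3_bc_actor_loss(q_pi, pi_action, dataset_action, alpha=2.5):
    """Published TD3+BC actor objective: -lambda * Q(s, pi(s)) + BC term.

    q_pi:            Q(s, pi(s)), shape [batch, 1]
    pi_action:       pi(s), shape [batch, action_dim]
    dataset_action:  behaviour action a from the dataset, same shape
    lambda normalises the Q term by the average |Q| magnitude of the batch.
    """
    lmbda = alpha / q_pi.abs().mean().detach()
    bc_loss = (pi_action - dataset_action).pow(2).mean()
    return -lmbda * q_pi.mean() + bc_loss

q_pi = torch.randn(8, 1)
pi_action = torch.randn(8, 4)
dataset_action = torch.randn(8, 4)
print(td3_bc_actor_loss(q_pi, pi_action, dataset_action))
```

Larger `alpha` weights the Q term more heavily (closer to plain TD3); smaller values push the policy towards behaviour cloning.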
ClipPPOLoss, CQLLoss, + CrossQLoss, DDPGLoss, DiscreteCQLLoss, DiscreteIQLLoss, @@ -114,6 +115,7 @@ PPOLoss, QMixerLoss, SACLoss, + TD3BCLoss, TD3Loss, ) from torchrl.objectives.common import LossModule @@ -261,9 +263,9 @@ def __init__(self): self.vmap_model = _vmap_func( self.model, (None, 0), - randomness="error" - if vmap_randomness == "error" - else self.vmap_randomness, + randomness=( + "error" if vmap_randomness == "error" else self.vmap_randomness + ), ) def forward(self, td): @@ -319,9 +321,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], @@ -2714,11 +2716,7 @@ def test_td3_reduction(self, reduction): assert loss[key].shape == torch.Size([]) -@pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" -) -@pytest.mark.parametrize("version", [1, 2]) -class TestSAC(LossModuleTestBase): +class TestTD3BC(LossModuleTestBase): seed = 0 def _create_mock_actor( @@ -2727,36 +2725,35 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", - observation_key="observation", - action_key="action", + in_keys=None, + out_keys=None, + dropout=0.0, ): # Actor action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + module = nn.Sequential( + nn.Linear(obs_dim, obs_dim), + nn.Dropout(dropout), + nn.Linear(obs_dim, action_dim), ) - actor = ProbabilisticActor( - module=module, - in_keys=["loc", "scale"], - spec=action_spec, - distribution_class=TanhNormal, - out_keys=[action_key], + actor = Actor( + spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys ) return actor.to(device) - def _create_mock_qvalue( + def _create_mock_value( self, batch=2, obs_dim=3, action_dim=4, device="cpu", - observation_key="observation", - action_key="action", out_keys=None, + action_key="action", + observation_key="observation", ): + # Actor class ValueClass(nn.Module): def __init__(self): super().__init__() @@ -2766,29 +2763,17 @@ def forward(self, obs, act): return self.linear(torch.cat([obs, act], -1)) module = ValueClass() - qvalue = ValueOperator( + value = ValueOperator( module=module, in_keys=[observation_key, action_key], out_keys=out_keys, ) - return qvalue.to(device) + return value.to(device) - def _create_mock_value( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - observation_key="observation", - out_keys=None, + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 ): - module = nn.Linear(obs_dim, 1) - value = ValueOperator( - module=module, - in_keys=[observation_key], - out_keys=out_keys, - ) - return value.to(device) + raise NotImplementedError def _create_mock_common_layer_setup( self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 @@ -2805,7 +2790,7 @@ def _create_mock_common_layer_setup( depth=1, out_features=2 * n_act, ) - qvalue = MLP( + value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=1, @@ -2836,31 +2821,27 @@ def _create_mock_common_layer_setup( in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, + return_log_prob=True, ), ) - qvalue_head = Mod( - qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"] + 
value_head = Mod( + value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) - qvalue = Seq(common, qvalue_head) - return actor, qvalue, common, td - - def _create_mock_distributional_actor( - self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 - ): - raise NotImplementedError + value = Seq(common, value_head) + return actor, value, common, td - def _create_mock_data_sac( + def _create_mock_data_td3bc( self, - batch=16, + batch=8, obs_dim=3, action_dim=4, atoms=None, device="cpu", - observation_key="observation", action_key="action", + observation_key="observation", + reward_key="reward", done_key="done", terminated_key="terminated", - reward_key="reward", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -2888,7 +2869,7 @@ def _create_mock_data_sac( ) return td - def _create_seq_mock_data_sac( + def _create_seq_mock_data_td3bc( self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" ): # create a tensordict @@ -2904,269 +2885,225 @@ def _create_seq_mock_data_sac( reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = torch.ones(batch, T, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ - "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "observation": obs * mask.to(obs.dtype), "next": { - "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "observation": next_obs * mask.to(obs.dtype), + "reward": reward * mask.to(obs.dtype), "done": done, "terminated": terminated, - "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": action * mask.to(obs.dtype), }, names=[None, "time"], device=device, ) return td - @pytest.mark.parametrize("delay_value", (True, False)) - @pytest.mark.parametrize("delay_actor", (True, False)) - @pytest.mark.parametrize("delay_qvalue", (True, False)) - @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) + @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) + @pytest.mark.parametrize("alpha", [0.1, 6.0]) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_sac( + @pytest.mark.parametrize("use_action_spec", [True, False]) + @pytest.mark.parametrize("dropout", [0.0, 0.1]) + def test_td3bc( self, - delay_value, delay_actor, delay_qvalue, - num_qvalue, device, - version, + policy_noise, + noise_clip, + alpha, td_est, + use_action_spec, + dropout, ): - if (delay_actor or delay_qvalue) and not delay_value: - pytest.skip("incompatible config") - torch.manual_seed(self.seed) - td = self._create_mock_data_sac(device=device) - - actor = self._create_mock_actor(device=device) - qvalue = self._create_mock_qvalue(device=device) - if version == 1: - value = self._create_mock_value(device=device) + actor = self._create_mock_actor(device=device, dropout=dropout) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3bc(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None else: - value = None - - kwargs = 
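`_create_mock_data_td3bc` above shows the tensordict layout the loss consumes: observation and action at the root, with the next observation, reward, done and terminated flags nested under "next". A minimal sketch of that layout built by hand:

```python
import torch
from tensordict import TensorDict

batch, obs_dim, action_dim = 8, 3, 4
td = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": torch.randn(batch, action_dim).clamp(-1, 1),
        "next": {
            "observation": torch.randn(batch, obs_dim),
            "reward": torch.randn(batch, 1),
            "done": torch.zeros(batch, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch],
)
# loss_td = loss_fn(td) would then return "loss_actor", "loss_qvalue", ... entries.
```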
{} - if delay_actor: - kwargs["delay_actor"] = True - if delay_qvalue: - kwargs["delay_qvalue"] = True - if delay_value: - kwargs["delay_value"] = True - - loss_fn = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=num_qvalue, + bounds = (-1, 1) + action_spec = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, loss_function="l2", - **kwargs, + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return if td_est is not None: loss_fn.make_value_estimator(td_est) - - with _check_td_steady(td), pytest.warns( - UserWarning, match="No target network updater" + with ( + pytest.warns( + UserWarning, + match="No target network updater has been associated with this loss module", + ) + if (delay_actor or delay_qvalue) + else contextlib.nullcontext() ): - loss = loss_fn(td) - - assert loss_fn.tensor_keys.priority in td.keys() + with _check_td_steady(td): + loss = loss_fn(td) - # check that losses are independent - for k in loss.keys(): - if not k.startswith("loss"): - continue - loss[k].sum().backward(retain_graph=True) - if k == "loss_actor": - if version == 1: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params.values( - include_nested=True, leaves_only=True - ) - ) - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ) - assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) - ) - elif k == "loss_value" and version == 1: - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) - ) - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ) - assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params.values( - include_nested=True, leaves_only=True + for p in loss_fn.qvalue_network_params.values(True, True) ) - ) - elif k == "loss_qvalue": - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) ) - ) - if version == 1: + elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params.values( - include_nested=True, leaves_only=True - ) - ) - assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ) - elif k == "loss_alpha": - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in 
loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True + for p in loss_fn.actor_network_params.values(True, True) ) - ) - if version == 1: - assert all( + assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params.values( - include_nested=True, leaves_only=True - ) + for p in loss_fn.qvalue_network_params.values(True, True) ) - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ) - else: - raise NotImplementedError(k) - loss_fn.zero_grad() + else: + raise NotImplementedError(k) + loss_fn.zero_grad() - sum( - [item for name, item in loss.items() if name.startswith("loss_")] - ).backward() - named_parameters = list(loss_fn.named_parameters()) - named_buffers = list(loss_fn.named_buffers()) + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) - assert len({p for n, p in named_parameters}) == len(list(named_parameters)) - assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) - for name, p in named_parameters: - if not name.startswith("target_"): - assert ( - p.grad is not None and p.grad.norm() > 0.0 - ), f"parameter {name} (shape: {p.shape}) has a null gradient" - else: - assert ( - p.grad is None or p.grad.norm() == 0.0 - ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" - @pytest.mark.parametrize("delay_value", (True, False)) - @pytest.mark.parametrize("delay_actor", (True, False)) - @pytest.mark.parametrize("delay_qvalue", (True, False)) - @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") @pytest.mark.parametrize("device", get_default_devices()) - def test_sac_state_dict( + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1]) + @pytest.mark.parametrize("noise_clip", [0.1]) + @pytest.mark.parametrize("alpha", [0.1]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + def test_td3bc_state_dict( self, - delay_value, delay_actor, delay_qvalue, - num_qvalue, device, - version, + policy_noise, + noise_clip, + alpha, + use_action_spec, ): - if (delay_actor or delay_qvalue) and not delay_value: - pytest.skip("incompatible config") - torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) - qvalue = self._create_mock_qvalue(device=device) - if version == 1: - value = self._create_mock_value(device=device) + value = self._create_mock_value(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None else: - value = None - - kwargs = {} - if delay_actor: - kwargs["delay_actor"] = True - if delay_qvalue: - kwargs["delay_qvalue"] = True - if delay_value: - kwargs["delay_value"] = True - - loss_fn = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - 
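The "check that losses are independent" blocks above all follow the same recipe: backpropagate one loss term in isolation and assert that the parameters belonging to the other network received no gradient (or only zeros). A stripped-down version of that pattern on two unrelated modules:

```python
import torch
from torch import nn

actor = nn.Linear(3, 4)
critic = nn.Linear(3, 1)
x = torch.randn(8, 3)

loss_actor = actor(x).sum()
loss_actor.backward()

# Only the actor should have been touched by loss_actor.
assert all(p.grad is None or (p.grad == 0).all() for p in critic.parameters())
assert not any(p.grad is None or (p.grad == 0).all() for p in actor.parameters())
```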
num_qvalue_nets=num_qvalue, + bounds = (-1, 1) + action_spec = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, loss_function="l2", - **kwargs, + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, ) sd = loss_fn.state_dict() - loss_fn2 = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=num_qvalue, + loss_fn2 = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, loss_function="l2", - **kwargs, + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, ) loss_fn2.load_state_dict(sd) + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [False, True]) - def test_sac_separate_losses( + def test_td3bc_separate_losses( self, device, separate_losses, - version, n_act=4, ): torch.manual_seed(self.seed) - actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) - - loss_fn = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), - num_qvalue_nets=1, + actor, value, common, td = self._create_mock_common_layer_setup(n_act=n_act) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + loss_function="l2", separate_losses=separate_losses, ) with pytest.warns(UserWarning, match="No target network updater has been"): loss = loss_fn(td) - assert loss_fn.tensor_keys.priority in td.keys() - + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) # check that losses are independent for k in loss.keys(): if not k.startswith("loss"): @@ -3175,25 +3112,19 @@ def test_sac_separate_losses( if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) + for p in loss_fn.qvalue_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) + for p in loss_fn.actor_network_params.values(True, True) ) elif k == "loss_qvalue": - common_layers_no = len(list(common.parameters())) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) + for p in loss_fn.actor_network_params.values(True, True) ) if separate_losses: + common_layers_no = len(list(common.parameters())) common_layers = itertools.islice( loss_fn.qvalue_network_params.values(True, True), common_layers_no, @@ -3216,235 +3147,1686 @@ def test_sac_separate_losses( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.qvalue_network_params.values(True, True) ) - elif k == "loss_alpha": - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) - ) - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ) + else: raise NotImplementedError(k) loss_fn.zero_grad() + @pytest.mark.skipif(not 
_has_functorch, reason="functorch not installed") @pytest.mark.parametrize("n", range(1, 4)) - @pytest.mark.parametrize("delay_value", (True, False)) - @pytest.mark.parametrize("delay_actor", (True, False)) - @pytest.mark.parametrize("delay_qvalue", (True, False)) - @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) - def test_sac_batcher( + @pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)]) + @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) + @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) + @pytest.mark.parametrize("alpha", [0.1, 6.0]) + def test_td3bc_batcher( self, n, - delay_value, delay_actor, delay_qvalue, - num_qvalue, device, - version, + policy_noise, + noise_clip, + alpha, + gamma=0.9, ): - if (delay_actor or delay_qvalue) and not delay_value: - pytest.skip("incompatible config") torch.manual_seed(self.seed) - td = self._create_seq_mock_data_sac(device=device) - actor = self._create_mock_actor(device=device) - qvalue = self._create_mock_qvalue(device=device) - if version == 1: - value = self._create_mock_value(device=device) - else: - value = None - - kwargs = {} - if delay_actor: - kwargs["delay_actor"] = True - if delay_qvalue: - kwargs["delay_qvalue"] = True - if delay_value: - kwargs["delay_value"] = True - - loss_fn = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=num_qvalue, - loss_function="l2", - **kwargs, + value = self._create_mock_value(device=device) + td = self._create_seq_mock_data_td3bc(device=device) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_qvalue=delay_qvalue, + delay_actor=delay_actor, ) - ms = MultiStep(gamma=0.9, n_steps=n).to(device) + ms = MultiStep(gamma=gamma, n_steps=n).to(device) td_clone = td.clone() ms_td = ms(td_clone) torch.manual_seed(0) np.random.seed(0) - with pytest.warns( - UserWarning, - match="No target network updater has been associated with this loss module", - ): - with _check_td_steady(ms_td): - loss_ms = loss_fn(ms_td) - assert loss_fn.tensor_keys.priority in ms_td.keys() - - with torch.no_grad(): - torch.manual_seed(0) # log-prob is computed with a random action - np.random.seed(0) - loss = loss_fn(td) - if n == 1: - assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum( - [item for name, item in loss.items() if name.startswith("loss_")] - ) - _loss_ms = sum( - [item for name, item in loss_ms.items() if name.startswith("loss_")] - ) - assert ( - abs(_loss - _loss_ms) < 1e-3 - ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" - else: - with pytest.raises(AssertionError): - assert_allclose_td(loss, loss_ms) - sum( - [item for name, item in loss_ms.items() if name.startswith("loss_")] - ).backward() - named_parameters = loss_fn.named_parameters() - for name, p in named_parameters: - if not name.startswith("target_"): - assert ( - p.grad is not None and p.grad.norm() > 0.0 - ), f"parameter {name} (shape: {p.shape}) has a null gradient" - else: - assert ( - p.grad is None or p.grad.norm() == 0.0 - ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" - # Check param update effect on targets - target_actor = [ - p.clone() - for p in loss_fn.target_actor_network_params.values( - include_nested=True, leaves_only=True - ) - ] - target_qvalue = [ - p.clone() - for p in loss_fn.target_qvalue_network_params.values( - include_nested=True, 
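`MultiStep(gamma=gamma, n_steps=n)` in the batcher test post-processes the sampled batch so that the loss bootstraps from an n-step return rather than a single-step one. The quantity it is there to enable is the usual n-step target, sketched here in plain PyTorch (the textbook formula, not the exact keys `MultiStep` writes):

```python
import torch

gamma, n = 0.9, 3
rewards = torch.tensor([1.0, 0.5, 0.25])  # r_t, r_{t+1}, r_{t+2}
q_next = torch.tensor(2.0)                # Q(s_{t+n}, a_{t+n}) from the target nets

n_step_return = sum(gamma**k * rewards[k] for k in range(n)) + gamma**n * q_next
print(n_step_return)  # 1.0 + 0.9 * 0.5 + 0.81 * 0.25 + 0.729 * q_next
```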
leaves_only=True + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if (delay_qvalue or delay_actor) + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + if delay_qvalue or delay_actor: + SoftUpdate(loss_fn, eps=0.5) + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + actorp_set = set(actor.parameters()) + loss_fnp_set = set(loss_fn.parameters()) + assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_td3bc_tensordict_keys(self, td_est): + actor = self._create_mock_actor() + value = self._create_mock_value() + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value(out_keys=["state_action_value_test"]) + loss_fn = TD3BCLoss( + actor, + value, + 
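`SoftUpdate(loss_fn, eps=0.5)` registers a Polyak-averaging updater for the target parameters, which is what the target-parameter assertions above rely on: as long as the updater is not stepped, delayed target parameters stay frozen while the online ones move. Under TorchRL's convention, `eps` is the fraction of the old target that is kept, i.e. target <- eps * target + (1 - eps) * online; a plain-tensor sketch of one such update:

```python
import torch
from torch import nn

online = nn.Linear(3, 1)
target = nn.Linear(3, 1)
target.load_state_dict(online.state_dict())

eps = 0.995  # close to 1.0 -> slowly moving target network
with torch.no_grad():
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(eps).add_((1.0 - eps) * p_o)
```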
action_spec=actor.spec, + ) + key_mapping = { + "state_action_value": ("value", "state_action_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("spec", [True, False]) + @pytest.mark.parametrize("bounds", [True, False]) + def test_constructor(self, spec, bounds): + actor = self._create_mock_actor() + value = self._create_mock_value() + action_spec = actor.spec if spec else None + bounds = (-1, 1) if bounds else None + if (bounds is not None and action_spec is not None) or ( + bounds is None and action_spec is None + ): + with pytest.raises(ValueError, match="but not both"): + TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, ) - ] - if version == 1: - target_value = [ - p.clone() - for p in loss_fn.target_value_network_params.values( + return + TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + ) + + # TODO: test for action_key, atm the action key of the TD3+BC loss is not configurable, + # since it is used in it's constructor + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_td3bc_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(in_keys=[observation_key]) + qvalue = self._create_mock_value( + observation_key=observation_key, out_keys=["state_action_value"] + ) + td = self._create_mock_data_td3bc( + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) + + kwargs = { + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + "action": td.get("action"), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + with pytest.warns(UserWarning, match="No target network updater has been"): + torch.manual_seed(0) + loss_val_td = loss(td) + torch.manual_seed(0) + loss_val = loss(**kwargs) + loss_val_reconstruct = TensorDict(dict(zip(loss.out_keys, loss_val)), []) + assert_allclose_td(loss_val_reconstruct, loss_val_td) + + # test select + loss.select_out_keys("loss_actor", "loss_qvalue") + torch.manual_seed(0) + if torch.__version__ >= "2.0.0": + loss_actor, loss_qvalue = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_qvalue = loss(**kwargs) + return + + assert loss_actor == loss_val_td["loss_actor"] + assert loss_qvalue == loss_val_td["loss_qvalue"] + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_td3bc_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) 
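`test_constructor` above pins down an API contract of `TD3BCLoss`: the target-policy noise must be bounded either by an `action_spec` or by explicit `bounds`, and supplying both (or neither) raises a `ValueError`. A small sketch of the two accepted call patterns, assembled from the same kind of mock modules the tests use:

```python
import torch
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules import Actor, ValueOperator
from torchrl.objectives.td3_bc import TD3BCLoss

obs_dim, action_dim = 3, 4
actor = Actor(
    module=nn.Linear(obs_dim, action_dim),
    spec=BoundedTensorSpec(-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)),
)

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(obs_dim + action_dim, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(module=QNet(), in_keys=["observation", "action"])

loss_from_spec = TD3BCLoss(actor, qvalue, action_spec=actor.spec)  # option 1: spec
loss_from_bounds = TD3BCLoss(actor, qvalue, bounds=(-1.0, 1.0))    # option 2: bounds
# TD3BCLoss(actor, qvalue, action_spec=actor.spec, bounds=(-1.0, 1.0))  # -> ValueError
```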
+ td = self._create_mock_data_td3bc(device=device) + action_spec = actor.spec + bounds = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + delay_qvalue=False, + delay_actor=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + if not key.startswith("loss"): + continue + assert loss[key].shape == torch.Size([]) + + +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +@pytest.mark.parametrize("version", [1, 2]) +class TestSAC(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, + out_keys=[action_key], + ) + return actor.to(device) + + def _create_mock_qvalue( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + out_keys=None, + ): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + qvalue = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return qvalue.to(device) + + def _create_mock_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + out_keys=None, + ): + module = nn.Linear(obs_dim, 1) + value = ValueOperator( + module=module, + in_keys=[observation_key], + out_keys=out_keys, + ) + return value.to(device) + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + qvalue = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + ), + ) + qvalue_head = Mod( + qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + qvalue = Seq(common, qvalue_head) + return actor, qvalue, common, td + 
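`test_td3bc_reduction` checks the usual PyTorch reduction contract: with `reduction="none"` each `loss_*` entry keeps the batch shape of the input tensordict, while `"mean"` and `"sum"` collapse it to a scalar, the same idea as the functional losses:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(8, 1)
target = torch.randn(8, 1)

print(F.mse_loss(pred, target, reduction="none").shape)  # per-element losses
print(F.mse_loss(pred, target, reduction="mean").shape)  # scalar, torch.Size([])
```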
+ def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_sac( + self, + batch=16, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + observation_key="observation", + action_key="action", + done_key="done", + terminated_key="terminated", + reward_key="reward", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_sac( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "terminated": terminated, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + "collector": {"mask": mask}, + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.parametrize("delay_value", (True, False)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_sac( + self, + delay_value, + delay_actor, + delay_qvalue, + num_qvalue, + device, + version, + td_est, + ): + if (delay_actor or delay_qvalue) and not delay_value: + pytest.skip("incompatible config") + + torch.manual_seed(self.seed) + td = self._create_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + if delay_value: + kwargs["delay_value"] = True + + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is 
not None: + loss_fn.make_value_estimator(td_est) + + with _check_td_steady(td), pytest.warns( + UserWarning, match="No target network updater" + ): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + if version == 1: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_value" and version == 1: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if version == 1: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if version == 1: + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_value", (True, False)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_sac_state_dict( + self, + delay_value, + delay_actor, + delay_qvalue, + num_qvalue, + device, + version, + ): + if (delay_actor or 
delay_qvalue) and not delay_value: + pytest.skip("incompatible config") + + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + if delay_value: + kwargs["delay_value"] = True + + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + sd = loss_fn.state_dict() + loss_fn2 = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_sac_separate_losses( + self, + device, + separate_losses, + version, + n_act=4, + ): + torch.manual_seed(self.seed) + actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) + + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + num_qvalue_nets=1, + separate_losses=separate_losses, + ) + with pytest.warns(UserWarning, match="No target network updater has been"): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + common_layers_no = len(list(common.parameters())) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.parametrize("n", range(1, 4)) + @pytest.mark.parametrize("delay_value", (True, False)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_sac_batcher( + self, + 
n, + delay_value, + delay_actor, + delay_qvalue, + num_qvalue, + device, + version, + ): + if (delay_actor or delay_qvalue) and not delay_value: + pytest.skip("incompatible config") + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + if delay_value: + kwargs["delay_value"] = True + + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + + ms = MultiStep(gamma=0.9, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + with pytest.warns( + UserWarning, + match="No target network updater has been associated with this loss module", + ): + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + if version == 1: + target_value = [ + p.clone() + for p in loss_fn.target_value_network_params.values( + include_nested=True, leaves_only=True + ) + ] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue2 = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + if version == 1: + target_value2 = [ + p.clone() + for p in loss_fn.target_value_network_params.values( + include_nested=True, leaves_only=True + ) + ] + if loss_fn.delay_actor: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == 
p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + if version == 1: + if loss_fn.delay_value: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_value, target_value2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all( + (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) + ) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_sac_tensordict_keys(self, td_est, version): + td = self._create_mock_data_sac() + + actor = self._create_mock_actor() + qvalue = self._create_mock_qvalue() + if version == 1: + value = self._create_mock_value() + else: + value = None + + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + loss_function="l2", + ) + + default_keys = { + "priority": "td_error", + "value": "state_value", + "state_action_value": "state_action_value", + "action": "action", + "log_prob": "_log_prob", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value() + loss_fn = SACLoss( + actor, + value, + loss_function="l2", + ) + + key_mapping = { + "value": ("value", "state_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_key", ["action", "action2"]) + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_sac_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key, version + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_sac( + action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + + actor = self._create_mock_actor( + observation_key=observation_key, action_key=action_key + ) + qvalue = self._create_mock_qvalue( + observation_key=observation_key, + action_key=action_key, + out_keys=["state_action_value"], + ) + if version == 1: + value = self._create_mock_value(observation_key=observation_key) + else: + value = None + + loss = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + ) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) + + kwargs = { + action_key: td.get(action_key), + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + # setting the seed for each loss so that drawing the random samples from value network + # leads to same numbers for both runs + torch.manual_seed(self.seed) + 
with pytest.warns(UserWarning, match="No target network updater"): + loss_val = loss(**kwargs) + + torch.manual_seed(self.seed) + + SoftUpdate(loss, eps=0.5) + + loss_val_td = loss(td) + + if version == 1: + assert len(loss_val) == 6 + elif version == 2: + assert len(loss_val) == 5 + + torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) + torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) + torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) + torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) + torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) + if version == 1: + torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[5]) + # test select + torch.manual_seed(self.seed) + loss.select_out_keys("loss_actor", "loss_alpha") + if torch.__version__ >= "2.0.0": + loss_actor, loss_alpha = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_alpha = loss(**kwargs) + return + assert loss_actor == loss_val_td["loss_actor"] + assert loss_alpha == loss_val_td["loss_alpha"] + + def test_state_dict(self, version): + if version == 1: + pytest.skip("Test not implemented for version 1.") + model = torch.nn.Linear(3, 4) + actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) + policy = ProbabilisticActor( + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=TanhDelta, + ) + value = ValueOperator(module=model, in_keys=["obs"], out_keys="value") + + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + state = loss.state_dict() + + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + # with an access in between + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.target_entropy + state = loss.state_dict() + + loss = SACLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_sac_reduction(self, reduction, version): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_sac(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + if version == 1: + value = self._create_mock_value(device=device) + else: + value = None + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + loss_function="l2", + delay_qvalue=False, + delay_actor=False, + delay_value=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + if not key.startswith("loss"): + continue + assert loss[key].shape == torch.Size([]) + + +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class TestDiscreteSAC(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + 
observation_key="observation", + action_key="action", + ): + # Actor + action_spec = OneHotDiscreteTensorSpec(action_dim) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) + actor = ProbabilisticActor( + spec=action_spec, + module=module, + in_keys=["logits"], + out_keys=[action_key], + distribution_class=OneHotCategorical, + return_log_prob=False, + ) + return actor.to(device) + + def _create_mock_qvalue( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + ): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim, action_dim) + + def forward(self, obs): + return self.linear(obs) + + module = ValueClass() + qvalue = ValueOperator( + module=module, in_keys=[observation_key], out_keys=["action_value"] + ) + return qvalue.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_sac( + self, + batch=16, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + observation_key="observation", + action_key="action", + done_key="done", + terminated_key="terminated", + reward_key="reward", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + action_value = torch.randn(batch, atoms, action_dim).softmax(-2) + action = ( + (action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0]) + .to(torch.long) + .to(device) + ) + else: + action_value = torch.randn(batch, action_dim, device=device) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_sac( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action_value = torch.randn( + batch, T, atoms, action_dim, device=device + ).softmax(-2) + action = ( + action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0] + ).to(torch.long) + else: + action_value = torch.randn(batch, T, action_dim, device=device) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "terminated": terminated, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + "collector": {"mask": mask}, + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 
0.0), + }, + names=[None, "time"], + ) + return td + + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99]) + @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0]) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_discrete_sac( + self, + delay_qvalue, + num_qvalue, + device, + target_entropy_weight, + target_entropy, + td_est, + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + kwargs = {} + if delay_qvalue: + kwargs["delay_qvalue"] = True + + loss_fn = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + num_qvalue_nets=num_qvalue, + target_entropy_weight=target_entropy_weight, + target_entropy=target_entropy, + loss_function="l2", + action_space="one-hot", + **kwargs, + ) + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + + with _check_td_steady(td), pytest.warns( + UserWarning, match="No target network updater" + ): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( include_nested=True, leaves_only=True ) - ] - for p in loss_fn.parameters(): - if p.requires_grad: - p.data += torch.randn_like(p) - target_actor2 = [ - p.clone() - for p in loss_fn.target_actor_network_params.values( - include_nested=True, leaves_only=True - ) - ] - target_qvalue2 = [ - p.clone() - for p in loss_fn.target_qvalue_network_params.values( - include_nested=True, leaves_only=True ) - ] - if version == 1: - target_value2 = [ - p.clone() - for p in loss_fn.target_value_network_params.values( + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( include_nested=True, leaves_only=True ) - ] - if loss_fn.delay_actor: + ) + elif k == "loss_qvalue": assert all( - (p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2) + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) - else: assert not any( - (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) - if loss_fn.delay_qvalue: + elif k == "loss_alpha": assert all( - (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) - ) - if version == 1: - if loss_fn.delay_value: - assert all( - (p1 == p2).all() for p1, p2 in 
zip(target_value, target_value2) - ) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("target_entropy_weight", [0.5]) + @pytest.mark.parametrize("target_entropy", ["auto"]) + def test_discrete_sac_state_dict( + self, + delay_qvalue, + num_qvalue, + device, + target_entropy_weight, + target_entropy, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + kwargs = {} + if delay_qvalue: + kwargs["delay_qvalue"] = True + + loss_fn = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + num_qvalue_nets=num_qvalue, + target_entropy_weight=target_entropy_weight, + target_entropy=target_entropy, + loss_function="l2", + action_space="one-hot", + **kwargs, + ) + sd = loss_fn.state_dict() + loss_fn2 = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + num_qvalue_nets=num_qvalue, + target_entropy_weight=target_entropy_weight, + target_entropy=target_entropy, + loss_function="l2", + action_space="one-hot", + **kwargs, + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("n", range(1, 4)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99]) + @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0]) + def test_discrete_sac_batcher( + self, + n, + delay_qvalue, + num_qvalue, + device, + target_entropy_weight, + target_entropy, + gamma=0.9, + ): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + kwargs = {} + if delay_qvalue: + kwargs["delay_qvalue"] = True + loss_fn = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, + num_qvalue_nets=num_qvalue, + loss_function="l2", + target_entropy_weight=target_entropy_weight, + target_entropy=target_entropy, + action_space="one-hot", + **kwargs, + ) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + with _check_td_steady(ms_td), pytest.warns( + UserWarning, match="No target network updater" + ): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + SoftUpdate(loss_fn, eps=0.5) + + with 
torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - if p.requires_grad: - p.data += torch.randn_like(p) + # Check param update effect on targets + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue2 = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: assert all( - (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) ) + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + @pytest.mark.parametrize( "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] ) - def test_sac_tensordict_keys(self, td_est, version): - td = self._create_mock_data_sac() - + def test_discrete_sac_tensordict_keys(self, td_est): actor = self._create_mock_actor() qvalue = self._create_mock_qvalue() - if version == 1: - value = self._create_mock_value() - else: - value = None - loss_fn = SACLoss( + loss_fn = DiscreteSACLoss( actor_network=actor, qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=2, + num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) default_keys = { "priority": "td_error", "value": "state_value", - "state_action_value": "state_action_value", "action": "action", - "log_prob": "_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", } - 
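# Editorial illustration (a sketch, not taken from the diff; the remapped names mirror the
# key_mapping dictionaries used in these tests): the entries above are the keys that
# ``LossModule.set_keys`` can remap at runtime, e.g.
#     loss_fn.set_keys(reward="reward_test", done=("done", "test"), terminated=("terminated", "test"))
#     assert loss_fn.tensor_keys.reward == "reward_test"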
self.tensordict_keys_test( loss_fn, default_keys=default_keys, td_est=td_est, ) - value = self._create_mock_value() - loss_fn = SACLoss( - actor, - value, + qvalue = self._create_mock_qvalue() + loss_fn = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -3460,8 +4842,8 @@ def test_sac_tensordict_keys(self, td_est, version): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) - def test_sac_notensordict( - self, action_key, observation_key, reward_key, done_key, terminated_key, version + def test_discrete_sac_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -3477,18 +4859,13 @@ def test_sac_notensordict( ) qvalue = self._create_mock_qvalue( observation_key=observation_key, - action_key=action_key, - out_keys=["state_action_value"], ) - if version == 1: - value = self._create_mock_value(observation_key=observation_key) - else: - value = None - loss = SACLoss( + loss = DiscreteSACLoss( actor_network=actor, qvalue_network=qvalue, - value_network=value, + num_actions=actor.spec[action_key].space.n, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -3507,90 +4884,32 @@ def test_sac_notensordict( } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") - # setting the seed for each loss so that drawing the random samples from value network - # leads to same numbers for both runs - torch.manual_seed(self.seed) - with pytest.warns(UserWarning, match="No target network updater"): + with pytest.warns(UserWarning, match="No target network updater has been"): loss_val = loss(**kwargs) + loss_val_td = loss(td) - torch.manual_seed(self.seed) - - SoftUpdate(loss, eps=0.5) - - loss_val_td = loss(td) - - if version == 1: - assert len(loss_val) == 6 - elif version == 2: - assert len(loss_val) == 5 - - torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) - torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) - torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) - torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) - torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) - if version == 1: - torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[5]) - # test select - torch.manual_seed(self.seed) - loss.select_out_keys("loss_actor", "loss_alpha") - if torch.__version__ >= "2.0.0": - loss_actor, loss_alpha = loss(**kwargs) - else: - with pytest.raises( - RuntimeError, - match="You are likely using tensordict.nn.dispatch with keyword arguments", - ): + torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) + torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) + torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) + torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) + torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) + # test select + torch.manual_seed(self.seed) + loss.select_out_keys("loss_actor", "loss_alpha") + if torch.__version__ >= "2.0.0": loss_actor, loss_alpha = loss(**kwargs) - return - assert loss_actor == loss_val_td["loss_actor"] - assert loss_alpha == loss_val_td["loss_alpha"] - - def test_state_dict(self, version): - if version == 1: 
- pytest.skip("Test not implemented for version 1.") - model = torch.nn.Linear(3, 4) - actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) - policy = ProbabilisticActor( - module=actor_module, - in_keys=["logits"], - out_keys=["action"], - distribution_class=TanhDelta, - ) - value = ValueOperator(module=model, in_keys=["obs"], out_keys="value") - - loss = SACLoss( - actor_network=policy, - qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), - ) - state = loss.state_dict() - - loss = SACLoss( - actor_network=policy, - qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), - ) - loss.load_state_dict(state) - - # with an access in between - loss = SACLoss( - actor_network=policy, - qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), - ) - loss.target_entropy - state = loss.state_dict() - - loss = SACLoss( - actor_network=policy, - qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), - ) - loss.load_state_dict(state) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_alpha = loss(**kwargs) + return + assert loss_actor == loss_val_td["loss_actor"] + assert loss_alpha == loss_val_td["loss_alpha"] @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_sac_reduction(self, reduction, version): + def test_discrete_sac_reduction(self, reduction): torch.manual_seed(self.seed) device = ( torch.device("cpu") @@ -3600,18 +4919,13 @@ def test_sac_reduction(self, reduction, version): td = self._create_mock_data_sac(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - if version == 1: - value = self._create_mock_value(device=device) - else: - value = None - loss_fn = SACLoss( + loss_fn = DiscreteSACLoss( actor_network=actor, qvalue_network=qvalue, - value_network=value, + num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", delay_qvalue=False, - delay_actor=False, - delay_value=False, reduction=reduction, ) loss_fn.make_value_estimator() @@ -3627,10 +4941,7 @@ def test_sac_reduction(self, reduction, version): assert loss[key].shape == torch.Size([]) -@pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" -) -class TestDiscreteSAC(LossModuleTestBase): +class TestCrossQ(LossModuleTestBase): seed = 0 def _create_mock_actor( @@ -3643,16 +4954,19 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) actor = ProbabilisticActor( - spec=action_spec, module=module, - in_keys=["logits"], + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, out_keys=[action_key], - distribution_class=OneHotCategorical, - return_log_prob=False, ) return actor.to(device) @@ -3663,27 +4977,85 @@ def _create_mock_qvalue( action_dim=4, device="cpu", observation_key="observation", + action_key="action", + out_keys=None, ): class ValueClass(nn.Module): def __init__(self): super().__init__() - self.linear = nn.Linear(obs_dim, action_dim) + 
self.linear = nn.Linear(obs_dim + action_dim, 1) - def forward(self, obs): - return self.linear(obs) + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) module = ValueClass() qvalue = ValueOperator( - module=module, in_keys=[observation_key], out_keys=["action_value"] + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, ) return qvalue.to(device) + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + qvalue = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + ), + ) + qvalue_head = Mod( + qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + qvalue = Seq(common, qvalue_head) + return actor, qvalue, common, td + def _create_mock_distributional_actor( self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 ): raise NotImplementedError - def _create_mock_data_sac( + def _create_mock_data_crossq( self, batch=16, obs_dim=3, @@ -3700,15 +5072,9 @@ def _create_mock_data_sac( obs = torch.randn(batch, obs_dim, device=device) next_obs = torch.randn(batch, obs_dim, device=device) if atoms: - action_value = torch.randn(batch, atoms, action_dim).softmax(-2) - action = ( - (action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0]) - .to(torch.long) - .to(device) - ) + raise NotImplementedError else: - action_value = torch.randn(batch, action_dim, device=device) - action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) @@ -3728,7 +5094,7 @@ def _create_mock_data_sac( ) return td - def _create_seq_mock_data_sac( + def _create_seq_mock_data_crossq( self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" ): # create a tensordict @@ -3736,20 +5102,15 @@ def _create_seq_mock_data_sac( obs = total_obs[:, :T] next_obs = total_obs[:, 1:] if atoms: - action_value = torch.randn( - batch, T, atoms, action_dim, device=device - ).softmax(-2) - action = ( - action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0] - ).to(torch.long) + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) else: - action_value = torch.randn(batch, T, action_dim, device=device) - action = (action_value == action_value.max(-1, True)[0]).to(torch.long) - + action = torch.randn(batch, T, action_dim, 
device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ @@ -3762,48 +5123,33 @@ def _create_seq_mock_data_sac( }, "collector": {"mask": mask}, "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), - "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0), }, names=[None, "time"], + device=device, ) return td - @pytest.mark.parametrize("delay_qvalue", (True, False)) - @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) - @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99]) - @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0]) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_discrete_sac( + def test_crossq( self, - delay_qvalue, num_qvalue, device, - target_entropy_weight, - target_entropy, td_est, ): torch.manual_seed(self.seed) - td = self._create_mock_data_sac(device=device) - + td = self._create_mock_data_crossq(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - kwargs = {} - if delay_qvalue: - kwargs["delay_qvalue"] = True - - loss_fn = DiscreteSACLoss( + loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, num_qvalue_nets=num_qvalue, - target_entropy_weight=target_entropy_weight, - target_entropy=target_entropy, loss_function="l2", - action_space="one-hot", - **kwargs, ) + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) @@ -3811,9 +5157,7 @@ def test_discrete_sac( if td_est is not None: loss_fn.make_value_estimator(td_est) - with _check_td_steady(td), pytest.warns( - UserWarning, match="No target network updater" - ): + with _check_td_steady(td): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -3842,13 +5186,145 @@ def test_discrete_sac( for p in loss_fn.actor_network_params.values( include_nested=True, leaves_only=True ) - ) - assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not 
None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_crossq_state_dict( + self, + num_qvalue, + device, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + ) + sd = loss_fn.state_dict() + loss_fn2 = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_crossq_separate_losses( + self, + separate_losses, + device, + ): + n_act = 4 + torch.manual_seed(self.seed) + actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + num_qvalue_nets=1, + separate_losses=separate_losses, + ) + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + common_layers_no = len(list(common.parameters())) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) ) - ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() @@ -3866,124 +5342,40 @@ def test_discrete_sac( raise NotImplementedError(k) loss_fn.zero_grad() - sum( - [item for name, item in loss.items() if name.startswith("loss_")] - ).backward() - named_parameters = list(loss_fn.named_parameters()) - named_buffers = list(loss_fn.named_buffers()) - - assert len({p for n, p in named_parameters}) == len(list(named_parameters)) - assert len({p for n, p in named_buffers}) == len(list(named_buffers)) - - for name, p in named_parameters: - if not name.startswith("target_"): - assert ( - p.grad is not None and p.grad.norm() > 0.0 - ), f"parameter {name} (shape: {p.shape}) has a null gradient" - else: - assert ( - p.grad is None or p.grad.norm() == 0.0 - ), f"target parameter {name} (shape: {p.shape}) has a non-null 
gradient" - - @pytest.mark.parametrize("delay_qvalue", (True, False)) - @pytest.mark.parametrize("num_qvalue", [2]) - @pytest.mark.parametrize("device", get_default_devices()) - @pytest.mark.parametrize("target_entropy_weight", [0.5]) - @pytest.mark.parametrize("target_entropy", ["auto"]) - def test_discrete_sac_state_dict( - self, - delay_qvalue, - num_qvalue, - device, - target_entropy_weight, - target_entropy, - ): - torch.manual_seed(self.seed) - - actor = self._create_mock_actor(device=device) - qvalue = self._create_mock_qvalue(device=device) - - kwargs = {} - if delay_qvalue: - kwargs["delay_qvalue"] = True - - loss_fn = DiscreteSACLoss( - actor_network=actor, - qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, - num_qvalue_nets=num_qvalue, - target_entropy_weight=target_entropy_weight, - target_entropy=target_entropy, - loss_function="l2", - action_space="one-hot", - **kwargs, - ) - sd = loss_fn.state_dict() - loss_fn2 = DiscreteSACLoss( - actor_network=actor, - qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, - num_qvalue_nets=num_qvalue, - target_entropy_weight=target_entropy_weight, - target_entropy=target_entropy, - loss_function="l2", - action_space="one-hot", - **kwargs, - ) - loss_fn2.load_state_dict(sd) - @pytest.mark.parametrize("n", range(1, 4)) - @pytest.mark.parametrize("delay_qvalue", (True, False)) - @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) - @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99]) - @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0]) - def test_discrete_sac_batcher( + def test_crossq_batcher( self, n, - delay_qvalue, num_qvalue, device, - target_entropy_weight, - target_entropy, - gamma=0.9, ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_sac(device=device) + td = self._create_seq_mock_data_crossq(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - kwargs = {} - if delay_qvalue: - kwargs["delay_qvalue"] = True - loss_fn = DiscreteSACLoss( + loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, num_qvalue_nets=num_qvalue, loss_function="l2", - target_entropy_weight=target_entropy_weight, - target_entropy=target_entropy, - action_space="one-hot", - **kwargs, ) - ms = MultiStep(gamma=gamma, n_steps=n).to(device) + ms = MultiStep(gamma=0.9, n_steps=n).to(device) td_clone = td.clone() ms_td = ms(td_clone) torch.manual_seed(0) np.random.seed(0) - with _check_td_steady(ms_td), pytest.warns( - UserWarning, match="No target network updater" - ): + + with _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() - SoftUpdate(loss_fn, eps=0.5) - with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -4023,12 +5415,6 @@ def test_discrete_sac_batcher( include_nested=True, leaves_only=True ) ] - target_qvalue = [ - p.clone() - for p in loss_fn.target_qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ] for p in loss_fn.parameters(): if p.requires_grad: p.data += torch.randn_like(p) @@ -4038,26 +5424,8 @@ def test_discrete_sac_batcher( include_nested=True, leaves_only=True ) ] - target_qvalue2 = [ - p.clone() - for p in loss_fn.target_qvalue_network_params.values( - include_nested=True, leaves_only=True - ) - ] - if loss_fn.delay_actor: 
- assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) - ) - if loss_fn.delay_qvalue: - assert all( - (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) - ) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) - ) + + assert not any((p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)) # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] @@ -4069,26 +5437,29 @@ def test_discrete_sac_batcher( @pytest.mark.parametrize( "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] ) - def test_discrete_sac_tensordict_keys(self, td_est): + def test_crossq_tensordict_keys(self, td_est): + actor = self._create_mock_actor() qvalue = self._create_mock_qvalue() + value = None - loss_fn = DiscreteSACLoss( + loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, + num_qvalue_nets=2, loss_function="l2", - action_space="one-hot", ) default_keys = { "priority": "td_error", - "value": "state_value", + "state_action_value": "state_action_value", "action": "action", + "log_prob": "_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", } + self.tensordict_keys_test( loss_fn, default_keys=default_keys, @@ -4096,16 +5467,13 @@ def test_discrete_sac_tensordict_keys(self, td_est): ) qvalue = self._create_mock_qvalue() - loss_fn = DiscreteSACLoss( - actor_network=actor, - qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, + loss_fn = CrossQLoss( + actor, + qvalue, loss_function="l2", - action_space="one-hot", ) key_mapping = { - "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), "terminated": ("terminated", ("terminated", "test")), @@ -4117,11 +5485,11 @@ def test_discrete_sac_tensordict_keys(self, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) - def test_discrete_sac_notensordict( + def test_crossq_notensordict( self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) - td = self._create_mock_data_sac( + td = self._create_mock_data_crossq( action_key=action_key, observation_key=observation_key, reward_key=reward_key, @@ -4134,13 +5502,13 @@ def test_discrete_sac_notensordict( ) qvalue = self._create_mock_qvalue( observation_key=observation_key, + action_key=action_key, + out_keys=["state_action_value"], ) - loss = DiscreteSACLoss( + loss = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - num_actions=actor.spec[action_key].space.n, - action_space="one-hot", ) loss.set_keys( action=action_key, @@ -4159,48 +5527,97 @@ def test_discrete_sac_notensordict( } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") - with pytest.warns(UserWarning, match="No target network updater has been"): - loss_val = loss(**kwargs) - loss_val_td = loss(td) + # setting the seed for each loss so that drawing the random samples from value network + # leads to same numbers for both runs + torch.manual_seed(self.seed) + loss_val = loss(**kwargs) - torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) - torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) - 
torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) - torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) - torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) - # test select - torch.manual_seed(self.seed) - loss.select_out_keys("loss_actor", "loss_alpha") - if torch.__version__ >= "2.0.0": + torch.manual_seed(self.seed) + + loss_val_td = loss(td) + assert len(loss_val) == 5 + + torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) + torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) + torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) + torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) + torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) + + # test select + torch.manual_seed(self.seed) + loss.select_out_keys("loss_actor", "loss_alpha") + if torch.__version__ >= "2.0.0": + loss_actor, loss_alpha = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): loss_actor, loss_alpha = loss(**kwargs) - else: - with pytest.raises( - RuntimeError, - match="You are likely using tensordict.nn.dispatch with keyword arguments", - ): - loss_actor, loss_alpha = loss(**kwargs) - return - assert loss_actor == loss_val_td["loss_actor"] - assert loss_alpha == loss_val_td["loss_alpha"] + return + assert loss_actor == loss_val_td["loss_actor"] + assert loss_alpha == loss_val_td["loss_alpha"] + + def test_state_dict( + self, + ): + + model = torch.nn.Linear(3, 4) + actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) + policy = ProbabilisticActor( + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=TanhDelta, + ) + value = ValueOperator(module=model, in_keys=["obs"], out_keys="value") + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + state = loss.state_dict() + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + # with an access in between + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.target_entropy + state = loss.state_dict() + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_discrete_sac_reduction(self, reduction): + def test_crossq_reduction(self, reduction): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_mock_data_sac(device=device) + td = self._create_mock_data_crossq(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - loss_fn = DiscreteSACLoss( + + loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - num_actions=actor.spec["action"].space.n, loss_function="l2", - action_space="one-hot", - delay_qvalue=False, reduction=reduction, ) loss_fn.make_value_estimator() @@ -5686,9 +7103,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else 
action_value_key + ): None, "chosen_action_value": None, }, shape=[], diff --git a/test/test_env.py b/test/test_env.py index e6ca38b729c..f8f242f3955 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2061,6 +2061,7 @@ def main_collector(j, q=None): total_frames=N * n_workers * 100, storing_device=device, device=device, + cat_results=-1, ) single_collectors = [ SyncDataCollector( diff --git a/test/test_modules.py b/test/test_modules.py index 59adbea653d..592464f0a96 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -34,7 +34,14 @@ VDNMixer, ) from torchrl.modules.distributions.utils import safeatanh, safetanh -from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear +from torchrl.modules.models import ( + BatchRenorm1d, + Conv3dNet, + ConvNet, + MLP, + NoisyLazyLinear, + NoisyLinear, +) from torchrl.modules.models.decision_transformer import ( _has_transformers, DecisionTransformer, @@ -1438,6 +1445,40 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers): torch.testing.assert_close(h1, h2) +class TestBatchRenorm: + @pytest.mark.parametrize("num_steps", [0, 5]) + @pytest.mark.parametrize("smooth", [False, True]) + def test_batchrenorm(self, num_steps, smooth): + torch.manual_seed(0) + bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5) + brn = BatchRenorm1d( + 5, + momentum=0.1, + eps=1e-5, + warmup_steps=num_steps, + max_d=10000, + max_r=10000, + smooth=smooth, + ) + bn.train() + brn.train() + data_train = torch.randn(100, 5).split(25) + data_test = torch.randn(100, 5) + for i, d in enumerate(data_train): + b = bn(d) + a = brn(d) + if num_steps > 0 and ( + (i < num_steps and not smooth) or (i == 0 and smooth) + ): + torch.testing.assert_close(a, b) + else: + assert not torch.isclose(a, b).all(), i + + bn.eval() + brn.eval() + torch.testing.assert_close(bn(data_test), brn(data_test)) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 50e3dd5cc49..32294a25edd 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2065,18 +2065,18 @@ def _queue_len(self) -> int: def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: - cat_results = 0 + cat_results = "stack" warnings.warn( f"`cat_results` was not specified in the constructor of {type(self).__name__}. " f"For MultiSyncDataCollector, `cat_results` indicates how the data should " - f"be packed: the preferred option is `cat_results='stack'` which provides " - f"the best interoperability across torchrl components. " + f"be packed: the preferred option and current default is `cat_results='stack'` " + f"which provides the best interoperability across torchrl components. " f"Other accepted values are `cat_results=0` (previous behaviour) and " f"`cat_results=-1` (cat along time dimension). Among these two, the latter " f"should be preferred for consistency across environment configurations. " - f"Currently, the default value is `0` (using torch.cat along first dimension)." - f"From v0.5 onward, this will default to `'stack'`. " - f"To suppress this warning, set stack_results to the desired value.", + f"Currently, the default value is `'stack'`." + f"From v0.6 onward, this warning will be removed. 
" + f"To suppress this warning, set `cat_results` to the desired value.", category=DeprecationWarning, ) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 04c24cb8d57..0006213cd27 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1143,6 +1143,7 @@ def __eq__(self, other): if not isinstance(other, LazyStackedTensorSpec): return False if self.device != other.device: + raise RuntimeError((self, other)) return False if len(self._specs) != len(other._specs): return False diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 7f462782757..4996e527527 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -26,8 +26,8 @@ LazyStackedTensorDict, TensorDict, TensorDictBase, + unravel_key, ) -from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, @@ -406,17 +406,16 @@ def _find_sync_values(self): return _do_nothing, _do_nothing if worker_device is None: - worker_not_main = [False] + worker_not_main = False - def find_all_worker_devices(item, worker_not_main=worker_not_main): + def find_all_worker_devices(item): + nonlocal worker_not_main if hasattr(item, "device"): - worker_not_main[0] = worker_not_main[0] or ( - item.device != self_device - ) + worker_not_main = worker_not_main or (item.device != self_device) for td in self.shared_tensordicts: td.apply(find_all_worker_devices, filter_empty=True) - if worker_not_main[0]: + if worker_not_main: if torch.cuda.is_available(): worker_device = ( torch.device("cuda") @@ -431,6 +430,8 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main): ) else: raise RuntimeError("Did not find a valid worker device") + else: + worker_device = self_device if ( worker_device is not None @@ -460,6 +461,7 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main): and self_device.type == "mps" ): return _mps_sync(self_device), _mps_sync(self_device) + return _do_nothing, _do_nothing def __getstate__(self): out = copy(self.__dict__) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c965e7dedf3..e30de3534d9 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.base import NO_DEFAULT from tensordict.utils import NestedKey from torchrl._utils import ( _ends_with, @@ -3020,21 +3019,11 @@ class _EnvWrapper(EnvBase): def __init__( self, *args, - device: DEVICE_TYPING = NO_DEFAULT, + device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, **kwargs, ): - if device is NO_DEFAULT: - warnings.warn( - "Your wrapper was not given a device. Currently, this " - "value will default to 'cpu'. From v0.5 it will " - "default to `None`. With a device of None, no device casting " - "is performed and the resulting tensordicts are deviceless. 
" - "Please set your device accordingly.", - category=DeprecationWarning, - ) - device = torch.device("cpu") super().__init__( device=device, batch_size=batch_size, diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 47f93f09779..c7935272c91 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -348,8 +348,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: batch_size=tensordict.batch_size, ) if self.device is not None: - tensordict_out = tensordict_out.to(self.device, non_blocking=True) - self._sync_device() + tensordict_out = tensordict_out.to(self.device) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -393,8 +392,7 @@ def _reset( if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() if self.device is not None: - tensordict_out = tensordict_out.to(self.device, non_blocking=True) - self._sync_device() + tensordict_out = tensordict_out.to(self.device) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 07c48587c14..9195929e31d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -27,7 +27,6 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -246,8 +245,8 @@ def _gym_to_torchrl_spec_transform( ).expand(batch_size) gym_spaces = gym_backend("spaces") if isinstance(spec, gym_spaces.tuple.Tuple): - result = LazyStackedTensorSpec( - *[ + result = torch.stack( + [ _gym_to_torchrl_spec_transform( s, device=device, diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 8e30fdb2a7e..9751e84a3ac 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -795,7 +795,9 @@ def _build_env( env=vmas.make_env( scenario=scenario, num_envs=num_envs, - device=self.device, + device=self.device + if self.device is not None + else torch.get_default_device(), continuous_actions=continuous_actions, max_steps=max_steps, seed=seed, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index bec76c603e6..70aef03e041 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -39,7 +39,7 @@ unravel_key, unravel_key_list, ) -from tensordict._tensordict import _unravel_key_to_tuple +from tensordict._C import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import expand_as_right, expand_right, NestedKey from torch import nn, Tensor @@ -3411,14 +3411,7 @@ def __init__( out_keys_inv: Sequence[NestedKey] | None = None, ): if in_keys is not None and in_keys_inv is None: - warnings.warn( - "in_keys have been provided but not in_keys_inv. From v0.5, " - "this will result in in_keys_inv being an empty list whereas " - "now the input keys are retrieved automatically. 
" - "To silence this warning, pass the (possibly empty) " - "list of in_keys_inv.", - category=DeprecationWarning, - ) + in_keys_inv = [] self.dtype_in = dtype_in self.dtype_out = dtype_out diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 087cabe4186..38d8d1dfd02 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -481,9 +481,10 @@ def root_dist(self): @property def mode(self): warnings.warn( - "This computation of the mode is based on the first-order Taylor expansion " - "of the transform around the normal mean value, which can be inaccurate. " + "This computation of the mode is based on an inaccurate estimation of the mode " + "given the base_dist mode. " "To use a more stable implementation of the mode, use dist.get_mode() method instead. " + "To silence this warning, consider using the DETERMINISTIC exploration_type." "This implementation will be removed in v0.6.", category=DeprecationWarning, ) diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index fb0cc0135b8..62ccf53c30a 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -6,6 +6,8 @@ from torchrl.modules.tensordict_module.common import DistributionalDQNnet +from .batchrenorm import BatchRenorm1d + from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import ( diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py new file mode 100644 index 00000000000..26a2f9d50d2 --- /dev/null +++ b/torchrl/modules/models/batchrenorm.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import torch.nn as nn + + +class BatchRenorm1d(nn.Module): + """BatchRenorm Module (https://arxiv.org/abs/1702.03275). + + The code is adapted from https://github.com/google-research/corenet + + BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm, + it utilizes running statistics to normalize batches after an initial warmup phase. + This approach reduces the impact of "outlier" batches that may occur during + extended training periods, making BatchRenorm more robust for long training runs. + + During the warmup phase, BatchRenorm functions identically to a BatchNorm layer. + + Args: + num_features (int): Number of features in the input tensor. + + Keyword Args: + momentum (float, optional): Momentum factor for computing the running mean and variance. + Defaults to ``0.01``. + eps (float, optional): Small value added to the variance to avoid division by zero. + Defaults to ``1e-5``. + max_r (float, optional): Maximum value for the scaling factor r. + Defaults to ``3.0``. + max_d (float, optional): Maximum value for the bias factor d. + Defaults to ``5.0``. + warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. + Defaults to ``10000``. + smooth (bool, optional): if ``True``, the behaviour smoothly transitions from regular + batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``). + Otherwise, the behaviour will transition from batch-norm to batch-renorm when + ``iter=warmup_steps``. Defaults to ``False``. 
+ """ + + def __init__( + self, + num_features: int, + *, + momentum: float = 0.01, + eps: float = 1e-5, + max_r: float = 3.0, + max_d: float = 5.0, + warmup_steps: int = 10000, + smooth: bool = False, + ): + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.max_r = max_r + self.max_d = max_d + self.warmup_steps = warmup_steps + self.smooth = smooth + + self.register_buffer( + "running_mean", torch.zeros(num_features, dtype=torch.float32) + ) + self.register_buffer( + "running_var", torch.ones(num_features, dtype=torch.float32) + ) + self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64)) + self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.dim() >= 2: + raise ValueError( + f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}." + ) + + view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2) + + def _v(v): + return v.view(view_dims) + + running_std = (self.running_var + self.eps).sqrt_() + + if self.training: + reduce_dims = [i for i in range(x.dim()) if i != 1] + b_mean = x.mean(reduce_dims) + b_var = x.var(reduce_dims, unbiased=False) + b_std = (b_var + self.eps).sqrt_() + + r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r) + d = torch.clamp( + (b_mean.detach() - self.running_mean) / running_std, + -self.max_d, + self.max_d, + ) + + # Compute warmup factor (0 during warmup, 1 after warmup) + if self.warmup_steps > 0: + if self.smooth: + warmup_factor = self.num_batches_tracked / self.warmup_steps + else: + warmup_factor = self.num_batches_tracked // self.warmup_steps + r = 1.0 + (r - 1.0) * warmup_factor + d = d * warmup_factor + + x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d) + + unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1) + self.running_var += self.momentum * (unbiased_var - self.running_var) + self.running_mean += self.momentum * (b_mean.detach() - self.running_mean) + self.num_batches_tracked += 1 + self.num_batches_tracked.clamp_max(self.warmup_steps) + else: + x = (x - _v(self.running_mean)) / _v(running_std) + + x = _v(self.weight) * x + _v(self.bias) + return x diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 17b1ea77ee4..83b6a8d1fb3 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union import torch @@ -922,10 +921,9 @@ def __init__( out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, + raise RuntimeError( + "Using specs in action_space is deprecated. 
" + "Please use the 'spec' argument if you want to provide an action spec" ) action_space, _ = _process_action_space_spec(action_space, None) @@ -1136,10 +1134,9 @@ def __init__( action_mask_key: Optional[NestedKey] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, + raise RuntimeError( + "Using specs in action_space is deprecated." + "Please use the 'spec' argument if you want to provide an action spec" ) action_space, spec = _process_action_space_spec(action_space, spec) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8d2bd1d977..aa13a88c7e9 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,6 +6,7 @@ from .a2c import A2CLoss from .common import LossModule from .cql import CQLLoss, DiscreteCQLLoss +from .crossq import CrossQLoss from .ddpg import DDPGLoss from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss @@ -17,6 +18,7 @@ from .reinforce import ReinforceLoss from .sac import DiscreteSACLoss, SACLoss from .td3 import TD3Loss +from .td3_bc import TD3BCLoss from .utils import ( default_value_kwargs, distance_loss, diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py new file mode 100644 index 00000000000..22d35bd5799 --- /dev/null +++ b/torchrl/objectives/crossq.py @@ -0,0 +1,662 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import wraps +from typing import Dict, Tuple, Union + +import torch +from tensordict import TensorDict, TensorDictBase, TensorDictParams + +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey +from torch import Tensor +from torchrl.data.tensor_specs import CompositeSpec +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ProbabilisticActor +from torchrl.objectives.common import LossModule + +from torchrl.objectives.utils import ( + _cache_values, + _reduce, + _vmap_func, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +def _delezify(func): + @wraps(func) + def new_func(self, *args, **kwargs): + self.target_entropy + return func(self, *args, **kwargs) + + return new_func + + +class CrossQLoss(LossModule): + """TorchRL implementation of the CrossQ loss. + + Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING + FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX + + This class has three loss functions that will be called sequentially by the `forward` method: + :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can + be called by the user that order. + + Args: + actor_network (ProbabilisticActor): stochastic actor + qvalue_network (TensorDictModule): Q(s, a) parametric model. + This module typically outputs a ``"state_action_value"`` entry. + + Keyword Args: + num_qvalue_nets (integer, optional): number of Q-Value networks used. + Defaults to ``2``. 
+ loss_function (str, optional): loss function to be used with + the value function loss. Default is `"smooth_l1"`. + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is None (no minimum value). + max_alpha (float, optional): max value of alpha. + Default is None (no maximum value). + action_spec (TensorSpec, optional): the action tensor spec. If not provided + and the target entropy is ``"auto"``, it will be retrieved from + the actor. + fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its + initial value. Otherwise, alpha will be optimized to + match the 'target_entropy' value. + Default is ``False``. + target_entropy (float or str, optional): Target entropy for the + stochastic policy. Default is "auto", where target entropy is + computed as :obj:`-prod(n_actions)`. + priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + Tensordict key where to write the + priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, ie. gradients are propagated to shared + parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.crossq import CrossQLoss + >>> from tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = CrossQLoss(actor, qvalue) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... 
}, batch) + >>> loss(data) + TensorDict( + fields={ + alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network. + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives import CrossQLoss + >>> _ = torch.manual_seed(42) + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = CrossQLoss(actor, qvalue) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue, _, _, _ = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + + The output keys can also be filtered using the :meth:`CrossQLoss.select_out_keys` + method. + + Examples: + >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue') + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"advantage"``. 
+ state_action_value (NestedKey): The input tensordict key where the + state action value is expected. Defaults to ``"state_action_value"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. + log_prob (NestedKey): The input tensordict key where the log probability is expected. + Defaults to ``"_log_prob"``. + """ + + action: NestedKey = "action" + state_action_value: NestedKey = "state_action_value" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + log_prob: NestedKey = "_log_prob" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + + actor_network: ProbabilisticActor + actor_network_params: TensorDictParams + qvalue_network: TensorDictModule + qvalue_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + target_qvalue_network_params: TensorDictParams + + def __init__( + self, + actor_network: ProbabilisticActor, + qvalue_network: TensorDictModule, + *, + num_qvalue_nets: int = 2, + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = None, + max_alpha: float = None, + action_spec=None, + fixed_alpha: bool = False, + target_entropy: Union[str, float] = "auto", + priority_key: str = None, + separate_losses: bool = False, + reduction: str = None, + ) -> None: + self._in_keys = None + self._out_keys = None + if reduction is None: + reduction = "mean" + super().__init__() + self._set_deprecated_ctor_keys(priority_key=priority_key) + + # Actor + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + ) + if separate_losses: + # we want to make sure there are no duplicates in the params: the + # params of critic must be refs to actor if they're shared + policy_params = list(actor_network.parameters()) + else: + policy_params = None + q_value_policy_params = None + + # Q value + self.num_qvalue_nets = num_qvalue_nets + + q_value_policy_params = policy_params + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=False, + compare_against=q_value_policy_params, + ) + + self.loss_function = loss_function + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + if bool(min_alpha) ^ bool(max_alpha): + min_alpha = min_alpha if min_alpha else 0.0 + if max_alpha == 0: + raise ValueError("max_alpha must be either None or greater than 0.") + max_alpha = max_alpha if max_alpha else 1e9 + if min_alpha: + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + else: + self.min_log_alpha = None + if max_alpha: + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + else: + self.max_log_alpha = None + self.fixed_alpha = fixed_alpha + if fixed_alpha: + 
self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + self._target_entropy = target_entropy + self._action_spec = action_spec + self._vmap_qnetworkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness=self.vmap_randomness + ) + self.reduction = reduction + + @property + def target_entropy_buffer(self): + """The target entropy. + + This value can be controlled via the `target_entropy` kwarg in the constructor. + """ + return self.target_entropy + + @property + def target_entropy(self): + target_entropy = self._buffers.get("_target_entropy", None) + if target_entropy is not None: + return target_entropy + target_entropy = self._target_entropy + action_spec = self._action_spec + actor_network = self.actor_network + device = next(self.parameters()).device + if target_entropy == "auto": + action_spec = ( + action_spec + if action_spec is not None + else getattr(actor_network, "spec", None) + ) + if action_spec is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." + ) + if not isinstance(action_spec, CompositeSpec): + action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape + else: + action_container_shape = action_spec.shape + target_entropy = -float( + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() + ) + delattr(self, "_target_entropy") + self.register_buffer( + "_target_entropy", torch.tensor(target_entropy, device=device) + ) + return self._target_entropy + + state_dict = _delezify(LossModule.state_dict) + load_state_dict = _delezify(LossModule.load_state_dict) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + value=self.tensor_keys.value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, + ) + self._set_in_keys() + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + + value_net = None + hp = dict(default_value_kwargs(value_type)) + hp.update(hyperparams) + if value_type is ValueEstimators.TD1: + self._value_estimator = TD1Estimator( + **hp, + value_network=value_net, + ) + elif value_type is ValueEstimators.TD0: + self._value_estimator = TD0Estimator( + **hp, + value_network=value_net, + ) + elif value_type is ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." 
+ ) + elif value_type is ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator( + **hp, + value_network=value_net, + ) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, + } + self._value_estimator.set_keys(**tensor_keys) + + @property + def device(self) -> torch.device: + for p in self.parameters(): + return p.device + raise RuntimeError( + "At least one of the networks of SACLoss must have trainable " "parameters." + ) + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """The forward method. + + Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns + a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached). + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. + """ + shape = None + if tensordict.ndimension() > 1: + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict + + loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape) + loss_actor, metadata_actor = self.actor_loss(tensordict_reshape) + loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"]) + tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) + if loss_actor.shape != loss_qvalue.shape: + raise RuntimeError( + f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}" + ) + if shape: + tensordict.update(tensordict_reshape.view(shape)) + entropy = -metadata_actor["log_prob"] + out = { + "loss_actor": loss_actor, + "loss_qvalue": loss_qvalue, + "loss_alpha": loss_alpha, + "alpha": self._alpha, + "entropy": entropy.detach().mean(), + **metadata_actor, + **value_metadata, + } + td_out = TensorDict(out, []) + # td_out = td_out.named_apply( + # lambda name, value: ( + # _reduce(value, reduction=self.reduction) + # if name.startswith("loss_") + # else value + # ), + # batch_size=[], + # ) + return td_out + + @property + @_cache_values + def _cached_detached_qvalue_params(self): + return self.qvalue_network_params.detach() + + def actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Compute the actor loss. + + The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which + requires the `log_prob` field of the `metadata` returned by this method. + + Args: + tensordict (TensorDictBase): the input data for the loss. 
Check the class's `in_keys` to see what fields + are required for this to be computed. + + Returns: a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action. + """ + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) + a_reparm = dist.rsample() + log_prob = dist.log_prob(a_reparm) + + td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + self.qvalue_network.eval() + td_q.set(self.tensor_keys.action, a_reparm) + td_q = self._vmap_qnetworkN0( + td_q, + self._cached_detached_qvalue_params, + ) + + min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) + self.qvalue_network.train() + + if log_prob.shape != min_q.shape: + raise RuntimeError( + f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}" + ) + actor_loss = self._alpha * log_prob - min_q + return _reduce(actor_loss, reduction=self.reduction), { + "log_prob": log_prob.detach() + } + + def qvalue_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Compute the q-value loss. + + The q-value loss should be computed before the :meth:`~.actor_loss`. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing + the detached `"td_error"` to be used for prioritized sampling. + """ + # # compute next action + with torch.no_grad(): + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + next_tensordict = tensordict.get("next").clone(False) + next_dist = self.actor_network.get_dist(next_tensordict) + next_action = next_dist.sample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = next_dist.log_prob(next_action) + + combined = torch.cat( + [ + tensordict.select(*self.qvalue_network.in_keys, strict=False), + next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + ] + ) + pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( + self.tensor_keys.state_action_value + ) + (current_state_action_value, next_state_action_value) = pred_qs.split( + tensordict.batch_size[0], dim=1 + ) + + # compute target value + if ( + next_state_action_value.shape[-len(next_sample_log_prob.shape) :] + != next_sample_log_prob.shape + ): + next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + next_state_action_value = next_state_action_value.min(0)[0] + next_state_action_value = ( + next_state_action_value - self._alpha * next_sample_log_prob + ).detach() + + target_value = self.value_estimator.value_estimate( + tensordict, next_value=next_state_action_value + ).squeeze(-1) + + # get current q-values + pred_val = current_state_action_value.squeeze(-1) + + # compute loss + td_error = abs(pred_val - target_value) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).sum(0) + metadata = {"td_error": td_error.detach().max(0)[0]} + return _reduce(loss_qval, reduction=self.reduction), metadata + + def alpha_loss(self, log_prob: Tensor) -> Tensor: + """Compute the entropy loss. + + The entropy loss should be computed last. 
+ + Args: + log_prob (torch.Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`. + + Returns: a differentiable tensor with the entropy loss. + """ + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_prob) + return _reduce(alpha_loss, reduction=self.reduction) + + @property + def _alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py new file mode 100644 index 00000000000..93845bb00bd --- /dev/null +++ b/torchrl/objectives/td3_bc.py @@ -0,0 +1,571 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey +from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec + +from torchrl.envs.utils import step_mdp +from torchrl.objectives.common import LossModule + +from torchrl.objectives.utils import ( + _cache_values, + _reduce, + _vmap_func, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +class TD3BCLoss(LossModule): + r"""TD3+BC Loss Module. + + Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to + Offline Reinforcement Learning" `. + + This class incorporates two loss functions, executed sequentially within the `forward` method: + + 1. :meth:`~.qvalue_loss` + 2. :meth:`~.actor_loss` + + Users also have the option to call these functions directly in the same order if preferred. + + Args: + actor_network (TensorDictModule): the actor to be trained + qvalue_network (TensorDictModule): a single Q-value network that will + be multiplicated as many times as needed. + + Keyword Args: + bounds (tuple of float, optional): the bounds of the action space. + Exclusive with ``action_spec``. Either this or ``action_spec`` must + be provided. + action_spec (TensorSpec, optional): the action spec. + Exclusive with ``bounds``. Either this or ``bounds`` must be provided. + num_qvalue_nets (int, optional): Number of Q-value networks to be + trained. Default is ``2``. + policy_noise (float, optional): Standard deviation for the target + policy action noise. Default is ``0.2``. + noise_clip (float, optional): Clipping range value for the sampled + target policy action noise. Default is ``0.5``. + alpha (float, optional): Weight for the behavioral cloning loss. + Defaults to ``2.5``. + priority_key (str, optional): Key where to write the priority value + for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. + Can be one of ``"smooth_l1"``, ``"l2"``, + ``"l1"``, Default is ``"smooth_l1"``. + delay_actor (bool, optional): whether to separate the target actor + networks from the actor networks used for + data collection. Default is ``True``. 
+ delay_qvalue (bool, optional): Whether to separate the target Q value + networks from the Q value networks used + for data collection. Default is ``True``. + spec (TensorSpec, optional): the action tensor spec. If not provided + and the target entropy is ``"auto"``, it will be retrieved from + the actor. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, ie. gradients are propagated to shared + parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.td3_bc import TD3BCLoss + >>> from tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), + state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), + target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. 
In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "bc_loss, "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator + >>> from torchrl.objectives.td3_bc import TD3BCLoss + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue") + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_reward=torch.randn(*batch, 1), + ... next_observation=torch.randn(*batch, n_obs)) + >>> loss_actor.backward() + + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + state_action_value (NestedKey): The input tensordict key where the state action value is expected. + Will be used for the underlying value estimator. Defaults to ``"state_action_value"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
+ """ + + action: NestedKey = "action" + state_action_value: NestedKey = "state_action_value" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = [ + "loss_actor", + "loss_qvalue", + "bc_loss", + "lmbd", + "pred_value", + "state_action_value_actor", + "next_state_value", + "target_value", + ] + + actor_network: TensorDictModule + qvalue_network: TensorDictModule + actor_network_params: TensorDictParams + qvalue_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + target_qvalue_network_params: TensorDictParams + + def __init__( + self, + actor_network: TensorDictModule, + qvalue_network: TensorDictModule, + *, + action_spec: TensorSpec = None, + bounds: Optional[Tuple[float]] = None, + num_qvalue_nets: int = 2, + policy_noise: float = 0.2, + noise_clip: float = 0.5, + alpha: float = 2.5, + loss_function: str = "smooth_l1", + delay_actor: bool = True, + delay_qvalue: bool = True, + priority_key: str = None, + separate_losses: bool = False, + reduction: str = None, + ) -> None: + if reduction is None: + reduction = "mean" + super().__init__() + self._in_keys = None + self._set_deprecated_ctor_keys(priority=priority_key) + + self.delay_actor = delay_actor + self.delay_qvalue = delay_qvalue + + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + if separate_losses: + # we want to make sure there are no duplicates in the params: the + # params of critic must be refs to actor if they're shared + policy_params = list(actor_network.parameters()) + else: + policy_params = None + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=policy_params, + ) + + for p in self.parameters(): + device = p.device + break + else: + device = None + self.num_qvalue_nets = num_qvalue_nets + self.loss_function = loss_function + self.policy_noise = policy_noise + self.noise_clip = noise_clip + self.alpha = alpha + if not ((action_spec is not None) ^ (bounds is not None)): + raise ValueError( + "One of 'bounds' and 'action_spec' must be provided, " + f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." + ) + elif action_spec is not None: + if isinstance(action_spec, CompositeSpec): + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape + action_spec = action_spec[self.tensor_keys.action][ + (0,) * len(action_container_shape) + ] + if not isinstance(action_spec, BoundedTensorSpec): + raise ValueError( + f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." 
+ ) + low = action_spec.space.low + high = action_spec.space.high + else: + low, high = bounds + if not isinstance(low, torch.Tensor): + low = torch.tensor(low) + if not isinstance(high, torch.Tensor): + high = torch.tensor(high, device=low.device, dtype=low.dtype) + if (low > high).any(): + raise ValueError("Got a low bound higher than a high bound.") + if device is not None: + low = low.to(device) + high = high.to(device) + self.register_buffer("max_action", high) + self.register_buffer("min_action", low) + self._vmap_qvalue_network00 = _vmap_func( + self.qvalue_network, randomness=self.vmap_randomness + ) + self._vmap_actor_network00 = _vmap_func( + self.actor_network, randomness=self.vmap_randomness + ) + self.reduction = reduction + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + value=self._tensor_keys.state_action_value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, + ) + self._set_in_keys() + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + @_cache_values + def _cached_detach_qvalue_network_params(self): + return self.qvalue_network_params.detach() + + @property + @_cache_values + def _cached_stack_actor_params(self): + return torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + def actor_loss(self, tensordict): + """Compute the actor loss. + + The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"` + used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda + value, and the lambda value `"lmbd"` itself. 
+ """ + tensordict_actor_grad = tensordict.select( + *self.actor_network.in_keys, strict=False + ) + with self.actor_network_params.to_module(self.actor_network): + tensordict_actor_grad = self.actor_network(tensordict_actor_grad) + actor_loss_td = tensordict_actor_grad.select( + *self.qvalue_network.in_keys, strict=False + ).expand( + self.num_qvalue_nets, *tensordict_actor_grad.batch_size + ) # for actor loss + state_action_value_actor = ( + self._vmap_qvalue_network00( + actor_loss_td, + self._cached_detach_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + bc_loss = torch.nn.functional.mse_loss( + tensordict_actor_grad.get(self.tensor_keys.action), + tensordict.get(self.tensor_keys.action), + ) + lmbd = self.alpha / state_action_value_actor[0].abs().mean().detach() + + loss_actor = -lmbd * state_action_value_actor[0] + bc_loss + + metadata = { + "state_action_value_actor": state_action_value_actor[0].detach(), + "bc_loss": bc_loss.detach(), + "lmbd": lmbd, + } + loss_actor = _reduce(loss_actor, reduction=self.reduction) + return loss_actor, metadata + + def qvalue_loss(self, tensordict): + """Compute the q-value loss. + + The q-value loss should be computed before the :meth:`~.actor_loss`. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing + the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`. + """ + tensordict = tensordict.clone(False) + + act = tensordict.get(self.tensor_keys.action) + + # computing early for reprod + noise = (torch.randn_like(act) * self.policy_noise).clamp( + -self.noise_clip, self.noise_clip + ) + + with torch.no_grad(): + next_td_actor = step_mdp(tensordict).select( + *self.actor_network.in_keys, strict=False + ) # next_observation -> + with self.target_actor_network_params.to_module(self.actor_network): + next_td_actor = self.actor_network(next_td_actor) + next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp( + self.min_action, self.max_action + ) + next_td_actor.set( + self.tensor_keys.action, + next_action, + ) + next_val_td = next_td_actor.select( + *self.qvalue_network.in_keys, strict=False + ).expand( + self.num_qvalue_nets, *next_td_actor.batch_size + ) # for next value estimation + next_target_q1q2 = ( + self._vmap_qvalue_network00( + next_val_td, + self.target_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + # min over the next target qvalues + next_target_qvalue = next_target_q1q2.min(0)[0] + + # set next target qvalues + tensordict.set( + ("next", self.tensor_keys.state_action_value), + next_target_qvalue.unsqueeze(-1), + ) + + qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand( + self.num_qvalue_nets, + *tensordict.batch_size, + ) + # preditcted current qvalues + current_qvalue = ( + self._vmap_qvalue_network00( + qval_td, + self.qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done)) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + + td_error = (current_qvalue - target_value).pow(2) + loss_qval = distance_loss( + current_qvalue, + 
target_value.expand_as(current_qvalue), + loss_function=self.loss_function, + ).sum(0) + metadata = { + "td_error": td_error, + "next_state_value": next_target_qvalue.detach(), + "pred_value": current_qvalue.detach(), + "target_value": target_value.detach(), + } + loss_qval = _reduce(loss_qval, reduction=self.reduction) + return loss_qval, metadata + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """The forward method. + + Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns + a tensordict with these values. + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. + """ + tensordict_save = tensordict + loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.qvalue_loss(tensordict_save) + tensordict_save.set( + self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] + ) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + source={ + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, + **metadata_actor, + **metadata_value, + }, + batch_size=[], + ) + return td_out + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + hp = dict(default_value_kwargs(value_type)) + if hasattr(self, "gamma"): + hp["gamma"] = self.gamma + hp.update(hyperparams) + # we do not need a value network bc the next state value is already passed + if value_type == ValueEstimators.TD1: + self._value_estimator = TD1Estimator(value_network=None, **hp) + elif value_type == ValueEstimators.TD0: + self._value_estimator = TD0Estimator(value_network=None, **hp) + elif value_type == ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) + elif value_type == ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator(value_network=None, **hp) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "value": self.tensor_keys.state_action_value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, + } + self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 2c2f3fb21ac..b7fb8ab4ed2 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -221,11 +221,11 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation_trsf = make_grid( obs_flat, nrow=int(math.ceil(math.sqrt(obs_flat.shape[0]))) ) - self.obs.append(observation_trsf.to(torch.uint8)) + self.obs.append(observation_trsf.to("cpu", torch.uint8)) elif observation_trsf.ndimension() >= 4: - self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4)) + self.obs.extend(observation_trsf.to("cpu", torch.uint8).flatten(0, -4)) else: - self.obs.append(observation_trsf.to(torch.uint8)) + self.obs.append(observation_trsf.to("cpu", torch.uint8)) return observation def forward(self, tensordict: TensorDictBase) -> TensorDictBase: diff --git a/version.txt b/version.txt index 1d0ba9ea182..8f0916f768f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.0 +0.5.0
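
Editor's note (not part of the patch): the CrossQLoss docstring above states that forward() evaluates qvalue_loss, actor_loss and alpha_loss in that order, and that the three components can also be called individually. The sketch below is a hedged illustration of that component-wise usage, reusing the modules from the docstring examples (NormalParamWrapper, SafeModule, ProbabilisticActor, ValueOperator). The optimizer setup, learning rate and dummy data are illustrative assumptions, not taken from this PR.

# Hedged sketch: stepping CrossQLoss component by component, in the documented
# order qvalue_loss -> actor_loss -> alpha_loss. Optimizers, learning rates and
# the random data batch are assumptions made for illustration only.
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.data import BoundedTensorSpec
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives import CrossQLoss

n_act, n_obs, batch = 4, 3, 8
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = ProbabilisticActor(
    module=SafeModule(
        NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)

class QValue(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)
    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(module=QValue(), in_keys=["observation", "action"])
loss = CrossQLoss(actor, qvalue)  # this loss creates no target networks

data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": spec.rand([batch]),
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    [batch],
)

optim_q = torch.optim.Adam(list(loss.qvalue_network_params.flatten_keys().values()), lr=3e-4)
optim_actor = torch.optim.Adam(list(loss.actor_network_params.flatten_keys().values()), lr=3e-4)
optim_alpha = torch.optim.Adam([loss.log_alpha], lr=3e-4)

# 1) critic update
loss_q, q_meta = loss.qvalue_loss(data)
optim_q.zero_grad()
loss_q.backward()
optim_q.step()

# 2) actor update (the loss uses detached q-value parameters internally)
loss_actor, actor_meta = loss.actor_loss(data)
optim_actor.zero_grad()
loss_actor.backward()
optim_actor.step()

# 3) temperature update, fed with the log-prob returned by actor_loss
loss_alpha = loss.alpha_loss(log_prob=actor_meta["log_prob"])
optim_alpha.zero_grad()
loss_alpha.backward()
optim_alpha.step()

Since CrossQLoss is built with create_target_params=False for both the actor and the q-value networks, no target-network update step appears in this sketch; the batch statistics of the critic stand in for the usual target machinery.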
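
Editor's note (not part of the patch): the TD3BCLoss docstrings above note that qvalue_loss should run before actor_loss and that the actor update is usually delayed by a few critic updates. Below is a hedged sketch of an offline update loop built on that API. The dummy dataset sampler, the delay of 2, the learning rates and the SoftUpdate coefficient are illustrative assumptions; target networks themselves are created by the loss by default (delay_actor=True, delay_qvalue=True).

# Hedged sketch of an offline TD3+BC update loop using TD3BCLoss. The sampler,
# update delay and hyperparameters are assumptions, not values from the patch.
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.data import BoundedTensorSpec
from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
from torchrl.objectives import SoftUpdate
from torchrl.objectives.td3_bc import TD3BCLoss

n_act, n_obs, batch = 4, 3, 8
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(module=nn.Linear(n_obs, n_act), spec=spec)

class QValue(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)
    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(module=QValue(), in_keys=["observation", "action"])
loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
target_updater = SoftUpdate(loss, eps=0.995)

optim_q = torch.optim.Adam(list(loss.qvalue_network_params.flatten_keys().values()), lr=3e-4)
optim_actor = torch.optim.Adam(list(loss.actor_network_params.flatten_keys().values()), lr=3e-4)

def sample_batch():
    # stand-in for an offline dataset / replay buffer sample (assumption)
    return TensorDict(
        {
            "observation": torch.randn(batch, n_obs),
            "action": spec.rand([batch]),
            ("next", "observation"): torch.randn(batch, n_obs),
            ("next", "reward"): torch.randn(batch, 1),
            ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
            ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
        },
        [batch],
    )

for step in range(6):
    data = sample_batch()
    # critic update on every step
    loss_q, _ = loss.qvalue_loss(data)
    optim_q.zero_grad()
    loss_q.backward()
    optim_q.step()
    # delayed actor update (every second critic update here, an assumed choice)
    if step % 2 == 0:
        loss_actor, _ = loss.actor_loss(data)
        optim_actor.zero_grad()
        loss_actor.backward()
        optim_actor.step()
    # move the target networks towards the online ones
    target_updater.step()

Calling the two component losses separately, as above, is what allows the actor update to lag the critic updates; calling loss(data) instead computes both losses in a single pass, as shown in the docstring examples.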