From 35663887d92419ec7eb326ec5ee4a1894b37d1d8 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 22:00:15 +0200 Subject: [PATCH] [Algorithm] TD3+BC (#2249) Co-authored-by: Vincent Moens Co-authored-by: Vincent Moens --- .../linux_examples/scripts/run_test.sh | 4 +- README.md | 1 + docs/source/reference/objectives.rst | 9 + sota-check/run_td3bc.sh | 26 + sota-check/submitit-release-check.sh | 1 + sota-implementations/td3_bc/config.yaml | 45 ++ sota-implementations/td3_bc/td3_bc.py | 146 ++++ sota-implementations/td3_bc/utils.py | 257 ++++++ test/test_cost.py | 742 +++++++++++++++++- torchrl/objectives/__init__.py | 1 + torchrl/objectives/td3_bc.py | 571 ++++++++++++++ 11 files changed, 1793 insertions(+), 10 deletions(-) create mode 100644 sota-check/run_td3bc.sh create mode 100644 sota-implementations/td3_bc/config.yaml create mode 100644 sota-implementations/td3_bc/td3_bc.py create mode 100644 sota-implementations/td3_bc/utils.py create mode 100644 torchrl/objectives/td3_bc.py diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 075489b208d..a984b37faa9 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -51,10 +51,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ logger.backend= - # ==================================================================================== # # ================================ Gymnasium ========================================= # +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \ + optim.gradient_steps=55 \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \ collector.total_frames=80 \ collector.frames_per_batch=20 \ diff --git a/README.md b/README.md index 5ac72ff052e..2b250dac540 100644 --- a/README.md +++ b/README.md @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr - [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py) - [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py) - [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py) +- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py) - [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py) - [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py) - [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index c2f43d8e9b6..ef9bc1ee907 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -160,6 +160,15 @@ TD3 TD3Loss +TD3+BC +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + TD3BCLoss + PPO --- diff --git a/sota-check/run_td3bc.sh b/sota-check/run_td3bc.sh new file mode 100644 index 00000000000..0fefb3ecd6f --- /dev/null +++ b/sota-check/run_td3bc.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=td3bc_offline +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/td3bc_offline_%j.txt +#SBATCH --error=slurm_errors/td3bc_offline_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="td3bc_offline" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log +fi diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh index cad2783c653..515ac06a50b 100755 --- a/sota-check/submitit-release-check.sh +++ b/sota-check/submitit-release-check.sh @@ -65,6 +65,7 @@ scripts=( run_ppo_mujoco.sh run_sac.sh run_td3.sh + run_td3bc.sh run_dt.sh run_dt_online.sh ) diff --git a/sota-implementations/td3_bc/config.yaml b/sota-implementations/td3_bc/config.yaml new file mode 100644 index 00000000000..54275a94bc2 --- /dev/null +++ b/sota-implementations/td3_bc/config.yaml @@ -0,0 +1,45 @@ +# task and env +env: + name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency + task: "" + library: gymnasium + seed: 42 + max_episode_steps: 1000 + +# replay buffer +replay_buffer: + dataset: halfcheetah-medium-v2 + batch_size: 256 + +# optim +optim: + gradient_steps: 100000 + gamma: 0.99 + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 + adam_eps: 1e-4 + batch_size: 256 + target_update_polyak: 0.995 + policy_update_delay: 2 + policy_noise: 0.2 + noise_clip: 0.5 + alpha: 2.5 + +# network +network: + hidden_sizes: [256, 256] + activation: relu + device: null + +# logging +logger: + backend: wandb + project_name: td3+bc_${replay_buffer.dataset} + group_name: null + exp_name: TD3+BC_${replay_buffer.dataset} + mode: online + eval_iter: 5000 + eval_steps: 1000 + eval_envs: 1 + video: False diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py new file mode 100644 index 00000000000..7c43fdc1a12 --- /dev/null +++ b/sota-implementations/td3_bc/td3_bc.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""TD3+BC Example. + +This is a self-contained example of an offline RL TD3+BC training script. + +The helper functions are coded in the utils.py associated with this script. + +""" +import time + +import hydra +import numpy as np +import torch +import tqdm +from torchrl._utils import logger as torchrl_logger + +from torchrl.envs import set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.record.loggers import generate_exp_name, get_logger + +from utils import ( + dump_video, + log_metrics, + make_environment, + make_loss_module, + make_offline_replay_buffer, + make_optimizer, + make_td3_agent, +) + + +@hydra.main(config_path="", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + set_gym_backend(cfg.env.library).set() + + # Create logger + exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="td3bc_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) + + # Creante env + eval_env = make_environment( + cfg, + logger=logger, + ) + + # Create replay buffer + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + + # Create agent + model, _ = make_td3_agent(cfg, eval_env, device) + + # Create loss + loss_module, target_net_updater = make_loss_module(cfg.optim, model) + + # Create optimizer + optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module) + + gradient_steps = cfg.optim.gradient_steps + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + delayed_updates = cfg.optim.policy_update_delay + update_counter = 0 + pbar = tqdm.tqdm(range(gradient_steps)) + # Training loop + start_time = time.time() + for i in pbar: + pbar.update(1) + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to(device) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) + + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_loss.item() + + to_log = {"q_loss": q_loss.item()} + + # Update actor + if update_actor: + actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict) + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # Update target params + target_net_updater.step() + + to_log["actor_loss"] = actor_loss.item() + to_log.update(actorloss_metadata) + + # evaluation + if i % evaluation_interval == 0: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_td = eval_env.rollout( + max_steps=eval_steps, policy=model[0], auto_cast_to_device=True + ) + eval_env.apply(dump_video) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward + if logger is not None: + log_metrics(logger, to_log, i) + + pbar.close() + torchrl_logger.info(f"Training time: {time.time() - start_time}") + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py new file mode 100644 index 00000000000..3772eefccde --- /dev/null +++ b/sota-implementations/td3_bc/utils.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import functools + +import torch + +from torch import nn, optim +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + InitTracker, + ParallelEnv, + RewardSum, + StepCounter, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ( + AdditiveGaussianWrapper, + MLP, + SafeModule, + SafeSequential, + TanhModule, + ValueOperator, +) + +from torchrl.objectives import SoftUpdate +from torchrl.objectives.td3_bc import TD3BCLoss +from torchrl.record import VideoRecorder + + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu", from_pixels=False): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + from_pixels=from_pixels, + pixels_only=False, + ) + elif lib == "dm_control": + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps): + transformed_env = TransformedEnv( + env, + Compose( + StepCounter(max_steps=max_episode_steps), + InitTracker(), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg, logger=None): + """Make environments for training and evaluation.""" + partial = functools.partial(env_maker, cfg=cfg) + parallel_env = ParallelEnv( + cfg.logger.eval_envs, + EnvCreator(partial), + serial_for_single=True, + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + return train_env + + +# ==================================================================== +# Replay buffer +# --------------------------- + + +def make_offline_replay_buffer(rb_cfg): + data = D4RLExperienceReplay( + dataset_id=rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, + ) + + data.append_transform(DoubleToFloat()) + + return data + + +# ==================================================================== +# Model +# ----- + + +def make_td3_agent(cfg, train_env, device): + """Make TD3 agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": get_activation(cfg), + } + + actor_net = MLP(**actor_net_kwargs) + + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "param", + ], + ) + actor = SafeSequential( + actor_module, + TanhModule( + in_keys=["param"], + out_keys=["action"], + spec=action_spec, + ), + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": 1, + "activation_class": get_activation(cfg), + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = train_env.fake_tensordict() + td = td.to(device) + for net in model: + net(td) + del td + + # Exploration wrappers: + actor_model_explore = AdditiveGaussianWrapper( + model[0], + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device) + return model, actor_model_explore + + +# ==================================================================== +# TD3 Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create TD3 loss + loss_module = TD3BCLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + loss_function=cfg.loss_function, + delay_actor=True, + delay_qvalue=True, + action_spec=model[0][1].spec, + policy_noise=cfg.policy_noise, + noise_clip=cfg.noise_clip, + alpha=cfg.alpha, + ) + loss_module.make_value_estimator(gamma=cfg.gamma) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) + return loss_module, target_net_updater + + +def make_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.values(True, True)) + actor_params = list(loss_module.actor_network_params.values(True, True)) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + eps=cfg.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + eps=cfg.adam_eps, + ) + return optimizer_actor, optimizer_critic + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/test/test_cost.py b/test/test_cost.py index 76fc4e651f4..2f187c8e3ba 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -114,6 +114,7 @@ PPOLoss, QMixerLoss, SACLoss, + TD3BCLoss, TD3Loss, ) from torchrl.objectives.common import LossModule @@ -261,9 +262,9 @@ def __init__(self): self.vmap_model = _vmap_func( self.model, (None, 0), - randomness="error" - if vmap_randomness == "error" - else self.vmap_randomness, + randomness=( + "error" if vmap_randomness == "error" else self.vmap_randomness + ), ) def forward(self, td): @@ -319,9 +320,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], @@ -2714,6 +2715,729 @@ def test_td3_reduction(self, reduction): assert loss[key].shape == torch.Size([]) +class TestTD3BC(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + in_keys=None, + out_keys=None, + dropout=0.0, + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = nn.Sequential( + nn.Linear(obs_dim, obs_dim), + nn.Dropout(dropout), + nn.Linear(obs_dim, action_dim), + ) + actor = Actor( + spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys + ) + return actor.to(device) + + def _create_mock_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + action_key="action", + observation_key="observation", + ): + # Actor + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + value = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + value = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + return_log_prob=True, + ), + ) + value_head = Mod( + value, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + value = Seq(common, value_head) + return actor, value, common, td + + def _create_mock_data_td3bc( + self, + batch=8, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + action_key="action", + observation_key="observation", + reward_key="reward", + done_key="done", + terminated_key="terminated", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_td3bc( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next": { + "observation": next_obs * mask.to(obs.dtype), + "reward": reward * mask.to(obs.dtype), + "done": done, + "terminated": terminated, + }, + "collector": {"mask": mask}, + "action": action * mask.to(obs.dtype), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) + @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) + @pytest.mark.parametrize("alpha", [0.1, 6.0]) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + @pytest.mark.parametrize("dropout", [0.0, 0.1]) + def test_td3bc( + self, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + alpha, + td_est, + use_action_spec, + dropout, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device, dropout=dropout) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3bc(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None + else: + bounds = (-1, 1) + action_spec = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with ( + pytest.warns( + UserWarning, + match="No target network updater has been associated with this loss module", + ) + if (delay_actor or delay_qvalue) + else contextlib.nullcontext() + ): + with _check_td_steady(td): + loss = loss_fn(td) + + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1]) + @pytest.mark.parametrize("noise_clip", [0.1]) + @pytest.mark.parametrize("alpha", [0.1]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + def test_td3bc_state_dict( + self, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + alpha, + use_action_spec, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None + else: + bounds = (-1, 1) + action_spec = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + sd = loss_fn.state_dict() + loss_fn2 = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_td3bc_separate_losses( + self, + device, + separate_losses, + n_act=4, + ): + torch.manual_seed(self.seed) + actor, value, common, td = self._create_mock_common_layer_setup(n_act=n_act) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + loss_function="l2", + separate_losses=separate_losses, + ) + with pytest.warns(UserWarning, match="No target network updater has been"): + loss = loss_fn(td) + + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + if separate_losses: + common_layers_no = len(list(common.parameters())) + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("n", range(1, 4)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)]) + @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) + @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) + @pytest.mark.parametrize("alpha", [0.1, 6.0]) + def test_td3bc_batcher( + self, + n, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + alpha, + gamma=0.9, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_seq_mock_data_td3bc(device=device) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_qvalue=delay_qvalue, + delay_actor=delay_actor, + ) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if (delay_qvalue or delay_actor) + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + if delay_qvalue or delay_actor: + SoftUpdate(loss_fn, eps=0.5) + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + actorp_set = set(actor.parameters()) + loss_fnp_set = set(loss_fn.parameters()) + assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_td3bc_tensordict_keys(self, td_est): + actor = self._create_mock_actor() + value = self._create_mock_value() + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value(out_keys=["state_action_value_test"]) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + ) + key_mapping = { + "state_action_value": ("value", "state_action_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("spec", [True, False]) + @pytest.mark.parametrize("bounds", [True, False]) + def test_constructor(self, spec, bounds): + actor = self._create_mock_actor() + value = self._create_mock_value() + action_spec = actor.spec if spec else None + bounds = (-1, 1) if bounds else None + if (bounds is not None and action_spec is not None) or ( + bounds is None and action_spec is None + ): + with pytest.raises(ValueError, match="but not both"): + TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + ) + return + TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + ) + + # TODO: test for action_key, atm the action key of the TD3+BC loss is not configurable, + # since it is used in it's constructor + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_td3bc_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(in_keys=[observation_key]) + qvalue = self._create_mock_value( + observation_key=observation_key, out_keys=["state_action_value"] + ) + td = self._create_mock_data_td3bc( + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) + + kwargs = { + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + "action": td.get("action"), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + with pytest.warns(UserWarning, match="No target network updater has been"): + torch.manual_seed(0) + loss_val_td = loss(td) + torch.manual_seed(0) + loss_val = loss(**kwargs) + loss_val_reconstruct = TensorDict(dict(zip(loss.out_keys, loss_val)), []) + assert_allclose_td(loss_val_reconstruct, loss_val_td) + + # test select + loss.select_out_keys("loss_actor", "loss_qvalue") + torch.manual_seed(0) + if torch.__version__ >= "2.0.0": + loss_actor, loss_qvalue = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_qvalue = loss(**kwargs) + return + + assert loss_actor == loss_val_td["loss_actor"] + assert loss_qvalue == loss_val_td["loss_qvalue"] + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_td3bc_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3bc(device=device) + action_spec = actor.spec + bounds = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + delay_qvalue=False, + delay_actor=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + if not key.startswith("loss"): + continue + assert loss[key].shape == torch.Size([]) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) @@ -5686,9 +6410,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8d2bd1d977..674c06123ad 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -17,6 +17,7 @@ from .reinforce import ReinforceLoss from .sac import DiscreteSACLoss, SACLoss from .td3 import TD3Loss +from .td3_bc import TD3BCLoss from .utils import ( default_value_kwargs, distance_loss, diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py new file mode 100644 index 00000000000..93845bb00bd --- /dev/null +++ b/torchrl/objectives/td3_bc.py @@ -0,0 +1,571 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey +from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec + +from torchrl.envs.utils import step_mdp +from torchrl.objectives.common import LossModule + +from torchrl.objectives.utils import ( + _cache_values, + _reduce, + _vmap_func, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +class TD3BCLoss(LossModule): + r"""TD3+BC Loss Module. + + Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to + Offline Reinforcement Learning" `. + + This class incorporates two loss functions, executed sequentially within the `forward` method: + + 1. :meth:`~.qvalue_loss` + 2. :meth:`~.actor_loss` + + Users also have the option to call these functions directly in the same order if preferred. + + Args: + actor_network (TensorDictModule): the actor to be trained + qvalue_network (TensorDictModule): a single Q-value network that will + be multiplicated as many times as needed. + + Keyword Args: + bounds (tuple of float, optional): the bounds of the action space. + Exclusive with ``action_spec``. Either this or ``action_spec`` must + be provided. + action_spec (TensorSpec, optional): the action spec. + Exclusive with ``bounds``. Either this or ``bounds`` must be provided. + num_qvalue_nets (int, optional): Number of Q-value networks to be + trained. Default is ``2``. + policy_noise (float, optional): Standard deviation for the target + policy action noise. Default is ``0.2``. + noise_clip (float, optional): Clipping range value for the sampled + target policy action noise. Default is ``0.5``. + alpha (float, optional): Weight for the behavioral cloning loss. + Defaults to ``2.5``. + priority_key (str, optional): Key where to write the priority value + for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. + Can be one of ``"smooth_l1"``, ``"l2"``, + ``"l1"``, Default is ``"smooth_l1"``. + delay_actor (bool, optional): whether to separate the target actor + networks from the actor networks used for + data collection. Default is ``True``. + delay_qvalue (bool, optional): Whether to separate the target Q value + networks from the Q value networks used + for data collection. Default is ``True``. + spec (TensorSpec, optional): the action tensor spec. If not provided + and the target entropy is ``"auto"``, it will be retrieved from + the actor. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, ie. gradients are propagated to shared + parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.td3_bc import TD3BCLoss + >>> from tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), + state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), + target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "bc_loss, "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator + >>> from torchrl.objectives.td3_bc import TD3BCLoss + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue") + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_reward=torch.randn(*batch, 1), + ... next_observation=torch.randn(*batch, n_obs)) + >>> loss_actor.backward() + + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + state_action_value (NestedKey): The input tensordict key where the state action value is expected. + Will be used for the underlying value estimator. Defaults to ``"state_action_value"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. + """ + + action: NestedKey = "action" + state_action_value: NestedKey = "state_action_value" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = [ + "loss_actor", + "loss_qvalue", + "bc_loss", + "lmbd", + "pred_value", + "state_action_value_actor", + "next_state_value", + "target_value", + ] + + actor_network: TensorDictModule + qvalue_network: TensorDictModule + actor_network_params: TensorDictParams + qvalue_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + target_qvalue_network_params: TensorDictParams + + def __init__( + self, + actor_network: TensorDictModule, + qvalue_network: TensorDictModule, + *, + action_spec: TensorSpec = None, + bounds: Optional[Tuple[float]] = None, + num_qvalue_nets: int = 2, + policy_noise: float = 0.2, + noise_clip: float = 0.5, + alpha: float = 2.5, + loss_function: str = "smooth_l1", + delay_actor: bool = True, + delay_qvalue: bool = True, + priority_key: str = None, + separate_losses: bool = False, + reduction: str = None, + ) -> None: + if reduction is None: + reduction = "mean" + super().__init__() + self._in_keys = None + self._set_deprecated_ctor_keys(priority=priority_key) + + self.delay_actor = delay_actor + self.delay_qvalue = delay_qvalue + + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + if separate_losses: + # we want to make sure there are no duplicates in the params: the + # params of critic must be refs to actor if they're shared + policy_params = list(actor_network.parameters()) + else: + policy_params = None + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=policy_params, + ) + + for p in self.parameters(): + device = p.device + break + else: + device = None + self.num_qvalue_nets = num_qvalue_nets + self.loss_function = loss_function + self.policy_noise = policy_noise + self.noise_clip = noise_clip + self.alpha = alpha + if not ((action_spec is not None) ^ (bounds is not None)): + raise ValueError( + "One of 'bounds' and 'action_spec' must be provided, " + f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." + ) + elif action_spec is not None: + if isinstance(action_spec, CompositeSpec): + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape + action_spec = action_spec[self.tensor_keys.action][ + (0,) * len(action_container_shape) + ] + if not isinstance(action_spec, BoundedTensorSpec): + raise ValueError( + f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." + ) + low = action_spec.space.low + high = action_spec.space.high + else: + low, high = bounds + if not isinstance(low, torch.Tensor): + low = torch.tensor(low) + if not isinstance(high, torch.Tensor): + high = torch.tensor(high, device=low.device, dtype=low.dtype) + if (low > high).any(): + raise ValueError("Got a low bound higher than a high bound.") + if device is not None: + low = low.to(device) + high = high.to(device) + self.register_buffer("max_action", high) + self.register_buffer("min_action", low) + self._vmap_qvalue_network00 = _vmap_func( + self.qvalue_network, randomness=self.vmap_randomness + ) + self._vmap_actor_network00 = _vmap_func( + self.actor_network, randomness=self.vmap_randomness + ) + self.reduction = reduction + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + value=self._tensor_keys.state_action_value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, + ) + self._set_in_keys() + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + @_cache_values + def _cached_detach_qvalue_network_params(self): + return self.qvalue_network_params.detach() + + @property + @_cache_values + def _cached_stack_actor_params(self): + return torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + def actor_loss(self, tensordict): + """Compute the actor loss. + + The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"` + used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda + value, and the lambda value `"lmbd"` itself. + """ + tensordict_actor_grad = tensordict.select( + *self.actor_network.in_keys, strict=False + ) + with self.actor_network_params.to_module(self.actor_network): + tensordict_actor_grad = self.actor_network(tensordict_actor_grad) + actor_loss_td = tensordict_actor_grad.select( + *self.qvalue_network.in_keys, strict=False + ).expand( + self.num_qvalue_nets, *tensordict_actor_grad.batch_size + ) # for actor loss + state_action_value_actor = ( + self._vmap_qvalue_network00( + actor_loss_td, + self._cached_detach_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + bc_loss = torch.nn.functional.mse_loss( + tensordict_actor_grad.get(self.tensor_keys.action), + tensordict.get(self.tensor_keys.action), + ) + lmbd = self.alpha / state_action_value_actor[0].abs().mean().detach() + + loss_actor = -lmbd * state_action_value_actor[0] + bc_loss + + metadata = { + "state_action_value_actor": state_action_value_actor[0].detach(), + "bc_loss": bc_loss.detach(), + "lmbd": lmbd, + } + loss_actor = _reduce(loss_actor, reduction=self.reduction) + return loss_actor, metadata + + def qvalue_loss(self, tensordict): + """Compute the q-value loss. + + The q-value loss should be computed before the :meth:`~.actor_loss`. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing + the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`. + """ + tensordict = tensordict.clone(False) + + act = tensordict.get(self.tensor_keys.action) + + # computing early for reprod + noise = (torch.randn_like(act) * self.policy_noise).clamp( + -self.noise_clip, self.noise_clip + ) + + with torch.no_grad(): + next_td_actor = step_mdp(tensordict).select( + *self.actor_network.in_keys, strict=False + ) # next_observation -> + with self.target_actor_network_params.to_module(self.actor_network): + next_td_actor = self.actor_network(next_td_actor) + next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp( + self.min_action, self.max_action + ) + next_td_actor.set( + self.tensor_keys.action, + next_action, + ) + next_val_td = next_td_actor.select( + *self.qvalue_network.in_keys, strict=False + ).expand( + self.num_qvalue_nets, *next_td_actor.batch_size + ) # for next value estimation + next_target_q1q2 = ( + self._vmap_qvalue_network00( + next_val_td, + self.target_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + # min over the next target qvalues + next_target_qvalue = next_target_q1q2.min(0)[0] + + # set next target qvalues + tensordict.set( + ("next", self.tensor_keys.state_action_value), + next_target_qvalue.unsqueeze(-1), + ) + + qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand( + self.num_qvalue_nets, + *tensordict.batch_size, + ) + # preditcted current qvalues + current_qvalue = ( + self._vmap_qvalue_network00( + qval_td, + self.qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done)) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + + td_error = (current_qvalue - target_value).pow(2) + loss_qval = distance_loss( + current_qvalue, + target_value.expand_as(current_qvalue), + loss_function=self.loss_function, + ).sum(0) + metadata = { + "td_error": td_error, + "next_state_value": next_target_qvalue.detach(), + "pred_value": current_qvalue.detach(), + "target_value": target_value.detach(), + } + loss_qval = _reduce(loss_qval, reduction=self.reduction) + return loss_qval, metadata + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """The forward method. + + Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns + a tensordict with these values. + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. + """ + tensordict_save = tensordict + loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.qvalue_loss(tensordict_save) + tensordict_save.set( + self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] + ) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + source={ + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, + **metadata_actor, + **metadata_value, + }, + batch_size=[], + ) + return td_out + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + hp = dict(default_value_kwargs(value_type)) + if hasattr(self, "gamma"): + hp["gamma"] = self.gamma + hp.update(hyperparams) + # we do not need a value network bc the next state value is already passed + if value_type == ValueEstimators.TD1: + self._value_estimator = TD1Estimator(value_network=None, **hp) + elif value_type == ValueEstimators.TD0: + self._value_estimator = TD0Estimator(value_network=None, **hp) + elif value_type == ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) + elif value_type == ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator(value_network=None, **hp) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "value": self.tensor_keys.state_action_value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, + } + self._value_estimator.set_keys(**tensor_keys)