diff --git a/.github/scripts/m1_script.sh b/.github/scripts/td_script.sh
similarity index 71%
rename from .github/scripts/m1_script.sh
rename to .github/scripts/td_script.sh
index 6552d8e4622..6da1cad5d79 100644
--- a/.github/scripts/m1_script.sh
+++ b/.github/scripts/td_script.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-export TORCHRL_BUILD_VERSION=0.4.0
+export TORCHRL_BUILD_VERSION=0.5.0
${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh
index 075489b208d..f8b700c0410 100755
--- a/.github/unittest/linux_examples/scripts/run_test.sh
+++ b/.github/unittest/linux_examples/scripts/run_test.sh
@@ -9,9 +9,19 @@
#
#
-set -e
+#set -e
set -v
+# Initialize an error flag
+error_occurred=0
+# Function to handle errors
+error_handler() {
+ echo "Error on line $1"
+ error_occurred=1
+}
+# Trap ERR to call the error_handler function with the failing line number
+trap 'error_handler $LINENO' ERR
+
export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
@@ -24,6 +34,7 @@ lib_dir="${env_dir}/lib"
# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
+export CUDA_LAUNCH_BLOCKING=1
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
#python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200
@@ -51,10 +62,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
logger.backend=
-
# ==================================================================================== #
# ================================ Gymnasium ========================================= #
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \
+ optim.gradient_steps=55 \
+ logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
@@ -149,18 +162,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
replay_buffer.size=120 \
env.name=CartPole-v1 \
logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
- collector.total_frames=200 \
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \
+ collector.total_frames=48 \
collector.init_random_frames=10 \
- collector.frames_per_batch=200 \
- env.n_parallel_envs=4 \
- optimization.optim_steps_per_batch=1 \
- logger.video=True \
- logger.backend=csv \
- replay_buffer.buffer_size=120 \
- replay_buffer.batch_size=24 \
- replay_buffer.batch_length=12 \
- networks.rssm_hidden_dim=17
+ collector.frames_per_batch=16 \
+ collector.env_per_collector=2 \
+ collector.device= \
+ optim.batch_size=10 \
+ optim.utd_ratio=1 \
+ replay_buffer.size=120 \
+ env.name=Pendulum-v1 \
+ network.device= \
+ logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -200,8 +213,8 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
- logger.backend=csv \
logger.video=True \
+ logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
@@ -298,3 +311,11 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ba
coverage combine
coverage xml -i
+
+# Check if any errors occurred during the script execution
+if [ "$error_occurred" -ne 0 ]; then
+ echo "Errors occurred during script execution"
+ exit 1
+else
+ echo "Script executed successfully"
+fi
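
Note: the hunk above replaces fail-fast behaviour (`set -e`) with an ERR trap that records failures and reports them once at the end, so every example script gets a chance to run. A minimal Python analogue of the same accumulate-then-fail pattern, with hypothetical placeholder commands (not part of the PR):

```python
# Run every step, remember whether anything failed, and exit non-zero at the end.
import subprocess
import sys

commands = [  # hypothetical placeholders for the example scripts
    ["python", "-c", "print('example 1')"],
    ["python", "-c", "print('example 2')"],
]

error_occurred = False
for cmd in commands:
    if subprocess.run(cmd).returncode != 0:
        print(f"Error while running {cmd}")
        error_occurred = True

sys.exit(1 if error_occurred else 0)
```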
diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml
index 5171a7c3e2a..f51c5ed79b6 100644
--- a/.github/workflows/build-wheels-linux.yml
+++ b/.github/workflows/build-wheels-linux.yml
@@ -45,3 +45,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
+ env-var-script: .github/scripts/td_script.sh
diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml
index 84fe79d09d2..73a365a79f2 100644
--- a/.github/workflows/build-wheels-m1.yml
+++ b/.github/workflows/build-wheels-m1.yml
@@ -46,4 +46,4 @@ jobs:
runner-type: macos-m1-stable
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
- env-var-script: .github/scripts/m1_script.sh
+ env-var-script: .github/scripts/td_script.sh
diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml
index 683f2a93f69..1beef7318f4 100644
--- a/.github/workflows/build-wheels-windows.yml
+++ b/.github/workflows/build-wheels-windows.yml
@@ -46,3 +46,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
+ env-var-script: .github/scripts/td_script.sh
diff --git a/README.md b/README.md
index 5ac72ff052e..2b250dac540 100644
--- a/README.md
+++ b/README.md
@@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr
- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
+- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3_bc/td3_bc.py)
- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index ccd6cb23ed0..b46d789ed15 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -317,6 +317,7 @@ Regular modules
Conv3dNet
SqueezeLayer
Squeeze2dLayer
+ BatchRenorm
Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
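
Note: `BatchRenorm` is added to the regular modules here and is consumed below by the CrossQ networks through `MLP`'s `norm_class`/`norm_kwargs` arguments. A small sketch of that wiring, mirroring the constructor arguments that appear in `sota-implementations/crossq/utils.py` (values are illustrative):

```python
import torch
from torchrl.modules import MLP
from torchrl.modules.models.batchrenorm import BatchRenorm1d

# Hidden layers of width 256, each followed by batch renormalization,
# mirroring the CrossQ actor/critic construction in this PR.
net = MLP(
    num_cells=[256, 256],
    out_features=6,
    activation_class=torch.nn.ReLU,
    norm_class=BatchRenorm1d,
    norm_kwargs={"momentum": 0.01, "num_features": 256, "warmup_steps": 100_000},
)
out = net(torch.randn(8, 17))  # input width is inferred lazily on first use
print(out.shape)  # torch.Size([8, 6])
```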
diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index c2f43d8e9b6..96a887196aa 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -121,6 +121,15 @@ REDQ
REDQLoss
+CrossQ
+------
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template_noinherit.rst
+
+ CrossQLoss
+
IQL
----
@@ -160,6 +169,15 @@ TD3
TD3Loss
+TD3+BC
+------
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template_noinherit.rst
+
+ TD3BCLoss
+
PPO
---
diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst
index 2f0982257eb..11384bda0e6 100644
--- a/docs/source/reference/trainers.rst
+++ b/docs/source/reference/trainers.rst
@@ -124,26 +124,26 @@ Checkpointing
-------------
The trainer class and hooks support checkpointing, which can be achieved either
-using the ``torchsnapshot <https://github.com/pytorch/torchsnapshot>``_ backend or
+using the `torchsnapshot <https://github.com/pytorch/torchsnapshot>`_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:
.. code-block::
- $ CKPT_BACKEND=torch python script.py
+ $ CKPT_BACKEND=torchsnapshot python script.py
-which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
+``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.
-When building a trainer, one can provide a file path where the checkpoints are to
+When building a trainer, one can provide a path where the checkpoints are to
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).
.. code-block::
- >>> filepath = "path/to/dir/"
+ >>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
... collector=collector,
... total_frames=total_frames,
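
Note: the corrected paragraph above states that ``CKPT_BACKEND`` defaults to ``torch`` and that this backend expects a file path. A hedged sketch of that usage follows; the `save_trainer_file` keyword and the placeholder objects are assumptions, not shown in this diff:

```python
import os

# Select the checkpointing backend; "torch" is the default, "torchsnapshot"
# is the more flexible alternative described above.
os.environ["CKPT_BACKEND"] = "torch"

filepath = "path/to/checkpoint.pt"  # torch backend: a file, not a directory

# Assumed Trainer wiring, with collector/loss_module/... built elsewhere:
# trainer = Trainer(
#     collector=collector,
#     total_frames=total_frames,
#     loss_module=loss_module,
#     optimizer=optimizer,
#     save_trainer_file=filepath,
# )
```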
diff --git a/setup.py b/setup.py
index a439829db17..73541790e8f 100644
--- a/setup.py
+++ b/setup.py
@@ -172,7 +172,7 @@ def _main(argv):
if is_nightly:
tensordict_dep = "tensordict-nightly"
else:
- tensordict_dep = "tensordict>=0.4.0"
+ tensordict_dep = "tensordict>=0.5.0"
if is_nightly:
version = get_nightly_version()
diff --git a/sota-check/run_crossq.sh b/sota-check/run_crossq.sh
new file mode 100644
index 00000000000..2ae4ea51c49
--- /dev/null
+++ b/sota-check/run_crossq.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+#SBATCH --job-name=crossq
+#SBATCH --ntasks=32
+#SBATCH --cpus-per-task=1
+#SBATCH --gres=gpu:1
+#SBATCH --output=slurm_logs/crossq_%j.txt
+#SBATCH --error=slurm_errors/crossq_%j.txt
+
+current_commit=$(git rev-parse --short HEAD)
+project_name="torchrl-example-check-$current_commit"
+group_name="crossq"
+export PYTHONPATH=$(dirname $(dirname $PWD))
+python $PYTHONPATH/sota-implementations/crossq/crossq.py \
+ logger.backend=wandb \
+ logger.project_name="$project_name" \
+ logger.group_name="$group_name"
+
+# Capture the exit status of the Python command
+exit_status=$?
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+ echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
+else
+ echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
+fi
diff --git a/sota-check/run_td3bc.sh b/sota-check/run_td3bc.sh
new file mode 100644
index 00000000000..0fefb3ecd6f
--- /dev/null
+++ b/sota-check/run_td3bc.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+#SBATCH --job-name=td3bc_offline
+#SBATCH --ntasks=32
+#SBATCH --cpus-per-task=1
+#SBATCH --gres=gpu:1
+#SBATCH --output=slurm_logs/td3bc_offline_%j.txt
+#SBATCH --error=slurm_errors/td3bc_offline_%j.txt
+
+current_commit=$(git rev-parse --short HEAD)
+project_name="torchrl-example-check-$current_commit"
+group_name="td3bc_offline"
+export PYTHONPATH=$(dirname $(dirname $PWD))
+python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \
+ logger.backend=wandb \
+ logger.project_name="$project_name" \
+ logger.group_name="$group_name"
+
+# Capture the exit status of the Python command
+exit_status=$?
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+ echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log
+else
+ echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log
+fi
diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh
index cad2783c653..515ac06a50b 100755
--- a/sota-check/submitit-release-check.sh
+++ b/sota-check/submitit-release-check.sh
@@ -65,6 +65,7 @@ scripts=(
run_ppo_mujoco.sh
run_sac.sh
run_td3.sh
+ run_td3bc.sh
run_dt.sh
run_dt_online.sh
)
diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 775dcfe206d..f8c18147306 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
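
Note: throughout this diff, evaluation blocks switch from `ExplorationType.MODE` to `ExplorationType.DETERMINISTIC`, which the PR standardizes on for test rollouts. A minimal, self-contained sketch of the pattern; in the actual scripts the trained actor is passed as the rollout policy:

```python
import torch
from torchrl.envs import GymEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type

env = GymEnv("CartPole-v1")
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
    # the real scripts pass their actor here, e.g. env.rollout(1000, policy)
    rollout = env.rollout(max_steps=10)
print(rollout["next", "reward"].sum().item())
```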
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 0276039058f..d115174eb9c 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index d8185c8091c..5ca70f83b53 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# evaluation
if i % evaluation_interval == 0:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index 5f8f81357c8..cf629ed0733 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py
index 4b6f14cd058..d0d6693eb97 100644
--- a/sota-implementations/cql/discrete_cql_online.py
+++ b/sota-implementations/cql/discrete_cql_online.py
@@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml
new file mode 100644
index 00000000000..1dcbd3db92d
--- /dev/null
+++ b/sota-implementations/crossq/config.yaml
@@ -0,0 +1,58 @@
+# environment and task
+env:
+ name: HalfCheetah-v4
+ task: ""
+ library: gym
+ max_episode_steps: 1000
+ seed: 42
+
+# collector
+collector:
+ total_frames: 1_000_000
+ init_random_frames: 25000
+ frames_per_batch: 1000
+ init_env_steps: 1000
+ device: cpu
+ env_per_collector: 1
+ reset_at_each_iter: False
+
+# replay buffer
+replay_buffer:
+ size: 1000000
+ prb: 0 # whether to use prioritized experience replay (0: disabled)
+ scratch_dir: null
+
+# optim
+optim:
+ utd_ratio: 1.0
+ policy_update_delay: 3
+ gamma: 0.99
+ loss_function: l2
+ lr: 1.0e-3
+ weight_decay: 0.0
+ batch_size: 256
+ alpha_init: 1.0
+ adam_eps: 1.0e-8
+ beta1: 0.5
+ beta2: 0.999
+
+# network
+network:
+ batch_norm_momentum: 0.01
+ warmup_steps: 100000
+ critic_hidden_sizes: [2048, 2048]
+ actor_hidden_sizes: [256, 256]
+ critic_activation: relu
+ actor_activation: relu
+ default_policy_scale: 1.0
+ scale_lb: 0.1
+ device: "cuda:0"
+
+# logging
+logger:
+ backend: wandb
+ project_name: torchrl_example_crossQ
+ group_name: null
+ exp_name: ${env.name}_CrossQ
+ mode: online
+ eval_iter: 25000
diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py
new file mode 100644
index 00000000000..df34d4ae68d
--- /dev/null
+++ b/sota-implementations/crossq/crossq.py
@@ -0,0 +1,229 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""CrossQ Example.
+
+This is a simple self-contained example of a CrossQ training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+import time
+
+import hydra
+
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from torchrl._utils import logger as torchrl_logger
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+ log_metrics,
+ make_collector,
+ make_crossQ_agent,
+ make_crossQ_optimizer,
+ make_environment,
+ make_loss_module,
+ make_replay_buffer,
+)
+
+
+@hydra.main(version_base="1.1", config_path=".", config_name="config")
+def main(cfg: "DictConfig"): # noqa: F821
+ device = cfg.network.device
+ if device in ("", None):
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+ device = torch.device(device)
+
+ # Create logger
+ exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
+ logger = None
+ if cfg.logger.backend:
+ logger = get_logger(
+ logger_type=cfg.logger.backend,
+ logger_name="crossq_logging",
+ experiment_name=exp_name,
+ wandb_kwargs={
+ "mode": cfg.logger.mode,
+ "config": dict(cfg),
+ "project": cfg.logger.project_name,
+ "group": cfg.logger.group_name,
+ },
+ )
+
+ torch.manual_seed(cfg.env.seed)
+ np.random.seed(cfg.env.seed)
+
+ # Create environments
+ train_env, eval_env = make_environment(cfg)
+
+ # Create agent
+ model, exploration_policy = make_crossQ_agent(cfg, train_env, device)
+
+ # Create CrossQ loss
+ loss_module = make_loss_module(cfg, model)
+
+ # Create off-policy collector
+ collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device)
+
+ # Create replay buffer
+ replay_buffer = make_replay_buffer(
+ batch_size=cfg.optim.batch_size,
+ prb=cfg.replay_buffer.prb,
+ buffer_size=cfg.replay_buffer.size,
+ scratch_dir=cfg.replay_buffer.scratch_dir,
+ device="cpu",
+ )
+
+ # Create optimizers
+ (
+ optimizer_actor,
+ optimizer_critic,
+ optimizer_alpha,
+ ) = make_crossQ_optimizer(cfg, loss_module)
+
+ # Main loop
+ start_time = time.time()
+ collected_frames = 0
+ pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+ init_random_frames = cfg.collector.init_random_frames
+ num_updates = int(
+ cfg.collector.env_per_collector
+ * cfg.collector.frames_per_batch
+ * cfg.optim.utd_ratio
+ )
+ prb = cfg.replay_buffer.prb
+ eval_iter = cfg.logger.eval_iter
+ frames_per_batch = cfg.collector.frames_per_batch
+ eval_rollout_steps = cfg.env.max_episode_steps
+
+ sampling_start = time.time()
+ update_counter = 0
+ delayed_updates = cfg.optim.policy_update_delay
+ for _, tensordict in enumerate(collector):
+ sampling_time = time.time() - sampling_start
+
+ # Update weights of the inference policy
+ collector.update_policy_weights_()
+
+ pbar.update(tensordict.numel())
+
+ tensordict = tensordict.reshape(-1)
+ current_frames = tensordict.numel()
+ # Add to replay buffer
+ replay_buffer.extend(tensordict.cpu())
+ collected_frames += current_frames
+
+ # Optimization steps
+ training_start = time.time()
+ if collected_frames >= init_random_frames:
+ (
+ actor_losses,
+ alpha_losses,
+ q_losses,
+ ) = ([], [], [])
+ for _ in range(num_updates):
+
+ # Update actor every delayed_updates
+ update_counter += 1
+ update_actor = update_counter % delayed_updates == 0
+ # Sample from replay buffer
+ sampled_tensordict = replay_buffer.sample()
+ if sampled_tensordict.device != device:
+ sampled_tensordict = sampled_tensordict.to(device)
+ else:
+ sampled_tensordict = sampled_tensordict.clone()
+
+ # Compute loss
+ q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
+ q_loss = q_loss.mean()
+ # Update critic
+ optimizer_critic.zero_grad()
+ q_loss.backward()
+ optimizer_critic.step()
+ q_losses.append(q_loss.detach().item())
+
+ if update_actor:
+ actor_loss, metadata_actor = loss_module.actor_loss(
+ sampled_tensordict
+ )
+ actor_loss = actor_loss.mean()
+ alpha_loss = loss_module.alpha_loss(
+ log_prob=metadata_actor["log_prob"]
+ ).mean()
+
+ # Update actor
+ optimizer_actor.zero_grad()
+ actor_loss.backward()
+ optimizer_actor.step()
+
+ # Update alpha
+ optimizer_alpha.zero_grad()
+ alpha_loss.backward()
+ optimizer_alpha.step()
+
+ actor_losses.append(actor_loss.detach().item())
+ alpha_losses.append(alpha_loss.detach().item())
+
+ # Update priority
+ if prb:
+ replay_buffer.update_priority(sampled_tensordict)
+
+ training_time = time.time() - training_start
+ episode_end = (
+ tensordict["next", "done"]
+ if tensordict["next", "done"].any()
+ else tensordict["next", "truncated"]
+ )
+ episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+ # Logging
+ metrics_to_log = {}
+ if len(episode_rewards) > 0:
+ episode_length = tensordict["next", "step_count"][episode_end]
+ metrics_to_log["train/reward"] = episode_rewards.mean().item()
+ metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+ episode_length
+ )
+ if collected_frames >= init_random_frames:
+ metrics_to_log["train/q_loss"] = np.mean(q_losses).item()
+ metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item()
+ metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item()
+ metrics_to_log["train/sampling_time"] = sampling_time
+ metrics_to_log["train/training_time"] = training_time
+
+ # Evaluation
+ if abs(collected_frames % eval_iter) < frames_per_batch:
+ with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ eval_start = time.time()
+ eval_rollout = eval_env.rollout(
+ eval_rollout_steps,
+ model[0],
+ auto_cast_to_device=True,
+ break_when_any_done=True,
+ )
+ eval_time = time.time() - eval_start
+ eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+ metrics_to_log["eval/reward"] = eval_reward
+ metrics_to_log["eval/time"] = eval_time
+ if logger is not None:
+ log_metrics(logger, metrics_to_log, collected_frames)
+ sampling_start = time.time()
+
+ collector.shutdown()
+ end_time = time.time()
+ execution_time = end_time - start_time
+ torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
+
+
+if __name__ == "__main__":
+ main()
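
Note: worked numbers for the update schedule in the loop above, using the defaults from `sota-implementations/crossq/config.yaml`. Each collected batch triggers `env_per_collector * frames_per_batch * utd_ratio` critic updates, and the actor/alpha are updated every `policy_update_delay` critic steps:

```python
# Update budget per collected batch with the shipped config values.
env_per_collector = 1
frames_per_batch = 1000
utd_ratio = 1.0
policy_update_delay = 3

num_updates = int(env_per_collector * frames_per_batch * utd_ratio)  # 1000 critic updates
actor_updates = num_updates // policy_update_delay                   # roughly 333 actor/alpha updates
print(num_updates, actor_updates)
```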
diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py
new file mode 100644
index 00000000000..9883bc50b17
--- /dev/null
+++ b/sota-implementations/crossq/utils.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from tensordict.nn import InteractionType, TensorDictModule
+from tensordict.nn.distributions import NormalParamExtractor
+from torch import nn, optim
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.envs import (
+ CatTensors,
+ Compose,
+ DMControlEnv,
+ DoubleToFloat,
+ EnvCreator,
+ ParallelEnv,
+ TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
+from torchrl.modules.distributions import TanhNormal
+
+from torchrl.modules.models.batchrenorm import BatchRenorm1d
+from torchrl.objectives import CrossQLoss
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def env_maker(cfg, device="cpu"):
+ lib = cfg.env.library
+ if lib in ("gym", "gymnasium"):
+ with set_gym_backend(lib):
+ return GymEnv(
+ cfg.env.name,
+ device=device,
+ )
+ elif lib == "dm_control":
+ env = DMControlEnv(cfg.env.name, cfg.env.task)
+ return TransformedEnv(
+ env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+ )
+ else:
+ raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+def apply_env_transforms(env, max_episode_steps=1000):
+ transformed_env = TransformedEnv(
+ env,
+ Compose(
+ InitTracker(),
+ StepCounter(max_episode_steps),
+ DoubleToFloat(),
+ RewardSum(),
+ ),
+ )
+ return transformed_env
+
+
+def make_environment(cfg):
+ """Make environments for training and evaluation."""
+ parallel_env = ParallelEnv(
+ cfg.collector.env_per_collector,
+ EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+ serial_for_single=True,
+ )
+ parallel_env.set_seed(cfg.env.seed)
+
+ train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)
+
+ eval_env = TransformedEnv(
+ ParallelEnv(
+ cfg.collector.env_per_collector,
+ EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+ serial_for_single=True,
+ ),
+ train_env.transform.clone(),
+ )
+ return train_env, eval_env
+
+
+# ====================================================================
+# Collector and replay buffer
+# ---------------------------
+
+
+def make_collector(cfg, train_env, actor_model_explore, device):
+ """Make collector."""
+ collector = SyncDataCollector(
+ train_env,
+ actor_model_explore,
+ init_random_frames=cfg.collector.init_random_frames,
+ frames_per_batch=cfg.collector.frames_per_batch,
+ total_frames=cfg.collector.total_frames,
+ device=device,
+ )
+ collector.set_seed(cfg.env.seed)
+ return collector
+
+
+def make_replay_buffer(
+ batch_size,
+ prb=False,
+ buffer_size=1000000,
+ scratch_dir=None,
+ device="cpu",
+ prefetch=3,
+):
+ if prb:
+ replay_buffer = TensorDictPrioritizedReplayBuffer(
+ alpha=0.7,
+ beta=0.5,
+ pin_memory=False,
+ prefetch=prefetch,
+ storage=LazyMemmapStorage(
+ buffer_size,
+ scratch_dir=scratch_dir,
+ ),
+ batch_size=batch_size,
+ )
+ else:
+ replay_buffer = TensorDictReplayBuffer(
+ pin_memory=False,
+ prefetch=prefetch,
+ storage=LazyMemmapStorage(
+ buffer_size,
+ scratch_dir=scratch_dir,
+ ),
+ batch_size=batch_size,
+ )
+ replay_buffer.append_transform(lambda x: x.to(device, non_blocking=True))
+ return replay_buffer
+
+
+# ====================================================================
+# Model
+# -----
+
+
+def make_crossQ_agent(cfg, train_env, device):
+ """Make CrossQ agent."""
+ # Define Actor Network
+ in_keys = ["observation"]
+ action_spec = train_env.action_spec
+ if train_env.batch_size:
+ action_spec = action_spec[(0,) * len(train_env.batch_size)]
+ actor_net_kwargs = {
+ "num_cells": cfg.network.actor_hidden_sizes,
+ "out_features": 2 * action_spec.shape[-1],
+ "activation_class": get_activation(cfg.network.actor_activation),
+ "norm_class": BatchRenorm1d,
+ "norm_kwargs": {
+ "momentum": cfg.network.batch_norm_momentum,
+ "num_features": cfg.network.actor_hidden_sizes[-1],
+ "warmup_steps": cfg.network.warmup_steps,
+ },
+ }
+
+ actor_net = MLP(**actor_net_kwargs)
+
+ dist_class = TanhNormal
+ dist_kwargs = {
+ "low": action_spec.space.low,
+ "high": action_spec.space.high,
+ "tanh_loc": False,
+ }
+
+ actor_extractor = NormalParamExtractor(
+ scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}",
+ scale_lb=cfg.network.scale_lb,
+ )
+ actor_net = nn.Sequential(actor_net, actor_extractor)
+
+ in_keys_actor = in_keys
+ actor_module = TensorDictModule(
+ actor_net,
+ in_keys=in_keys_actor,
+ out_keys=[
+ "loc",
+ "scale",
+ ],
+ )
+ actor = ProbabilisticActor(
+ spec=action_spec,
+ in_keys=["loc", "scale"],
+ module=actor_module,
+ distribution_class=dist_class,
+ distribution_kwargs=dist_kwargs,
+ default_interaction_type=InteractionType.RANDOM,
+ return_log_prob=False,
+ )
+
+ # Define Critic Network
+ qvalue_net_kwargs = {
+ "num_cells": cfg.network.critic_hidden_sizes,
+ "out_features": 1,
+ "activation_class": get_activation(cfg.network.critic_activation),
+ "norm_class": BatchRenorm1d,
+ "norm_kwargs": {
+ "momentum": cfg.network.batch_norm_momentum,
+ "num_features": cfg.network.critic_hidden_sizes[-1],
+ "warmup_steps": cfg.network.warmup_steps,
+ },
+ }
+
+ qvalue_net = MLP(
+ **qvalue_net_kwargs,
+ )
+
+ qvalue = ValueOperator(
+ in_keys=["action"] + in_keys,
+ module=qvalue_net,
+ )
+
+ model = nn.ModuleList([actor, qvalue]).to(device)
+
+ # init nets
+ with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+ td = train_env.fake_tensordict()
+ td = td.to(device)
+ for net in model:
+ net.eval()
+ net(td)
+ net.train()
+ del td
+
+ return model, model[0]
+
+
+# ====================================================================
+# CrossQ Loss
+# ---------
+
+
+def make_loss_module(cfg, model):
+ """Make loss module and target network updater."""
+ # Create CrossQ loss
+ loss_module = CrossQLoss(
+ actor_network=model[0],
+ qvalue_network=model[1],
+ num_qvalue_nets=2,
+ loss_function=cfg.optim.loss_function,
+ alpha_init=cfg.optim.alpha_init,
+ )
+ loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+
+ return loss_module
+
+
+def split_critic_params(critic_params):
+ critic1_params = []
+ critic2_params = []
+
+ for param in critic_params:
+ data1, data2 = param.data.chunk(2, dim=0)
+ critic1_params.append(nn.Parameter(data1))
+ critic2_params.append(nn.Parameter(data2))
+ return critic1_params, critic2_params
+
+
+def make_crossQ_optimizer(cfg, loss_module):
+ critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
+ actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+ optimizer_actor = optim.Adam(
+ actor_params,
+ lr=cfg.optim.lr,
+ weight_decay=cfg.optim.weight_decay,
+ eps=cfg.optim.adam_eps,
+ betas=(cfg.optim.beta1, cfg.optim.beta2),
+ )
+ optimizer_critic = optim.Adam(
+ critic_params,
+ lr=cfg.optim.lr,
+ weight_decay=cfg.optim.weight_decay,
+ eps=cfg.optim.adam_eps,
+ betas=(cfg.optim.beta1, cfg.optim.beta2),
+ )
+ optimizer_alpha = optim.Adam(
+ [loss_module.log_alpha],
+ lr=cfg.optim.lr,
+ )
+ return optimizer_actor, optimizer_critic, optimizer_alpha
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+ for metric_name, metric_value in metrics.items():
+ logger.log_scalar(metric_name, metric_value, step)
+
+
+def get_activation(activation: str):
+ if activation == "relu":
+ return nn.ReLU
+ elif activation == "tanh":
+ return nn.Tanh
+ elif activation == "leaky_relu":
+ return nn.LeakyReLU
+ else:
+ raise NotImplementedError
diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py
index eb0b88c26f7..a92ee6185c3 100644
--- a/sota-implementations/ddpg/ddpg.py
+++ b/sota-implementations/ddpg/ddpg.py
@@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py
index 59dbcafd8c9..9cca9fd8af5 100644
--- a/sota-implementations/decision_transformer/dt.py
+++ b/sota-implementations/decision_transformer/dt.py
@@ -56,7 +56,9 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Create test environment
- test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video)
+ test_env = make_env(
+ cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device
+ )
if cfg.logger.video:
test_env = test_env.append_transform(
VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
@@ -114,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821
to_log = {"train/loss": loss_vals["loss"]}
# Evaluation
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py
index 5cb297e5c0b..da2241ce9fa 100644
--- a/sota-implementations/decision_transformer/online_dt.py
+++ b/sota-implementations/decision_transformer/online_dt.py
@@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821
}
# Evaluation
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py
index 7c9500aa4e7..409833c75fa 100644
--- a/sota-implementations/decision_transformer/utils.py
+++ b/sota-implementations/decision_transformer/utils.py
@@ -57,7 +57,7 @@
# -----------------
-def make_base_env(env_cfg, from_pixels=False):
+def make_base_env(env_cfg, from_pixels=False, device=None):
set_gym_backend(env_cfg.backend).set()
env_library = LIBS[env_cfg.library]
@@ -73,7 +73,7 @@ def make_base_env(env_cfg, from_pixels=False):
if env_library is DMControlEnv:
env_task = env_cfg.task
env_kwargs.update({"task_name": env_task})
- env = env_library(**env_kwargs)
+ env = env_library(**env_kwargs, device=device)
return env
@@ -134,7 +134,9 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
return transformed_env
-def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False):
+def make_parallel_env(
+ env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None
+):
if train:
num_envs = env_cfg.num_train_envs
else:
@@ -142,10 +144,12 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False)
def make_env():
with set_gym_backend(env_cfg.backend):
- return make_base_env(env_cfg, from_pixels=from_pixels)
+ return make_base_env(env_cfg, from_pixels=from_pixels, device="cpu")
env = make_transformed_env(
- ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True),
+ ParallelEnv(
+ num_envs, EnvCreator(make_env), serial_for_single=True, device=device
+ ),
env_cfg,
obs_loc,
obs_std,
@@ -154,11 +158,15 @@ def make_env():
return env
-def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False):
- env = make_parallel_env(
- env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels
+def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None):
+ return make_parallel_env(
+ env_cfg,
+ obs_loc,
+ obs_std,
+ train=train,
+ from_pixels=from_pixels,
+ device=device,
)
- return env
# ====================================================================
diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py
index 6e100f92dc3..386f743c7d3 100644
--- a/sota-implementations/discrete_sac/discrete_sac.py
+++ b/sota-implementations/discrete_sac/discrete_sac.py
@@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // eval_iter
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index 90f93551d4d..906273ee2f5 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get and log evaluation rewards and eval time
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index ac3f17a9203..173f88f7028 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get and log evaluation rewards and eval time
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml
index ab101e8486a..604e1ac546a 100644
--- a/sota-implementations/dreamer/config.yaml
+++ b/sota-implementations/dreamer/config.yaml
@@ -9,17 +9,13 @@ env:
image_size : 64
horizon: 500
n_parallel_envs: 8
- device:
- _target_: dreamer_utils._default_device
- device: null
+ device: cpu
collector:
total_frames: 5_000_000
init_random_frames: 3000
frames_per_batch: 1000
device:
- _target_: dreamer_utils._default_device
- device: null
optimization:
train_every: 1000
@@ -41,8 +37,6 @@ optimization:
networks:
exploration_noise: 0.3
device:
- _target_: dreamer_utils._default_device
- device: null
state_dim: 30
rssm_hidden_dim: 200
hidden_dim: 400
diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py
index e7b346b2b22..e521b9df386 100644
--- a/sota-implementations/dreamer/dreamer.py
+++ b/sota-implementations/dreamer/dreamer.py
@@ -10,6 +10,7 @@
import torch.cuda
import tqdm
from dreamer_utils import (
+ _default_device,
dump_video,
log_metrics,
make_collector,
@@ -17,7 +18,6 @@
make_environments,
make_replay_buffer,
)
-from hydra.utils import instantiate
# mixed precision training
from torch.cuda.amp import GradScaler
@@ -38,7 +38,7 @@
def main(cfg: "DictConfig"): # noqa: F821
# cfg = correct_for_frame_skip(cfg)
- device = torch.device(instantiate(cfg.networks.device))
+ device = _default_device(cfg.networks.device)
# Create logger
exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name)
@@ -284,7 +284,7 @@ def compile_rssms(module):
# Evaluation
if (i % eval_iter) == 0:
# Real env
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_rollout = test_env.rollout(
eval_rollout_steps,
policy,
@@ -298,7 +298,9 @@ def compile_rssms(module):
log_metrics(logger, eval_metrics, collected_frames)
# Simulated env
if model_based_env_eval is not None:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(
+ ExplorationType.DETERMINISTIC
+ ), torch.no_grad():
eval_rollout = model_based_env_eval.rollout(
eval_rollout_steps,
policy,
diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py
index ff14871b011..73baa310821 100644
--- a/sota-implementations/dreamer/dreamer_utils.py
+++ b/sota-implementations/dreamer/dreamer_utils.py
@@ -9,7 +9,6 @@
import torch
import torch.nn as nn
-from hydra.utils import instantiate
from tensordict import NestedKey
from tensordict.nn import (
InteractionType,
@@ -88,6 +87,7 @@ def _make_env(cfg, device, from_pixels=False):
cfg.env.task,
from_pixels=cfg.env.from_pixels or from_pixels,
pixels_only=cfg.env.from_pixels,
+ device=device,
)
else:
raise NotImplementedError(f"Unknown lib {lib}.")
@@ -98,7 +98,6 @@ def _make_env(cfg, device, from_pixels=False):
env = env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
)
- assert env is not None
return env
@@ -129,7 +128,7 @@ def transform_env(cfg, env):
def make_environments(cfg, parallel_envs=1, logger=None):
"""Make environments for training and evaluation."""
- func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device)
+ func = functools.partial(_make_env, cfg=cfg, device=_default_device(cfg.env.device))
train_env = ParallelEnv(
parallel_envs,
EnvCreator(func),
@@ -138,7 +137,10 @@ def make_environments(cfg, parallel_envs=1, logger=None):
train_env = transform_env(cfg, train_env)
train_env.set_seed(cfg.env.seed)
func = functools.partial(
- _make_env, cfg=cfg, device=cfg.env.device, from_pixels=cfg.logger.video
+ _make_env,
+ cfg=cfg,
+ device=_default_device(cfg.env.device),
+ from_pixels=cfg.logger.video,
)
eval_env = ParallelEnv(
1,
@@ -332,7 +334,7 @@ def make_collector(cfg, train_env, actor_model_explore):
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
- policy_device=instantiate(cfg.collector.device),
+ policy_device=_default_device(cfg.collector.device),
env_device=train_env.device,
storing_device="cpu",
)
@@ -535,7 +537,7 @@ def _dreamer_make_actor_real(
SafeProbabilisticModule(
in_keys=["loc", "scale"],
out_keys=[action_key],
- default_interaction_type=InteractionType.MODE,
+ default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=CompositeSpec(
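
Note: `_default_device` is imported from `dreamer_utils.py`, but its body falls outside the hunks shown here. Based on how it is used above and on the analogous device handling in `crossq.py`/`td3_bc.py`, it presumably looks something like the following sketch (an assumption, not part of this diff):

```python
import torch

def _default_device(device=None):
    # Honour an explicit device; otherwise fall back to CUDA when available.
    if device in ("", None):
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return torch.device(device)
```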
diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py
index 0482a595ffa..1998c044305 100644
--- a/sota-implementations/impala/impala_multi_node_ray.py
+++ b/sota-implementations/impala/impala_multi_node_ray.py
@@ -247,7 +247,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py
index ce96cf06ce8..fdee4256c42 100644
--- a/sota-implementations/impala/impala_multi_node_submitit.py
+++ b/sota-implementations/impala/impala_multi_node_submitit.py
@@ -239,7 +239,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py
index bb0f314197a..cf583909620 100644
--- a/sota-implementations/impala/impala_single_node.py
+++ b/sota-implementations/impala/impala_single_node.py
@@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py
index 33513dd3973..ae1894379fd 100644
--- a/sota-implementations/iql/discrete_iql.py
+++ b/sota-implementations/iql/discrete_iql.py
@@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py
index d98724e1371..d1a16fd8192 100644
--- a/sota-implementations/iql/iql_offline.py
+++ b/sota-implementations/iql/iql_offline.py
@@ -130,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# evaluation
if i % evaluation_interval == 0:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index b66c6f9dcf2..d50ff806294 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py
index 81551ebefb7..a4d2b88a9d0 100644
--- a/sota-implementations/multiagent/iql.py
+++ b/sota-implementations/multiagent/iql.py
@@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py
index 9d14ff04b04..bd44bb0a043 100644
--- a/sota-implementations/multiagent/maddpg_iddpg.py
+++ b/sota-implementations/multiagent/maddpg_iddpg.py
@@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py
index e752c4d73f2..fa006a7d4a2 100644
--- a/sota-implementations/multiagent/mappo_ippo.py
+++ b/sota-implementations/multiagent/mappo_ippo.py
@@ -236,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py
index d294a9c783e..4e6a962c556 100644
--- a/sota-implementations/multiagent/qmix_vdn.py
+++ b/sota-implementations/multiagent/qmix_vdn.py
@@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py
index 30b7e7e98bc..f7b2523010b 100644
--- a/sota-implementations/multiagent/sac.py
+++ b/sota-implementations/multiagent/sac.py
@@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 908cb7924a3..2b02254032a 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index e3e74971a49..219ae1b59b6 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -210,7 +210,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Get test rewards
- with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
i * frames_in_batch
) // cfg_logger_test_interval:
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index f7a399cda72..9904fe072ab 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index 97fd039c238..5fbc9b032d7 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
diff --git a/sota-implementations/td3_bc/config.yaml b/sota-implementations/td3_bc/config.yaml
new file mode 100644
index 00000000000..54275a94bc2
--- /dev/null
+++ b/sota-implementations/td3_bc/config.yaml
@@ -0,0 +1,45 @@
+# task and env
+env:
+ name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency
+ task: ""
+ library: gymnasium
+ seed: 42
+ max_episode_steps: 1000
+
+# replay buffer
+replay_buffer:
+ dataset: halfcheetah-medium-v2
+ batch_size: 256
+
+# optim
+optim:
+ gradient_steps: 100000
+ gamma: 0.99
+ loss_function: l2
+ lr: 3.0e-4
+ weight_decay: 0.0
+ adam_eps: 1e-4
+ batch_size: 256
+ target_update_polyak: 0.995
+ policy_update_delay: 2
+ policy_noise: 0.2
+ noise_clip: 0.5
+ alpha: 2.5
+
+# network
+network:
+ hidden_sizes: [256, 256]
+ activation: relu
+ device: null
+
+# logging
+logger:
+ backend: wandb
+ project_name: td3+bc_${replay_buffer.dataset}
+ group_name: null
+ exp_name: TD3+BC_${replay_buffer.dataset}
+ mode: online
+ eval_iter: 5000
+ eval_steps: 1000
+ eval_envs: 1
+ video: False
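
Note: the `alpha: 2.5` entry above is the behaviour-cloning coefficient of TD3+BC (Fujimoto & Gu, 2021): the actor maximizes a Q term scaled by `lambda = alpha / mean|Q|` while staying close to the dataset actions. A sketch of that objective for intuition only; the actual computation lives in `TD3BCLoss.actor_loss`:

```python
import torch
import torch.nn.functional as F

def td3_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    # lambda normalizes the Q term so the BC term stays on a comparable scale
    lam = alpha / q_values.abs().mean().detach()
    bc_loss = F.mse_loss(policy_actions, dataset_actions)
    return -lam * q_values.mean() + bc_loss

q = torch.randn(256, 1)     # Q(s, pi(s)) for a sampled batch
pi_a = torch.randn(256, 6)  # actions proposed by the actor
a = torch.randn(256, 6)     # actions stored in the offline dataset
print(td3_bc_actor_loss(q, pi_a, a))
```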
diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py
new file mode 100644
index 00000000000..7c43fdc1a12
--- /dev/null
+++ b/sota-implementations/td3_bc/td3_bc.py
@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""TD3+BC Example.
+
+This is a self-contained example of an offline RL TD3+BC training script.
+
+The helper functions are coded in the utils.py associated with this script.
+
+"""
+import time
+
+import hydra
+import numpy as np
+import torch
+import tqdm
+from torchrl._utils import logger as torchrl_logger
+
+from torchrl.envs import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.record.loggers import generate_exp_name, get_logger
+
+from utils import (
+ dump_video,
+ log_metrics,
+ make_environment,
+ make_loss_module,
+ make_offline_replay_buffer,
+ make_optimizer,
+ make_td3_agent,
+)
+
+
+@hydra.main(config_path="", config_name="config")
+def main(cfg: "DictConfig"): # noqa: F821
+ set_gym_backend(cfg.env.library).set()
+
+ # Create logger
+ exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name)
+ logger = None
+ if cfg.logger.backend:
+ logger = get_logger(
+ logger_type=cfg.logger.backend,
+ logger_name="td3bc_logging",
+ experiment_name=exp_name,
+ wandb_kwargs={
+ "mode": cfg.logger.mode,
+ "config": dict(cfg),
+ "project": cfg.logger.project_name,
+ "group": cfg.logger.group_name,
+ },
+ )
+
+ # Set seeds
+ torch.manual_seed(cfg.env.seed)
+ np.random.seed(cfg.env.seed)
+ device = cfg.network.device
+ if device in ("", None):
+ if torch.cuda.is_available():
+ device = "cuda:0"
+ else:
+ device = "cpu"
+ device = torch.device(device)
+
+ # Create env
+ eval_env = make_environment(
+ cfg,
+ logger=logger,
+ )
+
+ # Create replay buffer
+ replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+
+ # Create agent
+ model, _ = make_td3_agent(cfg, eval_env, device)
+
+ # Create loss
+ loss_module, target_net_updater = make_loss_module(cfg.optim, model)
+
+ # Create optimizer
+ optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)
+
+ gradient_steps = cfg.optim.gradient_steps
+ evaluation_interval = cfg.logger.eval_iter
+ eval_steps = cfg.logger.eval_steps
+ delayed_updates = cfg.optim.policy_update_delay
+ update_counter = 0
+ pbar = tqdm.tqdm(range(gradient_steps))
+ # Training loop
+ start_time = time.time()
+ for i in pbar:
+ pbar.update(1)
+ # Update actor every delayed_updates
+ update_counter += 1
+ update_actor = update_counter % delayed_updates == 0
+
+ # Sample from replay buffer
+ sampled_tensordict = replay_buffer.sample()
+ if sampled_tensordict.device != device:
+ sampled_tensordict = sampled_tensordict.to(device)
+ else:
+ sampled_tensordict = sampled_tensordict.clone()
+
+ # Compute loss
+ q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
+
+ # Update critic
+ optimizer_critic.zero_grad()
+ q_loss.backward()
+ optimizer_critic.step()
+ q_loss.item()
+
+ to_log = {"q_loss": q_loss.item()}
+
+ # Update actor
+ if update_actor:
+ actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
+ optimizer_actor.zero_grad()
+ actor_loss.backward()
+ optimizer_actor.step()
+
+ # Update target params
+ target_net_updater.step()
+
+ to_log["actor_loss"] = actor_loss.item()
+ to_log.update(actorloss_metadata)
+
+ # evaluation
+ if i % evaluation_interval == 0:
+ with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ eval_td = eval_env.rollout(
+ max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
+ )
+ eval_env.apply(dump_video)
+ eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+ to_log["evaluation_reward"] = eval_reward
+ if logger is not None:
+ log_metrics(logger, to_log, i)
+
+ pbar.close()
+ torchrl_logger.info(f"Training time: {time.time() - start_time}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py
new file mode 100644
index 00000000000..3772eefccde
--- /dev/null
+++ b/sota-implementations/td3_bc/utils.py
@@ -0,0 +1,257 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import functools
+
+import torch
+
+from torch import nn, optim
+from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from torchrl.envs import (
+ CatTensors,
+ Compose,
+ DMControlEnv,
+ DoubleToFloat,
+ EnvCreator,
+ InitTracker,
+ ParallelEnv,
+ RewardSum,
+ StepCounter,
+ TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import (
+ AdditiveGaussianWrapper,
+ MLP,
+ SafeModule,
+ SafeSequential,
+ TanhModule,
+ ValueOperator,
+)
+
+from torchrl.objectives import SoftUpdate
+from torchrl.objectives.td3_bc import TD3BCLoss
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def env_maker(cfg, device="cpu", from_pixels=False):
+ lib = cfg.env.library
+ if lib in ("gym", "gymnasium"):
+ with set_gym_backend(lib):
+ return GymEnv(
+ cfg.env.name,
+ device=device,
+ from_pixels=from_pixels,
+ pixels_only=False,
+ )
+ elif lib == "dm_control":
+ env = DMControlEnv(
+ cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+ )
+ return TransformedEnv(
+ env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+ )
+ else:
+ raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+def apply_env_transforms(env, max_episode_steps):
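+ """Add step counting, init tracking, double-to-float casting and reward summation transforms."""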
+ transformed_env = TransformedEnv(
+ env,
+ Compose(
+ StepCounter(max_steps=max_episode_steps),
+ InitTracker(),
+ DoubleToFloat(),
+ RewardSum(),
+ ),
+ )
+ return transformed_env
+
+
+def make_environment(cfg, logger=None):
+ """Make environments for training and evaluation."""
+ partial = functools.partial(env_maker, cfg=cfg)
+ parallel_env = ParallelEnv(
+ cfg.logger.eval_envs,
+ EnvCreator(partial),
+ serial_for_single=True,
+ )
+ parallel_env.set_seed(cfg.env.seed)
+
+ eval_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)
+ return eval_env
+
+
+# ====================================================================
+# Replay buffer
+# ---------------------------
+
+
+def make_offline_replay_buffer(rb_cfg):
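+ """Build a replay buffer from the configured offline D4RL dataset."""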
+ data = D4RLExperienceReplay(
+ dataset_id=rb_cfg.dataset,
+ split_trajs=False,
+ batch_size=rb_cfg.batch_size,
+ sampler=SamplerWithoutReplacement(drop_last=False),
+ prefetch=4,
+ direct_download=True,
+ )
+
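+ # Cast any double-precision entries to float32 for the networks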
+ data.append_transform(DoubleToFloat())
+
+ return data
+
+
+# ====================================================================
+# Model
+# -----
+
+
+def make_td3_agent(cfg, train_env, device):
+ """Make TD3 agent."""
+ # Define Actor Network
+ in_keys = ["observation"]
+ action_spec = train_env.action_spec
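+ # ParallelEnv specs carry the env batch dimension; index to recover the per-env action spec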
+ if train_env.batch_size:
+ action_spec = action_spec[(0,) * len(train_env.batch_size)]
+ actor_net_kwargs = {
+ "num_cells": cfg.network.hidden_sizes,
+ "out_features": action_spec.shape[-1],
+ "activation_class": get_activation(cfg),
+ }
+
+ actor_net = MLP(**actor_net_kwargs)
+
+ in_keys_actor = in_keys
+ actor_module = SafeModule(
+ actor_net,
+ in_keys=in_keys_actor,
+ out_keys=[
+ "param",
+ ],
+ )
+ actor = SafeSequential(
+ actor_module,
+ TanhModule(
+ in_keys=["param"],
+ out_keys=["action"],
+ spec=action_spec,
+ ),
+ )
+
+ # Define Critic Network
+ qvalue_net_kwargs = {
+ "num_cells": cfg.network.hidden_sizes,
+ "out_features": 1,
+ "activation_class": get_activation(cfg),
+ }
+
+ qvalue_net = MLP(
+ **qvalue_net_kwargs,
+ )
+
+ qvalue = ValueOperator(
+ in_keys=["action"] + in_keys,
+ module=qvalue_net,
+ )
+
+ model = nn.ModuleList([actor, qvalue]).to(device)
+
+ # init nets
+ with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+ td = train_env.fake_tensordict()
+ td = td.to(device)
+ for net in model:
+ net(td)
+ del td
+
+ # Exploration wrappers:
+ actor_model_explore = AdditiveGaussianWrapper(
+ model[0],
+ sigma_init=1,
+ sigma_end=1,
+ mean=0,
+ std=0.1,
+ spec=action_spec,
+ ).to(device)
+ return model, actor_model_explore
+
+
+# ====================================================================
+# TD3+BC Loss
+# ---------
+
+
+def make_loss_module(cfg, model):
+ """Make loss module and target network updater."""
+ # Create TD3+BC loss
+ loss_module = TD3BCLoss(
+ actor_network=model[0],
+ qvalue_network=model[1],
+ num_qvalue_nets=2,
+ loss_function=cfg.loss_function,
+ delay_actor=True,
+ delay_qvalue=True,
+ action_spec=model[0][1].spec,
+ policy_noise=cfg.policy_noise,
+ noise_clip=cfg.noise_clip,
+ alpha=cfg.alpha,
+ )
+ loss_module.make_value_estimator(gamma=cfg.gamma)
+
+ # Define Target Network Updater
+ target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak)
+ return loss_module, target_net_updater
+
+
+def make_optimizer(cfg, loss_module):
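+ """Create separate Adam optimizers for the critic and actor parameters."""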
+ critic_params = list(loss_module.qvalue_network_params.values(True, True))
+ actor_params = list(loss_module.actor_network_params.values(True, True))
+
+ optimizer_actor = optim.Adam(
+ actor_params,
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ eps=cfg.adam_eps,
+ )
+ optimizer_critic = optim.Adam(
+ critic_params,
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ eps=cfg.adam_eps,
+ )
+ return optimizer_actor, optimizer_critic
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
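+ """Log a dictionary of scalar metrics to the logger at the given step."""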
+ for metric_name, metric_value in metrics.items():
+ logger.log_scalar(metric_name, metric_value, step)
+
+
+def get_activation(cfg):
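+ """Return the activation class selected in the config."""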
+ if cfg.network.activation == "relu":
+ return nn.ReLU
+ elif cfg.network.activation == "tanh":
+ return nn.Tanh
+ elif cfg.network.activation == "leaky_relu":
+ return nn.LeakyReLU
+ else:
+ raise NotImplementedError
+
+
+def dump_video(module):
+ if isinstance(module, VideoRecorder):
+ module.dump()
diff --git a/test/test_cost.py b/test/test_cost.py
index 76fc4e651f4..a318f5694cd 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -12,7 +12,7 @@
from dataclasses import asdict, dataclass
from packaging import version as pack_version
-from tensordict._tensordict import unravel_keys
+from tensordict._C import unravel_keys
from tensordict.nn import (
InteractionType,
@@ -98,6 +98,7 @@
A2CLoss,
ClipPPOLoss,
CQLLoss,
+ CrossQLoss,
DDPGLoss,
DiscreteCQLLoss,
DiscreteIQLLoss,
@@ -114,6 +115,7 @@
PPOLoss,
QMixerLoss,
SACLoss,
+ TD3BCLoss,
TD3Loss,
)
from torchrl.objectives.common import LossModule
@@ -261,9 +263,9 @@ def __init__(self):
self.vmap_model = _vmap_func(
self.model,
(None, 0),
- randomness="error"
- if vmap_randomness == "error"
- else self.vmap_randomness,
+ randomness=(
+ "error" if vmap_randomness == "error" else self.vmap_randomness
+ ),
)
def forward(self, td):
@@ -319,9 +321,9 @@ def _create_mock_actor(
spec=CompositeSpec(
{
"action": action_spec,
- "action_value"
- if action_value_key is None
- else action_value_key: None,
+ (
+ "action_value" if action_value_key is None else action_value_key
+ ): None,
"chosen_action_value": None,
},
shape=[],
@@ -2714,11 +2716,7 @@ def test_td3_reduction(self, reduction):
assert loss[key].shape == torch.Size([])
-@pytest.mark.skipif(
- not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
-)
-@pytest.mark.parametrize("version", [1, 2])
-class TestSAC(LossModuleTestBase):
+class TestTD3BC(LossModuleTestBase):
seed = 0
def _create_mock_actor(
@@ -2727,36 +2725,35 @@ def _create_mock_actor(
obs_dim=3,
action_dim=4,
device="cpu",
- observation_key="observation",
- action_key="action",
+ in_keys=None,
+ out_keys=None,
+ dropout=0.0,
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
- net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
- module = TensorDictModule(
- net, in_keys=[observation_key], out_keys=["loc", "scale"]
+ module = nn.Sequential(
+ nn.Linear(obs_dim, obs_dim),
+ nn.Dropout(dropout),
+ nn.Linear(obs_dim, action_dim),
)
- actor = ProbabilisticActor(
- module=module,
- in_keys=["loc", "scale"],
- spec=action_spec,
- distribution_class=TanhNormal,
- out_keys=[action_key],
+ actor = Actor(
+ spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys
)
return actor.to(device)
- def _create_mock_qvalue(
+ def _create_mock_value(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
- observation_key="observation",
- action_key="action",
out_keys=None,
+ action_key="action",
+ observation_key="observation",
):
+ # Value network
class ValueClass(nn.Module):
def __init__(self):
super().__init__()
@@ -2766,29 +2763,17 @@ def forward(self, obs, act):
return self.linear(torch.cat([obs, act], -1))
module = ValueClass()
- qvalue = ValueOperator(
+ value = ValueOperator(
module=module,
in_keys=[observation_key, action_key],
out_keys=out_keys,
)
- return qvalue.to(device)
+ return value.to(device)
- def _create_mock_value(
- self,
- batch=2,
- obs_dim=3,
- action_dim=4,
- device="cpu",
- observation_key="observation",
- out_keys=None,
+ def _create_mock_distributional_actor(
+ self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5
):
- module = nn.Linear(obs_dim, 1)
- value = ValueOperator(
- module=module,
- in_keys=[observation_key],
- out_keys=out_keys,
- )
- return value.to(device)
+ raise NotImplementedError
def _create_mock_common_layer_setup(
self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2
@@ -2805,7 +2790,7 @@ def _create_mock_common_layer_setup(
depth=1,
out_features=2 * n_act,
)
- qvalue = MLP(
+ value = MLP(
in_features=n_hidden + n_act,
num_cells=ncells,
depth=1,
@@ -2836,31 +2821,27 @@ def _create_mock_common_layer_setup(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
+ return_log_prob=True,
),
)
- qvalue_head = Mod(
- qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+ value_head = Mod(
+ value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
- qvalue = Seq(common, qvalue_head)
- return actor, qvalue, common, td
-
- def _create_mock_distributional_actor(
- self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5
- ):
- raise NotImplementedError
+ value = Seq(common, value_head)
+ return actor, value, common, td
- def _create_mock_data_sac(
+ def _create_mock_data_td3bc(
self,
- batch=16,
+ batch=8,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
- observation_key="observation",
action_key="action",
+ observation_key="observation",
+ reward_key="reward",
done_key="done",
terminated_key="terminated",
- reward_key="reward",
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -2888,7 +2869,7 @@ def _create_mock_data_sac(
)
return td
- def _create_seq_mock_data_sac(
+ def _create_seq_mock_data_td3bc(
self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
):
# create a tensordict
@@ -2904,269 +2885,225 @@ def _create_seq_mock_data_sac(
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
- mask = torch.ones(batch, T, dtype=torch.bool, device=device)
+ mask = torch.ones(batch, T, 1, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
- "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "observation": obs * mask.to(obs.dtype),
"next": {
- "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "observation": next_obs * mask.to(obs.dtype),
+ "reward": reward * mask.to(obs.dtype),
"done": done,
"terminated": terminated,
- "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
- "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "action": action * mask.to(obs.dtype),
},
names=[None, "time"],
device=device,
)
return td
- @pytest.mark.parametrize("delay_value", (True, False))
- @pytest.mark.parametrize("delay_actor", (True, False))
- @pytest.mark.parametrize("delay_qvalue", (True, False))
- @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
+ @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize(
+ "delay_actor, delay_qvalue", [(False, False), (True, True)]
+ )
+ @pytest.mark.parametrize("policy_noise", [0.1, 1.0])
+ @pytest.mark.parametrize("noise_clip", [0.1, 1.0])
+ @pytest.mark.parametrize("alpha", [0.1, 6.0])
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
- def test_sac(
+ @pytest.mark.parametrize("use_action_spec", [True, False])
+ @pytest.mark.parametrize("dropout", [0.0, 0.1])
+ def test_td3bc(
self,
- delay_value,
delay_actor,
delay_qvalue,
- num_qvalue,
device,
- version,
+ policy_noise,
+ noise_clip,
+ alpha,
td_est,
+ use_action_spec,
+ dropout,
):
- if (delay_actor or delay_qvalue) and not delay_value:
- pytest.skip("incompatible config")
-
torch.manual_seed(self.seed)
- td = self._create_mock_data_sac(device=device)
-
- actor = self._create_mock_actor(device=device)
- qvalue = self._create_mock_qvalue(device=device)
- if version == 1:
- value = self._create_mock_value(device=device)
+ actor = self._create_mock_actor(device=device, dropout=dropout)
+ value = self._create_mock_value(device=device)
+ td = self._create_mock_data_td3bc(device=device)
+ if use_action_spec:
+ action_spec = actor.spec
+ bounds = None
else:
- value = None
-
- kwargs = {}
- if delay_actor:
- kwargs["delay_actor"] = True
- if delay_qvalue:
- kwargs["delay_qvalue"] = True
- if delay_value:
- kwargs["delay_value"] = True
-
- loss_fn = SACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- value_network=value,
- num_qvalue_nets=num_qvalue,
+ bounds = (-1, 1)
+ action_spec = None
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=action_spec,
+ bounds=bounds,
loss_function="l2",
- **kwargs,
+ policy_noise=policy_noise,
+ noise_clip=noise_clip,
+ alpha=alpha,
+ delay_actor=delay_actor,
+ delay_qvalue=delay_qvalue,
)
-
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
if td_est is not None:
loss_fn.make_value_estimator(td_est)
-
- with _check_td_steady(td), pytest.warns(
- UserWarning, match="No target network updater"
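+ # The "no target network updater" warning only fires when target networks are delayed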
+ with (
+ pytest.warns(
+ UserWarning,
+ match="No target network updater has been associated with this loss module",
+ )
+ if (delay_actor or delay_qvalue)
+ else contextlib.nullcontext()
):
- loss = loss_fn(td)
-
- assert loss_fn.tensor_keys.priority in td.keys()
+ with _check_td_steady(td):
+ loss = loss_fn(td)
- # check that losses are independent
- for k in loss.keys():
- if not k.startswith("loss"):
- continue
- loss[k].sum().backward(retain_graph=True)
- if k == "loss_actor":
- if version == 1:
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(True, True)
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(True, True)
+ )
+ # check that losses are independent
+ for k in loss.keys():
+ if not k.startswith("loss"):
+ continue
+ loss[k].sum().backward(retain_graph=True)
+ if k == "loss_actor":
assert all(
(p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.value_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- assert not any(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- elif k == "loss_value" and version == 1:
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- assert not any(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.value_network_params.values(
- include_nested=True, leaves_only=True
+ for p in loss_fn.qvalue_network_params.values(True, True)
)
- )
- elif k == "loss_qvalue":
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(True, True)
)
- )
- if version == 1:
+ elif k == "loss_qvalue":
assert all(
(p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.value_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- assert not any(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- elif k == "loss_alpha":
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
+ for p in loss_fn.actor_network_params.values(True, True)
)
- )
- if version == 1:
- assert all(
+ assert not any(
(p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.value_network_params.values(
- include_nested=True, leaves_only=True
- )
+ for p in loss_fn.qvalue_network_params.values(True, True)
)
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- else:
- raise NotImplementedError(k)
- loss_fn.zero_grad()
+ else:
+ raise NotImplementedError(k)
+ loss_fn.zero_grad()
- sum(
- [item for name, item in loss.items() if name.startswith("loss_")]
- ).backward()
- named_parameters = list(loss_fn.named_parameters())
- named_buffers = list(loss_fn.named_buffers())
+ sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = list(loss_fn.named_parameters())
+ named_buffers = list(loss_fn.named_buffers())
- assert len({p for n, p in named_parameters}) == len(list(named_parameters))
- assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+ assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+ assert len({p for n, p in named_buffers}) == len(list(named_buffers))
- for name, p in named_parameters:
- if not name.startswith("target_"):
- assert (
- p.grad is not None and p.grad.norm() > 0.0
- ), f"parameter {name} (shape: {p.shape}) has a null gradient"
- else:
- assert (
- p.grad is None or p.grad.norm() == 0.0
- ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
- @pytest.mark.parametrize("delay_value", (True, False))
- @pytest.mark.parametrize("delay_actor", (True, False))
- @pytest.mark.parametrize("delay_qvalue", (True, False))
- @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
- def test_sac_state_dict(
+ @pytest.mark.parametrize(
+ "delay_actor, delay_qvalue", [(False, False), (True, True)]
+ )
+ @pytest.mark.parametrize("policy_noise", [0.1])
+ @pytest.mark.parametrize("noise_clip", [0.1])
+ @pytest.mark.parametrize("alpha", [0.1])
+ @pytest.mark.parametrize("use_action_spec", [True, False])
+ def test_td3bc_state_dict(
self,
- delay_value,
delay_actor,
delay_qvalue,
- num_qvalue,
device,
- version,
+ policy_noise,
+ noise_clip,
+ alpha,
+ use_action_spec,
):
- if (delay_actor or delay_qvalue) and not delay_value:
- pytest.skip("incompatible config")
-
torch.manual_seed(self.seed)
-
actor = self._create_mock_actor(device=device)
- qvalue = self._create_mock_qvalue(device=device)
- if version == 1:
- value = self._create_mock_value(device=device)
+ value = self._create_mock_value(device=device)
+ if use_action_spec:
+ action_spec = actor.spec
+ bounds = None
else:
- value = None
-
- kwargs = {}
- if delay_actor:
- kwargs["delay_actor"] = True
- if delay_qvalue:
- kwargs["delay_qvalue"] = True
- if delay_value:
- kwargs["delay_value"] = True
-
- loss_fn = SACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- value_network=value,
- num_qvalue_nets=num_qvalue,
+ bounds = (-1, 1)
+ action_spec = None
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=action_spec,
+ bounds=bounds,
loss_function="l2",
- **kwargs,
+ policy_noise=policy_noise,
+ noise_clip=noise_clip,
+ alpha=alpha,
+ delay_actor=delay_actor,
+ delay_qvalue=delay_qvalue,
)
sd = loss_fn.state_dict()
- loss_fn2 = SACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- value_network=value,
- num_qvalue_nets=num_qvalue,
+ loss_fn2 = TD3BCLoss(
+ actor,
+ value,
+ action_spec=action_spec,
+ bounds=bounds,
loss_function="l2",
- **kwargs,
+ policy_noise=policy_noise,
+ noise_clip=noise_clip,
+ alpha=alpha,
+ delay_actor=delay_actor,
+ delay_qvalue=delay_qvalue,
)
loss_fn2.load_state_dict(sd)
+ @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [False, True])
- def test_sac_separate_losses(
+ def test_td3bc_separate_losses(
self,
device,
separate_losses,
- version,
n_act=4,
):
torch.manual_seed(self.seed)
- actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act)
-
- loss_fn = SACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)),
- num_qvalue_nets=1,
+ actor, value, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+ loss_function="l2",
separate_losses=separate_losses,
)
with pytest.warns(UserWarning, match="No target network updater has been"):
loss = loss_fn(td)
- assert loss_fn.tensor_keys.priority in td.keys()
-
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(True, True)
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(True, True)
+ )
# check that losses are independent
for k in loss.keys():
if not k.startswith("loss"):
@@ -3175,25 +3112,19 @@ def test_sac_separate_losses(
if k == "loss_actor":
assert all(
(p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
+ for p in loss_fn.qvalue_network_params.values(True, True)
)
assert not any(
(p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
- )
+ for p in loss_fn.actor_network_params.values(True, True)
)
elif k == "loss_qvalue":
- common_layers_no = len(list(common.parameters()))
assert all(
(p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
- )
+ for p in loss_fn.actor_network_params.values(True, True)
)
if separate_losses:
+ common_layers_no = len(list(common.parameters()))
common_layers = itertools.islice(
loss_fn.qvalue_network_params.values(True, True),
common_layers_no,
@@ -3216,235 +3147,1686 @@ def test_sac_separate_losses(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.qvalue_network_params.values(True, True)
)
- elif k == "loss_alpha":
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.actor_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
- assert all(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- )
+
else:
raise NotImplementedError(k)
loss_fn.zero_grad()
+ @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("n", range(1, 4))
- @pytest.mark.parametrize("delay_value", (True, False))
- @pytest.mark.parametrize("delay_actor", (True, False))
- @pytest.mark.parametrize("delay_qvalue", (True, False))
- @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
- def test_sac_batcher(
+ @pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)])
+ @pytest.mark.parametrize("policy_noise", [0.1, 1.0])
+ @pytest.mark.parametrize("noise_clip", [0.1, 1.0])
+ @pytest.mark.parametrize("alpha", [0.1, 6.0])
+ def test_td3bc_batcher(
self,
n,
- delay_value,
delay_actor,
delay_qvalue,
- num_qvalue,
device,
- version,
+ policy_noise,
+ noise_clip,
+ alpha,
+ gamma=0.9,
):
- if (delay_actor or delay_qvalue) and not delay_value:
- pytest.skip("incompatible config")
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_sac(device=device)
-
actor = self._create_mock_actor(device=device)
- qvalue = self._create_mock_qvalue(device=device)
- if version == 1:
- value = self._create_mock_value(device=device)
- else:
- value = None
-
- kwargs = {}
- if delay_actor:
- kwargs["delay_actor"] = True
- if delay_qvalue:
- kwargs["delay_qvalue"] = True
- if delay_value:
- kwargs["delay_value"] = True
-
- loss_fn = SACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- value_network=value,
- num_qvalue_nets=num_qvalue,
- loss_function="l2",
- **kwargs,
+ value = self._create_mock_value(device=device)
+ td = self._create_seq_mock_data_td3bc(device=device)
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=actor.spec,
+ policy_noise=policy_noise,
+ noise_clip=noise_clip,
+ alpha=alpha,
+ delay_qvalue=delay_qvalue,
+ delay_actor=delay_actor,
)
- ms = MultiStep(gamma=0.9, n_steps=n).to(device)
+ ms = MultiStep(gamma=gamma, n_steps=n).to(device)
td_clone = td.clone()
ms_td = ms(td_clone)
torch.manual_seed(0)
np.random.seed(0)
- with pytest.warns(
- UserWarning,
- match="No target network updater has been associated with this loss module",
- ):
- with _check_td_steady(ms_td):
- loss_ms = loss_fn(ms_td)
- assert loss_fn.tensor_keys.priority in ms_td.keys()
-
- with torch.no_grad():
- torch.manual_seed(0) # log-prob is computed with a random action
- np.random.seed(0)
- loss = loss_fn(td)
- if n == 1:
- assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
- _loss = sum(
- [item for name, item in loss.items() if name.startswith("loss_")]
- )
- _loss_ms = sum(
- [item for name, item in loss_ms.items() if name.startswith("loss_")]
- )
- assert (
- abs(_loss - _loss_ms) < 1e-3
- ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
- else:
- with pytest.raises(AssertionError):
- assert_allclose_td(loss, loss_ms)
- sum(
- [item for name, item in loss_ms.items() if name.startswith("loss_")]
- ).backward()
- named_parameters = loss_fn.named_parameters()
- for name, p in named_parameters:
- if not name.startswith("target_"):
- assert (
- p.grad is not None and p.grad.norm() > 0.0
- ), f"parameter {name} (shape: {p.shape}) has a null gradient"
- else:
- assert (
- p.grad is None or p.grad.norm() == 0.0
- ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
- # Check param update effect on targets
- target_actor = [
- p.clone()
- for p in loss_fn.target_actor_network_params.values(
- include_nested=True, leaves_only=True
- )
- ]
- target_qvalue = [
- p.clone()
- for p in loss_fn.target_qvalue_network_params.values(
- include_nested=True, leaves_only=True
+ with (
+ pytest.warns(UserWarning, match="No target network updater has been")
+ if (delay_qvalue or delay_actor)
+ else contextlib.nullcontext()
+ ), _check_td_steady(ms_td):
+ loss_ms = loss_fn(ms_td)
+ assert loss_fn.tensor_keys.priority in ms_td.keys()
+
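+ # Registering a SoftUpdate associates a target updater with the loss, silencing the warning below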
+ if delay_qvalue or delay_actor:
+ SoftUpdate(loss_fn, eps=0.5)
+
+ with torch.no_grad():
+ torch.manual_seed(0) # log-prob is computed with a random action
+ np.random.seed(0)
+ loss = loss_fn(td)
+
+ if n == 1:
+ assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
+ _loss = sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ )
+ _loss_ms = sum(
+ [item for name, item in loss_ms.items() if name.startswith("loss_")]
+ )
+ assert (
+ abs(_loss - _loss_ms) < 1e-3
+ ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
+ else:
+ with pytest.raises(AssertionError):
+ assert_allclose_td(loss, loss_ms)
+
+ sum(
+ [item for name, item in loss_ms.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = loss_fn.named_parameters()
+
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+ # Check param update effect on targets
+ target_actor = loss_fn.target_actor_network_params.clone().values(
+ include_nested=True, leaves_only=True
+ )
+ target_qvalue = loss_fn.target_qvalue_network_params.clone().values(
+ include_nested=True, leaves_only=True
+ )
+ for p in loss_fn.parameters():
+ if p.requires_grad:
+ p.data += torch.randn_like(p)
+ target_actor2 = loss_fn.target_actor_network_params.clone().values(
+ include_nested=True, leaves_only=True
+ )
+ target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values(
+ include_nested=True, leaves_only=True
+ )
+ if loss_fn.delay_actor:
+ assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2))
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
+ )
+ if loss_fn.delay_qvalue:
+ assert all(
+ (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2)
+ )
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
+ )
+
+ # check that policy is updated after parameter update
+ actorp_set = set(actor.parameters())
+ loss_fnp_set = set(loss_fn.parameters())
+ assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set)
+ parameters = [p.clone() for p in actor.parameters()]
+ for p in loss_fn.parameters():
+ if p.requires_grad:
+ p.data += torch.randn_like(p)
+ assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))
+
+ @pytest.mark.parametrize(
+ "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
+ )
+ def test_td3bc_tensordict_keys(self, td_est):
+ actor = self._create_mock_actor()
+ value = self._create_mock_value()
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=actor.spec,
+ )
+
+ default_keys = {
+ "priority": "td_error",
+ "state_action_value": "state_action_value",
+ "action": "action",
+ "reward": "reward",
+ "done": "done",
+ "terminated": "terminated",
+ }
+
+ self.tensordict_keys_test(
+ loss_fn,
+ default_keys=default_keys,
+ td_est=td_est,
+ )
+
+ value = self._create_mock_value(out_keys=["state_action_value_test"])
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=actor.spec,
+ )
+ key_mapping = {
+ "state_action_value": ("value", "state_action_value_test"),
+ "reward": ("reward", "reward_test"),
+ "done": ("done", ("done", "test")),
+ "terminated": ("terminated", ("terminated", "test")),
+ }
+ self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)
+
+ @pytest.mark.parametrize("spec", [True, False])
+ @pytest.mark.parametrize("bounds", [True, False])
+ def test_constructor(self, spec, bounds):
+ actor = self._create_mock_actor()
+ value = self._create_mock_value()
+ action_spec = actor.spec if spec else None
+ bounds = (-1, 1) if bounds else None
+ if (bounds is not None and action_spec is not None) or (
+ bounds is None and action_spec is None
+ ):
+ with pytest.raises(ValueError, match="but not both"):
+ TD3BCLoss(
+ actor,
+ value,
+ action_spec=action_spec,
+ bounds=bounds,
)
- ]
- if version == 1:
- target_value = [
- p.clone()
- for p in loss_fn.target_value_network_params.values(
+ return
+ TD3BCLoss(
+ actor,
+ value,
+ action_spec=action_spec,
+ bounds=bounds,
+ )
+
+ # TODO: test for action_key; at the moment the action key of the TD3+BC loss is not
+ # configurable, since it is used in its constructor
+ @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+ @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+ @pytest.mark.parametrize("done_key", ["done", "done2"])
+ @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+ def test_td3bc_notensordict(
+ self, observation_key, reward_key, done_key, terminated_key
+ ):
+ torch.manual_seed(self.seed)
+ actor = self._create_mock_actor(in_keys=[observation_key])
+ qvalue = self._create_mock_value(
+ observation_key=observation_key, out_keys=["state_action_value"]
+ )
+ td = self._create_mock_data_td3bc(
+ observation_key=observation_key,
+ reward_key=reward_key,
+ done_key=done_key,
+ terminated_key=terminated_key,
+ )
+ loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
+ loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key)
+
+ kwargs = {
+ observation_key: td.get(observation_key),
+ f"next_{reward_key}": td.get(("next", reward_key)),
+ f"next_{done_key}": td.get(("next", done_key)),
+ f"next_{terminated_key}": td.get(("next", terminated_key)),
+ f"next_{observation_key}": td.get(("next", observation_key)),
+ "action": td.get("action"),
+ }
+ td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+
+ with pytest.warns(UserWarning, match="No target network updater has been"):
+ torch.manual_seed(0)
+ loss_val_td = loss(td)
+ torch.manual_seed(0)
+ loss_val = loss(**kwargs)
+ loss_val_reconstruct = TensorDict(dict(zip(loss.out_keys, loss_val)), [])
+ assert_allclose_td(loss_val_reconstruct, loss_val_td)
+
+ # test select
+ loss.select_out_keys("loss_actor", "loss_qvalue")
+ torch.manual_seed(0)
+ if torch.__version__ >= "2.0.0":
+ loss_actor, loss_qvalue = loss(**kwargs)
+ else:
+ with pytest.raises(
+ RuntimeError,
+ match="You are likely using tensordict.nn.dispatch with keyword arguments",
+ ):
+ loss_actor, loss_qvalue = loss(**kwargs)
+ return
+
+ assert loss_actor == loss_val_td["loss_actor"]
+ assert loss_qvalue == loss_val_td["loss_qvalue"]
+
+ @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+ def test_td3bc_reduction(self, reduction):
+ torch.manual_seed(self.seed)
+ device = (
+ torch.device("cpu")
+ if torch.cuda.device_count() == 0
+ else torch.device("cuda")
+ )
+ actor = self._create_mock_actor(device=device)
+ value = self._create_mock_value(device=device)
+ td = self._create_mock_data_td3bc(device=device)
+ action_spec = actor.spec
+ bounds = None
+ loss_fn = TD3BCLoss(
+ actor,
+ value,
+ action_spec=action_spec,
+ bounds=bounds,
+ loss_function="l2",
+ delay_qvalue=False,
+ delay_actor=False,
+ reduction=reduction,
+ )
+ loss_fn.make_value_estimator()
+ loss = loss_fn(td)
+ if reduction == "none":
+ for key in loss.keys():
+ if key.startswith("loss"):
+ assert loss[key].shape == td.shape
+ else:
+ for key in loss.keys():
+ if not key.startswith("loss"):
+ continue
+ assert loss[key].shape == torch.Size([])
+
+
+@pytest.mark.skipif(
+ not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
+)
+@pytest.mark.parametrize("version", [1, 2])
+class TestSAC(LossModuleTestBase):
+ seed = 0
+
+ def _create_mock_actor(
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ observation_key="observation",
+ action_key="action",
+ ):
+ # Actor
+ action_spec = BoundedTensorSpec(
+ -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
+ )
+ net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
+ module = TensorDictModule(
+ net, in_keys=[observation_key], out_keys=["loc", "scale"]
+ )
+ actor = ProbabilisticActor(
+ module=module,
+ in_keys=["loc", "scale"],
+ spec=action_spec,
+ distribution_class=TanhNormal,
+ out_keys=[action_key],
+ )
+ return actor.to(device)
+
+ def _create_mock_qvalue(
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ observation_key="observation",
+ action_key="action",
+ out_keys=None,
+ ):
+ class ValueClass(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(obs_dim + action_dim, 1)
+
+ def forward(self, obs, act):
+ return self.linear(torch.cat([obs, act], -1))
+
+ module = ValueClass()
+ qvalue = ValueOperator(
+ module=module,
+ in_keys=[observation_key, action_key],
+ out_keys=out_keys,
+ )
+ return qvalue.to(device)
+
+ def _create_mock_value(
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ observation_key="observation",
+ out_keys=None,
+ ):
+ module = nn.Linear(obs_dim, 1)
+ value = ValueOperator(
+ module=module,
+ in_keys=[observation_key],
+ out_keys=out_keys,
+ )
+ return value.to(device)
+
+ def _create_mock_common_layer_setup(
+ self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2
+ ):
+ common = MLP(
+ num_cells=ncells,
+ in_features=n_obs,
+ depth=3,
+ out_features=n_hidden,
+ )
+ actor_net = MLP(
+ num_cells=ncells,
+ in_features=n_hidden,
+ depth=1,
+ out_features=2 * n_act,
+ )
+ qvalue = MLP(
+ in_features=n_hidden + n_act,
+ num_cells=ncells,
+ depth=1,
+ out_features=1,
+ )
+ batch = [batch]
+ td = TensorDict(
+ {
+ "obs": torch.randn(*batch, n_obs),
+ "action": torch.randn(*batch, n_act),
+ "done": torch.zeros(*batch, 1, dtype=torch.bool),
+ "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
+ "next": {
+ "obs": torch.randn(*batch, n_obs),
+ "reward": torch.randn(*batch, 1),
+ "done": torch.zeros(*batch, 1, dtype=torch.bool),
+ "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
+ },
+ },
+ batch,
+ )
+ common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+ actor = ProbSeq(
+ common,
+ Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+ Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+ ProbMod(
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ ),
+ )
+ qvalue_head = Mod(
+ qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+ )
+ qvalue = Seq(common, qvalue_head)
+ return actor, qvalue, common, td
+
+ def _create_mock_distributional_actor(
+ self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5
+ ):
+ raise NotImplementedError
+
+ def _create_mock_data_sac(
+ self,
+ batch=16,
+ obs_dim=3,
+ action_dim=4,
+ atoms=None,
+ device="cpu",
+ observation_key="observation",
+ action_key="action",
+ done_key="done",
+ terminated_key="terminated",
+ reward_key="reward",
+ ):
+ # create a tensordict
+ obs = torch.randn(batch, obs_dim, device=device)
+ next_obs = torch.randn(batch, obs_dim, device=device)
+ if atoms:
+ raise NotImplementedError
+ else:
+ action = torch.randn(batch, action_dim, device=device).clamp(-1, 1)
+ reward = torch.randn(batch, 1, device=device)
+ done = torch.zeros(batch, 1, dtype=torch.bool, device=device)
+ terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device)
+ td = TensorDict(
+ batch_size=(batch,),
+ source={
+ observation_key: obs,
+ "next": {
+ observation_key: next_obs,
+ done_key: done,
+ terminated_key: terminated,
+ reward_key: reward,
+ },
+ action_key: action,
+ },
+ device=device,
+ )
+ return td
+
+ def _create_seq_mock_data_sac(
+ self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
+ ):
+ # create a tensordict
+ total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
+ obs = total_obs[:, :T]
+ next_obs = total_obs[:, 1:]
+ if atoms:
+ action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
+ -1, 1
+ )
+ else:
+ action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+ reward = torch.randn(batch, T, 1, device=device)
+ done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+ terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+ mask = torch.ones(batch, T, dtype=torch.bool, device=device)
+ td = TensorDict(
+ batch_size=(batch, T),
+ source={
+ "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "next": {
+ "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "done": done,
+ "terminated": terminated,
+ "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ },
+ "collector": {"mask": mask},
+ "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ },
+ names=[None, "time"],
+ device=device,
+ )
+ return td
+
+ @pytest.mark.parametrize("delay_value", (True, False))
+ @pytest.mark.parametrize("delay_actor", (True, False))
+ @pytest.mark.parametrize("delay_qvalue", (True, False))
+ @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+ def test_sac(
+ self,
+ delay_value,
+ delay_actor,
+ delay_qvalue,
+ num_qvalue,
+ device,
+ version,
+ td_est,
+ ):
+ if (delay_actor or delay_qvalue) and not delay_value:
+ pytest.skip("incompatible config")
+
+ torch.manual_seed(self.seed)
+ td = self._create_mock_data_sac(device=device)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+ if version == 1:
+ value = self._create_mock_value(device=device)
+ else:
+ value = None
+
+ kwargs = {}
+ if delay_actor:
+ kwargs["delay_actor"] = True
+ if delay_qvalue:
+ kwargs["delay_qvalue"] = True
+ if delay_value:
+ kwargs["delay_value"] = True
+
+ loss_fn = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ **kwargs,
+ )
+
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
+ with pytest.raises(NotImplementedError):
+ loss_fn.make_value_estimator(td_est)
+ return
+ if td_est is not None:
+ loss_fn.make_value_estimator(td_est)
+
+ with _check_td_steady(td), pytest.warns(
+ UserWarning, match="No target network updater"
+ ):
+ loss = loss_fn(td)
+
+ assert loss_fn.tensor_keys.priority in td.keys()
+
+ # check that losses are independent
+ for k in loss.keys():
+ if not k.startswith("loss"):
+ continue
+ loss[k].sum().backward(retain_graph=True)
+ if k == "loss_actor":
+ if version == 1:
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.value_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ elif k == "loss_value" and version == 1:
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.value_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ elif k == "loss_qvalue":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ if version == 1:
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.value_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ elif k == "loss_alpha":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ if version == 1:
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.value_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ else:
+ raise NotImplementedError(k)
+ loss_fn.zero_grad()
+
+ sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = list(loss_fn.named_parameters())
+ named_buffers = list(loss_fn.named_buffers())
+
+ assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+ assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+ @pytest.mark.parametrize("delay_value", (True, False))
+ @pytest.mark.parametrize("delay_actor", (True, False))
+ @pytest.mark.parametrize("delay_qvalue", (True, False))
+ @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("device", get_default_devices())
+ def test_sac_state_dict(
+ self,
+ delay_value,
+ delay_actor,
+ delay_qvalue,
+ num_qvalue,
+ device,
+ version,
+ ):
+ if (delay_actor or delay_qvalue) and not delay_value:
+ pytest.skip("incompatible config")
+
+ torch.manual_seed(self.seed)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+ if version == 1:
+ value = self._create_mock_value(device=device)
+ else:
+ value = None
+
+ kwargs = {}
+ if delay_actor:
+ kwargs["delay_actor"] = True
+ if delay_qvalue:
+ kwargs["delay_qvalue"] = True
+ if delay_value:
+ kwargs["delay_value"] = True
+
+ loss_fn = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ **kwargs,
+ )
+ sd = loss_fn.state_dict()
+ loss_fn2 = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ **kwargs,
+ )
+ loss_fn2.load_state_dict(sd)
+
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("separate_losses", [False, True])
+ def test_sac_separate_losses(
+ self,
+ device,
+ separate_losses,
+ version,
+ n_act=4,
+ ):
+ torch.manual_seed(self.seed)
+ actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+
+ loss_fn = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)),
+ num_qvalue_nets=1,
+ separate_losses=separate_losses,
+ )
+ with pytest.warns(UserWarning, match="No target network updater has been"):
+ loss = loss_fn(td)
+
+ assert loss_fn.tensor_keys.priority in td.keys()
+
+ # check that losses are independent
+ for k in loss.keys():
+ if not k.startswith("loss"):
+ continue
+ loss[k].sum().backward(retain_graph=True)
+ if k == "loss_actor":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ elif k == "loss_qvalue":
+ common_layers_no = len(list(common.parameters()))
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ if separate_losses:
+ common_layers = itertools.islice(
+ loss_fn.qvalue_network_params.values(True, True),
+ common_layers_no,
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in common_layers
+ )
+ qvalue_layers = itertools.islice(
+ loss_fn.qvalue_network_params.values(True, True),
+ common_layers_no,
+ None,
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in qvalue_layers
+ )
+ else:
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(True, True)
+ )
+ elif k == "loss_alpha":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ else:
+ raise NotImplementedError(k)
+ loss_fn.zero_grad()
+
+ @pytest.mark.parametrize("n", range(1, 4))
+ @pytest.mark.parametrize("delay_value", (True, False))
+ @pytest.mark.parametrize("delay_actor", (True, False))
+ @pytest.mark.parametrize("delay_qvalue", (True, False))
+ @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
+ @pytest.mark.parametrize("device", get_default_devices())
+ def test_sac_batcher(
+ self,
+ n,
+ delay_value,
+ delay_actor,
+ delay_qvalue,
+ num_qvalue,
+ device,
+ version,
+ ):
+ if (delay_actor or delay_qvalue) and not delay_value:
+ pytest.skip("incompatible config")
+ torch.manual_seed(self.seed)
+ td = self._create_seq_mock_data_sac(device=device)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+ if version == 1:
+ value = self._create_mock_value(device=device)
+ else:
+ value = None
+
+ kwargs = {}
+ if delay_actor:
+ kwargs["delay_actor"] = True
+ if delay_qvalue:
+ kwargs["delay_qvalue"] = True
+ if delay_value:
+ kwargs["delay_value"] = True
+
+ loss_fn = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ **kwargs,
+ )
+
+ ms = MultiStep(gamma=0.9, n_steps=n).to(device)
+
+ td_clone = td.clone()
+ ms_td = ms(td_clone)
+
+ torch.manual_seed(0)
+ np.random.seed(0)
+ with pytest.warns(
+ UserWarning,
+ match="No target network updater has been associated with this loss module",
+ ):
+ with _check_td_steady(ms_td):
+ loss_ms = loss_fn(ms_td)
+ assert loss_fn.tensor_keys.priority in ms_td.keys()
+
+ with torch.no_grad():
+ torch.manual_seed(0) # log-prob is computed with a random action
+ np.random.seed(0)
+ loss = loss_fn(td)
+ if n == 1:
+ assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
+ _loss = sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ )
+ _loss_ms = sum(
+ [item for name, item in loss_ms.items() if name.startswith("loss_")]
+ )
+ assert (
+ abs(_loss - _loss_ms) < 1e-3
+ ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
+ else:
+ with pytest.raises(AssertionError):
+ assert_allclose_td(loss, loss_ms)
+ sum(
+ [item for name, item in loss_ms.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = loss_fn.named_parameters()
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+ # Check param update effect on targets
+ target_actor = [
+ p.clone()
+ for p in loss_fn.target_actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ target_qvalue = [
+ p.clone()
+ for p in loss_fn.target_qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ if version == 1:
+ target_value = [
+ p.clone()
+ for p in loss_fn.target_value_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ for p in loss_fn.parameters():
+ if p.requires_grad:
+ p.data += torch.randn_like(p)
+ target_actor2 = [
+ p.clone()
+ for p in loss_fn.target_actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ target_qvalue2 = [
+ p.clone()
+ for p in loss_fn.target_qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ if version == 1:
+ target_value2 = [
+ p.clone()
+ for p in loss_fn.target_value_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ if loss_fn.delay_actor:
+ assert all(
+ (p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)
+ )
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
+ )
+ if loss_fn.delay_qvalue:
+ assert all(
+ (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2)
+ )
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
+ )
+ if version == 1:
+ if loss_fn.delay_value:
+ assert all(
+ (p1 == p2).all() for p1, p2 in zip(target_value, target_value2)
+ )
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_value, target_value2)
+ )
+
+ # check that policy is updated after parameter update
+ parameters = [p.clone() for p in actor.parameters()]
+ for p in loss_fn.parameters():
+ if p.requires_grad:
+ p.data += torch.randn_like(p)
+ assert all(
+ (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())
+ )
+
+ @pytest.mark.parametrize(
+ "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
+ )
+ def test_sac_tensordict_keys(self, td_est, version):
+ td = self._create_mock_data_sac()
+
+ actor = self._create_mock_actor()
+ qvalue = self._create_mock_qvalue()
+ if version == 1:
+ value = self._create_mock_value()
+ else:
+ value = None
+
+ loss_fn = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ num_qvalue_nets=2,
+ loss_function="l2",
+ )
+
+ default_keys = {
+ "priority": "td_error",
+ "value": "state_value",
+ "state_action_value": "state_action_value",
+ "action": "action",
+ "log_prob": "_log_prob",
+ "reward": "reward",
+ "done": "done",
+ "terminated": "terminated",
+ }
+
+ self.tensordict_keys_test(
+ loss_fn,
+ default_keys=default_keys,
+ td_est=td_est,
+ )
+
+ value = self._create_mock_value()
+ loss_fn = SACLoss(
+ actor,
+ value,
+ loss_function="l2",
+ )
+
+ key_mapping = {
+ "value": ("value", "state_value_test"),
+ "reward": ("reward", "reward_test"),
+ "done": ("done", ("done", "test")),
+ "terminated": ("terminated", ("terminated", "test")),
+ }
+ self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)
+
+ @pytest.mark.parametrize("action_key", ["action", "action2"])
+ @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+ @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+ @pytest.mark.parametrize("done_key", ["done", "done2"])
+ @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+ def test_sac_notensordict(
+ self, action_key, observation_key, reward_key, done_key, terminated_key, version
+ ):
+ torch.manual_seed(self.seed)
+ td = self._create_mock_data_sac(
+ action_key=action_key,
+ observation_key=observation_key,
+ reward_key=reward_key,
+ done_key=done_key,
+ terminated_key=terminated_key,
+ )
+
+ actor = self._create_mock_actor(
+ observation_key=observation_key, action_key=action_key
+ )
+ qvalue = self._create_mock_qvalue(
+ observation_key=observation_key,
+ action_key=action_key,
+ out_keys=["state_action_value"],
+ )
+ if version == 1:
+ value = self._create_mock_value(observation_key=observation_key)
+ else:
+ value = None
+
+ loss = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ )
+ loss.set_keys(
+ action=action_key,
+ reward=reward_key,
+ done=done_key,
+ terminated=terminated_key,
+ )
+
+ kwargs = {
+ action_key: td.get(action_key),
+ observation_key: td.get(observation_key),
+ f"next_{reward_key}": td.get(("next", reward_key)),
+ f"next_{done_key}": td.get(("next", done_key)),
+ f"next_{terminated_key}": td.get(("next", terminated_key)),
+ f"next_{observation_key}": td.get(("next", observation_key)),
+ }
+ td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+
+ # setting the seed for each loss so that drawing the random samples from value network
+ # leads to same numbers for both runs
+ torch.manual_seed(self.seed)
+ with pytest.warns(UserWarning, match="No target network updater"):
+ loss_val = loss(**kwargs)
+
+ torch.manual_seed(self.seed)
+
+ SoftUpdate(loss, eps=0.5)
+
+ loss_val_td = loss(td)
+
+ if version == 1:
+ assert len(loss_val) == 6
+ elif version == 2:
+ assert len(loss_val) == 5
+
+ torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
+ torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
+ torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
+ torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
+ torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
+ if version == 1:
+ torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[5])
+ # test select
+ torch.manual_seed(self.seed)
+ loss.select_out_keys("loss_actor", "loss_alpha")
+ if torch.__version__ >= "2.0.0":
+ loss_actor, loss_alpha = loss(**kwargs)
+ else:
+ with pytest.raises(
+ RuntimeError,
+ match="You are likely using tensordict.nn.dispatch with keyword arguments",
+ ):
+ loss_actor, loss_alpha = loss(**kwargs)
+ return
+ assert loss_actor == loss_val_td["loss_actor"]
+ assert loss_alpha == loss_val_td["loss_alpha"]
+
+ def test_state_dict(self, version):
+ if version == 1:
+ pytest.skip("Test not implemented for version 1.")
+ model = torch.nn.Linear(3, 4)
+ actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"])
+ policy = ProbabilisticActor(
+ module=actor_module,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=TanhDelta,
+ )
+ value = ValueOperator(module=model, in_keys=["obs"], out_keys="value")
+
+ loss = SACLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ state = loss.state_dict()
+
+ loss = SACLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ loss.load_state_dict(state)
+
+ # with an access in between
+ loss = SACLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ loss.target_entropy
+ state = loss.state_dict()
+
+ loss = SACLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ loss.load_state_dict(state)
+
+ @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+ def test_sac_reduction(self, reduction, version):
+ torch.manual_seed(self.seed)
+ device = (
+ torch.device("cpu")
+ if torch.cuda.device_count() == 0
+ else torch.device("cuda")
+ )
+ td = self._create_mock_data_sac(device=device)
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+ if version == 1:
+ value = self._create_mock_value(device=device)
+ else:
+ value = None
+ loss_fn = SACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ value_network=value,
+ loss_function="l2",
+ delay_qvalue=False,
+ delay_actor=False,
+ delay_value=False,
+ reduction=reduction,
+ )
+ loss_fn.make_value_estimator()
+ loss = loss_fn(td)
+ if reduction == "none":
+ for key in loss.keys():
+ if key.startswith("loss"):
+ assert loss[key].shape == td.shape
+ else:
+ for key in loss.keys():
+ if not key.startswith("loss"):
+ continue
+ assert loss[key].shape == torch.Size([])
+
+
+@pytest.mark.skipif(
+ not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
+)
+class TestDiscreteSAC(LossModuleTestBase):
+ seed = 0
+
+ def _create_mock_actor(
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ observation_key="observation",
+ action_key="action",
+ ):
+ # Actor
+ action_spec = OneHotDiscreteTensorSpec(action_dim)
+ net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
+ module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
+ actor = ProbabilisticActor(
+ spec=action_spec,
+ module=module,
+ in_keys=["logits"],
+ out_keys=[action_key],
+ distribution_class=OneHotCategorical,
+ return_log_prob=False,
+ )
+ return actor.to(device)
+
+ def _create_mock_qvalue(
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ observation_key="observation",
+ ):
+ class ValueClass(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(obs_dim, action_dim)
+
+ def forward(self, obs):
+ return self.linear(obs)
+
+ module = ValueClass()
+ qvalue = ValueOperator(
+ module=module, in_keys=[observation_key], out_keys=["action_value"]
+ )
+ return qvalue.to(device)
+
+ def _create_mock_distributional_actor(
+ self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5
+ ):
+ raise NotImplementedError
+
+ def _create_mock_data_sac(
+ self,
+ batch=16,
+ obs_dim=3,
+ action_dim=4,
+ atoms=None,
+ device="cpu",
+ observation_key="observation",
+ action_key="action",
+ done_key="done",
+ terminated_key="terminated",
+ reward_key="reward",
+ ):
+ # create a tensordict
+ obs = torch.randn(batch, obs_dim, device=device)
+ next_obs = torch.randn(batch, obs_dim, device=device)
+ if atoms:
+ action_value = torch.randn(batch, atoms, action_dim).softmax(-2)
+ action = (
+ (action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0])
+ .to(torch.long)
+ .to(device)
+ )
+ else:
+ action_value = torch.randn(batch, action_dim, device=device)
+ action = (action_value == action_value.max(-1, True)[0]).to(torch.long)
+ reward = torch.randn(batch, 1, device=device)
+ done = torch.zeros(batch, 1, dtype=torch.bool, device=device)
+ terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device)
+ td = TensorDict(
+ batch_size=(batch,),
+ source={
+ observation_key: obs,
+ "next": {
+ observation_key: next_obs,
+ done_key: done,
+ terminated_key: terminated,
+ reward_key: reward,
+ },
+ action_key: action,
+ },
+ device=device,
+ )
+ return td
+
+ def _create_seq_mock_data_sac(
+ self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
+ ):
+ # create a tensordict
+ total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
+ obs = total_obs[:, :T]
+ next_obs = total_obs[:, 1:]
+ if atoms:
+ action_value = torch.randn(
+ batch, T, atoms, action_dim, device=device
+ ).softmax(-2)
+ action = (
+ action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0]
+ ).to(torch.long)
+ else:
+ action_value = torch.randn(batch, T, action_dim, device=device)
+ action = (action_value == action_value.max(-1, True)[0]).to(torch.long)
+
+ reward = torch.randn(batch, T, 1, device=device)
+ done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+ terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+ mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
+ td = TensorDict(
+ batch_size=(batch, T),
+ source={
+ "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "next": {
+ "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "done": done,
+ "terminated": terminated,
+ "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ },
+ "collector": {"mask": mask},
+ "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ },
+ names=[None, "time"],
+ )
+ return td
+
+ @pytest.mark.parametrize("delay_qvalue", (True, False))
+ @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99])
+ @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0])
+ @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+ def test_discrete_sac(
+ self,
+ delay_qvalue,
+ num_qvalue,
+ device,
+ target_entropy_weight,
+ target_entropy,
+ td_est,
+ ):
+ torch.manual_seed(self.seed)
+ td = self._create_mock_data_sac(device=device)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+
+ kwargs = {}
+ if delay_qvalue:
+ kwargs["delay_qvalue"] = True
+
+ loss_fn = DiscreteSACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_actions=actor.spec["action"].space.n,
+ num_qvalue_nets=num_qvalue,
+ target_entropy_weight=target_entropy_weight,
+ target_entropy=target_entropy,
+ loss_function="l2",
+ action_space="one-hot",
+ **kwargs,
+ )
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
+ with pytest.raises(NotImplementedError):
+ loss_fn.make_value_estimator(td_est)
+ return
+ if td_est is not None:
+ loss_fn.make_value_estimator(td_est)
+
+ with _check_td_steady(td), pytest.warns(
+ UserWarning, match="No target network updater"
+ ):
+ loss = loss_fn(td)
+
+ assert loss_fn.tensor_keys.priority in td.keys()
+
+ # check that losses are independent
+ for k in loss.keys():
+ if not k.startswith("loss"):
+ continue
+ loss[k].sum().backward(retain_graph=True)
+ if k == "loss_actor":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
include_nested=True, leaves_only=True
)
- ]
- for p in loss_fn.parameters():
- if p.requires_grad:
- p.data += torch.randn_like(p)
- target_actor2 = [
- p.clone()
- for p in loss_fn.target_actor_network_params.values(
- include_nested=True, leaves_only=True
- )
- ]
- target_qvalue2 = [
- p.clone()
- for p in loss_fn.target_qvalue_network_params.values(
- include_nested=True, leaves_only=True
)
- ]
- if version == 1:
- target_value2 = [
- p.clone()
- for p in loss_fn.target_value_network_params.values(
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
include_nested=True, leaves_only=True
)
- ]
- if loss_fn.delay_actor:
+ )
+ elif k == "loss_qvalue":
assert all(
- (p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
)
- else:
assert not any(
- (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
)
- if loss_fn.delay_qvalue:
+ elif k == "loss_alpha":
assert all(
- (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2)
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
)
else:
- assert not any(
- (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
- )
- if version == 1:
- if loss_fn.delay_value:
- assert all(
- (p1 == p2).all() for p1, p2 in zip(target_value, target_value2)
- )
- else:
- assert not any(
- (p1 == p2).any() for p1, p2 in zip(target_value, target_value2)
- )
+ raise NotImplementedError(k)
+ loss_fn.zero_grad()
+
+ sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = list(loss_fn.named_parameters())
+ named_buffers = list(loss_fn.named_buffers())
+
+ assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+ assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+ @pytest.mark.parametrize("delay_qvalue", (True, False))
+ @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("target_entropy_weight", [0.5])
+ @pytest.mark.parametrize("target_entropy", ["auto"])
+ def test_discrete_sac_state_dict(
+ self,
+ delay_qvalue,
+ num_qvalue,
+ device,
+ target_entropy_weight,
+ target_entropy,
+ ):
+ torch.manual_seed(self.seed)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+
+ kwargs = {}
+ if delay_qvalue:
+ kwargs["delay_qvalue"] = True
+
+ loss_fn = DiscreteSACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_actions=actor.spec["action"].space.n,
+ num_qvalue_nets=num_qvalue,
+ target_entropy_weight=target_entropy_weight,
+ target_entropy=target_entropy,
+ loss_function="l2",
+ action_space="one-hot",
+ **kwargs,
+ )
+ sd = loss_fn.state_dict()
+ loss_fn2 = DiscreteSACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_actions=actor.spec["action"].space.n,
+ num_qvalue_nets=num_qvalue,
+ target_entropy_weight=target_entropy_weight,
+ target_entropy=target_entropy,
+ loss_function="l2",
+ action_space="one-hot",
+ **kwargs,
+ )
+ loss_fn2.load_state_dict(sd)
+
+ @pytest.mark.parametrize("n", range(1, 4))
+ @pytest.mark.parametrize("delay_qvalue", (True, False))
+ @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99])
+ @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0])
+ def test_discrete_sac_batcher(
+ self,
+ n,
+ delay_qvalue,
+ num_qvalue,
+ device,
+ target_entropy_weight,
+ target_entropy,
+ gamma=0.9,
+ ):
+ torch.manual_seed(self.seed)
+ td = self._create_seq_mock_data_sac(device=device)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+
+ kwargs = {}
+ if delay_qvalue:
+ kwargs["delay_qvalue"] = True
+ loss_fn = DiscreteSACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_actions=actor.spec["action"].space.n,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ target_entropy_weight=target_entropy_weight,
+ target_entropy=target_entropy,
+ action_space="one-hot",
+ **kwargs,
+ )
+
+ ms = MultiStep(gamma=gamma, n_steps=n).to(device)
+
+ td_clone = td.clone()
+ ms_td = ms(td_clone)
+
+ torch.manual_seed(0)
+ np.random.seed(0)
+ with _check_td_steady(ms_td), pytest.warns(
+ UserWarning, match="No target network updater"
+ ):
+ loss_ms = loss_fn(ms_td)
+ assert loss_fn.tensor_keys.priority in ms_td.keys()
+
+ SoftUpdate(loss_fn, eps=0.5)
+
+ with torch.no_grad():
+ torch.manual_seed(0) # log-prob is computed with a random action
+ np.random.seed(0)
+ loss = loss_fn(td)
+ if n == 1:
+ assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
+ _loss = sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ )
+ _loss_ms = sum(
+ [item for name, item in loss_ms.items() if name.startswith("loss_")]
+ )
+ assert (
+ abs(_loss - _loss_ms) < 1e-3
+            ), f"found abs(loss-loss_ms) = {abs(_loss - _loss_ms):4.5f} for n=1"
+ else:
+ with pytest.raises(AssertionError):
+ assert_allclose_td(loss, loss_ms)
+ sum(
+ [item for name, item in loss_ms.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = loss_fn.named_parameters()
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
- # check that policy is updated after parameter update
- parameters = [p.clone() for p in actor.parameters()]
- for p in loss_fn.parameters():
- if p.requires_grad:
- p.data += torch.randn_like(p)
+ # Check param update effect on targets
+ target_actor = [
+ p.clone()
+ for p in loss_fn.target_actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ target_qvalue = [
+ p.clone()
+ for p in loss_fn.target_qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ for p in loss_fn.parameters():
+ if p.requires_grad:
+ p.data += torch.randn_like(p)
+ target_actor2 = [
+ p.clone()
+ for p in loss_fn.target_actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ target_qvalue2 = [
+ p.clone()
+ for p in loss_fn.target_qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ ]
+ if loss_fn.delay_actor:
+ assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2))
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
+ )
+ if loss_fn.delay_qvalue:
assert all(
- (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())
+ (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2)
+ )
+ else:
+ assert not any(
+ (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
)
+ # check that policy is updated after parameter update
+ parameters = [p.clone() for p in actor.parameters()]
+ for p in loss_fn.parameters():
+ if p.requires_grad:
+ p.data += torch.randn_like(p)
+ assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))
+
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
- def test_sac_tensordict_keys(self, td_est, version):
- td = self._create_mock_data_sac()
-
+ def test_discrete_sac_tensordict_keys(self, td_est):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
- if version == 1:
- value = self._create_mock_value()
- else:
- value = None
- loss_fn = SACLoss(
+ loss_fn = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
- value_network=value,
- num_qvalue_nets=2,
+ num_actions=actor.spec["action"].space.n,
loss_function="l2",
+ action_space="one-hot",
)
default_keys = {
"priority": "td_error",
"value": "state_value",
- "state_action_value": "state_action_value",
"action": "action",
- "log_prob": "_log_prob",
"reward": "reward",
"done": "done",
"terminated": "terminated",
}
-
self.tensordict_keys_test(
loss_fn,
default_keys=default_keys,
td_est=td_est,
)
- value = self._create_mock_value()
- loss_fn = SACLoss(
- actor,
- value,
+ qvalue = self._create_mock_qvalue()
+ loss_fn = DiscreteSACLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_actions=actor.spec["action"].space.n,
loss_function="l2",
+ action_space="one-hot",
)
key_mapping = {
@@ -3460,8 +4842,8 @@ def test_sac_tensordict_keys(self, td_est, version):
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
- def test_sac_notensordict(
- self, action_key, observation_key, reward_key, done_key, terminated_key, version
+ def test_discrete_sac_notensordict(
+ self, action_key, observation_key, reward_key, done_key, terminated_key
):
torch.manual_seed(self.seed)
td = self._create_mock_data_sac(
@@ -3477,18 +4859,13 @@ def test_sac_notensordict(
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
- action_key=action_key,
- out_keys=["state_action_value"],
)
- if version == 1:
- value = self._create_mock_value(observation_key=observation_key)
- else:
- value = None
- loss = SACLoss(
+ loss = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
- value_network=value,
+ num_actions=actor.spec[action_key].space.n,
+ action_space="one-hot",
)
loss.set_keys(
action=action_key,
@@ -3507,90 +4884,32 @@ def test_sac_notensordict(
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
- # setting the seed for each loss so that drawing the random samples from value network
- # leads to same numbers for both runs
- torch.manual_seed(self.seed)
- with pytest.warns(UserWarning, match="No target network updater"):
+ with pytest.warns(UserWarning, match="No target network updater has been"):
loss_val = loss(**kwargs)
+ loss_val_td = loss(td)
- torch.manual_seed(self.seed)
-
- SoftUpdate(loss, eps=0.5)
-
- loss_val_td = loss(td)
-
- if version == 1:
- assert len(loss_val) == 6
- elif version == 2:
- assert len(loss_val) == 5
-
- torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
- torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
- torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
- torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
- torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
- if version == 1:
- torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[5])
- # test select
- torch.manual_seed(self.seed)
- loss.select_out_keys("loss_actor", "loss_alpha")
- if torch.__version__ >= "2.0.0":
- loss_actor, loss_alpha = loss(**kwargs)
- else:
- with pytest.raises(
- RuntimeError,
- match="You are likely using tensordict.nn.dispatch with keyword arguments",
- ):
+ torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
+ torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
+ torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
+ torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
+ torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
+ # test select
+ torch.manual_seed(self.seed)
+ loss.select_out_keys("loss_actor", "loss_alpha")
+ if torch.__version__ >= "2.0.0":
loss_actor, loss_alpha = loss(**kwargs)
- return
- assert loss_actor == loss_val_td["loss_actor"]
- assert loss_alpha == loss_val_td["loss_alpha"]
-
- def test_state_dict(self, version):
- if version == 1:
- pytest.skip("Test not implemented for version 1.")
- model = torch.nn.Linear(3, 4)
- actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"])
- policy = ProbabilisticActor(
- module=actor_module,
- in_keys=["logits"],
- out_keys=["action"],
- distribution_class=TanhDelta,
- )
- value = ValueOperator(module=model, in_keys=["obs"], out_keys="value")
-
- loss = SACLoss(
- actor_network=policy,
- qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
- )
- state = loss.state_dict()
-
- loss = SACLoss(
- actor_network=policy,
- qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
- )
- loss.load_state_dict(state)
-
- # with an access in between
- loss = SACLoss(
- actor_network=policy,
- qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
- )
- loss.target_entropy
- state = loss.state_dict()
-
- loss = SACLoss(
- actor_network=policy,
- qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
- )
- loss.load_state_dict(state)
+ else:
+ with pytest.raises(
+ RuntimeError,
+ match="You are likely using tensordict.nn.dispatch with keyword arguments",
+ ):
+ loss_actor, loss_alpha = loss(**kwargs)
+ return
+ assert loss_actor == loss_val_td["loss_actor"]
+ assert loss_alpha == loss_val_td["loss_alpha"]
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
- def test_sac_reduction(self, reduction, version):
+ def test_discrete_sac_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
@@ -3600,18 +4919,13 @@ def test_sac_reduction(self, reduction, version):
td = self._create_mock_data_sac(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
- if version == 1:
- value = self._create_mock_value(device=device)
- else:
- value = None
- loss_fn = SACLoss(
+ loss_fn = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
- value_network=value,
+ num_actions=actor.spec["action"].space.n,
loss_function="l2",
+ action_space="one-hot",
delay_qvalue=False,
- delay_actor=False,
- delay_value=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
@@ -3627,10 +4941,7 @@ def test_sac_reduction(self, reduction, version):
assert loss[key].shape == torch.Size([])
-@pytest.mark.skipif(
- not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
-)
-class TestDiscreteSAC(LossModuleTestBase):
+class TestCrossQ(LossModuleTestBase):
seed = 0
def _create_mock_actor(
@@ -3643,16 +4954,19 @@ def _create_mock_actor(
action_key="action",
):
# Actor
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = BoundedTensorSpec(
+ -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
+ )
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
- module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
+ module = TensorDictModule(
+ net, in_keys=[observation_key], out_keys=["loc", "scale"]
+ )
actor = ProbabilisticActor(
- spec=action_spec,
module=module,
- in_keys=["logits"],
+ in_keys=["loc", "scale"],
+ spec=action_spec,
+ distribution_class=TanhNormal,
out_keys=[action_key],
- distribution_class=OneHotCategorical,
- return_log_prob=False,
)
return actor.to(device)
@@ -3663,27 +4977,85 @@ def _create_mock_qvalue(
action_dim=4,
device="cpu",
observation_key="observation",
+ action_key="action",
+ out_keys=None,
):
class ValueClass(nn.Module):
def __init__(self):
super().__init__()
- self.linear = nn.Linear(obs_dim, action_dim)
+ self.linear = nn.Linear(obs_dim + action_dim, 1)
- def forward(self, obs):
- return self.linear(obs)
+ def forward(self, obs, act):
+ return self.linear(torch.cat([obs, act], -1))
module = ValueClass()
qvalue = ValueOperator(
- module=module, in_keys=[observation_key], out_keys=["action_value"]
+ module=module,
+ in_keys=[observation_key, action_key],
+ out_keys=out_keys,
)
return qvalue.to(device)
+ def _create_mock_common_layer_setup(
+ self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2
+ ):
+ common = MLP(
+ num_cells=ncells,
+ in_features=n_obs,
+ depth=3,
+ out_features=n_hidden,
+ )
+ actor_net = MLP(
+ num_cells=ncells,
+ in_features=n_hidden,
+ depth=1,
+ out_features=2 * n_act,
+ )
+ qvalue = MLP(
+ in_features=n_hidden + n_act,
+ num_cells=ncells,
+ depth=1,
+ out_features=1,
+ )
+ batch = [batch]
+ td = TensorDict(
+ {
+ "obs": torch.randn(*batch, n_obs),
+ "action": torch.randn(*batch, n_act),
+ "done": torch.zeros(*batch, 1, dtype=torch.bool),
+ "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
+ "next": {
+ "obs": torch.randn(*batch, n_obs),
+ "reward": torch.randn(*batch, 1),
+ "done": torch.zeros(*batch, 1, dtype=torch.bool),
+ "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
+ },
+ },
+ batch,
+ )
+ common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+ actor = ProbSeq(
+ common,
+ Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+ Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+ ProbMod(
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ ),
+ )
+ qvalue_head = Mod(
+ qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+ )
+ qvalue = Seq(common, qvalue_head)
+ return actor, qvalue, common, td
+
def _create_mock_distributional_actor(
self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5
):
raise NotImplementedError
- def _create_mock_data_sac(
+ def _create_mock_data_crossq(
self,
batch=16,
obs_dim=3,
@@ -3700,15 +5072,9 @@ def _create_mock_data_sac(
obs = torch.randn(batch, obs_dim, device=device)
next_obs = torch.randn(batch, obs_dim, device=device)
if atoms:
- action_value = torch.randn(batch, atoms, action_dim).softmax(-2)
- action = (
- (action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0])
- .to(torch.long)
- .to(device)
- )
+ raise NotImplementedError
else:
- action_value = torch.randn(batch, action_dim, device=device)
- action = (action_value == action_value.max(-1, True)[0]).to(torch.long)
+ action = torch.randn(batch, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, 1, device=device)
done = torch.zeros(batch, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device)
@@ -3728,7 +5094,7 @@ def _create_mock_data_sac(
)
return td
- def _create_seq_mock_data_sac(
+ def _create_seq_mock_data_crossq(
self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
):
# create a tensordict
@@ -3736,20 +5102,15 @@ def _create_seq_mock_data_sac(
obs = total_obs[:, :T]
next_obs = total_obs[:, 1:]
if atoms:
- action_value = torch.randn(
- batch, T, atoms, action_dim, device=device
- ).softmax(-2)
- action = (
- action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0]
- ).to(torch.long)
+ action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
+ -1, 1
+ )
else:
- action_value = torch.randn(batch, T, action_dim, device=device)
- action = (action_value == action_value.max(-1, True)[0]).to(torch.long)
-
+ action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
- mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
+ mask = torch.ones(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
@@ -3762,48 +5123,33 @@ def _create_seq_mock_data_sac(
},
"collector": {"mask": mask},
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
- "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
names=[None, "time"],
+ device=device,
)
return td
- @pytest.mark.parametrize("delay_qvalue", (True, False))
- @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
- @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99])
- @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0])
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
- def test_discrete_sac(
+ def test_crossq(
self,
- delay_qvalue,
num_qvalue,
device,
- target_entropy_weight,
- target_entropy,
td_est,
):
torch.manual_seed(self.seed)
- td = self._create_mock_data_sac(device=device)
-
+ td = self._create_mock_data_crossq(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
- kwargs = {}
- if delay_qvalue:
- kwargs["delay_qvalue"] = True
-
- loss_fn = DiscreteSACLoss(
+ loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
num_qvalue_nets=num_qvalue,
- target_entropy_weight=target_entropy_weight,
- target_entropy=target_entropy,
loss_function="l2",
- action_space="one-hot",
- **kwargs,
)
+
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
@@ -3811,9 +5157,7 @@ def test_discrete_sac(
if td_est is not None:
loss_fn.make_value_estimator(td_est)
- with _check_td_steady(td), pytest.warns(
- UserWarning, match="No target network updater"
- ):
+ with _check_td_steady(td):
loss = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys()
@@ -3842,13 +5186,145 @@ def test_discrete_sac(
for p in loss_fn.actor_network_params.values(
include_nested=True, leaves_only=True
)
- )
- assert not any(
- (p.grad is None) or (p.grad == 0).all()
- for p in loss_fn.qvalue_network_params.values(
- include_nested=True, leaves_only=True
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ elif k == "loss_alpha":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ else:
+ raise NotImplementedError(k)
+ loss_fn.zero_grad()
+
+ sum(
+ [item for name, item in loss.items() if name.startswith("loss_")]
+ ).backward()
+ named_parameters = list(loss_fn.named_parameters())
+ named_buffers = list(loss_fn.named_buffers())
+
+ assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+ assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+ for name, p in named_parameters:
+ if not name.startswith("target_"):
+ assert (
+ p.grad is not None and p.grad.norm() > 0.0
+ ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+ else:
+ assert (
+ p.grad is None or p.grad.norm() == 0.0
+ ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+ @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("device", get_default_devices())
+ def test_crossq_state_dict(
+ self,
+ num_qvalue,
+ device,
+ ):
+ torch.manual_seed(self.seed)
+
+ actor = self._create_mock_actor(device=device)
+ qvalue = self._create_mock_qvalue(device=device)
+
+ loss_fn = CrossQLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ )
+ sd = loss_fn.state_dict()
+ loss_fn2 = CrossQLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ num_qvalue_nets=num_qvalue,
+ loss_function="l2",
+ )
+ loss_fn2.load_state_dict(sd)
+
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("separate_losses", [False, True])
+ def test_crossq_separate_losses(
+ self,
+ separate_losses,
+ device,
+ ):
+ n_act = 4
+ torch.manual_seed(self.seed)
+ actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+
+ loss_fn = CrossQLoss(
+ actor_network=actor,
+ qvalue_network=qvalue,
+ action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)),
+ num_qvalue_nets=1,
+ separate_losses=separate_losses,
+ )
+ loss = loss_fn(td)
+
+ assert loss_fn.tensor_keys.priority in td.keys()
+
+ # check that losses are independent
+ for k in loss.keys():
+ if not k.startswith("loss"):
+ continue
+ loss[k].sum().backward(retain_graph=True)
+ if k == "loss_actor":
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ elif k == "loss_qvalue":
+ common_layers_no = len(list(common.parameters()))
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ if separate_losses:
+ common_layers = itertools.islice(
+ loss_fn.qvalue_network_params.values(True, True),
+ common_layers_no,
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all() for p in common_layers
+ )
+ qvalue_layers = itertools.islice(
+ loss_fn.qvalue_network_params.values(True, True),
+ common_layers_no,
+ None,
+ )
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all() for p in qvalue_layers
+ )
+ else:
+ assert not any(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(True, True)
)
- )
elif k == "loss_alpha":
assert all(
(p.grad is None) or (p.grad == 0).all()
@@ -3866,124 +5342,40 @@ def test_discrete_sac(
raise NotImplementedError(k)
loss_fn.zero_grad()
- sum(
- [item for name, item in loss.items() if name.startswith("loss_")]
- ).backward()
- named_parameters = list(loss_fn.named_parameters())
- named_buffers = list(loss_fn.named_buffers())
-
- assert len({p for n, p in named_parameters}) == len(list(named_parameters))
- assert len({p for n, p in named_buffers}) == len(list(named_buffers))
-
- for name, p in named_parameters:
- if not name.startswith("target_"):
- assert (
- p.grad is not None and p.grad.norm() > 0.0
- ), f"parameter {name} (shape: {p.shape}) has a null gradient"
- else:
- assert (
- p.grad is None or p.grad.norm() == 0.0
- ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
-
- @pytest.mark.parametrize("delay_qvalue", (True, False))
- @pytest.mark.parametrize("num_qvalue", [2])
- @pytest.mark.parametrize("device", get_default_devices())
- @pytest.mark.parametrize("target_entropy_weight", [0.5])
- @pytest.mark.parametrize("target_entropy", ["auto"])
- def test_discrete_sac_state_dict(
- self,
- delay_qvalue,
- num_qvalue,
- device,
- target_entropy_weight,
- target_entropy,
- ):
- torch.manual_seed(self.seed)
-
- actor = self._create_mock_actor(device=device)
- qvalue = self._create_mock_qvalue(device=device)
-
- kwargs = {}
- if delay_qvalue:
- kwargs["delay_qvalue"] = True
-
- loss_fn = DiscreteSACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
- num_qvalue_nets=num_qvalue,
- target_entropy_weight=target_entropy_weight,
- target_entropy=target_entropy,
- loss_function="l2",
- action_space="one-hot",
- **kwargs,
- )
- sd = loss_fn.state_dict()
- loss_fn2 = DiscreteSACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
- num_qvalue_nets=num_qvalue,
- target_entropy_weight=target_entropy_weight,
- target_entropy=target_entropy,
- loss_function="l2",
- action_space="one-hot",
- **kwargs,
- )
- loss_fn2.load_state_dict(sd)
-
@pytest.mark.parametrize("n", range(1, 4))
- @pytest.mark.parametrize("delay_qvalue", (True, False))
- @pytest.mark.parametrize("num_qvalue", [2])
+ @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
- @pytest.mark.parametrize("target_entropy_weight", [0.01, 0.5, 0.99])
- @pytest.mark.parametrize("target_entropy", ["auto", 1.0, 0.1, 0.0])
- def test_discrete_sac_batcher(
+ def test_crossq_batcher(
self,
n,
- delay_qvalue,
num_qvalue,
device,
- target_entropy_weight,
- target_entropy,
- gamma=0.9,
):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_sac(device=device)
+ td = self._create_seq_mock_data_crossq(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
- kwargs = {}
- if delay_qvalue:
- kwargs["delay_qvalue"] = True
- loss_fn = DiscreteSACLoss(
+ loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
num_qvalue_nets=num_qvalue,
loss_function="l2",
- target_entropy_weight=target_entropy_weight,
- target_entropy=target_entropy,
- action_space="one-hot",
- **kwargs,
)
- ms = MultiStep(gamma=gamma, n_steps=n).to(device)
+ ms = MultiStep(gamma=0.9, n_steps=n).to(device)
td_clone = td.clone()
ms_td = ms(td_clone)
torch.manual_seed(0)
np.random.seed(0)
- with _check_td_steady(ms_td), pytest.warns(
- UserWarning, match="No target network updater"
- ):
+
+ with _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()
- SoftUpdate(loss_fn, eps=0.5)
-
with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
@@ -4023,12 +5415,6 @@ def test_discrete_sac_batcher(
include_nested=True, leaves_only=True
)
]
- target_qvalue = [
- p.clone()
- for p in loss_fn.target_qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- ]
for p in loss_fn.parameters():
if p.requires_grad:
p.data += torch.randn_like(p)
@@ -4038,26 +5424,8 @@ def test_discrete_sac_batcher(
include_nested=True, leaves_only=True
)
]
- target_qvalue2 = [
- p.clone()
- for p in loss_fn.target_qvalue_network_params.values(
- include_nested=True, leaves_only=True
- )
- ]
- if loss_fn.delay_actor:
- assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2))
- else:
- assert not any(
- (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
- )
- if loss_fn.delay_qvalue:
- assert all(
- (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2)
- )
- else:
- assert not any(
- (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
- )
+
+ assert not any((p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2))
# check that policy is updated after parameter update
parameters = [p.clone() for p in actor.parameters()]
@@ -4069,26 +5437,29 @@ def test_discrete_sac_batcher(
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
- def test_discrete_sac_tensordict_keys(self, td_est):
+ def test_crossq_tensordict_keys(self, td_est):
+
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
+ value = None
- loss_fn = DiscreteSACLoss(
+ loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
+ num_qvalue_nets=2,
loss_function="l2",
- action_space="one-hot",
)
default_keys = {
"priority": "td_error",
- "value": "state_value",
+ "state_action_value": "state_action_value",
"action": "action",
+ "log_prob": "_log_prob",
"reward": "reward",
"done": "done",
"terminated": "terminated",
}
+
self.tensordict_keys_test(
loss_fn,
default_keys=default_keys,
@@ -4096,16 +5467,13 @@ def test_discrete_sac_tensordict_keys(self, td_est):
)
qvalue = self._create_mock_qvalue()
- loss_fn = DiscreteSACLoss(
- actor_network=actor,
- qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
+ loss_fn = CrossQLoss(
+ actor,
+ qvalue,
loss_function="l2",
- action_space="one-hot",
)
key_mapping = {
- "value": ("value", "state_value_test"),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
"terminated": ("terminated", ("terminated", "test")),
@@ -4117,11 +5485,11 @@ def test_discrete_sac_tensordict_keys(self, td_est):
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
- def test_discrete_sac_notensordict(
+ def test_crossq_notensordict(
self, action_key, observation_key, reward_key, done_key, terminated_key
):
torch.manual_seed(self.seed)
- td = self._create_mock_data_sac(
+ td = self._create_mock_data_crossq(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
@@ -4134,13 +5502,13 @@ def test_discrete_sac_notensordict(
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
+ action_key=action_key,
+ out_keys=["state_action_value"],
)
- loss = DiscreteSACLoss(
+ loss = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
- num_actions=actor.spec[action_key].space.n,
- action_space="one-hot",
)
loss.set_keys(
action=action_key,
@@ -4159,48 +5527,97 @@ def test_discrete_sac_notensordict(
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
- with pytest.warns(UserWarning, match="No target network updater has been"):
- loss_val = loss(**kwargs)
- loss_val_td = loss(td)
+ # setting the seed for each loss so that drawing the random samples from value network
+ # leads to same numbers for both runs
+ torch.manual_seed(self.seed)
+ loss_val = loss(**kwargs)
- torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
- torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
- torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
- torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
- torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
- # test select
- torch.manual_seed(self.seed)
- loss.select_out_keys("loss_actor", "loss_alpha")
- if torch.__version__ >= "2.0.0":
+ torch.manual_seed(self.seed)
+
+ loss_val_td = loss(td)
+ assert len(loss_val) == 5
+
+ torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
+ torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
+ torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
+ torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
+ torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
+
+ # test select
+ torch.manual_seed(self.seed)
+ loss.select_out_keys("loss_actor", "loss_alpha")
+ if torch.__version__ >= "2.0.0":
+ loss_actor, loss_alpha = loss(**kwargs)
+ else:
+ with pytest.raises(
+ RuntimeError,
+ match="You are likely using tensordict.nn.dispatch with keyword arguments",
+ ):
loss_actor, loss_alpha = loss(**kwargs)
- else:
- with pytest.raises(
- RuntimeError,
- match="You are likely using tensordict.nn.dispatch with keyword arguments",
- ):
- loss_actor, loss_alpha = loss(**kwargs)
- return
- assert loss_actor == loss_val_td["loss_actor"]
- assert loss_alpha == loss_val_td["loss_alpha"]
+ return
+ assert loss_actor == loss_val_td["loss_actor"]
+ assert loss_alpha == loss_val_td["loss_alpha"]
+
+ def test_state_dict(
+ self,
+ ):
+
+ model = torch.nn.Linear(3, 4)
+ actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"])
+ policy = ProbabilisticActor(
+ module=actor_module,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=TanhDelta,
+ )
+ value = ValueOperator(module=model, in_keys=["obs"], out_keys="value")
+
+ loss = CrossQLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ state = loss.state_dict()
+
+ loss = CrossQLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ loss.load_state_dict(state)
+
+ # with an access in between
+ loss = CrossQLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ loss.target_entropy
+ state = loss.state_dict()
+
+ loss = CrossQLoss(
+ actor_network=policy,
+ qvalue_network=value,
+ action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ )
+ loss.load_state_dict(state)
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
- def test_discrete_sac_reduction(self, reduction):
+ def test_crossq_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
- td = self._create_mock_data_sac(device=device)
+ td = self._create_mock_data_crossq(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
- loss_fn = DiscreteSACLoss(
+
+ loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
- num_actions=actor.spec["action"].space.n,
loss_function="l2",
- action_space="one-hot",
- delay_qvalue=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
@@ -5686,9 +7103,9 @@ def _create_mock_actor(
spec=CompositeSpec(
{
"action": action_spec,
- "action_value"
- if action_value_key is None
- else action_value_key: None,
+ (
+ "action_value" if action_value_key is None else action_value_key
+ ): None,
"chosen_action_value": None,
},
shape=[],
diff --git a/test/test_env.py b/test/test_env.py
index e6ca38b729c..f8f242f3955 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -2061,6 +2061,7 @@ def main_collector(j, q=None):
total_frames=N * n_workers * 100,
storing_device=device,
device=device,
+ cat_results=-1,
)
single_collectors = [
SyncDataCollector(
diff --git a/test/test_modules.py b/test/test_modules.py
index 59adbea653d..592464f0a96 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -34,7 +34,14 @@
VDNMixer,
)
from torchrl.modules.distributions.utils import safeatanh, safetanh
-from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models import (
+ BatchRenorm1d,
+ Conv3dNet,
+ ConvNet,
+ MLP,
+ NoisyLazyLinear,
+ NoisyLinear,
+)
from torchrl.modules.models.decision_transformer import (
_has_transformers,
DecisionTransformer,
@@ -1438,6 +1445,40 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers):
torch.testing.assert_close(h1, h2)
+class TestBatchRenorm:
+ @pytest.mark.parametrize("num_steps", [0, 5])
+ @pytest.mark.parametrize("smooth", [False, True])
+ def test_batchrenorm(self, num_steps, smooth):
+ torch.manual_seed(0)
+ bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5)
+ brn = BatchRenorm1d(
+ 5,
+ momentum=0.1,
+ eps=1e-5,
+ warmup_steps=num_steps,
+ max_d=10000,
+ max_r=10000,
+ smooth=smooth,
+ )
+ bn.train()
+ brn.train()
+ data_train = torch.randn(100, 5).split(25)
+ data_test = torch.randn(100, 5)
+ for i, d in enumerate(data_train):
+ b = bn(d)
+ a = brn(d)
+ if num_steps > 0 and (
+ (i < num_steps and not smooth) or (i == 0 and smooth)
+ ):
+ torch.testing.assert_close(a, b)
+ else:
+ assert not torch.isclose(a, b).all(), i
+
+ bn.eval()
+ brn.eval()
+ torch.testing.assert_close(bn(data_test), brn(data_test))
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 50e3dd5cc49..32294a25edd 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -2065,18 +2065,18 @@ def _queue_len(self) -> int:
def iterator(self) -> Iterator[TensorDictBase]:
cat_results = self.cat_results
if cat_results is None:
- cat_results = 0
+ cat_results = "stack"
warnings.warn(
f"`cat_results` was not specified in the constructor of {type(self).__name__}. "
f"For MultiSyncDataCollector, `cat_results` indicates how the data should "
- f"be packed: the preferred option is `cat_results='stack'` which provides "
- f"the best interoperability across torchrl components. "
+ f"be packed: the preferred option and current default is `cat_results='stack'` "
+ f"which provides the best interoperability across torchrl components. "
f"Other accepted values are `cat_results=0` (previous behaviour) and "
f"`cat_results=-1` (cat along time dimension). Among these two, the latter "
f"should be preferred for consistency across environment configurations. "
- f"Currently, the default value is `0` (using torch.cat along first dimension)."
- f"From v0.5 onward, this will default to `'stack'`. "
- f"To suppress this warning, set stack_results to the desired value.",
+                f"Currently, the default value is `'stack'`. "
+ f"From v0.6 onward, this warning will be removed. "
+ f"To suppress this warning, set `cat_results` to the desired value.",
category=DeprecationWarning,
)
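
A minimal, hedged sketch of what the `cat_results` options described in the warning above mean in practice; the environment, policy and frame counts are illustrative, and only `MultiSyncDataCollector` and `cat_results` come from this patch:

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs import GymEnv

    collector = MultiSyncDataCollector(
        [lambda: GymEnv("Pendulum-v1"), lambda: GymEnv("Pendulum-v1")],
        policy=None,                 # None resolves to a random policy over the action spec
        frames_per_batch=64,
        total_frames=128,
        cat_results="stack",         # new default: worker batches are stacked along a new dim
        # cat_results=0  -> torch.cat along the first (worker) dimension (previous behaviour)
        # cat_results=-1 -> torch.cat along the time dimension
    )
    for data in collector:
        print(data.shape)            # e.g. torch.Size([2, 32]) with "stack" and 2 workers
    collector.shutdown()
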
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 04c24cb8d57..0006213cd27 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1143,6 +1143,7 @@ def __eq__(self, other):
if not isinstance(other, LazyStackedTensorSpec):
return False
if self.device != other.device:
+ raise RuntimeError((self, other))
return False
if len(self._specs) != len(other._specs):
return False
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 7f462782757..4996e527527 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -26,8 +26,8 @@
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
+ unravel_key,
)
-from tensordict._tensordict import unravel_key
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
@@ -406,17 +406,16 @@ def _find_sync_values(self):
return _do_nothing, _do_nothing
if worker_device is None:
- worker_not_main = [False]
+ worker_not_main = False
- def find_all_worker_devices(item, worker_not_main=worker_not_main):
+ def find_all_worker_devices(item):
+ nonlocal worker_not_main
if hasattr(item, "device"):
- worker_not_main[0] = worker_not_main[0] or (
- item.device != self_device
- )
+ worker_not_main = worker_not_main or (item.device != self_device)
for td in self.shared_tensordicts:
td.apply(find_all_worker_devices, filter_empty=True)
- if worker_not_main[0]:
+ if worker_not_main:
if torch.cuda.is_available():
worker_device = (
torch.device("cuda")
@@ -431,6 +430,8 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main):
)
else:
raise RuntimeError("Did not find a valid worker device")
+ else:
+ worker_device = self_device
if (
worker_device is not None
@@ -460,6 +461,7 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main):
and self_device.type == "mps"
):
return _mps_sync(self_device), _mps_sync(self_device)
+ return _do_nothing, _do_nothing
def __getstate__(self):
out = copy(self.__dict__)
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index c965e7dedf3..e30de3534d9 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -15,7 +15,6 @@
import torch
import torch.nn as nn
from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
-from tensordict.base import NO_DEFAULT
from tensordict.utils import NestedKey
from torchrl._utils import (
_ends_with,
@@ -3020,21 +3019,11 @@ class _EnvWrapper(EnvBase):
def __init__(
self,
*args,
- device: DEVICE_TYPING = NO_DEFAULT,
+ device: DEVICE_TYPING = None,
batch_size: Optional[torch.Size] = None,
allow_done_after_reset: bool = False,
**kwargs,
):
- if device is NO_DEFAULT:
- warnings.warn(
- "Your wrapper was not given a device. Currently, this "
- "value will default to 'cpu'. From v0.5 it will "
- "default to `None`. With a device of None, no device casting "
- "is performed and the resulting tensordicts are deviceless. "
- "Please set your device accordingly.",
- category=DeprecationWarning,
- )
- device = torch.device("cpu")
super().__init__(
device=device,
batch_size=batch_size,
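
A brief, hedged illustration of the behaviour the removed warning used to describe: with the new `device=None` default, no device casting is performed and the output tensordicts are deviceless (`GymEnv` and the environment name are examples only):

    from torchrl.envs import GymEnv

    env = GymEnv("Pendulum-v1")      # no device passed -> device stays None
    td = env.reset()
    print(td.device)                 # None: the output tensordict is deviceless
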
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index 47f93f09779..c7935272c91 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -348,8 +348,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
batch_size=tensordict.batch_size,
)
if self.device is not None:
- tensordict_out = tensordict_out.to(self.device, non_blocking=True)
- self._sync_device()
+ tensordict_out = tensordict_out.to(self.device)
if self.info_dict_reader and (info_dict is not None):
if not isinstance(info_dict, dict):
@@ -393,8 +392,7 @@ def _reset(
if key not in tensordict_out.keys(True, True):
tensordict_out[key] = item.zero()
if self.device is not None:
- tensordict_out = tensordict_out.to(self.device, non_blocking=True)
- self._sync_device()
+ tensordict_out = tensordict_out.to(self.device)
return tensordict_out
@abc.abstractmethod
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index 07c48587c14..9195929e31d 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -27,7 +27,6 @@
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
- LazyStackedTensorSpec,
MultiDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
@@ -246,8 +245,8 @@ def _gym_to_torchrl_spec_transform(
).expand(batch_size)
gym_spaces = gym_backend("spaces")
if isinstance(spec, gym_spaces.tuple.Tuple):
- result = LazyStackedTensorSpec(
- *[
+ result = torch.stack(
+ [
_gym_to_torchrl_spec_transform(
s,
device=device,
diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py
index 8e30fdb2a7e..9751e84a3ac 100644
--- a/torchrl/envs/libs/vmas.py
+++ b/torchrl/envs/libs/vmas.py
@@ -795,7 +795,9 @@ def _build_env(
env=vmas.make_env(
scenario=scenario,
num_envs=num_envs,
- device=self.device,
+ device=self.device
+ if self.device is not None
+ else torch.get_default_device(),
continuous_actions=continuous_actions,
max_steps=max_steps,
seed=seed,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index bec76c603e6..70aef03e041 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -39,7 +39,7 @@
unravel_key,
unravel_key_list,
)
-from tensordict._tensordict import _unravel_key_to_tuple
+from tensordict._C import _unravel_key_to_tuple
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import expand_as_right, expand_right, NestedKey
from torch import nn, Tensor
@@ -3411,14 +3411,7 @@ def __init__(
out_keys_inv: Sequence[NestedKey] | None = None,
):
if in_keys is not None and in_keys_inv is None:
- warnings.warn(
- "in_keys have been provided but not in_keys_inv. From v0.5, "
- "this will result in in_keys_inv being an empty list whereas "
- "now the input keys are retrieved automatically. "
- "To silence this warning, pass the (possibly empty) "
- "list of in_keys_inv.",
- category=DeprecationWarning,
- )
+ in_keys_inv = []
self.dtype_in = dtype_in
self.dtype_out = dtype_out
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index 087cabe4186..38d8d1dfd02 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -481,9 +481,10 @@ def root_dist(self):
@property
def mode(self):
warnings.warn(
- "This computation of the mode is based on the first-order Taylor expansion "
- "of the transform around the normal mean value, which can be inaccurate. "
+            "This computation of the mode is based on an inaccurate estimate of the mode "
+            "derived from the mode of the base distribution. "
"To use a more stable implementation of the mode, use dist.get_mode() method instead. "
+            "To silence this warning, consider using the DETERMINISTIC exploration_type. "
"This implementation will be removed in v0.6.",
category=DeprecationWarning,
)
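
A minimal, hedged sketch of the two alternatives the warning points to; the `TanhNormal` arguments are arbitrary, while `get_mode()` and the DETERMINISTIC exploration type are taken from the message itself:

    import torch
    from torchrl.modules import TanhNormal
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    dist = TanhNormal(loc=torch.zeros(3), scale=torch.ones(3))
    mode = dist.get_mode()           # stable mode estimate, no deprecation warning

    with set_exploration_type(ExplorationType.DETERMINISTIC):
        pass                         # modules queried here use the deterministic output
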
diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py
index fb0cc0135b8..62ccf53c30a 100644
--- a/torchrl/modules/models/__init__.py
+++ b/torchrl/modules/models/__init__.py
@@ -6,6 +6,8 @@
from torchrl.modules.tensordict_module.common import DistributionalDQNnet
+from .batchrenorm import BatchRenorm1d
+
from .decision_transformer import DecisionTransformer
from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
from .model_based import (
diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py
new file mode 100644
index 00000000000..26a2f9d50d2
--- /dev/null
+++ b/torchrl/modules/models/batchrenorm.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import torch.nn as nn
+
+
+class BatchRenorm1d(nn.Module):
+ """BatchRenorm Module (https://arxiv.org/abs/1702.03275).
+
+ The code is adapted from https://github.com/google-research/corenet
+
+ BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm,
+ it utilizes running statistics to normalize batches after an initial warmup phase.
+ This approach reduces the impact of "outlier" batches that may occur during
+ extended training periods, making BatchRenorm more robust for long training runs.
+
+ During the warmup phase, BatchRenorm functions identically to a BatchNorm layer.
+
+ Args:
+ num_features (int): Number of features in the input tensor.
+
+ Keyword Args:
+ momentum (float, optional): Momentum factor for computing the running mean and variance.
+ Defaults to ``0.01``.
+ eps (float, optional): Small value added to the variance to avoid division by zero.
+ Defaults to ``1e-5``.
+ max_r (float, optional): Maximum value for the scaling factor r.
+ Defaults to ``3.0``.
+ max_d (float, optional): Maximum value for the bias factor d.
+ Defaults to ``5.0``.
+ warmup_steps (int, optional): Number of warm-up steps for the running mean and variance.
+ Defaults to ``10000``.
+ smooth (bool, optional): if ``True``, the behaviour smoothly transitions from regular
+ batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``).
+ Otherwise, the behaviour will transition from batch-norm to batch-renorm when
+ ``iter=warmup_steps``. Defaults to ``False``.
+ """
+
+ def __init__(
+ self,
+ num_features: int,
+ *,
+ momentum: float = 0.01,
+ eps: float = 1e-5,
+ max_r: float = 3.0,
+ max_d: float = 5.0,
+ warmup_steps: int = 10000,
+ smooth: bool = False,
+ ):
+ super().__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.max_r = max_r
+ self.max_d = max_d
+ self.warmup_steps = warmup_steps
+ self.smooth = smooth
+
+ self.register_buffer(
+ "running_mean", torch.zeros(num_features, dtype=torch.float32)
+ )
+ self.register_buffer(
+ "running_var", torch.ones(num_features, dtype=torch.float32)
+ )
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64))
+ self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32))
+ self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if not x.dim() >= 2:
+ raise ValueError(
+                f"The {type(self).__name__} expects a tensor with at least 2 dimensions, got {x.dim()}."
+ )
+
+ view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2)
+
+ def _v(v):
+ return v.view(view_dims)
+
+ running_std = (self.running_var + self.eps).sqrt_()
+
+ if self.training:
+ reduce_dims = [i for i in range(x.dim()) if i != 1]
+ b_mean = x.mean(reduce_dims)
+ b_var = x.var(reduce_dims, unbiased=False)
+ b_std = (b_var + self.eps).sqrt_()
+
+ r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r)
+ d = torch.clamp(
+ (b_mean.detach() - self.running_mean) / running_std,
+ -self.max_d,
+ self.max_d,
+ )
+
+            # Compute warmup factor: ramps from 0 to 1 during warmup when `smooth=True`,
+            # otherwise 0 during warmup and 1 afterwards
+ if self.warmup_steps > 0:
+ if self.smooth:
+ warmup_factor = self.num_batches_tracked / self.warmup_steps
+ else:
+ warmup_factor = self.num_batches_tracked // self.warmup_steps
+ r = 1.0 + (r - 1.0) * warmup_factor
+ d = d * warmup_factor
+
+ x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d)
+
+ unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1)
+ self.running_var += self.momentum * (unbiased_var - self.running_var)
+ self.running_mean += self.momentum * (b_mean.detach() - self.running_mean)
+ self.num_batches_tracked += 1
+            self.num_batches_tracked.clamp_max_(self.warmup_steps)
+ else:
+ x = (x - _v(self.running_mean)) / _v(running_std)
+
+ x = _v(self.weight) * x + _v(self.bias)
+ return x
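For orientation, a quick usage sketch of the new module on toy shapes (not part of the test suite):

import torch
from torchrl.modules.models import BatchRenorm1d

# 32 samples with 8 features; extra trailing dims would be treated as spatial dims.
bn = BatchRenorm1d(8, momentum=0.01, warmup_steps=100)
x = torch.randn(32, 8)
y_train = bn(x)   # training mode: batch statistics (plain BatchNorm while warming up)
bn.eval()
y_eval = bn(x)    # eval mode: running statistics only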
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py
index 17b1ea77ee4..83b6a8d1fb3 100644
--- a/torchrl/modules/tensordict_module/actors.py
+++ b/torchrl/modules/tensordict_module/actors.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
-import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
@@ -922,10 +921,9 @@ def __init__(
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
- warnings.warn(
- "Using specs in action_space will be deprecated in v0.4.0,"
- " please use the 'spec' argument if you want to provide an action spec",
- category=DeprecationWarning,
+ raise RuntimeError(
+ "Using specs in action_space is deprecated. "
+ "Please use the 'spec' argument if you want to provide an action spec"
)
action_space, _ = _process_action_space_spec(action_space, None)
@@ -1136,10 +1134,9 @@ def __init__(
action_mask_key: Optional[NestedKey] = None,
):
if isinstance(action_space, TensorSpec):
- warnings.warn(
- "Using specs in action_space will be deprecated v0.4.0,"
- " please use the 'spec' argument if you want to provide an action spec",
- category=DeprecationWarning,
+ raise RuntimeError(
+                "Using specs in action_space is deprecated. "
+ "Please use the 'spec' argument if you want to provide an action spec"
)
action_space, spec = _process_action_space_spec(action_space, spec)
diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py
index f8d2bd1d977..aa13a88c7e9 100644
--- a/torchrl/objectives/__init__.py
+++ b/torchrl/objectives/__init__.py
@@ -6,6 +6,7 @@
from .a2c import A2CLoss
from .common import LossModule
from .cql import CQLLoss, DiscreteCQLLoss
+from .crossq import CrossQLoss
from .ddpg import DDPGLoss
from .decision_transformer import DTLoss, OnlineDTLoss
from .dqn import DistributionalDQNLoss, DQNLoss
@@ -17,6 +18,7 @@
from .reinforce import ReinforceLoss
from .sac import DiscreteSACLoss, SACLoss
from .td3 import TD3Loss
+from .td3_bc import TD3BCLoss
from .utils import (
default_value_kwargs,
distance_loss,
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
new file mode 100644
index 00000000000..22d35bd5799
--- /dev/null
+++ b/torchrl/objectives/crossq.py
@@ -0,0 +1,662 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from functools import wraps
+from typing import Dict, Tuple, Union
+
+import torch
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+
+from tensordict.nn import dispatch, TensorDictModule
+from tensordict.utils import NestedKey
+from torch import Tensor
+from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import ProbabilisticActor
+from torchrl.objectives.common import LossModule
+
+from torchrl.objectives.utils import (
+ _cache_values,
+ _reduce,
+ _vmap_func,
+ default_value_kwargs,
+ distance_loss,
+ ValueEstimators,
+)
+from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
+
+
+def _delezify(func):
+ @wraps(func)
+ def new_func(self, *args, **kwargs):
+ self.target_entropy
+ return func(self, *args, **kwargs)
+
+ return new_func
+
+
+class CrossQLoss(LossModule):
+ """TorchRL implementation of the CrossQ loss.
+
+ Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING
+ FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX
+
+ This class has three loss functions that will be called sequentially by the `forward` method:
+ :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can
+    be called by the user in that order.
+
+ Args:
+ actor_network (ProbabilisticActor): stochastic actor
+ qvalue_network (TensorDictModule): Q(s, a) parametric model.
+ This module typically outputs a ``"state_action_value"`` entry.
+
+ Keyword Args:
+ num_qvalue_nets (integer, optional): number of Q-Value networks used.
+ Defaults to ``2``.
+ loss_function (str, optional): loss function to be used with
+ the value function loss. Default is `"smooth_l1"`.
+ alpha_init (float, optional): initial entropy multiplier.
+ Default is 1.0.
+ min_alpha (float, optional): min value of alpha.
+ Default is None (no minimum value).
+ max_alpha (float, optional): max value of alpha.
+ Default is None (no maximum value).
+ action_spec (TensorSpec, optional): the action tensor spec. If not provided
+ and the target entropy is ``"auto"``, it will be retrieved from
+ the actor.
+ fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
+ initial value. Otherwise, alpha will be optimized to
+ match the 'target_entropy' value.
+ Default is ``False``.
+ target_entropy (float or str, optional): Target entropy for the
+ stochastic policy. Default is "auto", where target entropy is
+ computed as :obj:`-prod(n_actions)`.
+ priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+ Tensordict key where to write the
+ priority (for prioritized replay buffer usage). Defaults to ``"td_error"``.
+ separate_losses (bool, optional): if ``True``, shared parameters between
+ policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+ parameters for both policy and critic losses.
+ reduction (str, optional): Specifies the reduction to apply to the output:
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+ ``"mean"``: the sum of the output will be divided by the number of
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+
+ Examples:
+ >>> import torch
+ >>> from torch import nn
+ >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
+ >>> from torchrl.objectives.crossq import CrossQLoss
+ >>> from tensordict import TensorDict
+ >>> n_act, n_obs = 4, 3
+ >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+ >>> actor = ProbabilisticActor(
+ ... module=module,
+ ... in_keys=["loc", "scale"],
+ ... spec=spec,
+ ... distribution_class=TanhNormal)
+ >>> class ValueClass(nn.Module):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
+ ... def forward(self, obs, act):
+ ... return self.linear(torch.cat([obs, act], -1))
+ >>> module = ValueClass()
+ >>> qvalue = ValueOperator(
+ ... module=module,
+ ... in_keys=['observation', 'action'])
+ >>> loss = CrossQLoss(actor, qvalue)
+ >>> batch = [2, ]
+ >>> action = spec.rand(batch)
+ >>> data = TensorDict({
+ ... "observation": torch.randn(*batch, n_obs),
+ ... "action": action,
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
+ ... ("next", "reward"): torch.randn(*batch, 1),
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
+ ... }, batch)
+ >>> loss(data)
+ TensorDict(
+ fields={
+ alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
+
+ This class is compatible with non-tensordict based modules too and can be
+    used without resorting to any tensordict-related primitive. In this case,
+ the expected keyword arguments are:
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network.
+ The return value is a tuple of tensors in the following order:
+ ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]``
+
+ Examples:
+ >>> import torch
+ >>> from torch import nn
+ >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
+ >>> from torchrl.objectives import CrossQLoss
+ >>> _ = torch.manual_seed(42)
+ >>> n_act, n_obs = 4, 3
+ >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+ >>> actor = ProbabilisticActor(
+ ... module=module,
+ ... in_keys=["loc", "scale"],
+ ... spec=spec,
+ ... distribution_class=TanhNormal)
+ >>> class ValueClass(nn.Module):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
+ ... def forward(self, obs, act):
+ ... return self.linear(torch.cat([obs, act], -1))
+ >>> module = ValueClass()
+ >>> qvalue = ValueOperator(
+ ... module=module,
+ ... in_keys=['observation', 'action'])
+ >>> loss = CrossQLoss(actor, qvalue)
+ >>> batch = [2, ]
+ >>> action = spec.rand(batch)
+ >>> loss_actor, loss_qvalue, _, _, _ = loss(
+ ... observation=torch.randn(*batch, n_obs),
+ ... action=action,
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+ ... next_observation=torch.zeros(*batch, n_obs),
+ ... next_reward=torch.randn(*batch, 1))
+ >>> loss_actor.backward()
+
+ The output keys can also be filtered using the :meth:`CrossQLoss.select_out_keys`
+ method.
+
+ Examples:
+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
+ >>> loss_actor, loss_qvalue = loss(
+ ... observation=torch.randn(*batch, n_obs),
+ ... action=action,
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+ ... next_observation=torch.zeros(*batch, n_obs),
+ ... next_reward=torch.randn(*batch, 1))
+ >>> loss_actor.backward()
+ """
+
+ @dataclass
+ class _AcceptedKeys:
+ """Maintains default values for all configurable tensordict keys.
+
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+ default values.
+
+ Attributes:
+ action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+ state_action_value (NestedKey): The input tensordict key where the
+ state action value is expected. Defaults to ``"state_action_value"``.
+ priority (NestedKey): The input tensordict key where the target priority is written to.
+ Defaults to ``"td_error"``.
+ reward (NestedKey): The input tensordict key where the reward is expected.
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
+ done (NestedKey): The key in the input TensorDict that indicates
+ whether a trajectory is done. Will be used for the underlying value estimator.
+ Defaults to ``"done"``.
+ terminated (NestedKey): The key in the input TensorDict that indicates
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
+ Defaults to ``"terminated"``.
+ log_prob (NestedKey): The input tensordict key where the log probability is expected.
+ Defaults to ``"_log_prob"``.
+ """
+
+ action: NestedKey = "action"
+ state_action_value: NestedKey = "state_action_value"
+ priority: NestedKey = "td_error"
+ reward: NestedKey = "reward"
+ done: NestedKey = "done"
+ terminated: NestedKey = "terminated"
+ log_prob: NestedKey = "_log_prob"
+
+ default_keys = _AcceptedKeys()
+ default_value_estimator = ValueEstimators.TD0
+
+ actor_network: ProbabilisticActor
+ actor_network_params: TensorDictParams
+ qvalue_network: TensorDictModule
+ qvalue_network_params: TensorDictParams
+ target_actor_network_params: TensorDictParams
+ target_qvalue_network_params: TensorDictParams
+
+ def __init__(
+ self,
+ actor_network: ProbabilisticActor,
+ qvalue_network: TensorDictModule,
+ *,
+ num_qvalue_nets: int = 2,
+ loss_function: str = "smooth_l1",
+ alpha_init: float = 1.0,
+ min_alpha: float = None,
+ max_alpha: float = None,
+ action_spec=None,
+ fixed_alpha: bool = False,
+ target_entropy: Union[str, float] = "auto",
+ priority_key: str = None,
+ separate_losses: bool = False,
+ reduction: str = None,
+ ) -> None:
+ self._in_keys = None
+ self._out_keys = None
+ if reduction is None:
+ reduction = "mean"
+ super().__init__()
+ self._set_deprecated_ctor_keys(priority_key=priority_key)
+
+ # Actor
+ self.convert_to_functional(
+ actor_network,
+ "actor_network",
+ create_target_params=False,
+ )
+ if separate_losses:
+ # we want to make sure there are no duplicates in the params: the
+ # params of critic must be refs to actor if they're shared
+ policy_params = list(actor_network.parameters())
+ else:
+ policy_params = None
+ q_value_policy_params = None
+
+ # Q value
+ self.num_qvalue_nets = num_qvalue_nets
+
+ q_value_policy_params = policy_params
+ self.convert_to_functional(
+ qvalue_network,
+ "qvalue_network",
+ num_qvalue_nets,
+ create_target_params=False,
+ compare_against=q_value_policy_params,
+ )
+
+ self.loss_function = loss_function
+ try:
+ device = next(self.parameters()).device
+ except AttributeError:
+ device = torch.device("cpu")
+ self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
+ if bool(min_alpha) ^ bool(max_alpha):
+ min_alpha = min_alpha if min_alpha else 0.0
+ if max_alpha == 0:
+ raise ValueError("max_alpha must be either None or greater than 0.")
+ max_alpha = max_alpha if max_alpha else 1e9
+ if min_alpha:
+ self.register_buffer(
+ "min_log_alpha", torch.tensor(min_alpha, device=device).log()
+ )
+ else:
+ self.min_log_alpha = None
+ if max_alpha:
+ self.register_buffer(
+ "max_log_alpha", torch.tensor(max_alpha, device=device).log()
+ )
+ else:
+ self.max_log_alpha = None
+ self.fixed_alpha = fixed_alpha
+ if fixed_alpha:
+ self.register_buffer(
+ "log_alpha", torch.tensor(math.log(alpha_init), device=device)
+ )
+ else:
+ self.register_parameter(
+ "log_alpha",
+ torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
+ )
+
+ self._target_entropy = target_entropy
+ self._action_spec = action_spec
+ self._vmap_qnetworkN0 = _vmap_func(
+ self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+ )
+ self.reduction = reduction
+
+ @property
+ def target_entropy_buffer(self):
+ """The target entropy.
+
+ This value can be controlled via the `target_entropy` kwarg in the constructor.
+ """
+ return self.target_entropy
+
+ @property
+ def target_entropy(self):
+ target_entropy = self._buffers.get("_target_entropy", None)
+ if target_entropy is not None:
+ return target_entropy
+ target_entropy = self._target_entropy
+ action_spec = self._action_spec
+ actor_network = self.actor_network
+ device = next(self.parameters()).device
+ if target_entropy == "auto":
+ action_spec = (
+ action_spec
+ if action_spec is not None
+ else getattr(actor_network, "spec", None)
+ )
+ if action_spec is None:
+ raise RuntimeError(
+ "Cannot infer the dimensionality of the action. Consider providing "
+                    "the target entropy explicitly or provide the spec of the "
+ "action tensor in the actor network."
+ )
+ if not isinstance(action_spec, CompositeSpec):
+ action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+ if (
+ isinstance(self.tensor_keys.action, tuple)
+ and len(self.tensor_keys.action) > 1
+ ):
+ action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
+ else:
+ action_container_shape = action_spec.shape
+ target_entropy = -float(
+ action_spec[self.tensor_keys.action]
+ .shape[len(action_container_shape) :]
+ .numel()
+ )
+ delattr(self, "_target_entropy")
+ self.register_buffer(
+ "_target_entropy", torch.tensor(target_entropy, device=device)
+ )
+ return self._target_entropy
+
+ state_dict = _delezify(LossModule.state_dict)
+ load_state_dict = _delezify(LossModule.load_state_dict)
+
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
+ if self._value_estimator is not None:
+ self._value_estimator.set_keys(
+ value=self.tensor_keys.value,
+ reward=self.tensor_keys.reward,
+ done=self.tensor_keys.done,
+ terminated=self.tensor_keys.terminated,
+ )
+ self._set_in_keys()
+
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+ if value_type is None:
+ value_type = self.default_value_estimator
+ self.value_type = value_type
+
+ value_net = None
+ hp = dict(default_value_kwargs(value_type))
+ hp.update(hyperparams)
+ if value_type is ValueEstimators.TD1:
+ self._value_estimator = TD1Estimator(
+ **hp,
+ value_network=value_net,
+ )
+ elif value_type is ValueEstimators.TD0:
+ self._value_estimator = TD0Estimator(
+ **hp,
+ value_network=value_net,
+ )
+ elif value_type is ValueEstimators.GAE:
+ raise NotImplementedError(
+                f"Value type {value_type} is not implemented for loss {type(self)}."
+ )
+ elif value_type is ValueEstimators.TDLambda:
+ self._value_estimator = TDLambdaEstimator(
+ **hp,
+ value_network=value_net,
+ )
+ else:
+ raise NotImplementedError(f"Unknown value type {value_type}")
+
+ tensor_keys = {
+ "reward": self.tensor_keys.reward,
+ "done": self.tensor_keys.done,
+ "terminated": self.tensor_keys.terminated,
+ }
+ self._value_estimator.set_keys(**tensor_keys)
+
+ @property
+ def device(self) -> torch.device:
+ for p in self.parameters():
+ return p.device
+ raise RuntimeError(
+            "At least one of the networks of CrossQLoss must have trainable parameters."
+ )
+
+ def _set_in_keys(self):
+ keys = [
+ self.tensor_keys.action,
+ ("next", self.tensor_keys.reward),
+ ("next", self.tensor_keys.done),
+ ("next", self.tensor_keys.terminated),
+ *self.actor_network.in_keys,
+ *[("next", key) for key in self.actor_network.in_keys],
+ *self.qvalue_network.in_keys,
+ ]
+ self._in_keys = list(set(keys))
+
+ @property
+ def in_keys(self):
+ if self._in_keys is None:
+ self._set_in_keys()
+ return self._in_keys
+
+ @in_keys.setter
+ def in_keys(self, values):
+ self._in_keys = values
+
+ @property
+ def out_keys(self):
+ if self._out_keys is None:
+ keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
+ self._out_keys = keys
+ return self._out_keys
+
+ @out_keys.setter
+ def out_keys(self, values):
+ self._out_keys = values
+
+ @dispatch
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+ """The forward method.
+
+ Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns
+ a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached).
+ To see what keys are expected in the input tensordict and what keys are expected as output, check the
+ class's `"in_keys"` and `"out_keys"` attributes.
+ """
+ shape = None
+ if tensordict.ndimension() > 1:
+ shape = tensordict.shape
+ tensordict_reshape = tensordict.reshape(-1)
+ else:
+ tensordict_reshape = tensordict
+
+ loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape)
+ loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
+ loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"])
+ tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
+ if loss_actor.shape != loss_qvalue.shape:
+ raise RuntimeError(
+ f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}"
+ )
+ if shape:
+ tensordict.update(tensordict_reshape.view(shape))
+ entropy = -metadata_actor["log_prob"]
+ out = {
+ "loss_actor": loss_actor,
+ "loss_qvalue": loss_qvalue,
+ "loss_alpha": loss_alpha,
+ "alpha": self._alpha,
+ "entropy": entropy.detach().mean(),
+ **metadata_actor,
+ **value_metadata,
+ }
+ td_out = TensorDict(out, [])
+ # td_out = td_out.named_apply(
+ # lambda name, value: (
+ # _reduce(value, reduction=self.reduction)
+ # if name.startswith("loss_")
+ # else value
+ # ),
+ # batch_size=[],
+ # )
+ return td_out
+
+ @property
+ @_cache_values
+ def _cached_detached_qvalue_params(self):
+ return self.qvalue_network_params.detach()
+
+ def actor_loss(
+ self, tensordict: TensorDictBase
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
+ """Compute the actor loss.
+
+        The actor loss should be computed after the :meth:`~.qvalue_loss` and before the :meth:`~.alpha_loss`, which
+ requires the `log_prob` field of the `metadata` returned by this method.
+
+ Args:
+ tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+ are required for this to be computed.
+
+        Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action.
+ """
+ with set_exploration_type(
+ ExplorationType.RANDOM
+ ), self.actor_network_params.to_module(self.actor_network):
+ dist = self.actor_network.get_dist(tensordict)
+ a_reparm = dist.rsample()
+ log_prob = dist.log_prob(a_reparm)
+
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
+ self.qvalue_network.eval()
+ td_q.set(self.tensor_keys.action, a_reparm)
+ td_q = self._vmap_qnetworkN0(
+ td_q,
+ self._cached_detached_qvalue_params,
+ )
+
+ min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
+ self.qvalue_network.train()
+
+ if log_prob.shape != min_q.shape:
+ raise RuntimeError(
+ f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}"
+ )
+ actor_loss = self._alpha * log_prob - min_q
+ return _reduce(actor_loss, reduction=self.reduction), {
+ "log_prob": log_prob.detach()
+ }
+
+ def qvalue_loss(
+ self, tensordict: TensorDictBase
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
+ """Compute the q-value loss.
+
+ The q-value loss should be computed before the :meth:`~.actor_loss`.
+
+ Args:
+ tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+ are required for this to be computed.
+
+ Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
+ the detached `"td_error"` to be used for prioritized sampling.
+ """
+        # compute next action
+ with torch.no_grad():
+ with set_exploration_type(
+ ExplorationType.RANDOM
+ ), self.actor_network_params.to_module(self.actor_network):
+ next_tensordict = tensordict.get("next").clone(False)
+ next_dist = self.actor_network.get_dist(next_tensordict)
+ next_action = next_dist.sample()
+ next_tensordict.set(self.tensor_keys.action, next_action)
+ next_sample_log_prob = next_dist.log_prob(next_action)
+
+ combined = torch.cat(
+ [
+ tensordict.select(*self.qvalue_network.in_keys, strict=False),
+ next_tensordict.select(*self.qvalue_network.in_keys, strict=False),
+ ]
+ )
+ pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get(
+ self.tensor_keys.state_action_value
+ )
+ (current_state_action_value, next_state_action_value) = pred_qs.split(
+ tensordict.batch_size[0], dim=1
+ )
+
+ # compute target value
+ if (
+ next_state_action_value.shape[-len(next_sample_log_prob.shape) :]
+ != next_sample_log_prob.shape
+ ):
+ next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
+ next_state_action_value = next_state_action_value.min(0)[0]
+ next_state_action_value = (
+ next_state_action_value - self._alpha * next_sample_log_prob
+ ).detach()
+
+ target_value = self.value_estimator.value_estimate(
+ tensordict, next_value=next_state_action_value
+ ).squeeze(-1)
+
+ # get current q-values
+ pred_val = current_state_action_value.squeeze(-1)
+
+ # compute loss
+ td_error = abs(pred_val - target_value)
+ loss_qval = distance_loss(
+ pred_val,
+ target_value.expand_as(pred_val),
+ loss_function=self.loss_function,
+ ).sum(0)
+ metadata = {"td_error": td_error.detach().max(0)[0]}
+ return _reduce(loss_qval, reduction=self.reduction), metadata
+
+ def alpha_loss(self, log_prob: Tensor) -> Tensor:
+ """Compute the entropy loss.
+
+ The entropy loss should be computed last.
+
+ Args:
+ log_prob (torch.Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`.
+
+ Returns: a differentiable tensor with the entropy loss.
+ """
+ if self.target_entropy is not None:
+ # we can compute this loss even if log_alpha is not a parameter
+ alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
+ else:
+ # placeholder
+ alpha_loss = torch.zeros_like(log_prob)
+ return _reduce(alpha_loss, reduction=self.reduction)
+
+ @property
+ def _alpha(self):
+ if self.min_log_alpha is not None:
+ self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
+ with torch.no_grad():
+ alpha = self.log_alpha.exp()
+ return alpha
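Since the docstring notes that the three loss terms can also be computed separately, here is a hedged sketch of doing so in the documented order, reusing ``loss`` and ``data`` from the docstring example:

# Sketch only: manual composition of the CrossQ terms
# (qvalue_loss -> actor_loss -> alpha_loss), with `loss` and `data`
# built as in the docstring example.
loss_qvalue, value_meta = loss.qvalue_loss(data)
loss_actor, actor_meta = loss.actor_loss(data)
loss_alpha = loss.alpha_loss(log_prob=actor_meta["log_prob"])
(loss_qvalue + loss_actor + loss_alpha).backward()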
diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py
new file mode 100644
index 00000000000..93845bb00bd
--- /dev/null
+++ b/torchrl/objectives/td3_bc.py
@@ -0,0 +1,571 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict.nn import dispatch, TensorDictModule
+from tensordict.utils import NestedKey
+from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec
+
+from torchrl.envs.utils import step_mdp
+from torchrl.objectives.common import LossModule
+
+from torchrl.objectives.utils import (
+ _cache_values,
+ _reduce,
+ _vmap_func,
+ default_value_kwargs,
+ distance_loss,
+ ValueEstimators,
+)
+from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
+
+
+class TD3BCLoss(LossModule):
+ r"""TD3+BC Loss Module.
+
+ Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to
+    Offline Reinforcement Learning" <https://arxiv.org/abs/2106.06860>`_.
+
+ This class incorporates two loss functions, executed sequentially within the `forward` method:
+
+ 1. :meth:`~.qvalue_loss`
+ 2. :meth:`~.actor_loss`
+
+ Users also have the option to call these functions directly in the same order if preferred.
+
+ Args:
+ actor_network (TensorDictModule): the actor to be trained
+ qvalue_network (TensorDictModule): a single Q-value network that will
+            be replicated as many times as needed.
+
+ Keyword Args:
+ bounds (tuple of float, optional): the bounds of the action space.
+ Exclusive with ``action_spec``. Either this or ``action_spec`` must
+ be provided.
+ action_spec (TensorSpec, optional): the action spec.
+ Exclusive with ``bounds``. Either this or ``bounds`` must be provided.
+ num_qvalue_nets (int, optional): Number of Q-value networks to be
+ trained. Default is ``2``.
+ policy_noise (float, optional): Standard deviation for the target
+ policy action noise. Default is ``0.2``.
+ noise_clip (float, optional): Clipping range value for the sampled
+ target policy action noise. Default is ``0.5``.
+ alpha (float, optional): Weight for the behavioral cloning loss.
+ Defaults to ``2.5``.
+ priority_key (str, optional): Key where to write the priority value
+ for prioritized replay buffers. Default is
+ `"td_error"`.
+ loss_function (str, optional): loss function to be used for the Q-value.
+ Can be one of ``"smooth_l1"``, ``"l2"``,
+ ``"l1"``, Default is ``"smooth_l1"``.
+ delay_actor (bool, optional): whether to separate the target actor
+ networks from the actor networks used for
+ data collection. Default is ``True``.
+ delay_qvalue (bool, optional): Whether to separate the target Q value
+ networks from the Q value networks used
+ for data collection. Default is ``True``.
+ separate_losses (bool, optional): if ``True``, shared parameters between
+ policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+ parameters for both policy and critic losses.
+ reduction (str, optional): Specifies the reduction to apply to the output:
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+ ``"mean"``: the sum of the output will be divided by the number of
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+
+ Examples:
+ >>> import torch
+ >>> from torch import nn
+ >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
+ >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
+ >>> from torchrl.objectives.td3_bc import TD3BCLoss
+ >>> from tensordict import TensorDict
+ >>> n_act, n_obs = 4, 3
+ >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> module = nn.Linear(n_obs, n_act)
+ >>> actor = Actor(
+ ... module=module,
+ ... spec=spec)
+ >>> class ValueClass(nn.Module):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
+ ... def forward(self, obs, act):
+ ... return self.linear(torch.cat([obs, act], -1))
+ >>> module = ValueClass()
+ >>> qvalue = ValueOperator(
+ ... module=module,
+ ... in_keys=['observation', 'action'])
+ >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
+ >>> batch = [2, ]
+ >>> action = spec.rand(batch)
+ >>> data = TensorDict({
+ ... "observation": torch.randn(*batch, n_obs),
+ ... "action": action,
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
+ ... ("next", "reward"): torch.randn(*batch, 1),
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
+ ... }, batch)
+ >>> loss(data)
+ TensorDict(
+ fields={
+ bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+ state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+ target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
+
+ This class is compatible with non-tensordict based modules too and can be
+    used without resorting to any tensordict-related primitive. In this case,
+ the expected keyword arguments are:
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
+ The return value is a tuple of tensors in the following order:
+    ``["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value"]``.
+
+ Examples:
+ >>> import torch
+ >>> from torch import nn
+ >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
+ >>> from torchrl.objectives.td3_bc import TD3BCLoss
+ >>> n_act, n_obs = 4, 3
+ >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> module = nn.Linear(n_obs, n_act)
+ >>> actor = Actor(
+ ... module=module,
+ ... spec=spec)
+ >>> class ValueClass(nn.Module):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
+ ... def forward(self, obs, act):
+ ... return self.linear(torch.cat([obs, act], -1))
+ >>> module = ValueClass()
+ >>> qvalue = ValueOperator(
+ ... module=module,
+ ... in_keys=['observation', 'action'])
+ >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
+ >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
+ >>> batch = [2, ]
+ >>> action = spec.rand(batch)
+ >>> loss_actor, loss_qvalue = loss(
+ ... observation=torch.randn(*batch, n_obs),
+ ... action=action,
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+ ... next_reward=torch.randn(*batch, 1),
+ ... next_observation=torch.randn(*batch, n_obs))
+ >>> loss_actor.backward()
+
+ """
+
+ @dataclass
+ class _AcceptedKeys:
+ """Maintains default values for all configurable tensordict keys.
+
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+ default values.
+
+ Attributes:
+ action (NestedKey): The input tensordict key where the action is expected.
+ Defaults to ``"action"``.
+ state_action_value (NestedKey): The input tensordict key where the state action value is expected.
+ Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
+ priority (NestedKey): The input tensordict key where the target priority is written to.
+ Defaults to ``"td_error"``.
+ reward (NestedKey): The input tensordict key where the reward is expected.
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
+ done (NestedKey): The key in the input TensorDict that indicates
+ whether a trajectory is done. Will be used for the underlying value estimator.
+ Defaults to ``"done"``.
+ terminated (NestedKey): The key in the input TensorDict that indicates
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
+ Defaults to ``"terminated"``.
+ """
+
+ action: NestedKey = "action"
+ state_action_value: NestedKey = "state_action_value"
+ priority: NestedKey = "td_error"
+ reward: NestedKey = "reward"
+ done: NestedKey = "done"
+ terminated: NestedKey = "terminated"
+
+ default_keys = _AcceptedKeys()
+ default_value_estimator = ValueEstimators.TD0
+ out_keys = [
+ "loss_actor",
+ "loss_qvalue",
+ "bc_loss",
+ "lmbd",
+ "pred_value",
+ "state_action_value_actor",
+ "next_state_value",
+ "target_value",
+ ]
+
+ actor_network: TensorDictModule
+ qvalue_network: TensorDictModule
+ actor_network_params: TensorDictParams
+ qvalue_network_params: TensorDictParams
+ target_actor_network_params: TensorDictParams
+ target_qvalue_network_params: TensorDictParams
+
+ def __init__(
+ self,
+ actor_network: TensorDictModule,
+ qvalue_network: TensorDictModule,
+ *,
+ action_spec: TensorSpec = None,
+ bounds: Optional[Tuple[float]] = None,
+ num_qvalue_nets: int = 2,
+ policy_noise: float = 0.2,
+ noise_clip: float = 0.5,
+ alpha: float = 2.5,
+ loss_function: str = "smooth_l1",
+ delay_actor: bool = True,
+ delay_qvalue: bool = True,
+ priority_key: str = None,
+ separate_losses: bool = False,
+ reduction: str = None,
+ ) -> None:
+ if reduction is None:
+ reduction = "mean"
+ super().__init__()
+ self._in_keys = None
+ self._set_deprecated_ctor_keys(priority=priority_key)
+
+ self.delay_actor = delay_actor
+ self.delay_qvalue = delay_qvalue
+
+ self.convert_to_functional(
+ actor_network,
+ "actor_network",
+ create_target_params=self.delay_actor,
+ )
+ if separate_losses:
+ # we want to make sure there are no duplicates in the params: the
+ # params of critic must be refs to actor if they're shared
+ policy_params = list(actor_network.parameters())
+ else:
+ policy_params = None
+ self.convert_to_functional(
+ qvalue_network,
+ "qvalue_network",
+ num_qvalue_nets,
+ create_target_params=self.delay_qvalue,
+ compare_against=policy_params,
+ )
+
+ for p in self.parameters():
+ device = p.device
+ break
+ else:
+ device = None
+ self.num_qvalue_nets = num_qvalue_nets
+ self.loss_function = loss_function
+ self.policy_noise = policy_noise
+ self.noise_clip = noise_clip
+ self.alpha = alpha
+ if not ((action_spec is not None) ^ (bounds is not None)):
+ raise ValueError(
+ "One of 'bounds' and 'action_spec' must be provided, "
+ f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
+ )
+ elif action_spec is not None:
+ if isinstance(action_spec, CompositeSpec):
+ if (
+ isinstance(self.tensor_keys.action, tuple)
+ and len(self.tensor_keys.action) > 1
+ ):
+ action_container_shape = action_spec[
+ self.tensor_keys.action[:-1]
+ ].shape
+ else:
+ action_container_shape = action_spec.shape
+ action_spec = action_spec[self.tensor_keys.action][
+ (0,) * len(action_container_shape)
+ ]
+ if not isinstance(action_spec, BoundedTensorSpec):
+ raise ValueError(
+ f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}."
+ )
+ low = action_spec.space.low
+ high = action_spec.space.high
+ else:
+ low, high = bounds
+ if not isinstance(low, torch.Tensor):
+ low = torch.tensor(low)
+ if not isinstance(high, torch.Tensor):
+ high = torch.tensor(high, device=low.device, dtype=low.dtype)
+ if (low > high).any():
+ raise ValueError("Got a low bound higher than a high bound.")
+ if device is not None:
+ low = low.to(device)
+ high = high.to(device)
+ self.register_buffer("max_action", high)
+ self.register_buffer("min_action", low)
+ self._vmap_qvalue_network00 = _vmap_func(
+ self.qvalue_network, randomness=self.vmap_randomness
+ )
+ self._vmap_actor_network00 = _vmap_func(
+ self.actor_network, randomness=self.vmap_randomness
+ )
+ self.reduction = reduction
+
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
+ if self._value_estimator is not None:
+ self._value_estimator.set_keys(
+ value=self._tensor_keys.state_action_value,
+ reward=self.tensor_keys.reward,
+ done=self.tensor_keys.done,
+ terminated=self.tensor_keys.terminated,
+ )
+ self._set_in_keys()
+
+ def _set_in_keys(self):
+ keys = [
+ self.tensor_keys.action,
+ ("next", self.tensor_keys.reward),
+ ("next", self.tensor_keys.done),
+ ("next", self.tensor_keys.terminated),
+ *self.actor_network.in_keys,
+ *[("next", key) for key in self.actor_network.in_keys],
+ *self.qvalue_network.in_keys,
+ ]
+ self._in_keys = list(set(keys))
+
+ @property
+ def in_keys(self):
+ if self._in_keys is None:
+ self._set_in_keys()
+ return self._in_keys
+
+ @in_keys.setter
+ def in_keys(self, values):
+ self._in_keys = values
+
+ @property
+ @_cache_values
+ def _cached_detach_qvalue_network_params(self):
+ return self.qvalue_network_params.detach()
+
+ @property
+ @_cache_values
+ def _cached_stack_actor_params(self):
+ return torch.stack(
+ [self.actor_network_params, self.target_actor_network_params], 0
+ )
+
+ def actor_loss(self, tensordict):
+ """Compute the actor loss.
+
+        The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed by 1-3 critic updates.
+
+ Args:
+ tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+ are required for this to be computed.
+ Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"`
+ used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda
+ value, and the lambda value `"lmbd"` itself.
+ """
+ tensordict_actor_grad = tensordict.select(
+ *self.actor_network.in_keys, strict=False
+ )
+ with self.actor_network_params.to_module(self.actor_network):
+ tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
+ actor_loss_td = tensordict_actor_grad.select(
+ *self.qvalue_network.in_keys, strict=False
+ ).expand(
+ self.num_qvalue_nets, *tensordict_actor_grad.batch_size
+ ) # for actor loss
+ state_action_value_actor = (
+ self._vmap_qvalue_network00(
+ actor_loss_td,
+ self._cached_detach_qvalue_network_params,
+ )
+ .get(self.tensor_keys.state_action_value)
+ .squeeze(-1)
+ )
+
+ bc_loss = torch.nn.functional.mse_loss(
+ tensordict_actor_grad.get(self.tensor_keys.action),
+ tensordict.get(self.tensor_keys.action),
+ )
+ lmbd = self.alpha / state_action_value_actor[0].abs().mean().detach()
+
+ loss_actor = -lmbd * state_action_value_actor[0] + bc_loss
+
+ metadata = {
+ "state_action_value_actor": state_action_value_actor[0].detach(),
+ "bc_loss": bc_loss.detach(),
+ "lmbd": lmbd,
+ }
+ loss_actor = _reduce(loss_actor, reduction=self.reduction)
+ return loss_actor, metadata
+
+ def qvalue_loss(self, tensordict):
+ """Compute the q-value loss.
+
+ The q-value loss should be computed before the :meth:`~.actor_loss`.
+
+ Args:
+ tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+ are required for this to be computed.
+ Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
+ the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`.
+ """
+ tensordict = tensordict.clone(False)
+
+ act = tensordict.get(self.tensor_keys.action)
+
+        # compute the target-policy noise early, for reproducibility
+ noise = (torch.randn_like(act) * self.policy_noise).clamp(
+ -self.noise_clip, self.noise_clip
+ )
+
+ with torch.no_grad():
+ next_td_actor = step_mdp(tensordict).select(
+ *self.actor_network.in_keys, strict=False
+            )  # step_mdp moves ("next", observation) entries to the root for the actor
+ with self.target_actor_network_params.to_module(self.actor_network):
+ next_td_actor = self.actor_network(next_td_actor)
+ next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
+ self.min_action, self.max_action
+ )
+ next_td_actor.set(
+ self.tensor_keys.action,
+ next_action,
+ )
+ next_val_td = next_td_actor.select(
+ *self.qvalue_network.in_keys, strict=False
+ ).expand(
+ self.num_qvalue_nets, *next_td_actor.batch_size
+ ) # for next value estimation
+ next_target_q1q2 = (
+ self._vmap_qvalue_network00(
+ next_val_td,
+ self.target_qvalue_network_params,
+ )
+ .get(self.tensor_keys.state_action_value)
+ .squeeze(-1)
+ )
+ # min over the next target qvalues
+ next_target_qvalue = next_target_q1q2.min(0)[0]
+
+ # set next target qvalues
+ tensordict.set(
+ ("next", self.tensor_keys.state_action_value),
+ next_target_qvalue.unsqueeze(-1),
+ )
+
+ qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
+ self.num_qvalue_nets,
+ *tensordict.batch_size,
+ )
+        # predicted current q-values
+ current_qvalue = (
+ self._vmap_qvalue_network00(
+ qval_td,
+ self.qvalue_network_params,
+ )
+ .get(self.tensor_keys.state_action_value)
+ .squeeze(-1)
+ )
+
+ # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done))
+ target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
+
+ td_error = (current_qvalue - target_value).pow(2)
+ loss_qval = distance_loss(
+ current_qvalue,
+ target_value.expand_as(current_qvalue),
+ loss_function=self.loss_function,
+ ).sum(0)
+ metadata = {
+ "td_error": td_error,
+ "next_state_value": next_target_qvalue.detach(),
+ "pred_value": current_qvalue.detach(),
+ "target_value": target_value.detach(),
+ }
+ loss_qval = _reduce(loss_qval, reduction=self.reduction)
+ return loss_qval, metadata
+
+ @dispatch
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+ """The forward method.
+
+ Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns
+ a tensordict with these values.
+ To see what keys are expected in the input tensordict and what keys are expected as output, check the
+ class's `"in_keys"` and `"out_keys"` attributes.
+ """
+ tensordict_save = tensordict
+ loss_actor, metadata_actor = self.actor_loss(tensordict)
+ loss_qval, metadata_value = self.qvalue_loss(tensordict_save)
+ tensordict_save.set(
+ self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
+ )
+ if not loss_qval.shape == loss_actor.shape:
+ raise RuntimeError(
+ f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
+ )
+ td_out = TensorDict(
+ source={
+ "loss_actor": loss_actor,
+ "loss_qvalue": loss_qval,
+ **metadata_actor,
+ **metadata_value,
+ },
+ batch_size=[],
+ )
+ return td_out
+
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+ if value_type is None:
+ value_type = self.default_value_estimator
+ self.value_type = value_type
+ hp = dict(default_value_kwargs(value_type))
+ if hasattr(self, "gamma"):
+ hp["gamma"] = self.gamma
+ hp.update(hyperparams)
+        # we do not need a value network because the next state value is already passed
+ if value_type == ValueEstimators.TD1:
+ self._value_estimator = TD1Estimator(value_network=None, **hp)
+ elif value_type == ValueEstimators.TD0:
+ self._value_estimator = TD0Estimator(value_network=None, **hp)
+ elif value_type == ValueEstimators.GAE:
+ raise NotImplementedError(
+                f"Value type {value_type} is not implemented for loss {type(self)}."
+ )
+ elif value_type == ValueEstimators.TDLambda:
+ self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
+ else:
+ raise NotImplementedError(f"Unknown value type {value_type}")
+
+ tensor_keys = {
+ "value": self.tensor_keys.state_action_value,
+ "reward": self.tensor_keys.reward,
+ "done": self.tensor_keys.done,
+ "terminated": self.tensor_keys.terminated,
+ }
+ self._value_estimator.set_keys(**tensor_keys)
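The actor objective above weights the Q term by ``lmbd = alpha / mean(|Q(s, pi(s))|)`` and adds the behavioral-cloning MSE, as in the TD3+BC paper. A hedged sketch of one update step, reusing ``loss`` and ``data`` from the docstring example; ``actor_optim`` and ``critic_optim`` are hypothetical optimizers over the actor and critic parameters held by the loss module:

from torchrl.objectives import SoftUpdate

target_updater = SoftUpdate(loss, eps=0.995)  # the loss keeps delayed target nets by default

loss_qval, q_meta = loss.qvalue_loss(data)    # critic loss first, as documented
critic_optim.zero_grad()
loss_qval.backward()
critic_optim.step()

loss_actor, a_meta = loss.actor_loss(data)    # -lmbd * Q(s, pi(s)) + behavioral-cloning MSE
actor_optim.zero_grad()
loss_actor.backward()
actor_optim.step()

target_updater.step()                         # Polyak update of the target networks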
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index 2c2f3fb21ac..b7fb8ab4ed2 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -221,11 +221,11 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
observation_trsf = make_grid(
obs_flat, nrow=int(math.ceil(math.sqrt(obs_flat.shape[0])))
)
- self.obs.append(observation_trsf.to(torch.uint8))
+ self.obs.append(observation_trsf.to("cpu", torch.uint8))
elif observation_trsf.ndimension() >= 4:
- self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4))
+ self.obs.extend(observation_trsf.to("cpu", torch.uint8).flatten(0, -4))
else:
- self.obs.append(observation_trsf.to(torch.uint8))
+ self.obs.append(observation_trsf.to("cpu", torch.uint8))
return observation
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
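The recorder change relies on ``torch.Tensor.to`` accepting a device and a dtype in the same call, so each frame is moved to CPU and cast to ``uint8`` in one step before being buffered. A tiny illustration with a made-up frame:

import torch

frame = torch.randint(0, 256, (3, 64, 64), dtype=torch.int64)  # stand-in for a rendered frame
frame_u8 = frame.to("cpu", torch.uint8)  # single call: move to CPU and cast to uint8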
diff --git a/version.txt b/version.txt
index 1d0ba9ea182..8f0916f768f 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.4.0
+0.5.0