Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into version-0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 10, 2024
2 parents d9a7178 + 28acf61 commit 971c2ed
Show file tree
Hide file tree
Showing 23 changed files with 4,695 additions and 751 deletions.
16 changes: 15 additions & 1 deletion .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
logger.backend=

# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \
optim.gradient_steps=55 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
Expand Down Expand Up @@ -160,6 +162,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
replay_buffer.size=120 \
env.name=CartPole-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device= \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device= \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.init_random_frames=10 \
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr
- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py)
- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ Regular modules
Conv3dNet
SqueezeLayer
Squeeze2dLayer
BatchRenorm

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
18 changes: 18 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ REDQ

REDQLoss

CrossQ
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

CrossQ

IQL
----

Expand Down Expand Up @@ -160,6 +169,15 @@ TD3

TD3Loss

TD3+BC
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

TD3BCLoss

PPO
---

Expand Down
10 changes: 5 additions & 5 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,26 @@ Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either
using the ``torchsnapshot <https://github.com/pytorch/torchsnapshot/>``_ backend or
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:

.. code-block::
$ CKPT_BACKEND=torch python script.py
$ CKPT_BACKEND=torchsnapshot python script.py
which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.

When building a trainer, one can provide a file path where the checkpoints are to
When building a trainer, one can provide a path where the checkpoints are to
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).

.. code-block::
>>> filepath = "path/to/dir/"
>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
... collector=collector,
... total_frames=total_frames,
Expand Down
26 changes: 26 additions & 0 deletions sota-check/run_crossq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=crossq
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/crossq_%j.txt
#SBATCH --error=slurm_errors/crossq_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="crossq"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/crossq/crossq.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
fi
26 changes: 26 additions & 0 deletions sota-check/run_td3bc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=td3bc_offline
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/td3bc_offline_%j.txt
#SBATCH --error=slurm_errors/td3bc_offline_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="td3bc_offline"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log
fi
1 change: 1 addition & 0 deletions sota-check/submitit-release-check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ scripts=(
run_ppo_mujoco.sh
run_sac.sh
run_td3.sh
run_td3bc.sh
run_dt.sh
run_dt_online.sh
)
Expand Down
58 changes: 58 additions & 0 deletions sota-implementations/crossq/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# environment and task
env:
name: HalfCheetah-v4
task: ""
library: gym
max_episode_steps: 1000
seed: 42

# collector
collector:
total_frames: 1_000_000
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
env_per_collector: 1
reset_at_each_iter: False

# replay buffer
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: null

# optim
optim:
utd_ratio: 1.0
policy_update_delay: 3
gamma: 0.99
loss_function: l2
lr: 1.0e-3
weight_decay: 0.0
batch_size: 256
alpha_init: 1.0
adam_eps: 1.0e-8
beta1: 0.5
beta2: 0.999

# network
network:
batch_norm_momentum: 0.01
warmup_steps: 100000
critic_hidden_sizes: [2048, 2048]
actor_hidden_sizes: [256, 256]
critic_activation: relu
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"

# logging
logger:
backend: wandb
project_name: torchrl_example_crossQ
group_name: null
exp_name: ${env.name}_CrossQ
mode: online
eval_iter: 25000
Loading

0 comments on commit 971c2ed

Please sign in to comment.