Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into add-python3.12-setup
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 10, 2024
2 parents 456eaab + ea79350 commit d4e5f30
Show file tree
Hide file tree
Showing 69 changed files with 4,819 additions and 875 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash

export TORCHRL_BUILD_VERSION=0.4.0
export TORCHRL_BUILD_VERSION=0.5.0

${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
49 changes: 35 additions & 14 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@
#
#

set -e
#set -e
set -v

# Initialize an error flag
error_occurred=0
# Function to handle errors
error_handler() {
echo "Error on line $1"
error_occurred=1
}
# Trap ERR to call the error_handler function with the failing line number
trap 'error_handler $LINENO' ERR

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
Expand All @@ -24,6 +34,7 @@ lib_dir="${env_dir}/lib"
# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
export CUDA_LAUNCH_BLOCKING=1

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
#python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200
Expand Down Expand Up @@ -51,10 +62,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
logger.backend=

# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \
optim.gradient_steps=55 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
Expand Down Expand Up @@ -149,18 +162,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
replay_buffer.size=120 \
env.name=CartPole-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=4 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device= \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device= \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand Down Expand Up @@ -200,8 +213,8 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.backend=csv \
logger.video=True \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
Expand Down Expand Up @@ -298,3 +311,11 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ba

coverage combine
coverage xml -i

# Check if any errors occurred during the script execution
if [ "$error_occurred" -ne 0 ]; then
echo "Errors occurred during script execution"
exit 1
else
echo "Script executed successfully"
fi
1 change: 1 addition & 0 deletions .github/workflows/build-wheels-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/td_script.sh
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-m1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
runner-type: macos-m1-stable
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/m1_script.sh
env-var-script: .github/scripts/td_script.sh
1 change: 1 addition & 0 deletions .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/td_script.sh
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr
- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py)
- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ Regular modules
Conv3dNet
SqueezeLayer
Squeeze2dLayer
BatchRenorm

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
18 changes: 18 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ REDQ

REDQLoss

CrossQ
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

CrossQ

IQL
----

Expand Down Expand Up @@ -160,6 +169,15 @@ TD3

TD3Loss

TD3+BC
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

TD3BCLoss

PPO
---

Expand Down
10 changes: 5 additions & 5 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,26 @@ Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either
using the ``torchsnapshot <https://github.com/pytorch/torchsnapshot/>``_ backend or
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:

.. code-block::
$ CKPT_BACKEND=torch python script.py
$ CKPT_BACKEND=torchsnapshot python script.py
which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.

When building a trainer, one can provide a file path where the checkpoints are to
When building a trainer, one can provide a path where the checkpoints are to
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).

.. code-block::
>>> filepath = "path/to/dir/"
>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
... collector=collector,
... total_frames=total_frames,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _main(argv):
if is_nightly:
tensordict_dep = "tensordict-nightly"
else:
tensordict_dep = "tensordict>=0.4.0"
tensordict_dep = "tensordict>=0.5.0"

if is_nightly:
version = get_nightly_version()
Expand Down
26 changes: 26 additions & 0 deletions sota-check/run_crossq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=crossq
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/crossq_%j.txt
#SBATCH --error=slurm_errors/crossq_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="crossq"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/crossq/crossq.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
fi
26 changes: 26 additions & 0 deletions sota-check/run_td3bc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=td3bc_offline
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/td3bc_offline_%j.txt
#SBATCH --error=slurm_errors/td3bc_offline_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="td3bc_offline"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log
fi
1 change: 1 addition & 0 deletions sota-check/submitit-release-check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ scripts=(
run_ppo_mujoco.sh
run_sac.sh
run_td3.sh
run_td3bc.sh
run_dt.sh
run_dt_online.sh
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
58 changes: 58 additions & 0 deletions sota-implementations/crossq/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# environment and task
env:
name: HalfCheetah-v4
task: ""
library: gym
max_episode_steps: 1000
seed: 42

# collector
collector:
total_frames: 1_000_000
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
env_per_collector: 1
reset_at_each_iter: False

# replay buffer
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: null

# optim
optim:
utd_ratio: 1.0
policy_update_delay: 3
gamma: 0.99
loss_function: l2
lr: 1.0e-3
weight_decay: 0.0
batch_size: 256
alpha_init: 1.0
adam_eps: 1.0e-8
beta1: 0.5
beta2: 0.999

# network
network:
batch_norm_momentum: 0.01
warmup_steps: 100000
critic_hidden_sizes: [2048, 2048]
actor_hidden_sizes: [256, 256]
critic_activation: relu
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"

# logging
logger:
backend: wandb
project_name: torchrl_example_crossQ
group_name: null
exp_name: ${env.name}_CrossQ
mode: online
eval_iter: 25000
Loading

0 comments on commit d4e5f30

Please sign in to comment.