Merge remote-tracking branch 'origin/main' into pettingzoo_tuto
vmoens committed Apr 17, 2024
2 parents c744307 + 730dd45 commit 56219b0
Showing 23 changed files with 1,085 additions and 176 deletions.
8 changes: 5 additions & 3 deletions .github/unittest/linux_libs/scripts_habitat/setup_env.sh
@@ -39,9 +39,11 @@ if [ ! -d "${env_dir}" ]; then
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
fi
conda activate "${env_dir}"
#pip3 uninstall cython -y
#pip uninstall cython -y
#conda uninstall cython -y

# set debug variables
conda env config vars set MAGNUM_LOG=debug HABITAT_SIM_LOG=debug
conda deactivate && conda activate "${env_dir}"

pip3 install "cython<3"
conda install -c anaconda cython="<3.0.0" -y

@@ -6,7 +6,7 @@ dependencies:
- protobuf
- pip:
# Initial version is required to install Atari ROMS in setup_env.sh
- gym==0.13
- gymnasium
- hypothesis
- future
- cloudpickle
7 changes: 6 additions & 1 deletion .github/unittest/linux_libs/scripts_robohive/setup_env.sh
@@ -78,4 +78,9 @@ conda env update --file "${this_dir}/environment.yml" --prune

conda install conda-forge::ffmpeg -y

pip install git+https://github.com/vikashplus/robohive@main
pip install robohive

python3 -m robohive_init

# make sure only gymnasium is available
# pip uninstall gym -y
6 changes: 3 additions & 3 deletions .github/workflows/test-linux-habitat.yml
@@ -19,14 +19,14 @@ jobs:
tests:
strategy:
matrix:
python_version: ["3.9"] # "3.8", "3.9", "3.10", "3.11"
cuda_arch_version: ["11.6"] # "11.6", "11.7"
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
runner: linux.g5.4xlarge.nvidia.gpu
repository: pytorch/rl
docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
docker-image: "nvidia/cuda:12.1.1-devel-ubuntu22.04"
gpu-arch-type: cuda
gpu-arch-version: ${{ matrix.cuda_arch_version }}
timeout: 90
16 changes: 13 additions & 3 deletions .github/workflows/test-linux-libs.yml
@@ -53,14 +53,16 @@ jobs:
unittests-brax:
strategy:
matrix:
python_version: ["3.9"]
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "11.7"
docker-image: "pytorch/manylinux-cuda124"
timeout: 120
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -73,7 +75,7 @@
set -euo pipefail
export PYTHON_VERSION="3.9"
export PYTHON_VERSION="3.11"
export CU_VERSION="12.1"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
@@ -123,7 +125,7 @@ jobs:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/rl
@@ -224,12 +226,14 @@ jobs:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "11.7"
docker-image: "pytorch/manylinux-cuda124"
timeout: 120
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -324,12 +328,14 @@ jobs:
bash .github/unittest/linux_libs/scripts_openx/post_process.sh
unittests-pettingzoo:
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "11.7"
docker-image: "pytorch/manylinux-cuda124"
timeout: 120
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -360,6 +366,7 @@ jobs:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/rl
@@ -468,6 +475,7 @@ jobs:
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "11.7"
docker-image: "pytorch/manylinux-cuda124"
timeout: 120
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -532,12 +540,14 @@ jobs:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "11.7"
docker-image: "pytorch/manylinux-cuda124"
timeout: 120
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
78 changes: 78 additions & 0 deletions docs/source/reference/envs.rst
@@ -471,6 +471,82 @@ single agent standards.
MarlGroupMapType
check_marl_grouping

Auto-resetting Envs
-------------------

Auto-resetting environments are environments where calls to :meth:`~torchrl.envs.EnvBase.reset` are not expected when
the environment reaches a ``"done"`` state during a rollout, as the reset happens automatically.
Usually, in such cases the observations delivered together with the done and reward signals (which effectively result from
performing the action in the environment) are actually the first observations of a new episode, not the last observations of
the current episode.

To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations
that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both
:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations).
This transform class also provides fine-grained control over how the invalid observations should be handled:
they can be masked with `"nan"` or any other value, or left unmasked.
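
The transform can also be appended explicitly to an environment that performs its own reset when an episode ends.
The snippet below is a minimal sketch assuming that :class:`~torchrl.envs.AutoResetTransform` can be instantiated
with its default arguments and used within a :class:`~torchrl.envs.TransformedEnv`; refer to the class documentation
for the exact signature and for the arguments controlling the masking behaviour.

>>> from torchrl.envs import TransformedEnv, AutoResetTransform
>>> # ``base_env`` is a placeholder for any environment that resets itself when done
>>> env = TransformedEnv(base_env, AutoResetTransform())
>>> rollout = env.rollout(100, break_when_any_done=False)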

To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument
during construction. If provided, an ``auto_reset_replace`` argument can also control whether the values of the last
observation of an episode should be replaced with some placeholder or not.

>>> from torchrl.envs import GymEnv
>>> from torchrl.envs import set_gym_backend
>>> import torch
>>> torch.manual_seed(0)
>>>
>>> class AutoResettingGymEnv(GymEnv):
...     def _step(self, tensordict):
...         tensordict = super()._step(tensordict)
...         if tensordict["done"].any():
...             td_reset = super().reset()
...             tensordict.update(td_reset.exclude(*self.done_keys))
...         return tensordict
...
...     def _reset(self, tensordict=None):
...         if tensordict is not None and "_reset" in tensordict:
...             return tensordict.copy()
...         return super()._reset(tensordict)
>>>
>>> with set_gym_backend("gym"):
...     env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True)
...     env.set_seed(0)
...     r = env.rollout(30, break_when_any_done=False)
>>> print(r["next", "done"].squeeze())
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, True, False, False, False, False, False, False,
        False, False, False, False, False, True, False, False, False, False])
>>> print("observation after reset are set as nan", r["next", "observation"])
observation after reset are set as nan tensor([[-4.3633e-02, -1.4877e-01, 1.2849e-02, 2.7584e-01],
[-4.6609e-02, 4.6166e-02, 1.8366e-02, -1.2761e-02],
[-4.5685e-02, 2.4102e-01, 1.8111e-02, -2.9959e-01],
[-4.0865e-02, 4.5644e-02, 1.2119e-02, -1.2542e-03],
[-3.9952e-02, 2.4059e-01, 1.2094e-02, -2.9009e-01],
[-3.5140e-02, 4.3554e-01, 6.2920e-03, -5.7893e-01],
[-2.6429e-02, 6.3057e-01, -5.2867e-03, -8.6963e-01],
[-1.3818e-02, 8.2576e-01, -2.2679e-02, -1.1640e+00],
[ 2.6972e-03, 1.0212e+00, -4.5959e-02, -1.4637e+00],
[ 2.3121e-02, 1.2168e+00, -7.5232e-02, -1.7704e+00],
[ 4.7457e-02, 1.4127e+00, -1.1064e-01, -2.0854e+00],
[ 7.5712e-02, 1.2189e+00, -1.5235e-01, -1.8289e+00],
[ 1.0009e-01, 1.0257e+00, -1.8893e-01, -1.5872e+00],
[ nan, nan, nan, nan],
[-3.9405e-02, -1.7766e-01, -1.0403e-02, 3.0626e-01],
[-4.2959e-02, -3.7263e-01, -4.2775e-03, 5.9564e-01],
[-5.0411e-02, -5.6769e-01, 7.6354e-03, 8.8698e-01],
[-6.1765e-02, -7.6292e-01, 2.5375e-02, 1.1820e+00],
[-7.7023e-02, -9.5836e-01, 4.9016e-02, 1.4826e+00],
[-9.6191e-02, -7.6387e-01, 7.8667e-02, 1.2056e+00],
[-1.1147e-01, -9.5991e-01, 1.0278e-01, 1.5219e+00],
[-1.3067e-01, -7.6617e-01, 1.3322e-01, 1.2629e+00],
[-1.4599e-01, -5.7298e-01, 1.5848e-01, 1.0148e+00],
[-1.5745e-01, -7.6982e-01, 1.7877e-01, 1.3527e+00],
[-1.7285e-01, -9.6668e-01, 2.0583e-01, 1.6956e+00],
[ nan, nan, nan, nan],
[-4.3962e-02, 1.9845e-01, -4.5015e-02, -2.5903e-01],
[-3.9993e-02, 3.9418e-01, -5.0196e-02, -5.6557e-01],
[-3.2109e-02, 5.8997e-01, -6.1507e-02, -8.7363e-01],
[-2.0310e-02, 3.9574e-01, -7.8980e-02, -6.0090e-01]])


Transforms
@@ -580,6 +656,8 @@ to be able to create this other composition:
Transform
TransformedEnv
ActionMask
AutoResetEnv
AutoResetTransform
BatchSizeTransform
BinarizeReward
BurnInTransform
48 changes: 48 additions & 0 deletions examples/video/video-from-dataset.py
@@ -0,0 +1,48 @@
"""Video from dataset example.
This example shows how to save a video from a dataset.
To run it, you will need to install the openx requirements as well as torchvision.
"""

from torchrl.data.datasets import OpenXExperienceReplay
from torchrl.record import CSVLogger, VideoRecorder

# Create a logger that saves videos as mp4
logger = CSVLogger("./dump", video_format="mp4")


# We use the VideoRecorder transform to register the images coming from the batch.
t = VideoRecorder(
    logger=logger, tag="pixels", in_keys=[("next", "observation", "image")]
)
# Each batch of data will contain 10 consecutive trajectories of at most 200 frames each
# (200 is a maximum, since strict_length=False allows shorter slices)
dataset = OpenXExperienceReplay(
    "cmu_stretch",
    batch_size=2000,
    slice_len=200,
    download=True,
    strict_length=False,
    transform=t,
)

# Get a batch of data and visualize it
for _ in dataset:
    # The transform has already seen the data, since it is applied by the replay buffer
    t.dump()
    break

# Alternatively, we can build the dataset without the VideoRecorder and call the transform manually:
dataset = OpenXExperienceReplay(
    "cmu_stretch",
    batch_size=2000,
    slice_len=200,
    download=True,
    strict_length=False,
)

# Get a batch of data and visualize it
for data in dataset:
    t(data)
    t.dump()
    break
