diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh index 8e929443ef6..ff1a899b06b 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/m1_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export BUILD_VERSION=0.3.0 +export BUILD_VERSION=0.3.1 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 2340dbb2e54..137255a786c 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -124,15 +124,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$CU_VERSION + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION fi else printf "Failed to install pytorch" @@ -150,7 +150,11 @@ else fi # install tensordict -pip3 install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi printf "* Installing torchrl\n" python setup.py develop diff --git a/.github/unittest/linux_distributed/scripts/install.sh b/.github/unittest/linux_distributed/scripts/install.sh index 95eda22aecb..a05bfd6be5e 100755 --- a/.github/unittest/linux_distributed/scripts/install.sh +++ b/.github/unittest/linux_distributed/scripts/install.sh @@ -26,11 +26,21 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U + else + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U + fi else - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + printf "Failed to install pytorch" + exit 1 fi # smoke test @@ -40,7 +50,11 @@ python -c "import functorch" pip install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip install 
git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi printf "* Installing torchrl\n" python setup.py develop diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh index 74fdc043f0a..d6c3fe47f97 100755 --- a/.github/unittest/linux_examples/scripts/run_all.sh +++ b/.github/unittest/linux_examples/scripts/run_all.sh @@ -148,7 +148,7 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION +pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION # smoke test python -c "import functorch" @@ -157,7 +157,11 @@ python -c "import functorch" pip install git+https://github.com/pytorch/torchsnapshot # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi printf "* Installing torchrl\n" python setup.py develop diff --git a/.github/unittest/linux_libs/scripts_ataridqn/install.sh b/.github/unittest/linux_libs/scripts_ataridqn/install.sh index 1be476425a6..15d4fa389e2 100755 --- a/.github/unittest/linux_libs/scripts_ataridqn/install.sh +++ b/.github/unittest/linux_libs/scripts_ataridqn/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_brax/install.sh b/.github/unittest/linux_libs/scripts_brax/install.sh index 93c1f113b52..80efdc536ab 100755 --- a/.github/unittest/linux_libs/scripts_brax/install.sh +++ b/.github/unittest/linux_libs/scripts_brax/install.sh @@ -25,14 +25,22 @@ fi # submodules git submodule sync && git submodule update --init 
--recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall --progress-bar off +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall --progress-bar off + printf "Failed to install pytorch" + exit 1 fi # install tensordict diff --git a/.github/unittest/linux_libs/scripts_d4rl/install.sh b/.github/unittest/linux_libs/scripts_d4rl/install.sh index 2eb52b8f65e..b53dae7ae9d 100755 --- a/.github/unittest/linux_libs/scripts_d4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_envpool/install.sh b/.github/unittest/linux_libs/scripts_envpool/install.sh index c62a2de25fb..0f60dd233ce 100755 --- a/.github/unittest/linux_libs/scripts_envpool/install.sh +++ b/.github/unittest/linux_libs/scripts_envpool/install.sh @@ -26,11 +26,11 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" +printf "Installing PyTorch with cu121" if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + pip3 
install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U fi # smoke test diff --git a/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh b/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh index 1be476425a6..15d4fa389e2 100755 --- a/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh +++ b/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_gym/install.sh b/.github/unittest/linux_libs/scripts_gym/install.sh index 718e4f37e3a..79d5a369d14 100755 --- a/.github/unittest/linux_libs/scripts_gym/install.sh +++ b/.github/unittest/linux_libs/scripts_gym/install.sh @@ -37,16 +37,20 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then - conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 cpuonly -c pytorch + conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch else - conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y + conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y fi # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has -pip install -U --force-reinstall charset-normalizer +pip install -U charset-normalizer # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" @@ -54,3 +58,11 @@ python -c "import tensordict" printf "* Installing torchrl\n" python setup.py develop python -c "import torchrl" + +## 
Reinstalling pytorch with specific version +#printf "Re-installing PyTorch with %s\n" "${CU_VERSION}" +#if [ "${CU_VERSION:-}" == cpu ] ; then +# conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch +#else +# conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y +#fi diff --git a/.github/unittest/linux_libs/scripts_habitat/install.sh b/.github/unittest/linux_libs/scripts_habitat/install.sh index 071af690448..47f6fdc974c 100755 --- a/.github/unittest/linux_libs/scripts_habitat/install.sh +++ b/.github/unittest/linux_libs/scripts_habitat/install.sh @@ -20,10 +20,19 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U +elif [[ "$TORCH_VERSION" == "stable" ]]; then + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 +fi # install tensordict -pip3 install git+https://github.com/pytorch/tensordict.git +# install tensordict +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python3 -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh index d287f8a5977..56064fce082 100755 --- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh @@ -65,7 +65,8 @@ pip install pip --upgrade conda env update --file "${this_dir}/environment.yml" --prune -conda install habitat-sim withbullet headless -c conda-forge -c aihabitat-nightly -y -conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-lab +#conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y +conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y +conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git@stable#subdirectory=habitat-lab #conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-baselines conda run python -m pip install "gym[atari,accept-rom-license]" pygame diff --git a/.github/unittest/linux_libs/scripts_jumanji/install.sh b/.github/unittest/linux_libs/scripts_jumanji/install.sh index 3d6ad9ed450..27062a8cdad 100755 --- a/.github/unittest/linux_libs/scripts_jumanji/install.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/install.sh @@ -25,18 +25,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else 
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_minari/install.sh b/.github/unittest/linux_libs/scripts_minari/install.sh index 2eb52b8f65e..b53dae7ae9d 100755 --- a/.github/unittest/linux_libs/scripts_minari/install.sh +++ b/.github/unittest/linux_libs/scripts_minari/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh index 1be476425a6..15d4fa389e2 100755 --- a/.github/unittest/linux_libs/scripts_openx/install.sh +++ b/.github/unittest/linux_libs/scripts_openx/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 
install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/install.sh b/.github/unittest/linux_libs/scripts_pettingzoo/install.sh index fb82bcb4ea8..278e975506e 100755 --- a/.github/unittest/linux_libs/scripts_pettingzoo/install.sh +++ b/.github/unittest/linux_libs/scripts_pettingzoo/install.sh @@ -25,18 +25,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh index 31a6b2b56d4..5607f04b2d8 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/install.sh +++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + 
else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh b/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh index 873962164d6..ded09c0c27b 100755 --- a/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh +++ b/.github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh @@ -36,18 +36,30 @@ esac # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_libs/scripts_roboset/install.sh b/.github/unittest/linux_libs/scripts_roboset/install.sh index 2eb52b8f65e..b53dae7ae9d 100755 --- a/.github/unittest/linux_libs/scripts_roboset/install.sh +++ b/.github/unittest/linux_libs/scripts_roboset/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url 
https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_sklearn/install.sh b/.github/unittest/linux_libs/scripts_sklearn/install.sh index 2eb52b8f65e..b53dae7ae9d 100755 --- a/.github/unittest/linux_libs/scripts_sklearn/install.sh +++ b/.github/unittest/linux_libs/scripts_sklearn/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_smacv2/install.sh b/.github/unittest/linux_libs/scripts_smacv2/install.sh index fb82bcb4ea8..278e975506e 100755 --- a/.github/unittest/linux_libs/scripts_smacv2/install.sh +++ b/.github/unittest/linux_libs/scripts_smacv2/install.sh @@ -25,18 +25,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url 
https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh index 1be476425a6..15d4fa389e2 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -28,18 +28,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch;import tensordict" diff --git a/.github/unittest/linux_libs/scripts_vmas/install.sh b/.github/unittest/linux_libs/scripts_vmas/install.sh index fb82bcb4ea8..278e975506e 100755 --- a/.github/unittest/linux_libs/scripts_vmas/install.sh +++ b/.github/unittest/linux_libs/scripts_vmas/install.sh @@ -25,18 +25,30 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - # conda install -y pytorch torchvision cpuonly -c pytorch-nightly - # use pip to install pytorch as conda can frequently pick older release -# conda install -y pytorch cpuonly -c pytorch-nightly - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url 
https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall + printf "Failed to install pytorch" + exit 1 fi # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh index f55daf8e8ce..fd259251dc9 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -37,16 +37,20 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then - conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 cpuonly -c pytorch + conda install pytorch==1.13.1 torchvision==0.14.1 cpuonly -c pytorch else - conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y + conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia -y fi # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has -pip install -U --force-reinstall charset-normalizer +pip install -U charset-normalizer # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import tensordict" diff --git a/.github/unittest/linux_optdeps/scripts/install.sh b/.github/unittest/linux_optdeps/scripts/install.sh index e7d48b4cb9b..c5d7ed062f1 100755 --- a/.github/unittest/linux_optdeps/scripts/install.sh +++ b/.github/unittest/linux_optdeps/scripts/install.sh @@ -23,7 +23,11 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION # install tensordict -pip install git+https://github.com/pytorch/tensordict.git +if [[ "$TORCH_VERSION" == "nightly" ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi # smoke test python -c "import functorch" diff --git a/.github/unittest/windows_optdepts/scripts/install.sh b/.github/unittest/windows_optdepts/scripts/install.sh index 506a4a0b035..a8e02795326 100644 --- a/.github/unittest/windows_optdepts/scripts/install.sh +++ b/.github/unittest/windows_optdepts/scripts/install.sh @@ -40,9 +40,9 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${cudatoolkit}" if $torch_cuda ; then - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 + python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 else - python -m pip install --pre torch --index-url 
https://download.pytorch.org/whl/nightly/cpu + python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U fi torch_cuda=$(python -c "import torch; print(torch.cuda.is_available())") diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 01d880708f4..8eaed2fb825 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -30,7 +30,7 @@ jobs: python-version: 3.8 - name: Setup Environment run: | - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U python -m pip install git+https://github.com/pytorch/tensordict python setup.py develop python -m pip install pytest pytest-benchmark @@ -91,7 +91,7 @@ jobs: echo /usr/local/bin >> $GITHUB_PATH - name: Setup Environment run: | - python3 -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 + python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U python3 -m pip install git+https://github.com/pytorch/tensordict python3 setup.py develop python3 -m pip install pytest pytest-benchmark diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 0f0ad3e5723..a8a1bc4c8dc 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -29,7 +29,7 @@ jobs: python-version: 3.8 - name: Setup Environment run: | - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U python -m pip install git+https://github.com/pytorch/tensordict python setup.py develop python -m pip install pytest pytest-benchmark @@ -102,7 +102,7 @@ jobs: echo /usr/local/bin >> $GITHUB_PATH - name: Setup Environment run: | - python3 -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 + python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U python3 -m pip install git+https://github.com/pytorch/tensordict python3 setup.py develop python3 -m pip install pytest pytest-benchmark diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 3f7ce76885a..bcf6b5066c3 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -59,7 +59,7 @@ jobs: git version # 5. Install PyTorch - python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu --quiet --root-user-action=ignore + python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --quiet --root-user-action=ignore # 6. 
Install tensordict python3 -m pip install git+https://github.com/pytorch/tensordict.git --quiet --root-user-action=ignore diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 42ff3c77251..82c6d84c231 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -84,7 +84,7 @@ jobs: uses: actions/checkout@v2 - name: Install PyTorch nightly run: | - python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U - name: Build TorchRL Nightly run: | export CC=clang CXX=clang++ @@ -116,7 +116,7 @@ jobs: uses: actions/checkout@v2 - name: Install PyTorch Nightly run: | - python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U - name: Upgrade pip run: | python3 -mpip install --upgrade pip @@ -289,7 +289,7 @@ jobs: - name: Install PyTorch nightly shell: bash run: | - python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U - name: Build TorchRL nightly shell: bash run: | @@ -322,7 +322,7 @@ jobs: - name: Install PyTorch Nightly shell: bash run: | - python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U - name: Upgrade pip shell: bash run: | diff --git a/.github/workflows/test-linux-examples.yml b/.github/workflows/test-linux-examples.yml index bfc6884bc7a..0981edc5684 100644 --- a/.github/workflows/test-linux-examples.yml +++ b/.github/workflows/test-linux-examples.yml @@ -39,6 +39,14 @@ jobs: export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" diff --git a/.github/workflows/test-linux-habitat.yml b/.github/workflows/test-linux-habitat.yml index 734052241d6..1459bbfd79c 100644 --- a/.github/workflows/test-linux-habitat.yml +++ b/.github/workflows/test-linux-habitat.yml @@ -31,6 +31,14 @@ jobs: gpu-arch-version: ${{ matrix.cuda_arch_version }} timeout: 90 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} # Commenting these out for now because the GPU test are not working inside docker diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index abf78e5e19c..9ef201bd103 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -30,6 +30,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -55,6 +63,14 @@ jobs: gpu-arch-version: "11.7" timeout: 120 script: | + if [[ "${{ github.ref == 
'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" @@ -81,6 +97,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -107,6 +131,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -133,6 +165,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -159,6 +199,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euxo pipefail export PYTHON_VERSION="3.9" # export CU_VERSION="${{ inputs.gpu-arch-version }}" @@ -184,6 +232,14 @@ jobs: gpu-arch-version: "11.7" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="12.1" @@ -212,6 +268,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -238,6 +302,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -260,6 +332,14 @@ jobs: gpu-arch-version: "11.7" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="12.1" @@ -287,6 +367,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -312,6 +400,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -339,6 +435,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: 
| + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -366,6 +470,14 @@ jobs: gpu-arch-version: "11.7" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="12.1" @@ -394,6 +506,14 @@ jobs: docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu117" @@ -420,6 +540,14 @@ jobs: gpu-arch-version: "11.7" timeout: 120 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="12.1" diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 508e9525dea..6473634af80 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -31,10 +31,16 @@ jobs: docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04" timeout: 90 script: | + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} export CU_VERSION="cpu" - export TORCH_VERSION=nightly echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" @@ -62,7 +68,13 @@ jobs: # Commenting these out for now because the GPU test are not working inside docker export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" - export TORCH_VERSION=nightly + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" @@ -90,7 +102,13 @@ jobs: export PYTHON_VERSION="3.9" export CU_VERSION="cu116" export TAR_OPTIONS="--no-same-owner" - export UPLOAD_CHANNEL="nightly" + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi export TF_CPP_MIN_LOG_LEVEL=0 @@ -121,6 +139,14 @@ jobs: # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" @@ -147,7 +173,13 @@ jobs: # Commenting these out for now because the GPU test are not working inside docker export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" - export TORCH_VERSION=stable + if [[ "${{ github.ref == 'release/*' }}" ]]; then + export RELEASE=1 + export TORCH_VERSION=nightly + else + export RELEASE=0 + export 
TORCH_VERSION=nightly + fi # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 47c1b0c6fec..e910ba4201b 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -4,7 +4,7 @@ on: types: [opened, synchronize, reopened] push: branches: - - release/0.3.0 + - release/0.4.0 concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.3.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -72,7 +72,7 @@ jobs: run: | export CC=clang CXX=clang++ python3 -mpip install wheel - BUILD_VERSION=0.3.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: @@ -104,7 +104,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - BUILD_VERSION=0.3.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/README.md b/README.md index 2e1d08a0757..6adbc2decfe 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,11 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f TorchRL aims at (1) a high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library. +## Getting started + +Check our [Getting Started tutorials](https://pytorch.org/rl/index.html#getting-started) to quickly ramp up with the basic +features of the library! + ## Documentation and knowledge base The TorchRL documentation can be found [here](https://pytorch.org/rl). diff --git a/docs/source/conf.py b/docs/source/conf.py index f0821ede0bf..060103b48b4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -189,3 +189,8 @@ generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial") # generate_tutorial_references("../../tutorials/src/", "src") generate_tutorial_references("../../tutorials/media/", "media") + +# We do this to indicate that the script is run by sphinx +import builtins + +builtins.__sphinx_build__ = True diff --git a/docs/source/index.rst b/docs/source/index.rst index 91906abb857..ab1cee681db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,7 +11,14 @@ TorchRL TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch. -It provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested. +You can install TorchRL directly from PyPI (see more about installation +instructions in the dedicated section below): + +.. code-block:: + + $ pip install torchrl + +TorchRL provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested. The code is aimed at supporting research in RL. 
Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort. This repo attempts to align with the existing pytorch ecosystem libraries in that it has a "dataset pillar" @@ -30,6 +37,49 @@ TorchRL aims at a high modularity and good runtime performance. To read more about TorchRL philosophy and capabilities beyond this API reference, check the `TorchRL paper <https://arxiv.org/abs/2306.00577>`__. +Installation +============ + +TorchRL releases are synced with PyTorch, so make sure you always enjoy the latest +features of the library with the `most recent version of PyTorch `__ (although core features +are guaranteed to be backward compatible with pytorch>=1.13). +Nightly releases can be installed via + +.. code-block:: + + $ pip install tensordict-nightly + $ pip install torchrl-nightly + +or via a ``git clone`` if you're willing to contribute to the library: + +.. code-block:: + + $ cd path/to/root + $ git clone https://github.com/pytorch/tensordict + $ git clone https://github.com/pytorch/rl + $ cd tensordict + $ python setup.py develop + $ cd ../rl + $ python setup.py develop + +Getting started +=============== + +A series of quick tutorials to get ramped up with the basic features of the +library. If you're in a hurry, you can start with +:ref:`the last item of the series ` +and navigate to the previous ones whenever you want to learn more! + +.. toctree:: + :maxdepth: 1 + + tutorials/getting-started-0 + tutorials/getting-started-1 + tutorials/getting-started-2 + tutorials/getting-started-3 + tutorials/getting-started-4 + tutorials/getting-started-5 + Tutorials ========= diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index aa8de179f20..982b8664862 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -3,6 +3,8 @@ torchrl.collectors package ========================== +.. _ref_collectors: + Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they collect data over non-static data sources and (2) the data is collected using a model (likely a version of the model that is being trained). diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 6ed32ebe921..3fff3e5fdc1 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -3,6 +3,8 @@ torchrl.data package ==================== +.. _ref_data: + Replay Buffers -------------- @@ -699,6 +701,9 @@ efficient sampling. TokenizedDatasetLoader create_infinite_iterator get_dataloader + ConstantKLController + AdaptiveKLController + Utils ----- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index cce34e14b14..4dbb5a5da57 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -475,6 +475,9 @@ single agent standards. Transforms ---------- + +.. _transforms: + .. currentmodule:: torchrl.envs.transforms In most cases, the raw output of an environment must be treated before being passed to another object (such as a diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index d859140bb70..bcd234a7ff9 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -3,9 +3,13 @@ torchrl.modules package ======================= +.. _ref_modules: + TensorDict modules: Actors, exploration, value models and generative models --------------------------------------------------------------------------- +.. 
_tdmodules: + TorchRL offers a series of module wrappers aimed at making it easy to build RL models from the ground up. These wrappers are exclusively based on :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`. diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 1aec88f2d11..c2f43d8e9b6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -3,6 +3,8 @@ torchrl.objectives package ========================== +.. _ref_objectives: + TorchRL provides a series of losses to use in your training scripts. The aim is to have losses that are easily reusable/swappable and that have a simple signature. diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index eb857f15a0f..04d4386c631 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -218,6 +218,8 @@ Utils Loggers ------- +.. _ref_loggers: + .. currentmodule:: torchrl.record.loggers .. autosummary:: diff --git a/examples/distributed/collectors/multi_nodes/lol.py b/examples/distributed/collectors/multi_nodes/lol.py deleted file mode 100644 index 89d5e66b487..00000000000 --- a/examples/distributed/collectors/multi_nodes/lol.py +++ /dev/null @@ -1,3 +0,0 @@ -from torchrl.envs.libs.gym import GymEnv - -env = GymEnv("ALE/Pong-v5") diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 1408e47e915..bb374b99941 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad() and set_exploration_type(ExplorationType.MEAN): + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py index d4ed03ad3c6..bed6240d244 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/examples/multiagent/maddpg_iddpg.py @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad() and set_exploration_type(ExplorationType.MEAN): + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index 8f4a2356c35..6b7206511ca 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -17,6 +17,7 @@ from torchrl.data.replay_buffers.storages import LazyTensorStorage from torchrl.envs import RewardSum, TransformedEnv from torchrl.envs.libs.vmas import VmasEnv +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import ClipPPOLoss, ValueEstimators @@ -175,7 +176,7 @@ def train(cfg: "DictConfig"): # noqa: F821 loss_module.value_estimator( tensordict_data, params=loss_module.critic_network_params, - target_params=loss_module.target_critic_params, + target_params=loss_module.target_critic_network_params, ) current_frames = tensordict_data.numel() total_frames += current_frames @@ -235,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with 
torch.no_grad(): + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index e814ce8f79f..008e01b28b9 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad() and set_exploration_type(ExplorationType.MEAN): + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index 28317dba728..d76ddd1f913 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad() and set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index a921e58bad6..94d9234db2a 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -100,9 +100,7 @@ def main(cfg): # using a Gym-like API (querying steps etc) introduces some # extra code that we can spare. # - kl_scheduler = AdaptiveKLController( - model, init_kl_coef=0.1, target=6, horizon=10000 - ) + kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000) rollout_from_model = RolloutFromModel( model, ref_model, diff --git a/setup.py b/setup.py index f31a2ed9f5c..c2e2b89cf20 100644 --- a/setup.py +++ b/setup.py @@ -170,7 +170,7 @@ def _main(argv): if is_nightly: tensordict_dep = "tensordict-nightly" else: - tensordict_dep = "tensordict>=0.3.0" + tensordict_dep = "tensordict>=0.3.1" if is_nightly: version = get_nightly_version() diff --git a/test/conftest.py b/test/conftest.py index 5ce980a4080..2dcd369003a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -53,7 +53,7 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -65,6 +65,11 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r"Skipping device Apple Paravirtual device", + ) warnings.filterwarnings( "ignore", category=DeprecationWarning, diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 7a32c9a38ef..d68c7f30aa3 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1072,7 +1072,7 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(torch.int).to(self.device) + self.count += action.to(dtype=torch.int, device=self.device) tensordict = TensorDict( source={ "observation": self.count.clone(), @@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): 3, ) ), + device=self.device, ) self.unbatched_action_spec = CompositeSpec( lazy=action_specs, + device=self.device, ) self.unbatched_reward_spec = CompositeSpec( { @@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): }, 
shape=(self.n_nested_dim,), ) - } + }, + device=self.device, ) self.unbatched_done_spec = CompositeSpec( { @@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): }, shape=(self.n_nested_dim,), ) - } + }, + device=self.device, ) self.action_spec = self.unbatched_action_spec.expand( @@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i): "lidar": lidar, "vector": vector_3d, "tensor_0": tensor_0, - } + }, + device=self.device, ) elif i == 1: return CompositeSpec( @@ -1497,7 +1502,8 @@ def get_agent_obs_spec(self, i): "lidar": lidar, "vector": vector_2d, "tensor_1": tensor_1, - } + }, + device=self.device, ) elif i == 2: return CompositeSpec( @@ -1505,7 +1511,8 @@ def get_agent_obs_spec(self, i): "camera": camera, "vector": vector_2d, "tensor_2": tensor_2, - } + }, + device=self.device, ) else: raise ValueError(f"Index {i} undefined for index 3") diff --git a/test/smoke_test.py b/test/smoke_test.py index 313c786088c..c6500deb5e8 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -14,3 +14,5 @@ def test_imports(): from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 from torchrl.modules import SafeModule # noqa: F401 from torchrl.objectives.common import LossModule # noqa: F401 + + PrioritizedReplayBuffer(alpha=1.1, beta=1.1) diff --git a/test/test_collector.py b/test/test_collector.py index b5afe7f35d7..09c6ee293c3 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1675,8 +1675,12 @@ def test_maxframes_error(): @pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) @pytest.mark.parametrize("env_device", [None, *get_available_devices()]) @pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("parallel", [False, True]) def test_reset_heterogeneous_envs( - policy_device: torch.device, env_device: torch.device, storing_device: torch.device + policy_device: torch.device, + env_device: torch.device, + storing_device: torch.device, + parallel, ): if ( policy_device is not None @@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs( env_device = torch.device("cpu") # explicit mapping elif env_device is not None and env_device.type == "cuda" and policy_device is None: policy_device = torch.device("cpu") - env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2)) - env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3)) - env = SerialEnv(2, [env1, env2], device=env_device) + env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2)) + env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3)) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + env = cls(2, [env1, env2], device=env_device) collector = SyncDataCollector( env, RandomPolicy(env.action_spec), @@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs( assert ( data[0]["next", "truncated"].squeeze() == torch.tensor([False, True], device=data_device).repeat(25)[:50] - ).all(), data[0]["next", "truncated"][:10] + ).all(), data[0]["next", "truncated"] assert ( data[1]["next", "truncated"].squeeze() == torch.tensor([False, False, True], device=data_device).repeat(17)[:50] diff --git a/test/test_cost.py b/test/test_cost.py index c6eb27172ee..2917b816367 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -25,6 +25,7 @@ TensorDictSequential, TensorDictSequential as Seq, ) +from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type from torchrl.modules.models import QMixer @@ -155,6 +156,10 @@ ) +# Capture all warnings +pytestmark = 
pytest.mark.filterwarnings("error") + + class _check_td_steady: def __init__(self, td): self.td_clone = td.clone() @@ -500,6 +505,11 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): else contextlib.nullcontext() ), _check_td_steady(td): loss = loss_fn(td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert loss_fn.tensor_keys.priority in td.keys() sum([item for _, item in loss.items()]).backward() @@ -561,6 +571,10 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -600,7 +614,7 @@ def test_dqn_tensordict_keys(self, td_est): torch.manual_seed(self.seed) action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) default_keys = { "advantage": "advantage", @@ -616,7 +630,7 @@ def test_dqn_tensordict_keys(self, td_est): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) key_mapping = { "advantage": ("advantage", "advantage_2"), "value_target": ("value_target", ("value_target", "nested")), @@ -629,7 +643,7 @@ def test_dqn_tensordict_keys(self, td_est): actor = self._create_mock_actor( action_spec_type=action_spec_type, action_value_key="chosen_action_value_2" ) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) key_mapping = { "value": ("value", "chosen_action_value_2"), } @@ -656,11 +670,14 @@ def test_dqn_tensordict_run(self, action_spec_type, td_est): action_value_key=tensor_keys["action_value"], ) - loss_fn = DQNLoss(actor, loss_function="l2") + loss_fn = DQNLoss(actor, loss_function="l2", delay_value=True) loss_fn.set_keys(**tensor_keys) if td_est is not None: loss_fn.make_value_estimator(td_est) + + SoftUpdate(loss_fn, eps=0.5) + with _check_td_steady(td): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -706,6 +723,10 @@ def test_distributional_dqn( sum([item for _, item in loss.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): @@ -743,7 +764,7 @@ def test_dqn_notensordict( module=module, in_keys=[observation_key], ) - dqn_loss = DQNLoss(actor) + dqn_loss = DQNLoss(actor, delay_value=True) dqn_loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) @@ -761,6 +782,8 @@ def test_dqn_notensordict( "action": action, } td = TensorDict(kwargs, []).unflatten_keys("_") + # Disable warning + SoftUpdate(dqn_loss, eps=0.5) loss_val = dqn_loss(**kwargs) loss_val_td = dqn_loss(td) torch.testing.assert_close(loss_val_td.get("loss"), loss_val) @@ -774,7 +797,7 @@ def test_distributional_dqn_tensordict_keys(self): action_spec_type=action_spec_type, atoms=atoms ) - loss_fn = DistributionalDQNLoss(actor, gamma=gamma) + loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=True) default_keys = { "priority": "td_error", @@ -809,11 +832,14 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est): action_key=tensor_keys["action"], 
action_value_key=tensor_keys["action_value"], ) - loss_fn = DistributionalDQNLoss(actor, gamma=0.9) + loss_fn = DistributionalDQNLoss(actor, gamma=0.9, delay_value=True) loss_fn.set_keys(**tensor_keys) loss_fn.make_value_estimator(td_est) + # remove warnings + SoftUpdate(loss_fn, eps=0.5) + with _check_td_steady(td): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -983,6 +1009,10 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): sum([item for _, item in loss.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): @@ -1050,6 +1080,11 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert loss_fn.tensor_keys.priority in ms_td.keys() with torch.no_grad(): @@ -1104,7 +1139,7 @@ def test_qmix_tensordict_keys(self, td_est): action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) mixer = self._create_mock_mixer() - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) default_keys = { "advantage": "advantage", @@ -1121,7 +1156,7 @@ def test_qmix_tensordict_keys(self, td_est): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) key_mapping = { "advantage": ("advantage", "advantage_2"), "value_target": ("value_target", ("value_target", "nested")), @@ -1137,7 +1172,7 @@ def test_qmix_tensordict_keys(self, td_est): mixer = self._create_mock_mixer( global_chosen_action_value_key=("some", "nested") ) - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) key_mapping = { "global_value": ("value", ("some", "nested")), } @@ -1172,9 +1207,9 @@ def test_qmix_tensordict_run(self, action_spec_type, td_est): action_value_key=tensor_keys["action_value"], ) - loss_fn = QMixerLoss(actor, mixer, loss_function="l2") + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=True) loss_fn.set_keys(**tensor_keys) - + SoftUpdate(loss_fn, eps=0.5) if td_est is not None: loss_fn.make_value_estimator(td_est) with _check_td_steady(td): @@ -1230,7 +1265,9 @@ def test_mixer_keys( ) td = actor(td) - loss = QMixerLoss(actor, mixer) + loss = QMixerLoss(actor, mixer, delay_value=True) + + SoftUpdate(loss, eps=0.5) # Wthout etting the keys if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): @@ -1244,7 +1281,10 @@ def test_mixer_keys( else: loss(td) - loss = QMixerLoss(actor, mixer) + loss = QMixerLoss(actor, mixer, delay_value=True) + + SoftUpdate(loss, eps=0.5) + # When setting the key loss.set_keys(global_value=mixer_global_chosen_action_value_key) if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): @@ -1465,6 +1505,10 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): ): loss = loss_fn(td) + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert all( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.value_network_params.values(True, True) @@ -1581,6 +1625,9 @@ def test_ddpg_separate_losses( with pytest.warns(UserWarning, match="No target 
network updater has been"): loss = loss_fn(td) + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert all( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.value_network_params.values(True, True) @@ -1701,6 +1748,11 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -2303,10 +2355,14 @@ def test_td3_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_qvalue or delay_actor: + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) + if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum([item for _, item in loss.items()]) @@ -3290,6 +3346,9 @@ def test_sac_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) + + SoftUpdate(loss, eps=0.5) + loss_val_td = loss(td) if version == 1: @@ -3537,6 +3596,7 @@ def test_discrete_sac( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): @@ -3647,6 +3707,7 @@ def test_discrete_sac_state_dict( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) sd = loss_fn.state_dict() @@ -3658,6 +3719,7 @@ def test_discrete_sac_state_dict( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) loss_fn2.load_state_dict(sd) @@ -3695,6 +3757,7 @@ def test_discrete_sac_batcher( loss_function="l2", target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, + action_space="one-hot", **kwargs, ) @@ -3711,6 +3774,8 @@ def test_discrete_sac_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -3799,6 +3864,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): qvalue_network=qvalue, num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -3821,6 +3887,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): qvalue_network=qvalue, num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -3859,6 +3926,7 @@ def test_discrete_sac_notensordict( actor_network=actor, qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -4369,6 +4437,8 @@ def test_redq_deprecated_separate_losses(self, separate_losses): ): loss = loss_fn(td) + SoftUpdate(loss_fn, eps=0.5) + # check that losses are independent for k in loss.keys(): if not k.startswith("loss"): @@ -5407,6 +5477,9 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys(True) + if delay_value: + SoftUpdate(loss_fn, eps=0.5) + sum([item for key, item in loss.items() if key.startswith("loss")]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 @@ -5466,6 +5539,9 @@ def test_dcql_batcher(self, n, delay_value, device, 
action_spec_type, gamma=0.9) loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_value: + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -5509,7 +5585,7 @@ def test_dcql_tensordict_keys(self, td_est): torch.manual_seed(self.seed) action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) default_keys = { "value_target": "value_target", @@ -5565,6 +5641,8 @@ def test_dcql_tensordict_run(self, action_spec_type, td_est): loss_fn = DiscreteCQLLoss(actor, loss_function="l2") loss_fn.set_keys(**tensor_keys) + SoftUpdate(loss_fn, eps=0.5) + if td_est is not None: loss_fn.make_value_estimator(td_est) with _check_td_steady(td): @@ -5589,6 +5667,9 @@ def test_dcql_notensordict( in_keys=[observation_key], ) loss = DiscreteCQLLoss(actor) + + SoftUpdate(loss, eps=0.5) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) @@ -5744,6 +5825,10 @@ def _create_mock_data_ppo( reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + loc_key = "loc" + scale_key = "scale" + loc = torch.randn(batch, 4, device=device) + scale = torch.rand(batch, 4, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -5756,6 +5841,8 @@ def _create_mock_data_ppo( }, action_key: action, sample_log_prob_key: torch.randn_like(action[..., 1]) / 10, + loc_key: loc, + scale_key: scale, }, device=device, ) @@ -5857,6 +5944,10 @@ def test_ppo( loss_fn.make_value_estimator(td_est) loss = loss_fn(td) + if isinstance(loss_fn, KLPENPPOLoss): + kl = loss.pop("kl") + assert (kl != 0).any() + loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) @@ -6345,20 +6436,32 @@ def test_ppo_notensordict( f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } + if loss_class is KLPENPPOLoss: + kwargs.update({"loc": td.get("loc"), "scale": td.get("scale")}) + td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") # setting the seed for each loss so that drawing the random samples from # value network leads to same numbers for both runs torch.manual_seed(self.seed) + beta = getattr(loss, "beta", None) + if beta is not None: + beta = beta.clone() loss_val = loss(**kwargs) torch.manual_seed(self.seed) + if beta is not None: + loss.beta = beta.clone() loss_val_td = loss(td) for i, out_key in enumerate(loss.out_keys): - torch.testing.assert_close(loss_val_td.get(out_key), loss_val[i]) + torch.testing.assert_close( + loss_val_td.get(out_key), loss_val[i], msg=out_key + ) # test select torch.manual_seed(self.seed) + if beta is not None: + loss.beta = beta.clone() loss.select_out_keys("loss_objective", "loss_critic") if torch.__version__ >= "2.0.0": loss_obj, loss_crit = loss(**kwargs) @@ -8782,6 +8885,9 @@ def test_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + # Remove warnings + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9205,6 +9311,7 @@ def test_discrete_iql( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) if td_est 
in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -9327,6 +9434,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) sd = loss_fn.state_dict() loss_fn2 = DiscreteIQLLoss( @@ -9337,6 +9445,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) loss_fn2.load_state_dict(sd) @@ -9350,6 +9459,7 @@ def test_discrete_iql_separate_losses(self, separate_losses): value_network=value, loss_function="l2", separate_losses=separate_losses, + action_space="one-hot", ) with pytest.warns(UserWarning, match="No target network updater has been"): loss = loss_fn(td) @@ -9528,6 +9638,7 @@ def test_discrete_iql_batcher( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) ms = MultiStep(gamma=gamma, n_steps=n).to(device) @@ -9543,6 +9654,8 @@ def test_discrete_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9614,6 +9727,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -9639,6 +9753,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -9674,7 +9789,10 @@ def test_discrete_iql_notensordict( value = self._create_mock_value(observation_key=observation_key) loss = DiscreteIQLLoss( - actor_network=actor, qvalue_network=qvalue, value_network=value + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -9743,6 +9861,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: out_keys=["action"], ) loss = MyLoss(actor_module) + + if create_target_params: + SoftUpdate(loss, eps=0.5) + if cast is not None: loss.to(cast) for name in ("weight", "bias"): @@ -9872,11 +9994,13 @@ def __init__(self, delay_module=True): self.convert_to_functional( module1, "module1", create_target_params=delay_module ) + module2 = torch.nn.BatchNorm2d(10).eval() self.module2 = module2 - iterator_params = self.target_module1_params.values( - include_nested=True, leaves_only=True - ) + tparam = self._modules.get("target_module1_params", None) + if tparam is None: + tparam = self._modules.get("module1_params").data + iterator_params = tparam.values(include_nested=True, leaves_only=True) for target in iterator_params: if target.dtype is not torch.int64: target.data.normal_() @@ -12284,10 +12408,14 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True) def test_instantiate_with_different_keys(): - loss_1 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_1 = DQNLoss( + value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True + ) loss_1.set_keys(reward="a") assert loss_1.tensor_keys.reward == "a" - loss_2 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_2 = DQNLoss( + value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True + ) loss_2.set_keys(reward="b") assert loss_1.tensor_keys.reward == "a" @@ -12391,6 +12519,22 @@ def __init__(self): assert p.device == dest +def test_loss_exploration(): + class 
DummyLoss(LossModule): + def forward(self, td): + assert exploration_type() == InteractionType.MODE + with set_exploration_type(ExplorationType.RANDOM): + assert exploration_type() == ExplorationType.RANDOM + assert exploration_type() == ExplorationType.MODE + return td + + loss_fn = DummyLoss() + with set_exploration_type(ExplorationType.RANDOM): + assert exploration_type() == ExplorationType.RANDOM + loss_fn(None) + assert exploration_type() == ExplorationType.RANDOM + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_distributions.py b/test/test_distributions.py index e6f228628a4..55f0f28cf18 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -140,22 +140,30 @@ def test_tanhnormal(self, min, max, vecs, upscale, shape, device): class TestTruncatedNormal: def test_truncnormal(self, min, max, vecs, upscale, shape, device): torch.manual_seed(0) - min, max, vecs, upscale, shape = _map_all( - min, max, vecs, upscale, shape, device=device + *vecs, min, max, vecs, upscale = torch.utils._pytree.tree_map( + lambda t: torch.as_tensor(t, device=device), + (*vecs, min, max, vecs, upscale), ) + assert all(t.device == device for t in vecs) d = TruncatedNormal( *vecs, upscale=upscale, min=min, max=max, ) + assert d.device == device for _ in range(100): a = d.rsample(shape) + assert a.device == device assert a.shape[: len(shape)] == shape assert (a >= d.min).all() assert (a <= d.max).all() lp = d.log_prob(a) assert torch.isfinite(lp).all() + oob_min = d.min.expand((*d.batch_shape, *d.event_shape)) - 1e-2 + assert not torch.isfinite(d.log_prob(oob_min)).any() + oob_max = d.max.expand((*d.batch_shape, *d.event_shape)) + 1e-2 + assert not torch.isfinite(d.log_prob(oob_max)).any() def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device): torch.manual_seed(0) diff --git a/test/test_env.py b/test/test_env.py index 22918c390df..d8136ff382b 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import functools import gc import os.path import re @@ -65,7 +66,14 @@ DiscreteTensorSpec, UnboundedContinuousTensorSpec, ) -from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvBase, + EnvCreator, + ParallelEnv, + SerialEnv, +) from torchrl.envs.gym_like import default_info_dict_reader from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper @@ -209,6 +217,35 @@ def test_rollout(env_name, frame_skip, seed=0): env.close() +@pytest.mark.parametrize("max_steps", [1, 5]) +def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4): + # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1 + env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size) + policy = CountingEnvCountPolicy( + action_spec=env.action_spec, action_key=env.action_key + ) + + input_td = env.reset() + for _ in range(epochs): + rollout_td = env.rollout( + max_steps=max_steps, + policy=policy, + auto_reset=False, + break_when_any_done=False, + tensordict=input_td, + ) + assert (env.count == max_steps).all() + input_td = step_mdp( + rollout_td[..., -1], + keep_other=True, + exclude_action=False, + exclude_reward=True, + reward_keys=env.reward_keys, + action_keys=env.action_keys, + done_keys=env.done_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) def test_rollout_predictability(device): env = MockSerialEnv(device=device) @@ -338,7 +375,8 @@ def test_mb_env_batch_lock(self, device, seed=0): mb_env.step(td) with pytest.raises( - RuntimeError, match=re.escape("Expected a tensordict with shape==env.shape") + RuntimeError, + match=re.escape("Expected a tensordict with shape==env.batch_size"), ): mb_env.step(td_expanded) @@ -624,6 +662,29 @@ def test_parallel_env_with_policy( # env_serial.close() env0.close() + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.parametrize("heterogeneous", [False, True]) + def test_transform_env_transform_no_device(self, heterogeneous): + # Tests non-regression on 1865 + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(), StepCounter(max_steps=3) + ) + + if heterogeneous: + make_envs = [EnvCreator(make_env), EnvCreator(make_env)] + else: + make_envs = make_env + penv = ParallelEnv(2, make_envs) + r = penv.rollout(6, break_when_any_done=False) + assert r.shape == (2, 6) + try: + env = TransformedEnv(penv) + r = env.rollout(6, break_when_any_done=False) + assert r.shape == (2, 6) + finally: + penv.close() + @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.parametrize( "env_name", @@ -1584,7 +1645,7 @@ def test_batch_locked(device): _ = env.step(td) with pytest.raises( - RuntimeError, match="Expected a tensordict with shape==env.shape, " + RuntimeError, match="Expected a tensordict with shape==env.batch_size, " ): env.step(td_expanded) @@ -2095,7 +2156,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count): @pytest.mark.parametrize("batch_size", [(1, 2)]) @pytest.mark.parametrize("env_type", ["serial", "parallel"]) - def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): + @pytest.mark.parametrize("break_when_any_done", [False, True]) + def test_vec_env( + self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2 + ): env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size) if env_type == 
"serial": vec_env = SerialEnv(n_workers, env_fun) @@ -2109,7 +2173,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): rollout_steps, policy=policy, return_contiguous=False, - break_when_any_done=False, + break_when_any_done=break_when_any_done, ) td = dense_stack_tds(td) for i in range(env_fun().n_nested_dim): @@ -2447,6 +2511,87 @@ def test_auto_cast_to_device(break_when_any_done): assert_allclose_td(rollout0, rollout1) +@pytest.mark.parametrize("device", get_default_devices()) +def test_backprop(device): + # Tests that backprop through a series of single envs and through a serial env are identical + # Also tests that no backprop can be achieved with parallel env. + class DifferentiableEnv(EnvBase): + def __init__(self, device): + super().__init__(device=device) + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec(3, device=device), + device=device, + ) + self.action_spec = CompositeSpec( + action=UnboundedContinuousTensorSpec(3, device=device), device=device + ) + self.reward_spec = CompositeSpec( + reward=UnboundedContinuousTensorSpec(1, device=device), device=device + ) + self.seed = 0 + + def _set_seed(self, seed): + self.seed = seed + return seed + + def _reset(self, tensordict): + td = self.observation_spec.zero().update(self.done_spec.zero()) + td["observation"] = ( + td["observation"].clone() + self.seed % 10 + ).requires_grad_() + return td + + def _step(self, tensordict): + action = tensordict.get("action") + obs = (tensordict.get("observation") + action) / action.norm() + return TensorDict( + { + "reward": action.sum().unsqueeze(0), + **self.full_done_spec.zero(), + "observation": obs, + } + ) + + torch.manual_seed(0) + policy = Actor(torch.nn.Linear(3, 3, device=device)) + env0 = DifferentiableEnv(device=device) + seed = env0.set_seed(0) + env1 = DifferentiableEnv(device=device) + env1.set_seed(seed) + r0 = env0.rollout(10, policy) + r1 = env1.rollout(10, policy) + r = torch.stack([r0, r1]) + g = torch.autograd.grad(r["next", "reward"].sum(), policy.parameters()) + + def make_env(seed, device=device): + env = DifferentiableEnv(device=device) + env.set_seed(seed) + return env + + serial_env = SerialEnv( + 2, + [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], + device=device, + ) + r_serial = serial_env.rollout(10, policy) + + g_serial = torch.autograd.grad( + r_serial["next", "reward"].sum(), policy.parameters() + ) + torch.testing.assert_close(g, g_serial) + + p_env = ParallelEnv( + 2, + [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], + device=device, + ) + try: + r_parallel = p_env.rollout(10, policy) + assert not r_parallel.exclude("action").requires_grad + finally: + p_env.close() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_exploration.py b/test/test_exploration.py index d0735a53ae8..e6493bd1804 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -54,6 +54,7 @@ class TestEGreedy: @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1]) @pytest.mark.parametrize("module", [True, False]) + @set_exploration_type(InteractionType.RANDOM) def test_egreedy(self, eps_init, module): torch.manual_seed(0) spec = BoundedTensorSpec(1, 1, torch.Size([4])) diff --git a/test/test_rb.py b/test/test_rb.py index 4b9b1a5dc9f..45b970da34e 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -155,44 +155,20 @@ def 
_get_datum(self, datatype): def _get_data(self, datatype, size): if datatype is None: - data = torch.randint( - 100, - ( - size, - 1, - ), - ) + data = torch.randint(100, (size, 1)) elif datatype == "tensor": - data = torch.randint( - 100, - ( - size, - 1, - ), - ) + data = torch.randint(100, (size, 1)) elif datatype == "tensordict": data = TensorDict( { - "a": torch.randint( - 100, - ( - size, - 1, - ), - ), + "a": torch.randint(100, (size, 1)), "next": {"reward": torch.randn(size, 1)}, }, [size], ) elif datatype == "pytree": data = { - "a": torch.randint( - 100, - ( - size, - 1, - ), - ), + "a": torch.randint(100, (size, 1)), "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]}, 30: torch.zeros(size, 2), } @@ -671,6 +647,8 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): def test_storage_dumps_loads( self, device_data, storage_type, data_type, isinit, tmpdir ): + torch.manual_seed(0) + dir_rb = tmpdir / "rb" dir_save = tmpdir / "save" dir_rb.mkdir() @@ -715,15 +693,18 @@ class TC: ) else: raise NotImplementedError + if storage_type in (LazyMemmapStorage,): storage = storage_type(max_size=10, scratch_dir=dir_rb) else: storage = storage_type(max_size=10) + # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index if data_type == "pytree": storage.set(range(3), tree_map(lambda x: x.cpu(), data)) else: storage.set(range(3), data.cpu()) + storage.dumps(dir_save) # check we can dump twice storage.dumps(dir_save) @@ -731,9 +712,11 @@ class TC: storage_recover = storage_type(max_size=10) if isinit: if data_type == "pytree": - storage_recover.set(range(3), tree_map(lambda x: x.cpu().zero_(), data)) + storage_recover.set( + range(3), tree_map(lambda x: x.cpu().clone().zero_(), data) + ) else: - storage_recover.set(range(3), data.cpu().zero_()) + storage_recover.set(range(3), data.cpu().clone().zero_()) if data_type in ("tensor", "pytree") and not isinit: with pytest.raises( @@ -830,11 +813,39 @@ def test_set_tensorclass(self, max_size, shape, storage): tc_sample = mystorage.get(idx) assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]]) + def test_extend_list_pytree(self, max_size, shape, storage): + memory = ReplayBuffer( + storage=storage(max_size=max_size), + sampler=SamplerWithoutReplacement(), + ) + data = [ + ( + torch.full(shape, i), + {"a": torch.full(shape, i), "b": (torch.full(shape, i))}, + [torch.full(shape, i)], + ) + for i in range(10) + ] + memory.extend(data) + sample = memory.sample(10) + for leaf in torch.utils._pytree.tree_leaves(sample): + assert (leaf.unique(sorted=True) == torch.arange(10)).all() + memory = ReplayBuffer( + storage=storage(max_size=max_size), + sampler=SamplerWithoutReplacement(), + ) + t1x4 = torch.Tensor([0.1, 0.2, 0.3, 0.4]) + t1x1 = torch.Tensor([0.01]) + with pytest.raises( + RuntimeError, match="Stacking the elements of the list resulted in an error" + ): + memory.extend([t1x4, t1x1, t1x4 + 0.4, t1x1 + 0.01]) + @pytest.mark.parametrize("priority_key", ["pk", "td_error"]) @pytest.mark.parametrize("contiguous", [True, False]) @pytest.mark.parametrize("device", get_default_devices()) -def test_prototype_prb(priority_key, contiguous, device): +def test_ptdrb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) rb = TensorDictReplayBuffer( diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 83a283e4e56..7e0fef99786 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -9,18 +9,14 @@ import 
torch from mocking_classes import DiscreteActionVecMockEnv from tensordict import pad, TensorDict, unravel_key_list -from tensordict.nn import ( - InteractionType, - make_functional, - TensorDictModule, - TensorDictSequential, -) +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec, ) +from torchrl.envs import EnvCreator, SerialEnv from torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( AdditiveGaussianWrapper, @@ -149,900 +145,28 @@ def test_stateful(self, safe, spec_type, lazy): RuntimeError, match="is not a valid configuration as the tensor specs are not " "specified", - ): - tensordict_module = SafeModule( - module=net, - spec=spec, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tensordict_module = SafeModule( - module=net, - spec=spec, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) - @pytest.mark.parametrize("lazy", [True, False]) - @pytest.mark.parametrize( - "exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net = nn.LazyLinear(4 * param_multiplier) - else: - net = nn.Linear(3, 4 * param_multiplier) - - in_keys = ["in"] - net = SafeModule( - module=NormalParamWrapper(net), - spec=None, - in_keys=in_keys, - out_keys=out_keys, - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - if out_keys == ["loc", "scale"]: - dist_in_keys = ["loc", "scale"] - elif out_keys == ["loc_1", "scale_1"]: - dist_in_keys = {"loc": "loc_1", "scale": "scale_1"} - else: - raise NotImplementedError - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=dist_in_keys, - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=dist_in_keys, - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tensordict_module = SafeProbabilisticTensorDictSequential(net, prob_module) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_exploration_type(exp_mode): - tensordict_module(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - 
@pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - params = make_functional(net) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tensordict_module = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=TensorDict({"module": params}, [])) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - tdnet = SafeModule( - module=NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tensordict_module = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tensordict_module) - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.BatchNorm1d(32 * param_multiplier) - params = make_functional(net) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the 
tensor specs are not " - "specified", - ): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=TensorDict({"module": params}, [])) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - tdnet = SafeModule( - module=NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - return - else: - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = 
vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) - tdnet = SafeModule( - module=net, in_keys=["in"], out_keys=["loc", "scale"], spec=None - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() 
- - -class TestTDSequence: - # Temporarily disabling this test until 473 is merged in tensordict - # def test_in_key_warning(self): - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] - # ) - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] - # ) - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 1 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - kwargs = {} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - spec=spec, - module=net2, - in_keys=["hidden"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful_probabilistic(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - in_keys=["in"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - 
safe=False, - ) - tdmodule2 = SafeModule( - module=net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - spec=spec, - in_keys=["loc", "scale"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if 
spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") + ): + tensordict_module = SafeModule( + module=net, + spec=spec, + in_keys=["in"], + out_keys=["out"], + safe=safe, + ) + return else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, + tensordict_module = SafeModule( + module=net, spec=spec, - in_keys=["hidden"], + in_keys=["in"], out_keys=["out"], safe=safe, ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with 
params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tensordict_module(td) assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) + assert td.get("out").shape == torch.Size([3, 4]) # test bounds if not safe and spec_type == "bounded": @@ -1052,91 +176,73 @@ def test_functional_with_buffer(self, safe, spec_type): @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) + @pytest.mark.parametrize("lazy", [True, False]) + @pytest.mark.parametrize( + "exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None] + ) + def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys): torch.manual_seed(0) param_multiplier = 2 + if lazy: + net = nn.LazyLinear(4 * param_multiplier) + else: + net = nn.Linear(3, 4 * param_multiplier) - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) + in_keys = ["in"] + net = SafeModule( + module=NormalParamWrapper(net), + spec=None, + in_keys=in_keys, + out_keys=out_keys, ) - net2 = NormalParamWrapper(net2) if spec_type is None: spec = None elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError kwargs = {"distribution_class": TanhNormal} + if out_keys == ["loc", "scale"]: + dist_in_keys = ["loc", "scale"] + elif out_keys == ["loc_1", "scale_1"]: + dist_in_keys = {"loc": "loc_1", "scale": "scale_1"} + else: + raise NotImplementedError if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + prob_module = SafeProbabilisticModule( + in_keys=dist_in_keys, + out_keys=["out"], + spec=spec, + safe=safe, + **kwargs, + ) + return else: - tdmodule1 = SafeModule( - net1, in_keys=["in"], out_keys=["hidden"], spec=None, safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], + in_keys=dist_in_keys, out_keys=["out"], spec=spec, safe=safe, **kwargs, ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, ["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - with params.unlock_(): - params["module", "1"] 
= params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape + tensordict_module = SafeProbabilisticTensorDictSequential(net, prob_module) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + with set_exploration_type(exp_mode): + tensordict_module(td) assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) + assert td.get("out").shape == torch.Size([3, 4]) # test bounds if not safe and spec_type == "bounded": @@ -1144,18 +250,33 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) + +class TestTDSequence: + # Temporarily disabling this test until 473 is merged in tensordict + # def test_in_key_warning(self): + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] + # ) + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] + # ) + @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) if spec_type is None: spec = None @@ -1164,6 +285,8 @@ def test_vmap(self, safe, spec_type): elif spec_type == "unbounded": spec = UnboundedContinuousTensorSpec(4) + kwargs = {} + if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: @@ -1182,80 +305,54 @@ def test_vmap(self, safe, spec_type): safe=False, ) tdmodule2 = SafeModule( - net2, spec=spec, + module=net2, in_keys=["hidden"], out_keys=["out"], - safe=safe, + safe=False, + **kwargs, ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - params = make_functional(tdmodule) - assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] - with params.unlock_(): - del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") assert tdmodule[0] is tdmodule1 assert tdmodule[1] is tdmodule2 - # vmap = True - params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": 
- with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + tdmodule(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) # test bounds if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful_probabilistic(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) if spec_type is None: @@ -1274,60 +371,68 @@ def test_vmap_probabilistic(self, safe, spec_type): else: tdmodule1 = SafeModule( net1, - spec=None, in_keys=["in"], out_keys=["hidden"], + spec=None, + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + in_keys=["hidden"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeModule( + module=net2, + in_keys=["hidden"], + out_keys=["loc", "scale"], + spec=None, safe=False, ) - tdmodule2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) + prob_module = SafeProbabilisticModule( + spec=spec, in_keys=["loc", "scale"], out_keys=["out"], - spec=spec, - safe=safe, + safe=False, **kwargs, ) tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, tdmodule2, prob_module + tdmodule1, dummy_tdmodule, tdmodule2, prob_module ) - params = make_functional(tdmodule) + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 4 + tdmodule[1] = tdmodule2 + tdmodule[2] = prob_module + assert len(tdmodule) == 4 - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == 
torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 4 + del tdmodule[3] + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 + assert tdmodule[2] is prob_module - # vmap = (0, 0) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) + tdmodule(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + dist = tdmodule.get_dist(td) + assert dist.rsample().shape[: td.ndimension()] == td.shape + # test bounds if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - @pytest.mark.parametrize("functional", [True, False]) - def test_submodule_sequence(self, functional): + def test_submodule_sequence(self): td_module_1 = SafeModule( nn.Linear(3, 2), in_keys=["in"], @@ -1340,34 +445,19 @@ def test_submodule_sequence(self, functional): ) td_module = SafeSequential(td_module_1, td_module_2) - if functional: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - params = make_functional(sub_seq_1) - sub_seq_1(td_1, params=params) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - params = make_functional(sub_seq_2) - sub_seq_2(td_2, params=params) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - else: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1(td_1) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2(td_2) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) @pytest.mark.parametrize("stack", [True, False]) - @pytest.mark.parametrize("functional", [True, False]) - def test_sequential_partial(self, stack, functional): + def test_sequential_partial(self, stack): torch.manual_seed(0) param_multiplier = 2 @@ -1416,11 +506,6 @@ def test_sequential_partial(self, stack, functional): tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) - if functional: - params = make_functional(tdmodule) - else: 
- params = None - if stack: td = torch.stack( [ @@ -1429,10 +514,7 @@ def test_sequential_partial(self, stack, functional): ], 0, ) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1443,10 +525,7 @@ def test_sequential_partial(self, stack, functional): assert "b" in td[0].keys() else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1782,9 +861,12 @@ def test_multi_consecutive(self, shape, python_based): ) @pytest.mark.parametrize("python_based", [True, False]) - def test_lstm_parallel_env(self, python_based): + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + torch.manual_seed(0) device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( @@ -1796,6 +878,10 @@ def test_lstm_parallel_env(self, python_based): device=device, python_based=python_based, ) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv def create_transformed_env(): primer = lstm_module.make_tensordict_primer() @@ -1807,7 +893,12 @@ def create_transformed_env(): env.append_transform(primer) return env - env = ParallelEnv( + if heterogeneous: + create_transformed_env = [ + EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env), + ] + env = cls( create_env_fn=create_transformed_env, num_workers=2, ) @@ -2109,9 +1200,13 @@ def test_multi_consecutive(self, shape, python_based): ) @pytest.mark.parametrize("python_based", [True, False]) - def test_gru_parallel_env(self, python_based): + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_gru_parallel_env(self, python_based, parallel, heterogeneous): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + torch.manual_seed(0) + device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs gru_module = GRUModule( @@ -2134,7 +1229,17 @@ def create_transformed_env(): env.append_transform(primer) return env - env = ParallelEnv( + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + if heterogeneous: + create_transformed_env = [ + EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env), + ] + + env = cls( create_env_fn=create_transformed_env, num_workers=2, ) diff --git a/test/test_transforms.py b/test/test_transforms.py index 725945ef113..c2fb7fca41c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -762,16 +762,16 @@ def test_transform_env_clone(self): ).all() assert cloned is not env.transform - @pytest.mark.parametrize("dim", [-2, -1]) + @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) + @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) def test_transform_model(self, dim, N, padding): # test equivalence between transforms within an env and within a rb key1 = "observation" keys = [key1] out_keys = ["out_" + key1] cat_frames = CatFrames( - N=N, in_keys=out_keys, out_keys=out_keys, 
dim=dim, padding=padding + N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding ) cat_frames2 = CatFrames( N=N, @@ -781,23 +781,22 @@ def test_transform_model(self, dim, N, padding): padding=padding, ) envbase = ContinuousActionVecMockEnv() - env = TransformedEnv( - envbase, - Compose( - UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames - ), - ) + env = TransformedEnv(envbase, cat_frames) + torch.manual_seed(10) env.set_seed(10) td = env.rollout(10) + torch.manual_seed(10) envbase.set_seed(10) tdbase = envbase.rollout(10) + tdbase0 = tdbase.clone() model = nn.Sequential(cat_frames2, nn.Identity()) model(tdbase) - assert (td == tdbase).all() + assert assert_allclose_td(td, tdbase) + with pytest.warns(UserWarning): tdbase0.names = None model(tdbase0) @@ -816,7 +815,7 @@ def test_transform_model(self, dim, N, padding): # check that swapping dims and names leads to same result assert_allclose_td(v1, v2.transpose(0, 1)) - @pytest.mark.parametrize("dim", [-2, -1]) + @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) @@ -826,7 +825,7 @@ def test_transform_rb(self, dim, N, padding, rbclass): keys = [key1] out_keys = ["out_" + key1] cat_frames = CatFrames( - N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding + N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding ) cat_frames2 = CatFrames( N=N, @@ -836,12 +835,7 @@ def test_transform_rb(self, dim, N, padding, rbclass): padding=padding, ) - env = TransformedEnv( - ContinuousActionVecMockEnv(), - Compose( - UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames - ), - ) + env = TransformedEnv(ContinuousActionVecMockEnv(), cat_frames) td = env.rollout(10) rb = rbclass(storage=LazyTensorStorage(20)) @@ -875,8 +869,8 @@ def test_transform_as_inverse(self, dim, N, padding): td = env1.rollout(rollout_length) transformed_td = cat_frames._inv_call(td) - assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim, N) - assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim, N) + assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim * N) + assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim * N) with pytest.raises( Exception, match="CatFrames as inverse is not supported as a transform for environments, only for replay buffers.", @@ -971,14 +965,50 @@ def test_transform_no_env(self, device, d, batch_size, dim, N): # we don't want the same tensor to be returned twice, but they're all copies of the same buffer assert v1 is not v2 + @pytest.mark.skipif(not _has_gym, reason="gym required for this test") + @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) + @pytest.mark.parametrize("envtype", ["gym", "conv"]) + def test_tranform_offline_against_online(self, padding, envtype): + torch.manual_seed(0) + key = "observation" if envtype == "gym" else "pixels" + env = SerialEnv( + 3, + lambda: TransformedEnv( + GymEnv("CartPole-v1") + if envtype == "gym" + else DiscreteActionConvMockEnv(), + CatFrames( + dim=-3 if envtype == "conv" else -1, + N=5, + in_keys=[key], + out_keys=[f"{key}_cat"], + padding=padding, + ), + ), + ) + env.set_seed(0) + + r = env.rollout(100, break_when_any_done=False) + + c = CatFrames( + dim=-3 if envtype == "conv" else -1, + N=5, + in_keys=[key, ("next", key)], + out_keys=[f"{key}_cat2", ("next", f"{key}_cat2")], + padding=padding, 
+ ) + + r2 = c(r) + + torch.testing.assert_close(r2[f"{key}_cat2"], r2[f"{key}_cat"]) + torch.testing.assert_close(r2["next", f"{key}_cat2"], r2["next", f"{key}_cat"]) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)]) @pytest.mark.parametrize("d", range(2, 3)) @pytest.mark.parametrize( "dim", - [ - -3, - ], + [-3], ) @pytest.mark.parametrize("N", [2, 4]) def test_transform_compose(self, device, d, batch_size, dim, N): @@ -8231,7 +8261,7 @@ def test_batch_locked_transformed(device): env.step(td) with pytest.raises( - RuntimeError, match="Expected a tensordict with shape==env.shape, " + RuntimeError, match="Expected a tensordict with shape==env.batch_size, " ): env.step(td_expanded) @@ -8275,7 +8305,7 @@ def test_batch_unlocked_with_batch_size_transformed(device): td_expanded = td.expand(2, 2).reshape(-1).to_tensordict() with pytest.raises( - RuntimeError, match="Expected a tensordict with shape==env.shape, " + RuntimeError, match="Expected a tensordict with shape==env.batch_size, " ): env.step(td_expanded) @@ -10113,17 +10143,15 @@ def test_trans_parallel_env_check(self): def test_transform_no_env(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) - assert not td.is_empty() t = RemoveEmptySpecs() t._call(td) - assert td.is_empty() + assert len(td.keys()) == 0 def test_transform_compose(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) - assert not td.is_empty() t = Compose(RemoveEmptySpecs()) t._call(td) - assert td.is_empty() + assert len(td.keys()) == 0 def test_transform_env(self): base_env = self.DummyEnv() @@ -10138,7 +10166,7 @@ def test_transform_model(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) t = nn.Sequential(Compose(RemoveEmptySpecs())) td = t(td) - assert td.is_empty(), td + assert len(td.keys()) == 0 @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, rbclass): @@ -10154,14 +10182,13 @@ def test_transform_rb(self, rbclass): td = rb.sample(1) if "index" in td.keys(): del td["index"] - assert td.is_empty() + assert len(td.keys()) == 0 def test_transform_inverse(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) - assert not td.is_empty() t = RemoveEmptySpecs() t.inv(td) - assert not td.is_empty() + assert len(td.keys()) != 0 env = TransformedEnv(self.DummyEnv(), RemoveEmptySpecs()) td2 = env.transform.inv(TensorDict({}, [])) assert ("state", "sub") in td2.keys(True) diff --git a/test/test_utils.py b/test/test_utils.py index 620149daeb6..c2ce2eae6b9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,10 @@ import _utils_internal import pytest -from torchrl._utils import get_binary_env_var, implement_for +import torch + +from _utils_internal import get_default_devices +from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @@ -358,6 +361,21 @@ class MockGym: ) # would break with gymnasium +@pytest.mark.parametrize("device", get_default_devices()) +def test_rng_decorator(device): + with torch.device(device): + torch.manual_seed(10) + s0a = torch.randn(3) + with _rng_decorator(0): + torch.randn(3) + s0b = torch.randn(3) + torch.manual_seed(10) + s1a = torch.randn(3) + s1b = torch.randn(3) + torch.testing.assert_close(s0a, s1a) + torch.testing.assert_close(s0b, s1b) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git 
a/torchrl/_extension.py b/torchrl/_extension.py index 5eb820cb86f..a9e52dbf9a4 100644 --- a/torchrl/_extension.py +++ b/torchrl/_extension.py @@ -3,18 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import importlib +import importlib.util import warnings def is_module_available(*modules: str) -> bool: - r"""Returns if a top-level module with :attr:`name` exists *without** importing it. + """Returns if a top-level module with :attr:`name` exists *without** importing it. This is generally safer than try-catch block around a `import X`. It avoids third party libraries breaking assumptions of some of our tests, e.g., setting multiprocessing start method when imported (see librosa/#747, torchvision/#544). - """ return all(importlib.util.find_spec(m) is not None for m in modules) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9538cecb026..ae01556f0e6 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: return new_ending else: return key[:-1] + (new_ending,) + + +class _rng_decorator(_DecoratorContextManager): + """Temporarily sets the seed and sets back the rng state when exiting.""" + + def __init__(self, seed, device=None): + self.seed = seed + self.device = device + self.has_cuda = torch.cuda.is_available() + + def __enter__(self): + self._get_state() + torch.manual_seed(self.seed) + + def _get_state(self): + if self.has_cuda: + if self.device is None: + self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state()) + else: + self._state = ( + torch.random.get_rng_state(), + torch.cuda.get_rng_state(self.device), + ) + + else: + self._state = torch.random.get_rng_state() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.has_cuda: + torch.random.set_rng_state(self._state[0]) + if self.device is not None: + torch.cuda.set_rng_state(self._state[1], device=self.device) + else: + torch.cuda.set_rng_state(self._state[1]) + + else: + torch.random.set_rng_state(self._state) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index eff2434d487..4e816f657ed 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -575,7 +575,7 @@ def __init__( reset_when_done: bool = True, interruptor=None, ): - from torchrl.envs.batched_envs import _BatchedEnv + from torchrl.envs.batched_envs import BatchedEnvBase self.closed = True @@ -589,7 +589,7 @@ def __init__( else: env = create_env_fn if create_env_kwargs: - if not isinstance(env, _BatchedEnv): + if not isinstance(env, BatchedEnvBase): raise RuntimeError( "kwargs were passed to SyncDataCollector but they can't be set " f"on environment of type {type(create_env_fn)}." 
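# --- editor's note (not part of the diff) -------------------------------------
# Usage sketch for the private `_rng_decorator` context manager added to
# torchrl/_utils.py above; it mirrors the new `test_rng_decorator` test. The seed
# values are illustrative. Inside the block the seed is fixed to the requested
# value; on exit the previous RNG state is restored, so draws after the block
# continue the outer stream as if the block had never run.
import torch
from torchrl._utils import _rng_decorator

torch.manual_seed(10)
a = torch.randn(3)           # first draw from the seed-10 stream
with _rng_decorator(0):
    torch.randn(3)           # consumes numbers from a temporary seed-0 stream
b = torch.randn(3)           # second draw, still from the seed-10 stream

torch.manual_seed(10)
assert torch.allclose(a, torch.randn(3))
assert torch.allclose(b, torch.randn(3))
# -------------------------------------------------------------------------------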
@@ -718,7 +718,8 @@ def __init__( self.return_same_td = return_same_td # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env - self._shuttle = self.env.reset() + with torch.no_grad(): + self._shuttle = self.env.reset() if self.policy_device != self.env_device or self.env_device is None: self._shuttle_has_no_device = True self._shuttle.clear_device_() @@ -1077,7 +1078,7 @@ def rollout(self) -> TensorDictBase: if self.storing_device is not None: tensordicts.append( - self._shuttle.to(self.storing_device, non_blocking=False) + self._shuttle.to(self.storing_device, non_blocking=True) ) else: tensordicts.append(self._shuttle) @@ -1135,6 +1136,7 @@ def _update_device_wise(tensor0, tensor1): return tensor1 return tensor1.to(tensor0.device, non_blocking=True) + @torch.no_grad() def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" # metadata @@ -1191,11 +1193,11 @@ def state_dict(self) -> OrderedDict: `"env_state_dict"`. """ - from torchrl.envs.batched_envs import _BatchedEnv + from torchrl.envs.batched_envs import BatchedEnvBase if isinstance(self.env, TransformedEnv): env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, _BatchedEnv): + elif isinstance(self.env, BatchedEnvBase): env_state_dict = self.env.state_dict() else: env_state_dict = OrderedDict() diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index a467c763fa5..faf4d4a6cce 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -472,7 +472,7 @@ def check_list_length_consistency(*lists): pending_samples = [ e.print_remote_collector_info.remote() for e in self.remote_collectors() ] - ray.wait(object_refs=pending_samples) + ray.wait(pending_samples) @property def num_workers(self): @@ -602,7 +602,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: samples_ready = [] while len(samples_ready) < self.num_collectors: samples_ready, samples_not_ready = ray.wait( - object_refs=pending_tasks, num_returns=len(pending_tasks) + pending_tasks, num_returns=len(pending_tasks) ) # Retrieve and concatenate Tensordicts @@ -645,7 +645,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: raise RuntimeError("Missing pending tasks, something went wrong") # Wait for first worker to finish - wait_results = ray.wait(object_refs=list(pending_tasks.keys())) + wait_results = ray.wait(list(pending_tasks.keys())) future = wait_results[0][0] collector_index = pending_tasks.pop(future) collector = self.remote_collectors()[collector_index] @@ -678,7 +678,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: # Wait for the in-process collections tasks to finish. refs = list(pending_tasks.keys()) - ray.wait(object_refs=refs, num_returns=len(refs)) + ray.wait(refs, num_returns=len(refs)) # Cancel the in-process collections tasks # for ref in refs: diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 6fb7d6fdbaf..cb84ce32586 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . 
import datasets from .postprocs import MultiStep from .replay_buffers import ( ImmutableDatasetWriter, diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index c3999806aaf..749bf0888ae 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -925,7 +925,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: if data.ndim: priority = self._get_priority_vector(data) else: - priority = self._get_priority_item(data) + priority = torch.as_tensor(self._get_priority_item(data)) index = data.get("index") while index.shape != priority.shape: # reduce index diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 15e46ae1038..0022fe41569 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -24,7 +24,6 @@ from torchrl._utils import _replace_last from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage -from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES try: from torchrl._torchrl import ( @@ -250,11 +249,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class PrioritizedSampler(Sampler): """Prioritized sampler for replay buffer. - Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. - Prioritized experience replay." - (https://arxiv.org/abs/1511.05952) + Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952) Args: + max_capacity (int): maximum capacity of the buffer. alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. @@ -264,6 +262,51 @@ class PrioritizedSampler(Sampler): tensordicts (ie stored trajectory). Can be one of "max", "min", "median" or "mean". + Examples: + >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler + >>> from tensordict import TensorDict + >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)) + >>> priority = torch.tensor([0, 1000]) + >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) + >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) + >>> rb.add(data_0) + >>> rb.add(data_1) + >>> rb.update_priority(torch.tensor([0, 1]), priority=priority) + >>> sample, info = rb.sample(10, return_info=True) + >>> print(sample) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), + obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), + priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + >>> print(info) + {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, + 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])} + + .. 
note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the + process of updating the priorities: + + >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler + >>> from tensordict import TensorDict + >>> rb = TDRB( + ... storage=LazyTensorStorage(10), + ... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0), + ... priority_key="priority", # This kwarg isn't present in regular RBs + ... ) + >>> priority = torch.tensor([0, 1000]) + >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) + >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) + >>> data = torch.stack([data_0, data_1]) + >>> rb.extend(data) + >>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor + >>> sample, info = rb.sample(10, return_info=True) + >>> print(sample['index']) # The index is packed with the tensordict + tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + """ def __init__( @@ -327,15 +370,17 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: raise RuntimeError("negative p_sum") if p_min <= 0: raise RuntimeError("negative p_min") + # For some undefined reason, only np.random works here. + # All PT attempts fail, even when subsequently transformed into numpy mass = np.random.uniform(0.0, p_sum, size=batch_size) + # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum) + # mass = torch.rand(batch_size).mul_(p_sum) index = self._sum_tree.scan_lower_bound(mass) - if not isinstance(index, np.ndarray): - index = np.array([index]) - if isinstance(index, torch.Tensor): - index.clamp_max_(len(storage) - 1) - else: - index = np.clip(index, None, len(storage) - 1) - weight = self._sum_tree[index] + index = torch.as_tensor(index) + if not index.ndim: + index = index.unsqueeze(0) + index.clamp_max_(len(storage) - 1) + weight = torch.as_tensor(self._sum_tree[index]) # Importance sampling weight formula: # w_i = (p_i / sum(p) * N) ^ (-beta) @@ -345,9 +390,10 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: # weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta) # weight_i = (p_i / min(p)) ^ (-beta) # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) + weight = torch.pow(weight / p_min, -self._beta) return index, {"_weight": weight} + @torch.no_grad() def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: priority = self.default_priority @@ -360,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: "priority should be a scalar or an iterable of the same " "length as index" ) + # make sure everything is cast to cpu + if isinstance(index, torch.Tensor) and not index.is_cpu: + index = index.cpu() + if isinstance(priority, torch.Tensor) and not priority.is_cpu: + priority = priority.cpu() self._sum_tree[index] = priority self._min_tree[index] = priority @@ -377,6 +428,7 @@ def extend(self, index: torch.Tensor) -> None: index = index.cpu() self._add_or_extend(index) + @torch.no_grad() def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] ) -> None: @@ -389,28 +441,26 @@ def update_priority( indexed elements. 
""" - if isinstance(index, INT_CLASSES): - if not isinstance(priority, float): - if len(priority) != 1: - raise RuntimeError( - f"priority length should be 1, got {len(priority)}" - ) - priority = priority.item() - else: - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(index) == len(priority) - ): + priority = torch.as_tensor(priority, device=torch.device("cpu")).detach() + index = torch.as_tensor( + index, dtype=torch.long, device=torch.device("cpu") + ).detach() + # we need to reshape priority if it has more than one elements or if it has + # a different shape than index + if priority.numel() > 1 and priority.shape != index.shape: + try: + priority = priority.reshape(index.shape[:1]) + except Exception as err: raise RuntimeError( "priority should be a number or an iterable of the same " - "length as index" - ) - index = _to_numpy(index) - priority = _to_numpy(priority) - - self._max_priority = max(self._max_priority, np.max(priority)) - priority = np.power(priority + self._eps, self._alpha) + f"length as index. Got priority of shape {priority.shape} and index " + f"{index.shape}." + ) from err + elif priority.numel() <= 1: + priority = priority.squeeze() + + self._max_priority = priority.max().clamp_min(self._max_priority).item() + priority = torch.pow(priority + self._eps, self._alpha) self._sum_tree[index] = priority self._min_tree[index] = priority @@ -668,7 +718,7 @@ def __init__( if end_key is None: end_key = ("next", "done") if traj_key is None: - traj_key = "run" + traj_key = "episode" self.end_key = end_key self.traj_key = traj_key @@ -790,7 +840,7 @@ def _get_stop_and_length(self, storage, fallback=True): raise RuntimeError( "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." 
) - vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] + vals = self._find_start_stop_traj(end=done.squeeze()[: len(storage)]) if self.cache_values: self._cache["stop-and-length"] = vals return vals @@ -867,7 +917,7 @@ def _sample_slices( truncated[seq_length.cumsum(0) - 1] = 1 traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1 terminated = torch.zeros_like(truncated) - if terminated.any(): + if traj_terminated.any(): if isinstance(seq_length, int): truncated.view(num_slices, -1)[traj_terminated] = 1 else: @@ -1233,7 +1283,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index d4d81f10bc1..bb86b18bad8 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -226,7 +226,9 @@ def load_state_dict(self, state_dict): if isinstance(elt, torch.Tensor): self._storage.append(elt) elif isinstance(elt, (dict, OrderedDict)): - self._storage.append(TensorDict({}, []).load_state_dict(elt)) + self._storage.append( + TensorDict({}, []).load_state_dict(elt, strict=False) + ) else: raise TypeError( f"Objects of type {type(elt)} are not supported by ListStorage.load_state_dict" @@ -497,9 +499,11 @@ def load_state_dict(self, state_dict): ) elif isinstance(_storage, (dict, OrderedDict)): if is_tensor_collection(self._storage): - self._storage.load_state_dict(_storage) + self._storage.load_state_dict(_storage, strict=False) elif self._storage is None: - self._storage = TensorDict({}, []).load_state_dict(_storage) + self._storage = TensorDict({}, []).load_state_dict( + _storage, strict=False + ) else: raise RuntimeError( f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}. If your storage is pytree-based, use the dumps/load API instead." @@ -538,6 +542,21 @@ def set( else: self._len = max(self._len, max(cursor) + 1) + if isinstance(data, list): + # flip list + try: + data = _flip_list(data) + except Exception: + raise RuntimeError( + "Stacking the elements of the list resulted in " + "an error. " + f"Storages of type {type(self)} expect all elements of the list " + f"to have the same tree structure. If the list is compact (each " + f"leaf is itself a batch with the appropriate number of elements) " + f"consider using a tuple instead, as lists are used within `extend` " + f"for per-item addition." + ) + if not self.initialized: if not isinstance(cursor, INT_CLASSES): if is_tensor_collection(data): @@ -832,7 +851,7 @@ def load_state_dict(self, state_dict): ) elif isinstance(_storage, (dict, OrderedDict)): if is_tensor_collection(self._storage): - self._storage.load_state_dict(_storage) + self._storage.load_state_dict(_storage, strict=False) self._storage.memmap_() elif self._storage is None: warnings.warn( @@ -840,7 +859,9 @@ def load_state_dict(self, state_dict): "It is preferable to load a storage onto a" "pre-allocated one whenever possible." 
) - self._storage = TensorDict({}, []).load_state_dict(_storage) + self._storage = TensorDict({}, []).load_state_dict( + _storage, strict=False + ) self._storage.memmap_() else: raise RuntimeError( @@ -888,7 +909,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: # to be deprecated in v0.4 def map_device(tensor): if tensor.device != self.device: - return tensor.to(self.device, non_blocking=False) + return tensor.to(self.device, non_blocking=True) return tensor if is_tensor_collection(result): @@ -1313,3 +1334,10 @@ def save_tensor(tensor_path: str, tensor: torch.Tensor): out.append(save_tensor(tensor_path, tensor)) return tree_unflatten(out, data_specs) + + +def _flip_list(data): + flat_data, flat_specs = zip(*[tree_flatten(item) for item in data]) + flat_data = zip(*flat_data) + stacks = [torch.stack(item) for item in flat_data] + return tree_unflatten(stacks, flat_specs[0]) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 156d32f9539..6517d915a0b 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -18,7 +18,16 @@ from tensordict import is_tensor_collection, MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp -from torch.utils._pytree import tree_flatten + +try: + from torch.utils._pytree import tree_leaves +except ImportError: + from torch.utils._pytree import tree_flatten + + def tree_leaves(data): # noqa: D103 + tree_flat, _ = tree_flatten(data) + return tree_flat + from torchrl.data.replay_buffers.storages import Storage from torchrl.data.replay_buffers.utils import _reduce @@ -125,7 +134,7 @@ def extend(self, data: Sequence) -> torch.Tensor: elif isinstance(data, list): batch_size = len(data) else: - batch_size = len(tree_flatten(data)[0][0]) + batch_size = len(tree_leaves(data)[0]) if batch_size == 0: raise RuntimeError("Expected at least one element in extend.") device = data.device if hasattr(data, "device") else None diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 19090d3f4c5..8f039b317fc 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -394,7 +394,7 @@ def get_dataloader( ) out = TensorDictReplayBuffer( storage=TensorStorage(data), - collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False), + collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True), sampler=SamplerWithoutReplacement(drop_last=True), batch_size=batch_size, prefetch=prefetch, diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 311b2584aa5..a4ccbfd8a1b 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -7,13 +7,13 @@ import abc import collections import importlib -from typing import Sequence, Tuple +from typing import List, Tuple import numpy as np import torch from tensordict import TensorDict -from torch import Tensor +from torch import nn, Tensor from torch.nn import functional as F from torchrl.data.rlhf.prompt import PromptData @@ -30,8 +30,8 @@ class KLControllerBase(abc.ABC): """ @abc.abstractmethod - def update(self, kl_values: float): - pass + def update(self, kl_values: List[float]) -> float: + ... class ConstantKLController(KLControllerBase): @@ -40,30 +40,39 @@ class ConstantKLController(KLControllerBase): This controller maintains a fixed coefficient no matter what values it is updated with. - Arguments: - model: wrapped model that needs to be controlled. 
Must have attribute 'kl_coef' + Keyword Arguments: kl_coef (float): The coefficient to multiply KL with when calculating the reward. + model (nn.Module, optional): wrapped model that needs to be controlled. + Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will + be updated in-place. """ - def __init__(self, model, kl_coef): + def __init__( + self, + *, + kl_coef: float = None, + model: nn.Module | None = None, + ): self.model = model - if not hasattr(model, "kl_coef"): + if model is not None and not hasattr(model, "kl_coef"): raise AttributeError( "Model input to ConstantKLController doesn't have attribute 'kl_coef'" ) self.coef = kl_coef - self.model.kl_coef = self.coef + if model is not None: + self.model.kl_coef = self.coef - def update(self, kl_values: Sequence[float] = None): - self.model.kl_coef = self.coef + def update(self, kl_values: List[float] = None) -> float: + if self.model is not None: + self.model.kl_coef = self.coef + return self.coef class AdaptiveKLController(KLControllerBase): """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences". - Arguments: - model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + Keyword Arguments: init_kl_coef (float): The starting value of the coefficient. target (float): The target KL value. When the observed KL is smaller, the coefficient is decreased, thereby relaxing the KL penalty in the training @@ -72,19 +81,30 @@ class AdaptiveKLController(KLControllerBase): increased, thereby pulling the model back towards the reference model. horizon (int): Scaling factor to control how aggressively we update the coefficient. + model (nn.Module, optional): wrapped model that needs to be controlled. + Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will + be updated in-place. Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py """ - def __init__(self, model, init_kl_coef: float, target: float, horizon: int): + def __init__( + self, + *, + init_kl_coef: float, + target: float, + horizon: int, + model: nn.Module | None = None, + ): self.model = model self.coef = init_kl_coef self.target = target self.horizon = horizon - self.model.kl_coef = self.coef + if model is not None: + self.model.kl_coef = self.coef - def update(self, kl_values: Sequence[float]): + def update(self, kl_values: List[float]): """Update ``self.coef`` adaptively. Arguments: @@ -104,6 +124,9 @@ def update(self, kl_values: Sequence[float]): proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ mult = 1 + proportional_error * n_steps / self.horizon self.coef *= mult # βₜ₊₁ + if self.model is not None: + self.model.kl_coef = self.coef + return self.coef class RolloutFromModel: @@ -233,8 +256,6 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens according to the generative model and the reference model. Can be obtained by calling the ``generate`` method. - kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting - from the reward. Defaults to 0.1. 
Returns: A :class:`~tensordict.TensorDict` with the following keys: @@ -514,7 +535,7 @@ def generate(self, batch: PromptData, generation_config=None): def step_scheduler(self): # recover true kl - self.kl_scheduler.update(self._kl_queue) + self.kl_coef = self.kl_scheduler.update(self._kl_queue) if isinstance(self._kl_queue, (list, collections.deque)): # remove all values while len(self._kl_queue): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 5e88cf4e86d..f472533b9f3 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -8,6 +8,7 @@ import gc import os +import weakref from collections import OrderedDict from copy import deepcopy from functools import wraps @@ -19,7 +20,7 @@ import torch from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict._tensordict import _unravel_key_to_tuple, unravel_key +from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, @@ -40,7 +41,6 @@ from torchrl.envs.utils import ( _aggregate_end_of_traj, - _set_single_key, _sort_keys, _update_during_reset, clear_mpi_env_vars, @@ -48,7 +48,7 @@ def _check_start(fun): - def decorated_fun(self: _BatchedEnv, *args, **kwargs): + def decorated_fun(self: BatchedEnvBase, *args, **kwargs): if self.is_closed: self._create_td() self._start_workers() @@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs): return super().__call__(*args, **kwargs) -class _BatchedEnv(EnvBase): +class BatchedEnvBase(EnvBase): """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely. Those queries will return a list of length equal to the number of workers containing the @@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase): serial_for_single (bool, optional): if ``True``, creating a parallel environment with a single worker will return a :class:`~SerialEnv` instead. This option has no effect with :class:`~SerialEnv`. Defaults to ``False``. + non_blocking (bool, optional): if ``True``, device moves will be done using the + ``non_blocking=True`` option. Defaults to ``True`` for batched environments + on cuda devices, and ``False`` otherwise. Examples: >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator @@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase): >>> env = ParallelEnv(2, [ ... lambda: DMControlEnv("humanoid", "stand"), ... 
lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands - >>> r = env.rollout(10) # executes 10 random steps in the environment - >>> r[0] # data for Humanoid stand + >>> rollout = env.rollout(10) # executes 10 random steps in the environment + >>> rollout[0] # data for Humanoid stand TensorDict( fields={ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), @@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase): batch_size=torch.Size([10]), device=cpu, is_shared=False) - >>> r[1] # data for Humanoid walk + >>> rollout[1] # data for Humanoid walk TensorDict( fields={ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), @@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase): batch_size=torch.Size([10]), device=cpu, is_shared=False) + >>> # serial_for_single to avoid creating parallel envs if not necessary >>> env = ParallelEnv(1, make_env, serial_for_single=True) >>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary """ @@ -270,6 +274,7 @@ def __init__( num_threads: int = None, num_sub_threads: int = 1, serial_for_single: bool = False, + non_blocking: bool = False, ): super().__init__(device=device) self.serial_for_single = serial_for_single @@ -327,6 +332,15 @@ def __init__( # self._prepare_dummy_env(create_env_fn, create_env_kwargs) self._properties_set = False self._get_metadata(create_env_fn, create_env_kwargs) + self._non_blocking = non_blocking + + @property + def non_blocking(self): + nb = self._non_blocking + if nb is None: + nb = self.device is not None and self.device.type == "cuda" + self._non_blocking = nb + return nb def _get_metadata( self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] @@ -419,7 +433,10 @@ def _check_for_empty_spec(specs: CompositeSpec): def map_device(key, value, device_map=device_map): return value.to(device_map[key]) - self._env_tensordict.named_apply(map_device, nested_keys=True) + self._env_tensordict.named_apply( + map_device, + nested_keys=True, + ) self._batch_locked = meta_data.batch_locked else: @@ -535,22 +552,17 @@ def _create_td(self) -> None: self._selected_keys = self._selected_keys.union(reset_keys) # input keys - self._selected_input_keys = { - _unravel_key_to_tuple(key) for key in self._env_input_keys - } + self._selected_input_keys = {unravel_key(key) for key in self._env_input_keys} # output keys after reset self._selected_reset_keys = { - _unravel_key_to_tuple(key) - for key in self._env_obs_keys + self.done_keys + reset_keys + unravel_key(key) for key in self._env_obs_keys + self.done_keys + reset_keys } # output keys after reset, filtered self._selected_reset_keys_filt = { unravel_key(key) for key in self._env_obs_keys + self.done_keys } # output keys after step - self._selected_step_keys = { - _unravel_key_to_tuple(key) for key in self._env_output_keys - } + self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys} if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( @@ -657,6 +669,7 @@ def start(self) -> None: self._start_workers() def to(self, device: DEVICE_TYPING): + self._non_blocking = None device = torch.device(device) if device == self.device: return self @@ -678,10 +691,10 @@ def to(self, device: DEVICE_TYPING): return self -class SerialEnv(_BatchedEnv): +class SerialEnv(BatchedEnvBase): """Creates a series of environments in the same process.""" - __doc__ += 
_BatchedEnv.__doc__ + __doc__ += BatchedEnvBase.__doc__ _share_memory = False @@ -689,11 +702,27 @@ def _start_workers(self) -> None: _num_workers = self.num_workers self._envs = [] - + weakref_set = set() for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - if self.device is not None: - env = env.to(self.device) + # We want to avoid having the same env multiple times + # so we try to deepcopy it if needed. If we can't, we make + # the user aware that this isn't a very good idea + wr = weakref.ref(env) + if wr in weakref_set: + try: + env = deepcopy(env) + except Exception: + warn( + "Deepcopying the env failed within SerialEnv " + "but more than one copy of the same env was found. " + "This is a dangerous situation if your env keeps track " + "of some variables (e.g., state) in-place. " + "We'll use the same copy of the environment. Be aware that " + "this may cause important, unwanted issues for stateful " + "environments!" + ) + weakref_set.add(wr) self._envs.append(env) self.is_closed = False @@ -755,8 +784,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = None else: env_device = _env.device - if env_device != self.device: - tensordict_ = tensordict_.to(env_device) + if env_device != self.device and env_device is not None: + tensordict_ = tensordict_.to( + env_device, non_blocking=self.non_blocking + ) else: tensordict_ = tensordict_.clone(False) else: @@ -764,30 +795,28 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( - _td.select(*self._selected_reset_keys_filt, strict=False) + _td, + keys_to_update=list(self._selected_reset_keys_filt), ) selected_output_keys = self._selected_reset_keys_filt device = self.device - if self._single_task: - # select + clone creates 2 tds, but we can create one only - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in selected_output_keys: - _set_single_key( - self.shared_tensordict_parent, out, key, clone=True, device=device - ) - else: - out = self.shared_tensordict_parent.select( - *selected_output_keys, - strict=False, - ) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clone().clear_device_() + + # select + clone creates 2 tds, but we can create one only + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() + + out = self.shared_tensordict_parent.named_apply( + select_and_clone, + nested_keys=True, + ) + del out["next"] + + if out.device != device: + if device is None: + out = out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=self.non_blocking) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -807,30 +836,31 @@ def _step( # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
env_device = self._envs[i].device - if env_device != self.device: - data_in = tensordict_in[i].to(env_device, non_blocking=False) + if env_device != self.device and env_device is not None: + data_in = tensordict_in[i].to( + env_device, non_blocking=self.non_blocking + ) else: data_in = tensordict_in[i] out_td = self._envs[i]._step(data_in) - next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) + next_td[i].update_(out_td, keys_to_update=list(self._env_output_keys)) + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps device = self.device - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True, device=device) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clone().clear_device_() - else: - out = out.to(device, non_blocking=False) + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() + + # out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) + out = next_td.named_apply(select_and_clone, nested_keys=True) + + if out.device != device: + if device is None: + out = out.clear_device_() + elif out.device != device: + out = out.to(device, non_blocking=self.non_blocking) return out def __getattr__(self, attr: str) -> Any: @@ -876,14 +906,14 @@ def to(self, device: DEVICE_TYPING): return self -class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta): +class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): """Creates one environment per process. TensorDicts are passed via shared memory or memory map. """ - __doc__ += _BatchedEnv.__doc__ + __doc__ += BatchedEnvBase.__doc__ __doc__ += """ .. warning:: @@ -1013,6 +1043,12 @@ class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta): >>> # If no cuda device is available >>> env = ParallelEnv(N, MyEnv(..., device="cpu")) + .. warning:: + ParallelEnv disable gradients in all operations (:meth:`~.step`, + :meth:`~.reset` and :meth:`~.step_and_maybe_reset`) because gradients + cannot be passed through :class:`multiprocessing.Pipe` objects. + Only :class:`~torchrl.envs.SerialEnv` will support backpropagation. + """ def _start_workers(self) -> None: @@ -1040,6 +1076,7 @@ def _start_workers(self) -> None: def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda + # self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True) self.shared_tensordict_parent.apply(look_for_cuda) has_cuda = has_cuda[0] if has_cuda: @@ -1115,36 +1152,34 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: event.wait() event.clear() + @torch.no_grad() @_check_start def step_and_maybe_reset( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - if self._single_task and not self.has_lazy_inputs: - # We must use the in_keys and nothing else for the following reasons: - # - efficiency: copying all the keys will in practice mean doing a lot - # of writing operations since the input tensordict may (and often will) - # contain all the previous output data. 
- # - value mismatch: if the batched env is placed within a transform - # and this transform overrides an observation key (eg, CatFrames) - # the shape, dtype or device may not necessarily match and writing - # the value in-place will fail. - for key in self._env_input_keys: - self.shared_tensordict_parent.set_(key, tensordict.get(key)) - next_td = tensordict.get("next", None) - if next_td is not None: - # we copy the input keys as well as the keys in the 'next' td, if any - # as this mechanism can be used by a policy to set anticipatively the - # keys of the next call (eg, with recurrent nets) - for key in next_td.keys(True, True): - key = unravel_key(("next", key)) - if key in self.shared_tensordict_parent.keys(True, True): - self.shared_tensordict_parent.set_(key, next_td.get(key[1:])) + # We must use the in_keys and nothing else for the following reasons: + # - efficiency: copying all the keys will in practice mean doing a lot + # of writing operations since the input tensordict may (and often will) + # contain all the previous output data. + # - value mismatch: if the batched env is placed within a transform + # and this transform overrides an observation key (eg, CatFrames) + # the shape, dtype or device may not necessarily match and writing + # the value in-place will fail. + self.shared_tensordict_parent.update_( + tensordict, keys_to_update=self._env_input_keys + ) + next_td_passthrough = tensordict.get("next", None) + if next_td_passthrough is not None: + # if we have input "next" data (eg, RNNs which pass the next state) + # the sub-envs will need to process them through step_and_maybe_reset. + # We keep track of which keys are present to let the worker know what + # should be passd to the env (we don't want to pass done states for instance) + next_td_keys = list(next_td_passthrough.keys(True, True)) + self.shared_tensordict_parent.get("next").update_(next_td_passthrough) else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, "next", strict=False) - ) + next_td_keys = None for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", None)) + self.parent_channels[i].send(("step_and_maybe_reset", next_td_keys)) for i in range(self.num_workers): event = self._events[i] @@ -1160,45 +1195,54 @@ def step_and_maybe_reset( next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: - next_td = next_td.to(device, non_blocking=False) - tensordict_ = tensordict_.to(device, non_blocking=False) + next_td = next_td._fast_apply( + lambda x: x.to(device, non_blocking=self.non_blocking) + if x.device != device + else x.clone(), + device=device, + ) + tensordict_ = tensordict_._fast_apply( + lambda x: x.to(device, non_blocking=self.non_blocking) + if x.device != device + else x.clone(), + device=device, + ) else: next_td = next_td.clone().clear_device_() tensordict_ = tensordict_.clone().clear_device_() tensordict.set("next", next_td) return tensordict, tensordict_ + @torch.no_grad() @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - if self._single_task and not self.has_lazy_inputs: - # We must use the in_keys and nothing else for the following reasons: - # - efficiency: copying all the keys will in practice mean doing a lot - # of writing operations since the input tensordict may (and often will) - # contain all the previous output data. 
- # - value mismatch: if the batched env is placed within a transform - # and this transform overrides an observation key (eg, CatFrames) - # the shape, dtype or device may not necessarily match and writing - # the value in-place will fail. - for key in tensordict.keys(True, True): - # we copy the input keys as well as the keys in the 'next' td, if any - # as this mechanism can be used by a policy to set anticipatively the - # keys of the next call (eg, with recurrent nets) - if key in self._env_input_keys or ( - isinstance(key, tuple) - and key[0] == "next" - and key in self.shared_tensordict_parent.keys(True, True) - ): - val = tensordict.get(key) - self.shared_tensordict_parent.set_(key, val) + # We must use the in_keys and nothing else for the following reasons: + # - efficiency: copying all the keys will in practice mean doing a lot + # of writing operations since the input tensordict may (and often will) + # contain all the previous output data. + # - value mismatch: if the batched env is placed within a transform + # and this transform overrides an observation key (eg, CatFrames) + # the shape, dtype or device may not necessarily match and writing + # the value in-place will fail. + self.shared_tensordict_parent.update_( + tensordict, keys_to_update=list(self._env_input_keys) + ) + next_td_passthrough = tensordict.get("next", None) + if next_td_passthrough is not None: + # if we have input "next" data (eg, RNNs which pass the next state) + # the sub-envs will need to process them through step_and_maybe_reset. + # We keep track of which keys are present to let the worker know what + # should be passd to the env (we don't want to pass done states for instance) + next_td_keys = list(next_td_passthrough.keys(True, True)) + self.shared_tensordict_parent.get("next").update_(next_td_passthrough) else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, "next", strict=False) - ) + next_td_keys = None + if self.event is not None: self.event.record() self.event.synchronize() for i in range(self.num_workers): - self.parent_channels[i].send(("step", None)) + self.parent_channels[i].send(("step", next_td_keys)) for i in range(self.num_workers): event = self._events[i] @@ -1209,21 +1253,23 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") device = self.device - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True, device=device) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False) - if out.device == device: - out = out.clone() + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() + + out = next_td.named_apply( + select_and_clone, + nested_keys=True, + ) + if out.device != device: + if device is None: + out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=self.non_blocking) return out + @torch.no_grad() @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if tensordict is not None: @@ -1258,16 +1304,30 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). 
self.shared_tensordicts[i].update_( - self.shared_tensordicts[i] - .get("next") - .select(*self._selected_reset_keys, strict=False) + self.shared_tensordicts[i].get("next"), + keys_to_update=list(self._selected_reset_keys), ) if tensordict_ is not None: self.shared_tensordicts[i].update_( - tensordict_.select(*self._selected_reset_keys, strict=False) + tensordict_, keys_to_update=list(self._selected_reset_keys) ) continue - out = ("reset", tensordict_) + if tensordict_ is not None: + tdkeys = list(tensordict_.keys(True, True)) + + # This way we can avoid calling select over all the keys in the shared tensordict + def tentative_update(val, other): + if other is not None: + val.copy_(other) + return val + + self.shared_tensordicts[i].apply_( + tentative_update, tensordict_, default=None + ) + out = ("reset", tdkeys) + else: + out = ("reset", False) + channel.send(out) workers.append(i) @@ -1278,26 +1338,22 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: selected_output_keys = self._selected_reset_keys_filt device = self.device - if self._single_task: - # select + clone creates 2 tds, but we can create one only - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in selected_output_keys: - _set_single_key( - self.shared_tensordict_parent, out, key, clone=True, device=device - ) - else: - out = self.shared_tensordict_parent.select( - *selected_output_keys, - strict=False, - ) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clear_device_().clone() + + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() + + out = self.shared_tensordict_parent.named_apply( + select_and_clone, + nested_keys=True, + ) + del out["next"] + + if out.device != device: + if device is None: + out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=self.non_blocking) return out @_check_start @@ -1496,9 +1552,16 @@ def look_for_cuda(tensor, has_cuda=has_cuda): torchrl_logger.info(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - cur_td = env.reset(tensordict=data) + # we use 'data' to pass the keys that we need to pass to reset, + # because passing the entire buffer may have unwanted consequences + cur_td = env.reset( + tensordict=root_shared_tensordict.select(*data, strict=False) + if data + else None + ) shared_tensordict.update_( - cur_td.select(*_selected_reset_keys, strict=False) + cur_td, + keys_to_update=list(_selected_reset_keys), ) if event is not None: event.record() @@ -1510,7 +1573,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + # No need to copy here since we don't write in-place + if data: + next_td_passthrough_keys = data + input = root_shared_tensordict.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + else: + input = root_shared_tensordict + next_td = env._step(input) next_shared_tensordict.update_(next_td) if event is not None: event.record() @@ -1522,9 +1593,25 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict) - next_shared_tensordict.update_(td.get("next")) + # We must copy the root shared td here, or at least get rid of done: + # if we don't `td is 
root_shared_tensordict` + # which means that root_shared_tensordict will carry the content of next + # in the next iteration. When using StepCounter, it will look for an + # existing done state, find it and consider the env as done by input (not + # by output) of the step! + # Caveat: for RNN we may need some keys of the "next" TD so we pass the list + # through data + if data: + next_td_passthrough_keys = data + input = root_shared_tensordict.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + else: + input = root_shared_tensordict + td, root_next_td = env.step_and_maybe_reset(input) + next_shared_tensordict.update_(td.pop("next")) root_shared_tensordict.update_(root_next_td) + if event is not None: event.record() event.synchronize() @@ -1585,8 +1672,5 @@ def look_for_cuda(tensor, has_cuda=has_cuda): child_pipe.send(("_".join([cmd, "done"]), None)) -def _update_cuda(t_dest, t_source): - if t_source is None: - return - t_dest.copy_(t_source.pin_memory(), non_blocking=False) - return +# Create an alias for possible imports +_BatchedEnv = BatchedEnvBase diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b2b201922e1..44d5a554043 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -419,6 +419,13 @@ def run_type_checks(self, run_type_checks: bool) -> None: @property def batch_size(self) -> torch.Size: + """Number of envs batched in this environment instance, organised in a `torch.Size()` object. + + Environments may be similar or different, but it is assumed that they have little if + any interaction between them (e.g., multi-task or batched execution + in parallel). + + """ _batch_size = self.__dict__["_batch_size"] if _batch_size is None: _batch_size = self._batch_size = torch.Size([]) @@ -439,6 +446,11 @@ def batch_size(self, value: torch.Size) -> None: self.input_spec.shape = value self.input_spec.lock_() + @property + def shape(self): + """Equivalent to :attr:`~.batch_size`.""" + return self.batch_size + @property def device(self) -> torch.device: device = self.__dict__.get("_device", None) @@ -2055,8 +2067,8 @@ def reset( tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly - # if tensordict_reset.device != self.device: - # tensordict_reset = tensordict_reset.to(self.device, non_blocking=False) + # if reset.device != self.device: + # reset = reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -2162,7 +2174,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None: self.batch_locked or self.batch_size != () ) and tensordict.batch_size != self.batch_size: raise RuntimeError( - f"Expected a tensordict with shape==env.shape, " + f"Expected a tensordict with shape==env.batch_size, " f"got {tensordict.batch_size} and {self.batch_size}" ) @@ -2261,7 +2273,9 @@ def rollout( called on the sub-envs that are done. Default is True. return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. tensordict (TensorDict, optional): if auto_reset is False, an initial - tensordict must be provided. + tensordict must be provided. Rollout will check if this tensordict has done flags and reset the + environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the + output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout. Returns: TensorDict object containing the resulting trajectory.
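# Note: illustrative sketch only, not part of this patch. It shows the ``batch_size``
# property and the new ``shape`` alias documented above; assumes a Gym backend is installed.
import torch
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))
assert env.batch_size == torch.Size([2])  # number of batched sub-envs
assert env.shape == env.batch_size        # ``shape`` is simply an alias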
@@ -2357,6 +2371,26 @@ def rollout( >>> print(rollout.names) [None, 'time'] + Rollouts can be used in a loop to emulate data collection. + To do so, you need to pass as input the last tensordict coming from the previous rollout after calling + :func:`~torchrl.envs.utils.step_mdp` on it. + + Examples: + >>> from torchrl.envs import GymEnv, step_mdp + >>> env = GymEnv("CartPole-v1") + >>> epochs = 10 + >>> input_td = env.reset() + >>> for i in range(epochs): + ... rollout_td = env.rollout( + ... max_steps=100, + ... break_when_any_done=False, + ... auto_reset=False, + ... tensordict=input_td, + ... ) + ... input_td = step_mdp( + ... rollout_td[..., -1], + ... ) + """ if auto_cast_to_device: try: @@ -2376,6 +2410,9 @@ def rollout( tensordict = self.reset() elif tensordict is None: raise RuntimeError("tensordict must be provided when auto_reset is False") + else: + tensordict = self.maybe_reset(tensordict) + if policy is None: policy = self.rand_action @@ -2418,13 +2455,13 @@ def _rollout_stop_early( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict = tensordict.to(policy_device, non_blocking=False) + tensordict = tensordict.to(policy_device, non_blocking=True) else: tensordict.clear_device_() tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: - tensordict = tensordict.to(env_device, non_blocking=False) + tensordict = tensordict.to(env_device, non_blocking=True) else: tensordict.clear_device_() tensordict = self.step(tensordict) @@ -2472,16 +2509,19 @@ def _rollout_nonstop( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict_ = tensordict_.to(policy_device, non_blocking=False) + tensordict_ = tensordict_.to(policy_device, non_blocking=True) else: tensordict_.clear_device_() tensordict_ = policy(tensordict_) if auto_cast_to_device: if env_device is not None: - tensordict_ = tensordict_.to(env_device, non_blocking=False) + tensordict_ = tensordict_.to(env_device, non_blocking=True) else: tensordict_.clear_device_() - tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) + if i == max_steps - 1: + tensordict = self.step(tensordict_) + else: + tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) tensordicts.append(tensordict) if i == max_steps - 1: # we don't truncated as one could potentially continue the run @@ -2545,14 +2585,28 @@ def step_and_maybe_reset( action_keys=self.action_keys, done_keys=self.done_keys, ) + tensordict_ = self.maybe_reset(tensordict_) + return tensordict, tensordict_ + + def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: + """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done. + + Args: + tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`. + + Returns: + A tensordict that is identical to the input where the environment was + not reset and contains the new reset data where the environment was reset. + + """ any_done = _terminated_or_truncated( - tensordict_, + tensordict, full_done_spec=self.output_spec["full_done_spec"], key="_reset", ) if any_done: - tensordict_ = self.reset(tensordict_) - return tensordict, tensordict_ + tensordict = self.reset(tensordict) + return tensordict def empty_cache(self): """Erases all the cached values. 
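# Note: illustrative sketch only, not part of this patch. It exercises the new
# ``EnvBase.maybe_reset`` hook together with ``step_mdp``; assumes a Gym backend is installed.
from torchrl.envs import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("CartPole-v1")
td = env.reset()
td = env.rand_step(td)    # take one random step
td = step_mdp(td)         # promote the "next" entries to the root
td = env.maybe_reset(td)  # resets only where a done flag is set, otherwise a no-op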
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 38995a07a6b..d3b3dfd659c 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -322,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, _run_checks=False ) - tensordict_out = tensordict_out.to(self.device, non_blocking=False) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -366,7 +366,7 @@ def _reset( for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() - tensordict_out = tensordict_out.to(self.device, non_blocking=False) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index a9b43fed62b..894d56ef5b6 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -102,6 +102,7 @@ def __init__(self, env_name, **kwargs): device_num = torch.device(kwargs.pop("device", 0)).index kwargs["override_options"] = [ f"habitat.simulator.habitat_sim_v0.gpu_device_id={device_num}", + "habitat.simulator.concur_render=False", ] super().__init__(env_name=env_name, **kwargs) diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 0aa5aa99313..7ac318e03cb 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -6,7 +6,6 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.data.tensor_specs import ( @@ -94,6 +93,8 @@ def available_envs(cls): ] def __init__(self, dataset_name, device="cpu", batch_size=None): + from torchrl.data.datasets.openml import OpenMLExperienceReplay + if batch_size is None: batch_size = torch.Size([]) else: diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 1c12cf9be15..546321d5815 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib.util from typing import List, Optional, Union import torch @@ -28,21 +29,7 @@ ) from torchrl.envs.transforms.utils import _set_missing_tolerance -try: - from torchvision import models - - _has_tv = True -except ImportError: - _has_tv = False - -try: - from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights - from torchvision.models._api import WeightsEnum -except ImportError: - - class WeightsEnum: # noqa: D101 - # placeholder - pass +_has_tv = importlib.util.find_spec("torchvision", None) is not None R3M_MODEL_MAP = { @@ -62,6 +49,8 @@ def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True): "Tried to instantiate R3M without torchvision. Make sure you have " "torchvision installed in your environment." 
) + from torchvision import models + self.model_name = model_name if model_name == "resnet18": # self.model_name = "r3m_18" @@ -152,6 +141,13 @@ def _load_weights(model_name, r3m_instance, dir_prefix): r3m_instance.convnet.load_state_dict(state_dict) def load_weights(self, dir_prefix=None, tv_weights=None): + from torchvision import models + from torchvision.models import ( + ResNet18_Weights, + ResNet34_Weights, + ResNet50_Weights, + ) + if dir_prefix is not None and tv_weights is not None: raise RuntimeError( "torchvision weights API does not allow for custom download path." @@ -244,7 +240,7 @@ def __init__( out_keys: List[str] = None, size: int = 244, stack_images: bool = True, - download: Union[bool, WeightsEnum, str] = False, + download: Union[bool, "WeightsEnum", str] = False, # noqa: F821 download_path: Optional[str] = None, tensor_pixels_keys: List[str] = None, ): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a661b152d39..a449a2395e0 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6,6 +6,7 @@ from __future__ import annotations import collections +import importlib.util import multiprocessing as mp import warnings from copy import copy @@ -28,7 +29,7 @@ ) from tensordict._tensordict import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase -from tensordict.utils import expand_as_right, NestedKey +from tensordict.utils import expand_as_right, expand_right, NestedKey from torch import nn, Tensor from torch.utils._pytree import tree_map from torchrl._utils import _replace_last @@ -55,25 +56,7 @@ from torchrl.envs.utils import _sort_keys, _update_during_reset, step_mdp from torchrl.objectives.value.functional import reward2go -try: - from torchvision.transforms.functional import center_crop - - try: - from torchvision.transforms.functional import InterpolationMode, resize - - def interpolation_fn(interpolation): # noqa: D103 - return InterpolationMode(interpolation) - - except ImportError: - - def interpolation_fn(interpolation): # noqa: D103 - return interpolation - - from torchvision.transforms.functional_tensor import resize - - _has_tv = True -except ImportError: - _has_tv = False +_has_tv = importlib.util.find_spec("torchvision", None) is not None IMAGE_KEYS = ["pixels"] _MAX_NOOPS_TRIALS = 10 @@ -791,7 +774,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): return tensordict_reset def _reset_proc_data(self, tensordict, tensordict_reset): - # self._complete_done(self.full_done_spec, tensordict_reset) + # self._complete_done(self.full_done_spec, reset) self._reset_check_done(tensordict, tensordict_reset) if tensordict is not None: tensordict_reset = _update_during_reset( @@ -802,7 +785,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset): # # doesn't do anything special # mt_mode = self.transform.missing_tolerance # self.set_missing_tolerance(True) - # tensordict_reset = self.transform._call(tensordict_reset) + # reset = self.transform._call(reset) # self.set_missing_tolerance(mt_mode) return tensordict_reset @@ -1748,6 +1731,18 @@ def __init__( super().__init__(in_keys=in_keys, out_keys=out_keys) self.w = int(w) self.h = int(h) + + try: + from torchvision.transforms.functional import InterpolationMode + + def interpolation_fn(interpolation): # noqa: D103 + return InterpolationMode(interpolation) + + except ImportError: + + def interpolation_fn(interpolation): # noqa: D103 + return interpolation + self.interpolation = 
interpolation_fn(interpolation) def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: @@ -1758,6 +1753,10 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: if ndim > 4: sizes = observation.shape[:-3] observation = torch.flatten(observation, 0, ndim - 4) + try: + from torchvision.transforms.functional import resize + except ImportError: + from torchvision.transforms.functional_tensor import resize observation = resize( observation, [self.w, self.h], @@ -1827,6 +1826,8 @@ def __init__( self.h = h if h else w def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + from torchvision.transforms.functional import center_crop + observation = center_crop(observation, [self.w, self.h]) return observation @@ -2619,6 +2620,8 @@ class CatFrames(ObservationTransform): reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise. + done_key (NestedKey, optional): the done key to be used as partial + done indicator. Must be unique. If not provided, defaults to ``"done"``. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -2699,6 +2702,7 @@ def __init__( padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, + done_key: NestedKey | None = None, ): if in_keys is None: in_keys = IMAGE_KEYS @@ -2732,6 +2736,19 @@ def __init__( # keeps track of calls to _reset since it's only _call that will populate the buffer self.as_inverse = as_inverse self.reset_key = reset_key + self.done_key = done_key + + @property + def done_key(self): + done_key = self.__dict__.get("_done_key", None) + if done_key is None: + done_key = "done" + self._done_key = done_key + return done_key + + @done_key.setter + def done_key(self, value): + self._done_key = value @property def reset_key(self): @@ -2828,15 +2845,6 @@ def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase: # make linter happy. 
An exception has already been raised raise NotImplementedError - # # this duplicates the code below, but only for _reset values - # if _all: - # buffer.copy_(torch.roll(buffer_reset, shifts=-d, dims=dim)) - # buffer_reset = buffer - # else: - # buffer_reset = buffer[_reset] = torch.roll( - # buffer_reset, shifts=-d, dims=dim - # ) - # add new obs if self.dim < 0: n = buffer_reset.ndimension() + self.dim else: @@ -2905,69 +2913,145 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: if i != tensordict.ndim - 1: tensordict = tensordict.transpose(tensordict.ndim - 1, i) # first sort the in_keys with strings and non-strings - in_keys = list( - zip( - (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) - if isinstance(in_key, str) or len(in_key) == 1 - ) - ) - in_keys += list( - zip( - (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) - if not isinstance(in_key, str) and not len(in_key) == 1 + keys = [ + (in_key, out_key) + for in_key, out_key in zip(self.in_keys, self.out_keys) + if isinstance(in_key, str) + ] + keys += [ + (in_key, out_key) + for in_key, out_key in zip(self.in_keys, self.out_keys) + if not isinstance(in_key, str) + ] + + def unfold_done(done, N): + prefix = (slice(None),) * (tensordict.ndim - 1) + reset = torch.cat( + [ + torch.zeros_like(done[prefix + (slice(self.N - 1),)]), + torch.ones_like(done[prefix + (slice(1),)]), + done[prefix + (slice(None, -1),)], + ], + tensordict.ndim - 1, ) - ) - for in_key, out_key in zip(self.in_keys, self.out_keys): + reset_unfold = reset.unfold(tensordict.ndim - 1, self.N, 1) + reset_unfold_slice = reset_unfold[..., -1] + reset_unfold_list = [torch.zeros_like(reset_unfold_slice)] + for r in reversed(reset_unfold.unbind(-1)): + reset_unfold_list.append(r | reset_unfold_list[-1]) + reset_unfold_slice = reset_unfold_list[-1] + reset_unfold = torch.stack(list(reversed(reset_unfold_list))[1:], -1) + reset = reset[prefix + (slice(self.N - 1, None),)] + reset[prefix + (0,)] = 1 + return reset_unfold, reset + + done = tensordict.get(("next", self.done_key)) + done_mask, reset = unfold_done(done, self.N) + + for in_key, out_key in keys: # check if we have an obs in "next" that has already been processed. # If so, we must add an offset - data = tensordict.get(in_key) + data_orig = data = tensordict.get(in_key) + n_feat = data_orig.shape[data.ndim + self.dim] + first_val = None if isinstance(in_key, tuple) and in_key[0] == "next": # let's get the out_key we have already processed - prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]] - prev_val = tensordict.get(prev_out_key) - # the first item is located along `dim+1` at the last index of the - # first time index - idx = ( - [slice(None)] * (tensordict.ndim - 1) - + [0] - + [..., -1] - + [slice(None)] * (abs(self.dim) - 1) + prev_out_key = dict(zip(self.in_keys, self.out_keys)).get( + in_key[1], None ) - first_val = prev_val[tuple(idx)].unsqueeze(tensordict.ndim - 1) - data0 = [first_val] * (self.N - 1) - if self.padding == "constant": - data0 = [ - torch.full_like(elt, self.padding_value) for elt in data0[:-1] - ] + data0[-1:] - elif self.padding == "same": - pass - else: - # make linter happy. 
An exception has already been raised - raise NotImplementedError - elif self.padding == "same": - idx = [slice(None)] * (tensordict.ndim - 1) + [0] - data0 = [data[tuple(idx)].unsqueeze(tensordict.ndim - 1)] * (self.N - 1) - elif self.padding == "constant": - idx = [slice(None)] * (tensordict.ndim - 1) + [0] - data0 = [ - torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze( - tensordict.ndim - 1 + if prev_out_key is not None: + prev_val = tensordict.get(prev_out_key) + # n_feat = prev_val.shape[data.ndim + self.dim] // self.N + first_val = prev_val.unflatten( + data.ndim + self.dim, (self.N, n_feat) ) - ] * (self.N - 1) - else: - # make linter happy. An exception has already been raised - raise NotImplementedError + + idx = [slice(None)] * (tensordict.ndim - 1) + [0] + data0 = [ + torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze( + tensordict.ndim - 1 + ) + ] * (self.N - 1) data = torch.cat(data0 + [data], tensordict.ndim - 1) data = data.unfold(tensordict.ndim - 1, self.N, 1) + + # Place -1 dim at self.dim place before squashing + done_mask_expand = done_mask.view( + *done_mask.shape[: tensordict.ndim], + *(1,) * (data.ndim - 1 - tensordict.ndim), + done_mask.shape[-1], + ) + done_mask_expand = expand_as_right(done_mask_expand, data) data = data.permute( - *range(0, data.ndim + self.dim), + *range(0, data.ndim + self.dim - 1), -1, - *range(data.ndim + self.dim, data.ndim - 1), + *range(data.ndim + self.dim - 1, data.ndim - 1), ) + done_mask_expand = done_mask_expand.permute( + *range(0, done_mask_expand.ndim + self.dim - 1), + -1, + *range(done_mask_expand.ndim + self.dim - 1, done_mask_expand.ndim - 1), + ) + if self.padding != "same": + data = torch.where(done_mask_expand, self.padding_value, data) + else: + # TODO: This is a pretty bad implementation, could be + # made more efficient but it works! 
+ reset_any = reset.any(-1, False) + reset_vals = list(data_orig[reset_any].unbind(0)) + j_ = float("inf") + reps = [] + d = data.ndim + self.dim - 1 + n_feat = data.shape[data.ndim + self.dim :].numel() + for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat: + if j > j_: + reset_vals = reset_vals[1:] + reps.extend([reset_vals[0]] * int(j)) + j_ = j + reps = torch.stack(reps) + data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1)) + + if first_val is not None: + # Aggregate reset along last dim + reset_any = reset.any(-1, False) + rexp = expand_right( + reset_any, (*reset_any.shape, *data.shape[data.ndim + self.dim :]) + ) + rexp = torch.cat( + [ + torch.zeros_like( + data0[0].repeat_interleave( + len(data0), dim=tensordict.ndim - 1 + ), + dtype=torch.bool, + ), + rexp, + ], + tensordict.ndim - 1, + ) + rexp = rexp.unfold(tensordict.ndim - 1, self.N, 1) + rexp_orig = rexp + rexp = torch.cat([rexp[..., 1:], torch.zeros_like(rexp[..., -1:])], -1) + if self.padding == "same": + rexp_orig = rexp_orig.flip(-1).cumsum(-1).flip(-1).bool() + rexp = rexp.flip(-1).cumsum(-1).flip(-1).bool() + rexp_orig = torch.cat( + [torch.zeros_like(rexp_orig[..., -1:]), rexp_orig[..., 1:]], -1 + ) + rexp = rexp.permute( + *range(0, rexp.ndim + self.dim - 1), + -1, + *range(rexp.ndim + self.dim - 1, rexp.ndim - 1), + ) + rexp_orig = rexp_orig.permute( + *range(0, rexp_orig.ndim + self.dim - 1), + -1, + *range(rexp_orig.ndim + self.dim - 1, rexp_orig.ndim - 1), + ) + data[rexp] = first_val[rexp_orig] + data = data.flatten(data.ndim + self.dim - 1, data.ndim + self.dim) tensordict.set(out_key, data) if tensordict_orig is not tensordict: tensordict_orig = tensordict.transpose(tensordict.ndim - 1, i) @@ -3612,10 +3696,10 @@ def set_container(self, container: Union[Transform, EnvBase]) -> None: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=False) + return tensordict.to(self.device, non_blocking=True) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=False) + return tensordict.to(self.device, non_blocking=True) def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -3628,8 +3712,8 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if parent is None: if self.orig_device is None: return tensordict - return tensordict.to(self.orig_device, non_blocking=False) - return tensordict.to(parent.device, non_blocking=False) + return tensordict.to(self.orig_device, non_blocking=True) + return tensordict.to(parent.device, non_blocking=True) def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec.to(self.device) @@ -5146,7 +5230,7 @@ def _reset( if step_count is None: step_count = self.container.observation_spec[step_count_key].zero() if step_count.device != reset.device: - step_count = step_count.to(reset.device, non_blocking=False) + step_count = step_count.to(reset.device, non_blocking=True) # zero the step count if reset is needed step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) @@ -5163,11 +5247,10 @@ def _reset( def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - for step_count_key, truncated_key, done_key, terminated_key in zip( + for step_count_key, truncated_key, done_key in zip( self.step_count_keys, self.truncated_keys, self.done_keys, - self.terminated_keys, ): 
step_count = tensordict.get(step_count_key) next_step_count = step_count + 1 @@ -5178,9 +5261,12 @@ def _step( truncated = truncated | next_tensordict.get(truncated_key, False) if self.update_done: done = next_tensordict.get(done_key, None) - terminated = next_tensordict.get(terminated_key, None) - if terminated is not None: - truncated = truncated & ~terminated + + # we can have terminated and truncated + # terminated = next_tensordict.get(terminated_key, None) + # if terminated is not None: + # truncated = truncated & ~terminated + done = truncated | done # we assume no done after reset next_tensordict.set(done_key, done) next_tensordict.set(truncated_key, truncated) diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 9c272b42b89..e814f5da476 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib.util from typing import List, Optional, Union import torch @@ -26,21 +27,7 @@ ) from torchrl.envs.transforms.utils import _set_missing_tolerance -try: - from torchvision import models - - _has_tv = True -except ImportError: - _has_tv = False - -try: - from torchvision.models import ResNet50_Weights - from torchvision.models._api import WeightsEnum -except ImportError: - - class WeightsEnum: # noqa: D101 - # placeholder - pass +_has_tv = importlib.util.find_spec("torchvision", None) is not None VIP_MODEL_MAP = { @@ -58,6 +45,9 @@ def __init__(self, in_keys, out_keys, model_name="resnet50", del_keys: bool = Tr "Tried to instantiate VIP without torchvision. Make sure you have " "torchvision installed in your environment." ) + + from torchvision import models + self.model_name = model_name if model_name == "resnet50": self.outdim = 2048 @@ -138,6 +128,9 @@ def _load_weights(model_name, vip_instance, dir_prefix): vip_instance.convnet.load_state_dict(state_dict) def load_weights(self, dir_prefix=None, tv_weights=None): + from torchvision import models + from torchvision.models import ResNet50_Weights + if dir_prefix is not None and tv_weights is not None: raise RuntimeError( "torchvision weights API does not allow for custom download path." @@ -218,7 +211,7 @@ def __init__( out_keys: List[str] = None, size: int = 244, stack_images: bool = True, - download: Union[bool, WeightsEnum, str] = False, + download: Union[bool, "WeightsEnum", str] = False, # noqa: F821 download_path: Optional[str] = None, tensor_pixels_keys: List[str] = None, ): diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ebb9100655c..1ef0c05a9ae 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,13 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import abc + import contextlib +import functools import importlib.util +import inspect import os import re from enum import Enum -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import torch @@ -20,6 +24,7 @@ TensorDictBase, unravel_key, ) +from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. 
@@ -31,7 +36,9 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey -from torchrl._utils import _replace_last, logger as torchrl_logger +from torch import nn +from torch.utils._pytree import tree_map +from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -268,7 +275,7 @@ def _set_single_key( dest = new_val else: if device is not None and val.device != device: - val = val.to(device, non_blocking=False) + val = val.to(device, non_blocking=True) elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) @@ -419,7 +426,9 @@ def _per_level_env_check(data0, data1, check_dtype): ) -def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): +def check_env_specs( + env, return_contiguous=True, check_dtype=True, seed: int | None = None +): """Tests an environment specs against the results of short rollout. This test function should be used as a sanity check for an env wrapped with @@ -436,7 +445,12 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of inputs/outputs). Defaults to True. check_dtype (bool, optional): if False, dtype checks will be skipped. Defaults to True. - seed (int, optional): for reproducibility, a seed is set. + seed (int, optional): for reproducibility, a seed can be set. + The seed will be set in pytorch temporarily, then the RNG state will + be reverted to what it was before. For the env, we set the seed but since + setting the rng state back to what it was isn't a feature of most environments, + we leave it to the user to accomplish that. + Defaults to ``None``. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -444,8 +458,14 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): """ if seed is not None: - torch.manual_seed(seed) - env.set_seed(seed) + device = ( + env.device if env.device is not None and env.device.type == "cuda" else None + ) + with _rng_decorator(seed, device=device): + env.set_seed(seed) + return check_env_specs( + env, return_contiguous=return_contiguous, check_dtype=check_dtype + ) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous) @@ -1102,3 +1122,207 @@ def _repr_by_depth(key): return (0, key) else: return (len(key) - 1, ".".join(key)) + + +def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False): + if policy is None: + if env is None: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + policy = RandomPolicy(env.input_spec["full_action_spec"]) + # make sure policy is an nn.Module + policy = _NonParametricPolicyWrapper(policy) + if not _policy_is_tensordict_compatible(policy): + # policy is a nn.Module that doesn't operate on tensordicts directly + # so we attempt to auto-wrap policy with TensorDictModule + if observation_spec is None: + raise ValueError( + "Unable to read observation_spec from the environment. This is " + "required to check compatibility of the environment and policy " + "since the policy is a nn.Module that operates on tensors " + "rather than a TensorDictModule or a nn.Module that accepts a " + "TensorDict as input and defines in_keys and out_keys."
+ ) + + try: + sig = policy.forward.__signature__ + except AttributeError: + sig = inspect.signature(policy.forward) + # we check if all the mandatory params are there + params = list(sig.parameters.keys()) + if ( + set(sig.parameters) == {"tensordict"} + or set(sig.parameters) == {"td"} + or ( + len(params) == 1 + and is_tensor_collection(sig.parameters[params[0]].annotation) + ) + ): + return policy + if fast_wrap: + in_keys = list(observation_spec.keys()) + out_keys = list(env.action_keys) + return TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + + required_kwargs = { + str(k) for k, p in sig.parameters.items() if p.default is inspect._empty + } + next_observation = { + key: value for key, value in observation_spec.rand().items() + } + if not required_kwargs.difference(set(next_observation)): + in_keys = [str(k) for k in sig.parameters if k in next_observation] + if env is None: + out_keys = ["action"] + else: + out_keys = list(env.action_keys) + for p in policy.parameters(): + policy_device = p.device + break + else: + policy_device = None + if policy_device: + next_observation = tree_map( + lambda x: x.to(policy_device), next_observation + ) + + output = policy(**next_observation) + + if isinstance(output, tuple): + out_keys.extend(f"output{i + 1}" for i in range(len(output) - 1)) + + policy = TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + else: + raise TypeError( + f"""Arguments to policy.forward are incompatible with entries in + env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). + If you want TorchRL to automatically wrap your policy with a TensorDictModule + then the arguments to policy.forward must correspond one-to-one with entries + in env.observation_spec. + For more complex behaviour and more control you can consider writing your + own TensorDictModule. + Check the collector documentation to know more about accepted policies. + """ + ) + return policy + + +def _policy_is_tensordict_compatible(policy: nn.Module): + if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( + policy.policy, RandomPolicy + ): + return True + + if isinstance(policy, TensorDictModuleBase): + return True + + sig = inspect.signature(policy.forward) + + if ( + len(sig.parameters) == 1 + and hasattr(policy, "in_keys") + and hasattr(policy, "out_keys") + ): + raise RuntimeError( + "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys " + "is deprecated. Users should inherit from this class (which " + "has very few restrictions) to make the experience smoother. " + "Simply change your policy from `class Policy(nn.Module)` to `Policy(tensordict.nn.TensorDictModuleBase)` " + "and this error should disappear.", + ) + elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): + # if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # we assume no TensorDict compatibility and will try to wrap it. + return False + + # if in_keys or out_keys were defined but policy is not a TensorDictModule or + # accepts multiple arguments then it's likely the user is trying to do something + # that will have undetermined behaviour, we raise an error + raise TypeError( + "Received a policy that defines in_keys or out_keys and also expects multiple " + "arguments to policy.forward. 
If the policy is compatible with TensorDict, it " + "should take a single argument of type TensorDict to policy.forward and define " + "both in_keys and out_keys. Alternatively, policy.forward can accept " + "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " + "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." + ) + + +class RandomPolicy: + """A random policy for data collectors. + + This is a wrapper around the action_spec.rand method. + + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] + """ + + def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() + self.action_spec = action_spec.clone() + self.action_key = action_key + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + if isinstance(self.action_spec, CompositeSpec): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) + + +class _PolicyMetaClass(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + # no kwargs + if isinstance(args[0], nn.Module): + return args[0] + return super().__call__(*args) + + +class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): + """A wrapper for non-parametric policies.""" + + def __init__(self, policy): + super().__init__() + self.policy = policy + + @property + def forward(self): + forward = self.__dict__.get("_forward", None) + if forward is None: + + @functools.wraps(self.policy) + def forward(*input, **kwargs): + return self.policy.__call__(*input, **kwargs) + + self.__dict__["_forward"] = forward + return forward + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + + elif attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + + elif "policy" in self.__dir__(): + policy = self.__getattribute__("policy") + return getattr(policy, attr) + try: + super().__getattr__(attr) + except Exception: + raise AttributeError( + f"policy not set in {self.__class__.__name__}, cannot access {attr}." 
+ ) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index eb5f2a38944..51a6ba783b4 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -194,6 +194,8 @@ class TruncatedNormal(D.Independent): num_params: int = 2 + base_dist: _TruncatedNormal + arg_constraints = { "loc": constraints.real, "scale": constraints.greater_than(1e-6), @@ -231,20 +233,10 @@ def __init__( self.tanh_loc = tanh_loc self.device = loc.device - self.upscale = ( - upscale - if not isinstance(upscale, torch.Tensor) - else upscale.to(self.device) - ) + self.upscale = torch.as_tensor(upscale, device=self.device) - if isinstance(max, torch.Tensor): - max = max.to(self.device) - else: - max = torch.as_tensor(max, device=self.device) - if isinstance(min, torch.Tensor): - min = min.to(self.device) - else: - min = torch.as_tensor(min, device=self.device) + max = torch.as_tensor(max, device=self.device) + min = torch.as_tensor(min, device=self.device) self.min = min self.max = max self.update(loc, scale) @@ -258,7 +250,11 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: self.scale = scale base_dist = _TruncatedNormal( - loc, scale, self.min.expand_as(loc), self.max.expand_as(scale) + loc, + scale, + self.min.expand_as(loc), + self.max.expand_as(scale), + device=self.device, ) super().__init__(base_dist, 1, validate_args=False) @@ -271,13 +267,25 @@ def mode(self): return torch.max(torch.stack([m, a], -1), dim=-1)[0] def log_prob(self, value, **kwargs): + above_or_below = (self.min > value) | (self.max < value) a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 a = a.expand_as(value) b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 b = b.expand_as(value) value = torch.min(torch.stack([value, b], -1), dim=-1)[0] value = torch.max(torch.stack([value, a], -1), dim=-1)[0] - return super().log_prob(value, **kwargs) + lp = super().log_prob(value, **kwargs) + if above_or_below.any(): + if self.event_shape: + above_or_below = above_or_below.flatten(-len(self.event_shape), -1).any( + -1 + ) + lp = torch.masked_fill( + lp, + above_or_below.expand_as(lp), + torch.tensor(-float("inf"), device=lp.device, dtype=lp.dtype), + ) + return lp class TanhNormal(FasterTransformedDistribution): diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index 59b95658ea5..ea8ee0b9fe2 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -33,8 +33,10 @@ class TruncatedStandardNormal(Distribution): has_rsample = True eps = 1e-6 - def __init__(self, a, b, validate_args=None): + def __init__(self, a, b, validate_args=None, device=None): self.a, self.b = broadcast_all(a, b) + self.a = self.a.to(device) + self.b = self.b.to(device) if isinstance(a, Number) and isinstance(b, Number): batch_shape = torch.Size() else: @@ -139,9 +141,11 @@ class TruncatedNormal(TruncatedStandardNormal): has_rsample = True - def __init__(self, loc, scale, a, b, validate_args=None): + def __init__(self, loc, scale, a, b, validate_args=None, device=None): scale = scale.clamp_min(self.eps) self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) + a = a.to(device) + b = b.to(device) self._non_std_a = a self._non_std_b = b a = (a - self.loc) / self.scale diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index b7a044cae7d..8d9855283f5 100644 --- 
a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -33,12 +33,13 @@ class Actor(SafeModule): """General class for deterministic actors in RL. - The Actor class comes with default values for the out_keys (["action"]) - and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into :obj:`spec = CompositeSpec(action=spec)` + The Actor class comes with default values for the out_keys (``["action"]``) + and if the spec is provided but not as a + :class:`~torchrl.data.CompositeSpec` object, it will be + automatically translated into ``spec = CompositeSpec(action=spec)``. Args: - module (nn.Module): a :class:`torch.nn.Module` used to map the input to + module (nn.Module): a :class:`~torch.nn.Module` used to map the input to the output parameter space. in_keys (iterable of str, optional): keys to be read from input tensordict and passed to the module. If it @@ -47,9 +48,11 @@ class Actor(SafeModule): Defaults to ``["observation"]``. out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the - number of tensors returned by the embedded module. Using "_" as a + number of tensors returned by the embedded module. Using ``"_"`` as a key avoid writing tensor to output. Defaults to ``["action"]``. + + Keyword Args: spec (TensorSpec, optional): Keyword-only argument. Specs of the output tensor. If the module outputs multiple output tensors, @@ -59,7 +62,7 @@ class Actor(SafeModule): input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :meth:`~torchrl.data.TensorSpec.project` method. Default is ``False``. Examples: @@ -148,17 +151,23 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): issues. If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is ``False``. - default_interaction_type=InteractionType.RANDOM (str, optional): keyword-only argument. + default_interaction_type (str, optional): keyword-only argument. Default method to be used to retrieve - the output value. Should be one of: 'mode', 'median', 'mean' or 'random' - (in which case the value is sampled randomly from the distribution). Default - is 'mode'. - Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will - first look for the interaction mode dictated by the `interaction_typ()` - global function. If this returns `None` (its default value), then the - `default_interaction_type` of the `ProbabilisticTDModule` instance will be - used. Note that DataCollector instances will use `set_interaction_type` to - :class:`tensordict.nn.InteractionType.RANDOM` by default. + the output value. Should be one of: 'InteractionType.MODE', + 'InteractionType.MEDIAN', 'InteractionType.MEAN' or + 'InteractionType.RANDOM' (in which case the value is sampled + randomly from the distribution). Defaults to is 'InteractionType.RANDOM'. + + .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will + first look for the interaction mode dictated by the + :func:`~tensordict.nn.probabilistic.interaction_type` + global function. If this returns `None` (its default value), then the + `default_interaction_type` of the `ProbabilisticTDModule` + instance will be used. 
Note that + :class:`~torchrl.collectors.collectors.DataCollectorBase` + instances will use `set_interaction_type` to + :class:`tensordict.nn.InteractionType.RANDOM` by default. + distribution_class (Type, optional): keyword-only argument. A :class:`torch.distributions.Distribution` class to be used for sampling. @@ -197,9 +206,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): ... in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... ) - >>> params = TensorDict.from_module(td_module) - >>> with params.to_module(td_module): - ... td = td_module(td) + >>> td = td_module(td) >>> td TensorDict( fields={ @@ -315,7 +322,8 @@ class ValueOperator(TensorDictModule): The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. - Defaults to ``["action"]``. + Defaults to ``["state_value"]`` or + ``["state_action_value"]`` if ``"action"`` is part of the ``in_keys``. Examples: >>> import torch @@ -334,9 +342,7 @@ class ValueOperator(TensorDictModule): >>> td_module = ValueOperator( ... in_keys=["observation", "action"], module=module ... ) - >>> params = TensorDict.from_module(td_module) - >>> with params.to_module(td_module): - ... td = td_module(td) + >>> td = td_module(td) >>> print(td) TensorDict( fields={ diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 9a7f88844cc..763b50eaa60 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -149,10 +149,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: out = action_tensordict.get(action_key) eps = self.eps.item() - cond = ( - torch.rand(action_tensordict.shape, device=action_tensordict.device) - < eps - ).to(out.dtype) + cond = torch.rand(action_tensordict.shape, device=out.device) < eps cond = expand_as_right(cond, out) spec = self.spec if spec is not None: @@ -177,7 +174,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"Action mask key {self.action_mask_key} not found in {tensordict}." ) spec.update_mask(action_mask) - out = cond * spec.rand().to(out.device) + (1 - cond) * out + out = torch.where(cond, spec.rand().to(out.device), out) else: raise RuntimeError("spec must be provided to the exploration wrapper.") action_tensordict.set(action_key, out) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6edcda5c800..8fcbd5a6699 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -255,7 +255,8 @@ def __init__( if functional: self.convert_to_functional( - actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"] + actor_network, + "actor_network", ) else: self.actor_network = actor_network @@ -350,7 +351,7 @@ def in_keys(self): *[("next", key) for key in self.actor_network.in_keys], ] if self.critic_coef: - keys.extend(self.critic.in_keys) + keys.extend(self.critic_network.in_keys) return list(set(keys)) @property @@ -414,11 +415,11 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # TODO: if the advantage is gathered by forward, this introduces an # overhead that we could easily reduce. 
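        # Illustrative note: when ``self.functional`` is True, the
        # ``critic_network_params.to_module(critic_network)`` call below acts as a
        # context manager that temporarily swaps the critic's parameters for the
        # stored ones while the value is computed; otherwise a null context is used
        # and the module keeps its own parameters.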
target_return = tensordict.get(self.tensor_keys.value_target) - tensordict_select = tensordict.select(*self.critic.in_keys) + tensordict_select = tensordict.select(*self.critic_network.in_keys) with self.critic_network_params.to_module( - self.critic + self.critic_network ) if self.functional else contextlib.nullcontext(): - state_value = self.critic( + state_value = self.critic_network( tensordict_select, ).get(self.tensor_keys.value) loss_value = distance_loss( @@ -477,13 +478,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor if self.functional: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 1f5edcf26ed..0499e110398 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -5,13 +5,14 @@ from __future__ import annotations +import abc import warnings from copy import deepcopy from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams from torch import nn @@ -31,7 +32,13 @@ def _updater_check_forward_prehook(module, *args, **kwargs): ) -class LossModule(TensorDictModuleBase): +class _LossMeta(abc.ABCMeta): + def __init__(cls, name, bases, attr_dict): + super().__init__(name, bases, attr_dict) + cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward) + + +class LossModule(TensorDictModuleBase, metaclass=_LossMeta): """A parent class for RL losses. LossModule inherits from nn.Module. It is designed to read an input @@ -97,7 +104,6 @@ def tensor_keys(self) -> _AcceptedKeys: return self._tensor_keys def __new__(cls, *args, **kwargs): - cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward) self = super().__new__(cls) return self @@ -110,7 +116,6 @@ def __init__(self): self.value_type = self.default_value_estimator self._tensor_keys = self._AcceptedKeys() self.register_forward_pre_hook(_updater_check_forward_prehook) - # self.register_forward_pre_hook(_parameters_to_tensordict) def _set_deprecated_ctor_keys(self, **kwargs) -> None: for key, value in kwargs.items(): @@ -243,6 +248,12 @@ def convert_to_functional( # For buffers, a cloned expansion (or equivalently a repeat) is returned. 
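        # In other words, a "cloned expansion" adds a leading dimension of size
        # ``expand_dim`` and materializes the copy, roughly
        # ``param.expand(expand_dim, *param.shape).clone()`` (a sketch of the idea,
        # mirroring the helper defined below).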
def _compare_and_expand(param): + if is_tensor_collection(param): + return param._apply_nest( + _compare_and_expand, + batch_size=[expand_dim, *param.shape], + call_on_nested=True, + ) if not isinstance(param, nn.Parameter): buffer = param.expand(expand_dim, *param.shape).clone() return buffer @@ -252,7 +263,7 @@ def _compare_and_expand(param): # is called: return expanded_param else: - p_out = param.repeat(expand_dim, *[1 for _ in param.shape]) + p_out = param.expand(expand_dim, *param.shape).clone() p_out = nn.Parameter( p_out.uniform_( p_out.min().item(), p_out.max().item() @@ -262,7 +273,9 @@ def _compare_and_expand(param): params = TensorDictParams( params.apply( - _compare_and_expand, batch_size=[expand_dim, *params.shape] + _compare_and_expand, + batch_size=[expand_dim, *params.shape], + call_on_nested=True, ), no_convert=True, ) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index f963f0e0b52..87ccc75b863 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -45,6 +45,7 @@ class CQLLoss(LossModule): actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + Keyword args: loss_function (str, optional): loss function to be used with the value function loss. Default is `"smooth_l1"`. @@ -127,8 +128,9 @@ class CQLLoss(LossModule): alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_alpha_prime: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, @@ -169,10 +171,10 @@ class CQLLoss(LossModule): >>> qvalue = ValueOperator( ... module=module, ... in_keys=['observation', 'action']) - >>> loss = CQLLoss(actor, qvalue, value) + >>> loss = CQLLoss(actor, qvalue) >>> batch = [2, ] >>> action = spec.rand(batch) - >>> loss_actor, loss_qvalue, _, _, _, _ = loss( + >>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss( ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), @@ -185,7 +187,7 @@ class CQLLoss(LossModule): method. Examples: - >>> loss.select_out_keys('loss_actor', 'loss_qvalue') + >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue') >>> loss_actor, loss_qvalue = loss( ... observation=torch.randn(*batch, n_obs), ... 
action=action, @@ -471,10 +473,11 @@ def out_keys(self): "loss_qvalue", "loss_cql", "loss_alpha", - "loss_alpha_prime", "alpha", "entropy", ] + if self.with_lagrange: + keys.append("loss_alpha_prime") self._out_keys = keys return self._out_keys @@ -876,8 +879,9 @@ class DiscreteCQLLoss(LossModule): Examples: - >>> from torchrl.modules import MLP + >>> from torchrl.modules import MLP, QValueActor >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.objectives import DiscreteCQLLoss >>> n_obs, n_act = 4, 3 >>> value_net = MLP(in_features=n_obs, out_features=n_act) >>> spec = OneHotDiscreteTensorSpec(n_act) @@ -895,8 +899,11 @@ class DiscreteCQLLoss(LossModule): >>> loss(data) TensorDict( fields={ - loss: Tensor(shape=torch.Size([]), device=cuda:0, dtype=torch.float32, is_shared=True), - loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a24aa4a1271..954bd0b9a42 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -275,7 +275,6 @@ def __init__( actor_network, "actor_network", create_target_params=False, - funs_to_decorate=["forward"], ) self.loss_function = loss_function diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 37fd1cbdaea..2298c262368 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -213,7 +213,10 @@ def __init__( try: action_space = value_network.action_space except AttributeError: - raise ValueError(self.ACTION_SPEC_ERROR) + raise ValueError( + "The action space could not be retrieved from the value_network. " + "Make sure it is available to the DQN loss module." + ) if action_space is None: warnings.warn( "action_space was not specified. DQNLoss will default to 'one-hot'." diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 62d2a628af4..9de05e09d9d 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -73,20 +73,22 @@ class IQLLoss(LossModule): ... in_keys=["loc", "scale"], ... spec=spec, ... distribution_class=TanhNormal) - >>> class ValueClass(nn.Module): + >>> class QValueClass(nn.Module): ... def __init__(self): ... super().__init__() ... self.linear = nn.Linear(n_obs + n_act, 1) ... def forward(self, obs, act): ... return self.linear(torch.cat([obs, act], -1)) - >>> module = ValueClass() - >>> qvalue = ValueOperator( - ... module=module, - ... in_keys=['observation', 'action']) - >>> module = nn.Linear(n_obs, 1) - >>> value = ValueOperator( - ... module=module, - ... in_keys=["observation"]) + >>> qvalue = SafeModule( + ... QValueClass(), + ... in_keys=["observation", "action"], + ... out_keys=["state_action_value"], + ... ) + >>> value = SafeModule( + ... nn.Linear(n_obs, 1), + ... in_keys=["observation"], + ... out_keys=["state_value"], + ... 
) >>> loss = IQLLoss(actor, qvalue, value) >>> batch = [2, ] >>> action = spec.rand(batch) @@ -134,20 +136,22 @@ class IQLLoss(LossModule): ... in_keys=["loc", "scale"], ... spec=spec, ... distribution_class=TanhNormal) - >>> class ValueClass(nn.Module): + >>> class QValueClass(nn.Module): ... def __init__(self): ... super().__init__() ... self.linear = nn.Linear(n_obs + n_act, 1) ... def forward(self, obs, act): ... return self.linear(torch.cat([obs, act], -1)) - >>> module = ValueClass() - >>> qvalue = ValueOperator( - ... module=module, - ... in_keys=['observation', 'action']) - >>> module = nn.Linear(n_obs, 1) - >>> value = ValueOperator( - ... module=module, - ... in_keys=["observation"]) + >>> qvalue = SafeModule( + ... QValueClass(), + ... in_keys=["observation", "action"], + ... out_keys=["state_action_value"], + ... ) + >>> value = SafeModule( + ... nn.Linear(n_obs, 1), + ... in_keys=["observation"], + ... out_keys=["state_value"], + ... ) >>> loss = IQLLoss(actor, qvalue, value) >>> batch = [2, ] >>> action = spec.rand(batch) @@ -165,7 +169,7 @@ class IQLLoss(LossModule): method. Examples: - >>> loss.select_out_keys('loss_actor', 'loss_qvalue') + >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue') >>> loss_actor, loss_qvalue = loss( ... observation=torch.randn(*batch, n_obs), ... action=action, @@ -495,7 +499,7 @@ class DiscreteIQLLoss(IQLLoss): Args: actor_network (ProbabilisticActor): stochastic actor - qvalue_network (TensorDictModule): Q(s) parametric model + qvalue_network (TensorDictModule): Q(s, a) parametric model. value_network (TensorDictModule, optional): V(s) parametric model. Keyword Args: @@ -526,34 +530,33 @@ class DiscreteIQLLoss(IQLLoss): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper >>> from torchrl.modules.distributions.discrete import OneHotCategorical - >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) - >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) + >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, ... in_keys=["logits"], ... out_keys=["action"], ... spec=spec, ... distribution_class=OneHotCategorical) - >>> qvalue = TensorDictModule( - ... nn.Linear(n_obs), + >>> qvalue = SafeModule( + ... nn.Linear(n_obs, n_act), ... in_keys=["observation"], ... out_keys=["state_action_value"], ... ) - >>> value = TensorDictModule( - ... nn.Linear(n_obs), + >>> value = SafeModule( + ... nn.Linear(n_obs, 1), ... in_keys=["observation"], ... out_keys=["state_value"], ... ) >>> loss = DiscreteIQLLoss(actor, qvalue, value) >>> batch = [2, ] - >>> action = spec.rand(batch) + >>> action = spec.rand(batch).long() >>> data = TensorDict({ ... "observation": torch.randn(*batch, n_obs), ... 
"action": action, @@ -585,40 +588,33 @@ class DiscreteIQLLoss(IQLLoss): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper >>> from torchrl.modules.distributions.discrete import OneHotCategorical - >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss - >>> from tensordict import TensorDict >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) - >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) + >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, ... in_keys=["logits"], ... out_keys=["action"], ... spec=spec, ... distribution_class=OneHotCategorical) - >>> class ValueClass(nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.linear = nn.Linear(n_obs, n_act) - ... def forward(self, obs): - ... return self.linear(obs) - >>> module = ValueClass() - >>> qvalue = ValueOperator( - ... module=module, - ... in_keys=['observation']) - >>> module = nn.Linear(n_obs, 1) - >>> value = ValueOperator( - ... module=module, - ... in_keys=["observation"]) + >>> qvalue = SafeModule( + ... nn.Linear(n_obs, n_act), + ... in_keys=["observation"], + ... out_keys=["state_action_value"], + ... ) + >>> value = SafeModule( + ... nn.Linear(n_obs, 1), + ... in_keys=["observation"], + ... out_keys=["state_value"], + ... ) >>> loss = DiscreteIQLLoss(actor, qvalue, value) >>> batch = [2, ] - >>> action = spec.rand(batch) + >>> action = spec.rand(batch).long() >>> loss_actor, loss_qvalue, loss_value, entropy = loss( ... observation=torch.randn(*batch, n_obs), ... action=action, @@ -633,7 +629,7 @@ class DiscreteIQLLoss(IQLLoss): method. Examples: - >>> loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value') + >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value') >>> loss_actor, loss_qvalue, loss_value = loss( ... observation=torch.randn(*batch, n_obs), ... 
action=action, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index ac2244b9a23..7a3e2e33ebc 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -14,10 +14,17 @@ import torch from tensordict import TensorDict, TensorDictBase -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + TensorDictModule, +) from tensordict.utils import NestedKey from torch import distributions as d +from torchrl.objectives.common import LossModule + from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_ERROR, @@ -25,9 +32,13 @@ distance_loss, ValueEstimators, ) - -from .common import LossModule -from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + TD1Estimator, + TDLambdaEstimator, + VTrace, +) class PPOLoss(LossModule): @@ -927,6 +938,35 @@ def __init__( self.decrement = decrement self.samples_mc_kl = samples_mc_kl + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + self.tensor_keys.sample_log_prob, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.critic_network.in_keys, + ] + # Get the parameter keys from the actor dist + actor_dist_module = None + for module in self.actor_network.modules(): + # Ideally we should combine them if there is more than one + if isinstance(module, ProbabilisticTensorDictModule): + if actor_dist_module is not None: + raise RuntimeError( + "Actors with one and only one distribution are currently supported " + f"in {type(self).__name__}. If you need to use more than one " + f"distribtuion over the action space please submit an issue " + f"on github." + ) + actor_dist_module = module + if actor_dist_module is None: + raise RuntimeError("Could not find the probabilistic module in the actor.") + keys += list(actor_dist_module.in_keys) + self._in_keys = list(set(keys)) + @property def out_keys(self): if self._out_keys is None: @@ -944,27 +984,33 @@ def out_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDict: - tensordict = tensordict.clone(False) - advantage = tensordict.get(self.tensor_keys.advantage, None) + tensordict_copy = tensordict.copy() + try: + previous_dist = self.actor_network.build_dist_from_params(tensordict) + except KeyError as err: + raise KeyError( + "The parameters of the distribution were not found. " + f"Make sure they are provided to {type(self).__name__}." 
+ ) from err + advantage = tensordict_copy.get(self.tensor_keys.advantage, None) if advantage is None: self.value_estimator( - tensordict, + tensordict_copy, params=self._cached_critic_network_params_detached, target_params=self.target_critic_network_params, ) - advantage = tensordict.get(self.tensor_keys.advantage) + advantage = tensordict_copy.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: loc = advantage.mean() scale = advantage.std().clamp_min(1e-6) advantage = (advantage - loc) / scale - log_weight, dist = self._log_weight(tensordict) + log_weight, dist = self._log_weight(tensordict_copy) neg_loss = log_weight.exp() * advantage - previous_dist = self.actor_network.build_dist_from_params(tensordict) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): - current_dist = self.actor_network.get_dist(tensordict) + current_dist = self.actor_network.get_dist(tensordict_copy) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 4613810d0d3..9738b922c5d 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -351,7 +351,7 @@ def _set_in_keys(self): ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], - *self.critic.in_keys, + *self.critic_network.in_keys, ] self._in_keys = list(set(keys)) @@ -398,11 +398,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: try: target_return = tensordict.get(self.tensor_keys.value_target) - tensordict_select = tensordict.select(*self.critic.in_keys) + tensordict_select = tensordict.select(*self.critic_network.in_keys) with self.critic_network_params.to_module( - self.critic + self.critic_network ) if self.functional else contextlib.nullcontext(): - state_value = self.critic(tensordict_select).get(self.tensor_keys.value) + state_value = self.critic_network(tensordict_select).get( + self.tensor_keys.value + ) loss_value = distance_loss( target_return, state_value, @@ -427,13 +429,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp["gamma"] = self.gamma hp.update(hyperparams) if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor if self.functional: diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 053da9e53d2..5b722fd05f3 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -292,7 +292,6 @@ def __init__( actor_network, "actor_network", 
create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -980,7 +979,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -1036,7 +1034,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DiscreteSACLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed." + "This behaviour will be deprecated soon and a space will have to be passed. " "Check the DiscreteSACLoss documentation to see how to pass the action space. " ) action_space = "one-hot" diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 43dfa65c0c4..9afbf8095f0 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -300,8 +300,7 @@ def __init__( ): if eps is None and tau is None: raise RuntimeError( - "Neither eps nor tau was provided. " "This behaviour is deprecated.", - category=DeprecationWarning, + "Neither eps nor tau was provided. This behaviour is deprecated.", ) eps = 0.999 if (eps is None) ^ (tau is None): @@ -460,7 +459,7 @@ def _cache_values(fun): def new_fun(self, netname=None): __dict__ = self.__dict__ - _cache = __dict__["_cache"] + _cache = __dict__.setdefault("_cache", {}) attr_name = name if netname is not None: attr_name += "_" + netname diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index dfa56e5c672..1d28897105a 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -27,7 +27,7 @@ from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import _vmap_func, hold_out_net +from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -78,6 +78,7 @@ def _call_value_nets( single_call: bool, value_key: NestedKey, detach_next: bool, + vmap_randomness: str = "error", ): in_keys = value_net.in_keys if single_call: @@ -141,9 +142,11 @@ def _call_value_nets( ) elif params is not None: params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack) + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) else: - data_out = vmap(value_net, (0,))(data_in) + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) value_est = data_out.get(value_key) value, value_ = value_est[0], value_est[1] data.set(value_key, value) @@ -214,6 +217,7 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] + _vmap_randomness = None @property def advantage_key(self): @@ -428,6 +432,28 @@ def _next_value(self, tensordict, target_params, kwargs): next_value = step_td.get(self.tensor_keys.value) return next_value + @property + def vmap_randomness(self): + if self._vmap_randomness is None: + do_break = False + for val in self.__dict__.values(): + if isinstance(val, torch.nn.Module): + for module in val.modules(): + if isinstance(module, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" + do_break = True + break + if do_break: + # double break + break + else: + self._vmap_randomness = 
"error" + + return self._vmap_randomness + + def set_vmap_randomness(self, value): + self._vmap_randomness = value + class TD0Estimator(ValueEstimatorBase): """Temporal Difference (TD(0)) estimate of advantage function. @@ -589,6 +615,7 @@ def forward( single_call=self.shifted, value_key=self.tensor_keys.value, detach_next=True, + vmap_randomness=self.vmap_randomness, ) else: value = tensordict.get(self.tensor_keys.value) @@ -790,6 +817,7 @@ def forward( single_call=self.shifted, value_key=self.tensor_keys.value, detach_next=True, + vmap_randomness=self.vmap_randomness, ) else: value = tensordict.get(self.tensor_keys.value) @@ -1001,6 +1029,7 @@ def forward( single_call=self.shifted, value_key=self.tensor_keys.value, detach_next=True, + vmap_randomness=self.vmap_randomness, ) else: value = tensordict.get(self.tensor_keys.value) @@ -1247,6 +1276,7 @@ def forward( single_call=self.shifted, value_key=self.tensor_keys.value, detach_next=True, + vmap_randomness=self.vmap_randomness, ) else: value = tensordict.get(self.tensor_keys.value) @@ -1329,6 +1359,7 @@ def value_estimate( single_call=self.shifted, value_key=self.tensor_keys.value, detach_next=True, + vmap_randomness=self.vmap_randomness, ) else: value = tensordict.get(self.tensor_keys.value) @@ -1575,6 +1606,7 @@ def forward( single_call=self.shifted, value_key=self.tensor_keys.value, detach_next=True, + vmap_randomness=self.vmap_randomness, ) else: value = tensordict.get(self.tensor_keys.value) diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 256d0a2e840..6bcd3f50c86 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os from collections import defaultdict from pathlib import Path @@ -126,7 +128,7 @@ class CSVLogger(Logger): def __init__( self, exp_name: str, - log_dir: Optional[str] = None, + log_dir: str | None = None, video_format: str = "pt", video_fps: int = 30, ) -> None: diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index a6181145311..fbfe4f6e205 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib.util from copy import copy from typing import Optional, Sequence @@ -15,11 +16,7 @@ from torchrl.envs.transforms import ObservationTransform, Transform from torchrl.record.loggers import Logger -try: - from torchvision.transforms.functional import center_crop as center_crop_fn - from torchvision.utils import make_grid -except ImportError: - center_crop_fn = None +_has_tv = importlib.util.find_spec("torchvision", None) is not None class VideoRecorder(ObservationTransform): @@ -95,7 +92,7 @@ def __init__( self.count = 0 self.center_crop = center_crop self.make_grid = make_grid - if center_crop and not center_crop_fn: + if center_crop and not _has_tv: raise ImportError( "Could not load center_crop from torchvision. Make sure torchvision is installed." 
) @@ -118,20 +115,26 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: trailing_dim = range(observation_trsf.ndimension() - 3) observation_trsf = observation_trsf.permute(*trailing_dim, -1, -3, -2) if self.center_crop: - if center_crop_fn is None: + if not _has_tv: raise ImportError( "Could not import torchvision, `center_crop` not available." "Make sure torchvision is installed in your environment." ) + from torchvision.transforms.functional import ( + center_crop as center_crop_fn, + ) + observation_trsf = center_crop_fn( observation_trsf, [self.center_crop, self.center_crop] ) if self.make_grid and observation_trsf.ndimension() == 4: - if make_grid is None: + if not _has_tv: raise ImportError( "Could not import torchvision, `make_grid` not available." "Make sure torchvision is installed in your environment." ) + from torchvision.utils import make_grid + observation_trsf = make_grid(observation_trsf) self.obs.append(observation_trsf.to(torch.uint8)) return observation diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index f844613432c..03a7be37573 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -708,7 +708,7 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: def sample(self, batch: TensorDictBase) -> TensorDictBase: sample = self.replay_buffer.sample(batch_size=self.batch_size) - return sample.to(self.device, non_blocking=False) + return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: self.replay_buffer.update_tensordict_priority(batch) diff --git a/tutorials/sphinx-tutorials/README.rst b/tutorials/sphinx-tutorials/README.rst index a7e41cccf45..7995a1fbb2e 100644 --- a/tutorials/sphinx-tutorials/README.rst +++ b/tutorials/sphinx-tutorials/README.rst @@ -1,2 +1,4 @@ README Tutos ============ + +Check the tutorials on torchrl documentation: https://pytorch.org/rl diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 85590c545fa..4a818474985 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -4,9 +4,14 @@ ====================================== **Author**: `Vincent Moens `_ +.. _coding_ddpg: + """ ############################################################################## +# Overview +# -------- +# # TorchRL separates the training of RL algorithms in various pieces that will be # assembled in your training script: the environment, the data collection and # storage, the model and finally the loss function. @@ -14,29 +19,33 @@ # TorchRL losses (or "objectives") are stateful objects that contain the # trainable parameters (policy and value models). # This tutorial will guide you through the steps to code a loss from the ground up -# using torchrl. +# using TorchRL. # # To this aim, we will be focusing on DDPG, which is a relatively straightforward # algorithm to code. -# DDPG (`Deep Deterministic Policy Gradient `_) +# `Deep Deterministic Policy Gradient `_ (DDPG) # is a simple continuous control algorithm. It consists in learning a # parametric value function for an action-observation pair, and -# then learning a policy that outputs actions that maximise this value +# then learning a policy that outputs actions that maximize this value # function given a certain observation. 
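#
# As a rough sketch of these two objectives (illustrative only, using hypothetical
# toy networks rather than the modules built later in this tutorial):
#
# .. code-block:: python
#
#    import torch
#    from torch import nn
#    qnet = nn.Linear(4 + 2, 1)                          # Q(s, a): reads [obs, action]
#    actor = nn.Sequential(nn.Linear(4, 2), nn.Tanh())   # pi(s): outputs an action
#    obs, next_obs = torch.randn(8, 4), torch.randn(8, 4)
#    action, reward, done = torch.randn(8, 2), torch.randn(8, 1), torch.zeros(8, 1)
#    # value objective: regress Q(s, a) onto a bootstrapped target
#    with torch.no_grad():
#        next_q = qnet(torch.cat([next_obs, actor(next_obs)], -1))
#        target = reward + 0.99 * (1 - done) * next_q
#    value_loss = nn.functional.mse_loss(qnet(torch.cat([obs, action], -1)), target)
#    # policy objective: output actions that maximize the value predicted by the critic
#    actor_loss = -qnet(torch.cat([obs, actor(obs)], -1)).mean()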
# -# Key learnings: +# What you will learn: # # - how to write a loss module and customize its value estimator; -# - how to build an environment in torchrl, including transforms -# (e.g. data normalization) and parallel execution; +# - how to build an environment in TorchRL, including transforms +# (for example, data normalization) and parallel execution; # - how to design a policy and value network; # - how to collect data from your environment efficiently and store them # in a replay buffer; # - how to store trajectories (and not transitions) in your replay buffer); -# - and finally how to evaluate your model. +# - how to evaluate your model. +# +# Prerequisites +# ~~~~~~~~~~~~~ # -# This tutorial assumes that you have completed the PPO tutorial which gives -# an overview of the torchrl components and dependencies, such as +# This tutorial assumes that you have completed the +# `PPO tutorial `_ which gives +# an overview of the TorchRL components and dependencies, such as # :class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModules`, # although it should be # sufficiently transparent to be understood without a deep understanding of @@ -44,17 +53,20 @@ # # .. note:: # We do not aim at giving a SOTA implementation of the algorithm, but rather -# to provide a high-level illustration of torchrl's loss implementations +# to provide a high-level illustration of TorchRL's loss implementations # and the library features that are to be used in the context of # this algorithm. # # Imports and setup # ----------------- # +# .. code-block:: bash +# +# %%bash +# pip3 install torchrl mujoco glfw # sphinx_gallery_start_ignore import warnings -from typing import Tuple warnings.filterwarnings("ignore") from torch import multiprocessing @@ -63,24 +75,34 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore -import torch.cuda + +import torch import tqdm ############################################################################### -# We will execute the policy on cuda if available +# We will execute the policy on CUDA if available +is_fork = multiprocessing.get_start_method() == "fork" device = ( - torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") ) +collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA ############################################################################### -# torchrl :class:`~torchrl.objectives.LossModule` +# TorchRL :class:`~torchrl.objectives.LossModule` # ----------------------------------------------- # # TorchRL provides a series of losses to use in your training scripts. @@ -89,11 +111,11 @@ # # The main characteristics of TorchRL losses are: # -# - they are stateful objects: they contain a copy of the trainable parameters +# - They are stateful objects: they contain a copy of the trainable parameters # such that ``loss_module.parameters()`` gives whatever is needed to train the # algorithm. 
-# - They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward` -# method will receive a tensordict as input that contains all the necessary +# - They follow the ``TensorDict`` convention: the :meth:`torch.nn.Module.forward` +# method will receive a TensorDict as input that contains all the necessary # information to return a loss value. # # >>> data = replay_buffer.sample() @@ -101,8 +123,9 @@ # # - They output a :class:`tensordict.TensorDict` instance with the loss values # written under a ``"loss_"`` where ``smth`` is a string describing the -# loss. Additional keys in the tensordict may be useful metrics to log during +# loss. Additional keys in the ``TensorDict`` may be useful metrics to log during # training time. +# # .. note:: # The reason we return independent losses is to let the user use a different # optimizer for different sets of parameters for instance. Summing the losses @@ -129,14 +152,14 @@ # # Let us start with the :meth:`~torchrl.objectives.LossModule.__init__` # method. DDPG aims at solving a control task with a simple strategy: -# training a policy to output actions that maximise the value predicted by +# training a policy to output actions that maximize the value predicted by # a value network. Hence, our loss module needs to receive two networks in its # constructor: an actor and a value networks. We expect both of these to be -# tensordict-compatible objects, such as +# TensorDict-compatible objects, such as # :class:`tensordict.nn.TensorDictModule`. # Our loss function will need to compute a target value and fit the value # network to this, and generate an action and fit the policy such that its -# value estimate is maximised. +# value estimate is maximized. # # The crucial step of the :meth:`LossModule.__init__` method is the call to # :meth:`~torchrl.LossModule.convert_to_functional`. This method will extract @@ -149,7 +172,7 @@ # model with different sets of parameters, called "trainable" and "target" # parameters. # The "trainable" parameters are those that the optimizer needs to fit. The -# "target" parameters are usually a copy of the formers with some time lag +# "target" parameters are usually a copy of the former's with some time lag # (absolute or diluted through a moving average). # These target parameters are used to compute the value associated with the # next observation. One the advantages of using a set of target parameters @@ -163,7 +186,7 @@ # accessible but this will just return a **detached** version of the # actor parameters. # -# Later, we will see how the target parameters should be updated in torchrl. +# Later, we will see how the target parameters should be updated in TorchRL. # from tensordict.nn import TensorDictModule @@ -235,27 +258,22 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): hp.update(hyperparams) value_key = "state_action_value" if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator( - value_network=self.actor_critic, value_key=value_key, **hp - ) + self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator( - value_network=self.actor_critic, value_key=value_key, **hp - ) + self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp) elif value_type == ValueEstimators.GAE: raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." 
) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator( - value_network=self.actor_critic, value_key=value_key, **hp - ) + self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp) else: raise NotImplementedError(f"Unknown value type {value_type}") + self._value_estimator.set_keys(value=value_key) ############################################################################### -# The ``make_value_estimator`` method can but does not need to be called: if +# The ``make_value_estimator`` method can but does not need to be called: ifgg # not, the :class:`~torchrl.objectives.LossModule` will query this method with # its default estimator. # @@ -265,7 +283,7 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): # The central piece of an RL algorithm is the training loss for the actor. # In the case of DDPG, this function is quite simple: we just need to compute # the value associated with an action computed using the policy and optimize -# the actor weights to maximise this value. +# the actor weights to maximize this value. # # When computing this value, we must make sure to take the value parameters out # of the graph, otherwise the actor and value loss will be mixed up. @@ -279,12 +297,11 @@ def _loss_actor( ) -> torch.Tensor: td_copy = tensordict.select(*self.actor_in_keys) # Get an action from the actor network: since we made it functional, we need to pass the params - td_copy = self.actor_network(td_copy, params=self.actor_network_params) + with self.actor_network_params.to_module(self.actor_network): + td_copy = self.actor_network(td_copy) # get the value associated with that action - td_copy = self.value_network( - td_copy, - params=self.value_network_params.detach(), - ) + with self.value_network_params.detach().to_module(self.value_network): + td_copy = self.value_network(td_copy) return -td_copy.get("state_action_value") @@ -302,11 +319,12 @@ def _loss_actor( def _loss_value( self, tensordict, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +): td_copy = tensordict.clone() # V(s, a) - self.value_network(td_copy, params=self.value_network_params) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) pred_val = td_copy.get("state_action_value").squeeze(-1) # we manually reconstruct the parameters of the actor-critic, where the first @@ -321,11 +339,10 @@ def _loss_value( batch_size=self.target_actor_network_params.batch_size, device=self.target_actor_network_params.device, ) - target_value = self.value_estimator.value_estimate( - tensordict, target_params=target_params - ).squeeze(-1) + with target_params.to_module(self.value_estimator): + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - # Computes the value loss: L2, L1 or smooth L1 depending on self.loss_funtion + # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function` loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function) td_error = (pred_val - target_value).pow(2) @@ -337,7 +354,7 @@ def _loss_value( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The only missing piece is the forward method, which will glue together the -# value and actor loss, collect the cost values and write them in a tensordict +# value and actor loss, collect the cost values and write them in a ``TensorDict`` # delivered to the user. 
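#
# A minimal sketch of such a ``forward`` (illustrative: the exact return values of
# ``_loss_value`` and the entries written out are assumptions; the tutorial's complete
# version may record additional entries such as the TD error):
#
# .. code-block:: python
#
#    def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
#        loss_value, td_error, pred_val, target_value = self._loss_value(input_tensordict)
#        loss_actor = self._loss_actor(input_tensordict)
#        return TensorDict(
#            {"loss_actor": loss_actor.mean(), "loss_value": loss_value.mean()},
#            batch_size=[],
#        )
#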
from tensordict import TensorDict, TensorDictBase @@ -397,7 +414,7 @@ class DDPGLoss(LossModule): # For this example, we will be using the ``"cheetah"`` task. The goal is to make # a half-cheetah run as fast as possible. # -# In TorchRL, one can create such a task by relying on dm_control or gym: +# In TorchRL, one can create such a task by relying on ``dm_control`` or ``gym``: # # .. code-block:: python # @@ -411,7 +428,7 @@ class DDPGLoss(LossModule): # # By default, these environment disable rendering. Training from states is # usually easier than training from images. To keep things simple, we focus -# on learning from states only. To pass the pixels to the tensordicts that +# on learning from states only. To pass the pixels to the ``tensordicts`` that # are collected by :func:`env.step()`, simply pass the ``from_pixels=True`` # argument to the constructor: # @@ -420,7 +437,7 @@ class DDPGLoss(LossModule): # env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True) # # We write a :func:`make_env` helper function that will create an environment -# with either one of the two backends considered above (dm-control or gym). +# with either one of the two backends considered above (``dm-control`` or ``gym``). # from torchrl.envs.libs.dm_control import DMControlEnv @@ -431,7 +448,7 @@ class DDPGLoss(LossModule): def make_env(from_pixels=False): - """Create a base env.""" + """Create a base ``env``.""" global env_library global env_name @@ -502,7 +519,7 @@ def make_env(from_pixels=False): def make_transformed_env( env, ): - """Apply transforms to the env (such as reward scaling and state normalization).""" + """Apply transforms to the ``env`` (such as reward scaling and state normalization).""" env = TransformedEnv(env) @@ -511,16 +528,6 @@ def make_transformed_env( # syntax. env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) - double_to_float_list = [] - double_to_float_inv_list = [] - if env_library is DMControlEnv: - # DMControl requires double-precision - double_to_float_list += [ - "reward", - "action", - ] - double_to_float_inv_list += ["action"] - # We concatenate all states into a single "observation_vector" # even if there is a single tensor, it'll be renamed in "observation_vector". # This facilitates the downstream operations as we know the name of the @@ -536,16 +543,12 @@ def make_transformed_env( # version of the transform env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True)) - double_to_float_list.append(out_key) - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) + env.append_transform(DoubleToFloat()) env.append_transform(StepCounter(max_frames_per_traj)) - # We need a marker for the start of trajectories for our OU exploration: + # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU) + # exploration: env.append_transform(InitTracker()) return env @@ -608,15 +611,16 @@ def make_t_env(): return env -# The backend can be gym or dm_control +# The backend can be ``gym`` or ``dm_control`` backend = "gym" ############################################################################### # .. note:: +# # ``frame_skip`` batches multiple step together with a single action -# If > 1, the other frame counts (e.g. frames_per_batch, total_frames) need to -# be adjusted to have a consistent total number of frames collected across -# experiments. 
This is important as raising the frame-skip but keeping the +# If > 1, the other frame counts (for example, frames_per_batch, total_frames) +# need to be adjusted to have a consistent total number of frames collected +# across experiments. This is important as raising the frame-skip but keeping the # total number of frames unchanged may seem like cheating: all things compared, # a dataset of 10M elements collected with a frame-skip of 2 and another with # a frame-skip of 1 actually have a ratio of interactions with the environment @@ -630,7 +634,7 @@ def make_t_env(): ############################################################################### # We also define when a trajectory will be truncated. A thousand steps (500 if -# frame-skip = 2) is a good number to use for cheetah: +# frame-skip = 2) is a good number to use for the cheetah task: max_frames_per_traj = 500 @@ -660,7 +664,7 @@ def get_env_stats(): ############################################################################### # Normalization stats # ~~~~~~~~~~~~~~~~~~~ -# Number of random steps used as for stats computation using ObservationNorm +# Number of random steps used as for stats computation using ``ObservationNorm`` init_env_steps = 5000 @@ -764,8 +768,8 @@ def make_ddpg_actor( module=q_net, ).to(device) - # init lazy moduless - qnet(actor(proof_environment.reset())) + # initialize lazy modules + qnet(actor(proof_environment.reset().to(device))) return actor, qnet @@ -779,7 +783,7 @@ def make_ddpg_actor( # ~~~~~~~~~~~ # # The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` -# exploration module, as suggesed in the original paper. +# exploration module, as suggested in the original paper. # Let's define the number of frames before OU noise reaches its minimum value annealing_frames = 1_000_000 @@ -801,24 +805,27 @@ def make_ddpg_actor( # environment and reset it when required. # Data collectors are designed to help developers have a tight control # on the number of frames per batch of data, on the (a)sync nature of this -# collection and on the resources allocated to the data collection (e.g. GPU, -# number of workers etc). +# collection and on the resources allocated to the data collection (for example +# GPU, number of workers, and so on). # # Here we will use -# :class:`~torchrl.collectors.MultiaSyncDataCollector`, a data collector that -# will be executed in an async manner (i.e. data will be collected while -# the policy is being optimized). With the :class:`MultiaSyncDataCollector`, -# multiple workers are running rollouts separately. When a batch is asked, it -# is gathered from the first worker that can provide it. +# :class:`~torchrl.collectors.SyncDataCollector`, a simple, single-process +# data collector. TorchRL offers other collectors, such as +# :class:`~torchrl.collectors.MultiaSyncDataCollector`, which executed the +# rollouts in an asynchronous manner (for example, data will be collected while +# the policy is being optimized, thereby decoupling the training and +# data collection). # # The parameters to specify are: # -# - the list of environment creation functions, +# - an environment factory or an environment, # - the policy, # - the total number of frames before the collector is considered empty, # - the maximum number of frames per trajectory (useful for non-terminating -# environments, like dm_control ones). +# environments, like ``dm_control`` ones). +# # .. 
note:: +# # The ``max_frames_per_traj`` passed to the collector will have the effect # of registering a new :class:`~torchrl.envs.StepCounter` transform # with the environment used for inference. We can achieve the same result @@ -837,8 +844,8 @@ def make_ddpg_actor( ############################################################################### # The number of frames returned by the collector at each iteration of the outer -# loop is equal to the length of each sub-trajectories times the number of envs -# run in parallel in each collector. +# loop is equal to the length of each sub-trajectories times the number of +# environments run in parallel in each collector. # # In other words, we expect batches from the collector to have a shape # ``[env_per_collector, traj_len]`` where @@ -849,26 +856,18 @@ def make_ddpg_actor( init_random_frames = 5000 num_collectors = 2 -from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.envs import ExplorationType -collector = MultiaSyncDataCollector( - create_env_fn=[ - parallel_env, - ] - * num_collectors, +collector = SyncDataCollector( + parallel_env, policy=actor_model_explore, total_frames=total_frames, - # max_frames_per_traj=max_frames_per_traj, # this is achieved by the env constructor frames_per_batch=frames_per_batch, init_random_frames=init_random_frames, reset_at_each_iter=False, split_trajs=False, - device=device, - # device for execution - storing_device=device, - # device where data will be stored and passed - update_at_each_batch=False, + device=collector_device, exploration_type=ExplorationType.RANDOM, ) @@ -961,7 +960,7 @@ def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb ############################################################################### -# We'll store the replay buffer in a temporary dirrectory on disk +# We'll store the replay buffer in a temporary directory on disk import tempfile @@ -977,17 +976,17 @@ def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb # size by dividing it by the length of the sub-trajectories yielded by our # data collector. # Regarding the batch-size, our sampling strategy will consist in sampling -# trajectories of length ``traj_len=200`` before selecting sub-trajecotries +# trajectories of length ``traj_len=200`` before selecting sub-trajectories # or length ``random_crop_len=25`` on which the loss will be computed. # This strategy balances the choice of storing whole trajectories of a certain -# length with the need for providing sampels with a sufficient heterogeneity +# length with the need for providing samples with a sufficient heterogeneity # to our loss. The following figure shows the dataflow from a collector # that gets 8 frames in each batch with 2 environments run in parallel, # feeds them to a replay buffer that contains 1000 trajectories and # samples sub-trajectories of 2 time steps each. # # .. figure:: /_static/img/replaybuffer_traj.png -# :alt: Storign trajectories in the replay buffer +# :alt: Storing trajectories in the replay buffer # # Let's start with the number of frames stored in the buffer @@ -1005,7 +1004,7 @@ def ceil_div(x, y): ############################################################################### # We also need to define how many updates we'll be doing per batch of data -# collected. This is known as the update-to-data or UTD ratio: +# collected. 
This is known as the update-to-data or ``UTD`` ratio: update_to_data = 64 ############################################################################### @@ -1032,7 +1031,7 @@ def ceil_div(x, y): # Loss module construction # ------------------------ # -# We build our loss module with the actor and qnet we've just created. +# We build our loss module with the actor and ``qnet`` we've just created. # Because we have target parameters to update, we _must_ create a target network # updater. # @@ -1189,7 +1188,7 @@ def ceil_div(x, y): # # .. note:: # As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 1M. +# use a greater value for ``total_frames`` for example, 1M. from matplotlib import pyplot as plt @@ -1205,7 +1204,7 @@ def ceil_div(x, y): # Conclusion # ---------- # -# In this tutorial, we have learnt how to code a loss module in TorchRL given +# In this tutorial, we have learned how to code a loss module in TorchRL given # the concrete example of DDPG. # # The key takeaways are: @@ -1215,3 +1214,11 @@ def ceil_div(x, y): # - How to use (or not) a target network, and how to update its parameters; # - How to create an optimizer associated with a loss module. # +# Next Steps +# ---------- +# +# To iterate further on this loss module we might consider: +# +# - Using `@dispatch` (see `[Feature] Distpatch IQL loss module `_.) +# - Allowing flexible TensorDict keys. +# diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index fcddd699b3a..eb476dfcc15 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -4,6 +4,8 @@ ============================== **Author**: `Vincent Moens `_ +.. _coding_dqn: + """ ############################################################################## @@ -86,6 +88,8 @@ import tempfile import warnings +from tensordict.nn import TensorDictSequential + warnings.filterwarnings("ignore") from torch import multiprocessing @@ -94,13 +98,17 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore - import os import uuid @@ -125,7 +133,7 @@ ToTensorImage, TransformedEnv, ) -from torchrl.modules import DuelingCnnDQNet, EGreedyWrapper, QValueActor +from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.record.loggers.csv import CSVLogger @@ -270,6 +278,7 @@ def get_norm_stats(): # let's check that normalizing constants have a size of ``[C, 1, 1]`` where # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`). 
print("state dict of the observation norm:", obs_norm_sd) + test_env.close() return obs_norm_sd @@ -328,13 +337,14 @@ def make_model(dummy_env): tensordict = dummy_env.fake_tensordict() actor(tensordict) - # we wrap our actor in an EGreedyWrapper for data collection - actor_explore = EGreedyWrapper( - actor, + # we join our actor with an EGreedyModule for data collection + exploration_module = EGreedyModule( + spec=dummy_env.action_spec, annealing_num_steps=total_frames, eps_init=eps_greedy_val, eps_end=eps_greedy_val_env, ) + actor_explore = TensorDictSequential(actor, exploration_module) return actor, actor_explore @@ -381,6 +391,13 @@ def get_replay_buffer(buffer_size, n_optim, batch_size): # We choose the following configuration: we will be running a series of # parallel environments synchronously in parallel in different collectors, # themselves running in parallel but asynchronously. +# +# .. note:: +# This feature is only available when running the code within the "spawn" +# start method of python multiprocessing library. If this tutorial is run +# directly as a script (thereby using the "fork" method) we will be using +# a regular :class:`~torchrl.collectors.SyncDataCollector`. +# # The advantage of this configuration is that we can balance the amount of # compute that is executed in batch with what we want to be executed # asynchronously. We encourage the reader to experiment how the collection @@ -389,9 +406,9 @@ def get_replay_buffer(buffer_size, n_optim, batch_size): # environment executed in parallel in each collector (controlled by the # ``num_workers`` hyperparameter). # -# When building the collector, we can choose on which device we want the -# environment and policy to execute the operations through the ``device`` -# keyword argument. The ``storing_devices`` argument will modify the +# Collector's devices are fully parametrizable through the ``device`` (general), +# ``policy_device``, ``env_device`` and ``storing_device`` arguments. +# The ``storing_device`` argument will modify the # location of the data being collected: if the batches that we are gathering # have a considerable size, we may want to store them on a different location # than the device where the computation is happening. For asynchronous data @@ -409,11 +426,10 @@ def get_collector( total_frames, device, ): - data_collector = MultiaSyncDataCollector( - [ - make_env(parallel=True, obs_norm_sd=stats), - ] - * num_collectors, + cls = MultiaSyncDataCollector + env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors + data_collector = cls( + env_arg, policy=actor_explore, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -464,7 +480,12 @@ def get_loss_module(actor, gamma): # in practice, and the performance of the algorithm should hopefully not be # too sensitive to slight variations of these. 
-device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu" +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ############################################################################### # Optimizer @@ -642,6 +663,12 @@ def get_loss_module(actor, gamma): ) recorder.register(trainer) +############################################################################### +# The exploration module epsilon factor is also annealed: +# + +trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch) + ############################################################################### # - Any callable (including :class:`~torchrl.trainers.TrainerHookBase` # subclasses) can be registered using :meth:`~torchrl.trainers.Trainer.register_op`. diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 679d625220c..6f31a0aed1a 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -4,6 +4,8 @@ ================================================== **Author**: `Vincent Moens `_ +.. _coding_ppo: + This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium control library `__. @@ -15,8 +17,8 @@ Key learnings: -- How to create an environment in TorchRL, transform its outputs, and collect data from this env; -- How to make your classes talk to each other using :class:`tensordict.TensorDict`; +- How to create an environment in TorchRL, transform its outputs, and collect data from this environment; +- How to make your classes talk to each other using :class:`~tensordict.TensorDict`; - The basics of building your training loop with TorchRL: - How to compute the advantage signal for policy gradient methods; @@ -56,7 +58,7 @@ # problem rather than re-inventing the wheel every time you want to train a policy. # # For completeness, here is a brief overview of what the loss computes, even though -# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows: +# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows: # 1. we will sample a batch of data by playing the # policy in the environment for a given number of steps. # 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using @@ -99,7 +101,7 @@ # 5. Finally, we will run our training loop and analyze the results. # # Throughout this tutorial, we'll be using the :mod:`tensordict` library. -# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract +# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract # what a module reads and writes and care less about the specific data # description and more about the algorithm itself. # @@ -114,9 +116,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -159,7 +166,12 @@ # actually return ``frame_skip`` frames). 
# -device = "cpu" if not torch.has_cuda else "cuda:0" +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) num_cells = 256 # number of cells in each layer i.e. output dim. lr = 3e-4 max_grad_norm = 1.0 @@ -174,22 +186,10 @@ # use. In general, the goal of an RL algorithm is to learn to solve the task # as fast as it can in terms of environment interactions: the lower the ``total_frames`` # the better. -# We also define a ``frame_skip``: in some contexts, repeating the same action -# multiple times over the course of a trajectory may be beneficial as it makes -# the behavior more consistent and less erratic. However, "skipping" -# too many frames will hamper training by reducing the reactivity of the actor -# to observation changes. -# -# When using ``frame_skip`` it is good practice to -# correct the other frame counts by the number of frames we are grouping -# together. If we configure a total count of X frames for training but -# use a ``frame_skip`` of Y, we will be actually collecting XY frames in total -# which exceeds our predefined budget. -# -frame_skip = 1 -frames_per_batch = 1000 // frame_skip +# +frames_per_batch = 1000 # For a complete training, bring the number of frames up to 1M -total_frames = 10_000 // frame_skip +total_frames = 10_000 ###################################################################### # PPO parameters @@ -220,23 +220,23 @@ # control system. Various libraries provide simulation environments for reinforcement # learning, including Gymnasium (previously OpenAI Gym), DeepMind control suite, and # many others. -# As a generalistic library, TorchRL's goal is to provide an interchangeable interface +# As a general library, TorchRL's goal is to provide an interchangeable interface # to a large panel of RL simulators, allowing you to easily swap one environment # with another. For example, creating a wrapped gym environment can be achieved with few characters: # -base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip) +base_env = GymEnv("InvertedDoublePendulum-v4", device=device) ###################################################################### # There are a few things to notice in this code: first, we created # the environment by calling the ``GymEnv`` wrapper. If extra keyword arguments # are passed, they will be transmitted to the ``gym.make`` method, hence covering -# the most common env construction commands. +# the most common environment construction commands. # Alternatively, one could also directly create a gym environment using ``gym.make(env_name, **kwargs)`` # and wrap it in a `GymWrapper` class. # # Also the ``device`` argument: for gym, this only controls the device where -# input action and observered states will be stored, but the execution will always +# input action and observed states will be stored, but the execution will always # be done on CPU. The reason for this is simply that gym does not support on-device # execution, unless specified otherwise. For other libraries, we have control over # the execution device and, as much as we can, we try to stay consistent in terms of @@ -248,9 +248,9 @@ # We will append some transforms to our environments to prepare the data for # the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different # approach, more similar to other pytorch domain libraries, through the use of transforms. 
-# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv` -# instance and append the sequence of transforms to it. The transformed env will inherit -# the device and meta-data of the wrapped env, and transform these depending on the sequence +# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv` +# instance and append the sequence of transforms to it. The transformed environment will inherit +# the device and meta-data of the wrapped environment, and transform these depending on the sequence # of transforms it contains. # # Normalization @@ -262,17 +262,17 @@ # run a certain number of random steps in the environment and compute # the summary statistics of these observations. # -# We'll append two other transforms: the :class:`DoubleToFloat` transform will +# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will # convert double entries to single-precision numbers, ready to be read by the -# policy. The :class:`StepCounter` transform will be used to count the steps before +# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before # the environment is terminated. We will use this measure as a supplementary measure # of performance. # -# As we will see later, many of the TorchRL's classes rely on :class:`tensordict.TensorDict` +# As we will see later, many of the TorchRL's classes rely on :class:`~tensordict.TensorDict` # to communicate. You could think of it as a python dictionary with some extra # tensor features. In practice, this means that many modules we will be working # with need to be told what key to read (``in_keys``) and what key to write -# (``out_keys``) in the tensordict they will receive. Usually, if ``out_keys`` +# (``out_keys``) in the ``tensordict`` they will receive. Usually, if ``out_keys`` # is omitted, it is assumed that the ``in_keys`` entries will be updated # in-place. For our transforms, the only entry we are interested in is referred # to as ``"observation"`` and our transform layers will be told to modify this @@ -284,22 +284,20 @@ Compose( # normalize observations ObservationNorm(in_keys=["observation"]), - DoubleToFloat( - in_keys=["observation"], - ), + DoubleToFloat(), StepCounter(), ), ) ###################################################################### # As you may have noticed, we have created a normalization layer but we did not -# set its normalization parameters. To do this, :class:`ObservationNorm` can +# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can # automatically gather the summary statistics of our environment: # env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0) ###################################################################### -# The :class:`ObservationNorm` transform has now been populated with a +# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a # location and a scale that will be used to normalize the data. # # Let us do a little sanity check for the shape of our summary stats: @@ -313,25 +311,23 @@ # For efficiency purposes, TorchRL is quite stringent when it comes to # environment specs, but you can easily check that your environment specs are # adequate. 
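As a quick illustrative check (not part of the patch, and assuming the transformed ``env`` built above with its ``ObservationNorm`` already initialized), the ``check_env_specs`` helper runs a short rollout and compares the collected data against the declared specs:

from torchrl.envs.utils import check_env_specs

# Raises an error if the rollout data does not match the declared specs.
check_env_specs(env)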
-# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits -# from it already take care of setting the proper specs for your env so +# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and +# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits +# from it already take care of setting the proper specs for your environment so # you should not have to care about this. # # Nevertheless, let's see a concrete example using our transformed # environment by looking at its specs. -# There are five specs to look at: ``observation_spec`` which defines what +# There are three specs to look at: ``observation_spec`` which defines what # is to be expected when executing an action in the environment, -# ``reward_spec`` which indicates the reward domain, -# ``done_spec`` which indicates the done state of an environment, -# the ``action_spec`` which defines the action space, dtype and device and -# the ``state_spec`` which groups together the specs of all the other inputs -# (if any) to the environment. +# ``reward_spec`` which indicates the reward domain and finally the +# ``input_spec`` (which contains the ``action_spec``) and which represents +# everything an environment requires to execute a single step. # print("observation_spec:", env.observation_spec) print("reward_spec:", env.reward_spec) -print("done_spec:", env.done_spec) -print("action_spec:", env.action_spec) -print("state_spec:", env.state_spec) +print("input_spec:", env.input_spec) +print("action_spec (as defined by input_spec):", env.action_spec) ###################################################################### # the :func:`check_env_specs` function runs a small rollout and compares its output against the environment @@ -349,9 +345,9 @@ # action as input, and outputs an observation, a reward and a done state. The # observation may be composite, meaning that it could be composed of more than one # tensor. This is not a problem for TorchRL, since the whole set of observations -# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout -# (ie a sequence of environment steps and random action generations) over a given -# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape +# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout +# (for example, a sequence of environment steps and random action generations) over a given +# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape # that matches this trajectory length: # rollout = env.rollout(3) @@ -361,8 +357,8 @@ ###################################################################### # Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps # we ran it for. The ``"next"`` entry points to the data coming after the current step. -# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this -# may not be the case if we are using some specific transformations (e.g. multi-step). +# In most cases, the ``"next"`` data at time `t` matches the data at ``t+1``, but this +# may not be the case if we are using some specific transformations (for example, multi-step). # # Policy # ------ @@ -388,10 +384,9 @@ # # 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``. # -# 2. 
Append a :class:`NormalParamExtractor` to extract a location and a scale (ie splits the input in two equal parts -# and applies a positive transformation to the scale parameter). +# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter). # -# 3. Create a probabilistic :class:`TensorDictModule` that can generate this distribution and sample from it. +# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it. # actor_net = nn.Sequential( @@ -406,8 +401,8 @@ ) ###################################################################### -# To enable the policy to "talk" with the environment through the tensordict -# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This +# To enable the policy to "talk" with the environment through the ``tensordict`` +# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This # class will simply ready the ``in_keys`` it is provided with and write the # outputs in-place at the registered ``out_keys``. # @@ -417,18 +412,19 @@ ###################################################################### # We now need to build a distribution out of the location and scale of our -# normal distribution. To do so, we instruct the :class:`ProbabilisticActor` -# class to build a :class:`TanhNormal` out of the location and scale +# normal distribution. To do so, we instruct the +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` +# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale # parameters. We also provide the minimum and maximum values of this # distribution, which we gather from the environment specs. # # The name of the ``in_keys`` (and hence the name of the ``out_keys`` from -# the :class:`TensorDictModule` above) cannot be set to any value one may -# like, as the :class:`TanhNormal` distribution constructor will expect the +# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may +# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the # ``loc`` and ``scale`` keyword arguments. That being said, -# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys`` -# where the key-value pair indicates what ``in_key`` string should be used for -# every keyword argument that is to be used. +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts +# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates +# what ``in_key`` string should be used for every keyword argument that is to be used. # policy_module = ProbabilisticActor( module=policy_module, @@ -436,8 +432,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.minimum, - "max": env.action_spec.space.maximum, + "min": env.action_spec.space.low, + "max": env.action_spec.space.high, }, return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights @@ -451,7 +447,7 @@ # won't be used at inference time. This module will read the observations and # return an estimation of the discounted return for the following trajectory. # This allows us to amortize learning by relying on the some utility estimation -# that is learnt on-the-fly during training. 
Our value network share the same +# that is learned on-the-fly during training. Our value network share the same # structure as the policy, but for simplicity we assign it its own set of # parameters. # @@ -472,7 +468,7 @@ ###################################################################### # let's try our policy and value modules. As we said earlier, the usage of -# :class:`TensorDictModule` makes it possible to directly read the output +# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output # of the environment to run these modules, as they know what information to read # and where to write it: # @@ -483,11 +479,11 @@ # Data collector # -------------- # -# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these -# classes execute three operations: reset an environment, compute an action -# given the latest observation, execute a step in the environment, and repeat -# the last two steps until the environment signals a stop (or reaches a done -# state). +# TorchRL provides a set of `DataCollector classes `__. +# Briefly, these classes execute three operations: reset an environment, +# compute an action given the latest observation, execute a step in the environment, +# and repeat the last two steps until the environment signals a stop (or reaches +# a done state). # # They allow you to control how many frames to collect at each iteration # (through the ``frames_per_batch`` parameter), @@ -495,18 +491,19 @@ # on which ``device`` the policy should be executed, etc. They are also # designed to work efficiently with batched and multiprocessed environments. # -# The simplest data collector is the :class:`SyncDataCollector`: it is an -# iterator that you can use to get batches of data of a given length, and +# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`: +# it is an iterator that you can use to get batches of data of a given length, and # that will stop once a total number of frames (``total_frames``) have been # collected. -# Other data collectors (``MultiSyncDataCollector`` and -# ``MultiaSyncDataCollector``) will execute the same operations in synchronous -# and asynchronous manner over a set of multiprocessed workers. +# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute +# the same operations in synchronous and asynchronous manner over a +# set of multiprocessed workers. # # As for the policy and environment before, the data collector will return -# :class:`tensordict.TensorDict` instances with a total number of elements that will -# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the -# training loop allows you to write dataloading pipelines +# :class:`~tensordict.TensorDict` instances with a total number of elements that will +# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the +# training loop allows you to write data loading pipelines # that are 100% oblivious to the actual specificities of the rollout content. # collector = SyncDataCollector( @@ -528,10 +525,10 @@ # of epochs. # # TorchRL's replay buffers are built using a common container -# :class:`ReplayBuffer` which takes as argument the components of the buffer: -# a storage, a writer, a sampler and possibly some transforms. Only the -# storage (which indicates the replay buffer capacity) is mandatory. 
We -# also specify a sampler without repetition to avoid sampling multiple times +# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components +# of the buffer: a storage, a writer, a sampler and possibly some transforms. +# Only the storage (which indicates the replay buffer capacity) is mandatory. +# We also specify a sampler without repetition to avoid sampling multiple times # the same item in one epoch. # Using a replay buffer for PPO is not mandatory and we could simply # sample the sub-batches from the collected batch, but using these classes @@ -539,7 +536,7 @@ # replay_buffer = ReplayBuffer( - storage=LazyTensorStorage(frames_per_batch), + storage=LazyTensorStorage(max_size=frames_per_batch), sampler=SamplerWithoutReplacement(), ) @@ -547,8 +544,8 @@ # Loss function # ------------- # -# The PPO loss can be directly imported from torchrl for convenience using the -# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO: +# The PPO loss can be directly imported from TorchRL for convenience using the +# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO: # it hides away the mathematical operations of PPO and the control flow that # goes with it. # @@ -558,11 +555,11 @@ # To compute the advantage, one just needs to (1) build the advantage module, which # utilizes our value operator, and (2) pass each batch of data through it before each # epoch. -# The GAE module will update the input :class:`TensorDict` with new ``"advantage"`` and +# The GAE module will update the input ``tensordict`` with new ``"advantage"`` and # ``"value_target"`` entries. # The ``"value_target"`` is a gradient-free tensor that represents the empirical # value that the value network should represent with the input observation. -# Both of these will be used by :class:`ClipPPOLoss` to +# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to # return the policy and value losses. # @@ -577,9 +574,7 @@ entropy_bonus=bool(entropy_eps), entropy_coef=entropy_eps, # these keys match by default but we set this for completeness - value_target_key=advantage_module.value_target_key, critic_coef=1.0, - gamma=0.99, loss_critic_type="smooth_l1", ) @@ -610,7 +605,7 @@ logs = defaultdict(list) -pbar = tqdm(total=total_frames * frame_skip) +pbar = tqdm(total=total_frames) eval_str = "" # We iterate over the collector until it reaches the total number of frames it was @@ -621,8 +616,7 @@ # We'll need an "advantage" signal to make PPO work. # We re-compute it at each epoch as its value depends on the value # network which is updated in the inner loop. - with torch.no_grad(): - advantage_module(tensordict_data) + advantage_module(tensordict_data) data_view = tensordict_data.reshape(-1) replay_buffer.extend(data_view.cpu()) for _ in range(frames_per_batch // sub_batch_size): @@ -634,7 +628,7 @@ + loss_vals["loss_entropy"] ) - # Optimization: backward, grad clipping and optim step + # Optimization: backward, grad clipping and optimization step loss_value.backward() # this is not strictly mandatory but it's good practice to keep # your gradient norm bounded @@ -643,7 +637,7 @@ optim.zero_grad() logs["reward"].append(tensordict_data["next", "reward"].mean().item()) - pbar.update(tensordict_data.numel() * frame_skip) + pbar.update(tensordict_data.numel()) cum_reward_str = ( f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})" ) @@ -655,8 +649,8 @@ # We evaluate the policy once every 10 batches of data. 
# Evaluation is rather simple: execute the policy without exploration # (take the expected value of the action distribution) for a given - # number of steps (1000, which is our env horizon). - # The ``rollout`` method of the env can take a policy as argument: + # number of steps (1000, which is our ``env`` horizon). + # The ``rollout`` method of the ``env`` can take a policy as argument: # it will then execute this policy at each step. with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): # execute a rollout with the trained policy @@ -717,7 +711,7 @@ # we could run several simulations in parallel to speed up data collection. # Check :class:`~torchrl.envs.ParallelEnv` for further information. # -# * From a logging perspective, one could add a :class:`~torchrl.record.VideoRecorder` transform to +# * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to # the environment after asking for rendering to get a visual rendering of the # inverted pendulum in action. Check :py:mod:`torchrl.record` to # know more. diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index a1c82d5c429..a2b2b12b562 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -6,6 +6,8 @@ **Author**: `Vincent Moens `_ +.. _RNN_tuto: + .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn @@ -78,15 +80,24 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore import torch import tqdm -from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictSequential, + TensorDictSequential as Seq, +) from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -104,10 +115,15 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv -from torchrl.modules import ConvNet, EGreedyWrapper, LSTMModule, MLP, QValueModule +from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule from torchrl.objectives import DQNLoss, SoftUpdate -device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu") +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ###################################################################### # Environment @@ -309,11 +325,15 @@ # DQN being a deterministic algorithm, exploration is a crucial part of it. # We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying # progressively to 0. -# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step` +# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step` # (see training loop below). 
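To make the annealing behaviour concrete, here is a small hypothetical sketch (not part of the patch): each call to ``step(frames=n)`` moves epsilon from ``eps_init`` towards ``eps_end``, and after roughly ``annealing_num_steps`` frames the minimum value is reached.

from torchrl.modules import EGreedyModule

# Toy numbers for illustration only (the tutorial uses 1_000_000 annealing frames).
sketch_explore = EGreedyModule(
    spec=env.action_spec, eps_init=0.2, eps_end=0.05, annealing_num_steps=1_000
)
for _ in range(10):
    sketch_explore.step(frames=100)  # e.g. called once per collected batch
# After 10 * 100 = 1_000 frames, epsilon has been annealed down to roughly eps_end.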
# -stoch_policy = EGreedyWrapper( - stoch_policy, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 +exploration_module = EGreedyModule( + annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 +) +stoch_policy = TensorDictSequential( + stoch_policy, + exploration_module, ) ###################################################################### @@ -419,7 +439,7 @@ pbar.set_description( f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}" ) - stoch_policy.step(data.numel()) + exploration_module.step(data.numel()) updater.step() with set_exploration_type(ExplorationType.MODE), torch.no_grad(): diff --git a/tutorials/sphinx-tutorials/getting-started-0.py b/tutorials/sphinx-tutorials/getting-started-0.py new file mode 100644 index 00000000000..752e116a5c3 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-0.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +""" + +Get started with Environments, TED and transforms +================================================= + +**Author**: `Vincent Moens `_ + +.. _gs_env_ted: + +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" + +################################ +# Welcome to the getting started tutorials! +# +# Below is the list of the topics we will be covering. +# +# - :ref:`Environments, TED and transforms `; +# - :ref:`TorchRL's modules `; +# - :ref:`Losses and optimization `; +# - :ref:`Data collection and storage `; +# - :ref:`TorchRL's logging API `. +# +# If you are in a hurry, you can jump straight away to the last tutorial, +# :ref:`Your own first training loop `, from where you can +# backtrack every other "Getting Started" tutorial if things are not clear or +# if you want to learn more about a specific topic! +# +# Environments in RL +# ------------------ +# +# The standard RL (Reinforcement Learning) training loop involves a model, +# also known as a policy, which is trained to accomplish a task within a +# specific environment. Often, this environment is a simulator that accepts +# actions as input and produces an observation along with some metadata as +# output. +# +# In this document, we will explore the environment API of TorchRL: we will +# learn how to create an environment, interact with it, and understand the +# data format it uses. +# +# Creating an environment +# ----------------------- +# +# In essence, TorchRL does not directly provide environments, but instead +# offers wrappers for other libraries that encapsulate the simulators. The +# :mod:`~torchrl.envs` module can be viewed as a provider for a generic +# environment API, as well as a central hub for simulation backends like +# `gym `_ (:class:`~torchrl.envs.GymEnv`), +# `Brax `_ (:class:`~torchrl.envs.BraxEnv`) +# or `DeepMind Control Suite `_ +# (:class:`~torchrl.envs.DMControlEnv`). +# +# Creating your environment is typically as straightforward as the underlying +# backend API allows. Here's an example using gym: + +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") + +################################ +# +# Running an environment +# ---------------------- +# +# Environments in TorchRL have two crucial methods: +# :meth:`~torchrl.envs.EnvBase.reset`, which initiates +# an episode, and :meth:`~torchrl.envs.EnvBase.step`, which executes an +# action selected by the actor. 
+# In TorchRL, environment methods read and write +# :class:`~tensordict.TensorDict` instances. +# Essentially, :class:`~tensordict.TensorDict` is a generic key-based data +# carrier for tensors. +# The benefit of using TensorDict over plain tensors is that it enables us to +# handle simple and complex data structures interchangeably. As our function +# signatures are very generic, it eliminates the challenge of accommodating +# different data formats. In simpler terms, after this brief tutorial, +# you will be capable of operating on both simple and highly complex +# environments, as their user-facing API is identical and simple! +# +# Let's put the environment into action and see what a tensordict instance +# looks like: + +reset = env.reset() +print(reset) + +################################ +# Now let's take a random action in the action space. First, sample the action: +reset_with_action = env.rand_action(reset) +print(reset_with_action) + +################################ +# This tensordict has the same structure as the one obtained from +# :meth:`~torchrl.envs.EnvBase` with an additional ``"action"`` entry. +# You can access the action easily, like you would do with a regular +# dictionary: +# + +print(reset_with_action["action"]) + +################################ +# We now need to pass this action tp the environment. +# We'll be passing the entire tensordict to the ``step`` method, since there +# might be more than one tensor to be read in more advanced cases like +# Multi-Agent RL or stateless environments: + +stepped_data = env.step(reset_with_action) +print(stepped_data) + +################################ +# Again, this new tensordict is identical to the previous one except for the +# fact that it has a ``"next"`` entry (itself a tensordict!) containing the +# observation, reward and done state resulting from +# our action. +# +# We call this format TED, for +# :ref:`TorchRL Episode Data format `. It is +# the ubiquitous way of representing data in the library, both dynamically like +# here, or statically with offline datasets. +# +# The last bit of information you need to run a rollout in the environment is +# how to bring that ``"next"`` entry at the root to perform the next step. +# TorchRL provides a dedicated :func:`~torchrl.envs.utils.step_mdp` function +# that does just that: it filters out the information you won't need and +# delivers a data structure corresponding to your observation after a step in +# the Markov Decision Process, or MDP. + +from torchrl.envs import step_mdp + +data = step_mdp(stepped_data) +print(data) + +################################ +# Environment rollouts +# -------------------- +# +# .. _gs_env_ted_rollout: +# +# Writing down those three steps (computing an action, making a step, +# moving in the MDP) can be a bit tedious and repetitive. Fortunately, +# TorchRL provides a nice :meth:`~torchrl.envs.EnvBase.rollout` function that +# allows you to run them in a closed loop at will: +# + +rollout = env.rollout(max_steps=10) +print(rollout) + +################################ +# This data looks pretty much like the ``stepped_data`` above with the +# exception of its batch-size, which now equates the number of steps we +# provided through the ``max_steps`` argument. 
The magic of tensordict +# doesn't end there: if you're interested in a single transition of this +# environment, you can index the tensordict like you would index a tensor: + +transition = rollout[3] +print(transition) + +################################ +# :class:`~tensordict.TensorDict` will automatically check if the index you +# provided is a key (in which case we index along the key-dimension) or a +# spatial index like here. +# +# Executed as such (without a policy), the ``rollout`` method may seem rather +# useless: it just runs random actions. If a policy is available, it can +# be passed to the method and used to collect data. +# +# Nevertheless, it can useful to run a naive, policyless rollout at first to +# check what is to be expected from an environment at a glance. +# +# To appreciate the versatility of TorchRL's API, consider the fact that the +# rollout method is universally applicable. It functions across **all** use +# cases, whether you're working with a single environment like this one, +# multiple copies across various processes, a multi-agent environment, or even +# a stateless version of it! +# +# +# Transforming an environment +# --------------------------- +# +# Most of the time, you'll want to modify the output of the environment to +# better suit your requirements. For example, you might want to monitor the +# number of steps executed since the last reset, resize images, or stack +# consecutive observations together. +# +# In this section, we'll examine a simple transform, the +# :class:`~torchrl.envs.transforms.StepCounter` transform. +# The complete list of transforms can be found +# :ref:`here `. +# +# The transform is integrated with the environment through a +# :class:`~torchrl.envs.transforms.TransformedEnv`: +# + +from torchrl.envs import StepCounter, TransformedEnv + +transformed_env = TransformedEnv(env, StepCounter(max_steps=10)) +rollout = transformed_env.rollout(max_steps=100) +print(rollout) + +################################ +# As you can see, our environment now has one more entry, ``"step_count"`` that +# tracks the number of steps since the last reset. +# Given that we passed the optional +# argument ``max_steps=10`` to the transform constructor, we also truncated the +# trajectory after 10 steps (not completing a full rollout of 100 steps like +# we asked with the ``rollout`` call). We can see that the trajectory was +# truncated by looking at the truncated entry: + +print(rollout["next", "truncated"]) + +################################ +# +# This is all for this short introduction to TorchRL's environment API! +# +# Next steps +# ---------- +# +# To explore further what TorchRL's environments can do, go and check: +# +# - The :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` method that packs +# together :meth:`~torchrl.envs.EnvBase.step`, +# :func:`~torchrl.envs.utils.step_mdp` and +# :meth:`~torchrl.envs.EnvBase.reset`. +# - Some environments like :class:`~torchrl.envs.GymEnv` support rendering +# through the ``from_pixels`` argument. Check the class docstrings to know +# more! +# - The batched environments, in particular :class:`~torchrl.envs.ParallelEnv` +# which allows you to run multiple copies of one same (or different!) +# environments on multiple processes. +# - Design your own environment with the +# :ref:`Pendulum tutorial ` and learn about specs and +# stateless environments. 
+# - See the more in-depth tutorial about environments +# :ref:`in the dedicated tutorial `; +# - Check the +# :ref:`multi-agent environment API ` +# if you're interested in MARL; +# - TorchRL has many tools to interact with the Gym API such as +# a way to register TorchRL envs in the Gym register through +# :meth:`~torchrl.envs.EnvBase.register_gym`, an API to read +# the info dictionaries through +# :meth:`~torchrl.envs.EnvBase.set_info_dict_reader` or a way +# to control the gym backend thanks to +# :func:`~torchrl.envs.set_gym_backend`. +# diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py new file mode 100644 index 00000000000..75ccf7cf8e7 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +""" +Get started with TorchRL's modules +================================== + +**Author**: `Vincent Moens `_ + +.. _gs_modules: + +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" +################################### +# Reinforcement Learning is designed to create policies that can effectively +# tackle specific tasks. Policies can take various forms, from a differentiable +# map transitioning from the observation space to the action space, to a more +# ad-hoc method like an argmax over a list of values computed for each possible +# action. Policies can be deterministic or stochastic, and may incorporate +# complex elements such as Recurrent Neural Networks (RNNs) or transformers. +# +# Accommodating all these scenarios can be quite intricate. In this succinct +# tutorial, we will delve into the core functionality of TorchRL in terms of +# policy construction. We will primarily focus on stochastic and Q-Value +# policies in two common scenarios: using a Multi-Layer Perceptron (MLP) or +# a Convolutional Neural Network (CNN) as backbones. +# +# TensorDictModules +# ----------------- +# +# Similar to how environments interact with instances of +# :class:`~tensordict.TensorDict`, the modules used to represent policies and +# value functions also do the same. The core idea is simple: encapsulate a +# standard :class:`~torch.nn.Module` (or any other function) within a class +# that knows which entries need to be read and passed to the module, and then +# records the results with the assigned entries. To illustrate this, we will +# use the simplest policy possible: a deterministic map from the observation +# space to the action space. For maximum generality, we will use a +# :class:`~torch.nn.LazyLinear` module with the Pendulum environment we +# instantiated in the previous tutorial. +# + +import torch + +from tensordict.nn import TensorDictModule +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") +module = torch.nn.LazyLinear(out_features=env.action_spec.shape[-1]) +policy = TensorDictModule( + module, + in_keys=["observation"], + out_keys=["action"], +) + +################################### +# This is all that's required to execute our policy! The use of a lazy module +# allows us to bypass the need to fetch the shape of the observation space, as +# the module will automatically determine it. 
This policy is now ready to be +# run in the environment: + +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# Specialized wrappers +# -------------------- +# +# To simplify the incorporation of :class:`~torch.nn.Module`s into your +# codebase, TorchRL offers a range of specialized wrappers designed to be +# used as actors, including :class:`~torchrl.modules.tensordict_module.Actor`, +# # :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`, +# # :class:`~torchrl.modules.tensordict_module.ActorValueOperator` or +# # :class:`~torchrl.modules.tensordict_module.ActorCriticOperator`. +# For example, :class:`~torchrl.modules.tensordict_module.Actor` provides +# default values for the ``in_keys`` and ``out_keys``, making integration +# with many common environments straightforward: +# + +from torchrl.modules import Actor + +policy = Actor(module) +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# The list of available specialized TensorDictModules is available in the +# :ref:`API reference `. +# +# Networks +# -------- +# +# TorchRL also provides regular modules that can be used without recurring to +# tensordict features. The two most common networks you will encounter are +# the :class:`~torchrl.modules.MLP` and the :class:`~torchrl.modules.ConvNet` +# (CNN) modules. We can substitute our policy module with one of these: +# + +from torchrl.modules import MLP + +module = MLP( + out_features=env.action_spec.shape[-1], + num_cells=[32, 64], + activation_class=torch.nn.Tanh, +) +policy = Actor(module) +rollout = env.rollout(max_steps=10, policy=policy) + +################################### +# TorchRL also supports RNN-based policies. Since this is a more technical +# topic, it is treated in :ref:`a separate tutorial `. +# +# Probabilistic policies +# ---------------------- +# +# Policy-optimization algorithms like +# `PPO `_ require the policy to be +# stochastic: unlike in the examples above, the module now encodes a map from +# the observation space to a parameter space encoding a distribution over the +# possible actions. TorchRL facilitates the design of such modules by grouping +# under a single class the various operations such as building the distribution +# from the parameters, sampling from that distribution and retrieving the +# log-probability. Here, we'll be building an actor that relies on a regular +# normal distribution using three components: +# +# - An :class:`~torchrl.modules.MLP` backbone reading observations of size +# ``[3]`` and outputting a single tensor of size ``[2]``; +# - A :class:`~tensordict.nn.distributions.NormalParamExtractor` module that +# will split this output on two chunks, a mean and a standard deviation of +# size ``[1]``; +# - A :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` that will +# read those parameters as ``in_keys``, create a distribution with them and +# populate our tensordict with samples and log-probabilities. 
+# + +from tensordict.nn.distributions import NormalParamExtractor +from torch.distributions import Normal +from torchrl.modules import ProbabilisticActor + +backbone = MLP(in_features=3, out_features=2) +extractor = NormalParamExtractor() +module = torch.nn.Sequential(backbone, extractor) +td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) +policy = ProbabilisticActor( + td_module, + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=Normal, + return_log_prob=True, +) + +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# There are a few things to note about this rollout: +# +# - Since we asked for it during the construction of the actor, the +# log-probability of the actions given the distribution at that time is +# also written. This is necessary for algorithms like PPO. +# - The parameters of the distribution are returned within the output +# tensordict too under the ``"loc"`` and ``"scale"`` entries. +# +# You can control the sampling of the action to use the expected value or +# other properties of the distribution instead of using random samples if +# your application requires it. This can be controlled via the +# :func:`~torchrl.envs.utils.set_exploration_type` function: + +from torchrl.envs.utils import ExplorationType, set_exploration_type + +with set_exploration_type(ExplorationType.MEAN): + # takes the mean as action + rollout = env.rollout(max_steps=10, policy=policy) +with set_exploration_type(ExplorationType.RANDOM): + # Samples actions according to the dist + rollout = env.rollout(max_steps=10, policy=policy) + +################################### +# Check the ``default_interaction_type`` keyword argument in +# the docstrings to know more. +# +# Exploration +# ----------- +# +# Stochastic policies like this somewhat naturally trade off exploration and +# exploitation, but deterministic policies won't. Fortunately, TorchRL can +# also palliate to this with its exploration modules. +# We will take the example of the :class:`~torchrl.modules.EGreedyModule` +# exploration module (check also +# :class:`~torchrl.modules.AdditiveGaussianWrapper` and +# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`). +# To see this module in action, let's revert to a deterministic policy: + +from tensordict.nn import TensorDictSequential +from torchrl.modules import EGreedyModule + +policy = Actor(MLP(3, 1, num_cells=[32, 64])) + +################################### +# Our :math:`\epsilon`-greedy exploration module will usually be customized +# with a number of annealing frames and an initial value for the +# :math:`\epsilon` parameter. A value of :math:`\epsilon = 1` means that every +# action taken is random, while :math:`\epsilon=0` means that there is no +# exploration at all. To anneal (i.e., decrease) the exploration factor, a call +# to :meth:`~torchrl.modules.EGreedyModule.step` is required (see the last +# :ref:`tutorial ` for an example). +# +exploration_module = EGreedyModule( + spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5 +) + +################################### +# To build our explorative policy, we only had to concatenate the +# deterministic policy module with the exploration module within a +# :class:`~tensordict.nn.TensorDictSequential` module (which is the analogous +# to :class:`~torch.nn.Sequential` in the tensordict realm). 
+ +exploration_policy = TensorDictSequential(policy, exploration_module) + +with set_exploration_type(ExplorationType.MEAN): + # Turns off exploration + rollout = env.rollout(max_steps=10, policy=exploration_policy) +with set_exploration_type(ExplorationType.RANDOM): + # Turns on exploration + rollout = env.rollout(max_steps=10, policy=exploration_policy) + +################################### +# Because it must be able to sample random actions in the action space, the +# :class:`~torchrl.modules.EGreedyModule` must be equipped with the action +# spec from the environment to know what strategy to use to +# sample actions randomly. +# +# Q-Value actors +# -------------- +# +# In some settings, the policy isn't a standalone module but is constructed on +# top of another module. This is the case with **Q-Value actors**. In short, these +# actors require an estimate of the action value (most of the time discrete) +# and will greedily pick the action with the highest value. In some +# settings (finite discrete action space and finite discrete state space), +# one can just store a 2D table of state-action pairs and pick the +# action with the highest value. The innovation brought by +# `DQN `_ was to scale this up to continuous +# state spaces by utilizing a neural network to encode the ``Q(s, a)`` +# value map. Let's consider another environment with a discrete action space +# for a clearer understanding: + +env = GymEnv("CartPole-v1") +print(env.action_spec) + +################################### +# We build a value network that produces one value per action when it reads a +# state from the environment: + +num_actions = 2 +value_net = TensorDictModule( + MLP(out_features=num_actions, num_cells=[32, 32]), + in_keys=["observation"], + out_keys=["action_value"], +) + +################################### +# We can easily build our Q-Value actor by adding a +# :class:`~torchrl.modules.tensordict_module.QValueModule` after our value +# network: + +from torchrl.modules import QValueModule + +policy = TensorDictSequential( + value_net, # writes action values in our tensordict + QValueModule( + action_space=env.action_spec + ), # Reads the "action_value" entry by default +) + +################################### +# Let's check it out! We run the policy for a couple of steps and look at the +# output. We should find ``"action_value"`` as well as +# ``"chosen_action_value"`` entries in the rollout that we obtain: +# + +rollout = env.rollout(max_steps=3, policy=policy) +print(rollout) + +################################### +# Because it relies on the ``argmax`` operator, this policy is deterministic. +# During data collection, we will need to explore the environment. For that, +# we are using the :class:`~torchrl.modules.tensordict_module.EGreedyModule` +# once again: + +policy_explore = TensorDictSequential(policy, EGreedyModule(env.action_spec)) + +with set_exploration_type(ExplorationType.RANDOM): + rollout_explore = env.rollout(max_steps=3, policy=policy_explore) + +################################### +# This is it for our short tutorial on building a policy with TorchRL! +# +# There are many more things you can do with the library. A good place to start +# is to look at the :ref:`API reference for modules `.
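+
+###################################
+# As a purely illustrative sketch (the loop and frame counts below are
+# arbitrary and not part of the DQN example that follows in this series),
+# the exploration factor of an :class:`~torchrl.modules.EGreedyModule` can be
+# annealed during collection by calling its ``step`` method with the number
+# of frames gathered since the last call. We reuse the CartPole ``policy``
+# defined above with a fresh e-greedy module:
+
+egreedy = EGreedyModule(env.action_spec, annealing_num_steps=1000, eps_init=0.5)
+annealed_policy = TensorDictSequential(policy, egreedy)
+with set_exploration_type(ExplorationType.RANDOM):
+    for _ in range(5):
+        rollout_explore = env.rollout(max_steps=3, policy=annealed_policy)
+        # each call moves eps a bit closer to its final value
+        egreedy.step(rollout_explore.numel())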
+# +# Next steps: +# +# - Check how to use compound distributions with +# :class:`~tensordict.nn.distributions.CompositeDistribution` when the +# action is composite (e.g., a discrete and a continuous action are +# required by the env); +# - Have a look at how you can use an RNN within the policy (a +# :ref:`tutorial `); +# - Compare this to the usage of transformers with the Decision Transformers +# examples (see the ``example`` directory on GitHub). +# diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py new file mode 100644 index 00000000000..0a16071bed2 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-2.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +""" +Getting started with model optimization +======================================= + +**Author**: `Vincent Moens `_ + +.. _gs_optim: + +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" + +################################### +# In TorchRL, we try to treat optimization as is customary in PyTorch, +# using dedicated loss modules which are designed with the sole purpose of +# optimizing the model. This approach efficiently decouples the execution of +# the policy from its training and allows us to design training loops that are +# similar to what can be found in traditional supervised learning examples. +# +# The typical training loop therefore looks like this: +# +# >>> for i in range(n_collections): +# ... data = get_next_batch(env, policy) +# ... for j in range(n_optim): +# ... loss = loss_fn(data) +# ... loss.backward() +# ... optim.step() +# +# This tutorial gives you a brief overview of the loss modules. Because the API +# is typically straightforward for basic usage, the tutorial is kept short. +# +# RL objective functions +# ---------------------- +# +# In RL, innovation typically involves the exploration of novel methods +# for optimizing a policy (i.e., new algorithms), rather than focusing +# on new architectures, as seen in other domains. Within TorchRL, +# these algorithms are encapsulated within loss modules. A loss +# module orchestrates the various components of your algorithm and +# yields a set of loss values that can be backpropagated +# through to train the corresponding components. +# +# In this tutorial, we will take a popular +# off-policy algorithm as an example, +# `DDPG `_. +# +# To build a loss module, the only thing one needs is a set of networks +# defined as :class:`~tensordict.nn.TensorDictModule`s. Most of the time, one +# of these modules will be the policy. Other auxiliary networks such as +# Q-Value networks or critics of some kind may be needed as well. Let's see +# what this looks like in practice: DDPG requires a deterministic +# map from the observation space to the action space as well as a value +# network that predicts the value of a state-action pair. The DDPG loss will +# attempt to find the policy parameters that output actions that maximize the +# value for a given state. +# +# To build the loss, we need both the actor and value networks.
+# If they are built according to DDPG's expectations, it is all +# we need to get a trainable loss module: + +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") + +from torchrl.modules import Actor, MLP, ValueOperator +from torchrl.objectives import DDPGLoss + +n_obs = env.observation_spec["observation"].shape[-1] +n_act = env.action_spec.shape[-1] +actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32])) +value_net = ValueOperator( + MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]), + in_keys=["observation", "action"], +) + +ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net) + +################################### +# And that is it! Our loss module can now be run with data coming from the +# environment (we omit exploration, storage and other features to focus on +# the loss functionality): +# + +rollout = env.rollout(max_steps=100, policy=actor) +loss_vals = ddpg_loss(rollout) +print(loss_vals) + +################################### +# LossModule's output +# ------------------- +# +# As you can see, the value we received from the loss isn't a single scalar +# but a dictionary containing multiple losses. +# +# The reason is simple: because more than one network may be trained at a time, +# and because some users may wish to optimize each module in distinct steps, +# TorchRL's objectives return dictionaries containing +# the various loss components. +# +# This format also allows us to pass metadata along with the loss values. In +# general, we make sure that only the loss values are differentiable such that +# you can simply sum over the values of the dictionary to obtain the total +# loss. If you want to make sure you're fully in control of what is happening, +# you can sum over only the entries whose keys start with the ``"loss_"`` prefix: +# + +total_loss = 0 +for key, val in loss_vals.items(): + if key.startswith("loss_"): + total_loss += val + +################################### +# Training a LossModule +# --------------------- +# +# Given all this, training the modules is not so different from what would be +# done in any other training loop. Because it wraps the modules, +# the easiest way to get the list of trainable parameters is to query +# the :meth:`~torchrl.objectives.LossModule.parameters` method. +# +# We'll need an optimizer (or one optimizer +# per module if that is your choice). +# + +from torch.optim import Adam + +optim = Adam(ddpg_loss.parameters()) +total_loss.backward() + +################################### +# The following items will typically be +# found in your training loop: + +optim.step() +optim.zero_grad() + +################################### +# Further considerations: Target parameters +# ----------------------------------------- +# +# Another important aspect to consider is the presence of target parameters +# in off-policy algorithms like DDPG. Target parameters typically represent +# a delayed or smoothed version of the parameters over time, and they play +# a crucial role in value estimation during policy training. Utilizing target +# parameters for policy training often proves to be significantly more +# efficient compared to using the current configuration of value network +# parameters. Generally, managing target parameters is handled by the loss +# module, relieving users of direct concern. However, it remains the user's +# responsibility to update these values as necessary based on specific +# requirements.
TorchRL offers a couple of updaters, namely +# :class:`~torchrl.objectives.HardUpdate` and +# :class:`~torchrl.objectives.SoftUpdate`, +# which can be easily instantiated without requiring in-depth +# knowledge of the underlying mechanisms of the loss module. +# +from torchrl.objectives import SoftUpdate + +updater = SoftUpdate(ddpg_loss, eps=0.99) + +################################### +# In your training loop, you will need to update the target parameters at each +# optimization step or each collection step: + +updater.step() + +################################### +# This is all you need to know about loss modules to get started! +# +# To further explore the topic, have a look at: +# +# - The :ref:`loss module reference page `; +# - The :ref:`Coding a DDPG loss tutorial `; +# - Losses in action in :ref:`PPO ` or :ref:`DQN `. +# diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py new file mode 100644 index 00000000000..829b22cf061 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +""" +Get started with data collection and storage +============================================ + +**Author**: `Vincent Moens `_ + +.. _gs_storage: + +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" + +################################# +# +# There is no learning without data. In supervised learning, users are +# accustomed to using :class:`~torch.utils.data.DataLoader` and the like +# to integrate data in their training loop. +# Dataloaders are iterable objects that provide you with the data that you will +# be using to train your model. +# +# TorchRL approaches the problem of dataloading in a similar manner, although +# its approach is rather unique in the ecosystem of RL libraries. TorchRL's +# dataloaders are referred to as ``DataCollectors``. Most of the time, +# data collection does not stop at the collection of raw data, +# as the data needs to be stored temporarily in a buffer +# (or equivalent structure for on-policy algorithms) before being consumed +# by the :ref:`loss module `. This tutorial will explore +# these two classes. +# +# Data collectors +# --------------- +# +# .. _gs_storage_collector: +# +# +# The primary data collector discussed here, and the focus of this tutorial, +# is the :class:`~torchrl.collectors.SyncDataCollector`. +# At a fundamental level, a collector is a straightforward +# class responsible for executing your policy within the environment, +# resetting the environment when necessary, and providing batches of a +# predefined size. Unlike the :meth:`~torchrl.envs.EnvBase.rollout` method +# demonstrated in :ref:`the env tutorial `, collectors do not +# reset between consecutive batches of data. Consequently, two successive +# batches of data may contain elements from the same trajectory. +# +# The basic arguments you need to pass to your collector are the size of the +# batches you want to collect (``frames_per_batch``), the length (possibly +# infinite) of the iterator, the policy and the environment. For simplicity, +# we will use a dummy, random policy in this example.
+ +import torch + +torch.manual_seed(0) + +from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.envs import GymEnv + +env = GymEnv("CartPole-v1") +env.set_seed(0) + +policy = RandomPolicy(env.action_spec) +collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1) + +################################# +# We now expect that our collector will deliver batches of size ``200`` no +# matter what happens during collection. In other words, we may have multiple +# trajectories in this batch! The ``total_frames`` argument indicates how many +# frames the collector should deliver in total. A value of ``-1`` will produce +# a never-ending collector. +# +# Let's iterate over the collector to get a sense +# of what this data looks like: + +for data in collector: + print(data) + break + +################################# +# As you can see, our data is augmented with some collector-specific metadata +# grouped in a ``"collector"`` sub-tensordict that we did not see during +# :ref:`environment rollouts `. This is useful to keep track of +# the trajectory ids. In the following list, each item marks the trajectory +# number the corresponding transition belongs to: + +print(data["collector", "traj_ids"]) + +################################# +# Data collectors are very useful when it comes to coding state-of-the-art +# algorithms, as performance is usually measured by the capability of a +# specific technique to solve a problem in a given number of interactions with +# the environment (the ``total_frames`` argument in the collector). +# For this reason, most training loops in our examples look like this: +# +# >>> for data in collector: +# ... # your algorithm here +# +# +# Replay Buffers +# -------------- +# +# .. _gs_storage_rb: +# +# Now that we have explored how to collect data, we would like to know how to +# store it. In RL, the typical setting is that the data is collected, stored +# temporarily and cleared after a little while according to some heuristic: +# first-in first-out or another. A typical pseudo-code would look like this: +# +# >>> for data in collector: +# ... storage.store(data) +# ... for i in range(n_optim): +# ... sample = storage.sample() +# ... loss_val = loss_fn(sample) +# ... loss_val.backward() +# ... optim.step() # etc +# +# The parent class that stores the data in TorchRL +# is referred to as :class:`~torchrl.data.ReplayBuffer`. TorchRL's replay +# buffers are composable: you can edit the storage type, their sampling +# technique, the writing heuristic or the transforms applied to them. We will +# leave the fancy stuff for a dedicated in-depth tutorial. The generic replay +# buffer only needs to know what storage it has to use. In general, we +# recommend a :class:`~torchrl.data.TensorStorage` subclass, which will work +# fine in most cases. We'll be using +# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# in this tutorial, which enjoys two nice properties: first, being "lazy", +# you don't need to explicitly tell it what your data looks like in advance. +# Second, it uses :class:`~tensordict.MemoryMappedTensor` as a backend to save +# your data on disk in an efficient way. The only thing you need to know is +# how big you want your buffer to be.
+ +from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer + +buffer = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) + +################################# +# Populating the buffer can be done via the +# :meth:`~torchrl.data.ReplayBuffer.add` (single element) or +# :meth:`~torchrl.data.ReplayBuffer.extend` (multiple elements) methods. Using +# the data we just collected, we initialize and populate the buffer in one go: + +indices = buffer.extend(data) + +################################# +# We can check that the buffer now has the same number of elements as what +# we got from the collector: + +assert len(buffer) == collector.frames_per_batch + +################################# +# The only thing left to know is how to gather data from the buffer. +# Naturally, this relies on the :meth:`~torchrl.data.ReplayBuffer.sample` +# method. Because we did not specify that sampling had to be done without +# repetitions, it is not guaranteed that the samples gathered from our buffer +# will be unique: + +sample = buffer.sample(batch_size=30) +print(sample) + +################################# +# Again, our sample looks exactly the same as the data we gathered from the +# collector! +# +# Next steps +# ---------- +# +# - You can have a look at other multiprocessed +# collectors such as :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` or +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`. +# - TorchRL also offers distributed collectors if you have multiple nodes to +# use for inference. Check them out in the +# :ref:`API reference `. +# - Check the dedicated :ref:`Replay Buffer tutorial ` to know +# more about the options you have when building a buffer, or the +# :ref:`API reference ` which covers all the features in +# detail. Replay buffers have countless features such as multithreaded +# sampling, prioritized experience replay, and many more... +# - For simplicity, we left out the fact that replay buffers can be iterated +# over. Try it out for yourself: build a buffer and indicate its +# batch-size in the constructor, then try to iterate over it. This is +# equivalent to calling ``rb.sample()`` within a loop! +# diff --git a/tutorials/sphinx-tutorials/getting-started-4.py b/tutorials/sphinx-tutorials/getting-started-4.py new file mode 100644 index 00000000000..a7c6462375d --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-4.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +""" +Get started with logging +======================== + +**Author**: `Vincent Moens `_ + +.. _gs_logging: + +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" + +##################################### +# The final chapter of this series before we orchestrate everything in a +# training script is to learn about logging. +# +# Loggers +# ------- +# +# Logging is crucial for reporting your results to the outside world and for +# you to check that your algorithm is learning properly. TorchRL has several +# loggers that interface with custom backends such as +# wandb (:class:`~torchrl.record.loggers.wandb.WandbLogger`), +# tensorboard (:class:`~torchrl.record.loggers.tensorboard.TensorBoardLogger`) or a lightweight and +# portable CSV logger (:class:`~torchrl.record.loggers.csv.CSVLogger`) that you can use +# pretty much everywhere.
+# +# Loggers are located in the ``torchrl.record`` module and the various classes +# can be found in the :ref:`API reference `. +# +# We tried to keep the logger APIs as similar as we could, given the +# differences in the underlying backends. While execution of the loggers will +# mostly be interchangeable, their instantiation can differ. +# +# Usually, building a logger requires +# at least an experiment name and possibly a logging directory and other +# hyperparameters. +# + +from torchrl.record import CSVLogger + +logger = CSVLogger(exp_name="my_exp") + +##################################### +# Once the logger is instantiated, the only thing left to do is call the +# logging methods! For example, :meth:`~torchrl.record.CSVLogger.log_scalar` +# is used in several places across the training examples to log values such as +# reward, loss value or time elapsed for executing a piece of code. + +logger.log_scalar("my_scalar", 0.4) + +##################################### +# Recording videos +# ---------------- +# +# Finally, it can come in handy to record videos of a simulator. Some +# environments (e.g., Atari games) are already rendered as images whereas +# others require you to create them as such. Fortunately, in most common cases, +# rendering and recording videos isn't too difficult. +# +# Let's first see how we can create a Gym environment that outputs images +# alongside its observations. :class:`~torchrl.envs.GymEnv` accepts two keyword +# arguments for this purpose: ``from_pixels=True`` will make the env ``step`` function +# write a ``"pixels"`` entry containing the images corresponding to your +# observations, and ``pixels_only=False`` will indicate that you want the +# observations to be returned as well. +# + +from torchrl.envs import GymEnv + +env = GymEnv("CartPole-v1", from_pixels=True, pixels_only=False) + +print(env.rollout(max_steps=3)) + +from torchrl.envs import TransformedEnv + +##################################### +# We have now built an environment that renders images with its observations. +# To record videos, we will need to combine that environment with a recorder +# and the logger (the logger providing the backend to save the video). +# This will happen within a transformed environment, like the one we saw in +# the :ref:`first tutorial `. + +from torchrl.record import VideoRecorder + +recorder = VideoRecorder(logger, tag="my_video") +record_env = TransformedEnv(env, recorder) + +##################################### +# When running this environment, all the ``"pixels"`` entries will be saved in +# a local buffer and dumped to a video on demand via the recorder's ``dump`` +# method (it is important that you call it when appropriate): + +rollout = record_env.rollout(max_steps=3) +# Uncomment this line to save the video on disk: +# recorder.dump() + +##################################### +# In this specific case, the video format can be chosen when instantiating +# the CSVLogger. +# +# This is all we wanted to cover in the getting started tutorial. +# You should now be ready to code your +# :ref:`first training loop with TorchRL `! +# diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py new file mode 100644 index 00000000000..039e15fa035 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +""" +Get started with your own first training loop +============================================= + +**Author**: `Vincent Moens `_ + +.. _gs_first_training: + +.. 
note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" + +################################# +# Time to wrap up everything we've learned so far in this Getting Started +# series! +# +# In this tutorial, we will be writing the most basic training loop there is +# using only components we have presented in the previous lessons. +# +# We'll be using DQN with a CartPole environment as a prototypical example. +# +# We will deliberately keep verbosity to a minimum, only linking +# each section to the related tutorial. +# +# Building the environment +# ------------------------ +# +# We'll be using a gym environment with a :class:`~torchrl.envs.transforms.StepCounter` +# transform. If you need a refresher, these features are presented in +# :ref:`the environment tutorial `. +# + +import torch + +torch.manual_seed(0) + +import time + +from torchrl.envs import GymEnv, StepCounter, TransformedEnv + +env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) +env.set_seed(0) + +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + +################################# +# Designing a policy
# ------------------ +# +# The next step is to build our policy. +# We'll be making a regular, deterministic +# version of the actor to be used within the +# :ref:`loss module ` and during +# :ref:`evaluation `. +# Next, we will augment it with an exploration module +# for :ref:`inference `. + +from torchrl.modules import EGreedyModule, MLP, QValueModule + +value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64]) +value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"]) +policy = Seq(value_net, QValueModule(env.action_spec)) +exploration_module = EGreedyModule( + env.action_spec, annealing_num_steps=100_000, eps_init=0.5 +) +policy_explore = Seq(policy, exploration_module) + + +################################# +# Data Collector and replay buffer +# -------------------------------- +# +# Here comes the data part: we need a +# :ref:`data collector ` to easily get batches of data +# and a :ref:`replay buffer ` to store that data for training. +# + +from torchrl.collectors import SyncDataCollector +from torchrl.data import LazyTensorStorage, ReplayBuffer + +init_rand_steps = 5000 +frames_per_batch = 100 +optim_steps = 10 +collector = SyncDataCollector( + env, + policy, + frames_per_batch=frames_per_batch, + total_frames=-1, + init_random_frames=init_rand_steps, +) +rb = ReplayBuffer(storage=LazyTensorStorage(100_000)) + +from torch.optim import Adam + +################################# +# Loss module and optimizer +# ------------------------- +# +# We build our loss as indicated in the :ref:`dedicated tutorial `, with +# its optimizer and target parameter updater: + +from torchrl.objectives import DQNLoss, SoftUpdate + +loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True) +optim = Adam(loss.parameters(), lr=0.02) +updater = SoftUpdate(loss, eps=0.99) + +################################# +# Logger +# ------ +# +# We'll be using a CSV logger to log our results, and save rendered videos.
+# + +from torchrl._utils import logger as torchrl_logger +from torchrl.record import CSVLogger, VideoRecorder + +path = "./training_loop" +logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4") +video_recorder = VideoRecorder(logger, tag="video") +record_env = TransformedEnv( + GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder +) + +################################# +# Training loop +# ------------- +# +# Instead of fixing a specific number of iterations to run, we will keep on +# training the network until it reaches a certain performance (arbitrarily +# defined as 200 steps in the environment -- with CartPole, success is defined +# as having longer trajectories). +# + +total_count = 0 +total_episodes = 0 +t0 = time.time() +for i, data in enumerate(collector): + # Write data in replay buffer + rb.extend(data) + max_length = rb[:]["next", "step_count"].max() + if len(rb) > init_rand_steps: + # Optim loop (we do several optim steps + # per batch collected for efficiency) + for _ in range(optim_steps): + sample = rb.sample(128) + loss_vals = loss(sample) + loss_vals["loss"].backward() + optim.step() + optim.zero_grad() + # Update exploration factor + exploration_module.step(data.numel()) + # Update target params + updater.step() + if i % 10 == 0: + torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}") + total_count += data.numel() + total_episodes += data["next", "done"].sum() + if max_length > 200: + break + +t1 = time.time() + +torchrl_logger.info( + f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s." +) + +################################# +# Rendering +# --------- +# +# Finally, we run the environment for as many steps as we can and save the +# video locally (notice that we are not exploring). + +record_env.rollout(max_steps=1000, policy=policy) +video_recorder.dump() + +################################# +# +# This is what your rendered CartPole video will look like after a full +# training loop: +# +# .. figure:: /_static/img/cartpole.gif +# +# This concludes our series of "Getting started with TorchRL" tutorials! +# Feel free to share feedback about it on GitHub.
+# diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index a12c2b05ff8..68cb995a1a3 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -20,9 +20,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index 90fd82dab3c..7451d6b39e7 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -122,10 +122,12 @@ # Torch import torch -# Tensordict modules from tensordict.nn import TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +# Tensordict modules +from torch import multiprocessing + # Data collection from torchrl.collectors import SyncDataCollector from torchrl.data.replay_buffers import ReplayBuffer @@ -161,7 +163,12 @@ # # Devices -device = "cpu" if not torch.has_cuda else "cuda:0" # The divice where learning is run +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) vmas_device = device # The device where the simulator is run (VMAS can run on GPU) # Sampling diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 12c8bdc3193..8e7817978e4 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -6,6 +6,8 @@ **Author**: `Vincent Moens `_ +.. _pendulum_tuto: + Creating an environment (a simulator or an interface to a physical control system) is an integrative part of reinforcement learning and control engineering. 
@@ -84,9 +86,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index e8abf33cef8..03265c50d2b 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -24,9 +24,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -37,7 +42,12 @@ from torchrl.envs.libs.gym import GymEnv from torchrl.modules import Actor -device = "cuda:0" if torch.cuda.device_count() else "cpu" +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ############################################################################## # Let us first create an environment. For the sake of simplicity, we will be using diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 6106e3cf65a..2c5cd95e780 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -5,6 +5,8 @@ **Author**: `Vincent Moens `_ +.. _rb_tuto: + """ ###################################################################### # Replay buffers are a central piece of any RL or control algorithm. @@ -30,17 +32,24 @@ # # # In this tutorial, you will learn: -# - How to build a Replay Buffer (RB) and use it with any datatype; -# - How to use RBs with TensorDict; -# - How to sample from or iterate over a replay buffer, and how to define the sampling strategy; -# - How to use prioritized replay buffers; -# - How to transform data coming in and out from the buffer; -# - How to store trajectories in the buffer. +# +# - How to build a :ref:`Replay Buffer (RB) ` and use it with +# any datatype; +# - How to customize the :ref:`buffer's storage `; +# - How to use :ref:`RBs with TensorDict `; +# - How to :ref:`sample from or iterate over a replay buffer `, +# and how to define the sampling strategy; +# - How to use :ref:`prioritized replay buffers `; +# - How to :ref:`transform data ` coming in and out from +# the buffer; +# - How to store :ref:`trajectories ` in the buffer. # # # Basics: building a vanilla replay buffer # ---------------------------------------- # +# .. _tuto_rb_vanilla: +# # TorchRL's replay buffers are designed to prioritize modularity, # composability, efficiency, and simplicity. 
For instance, creating a basic # replay buffer is a straightforward process, as shown in the following @@ -57,9 +66,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -72,7 +86,7 @@ ###################################################################### # By default, this replay buffer will have a size of 1000. Let's check this -# by populating our buffer using the :meth:`torchrl.data.ReplayBuffer.extend` +# by populating our buffer using the :meth:`~torchrl.data.ReplayBuffer.extend` # method: # @@ -82,24 +96,24 @@ print("length after adding elements:", len(buffer)) -import torch -from tensordict import TensorDict - ###################################################################### -# We have used the :meth:`torchrl.data.ReplayBuffer.extend` method which is +# We have used the :meth:`~torchrl.data.ReplayBuffer.extend` method which is # designed to add multiple items all at once. If the object that is passed # to ``extend`` has more than one dimension, its first dimension is # considered to be the one to be split in separate elements in the buffer. +# # This essentially means that when adding multidimensional tensors or # tensordicts to the buffer, the buffer will only look at the first dimension # when counting the elements it holds in memory. # If the object passed it not iterable, an exception will be thrown. # -# To add items one at a time, the :meth:`torchrl.data.ReplayBuffer.add` method +# To add items one at a time, the :meth:`~torchrl.data.ReplayBuffer.add` method # should be used instead. # # Customizing the storage -# ~~~~~~~~~~~~~~~~~~~~~~~ +# ----------------------- +# +# .. _tuto_rb_storage: # # We see that the buffer has been capped to the first 1000 elements that we # passed to it. @@ -107,25 +121,27 @@ # # TorchRL proposes three types of storages: # -# - The :class:`torchrl.dataListStorage` stores elements independently in a +# - The :class:`~torchrl.data.ListStorage` stores elements independently in a # list. It supports any data type, but this flexibility comes at the cost # of efficiency; -# - The :class:`torchrl.dataLazyTensorStorage` stores tensors or -# :class:`tensordidct.TensorDict` (or :class:`torchrl.data.tensorclass`) +# - The :class:`~torchrl.data.LazyTensorStorage` stores tensor data +# structures contiguously. +# It works naturally with :class:`~tensordict.TensorDict` +# (or :class:`~torchrl.data.tensorclass`) +# objects. The storage is contiguous on a per-tensor basis, meaning that # sampling will be more efficient than when using a list, but the # implicit restriction is that any data passed to it must have the same -# basic properties as the -# first batch of data that was used to instantiate the buffer. +# basic properties (such as shape and dtype) as the first batch of data that +# was used to instantiate the buffer. # Passing data that does not match this requirement will either raise an # exception or lead to some undefined behaviours. -# - The :class:`torchrl.dataLazyMemmapStorage` works as the -# :class:`torchrl.data.LazyTensorStorage` in that it is lazy (ie. 
it +# - The :class:`~torchrl.data.LazyMemmapStorage` works as the +# :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., it # expects the first batch of data to be instantiated), and it requires data # that match in shape and dtype for each batch stored. What makes this -# storage unique is that it points to disk files, meaning that it can -# support very large datasets while still accessing data in a contiguous -# manner. +# storage unique is that it points to disk files (or uses the filesystem +# storage), meaning that it can support very large datasets while still +# accessing data in a contiguous manner. # # Let us see how we can use each of these storages: @@ -133,7 +149,7 @@ from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage # We define the maximum size of the buffer -size = 10_000 +size = 100 ###################################################################### # A buffer with a list storage buffer can store any kind of data (but we must @@ -144,9 +160,9 @@ ###################################################################### # Because it is the one with the lowest amount of assumption, the -# :class:`torchrl.data.ListStorage` is the default storage in TorchRL. +# :class:`~torchrl.data.ListStorage` is the default storage in TorchRL. # -# A :class:`torchrl.data.LazyTensorStorage` can store data contiguously. +# A :class:`~torchrl.data.LazyTensorStorage` can store data contiguously. # This should be the preferred option when dealing with complicated but # unchanging data structures of medium size: @@ -156,6 +172,10 @@ # Let us create a batch of data of size ``torch.Size([3])` with 2 tensors # stored in it: # + +import torch +from tensordict import TensorDict + data = TensorDict( { "a": torch.arange(12).view(3, 4), @@ -166,7 +186,7 @@ print(data) ###################################################################### -# The first call to :meth:`torchrl.data.ReplayBuffer.extend` will +# The first call to :meth:`~torchrl.data.ReplayBuffer.extend` will # instantiate the storage. The first dimension of the data is unbound into # separate datapoints: @@ -181,7 +201,7 @@ print("samples", sample["a"], sample["b", "c"]) ###################################################################### -# A :class:`torchrl.data.LazyMemmapStorage` is created in the same manner: +# A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner: # buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size)) @@ -208,16 +228,20 @@ # Integration with TensorDict # --------------------------- # +# .. _tuto_rb_td: +# # The tensor location follows the same structure as the TensorDict that # contains them: this makes it easy to save and load buffers during training. # -# To use :class:`tensordict.TensorDict` as a data carrier at its fullest -# potential, the :class:`torchrl.data.TensorDictReplayBuffer` class should +# To use :class:`~tensordict.TensorDict` as a data carrier at its fullest +# potential, the :class:`~torchrl.data.TensorDictReplayBuffer` class can # be used. # One of its key benefits is its ability to handle the organization of sampled # data, along with any additional information that may be required # (such as sample indices). -# It can be built in the same manner as a standard :class:`torchrl.data.ReplayBuffer` and can +# +# It can be built in the same manner as a standard +# :class:`~torchrl.data.ReplayBuffer` and can # generally be used interchangeably. 
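+
+######################################################################
+# As a brief, purely illustrative sketch (the keys and sizes below are
+# arbitrary and not taken from the surrounding example), such a buffer can be
+# created and used just like the generic one, with the sampled tensordict
+# also carrying the storage indices under the ``"index"`` entry:
+
+from torchrl.data import TensorDictReplayBuffer
+
+td_buffer = TensorDictReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
+# a tiny tensordict with 3 elements along its first dimension
+td_buffer.extend(TensorDict({"a": torch.arange(12).view(3, 4)}, batch_size=[3]))
+td_sample = td_buffer.sample()
+print(td_sample["index"])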
# @@ -245,7 +269,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The ReplayBuffer class and associated subclasses also work natively with -# :class:`tensordict.tensorclass` classes, which can conviniently be used to +# :class:`~tensordict.tensorclass` classes, which can conveniently be used to # encode datasets in a more explicit manner: from tensordict import tensorclass @@ -260,10 +284,10 @@ class MyData: data = MyData( images=torch.randint( 255, - (1000, 64, 64, 3), + (10, 64, 64, 3), ), - labels=torch.randint(100, (1000,)), - batch_size=[1000], + labels=torch.randint(100, (10,)), + batch_size=[10], ) tempdir = tempfile.TemporaryDirectory() @@ -279,31 +303,28 @@ class MyData: ###################################################################### # As expected. the data has the proper class and shape! # -# Integration with PyTree -# ~~~~~~~~~~~~~~~~~~~~~~~ +# Integration with other tensor structures (PyTrees) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # TorchRL's replay buffers also work with any pytree data structure. # A PyTree is a nested structure of arbitrary depth made of dicts, lists and/or # tuples where the leaves are tensors. # This means that one can store in contiguous memory any such tree structure! # Various storages can be used: -# :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` -# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this kind of data. +# :class:`~torchrl.data.replay_buffers.TensorStorage`, +# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this +# kind of data. # -# Here is a bried demonstration of what this feature looks like: +# Here is a brief demonstration of what this feature looks like: # from torch.utils._pytree import tree_map -# With pytrees, any callable can be used as a transform: -def transform(x): - # Zeros all the data in the pytree - return tree_map(lambda y: y * 0, x) - - +###################################################################### # Let's build our replay buffer on disk: -rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform) +rb = ReplayBuffer(storage=LazyMemmapStorage(size)) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, @@ -315,6 +336,20 @@ def transform(x): sample = rb.sample(10) +###################################################################### +# With pytrees, any callable can be used as a transform: + + +def transform(x): + # Zeros all the data in the pytree + return tree_map(lambda y: y * 0, x) + + +rb.append_transform(transform) +sample = rb.sample(batch_size=12) + + +###################################################################### # let's check that our transform did its job: def assert0(x): assert (x == 0).all() @@ -323,9 +358,12 @@ def assert0(x): tree_map(assert0, sample) +###################################################################### # Sampling and iterating over buffers # ----------------------------------- # +# .. _tuto_rb_sampling: +# # Replay Buffers support multiple sampling strategies: # # - If the batch-size is fixed and can be defined at construction time, it can @@ -333,7 +371,7 @@ def assert0(x): # - With a fixed batch-size, the replay buffer can be iterated over to gather # samples; # - If the batch-size is dynamic, it can be passed to the -# :class:`torchrl.data.ReplayBuffer.sample` method +# :class:`~torchrl.data.ReplayBuffer.sample` method # on-the-fly. 
# # Sampling can be done using multithreading, but this is incompatible with the @@ -344,9 +382,19 @@ def assert0(x): # # Fixed batch-size # ~~~~~~~~~~~~~~~~ -# If the batch-size is passed during construction, it should be ommited when +# +# If the batch-size is passed during construction, it should be omitted when # sampling: +data = MyData( + images=torch.randint( + 255, + (200, 64, 64, 3), + ), + labels=torch.randint(100, (200,)), + batch_size=[200], +) + buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) buffer_lazymemmap.extend(data) buffer_lazymemmap.sample() @@ -357,7 +405,8 @@ def assert0(x): # # To enable multithreaded sampling, just pass a positive integer to the # ``prefetch`` keyword argument during construction. This should speed up -# sampling considerably: +# sampling considerably whenever sampling is time consuming (e.g., when +# using prioritized samplers): buffer_lazymemmap = ReplayBuffer( @@ -368,8 +417,8 @@ def assert0(x): ###################################################################### -# Fixed batch-size, iterating over the buffer -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Iterating over the buffer with a fixed batch-size +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # We can also iterate over the buffer like we would do with a regular # dataloader, as long as the batch-size is predefined: @@ -384,7 +433,8 @@ def assert0(x): ###################################################################### # Due to the fact that our sampling technique is entirely random and does not # prevent replacement, the iterator in question is infinite. However, we can -# make use of the :class:`torchrl.data.replay_buffers.SamplerWithoutReplacement` +# make use of the +# :class:`~torchrl.data.replay_buffers.SamplerWithoutReplacement` # instead, which will transform our buffer into a finite iterator: # @@ -397,10 +447,10 @@ def assert0(x): # we create a data that is big enough to get a couple of samples data = TensorDict( { - "a": torch.arange(512).view(128, 4), - ("b", "c"): torch.arange(1024).view(128, 8), + "a": torch.arange(64).view(16, 4), + ("b", "c"): torch.arange(128).view(16, 8), }, - batch_size=[128], + batch_size=[16], ) buffer_lazymemmap.extend(data) @@ -414,7 +464,7 @@ def assert0(x): # ~~~~~~~~~~~~~~~~~~ # # In contrast to what we have seen earlier, the ``batch_size`` keyword -# argument can be omitted and passed directly to the `sample` method: +# argument can be omitted and passed directly to the ``sample`` method: buffer_lazymemmap = ReplayBuffer( @@ -428,7 +478,10 @@ def assert0(x): # Prioritized Replay buffers # -------------------------- # -# TorchRL also provides an interface for prioritized replay buffers. +# .. _tuto_rb_prb: +# +# TorchRL also provides an interface for +# `prioritized replay buffers `_. # This buffer class samples data according to a priority signal that is passed # through the data. # @@ -443,7 +496,7 @@ def assert0(x): from torchrl.data.replay_buffers.samplers import PrioritizedSampler -size = 1000 +size = 100 rb = ReplayBuffer( storage=ListStorage(size), @@ -462,8 +515,8 @@ def assert0(x): # buffer, the priority is set to a default value of 1. Once the priority has # been computed (usually through the loss), it must be updated in the buffer. # -# This is done via the `update_priority` method, which requires the indices -# as well as the priority. +# This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority` +# method, which requires the indices as well as the priority. 
# We assign an artificially high priority to the second sample in the dataset # to observe its effect on sampling: # @@ -485,6 +538,7 @@ def assert0(x): ###################################################################### # We see that using a prioritized replay buffer requires a series of extra # steps in the training loop compared with a regular buffer: +# # - After collecting data and extending the buffer, the priority of the # items must be updated; # - After computing the loss and getting a "priority signal" from it, we must @@ -497,10 +551,10 @@ def assert0(x): # that the appropriate methods are called at the appropriate place, if and # only if a prioritized buffer is being used. # -# Let us see how we can improve this with TensorDict. We saw that the -# :class:`torchrl.data.TensorDictReplayBuffer` returns data augmented with -# their relative storage indices. One feature we did not mention is that -# this class also ensures that the priority +# Let us see how we can improve this with :class:`~tensordict.TensorDict`. +# We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data +# augmented with their relative storage indices. One feature we did not mention +# is that this class also ensures that the priority # signal is automatically parsed to the prioritized sampler if present during # extension. # @@ -568,6 +622,8 @@ def assert0(x): # Using transforms # ---------------- # +# .. _tuto_rb_transform: +# # The data stored in a replay buffer may not be ready to be presented to a # loss module. # In some cases, the data produced by a collector can be too heavy to be @@ -591,8 +647,14 @@ def assert0(x): from torchrl.collectors import RandomPolicy, SyncDataCollector -from torchrl.envs import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import ( + Compose, + GrayScale, + Resize, + ToTensorImage, + TransformedEnv, +) env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True), @@ -616,7 +678,7 @@ def assert0(x): # To do this, we will append a transform to the collector to select the keys # we want to see appearing: -from torchrl.envs import ExcludeTransform +from torchrl.envs.transforms import ExcludeTransform collector = SyncDataCollector( env, @@ -671,7 +733,7 @@ def assert0(x): # A more complex examples: using CatFrames # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The :class:`torchrl.envs.CatFrames` transform unfolds the observations +# The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations # through time, creating a n-back memory of past events that allow the model # to take the past events into account (in the case of POMDPs or with # recurrent policies such as Decision Transformers). Storing these concatenated @@ -718,7 +780,7 @@ def assert0(x): GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) -rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) +rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16) data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) rb.add(data_exclude) @@ -738,6 +800,56 @@ def assert0(x): assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all() +###################################################################### +# Storing trajectories +# -------------------- +# +# .. 
_tuto_rb_traj: +# +# In many cases, it is desirable to access trajectories from the buffer rather +# than simple transitions. TorchRL offers multiple ways of achieving this. +# +# The preferred way is currently to store trajectories along the first +# dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to +# sample these batches of data. This class only needs a couple of pieces of +# information about your data structure to do its job (note that, as of now, it +# is only compatible with tensordict-structured data): the number of slices or +# their length and some information about where the separation between the +# episodes can be found (e.g. :ref:`recall that ` with a +# :ref:`DataCollector `, the trajectory id is stored in +# ``("collector", "traj_ids")``). In this simple example, we construct data +# with 4 consecutive short trajectories and sample 4 slices out of it, each of +# length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps). +# We mark the steps as well. + +from torchrl.data import SliceSampler + +rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(size), + sampler=SliceSampler(traj_key="episode", num_slices=4), + batch_size=8, +) +episode = torch.zeros(10, dtype=torch.int) +episode[:3] = 1 +episode[3:5] = 2 +episode[5:7] = 3 +episode[7:] = 4 +steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)]) +data = TensorDict( + { + "episode": episode, + "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5), + "act": torch.randn((20,)).expand(10, 20), + "other": torch.randn((20, 50)).expand(10, 20, 50), + "steps": steps, + }, + [10], +) +rb.extend(data) +sample = rb.sample() +print("episodes are grouped", sample["episode"]) +print("steps are successive", sample["steps"]) + ###################################################################### # Conclusion # ---------- @@ -751,3 +863,13 @@ def assert0(x): # - Choose the best storage type for your problem (list, memory or disk-based); # - Minimize the memory footprint of your buffer. # +# Next steps +# ---------- +# +# - Check the data API reference to learn about offline datasets in TorchRL, +# which are based on our Replay Buffer API; +# - Check other samplers such as +# :class:`~torchrl.data.SamplerWithoutReplacement`, +# :class:`~torchrl.data.PrioritizedSliceSampler` and +# :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers +# such as :class:`~torchrl.data.TensorDictMaxValueWriter`. diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index d1a261e63f5..25213503e19 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ Introduction to TorchRL -============================ +======================= This demo was presented at ICML 2022 on the industry demo day. """ ############################################################################## @@ -32,75 +32,138 @@ # **Content**: # .. 
 #
-# "torchrl"
-# │
-# ├── "collectors"
-# │ └── "collectors.py"
-# ├── "data"
-# │ ├── "tensor_specs.py"
-# │ ├── "postprocs"
-# │ │ └── "postprocs.py"
-# │ └── "replay_buffers"
-# │ ├── "replay_buffers.py"
-# │ └── "storages.py"
-# ├── "envs"
-# │ ├── "common.py"
-# │ ├── "env_creator.py"
-# │ ├── "gym_like.py"
-# │ ├── "vec_env.py"
-# │ ├── "libs"
-# │ │ ├── "dm_control.py"
-# │ │ └── "gym.py"
-# │ └── "transforms"
-# │ ├── "functional.py"
-# │ └── "transforms.py"
-# ├── "modules"
-# │ ├── "distributions"
-# │ │ ├── "continuous.py"
-# │ │ └── "discrete.py"
-# │ ├── "models"
-# │ │ ├── "models.py"
-# │ │ └── "exploration.py"
-# │ └── "tensordict_module"
-# │ ├── "actors.py"
-# │ ├── "common.py"
-# │ ├── "exploration.py"
-# │ ├── "probabilistic.py"
-# │ └── "sequence.py"
-# ├── "objectives"
-# │ ├── "common.py"
-# │ ├── "ddpg.py"
-# │ ├── "dqn.py"
-# │ ├── "functional.py"
-# │ ├── "ppo.py"
-# │ ├── "redq.py"
-# │ ├── "reinforce.py"
-# │ ├── "sac.py"
-# │ ├── "utils.py"
-# │ └── "value"
-# │ ├── "advantages.py"
-# │ ├── "functional.py"
-# │ ├── "pg.py"
-# │ ├── "utils.py"
-# │ └── "vtrace.py"
-# ├── "record"
-# │ └── "recorder.py"
-# └── "trainers"
-# ├── "loggers"
-# │ ├── "common.py"
-# │ ├── "csv.py"
-# │ ├── "mlflow.py"
-# │ ├── "tensorboard.py"
-# │ └── "wandb.py"
-# ├── "trainers.py"
-# └── "helpers"
-# ├── "collectors.py"
-# ├── "envs.py"
-# ├── "loggers.py"
-# ├── "losses.py"
-# ├── "models.py"
-# ├── "replay_buffer.py"
-# └── "trainers.py"
+# "torchrl"
+# │
+# ├── "collectors"
+# │ └── "collectors.py"
+# │ │
+# │ └── "distributed"
+# │ └── "default_configs.py"
+# │ └── "generic.py"
+# │ └── "ray.py"
+# │ └── "rpc.py"
+# │ └── "sync.py"
+# ├── "data"
+# │ │
+# │ ├── "datasets"
+# │ │ └── "atari_dqn.py"
+# │ │ └── "d4rl.py"
+# │ │ └── "d4rl_infos.py"
+# │ │ └── "gen_dgrl.py"
+# │ │ └── "minari_data.py"
+# │ │ └── "openml.py"
+# │ │ └── "openx.py"
+# │ │ └── "roboset.py"
+# │ │ └── "vd4rl.py"
+# │ ├── "postprocs"
+# │ │ └── "postprocs.py"
+# │ ├── "replay_buffers"
+# │ │ └── "replay_buffers.py"
+# │ │ └── "samplers.py"
+# │ │ └── "storages.py"
+# │ │ └── "transforms.py"
+# │ │ └── "writers.py"
+# │ ├── "rlhf"
+# │ │ └── "dataset.py"
+# │ │ └── "prompt.py"
+# │ │ └── "reward.py"
+# │ └── "tensor_specs.py"
+# ├── "envs"
+# │ └── "batched_envs.py"
+# │ └── "common.py"
+# │ └── "env_creator.py"
+# │ └── "gym_like.py"
+# │ ├── "libs"
+# │ │ └── "brax.py"
+# │ │ └── "dm_control.py"
+# │ │ └── "envpool.py"
+# │ │ └── "gym.py"
+# │ │ └── "habitat.py"
+# │ │ └── "isaacgym.py"
+# │ │ └── "jumanji.py"
+# │ │ └── "openml.py"
+# │ │ └── "pettingzoo.py"
+# │ │ └── "robohive.py"
+# │ │ └── "smacv2.py"
+# │ │ └── "vmas.py"
+# │ ├── "model_based"
+# │ │ └── "common.py"
+# │ │ └── "dreamer.py"
+# │ ├── "transforms"
+# │ │ └── "functional.py"
+# │ │ └── "gym_transforms.py"
+# │ │ └── "r3m.py"
+# │ │ └── "rlhf.py"
+# │ │ └── "transforms.py"
+# │ │ └── "vc1.py"
+# │ │ └── "vip.py"
+# │ └── "vec_envs.py"
+# ├── "modules"
+# │ ├── "distributions"
+# │ │ └── "continuous.py"
+# │ │ └── "discrete.py"
+# │ │ └── "truncated_normal.py"
+# │ ├── "models"
+# │ │ └── "decision_transformer.py"
+# │ │ └── "exploration.py"
+# │ │ └── "model_based.py"
+# │ │ └── "models.py"
+# │ │ └── "multiagent.py"
+# │ │ └── "rlhf.py"
+# │ ├── "planners"
+# │ │ └── "cem.py"
+# │ │ └── "common.py"
+# │ │ └── "mppi.py"
+# │ └── "tensordict_module"
+# │ └── "actors.py"
+# │ └── "common.py"
+# │ └── "exploration.py"
+# │ └── "probabilistic.py"
+# │ └── "rnn.py"
+# │ └── "sequence.py"
+# │ └── "world_models.py"
+# ├── "objectives"
+# │ └── "a2c.py"
+# │ └── "common.py"
+# │ └── "cql.py"
+# │ └── "ddpg.py"
+# │ └── "decision_transformer.py"
+# │ └── "deprecated.py"
+# │ └── "dqn.py"
+# │ └── "dreamer.py"
+# │ └── "functional.py"
+# │ └── "iql.py"
+# │ ├── "multiagent"
+# │ │ └── "qmixer.py"
+# │ └── "ppo.py"
+# │ └── "redq.py"
+# │ └── "reinforce.py"
+# │ └── "sac.py"
+# │ └── "td3.py"
+# │ ├── "value"
+# │ └── "advantages.py"
+# │ └── "functional.py"
+# │ └── "pg.py"
+# ├── "record"
+# │ ├── "loggers"
+# │ │ └── "common.py"
+# │ │ └── "csv.py"
+# │ │ └── "mlflow.py"
+# │ │ └── "tensorboard.py"
+# │ │ └── "wandb.py"
+# │ └── "recorder.py"
+# ├── "trainers"
+# │ │
+# │ ├── "helpers"
+# │ │ └── "collectors.py"
+# │ │ └── "envs.py"
+# │ │ └── "logger.py"
+# │ │ └── "losses.py"
+# │ │ └── "models.py"
+# │ │ └── "replay_buffer.py"
+# │ │ └── "trainers.py"
+# │ └── "trainers.py"
+# └── "version.py"
 #
 # Unlike other domains, RL is less about media than *algorithms*. As such, it
 # is harder to make truly independent components.
@@ -135,9 +198,14 @@
 # `__main__` method call, but for the easy of reading the code switch to fork
 # which is also a default spawn method in Google's Colaboratory
 try:
-    multiprocessing.set_start_method("fork")
+    is_sphinx = __sphinx_build__
+except NameError:
+    is_sphinx = False
+
+try:
+    multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
 except RuntimeError:
-    assert multiprocessing.get_start_method() == "fork"
+    pass

 # sphinx_gallery_end_ignore

@@ -475,21 +543,30 @@
 # Functional Programming (Ensembling / Meta-RL)
 # ----------------------------------------------

-from tensordict.nn import make_functional
+from tensordict import TensorDict

-params = make_functional(sequence)
-len(list(sequence.parameters()))  # functional modules have no parameters
+params = TensorDict.from_module(sequence)
+print("extracted params", params)

 ###############################################################################
+# functional call using tensordict:

-sequence(tensordict, params)
+with params.to_module(sequence):
+    sequence(tensordict)

 ###############################################################################
-
+# Using vectorized map for model ensembling
 from torch import vmap

 params_expand = params.expand(4)
-tensordict_exp = vmap(sequence, (None, 0))(tensordict, params_expand)
+
+
+def exec_sequence(params, data):
+    with params.to_module(sequence):
+        return sequence(data)
+
+
+tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict)
 print(tensordict_exp)

 ###############################################################################
@@ -678,8 +755,7 @@
     total_frames=240,
     max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
     frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
-    storing_devices=devices,  # len must match len of env created
-    devices=devices,
+    device=devices,
 )

 ###############################################################################
@@ -701,8 +777,7 @@
     total_frames=240,
     max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
     frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
-    storing_devices=devices,  # len must match len of env created
-    devices=devices,
+    device=devices,
 )

 for i, d in enumerate(collector):
@@ -737,7 +812,8 @@ def forward(self, obs, action):
     value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
 )

-loss_fn = DDPGLoss(actor, value, gamma=0.99)
+loss_fn = DDPGLoss(actor, value)
+loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)

 ###############################################################################

diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py
index dc836b43150..4c792d44b80 100644
--- a/tutorials/sphinx-tutorials/torchrl_envs.py
+++ b/tutorials/sphinx-tutorials/torchrl_envs.py
@@ -1,9 +1,15 @@
 # -*- coding: utf-8 -*-
 """
 TorchRL envs
-============================
+============
+
+**Author**: `Vincent Moens <https://github.com/vmoens>`_
+
+.. _envs_tuto:
+
 """
 ##############################################################################
+#
 # Environments play a crucial role in RL settings, often somewhat similar to
 # datasets in supervised and unsupervised settings. The RL community has
 # become quite familiar with OpenAI gym API which offers a flexible way of
@@ -19,7 +25,10 @@
 # To run this part of the tutorial, you will need to have a recent version of
 # the gym library installed, as well as the atari suite. You can get this
 # installed by installing the following packages:
-# $ pip install gym atari-py ale-py gym[accept-rom-license] pygame
+#
+# .. code-block::
+#   $ pip install gym atari-py ale-py gym[accept-rom-license] pygame
+#
 # To unify all frameworks, torchrl environments are built inside the
 # ``__init__`` method with a private method called ``_build_env`` that
 # will pass the arguments and keyword arguments to the root library builder.
@@ -37,9 +46,14 @@
 # `__main__` method call, but for the easy of reading the code switch to fork
 # which is also a default spawn method in Google's Colaboratory
 try:
-    multiprocessing.set_start_method("fork")
+    is_sphinx = __sphinx_build__
+except NameError:
+    is_sphinx = False
+
+try:
+    multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
 except RuntimeError:
-    assert multiprocessing.get_start_method() == "fork"
+    pass

 # sphinx_gallery_end_ignore

@@ -575,7 +589,7 @@ def env_make(env_name):
 parallel_env = ParallelEnv(
     2,
     [env_make, env_make],
-    [{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}],
+    create_env_kwargs=[{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}],
 )

 tensordict = parallel_env.reset()
@@ -619,7 +633,7 @@ def env_make(env_name):
 parallel_env = ParallelEnv(
     2,
     [env_make, env_make],
-    [{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}],
+    create_env_kwargs=[{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}],
 )
 parallel_env = TransformedEnv(parallel_env, GrayScale())  # transforms on main process
 tensordict = parallel_env.reset()
diff --git a/version.txt b/version.txt
index 0d91a54c7d4..9e11b32fcaa 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.0
+0.3.1
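
Aside on the functional-call idiom that the torchrl_demo.py changes above switch to (TensorDict.from_module / to_module combined with torch.vmap): the snippet below is a minimal, self-contained sketch of that pattern for reference. The module, variable names and shapes are illustrative only and are not part of this patch; it assumes torch >= 2.0 and a recent tensordict release.

    import torch
    from torch import nn, vmap
    from tensordict import TensorDict

    # Stand-in for the tutorial's `sequence` module (illustrative only).
    module = nn.Linear(3, 4)

    # Extract the module's parameters as a TensorDict.
    params = TensorDict.from_module(module)

    # Stack 4 copies of the parameters for an "ensemble" of 4 members.
    # expand() creates views, so all members share the same weights here;
    # clone and perturb them if you want genuinely different members.
    params_expand = params.expand(4)


    def exec_module(p, x):
        # Temporarily swap `p` into the module for a functional call.
        with p.to_module(module):
            return module(x)


    x = torch.randn(3)
    out = vmap(exec_module, (0, None))(params_expand, x)
    print(out.shape)  # torch.Size([4, 4]): one output row per ensemble member

This swap-in/vmap pattern is what the updated tutorial uses in place of the removed tensordict.nn.make_functional API.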