From 3f0413187353a01cba62ee8ff1565431c17432d3 Mon Sep 17 00:00:00 2001 From: Skander Moalla <37197319+skandermoalla@users.noreply.github.com> Date: Mon, 22 Jan 2024 09:17:31 +0100 Subject: [PATCH 01/35] [Docs] Fix doc of ToTensorImage transforms.py (#1824) --- torchrl/envs/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 47374a43a6b..ab95cd8352b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1163,7 +1163,7 @@ class ToTensorImage(ObservationTransform): from_int (bool, optional): if ``True``, the tensor will be scaled from the range [0, 255] to the range [0.0, 1.0]. if `False``, the tensor will not be scaled. if `None`, the tensor will be scaled if - it's a floating-point tensor. default=None. + it's not a floating-point tensor. default=None. unsqueeze (bool): if ``True``, the observation tensor is unsqueezed along the first dimension. default=False. dtype (torch.dtype, optional): dtype to use for the resulting From 55ec016128c5b1de885491e5b0eb1c2b55e9aafc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Jan 2024 17:24:50 +0000 Subject: [PATCH 02/35] [BugFix] Fix device of container generated values in transforms (#1827) --- torchrl/envs/transforms/transforms.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ab95cd8352b..ed8be751474 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5151,6 +5151,8 @@ def _reset( step_count = tensordict.get(step_count_key, default=None) if step_count is None: step_count = self.container.observation_spec[step_count_key].zero() + if step_count.device != reset.device: + step_count = step_count.to(reset.device, non_blocking=True) # zero the step count if reset is needed step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) @@ -6413,7 +6415,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: raise ValueError( self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec)) ) - action_spec.update_mask(mask) + action_spec.update_mask(mask.to(action_spec.device)) return tensordict def _reset( @@ -6424,7 +6426,10 @@ def _reset( raise ValueError( self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec)) ) - action_spec.update_mask(tensordict.get(self.in_keys[1], None)) + mask = tensordict.get(self.in_keys[1], None) + if mask is not None: + mask = mask.to(action_spec.device) + action_spec.update_mask(mask) # TODO: Check that this makes sense with _set_missing_tolerance(self, True): From 6769fee3cc28ec5940a9b065a81e3f587b2742c0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Jan 2024 17:58:34 +0000 Subject: [PATCH 03/35] [Feature] Atari DQN dataset (#1815) --- .../scripts_ataridqn/environment.yml | 25 + .../linux_libs/scripts_ataridqn/install.sh | 51 ++ .../scripts_ataridqn/post_process.sh | 6 + .../scripts_ataridqn/run-clang-format.py | 356 ++++++++ .../linux_libs/scripts_ataridqn/run_test.sh | 24 + .../linux_libs/scripts_ataridqn/setup_env.sh | 50 ++ .github/workflows/test-linux-libs.yml | 26 + docs/source/reference/data.rst | 2 +- test/test_libs.py | 55 ++ test/test_rb.py | 3 +- torchrl/data/datasets/__init__.py | 1 + torchrl/data/datasets/atari_dqn.py | 760 ++++++++++++++++++ torchrl/data/datasets/d4rl.py | 4 +- torchrl/data/datasets/minari_data.py | 8 +- torchrl/data/datasets/openml.py | 4 +- torchrl/data/datasets/openx.py | 12 +- torchrl/data/datasets/roboset.py | 8 +- torchrl/data/datasets/vd4rl.py | 4 +- torchrl/data/replay_buffers/replay_buffers.py | 45 +- torchrl/data/replay_buffers/samplers.py | 133 ++- 20 files changed, 1511 insertions(+), 66 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_ataridqn/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/install.sh create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh create mode 100644 torchrl/data/datasets/atari_dqn.py diff --git a/.github/unittest/linux_libs/scripts_ataridqn/environment.yml b/.github/unittest/linux_libs/scripts_ataridqn/environment.yml new file mode 100644 index 00000000000..b88860dddde --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/environment.yml @@ -0,0 +1,25 @@ +channels: + - pytorch + - defaults + - conda-forge +dependencies: + - pip + - gsutil + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - tqdm + - h5py + - datasets + - pillow diff --git a/.github/unittest/linux_libs/scripts_ataridqn/install.sh b/.github/unittest/linux_libs/scripts_ataridqn/install.sh new file mode 100755 index 00000000000..1be476425a6 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +else + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_ataridqn/post_process.sh b/.github/unittest/linux_libs/scripts_ataridqn/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py b/.github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh b/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh new file mode 100755 index 00000000000..ee7bf9b46b1 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestAtariDQN --error-for-skips --runslow +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_ataridqn/setup_env.sh b/.github/unittest/linux_libs/scripts_ataridqn/setup_env.sh new file mode 100755 index 00000000000..5b415112814 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_ataridqn/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip curl + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 3b090582e4f..abf78e5e19c 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -16,6 +16,32 @@ concurrency: cancel-in-progress: true jobs: + + unittests-atari-dqn: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh + bash .github/unittest/linux_libs/scripts_ataridqn/install.sh + bash .github/unittest/linux_libs/scripts_ataridqn/run_test.sh + bash .github/unittest/linux_libs/scripts_ataridqn/post_process.sh + unittests-brax: strategy: matrix: diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 90dbe4f3d4e..2def1b4bfa8 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -280,7 +280,7 @@ Here's an example: :toctree: generated/ :template: rl_template.rst - + AtariDQNExperienceReplay D4RLExperienceReplay GenDGRLExperienceReplay MinariExperienceReplay diff --git a/test/test_libs.py b/test/test_libs.py index e034cea84c7..13891331b05 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -61,9 +61,11 @@ MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, + ReplayBufferEnsemble, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) +from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay @@ -2489,6 +2491,59 @@ def test_load(self, image_size): break +@pytest.mark.slow +class TestAtariDQN: + @pytest.fixture(scope="class") + def limit_max_runs(self): + prev_val = AtariDQNExperienceReplay._max_runs + AtariDQNExperienceReplay._max_runs = 3 + yield + AtariDQNExperienceReplay._max_runs = prev_val + + @pytest.mark.parametrize("dataset_id", ["Asterix/1", "Pong/4"]) + @pytest.mark.parametrize( + "num_slices,slice_len", [[None, None], [None, 8], [2, None]] + ) + def test_single_dataset(self, dataset_id, slice_len, num_slices, limit_max_runs): + dataset = AtariDQNExperienceReplay( + dataset_id, slice_len=slice_len, num_slices=num_slices + ) + sample = dataset.sample(64) + for key in ( + ("next", "observation"), + ("next", "truncated"), + ("next", "terminated"), + ("next", "done"), + ("next", "reward"), + "observation", + "action", + "done", + "truncated", + "terminated", + ): + assert key in sample.keys(True) + assert sample.shape == (64,) + assert sample.get_non_tensor("metadata")["dataset_id"] == dataset_id + + @pytest.mark.parametrize( + "num_slices,slice_len", [[None, None], [None, 8], [2, None]] + ) + def test_double_dataset(self, slice_len, num_slices, limit_max_runs): + dataset_pong = AtariDQNExperienceReplay( + "Pong/4", slice_len=slice_len, num_slices=num_slices + ) + dataset_asterix = AtariDQNExperienceReplay( + "Asterix/1", slice_len=slice_len, num_slices=num_slices + ) + dataset = ReplayBufferEnsemble( + dataset_pong, dataset_asterix, sample_from_all=True, batch_size=128 + ) + sample = dataset.sample() + assert sample.shape == (2, 64) + assert sample[0].get_non_tensor("metadata")["dataset_id"] == "Pong/4" + assert sample[1].get_non_tensor("metadata")["dataset_id"] == "Asterix/1" + + @pytest.mark.slow class TestOpenX: @pytest.mark.parametrize( diff --git a/test/test_rb.py b/test/test_rb.py index cf9deabb956..5d184c365e2 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1808,7 +1808,8 @@ def test_slice_sampler_errors(self): storage.set(range(100), data) sampler = SliceSampler(num_slices=num_slices) with pytest.raises( - RuntimeError, match="can only sample from TensorStorage subclasses" + RuntimeError, + match="Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories.", ): index, _ = sampler.sample(storage, batch_size=batch_size) diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 4822ac35c54..092b80083a1 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1,3 +1,4 @@ +from .atari_dqn import AtariDQNExperienceReplay from .d4rl import D4RLExperienceReplay from .gen_dgrl import GenDGRLExperienceReplay from .minari_data import MinariExperienceReplay diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py new file mode 100644 index 00000000000..93950532026 --- /dev/null +++ b/torchrl/data/datasets/atari_dqn.py @@ -0,0 +1,760 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import functools +import gzip +import io +import json +import logging + +import os +import shutil +import subprocess +import tempfile +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +from tensordict import MemoryMappedTensor, TensorDict +from torch import multiprocessing as mp + +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import ( + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import Storage +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter +from torchrl.envs.utils import _classproperty + + +class AtariDQNExperienceReplay(TensorDictReplayBuffer): + """Atari DQN Experience replay class. + + The Atari DQN dataset (https://offline-rl.github.io/) is a collection of 5 training + iterations of DQN over each of the Arari 2600 games for a total of 200 million frames. + The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset + has 50 million steps in total. + + The data format follows the TED convention. Since the dataset is quite heavy, + the data formatting is done on-line, at sampling time. + + To make training more modular, we split the dataset in each of the Atari games + and separate each training round. Consequently, each dataset is presented as + a Storage of length 50x10^6 elements. Under the hood, this dataset is split + in 50 memory-mapped tensordicts of length 1 million each. + + Args: + dataset_id (str): The dataset to be downloaded. + Must be part of ``AtariDQNExperienceReplay.available_datasets``. + batch_size (int): Batch-size used during sampling. + Can be overridden by `data.sample(batch_size)` if necessary. + + Keyword Args: + root (Path or str, optional): The AtariDQN dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/atari`. + num_procs (int, optional): number of processes to launch for preprocessing. + Has no effect whenever the data is already downloaded. Defaults to 0 + (no multiprocessing used). + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. + num_slices (int, optional): the number of slices to be sampled. The batch-size + must be greater or equal to the ``num_slices`` argument. Exclusive + with ``slice_len``. Defaults to ``None`` (no slice sampling). + The ``sampler`` arg will override this value. + slice_len (int, optional): the length of the slices to be sampled. The batch-size + must be greater or equal to the ``slice_len`` argument and divisible + by it. Exclusive with ``num_slices``. Defaults to ``None`` (no slice sampling). + The ``sampler`` arg will override this value. + strict_length (bool, optional): if ``False``, trajectories of length + shorter than `slice_len` (or `batch_size // num_slices`) will be + allowed to appear in the batch. + Be mindful that this can result in effective `batch_size` shorter + than the one asked for! Trajectories can be split using + :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. + The ``sampler`` arg will override this value. + replacement (bool, optional): if ``False``, sampling will occur without replacement. + The ``sampler`` arg will override this value. + + Attributes: + available_datasets: list of available datasets, formatted as `/`. Example: + `"Pong/5"`, `"Krull/2"`, ... + dataset_id (str): the name of the dataset. + episodes (torch.Tensor): a 1d tensor indicating to what run each of the + 1M frames belongs. To be used with :class:`~torchrl.data.replay_buffers.SliceSampler` + to cheaply sample slices of episodes. + + Examples: + >>> from torchrl.data.datasets import AtariDQNExperienceReplay + >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128) + >>> for data in dataset: + ... print(data) + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False) + + .. warning:: + Atari-DQN does not provide the next observation after a termination signal. + In other words, there is no way to obtain the ``("next", "observation")`` state + when ``("next", "done")`` is ``True``. This value is filled with 0s but should + not be used in practice. If TorchRL's value estimators (:class:`~torchrl.objectives.values.ValueEstimator`) + are used, this should not be an issue. + + .. note:: + Because the construction of the sampler for episode sampling is slightly + convoluted, we made it easy for users to pass the arguments of the + :class:`~torchrl.data.replay_buffers.SliceSampler` directly to the + ``AtariDQNExperienceReplay`` dataset: any of the ``num_slices`` or + ``slice_len`` arguments will make the sampler an instance of + :class:`~torchrl.data.replay_buffers.SliceSampler`. The ``strict_length`` + can also be passed. + + >>> from torchrl.data.datasets import AtariDQNExperienceReplay + >>> from torchrl.data.replay_buffers import SliceSampler + >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64) + >>> for data in dataset: + ... print(data) + ... print(data.get("index")) # indices are in 4 groups of consecutive values + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([128]), + device=None, + is_shared=False) + tensor([2657628, 2657629, 2657630, 2657631, 2657632, 2657633, 2657634, 2657635, + 2657636, 2657637, 2657638, 2657639, 2657640, 2657641, 2657642, 2657643, + 2657644, 2657645, 2657646, 2657647, 2657648, 2657649, 2657650, 2657651, + 2657652, 2657653, 2657654, 2657655, 2657656, 2657657, 2657658, 2657659, + 2657660, 2657661, 2657662, 2657663, 2657664, 2657665, 2657666, 2657667, + 2657668, 2657669, 2657670, 2657671, 2657672, 2657673, 2657674, 2657675, + 2657676, 2657677, 2657678, 2657679, 2657680, 2657681, 2657682, 2657683, + 2657684, 2657685, 2657686, 2657687, 2657688, 2657689, 2657690, 2657691, + 1995687, 1995688, 1995689, 1995690, 1995691, 1995692, 1995693, 1995694, + 1995695, 1995696, 1995697, 1995698, 1995699, 1995700, 1995701, 1995702, + 1995703, 1995704, 1995705, 1995706, 1995707, 1995708, 1995709, 1995710, + 1995711, 1995712, 1995713, 1995714, 1995715, 1995716, 1995717, 1995718, + 1995719, 1995720, 1995721, 1995722, 1995723, 1995724, 1995725, 1995726, + 1995727, 1995728, 1995729, 1995730, 1995731, 1995732, 1995733, 1995734, + 1995735, 1995736, 1995737, 1995738, 1995739, 1995740, 1995741, 1995742, + 1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750]) + + .. note:: + As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`: + + >>> from torchrl.data.datasets import AtariDQNExperienceReplay + >>> from torchrl.data.replay_buffers import ReplayBufferEnsemble + >>> # we change this parameter for quick experimentation, in practice it should be left untouched + >>> AtariDQNExperienceReplay._max_runs = 2 + >>> dataset_asterix = AtariDQNExperienceReplay("Asterix/5", batch_size=128, slice_len=64, num_procs=4) + >>> dataset_pong = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64, num_procs=4) + >>> dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True) + >>> sample = dataset.sample() + >>> print("first sample, Asterix", sample[0]) + first sample, Asterix TensorDict( + fields={ + action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + index: TensorDict( + fields={ + buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False) + >>> print("second sample, Pong", sample[1]) + second sample, Pong TensorDict( + fields={ + action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + index: TensorDict( + fields={ + buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + metadata: NonTensorData( + data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Asterix/5'}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([64]), + device=None, + is_shared=False) + >>> print("Aggregate (metadata hidden)", sample) + Aggregate (metadata hidden) LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + index: LazyStackedTensorDict( + fields={ + buffer_ids: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False), + index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0), + metadata: LazyStackedTensorDict( + fields={ + }, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0), + observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False), + terminated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False), + truncated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2, 64]), + device=None, + is_shared=False, + stack_dim=0) + + """ + + @_classproperty + def available_datasets(cls): + games = [ + "AirRaid", + "Alien", + "Amidar", + "Assault", + "Asterix", + "Asteroids", + "Atlantis", + "BankHeist", + "BattleZone", + "BeamRider", + "Berzerk", + "Bowling", + "Boxing", + "Breakout", + "Carnival", + "Centipede", + "ChopperCommand", + "CrazyClimber", + "DemonAttack", + "DoubleDunk", + "ElevatorAction", + "Enduro", + "FishingDerby", + "Freeway", + "Frostbite", + "Gopher", + "Gravitar", + "Hero", + "IceHockey", + "Jamesbond", + "JourneyEscape", + "Kangaroo", + "Krull", + "KungFuMaster", + "MontezumaRevenge", + "MsPacman", + "NameThisGame", + "Phoenix", + "Pitfall", + "Pong", + "Pooyan", + "PrivateEye", + "Qbert", + "Riverraid", + "RoadRunner", + "Robotank", + "Seaquest", + "Skiing", + "Solaris", + "SpaceInvaders", + ] + return ["/".join((game, str(loop))) for game in games for loop in range(1, 6)] + + # If we want to keep track of the original atari files + tmpdir = None + # use _max_runs for debugging, avoids downloading the entire dataset + _max_runs = None + + def __init__( + self, + dataset_id: str, + batch_size: int | None = None, + *, + root: str | Path | None = None, + download: bool | str = True, + sampler=None, + writer=None, + transform: "Transform" | None = None, # noqa: F821 + num_procs: int = 0, + num_slices: int | None = None, + slice_len: int | None = None, + strict_len: bool = True, + replacement: bool = True, + **kwargs, + ): + if dataset_id not in self.available_datasets: + raise ValueError( + "The dataseet_id is not part of the available datasets. The dataset should be named / " + "where is one of the Atari 2600 games and the run is a number betweeen 1 and 5. " + "The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets." + ) + self.dataset_id = dataset_id + from torchrl.data.datasets.utils import _get_root_dir + + if root is None: + root = _get_root_dir("atari") + self.root = root + self.num_procs = num_procs + if download == "force" or (download and not self._is_downloaded): + try: + self._download_and_preproc() + except Exception: + # remove temporary data + if os.path.exists(self.dataset_path): + shutil.rmtree(self.dataset_path) + raise + storage = _AtariStorage(self.dataset_path) + if writer is None: + writer = ImmutableDatasetWriter() + if sampler is None: + if num_slices is not None or slice_len is not None: + if not replacement: + sampler = SliceSamplerWithoutReplacement( + num_slices=num_slices, + slice_len=slice_len, + trajectories=storage.episodes, + ) + else: + sampler = SliceSampler( + num_slices=num_slices, + slice_len=slice_len, + trajectories=storage.episodes, + cache_values=True, + ) + elif not replacement: + sampler = SamplerWithoutReplacement() + + super().__init__( + storage=storage, + batch_size=batch_size, + writer=writer, + sampler=sampler, + collate_fn=lambda x: x, + transform=transform, + **kwargs, + ) + + @property + def episodes(self): + return self._storage.episodes + + @property + def root(self) -> Path: + return self._root + + @root.setter + def root(self, value): + self._root = Path(value) + + @property + def dataset_path(self) -> Path: + return self._root / self.dataset_id + + @property + def _is_downloaded(self): + if os.path.exists(self.dataset_path / "processed.json"): + with open(self.dataset_path / "processed.json", "r") as jsonfile: + return json.load(jsonfile).get("processed", False) == self._max_runs + return False + + def _download_and_preproc(self): + logging.info( + f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..." + ) + if os.path.exists(self.dataset_path): + shutil.rmtree(self.dataset_path) + with tempfile.TemporaryDirectory() as tempdir: + if self.tmpdir is not None: + tempdir = self.tmpdir + if not os.listdir(tempdir): + os.makedirs(tempdir, exist_ok=True) + # get the list of runs + command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs" + output = subprocess.run( + command, shell=True, capture_output=True + ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + files = [ + file.decode("utf-8").replace("$", "\$") # noqa: W605 + for file in output.stdout.splitlines() + if file.endswith(b".gz") + ] + self.remote_gz_files = self._list_runs(None, files) + total_runs = list(self.remote_gz_files)[-1] + if self.num_procs == 0: + for run, run_files in self.remote_gz_files.items(): + self._download_and_proc_split( + run, + run_files, + tempdir=tempdir, + dataset_path=self.dataset_path, + total_episodes=total_runs, + max_runs=self._max_runs, + ) + else: + func = functools.partial( + self._download_and_proc_split, + tempdir=tempdir, + dataset_path=self.dataset_path, + total_episodes=total_runs, + max_runs=self._max_runs, + ) + args = [ + (run, run_files) + for (run, run_files) in self.remote_gz_files.items() + ] + with mp.Pool(self.num_procs) as pool: + pool.starmap(func, args) + with open(self.dataset_path / "processed.json", "w") as file: + # we save self._max_runs such that changing the number of runs to process + # forces the data to be re-downloaded + json.dump({"processed": self._max_runs}, file) + + @classmethod + def _download_and_proc_split( + cls, run, run_files, *, tempdir, dataset_path, total_episodes, max_runs + ): + if (max_runs is not None) and (run >= max_runs): + return + tempdir = Path(tempdir) + os.makedirs(tempdir / str(run)) + files_str = " ".join(run_files) # .decode("utf-8") + logging.info("downloading", files_str) + command = f"gsutil -m cp {files_str} {tempdir}/{run}" + subprocess.run( + command, shell=True + ) # , stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + local_gz_files = cls._list_runs(tempdir / str(run)) + # we iterate over the dict but this one has length 1 + for run in local_gz_files: + path = dataset_path / str(run) + try: + cls._preproc_run(path, local_gz_files, run) + except Exception: + shutil.rmtree(path) + raise + shutil.rmtree(tempdir / str(run)) + logging.info(f"Concluded run {run} out of {total_episodes}") + + @classmethod + def _preproc_run(cls, path, gz_files, run): + files = gz_files[run] + td = TensorDict({}, []) + path = Path(path) + for file in files: + name = str(Path(file).parts[-1]).split(".")[0] + with gzip.GzipFile(file, mode="rb") as f: + file_content = f.read() + file_content = io.BytesIO(file_content) + file_content = np.load(file_content) + t = torch.as_tensor(file_content) + # Create the memmap file + key = cls._process_name(name) + if key == ("data", "observation"): + shape = t.shape + shape = [shape[0] + 1] + list(shape[1:]) + filename = path / "data" / "observation.memmap" + os.makedirs(filename.parent, exist_ok=True) + mmap = MemoryMappedTensor.empty(shape, dtype=t.dtype, filename=filename) + mmap[:-1].copy_(t) + td[key] = mmap + # td["data", "next", key[1:]] = mmap[1:] + else: + if key in ( + ("data", "reward"), + ("data", "done"), + ("data", "terminated"), + ): + filename = path / "data" / "next" / (key[-1] + ".memmap") + os.makedirs(filename.parent, exist_ok=True) + mmap = MemoryMappedTensor.from_tensor(t, filename=filename) + td["data", "next", key[1:]] = mmap + else: + filename = path + for i, _key in enumerate(key): + if i == len(key) - 1: + _key = _key + ".memmap" + filename = filename / _key + os.makedirs(filename.parent, exist_ok=True) + mmap = MemoryMappedTensor.from_tensor(t, filename=filename) + td[key] = mmap + td.set_non_tensor("dataset_id", "/".join(path.parts[-3:-1])) + td.memmap_(path, copy_existing=False) + + @staticmethod + def _process_name(name): + if name.endswith("_ckpt"): + name = name[:-5] + if "store" in name: + key = ("data", name.split("_")[1]) + else: + key = (name,) + if key[-1] == "terminal": + key = (*key[:-1], "terminated") + return key + + @classmethod + def _list_runs(cls, download_path, gz_files=None): + path = download_path + if gz_files is None: + gz_files = [] + for root, _, files in os.walk(path): + for file in files: + if file.endswith(".gz"): + gz_files.append(os.path.join(root, file)) + runs = defaultdict(list) + for file in gz_files: + filename = Path(file).parts[-1] + name, episode, extension = str(filename).split(".") + episode = int(episode) + runs[episode].append(file) + return dict(sorted(runs.items(), key=lambda x: x[0])) + + +class _AtariStorage(Storage): + def __init__(self, path): + self.path = Path(path) + + def get_folders(path): + return [ + name + for name in os.listdir(path) + if os.path.isdir(os.path.join(path, name)) + ] + + # Usage + self.splits = [] + folders = get_folders(path) + for folder in folders: + self.splits.append(int(Path(folder).parts[-1])) + self.splits = sorted(self.splits) + self._split_tds = [] + frames_per_split = {} + for split in self.splits: + path = self.path / str(split) + self._split_tds.append(self._load_split(path)) + # take away 1 because we padded with 1 empty val + frames_per_split[split] = ( + self._split_tds[-1].get(("data", "observation")).shape[0] - 1 + ) + + frames_per_split = torch.tensor( + [[split, length] for (split, length) in frames_per_split.items()] + ) + frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0) + self.frames_per_split = torch.cat( + [torch.tensor([[-1, 0]]), frames_per_split], 0 + ) + + # retrieve episodes + self.episodes = torch.cumsum( + torch.cat( + [td.get(("data", "next", "terminated")) for td in self._split_tds], 0 + ), + 0, + ) + + def __len__(self): + return self.frames_per_split[-1, 1].item() + + def _read_from_splits(self, item: int | torch.Tensor): + # We need to allocate each item to its storage. + # We don't assume each storage has the same size (too expensive to test) + # so we keep a map of each storage cumulative length and retrieve the + # storages one after the other. + split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & ( + item >= self.frames_per_split[:-1, 1].unsqueeze(1) + ) + split_tmp, idx = split.squeeze().nonzero().unbind(-1) + split = torch.zeros_like(split_tmp) + split[idx] = split_tmp + split = self.frames_per_split[split + 1, 0] + item = item - self.frames_per_split[split, 1] + assert (item >= 0).all() + if isinstance(item, int): + unique_splits = (split,) + split_inverse = None + else: + unique_splits, split_inverse = torch.unique(split, return_inverse=True) + unique_splits = unique_splits.tolist() + out = [] + for i, split in enumerate(unique_splits): + _item = item[split_inverse == i] if split_inverse is not None else item + out.append(self._proc_td(self._split_tds[split], _item)) + return torch.cat(out, 0) + + def _load_split(self, path): + return TensorDict.load_memmap(path) + + def _proc_td(self, td, index): + td_data = td.get("data") + obs_ = td_data.get(("observation"))[index + 1] + done = td_data.get(("next", "terminated"))[index].squeeze(-1).bool() + if done.ndim and done.any(): + obs_ = torch.index_fill(obs_, 0, done.nonzero().squeeze(), 0) + td_idx = td.empty() + td_idx.set(("next", "observation"), obs_) + non_tensor = td.exclude("data").to_dict() + td_idx.update(td_data.apply(lambda x: x[index])) + if isinstance(index, torch.Tensor): + td_idx.batch_size = [len(index)] + td_idx.set_non_tensor("metadata", non_tensor) + + terminated = td_idx.get(("next", "terminated")) + zterminated = torch.zeros_like(terminated) + td_idx.set(("next", "done"), terminated.clone()) + td_idx.set(("next", "truncated"), zterminated) + td_idx.set("terminated", zterminated) + td_idx.set("done", zterminated) + td_idx.set("truncated", zterminated) + + return td_idx + + def get(self, index): + if isinstance(index, int): + return self._read_from_splits(index) + if isinstance(index, tuple): + if len(index) == 1: + return self.get(index[0]) + return self.get(index[0])[(Ellipsis, *index[1:])] + if isinstance(index, torch.Tensor): + if index.ndim <= 1: + return self._read_from_splits(index) + else: + raise RuntimeError("Only 1d tensors are accepted") + # with ThreadPoolExecutor(16) as pool: + # results = map(self.__getitem__, index.tolist()) + # return torch.stack(list(results)) + if isinstance(index, (range, list)): + return self[torch.tensor(index)] + if isinstance(index, slice): + start = index.start if index.start is not None else 0 + stop = index.stop if index.stop is not None else len(self) + step = index.step if index.step is not None else 1 + return self.get(torch.arange(start, stop, step)) + return self[torch.arange(len(self))[index]] diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index c2646f8366b..2d91da82367 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -52,7 +52,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -61,7 +61,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 8e20ebc12da..866888ae925 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -75,7 +75,7 @@ class MinariExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -84,15 +84,13 @@ class MinariExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which is recovered via ``done = truncated | terminated``. In other words, it is assumed that any ``truncated`` or ``terminated`` signal is - equivalent to the end of a trajectory. For some datasets from - ``D4RL``, this may not be true. It is up to the user to make - accurate choices regarding this usage of ``split_trajs``. + equivalent to the end of a trajectory. Defaults to ``False``. Attributes: diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index 07e5dfc8cdc..fadcc0e7f96 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -42,7 +42,7 @@ class OpenMLExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -51,7 +51,7 @@ class OpenMLExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. """ diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 5237386c200..4beb18b00a1 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -114,10 +114,10 @@ class for more information on how to interact with non-tensor data 0s. If another value is provided, it will be used for padding. If ``False`` or ``None`` (default) any encounter with a trajectory of insufficient length will raise an exception. - root (Path or str, optional): The Minari dataset root directory. + root (Path or str, optional): The OpenX dataset root directory. The actual dataset memory-mapped files will be saved under `/`. If none is provided, it defaults to - ``~/.cache/torchrl/minari`. + ``~/.cache/torchrl/openx`. streaming (bool, optional): if ``True``, the data won't be downloaded but read from a stream instead. @@ -139,7 +139,7 @@ class for more information on how to interact with non-tensor data sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -148,15 +148,13 @@ class for more information on how to interact with non-tensor data prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which is recovered via ``done = truncated | terminated``. In other words, it is assumed that any ``truncated`` or ``terminated`` signal is - equivalent to the end of a trajectory. For some datasets from - ``D4RL``, this may not be true. It is up to the user to make - accurate choices regarding this usage of ``split_trajs``. + equivalent to the end of a trajectory. Defaults to ``False``. strict_length (bool, optional): if ``False``, trajectories of length shorter than `slice_len` (or `batch_size // num_slices`) will be diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index bcbb12a4891..825b937e8ac 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -59,7 +59,7 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -68,15 +68,13 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which is recovered via ``done = truncated | terminated``. In other words, it is assumed that any ``truncated`` or ``terminated`` signal is - equivalent to the end of a trajectory. For some datasets from - ``D4RL``, this may not be true. It is up to the user to make - accurate choices regarding this usage of ``split_trajs``. + equivalent to the end of a trajectory. Defaults to ``False``. Attributes: diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index a6e79f9b266..417c025ae59 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -67,7 +67,7 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer): sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. + a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -76,7 +76,7 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer): prefetch (int, optional): number of next batches to be prefetched using multithreading. transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class. split_trajs (bool, optional): if ``True``, the trajectories will be split along the first dimension and padded to have a matching shape. To split the trajectories, the ``"done"`` signal will be used, which diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index de9b13b8129..1e1ce31bf96 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -157,13 +157,7 @@ def __init__( self._writer = writer if writer is not None else RoundRobinWriter() self._writer.register_storage(self._storage) - self._collate_fn = ( - collate_fn - if collate_fn is not None - else _get_default_collate( - self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer) - ) - ) + self._get_collate_fn(collate_fn) self._pin_memory = pin_memory self._prefetch = bool(prefetch) @@ -201,6 +195,43 @@ def __init__( ) self._batch_size = batch_size + def _get_collate_fn(self, collate_fn): + self._collate_fn = ( + collate_fn + if collate_fn is not None + else _get_default_collate( + self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer) + ) + ) + + def set_storage(self, storage: Storage, collate_fn: Callable | None = None): + """Sets a new storage in the replay buffer and returns the previous storage. + + Args: + storage (Storage): the new storage for the buffer. + collate_fn (callable, optional): if provided, the collate_fn is set to this + value. Otherwise it is reset to a default value. + + """ + prev_storage = self._storage + self._storage = storage + self._get_collate_fn(collate_fn) + + return prev_storage + + def set_writer(self, writer: Writer): + """Sets a new writer in the replay buffer and returns the previous writer.""" + prev_writer = self._writer + self._writer = writer + self._writer.register_storage(self._storage) + return prev_writer + + def set_sampler(self, sampler: Sampler): + """Sets a new sampler in the replay buffer and returns the previous sampler.""" + prev_sampler = self._sampler + self._sampler = sampler + return prev_sampler + def __len__(self) -> int: with self._replay_lock: return len(self._storage) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 05baa2eaee1..3460f6ed51c 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -524,6 +524,14 @@ class SliceSampler(Sampler): trajectory (or episode). Defaults to ``("next", "done")``. traj_key (NestedKey, optional): the key indicating the trajectories. Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. cache_values (bool, optional): to be used with static datasets. Will cache the start and end signal of the trajectory. truncated_key (NestedKey, optional): If not ``None``, this argument @@ -612,19 +620,12 @@ def __init__( slice_len: int = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, + ends: torch.Tensor | None = None, + trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, ) -> object: - if end_key is None: - end_key = ("next", "done") - if traj_key is None: - traj_key = "episode" - if not ((num_slices is None) ^ (slice_len is None)): - raise TypeError( - "Either num_slices or slice_len must be not None, and not both. " - f"Got num_slices={num_slices} and slice_len={slice_len}." - ) self.num_slices = num_slices self.slice_len = slice_len self.end_key = end_key @@ -635,6 +636,47 @@ def __init__( self._uses_data_prefix = False self.strict_length = strict_length self._cache = {} + if trajectories is not None: + if traj_key is not None or end_key: + raise RuntimeError( + "`trajectories` and `end_key` or `traj_key` are exclusive arguments." + ) + if ends is not None: + raise RuntimeError("trajectories and ends are exclusive arguments.") + if not cache_values: + raise RuntimeError( + "To be used, trajectories requires `cache_values` to be set to `True`." + ) + vals = self._find_start_stop_traj(trajectory=trajectories) + self._cache["stop-and-length"] = vals + + elif ends is not None: + if traj_key is not None or end_key: + raise RuntimeError( + "`ends` and `end_key` or `traj_key` are exclusive arguments." + ) + if trajectories is not None: + raise RuntimeError("trajectories and ends are exclusive arguments.") + if not cache_values: + raise RuntimeError( + "To be used, ends requires `cache_values` to be set to `True`." + ) + vals = self._find_start_stop_traj(end=ends) + self._cache["stop-and-length"] = vals + + else: + if end_key is None: + end_key = ("next", "done") + if traj_key is None: + traj_key = "run" + self.end_key = end_key + self.traj_key = traj_key + + if not ((num_slices is None) ^ (slice_len is None)): + raise TypeError( + "Either num_slices or slice_len must be not None, and not both. " + f"Got num_slices={num_slices} and slice_len={slice_len}." + ) @staticmethod def _find_start_stop_traj(*, trajectory=None, end=None): @@ -696,16 +738,24 @@ def _get_stop_and_length(self, storage, fallback=True): # In the future, this may be deprecated, and we don't want to mess # with the keys provided by the user so we fall back on a proxy to # the traj key. - try: - trajectory = storage._storage.get(self._used_traj_key) - except KeyError: - trajectory = storage._storage.get(("_data", self.traj_key)) - # cache that value for future use - self._used_traj_key = ("_data", self.traj_key) - self._uses_data_prefix = ( - isinstance(self._used_traj_key, tuple) - and self._used_traj_key[0] == "_data" - ) + if isinstance(storage, TensorStorage): + try: + trajectory = storage._storage.get(self._used_traj_key) + except KeyError: + trajectory = storage._storage.get(("_data", self.traj_key)) + # cache that value for future use + self._used_traj_key = ("_data", self.traj_key) + self._uses_data_prefix = ( + isinstance(self._used_traj_key, tuple) + and self._used_traj_key[0] == "_data" + ) + else: + try: + trajectory = storage[:].get(self.traj_key) + except Exception: + raise RuntimeError( + "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." + ) vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) if self.cache_values: self._cache["stop-and-length"] = vals @@ -722,16 +772,24 @@ def _get_stop_and_length(self, storage, fallback=True): # In the future, this may be deprecated, and we don't want to mess # with the keys provided by the user so we fall back on a proxy to # the traj key. - try: - done = storage._storage.get(self._used_end_key) - except KeyError: - done = storage._storage.get(("_data", self.end_key)) - # cache that value for future use - self._used_end_key = ("_data", self.end_key) - self._uses_data_prefix = ( - isinstance(self._used_end_key, tuple) - and self._used_end_key[0] == "_data" - ) + if isinstance(storage, TensorStorage): + try: + done = storage._storage.get(self._used_end_key) + except KeyError: + done = storage._storage.get(("_data", self.end_key)) + # cache that value for future use + self._used_end_key = ("_data", self.end_key) + self._uses_data_prefix = ( + isinstance(self._used_end_key, tuple) + and self._used_end_key[0] == "_data" + ) + else: + try: + done = storage[:].get(self.end_key) + except Exception: + raise RuntimeError( + "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." + ) vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] if self.cache_values: self._cache["stop-and-length"] = vals @@ -760,11 +818,6 @@ def _adjusted_batch_size(self, batch_size): return seq_length, num_slices def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: - if not isinstance(storage, TensorStorage): - raise RuntimeError( - f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead." - ) - # pick up as many trajs as we need start_idx, stop_idx, lengths = self._get_stop_and_length(storage) seq_length, num_slices = self._adjusted_batch_size(batch_size) @@ -889,6 +942,14 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): trajectory (or episode). Defaults to ``("next", "done")``. traj_key (NestedKey, optional): the key indicating the trajectories. Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. truncated_key (NestedKey, optional): If not ``None``, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided @@ -973,6 +1034,8 @@ def __init__( drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, + ends: torch.Tensor | None = None, + trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, shuffle: bool = True, @@ -986,6 +1049,8 @@ def __init__( cache_values=True, truncated_key=truncated_key, strict_length=strict_length, + ends=ends, + trajectories=trajectories, ) SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle) From 5b67dd3c770dab9760e1f08376077d4f853ea138 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Jan 2024 07:58:35 +0000 Subject: [PATCH 04/35] [Feature] Non-functional objectives (PPO, A2C, Reinforce) (#1804) --- benchmarks/test_objectives_benchmarks.py | 6 +- examples/a2c/a2c_atari.py | 4 +- examples/a2c/a2c_mujoco.py | 4 +- .../collectors/multi_nodes/ray_train.py | 2 +- examples/impala/impala_multi_node_ray.py | 4 +- examples/impala/impala_multi_node_submitit.py | 4 +- examples/impala/impala_single_node.py | 4 +- examples/multiagent/mappo_ippo.py | 6 +- examples/ppo/ppo_atari.py | 4 +- examples/ppo/ppo_mujoco.py | 4 +- test/test_cost.py | 49 +++-- torchrl/objectives/a2c.py | 137 +++++++++--- torchrl/objectives/ppo.py | 203 +++++++++++++----- torchrl/objectives/reinforce.py | 151 ++++++++++--- tutorials/sphinx-tutorials/coding_ppo.py | 4 +- tutorials/sphinx-tutorials/multiagent_ppo.py | 4 +- 16 files changed, 433 insertions(+), 157 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07e8f5da90..4cfc8470a15 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -548,7 +548,7 @@ def test_a2c_speed( actor(td.clone()) critic(td.clone()) - loss = A2CLoss(actor=actor, critic=critic) + loss = A2CLoss(actor_network=actor, critic_network=critic) advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) @@ -605,7 +605,7 @@ def test_ppo_speed( actor(td.clone()) critic(td.clone()) - loss = ClipPPOLoss(actor=actor, critic=critic) + loss = ClipPPOLoss(actor_network=actor, critic_network=critic) advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) @@ -662,7 +662,7 @@ def test_reinforce_speed( actor(td.clone()) critic(td.clone()) - loss = ReinforceLoss(actor=actor, critic=critic) + loss = ReinforceLoss(actor_network=actor, critic_network=critic) advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py index 8d19080f223..0452d7d600f 100644 --- a/examples/a2c/a2c_atari.py +++ b/examples/a2c/a2c_atari.py @@ -69,8 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_gae=True, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py index 4076631f1ef..2628a6f388c 100644 --- a/examples/a2c/a2c_mujoco.py +++ b/examples/a2c/a2c_mujoco.py @@ -63,8 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_gae=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 955d97113fe..2db86b9f917 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -145,7 +145,7 @@ ) loss_module = ClipPPOLoss( actor=policy_module, - critic=value_module, + critic_network=value_module, advantage_key="advantage", clip_epsilon=clip_epsilon, entropy_bonus=bool(entropy_eps), diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 46941529c00..49b3dd4bd4d 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -114,8 +114,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_adv=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 7eef42ec98f..2b89ef046a1 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -106,8 +106,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_adv=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 9a853e9bc76..f5b64e4718a 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -84,8 +84,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_adv=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index 95d340046fa..b00bb18a2a0 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -137,8 +137,8 @@ def train(cfg: "DictConfig"): # noqa: F821 # Loss loss_module = ClipPPOLoss( - actor=policy, - critic=value_module, + actor_network=policy, + critic_network=value_module, clip_epsilon=cfg.loss.clip_epsilon, entropy_coef=cfg.loss.entropy_eps, normalize_advantage=False, @@ -174,7 +174,7 @@ def train(cfg: "DictConfig"): # noqa: F821 with torch.no_grad(): loss_module.value_estimator( tensordict_data, - params=loss_module.critic_params, + params=loss_module.critic_network_params, target_params=loss_module.target_critic_params, ) current_frames = tensordict_data.numel() diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index 86685fa2642..1e69dd7678d 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_gae=False, ) loss_module = ClipPPOLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, clip_epsilon=cfg.loss.clip_epsilon, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py index eca985c2069..90fe74650f5 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/examples/ppo/ppo_mujoco.py @@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821 ) loss_module = ClipPPOLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, clip_epsilon=cfg.loss.clip_epsilon, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, diff --git a/test/test_cost.py b/test/test_cost.py index 8d704566c39..b8b5e265f8b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5820,7 +5820,10 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): + @pytest.mark.parametrize("functional", [True, False]) + def test_ppo( + self, loss_class, device, gradient_mode, advantage, td_est, functional + ): torch.manual_seed(self.seed) td = self._create_seq_mock_data_ppo(device=device) @@ -5850,7 +5853,7 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): else: raise NotImplementedError - loss_fn = loss_class(actor, value, loss_critic_type="l2") + loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional) if advantage is not None: advantage(td) else: @@ -6328,7 +6331,7 @@ def test_ppo_notensordict( ) value = self._create_mock_value(observation_key=observation_key) - loss = loss_class(actor=actor, critic=value) + loss = loss_class(actor_network=actor, critic_network=value) loss.set_keys( action=action_key, reward=reward_key, @@ -6537,7 +6540,8 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_a2c(self, device, gradient_mode, advantage, td_est): + @pytest.mark.parametrize("functional", (True, False)) + def test_a2c(self, device, gradient_mode, advantage, td_est, functional): torch.manual_seed(self.seed) td = self._create_seq_mock_data_a2c(device=device) @@ -6567,7 +6571,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): else: raise NotImplementedError - loss_fn = A2CLoss(actor, value, loss_critic_type="l2") + loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional) # Check error is raised when actions require grads td["action"].requires_grad = True @@ -6629,7 +6633,9 @@ def test_a2c_state_dict(self, device, gradient_mode): def test_a2c_separate_losses(self, separate_losses): torch.manual_seed(self.seed) actor, critic, common, td = self._create_mock_common_layer_setup() - loss_fn = A2CLoss(actor=actor, critic=critic, separate_losses=separate_losses) + loss_fn = A2CLoss( + actor_network=actor, critic_network=critic, separate_losses=separate_losses + ) # Check error is raised when actions require grads td["action"].requires_grad = True @@ -6966,7 +6972,6 @@ def test_a2c_notensordict( class TestReinforce(LossModuleTestBase): seed = 0 - @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) @pytest.mark.parametrize( @@ -6979,7 +6984,12 @@ class TestReinforce(LossModuleTestBase): None, ], ) - def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): + @pytest.mark.parametrize( + "delay_value,functional", [[False, True], [False, False], [True, True]] + ) + def test_reinforce_value_net( + self, advantage, gradient_mode, delay_value, td_est, functional + ): n_obs = 3 n_act = 5 batch = 4 @@ -7023,8 +7033,9 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est loss_fn = ReinforceLoss( actor_net, - critic=value_net, + critic_network=value_net, delay_value=delay_value, + functional=functional, ) td = TensorDict( @@ -7049,7 +7060,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est if advantage is not None: params = TensorDict.from_module(value_net) if delay_value: - target_params = loss_fn.target_critic_params + target_params = loss_fn.target_critic_network_params else: target_params = None advantage(td, params=params, target_params=target_params) @@ -7108,7 +7119,7 @@ def test_reinforce_tensordict_keys(self, td_est): loss_fn = ReinforceLoss( actor_net, - critic=value_net, + critic_network=value_net, ) default_keys = { @@ -7133,7 +7144,7 @@ def test_reinforce_tensordict_keys(self, td_est): loss_fn = ReinforceLoss( actor_net, - critic=value_net, + critic_network=value_net, ) key_mapping = { @@ -7207,14 +7218,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): torch.manual_seed(self.seed) actor, critic, common, td = self._create_mock_common_layer_setup() loss_fn = ReinforceLoss( - actor=actor, critic=critic, separate_losses=separate_losses + actor_network=actor, critic_network=critic, separate_losses=separate_losses ) loss = loss_fn(td) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.critic_params.values(True, True) + for p in loss_fn.critic_network_params.values(True, True) ) assert all( (p.grad is None) or (p.grad == 0).all() @@ -7234,14 +7245,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): for p in loss_fn.actor_network_params.values(True, True) ) common_layers = itertools.islice( - loss_fn.critic_params.values(True, True), + loss_fn.critic_network_params.values(True, True), common_layers_no, ) assert all( (p.grad is None) or (p.grad == 0).all() for p in common_layers ) critic_layers = itertools.islice( - loss_fn.critic_params.values(True, True), + loss_fn.critic_network_params.values(True, True), common_layers_no, None, ) @@ -7250,7 +7261,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): ) else: common_layers = itertools.islice( - loss_fn.critic_params.values(True, True), + loss_fn.critic_network_params.values(True, True), common_layers_no, ) assert not any( @@ -7266,7 +7277,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.critic_params.values(True, True) + for p in loss_fn.critic_network_params.values(True, True) ) else: @@ -7297,7 +7308,7 @@ def test_reinforce_notensordict( in_keys=["loc", "scale"], spec=UnboundedContinuousTensorSpec(n_act), ) - loss = ReinforceLoss(actor=actor_net, critic=value_net) + loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net) loss.set_keys( reward=reward_key, done=done_key, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 4384ccef282..397b9de4e23 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -2,18 +2,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import contextlib +import logging import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import ( - dispatch, - ProbabilisticTensorDictSequential, - repopulate_module, - TensorDictModule, -) +from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -46,8 +43,8 @@ class A2CLoss(LossModule): https://arxiv.org/abs/1602.01783v2 Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. entropy_bonus (bool): if ``True``, an entropy bonus will be added to the loss to favour exploratory policies. samples_mc_entropy (int): if the distribution retrieved from the policy @@ -68,6 +65,10 @@ class A2CLoss(LossModule): The input tensordict key where the advantage is expected to be written. default: "advantage" value_target_key (str): [Deprecated, use set_keys() instead] the input tensordict key where the target state value is expected to be written. Defaults to ``"value_target"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. note: The advantage (typically GAE) can be computed by the loss function or @@ -221,8 +222,8 @@ class _AcceptedKeys: def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential = None, + critic_network: TensorDictModule = None, *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, @@ -233,23 +234,49 @@ def __init__( separate_losses: bool = False, advantage_key: str = None, value_target_key: str = None, + functional: bool = True, + actor: ProbabilisticTensorDictSequential = None, + critic: ProbabilisticTensorDictSequential = None, ): + if actor is not None: + actor_network = actor + del actor + if critic is not None: + critic_network = critic + del critic + if actor_network is None or critic_network is None: + raise TypeError( + "Missing positional arguments actor_network or critic_network." + ) + + self._functional = functional self._out_keys = None super().__init__() self._set_deprecated_ctor_keys( advantage=advantage_key, value_target=value_target_key ) - self.convert_to_functional( - actor, "actor", funs_to_decorate=["forward", "get_dist"] - ) + if functional: + self.convert_to_functional( + actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"] + ) + else: + self.actor_network = actor_network + if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared - policy_params = list(actor.parameters()) + policy_params = list(actor_network.parameters()) else: policy_params = None - self.convert_to_functional(critic, "critic", compare_against=policy_params) + if functional: + self.convert_to_functional( + critic_network, "critic_network", compare_against=policy_params + ) + else: + self.critic_network = critic_network + self.target_critic_network_params = None + self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus and entropy_coef @@ -265,6 +292,50 @@ def __init__( self.gamma = gamma self.loss_critic_type = loss_critic_type + @property + def functional(self): + return self._functional + + @property + def actor(self): + logging.warning( + f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " + "link will be removed in v0.4." + ) + return self.actor_network + + @property + def critic(self): + logging.warning( + f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " + "link will be removed in v0.4." + ) + return self.critic_network + + @property + def actor_params(self): + logging.warning( + f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " + "link will be removed in v0.4." + ) + return self.actor_network_params + + @property + def critic_params(self): + logging.warning( + f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.critic_network_params + + @property + def target_critic_params(self): + logging.warning( + f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.target_critic_network_params + @property def in_keys(self): keys = [ @@ -272,8 +343,8 @@ def in_keys(self): ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), ("next", self.tensor_keys.terminated), - *self.actor.in_keys, - *[("next", key) for key in self.actor.in_keys], + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], ] if self.critic_coef: keys.extend(self.critic.in_keys) @@ -326,9 +397,11 @@ def _log_probs( raise RuntimeError( f"tensordict stored {self.tensor_keys.action} require grad." ) - tensordict_clone = tensordict.select(*self.actor.in_keys).clone() - with self.actor_params.to_module(self.actor): - dist = self.actor.get_dist(tensordict_clone) + tensordict_clone = tensordict.select(*self.actor_network.in_keys).clone() + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): + dist = self.actor_network.get_dist(tensordict_clone) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) return log_prob, dist @@ -339,7 +412,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # overhead that we could easily reduce. target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - with self.critic_params.to_module(self.critic): + with self.critic_network_params.to_module( + self.critic + ) if self.functional else contextlib.nullcontext(): state_value = self.critic( tensordict_select, ).get(self.tensor_keys.value) @@ -360,8 +435,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: @property @_cache_values - def _cached_detach_critic_params(self): - return self.critic_params.detach() + def _cached_detach_critic_network_params(self): + if not self.functional: + return None + return self.critic_network_params.detach() @dispatch() def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -370,8 +447,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self._cached_detach_critic_params, - target_params=self.target_critic_params, + params=self._cached_detach_critic_network_params, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) assert not advantage.requires_grad @@ -406,11 +483,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor - actor_with_params = repopulate_module( - deepcopy(self.actor), self.actor_params - ) + if self.functional: + actor_with_params = deepcopy(self.actor_network) + self.actor_network_params.to_module(actor_with_params) + else: + actor_with_params = self.actor_network self._value_estimator = VTrace( - value_network=self.critic, actor_network=actor_with_params, **hp + value_network=self.critic_network, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index beefa2467fa..0d4714e6a2d 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -2,6 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import contextlib +import logging + import math import warnings from copy import deepcopy @@ -9,12 +14,7 @@ from typing import Tuple import torch -from tensordict.nn import ( - dispatch, - ProbabilisticTensorDictSequential, - repopulate_module, - TensorDictModule, -) +from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -49,8 +49,8 @@ class PPOLoss(LossModule): https://arxiv.org/abs/1707.06347 Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. Keyword Args: entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the @@ -82,6 +82,10 @@ class PPOLoss(LossModule): value_key (str, optional): [Deprecated, use set_keys(value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to ``"state_value"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. note:: The advantage (typically GAE) can be computed by the loss function or @@ -259,8 +263,8 @@ class _AcceptedKeys: def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, @@ -273,18 +277,47 @@ def __init__( advantage_key: str = None, value_target_key: str = None, value_key: str = None, + functional: bool = True, + actor: ProbabilisticTensorDictSequential = None, + critic: ProbabilisticTensorDictSequential = None, ): + if actor is not None: + actor_network = actor + del actor + if critic is not None: + critic_network = critic + del critic + if actor_network is None or critic_network is None: + raise TypeError( + "Missing positional arguments actor_network or critic_network." + ) + + self._functional = functional self._in_keys = None self._out_keys = None super().__init__() - self.convert_to_functional(actor, "actor") + if functional: + self.convert_to_functional(actor_network, "actor_network") + else: + self.actor_network = actor_network + self.actor_network_params = None + self.target_actor_network_params = None + if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared - policy_params = list(actor.parameters()) + policy_params = list(actor_network.parameters()) else: policy_params = None - self.convert_to_functional(critic, "critic", compare_against=policy_params) + if functional: + self.convert_to_functional( + critic_network, "critic_network", compare_against=policy_params + ) + else: + self.critic_network = critic_network + self.critic_network_params = None + self.target_critic_network_params = None + self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus self.separate_losses = separate_losses @@ -307,6 +340,50 @@ def __init__( value=value_key, ) + @property + def functional(self): + return self._functional + + @property + def actor(self): + logging.warning( + f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " + "link will be removed in v0.4." + ) + return self.actor_network + + @property + def critic(self): + logging.warning( + f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " + "link will be removed in v0.4." + ) + return self.critic_network + + @property + def actor_params(self): + logging.warning( + f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " + "link will be removed in v0.4." + ) + return self.actor_network_params + + @property + def critic_params(self): + logging.warning( + f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.critic_network_params + + @property + def target_critic_params(self): + logging.warning( + f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.target_critic_network_params + def _set_in_keys(self): keys = [ self.tensor_keys.action, @@ -314,9 +391,9 @@ def _set_in_keys(self): ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), ("next", self.tensor_keys.terminated), - *self.actor.in_keys, - *[("next", key) for key in self.actor.in_keys], - *self.critic.in_keys, + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.critic_network.in_keys, ] self._in_keys = list(set(keys)) @@ -378,8 +455,10 @@ def _log_weight( f"tensordict stored {self.tensor_keys.action} requires grad." ) - with self.actor_params.to_module(self.actor): - dist = self.actor.get_dist(tensordict) + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): + dist = self.actor_network.get_dist(tensordict) log_prob = dist.log_prob(action) prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) @@ -405,8 +484,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: f"can be used for the value loss." ) - with self.critic_params.to_module(self.critic): - state_value_td = self.critic(tensordict) + with self.critic_network_params.to_module( + self.critic_network + ) if self.functional else contextlib.nullcontext(): + state_value_td = self.critic_network(tensordict) try: state_value = state_value_td.get(self.tensor_keys.value) @@ -425,8 +506,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: @property @_cache_values - def _cached_critic_params_detached(self): - return self.critic_params.detach() + def _cached_critic_network_params_detached(self): + if not self.functional: + return None + return self.critic_network_params.detach() @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -435,8 +518,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self._cached_critic_params_detached, - target_params=self.target_critic_params, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: @@ -465,20 +548,28 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp["gamma"] = self.gamma hp.update(hyperparams) if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor - actor_with_params = repopulate_module( - deepcopy(self.actor), self.actor_params - ) + if self.functional: + actor_with_params = deepcopy(self.actor_network) + self.actor_network_params.to_module(actor_with_params) + else: + actor_with_params = self.actor_network self._value_estimator = VTrace( - value_network=self.critic, actor_network=actor_with_params, **hp + value_network=self.critic_network, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -502,8 +593,8 @@ class ClipPPOLoss(PPOLoss): loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage) Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. Keyword Args: clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation. @@ -537,6 +628,10 @@ class ClipPPOLoss(PPOLoss): value_key (str, optional): [Deprecated, use set_keys(value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to ``"state_value"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. note: The advantage (typically GAE) can be computed by the loss function or @@ -583,8 +678,8 @@ class ClipPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, *, clip_epsilon: float = 0.2, entropy_bonus: bool = True, @@ -598,8 +693,8 @@ def __init__( **kwargs, ): super(ClipPPOLoss, self).__init__( - actor, - critic, + actor_network, + critic_network, entropy_bonus=entropy_bonus, samples_mc_entropy=samples_mc_entropy, entropy_coef=entropy_coef, @@ -642,8 +737,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self._cached_critic_params_detached, - target_params=self.target_critic_params, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: @@ -689,8 +784,8 @@ class KLPENPPOLoss(PPOLoss): favouring a certain level of distancing between the two while still preventing them to be too much apart. Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. Keyword Args: dtarg (scalar, optional): target KL divergence. Defaults to ``0.01``. @@ -731,6 +826,10 @@ class KLPENPPOLoss(PPOLoss): value_key (str, optional): [Deprecated, use set_keys(value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to ``"state_value"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. note: @@ -778,8 +877,8 @@ class KLPENPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, *, dtarg: float = 0.01, beta: float = 1.0, @@ -797,8 +896,8 @@ def __init__( **kwargs, ): super(KLPENPPOLoss, self).__init__( - actor, - critic, + actor_network, + critic_network, entropy_bonus=entropy_bonus, samples_mc_entropy=samples_mc_entropy, entropy_coef=entropy_coef, @@ -848,8 +947,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: if advantage is None: self.value_estimator( tensordict, - params=self._cached_critic_params_detached, - target_params=self.target_critic_params, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: @@ -859,9 +958,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: log_weight, dist = self._log_weight(tensordict) neg_loss = log_weight.exp() * advantage - previous_dist = self.actor.build_dist_from_params(tensordict) - with self.actor_params.to_module(self.actor): - current_dist = self.actor.get_dist(tensordict) + previous_dist = self.actor_network.build_dist_from_params(tensordict) + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): + current_dist = self.actor_network.get_dist(tensordict) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 832af829c64..98c4d4d14d3 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -2,19 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import contextlib +import logging import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Optional import torch -from tensordict.nn import ( - dispatch, - ProbabilisticTensorDictSequential, - repopulate_module, - TensorDictModule, -) +from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule @@ -41,10 +39,12 @@ class ReinforceLoss(LossModule): Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. + + Keyword Args: delay_value (bool, optional): if ``True``, a target network is needed - for the critic. Defaults to ``False``. + for the critic. Defaults to ``False``. Incompatible with ``functional=False``. loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. advantage_key (str): [Deprecated, use .set_keys(advantage_key=advantage_key) instead] @@ -57,6 +57,10 @@ class ReinforceLoss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. note: The advantage (typically GAE) can be computed by the loss function or @@ -208,8 +212,8 @@ def __new__(cls, *args, **kwargs): def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: Optional[TensorDictModule] = None, + actor_network: ProbabilisticTensorDictSequential, + critic_network: TensorDictModule | None = None, *, delay_value: bool = False, loss_critic_type: str = "smooth_l1", @@ -217,7 +221,27 @@ def __init__( advantage_key: str = None, value_target_key: str = None, separate_losses: bool = False, + functional: bool = True, + actor: ProbabilisticTensorDictSequential = None, + critic: ProbabilisticTensorDictSequential = None, ) -> None: + if actor is not None: + actor_network = actor + del actor + if critic is not None: + critic_network = critic + del critic + if actor_network is None or critic_network is None: + raise TypeError( + "Missing positional arguments actor_network or critic_network." + ) + if not functional and delay_value: + raise RuntimeError( + "delay_value and ~functional are incompatible, as delayed value currently relies on functional calls." + ) + + self._functional = functional + super().__init__() self.in_keys = None self._set_deprecated_ctor_keys( @@ -228,29 +252,82 @@ def __init__( self.loss_critic_type = loss_critic_type # Actor - self.convert_to_functional( - actor, - "actor_network", - create_target_params=False, - ) + if self.functional: + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + ) + else: + self.actor_network = actor_network + if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared - policy_params = list(actor.parameters()) + policy_params = list(actor_network.parameters()) else: policy_params = None # Value - if critic is not None: - self.convert_to_functional( - critic, - "critic", - create_target_params=self.delay_value, - compare_against=policy_params, - ) + if critic_network is not None: + if self.functional: + self.convert_to_functional( + critic_network, + "critic_network", + create_target_params=self.delay_value, + compare_against=policy_params, + ) + else: + self.critic_network = critic_network + self.target_critic_network_params = None + if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma + @property + def functional(self): + return self._functional + + @property + def actor(self): + logging.warning( + f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " + "link will be removed in v0.4." + ) + return self.actor_network + + @property + def critic(self): + logging.warning( + f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " + "link will be removed in v0.4." + ) + return self.critic_network + + @property + def actor_params(self): + logging.warning( + f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " + "link will be removed in v0.4." + ) + return self.actor_network_params + + @property + def critic_params(self): + logging.warning( + f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.critic_network_params + + @property + def target_critic_params(self): + logging.warning( + f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.target_critic_network_params + def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys( @@ -291,13 +368,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self.critic_params.detach(), - target_params=self.target_critic_params, + params=self.critic_network_params.detach() if self.functional else None, + target_params=self.target_critic_network_params + if self.functional + else None, ) advantage = tensordict.get(self.tensor_keys.advantage) # compute log-prob - with self.actor_network_params.to_module(self.actor_network): + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): tensordict = self.actor_network(tensordict) log_prob = tensordict.get(self.tensor_keys.sample_log_prob) @@ -315,7 +396,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: try: target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - with self.critic_params.to_module(self.critic): + with self.critic_network_params.to_module( + self.critic + ) if self.functional else contextlib.nullcontext(): state_value = self.critic(tensordict_select).get(self.tensor_keys.value) loss_value = distance_loss( target_return, @@ -350,11 +433,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor - actor_with_params = repopulate_module( - deepcopy(self.actor), self.actor_params - ) + if self.functional: + actor_with_params = deepcopy(self.actor_network) + self.actor_network_params.to_module(actor_with_params) + else: + actor_with_params = self.actor_network self._value_estimator = VTrace( - value_network=self.critic, actor_network=actor_with_params, **hp + value_network=self.critic_network, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 51228e66da1..56f96221a40 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -555,8 +555,8 @@ ) loss_module = ClipPPOLoss( - actor=policy_module, - critic=value_module, + actor_network=policy_module, + critic_network=value_module, clip_epsilon=clip_epsilon, entropy_bonus=bool(entropy_eps), entropy_coef=entropy_eps, diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index d8726e804f4..f32d2d93b2f 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -595,8 +595,8 @@ # loss_module = ClipPPOLoss( - actor=policy, - critic=critic, + actor_network=policy, + critic_network=critic, clip_epsilon=clip_epsilon, entropy_coef=entropy_eps, normalize_advantage=False, # Important to avoid normalizing across the agent dimension From 069975374bae166666af6b61d169ee23da88c2c1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Jan 2024 11:25:06 +0000 Subject: [PATCH 05/35] [Refactor] change default CKPT_BACKEND to torch (#1830) --- torchrl/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9f04d3f87c1..7de6453d10a 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -178,7 +178,7 @@ class _Dynamic_CKPT_BACKEND: backends = ["torch", "torchsnapshot"] def _get_backend(self): - backend = os.environ.get("CKPT_BACKEND", "torchsnapshot") + backend = os.environ.get("CKPT_BACKEND", "torch") if backend == "torchsnapshot": try: import torchsnapshot # noqa: F401 From c390cf602fc79cb37d5f7bda6e44b5e9546ecda0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=A9tan=20Lepage?= <33058747+GaetanLepage@users.noreply.github.com> Date: Tue, 23 Jan 2024 12:26:56 +0100 Subject: [PATCH 06/35] pyproject.toml: remove unknown properties (#1828) --- pyproject.toml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05461e4a8c7..5b1d99711fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,12 +3,3 @@ first_party_detection = false [build-system] requires = ["setuptools", "wheel", "torch", "ninja"] - -first_party_detection = false - -target-version = ["py38"] - -excludes = [ - "gallery", - "tutorials", -] From 24d14ad4ec405f1c827284e360a02c02006663eb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Jan 2024 13:20:03 +0000 Subject: [PATCH 07/35] [Doc, Feature] Doc improvements for video recording and CSV video formats (#1829) --- .../scripts_gym_0_13/environment.yml | 1 + test/test_loggers.py | 42 ++++++++++-- torchrl/objectives/a2c.py | 4 +- torchrl/objectives/ppo.py | 3 +- torchrl/objectives/redq.py | 3 +- torchrl/record/loggers/csv.py | 64 +++++++++++++++++-- torchrl/record/loggers/tensorboard.py | 2 +- torchrl/record/loggers/wandb.py | 16 +++++ torchrl/record/recorder.py | 26 +++++++- 9 files changed, 139 insertions(+), 22 deletions(-) diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml index be549ec2a5f..9efcbbfa640 100644 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml @@ -25,3 +25,4 @@ dependencies: - patchelf - pyopengl==3.1.4 - ray<2.8.0 + - av diff --git a/test/test_loggers.py b/test/test_loggers.py index a19c8251b28..98a330d0daf 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -12,6 +12,8 @@ import pytest import torch + +from tensordict import MemoryMappedTensor from torchrl.record.loggers.csv import CSVLogger from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger @@ -150,16 +152,22 @@ def test_log_scalar(self, steps, tmpdir): assert row == f"{step},{values[i].item()}\n" @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) - def test_log_video(self, steps, tmpdir): + @pytest.mark.parametrize( + "video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else [] + ) + def test_log_video(self, steps, video_format, tmpdir): torch.manual_seed(0) exp_name = "ramala" - logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name) + logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name, video_format=video_format) # creating a sample video (T, C, H, W), where T - number of frames, # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. # the first 64 frames are black and the next 64 are white video = torch.cat( - (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + ( + torch.zeros(64, 1, 32, 32, dtype=torch.uint8), + torch.full((64, 1, 32, 32), 255, dtype=torch.uint8), + ) ) video = video[None, :] for i in range(3): @@ -171,11 +179,31 @@ def test_log_video(self, steps, tmpdir): sleep(0.01) # wait until events are registered # check that the logged videos are the same as the initial video - video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + ".pt" - logged_video = torch.load( - os.path.join(tmpdir, exp_name, "videos", video_file_name) + extention = ( + ".pt" + if video_format == "pt" + else ".memmap" + if video_format == "memmap" + else ".mp4" ) - assert torch.equal(video, logged_video), logged_video + video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + extention + path = os.path.join(tmpdir, exp_name, "videos", video_file_name) + if video_format == "pt": + logged_video = torch.load(path) + assert torch.equal(video, logged_video), logged_video + elif video_format == "memmap": + logged_video = MemoryMappedTensor.from_filename( + path, dtype=torch.uint8, shape=(1, 128, 1, 32, 32) + ) + assert torch.equal(video, logged_video), logged_video + elif video_format == "mp4": + import torchvision + + logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][ + :, :1 + ] + logged_video = logged_video.unsqueeze(0) + torch.testing.assert_close(video, logged_video) # check that we catch the error in case the format of the tensor is wrong video_wrong_format = torch.zeros(64, 2, 32, 32) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 397b9de4e23..2cdb7af2553 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -130,9 +130,7 @@ class A2CLoss(LossModule): the expected keyword arguments are: ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic. The return value is a tuple of tensors in the following order: - ``["loss_objective"]`` - + ``["loss_critic"]`` if critic_coef is not None - + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None + ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None Examples: >>> import torch diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 0d4714e6a2d..5533dfd74d9 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -178,8 +178,7 @@ class PPOLoss(LossModule): the expected keyword arguments are: ``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network. The return value is a tuple of tensors in the following order: - ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set - + ``"loss_critic"`` if critic_coef is not None. + ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not ``None``. The output keys can also be filtered using :meth:`PPOLoss.select_out_keys` method. Examples: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index d76f76ddc41..cac829964fc 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -138,8 +138,7 @@ class REDQLoss(LossModule): the expected keyword arguments are: ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network The return value is a tuple of tensors in the following order: - ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy", - "state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``. + ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy", "state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``. Examples: >>> import torch diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 6db921f3201..d9b5f45c25f 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -8,6 +8,8 @@ from typing import Dict, Optional, Sequence, Union import torch + +from tensordict import MemoryMappedTensor from torch import Tensor from .common import Logger @@ -16,11 +18,13 @@ class CSVExperiment: """A CSV logger experiment class.""" - def __init__(self, log_dir: str): + def __init__(self, log_dir: str, *, video_format="pt", video_fps=30): self.scalars = defaultdict(lambda: []) self.videos_counter = defaultdict(lambda: 0) self.text_counter = defaultdict(lambda: 0) self.log_dir = log_dir + self.video_format = video_format + self.video_fps = video_fps os.makedirs(self.log_dir, exist_ok=True) os.makedirs(os.path.join(self.log_dir, "scalars"), exist_ok=True) os.makedirs(os.path.join(self.log_dir, "videos"), exist_ok=True) @@ -44,12 +48,43 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs if global_step is None: global_step = self.videos_counter[tag] self.videos_counter[tag] += 1 + if self.video_format == "pt": + extension = ".pt" + elif self.video_format == "memmap": + extension = ".memmap" + elif self.video_format == "mp4": + extension = ".mp4" + else: + raise ValueError( + f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'." + ) + filepath = os.path.join( - self.log_dir, "videos", "_".join([tag, str(global_step)]) + ".pt" + self.log_dir, "videos", "_".join([tag, str(global_step)]) + extension ) path_to_create = Path(str(filepath)).parent os.makedirs(path_to_create, exist_ok=True) - torch.save(vid_tensor, filepath) + if self.video_format == "pt": + torch.save(vid_tensor, filepath) + elif self.video_format == "memmap": + MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath) + elif self.video_format == "mp4": + import torchvision + + if vid_tensor.shape[-3] not in (3, 1): + raise RuntimeError( + "expected the video tensor to be of format [T, C, H, W] but the third channel " + f"starting from the end isn't in (1, 3) but is {vid_tensor.shape[-3]}." + ) + if vid_tensor.ndim > 4: + vid_tensor = vid_tensor.flatten(0, vid_tensor.ndim - 4) + vid_tensor = vid_tensor.permute((0, 2, 3, 1)) + vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3) + torchvision.io.write_video(filepath, vid_tensor, fps=self.video_fps) + else: + raise ValueError( + f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'." + ) def add_text(self, tag, text, global_step: Optional[int] = None): if global_step is None: @@ -77,20 +112,37 @@ class CSVLogger(Logger): Args: exp_name (str): The name of the experiment. + log_dir (str or Path, optional): where the experiment should be saved. + Defaults to ``/csv_logs``. + video_format (str, optional): how videos should be saved. Must be one of + ``"pt"`` (video saved as a `video__.pt` file with torch.save), + ``"memmap"`` (video saved as a `video__.memmap` file with :class:`~tensordict.MemoryMappedTensor`), + ``"mp4"`` (video saved as a `video__.mp4` file, requires torchvision to be installed). + Defaults to ``"pt"``. + video_fps (int, optional): the video frames-per-seconds if `video_format="mp4"`. Defaults to 30. """ - def __init__(self, exp_name: str, log_dir: Optional[str] = None) -> None: + def __init__( + self, + exp_name: str, + log_dir: Optional[str] = None, + video_format: str = "pt", + video_fps: int = 30, + ) -> None: if log_dir is None: log_dir = "csv_logs" + self.video_format = video_format + self.video_fps = video_fps super().__init__(exp_name=exp_name, log_dir=log_dir) - self._has_imported_moviepy = False def _create_experiment(self) -> "CSVExperiment": """Creates a CSV experiment.""" log_dir = str(os.path.join(self.log_dir, self.exp_name)) - return CSVExperiment(log_dir) + return CSVExperiment( + log_dir, video_format=self.video_format, video_fps=self.video_fps + ) def log_scalar(self, name: str, value: float, step: int = None) -> None: """Logs a scalar value to the tensorboard. diff --git a/torchrl/record/loggers/tensorboard.py b/torchrl/record/loggers/tensorboard.py index 12e52a91a64..7327e4aff33 100644 --- a/torchrl/record/loggers/tensorboard.py +++ b/torchrl/record/loggers/tensorboard.py @@ -20,7 +20,7 @@ class TensorboardLogger(Logger): Args: exp_name (str): The name of the experiment. - log_dir (str): the tensorboard log_dir. + log_dir (str): the tensorboard log_dir. Defaults to ``td_logs``. """ diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index 9a818753956..441360a48bb 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -19,8 +19,24 @@ class WandbLogger(Logger): """Wrapper for the wandb logger. + The keyword arguments are mainly based on the :func:`wandb.init` kwargs. + See the doc `here `__. + Args: exp_name (str): The name of the experiment. + offline (bool, optional): if ``True``, the logs will be stored locally + only. Defaults to ``False``. + save_dir (path, optional): the directory where to save data. Exclusive with + ``log_dir``. + log_dir (path, optional): the directory where to save data. Exclusive with + ``save_dir``. + id (str, optional): A unique ID for this run, used for resuming. + It must be unique in the project, and if you delete a run you can't reuse the ID. + project (str, optional): The name of the project where you're sending + the new run. If the project is not specified, the run is put in + an ``"Uncategorized"`` project. + **kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for + more info. """ diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 7883cef26ee..1910c920a41 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -30,7 +30,8 @@ class VideoRecorder(ObservationTransform): Args: logger (Logger): a Logger instance where the video - should be written. + should be written. To save the video under a memmap tensor or an mp4 file, use + the :class:`~torchrl.record.loggers.CSVLogger` class. tag (str): the video tag in the logger. in_keys (Sequence of NestedKey, optional): keys to be read to produce the video. Default is :obj:`"pixels"`. @@ -43,6 +44,29 @@ class VideoRecorder(ObservationTransform): out_keys (sequence of NestedKey, optional): destination keys. Defaults to ``in_keys`` if not provided. + Examples: + The following example shows how to save a rollout under a video. First a few imports: + >>> from torchrl.record import VideoRecorder + >>> from torchrl.record.loggers.csv import CSVLogger + >>> from torchrl.envs import TransformedEnv, DMControlEnv + + The video format is chosen in the logger. Wandb and tensorboard will take care of that + on their own, CSV accepts various video formats. + >>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4") + + Some envs (eg, Atari games) natively return images, some require the user to ask for them. + Check :class:`~torchrl.env.GymEnv` or :class:`~torchrl.envs.DMControlEnv` to see how to render images + in these contexts. + >>> base_env = DMControlEnv("cheetah", "run", from_pixels=True) + >>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video")) + >>> env.rollout(100) + + All transforms have a dump function, mostly a no-op except for ``VideoRecorder``, and :class:`~torchrl.envs.transforms.Composite` + which will dispatch the `dumps` to all its members. + >>> env.transform.dump() + + Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``! + """ def __init__( From da7904eb1992cb2ce6937caa82594d3ec7d2a006 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jan 2024 18:15:18 +0000 Subject: [PATCH 08/35] [Feature] PyTrees in replay buffers (#1831) --- docs/source/reference/data.rst | 114 +++- test/test_cost.py | 16 +- test/test_rb.py | 575 ++++++++++++------ test/test_transforms.py | 4 +- torchrl/_utils.py | 7 +- torchrl/data/replay_buffers/replay_buffers.py | 109 +++- torchrl/data/replay_buffers/storages.py | 395 +++++++++--- torchrl/data/replay_buffers/transforms.py | 22 + torchrl/data/replay_buffers/writers.py | 11 +- torchrl/envs/transforms/transforms.py | 12 +- tutorials/sphinx-tutorials/rb_tutorial.py | 44 ++ 11 files changed, 985 insertions(+), 324 deletions(-) create mode 100644 torchrl/data/replay_buffers/transforms.py diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 2def1b4bfa8..91391c6af36 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -22,7 +22,108 @@ widely used replay buffers: Composable Replay Buffers ------------------------- -We also give users the ability to compose a replay buffer using the following components: +We also give users the ability to compose a replay buffer. +We provide a wide panel of solutions for replay buffer usage, including support for +almost any data type; storage in memory, on device or on physical memory; +several sampling strategies; usage of transforms etc. + +Supported data types and choosing a storage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In theory, replay buffers support any data type but we can't guarantee that each +component will support any data type. The most crude replay buffer implementation +is made of a :class:`~torchrl.data.replay_buffers.ReplayBuffer` base with a +:class:`~torchrl.data.replay_buffers.ListStorage` storage. This is very inefficient +but it will allow you to store complex data structures with non-tensor data. +Storages in contiguous memory include :class:`~torchrl.data.replay_buffers.TensorStorage`, +:class:`~torchrl.data.replay_buffers.LazyTensorStorage` and +:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`. +These classes support :class:`~tensordict.TensorDict` data as first-class citizens, but also +any PyTree data structure (eg, tuples, lists, dictionaries and nested versions +of these). The :class:`~torchrl.data.replay_buffers.TensorStorage` storage requires +you to provide the storage at construction time, whereas :class:`~torchrl.data.replay_buffers.TensorStorage` +(RAM, CUDA) and :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` (physical memory) +will preallocate the storage for you after they've been extended the first time. + +Here are a few examples, starting with the generic :class:`~torchrl.data.replay_buffers.ListStorage`: + + >>> from torchrl.data.replay_buffers import ReplayBuffer, ListStorage + >>> rb = ReplayBuffer(storage=ListStorage(10)) + >>> rb.add("a string!") # first element will be a string + >>> rb.extend([30, None]) # element [1] is an int, [2] is None + +Using a :class:`~torchrl.data.replay_buffers.TensorStorage` we tell our RB that +we want the storage to be contiguous, which is by far more efficient but also +more restrictive: + + >>> import torch + >>> from torchrl.data.replay_buffers import ReplayBuffer, TensorStorage + >>> container = torch.empty(10, 3, 64, 64, dtype=torch.unit8) + >>> rb = ReplayBuffer(storage=TensorStorage(container)) + >>> img = torch.randint(255, (3, 64, 64), dtype=torch.uint8) + >>> rb.add(img) + +Next we can avoid creating the container and ask the storage to do it automatically. +This is very useful when using PyTrees and tensordicts! For PyTrees as other data +structures, :meth:`~torchrl.data.replay_buffers.ReplayBuffer.add` considers the sampled +passed to it as a single instance of the type. :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` +on the other hand will consider that the data is an iterable. For tensors, tensordicts +and lists (see below), the iterable is looked for at the root level. For PyTrees, +we assume that the leading dimension of all the leaves (tensors) in the tree +match. If they don't, ``extend`` will throw an exception. + + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyMemmapStorage + >>> rb_td = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=1) # max 10 elements stored + >>> rb_td.add(TensorDict({"img": torch.randint(255, (3, 64, 64), dtype=torch.unit8), + ... "labels": torch.randint(100, ())}, batch_size=[])) + >>> rb_pytree = ReplayBuffer(storage=LazyMemmapStorage(10)) # max 10 elements stored + >>> # extend with a PyTree where all tensors have the same leading dim (3) + >>> rb_pytree.extend({"a": {"b": torch.randn(3), "c": [torch.zeros(3, 2), (torch.ones(3, 10),)]}}) + >>> assert len(rb_pytree) == 3 # the replay buffer has 3 elements! + +.. note:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an + ambiguous signature when dealing with lists of values, which should be interpreted + either as PyTree (in which case all elements in the list will be put in a slice + in the stored PyTree in the storage) or a list of values to add one at a time. + To solve this, TorchRL makes the clear-cut distinction between list and tuple: + a tuple will be viewed as a PyTree, a list (at the root level) will be interpreted + as a stack of values to add one at a time to the buffer. + +Sampling and indexing +~~~~~~~~~~~~~~~~~~~~~ + +Replay buffers can be indexed and sampled. +Indexing and sampling collect data at given indices in the storage and then process them +through a series of transforms and ``collate_fn`` that can be passed to the `__init__` +function of the replay buffer. ``collate_fn`` comes with default values that should +match user expectations in the majority of cases, such that you should not have +to worry about it most of the time. Transforms are usually instances of :class:`~torchrl.envs.transforms.Transform` +even though regular functions will work too (in the latter case, the :meth:`~torchrl.envs.transforms.Transform.inv` +method will obviously be ignored, whereas in the first case it can be used to +preprocess the data before it is passed to the buffer). +Finally, sampling can be achieved using multithreading by passing the number of threads +to the constructor through the ``prefetch`` keyword argument. We advise users to +benchmark this technique in real life settings before adopting it, as there is +no guarantee that it will lead to a faster throughput in practice depending on +the machine and setting where it is used. + +When sampling, the ``batch_size`` can be either passed during construction +(e.g., if it's constant throughout training) or +to the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method. + +To further refine the sampling strategy, we advise you to look into our samplers! + +Here are a couple of examples of how to get data out of a replay buffer: + + >>> first_elt = rb_td[0] + >>> storage = rb_td[:] # returns all valid elements from the buffer + >>> sample = rb_td.sample(128) + >>> for data in rb_td: # iterate over the buffer using the sampler -- batch-size was set in the constructor to 1 + ... print(data) + +using the following components: .. currentmodule:: torchrl.data.replay_buffers @@ -48,9 +149,14 @@ We also give users the ability to compose a replay buffer using the following co TensorDictRoundRobinWriter TensorDictMaxValueWriter -Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes. -:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery. -The following mean sampling latency improvements over using ListStorage were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/benchmarks/storage. +Storage choice is very influential on replay buffer sampling latency, especially +in distributed reinforcement learning settings with larger data volumes. +:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly +advised in distributed settings with shared storage due to the lower serialisation +cost of MemoryMappedTensors as well as the ability to specify file storage locations +for improved node failure recovery. +The following mean sampling latency improvements over using :class:`~torchrl.data.replay_buffers.ListStorage` +were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/benchmarks/storage. +-------------------------------+-----------+ | Storage Type | Speed up | diff --git a/test/test_cost.py b/test/test_cost.py index b8b5e265f8b..9561f6063e4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -8,7 +8,6 @@ import functools import itertools import operator -import re import warnings from copy import deepcopy from dataclasses import asdict, dataclass @@ -270,14 +269,11 @@ def forward(self, td): loss_module.set_vmap_randomness(vmap_randomness) # Fail case elif vmap_randomness == "error" and dropout > 0.0: - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises( + RuntimeError, + match="vmap: called random operation while in randomness error mode", + ): loss_module(td)["loss"] - - # Accessing cause of the caught exception - cause = exc_info.value.__cause__ - assert re.match( - r"vmap: called random operation while in randomness error mode", str(cause) - ) return loss_module(td)["loss"] @@ -1238,7 +1234,7 @@ def test_mixer_keys( # Wthout etting the keys if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): loss(td) elif unravel_key(mixer_global_chosen_action_value_key) != "chosen_action_value": with pytest.raises( @@ -1253,7 +1249,7 @@ def test_mixer_keys( loss.set_keys(global_value=mixer_global_chosen_action_value_key) if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): with pytest.raises( - RuntimeError + KeyError ): # The mixer in key still does not match the actor out_key loss(td) else: diff --git a/test/test_rb.py b/test/test_rb.py index 5d184c365e2..96f392d5a22 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -15,11 +15,14 @@ import numpy as np import pytest import torch + from _utils_internal import get_default_devices, make_tc +from packaging import version from packaging.version import parse from tensordict import is_tensor_collection, is_tensorclass, tensorclass from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase from torch import multiprocessing as mp +from torch.utils._pytree import tree_flatten, tree_map from torchrl.data import ( PrioritizedReplayBuffer, RemoteTensorDictReplayBuffer, @@ -80,22 +83,35 @@ _has_snapshot = importlib.util.find_spec("torchsnapshot") is not None _os_is_windows = sys.platform == "win32" +torch_2_3 = version.parse( + ".".join([str(s) for s in version.parse(str(torch.__version__)).release]) +) >= version.parse("2.3.0") + -@pytest.mark.parametrize( - "rb_type", - [ - ReplayBuffer, - TensorDictReplayBuffer, - RemoteTensorDictReplayBuffer, - ], -) @pytest.mark.parametrize( "sampler", [samplers.RandomSampler, samplers.PrioritizedSampler] ) @pytest.mark.parametrize( "writer", [writers.RoundRobinWriter, writers.TensorDictMaxValueWriter] ) -@pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage]) +@pytest.mark.parametrize( + "rb_type,storage,datatype", + [ + [ReplayBuffer, ListStorage, None], + [TensorDictReplayBuffer, ListStorage, "tensordict"], + [RemoteTensorDictReplayBuffer, ListStorage, "tensordict"], + [ReplayBuffer, LazyTensorStorage, "tensor"], + [ReplayBuffer, LazyTensorStorage, "tensordict"], + [ReplayBuffer, LazyTensorStorage, "pytree"], + [TensorDictReplayBuffer, LazyTensorStorage, "tensordict"], + [RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"], + [ReplayBuffer, LazyMemmapStorage, "tensor"], + [ReplayBuffer, LazyMemmapStorage, "tensordict"], + [ReplayBuffer, LazyMemmapStorage, "pytree"], + [TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], + [RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], + ], +) @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: def _get_rb(self, rb_type, size, sampler, writer, storage): @@ -112,38 +128,73 @@ def _get_rb(self, rb_type, size, sampler, writer, storage): rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3) return rb - def _get_datum(self, rb_type): - if rb_type is ReplayBuffer: + def _get_datum(self, datatype): + if datatype is None: data = torch.randint(100, (1,)) - elif ( - rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer - ): + elif datatype == "tensor": + data = torch.randint(100, (1,)) + elif datatype == "tensordict": data = TensorDict( {"a": torch.randint(100, (1,)), "next": {"reward": torch.randn(1)}}, [] ) + elif datatype == "pytree": + data = { + "a": torch.randint(100, (1,)), + "b": {"c": [torch.zeros(3), (torch.ones(2),)]}, + 30: torch.zeros(2), + } else: - raise NotImplementedError(rb_type) + raise NotImplementedError(datatype) return data - def _get_data(self, rb_type, size): - if rb_type is ReplayBuffer: - data = torch.randint(100, (size, 1)) - elif ( - rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer - ): + def _get_data(self, datatype, size): + if datatype is None: + data = torch.randint( + 100, + ( + size, + 1, + ), + ) + elif datatype == "tensor": + data = torch.randint( + 100, + ( + size, + 1, + ), + ) + elif datatype == "tensordict": data = TensorDict( { - "a": torch.randint(100, (size,)), - "b": TensorDict({"c": torch.randint(100, (size,))}, [size]), + "a": torch.randint( + 100, + ( + size, + 1, + ), + ), "next": {"reward": torch.randn(size, 1)}, }, [size], ) + elif datatype == "pytree": + data = { + "a": torch.randint( + 100, + ( + size, + 1, + ), + ), + "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]}, + 30: torch.zeros(size, 2), + } else: - raise NotImplementedError(rb_type) + raise NotImplementedError(datatype) return data - def test_add(self, rb_type, sampler, writer, storage, size): + def test_add(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( "Distributed package support on Windows is a prototype feature and is subject to changes." @@ -152,8 +203,8 @@ def test_add(self, rb_type, sampler, writer, storage, size): rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) - data = self._get_datum(rb_type) - if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + data = self._get_datum(datatype) + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( RuntimeError, match="expects data to be a tensor collection" ): @@ -161,21 +212,31 @@ def test_add(self, rb_type, sampler, writer, storage, size): return rb.add(data) s = rb.sample(1) - assert s.ndim, s - s = s[0] + if isinstance(s, (torch.Tensor, TensorDictBase)): + assert s.ndim, s + s = s[0] + else: + + def assert_ndim(tensor): + assert tensor.shape[0] == 1 + + tree_map(assert_ndim, s) + s = tree_map(lambda s: s[0], s) if isinstance(s, TensorDictBase): s = s.select(*data.keys(True), strict=False) data = data.select(*s.keys(True), strict=False) assert (s == data).all() assert list(s.keys(True, True)) else: - assert (s == data).all() + flat_s = tree_flatten(s)[0] + flat_data = tree_flatten(data)[0] + assert all((_s == _data).all() for (_s, _data) in zip(flat_s, flat_data)) - def test_cursor_position(self, rb_type, sampler, writer, storage, size): + def test_cursor_position(self, rb_type, sampler, writer, storage, size, datatype): storage = storage(size) writer = writer() writer.register_storage(storage) - batch1 = self._get_data(rb_type, size=5) + batch1 = self._get_data(datatype, size=5) cond = ( OLD_TORCH and not isinstance(writer, TensorDictMaxValueWriter) @@ -183,7 +244,7 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size): and isinstance(storage, TensorStorage) ) - if isinstance(batch1, torch.Tensor) and isinstance( + if not is_tensor_collection(batch1) and isinstance( writer, TensorDictMaxValueWriter ): with pytest.raises( @@ -213,11 +274,11 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size): else: assert writer._cursor == 0 if not isinstance(writer, TensorDictMaxValueWriter): - batch2 = self._get_data(rb_type, size=size - 1) + batch2 = self._get_data(datatype, size=size - 1) writer.extend(batch2) assert writer._cursor == size - 1 - def test_extend(self, rb_type, sampler, writer, storage, size): + def test_extend(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( "Distributed package support on Windows is a prototype feature and is subject to changes." @@ -226,20 +287,21 @@ def test_extend(self, rb_type, sampler, writer, storage, size): rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) - data = self._get_data(rb_type, size=5) + data_shape = 5 + data = self._get_data(datatype, size=data_shape) cond = ( OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data) and isinstance(rb._storage, TensorStorage) ) - if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( RuntimeError, match="expects data to be a tensor collection" ): rb.extend(data) return - length = min(rb._storage.max_size, len(rb) + data.shape[0]) + length = min(rb._storage.max_size, len(rb) + data_shape) if writer is TensorDictMaxValueWriter: data["next", "reward"][-length:] = 1_000_000 with pytest.warns( @@ -248,22 +310,35 @@ def test_extend(self, rb_type, sampler, writer, storage, size): ) if cond else contextlib.nullcontext(): rb.extend(data) length = len(rb) - for d in data[-length:]: + if is_tensor_collection(data): + data_iter = data[-length:] + else: + + def data_iter(): + for t in range(-length, -1): + yield tree_map(lambda x, t=t: x[t], data) + + data_iter = data_iter() + for d in data_iter: for b in rb._storage: if isinstance(b, TensorDictBase): keys = set(d.keys()).intersection(b.keys()) b = b.exclude("index").select(*keys, strict=False) keys = set(d.keys()).intersection(b.keys()) d = d.select(*keys, strict=False) - - value = b == d - if isinstance(value, (torch.Tensor, TensorDictBase)): + if isinstance(b, (torch.Tensor, TensorDictBase)): + value = b == d value = value.all() + else: + d_flat = tree_flatten(d)[0] + b_flat = tree_flatten(b)[0] + value = all((_b == _d).all() for (_b, _d) in zip(b_flat, d_flat)) if value: break else: raise RuntimeError("did not find match") - data2 = self._get_data(rb_type, size=2 * size + 2) + + data2 = self._get_data(datatype, size=2 * size + 2) cond = ( OLD_TORCH and writer is not TensorDictMaxValueWriter @@ -276,7 +351,7 @@ def test_extend(self, rb_type, sampler, writer, storage, size): ) if cond else contextlib.nullcontext(): rb.extend(data2) - def test_sample(self, rb_type, sampler, writer, storage, size): + def test_sample(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( "Distributed package support on Windows is a prototype feature and is subject to changes." @@ -285,14 +360,14 @@ def test_sample(self, rb_type, sampler, writer, storage, size): rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) - data = self._get_data(rb_type, size=5) + data = self._get_data(datatype, size=5) cond = ( OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data) and isinstance(rb._storage, TensorStorage) ) - if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( RuntimeError, match="expects data to be a tensor collection" ): @@ -303,27 +378,50 @@ def test_sample(self, rb_type, sampler, writer, storage, size): match="A cursor of length superior to the storage capacity was provided", ) if cond else contextlib.nullcontext(): rb.extend(data) - new_data = rb.sample() - if not isinstance(new_data, (torch.Tensor, TensorDictBase)): - new_data = new_data[0] + rb_sample = rb.sample() + # if not isinstance(new_data, (torch.Tensor, TensorDictBase)): + # new_data = new_data[0] - for d in new_data: - for b in data: - if isinstance(b, TensorDictBase): - keys = set(d.keys()).intersection(b.keys()) - b = b.exclude("index").select(*keys, strict=False) - keys = set(d.keys()).intersection(b.keys()) - d = d.select(*keys, strict=False) + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + rb_sample_iter = rb_sample + else: - value = b == d - if isinstance(value, (torch.Tensor, TensorDictBase)): + def data_iter_func(maxval, data=data): + for t in range(maxval): + yield tree_map(lambda x, t=t: x[t], data) + + rb_sample_iter = data_iter_func(rb._batch_size, rb_sample) + + for single_sample in rb_sample_iter: + + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + data_iter = data + else: + data_iter = data_iter_func(5, data) + + for data_sample in data_iter: + if isinstance(data_sample, TensorDictBase): + keys = set(single_sample.keys()).intersection(data_sample.keys()) + data_sample = data_sample.exclude("index").select( + *keys, strict=False + ) + keys = set(single_sample.keys()).intersection(data_sample.keys()) + single_sample = single_sample.select(*keys, strict=False) + + if isinstance(data_sample, (torch.Tensor, TensorDictBase)): + value = data_sample == single_sample value = value.all() + else: + d_flat = tree_flatten(single_sample)[0] + b_flat = tree_flatten(data_sample)[0] + value = all((_b == _d).all() for (_b, _d) in zip(b_flat, d_flat)) + if value: break else: raise RuntimeError("did not find match") - def test_index(self, rb_type, sampler, writer, storage, size): + def test_index(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( "Distributed package support on Windows is a prototype feature and is subject to changes." @@ -332,14 +430,14 @@ def test_index(self, rb_type, sampler, writer, storage, size): rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) - data = self._get_data(rb_type, size=5) + data = self._get_data(datatype, size=5) cond = ( OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data) and isinstance(rb._storage, TensorStorage) ) - if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( RuntimeError, match="expects data to be a tensor collection" ): @@ -354,12 +452,17 @@ def test_index(self, rb_type, sampler, writer, storage, size): d2 = rb._storage[2] if type(d1) is not type(d2): d1 = d1[0] - b = d1 == d2 - if not isinstance(b, bool): - b = b.all() + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + b = d1 == d2 + if not isinstance(b, bool): + b = b.all() + else: + d1_flat = tree_flatten(d1)[0] + d2_flat = tree_flatten(d2)[0] + b = all((_d1 == _d2).all() for (_d1, _d2) in zip(d1_flat, d2_flat)) assert b - def test_pickable(self, rb_type, sampler, writer, storage, size): + def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size @@ -381,6 +484,13 @@ def _get_tensordict(self): [10, 11], ) + def _get_pytree(self): + return { + "a": torch.randint(100, (10, 11, 1)), + "b": {"c": [torch.zeros(10, 11), (torch.ones(10, 11),)]}, + 30: torch.zeros(10, 11), + } + def _get_tensorclass(self): data = self._get_tensordict() return make_tc(data)(**data, batch_size=data.shape) @@ -395,7 +505,9 @@ def test_errors(self, storage_type): ): storage_type(data, max_size=4) - @pytest.mark.parametrize("data_type", ["tensor", "tensordict", "tensorclass"]) + @pytest.mark.parametrize( + "data_type", ["tensor", "tensordict", "tensorclass", "pytree"] + ) @pytest.mark.parametrize("storage_type", [TensorStorage]) def test_get_set(self, storage_type, data_type): if data_type == "tensor": @@ -404,13 +516,25 @@ def test_get_set(self, storage_type, data_type): data = self._get_tensorclass() elif data_type == "tensordict": data = self._get_tensordict() + elif data_type == "pytree": + data = self._get_pytree() else: raise NotImplementedError storage = storage_type(data) - storage.set(range(10), torch.zeros_like(data)) - assert (storage.get(range(10)) == 0).all() + if data_type == "pytree": + storage.set(range(10), tree_map(torch.zeros_like, data)) + + def check(x): + assert (x == 0).all() + + tree_map(check, storage.get(range(10))) + else: + storage.set(range(10), torch.zeros_like(data)) + assert (storage.get(range(10)) == 0).all() - @pytest.mark.parametrize("data_type", ["tensor", "tensordict", "tensorclass"]) + @pytest.mark.parametrize( + "data_type", ["tensor", "tensordict", "tensorclass", "pytree"] + ) @pytest.mark.parametrize("storage_type", [TensorStorage]) def test_state_dict(self, storage_type, data_type): if data_type == "tensor": @@ -419,9 +543,15 @@ def test_state_dict(self, storage_type, data_type): data = self._get_tensorclass() elif data_type == "tensordict": data = self._get_tensordict() + elif data_type == "pytree": + data = self._get_pytree() else: raise NotImplementedError storage = storage_type(data) + if data_type == "pytree": + with pytest.raises(TypeError, match="are not supported by"): + storage.state_dict() + return sd = storage.state_dict() storage2 = storage_type(torch.zeros_like(data)) storage2.load_state_dict(sd) @@ -530,7 +660,7 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): @pytest.mark.parametrize("device_data", get_default_devices()) @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - @pytest.mark.parametrize("data_type", ["tensor", "tc", "td"]) + @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"]) @pytest.mark.parametrize("isinit", [True, False]) def test_storage_dumps_loads( self, device_data, storage_type, data_type, isinit, tmpdir @@ -549,6 +679,12 @@ class TC: if data_type == "tensor": data = torch.randint(10, (3,), device=device_data) + elif data_type == "pytree": + data = { + "a": torch.randint(10, (3,), device=device_data), + "b": {"c": [torch.ones(3), (-torch.ones(3, 2),)]}, + 30: -torch.ones(3, 1), + } elif data_type == "td": data = TensorDict( { @@ -578,18 +714,39 @@ class TC: else: storage = storage_type(max_size=10) # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index - storage.set(range(3), data.cpu()) + if data_type == "pytree": + storage.set(range(3), tree_map(lambda x: x.cpu(), data)) + else: + storage.set(range(3), data.cpu()) storage.dumps(dir_save) # check we can dump twice storage.dumps(dir_save) + storage_recover = storage_type(max_size=10) if isinit: - storage_recover.set(range(3), data.cpu().zero_()) + if data_type == "pytree": + storage_recover.set(range(3), tree_map(lambda x: x.cpu().zero_(), data)) + else: + storage_recover.set(range(3), data.cpu().zero_()) + + if data_type in ("tensor", "pytree") and not isinit: + with pytest.raises( + RuntimeError, + match="Cannot fill a non-initialized pytree-based TensorStorage", + ): + storage_recover.loads(dir_save) + return storage_recover.loads(dir_save) - if data_type == "tensor": - torch.testing.assert_close(storage._storage, storage_recover._storage) - else: - assert_allclose_td(storage._storage, storage_recover._storage) + # tree_map with more than one pytree is only available in torch >= 2.3 + if torch_2_3: + if data_type in ("tensor", "pytree"): + tree_map( + torch.testing.assert_close, + tree_flatten(storage._storage)[0], + tree_flatten(storage_recover._storage)[0], + ) + else: + assert_allclose_td(storage._storage, storage_recover._storage) if data == "tc": assert storage._storage.text == storage_recover._storage.text @@ -872,7 +1029,7 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): ) if cond else contextlib.nullcontext(): rb.extend(batch1) - # Added less data than storage max size + # Added fewer data than storage max size if size > 5 or storage is None: assert rb._writer._cursor == 5 # Added more data than storage max size @@ -1180,142 +1337,166 @@ def test_shared_storage_prioritized_sampler(): assert rb1._sampler._sum_tree.query(0, 70) == 50 -def test_append_transform(): - rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), batch_size=1) - td = TensorDict( - { - "observation": torch.randn(2, 4, 3, 16), - "observation2": torch.randn(2, 4, 3, 16), - }, - [], - ) - rb.add(td) - flatten = CatTensors( - in_keys=["observation", "observation2"], out_key="observation_cat" - ) - - rb.append_transform(flatten) - - sampled = rb.sample() - assert sampled.get("observation_cat").shape[-1] == 32 - - -def test_init_transform(): - flatten = FlattenObservation( - -2, -1, in_keys=["observation"], out_keys=["flattened"] - ) - - rb = ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 - ) - - td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) - rb.add(td) - sampled = rb.sample() - assert sampled.get("flattened").shape[-1] == 48 +class TestTransforms: + def test_append_transform(self): + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), batch_size=1) + td = TensorDict( + { + "observation": torch.randn(2, 4, 3, 16), + "observation2": torch.randn(2, 4, 3, 16), + }, + [], + ) + rb.add(td) + flatten = CatTensors( + in_keys=["observation", "observation2"], out_key="observation_cat" + ) + rb.append_transform(flatten) -def test_insert_transform(): - flatten = FlattenObservation( - -2, -1, in_keys=["observation"], out_keys=["flattened"] - ) - rb = ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 - ) - td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) - rb.add(td) + sampled = rb.sample() + assert sampled.get("observation_cat").shape[-1] == 32 - rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) + def test_init_transform(self): + flatten = FlattenObservation( + -2, -1, in_keys=["observation"], out_keys=["flattened"] + ) - sampled = rb.sample() - assert sampled.get("flattened").shape[-1] == 48 + rb = ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 + ) - with pytest.raises(ValueError): - rb.insert_transform(10, SqueezeTransform(-1, in_keys=["observation"])) + td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) + rb.add(td) + sampled = rb.sample() + assert sampled.get("flattened").shape[-1] == 48 + def test_insert_transform(self): + flatten = FlattenObservation( + -2, -1, in_keys=["observation"], out_keys=["flattened"] + ) + rb = ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 + ) + td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) + rb.add(td) -transforms = [ - ToTensorImage, - pytest.param( - partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping" - ), - BinarizeReward, - pytest.param( - partial(Resize, w=2, h=2), - id="Resize", - marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"), - ), - pytest.param( - partial(CenterCrop, w=1), - id="CenterCrop", - marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"), - ), - pytest.param( - partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), - GrayScale, - pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), - pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), - pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), - DoubleToFloat, - VecNorm, -] + rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) + sampled = rb.sample() + assert sampled.get("flattened").shape[-1] == 48 -@pytest.mark.parametrize("transform", transforms) -def test_smoke_replay_buffer_transform(transform): - rb = TensorDictReplayBuffer( - transform=transform(in_keys=["observation"]), batch_size=1 - ) + with pytest.raises(ValueError): + rb.insert_transform(10, SqueezeTransform(-1, in_keys=["observation"])) - # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, []) - td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, []) - rb.add(td) + transforms = [ + ToTensorImage, + pytest.param( + partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping" + ), + BinarizeReward, + pytest.param( + partial(Resize, w=2, h=2), + id="Resize", + marks=pytest.mark.skipif( + not _has_tv, reason="needs torchvision dependency" + ), + ), + pytest.param( + partial(CenterCrop, w=1), + id="CenterCrop", + marks=pytest.mark.skipif( + not _has_tv, reason="needs torchvision dependency" + ), + ), + pytest.param( + partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" + ), + pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + GrayScale, + pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), + pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), + pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), + DoubleToFloat, + VecNorm, + ] + + @pytest.mark.parametrize("transform", transforms) + def test_smoke_replay_buffer_transform(self, transform): + rb = TensorDictReplayBuffer( + transform=transform(in_keys=["observation"]), batch_size=1 + ) - m = mock.Mock() - m.side_effect = [td.unsqueeze(0)] - rb._transform.forward = m - # rb._transform.__len__ = lambda *args: 3 - rb.sample() - assert rb._transform.forward.called + # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, []) + td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, []) + rb.add(td) - # was_called = [False] - # forward = rb._transform.forward - # def new_forward(*args, **kwargs): - # was_called[0] = True - # return forward(*args, **kwargs) - # rb._transform.forward = new_forward - # rb.sample() - # assert was_called[0] + m = mock.Mock() + m.side_effect = [td.unsqueeze(0)] + rb._transform.forward = m + # rb._transform.__len__ = lambda *args: 3 + rb.sample() + assert rb._transform.forward.called + + # was_called = [False] + # forward = rb._transform.forward + # def new_forward(*args, **kwargs): + # was_called[0] = True + # return forward(*args, **kwargs) + # rb._transform.forward = new_forward + # rb.sample() + # assert was_called[0] + + transforms2 = [ + partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3), + FiniteTensorDictCheck, + gSDENoise, + PinMemoryTransform, + ] + + @pytest.mark.parametrize("transform", transforms2) + def test_smoke_replay_buffer_transform_no_inkeys(self, transform): + if transform == PinMemoryTransform and not torch.cuda.is_available(): + raise pytest.skip("No CUDA device detected, skipping PinMemory") + rb = ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1 + ) + action = torch.zeros(3) + action[..., 0] = 1 + td = TensorDict( + {"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, [] + ) + rb.add(td) + rb.sample() -transforms = [ - partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3), - FiniteTensorDictCheck, - gSDENoise, - PinMemoryTransform, -] + rb._transform = mock.MagicMock() + rb._transform.__len__ = lambda *args: 3 + rb.sample() + assert rb._transform.called + @pytest.mark.parametrize("at_init", [True, False]) + def test_transform_nontensor(self, at_init): + def t(x): + return tree_map(lambda y: y * 0, x) -@pytest.mark.parametrize("transform", transforms) -def test_smoke_replay_buffer_transform_no_inkeys(transform): - if transform == PinMemoryTransform and not torch.cuda.is_available(): - raise pytest.skip("No CUDA device detected, skipping PinMemory") - rb = ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1 - ) + if at_init: + rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=t) + else: + rb = ReplayBuffer(storage=LazyMemmapStorage(100)) + rb.append_transform(t) + data = { + "a": torch.randn(3), + "b": {"c": (torch.zeros(2), [torch.ones(1)])}, + 30: -torch.ones(()), + } + rb.add(data) - action = torch.zeros(3) - action[..., 0] = 1 - td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, []) - rb.add(td) - rb.sample() + def assert0(x): + assert (x == 0).all() - rb._transform = mock.MagicMock() - rb._transform.__len__ = lambda *args: 3 - rb.sample() - assert rb._transform.called + s = rb.sample(10) + tree_map(assert0, s) @pytest.mark.parametrize("size", [10, 15, 20]) @@ -1530,7 +1711,7 @@ def worker(rb, q0, q1): extended = q1.get(timeout=5) assert extended == "extended" assert len(rb) == 21, len(rb) - assert (rb["_data", "a"][:9] == 2).all() + assert (rb["a"][:9] == 2).all() q0.put("finish") def exec_multiproc_rb( @@ -1556,7 +1737,7 @@ def exec_multiproc_rb( extended = q0.get(timeout=100) assert extended == "extended" assert len(rb) == 20 - assert (rb["_data", "a"][10:20] == 1).all() + assert (rb["a"][10:20] == 1).all() td = TensorDict({"a": torch.zeros(10) + 2}, [10]) rb.extend(td) q1.put("extended") diff --git a/test/test_transforms.py b/test/test_transforms.py index d0d892f1e2f..708bdaa715c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -45,7 +45,7 @@ from tensordict import unravel_key from tensordict.nn import TensorDictSequential from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import _unravel_key_to_tuple +from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import _replace_last, prod from torchrl.data import ( @@ -797,7 +797,7 @@ def test_transform_model(self, dim, N, padding): v1 = model(tdbase0) v2 = model(tdbase0_copy) # check that swapping dims and names leads to same result - assert (v1 == v2.transpose(0, 1)).all() + assert_allclose_td(v1, v2.transpose(0, 1)) @pytest.mark.parametrize("dim", [-2, -1]) @pytest.mark.parametrize("N", [3, 4]) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 7de6453d10a..93b72483ce8 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -271,9 +271,10 @@ def __init__( implement_for._setters.append(self) @staticmethod - def check_version(version, from_version, to_version): - return (from_version is None or parse(version) >= parse(from_version)) and ( - to_version is None or parse(version) < parse(to_version) + def check_version(version: str, from_version: str | None, to_version: str | None): + version = parse(".".join([str(v) for v in parse(version).release])) + return (from_version is None or version >= parse(from_version)) and ( + to_version is None or version < parse(to_version) ) @staticmethod diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 1e1ce31bf96..1381c6d2383 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -17,7 +17,8 @@ import torch -from tensordict import is_tensorclass +from tensordict import is_tensorclass, unravel_key +from tensordict.nn.utils import _set_dispatch_td_nn_modules from tensordict.tensordict import ( is_tensor_collection, LazyStackedTensorDict, @@ -78,12 +79,11 @@ class ReplayBuffer: prefetch (int, optional): number of next batches to be prefetched using multithreading. Defaults to None (no prefetching). transform (Transform, optional): Transform to be executed when - sample() is called. + :meth:`~.sample` is called. To chain transforms use the :class:`~torchrl.envs.Compose` class. Transforms should be used with :class:`tensordict.TensorDict` - content. If used with other structures, the transforms should be - encoded with a ``"data"`` leading key that will be used to - construct a tensordict from the non-tensordict content. + content. A generic callable can also be passed if the replay buffer + is used with PyTree structures (see example below). batch_size (int, optional): the batch size to be used when sample() is called. .. note:: @@ -129,7 +129,7 @@ class ReplayBuffer: Replay buffers accept *any* kind of data. Not all storage types will work, as some expect numerical data only, but the default - :class:`torchrl.data.ListStorage` will: + :class:`~torchrl.data.ListStorage` will: Examples: >>> torch.manual_seed(0) @@ -137,6 +137,32 @@ class ReplayBuffer: >>> indices = buffer.extend(["a", 1, None]) >>> buffer.sample(3) [None, 'a', None] + + The :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` + and :class:`~torchrl.data.replay_buffers.LazyTensorStorage` also work + with any PyTree structure (a PyTree is a nested structure of arbitrary depth made of dicts, + lists or tuples where the leaves are tensors) provided that it only contains + tensor data. + + Examples: + >>> from torch.utils._pytree import tree_map + >>> def transform(x): + ... # Zeros all the data in the pytree + ... return tree_map(lambda y: y * 0, x) + >>> rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform) + >>> data = { + ... "a": torch.randn(3), + ... "b": {"c": (torch.zeros(2), [torch.ones(1)])}, + ... 30: -torch.ones(()), + ... } + >>> rb.add(data) + >>> # The sample has a similar structure to the data (with a leading dimension of 10 for each tensor) + >>> s = rb.sample(10) + >>> # let's check that our transform did its job: + >>> def assert0(x): + >>> assert (x == 0).all() + >>> tree_map(assert0, s) + """ def __init__( @@ -168,11 +194,18 @@ def __init__( self._replay_lock = threading.RLock() self._futures_lock = threading.RLock() - from torchrl.envs.transforms.transforms import Compose + from torchrl.data.replay_buffers.transforms import _CallableTransform + from torchrl.envs.transforms.transforms import Compose, Transform if transform is None: transform = Compose() elif not isinstance(transform, Compose): + if not isinstance(transform, Transform) and callable(transform): + transform = _CallableTransform(transform) + elif not isinstance(transform, Transform): + raise RuntimeError( + "transform must be either a Transform instance or a callable." + ) transform = Compose(transform) transform.eval() self._transform = transform @@ -247,6 +280,8 @@ def __repr__(self) -> str: @pin_memory_output def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: + if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)): + return self[:][index] index = _to_numpy(index) with self._replay_lock: data = self._storage[index] @@ -362,10 +397,9 @@ def add(self, data: Any) -> int: Returns: index where the data lives in the replay buffer. """ - if self._transform is not None and ( - is_tensor_collection(data) or len(self._transform) - ): - data = self._transform.inv(data) + if self._transform is not None and len(self._transform): + with _set_dispatch_td_nn_modules(is_tensor_collection(data)): + data = self._transform.inv(data) return self._add(data) def _add(self, data): @@ -391,11 +425,21 @@ def extend(self, data: Sequence) -> torch.Tensor: Returns: Indices of the data added to the replay buffer. + + .. warning:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an + ambiguous signature when dealing with lists of values, which should be interpreted + either as PyTree (in which case all elements in the list will be put in a slice + in the stored PyTree in the storage) or a list of values to add one at a time. + To solve this, TorchRL makes the clear-cut distinction between list and tuple: + a tuple will be viewed as a PyTree, a list (at the root level) will be interpreted + as a stack of values to add one at a time to the buffer. + For :class:`~torchrl.data.replay_buffers.ListStorage` instances, only + unbound elements can be provided (no PyTrees). + """ - if self._transform is not None and ( - is_tensor_collection(data) or len(self._transform) - ): - data = self._transform.inv(data) + if self._transform is not None and len(self._transform): + with _set_dispatch_td_nn_modules(is_tensor_collection(data)): + data = self._transform.inv(data) return self._extend(data) def update_priority( @@ -415,18 +459,16 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]: if not isinstance(index, INT_CLASSES): data = self._collate_fn(data) if self._transform is not None and len(self._transform): - is_td = True - if not is_tensor_collection(data): - data = TensorDict({"data": data}, []) - is_td = False - is_locked = data.is_locked - if is_locked: - data.unlock_() - data = self._transform(data) - if is_locked: - data.lock_() - if not is_td: - data = data["data"] + is_td = is_tensor_collection(data) + if is_td: + is_locked = data.is_locked + if is_locked: + data.unlock_() + with _set_dispatch_td_nn_modules(is_td): + data = self._transform(data) + if is_td: + if is_locked: + data.lock_() return data, info @@ -497,6 +539,11 @@ def append_transform(self, transform: "Transform") -> None: # noqa-F821 Args: transform (Transform): The transform to be appended """ + from torchrl.data.replay_buffers.transforms import _CallableTransform + from torchrl.envs.transforms.transforms import Transform + + if not isinstance(transform, Transform) and callable(transform): + transform = _CallableTransform(transform) transform.eval() self._transform.append(transform) @@ -807,7 +854,8 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor: def add(self, data: TensorDictBase) -> int: if self._transform is not None: - data = self._transform.inv(data) + with _set_dispatch_td_nn_modules(is_tensor_collection(data)): + data = self._transform.inv(data) if is_tensor_collection(data): data_add = TensorDict( @@ -913,7 +961,8 @@ def sample( ) data, info = super().sample(batch_size, return_info=True) - if not is_tensorclass(data) and include_info in (True, None): + is_tc = is_tensor_collection(data) + if is_tc and not is_tensorclass(data) and include_info in (True, None): is_locked = data.is_locked if is_locked: data.unlock_() @@ -924,6 +973,8 @@ def sample( data.set(k, v) if is_locked: data.lock_() + elif not is_tc and include_info in (True, None): + raise RuntimeError("Cannot include info in non-tensordict data") if return_info: return data, info return data diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index eef391de58c..5357b9a835f 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Sequence, Union import numpy as np +import tensordict import torch from tensordict import is_tensorclass from tensordict.memmap import MemmapTensor, MemoryMappedTensor @@ -23,6 +24,8 @@ from tensordict.utils import _STRDTYPE2DTYPE, expand_right from torch import multiprocessing as mp +from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten + from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE from torchrl.data.replay_buffers.utils import INT_CLASSES @@ -33,6 +36,10 @@ except ImportError: _has_ts = False +SINGLE_TENSOR_BUFFER_NAME = os.environ.get( + "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_" +) + class Storage: """A Storage is the container of a replay buffer. @@ -120,6 +127,10 @@ def _empty(self): class ListStorage(Storage): """A storage stored in a list. + This class cannot be extended with PyTrees, the data provided during calls to + :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` should be iterables + (like lists, tuples, tensors or tensordicts with non-empty batch-size). + Args: max_size (int): the maximum number of elements stored in the storage. @@ -149,8 +160,26 @@ def set(self, cursor: Union[int, Sequence[int], slice], data: Any): if isinstance(cursor, slice): self._storage[cursor] = data return - for _cursor, _data in zip(cursor, data): - self.set(_cursor, _data) + if isinstance( + data, + ( + list, + tuple, + torch.Tensor, + TensorDictBase, + *tensordict.base._ACCEPTED_CLASSES, + range, + set, + np.ndarray, + ), + ): + for _cursor, _data in zip(cursor, data): + self.set(_cursor, _data) + else: + raise TypeError( + f"Cannot extend a {type(self)} with data of type {type(data)}. " + f"Provide a list, tuple, set, range, np.ndarray, tensor or tensordict subclass instead." + ) return else: if cursor > len(self._storage): @@ -288,7 +317,10 @@ def __init__(self, storage, max_size=None, device="cpu"): f"max_size={max_size} for a storage of shape {storage.shape}." ) elif storage is not None: - max_size = storage.shape[0] + if is_tensor_collection(storage): + max_size = storage.shape[0] + else: + max_size = tree_flatten(storage)[0][0].shape[0] super().__init__(max_size) self.initialized = storage is not None if self.initialized: @@ -310,35 +342,22 @@ def dumps(self, path): if not self.initialized: raise RuntimeError("Cannot save a non-initialized storage.") - if isinstance(self._storage, torch.Tensor): - try: - MemoryMappedTensor.from_filename( - shape=self._storage.shape, - filename=path / "storage.memmap", - dtype=self._storage.dtype, - ).copy_(self._storage) - except FileNotFoundError: - MemoryMappedTensor.from_tensor( - self._storage, filename=path / "storage.memmap", copy_existing=True - ) - is_tensor = True - dtype = str(self._storage.dtype) - shape = list(self._storage.shape) - else: + metadata = {} + if is_tensor_collection(self._storage): # try to load the path and overwrite. self._storage.memmap( path, copy_existing=True, num_threads=torch.get_num_threads() ) - is_tensor = False - dtype = None - shape = None + is_pytree = False + else: + _save_pytree(self._storage, metadata, path) + is_pytree = True with open(path / "storage_metadata.json", "w") as file: json.dump( { - "is_tensor": is_tensor, - "dtype": dtype, - "shape": shape, + "metadata": metadata, + "is_pytree": is_pytree, "len": self._len, }, file, @@ -347,24 +366,50 @@ def dumps(self, path): def loads(self, path): with open(path / "storage_metadata.json", "r") as file: metadata = json.load(file) - is_tensor = metadata["is_tensor"] - shape = metadata["shape"] - dtype = metadata["dtype"] + is_pytree = metadata["is_pytree"] _len = metadata["len"] - if dtype is not None: - shape = torch.Size(shape) - dtype = _STRDTYPE2DTYPE[dtype] - if is_tensor: - _storage = MemoryMappedTensor.from_filename( - path / "storage.memmap", shape=shape, dtype=dtype - ).clone() + if is_pytree: + path = Path(path) + for local_path, md in metadata["metadata"].items(): + # load tensor + local_path_dot = local_path.replace(".", "/") + total_tensor_path = path / (local_path_dot + ".memmap") + shape = torch.Size(md["shape"]) + dtype = _STRDTYPE2DTYPE[md["dtype"]] + tensor = MemoryMappedTensor.from_filename( + filename=total_tensor_path, shape=shape, dtype=dtype + ) + # split path + local_path = local_path.split(".") + # replace potential dots + local_path = [_path.replace("__", ".") for _path in local_path] + if self.initialized: + # copy in-place + _storage_tensor = self._storage + # in this case there is a single tensor, so we skip + if local_path != ["_-single-tensor-_"]: + for _path in local_path: + if _path.isdigit(): + _path_attempt = int(_path) + try: + _storage_tensor = _storage_tensor[_path_attempt] + continue + except IndexError: + pass + _storage_tensor = _storage_tensor[_path] + _storage_tensor.copy_(tensor) + else: + raise RuntimeError( + "Cannot fill a non-initialized pytree-based TensorStorage." + ) else: _storage = TensorDict.load_memmap(path) - if not self.initialized: - self._storage = _storage - self.initialized = True - else: - self._storage.copy_(_storage) + if not self.initialized: + # this should not be reached if is_pytree=True + self._storage = _storage + self.initialized = True + else: + self._storage.copy_(_storage) self._len = _len @property @@ -454,7 +499,7 @@ def load_state_dict(self, state_dict): self._storage = TensorDict({}, []).load_state_dict(_storage) else: raise RuntimeError( - f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}" + f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}. If your storage is pytree-based, use the dumps/load API instead." ) else: raise TypeError( @@ -463,6 +508,22 @@ def load_state_dict(self, state_dict): self.initialized = state_dict["initialized"] self._len = state_dict["_len"] + @implement_for("torch", "2.3") + def _set_tree_map(self, cursor, data, storage): + def set_tensor(datum, store): + store[cursor] = datum + + # this won't be available until v2.3 + tree_map(set_tensor, data, storage) + + @implement_for("torch", "2.0", "2.3") + def _set_tree_map(self, cursor, data, storage): # noqa: 534 + # flatten data and cursor + data_flat = tree_flatten(data)[0] + storage_flat = tree_flatten(storage)[0] + for datum, store in zip(data_flat, storage_flat): + store[cursor] = datum + @implement_for("torch", "2.0", None) def set( self, @@ -476,10 +537,16 @@ def set( if not self.initialized: if not isinstance(cursor, INT_CLASSES): - self._init(data[0]) + if is_tensor_collection(data): + self._init(data[0]) + else: + self._init(tree_map(lambda x: x[0], data)) else: self._init(data) - self._storage[cursor] = data + if is_tensor_collection(data): + self._storage[cursor] = data + else: + self._set_tree_map(cursor, data, self._storage) @implement_for("torch", None, "2.0") def set( # noqa: F811 @@ -492,6 +559,11 @@ def set( # noqa: F811 else: self._len = max(self._len, max(cursor) + 1) + if not is_tensor_collection(data) and not isinstance(data, torch.Tensor): + raise NotImplementedError( + "storage extension with pytrees is only available with torch >= 2.0. If you need this " + "feature, please open an issue on TorchRL's github repository." + ) if not self.initialized: if not isinstance(cursor, INT_CLASSES): self._init(data[0]) @@ -514,19 +586,24 @@ def set( # noqa: F811 self._storage[cursor] = data def get(self, index: Union[int, Sequence[int], slice]) -> Any: + _storage = self._storage + is_tc = is_tensor_collection(_storage) if self._len < self.max_size: - storage = self._storage[: self._len] + if is_tc: + storage = self._storage[: self._len] + else: + storage = tree_map(lambda x: x[: self._len], self._storage) else: storage = self._storage if not self.initialized: raise RuntimeError( "Cannot get an item from an unitialized LazyMemmapStorage" ) - out = storage[index] - if is_tensor_collection(out): - out = _reset_batch_size(out) - return out # .unlock_() - return out + if is_tc: + out = storage[index] + return _reset_batch_size(out) + else: + return tree_map(lambda x: x[index], storage) def __len__(self): return self._len @@ -607,24 +684,19 @@ class LazyTensorStorage(TensorStorage): def __init__(self, max_size, device="cpu"): super().__init__(storage=None, max_size=max_size, device=device) - def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: + def _init( + self, + data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 + ) -> None: if VERBOSE: logging.info("Creating a TensorStorage...") if self.device == "auto": self.device = data.device - if isinstance(data, torch.Tensor): - # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype - out = torch.empty( - self.max_size, - *data.shape, - device=self.device, - dtype=data.dtype, - ) - elif is_tensorclass(data): + if is_tensorclass(data): out = ( data.expand(self.max_size, *data.shape).clone().zero_().to(self.device) ) - else: + elif is_tensor_collection(data): out = ( data.expand(self.max_size, *data.shape) .to_tensordict() @@ -632,6 +704,17 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: .clone() .to(self.device) ) + else: + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype + out = tree_map( + lambda data: torch.empty( + self.max_size, + *data.shape, + device=self.device, + dtype=data.dtype, + ), + data, + ) self._storage = out self.initialized = True @@ -720,7 +803,7 @@ def state_dict(self) -> Dict[str, Any]: _storage = {} else: raise TypeError( - f"Objects of type {type(_storage)} are not supported by LazyTensorStorage.state_dict" + f"Objects of type {type(_storage)} are not supported by LazyTensorStorage.state_dict. If you are trying to serialize a PyTree, the storage.dumps/loads is preferred." ) return { "_storage": _storage, @@ -792,29 +875,23 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." ) else: - # If not a tensorclass/tensordict, it must be a tensor(-like) - # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype - out = _make_empty_memmap( - (self.max_size, *data.shape), - dtype=data.dtype, - path=self.scratch_dir + "/tensor.memmap" - if self.scratch_dir is not None - else None, - ) - if VERBOSE: - filesize = os.path.getsize(out.filename) / 1024 / 1024 - logging.info( - f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." - ) + out = _init_pytree(self.scratch_dir, self.max_size, data) self._storage = out self.initialized = True def get(self, index: Union[int, Sequence[int], slice]) -> Any: result = super().get(index) + # to be deprecated in v0.4 - if result.device != self.device: - return result.to(self.device, non_blocking=True) - return result + def map_device(tensor): + if tensor.device != self.device: + return tensor.to(self.device, non_blocking=True) + return tensor + + if is_tensor_collection(result): + return map_device(result) + else: + return tree_map(map_device, result) class StorageEnsemble(Storage): @@ -1069,3 +1146,167 @@ def _make_memmap(tensor, path): def _make_empty_memmap(shape, dtype, path): return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path) + + +@implement_for("torch", "2.3", None) +def _path2str(path, default_name=None): + # Uses the Keys defined in pytree to build a path + from torch.utils._pytree import MappingKey, SequenceKey + + if default_name is None: + default_name = SINGLE_TENSOR_BUFFER_NAME + if not path: + return default_name + if isinstance(path, tuple): + return "/".join([_path2str(_sub, default_name=default_name) for _sub in path]) + if isinstance(path, MappingKey): + if not isinstance(path.key, (int, str, bytes)): + raise ValueError("Values must be of type int, str or bytes in PyTree maps.") + result = str(path.key) + if result == default_name: + raise RuntimeError( + "A tensor had the same identifier as the default name used when the buffer contains " + f"a single tensor (name={default_name}). This behaviour is not allowed. Please rename your " + f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME." + ) + return result + if isinstance(path, SequenceKey): + return str(path.idx) + + +@implement_for("torch", None, "2.3") +def _path2str(path, default_name=None): # noqa: F811 + raise RuntimeError + + +def _get_paths(spec, cumulpath=""): + # alternative way to build a path without the keys + if isinstance(spec, LeafSpec): + yield cumulpath if cumulpath else SINGLE_TENSOR_BUFFER_NAME + + contexts = spec.context + children_specs = spec.children_specs + if contexts is None: + contexts = range(len(children_specs)) + + for context, spec in zip(contexts, children_specs): + cpath = "/".join((cumulpath, str(context))) if cumulpath else str(context) + yield from _get_paths(spec, cpath) + + +def _save_pytree_common(tensor_path, path, tensor, metadata): + if "." in tensor_path: + tensor_path.replace(".", "__") + total_tensor_path = path / (tensor_path + ".memmap") + if os.path.exists(total_tensor_path): + MemoryMappedTensor.from_filename( + shape=tensor.shape, + filename=total_tensor_path, + dtype=tensor.dtype, + ).copy_(tensor) + else: + os.makedirs(total_tensor_path.parent, exist_ok=True) + MemoryMappedTensor.from_tensor( + tensor, + filename=total_tensor_path, + copy_existing=True, + copy_data=True, + ) + key = tensor_path.replace("/", ".") + if key in metadata: + raise KeyError( + "At least two values have conflicting representations in " + f"the data structure to be serialized: {key}." + ) + metadata[key] = { + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + } + + +@implement_for("torch", "2.3", None) +def _save_pytree(_storage, metadata, path): + from torch.utils._pytree import tree_map_with_path + + def save_tensor( + tensor_path: tuple, tensor: torch.Tensor, metadata=metadata, path=path + ): + tensor_path = _path2str(tensor_path) + _save_pytree_common(tensor_path, path, tensor, metadata) + + tree_map_with_path(save_tensor, _storage) + + +@implement_for("torch", None, "2.3") +def _save_pytree(_storage, metadata, path): # noqa: F811 + + flat_storage, storage_specs = tree_flatten(_storage) + storage_paths = _get_paths(storage_specs) + + def save_tensor( + tensor_path: str, tensor: torch.Tensor, metadata=metadata, path=path + ): + _save_pytree_common(tensor_path, path, tensor, metadata) + + for tensor, tensor_path in zip(flat_storage, storage_paths): + save_tensor(tensor_path, tensor) + + +def _init_pytree_common(tensor_path, scratch_dir, max_size, tensor): + if "." in tensor_path: + tensor_path.replace(".", "__") + if scratch_dir is not None: + total_tensor_path = Path(scratch_dir) / (tensor_path + ".memmap") + if os.path.exists(total_tensor_path): + raise RuntimeError( + f"The storage of tensor {total_tensor_path} already exists. " + f"To load an existing replay buffer, use storage.loads. " + f"Choose a different path to store your buffer or delete the existing files." + ) + os.makedirs(total_tensor_path.parent, exist_ok=True) + else: + total_tensor_path = None + out = MemoryMappedTensor.empty( + shape=(max_size, *tensor.shape), + filename=total_tensor_path, + dtype=tensor.dtype, + ) + if VERBOSE: + filesize = os.path.getsize(out.filename) / 1024 / 1024 + logging.info( + f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." + ) + return out + + +@implement_for("torch", "2.3", None) +def _init_pytree(scratch_dir, max_size, data): + from torch.utils._pytree import tree_map_with_path + + # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype + def save_tensor(tensor_path: tuple, tensor: torch.Tensor): + tensor_path = _path2str(tensor_path) + return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor) + + out = tree_map_with_path(save_tensor, data) + return out + + +@implement_for("torch", None, "2.3") +def _init_pytree(scratch_dir, max_size, data): # noqa: F811 + + flat_data, data_specs = tree_flatten(data) + data_paths = _get_paths(data_specs) + data_paths = list(data_paths) + + # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype + def save_tensor(tensor_path: str, tensor: torch.Tensor): + return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor) + + out = [] + for tensor, tensor_path in zip(flat_data, data_paths): + out.append(save_tensor(tensor_path, tensor)) + + return tree_unflatten(out, data_specs) diff --git a/torchrl/data/replay_buffers/transforms.py b/torchrl/data/replay_buffers/transforms.py new file mode 100644 index 00000000000..828bfbe549b --- /dev/null +++ b/torchrl/data/replay_buffers/transforms.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from tensordict import TensorDictBase +from torchrl.envs.transforms.transforms import Transform + + +class _CallableTransform(Transform): + # A wrapper around a custom callable to make it possible to transform any data type + def __init__(self, func): + super().__init__() + self.func = func + + def forward(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def _call(self, tensordict): + raise RuntimeError + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + return tensordict diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 243df9a8011..41d551535ac 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -14,9 +14,11 @@ import numpy as np import torch + from tensordict import is_tensor_collection, MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp +from torch.utils._pytree import tree_flatten from torchrl.data.replay_buffers.storages import Storage from torchrl.data.replay_buffers.utils import _reduce @@ -118,7 +120,14 @@ def add(self, data: Any) -> int: def extend(self, data: Sequence) -> torch.Tensor: cur_size = self._cursor - batch_size = len(data) + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + batch_size = len(data) + elif isinstance(data, list): + batch_size = len(data) + else: + batch_size = len(tree_flatten(data)[0][0]) + if batch_size == 0: + raise RuntimeError("Expected at least one element in extend.") device = data.device if hasattr(data, "device") else None index = ( torch.arange( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ed8be751474..e2454c6d35e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -28,6 +28,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import expand_as_right, NestedKey from torch import nn, Tensor +from torch.utils._pytree import tree_map from torchrl._utils import _replace_last from torchrl.data.tensor_specs import ( @@ -346,7 +347,16 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: @dispatch(source="in_keys_inv", dest="out_keys_inv") def inv(self, tensordict: TensorDictBase) -> TensorDictBase: - out = self._inv_call(tensordict.clone(False)) + def clone(data): + try: + # we priviledge speed for tensordicts + return data.clone(recurse=False) + except AttributeError: + return tree_map(lambda x: x, data) + except TypeError: + return tree_map(lambda x: x, data) + + out = self._inv_call(clone(tensordict)) return out def transform_env_device(self, device: torch.device): diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index be6e607c1b5..ab3a99cdece 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -262,6 +262,50 @@ class MyData: ###################################################################### # As expected. the data has the proper class and shape! # +# Integration with PyTree +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# TorchRL's replay buffers also work with any pytree data structure. +# A PyTree is a nested structure of arbitrary depth made of dicts, lists and/or +# tuples where the leaves are tensors. +# This means that one can store in contiguous memory any such tree structure! +# Various storages can be used: +# :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this kind of data. +# +# Here is a bried demonstration of what this feature looks like: +# + +from torch.utils._pytree import tree_map + + +# With pytrees, any callable can be used as a transform: +def transform(x): + # Zeros all the data in the pytree + return tree_map(lambda y: y * 0, x) + + +# Let's build our replay buffer on disk: +rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform) +data = { + "a": torch.randn(3), + "b": {"c": (torch.zeros(2), [torch.ones(1)])}, + 30: -torch.ones(()), # non-string keys also work +} +rb.add(data) + +# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor) +sample = rb.sample(10) + + +# let's check that our transform did its job: +def assert0(x): + assert (x == 0).all() + + +tree_map(assert0, sample) + + # Sampling and iterating over buffers # ----------------------------------- # From e679e7189ac94cf774315a00219ea5b7cbcb7d74 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jan 2024 16:23:53 +0000 Subject: [PATCH 09/35] [BugFix] Fix sequential step counts (#1838) --- test/test_transforms.py | 16 ++++++++++++++++ torchrl/envs/transforms/transforms.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 708bdaa715c..94ea338365e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1477,6 +1477,22 @@ def test_step_count_gym(self): env.rollout(1000) check_env_specs(env) + @pytest.mark.skipif(not _has_gym, reason="no gym detected") + def test_step_count_gym_doublecount(self): + # tests that 2 truncations can be used together + env = TransformedEnv( + GymEnv(PENDULUM_VERSIONED), + Compose( + StepCounter(max_steps=2), + StepCounter(max_steps=3), # this one will be ignored + ), + ) + r = env.rollout(10, break_when_any_done=False) + assert ( + r.get(("next", "truncated")).squeeze().nonzero().squeeze(-1) + == torch.arange(1, 10, 2) + ).all() + @pytest.mark.skipif(not _has_dm_control, reason="no dm_control detected") def test_step_count_dmc(self): env = TransformedEnv(DMControlEnv("cheetah", "run"), StepCounter(max_steps=30)) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e2454c6d35e..41ce74293bb 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5169,6 +5169,7 @@ def _reset( tensordict_reset.set(step_count_key, step_count) if self.max_steps is not None: truncated = step_count >= self.max_steps + truncated = truncated | tensordict_reset.get(truncated_key, False) if self.update_done: # we assume no done after reset tensordict_reset.set(done_key, truncated) @@ -5187,8 +5188,10 @@ def _step( step_count = tensordict.get(step_count_key) next_step_count = step_count + 1 next_tensordict.set(step_count_key, next_step_count) + if self.max_steps is not None: truncated = next_step_count >= self.max_steps + truncated = truncated | next_tensordict.get(truncated_key, False) if self.update_done: done = next_tensordict.get(done_key, None) terminated = next_tensordict.get(terminated_key, None) From 3fd637f6bd19d842bf1095cbb258d73688833519 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jan 2024 17:20:53 +0000 Subject: [PATCH 10/35] [Doc] TED format (#1836) --- docs/source/reference/data.rst | 169 +++++++++++++++++++++++ docs/source/reference/envs.rst | 4 + torchrl/modules/tensordict_module/rnn.py | 4 +- 3 files changed, 175 insertions(+), 2 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 91391c6af36..859ca9c389c 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -298,6 +298,175 @@ before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback of this method is that it will struggle to save big data structures, which is a common setting when using replay buffers. +TorchRL Episode Data Format (TED) +--------------------------------- + +In TorchRL, sequential data is consistently presented in a specific format, known +as the TorchRL Episode Data Format (TED). This format is crucial for the seamless +integration and functioning of various components within TorchRL. + +Some components, such as replay buffers, are somewhat indifferent to the data +format. However, others, particularly environments, heavily depend on it for smooth operation. + +Therefore, it's essential to understand the TED, its purpose, and how to interact +with it. This guide will provide a clear explanation of the TED, why it's used, +and how to effectively work with it. + +The Rationale Behind TED +~~~~~~~~~~~~~~~~~~~~~~~~ + +Formatting sequential data can be a complex task, especially in the realm of +Reinforcement Learning (RL). As practitioners, we often encounter situations +where data is delivered at the reset time (though not always), and sometimes data +is provided or discarded at the final step of the trajectory. + +This variability means that we can observe data of different lengths in a dataset, +and it's not always immediately clear how to match each time step across the +various elements of this dataset. Consider the following ambiguous dataset structure: + + >>> observation.shape + [200, 3] + >>> action.shape + [199, 4] + >>> info.shape + [200, 3] + +At first glance, it seems that the info and observation were delivered +together (one of each at reset + one of each at each step call), as suggested by +the action having one less element. However, if info has one less element, we +must assume that it was either omitted at reset time or not delivered or recorded +for the last step of the trajectory. Without proper documentation of the data +structure, it's impossible to determine which info corresponds to which time step. + +Complicating matters further, some datasets provide inconsistent data formats, +where ``observations`` or ``infos`` are missing at the start or end of the +rollout, and this behavior is often not documented. +The primary aim of TED is to eliminate these ambiguities by providing a clear +and consistent data representation. + +The structure of TED +~~~~~~~~~~~~~~~~~~~~ + +TED is built upon the canonical definition of a Markov Decision Process (MDP) in RL contexts. +At each step, an observation conditions an action that results in (1) a new +observation, (2) an indicator of task completion (terminated, truncated, done), +and (3) a reward signal. + +Some elements may be missing (for example, the reward is optional in imitation +learning contexts), or additional information may be passed through a state or +info container. In some cases, additional information is required to get the +observation during a call to ``step`` (for instance, in stateless environment simulators). Furthermore, +in certain scenarios, an "action" (or any other data) cannot be represented as a +single tensor and needs to be organized differently. For example, in Multi-Agent RL +settings, actions, observations, rewards, and completion signals may be composite. + +TED accommodates all these scenarios with a single, uniform, and unambiguous +format. We distinguish what happens at time step ``t`` and ``t+1`` by setting a +limit at the time the action is executed. In other words, everything that was +present before ``env.step`` was called belongs to ``t``, and everything that +comes after belongs to ``t+1``. + +The general rule is that everything that belongs to time step ``t`` is stored +at the root of the tensordict, while everything that belongs to ``t+1`` is stored +in the ``"next"`` entry of the tensordict. Here's an example: + + >>> data = env.reset() + >>> data = policy(data) + >>> print(env.step(data)) + TensorDict( + fields={ + action: Tensor(...), # The action taken at time t + done: Tensor(...), # The done state when the action was taken (at reset) + next: TensorDict( # all of this content comes from the call to `step` + fields={ + done: Tensor(...), # The done state after the action has been taken + observation: Tensor(...), # The observation resulting from the action + reward: Tensor(...), # The reward resulting from the action + terminated: Tensor(...), # The terminated state after the action has been taken + truncated: Tensor(...), # The truncated state after the action has been taken + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(...), # the observation at reset + terminated: Tensor(...), # the terminated at reset + truncated: Tensor(...), # the truncated at reset + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + +During a rollout (either using :class:`~torchrl.envs.EnvBase` or +:class:`~torchrl.collectors.SyncDataCollector`), the content of the ``"next"`` +tensordict is brought to the root through the :func:`~torchrl.envs.utils.step_mdp` +function when the agent resets its step count: ``t <- t+1``. You can read more +about the environment API :ref:`here `. + +In most cases, there is no `True`-valued ``"done"`` state at the root since any +done state will trigger a (partial) reset which will turn the ``"done"`` to ``False``. +However, this is only true as long as resets are automatically performed. In some +cases, partial resets will not trigger a reset, so we retain these data, which +should have a considerably lower memory footprint than observations, for instance. + +This format eliminates any ambiguity regarding the matching of an observation with +its action, info, or done state. + +Dimensionality of the Tensordict +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +During a rollout, all collected tensordicts will be stacked along a new dimension +positioned at the end. Both collectors and environments will label this dimension +with the ``"time"`` name. Here's an example: + + >>> rollout = env.rollout(10, policy) + >>> assert rollout.shape[-1] == 10 + >>> assert rollout.names[-1] == "time" + +This ensures that the time dimension is clearly marked and easily identifiable +in the data structure. + +Special cases and footnotes +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Multi-Agent data presentation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The multi-agent data formatting documentation can be accessed in the :ref:`MARL environment API ` section. + +Memory-based policies (RNNs and Transformers) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the examples provided above, only ``env.step(data)`` generates data that +needs to be read in the next step. However, in some cases, the policy also +outputs information that will be required in the next step. This is typically +the case for RNN-based policies, which output an action as well as a recurrent +state that needs to be used in the next step. +To accommodate this, we recommend users to adjust their RNN policy to write this +data under the ``"next"`` entry of the tensordict. This ensures that this content +will be brought to the root in the next step. More information can be found in +:class:`~torchrl.modules.GRUModule` and :class:`~torchrl.modules.LSTMModule`. + +Multi-step +^^^^^^^^^^ + +Collectors allow users to skip steps when reading the data, accumulating reward +for the upcoming n steps. This technique is popular in DQN-like algorithms like Rainbow. +The :class:`~torchrl.data.postprocs.MultiStep` class performs this data transformation +on batches coming out of collectors. In these cases, a check like the following +will fail since the next observation is shifted by n steps: + + >>> assert (data[..., 1:]["observation"] == data[..., :-1]["next", "observation"]).all() + +What about memory requirements? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Implemented naively, this data format consumes approximately twice the memory +that a flat representation would. In some memory-intensive settings +(for example, in the :class:`~torchrl.data.datasets.AtariDQNExperienceReplay` dataset), +we store only the ``T+1`` observation on disk and perform the formatting online at get time. +In other cases, we assume that the 2x memory cost is a small price to pay for a +clearer representation. However, generalizing the lazy representation for offline +datasets would certainly be a beneficial feature to have, and we welcome +contributions in this direction! + Datasets -------- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 948bf896a45..cce34e14b14 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -3,6 +3,8 @@ torchrl.envs package ==================== +.. _Environment-API: + TorchRL offers an API to handle environments of different backends, such as gym, dm-control, dm-lab, model-based environments as well as custom environments. The goal is to be able to swap environments in an experiment with little or no effort, @@ -333,6 +335,8 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w Multi-agent environments ------------------------ +.. _MARL-environment-API: + .. currentmodule:: torchrl.envs TorchRL supports multi-agent learning out-of-the-box. diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 75c6110c413..e392070c517 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -828,7 +828,8 @@ class GRU(GRUBase): """A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like :class:`torch.nn.GRU`, but this implementation is exclusively coded in Python. .. note:: - This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`. + This class is implemented without relying on CuDNN, which makes it + compatible with :func:`torch.vmap` and :func:`torch.compile`. Examples: >>> import torch @@ -1031,7 +1032,6 @@ class GRUModule(ModuleBase): dropout: If non-zero, introduces a `Dropout` layer on the outputs of each GRU layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 - proj_size: If ``> 0``, will use GRU with projections of corresponding size. Default: 0 python_based: If ``True``, will use a full Python implementation of the GRU cell. Default: ``False`` Keyword Args: From 6f90397dc8ffe441ca6efb345c2c3e9346fda6c7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jan 2024 17:25:43 +0000 Subject: [PATCH 11/35] [Doc] References to TED (#1839) --- docs/source/reference/data.rst | 2 ++ torchrl/data/datasets/atari_dqn.py | 2 +- torchrl/data/datasets/d4rl.py | 1 + torchrl/data/datasets/gen_dgrl.py | 2 ++ torchrl/data/datasets/minari_data.py | 4 ++++ torchrl/data/datasets/openml.py | 2 ++ torchrl/data/datasets/openx.py | 2 ++ torchrl/data/datasets/roboset.py | 2 ++ torchrl/data/datasets/vd4rl.py | 2 ++ 9 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 859ca9c389c..6ed32ebe921 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -301,6 +301,8 @@ common setting when using replay buffers. TorchRL Episode Data Format (TED) --------------------------------- +.. _TED-format: + In TorchRL, sequential data is consistently presented in a specific format, known as the TorchRL Episode Data Format (TED). This format is crucial for the seamless integration and functioning of various components within TorchRL. diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 93950532026..28fddb79fca 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -41,7 +41,7 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer): The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset has 50 million steps in total. - The data format follows the TED convention. Since the dataset is quite heavy, + The data format follows the :ref:`TED convention `. Since the dataset is quite heavy, the data formatting is done on-line, at sampling time. To make training more modular, we split the dataset in each of the Atari games diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 2d91da82367..adf6317e679 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -38,6 +38,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): To install D4RL, follow the instructions on the `official repo `__. + The data format follows the :ref:`TED convention `. The replay buffer contains the env specs under D4RLExperienceReplay.specs. If present, metadata will be written in ``D4RLExperienceReplay.metadata`` diff --git a/torchrl/data/datasets/gen_dgrl.py b/torchrl/data/datasets/gen_dgrl.py index 47f1c56f58e..d1ca0b15fb8 100644 --- a/torchrl/data/datasets/gen_dgrl.py +++ b/torchrl/data/datasets/gen_dgrl.py @@ -34,6 +34,8 @@ class GenDGRLExperienceReplay(TensorDictReplayBuffer): GitHub: https://github.com/facebookresearch/gen_dgrl + The data format follows the :ref:`TED convention `. + This class gives you access to the ProcGen dataset. Each `dataset_id` registered in `GenDGRLExperienceReplay.available_datasets` consists in a particular task (`"bigfish"`, `"bossfight"`, ...) separated from a category (`"1M_E"`, `"1M_S"`, ...) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 866888ae925..5deeccd3253 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -59,6 +59,10 @@ class MinariExperienceReplay(TensorDictReplayBuffer): """Minari Experience replay dataset. + Learn more about Minari on their website: https://minari.farama.org/ + + The data format follows the :ref:`TED convention `. + Args: dataset_id (str): The dataset to be downloaded. Must be part of MinariExperienceReplay.available_datasets batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index fadcc0e7f96..0070c86d534 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -27,6 +27,8 @@ class OpenMLExperienceReplay(TensorDictReplayBuffer): This class provides an easy entry point for public datasets. See "Dua, D. and Graff, C. (2017) UCI Machine Learning Repository. http://archive.ics.uci.edu/ml" + The data format follows the :ref:`TED convention `. + The data is accessed via scikit-learn. Make sure sklearn and pandas are installed before retrieving the data: diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 4beb18b00a1..0b825188a5b 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -42,6 +42,8 @@ class OpenXExperienceReplay(TensorDictReplayBuffer): Paper: https://arxiv.org/abs/2310.08864 + The data format follows the :ref:`TED convention `. + .. note:: Non-tensor data will be written in the tensordict data using the :class:`~tensordict.tensorclass.NonTensorData` primitive. diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 825b937e8ac..8d8b84fb7a9 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -43,6 +43,8 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): Learn more about roboset here: https://sites.google.com/view/robohive/roboset + The data format follows the :ref:`TED convention `. + Args: dataset_id (str): the dataset to be downloaded. Must be part of RobosetExperienceReplay.available_datasets. batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 417c025ae59..f107804ae84 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -50,6 +50,8 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer): that is not reward, done-state, action or pixels is moved under a `"state"` node. + The data format follows the :ref:`TED convention `. + Args: dataset_id (str): the dataset to be downloaded. Must be part of VD4RLExperienceReplay.available_datasets. From 6a42116c839e48f0f679b10dec8dea07619689f0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 26 Jan 2024 10:40:08 +0000 Subject: [PATCH 12/35] [BugFix] Temporarily set lazy legacy to True (#1840) --- torchrl/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 7d807244f70..9b460db5216 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -7,8 +7,12 @@ import torch +from tensordict import set_lazy_legacy + from torch import multiprocessing as mp +set_lazy_legacy(True).set() + if torch.cuda.device_count() > 1: n = torch.cuda.device_count() - 1 os.environ["MUJOCO_EGL_DEVICE_ID"] = str(1 + (os.getpid() % n)) From c2fae321abe574fd8e63bf2198e4015829344fa0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 26 Jan 2024 14:16:49 +0000 Subject: [PATCH 13/35] [BugFix] Fix gym info scalar infos (#1842) --- torchrl/envs/gym_like.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index e6d6ca93a68..6d0b3dc0213 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -84,7 +84,7 @@ def __init__( _info_spec = None elif spec is None: _info_spec = CompositeSpec( - {key: UnboundedContinuousTensorSpec() for key in keys}, shape=[] + {key: UnboundedContinuousTensorSpec(()) for key in keys}, shape=[] ) elif not isinstance(spec, CompositeSpec): if self.keys is not None and len(spec) != len(self.keys): From 9da61f239258d421f8dcf8b63da666d354bb58ae Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Jan 2024 09:12:10 +0000 Subject: [PATCH 14/35] [Refactor] LAZY_LEGACY_OP=False (#1832) --- .github/unittest/linux/scripts/run_all.sh | 2 +- .../unittest/linux_distributed/scripts/run_test.sh | 1 + .github/unittest/linux_examples/scripts/run_all.sh | 1 + .../unittest/linux_libs/scripts_ataridqn/run_test.sh | 1 + .github/unittest/linux_libs/scripts_brax/run_test.sh | 1 + .github/unittest/linux_libs/scripts_d4rl/run_test.sh | 1 + .../unittest/linux_libs/scripts_envpool/run_test.sh | 1 + .../unittest/linux_libs/scripts_gen-dgrl/run_test.sh | 1 + .github/unittest/linux_libs/scripts_gym/run_test.sh | 1 + .../unittest/linux_libs/scripts_habitat/run_test.sh | 1 + .../unittest/linux_libs/scripts_jumanji/run_test.sh | 1 + .../unittest/linux_libs/scripts_minari/run_test.sh | 1 + .github/unittest/linux_libs/scripts_openx/run_test.sh | 1 + .../linux_libs/scripts_pettingzoo/run_test.sh | 1 + .github/unittest/linux_libs/scripts_rlhf/run_test.sh | 1 + .../unittest/linux_libs/scripts_robohive/setup_env.sh | 3 ++- .../unittest/linux_libs/scripts_roboset/run_test.sh | 1 + .../unittest/linux_libs/scripts_sklearn/run_test.sh | 1 + .../unittest/linux_libs/scripts_smacv2/run_test.sh | 1 + .github/unittest/linux_libs/scripts_vd4rl/run_test.sh | 1 + .github/unittest/linux_libs/scripts_vmas/run_test.sh | 1 + .../linux_olddeps/scripts_gym_0_13/run_test.sh | 1 + .github/unittest/linux_optdeps/scripts/run_test.sh | 1 + .github/unittest/windows_optdepts/scripts/run_test.sh | 1 + test/test_collector.py | 5 +++-- test/test_env.py | 11 ++++++----- test/test_libs.py | 4 ++-- test/test_shared.py | 4 ++-- test/test_transforms.py | 2 +- torchrl/__init__.py | 2 +- torchrl/collectors/utils.py | 5 ++++- torchrl/envs/batched_envs.py | 8 +++++--- torchrl/envs/common.py | 11 +++++++---- torchrl/envs/transforms/r3m.py | 7 ++++--- torchrl/envs/transforms/transforms.py | 4 ++++ torchrl/envs/transforms/vc1.py | 4 ++-- torchrl/envs/transforms/vip.py | 8 +++++--- torchrl/envs/utils.py | 4 +++- torchrl/modules/tensordict_module/rnn.py | 5 ++--- torchrl/trainers/helpers/models.py | 4 +++- 40 files changed, 79 insertions(+), 36 deletions(-) diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 452cfb1c452..2340dbb2e54 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -76,7 +76,7 @@ export DISPLAY=:0 export SDL_VIDEODRIVER=dummy # legacy from bash scripts: remove? -conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy +conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False pip3 install pip --upgrade pip install virtualenv diff --git a/.github/unittest/linux_distributed/scripts/run_test.sh b/.github/unittest/linux_distributed/scripts/run_test.sh index 863e940ad4b..fe7d1ba1ea3 100755 --- a/.github/unittest/linux_distributed/scripts/run_test.sh +++ b/.github/unittest/linux_distributed/scripts/run_test.sh @@ -18,6 +18,7 @@ lib_dir="${env_dir}/lib" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch +export LAZY_LEGACY_OP=False export BATCHED_PIPE_TIMEOUT=60 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh index 171f4637c07..74fdc043f0a 100755 --- a/.github/unittest/linux_examples/scripts/run_all.sh +++ b/.github/unittest/linux_examples/scripts/run_all.sh @@ -74,6 +74,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin export SDL_VIDEODRIVER=dummy export MUJOCO_GL=egl export PYOPENGL_PLATFORM=egl +export LAZY_LEGACY_OP=False conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ DISPLAY=unix:0.0 \ diff --git a/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh b/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh index ee7bf9b46b1..d30e71112a8 100755 --- a/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh +++ b/.github/unittest/linux_libs/scripts_ataridqn/run_test.sh @@ -8,6 +8,7 @@ conda activate ./env apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 ln -s /usr/bin/swig3.0 /usr/bin/swig +export LAZY_LEGACY_OP=False export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" diff --git a/.github/unittest/linux_libs/scripts_brax/run_test.sh b/.github/unittest/linux_libs/scripts_brax/run_test.sh index 6a4dac48331..5a6f3a1aa30 100755 --- a/.github/unittest/linux_libs/scripts_brax/run_test.sh +++ b/.github/unittest/linux_libs/scripts_brax/run_test.sh @@ -7,6 +7,7 @@ conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_d4rl/run_test.sh b/.github/unittest/linux_libs/scripts_d4rl/run_test.sh index 242fcf7fc81..062341eacd7 100755 --- a/.github/unittest/linux_libs/scripts_d4rl/run_test.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/run_test.sh @@ -24,6 +24,7 @@ cd .. #cd .. export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_envpool/run_test.sh b/.github/unittest/linux_libs/scripts_envpool/run_test.sh index 289adf454e7..bb4ee655673 100755 --- a/.github/unittest/linux_libs/scripts_envpool/run_test.sh +++ b/.github/unittest/linux_libs/scripts_envpool/run_test.sh @@ -12,6 +12,7 @@ eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_gen-dgrl/run_test.sh b/.github/unittest/linux_libs/scripts_gen-dgrl/run_test.sh index d42193855fa..0b72f134751 100755 --- a/.github/unittest/linux_libs/scripts_gen-dgrl/run_test.sh +++ b/.github/unittest/linux_libs/scripts_gen-dgrl/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_gym/run_test.sh b/.github/unittest/linux_libs/scripts_gym/run_test.sh index 2e5860468c3..d59c5ce6213 100755 --- a/.github/unittest/linux_libs/scripts_gym/run_test.sh +++ b/.github/unittest/linux_libs/scripts_gym/run_test.sh @@ -6,6 +6,7 @@ eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_habitat/run_test.sh b/.github/unittest/linux_libs/scripts_habitat/run_test.sh index 5c9becfe832..a60fffd8f45 100755 --- a/.github/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.github/unittest/linux_libs/scripts_habitat/run_test.sh @@ -16,6 +16,7 @@ conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_jumanji/run_test.sh b/.github/unittest/linux_libs/scripts_jumanji/run_test.sh index 67f86ed73ee..542daa6eb99 100755 --- a/.github/unittest/linux_libs/scripts_jumanji/run_test.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/run_test.sh @@ -8,6 +8,7 @@ apt-get update && apt-get install -y git wget export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh index 0567e2be25d..30aabf36ac7 100755 --- a/.github/unittest/linux_libs/scripts_minari/run_test.sh +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_openx/run_test.sh b/.github/unittest/linux_libs/scripts_openx/run_test.sh index 00f9f2f4512..f80bd4cc71a 100755 --- a/.github/unittest/linux_libs/scripts_openx/run_test.sh +++ b/.github/unittest/linux_libs/scripts_openx/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh index 1cdb653ede8..7b15bc9113f 100755 --- a/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh +++ b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh @@ -8,6 +8,7 @@ apt-get update && apt-get install -y git wget export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_rlhf/run_test.sh b/.github/unittest/linux_libs/scripts_rlhf/run_test.sh index bdbe1b18ff1..dcfc686ade0 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/run_test.sh +++ b/.github/unittest/linux_libs/scripts_rlhf/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get install -y git gcc ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh index 50625f1e906..858062125a1 100755 --- a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -66,7 +66,8 @@ conda env config vars set \ DISPLAY=unix:0.0 \ PYOPENGL_PLATFORM=egl \ NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ - sim_backend=MUJOCO + sim_backend=MUJOCO \ + LAZY_LEGACY_OP=False # make env variables apparent conda deactivate && conda activate "${env_dir}" diff --git a/.github/unittest/linux_libs/scripts_roboset/run_test.sh b/.github/unittest/linux_libs/scripts_roboset/run_test.sh index 67ae605a43e..2d954032c3e 100755 --- a/.github/unittest/linux_libs/scripts_roboset/run_test.sh +++ b/.github/unittest/linux_libs/scripts_roboset/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_sklearn/run_test.sh b/.github/unittest/linux_libs/scripts_sklearn/run_test.sh index f07f1ad949e..ec3b9ed31e6 100755 --- a/.github/unittest/linux_libs/scripts_sklearn/run_test.sh +++ b/.github/unittest/linux_libs/scripts_sklearn/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get install -y git gcc ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_smacv2/run_test.sh b/.github/unittest/linux_libs/scripts_smacv2/run_test.sh index 65fd7462df3..f1f130e488a 100755 --- a/.github/unittest/linux_libs/scripts_smacv2/run_test.sh +++ b/.github/unittest/linux_libs/scripts_smacv2/run_test.sh @@ -8,6 +8,7 @@ apt-get update && apt-get install -y git wget export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh b/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh index e0323047a16..f4684ff5f30 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/run_test.sh @@ -9,6 +9,7 @@ apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf ln -s /usr/bin/swig3.0 /usr/bin/swig export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_libs/scripts_vmas/run_test.sh b/.github/unittest/linux_libs/scripts_vmas/run_test.sh index 66934039783..7bcef426e6f 100755 --- a/.github/unittest/linux_libs/scripts_vmas/run_test.sh +++ b/.github/unittest/linux_libs/scripts_vmas/run_test.sh @@ -8,6 +8,7 @@ apt-get update && apt-get install -y git wget export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh index 408b6a748de..c10f65fb4e3 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh @@ -6,6 +6,7 @@ eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/linux_optdeps/scripts/run_test.sh b/.github/unittest/linux_optdeps/scripts/run_test.sh index 5c2d0994f6b..6ebdb427e0a 100755 --- a/.github/unittest/linux_optdeps/scripts/run_test.sh +++ b/.github/unittest/linux_optdeps/scripts/run_test.sh @@ -9,6 +9,7 @@ conda activate ./env STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' diff --git a/.github/unittest/windows_optdepts/scripts/run_test.sh b/.github/unittest/windows_optdepts/scripts/run_test.sh index 9404909d22b..351eb4bfef7 100644 --- a/.github/unittest/windows_optdepts/scripts/run_test.sh +++ b/.github/unittest/windows_optdepts/scripts/run_test.sh @@ -12,6 +12,7 @@ source "$this_dir/set_cuda_envs.sh" export CKPT_BACKEND=torch export MAX_IDLE_COUNT=60 export BATCHED_PIPE_TIMEOUT=60 +export LAZY_LEGACY_OP=False python -m torch.utils.collect_env pytest --junitxml=test-results/junit.xml -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py diff --git a/test/test_collector.py b/test/test_collector.py index 61c1d886c24..ce7cade5746 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -35,6 +35,7 @@ MultiKeyCountingEnvPolicy, NestedCountingEnv, ) +from tensordict import LazyStackedTensorDict from tensordict.nn import TensorDictModule, TensorDictSequential from tensordict.tensordict import assert_allclose_td, TensorDict @@ -1896,7 +1897,7 @@ def test_aggregate_reset_to_root(self): }, [1, 2], ) - td = torch.stack([td0, td1], 0) + td = LazyStackedTensorDict.lazy_stack([td0, td1], 0) assert _aggregate_end_of_traj(td).all() def test_aggregate_reset_to_root_keys(self): @@ -1991,7 +1992,7 @@ def test_aggregate_reset_to_root_keys(self): }, [1, 2], ) - td = torch.stack([td0, td1], 0) + td = LazyStackedTensorDict.lazy_stack([td0, td1], 0) assert _aggregate_end_of_traj(td, reset_keys=["_reset", ("a", "_reset")]).all() def test_aggregate_reset_to_root_errors(self): diff --git a/test/test_env.py b/test/test_env.py index c9965b6b116..00f4bf1dfcf 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1141,10 +1141,12 @@ def test_steptensordict( tds[0]["this", "one"] = torch.zeros(2) tds[1]["but", "not", "this", "one"] = torch.ones(2) tds[0]["next", "this", "one"] = torch.ones(2) * 2 - tensordict = torch.stack(tds, 0) + tensordict = LazyStackedTensorDict.lazy_stack(tds, 0) next_tensordict = TensorDict({}, [4]) if has_out else None if has_out and lazy_stack: - next_tensordict = torch.stack(next_tensordict.unbind(0), 0) + next_tensordict = LazyStackedTensorDict.lazy_stack( + next_tensordict.unbind(0), 0 + ) out = step_mdp( tensordict.lock_(), keep_other=keep_other, @@ -1498,8 +1500,7 @@ def test_heterogeenous( [td_batch_size], ) ) - lazy_td = torch.stack(tds, dim=1) - input_td = lazy_td + lazy_td = LazyStackedTensorDict.lazy_stack(tds, dim=1) td = step_mdp( lazy_td.lock_(), @@ -1785,7 +1786,7 @@ def main_penv(j, q=None): r_p.append(env_s.rollout(100, break_when_any_done=False, policy=policy)) r_s.append(env_p.rollout(100, break_when_any_done=False, policy=policy)) - td_equals = torch.stack(r_p).contiguous() == torch.stack(r_s).contiguous() + td_equals = torch.stack(r_p) == torch.stack(r_s) if td_equals.all(): if q is not None: q.put(("passed", j)) diff --git a/test/test_libs.py b/test/test_libs.py index 13891331b05..5fcf3497139 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -555,7 +555,7 @@ def non_null_obs(batched_td): env_type = type(env0._env) assert_allclose_td(*tdreset, rtol=RTOL, atol=ATOL) - tdrollout = torch.stack(tdrollout, 0).contiguous() + tdrollout = torch.stack(tdrollout, 0) # custom filtering of non-null obs: mujoco rendering sometimes fails # and renders black images. To counter this in the tests, we select @@ -597,7 +597,7 @@ def non_null_obs(batched_td): assert_allclose_td(tdreset[0], tdreset2, rtol=RTOL, atol=ATOL) assert final_seed0 == final_seed2 # same magic trick for mujoco as above - tdrollout = torch.stack([tdrollout[0], rollout2], 0).contiguous() + tdrollout = torch.stack([tdrollout[0], rollout2], 0) idx = non_null_obs(tdrollout) assert_allclose_td( tdrollout[0][..., idx], tdrollout[1][..., idx], rtol=RTOL, atol=ATOL diff --git a/test/test_shared.py b/test/test_shared.py index e7cfa77b137..a2d2a88d6ca 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -9,7 +9,7 @@ import pytest import torch -from tensordict import TensorDict +from tensordict import LazyStackedTensorDict, TensorDict from torch import multiprocessing as mp @@ -81,7 +81,7 @@ def remote_process(command_pipe_child, command_pipe_parent, tensordict): command_pipe_parent.close() assert isinstance(tensordict, TensorDict), f"td is of type {type(tensordict)}" assert tensordict.is_shared() or tensordict.is_memmap() - new_tensordict = torch.stack( + new_tensordict = LazyStackedTensorDict.lazy_stack( [ tensordict[i].contiguous().clone().zero_() for i in range(tensordict.shape[0]) diff --git a/test/test_transforms.py b/test/test_transforms.py index 94ea338365e..2bc9f36a79b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -791,7 +791,7 @@ def test_transform_model(self, dim, N, padding): model(tdbase0) tdbase0.batch_size = [10] tdbase0 = tdbase0.expand(5, 10) - tdbase0_copy = tdbase0.transpose(0, 1).to_tensordict() + tdbase0_copy = tdbase0.transpose(0, 1) tdbase0.refine_names("time", None) tdbase0_copy.names = [None, "time"] v1 = model(tdbase0) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 9b460db5216..ef80f84a428 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -11,7 +11,7 @@ from torch import multiprocessing as mp -set_lazy_legacy(True).set() +set_lazy_legacy(False).set() if torch.cuda.device_count() > 1: n = torch.cuda.device_count() - 1 diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 87145f26847..eee3b3e4a98 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -6,6 +6,8 @@ from typing import Callable import torch + +from tensordict import set_lazy_legacy from tensordict.tensordict import pad, TensorDictBase @@ -25,6 +27,7 @@ def stacked_output_fun(*args, **kwargs): return stacked_output_fun +@set_lazy_legacy(False) def split_trajectories( rollout_tensordict: TensorDictBase, prefix=None ) -> TensorDictBase: @@ -88,7 +91,7 @@ def split_trajectories( ), ) if rollout_tensordict.ndimension() == 1: - rollout_tensordict = rollout_tensordict.unsqueeze(0).to_tensordict() + rollout_tensordict = rollout_tensordict.unsqueeze(0) return rollout_tensordict.unflatten_keys(sep) out_splits = rollout_tensordict.view(-1).split(splits, 0) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 187bd083f09..96fd0c4f165 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -348,7 +348,7 @@ def _check_for_empty_spec(specs: CompositeSpec): self.done_spec = output_spec["full_done_spec"] self._dummy_env_str = str(meta_data[0]) - self._env_tensordict = torch.stack( + self._env_tensordict = LazyStackedTensorDict.lazy_stack( [meta_data.tensordict for meta_data in meta_data], 0 ) self._batch_locked = meta_data[0].batch_locked @@ -463,7 +463,7 @@ def _create_td(self) -> None: ) for tensordict in shared_tensordict_parent ] - shared_tensordict_parent = torch.stack( + shared_tensordict_parent = LazyStackedTensorDict.lazy_stack( shared_tensordict_parent, 0, ) @@ -474,7 +474,9 @@ def _create_td(self) -> None: self.shared_tensordicts = [ td.clone() for td in self.shared_tensordict_parent.unbind(0) ] - self.shared_tensordict_parent = torch.stack(self.shared_tensordicts, 0) + self.shared_tensordict_parent = LazyStackedTensorDict.lazy_stack( + self.shared_tensordicts, 0 + ) else: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 87a51e6bef5..39484ac355a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,7 +14,7 @@ import numpy as np import torch import torch.nn as nn -from tensordict import unravel_key +from tensordict import LazyStackedTensorDict, unravel_key from tensordict.tensordict import TensorDictBase from tensordict.utils import NestedKey from torchrl._utils import _replace_last, implement_for, prod, seed_generator @@ -2307,9 +2307,12 @@ def rollout( else: tensordicts = self._rollout_nonstop(**kwargs) batch_size = self.batch_size if tensordict is None else tensordict.batch_size - out_td = torch.stack(tensordicts, len(batch_size), out=out) if return_contiguous: - out_td = out_td.contiguous() + out_td = torch.stack(tensordicts, len(batch_size), out=out) + else: + out_td = LazyStackedTensorDict.lazy_stack( + tensordicts, len(batch_size), out=out + ) out_td.refine_names(..., "time") return out_td @@ -2408,7 +2411,7 @@ def step_and_maybe_reset( ... for i in range(n): ... data, data_ = env.step_and_maybe_reset(data_) ... result.append(data) - ... return torch.stack(result).contiguous() + ... return torch.stack(result) >>> env = ParallelEnv(2, lambda: GymEnv("CartPole-v1")) >>> print(rollout(env, 2)) TensorDict( diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 9c10c15b2e4..05017a8a8ec 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -6,7 +6,7 @@ from typing import List, Optional, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import set_lazy_legacy, TensorDict, TensorDictBase from torch.hub import load_state_dict_from_url from torch.nn import Identity @@ -84,9 +84,10 @@ def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True): self.convnet = convnet self.del_keys = del_keys + @set_lazy_legacy(False) def _call(self, tensordict): - tensordict_view = tensordict.view(-1) - super()._call(tensordict_view) + with tensordict.view(-1) as tensordict_view: + super()._call(tensordict_view) if self.del_keys: tensordict.exclude(*self.in_keys, inplace=True) return tensordict diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 41ce74293bb..07bc29d1b59 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -20,6 +20,7 @@ from tensordict import ( is_tensor_collection, NonTensorData, + set_lazy_legacy, unravel_key, unravel_key_list, ) @@ -2883,6 +2884,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: return self.unfolding(tensordict) + @set_lazy_legacy(False) def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # it is assumed that the last dimension of the tensordict is the time dimension if not tensordict.ndim: @@ -2972,6 +2974,8 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: *range(data.ndim + self.dim, data.ndim - 1), ) tensordict.set(out_key, data) + if tensordict_orig is not tensordict: + tensordict_orig = tensordict.transpose(tensordict.ndim - 1, i) return tensordict_orig def __repr__(self) -> str: diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index 5cb038b699a..252ddfc4a90 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -167,8 +167,8 @@ def _call(self, tensordict): if in_key != out_key ] saved_td = tensordict.select(*in_keys) - tensordict_view = tensordict.view(-1) - super()._call(self.model_transforms(tensordict_view)) + with tensordict.view(-1) as tensordict_view: + super()._call(self.model_transforms(tensordict_view)) if self.del_keys: tensordict.exclude(*self.in_keys, inplace=True) else: diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 48110387fad..289dd60f053 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union import torch -from tensordict import TensorDict +from tensordict import set_lazy_legacy, TensorDict from tensordict.tensordict import TensorDictBase from torch.hub import load_state_dict_from_url @@ -72,9 +72,11 @@ def __init__(self, in_keys, out_keys, model_name="resnet50", del_keys: bool = Tr self.convnet = convnet self.del_keys = del_keys + @set_lazy_legacy(False) def _call(self, tensordict): - tensordict_view = tensordict.view(-1) - super()._call(tensordict_view) + with tensordict.view(-1) as tensordict_view: + super()._call(tensordict_view) + if self.del_keys: tensordict.exclude(*self.in_keys, inplace=True) return tensordict diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f505def52af..82f0c2d21fb 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -450,7 +450,9 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) else: - fake_tensordict = torch.stack([fake_tensordict.clone() for _ in range(3)], -1) + fake_tensordict = LazyStackedTensorDict.lazy_stack( + [fake_tensordict.clone() for _ in range(3)], -1 + ) # eliminate empty containers fake_tensordict_select = fake_tensordict.select(*fake_tensordict.keys(True, True)) real_tensordict_select = real_tensordict.select(*real_tensordict.keys(True, True)) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index e392070c517..fe970c292be 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -12,7 +12,7 @@ from tensordict.nn import TensorDictModuleBase as ModuleBase from tensordict.tensordict import NO_DEFAULT -from tensordict.utils import expand_as_right, prod +from tensordict.utils import expand_as_right, prod, set_lazy_legacy from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase @@ -1297,6 +1297,7 @@ def set_recurrent_mode(self, mode: bool = True): out._recurrent_mode = mode return out + @set_lazy_legacy(False) def forward(self, tensordict: TensorDictBase): # we want to get an error if the value input is missing, but not the hidden states defaults = [NO_DEFAULT, None] @@ -1318,8 +1319,6 @@ def forward(self, tensordict: TensorDictBase): ) else: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) - # TODO: replace by contiguous, or ultimately deprecate the default lazy unsqueeze - tensordict_shaped = tensordict_shaped.to_tensordict() is_init = tensordict_shaped.get("is_init").squeeze(-1) splits = None diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 3782de64fa2..c57642a7237 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -8,6 +8,8 @@ from typing import Optional, Sequence import torch + +from tensordict import set_lazy_legacy from tensordict.nn import InteractionType from torch import distributions as d, nn @@ -450,6 +452,7 @@ def make_redq_model( return model +@set_lazy_legacy(False) def make_dreamer( cfg: "DictConfig", # noqa: F821 proof_environment: EnvBase = None, @@ -511,7 +514,6 @@ def make_dreamer( ).to(device) with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): tensordict = proof_environment.fake_tensordict().unsqueeze(-1) - tensordict = tensordict.to_tensordict().to(device) tensordict = tensordict.to(device) world_model(tensordict) From 156a668e7912660776d1aa4c4b16fd6936370dee Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Jan 2024 14:55:45 +0000 Subject: [PATCH 15/35] [Feature] `serial_for_single` arg in batched envs (#1846) --- examples/a2c/utils_atari.py | 4 +- examples/cql/utils.py | 2 + examples/ddpg/utils.py | 2 + examples/decision_transformer/utils.py | 2 +- examples/discrete_sac/utils.py | 2 + .../collectors/single_machine/generic.py | 6 +- .../collectors/single_machine/rpc.py | 6 +- .../collectors/single_machine/sync.py | 6 +- examples/dreamer/dreamer_utils.py | 1 + examples/iql/utils.py | 2 + examples/ppo/utils_atari.py | 4 +- examples/redq/utils.py | 1 + examples/sac/utils.py | 2 + examples/td3/utils.py | 2 + test/test_env.py | 8 ++ torchrl/_utils.py | 2 +- torchrl/collectors/collectors.py | 2 +- torchrl/envs/batched_envs.py | 105 +++++++++++++++++- torchrl/envs/gym_like.py | 2 +- torchrl/objectives/dqn.py | 2 +- torchrl/trainers/helpers/collectors.py | 2 +- torchrl/trainers/helpers/losses.py | 2 +- 22 files changed, 150 insertions(+), 17 deletions(-) diff --git a/examples/a2c/utils_atari.py b/examples/a2c/utils_atari.py index 63d15557700..7b3625b1e2b 100644 --- a/examples/a2c/utils_atari.py +++ b/examples/a2c/utils_atari.py @@ -61,7 +61,9 @@ def make_base_env( def make_parallel_env(env_name, num_envs, device, is_test=False): env = ParallelEnv( - num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + num_envs, + EnvCreator(lambda: make_base_env(env_name, device=device)), + serial_for_single=True, ) env = TransformedEnv(env) env.append_transform(ToTensorImage()) diff --git a/examples/cql/utils.py b/examples/cql/utils.py index 828b370a559..0af1a082e28 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -80,6 +80,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1): parallel_env = ParallelEnv( train_num_envs, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -89,6 +90,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1): ParallelEnv( eval_num_envs, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ), train_env.transform.clone(), ) diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 2260e220b4b..935fb426988 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -76,6 +76,7 @@ def make_environment(cfg): parallel_env = ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -87,6 +88,7 @@ def make_environment(cfg): ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ), train_env.transform.clone(), ) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 7cb5b52b6ea..940e26a5c0a 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -142,7 +142,7 @@ def make_env(): return make_base_env(env_cfg) env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(make_env)), + ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True), env_cfg, obs_loc, obs_std, diff --git a/examples/discrete_sac/utils.py b/examples/discrete_sac/utils.py index f7d581ce7e2..49ec8bc1204 100644 --- a/examples/discrete_sac/utils.py +++ b/examples/discrete_sac/utils.py @@ -77,6 +77,7 @@ def make_environment(cfg): parallel_env = ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -88,6 +89,7 @@ def make_environment(cfg): ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ), train_env.transform.clone(), ) diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index 9c1fd9976f0..77dbf4a7cde 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -100,7 +100,11 @@ def gym_make(): if args.worker_parallelism == "collector" or num_workers == 1: action_spec = make_env().action_spec else: - make_env = ParallelEnv(num_workers, make_env) + make_env = ParallelEnv( + num_workers, + make_env, + serial_for_single=True, + ) action_spec = make_env.action_spec if args.worker_parallelism == "collector" and num_workers > 1: diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 7de1cf5aad0..4ca9e9f4a3e 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -96,7 +96,11 @@ def gym_make(): if num_workers == 1: action_spec = make_env().action_spec else: - make_env = ParallelEnv(num_workers, make_env) + make_env = ParallelEnv( + num_workers, + make_env, + serial_for_single=True, + ) action_spec = make_env.action_spec collector = RPCDataCollector( diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index 7f3d62efa45..8b3bd02aad2 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -95,7 +95,11 @@ def gym_make(): if args.worker_parallelism == "collector" or num_workers == 1: action_spec = make_env().action_spec else: - make_env = ParallelEnv(num_workers, make_env) + make_env = ParallelEnv( + num_workers, + make_env, + serial_for_single=True, + ) action_spec = make_env.action_spec if args.worker_parallelism == "collector" and num_workers > 1: diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 385e4a53aab..51593a33caa 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -270,6 +270,7 @@ def parallel_env_constructor( create_env_kwargs=None, pin_memory=cfg.pin_memory, device=cfg.collector_device, + serial_for_single=True, ) if batch_transform: kwargs.update( diff --git a/examples/iql/utils.py b/examples/iql/utils.py index 31dcb732b00..fe1e5ce32b8 100644 --- a/examples/iql/utils.py +++ b/examples/iql/utils.py @@ -84,6 +84,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1): parallel_env = ParallelEnv( train_num_envs, EnvCreator(lambda: env_maker(cfg)), + serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -93,6 +94,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1): ParallelEnv( eval_num_envs, EnvCreator(lambda: env_maker(cfg)), + serial_for_single=True, ), train_env.transform.clone(), ) diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py index 1355212ed70..c78bc67f45a 100644 --- a/examples/ppo/utils_atari.py +++ b/examples/ppo/utils_atari.py @@ -60,7 +60,9 @@ def make_base_env( def make_parallel_env(env_name, num_envs, device, is_test=False): env = ParallelEnv( - num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + num_envs, + EnvCreator(lambda: make_base_env(env_name, device=device)), + serial_for_single=True, ) env = TransformedEnv(env) env.append_transform(ToTensorImage()) diff --git a/examples/redq/utils.py b/examples/redq/utils.py index 76ddf4ad302..ef377903202 100644 --- a/examples/redq/utils.py +++ b/examples/redq/utils.py @@ -747,6 +747,7 @@ def parallel_env_constructor( num_workers=cfg.collector.env_per_collector, create_env_fn=make_transformed_env, create_env_kwargs=None, + serial_for_single=True, pin_memory=False, ) if batch_transform: diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 69c7b7c7658..1e157ce85cd 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -68,6 +68,7 @@ def make_environment(cfg): parallel_env = ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -77,6 +78,7 @@ def make_environment(cfg): ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ), train_env.transform.clone(), ) diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 36d3ef99a9a..0abc769d365 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -78,6 +78,7 @@ def make_environment(cfg): parallel_env = ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -89,6 +90,7 @@ def make_environment(cfg): ParallelEnv( cfg.collector.env_per_collector, EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, ), train_env.transform.clone(), ) diff --git a/test/test_env.py b/test/test_env.py index 00f4bf1dfcf..a6f6873f9ba 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -395,6 +395,14 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): env.shared_tensordict_parent.device.type == torch.device(edevice).type ) + def test_serial_for_single(self): + env = ParallelEnv(1, ContinuousActionVecMockEnv, serial_for_single=True) + assert isinstance(env, SerialEnv) + env = ParallelEnv(1, ContinuousActionVecMockEnv) + assert isinstance(env, ParallelEnv) + env = ParallelEnv(2, ContinuousActionVecMockEnv, serial_for_single=True) + assert isinstance(env, ParallelEnv) + @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) def test_env_with_batch_size(self, num_parallel_env, env_batch_size): diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 93b72483ce8..98abe9648c6 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -565,7 +565,7 @@ def ctx_factory(): if inspect.isclass(func): raise RuntimeError( - "Cannot decorate classes; it is ambiguous whether or not only the " + "Cannot decorate classes; it is ambiguous whether only the " "constructor or all methods should have the context manager applied; " "additionally, decorating a class at definition-site will prevent " "use of the identifier as a conventional type. " diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 8d1762f8465..08c86ffea45 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1993,7 +1993,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. Defaults to ``None`` (i.e. no random frames) - reset_at_each_iter (bool): Whether or not environments should be reset for each batch. + reset_at_each_iter (bool): whether environments should be reset for each batch. default=False. postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a useful format for training. diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 96fd0c4f165..a8d399ea08b 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -25,7 +25,7 @@ from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE from torchrl.data.tensor_specs import CompositeSpec from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING -from torchrl.envs.common import EnvBase +from torchrl.envs.common import _EnvPostInit, EnvBase from torchrl.envs.env_creator import get_env_metadata # legacy @@ -104,6 +104,19 @@ def new_fun(self, *args, **kwargs): return new_fun +class _PEnvMeta(_EnvPostInit): + def __call__(cls, *args, **kwargs): + serial_for_single = kwargs.pop("serial_for_single", False) + if serial_for_single: + num_workers = kwargs.get("num_workers", None) + if num_workers is None: + num_workers = args[0] + if num_workers == 1: + # We still use a serial to keep the shape unchanged + return SerialEnv(*args, **kwargs) + return super().__call__(*args, **kwargs) + + class _BatchedEnv(EnvBase): """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely. @@ -120,12 +133,14 @@ class _BatchedEnv(EnvBase): If a single task is used, a callable should be used and not a list of identical callables: if a list of callable is provided, the environment will be executed as if multiple, diverse tasks were needed, which comes with a slight compute overhead; + + Keyword Args: create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created; share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy stack is returned. default = None (False if single task); - shared_memory (bool): whether or not the returned tensordict will be placed in shared memory; - memmap (bool): whether or not the returned tensordict will be placed in memory map. + shared_memory (bool): whether the returned tensordict will be placed in shared memory; + memmap (bool): whether the returned tensordict will be placed in memory map. policy_proof (callable, optional): if provided, it'll be used to get the list of tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc. device (str, int, torch.device): The device of the batched environment can be passed. @@ -147,7 +162,84 @@ class _BatchedEnv(EnvBase): Defaults to 1 for safety: if none is indicated, launching multiple workers may charge the cpu load too much and harm performance. This parameter has no effect for the :class:`~SerialEnv` class. - + serial_for_single (bool, optional): if ``True``, creating a parallel environment + with a single worker will return a :class:`~SerialEnv` instead. + This option has no effect with :class:`~SerialEnv`. Defaults to ``False``. + + Examples: + >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator + >>> make_env = EnvCreator(lambda: GymEnv("Pendulum-v1")) # EnvCreator ensures that the env is sharable. Optional in most cases. + >>> env = SerialEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on the same process serially + >>> env = ParallelEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on dedicated processes + >>> from torchrl.envs import DMControlEnv + >>> env = ParallelEnv(2, [ + ... lambda: DMControlEnv("humanoid", "stand"), + ... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands + >>> r = env.rollout(10) # executes 10 random steps in the environment + >>> r[0] # data for Humanoid stand + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), + com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False), + head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False), + joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), + next: TensorDict( + fields={ + com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False), + head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False), + joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), + reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + >>> r[1] # data for Humanoid walk + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), + com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False), + head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False), + joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), + next: TensorDict( + fields={ + com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False), + head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False), + joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), + reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + >>> env = ParallelEnv(1, make_env, serial_for_single=True) + >>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary """ _verbose: bool = VERBOSE @@ -162,6 +254,7 @@ def __init__( self, num_workers: int, create_env_fn: Union[Callable[[], EnvBase], Sequence[Callable[[], EnvBase]]], + *, create_env_kwargs: Union[dict, Sequence[dict]] = None, pin_memory: bool = False, share_individual_td: Optional[bool] = None, @@ -172,8 +265,10 @@ def __init__( allow_step_when_done: bool = False, num_threads: int = None, num_sub_threads: int = 1, + serial_for_single: bool = False, ): super().__init__(device=device) + self.serial_for_single = serial_for_single self.is_closed = True if num_threads is None: num_threads = num_workers + 1 # 1 more thread for this proc @@ -760,7 +855,7 @@ def to(self, device: DEVICE_TYPING): return self -class ParallelEnv(_BatchedEnv): +class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta): """Creates one environment per process. TensorDicts are passed via shared memory or memory map. diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 6d0b3dc0213..7eca6f5a1db 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -478,7 +478,7 @@ def auto_register_info_dict(self): within the tensordict. This method returns a (possibly transformed) environment where we make sure that - the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether or not + the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether the info is filled at reset time. This method requires running a few iterations in the environment to diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index dc0ed1e1df4..623c3f7189a 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -45,7 +45,7 @@ class DQNLoss(LossModule): delay_value (bool, optional): whether to duplicate the value network into a new target value network to create a DQN with a target network. Default is ``False``. - double_dqn (bool, optional): whether or not to use Double DQN, as described in + double_dqn (bool, optional): whether to use Double DQN, as described in https://arxiv.org/abs/1509.06461. Defaults to ``False``. action_space (str or TensorSpec, optional): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 418dd638269..f8f9c55809b 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -411,7 +411,7 @@ class OffPolicyCollectorConfig(OnPolicyCollectorConfig): """Off-policy collector config struct.""" multi_step: bool = False - # whether or not multi-step rewards should be used. + # whether multi-step rewards should be used. n_steps_return: int = 3 # If multi_step is set to True, this value defines the number of steps to look ahead for the reward computation. init_random_frames: int = 50000 diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 0adff694d3f..a949bea6718 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -147,7 +147,7 @@ class PPOLossConfig: lmbda: float = 0.95 # lambda factor in GAE (using 'lambda' as attribute is prohibited in python, hence the misspelling) entropy_bonus: bool = True - # Whether or not to add an entropy term to the PPO loss. + # whether to add an entropy term to the PPO loss. entropy_coef: float = 1e-3 # Entropy factor for the PPO loss samples_mc_entropy: int = 1 From 79374d82c7120ba2c738038b9f32532978421e02 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Jan 2024 18:50:12 +0000 Subject: [PATCH 16/35] [BugFix] Fix VD4RL (#1834) --- torchrl/data/datasets/vd4rl.py | 53 ++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index f107804ae84..54e933f71f5 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import functools + import importlib import json import logging @@ -12,7 +14,6 @@ import shutil import tempfile from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Callable, List @@ -20,6 +21,7 @@ import torch from tensordict import PersistentTensorDict, TensorDict +from torch import multiprocessing as mp from torchrl._utils import KeyDependentDefaultDict from torchrl.data.datasets.utils import _get_root_dir @@ -96,6 +98,8 @@ class VD4RLExperienceReplay(TensorDictReplayBuffer): transform that will be appended to the transform list. Supports `int` types (square resizing) or a list/tuple of `int` (rectangular resizing). Defaults to ``None`` (no resizing). + num_workers (int, optional): the number of workers to download the files. + Defaults to ``0`` (no multiprocessing). Attributes: available_datasets: a list of accepted entries to be downloaded. These @@ -173,6 +177,7 @@ def __init__( split_trajs: bool = False, totensor: bool = True, image_size: int | List[int] | None = None, + num_workers: int = 0, **env_kwargs, ): if not _has_h5py or not _has_hf_hub: @@ -191,6 +196,7 @@ def __init__( self.root = root self.split_trajs = split_trajs self.download = download + self.num_workers = num_workers if self.download == "force" or (self.download and not self._is_downloaded()): if self.download == "force": try: @@ -199,7 +205,9 @@ def __init__( shutil.rmtree(self.data_path) except FileNotFoundError: pass - storage = self._download_and_preproc(dataset_id, data_path=self.data_path) + storage = self._download_and_preproc( + dataset_id, data_path=self.data_path, num_workers=self.num_workers + ) elif self.split_trajs and not os.path.exists(self.data_path): storage = self._make_split() else: @@ -251,14 +259,23 @@ def _parse_datasets(cls): return sibs @classmethod - def _download_and_preproc(cls, dataset_id, data_path): + def _hf_hub_download(cls, subfolder, filename, *, tmpdir): from huggingface_hub import hf_hub_download - files = [] + return hf_hub_download( + "conglu/vd4rl", + subfolder=subfolder, + filename=filename, + repo_type="dataset", + cache_dir=str(tmpdir), + ) + + @classmethod + def _download_and_preproc(cls, dataset_id, data_path, num_workers): + tds = [] with tempfile.TemporaryDirectory() as tmpdir: sibs = cls._parse_datasets() - # files = [] total_steps = 0 paths_to_proc = [] @@ -270,19 +287,19 @@ def _download_and_preproc(cls, dataset_id, data_path): for file in sibs[path]: paths_to_proc.append(str(path)) files_to_proc.append(str(file.parts[-1])) - - with ThreadPoolExecutor(32) as executor: - files = executor.map( - lambda path_file: hf_hub_download( - "conglu/vd4rl", - subfolder=path_file[0], - filename=path_file[1], - repo_type="dataset", - cache_dir=str(tmpdir), - ), - zip(paths_to_proc, files_to_proc), - ) - files = list(files) + func = functools.partial(cls._hf_hub_download, tmpdir=tmpdir) + if num_workers > 0: + with mp.Pool(num_workers) as pool: + files = pool.starmap( + func, + zip(paths_to_proc, files_to_proc), + ) + files = list(files) + else: + files = [ + func(subfolder, filename) + for (subfolder, filename) in zip(paths_to_proc, files_to_proc) + ] logging.info("Downloaded, processing files") if _has_tqdm: import tqdm From 6277226411fe88142fd0bc1627204f3c813b64a1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Jan 2024 20:51:23 +0000 Subject: [PATCH 17/35] [Doc] Make tutos runnable without colab (#1826) --- torchrl/objectives/ppo.py | 23 +++++++++++-------- tutorials/sphinx-tutorials/coding_ddpg.py | 10 ++++++++ tutorials/sphinx-tutorials/coding_dqn.py | 12 ++++++++++ tutorials/sphinx-tutorials/coding_ppo.py | 16 +++++++++++++ tutorials/sphinx-tutorials/dqn_with_rnn.py | 16 +++++++++++++ tutorials/sphinx-tutorials/multi_task.py | 11 +++++++++ tutorials/sphinx-tutorials/multiagent_ppo.py | 4 ++-- tutorials/sphinx-tutorials/pendulum.py | 17 ++++++++++++++ .../sphinx-tutorials/pretrained_models.py | 17 ++++++++++++++ tutorials/sphinx-tutorials/rb_tutorial.py | 17 ++++++++++++++ tutorials/sphinx-tutorials/torchrl_demo.py | 12 ++++++++++ tutorials/sphinx-tutorials/torchrl_envs.py | 12 ++++++++++ 12 files changed, 156 insertions(+), 11 deletions(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 5533dfd74d9..542877f8f20 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -345,17 +345,19 @@ def functional(self): @property def actor(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network @property def critic(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network @@ -363,23 +365,26 @@ def critic(self): def actor_params(self): logging.warning( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network_params @property def critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network_params @property def target_critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.target_critic_network_params diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 319206e6e5d..1f69651a3b4 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -57,6 +57,16 @@ from typing import Tuple warnings.filterwarnings("ignore") +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + # sphinx_gallery_end_ignore import torch.cuda diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index ef05c2d977f..fcddd699b3a 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -87,6 +87,18 @@ import warnings warnings.filterwarnings("ignore") + +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + + # sphinx_gallery_end_ignore import os diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 56f96221a40..679d625220c 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -104,6 +104,22 @@ # description and more about the algorithm itself. # +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + +# sphinx_gallery_end_ignore + from collections import defaultdict import matplotlib.pyplot as plt diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 14470617eef..a1c82d5c429 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -68,6 +68,22 @@ # ----- # +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + +# sphinx_gallery_end_ignore + import torch import tqdm from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 76eefc4b671..a12c2b05ff8 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -13,6 +13,17 @@ import warnings warnings.filterwarnings("ignore") + +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + # sphinx_gallery_end_ignore import torch diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index f32d2d93b2f..90fd82dab3c 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -659,8 +659,8 @@ with torch.no_grad(): GAE( tensordict_data, - params=loss_module.critic_params, - target_params=loss_module.target_critic_params, + params=loss_module.critic_network_params, + target_params=loss_module.target_critic_network_params, ) # Compute GAE and add it to the data data_view = tensordict_data.reshape(-1) # Flatten the batch size to shuffle data diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 889c9616a2b..b72d2ff0f92 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -73,6 +73,23 @@ # simulation graph. # * Finally, we will train a simple policy to solve the system we implemented. # + +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + +# sphinx_gallery_end_ignore + from collections import defaultdict from typing import Optional diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index 24c4dee726e..e8abf33cef8 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -13,6 +13,23 @@ # in one or the other context. In this tutorial, we will be using R3M (https://arxiv.org/abs/2203.12601), # but other models (e.g. VIP) will work equally well. # + +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + +# sphinx_gallery_end_ignore + import torch.cuda from tensordict.nn import TensorDictSequential from torch import nn diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index ab3a99cdece..6106e3cf65a 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -46,6 +46,23 @@ # replay buffer is a straightforward process, as shown in the following # example: # + +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + +# sphinx_gallery_end_ignore + import tempfile from torchrl.data import ReplayBuffer diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 9da761696d4..d1a261e63f5 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -128,6 +128,18 @@ import warnings warnings.filterwarnings("ignore") + +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + + # sphinx_gallery_end_ignore import torch diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index ef995030c9d..dc836b43150 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -30,6 +30,18 @@ import warnings warnings.filterwarnings("ignore") + +from torch import multiprocessing + +# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside +# `__main__` method call, but for the easy of reading the code switch to fork +# which is also a default spawn method in Google's Colaboratory +try: + multiprocessing.set_start_method("fork") +except RuntimeError: + assert multiprocessing.get_start_method() == "fork" + + # sphinx_gallery_end_ignore import torch From b1cc7962ba92a529ed997e3d0210af2cc6d370f3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 30 Jan 2024 20:37:54 +0000 Subject: [PATCH 18/35] [Feature] Fine control over devices in collectors (#1835) --- benchmarks/ecosystem/gym_env_throughput.py | 4 - test/mocking_classes.py | 16 +- test/test_collector.py | 596 ++++++-- test/test_cost.py | 6 +- test/test_distributions.py | 2 +- test/test_env.py | 28 +- test/test_exploration.py | 2 +- test/test_libs.py | 3 +- test/test_postprocs.py | 2 +- test/test_rb.py | 10 +- test/test_rb_distributed.py | 2 +- test/test_specs.py | 77 +- test/test_transforms.py | 3 +- torchrl/collectors/collectors.py | 1280 +++++++++++------ torchrl/collectors/distributed/generic.py | 196 ++- torchrl/collectors/distributed/ray.py | 138 +- torchrl/collectors/distributed/rpc.py | 182 ++- torchrl/collectors/distributed/sync.py | 185 ++- torchrl/collectors/utils.py | 3 +- torchrl/data/datasets/d4rl.py | 3 +- torchrl/data/datasets/openml.py | 2 +- torchrl/data/postprocs/postprocs.py | 2 +- torchrl/data/replay_buffers/replay_buffers.py | 7 +- torchrl/data/replay_buffers/storages.py | 3 +- torchrl/data/rlhf/dataset.py | 2 +- torchrl/data/tensor_specs.py | 212 ++- torchrl/envs/batched_envs.py | 85 +- torchrl/envs/common.py | 149 +- torchrl/envs/env_creator.py | 2 +- torchrl/envs/gym_like.py | 3 +- torchrl/envs/libs/brax.py | 2 +- torchrl/envs/libs/gym.py | 6 +- torchrl/envs/libs/jax_utils.py | 2 +- torchrl/envs/libs/jumanji.py | 2 +- torchrl/envs/libs/openml.py | 2 +- torchrl/envs/libs/pettingzoo.py | 2 +- torchrl/envs/libs/robohive.py | 3 +- torchrl/envs/libs/vmas.py | 2 +- torchrl/envs/transforms/transforms.py | 3 +- torchrl/envs/transforms/vip.py | 3 +- torchrl/envs/utils.py | 9 +- torchrl/modules/models/recipes/impala.py | 2 +- torchrl/modules/planners/cem.py | 2 +- torchrl/modules/planners/common.py | 2 +- torchrl/modules/planners/mppi.py | 2 +- torchrl/modules/tensordict_module/common.py | 3 +- .../modules/tensordict_module/exploration.py | 2 +- torchrl/modules/tensordict_module/rnn.py | 4 +- torchrl/objectives/a2c.py | 4 +- torchrl/objectives/cql.py | 4 +- torchrl/objectives/ddpg.py | 4 +- torchrl/objectives/decision_transformer.py | 2 +- torchrl/objectives/deprecated.py | 3 +- torchrl/objectives/iql.py | 8 +- torchrl/objectives/ppo.py | 4 +- torchrl/objectives/redq.py | 4 +- torchrl/objectives/reinforce.py | 4 +- torchrl/objectives/sac.py | 4 +- torchrl/objectives/td3.py | 6 +- torchrl/objectives/utils.py | 2 +- torchrl/objectives/value/advantages.py | 2 +- torchrl/record/recorder.py | 2 +- torchrl/trainers/helpers/collectors.py | 3 +- torchrl/trainers/trainers.py | 2 +- tutorials/sphinx-tutorials/coding_ddpg.py | 2 +- tutorials/sphinx-tutorials/pendulum.py | 2 +- 66 files changed, 2334 insertions(+), 986 deletions(-) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index c69fc985ded..13adacd5868 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -115,7 +115,6 @@ def make(envname=envname, gym_backend=gym_backend): frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -178,7 +177,6 @@ def make_env(envname=envname, gym_backend=gym_backend): frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -222,7 +220,6 @@ def make_env( total_frames=num_workers * 10_000, num_sub_threads=num_workers // num_collectors, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -260,7 +257,6 @@ def make_env(envname=envname, gym_backend=gym_backend): frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 diff --git a/test/mocking_classes.py b/test/mocking_classes.py index def8ddae1d5..9e5b2ff6879 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -6,7 +6,8 @@ import torch import torch.nn as nn -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase from tensordict.utils import expand_right, NestedKey from torchrl.data.tensor_specs import ( @@ -229,6 +230,7 @@ def _step(self, tensordict): "observation": n.clone(), }, batch_size=[], + device=self.device, ) def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: @@ -240,7 +242,9 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: done = self.counter >= self.max_val done = torch.tensor([done], dtype=torch.bool, device=self.device) return TensorDict( - {"done": done, "terminated": done.clone(), "observation": n}, [] + {"done": done, "terminated": done.clone(), "observation": n}, + [], + device=self.device, ) def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: @@ -1374,8 +1378,9 @@ def _step( return tensordict -class HeteroCountingEnvPolicy: +class HeterogeneousCountingEnvPolicy(TensorDictModuleBase): def __init__(self, full_action_spec: TensorSpec, count: bool = True): + super().__init__() self.full_action_spec = full_action_spec self.count = count @@ -1386,7 +1391,7 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase: return td.update(action_td) -class HeteroCountingEnv(EnvBase): +class HeterogeneousCountingEnv(EnvBase): """A heterogeneous, counting Env.""" def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): @@ -1569,13 +1574,14 @@ def _set_seed(self, seed: Optional[int]): torch.manual_seed(seed) -class MultiKeyCountingEnvPolicy: +class MultiKeyCountingEnvPolicy(TensorDictModuleBase): def __init__( self, full_action_spec: TensorSpec, count: bool = True, deterministic: bool = False, ): + super().__init__() if not deterministic and not count: raise ValueError("Not counting policy is always deterministic") diff --git a/test/test_collector.py b/test/test_collector.py index ce7cade5746..2e090ad8fcf 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import argparse import logging @@ -11,13 +12,16 @@ import numpy as np import pytest import torch + from _utils_internal import ( check_rollout_consistency_multikey_env, decorate_thread_sub_func, generate_seeds, + get_available_devices, get_default_devices, PENDULUM_VERSIONED, PONG_VERSIONED, + retry, ) from mocking_classes import ( ContinuousActionVecMockEnv, @@ -28,16 +32,15 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, - HeteroCountingEnv, - HeteroCountingEnvPolicy, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, MockSerialEnv, MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, ) -from tensordict import LazyStackedTensorDict -from tensordict.nn import TensorDictModule, TensorDictSequential -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import nn from torchrl._utils import _replace_last, prod, seed_generator @@ -95,16 +98,15 @@ def __init__(self, out_features: int): self.linear = nn.LazyLinear(out_features) def forward(self, tensordict): - return TensorDict( - {self.out_keys[0]: self.linear(tensordict.get(self.in_keys[0]))}, - [], + return tensordict.set( + self.out_keys[0], self.linear(tensordict.get(self.in_keys[0])) ) class UnwrappablePolicy(nn.Module): def __init__(self, out_features: int): super().__init__() - self.linear = nn.LazyLinear(out_features) + self.linear = nn.Linear(2, out_features) def forward(self, observation, other_stuff): return self.linear(observation), other_stuff.sum() @@ -163,110 +165,360 @@ def make_policy(env): raise NotImplementedError -def _is_consistent_device_type( - device_type, policy_device_type, storing_device_type, tensordict_device_type -): - if storing_device_type is None: - if device_type is None: - if policy_device_type is None: - return tensordict_device_type == "cpu" +# def _is_consistent_device_type( +# device_type, policy_device_type, storing_device_type, tensordict_device_type +# ): +# if storing_device_type is None: +# if device_type is None: +# if policy_device_type is None: +# return tensordict_device_type == "cpu" +# +# return tensordict_device_type == policy_device_type +# +# return tensordict_device_type == device_type +# +# return tensordict_device_type == storing_device_type - return tensordict_device_type == policy_device_type - return tensordict_device_type == device_type +class TestCollectorDevices: + class DeviceLessEnv(EnvBase): + # receives data on cpu, outputs on gpu -- tensordict has no device + def __init__(self, default_device): + self.default_device = default_device + super().__init__(device=None) + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec((), device=default_device) + ) + self.reward_spec = UnboundedContinuousTensorSpec(1, device=default_device) + self.full_done_spec = CompositeSpec( + done=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + truncated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + terminated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + ) + self.action_spec = UnboundedContinuousTensorSpec((), device=None) + assert self.device is None + assert self.full_observation_spec is not None + assert self.full_done_spec is not None + assert self.full_state_spec is not None + assert self.full_action_spec is not None + assert self.full_reward_spec is not None + + def _step(self, tensordict): + assert tensordict.device is None + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "reward": torch.zeros((1,)), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=None, + ) - return tensordict_device_type == storing_device_type + def _reset(self, tensordict=None): + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=None, + ) + def _set_seed(self, seed: int | None = None): + return seed -@pytest.mark.skipif( - IS_WINDOWS and PYTHON_3_10, - reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection", -) -@pytest.mark.parametrize("num_env", [2]) -@pytest.mark.parametrize("device", ["cuda", "cpu", None]) -@pytest.mark.parametrize("policy_device", ["cuda", "cpu", None]) -@pytest.mark.parametrize("storing_device", ["cuda", "cpu", None]) -def test_output_device_consistency( - num_env, device, policy_device, storing_device, seed=40 -): - if ( - device == "cuda" or policy_device == "cuda" or storing_device == "cuda" - ) and not torch.cuda.is_available(): - pytest.skip("cuda is not available") - - if IS_WINDOWS and PYTHON_3_7: - if device == "cuda" and policy_device == "cuda" and device is None: - pytest.skip( - "BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows" + class EnvWithDevice(EnvBase): + def __init__(self, default_device): + self.default_device = default_device + super().__init__(device=self.default_device) + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec( + (), device=self.default_device + ) ) + self.reward_spec = UnboundedContinuousTensorSpec( + 1, device=self.default_device + ) + self.full_done_spec = CompositeSpec( + done=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + truncated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + terminated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + device=self.default_device, + ) + self.action_spec = UnboundedContinuousTensorSpec( + (), device=self.default_device + ) + assert self.device == torch.device(self.default_device) + assert self.full_observation_spec is not None + assert self.full_done_spec is not None + assert self.full_state_spec is not None + assert self.full_action_spec is not None + assert self.full_reward_spec is not None + + def _step(self, tensordict): + assert tensordict.device == torch.device(self.default_device) + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "reward": torch.zeros((1,)), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=self.default_device, + ) - _device = "cuda:0" if device == "cuda" else device - _policy_device = "cuda:0" if policy_device == "cuda" else policy_device - _storing_device = "cuda:0" if storing_device == "cuda" else storing_device - - if num_env == 1: - - def env_fn(seed): - env = make_make_env("vec")() - env.set_seed(seed) - return env + def _reset(self, tensordict=None): + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=self.default_device, + ) - else: + def _set_seed(self, seed: int | None = None): + return seed - def env_fn(seed): - # 1226: faster execution - # env = ParallelEnv( - env = SerialEnv( - num_workers=num_env, - create_env_fn=make_make_env("vec"), - create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - ) - return env + class DeviceLessPolicy(TensorDictModuleBase): + in_keys = ["observation"] + out_keys = ["action"] - if _policy_device is None: - policy = make_policy("vec") - else: - policy = ParametricPolicy().to(torch.device(_policy_device)) + # receives data on gpu and outputs on cpu + def forward(self, tensordict): + assert tensordict.device is None + return tensordict.set("action", torch.zeros((), device="cpu")) - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=_device, - storing_device=_storing_device, - ) - for _, d in enumerate(collector): - assert _is_consistent_device_type( - device, policy_device, storing_device, d.device.type + class PolicyWithDevice(TensorDictModuleBase): + in_keys = ["observation"] + out_keys = ["action"] + # receives and sends data on gpu + default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + + def forward(self, tensordict): + assert tensordict.device == torch.device(self.default_device) + return tensordict.set("action", torch.zeros((), device=self.default_device)) + + @pytest.mark.parametrize("main_device", get_default_devices()) + @pytest.mark.parametrize("storing_device", [None, *get_default_devices()]) + def test_output_device(self, main_device, storing_device): + + # env has no device, policy is strictly on GPU + device = None + env_device = None + policy_device = main_device + env = self.DeviceLessEnv(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, ) - break - assert d.names[-1] == "time" + for data in collector: # noqa: B007 + break - collector.shutdown() + assert data.device == storing_device - ccollector = aSyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=_device, - storing_device=_storing_device, - ) - - for _, d in enumerate(ccollector): - assert _is_consistent_device_type( - device, policy_device, storing_device, d.device.type + # env is on cuda, policy has no device + device = None + env_device = main_device + policy_device = None + env = self.EnvWithDevice(main_device) + policy = self.DeviceLessPolicy() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, ) - break - assert d.names[-1] == "time" - - ccollector.shutdown() - del ccollector + for data in collector: # noqa: B007 + break + assert data.device == storing_device + + # env and policy are on device + device = main_device + env_device = None + policy_device = None + env = self.EnvWithDevice(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == main_device + + # same but more specific + device = None + env_device = main_device + policy_device = main_device + env = self.EnvWithDevice(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == main_device + + # none has a device + device = None + env_device = None + policy_device = None + env = self.DeviceLessEnv(main_device) + policy = self.DeviceLessPolicy() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == storing_device + + +# @pytest.mark.skipif( +# IS_WINDOWS and PYTHON_3_10, +# reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection", +# ) +# @pytest.mark.parametrize("num_env", [2]) +# @pytest.mark.parametrize("device", ["cuda", "cpu", None]) +# @pytest.mark.parametrize("policy_device", ["cuda", "cpu", None]) +# @pytest.mark.parametrize("storing_device", ["cuda", "cpu", None]) +# def test_output_device_consistency( +# num_env, device, policy_device, storing_device, seed=40 +# ): +# if ( +# device == "cuda" or policy_device == "cuda" or storing_device == "cuda" +# ) and not torch.cuda.is_available(): +# pytest.skip("cuda is not available") +# +# if IS_WINDOWS and PYTHON_3_7: +# if device == "cuda" and policy_device == "cuda" and device is None: +# pytest.skip( +# "BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows" +# ) +# +# _device = "cuda:0" if device == "cuda" else device +# _policy_device = "cuda:0" if policy_device == "cuda" else policy_device +# _storing_device = "cuda:0" if storing_device == "cuda" else storing_device +# +# if num_env == 1: +# +# def env_fn(seed): +# env = make_make_env("vec")() +# env.set_seed(seed) +# return env +# +# else: +# +# def env_fn(seed): +# # 1226: faster execution +# # env = ParallelEnv( +# env = SerialEnv( +# num_workers=num_env, +# create_env_fn=make_make_env("vec"), +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# return env +# +# if _policy_device is None: +# policy = make_policy("vec") +# else: +# policy = ParametricPolicy().to(torch.device(_policy_device)) +# +# collector = SyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# device=_device, +# storing_device=_storing_device, +# ) +# for _, d in enumerate(collector): +# assert _is_consistent_device_type( +# device, policy_device, storing_device, d.device.type +# ) +# break +# assert d.names[-1] == "time" +# +# collector.shutdown() +# +# ccollector = aSyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# device=_device, +# storing_device=_storing_device, +# ) +# +# for _, d in enumerate(ccollector): +# assert _is_consistent_device_type( +# device, policy_device, storing_device, d.device.type +# ) +# break +# assert d.names[-1] == "time" +# +# ccollector.shutdown() +# del ccollector @pytest.mark.parametrize("num_env", [1, 2]) @@ -830,7 +1082,10 @@ def test_collector_vecnorm_envcreator(static_seed): policy = RandomPolicy(env_make.action_spec) num_data_collectors = 2 c = MultiSyncDataCollector( - [env_make] * num_data_collectors, policy=policy, total_frames=int(1e6) + [env_make] * num_data_collectors, + policy=policy, + total_frames=int(1e6), + frames_per_batch=200, ) init_seed = 0 @@ -889,8 +1144,9 @@ def create_env(): collector = collector_class( [create_env] * 3, policy=policy, - devices=[torch.device("cuda:0")] * 3, - storing_devices=[torch.device("cuda:0")] * 3, + device=[torch.device("cuda:0")] * 3, + storing_device=[torch.device("cuda:0")] * 3, + frames_per_batch=20, ) # collect state_dict state_dict = collector.state_dict() @@ -1125,10 +1381,10 @@ def env_fn(seed): frames_per_batch=20, max_frames_per_traj=2000, total_frames=20000, - devices=[ + device=[ device, ], - storing_devices=[ + storing_device=[ storing_device, ], ) @@ -1147,10 +1403,10 @@ def env_fn(seed): frames_per_batch=20, max_frames_per_traj=2000, total_frames=20000, - devices=[ + device=[ device, ], - storing_devices=[ + storing_device=[ storing_device, ], ) @@ -1170,7 +1426,7 @@ def env_fn(seed): ], ) class TestAutoWrap: - num_envs = 2 + num_envs = 1 @pytest.fixture def env_maker(self): @@ -1193,30 +1449,44 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy): return collector_kwargs - @pytest.mark.parametrize("multiple_outputs", [False, True]) - def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): + @pytest.mark.parametrize("multiple_outputs", [True, False]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_auto_wrap_modules( + self, collector_class, multiple_outputs, env_maker, device + ): policy = WrappablePolicy( out_features=env_maker().action_spec.shape[-1], multiple_outputs=multiple_outputs, ) + # init lazy params + policy(env_maker().reset().get("observation")) + collector = collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy) + **self._create_collector_kwargs(env_maker, collector_class, policy), + device=device, ) out_keys = ["action"] if multiple_outputs: out_keys.extend(f"output{i}" for i in range(1, 4)) - if collector_class is not SyncDataCollector: - assert all( - isinstance(p, TensorDictModule) for p in collector._policy_dict.values() - ) - assert all(p.out_keys == out_keys for p in collector._policy_dict.values()) - assert all(p.module is policy for p in collector._policy_dict.values()) - else: + if collector_class is SyncDataCollector: assert isinstance(collector.policy, TensorDictModule) assert collector.policy.out_keys == out_keys - assert collector.policy.module is policy + # this does not work now that we force the device of the policy + # assert collector.policy.module is policy + + for i, data in enumerate(collector): + if i == 0: + assert (data["action"] != 0).any() + for p in policy.parameters(): + p.data.zero_() + assert p.device == torch.device("cpu") + collector.update_policy_weights_() + elif i == 4: + assert (data["action"] == 0).all() + break + collector.shutdown() del collector @@ -1231,28 +1501,33 @@ def test_no_wrap_compatible_module(self, collector_class, env_maker): ) if collector_class is not SyncDataCollector: - assert all( - isinstance(p, TensorDictCompatiblePolicy) - for p in collector._policy_dict.values() - ) - assert all( - p.out_keys == ["action"] for p in collector._policy_dict.values() - ) - assert all(p is policy for p in collector._policy_dict.values()) + # We now do the casting only on the remote workers + pass else: assert isinstance(collector.policy, TensorDictCompatiblePolicy) assert collector.policy.out_keys == ["action"] assert collector.policy is policy + + for i, data in enumerate(collector): + if i == 0: + assert (data["action"] != 0).any() + for p in policy.parameters(): + p.data.zero_() + assert p.device == torch.device("cpu") + collector.update_policy_weights_() + elif i == 4: + assert (data["action"] == 0).all() + break + collector.shutdown() del collector def test_auto_wrap_error(self, collector_class, env_maker): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) - with pytest.raises( TypeError, match=(r"Arguments to policy.forward are incompatible with entries in"), - ): + ) if collector_class is SyncDataCollector else pytest.raises(EOFError): collector_class( **self._create_collector_kwargs(env_maker, collector_class, policy) ) @@ -1367,8 +1642,7 @@ def env_fn(seed): frames_per_batch=frames_per_batch, init_random_frames=-1, reset_at_each_iter=False, - devices=get_default_devices()[0], - storing_devices=get_default_devices()[0], + device=get_default_devices()[0], split_trajs=False, preemptive_threshold=0.0, # stop after one iteration ) @@ -1396,23 +1670,44 @@ def test_maxframes_error(): ) -def test_reset_heterogeneous_envs(): +@retry(AssertionError, tries=10, delay=0) +@pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("env_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) +def test_reset_heterogeneous_envs( + policy_device: torch.device, env_device: torch.device, storing_device: torch.device +): + if ( + policy_device is not None + and policy_device.type == "cuda" + and env_device is None + ): + env_device = torch.device("cpu") # explicit mapping + elif env_device is not None and env_device.type == "cuda" and policy_device is None: + policy_device = torch.device("cpu") env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2)) env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3)) - env = SerialEnv(2, [env1, env2]) + env = SerialEnv(2, [env1, env2], device=env_device) collector = SyncDataCollector( - env, RandomPolicy(env.action_spec), total_frames=10_000, frames_per_batch=1000 + env, + RandomPolicy(env.action_spec), + total_frames=10_000, + frames_per_batch=100, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, ) try: for data in collector: # noqa: B007 break + data_device = storing_device if storing_device is not None else env_device assert ( data[0]["next", "truncated"].squeeze() - == torch.tensor([False, True]).repeat(250)[:500] + == torch.tensor([False, True], device=data_device).repeat(25)[:50] ).all(), data[0]["next", "truncated"][:10] assert ( data[1]["next", "truncated"].squeeze() - == torch.tensor([False, False, True]).repeat(168)[:500] + == torch.tensor([False, False, True], device=data_device).repeat(17)[:50] ).all(), data[1]["next", "truncated"][:10] finally: collector.shutdown() @@ -1554,15 +1849,17 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): ) -class TestHetEnvsCollector: +class TestHeterogeneousEnvsCollector: @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) @pytest.mark.parametrize("frames_per_batch", [4, 8, 16]) - def test_collector_het_env(self, batch_size, frames_per_batch, seed=1, max_steps=4): + def test_collector_heterogeneous_env( + self, batch_size, frames_per_batch, seed=1, max_steps=4 + ): batch_size = torch.Size(batch_size) - env = HeteroCountingEnv(max_steps=max_steps - 1, batch_size=batch_size) + env = HeterogeneousCountingEnv(max_steps=max_steps - 1, batch_size=batch_size) torch.manual_seed(seed) device = get_default_devices()[0] - policy = HeteroCountingEnvPolicy(env.input_spec["full_action_spec"]) + policy = HeterogeneousCountingEnvPolicy(env.input_spec["full_action_spec"]) ccollector = SyncDataCollector( create_env_fn=env, policy=policy, @@ -1590,14 +1887,14 @@ def test_collector_het_env(self, batch_size, frames_per_batch, seed=1, max_steps assert (_td["lazy"][..., i]["action"] == 1).all() del ccollector - def test_multi_collector_het_env_consistency( + def test_multi_collector_heterogeneous_env_consistency( self, seed=1, frames_per_batch=20, batch_dim=10 ): - env = HeteroCountingEnv(max_steps=3, batch_size=(batch_dim,)) + env = HeterogeneousCountingEnv(max_steps=3, batch_size=(batch_dim,)) torch.manual_seed(seed) env_fn = lambda: TransformedEnv(env, InitTracker()) check_env_specs(env_fn(), return_contiguous=False) - policy = HeteroCountingEnvPolicy(env.input_spec["full_action_spec"]) + policy = HeterogeneousCountingEnvPolicy(env.input_spec["full_action_spec"]) ccollector = MultiaSyncDataCollector( create_env_fn=[env_fn], @@ -1649,13 +1946,16 @@ class TestMultiKeyEnvsCollector: def test_collector(self, batch_size, frames_per_batch, max_steps, seed=1): env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps) torch.manual_seed(seed) - policy = MultiKeyCountingEnvPolicy(env.input_spec["full_action_spec"]) + device = get_default_devices()[0] + policy = MultiKeyCountingEnvPolicy( + env.input_spec["full_action_spec"].to(device) + ) ccollector = SyncDataCollector( create_env_fn=env, policy=policy, frames_per_batch=frames_per_batch, total_frames=100, - device=get_default_devices()[0], + device=device, ) for _td in ccollector: @@ -1672,8 +1972,9 @@ def test_multi_collector_consistency( env = MultiKeyCountingEnv(batch_size=(batch_dim,)) env_fn = lambda: env torch.manual_seed(seed) + device = get_default_devices()[0] policy = MultiKeyCountingEnvPolicy( - env.input_spec["full_action_spec"], deterministic=True + env.input_spec["full_action_spec"].to(device), deterministic=True ) ccollector = MultiaSyncDataCollector( @@ -1681,7 +1982,7 @@ def test_multi_collector_consistency( policy=policy, frames_per_batch=frames_per_batch, total_frames=100, - device=get_default_devices()[0], + device=device, ) for i, d in enumerate(ccollector): if i == 0: @@ -1748,11 +2049,14 @@ def _step( **self.full_done_spec.zero(), }, self.batch_size, + device=self.device, ) def _reset(self, tensordict=None): self.state.zero_() - return TensorDict({"state": self.state.clone()}, self.batch_size) + return TensorDict( + {"state": self.state.clone()}, self.batch_size, device=self.device + ) def _set_seed(self, seed): return seed diff --git a/test/test_cost.py b/test/test_cost.py index 9561f6063e4..87e17eb252c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -47,11 +47,11 @@ get_default_devices, ) from mocking_classes import ContinuousActionConvMockEnv -from tensordict.nn import NormalParamExtractor, TensorDictModule -from tensordict.nn.utils import Buffer # from torchrl.data.postprocs.utils import expand_as_right -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict +from tensordict.nn import NormalParamExtractor, TensorDictModule +from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn from torchrl.data import ( diff --git a/test/test_distributions.py b/test/test_distributions.py index 30bb0288dd4..e6f228628a4 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from _utils_internal import get_default_devices -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from torch import autograd, nn from torchrl.modules import ( NormalParamWrapper, diff --git a/test/test_env.py b/test/test_env.py index a6f6873f9ba..eaa31007186 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -38,8 +38,8 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, - HeteroCountingEnv, - HeteroCountingEnvPolicy, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, MockBatchedLockedEnv, MockBatchedUnLockedEnv, MockSerialEnv, @@ -48,9 +48,13 @@ NestedCountingEnv, ) from packaging import version -from tensordict import dense_stack_tds +from tensordict import ( + assert_allclose_td, + dense_stack_tds, + LazyStackedTensorDict, + TensorDict, +) from tensordict.nn import TensorDictModuleBase -from tensordict.tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict from tensordict.utils import _unravel_key_to_tuple from torch import nn @@ -2034,12 +2038,12 @@ def test_nested_reset(self, nest_done, has_root_done, batch_size): class TestHeteroEnvs: @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)]) def test_reset(self, batch_size): - env = HeteroCountingEnv(batch_size=batch_size) + env = HeterogeneousCountingEnv(batch_size=batch_size) env.reset() @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)]) def test_rand_step(self, batch_size): - env = HeteroCountingEnv(batch_size=batch_size) + env = HeterogeneousCountingEnv(batch_size=batch_size) td = env.reset() assert (td["lazy"][..., 0]["tensor_0"] == 0).all() td = env.rand_step() @@ -2050,7 +2054,7 @@ def test_rand_step(self, batch_size): @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) @pytest.mark.parametrize("rollout_steps", [1, 2, 5]) def test_rollout(self, batch_size, rollout_steps, n_lazy_dim=3): - env = HeteroCountingEnv(batch_size=batch_size) + env = HeterogeneousCountingEnv(batch_size=batch_size) td = env.rollout(rollout_steps, return_contiguous=False) td = dense_stack_tds(td) @@ -2072,8 +2076,8 @@ def test_rollout(self, batch_size, rollout_steps, n_lazy_dim=3): @pytest.mark.parametrize("rollout_steps", [1, 2, 5]) @pytest.mark.parametrize("count", [True, False]) def test_rollout_policy(self, batch_size, rollout_steps, count): - env = HeteroCountingEnv(batch_size=batch_size) - policy = HeteroCountingEnvPolicy( + env = HeterogeneousCountingEnv(batch_size=batch_size) + policy = HeterogeneousCountingEnvPolicy( env.input_spec["full_action_spec"], count=count ) td = env.rollout(rollout_steps, policy=policy, return_contiguous=False) @@ -2091,14 +2095,14 @@ def test_rollout_policy(self, batch_size, rollout_steps, count): @pytest.mark.parametrize("batch_size", [(1, 2)]) @pytest.mark.parametrize("env_type", ["serial", "parallel"]) def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): - env_fun = lambda: HeteroCountingEnv(batch_size=batch_size) + env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size) if env_type == "serial": vec_env = SerialEnv(n_workers, env_fun) else: vec_env = ParallelEnv(n_workers, env_fun) vec_batch_size = (n_workers,) + batch_size # check_env_specs(vec_env, return_contiguous=False) - policy = HeteroCountingEnvPolicy(vec_env.input_spec["full_action_spec"]) + policy = HeterogeneousCountingEnvPolicy(vec_env.input_spec["full_action_spec"]) vec_env.reset() td = vec_env.rollout( rollout_steps, @@ -2173,7 +2177,7 @@ def test_parallel( MockBatchedUnLockedEnv, MockSerialEnv, NestedCountingEnv, - HeteroCountingEnv, + HeterogeneousCountingEnv, MultiKeyCountingEnv, ], ) diff --git a/test/test_exploration.py b/test/test_exploration.py index 0d916e5d5e9..777f2714edb 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,9 +14,9 @@ NestedCountingEnv, ) from scipy.stats import ttest_1samp +from tensordict import TensorDict from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential -from tensordict.tensordict import TensorDict from torch import nn from torchrl._utils import _replace_last diff --git a/test/test_libs.py b/test/test_libs.py index 5fcf3497139..a1414948817 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -43,13 +43,12 @@ rollout_consistency_assertion, ) from packaging import version -from tensordict import LazyStackedTensorDict +from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict from tensordict.nn import ( ProbabilisticTensorDictModule, TensorDictModule, TensorDictSequential, ) -from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import implement_for from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector diff --git a/test/test_postprocs.py b/test/test_postprocs.py index e28ee2eb592..10a559d3cac 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -7,7 +7,7 @@ import pytest import torch from _utils_internal import get_default_devices -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict from torchrl.collectors.utils import split_trajectories from torchrl.data.postprocs.postprocs import MultiStep diff --git a/test/test_rb.py b/test/test_rb.py index 96f392d5a22..4b9b1a5dc9f 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -19,8 +19,14 @@ from _utils_internal import get_default_devices, make_tc from packaging import version from packaging.version import parse -from tensordict import is_tensor_collection, is_tensorclass, tensorclass -from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from tensordict import ( + assert_allclose_td, + is_tensor_collection, + is_tensorclass, + tensorclass, + TensorDict, + TensorDictBase, +) from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map from torchrl.data import ( diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 1d5a2398e92..a31836a4e72 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -13,7 +13,7 @@ import torch import torch.distributed.rpc as rpc import torch.multiprocessing as mp -from tensordict.tensordict import TensorDict +from tensordict import TensorDict from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage diff --git a/test/test_specs.py b/test/test_specs.py index cc97be11918..36f5aef65ca 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import contextlib import numpy as np import pytest @@ -10,7 +11,7 @@ import torchrl.data.tensor_specs from _utils_internal import get_available_devices, get_default_devices, set_global_var from scipy.stats import chisquare -from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple from torchrl.data.tensor_specs import ( @@ -341,7 +342,7 @@ def test_multi_discrete_conversion(ns, shape, device): @pytest.mark.parametrize("is_complete", [True, False]) -@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("device", [None, *get_default_devices()]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) @pytest.mark.parametrize("shape", [(), (2, 3)]) class TestComposite: @@ -368,6 +369,7 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None): if is_complete else None, shape=shape, + device=device, ) def test_getitem(self, shape, is_complete, device, dtype): @@ -390,18 +392,26 @@ def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype): def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest): ts = self._composite_spec(shape, is_complete, device, dtype) - if dest == device: - ts["good"] = UnboundedContinuousTensorSpec( + ts["good"] = UnboundedContinuousTensorSpec( + shape=shape, device=device, dtype=dtype + ) + cm = ( + contextlib.nullcontext() + if (device == dest) or (device is None) + else pytest.raises( + RuntimeError, match="All devices of CompositeSpec must match" + ) + ) + with cm: + # auto-casting is introduced since v0.3 + ts["bad"] = UnboundedContinuousTensorSpec( shape=shape, device=dest, dtype=dtype ) - assert ts["good"].device == dest - else: - with pytest.raises( - RuntimeError, match="All devices of CompositeSpec must match" - ): - ts["bad"] = UnboundedContinuousTensorSpec( - shape=shape, device=dest, dtype=dtype - ) + assert ts.device == device + assert ts["good"].device == ( + device if device is not None else torch.zeros(()).device + ) + assert ts["bad"].device == (device if device is not None else dest) def test_del(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) @@ -682,9 +692,12 @@ def test_create_composite_nested(shape, device): c = CompositeSpec(_d, shape=shape) assert isinstance(c["a", "b"], UnboundedContinuousTensorSpec) assert c["a"].shape == torch.Size(shape) + assert c.device is None # device not explicitly passed + assert c["a"].device is None # device not explicitly passed + assert c["a", "b"].device == device + c = c.to(device) assert c.device == device assert c["a"].device == device - assert c["a", "b"].device == device @pytest.mark.parametrize("recurse", [True, False]) @@ -2277,7 +2290,7 @@ def test_stack(self): class TestLazyStackedCompositeSpecs: - def _get_het_specs( + def _get_heterogeneous_specs( self, batch_size=(), stack_dim: int = 0, @@ -2362,7 +2375,7 @@ def _get_het_specs( ), ] - return torch.stack(spec_list, dim=stack_dim) + return torch.stack(spec_list, dim=stack_dim).cpu() def test_stack_index(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) @@ -2640,7 +2653,7 @@ def test_unsqueeze(self): assert c.squeeze().shape == torch.Size([2, 3]) - c = self._get_het_specs() + c = self._get_heterogeneous_specs() cu = c.unsqueeze(0) assert cu.shape == torch.Size([1, 3]) cus = cu.squeeze(0) @@ -2648,14 +2661,14 @@ def test_unsqueeze(self): @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_len(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) assert len(c) == c.shape[0] assert len(c) == len(c.rand()) @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_eq(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) - c2 = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) + c2 = self._get_heterogeneous_specs(batch_size=batch_size) assert c == c2 and not c != c2 assert c == c.clone() and not c != c.clone() @@ -2663,12 +2676,12 @@ def test_eq(self, batch_size): del c2["shared"] assert not c == c2 and c != c2 - c2 = self._get_het_specs(batch_size=batch_size) + c2 = self._get_heterogeneous_specs(batch_size=batch_size) del c2[0]["lidar"] assert not c == c2 and c != c2 - c2 = self._get_het_specs(batch_size=batch_size) + c2 = self._get_heterogeneous_specs(batch_size=batch_size) c2[0]["lidar"].space.low += 1 assert not c == c2 and c != c2 @@ -2676,7 +2689,7 @@ def test_eq(self, batch_size): @pytest.mark.parametrize("include_nested", [True, False]) @pytest.mark.parametrize("leaves_only", [True, False]) def test_del(self, batch_size, include_nested, leaves_only): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) td_c = c.rand() keys = list(c.keys(include_nested=include_nested, leaves_only=leaves_only)) @@ -2709,7 +2722,7 @@ def test_del(self, batch_size, include_nested, leaves_only): @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_is_in(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) td_c = c.rand() assert c.is_in(td_c) @@ -2735,7 +2748,7 @@ def test_is_in(self, batch_size): assert c.is_in(td_c) def test_type_check(self): - c = self._get_het_specs() + c = self._get_heterogeneous_specs() td_c = c.rand() c.type_check(td_c) @@ -2743,7 +2756,7 @@ def test_type_check(self): @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_project(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) td_c = c.rand() assert c.is_in(td_c) val = c.project(td_c) @@ -2775,7 +2788,7 @@ def test_project(self, batch_size): assert c.is_in(td_c) def test_repr(self): - c = self._get_het_specs() + c = self._get_heterogeneous_specs() expected = f"""LazyStackedCompositeSpec( fields={{ @@ -2869,7 +2882,7 @@ def test_repr(self): @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) def test_consolidate_spec(self, batch_size): - spec = self._get_het_specs(batch_size) + spec = self._get_heterogeneous_specs(batch_size) spec_lazy = spec.clone() assert not check_no_exclusive_keys(spec_lazy) @@ -2938,8 +2951,8 @@ def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): @pytest.mark.parametrize("batch_size", [(2,), (2, 1)]) def test_update(self, batch_size, stack_dim=0): - spec = self._get_het_specs(batch_size, stack_dim) - spec2 = self._get_het_specs(batch_size, stack_dim) + spec = self._get_heterogeneous_specs(batch_size, stack_dim) + spec2 = self._get_heterogeneous_specs(batch_size, stack_dim) del spec2["shared"] spec2["hetero"] = spec2["hetero"].unsqueeze(-1) @@ -2964,7 +2977,7 @@ def test_update(self, batch_size, stack_dim=0): @pytest.mark.parametrize("batch_size", [(2,), (2, 1)]) @pytest.mark.parametrize("stack_dim", [0, 1]) def test_set_item(self, batch_size, stack_dim): - spec = self._get_het_specs(batch_size, stack_dim) + spec = self._get_heterogeneous_specs(batch_size, stack_dim) new = torch.stack( [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], @@ -2998,8 +3011,8 @@ def test_set_item(self, batch_size, stack_dim): stack_dim, ) spec["comp"] = comp - assert spec["comp"] == comp - assert spec["comp", "a"] == new + assert spec["comp"] == comp.to(spec.device) + assert spec["comp", "a"] == new.to(spec.device) # MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080. diff --git a/test/test_transforms.py b/test/test_transforms.py index 2bc9f36a79b..b325a1ccd99 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -42,9 +42,8 @@ MultiKeyCountingEnvPolicy, NestedCountingEnv, ) -from tensordict import unravel_key +from tensordict import TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictSequential -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import _replace_last, prod diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 08c86ffea45..661903b784d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -6,6 +6,11 @@ import _pickle import abc + +import contextlib + +import functools + import inspect import logging import os @@ -17,17 +22,23 @@ from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager - from textwrap import indent from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn as nn + +from tensordict import ( + LazyStackedTensorDict, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import TensorDictModule, TensorDictModuleBase -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import multiprocessing as mp +from torch.utils._pytree import tree_map from torch.utils.data import IterableDataset from torchrl._utils import ( @@ -77,7 +88,8 @@ class RandomPolicy: """ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - self.action_spec = action_spec + super().__init__() + self.action_spec = action_spec.clone() self.action_key = action_key def __call__(self, td: TensorDictBase) -> TensorDictBase: @@ -142,10 +154,16 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: def _policy_is_tensordict_compatible(policy: nn.Module): - sig = inspect.signature(policy.forward) + if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( + policy.policy, RandomPolicy + ): + return True if isinstance(policy, TensorDictModuleBase): return True + + sig = inspect.signature(policy.forward) + if ( len(sig.parameters) == 1 and hasattr(policy, "in_keys") @@ -184,6 +202,74 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): _iterator = None + def _make_compatible_policy(self, policy, observation_spec=None): + if policy is None: + if not hasattr(self, "env") or self.env is None: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + policy = RandomPolicy(self.env.input_spec["full_action_spec"]) + # make sure policy is an nn.Module + policy = _NonParametricPolicyWrapper(policy) + if not _policy_is_tensordict_compatible(policy): + # policy is a nn.Module that doesn't operate on tensordicts directly + # so we attempt to auto-wrap policy with TensorDictModule + if observation_spec is None: + raise ValueError( + "Unable to read observation_spec from the environment. This is " + "required to check compatibility of the environment and policy " + "since the policy is a nn.Module that operates on tensors " + "rather than a TensorDictModule or a nn.Module that accepts a " + "TensorDict as input and defines in_keys and out_keys." + ) + + try: + # signature modified by make_functional + sig = policy.forward.__signature__ + except AttributeError: + sig = inspect.signature(policy.forward) + required_kwargs = { + str(k) for k, p in sig.parameters.items() if p.default is inspect._empty + } + next_observation = { + key: value for key, value in observation_spec.rand().items() + } + # we check if all the mandatory params are there + if set(sig.parameters) == {"tensordict"} or set(sig.parameters) == {"td"}: + pass + elif not required_kwargs.difference(set(next_observation)): + in_keys = [str(k) for k in sig.parameters if k in next_observation] + if not hasattr(self, "env") or self.env is None: + out_keys = ["action"] + else: + out_keys = list(self.env.action_keys) + for p in policy.parameters(): + policy_device = p.device + break + else: + policy_device = None + if policy_device: + next_observation = tree_map( + lambda x: x.to(policy_device), next_observation + ) + output = policy(**next_observation) + + if isinstance(output, tuple): + out_keys.extend(f"output{i + 1}" for i in range(len(output) - 1)) + + policy = TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + else: + raise TypeError( + f"""Arguments to policy.forward are incompatible with entries in +env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). +If you want TorchRL to automatically wrap your policy with a TensorDictModule +then the arguments to policy.forward must correspond one-to-one with entries +in env.observation_spec that are prefixed with 'next_'. For more complex +behaviour and more control you can consider writing your own TensorDictModule. +""" + ) + return policy + def _get_policy_and_device( self, policy: Optional[ @@ -192,110 +278,46 @@ def _get_policy_and_device( Callable[[TensorDictBase], TensorDictBase], ] ] = None, - device: Optional[DEVICE_TYPING] = None, observation_spec: TensorSpec = None, - ) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]: + ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. - From a policy and a device, assigns the self.device attribute to - the desired device and maps the policy onto it or (if the device is - ommitted) assigns the self.device attribute to the policy device. - Args: create_env_fn (Callable or list of callables): an env creator function (or a list of creators) create_env_kwargs (dictionary): kwargs for the env creator policy (TensorDictModule, optional): a policy to be used - device (int, str or torch.device, optional): device where to place - the policy observation_spec (TensorSpec, optional): spec of the observations """ - if policy is None: - if not hasattr(self, "env") or self.env is None: - raise ValueError( - "env must be provided to _get_policy_and_device if policy is None" - ) - policy = RandomPolicy(self.env.input_spec["full_action_spec"]) - elif isinstance(policy, nn.Module): - # TODO: revisit these checks when we have determined whether arbitrary - # callables should be supported as policies. - if not _policy_is_tensordict_compatible(policy): - # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with TensorDictModule - if observation_spec is None: - raise ValueError( - "Unable to read observation_spec from the environment. This is " - "required to check compatibility of the environment and policy " - "since the policy is a nn.Module that operates on tensors " - "rather than a TensorDictModule or a nn.Module that accepts a " - "TensorDict as input and defines in_keys and out_keys." - ) + policy = self._make_compatible_policy(policy, observation_spec) + param_and_buf = TensorDict.from_module(policy, as_module=True) - try: - # signature modified by make_functional - sig = policy.forward.__signature__ - except AttributeError: - sig = inspect.signature(policy.forward) - required_params = { - str(k) - for k, p in sig.parameters.items() - if p.default is inspect._empty - } - next_observation = { - key: value for key, value in observation_spec.rand().items() - } - # we check if all the mandatory params are there - if not required_params.difference(set(next_observation)): - in_keys = [str(k) for k in sig.parameters if k in next_observation] - if not hasattr(self, "env") or self.env is None: - out_keys = ["action"] - else: - out_keys = self.env.action_keys - output = policy(**next_observation) + def get_weights_fn(param_and_buf=param_and_buf): + return param_and_buf.data - if isinstance(output, tuple): - out_keys.extend(f"output{i+1}" for i in range(len(output) - 1)) + if self.policy_device: + # create a stateless policy and populate it with params + def _map_to_device_params(param, device): + is_param = isinstance(param, nn.Parameter) - policy = TensorDictModule( - policy, in_keys=in_keys, out_keys=out_keys - ) - else: - raise TypeError( - f"""Arguments to policy.forward are incompatible with entries in -env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). -If you want TorchRL to automatically wrap your policy with a TensorDictModule -then the arguments to policy.forward must correspond one-to-one with entries -in env.observation_spec that are prefixed with 'next_'. For more complex -behaviour and more control you can consider writing your own TensorDictModule. -""" - ) + pd = param.detach().to(device, non_blocking=True) - try: - policy_device = next(policy.parameters()).device - except Exception: - policy_device = ( - torch.device(device) if device is not None else torch.device("cpu") - ) + if is_param: + pd = nn.Parameter(pd, requires_grad=False) + return pd - device = torch.device(device) if device is not None else policy_device - get_weights_fn = None - if policy_device != device: - param_and_buf = TensorDict.from_module(policy, as_module=True) + # Create a stateless policy, then populate this copy with params on device + with param_and_buf.apply( + functools.partial(_map_to_device_params, device="meta") + ).to_module(policy): + policy = deepcopy(policy) - def get_weights_fn(param_and_buf=param_and_buf): - return param_and_buf.data + param_and_buf.apply( + functools.partial(_map_to_device_params, device=self.policy_device) + ).to_module(policy) - policy_cast = deepcopy(policy).requires_grad_(False).to(device) - # here things may break bc policy.to("cuda") gives us weights on cuda:0 (same - # but different) - try: - device = next(policy_cast.parameters()).device - except StopIteration: # noqa - pass - else: - policy_cast = policy - return policy_cast, device, get_weights_fn + return policy, get_weights_fn def update_policy_weights_( self, policy_weights: Optional[TensorDictBase] = None @@ -363,6 +385,8 @@ class SyncDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -370,28 +394,47 @@ class SyncDataCollector(DataCollectorBase): during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. - device (int, str or torch.device, optional): The device on which the - policy will be placed. - If it differs from the input policy device, the - :meth:`~.update_policy_weights_` method should be queried - at appropriate times during the training loop to accommodate for - the lag between parameter configuration at various times. - Defaults to ``None`` (i.e. policy is kept on its original device). + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). storing_device (int, str or torch.device, optional): The device on which - the output :class:`tensordict.TensorDict` will be stored. For long - trajectories, it may be necessary to store the data on a different + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. - Defaults to ``"cpu"``. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. create_env_kwargs (dict, optional): Dictionary of kwargs for ``create_env_fn``. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a @@ -411,15 +454,15 @@ class SyncDataCollector(DataCollectorBase): information. Defaults to ``False``. exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``ExplorationType.RANDOM`` + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. return_same_td (bool, optional): if ``True``, the same TensorDict will be returned at each iteration, with its values updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. - Default is False. + Default is ``False``. interruptor (_Interruptor, optional): An _Interruptor object that can be used from outside the class to control rollout collection. The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement @@ -496,9 +539,11 @@ def __init__( ], *, frames_per_batch: int, - total_frames: int, + total_frames: int = -1, device: DEVICE_TYPING = None, storing_device: DEVICE_TYPING = None, + policy_device: DEVICE_TYPING = None, + env_device: DEVICE_TYPING = None, create_env_kwargs: dict | None = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, @@ -532,29 +577,43 @@ def __init__( ) env.update_kwargs(create_env_kwargs) - if storing_device is None: - if device is not None: - storing_device = device - elif policy is not None: - try: - policy_device = next(policy.parameters()).device - except (AttributeError, StopIteration): - policy_device = torch.device("cpu") - storing_device = policy_device - else: - storing_device = torch.device("cpu") + ########################## + # Setting devices: + # The rule is the following: + # - If no device is passed, all devices are assumed to work OOB. + # The tensordict used for output is not on any device (ie, actions and observations + # can be on a different device). + # - If the ``device`` is passed, it is used for all devices (storing, env and policy) + # unless overridden by another kwarg. + # - The rest of the kwargs control the respective device. + storing_device, policy_device, env_device = self._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + + self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + self.device = device + # Check if we need to cast things from device to device + # If the policy has a None device and the env too, no need to cast (we don't know + # and assume the user knows what she's doing). + # If the devices match we're happy too. + # Only if the values differ we need to cast + self._cast_to_policy_device = self.policy_device != self.env_device - self.storing_device = torch.device(storing_device) self.env: EnvBase = env + del env self.closed = False if not reset_when_done: raise ValueError("reset_when_done is deprectated.") self.reset_when_done = reset_when_done self.n_env = self.env.batch_size.numel() - (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( + (self.policy, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, - device=device, observation_spec=self.env.observation_spec, ) @@ -563,7 +622,12 @@ def __init__( else: self.policy_weights = TensorDict({}, []) - self.env: EnvBase = self.env.to(self.device) + if self.env_device: + self.env: EnvBase = self.env.to(self.env_device) + elif self.env.device is not None: + # we we did not receive an env device, we use the device of the env + self.env_device = self.env.device + self.max_frames_per_traj = ( int(max_frames_per_traj) if max_frames_per_traj is not None else 0 ) @@ -580,7 +644,7 @@ def __init__( "Possible solutions: Set max_frames_per_traj to 0 or " "remove the StepCounter limit from the environment transforms." ) - env = self.env = TransformedEnv( + self.env = TransformedEnv( self.env, StepCounter(max_steps=self.max_frames_per_traj) ) @@ -614,7 +678,11 @@ def __init__( ) self.postproc = postproc - if self.postproc is not None and hasattr(self.postproc, "to"): + if ( + self.postproc is not None + and hasattr(self.postproc, "to") + and self.storing_device + ): self.postproc.to(self.storing_device) if frames_per_batch % self.n_env != 0 and RL_WARNINGS: warnings.warn( @@ -630,64 +698,139 @@ def __init__( ) self.return_same_td = return_same_td - self._tensordict = env.reset() - traj_ids = torch.arange(self.n_env, device=env.device).view(self.env.batch_size) - self._tensordict.set( + # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env + self._shuttle = self.env.reset() + if self.policy_device != self.env_device or self.env_device is None: + self._shuttle_has_no_device = True + self._shuttle.clear_device_() + else: + self._shuttle_has_no_device = False + + traj_ids = torch.arange(self.n_env, device=self.storing_device).view( + self.env.batch_size + ) + self._shuttle.set( ("collector", "traj_ids"), traj_ids, ) - with torch.no_grad(): - self._tensordict_out = self.env.fake_tensordict() + self._final_rollout = self.env.fake_tensordict() + + # If storing device is not None, we use this to cast the storage. + # If it is None and the env and policy are on the same device, + # the storing device is already the same as those, so we don't need + # to consider this use case. + # In all other cases, we can't really put a device on the storage, + # since at least one data source has a device that is not clear. + if self.storing_device: + self._final_rollout = self._final_rollout.to( + self.storing_device, non_blocking=True + ) + else: + # erase all devices + self._final_rollout.clear_device_() + # If the policy has a valid spec, we use it + self._policy_output_keys = set() if ( hasattr(self.policy, "spec") and self.policy.spec is not None and all(v is not None for v in self.policy.spec.values(True, True)) ): if any( - key not in self._tensordict_out.keys(isinstance(key, tuple)) + key not in self._final_rollout.keys(isinstance(key, tuple)) for key in self.policy.spec.keys(True, True) ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information # the policy could have more keys than the env: policy_spec = self.policy.spec - if policy_spec.ndim < self._tensordict_out.ndim: - policy_spec = policy_spec.expand(self._tensordict_out.shape) + if policy_spec.ndim < self._final_rollout.ndim: + policy_spec = policy_spec.expand(self._final_rollout.shape) for key, spec in policy_spec.items(True, True): - if key in self._tensordict_out.keys(isinstance(key, tuple)): + self._policy_output_keys.add(key) + if key in self._final_rollout.keys(True): continue - self._tensordict_out.set(key, spec.zero()) + self._final_rollout.set(key, spec.zero()) else: # otherwise, we perform a small number of steps with the policy to - # determine the relevant keys with which to pre-populate _tensordict_out. + # determine the relevant keys with which to pre-populate _final_rollout. # This is the safest thing to do if the spec has None fields or if there is # no spec at all. # See #505 for additional context. - self._tensordict_out.update(self._tensordict) + self._final_rollout.update(self._shuttle.copy()) with torch.no_grad(): - self._tensordict_out = self.policy(self._tensordict_out.to(self.device)) + policy_input = self._shuttle.copy() + if self.policy_device: + policy_input = policy_input.to(self.policy_device) + # we cast to policy device, we'll deal with the device later + policy_input_copy = policy_input.copy() + policy_input_clone = ( + policy_input.clone() + ) # to test if values have changed in-place + policy_output = self.policy(policy_input) + + # check that we don't have exclusive keys, because they don't appear in keys + def check_exclusive(val): + if ( + isinstance(val, LazyStackedTensorDict) + and val._has_exclusive_keys + ): + raise RuntimeError( + "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " + "Consider using a placeholder for missing keys." + ) - self._tensordict_out = ( - self._tensordict_out.unsqueeze(-1) - .expand(*env.batch_size, self.frames_per_batch) + policy_output._fast_apply(check_exclusive, call_on_nested=True) + # Use apply, because it works well with lazy stacks + # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit + # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has + # changed them here). + # This will cause a failure to update entries when policy and env device mismatch and + # casting is necessary. + filtered_policy_output = policy_output.apply( + lambda value_output, value_input, value_input_clone: value_output + if (value_input is None) + or (value_output is not value_input) + or ~torch.isclose(value_output, value_input_clone).any() + else None, + policy_input_copy, + policy_input_clone, + default=None, + ) + self._policy_output_keys = list( + self._policy_output_keys.union( + set(filtered_policy_output.keys(True, True)) + ) + ) + self._final_rollout.update( + policy_output.select(*self._policy_output_keys) + ) + del filtered_policy_output, policy_output, policy_input + + _env_output_keys = [] + for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: + _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) + self._env_output_keys = _env_output_keys + self._final_rollout = ( + self._final_rollout.unsqueeze(-1) + .expand(*self.env.batch_size, self.frames_per_batch) .clone() .zero_() ) + # in addition to outputs of the policy, we add traj_ids to - # _tensordict_out which will be collected during rollout - self._tensordict_out = self._tensordict_out.to(self.storing_device) - self._tensordict_out.set( + # _final_rollout which will be collected during rollout + self._final_rollout.set( ("collector", "traj_ids"), torch.zeros( - *self._tensordict_out.batch_size, + *self._final_rollout.batch_size, dtype=torch.int64, device=self.storing_device, ), ) - self._tensordict_out.refine_names(..., "time") + self._final_rollout.refine_names(..., "time") if split_trajs is None: split_trajs = False @@ -697,6 +840,23 @@ def __init__( self._frames = 0 self._iter = -1 + @classmethod + def _get_devices( + cls, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + device = torch.device(device) if device else device + storing_device = torch.device(storing_device) if storing_device else device + policy_device = torch.device(policy_device) if policy_device else device + env_device = torch.device(env_device) if env_device else device + if storing_device is None and (env_device == policy_device): + storing_device = env_device + return storing_device, policy_device, env_device + # for RPC def next(self): return super().next() @@ -739,13 +899,32 @@ def iterator(self) -> Iterator[TensorDictBase]: Yields: TensorDictBase objects containing (chunks of) trajectories """ - if self.storing_device.type == "cuda": + if self.storing_device and self.storing_device.type == "cuda": stream = torch.cuda.Stream(self.storing_device, priority=-1) event = stream.record_event() + streams = [stream] + events = [event] + elif self.storing_device is None: + streams = [] + events = [] + # this way of checking cuda is robust to lazy stacks with mismatching shapes + cuda_devices = set() + + def cuda_check(tensor: torch.Tensor): + if tensor.is_cuda: + cuda_devices.add(tensor.device) + + self._final_rollout.apply(cuda_check) + for device in cuda_devices: + streams.append(torch.cuda.Stream(device, priority=-1)) + events.append(streams[-1].record_event()) else: - event = None - stream = None - with torch.cuda.stream(stream): + streams = [] + events = [] + with contextlib.ExitStack() as stack: + for stream in streams: + stack.enter_context(torch.cuda.stream(stream)) + total_frames = self.total_frames while self._frames < self.total_frames: @@ -781,9 +960,10 @@ def is_private(key): if self.return_same_td: # This is used with multiprocessed collectors to use the buffers # stored in the tensordict. - if event is not None: - event.record() - event.synchronize() + if events: + for event in events: + event.record() + event.synchronize() yield tensordict_out else: # we must clone the values, as the tensordict is updated in-place. @@ -804,12 +984,15 @@ def _update_traj_ids(self, tensordict) -> None: tensordict.get("next"), done_keys=self.env.done_keys ) if traj_sop.any(): - traj_ids = self._tensordict.get(("collector", "traj_ids")) - traj_ids = traj_ids.clone() + traj_ids = self._shuttle.get(("collector", "traj_ids")) + traj_sop = traj_sop.to(self.storing_device) + traj_ids = traj_ids.clone().to(self.storing_device) traj_ids[traj_sop] = traj_ids.max() + torch.arange( - 1, traj_sop.sum() + 1, device=traj_ids.device + 1, + traj_sop.sum() + 1, + device=self.storing_device, ) - self._tensordict.set(("collector", "traj_ids"), traj_ids) + self._shuttle.set(("collector", "traj_ids"), traj_ids) @torch.no_grad() def rollout(self) -> TensorDictBase: @@ -820,10 +1003,10 @@ def rollout(self) -> TensorDictBase: """ if self.reset_at_each_iter: - self._tensordict.update(self.env.reset()) + self._shuttle.update(self.env.reset()) - # self._tensordict.fill_(("collector", "step_count"), 0) - self._tensordict_out.fill_(("collector", "traj_ids"), -1) + # self._shuttle.fill_(("collector", "step_count"), 0) + self._final_rollout.fill_(("collector", "traj_ids"), -1) tensordicts = [] with set_exploration_type(self.exploration_type): for t in range(self.frames_per_batch): @@ -831,20 +1014,64 @@ def rollout(self) -> TensorDictBase: self.init_random_frames is not None and self._frames < self.init_random_frames ): - self.env.rand_action(self._tensordict) + self.env.rand_action(self._shuttle) else: - self.policy(self._tensordict) - tensordict, tensordict_ = self.env.step_and_maybe_reset( - self._tensordict - ) - self._tensordict = tensordict_.set( - "collector", tensordict.get("collector").clone(False) - ) - tensordicts.append( - tensordict.to(self.storing_device, non_blocking=True) - ) + if self._cast_to_policy_device: + if self.policy_device is not None: + policy_input = self._shuttle.to( + self.policy_device, non_blocking=True + ) + elif self.policy_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # policy_input = self._shuttle.clear_device_() + policy_input = self._shuttle + else: + policy_input = self._shuttle + # we still do the assignment for security + policy_output = self.policy(policy_input) + if self._shuttle is not policy_output: + # ad-hoc update shuttle + self._shuttle.update( + policy_output, keys_to_update=self._policy_output_keys + ) + + if self._cast_to_policy_device: + if self.env_device is not None: + env_input = self._shuttle.to(self.env_device, non_blocking=True) + elif self.env_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # env_input = self._shuttle.clear_device_() + env_input = self._shuttle + else: + env_input = self._shuttle + env_output, env_next_output = self.env.step_and_maybe_reset(env_input) + + if self._shuttle is not env_output: + # ad-hoc update shuttle + next_data = env_output.get("next") + if self._shuttle_has_no_device: + # Make sure + next_data.clear_device_() + self._shuttle.set("next", next_data) + + if self.storing_device is not None: + tensordicts.append( + self._shuttle.to(self.storing_device, non_blocking=True) + ) + else: + tensordicts.append(self._shuttle) + + # carry over collector data without messing up devices + collector_data = self._shuttle.get("collector").copy() + self._shuttle = env_next_output + if self._shuttle_has_no_device: + self._shuttle.clear_device_() + self._shuttle.set("collector", collector_data) + + self._update_traj_ids(env_output) - self._update_traj_ids(tensordict) if ( self.interruptor is not None and self.interruptor.collection_stopped() @@ -852,37 +1079,47 @@ def rollout(self) -> TensorDictBase: try: torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out[: t + 1], + self._final_rollout.ndim - 1, + out=self._final_rollout[: t + 1], ) except RuntimeError: - with self._tensordict_out.unlock_(): + with self._final_rollout.unlock_(): torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out[: t + 1], + self._final_rollout.ndim - 1, + out=self._final_rollout[: t + 1], ) break else: try: - self._tensordict_out = torch.stack( + self._final_rollout = torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out, + self._final_rollout.ndim - 1, + out=self._final_rollout, ) except RuntimeError: - with self._tensordict_out.unlock_(): - self._tensordict_out = torch.stack( + with self._final_rollout.unlock_(): + self._final_rollout = torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out, + self._final_rollout.ndim - 1, + out=self._final_rollout, ) - return self._tensordict_out + return self._final_rollout + + @staticmethod + def _update_device_wise(tensor0, tensor1): + # given 2 tensors, returns tensor0 if their identity matches, + # or a copy of tensor1 on the device of tensor0 otherwise + if tensor1 is None or tensor1 is tensor0: + return tensor0 + if tensor1.device == tensor0.device: + return tensor1 + return tensor1.to(tensor0.device, non_blocking=True) def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" # metadata - md = self._tensordict.get("collector").clone() + collector_metadata = self._shuttle.get("collector").clone() if index is not None: # check that the env supports partial reset if prod(self.env.batch_size) == 0: @@ -896,20 +1133,22 @@ def reset(self, index=None, **kwargs) -> None: device=self.env.device, ) _reset[index] = 1 - self._tensordict.set(reset_key, _reset) + self._shuttle.set(reset_key, _reset) else: _reset = None - self._tensordict.zero_() + self._shuttle.zero_() - self._tensordict.update(self.env.reset(**kwargs)) - md["traj_ids"] = md["traj_ids"] - md["traj_ids"].min() - self._tensordict["collector"] = md + self._shuttle.update(self.env.reset(**kwargs), inplace=True) + collector_metadata["traj_ids"] = ( + collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() + ) + self._shuttle["collector"] = collector_metadata def shutdown(self) -> None: """Shuts down all workers and/or closes the local environment.""" if not self.closed: self.closed = True - del self._tensordict, self._tensordict_out + del self._shuttle, self._final_rollout if not self.env.is_closed: self.env.close() del self.env @@ -974,7 +1213,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: def __repr__(self) -> str: env_str = indent(f"env={self.env}", 4 * " ") policy_str = indent(f"policy={self.policy}", 4 * " ") - td_out_str = indent(f"td_out={self._tensordict_out}", 4 * " ") + td_out_str = indent(f"td_out={self._final_rollout}", 4 * " ") string = ( f"{self.__class__.__name__}(" f"\n{env_str}," @@ -994,38 +1233,61 @@ class _MultiDataCollector(DataCollectorBase): policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. - total_frames (int): A keyword-only argument representing the + total_frames (int, optional): A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. - device (int, str, torch.device or sequence of such, optional): - The device on which the policy will be placed. - If it differs from the input policy device, the - :meth:`~.update_policy_weights_` method should be queried - at appropriate times during the training loop to accommodate for - the lag between parameter configuration at various times. - If necessary, a list of devices can be passed in which case each - element will correspond to the designated device of a sub-collector. - Defaults to ``None`` (i.e. policy is kept on its original device). - storing_device (int, str, torch.device or sequence of such, optional): - The device on which the output :class:`tensordict.TensorDict` will - be stored. For long trajectories, it may be necessary to store the - data on a different device than the one where the policy and env - are executed. - If necessary, a list of devices can be passed in which case each - element will correspond to the designated storing device of a - sub-collector. - Defaults to ``"cpu"``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. create_env_kwargs (dict, optional): A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number @@ -1051,15 +1313,9 @@ class _MultiDataCollector(DataCollectorBase): information. Defaults to ``False``. exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``ExplorationType.RANDOM`` - return_same_td (bool, optional): if ``True``, the same TensorDict - will be returned at each iteration, with its values - updated. This feature should be used cautiously: if the same - tensordict is added to a replay buffer for instance, - the whole content of the buffer will be identical. - Default is ``False``. + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. reset_when_done (bool, optional): if ``True`` (default), an environment that return a ``True`` value in its ``"done"`` or ``"truncated"`` entry will be reset at the corresponding indices. @@ -1088,10 +1344,12 @@ def __init__( ] ], *, - frames_per_batch: int = 200, + frames_per_batch: int, total_frames: Optional[int] = -1, - device: DEVICE_TYPING = None, - storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, create_env_kwargs: Optional[Sequence[dict]] = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, @@ -1101,10 +1359,8 @@ def __init__( exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, exploration_mode=None, reset_when_done: bool = True, - preemptive_threshold: float = None, update_at_each_batch: bool = False, - devices=None, - storing_devices=None, + preemptive_threshold: float = None, num_threads: int = None, num_sub_threads: int = 1, ): @@ -1134,99 +1390,86 @@ def __init__( # To go around this, we do the copies of the policy in the server # (this object) to each possible device, and send to all the # processes their copy of the policy. - if devices is not None: - if device is not None: - raise ValueError("Cannot pass both devices and device") - warnings.warn( - "`devices` keyword argument will soon be deprecated from multiprocessed collectors. " - "Please use `device` instead." - ) - device = devices - if storing_devices is not None: - if storing_device is not None: - raise ValueError("Cannot pass both storing_devices and storing_device") - warnings.warn( - "`storing_devices` keyword argument will soon be deprecated from multiprocessed collectors. " - "Please use `storing_device` instead." - ) - storing_device = storing_devices - - def device_err_msg(device_name, devices_list): - return ( - f"The length of the {device_name} argument should match the " - f"number of workers of the collector. Got len(" - f"create_env_fn)={self.num_workers} and len(" - f"storing_device)={len(devices_list)}" - ) - if isinstance(device, (str, int, torch.device)): - device = [torch.device(device) for _ in range(self.num_workers)] - elif device is None: - device = [None for _ in range(self.num_workers)] - elif isinstance(device, Sequence): - if len(device) != self.num_workers: - raise RuntimeError(device_err_msg("devices", device)) - device = [torch.device(_device) for _device in device] - else: - raise ValueError( - "devices should be either None, a torch.device or equivalent " - "or an iterable of devices. " - f"Found {type(device)} instead." - ) - self._policy_dict = {} - self._policy_weights_dict = {} - self._get_weights_fn_dict = {} + storing_devices, policy_devices, env_devices = self._get_devices( + storing_device=storing_device, + env_device=env_device, + policy_device=policy_device, + device=device, + ) - for i, (_device, create_env, kwargs) in enumerate( - zip(device, self.create_env_fn, self.create_env_kwargs) - ): - if _device in self._policy_dict: - device[i] = _device - continue + # to avoid confusion + self.storing_device = storing_devices + self.policy_device = policy_devices + self.env_device = env_devices - if hasattr(create_env, "observation_spec"): - observation_spec = create_env.observation_spec - else: - try: - observation_spec = create_env(**kwargs).observation_spec - except: # noqa - observation_spec = None + del storing_device, env_device, policy_device, device - _policy, _device, _get_weight_fn = self._get_policy_and_device( - policy=policy, device=_device, observation_spec=observation_spec - ) - self._policy_dict[_device] = _policy - if isinstance(_policy, nn.Module): - self._policy_weights_dict[_device] = TensorDict.from_module( - _policy, as_module=True - ) - else: - self._policy_weights_dict[_device] = TensorDict({}, []) + _policy_weights_dict = {} + _get_weights_fn_dict = {} - self._get_weights_fn_dict[_device] = _get_weight_fn - device[i] = _device - self.device = device + policy = _NonParametricPolicyWrapper(policy) + policy_weights = TensorDict.from_module(policy, as_module=True) - if storing_device is None: - self.storing_device = self.device - else: - if isinstance(storing_device, (str, int, torch.device)): - self.storing_device = [ - torch.device(storing_device) for _ in range(self.num_workers) - ] - elif isinstance(storing_device, Sequence): - if len(storing_device) != self.num_workers: - raise RuntimeError( - device_err_msg("storing_devices", storing_device) - ) - self.storing_device = [ - torch.device(_storing_device) for _storing_device in storing_device - ] + # store a stateless policy + + with policy_weights.apply(_make_meta_params).to_module(policy): + self.policy = deepcopy(policy) + + for policy_device in policy_devices: + # if we have already mapped onto that device, get that value + if policy_device in _policy_weights_dict: + continue + # If policy device is None, the only thing we need to do is + # make sure that the weights are shared. + if policy_device is None: + + def map_weight( + weight, + ): + is_param = isinstance(weight, nn.Parameter) + weight = weight.data + if weight.device.type in ("cpu", "mps"): + weight = weight.share_memory_() + if is_param: + weight = nn.Parameter(weight, requires_grad=False) + return weight + + # in other cases, we need to cast the policy if and only if not all the weights + # are on the appropriate device else: - raise ValueError( - "storing_devices should be either a torch.device or equivalent or an iterable of devices. " - f"Found {type(storing_device)} instead." - ) + # check the weights devices + has_different_device = [False] + + def map_weight( + weight, + policy_device=policy_device, + has_different_device=has_different_device, + ): + is_param = isinstance(weight, nn.Parameter) + weight = weight.data + if weight.device != policy_device: + has_different_device[0] = True + weight = weight.to(policy_device) + elif weight.device.type in ("cpu", "mps"): + weight = weight.share_memory_() + if is_param: + weight = nn.Parameter(weight, requires_grad=False) + return weight + + local_policy_weights = TensorDictParams(policy_weights.apply(map_weight)) + + def _get_weight_fn(weights=policy_weights): + # This function will give the local_policy_weight the original weights. + # see self.update_policy_weights_ to see how this is used + return weights + + # We lock the weights to be able to cache a bunch of ops and to avoid modifying it + _policy_weights_dict[policy_device] = local_policy_weights.lock_() + _get_weights_fn_dict[policy_device] = _get_weight_fn + + self._policy_weights_dict = _policy_weights_dict + self._get_weights_fn_dict = _get_weights_fn_dict if total_frames is None or total_frames < 0: total_frames = float("inf") @@ -1278,18 +1521,74 @@ def device_err_msg(device_name, devices_list): self._frames = 0 self._iter = -1 + def _get_devices( + self, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + # convert all devices to lists + if not isinstance(storing_device, (list, tuple)): + storing_device = [ + storing_device, + ] * self.num_workers + if not isinstance(policy_device, (list, tuple)): + policy_device = [ + policy_device, + ] * self.num_workers + if not isinstance(env_device, (list, tuple)): + env_device = [ + env_device, + ] * self.num_workers + if not isinstance(device, (list, tuple)): + device = [ + device, + ] * self.num_workers + if not ( + len(device) + == len(storing_device) + == len(policy_device) + == len(env_device) + == self.num_workers + ): + raise RuntimeError( + f"THe length of the devices does not match the number of workers: {self.num_workers}." + ) + storing_device, policy_device, env_device = zip( + *[ + SyncDataCollector._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + for (storing_device, policy_device, env_device, device) in zip( + storing_device, policy_device, env_device, device + ) + ] + ) + return storing_device, policy_device, env_device + @property def frames_per_batch_worker(self): raise NotImplementedError def update_policy_weights_(self, policy_weights=None) -> None: - for _device in self._policy_dict: + for _device in self._policy_weights_dict: if policy_weights is not None: + if isinstance(policy_weights, TensorDictParams): + policy_weights = policy_weights.data self._policy_weights_dict[_device].data.update_(policy_weights) elif self._get_weights_fn_dict[_device] is not None: - self._policy_weights_dict[_device].data.update_( - self._get_weights_fn_dict[_device]() - ) + original_weights = self._get_weights_fn_dict[_device]() + if original_weights is None: + # if the weights match in identity, we can spare a call to update_ + continue + if isinstance(original_weights, TensorDictParams): + original_weights = original_weights.data + self._policy_weights_dict[_device].data.update_(original_weights) @property def _queue_len(self) -> int: @@ -1303,53 +1602,58 @@ def _run_processes(self) -> None: for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): - _device = self.device[i] - _storing_device = self.storing_device[i] pipe_parent, pipe_child = mp.Pipe() # send messages to procs if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( env_fun, EnvBase ): # to avoid circular imports env_fun = CloudpickleWrapper(env_fun) - kwargs = { - "pipe_parent": pipe_parent, - "pipe_child": pipe_child, - "queue_out": queue_out, - "create_env_fn": env_fun, - "create_env_kwargs": env_fun_kwargs, - "policy": self._policy_dict[_device], - "max_frames_per_traj": self.max_frames_per_traj, - "frames_per_batch": self.frames_per_batch_worker, - "reset_at_each_iter": self.reset_at_each_iter, - "device": _device, - "storing_device": _storing_device, - "exploration_type": self.exploration_type, - "reset_when_done": self.reset_when_done, - "idx": i, - "interruptor": self.interruptor, - } - proc = _ProcessNoWarn( - target=_main_async_collector, - num_threads=self.num_sub_threads, - kwargs=kwargs, - ) - # proc.daemon can't be set as daemonic processes may be launched by the process itself - try: - proc.start() - except _pickle.PicklingError as err: - if "" in str(err): - raise RuntimeError( - """Can't open a process with doubly cloud-pickled lambda function. + # Create a policy on the right device + policy_device = self.policy_device[i] + storing_device = self.storing_device[i] + env_device = self.env_device[i] + policy = self.policy + with self._policy_weights_dict[policy_device].to_module(policy): + kwargs = { + "pipe_parent": pipe_parent, + "pipe_child": pipe_child, + "queue_out": queue_out, + "create_env_fn": env_fun, + "create_env_kwargs": env_fun_kwargs, + "policy": policy, + "max_frames_per_traj": self.max_frames_per_traj, + "frames_per_batch": self.frames_per_batch_worker, + "reset_at_each_iter": self.reset_at_each_iter, + "policy_device": policy_device, + "storing_device": storing_device, + "env_device": env_device, + "exploration_type": self.exploration_type, + "reset_when_done": self.reset_when_done, + "idx": i, + "interruptor": self.interruptor, + } + proc = _ProcessNoWarn( + target=_main_async_collector, + num_threads=self.num_sub_threads, + kwargs=kwargs, + ) + # proc.daemon can't be set as daemonic processes may be launched by the process itself + try: + proc.start() + except _pickle.PicklingError as err: + if "" in str(err): + raise RuntimeError( + """Can't open a process with doubly cloud-pickled lambda function. This error is likely due to an attempt to use a ParallelEnv in a multiprocessed data collector. To do this, consider wrapping your lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: `env = ParallelEnv(N, EnvCreator(my_lambda_function))`. This will not only ensure that your lambda function is cloud-pickled once, but also that the state dict is synchronised across processes if needed.""" - ) from err - pipe_child.close() - self.procs.append(proc) - self.pipes.append(pipe_parent) + ) from err + pipe_child.close() + self.procs.append(proc) + self.pipes.append(pipe_parent) for pipe_parent in self.pipes: msg = pipe_parent.recv() if msg != "instantiated": @@ -1971,48 +2275,102 @@ class aSyncDataCollector(MultiaSyncDataCollector): create_env_fn (Callabled): Callable returning an instance of EnvBase policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. - total_frames (int): lower bound of the total number of frames returned - by the collector. In parallel settings, the actual number of - frames may well be greater than this as the closing signals are - sent to the workers only once the total number of frames has - been collected on the server. - create_env_kwargs (dict, optional): A dictionary with the arguments - used to create an environment - max_frames_per_traj: Maximum steps per trajectory. Note that a - trajectory can span over multiple batches (unless - reset_at_each_iter is set to True, see below). Once a trajectory - reaches n_steps, the environment is reset. If the - environment wraps multiple environments together, the number of - steps is tracked for each environment independently. Negative + + Keyword Args: + frames_per_batch (int): A keyword-only argument representing the + total number of elements in a batch. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps) - frames_per_batch (int): Time-length of a batch. - reset_at_each_iter and frames_per_batch == n_steps are equivalent configurations. - Defaults to ``200`` - init_random_frames (int): Number of frames for which the policy is ignored before it is called. - This feature is mainly intended to be used in offline/model-based settings, where a batch of random - trajectories can be used to initialize training. - Defaults to ``None`` (i.e. no random frames) - reset_at_each_iter (bool): whether environments should be reset for each batch. - default=False. - postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a - useful format for training. - default: None. - split_trajs (bool): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. - See utils.split_trajectories for more information. - device (int, str, torch.device, optional): The device on which the - policy will be placed. If it differs from the input policy - device, the update_policy_weights_() method should be queried - at appropriate times during the training loop to accommodate for - the lag between parameter configuration at various times. - Default is `None` (i.e. policy is kept on its original device) - storing_device (int, str, torch.device, optional): The device on which - the output TensorDict will be stored. For long trajectories, - it may be necessary to store the data on a different. - device than the one where the policy is stored. Default is None. - update_at_each_batch (bool): if ``True``, the policy weights will be updated every time a batch of trajectories - is collected. - default=False + Defaults to ``None`` (i.e. no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + Defaults to ``None``. + split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. + reset_when_done (bool, optional): if ``True`` (default), an environment + that return a ``True`` value in its ``"done"`` or ``"truncated"`` + entry will be reset at the corresponding indices. + update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()` + will be called before (sync) or after (async) each data collection. + Defaults to ``False``. + preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers + that will be allowed to finished collecting their rollout before the rest are forced to end early. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + num_sub_threads (int, optional): number of threads of the subprocesses. + Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may charge the cpu load too much and harm performance. """ @@ -2024,19 +2382,27 @@ def __init__( TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] - ] = None, + ], + *, + frames_per_batch: int, total_frames: Optional[int] = -1, - create_env_kwargs: Optional[dict] = None, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Optional[Sequence[dict]] = None, max_frames_per_traj: int | None = None, - frames_per_batch: int = 200, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, - device: Optional[Union[int, str, torch.device]] = None, - storing_device: Optional[Union[int, str, torch.device]] = None, - seed: Optional[int] = None, - pin_memory: bool = False, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + exploration_mode=None, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float = None, + num_threads: int = None, + num_sub_threads: int = 1, **kwargs, ): super().__init__( @@ -2050,9 +2416,17 @@ def __init__( init_random_frames=init_random_frames, postproc=postproc, split_trajs=split_trajs, - devices=[device] if device is not None else None, - storing_devices=[storing_device] if storing_device is not None else None, - **kwargs, + device=device, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + exploration_type=exploration_type, + exploration_mode=exploration_mode, + reset_when_done=reset_when_done, + update_at_each_batch=update_at_each_batch, + preemptive_threshold=preemptive_threshold, + num_threads=num_threads, + num_sub_threads=num_sub_threads, ) # for RPC @@ -2086,8 +2460,9 @@ def _main_async_collector( max_frames_per_traj: int, frames_per_batch: int, reset_at_each_iter: bool, - device: Optional[Union[torch.device, str, int]], storing_device: Optional[Union[torch.device, str, int]], + env_device: Optional[Union[torch.device, str, int]], + policy_device: Optional[Union[torch.device, str, int]], idx: int = 0, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, reset_when_done: bool = True, @@ -2098,15 +2473,6 @@ def _main_async_collector( # init variables that will be cleared when closing tensordict = data = d = data_in = inner_collector = dc_iter = None - # send the policy to device - try: - policy = policy.to(device) - except Exception: - if RL_WARNINGS: - warnings.warn( - "Couldn't cast the policy onto the desired device on remote process. " - "If your policy is not a nn.Module instance you can probably ignore this warning." - ) inner_collector = SyncDataCollector( create_env_fn, create_env_kwargs=create_env_kwargs, @@ -2117,8 +2483,9 @@ def _main_async_collector( reset_at_each_iter=reset_at_each_iter, postproc=None, split_trajs=False, - device=device, storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, exploration_type=exploration_type, reset_when_done=reset_when_done, return_same_td=True, @@ -2190,7 +2557,27 @@ def _main_async_collector( raise RuntimeError( f"expected device to be {storing_device} but got {tensordict.device}" ) - tensordict.share_memory_() + # If policy and env are on cpu, we put in shared mem, + # if policy is on cuda and env on cuda, we are fine with this + # If policy is on cuda and env on cpu (or opposite) we put tensors that + # are on cpu in shared mem. + if tensordict.device is not None: + # placehoder in case we need different behaviours + if tensordict.device.type in ("cpu", "mps"): + tensordict.share_memory_() + elif tensordict.device.type == "cuda": + tensordict.share_memory_() + else: + raise NotImplementedError( + f"Device {tensordict.device} is not supported in multi-collectors yet." + ) + else: + # make sure each cpu tensor is shared - assuming non-cpu devices are shared + tensordict.apply( + lambda x: x.share_memory_() + if x.device.type in ("cpu", "mps") + else x + ) data = (tensordict, idx) else: if d is not tensordict: @@ -2258,3 +2645,64 @@ def _main_async_collector( else: raise Exception(f"Unrecognized message {msg}") + + +class _PolicyMetaClass(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + # no kwargs + if isinstance(args[0], nn.Module): + return args[0] + return super().__call__(*args) + + +class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): + """A wrapper for non-parametric policies.""" + + def __init__(self, policy): + super().__init__() + self.policy = policy + + @property + def forward(self): + forward = self.__dict__.get("_forward", None) + if forward is None: + + @functools.wraps(self.policy) + def forward(*input, **kwargs): + return self.policy.__call__(*input, **kwargs) + + self.__dict__["_forward"] = forward + return forward + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + + elif attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + + elif "policy" in self.__dir__(): + policy = self.__getattribute__("policy") + return getattr(policy, attr) + try: + super().__getattr__(attr) + except Exception: + raise AttributeError( + f"policy not set in {self.__class__.__name__}, cannot access {attr}." + ) + + +def _make_meta_params(param): + is_param = isinstance(param, nn.Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = nn.Parameter(pd, requires_grad=False) + return pd diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index f213f73d160..073d2f445ab 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. r"""Generic distributed data-collector using torch.distributed backend.""" +from __future__ import annotations import logging import os @@ -11,7 +12,7 @@ import warnings from copy import copy, deepcopy from datetime import timedelta -from typing import OrderedDict +from typing import Callable, List, OrderedDict, Type import torch.cuda from tensordict import TensorDict @@ -261,6 +262,8 @@ class DistributedDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -268,19 +271,55 @@ class DistributedDataCollector(DataCollectorBase): during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -293,13 +332,10 @@ class DistributedDataCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``"random"``, ``"mode"`` or - ``"mean"``. - Defaults to ``"random"`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -328,8 +364,6 @@ class DistributedDataCollector(DataCollectorBase): is one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed documentation for more information. Defaults to ``"gloo"``. - storing_device (torch.device or compatible, optional): the device where the - data will be delivered. Defaults to ``"cpu"``. update_after_each_batch (bool, optional): if ``True``, the weights will be updated after each collection. For ``sync=True``, this means that all workers will see their weights updated. For ``sync=False``, @@ -367,27 +401,29 @@ def __init__( create_env_fn, policy, *, - frames_per_batch, - total_frames, - max_frames_per_traj=-1, - init_random_frames=-1, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - exploration_type=DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, - reset_when_done=True, - collector_class=SyncDataCollector, - collector_kwargs=None, - num_workers_per_collector=1, - sync=False, - slurm_kwargs=None, - backend="gloo", - storing_device="cpu", - update_after_each_batch=False, - max_weight_update_interval=-1, - launcher="submitit", - tcp_port=None, + frames_per_batch: int, + total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, + max_frames_per_traj: int = -1, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Callable | None = None, + split_trajs: bool = False, + exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_mode: str = None, + collector_class: Type = SyncDataCollector, + collector_kwargs: dict = None, + num_workers_per_collector: int = 1, + sync: bool = False, + slurm_kwargs: dict | None = None, + backend: str = "gloo", + update_after_each_batch: bool = False, + max_weight_update_interval: int = -1, + launcher: str = "submitit", + tcp_port: int = None, ): exploration_type = _convert_exploration_type( exploration_mode=exploration_mode, exploration_type=exploration_type @@ -410,7 +446,12 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + + self.device = device self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + # make private to avoid changes from users during collection self._sync = sync self.update_after_each_batch = update_after_each_batch @@ -450,7 +491,7 @@ def __init__( ) # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_workers @@ -465,12 +506,12 @@ def __init__( ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc self.split_trajs = split_trajs self.backend = backend @@ -480,6 +521,66 @@ def __init__( self._init_workers() self._make_container() + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._device = value + else: + self._device = [value] * self.num_workers + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._storing_device = value + else: + self._storing_device = [value] * self.num_workers + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._env_device = value + else: + self._env_device = [value] * self.num_workers + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._policy_device = value + else: + self._policy_device = [value] * self.num_workers + def _init_master_dist( self, world_size, @@ -530,20 +631,7 @@ def _make_container(self): if self._VERBOSE: logging.info("got data", _data) logging.info("expanding...") - if not issubclass(self.collector_class, SyncDataCollector): - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) - else: - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) + self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) if self._VERBOSE: logging.info("locking") if self._sync: diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 11f94e4ea64..fa2d8e8191e 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import logging import warnings from typing import Callable, Dict, Iterator, List, OrderedDict, Union import torch import torch.nn as nn -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -119,25 +120,64 @@ class RayCollector(DataCollectorBase): instance of :class:`~torchrl.envs.EnvBase`. policy (Callable): Instance of TensorDictModule class. Must accept TensorDictBase object as input. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int, Optional): lower bound of the total number of frames returned by the collector. The iterator will stop once the total number of frames equates or exceeds the total number of frames passed to the collector. Default value is -1, which mean no target total number of frames (i.e. the collector will run indefinitely). - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. + create_env_kwargs (dict, optional): Dictionary of kwargs for + ``create_env_fn``. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -150,13 +190,10 @@ class RayCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``"random"`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (Python class): a collector class to be remotely instantiated. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -182,8 +219,6 @@ class RayCollector(DataCollectorBase): tensordicts collected on each node. If ``False`` (default), each tensordict results from a separate node in a "first-ready, first-served" fashion. - storing_device (torch.device, optional): if specified, collected tensordicts will be moved - to these devices before returning them to the user. update_after_each_batch (bool, optional): if ``True``, the weights will be updated after each collection. For ``sync=True``, this means that all workers will see their weights updated. For ``sync=False``, @@ -237,13 +272,16 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, max_frames_per_traj=-1, init_random_frames=-1, reset_at_each_iter=False, postproc=None, split_trajs=False, exploration_type=DEFAULT_EXPLORATION_TYPE, - reset_when_done=True, collector_class: Callable[[TensorDict], TensorDict] = SyncDataCollector, collector_kwargs: Union[Dict, List[Dict]] = None, num_workers_per_collector: int = 1, @@ -251,7 +289,6 @@ def __init__( ray_init_config: Dict = None, remote_configs: Union[Dict, List[Dict]] = None, num_collectors: int = None, - storing_device: torch.device = "cpu", update_after_each_batch=False, max_weight_update_interval=-1, ): @@ -358,7 +395,10 @@ def check_list_length_consistency(*lists): self.collector_kwargs = ( collector_kwargs if collector_kwargs is not None else [{}] ) + self.device = device self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device self._batches_since_weight_update = [0 for _ in range(self.num_collectors)] self._sync = sync @@ -373,7 +413,7 @@ def check_list_length_consistency(*lists): self._frames_per_batch_corrected = frames_per_batch # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_collectors @@ -388,14 +428,14 @@ def check_list_length_consistency(*lists): ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done collector_kwarg["split_trajs"] = False collector_kwarg["frames_per_batch"] = self._frames_per_batch_corrected + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc # Create remote instances of the collector class self._remote_collectors = [] @@ -414,6 +454,54 @@ def check_list_length_consistency(*lists): ] ray.wait(object_refs=pending_samples) + @property + def num_workers(self): + return self.num_collectors + + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + self._device = value + else: + self._device = [value] * self.num_collectors + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + self._storing_device = value + else: + self._storing_device = [value] * self.num_collectors + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + self._env_device = value + else: + self._env_device = [value] * self.num_collectors + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + self._policy_device = value + else: + self._policy_device = [value] * self.num_collectors + @staticmethod def _make_collector(cls, env_maker, policy, other_params): """Create a single collector instance.""" @@ -512,7 +600,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: self.collected_frames += out_td.numel() - yield out_td.to(self.storing_device) + yield out_td if self.max_weight_update_interval > -1: for j in range(self.num_collectors): @@ -549,7 +637,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: ) # should not be necessary, deleted automatically when ref count is down to 0 self.collected_frames += out_td.numel() - yield out_td.to(self.storing_device) + yield out_td for j in range(self.num_collectors): self._batches_since_weight_update[j] += 1 diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 98228d15f7b..50729038b4a 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. r"""Generic distributed data-collector using torch.distributed.rpc backend.""" +from __future__ import annotations + import collections import logging import os @@ -11,7 +13,7 @@ import time import warnings from copy import copy, deepcopy -from typing import OrderedDict +from typing import Callable, List, OrderedDict from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( @@ -96,31 +98,69 @@ class RPCDataCollector(DataCollectorBase): Args: create_env_fn (Callable or List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable, optional): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a :class:`RandomPolicy` instance with the environment ``action_spec``. - frames_per_batch (int): A keyword-only argument representing the - total number of elements in a batch. - total_frames (int): A keyword-only argument representing the - total number of frames returned by the collector + + Keyword Args: + frames_per_batch (int): A keyword-only argument representing the total + number of elements in a batch. + total_frames (int): A keyword-only argument representing the total + number of frames returned by the collector during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -133,14 +173,10 @@ class RPCDataCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, - ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``ExplorationType.RANDOM`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -156,7 +192,6 @@ class RPCDataCollector(DataCollectorBase): should always be preferred. If multiple simultaneous environment need to be executed on a single node, consider using a :class:`~torchrl.envs.ParallelEnv` instance. - collector_kwargs (dict or list, optional): a dictionary of parameters to be passed to the remote data-collector. If a list is provided, each element will correspond to an individual set of keyword arguments for the @@ -174,9 +209,6 @@ class RPCDataCollector(DataCollectorBase): first-served" fashion. slurm_kwargs (dict): a dictionary of parameters to be passed to the submitit executor. - storing_device (int, str or torch.device, optional): the device where - data will be stored and delivered by the iterator. Defaults to - ``"cpu"``. update_after_each_batch (bool, optional): if ``True``, the weights will be updated after each collection. For ``sync=True``, this means that all workers will see their weights updated. For ``sync=False``, @@ -217,22 +249,24 @@ def __init__( create_env_fn, policy, *, - frames_per_batch, - total_frames, - max_frames_per_traj=-1, - init_random_frames=-1, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - exploration_type=DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, - reset_when_done=True, + frames_per_batch: int, + total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, + max_frames_per_traj: int = -1, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Callable | None = None, + split_trajs: bool = False, + exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, sync=False, slurm_kwargs=None, - storing_device="cpu", update_after_each_batch=False, max_weight_update_interval=-1, launcher="submitit", @@ -259,6 +293,12 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + + self.device = device + self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + self.storing_device = storing_device # make private to avoid changes from users during collection self._sync = sync @@ -300,7 +340,7 @@ def __init__( ) # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_workers @@ -315,12 +355,12 @@ def __init__( ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc self.split_trajs = split_trajs if tensorpipe_options is None: @@ -331,6 +371,66 @@ def __init__( ) self._init() + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._device = value + else: + self._device = [value] * self.num_workers + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._storing_device = value + else: + self._storing_device = [value] * self.num_workers + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._env_device = value + else: + self._env_device = [value] * self.num_workers + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._policy_device = value + else: + self._policy_device = [value] * self.num_workers + def _init_master_rpc( self, world_size, @@ -585,7 +685,7 @@ def _next_async_rpc(self): args=(self.collector_rrefs[i],), ) self.futures.append((future, i)) - return data.to(self.storing_device) + return data self.futures.append((future, i)) def _next_sync_rpc(self): @@ -612,7 +712,7 @@ def _next_sync_rpc(self): ) else: self.futures.append((future, i)) - data = torch.cat(data).to(self.storing_device) + data = torch.cat(data) traj_ids = data.get(("collector", "traj_ids"), None) if traj_ids is not None: for i in range(1, self.num_workers): diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 8d3afa488d4..d7a5c94487d 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -4,13 +4,14 @@ # LICENSE file in the root directory of this source tree. r"""Generic distributed data-collector using torch.distributed backend.""" +from __future__ import annotations import logging import os import socket from copy import copy, deepcopy from datetime import timedelta -from typing import OrderedDict +from typing import Callable, List, OrderedDict import torch.cuda from tensordict import TensorDict @@ -142,6 +143,8 @@ class DistributedSyncDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -149,19 +152,55 @@ class DistributedSyncDataCollector(DataCollectorBase): during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -174,13 +213,10 @@ class DistributedSyncDataCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``"random"``, ``"mode"`` or - ``"mean"``. - Defaults to ``"random"`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -205,8 +241,14 @@ class DistributedSyncDataCollector(DataCollectorBase): is one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed documentation for more information. Defaults to ``"gloo"``. - storing_device (torch.device or compatible, optional): the device where the - data will be delivered. Defaults to ``"cpu"``. + max_weight_update_interval (int, optional): the maximum number of + batches that can be collected before the policy weights of a worker + is updated. + For sync collections, this parameter is overwritten by ``update_after_each_batch``. + For async collections, it may be that one worker has not seen its + parameters being updated for a certain time even if ``update_after_each_batch`` + is turned on. + Defaults to -1 (no forced update). update_interval (int, optional): the frequency at which the policy is updated. Defaults to 1. launcher (str, optional): how jobs should be launched. @@ -225,22 +267,24 @@ def __init__( create_env_fn, policy, *, - frames_per_batch, - total_frames, - max_frames_per_traj=-1, - init_random_frames=-1, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - exploration_type=DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, - reset_when_done=True, + frames_per_batch: int, + total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, + max_frames_per_traj: int = -1, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Callable | None = None, + split_trajs: bool = False, + exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, slurm_kwargs=None, backend="gloo", - storing_device="cpu", max_weight_update_interval=-1, update_interval=1, launcher="submitit", @@ -267,6 +311,12 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + + self.device = device + self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + self.storing_device = storing_device # make private to avoid changes from users during collection self.update_interval = update_interval @@ -304,19 +354,19 @@ def __init__( ) # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_workers ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc self.split_trajs = split_trajs self.backend = backend @@ -326,6 +376,66 @@ def __init__( self._init_workers() self._make_container() + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._device = value + else: + self._device = [value] * self.num_workers + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._storing_device = value + else: + self._storing_device = [value] * self.num_workers + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._env_device = value + else: + self._env_device = [value] * self.num_workers + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._policy_device = value + else: + self._policy_device = [value] * self.num_workers + def _init_master_dist( self, world_size, @@ -353,20 +463,7 @@ def _make_container(self): ) for _data in pseudo_collector: break - if not issubclass(self.collector_class, SyncDataCollector): - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) - else: - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) + self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) self._single_tds = self._tensordict_out.unbind(0) self._tensordict_out.lock_() pseudo_collector.shutdown() diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index eee3b3e4a98..b8db47f412d 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -7,8 +7,7 @@ import torch -from tensordict import set_lazy_legacy -from tensordict.tensordict import pad, TensorDictBase +from tensordict import pad, set_lazy_legacy, TensorDictBase def _stack_output(fun) -> Callable: diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index adf6317e679..10b9767de8e 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -19,8 +19,7 @@ import torch -from tensordict import PersistentTensorDict, TensorDict -from tensordict.tensordict import make_tensordict +from tensordict import make_tensordict, PersistentTensorDict, TensorDict from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index 0070c86d534..d0acd37822e 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -9,7 +9,7 @@ from typing import Callable import numpy as np -from tensordict.tensordict import TensorDict +from tensordict import TensorDict from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import ( diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 21f51115d6c..d7b2db3f15a 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -6,7 +6,7 @@ from __future__ import annotations import torch -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from tensordict.utils import expand_right from torch import nn diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 1381c6d2383..79bf3b9b180 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -17,14 +17,15 @@ import torch -from tensordict import is_tensorclass, unravel_key -from tensordict.nn.utils import _set_dispatch_td_nn_modules -from tensordict.tensordict import ( +from tensordict import ( is_tensor_collection, + is_tensorclass, LazyStackedTensorDict, TensorDict, TensorDictBase, + unravel_key, ) +from tensordict.nn.utils import _set_dispatch_td_nn_modules from tensordict.utils import expand_as_right, expand_right from torch import Tensor diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 5357b9a835f..c37cac634e4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -18,9 +18,8 @@ import numpy as np import tensordict import torch -from tensordict import is_tensorclass +from tensordict import is_tensor_collection, is_tensorclass, TensorDict, TensorDictBase from tensordict.memmap import MemmapTensor, MemoryMappedTensor -from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.utils import _STRDTYPE2DTYPE, expand_right from torch import multiprocessing as mp diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 35f4e99914c..9c6b3d1e58a 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -15,7 +15,7 @@ from tensordict import TensorDict, TensorDictBase -from tensordict.tensordict import NestedKey +from tensordict.utils import NestedKey from torchrl.data.replay_buffers import ( SamplerWithoutReplacement, TensorDictReplayBuffer, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4aad6a7b3c1..b4d628a9051 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -31,8 +31,7 @@ import numpy as np import torch -from tensordict import unravel_key -from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import get_binary_env_var @@ -81,12 +80,14 @@ def _default_dtype_and_device( dtype: Union[None, torch.dtype], device: Union[None, str, int, torch.device], -) -> Tuple[torch.dtype, torch.device]: + allow_none_device: bool = False, +) -> Tuple[torch.dtype, torch.device | None]: if dtype is None: dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") - device = torch.device(device) + if device is not None: + device = torch.device(device) + elif not allow_none_device: + device = torch.zeros(()).device return dtype, device @@ -354,7 +355,7 @@ class ContinuousBox(Box): _low: torch.Tensor _high: torch.Tensor - device: torch.device = None + device: torch.device | None = None # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used. @property @@ -522,7 +523,7 @@ class TensorSpec: shape: torch.Size space: Union[None, Box] - device: torch.device = torch.device("cpu") + device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -539,6 +540,10 @@ def decorator(func): return decorator + def clear_device_(self): + """A no-op for all leaf specs (which must have a device).""" + pass + def encode( self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False ) -> torch.Tensor: @@ -761,6 +766,14 @@ def zero(self, shape=None) -> torch.Tensor: def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": raise NotImplementedError + def cpu(self): + return self.to("cpu") + + def cuda(self, device=None): + if device is None: + return self.to("cuda") + return self.to(f"cuda:{device}") + @abc.abstractmethod def clone(self) -> "TensorSpec": raise NotImplementedError @@ -809,6 +822,11 @@ def __init__(self, *specs: tuple[T, ...], dim: int) -> None: if self.dim < 0: self.dim = len(self.shape) + self.dim + def clear_device_(self): + """Clears the device of the CompositeSpec.""" + for spec in self._specs: + spec.clear_device_() + def __getitem__(self, item): is_key = isinstance(item, str) or ( isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) @@ -918,6 +936,8 @@ def rand(self, shape=None) -> TensorDictBase: return torch.stack([spec.rand(shape) for spec in self._specs], dim) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: + if dest is None: + return self return torch.stack([spec.to(dest) for spec in self._specs], self.dim) def unbind(self, dim: int): @@ -1160,7 +1180,7 @@ class OneHotDiscreteTensorSpec(TensorSpec): shape: torch.Size space: DiscreteBox - device: torch.device = torch.device("cpu") + device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -1187,7 +1207,9 @@ def __init__( f"The last value of the shape must match n for transform of type {self.__class__}. " f"Got n={space.n} and shape={shape}." ) - super().__init__(shape, space, device, dtype, "discrete") + super().__init__( + shape=shape, space=space, device=device, dtype=dtype, domain="discrete" + ) self.update_mask(mask) @property @@ -1205,6 +1227,8 @@ def update_mask(self, mask): self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if dest is None: + return self if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -1529,8 +1553,6 @@ def __init__( dtype, device = _default_dtype_and_device(dtype, device) if dtype is None: dtype = torch.get_default_dtype() - if device is None: - device = torch._get_default_device() if not isinstance(low, torch.Tensor): low = torch.tensor(low, dtype=dtype, device=device) @@ -1592,7 +1614,11 @@ def __init__( self.shape = shape super().__init__( - shape, ContinuousBox(low, high, device=device), device, dtype, domain=domain + shape=shape, + space=ContinuousBox(low, high, device=device), + device=device, + dtype=dtype, + domain=domain, ) def __eq__(self, other): @@ -1750,6 +1776,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -1845,6 +1873,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -1859,7 +1889,9 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) shape = [*shape, *self.shape] - return torch.randn(shape, device=self.device, dtype=self.dtype) + if self.dtype.is_floating_point: + return torch.randn(shape, device=self.device, dtype=self.dtype) + return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: return True @@ -1979,6 +2011,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2167,6 +2201,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2474,7 +2510,7 @@ class DiscreteTensorSpec(TensorSpec): shape: torch.Size space: DiscreteBox - device: torch.device = torch.device("cpu") + device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -2492,7 +2528,9 @@ def __init__( shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) - super().__init__(shape, space, device, dtype, domain="discrete") + super().__init__( + shape=shape, space=space, device=device, dtype=dtype, domain="discrete" + ) self.update_mask(mask) @property @@ -2690,6 +2728,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2796,6 +2836,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2908,6 +2950,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -3205,6 +3249,15 @@ class CompositeSpec(TensorSpec): to be ``True`` for the corresponding tensors, and :obj:`project()` will have no effect. `spec.encode` cannot be used with missing values. + Attributes: + device (torch.device or None): if not specified, the device of the composite + spec is ``None`` (as it is the case for TensorDicts). A non-none device + constraints all leaves to be of the same device. On the other hand, + a ``None`` device allows leaves to have different devices. Defaults + to ``None``. + shape (torch.Size): the leading shape of all the leaves. Equivalent + to the batch-size of the corresponding tensordicts. + Examples: >>> pixels_spec = BoundedTensorSpec( ... torch.zeros(3,32,32), @@ -3237,7 +3290,6 @@ class CompositeSpec(TensorSpec): device=None, is_shared=False) - Examples: >>> # we can build a nested composite spec using unnamed arguments >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None})) @@ -3264,7 +3316,7 @@ class CompositeSpec(TensorSpec): @classmethod def __new__(cls, *args, **kwargs): - cls._device = torch.device("cpu") + cls._device = None cls._locked = False return super().__new__(cls) @@ -3330,19 +3382,13 @@ def __init__(self, *args, shape=None, device=None, **kwargs): for key, item in self.items(): if item is None: continue - - try: - item_device = item.device - except RuntimeError as err: - cond1 = DEVICE_ERR_MSG in str(err) - if cond1: - item_device = _device - else: - raise err - - if _device is None: - _device = item_device - elif item_device != _device: + if ( + isinstance(item, CompositeSpec) + and item.device is None + and _device is not None + ): + item = item.clone().to(_device) + elif (_device is not None) and (item.device != _device): raise RuntimeError( f"Setting a new attribute ({key}) on another device " f"({item.device} against {_device}). All devices of " @@ -3361,40 +3407,27 @@ def __init__(self, *args, shape=None, device=None, **kwargs): ) for k, item in argdict.items(): if isinstance(item, dict): - item = CompositeSpec(item, shape=shape) - if item is not None: - if self._device is None: - try: - self._device = item.device - except RuntimeError as err: - if DEVICE_ERR_MSG in str(err): - self._device = item._device - else: - raise err + item = CompositeSpec(item, shape=shape, device=_device) self[k] = item @property def device(self) -> DEVICE_TYPING: - if self._device is None: - # try to replace device by the true device - _device = None - for value in self.values(): - if value is not None: - _device = value.device - if _device is None: - raise RuntimeError( - "device of empty CompositeSpec is not defined. " - "You can set it directly by calling " - "`spec.device = device`." - ) - self._device = _device return self._device @device.setter def device(self, device: DEVICE_TYPING): + if device is None and self._device is not None: + raise RuntimeError( + "To erase the device of a composite spec, call " "spec.clear_device_()." + ) device = torch.device(device) self.to(device) + def clear_device_(self): + """Clears the device of the CompositeSpec.""" + for spec in self._specs: + spec.clear_device_() + def __getitem__(self, idx): """Indexes the current CompositeSpec based on the provided index.""" if isinstance(idx, (str, tuple)): @@ -3456,7 +3489,7 @@ def get(self, item, default=NO_DEFAULT): def __setitem__(self, key, value): if isinstance(key, tuple) and len(key) > 1: if key[0] not in self.keys(True): - self[key[0]] = CompositeSpec(shape=self.shape) + self[key[0]] = CompositeSpec(shape=self.shape, device=self.device) self[key[0]][key[1:]] = value return elif isinstance(key, tuple): @@ -3466,34 +3499,25 @@ def __setitem__(self, key, value): raise TypeError(f"Got key of type {type(key)} when a string was expected.") if key in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec[{key}] cannot be set") - try: - if value is not None and value.device != self.device: + if isinstance(value, dict): + value = CompositeSpec(value, device=self._device, shape=self.shape) + if ( + value is not None + and self.device is not None + and value.device != self.device + ): + if isinstance(value, CompositeSpec) and value.device is None: + value = value.clone().to(self.device) + else: raise RuntimeError( f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " f"All devices of CompositeSpec must match." ) - except RuntimeError as err: - cond1 = DEVICE_ERR_MSG in str(err) - cond2 = self._device is None - if cond1 and cond2: - try: - device_val = value.device - self.to(device_val) - except RuntimeError as suberr: - if DEVICE_ERR_MSG in str(suberr): - pass - else: - raise suberr - elif cond1: - pass - else: - raise err self.set(key, value) def __iter__(self): - for k in self._specs: - yield k + yield from self._specs def __delitem__(self, key: str) -> None: if isinstance(key, tuple) and len(key) > 1: @@ -3668,6 +3692,8 @@ def __len__(self): return len(self.keys()) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if dest is None: + return self if not isinstance(dest, (str, int, torch.device)): raise ValueError( "Only device casting is allowed with specs of type CompositeSpec." @@ -4123,7 +4149,23 @@ def __setitem__(self, key: NestedKey, value): @property def device(self) -> DEVICE_TYPING: - return self._specs[0].device + device = self.__dict__.get("_device", NO_DEFAULT) + if device is NO_DEFAULT: + devices = {spec.device for spec in self._specs} + if len(devices) == 1: + device = list(devices)[0] + elif len(devices) == 2: + device0, device1 = devices + if device0 is None: + device = device1 + elif device1 is None: + device = device0 + else: + device = None + else: + device = None + self.__dict__["_device"] = device + return device @property def ndim(self): @@ -4232,7 +4274,18 @@ def _stack_composite_specs(list_of_spec, dim, out=None): raise ValueError("Cannot stack an empty list of specs.") spec0 = list_of_spec[0] if isinstance(spec0, CompositeSpec): - device = spec0.device + devices = {spec.device for spec in list_of_spec} + if len(devices) == 1: + device = list(devices)[0] + elif len(devices) == 2: + device0, device1 = devices + if device0 is None: + device = device1 + elif device1 is None: + device = device0 + else: + device = None + all_equal = True for spec in list_of_spec[1:]: if not isinstance(spec, CompositeSpec): @@ -4240,8 +4293,9 @@ def _stack_composite_specs(list_of_spec, dim, out=None): "Stacking specs cannot occur: Found more than one type of spec in " "the list." ) - if device != spec.device: - raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + if device != spec.device and device is not None: + # spec.device must be None + spec = spec.to(device) if spec.shape != spec0.shape: raise RuntimeError(f"Shapes differ, got {spec.shape} and {spec0.shape}") all_equal = all_equal and spec == spec0 diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index a8d399ea08b..bc253cd3ac7 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -5,6 +5,7 @@ from __future__ import annotations +import gc import logging import os @@ -18,9 +19,8 @@ import torch -from tensordict import TensorDict +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple, unravel_key -from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE from torchrl.data.tensor_specs import CompositeSpec @@ -410,6 +410,15 @@ def _check_for_empty_spec(specs: CompositeSpec): self._dummy_env_str = meta_data.env_str self._env_tensordict = meta_data.tensordict + if device is None: # In other cases, the device will be mapped later + self._env_tensordict.clear_device_() + device_map = meta_data.device_map + + def map_device(key, value, device_map=device_map): + return value.to(device_map[key]) + + self._env_tensordict.named_apply(map_device, nested_keys=True) + self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) @@ -576,11 +585,10 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.shared_tensordict_parent.device.type == "cpu": - if self._share_memory: - self.shared_tensordict_parent.share_memory_() - elif self._memmap: - self.shared_tensordict_parent.memmap_() + if self._share_memory: + self.shared_tensordict_parent.share_memory_() + elif self._memmap: + self.shared_tensordict_parent.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -676,6 +684,8 @@ def _start_workers(self) -> None: for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) + if self.device is not None: + env = env.to(self.device) self._envs.append(env) self.is_closed = False @@ -766,6 +776,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ) if out.device == device: out = out.clone() + elif device is None: + out = out.clone().clear_device_() else: out = out.to(device, non_blocking=True) return out @@ -807,6 +819,8 @@ def _step( out = next_td.select(*self._selected_step_keys, strict=False) if out.device == device: out = out.clone() + elif device is None: + out = out.clone().clear_device_() else: out = out.to(device, non_blocking=True) return out @@ -850,8 +864,7 @@ def to(self, device: DEVICE_TYPING): return self super().to(device) if not self.is_closed: - for env in self._envs: - env.to(device) + self._envs = [env.to(device) for env in self._envs] return self @@ -1006,7 +1019,17 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] func = _run_worker_pipe_shared_mem - if self.shared_tensordict_parent.device.type == "cuda": + # We look for cuda tensors through the leaves + # because the shared tensordict could be partially on cuda + # and some leaves may be inaccessible through get (e.g., LazyStacked) + has_cuda = [False] + + def look_for_cuda(tensor, has_cuda=has_cuda): + has_cuda[0] = has_cuda[0] or tensor.is_cuda + + self.shared_tensordict_parent.apply(look_for_cuda) + has_cuda = has_cuda[0] + if has_cuda: self.event = torch.cuda.Event() else: self.event = None @@ -1123,9 +1146,12 @@ def step_and_maybe_reset( if self.shared_tensordict_parent.device == device: next_td = next_td.clone() tensordict_ = tensordict_.clone() - else: + elif device is not None: next_td = next_td.to(device, non_blocking=True) tensordict_ = tensordict_.to(device, non_blocking=True) + else: + next_td = next_td.clone().clear_device_() + tensordict_ = tensordict_.clone().clear_device_() tensordict.set("next", next_td) return tensordict, tensordict_ @@ -1255,6 +1281,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ) if out.device == device: out = out.clone() + elif device is None: + out = out.clear_device_().clone() else: out = out.to(device, non_blocking=True) return out @@ -1379,7 +1407,18 @@ def _run_worker_pipe_shared_mem( verbose: bool = False, ) -> None: device = shared_tensordict.device - if device.type == "cuda": + if device is None or device.type != "cuda": + # Check if some tensors are shared on cuda + has_cuda = [False] + + def look_for_cuda(tensor, has_cuda=has_cuda): + has_cuda[0] = has_cuda[0] or tensor.is_cuda + + shared_tensordict.apply(look_for_cuda) + has_cuda = has_cuda[0] + else: + has_cuda = device.type == "cuda" + if has_cuda: event = torch.cuda.Event() else: event = None @@ -1403,7 +1442,7 @@ def _run_worker_pipe_shared_mem( initialized = False child_pipe.send("started") - + next_shared_tensordict, root_shared_tensordict = (None,) * 2 while True: try: if child_pipe.poll(_timeout): @@ -1452,42 +1491,49 @@ def _run_worker_pipe_shared_mem( event.record() event.synchronize() mp_event.set() + del cur_td elif cmd == "step": if not initialized: raise RuntimeError("called 'init' before step") i += 1 - env_input = shared_tensordict - next_td = env._step(env_input) + next_td = env._step(shared_tensordict) next_shared_tensordict.update_(next_td) if event is not None: event.record() event.synchronize() mp_event.set() + del next_td elif cmd == "step_and_maybe_reset": if not initialized: raise RuntimeError("called 'init' before step") i += 1 - env_input = shared_tensordict - td, root_next_td = env.step_and_maybe_reset(env_input) + td, root_next_td = env.step_and_maybe_reset(shared_tensordict) next_shared_tensordict.update_(td.get("next")) root_shared_tensordict.update_(root_next_td) if event is not None: event.record() event.synchronize() mp_event.set() + del td, root_next_td elif cmd == "close": - del shared_tensordict, data if not initialized: raise RuntimeError("call 'init' before closing") env.close() - del env + del ( + env, + shared_tensordict, + data, + next_shared_tensordict, + root_shared_tensordict, + ) mp_event.set() child_pipe.close() if verbose: logging.info(f"{pid} closed") + gc.collect() break elif cmd == "load_state_dict": @@ -1498,6 +1544,7 @@ def _run_worker_pipe_shared_mem( state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) msg = "state_dict" child_pipe.send((msg, state_dict)) + del state_dict else: err_msg = f"{cmd} from env" diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 39484ac355a..61cd211b6ae 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,8 +14,7 @@ import numpy as np import torch import torch.nn as nn -from tensordict import LazyStackedTensorDict, unravel_key -from tensordict.tensordict import TensorDictBase +from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key from tensordict.utils import NestedKey from torchrl._utils import _replace_last, implement_for, prod, seed_generator @@ -59,6 +58,7 @@ def __init__( env_str: str, device: torch.device, batch_locked: bool = True, + device_map: dict = None, ): self.device = device self.tensordict = tensordict @@ -66,6 +66,7 @@ def __init__( self.batch_size = batch_size self.env_str = env_str self.batch_locked = batch_locked + self.device_map = device_map @property def tensordict(self): @@ -100,7 +101,16 @@ def metadata_from_env(env) -> EnvMetaData: device = env.device specs = specs.to("cpu") batch_locked = env.batch_locked - return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked) + # we need to save the device map, as the tensordict will be placed on cpu + device_map = {} + + def fill_device_map(name, val, device_map=device_map): + device_map[name] = val.device + + tensordict.named_apply(fill_device_map, nested_keys=True) + return EnvMetaData( + tensordict, specs, batch_size, env_str, device, batch_locked, device_map + ) def expand(self, *size: int) -> EnvMetaData: tensordict = self.tensordict.expand(*size).clone() @@ -112,6 +122,7 @@ def expand(self, *size: int) -> EnvMetaData: self.env_str, self.device, self.batch_locked, + self.device_map, ) def clone(self): @@ -122,13 +133,23 @@ def clone(self): deepcopy(self.env_str), self.device, self.batch_locked, + self.device_map, ) def to(self, device: DEVICE_TYPING) -> EnvMetaData: + if device is not None: + device = torch.device(device) + device_map = {key: device for key in self.device_map} tensordict = self.tensordict.contiguous().to(device) specs = self.specs.to(device) return EnvMetaData( - tensordict, specs, self.batch_size, self.env_str, device, self.batch_locked + tensordict, + specs, + self.batch_size, + self.env_str, + device, + self.batch_locked, + device_map, ) @@ -149,6 +170,51 @@ def __call__(cls, *args, **kwargs): class EnvBase(nn.Module, metaclass=_EnvPostInit): """Abstract environment parent class. + Keyword Args: + device (torch.device): The device of the environment. Deviceless environments + are allowed (device=None). If not ``None``, all specs will be cast + on that device and it is expected that all inputs and outputs will + live on that device. + Defaults to ``None``. + dtype (deprecated): dtype of the observations. Will be deprecated in v0.4. + batch_size (torch.Size or equivalent, optional): batch-size of the environment. + Corresponds to the leading dimension of all the input and output + tensordicts the environment reads and writes. Defaults to an empty batch-size. + run_type_checks (bool, optional): If ``True``, type-checks will occur + at every reset and every step. Defaults to ``False``. + allow_done_after_reset (bool, optional): if ``True``, an environment can + be done after a call to :meth:`~.reset` is made. Defaults to ``False``. + + Attributes: + done_spec (CompositeSpec): equivalent to ``full_done_spec`` as all + ``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry + action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf + action if only one action tensor is to be expected. Otherwise links to + ``full_action_spec``. + observation_spec (CompositeSpec): equivalent to ``full_observation_spec``. + reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf + reward if only one reward tensor is to be expected. Otherwise links to + ``full_reward_spec``. + state_spec (CompositeSpec): equivalent to ``full_state_spec``. + full_done_spec (CompositeSpec): a composite spec such that ``full_done_spec.zero()`` + returns a tensordict containing only the leaves encoding the done status of the + environment. + full_action_spec (CompositeSpec): a composite spec such that ``full_action_spec.zero()`` + returns a tensordict containing only the leaves encoding the action of the + environment. + full_observation_spec (CompositeSpec): a composite spec such that ``full_observation_spec.zero()`` + returns a tensordict containing only the leaves encoding the observation of the + environment. + full_reward_spec (CompositeSpec): a composite spec such that ``full_reward_spec.zero()`` + returns a tensordict containing only the leaves encoding the reward of the + environment. + full_state_spec (CompositeSpec): a composite spec such that ``full_state_spec.zero()`` + returns a tensordict containing only the leaves encoding the inputs (actions + excluded) of the environment. + batch_size (torch.Size): The batch-size of the environment. + device (torch.device): the device where the input/outputs of the environment + are to be expected. Can be ``None``. + Methods: step (TensorDictBase -> TensorDictBase): step in the environment reset (TensorDictBase, optional -> TensorDictBase): reset the environment @@ -158,6 +224,15 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): steps if no policy is provided) Examples: + >>> from torchrl.envs import EnvBase + >>> class CounterEnv(EnvBase): + ... def __init__(self, batch_size=(), device=None, **kwargs): + ... self.observation_spec = CompositeSpec( + ... count=UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int64)) + ... self.action_spec = UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int8) + ... # done spec and reward spec are set automatically + ... def _step(self, tensordict): + ... >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.batch_size # how many envs are run at once @@ -238,23 +313,30 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): def __init__( self, + *, device: DEVICE_TYPING = None, dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, allow_done_after_reset: bool = False, ): - if device is None: - device = torch.device("cpu") self.__dict__.setdefault("_batch_size", None) if device is not None: self.__dict__["_device"] = torch.device(device) output_spec = self.__dict__.get("_output_spec", None) if output_spec is not None: - self.__dict__["_output_spec"] = output_spec.to(self.device) + self.__dict__["_output_spec"] = ( + output_spec.to(self.device) + if self.device is not None + else output_spec + ) input_spec = self.__dict__.get("_input_spec", None) if input_spec is not None: - self.__dict__["_input_spec"] = input_spec.to(self.device) + self.__dict__["_input_spec"] = ( + input_spec.to(self.device) + if self.device is not None + else input_spec + ) super().__init__() self.dtype = dtype_map.get(dtype, dtype) @@ -360,8 +442,6 @@ def batch_size(self, value: torch.Size) -> None: @property def device(self) -> torch.device: device = self.__dict__.get("_device", None) - if device is None: - device = self.__dict__["_device"] = torch.device("cpu") return device @device.setter @@ -618,7 +698,7 @@ def action_spec(self) -> TensorSpec: def action_spec(self, value: TensorSpec) -> None: try: self.input_spec.unlock_() - device = self.input_spec.device + device = self.input_spec._device try: delattr(self, "_action_keys") except AttributeError: @@ -806,7 +886,7 @@ def reward_spec(self) -> TensorSpec: def reward_spec(self, value: TensorSpec) -> None: try: self.output_spec.unlock_() - device = self.output_spec.device + device = self.output_spec._device try: delattr(self, "_reward_keys") except AttributeError: @@ -873,7 +953,7 @@ def full_reward_spec(self) -> CompositeSpec: @full_reward_spec.setter def full_reward_spec(self, spec: CompositeSpec) -> None: - self.reward_spec = spec + self.reward_spec = spec.to(self.device) if self.device is not None else spec # done spec @property @@ -938,7 +1018,7 @@ def full_done_spec(self) -> CompositeSpec: @full_done_spec.setter def full_done_spec(self, spec: CompositeSpec) -> None: - self.done_spec = spec + self.done_spec = spec.to(self.device) if self.device is not None else spec # Done spec: done specs belong to output_spec @property @@ -1168,7 +1248,6 @@ def observation_spec(self) -> CompositeSpec: def observation_spec(self, value: TensorSpec) -> None: try: self.output_spec.unlock_() - device = self.output_spec.device if not isinstance(value, CompositeSpec): raise TypeError("The type of an observation_spec must be Composite.") elif value.shape[: len(self.batch_size)] != self.batch_size: @@ -1179,7 +1258,10 @@ def observation_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - self.output_spec["full_observation_spec"] = value.to(device) + device = self.output_spec._device + self.output_spec["full_observation_spec"] = ( + value.to(device) if device is not None else value + ) finally: self.output_spec.lock_() @@ -1253,7 +1335,9 @@ def state_spec(self, value: CompositeSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - self.input_spec["full_state_spec"] = value.to(device) + self.input_spec["full_state_spec"] = ( + value.to(device) if device is not None else value + ) finally: self.input_spec.lock_() @@ -2274,10 +2358,13 @@ def rollout( [None, 'time'] """ - try: - policy_device = next(policy.parameters()).device - except (StopIteration, AttributeError): - policy_device = self.device + if auto_cast_to_device: + try: + policy_device = next(policy.parameters()).device + except (StopIteration, AttributeError): + policy_device = None + else: + policy_device = None env_device = self.device @@ -2330,10 +2417,16 @@ def _rollout_stop_early( tensordicts = [] for i in range(max_steps): if auto_cast_to_device: - tensordict = tensordict.to(policy_device, non_blocking=True) + if policy_device is not None: + tensordict = tensordict.to(policy_device, non_blocking=True) + else: + tensordict.clear_device_() tensordict = policy(tensordict) if auto_cast_to_device: - tensordict = tensordict.to(env_device, non_blocking=True) + if env_device is not None: + tensordict = tensordict.to(env_device, non_blocking=True) + else: + tensordict.clear_device_() tensordict = self.step(tensordict) tensordicts.append(tensordict.clone(False)) @@ -2378,10 +2471,16 @@ def _rollout_nonstop( tensordict_ = tensordict for i in range(max_steps): if auto_cast_to_device: - tensordict_ = tensordict_.to(policy_device, non_blocking=True) + if policy_device is not None: + tensordict_ = tensordict_.to(policy_device, non_blocking=True) + else: + tensordict_.clear_device_() tensordict_ = policy(tensordict_) if auto_cast_to_device: - tensordict_ = tensordict_.to(env_device, non_blocking=True) + if env_device is not None: + tensordict_ = tensordict_.to(env_device, non_blocking=True) + else: + tensordict_.clear_device_() tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) tensordicts.append(tensordict) if i == max_steps - 1: diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 9053b42f7f6..28c9e00c42a 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -10,7 +10,7 @@ from typing import Callable, Dict, Optional, Union import torch -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase, EnvMetaData diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 7eca6f5a1db..60cb026c658 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -13,8 +13,7 @@ import numpy as np import torch -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( CompositeSpec, diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index c50d1189e59..c10813d4fc3 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -7,7 +7,7 @@ from typing import Dict, Optional, Union import torch -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( BoundedTensorSpec, diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 7915ed91338..c6590d344e3 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -212,7 +212,7 @@ def gym_backend(submodule=None): def _gym_to_torchrl_spec_transform( spec, dtype=None, - device="cpu", + device=None, categorical_action_encoding=False, remap_state_to_observation: bool = True, batch_size: tuple = (), @@ -224,7 +224,7 @@ def _gym_to_torchrl_spec_transform( dtype (torch.dtype): a dtype to use for the spec. Defaults to`spec.dtype`. device (torch.device): the device for the spec. - Defaults to ``"cpu"``. + Defaults to ``None`` (no device for composite and default device for specs). categorical_action_encoding (bool): whether discrete spaces should be mapped to categorical or one-hot. Defaults to ``False`` (one-hot). remap_state_to_observation (bool): whether to rename the 'state' key of @@ -349,7 +349,7 @@ def _gym_to_torchrl_spec_transform( remap_state_to_observation=remap_state_to_observation, ) # the batch-size must be set later - return CompositeSpec(spec_out) + return CompositeSpec(spec_out, device=device) elif isinstance(spec, gym_spaces.dict.Dict): return _gym_to_torchrl_spec_transform( spec.spaces, diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 68437d07d35..95c64183b7d 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -11,7 +11,7 @@ import torch # from jax import dlpack as jax_dlpack, numpy as jnp -from tensordict.tensordict import make_tensordict, TensorDictBase +from tensordict import make_tensordict, TensorDictBase from torch.utils import dlpack as torch_dlpack from torchrl.data.tensor_specs import ( CompositeSpec, diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 8f8e63a6e59..42c32b3547f 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -8,7 +8,7 @@ import numpy as np import torch -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.envs.utils import _classproperty _has_jumanji = importlib.util.find_spec("jumanji") is not None diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 1c3927e6d0f..0aa5aa99313 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -5,7 +5,7 @@ import importlib.util import torch -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 73c293186ac..14e45eb4bc4 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -10,7 +10,7 @@ from typing import Dict, List, Tuple, Union import torch -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from torchrl.data.tensor_specs import ( CompositeSpec, diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index ee43e72ffe0..4d4998eb721 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -11,8 +11,7 @@ import numpy as np import torch -from tensordict import TensorDict -from tensordict.tensordict import make_tensordict +from tensordict import make_tensordict, TensorDict from torchrl._utils import implement_for from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 1a5d0e2ce15..51d3970fded 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Union import torch -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( BoundedTensorSpec, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 07bc29d1b59..cf67cdc0cc3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -21,12 +21,13 @@ is_tensor_collection, NonTensorData, set_lazy_legacy, + TensorDict, + TensorDictBase, unravel_key, unravel_key_list, ) from tensordict._tensordict import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import expand_as_right, NestedKey from torch import nn, Tensor from torch.utils._pytree import tree_map diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 289dd60f053..ed288fdea9e 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -5,8 +5,7 @@ from typing import List, Optional, Union import torch -from tensordict import set_lazy_legacy, TensorDict -from tensordict.tensordict import TensorDictBase +from tensordict import set_lazy_legacy, TensorDict, TensorDictBase from torch.hub import load_state_dict_from_url from torchrl.data.tensor_specs import ( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 82f0c2d21fb..0b978e5ef68 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -15,7 +15,12 @@ import torch -from tensordict import is_tensor_collection, TensorDictBase, unravel_key +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDictBase, + unravel_key, +) from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. @@ -26,7 +31,7 @@ set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) -from tensordict.tensordict import LazyStackedTensorDict, NestedKey +from tensordict.utils import NestedKey from torchrl._utils import _replace_last from torchrl.data.tensor_specs import ( diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index f80524a0f9f..5a59bc55fa1 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase # TODO: code small architecture ref in Impala paper diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 1a3fdac7387..6d9e6fb3b49 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 66fd1bb9e1f..3ddb9012139 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -6,7 +6,7 @@ from typing import Optional import torch -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from torchrl.envs.common import EnvBase from torchrl.modules import SafeModule diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index b390d05fad6..c65b81eb11d 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl.envs.common import EnvBase diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 22786519681..221ba3cde8d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -13,10 +13,9 @@ import torch -from tensordict import unravel_key_list +from tensordict import TensorDictBase, unravel_key_list from tensordict.nn import TensorDictModule, TensorDictModuleBase -from tensordict.tensordict import TensorDictBase from tensordict.utils import NestedKey from torch import nn diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 5c8ae799061..f641fdfef88 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -7,13 +7,13 @@ import numpy as np import torch +from tensordict import TensorDictBase from tensordict.nn import ( TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, ) -from tensordict.tensordict import TensorDictBase from tensordict.utils import expand_as_right, expand_right, NestedKey from torchrl.data.tensor_specs import CompositeSpec, TensorSpec diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index fe970c292be..b05cbd55356 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -9,9 +9,9 @@ import torch.nn.functional as F from tensordict import TensorDictBase, unravel_key_list -from tensordict.nn import TensorDictModuleBase as ModuleBase +from tensordict.base import NO_DEFAULT -from tensordict.tensordict import NO_DEFAULT +from tensordict.nn import TensorDictModuleBase as ModuleBase from tensordict.utils import expand_as_right, prod, set_lazy_legacy from torch import nn, Tensor diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 2cdb7af2553..c32a795a2a0 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -10,8 +10,8 @@ from typing import Tuple import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -89,7 +89,7 @@ class A2CLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 8213fa7044f..3d90d0174b9 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -12,8 +12,8 @@ import numpy as np import torch import torch.nn as nn +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey, unravel_key from torch import Tensor @@ -90,7 +90,7 @@ class CQLLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 3b4debe6259..70239ea62e9 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -11,8 +11,8 @@ from typing import Tuple import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey, unravel_key from torchrl.modules.tensordict_module.actors import ActorCriticWrapper @@ -49,7 +49,7 @@ class DDPGLoss(LossModule): >>> from torchrl.data import BoundedTensorSpec >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 52339d583dd..a24aa4a1271 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -8,8 +8,8 @@ from typing import Union import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index ea329a2b726..e920bc83960 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -11,9 +11,8 @@ import numpy as np import torch -from tensordict import TensorDict +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule -from tensordict.tensordict import TensorDictBase from tensordict.utils import NestedKey from torch import Tensor diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index aa0ada74801..1fd48675cb4 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -7,8 +7,8 @@ from typing import Optional, Tuple, Union import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor from torchrl.data.tensor_specs import TensorSpec @@ -63,7 +63,7 @@ class IQLLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) @@ -538,7 +538,7 @@ class DiscreteIQLLoss(IQLLoss): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) @@ -597,7 +597,7 @@ class DiscreteIQLLoss(IQLLoss): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 542877f8f20..0f7ea835949 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -14,8 +14,8 @@ from typing import Tuple import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -136,7 +136,7 @@ class PPOLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index cac829964fc..af0a94cbc96 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -9,9 +9,9 @@ from typing import Union import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor @@ -86,7 +86,7 @@ class REDQLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 98c4d4d14d3..c9cc8f383ad 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -11,9 +11,9 @@ from dataclasses import dataclass import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -95,7 +95,7 @@ class ReinforceLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.reinforce import ReinforceLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_obs, n_act = 3, 5 >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 4da874148e7..431296e7486 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -11,9 +11,9 @@ import numpy as np import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor from torchrl.data.tensor_specs import CompositeSpec, TensorSpec @@ -106,7 +106,7 @@ class SACLoss(LossModule): >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 7736c5cfbbf..e1aeb253681 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -7,9 +7,9 @@ from typing import Optional, Tuple import torch -from tensordict.nn import dispatch, TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec @@ -75,7 +75,7 @@ class TD3Loss(LossModule): >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3 import TD3Loss - >>> from tensordict.tensordict import TensorDict + >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 91305a6a777..4c0b8ae67bd 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -10,8 +10,8 @@ from typing import Iterable, Optional, Union import torch +from tensordict import TensorDict, TensorDictBase from tensordict.nn import TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor from torch.nn import functional as F from torch.nn.modules import dropout diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 1c43d536fe8..fc2e58a19f6 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -13,6 +13,7 @@ from typing import Callable, List, Optional, Union import torch +from tensordict import TensorDictBase from tensordict.nn import ( dispatch, is_functional, @@ -20,7 +21,6 @@ TensorDictModule, TensorDictModuleBase, ) -from tensordict.tensordict import TensorDictBase from tensordict.utils import NestedKey from torch import nn, Tensor diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 1910c920a41..a6181145311 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -8,7 +8,7 @@ import torch -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from tensordict.utils import NestedKey diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index f8f9c55809b..7063fb2f1c4 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -6,8 +6,9 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Type, Union +from tensordict import TensorDictBase + from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModuleWrapper -from tensordict.tensordict import TensorDictBase from torchrl.collectors.collectors import ( DataCollectorBase, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 6985037d17c..c8629be7f15 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -16,8 +16,8 @@ import numpy as np import torch.nn +from tensordict import pad, TensorDictBase from tensordict.nn import TensorDictModule -from tensordict.tensordict import pad, TensorDictBase from tensordict.utils import expand_right from torch import nn, optim diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 1f69651a3b4..85590c545fa 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -340,7 +340,7 @@ def _loss_value( # value and actor loss, collect the cost values and write them in a tensordict # delivered to the user. -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase def _forward(self, input_tensordict: TensorDictBase) -> TensorDict: diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index b72d2ff0f92..12c8bdc3193 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -96,8 +96,8 @@ import numpy as np import torch import tqdm +from tensordict import TensorDict, TensorDictBase from tensordict.nn import TensorDictModule -from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec From 967bad2ce0df91a617bfcaa155c93346f765559d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 30 Jan 2024 20:47:09 +0000 Subject: [PATCH 19/35] [Feature, BugFix] Better thread control in penv and collectors (#1848) --- test/test_collector.py | 101 ++++++++++++++++++++++--------- test/test_env.py | 79 +++++++++++++++++------- torchrl/__init__.py | 3 + torchrl/collectors/collectors.py | 37 ++++++++++- torchrl/envs/batched_envs.py | 18 +++++- torchrl/envs/libs/gym.py | 14 ++--- 6 files changed, 188 insertions(+), 64 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 2e090ad8fcf..5dd1cac8d4c 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import gc import logging import sys @@ -2357,39 +2358,79 @@ def make_env(): del collector -@pytest.mark.skipif( - IS_OSX, reason="setting different threads across workeres can randomly fail on OSX." -) -def test_num_threads(): - from torchrl.collectors import collectors - - _main_async_collector_saved = collectors._main_async_collector - collectors._main_async_collector = decorate_thread_sub_func( - collectors._main_async_collector, num_threads=3 +class TestLibThreading: + @pytest.mark.skipif( + IS_OSX, + reason="setting different threads across workeres can randomly fail on OSX.", ) - num_threads = torch.get_num_threads() - try: - env = ContinuousActionVecMockEnv() - c = MultiSyncDataCollector( - [env], - policy=RandomPolicy(env.action_spec), - num_threads=7, - num_sub_threads=3, - total_frames=200, - frames_per_batch=200, + def test_num_threads(self): + from torchrl.collectors import collectors + + _main_async_collector_saved = collectors._main_async_collector + collectors._main_async_collector = decorate_thread_sub_func( + collectors._main_async_collector, num_threads=3 ) - assert torch.get_num_threads() == 7 - for _ in c: - pass - finally: + num_threads = torch.get_num_threads() + try: + env = ContinuousActionVecMockEnv() + c = MultiSyncDataCollector( + [env], + policy=RandomPolicy(env.action_spec), + num_threads=7, + num_sub_threads=3, + total_frames=200, + frames_per_batch=200, + ) + assert torch.get_num_threads() == 7 + for _ in c: + pass + finally: + try: + c.shutdown() + del c + except Exception: + logging.info("Failed to shut down collector") + # reset vals + collectors._main_async_collector = _main_async_collector_saved + torch.set_num_threads(num_threads) + + @pytest.mark.skipif( + IS_OSX, + reason="setting different threads across workeres can randomly fail on OSX.", + ) + def test_auto_num_threads(self): + init_threads = torch.get_num_threads() + try: + collector = MultiSyncDataCollector( + [ContinuousActionVecMockEnv], + RandomPolicy(ContinuousActionVecMockEnv().full_action_spec), + frames_per_batch=3, + ) + for _ in collector: + assert torch.get_num_threads() == init_threads - 1 + break + collector.shutdown() + assert torch.get_num_threads() == init_threads + del collector + gc.collect() + finally: + torch.set_num_threads(init_threads) + try: - c.shutdown() - del c - except Exception: - logging.info("Failed to shut down collector") - # reset vals - collectors._main_async_collector = _main_async_collector_saved - torch.set_num_threads(num_threads) + collector = MultiSyncDataCollector( + [ParallelEnv(2, ContinuousActionVecMockEnv)], + RandomPolicy(ContinuousActionVecMockEnv().full_action_spec.expand(2)), + frames_per_batch=3, + ) + for _ in collector: + assert torch.get_num_threads() == init_threads - 2 + break + collector.shutdown() + assert torch.get_num_threads() == init_threads + del collector + gc.collect() + finally: + torch.set_num_threads(init_threads) if __name__ == "__main__": diff --git a/test/test_env.py b/test/test_env.py index eaa31007186..71fe4ab6c60 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import gc import os.path import re from collections import defaultdict @@ -2333,30 +2334,64 @@ def test_terminated_or_truncated_spec(self): assert not data["nested", "_reset"].any() -@pytest.mark.skipif( - IS_OSX, reason="setting different threads across workeres can randomly fail on OSX." -) -def test_num_threads(): - from torchrl.envs import batched_envs - - _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem - batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func( - batched_envs._run_worker_pipe_shared_mem, num_threads=3 +class TestLibThreading: + @pytest.mark.skipif( + IS_OSX, + reason="setting different threads across workeres can randomly fail on OSX.", ) - num_threads = torch.get_num_threads() - try: - env = ParallelEnv( - 2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7 + def test_num_threads(self): + from torchrl.envs import batched_envs + + _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem + batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func( + batched_envs._run_worker_pipe_shared_mem, num_threads=3 ) - # We could test that the number of threads isn't changed until we start the procs. - # Even though it's unlikely that we have 7 threads, we still disable this for safety - # assert torch.get_num_threads() != 7 - env.rollout(3) - assert torch.get_num_threads() == 7 - finally: - # reset vals - batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save - torch.set_num_threads(num_threads) + num_threads = torch.get_num_threads() + try: + env = ParallelEnv( + 2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7 + ) + # We could test that the number of threads isn't changed until we start the procs. + # Even though it's unlikely that we have 7 threads, we still disable this for safety + # assert torch.get_num_threads() != 7 + env.rollout(3) + assert torch.get_num_threads() == 7 + finally: + # reset vals + batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save + torch.set_num_threads(num_threads) + + @pytest.mark.skipif( + IS_OSX, + reason="setting different threads across workeres can randomly fail on OSX.", + ) + def test_auto_num_threads(self): + init_threads = torch.get_num_threads() + + try: + env3 = ParallelEnv(3, lambda: GymEnv("Pendulum-v1")) + env3.rollout(2) + + assert torch.get_num_threads() == max(1, init_threads - 3) + + env2 = ParallelEnv(2, lambda: GymEnv("Pendulum-v1")) + env2.rollout(2) + + assert torch.get_num_threads() == max(1, init_threads - 5) + + env2.close() + del env2 + gc.collect() + + assert torch.get_num_threads() == max(1, init_threads - 3) + + env3.close() + del env3 + gc.collect() + + assert torch.get_num_threads() == init_threads + finally: + torch.set_num_threads(init_threads) def test_run_type_checks(): diff --git a/torchrl/__init__.py b/torchrl/__init__.py index ef80f84a428..109f11657e4 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -49,3 +49,6 @@ # Filter warnings in subprocesses: True by default given the multiple optional # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`. filter_warnings_subprocess = True + +_THREAD_POOL_INIT = torch.get_num_threads() +_THREAD_POOL = torch.get_num_threads() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 661903b784d..ab6292fef99 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1368,12 +1368,11 @@ def __init__( exploration_mode=exploration_mode, exploration_type=exploration_type ) self.closed = True - if num_threads is None: - num_threads = len(create_env_fn) + 1 # 1 more thread for this proc + self.num_workers = len(create_env_fn) + self.num_sub_threads = num_sub_threads self.num_threads = num_threads self.create_env_fn = create_env_fn - self.num_workers = len(create_env_fn) self.create_env_kwargs = ( create_env_kwargs if create_env_kwargs is not None @@ -1521,6 +1520,18 @@ def _get_weight_fn(weights=policy_weights): self._frames = 0 self._iter = -1 + @classmethod + def _total_workers_from_env(cls, env_creators): + if isinstance(env_creators, (tuple, list)): + return sum( + cls._total_workers_from_env(env_creator) for env_creator in env_creators + ) + from torchrl.envs import ParallelEnv + + if isinstance(env_creators, ParallelEnv): + return env_creators.num_workers + return 1 + def _get_devices( self, *, @@ -1595,7 +1606,19 @@ def _queue_len(self) -> int: raise NotImplementedError def _run_processes(self) -> None: + if self.num_threads is None: + import torchrl + + total_workers = self._total_workers_from_env(self.create_env_fn) + self.num_threads = max( + 1, torchrl._THREAD_POOL - total_workers + ) # 1 more thread for this proc + torch.set_num_threads(self.num_threads) + assert torch.get_num_threads() == self.num_threads + import torchrl + + torchrl._THREAD_POOL = self.num_threads queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] self.pipes = [] @@ -1702,6 +1725,14 @@ def _shutdown_main(self) -> None: for proc in self.procs: proc.join(1.0) finally: + import torchrl + + torchrl._THREAD_POOL = min( + torchrl._THREAD_POOL_INIT, + torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn), + ) + torch.set_num_threads(torchrl._THREAD_POOL) + for proc in self.procs: if proc.is_alive(): proc.terminate() diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index bc253cd3ac7..53e9913baa4 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -270,8 +270,6 @@ def __init__( super().__init__(device=device) self.serial_for_single = serial_for_single self.is_closed = True - if num_threads is None: - num_threads = num_workers + 1 # 1 more thread for this proc self.num_sub_threads = num_sub_threads self.num_threads = num_threads self._cache_in_keys = None @@ -633,6 +631,12 @@ def close(self) -> None: self._shutdown_workers() self.is_closed = True + import torchrl + + torchrl._THREAD_POOL = min( + torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers + ) + torch.set_num_threads(torchrl._THREAD_POOL) def _shutdown_workers(self) -> None: raise NotImplementedError @@ -1010,7 +1014,17 @@ class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta): def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator + if self.num_threads is None: + import torchrl + + self.num_threads = max( + 1, torchrl._THREAD_POOL - self.num_workers + ) # 1 more thread for this proc + torch.set_num_threads(self.num_threads) + import torchrl + + torchrl._THREAD_POOL = self.num_threads ctx = mp.get_context("spawn") diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index c6590d344e3..48da354e7ba 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -306,16 +306,16 @@ def _gym_to_torchrl_spec_transform( shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] - low = torch.tensor(spec.low, device=device, dtype=dtype) - high = torch.tensor(spec.high, device=device, dtype=dtype) + low = torch.as_tensor(spec.low, device=device, dtype=dtype) + high = torch.as_tensor(spec.high, device=device, dtype=dtype) is_unbounded = low.isinf().all() and high.isinf().all() minval, maxval = _minmax_dtype(dtype) minval = torch.as_tensor(minval).to(low.device, dtype) maxval = torch.as_tensor(maxval).to(low.device, dtype) is_unbounded = is_unbounded or ( - torch.isclose(low, torch.tensor(minval, dtype=dtype)).all() - and torch.isclose(high, torch.tensor(maxval, dtype=dtype)).all() + torch.isclose(low, torch.as_tensor(minval, dtype=dtype)).all() + and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all() ) return ( UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype) @@ -1480,7 +1480,7 @@ def _read_obs(self, obs, key, tensor, index): # Simplest case: there is one observation, # presented as a np.ndarray. The key should be pixels or observation. # We just write that value at its location in the tensor - tensor[index] = torch.tensor(obs, device=tensor.device) + tensor[index] = torch.as_tensor(obs, device=tensor.device) elif isinstance(obs, dict): if key not in obs: raise KeyError( @@ -1491,13 +1491,13 @@ def _read_obs(self, obs, key, tensor, index): # if the obs is a dict, we expect that the key points also to # a value in the obs. We retrieve this value and write it in the # tensor - tensor[index] = torch.tensor(subobs, device=tensor.device) + tensor[index] = torch.as_tensor(subobs, device=tensor.device) elif isinstance(obs, (list, tuple)): # tuples are stacked along the first dimension when passing gym spaces # to torchrl specs. As such, we can simply stack the tuple and set it # at the relevant index (assuming stacking can be achieved) - tensor[index] = torch.tensor(obs, device=tensor.device) + tensor[index] = torch.as_tensor(obs, device=tensor.device) else: raise NotImplementedError( f"Observations of type {type(obs)} are not supported yet." From 017bcd05a0dfbde63786473e99d89607bf1fbff0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 30 Jan 2024 21:39:35 +0000 Subject: [PATCH 20/35] [CI] Update macos image (#1849) --- .github/workflows/build-wheels-m1.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 5601a9f8485..9cbdf460894 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -33,7 +33,7 @@ jobs: build-matrix: ${{ needs.generate-matrix.outputs.matrix }} post-script: "" package-name: torchrl - runner-type: macos-m1-12 + runner-type: macos-m1-stable smoke-test-script: "" trigger-event: ${{ github.event_name }} env-var-script: .github/scripts/m1_script.sh From 86b8918d718eaec6d2a28c3dcf34ebe9f5615b78 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 09:21:47 +0000 Subject: [PATCH 21/35] [BugFix] thread setting bug (#1852) --- test/test_collector.py | 4 ++-- test/test_env.py | 8 ++++---- torchrl/__init__.py | 1 - torchrl/collectors/collectors.py | 15 +++++---------- torchrl/envs/batched_envs.py | 13 ++++--------- 5 files changed, 15 insertions(+), 26 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 5dd1cac8d4c..027cf776ee4 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2361,7 +2361,7 @@ def make_env(): class TestLibThreading: @pytest.mark.skipif( IS_OSX, - reason="setting different threads across workeres can randomly fail on OSX.", + reason="setting different threads across workers can randomly fail on OSX.", ) def test_num_threads(self): from torchrl.collectors import collectors @@ -2396,7 +2396,7 @@ def test_num_threads(self): @pytest.mark.skipif( IS_OSX, - reason="setting different threads across workeres can randomly fail on OSX.", + reason="setting different threads across workers can randomly fail on OSX.", ) def test_auto_num_threads(self): init_threads = torch.get_num_threads() diff --git a/test/test_env.py b/test/test_env.py index 71fe4ab6c60..22918c390df 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2337,7 +2337,7 @@ def test_terminated_or_truncated_spec(self): class TestLibThreading: @pytest.mark.skipif( IS_OSX, - reason="setting different threads across workeres can randomly fail on OSX.", + reason="setting different threads across workers can randomly fail on OSX.", ) def test_num_threads(self): from torchrl.envs import batched_envs @@ -2363,18 +2363,18 @@ def test_num_threads(self): @pytest.mark.skipif( IS_OSX, - reason="setting different threads across workeres can randomly fail on OSX.", + reason="setting different threads across workers can randomly fail on OSX.", ) def test_auto_num_threads(self): init_threads = torch.get_num_threads() try: - env3 = ParallelEnv(3, lambda: GymEnv("Pendulum-v1")) + env3 = ParallelEnv(3, ContinuousActionVecMockEnv) env3.rollout(2) assert torch.get_num_threads() == max(1, init_threads - 3) - env2 = ParallelEnv(2, lambda: GymEnv("Pendulum-v1")) + env2 = ParallelEnv(2, ContinuousActionVecMockEnv) env2.rollout(2) assert torch.get_num_threads() == max(1, init_threads - 5) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 109f11657e4..25103423cac 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -51,4 +51,3 @@ filter_warnings_subprocess = True _THREAD_POOL_INIT = torch.get_num_threads() -_THREAD_POOL = torch.get_num_threads() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index ab6292fef99..7d5e635a5a9 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1607,18 +1607,12 @@ def _queue_len(self) -> int: def _run_processes(self) -> None: if self.num_threads is None: - import torchrl - total_workers = self._total_workers_from_env(self.create_env_fn) self.num_threads = max( - 1, torchrl._THREAD_POOL - total_workers + 1, torch.get_num_threads() - total_workers ) # 1 more thread for this proc torch.set_num_threads(self.num_threads) - assert torch.get_num_threads() == self.num_threads - import torchrl - - torchrl._THREAD_POOL = self.num_threads queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] self.pipes = [] @@ -1727,11 +1721,12 @@ def _shutdown_main(self) -> None: finally: import torchrl - torchrl._THREAD_POOL = min( + num_threads = min( torchrl._THREAD_POOL_INIT, - torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn), + torch.get_num_threads() + + self._total_workers_from_env(self.create_env_fn), ) - torch.set_num_threads(torchrl._THREAD_POOL) + torch.set_num_threads(num_threads) for proc in self.procs: if proc.is_alive(): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 53e9913baa4..655db99c983 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -633,10 +633,10 @@ def close(self) -> None: self.is_closed = True import torchrl - torchrl._THREAD_POOL = min( - torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers + num_threads = min( + torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers ) - torch.set_num_threads(torchrl._THREAD_POOL) + torch.set_num_threads(num_threads) def _shutdown_workers(self) -> None: raise NotImplementedError @@ -1015,16 +1015,11 @@ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator if self.num_threads is None: - import torchrl - self.num_threads = max( - 1, torchrl._THREAD_POOL - self.num_workers + 1, torch.get_num_threads() - self.num_workers ) # 1 more thread for this proc torch.set_num_threads(self.num_threads) - import torchrl - - torchrl._THREAD_POOL = self.num_threads ctx = mp.get_context("spawn") From 06fcac1676c23179e43755b8dc50b1e2c280f555 Mon Sep 17 00:00:00 2001 From: Skander Moalla <37197319+skandermoalla@users.noreply.github.com> Date: Wed, 31 Jan 2024 11:50:23 +0100 Subject: [PATCH 22/35] [Refactor] Remove unused completed_keys property from StepCounter. (#1854) --- torchrl/envs/transforms/transforms.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index cf67cdc0cc3..c2c2a6aa047 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5068,21 +5068,6 @@ def truncated_keys(self): self._truncated_keys = truncated_keys return truncated_keys - @property - def completed_keys(self): - done_keys = self.__dict__.get("_done_keys", None) - if done_keys is None: - # make the default done keys - done_keys = [] - for reset_key in self.parent._filtered_reset_keys: - if isinstance(reset_key, str): - key = "done" - else: - key = (*reset_key[:-1], "done") - done_keys.append(key) - self.__dict__["_done_keys"] = done_keys - return done_keys - @property def done_keys(self): done_keys = self.__dict__.get("_done_keys", None) From 2754200d3b784c94f1dfb3dec04e2e2f53dc246b Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 31 Jan 2024 12:29:09 +0100 Subject: [PATCH 23/35] [Feature] Submitit run script (#1822) Co-authored-by: vmoens --- .../linux_examples/scripts/run_test.sh | 1 + examples/a2c/a2c_atari.py | 9 +- examples/a2c/a2c_mujoco.py | 9 +- examples/a2c/config_atari.yaml | 2 + examples/a2c/config_mujoco.yaml | 4 +- examples/bandits/README.md | 7 ++ examples/cql/cql_offline.py | 9 +- examples/cql/cql_online.py | 9 +- examples/cql/discrete_cql_config.yaml | 4 +- examples/cql/discrete_cql_online.py | 8 +- examples/cql/offline_config.yaml | 6 +- examples/cql/online_config.yaml | 4 +- examples/ddpg/config.yaml | 8 +- examples/ddpg/ddpg.py | 9 +- examples/decision_transformer/dt_config.yaml | 4 +- examples/decision_transformer/odt_config.yaml | 4 +- examples/decision_transformer/utils.py | 9 +- examples/discrete_sac/config.yaml | 6 +- examples/discrete_sac/discrete_sac.py | 9 +- examples/dqn/config_atari.yaml | 4 +- examples/dqn/config_cartpole.yaml | 4 +- examples/dqn/dqn_atari.py | 9 +- examples/dqn/dqn_cartpole.py | 9 +- examples/dreamer/README.md | 7 ++ examples/dreamer/config.yaml | 1 + examples/dreamer/dreamer.py | 2 +- examples/impala/config_multi_node_ray.yaml | 2 + .../impala/config_multi_node_submitit.yaml | 2 + examples/impala/config_single_node.yaml | 2 + examples/impala/impala_multi_node_ray.py | 6 +- examples/impala/impala_multi_node_submitit.py | 6 +- examples/impala/impala_single_node.py | 6 +- examples/iql/discrete_iql.py | 13 ++- examples/iql/discrete_iql.yaml | 6 +- examples/iql/iql_offline.py | 12 ++- examples/iql/iql_online.py | 13 ++- examples/iql/offline_config.yaml | 5 +- examples/iql/online_config.yaml | 7 +- examples/multiagent/iql.yaml | 2 + examples/multiagent/maddpg_iddpg.yaml | 2 + examples/multiagent/mappo_ippo.yaml | 2 + examples/multiagent/qmix_vdn.yaml | 2 + examples/multiagent/sac.yaml | 2 + examples/multiagent/utils/logging.py | 5 +- examples/ppo/config_atari.yaml | 2 + examples/ppo/config_mujoco.yaml | 4 +- examples/ppo/ppo_atari.py | 23 +++-- examples/ppo/ppo_mujoco.py | 9 +- examples/ppo/utils_atari.py | 9 +- examples/redq/README.md | 7 ++ examples/redq/config.yaml | 9 +- examples/redq/redq.py | 8 +- examples/rlhf/README.md | 6 ++ examples/rlhf/config/train_rlhf.yaml | 2 + examples/rlhf/train_rlhf.py | 9 +- examples/sac/config.yaml | 8 +- examples/sac/sac.py | 9 +- examples/td3/config.yaml | 8 +- examples/td3/td3.py | 9 +- sota-check/README.md | 35 +++++++ sota-check/run_a2c_atari.sh | 27 ++++++ sota-check/run_a2c_mujoco.sh | 26 ++++++ sota-check/run_cql_offline.sh | 27 ++++++ sota-check/run_cql_online.sh | 26 ++++++ sota-check/run_ddpg.sh | 26 ++++++ sota-check/run_discrete_sac.sh | 26 ++++++ sota-check/run_dqn_atari.sh | 26 ++++++ sota-check/run_dqn_cartpole.sh | 26 ++++++ sota-check/run_dt.sh | 26 ++++++ sota-check/run_dt_online.sh | 26 ++++++ sota-check/run_impala_single_node.sh | 26 ++++++ sota-check/run_iql_discrete.sh | 26 ++++++ sota-check/run_iql_offline.sh | 26 ++++++ sota-check/run_iql_online.sh | 26 ++++++ sota-check/run_multiagent_iddpg.sh | 26 ++++++ sota-check/run_multiagent_ippo.sh | 26 ++++++ sota-check/run_multiagent_iql.sh | 26 ++++++ sota-check/run_multiagent_qmix.sh | 26 ++++++ sota-check/run_multiagent_sac.sh | 27 ++++++ sota-check/run_ppo_atari.sh | 26 ++++++ sota-check/run_ppo_mujoco.sh | 26 ++++++ sota-check/run_sac.sh | 26 ++++++ sota-check/run_td3.sh | 26 ++++++ sota-check/submitit-release-check.sh | 91 +++++++++++++++++++ torchrl/collectors/collectors.py | 2 +- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/rlhf/dataset.py | 2 +- torchrl/envs/batched_envs.py | 16 ++-- torchrl/envs/common.py | 10 +- torchrl/envs/gym_like.py | 4 +- torchrl/envs/transforms/gym_transforms.py | 2 +- torchrl/envs/transforms/transforms.py | 10 +- torchrl/envs/utils.py | 2 +- torchrl/trainers/trainers.py | 2 +- 94 files changed, 1043 insertions(+), 110 deletions(-) create mode 100644 examples/bandits/README.md create mode 100644 examples/dreamer/README.md create mode 100644 examples/redq/README.md create mode 100644 sota-check/README.md create mode 100644 sota-check/run_a2c_atari.sh create mode 100644 sota-check/run_a2c_mujoco.sh create mode 100644 sota-check/run_cql_offline.sh create mode 100644 sota-check/run_cql_online.sh create mode 100644 sota-check/run_ddpg.sh create mode 100644 sota-check/run_discrete_sac.sh create mode 100644 sota-check/run_dqn_atari.sh create mode 100644 sota-check/run_dqn_cartpole.sh create mode 100644 sota-check/run_dt.sh create mode 100644 sota-check/run_dt_online.sh create mode 100644 sota-check/run_impala_single_node.sh create mode 100644 sota-check/run_iql_discrete.sh create mode 100644 sota-check/run_iql_offline.sh create mode 100644 sota-check/run_iql_online.sh create mode 100644 sota-check/run_multiagent_iddpg.sh create mode 100644 sota-check/run_multiagent_ippo.sh create mode 100644 sota-check/run_multiagent_iql.sh create mode 100644 sota-check/run_multiagent_qmix.sh create mode 100644 sota-check/run_multiagent_sac.sh create mode 100644 sota-check/run_ppo_atari.sh create mode 100644 sota-check/run_ppo_mujoco.sh create mode 100644 sota-check/run_sac.sh create mode 100644 sota-check/run_td3.sh create mode 100755 sota-check/submitit-release-check.sh diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 0f3685ee59e..0cbcb70ad15 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -45,6 +45,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans optim.updates_per_episode=3 \ optim.warmup_steps=10 \ optim.device=cuda:0 \ + env.backend=gymnasium \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \ optim.gradient_steps=55 \ diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py index 0452d7d600f..f329bf7b120 100644 --- a/examples/a2c/a2c_atari.py +++ b/examples/a2c/a2c_atari.py @@ -93,7 +93,14 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend: exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}") logger = get_logger( - cfg.logger.backend, logger_name="a2c", experiment_name=exp_name + cfg.logger.backend, + logger_name="a2c", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py index 2628a6f388c..2f38af032a8 100644 --- a/examples/a2c/a2c_mujoco.py +++ b/examples/a2c/a2c_mujoco.py @@ -79,7 +79,14 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend: exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}") logger = get_logger( - cfg.logger.backend, logger_name="a2c", experiment_name=exp_name + cfg.logger.backend, + logger_name="a2c", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment diff --git a/examples/a2c/config_atari.yaml b/examples/a2c/config_atari.yaml index 0b06584ee67..8c94f62fb93 100644 --- a/examples/a2c/config_atari.yaml +++ b/examples/a2c/config_atari.yaml @@ -11,6 +11,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_a2c + group_name: null exp_name: Atari_Schulman17 test_interval: 40_000_000 num_test_episodes: 3 diff --git a/examples/a2c/config_mujoco.yaml b/examples/a2c/config_mujoco.yaml index 48627059de9..b30b7304f61 100644 --- a/examples/a2c/config_mujoco.yaml +++ b/examples/a2c/config_mujoco.yaml @@ -1,6 +1,6 @@ # task and env env: - env_name: HalfCheetah-v3 + env_name: HalfCheetah-v4 # collector collector: @@ -10,6 +10,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_a2c + group_name: null exp_name: Mujoco_Schulman17 test_interval: 1_000_000 num_test_episodes: 5 diff --git a/examples/bandits/README.md b/examples/bandits/README.md new file mode 100644 index 00000000000..3c4a0f680f4 --- /dev/null +++ b/examples/bandits/README.md @@ -0,0 +1,7 @@ +# Bandits example + +## Note: +This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the +benchmarking of future releases, to ensure that it can be successfully run with the release code and that the +results are consistent. For now, be aware that this additional check has not been performed in the case of this +specific example. diff --git a/examples/cql/cql_offline.py b/examples/cql/cql_offline.py index c33bce7d65b..8f1dc5e3897 100644 --- a/examples/cql/cql_offline.py +++ b/examples/cql/cql_offline.py @@ -32,14 +32,19 @@ @hydra.main(config_path=".", config_name="offline_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 # Create logger - exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name) + exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="cql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds torch.manual_seed(cfg.env.seed) diff --git a/examples/cql/cql_online.py b/examples/cql/cql_online.py index 4ee218da770..c42e733c31b 100644 --- a/examples/cql/cql_online.py +++ b/examples/cql/cql_online.py @@ -36,14 +36,19 @@ @hydra.main(version_base="1.1", config_path=".", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 # Create logger - exp_name = generate_exp_name("CQL-online", cfg.env.exp_name) + exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="cql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/examples/cql/discrete_cql_config.yaml b/examples/cql/discrete_cql_config.yaml index 2b449629d16..b7f8d527ba3 100644 --- a/examples/cql/discrete_cql_config.yaml +++ b/examples/cql/discrete_cql_config.yaml @@ -3,7 +3,6 @@ env: name: CartPole-v1 task: "" backend: gym - exp_name: cql_cartpole_gym n_samples_stats: 1000 max_episode_steps: 200 seed: 0 @@ -24,6 +23,9 @@ collector: # Logger logger: backend: wandb + project_name: torchrl_example_cql + group_name: null + exp_name: cql_cartpole_gym log_interval: 5000 # record interval in frames eval_steps: 200 mode: online diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py index cc4f89d667e..facbcc49bf9 100644 --- a/examples/cql/discrete_cql_online.py +++ b/examples/cql/discrete_cql_online.py @@ -38,14 +38,18 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.optim.device) # Create logger - exp_name = generate_exp_name("DiscreteCQL", cfg.env.exp_name) + exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="discretecql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + }, ) # Set seeds diff --git a/examples/cql/offline_config.yaml b/examples/cql/offline_config.yaml index 517da255481..d41db847077 100644 --- a/examples/cql/offline_config.yaml +++ b/examples/cql/offline_config.yaml @@ -1,9 +1,8 @@ # env and task env: - name: Hopper-v2 + name: Hopper-v4 task: "" library: gym - exp_name: cql_${replay_buffer.dataset} n_samples_stats: 1000 seed: 0 backend: gym # D4RL uses gym so we make sure gymnasium is hidden @@ -11,6 +10,9 @@ env: # logger logger: backend: wandb + project_name: torchrl_example_cql + group_name: null + exp_name: cql_${replay_buffer.dataset} eval_iter: 5000 eval_steps: 1000 mode: online diff --git a/examples/cql/online_config.yaml b/examples/cql/online_config.yaml index 6c29820856b..367d4755cac 100644 --- a/examples/cql/online_config.yaml +++ b/examples/cql/online_config.yaml @@ -2,7 +2,6 @@ env: name: Pendulum-v1 task: "" - exp_name: cql_${env.name} n_samples_stats: 1000 seed: 0 train_num_envs: 1 @@ -23,6 +22,9 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_cql + group_name: null + exp_name: cql_${env.name} log_interval: 5000 # record interval in frames eval_steps: 1000 mode: online diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 2b3713c0407..fb4a3fa4725 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -1,8 +1,7 @@ # environment and task env: - name: HalfCheetah-v3 + name: HalfCheetah-v4 task: "" - exp_name: ${env.name}_DDPG library: gymnasium max_episode_steps: 1000 seed: 42 @@ -22,7 +21,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ${env.exp_name}_${env.seed} + scratch_dir: ${logger.exp_name}_${env.seed} # optimization optim: @@ -44,5 +43,8 @@ network: # logging logger: backend: wandb + project_name: torchrl_example_ddpg + group_name: null + exp_name: ${env.name}_DDPG mode: online eval_iter: 25000 diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 1eb7af83e02..ea5a1386e4f 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) # Create logger - exp_name = generate_exp_name("DDPG", cfg.env.exp_name) + exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="ddpg_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index d42b52f365e..80915c4f93a 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -1,6 +1,6 @@ # environment and task env: - name: HalfCheetah-v3 + name: HalfCheetah-v4 task: "" library: gym stacked_frames: 20 @@ -20,7 +20,9 @@ env: # logger logger: backend: wandb + project_name: torchrl_example_dt model_name: DT + group_name: null exp_name: DT-HalfCheetah-medium-v2 pretrain_log_interval: 500 # record interval in frames fintune_log_interval: 1 diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index 62376414949..b6137ac62a1 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -1,6 +1,6 @@ # environment and task env: - name: HalfCheetah-v3 + name: HalfCheetah-v4 task: "" library: gym stacked_frames: 20 @@ -20,6 +20,8 @@ env: # logger logger: backend: wandb + project_name: torchrl_example_odt + group_name: null exp_name: oDT-HalfCheetah-medium-v2 model_name: oDT pretrain_log_interval: 500 # record interval in frames diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 940e26a5c0a..8bd9f3bebbf 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -493,17 +493,18 @@ def make_dt_optimizer(optim_cfg, loss_module): def make_logger(cfg): - from omegaconf import OmegaConf - if not cfg.logger.backend: return None exp_name = generate_exp_name(cfg.logger.model_name, cfg.logger.exp_name) - cfg.logger.exp_name = exp_name logger = get_logger( cfg.logger.backend, logger_name=cfg.logger.model_name, experiment_name=exp_name, - wandb_kwargs={"config": OmegaConf.to_container(cfg)}, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) return logger diff --git a/examples/discrete_sac/config.yaml b/examples/discrete_sac/config.yaml index 98f908e84d8..03ae3999f87 100644 --- a/examples/discrete_sac/config.yaml +++ b/examples/discrete_sac/config.yaml @@ -3,7 +3,6 @@ env: name: CartPole-v1 task: "" - exp_name: ${env.name}_DiscreteSAC library: gym seed: 42 max_episode_steps: 500 @@ -23,7 +22,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: ${env.exp_name}_${env.seed} + scratch_dir: ${logger.exp_name}_${env.seed} # optim optim: @@ -48,5 +47,8 @@ network: # logging logger: backend: wandb + project_name: torchrl_example_discrete_sac + group_name: null + exp_name: ${env.name}_DiscreteSAC mode: online eval_iter: 5000 diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 1f052837b2d..2976cf8806d 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) # Create logger - exp_name = generate_exp_name("DiscreteSAC", cfg.env.exp_name) + exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="DiscreteSAC_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/examples/dqn/config_atari.yaml b/examples/dqn/config_atari.yaml index dcdba004c48..691fb4ff626 100644 --- a/examples/dqn/config_atari.yaml +++ b/examples/dqn/config_atari.yaml @@ -21,7 +21,9 @@ buffer: # logger logger: - backend: null + backend: wandb + project_name: torchrl_example_dqn + group_name: null exp_name: DQN test_interval: 1_000_000 num_test_episodes: 3 diff --git a/examples/dqn/config_cartpole.yaml b/examples/dqn/config_cartpole.yaml index c29a0c9cb35..1ebeba42f8c 100644 --- a/examples/dqn/config_cartpole.yaml +++ b/examples/dqn/config_cartpole.yaml @@ -20,7 +20,9 @@ buffer: # logger logger: - backend: null + backend: wandb + project_name: torchrl_example_dqn + group_name: null exp_name: DQN test_interval: 50_000 num_test_episodes: 5 diff --git a/examples/dqn/dqn_atari.py b/examples/dqn/dqn_atari.py index ecfbfa9deab..34be877c320 100644 --- a/examples/dqn/dqn_atari.py +++ b/examples/dqn/dqn_atari.py @@ -99,7 +99,14 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend: exp_name = generate_exp_name("DQN", f"Atari_mnih15_{cfg.env.env_name}") logger = get_logger( - cfg.logger.backend, logger_name="dqn", experiment_name=exp_name + cfg.logger.backend, + logger_name="dqn", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create the test environment diff --git a/examples/dqn/dqn_cartpole.py b/examples/dqn/dqn_cartpole.py index 792b1f65477..5f6bc742cb7 100644 --- a/examples/dqn/dqn_cartpole.py +++ b/examples/dqn/dqn_cartpole.py @@ -82,7 +82,14 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend: exp_name = generate_exp_name("DQN", f"CartPole_{cfg.env.env_name}") logger = get_logger( - cfg.logger.backend, logger_name="dqn", experiment_name=exp_name + cfg.logger.backend, + logger_name="dqn", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create the test environment diff --git a/examples/dreamer/README.md b/examples/dreamer/README.md new file mode 100644 index 00000000000..94e28dc63d9 --- /dev/null +++ b/examples/dreamer/README.md @@ -0,0 +1,7 @@ +# Dreamer example + +## Note: +This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the +benchmarking of future releases, to ensure that it can be successfully run with the release code and that the +results are consistent. For now, be aware that this additional check has not been performed in the case of this +specific example. diff --git a/examples/dreamer/config.yaml b/examples/dreamer/config.yaml index 0ea20873557..e81d74e08fa 100644 --- a/examples/dreamer/config.yaml +++ b/examples/dreamer/config.yaml @@ -32,6 +32,7 @@ init_env_steps: 1000 init_random_frames: 5000 logger: csv offline_logging: False +project_name: torchrl_example_dreamer normalize_rewards_online: True normalize_rewards_online_scale: 5.0 normalize_rewards_online_decay: 0.99999 diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 8c1e9da2e46..1cf6c91856f 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -92,7 +92,7 @@ def main(cfg: "DictConfig"): # noqa: F821 logger_name="dreamer", experiment_name=exp_name, wandb_kwargs={ - "project": "torchrl", + "project": cfg.project_name, "group": f"Dreamer_{cfg.env_name}", "offline": cfg.offline_logging, }, diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index e312b336651..c67b5ed52da 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -41,6 +41,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_impala_ray + group_name: null exp_name: Atari_IMPALA test_interval: 200_000_000 num_test_episodes: 3 diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index f632ba15dc2..59973e46b40 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -22,6 +22,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_impala_submitit + group_name: null exp_name: Atari_IMPALA test_interval: 200_000_000 num_test_episodes: 3 diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml index d39407c1a69..b93c3802a33 100644 --- a/examples/impala/config_single_node.yaml +++ b/examples/impala/config_single_node.yaml @@ -14,6 +14,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_impala + group_name: null exp_name: Atari_IMPALA test_interval: 200_000_000 num_test_episodes: 3 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 49b3dd4bd4d..0a2ce0d02e2 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -141,7 +141,11 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg.logger.backend, logger_name="impala", experiment_name=exp_name, - project="impala", + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 2b89ef046a1..d702a17a4e6 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -133,7 +133,11 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg.logger.backend, logger_name="impala", experiment_name=exp_name, - project="impala", + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index f5b64e4718a..836e85de8c3 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -111,7 +111,11 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg.logger.backend, logger_name="impala", experiment_name=exp_name, - project="impala", + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment diff --git a/examples/iql/discrete_iql.py b/examples/iql/discrete_iql.py index 39009923d02..8a8307366fc 100644 --- a/examples/iql/discrete_iql.py +++ b/examples/iql/discrete_iql.py @@ -18,6 +18,8 @@ import numpy as np import torch import tqdm + +from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -34,15 +36,22 @@ @hydra.main(config_path=".", config_name="discrete_iql") def main(cfg: "DictConfig"): # noqa: F821 + set_gym_backend(cfg.env.backend).set() + # Create logger - exp_name = generate_exp_name("Discrete-IQL-online", cfg.env.exp_name) + exp_name = generate_exp_name("Discrete-IQL-online", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="iql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/examples/iql/discrete_iql.yaml b/examples/iql/discrete_iql.yaml index 52b6f8e13ca..c21a320e375 100644 --- a/examples/iql/discrete_iql.yaml +++ b/examples/iql/discrete_iql.yaml @@ -2,12 +2,11 @@ env: name: CartPole-v1 task: "" - exp_name: iql_${env.name} n_samples_stats: 1000 seed: 0 train_num_envs: 1 eval_num_envs: 1 - backend: gym + backend: gymnasium # collector @@ -22,6 +21,9 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_discrete_iql + exp_name: iql_${env.name} + group_name: null log_interval: 5000 # record interval in frames eval_steps: 200 mode: online diff --git a/examples/iql/iql_offline.py b/examples/iql/iql_offline.py index 927ac924e90..b6895592a20 100644 --- a/examples/iql/iql_offline.py +++ b/examples/iql/iql_offline.py @@ -31,18 +31,24 @@ ) -@set_gym_backend("gym") @hydra.main(config_path=".", config_name="offline_config") def main(cfg: "DictConfig"): # noqa: F821 + set_gym_backend(cfg.env.backend).set() + # Create logger - exp_name = generate_exp_name("IQL-offline", cfg.env.exp_name) + exp_name = generate_exp_name("IQL-offline", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="iql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index 663aa2d82d3..461eb6bb37d 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -18,6 +18,8 @@ import numpy as np import torch import tqdm + +from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -34,15 +36,22 @@ @hydra.main(config_path=".", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 + set_gym_backend(cfg.env.backend).set() + # Create logger - exp_name = generate_exp_name("IQL-online", cfg.env.exp_name) + exp_name = generate_exp_name("IQL-online", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="iql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/examples/iql/offline_config.yaml b/examples/iql/offline_config.yaml index 8b8cbe8c776..341e995967a 100644 --- a/examples/iql/offline_config.yaml +++ b/examples/iql/offline_config.yaml @@ -2,14 +2,17 @@ env: name: HalfCheetah-v2 task: "" - backend: gym exp_name: iql_${replay_buffer.dataset} n_samples_stats: 1000 seed: 0 + backend: gymnasium # logger logger: backend: wandb + project_name: torchrl_example_iql + exp_name: iql_${replay_buffer.dataset} + group_name: null eval_iter: 500 eval_steps: 1000 mode: online diff --git a/examples/iql/online_config.yaml b/examples/iql/online_config.yaml index e3ef0d081c4..511d77ec365 100644 --- a/examples/iql/online_config.yaml +++ b/examples/iql/online_config.yaml @@ -2,13 +2,11 @@ env: name: Pendulum-v1 task: "" - exp_name: iql_${env.name} n_samples_stats: 1000 seed: 0 train_num_envs: 1 eval_num_envs: 1 - backend: gym - + backend: gymnasium # collector collector: @@ -23,6 +21,9 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_iql + exp_name: iql_${env.name} + group_name: null log_interval: 5000 # record interval in frames eval_steps: 200 mode: online diff --git a/examples/multiagent/iql.yaml b/examples/multiagent/iql.yaml index 801503b7e9d..dc748e3601c 100644 --- a/examples/multiagent/iql.yaml +++ b/examples/multiagent/iql.yaml @@ -36,3 +36,5 @@ eval: logger: backend: wandb # Delete to remove logging + project_name: null + group_name: null diff --git a/examples/multiagent/maddpg_iddpg.yaml b/examples/multiagent/maddpg_iddpg.yaml index 19328cbc39e..8aa97db09da 100644 --- a/examples/multiagent/maddpg_iddpg.yaml +++ b/examples/multiagent/maddpg_iddpg.yaml @@ -37,3 +37,5 @@ eval: logger: backend: wandb # Delete to remove logging + project_name: null + group_name: null diff --git a/examples/multiagent/mappo_ippo.yaml b/examples/multiagent/mappo_ippo.yaml index befec1cf1ca..ed47456b63f 100644 --- a/examples/multiagent/mappo_ippo.yaml +++ b/examples/multiagent/mappo_ippo.yaml @@ -39,3 +39,5 @@ eval: logger: backend: wandb # Delete to remove logging + project_name: null + group_name: null diff --git a/examples/multiagent/qmix_vdn.yaml b/examples/multiagent/qmix_vdn.yaml index a78b3987ffb..bac6db99d63 100644 --- a/examples/multiagent/qmix_vdn.yaml +++ b/examples/multiagent/qmix_vdn.yaml @@ -37,3 +37,5 @@ eval: logger: backend: wandb # Delete to remove logging + project_name: null + group_name: null diff --git a/examples/multiagent/sac.yaml b/examples/multiagent/sac.yaml index ab478ab0dc8..33464debc7d 100644 --- a/examples/multiagent/sac.yaml +++ b/examples/multiagent/sac.yaml @@ -39,3 +39,5 @@ eval: logger: backend: wandb # Delete to remove logging + project_name: null + group_name: null diff --git a/examples/multiagent/utils/logging.py b/examples/multiagent/utils/logging.py index 352d0addc51..cb6df4de7ea 100644 --- a/examples/multiagent/utils/logging.py +++ b/examples/multiagent/utils/logging.py @@ -18,8 +18,9 @@ def init_logging(cfg, model_name: str): logger_name=os.getcwd(), experiment_name=generate_exp_name(cfg.env.scenario_name, model_name), wandb_kwargs={ - "group": model_name, - "project": f"torchrl_{cfg.env.scenario_name}", + "group": cfg.logger.group_name or model_name, + "project": cfg.logger.project_name + or f"torchrl_example_{cfg.env.scenario_name}", }, ) logger.log_hparams(cfg) diff --git a/examples/ppo/config_atari.yaml b/examples/ppo/config_atari.yaml index 6957fd9bddd..d6ec35ab5f2 100644 --- a/examples/ppo/config_atari.yaml +++ b/examples/ppo/config_atari.yaml @@ -11,6 +11,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_ppo + group_name: null exp_name: Atari_Schulman17 test_interval: 40_000_000 num_test_episodes: 3 diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml index 0322526e7b1..3320837ae3d 100644 --- a/examples/ppo/config_mujoco.yaml +++ b/examples/ppo/config_mujoco.yaml @@ -1,6 +1,6 @@ # task and env env: - env_name: HalfCheetah-v3 + env_name: HalfCheetah-v4 # collector collector: @@ -10,6 +10,8 @@ collector: # logger logger: backend: wandb + project_name: torchrl_example_ppo + group_name: null exp_name: Mujoco_Schulman17 test_interval: 1_000_000 num_test_episodes: 5 diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index 1e69dd7678d..238e612e614 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -45,12 +45,12 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"), policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, - device=device, - storing_device=device, + device="cpu", + storing_device="cpu", max_frames_per_traj=-1, ) @@ -96,7 +96,14 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend: exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") logger = get_logger( - cfg.logger.backend, logger_name="ppo", experiment_name=exp_name + cfg.logger.backend, + logger_name="ppo", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment @@ -151,9 +158,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # Compute GAE with torch.no_grad(): - data = adv_module(data) + data = adv_module(data.to(device, non_blocking=True)) data_reshape = data.reshape(-1) - # Update the data buffer data_buffer.extend(data_reshape) @@ -168,9 +174,8 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg_loss_anneal_clip_eps: loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) num_network_updates += 1 - # Get a data batch - batch = batch.to(device) + batch = batch.to(device, non_blocking=True) # Forward pass PPO loss loss = loss_module(batch) @@ -180,7 +185,6 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_sum = ( loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] ) - # Backward pass loss_sum.backward() torch.nn.utils.clip_grad_norm_( @@ -231,6 +235,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.update_policy_weights_() sampling_start = time.time() + collector.shutdown() end_time = time.time() execution_time = end_time - start_time logging.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py index 90fe74650f5..83ee779c6ab 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/examples/ppo/ppo_mujoco.py @@ -88,7 +88,14 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend: exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") logger = get_logger( - cfg.logger.backend, logger_name="ppo", experiment_name=exp_name + cfg.logger.backend, + logger_name="ppo", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Create test environment diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py index c78bc67f45a..eaef640ebb0 100644 --- a/examples/ppo/utils_atari.py +++ b/examples/ppo/utils_atari.py @@ -41,15 +41,13 @@ # -------------------------------------------------------------------- -def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False -): +def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False): env = GymEnv( env_name, frame_skip=frame_skip, from_pixels=True, pixels_only=False, - device=device, + device="cpu", ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) @@ -61,8 +59,9 @@ def make_base_env( def make_parallel_env(env_name, num_envs, device, is_test=False): env = ParallelEnv( num_envs, - EnvCreator(lambda: make_base_env(env_name, device=device)), + EnvCreator(lambda: make_base_env(env_name)), serial_for_single=True, + device=device, ) env = TransformedEnv(env) env.append_transform(ToTensorImage()) diff --git a/examples/redq/README.md b/examples/redq/README.md new file mode 100644 index 00000000000..151cd4c3be8 --- /dev/null +++ b/examples/redq/README.md @@ -0,0 +1,7 @@ +# REDQ example + +## Note: +This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the +benchmarking of future releases, to ensure that it can be successfully run with the release code and that the +results are consistent. For now, be aware that this additional check has not been performed in the case of this +specific example. diff --git a/examples/redq/config.yaml b/examples/redq/config.yaml index 24e9ae2a60e..fc77974cb38 100644 --- a/examples/redq/config.yaml +++ b/examples/redq/config.yaml @@ -39,13 +39,14 @@ collector: exploration_mode: random logger: + backend: wandb + project_name: torchrl_example_redq + group_name: null + exp_name: cheetah record_video: 0 record_interval: 10 record_frames: 10000 - exp_name: cheetah - backend: wandb - kwargs: - offline: False + mode: online recorder_log_keys: optim: diff --git a/examples/redq/redq.py b/examples/redq/redq.py index 913216f44a8..f89098e1441 100644 --- a/examples/redq/redq.py +++ b/examples/redq/redq.py @@ -8,7 +8,6 @@ import hydra import torch.cuda -from omegaconf import OmegaConf from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.transforms import RewardScaling, TransformedEnv from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -68,7 +67,12 @@ def main(cfg: "DictConfig"): # noqa: F821 logger_type=cfg.logger.backend, logger_name="redq_logging", experiment_name=exp_name, - **OmegaConf.to_container(cfg.logger.kwargs), + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) video_tag = exp_name if cfg.logger.record_video else "" diff --git a/examples/rlhf/README.md b/examples/rlhf/README.md index c4b0a261101..0a20f8dc906 100644 --- a/examples/rlhf/README.md +++ b/examples/rlhf/README.md @@ -3,6 +3,12 @@ This example uses RLHF (Reinforcement Learning with Human Feedback) to train a language model to summarize Reddit posts. +## Note: +This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the +benchmarking of future releases, to ensure that it can be successfully run with the release code and that the +results are consistent. For now, be aware that this additional check has not been performed in the case of this +specific example. + ## Getting started Make sure you have PyTorch>=2.0 installed. You can find installation instructions diff --git a/examples/rlhf/config/train_rlhf.yaml b/examples/rlhf/config/train_rlhf.yaml index 024c239463e..aa8e41a1319 100644 --- a/examples/rlhf/config/train_rlhf.yaml +++ b/examples/rlhf/config/train_rlhf.yaml @@ -3,6 +3,8 @@ io: log_interval: 1 eval_iters: 10 logger: wandb + project_name: torchrl_example_rlhf + group_name: null data: batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size block_size: 550 diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 7dce72e7dd4..6f3e80649d7 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -56,7 +56,14 @@ def main(cfg): ctx = setup(cfg.sys) logger = get_logger( - logger_type=cfg.io.logger, logger_name="./log", experiment_name="torchrlhf-gpt2" + logger_type=cfg.io.logger, + logger_name="./log", + experiment_name="torchrlhf-gpt2", + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.io.project_name, + "group": cfg.logger.group_name, + }, ) # =============== Dataloaders =============== # diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index dfd0ae30c14..b6675ecc9a0 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,8 +1,7 @@ # environment and task env: - name: HalfCheetah-v3 + name: HalfCheetah-v4 task: "" - exp_name: ${env.name}_SAC library: gymnasium max_episode_steps: 1000 seed: 42 @@ -21,7 +20,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ${env.exp_name}_${env.seed} + scratch_dir: ${logger.exp_name}_${env.seed} # optim optim: @@ -46,5 +45,8 @@ network: # logging logger: backend: wandb + project_name: torchrl_example_sac + group_name: null + exp_name: ${env.name}_SAC mode: online eval_iter: 25000 diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 9a08cd8ef9b..a93e3a833dd 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -39,14 +39,19 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) # Create logger - exp_name = generate_exp_name("SAC", cfg.env.exp_name) + exp_name = generate_exp_name("SAC", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="sac_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) torch.manual_seed(cfg.env.seed) diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 210d865c11d..561766cd5a4 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -1,8 +1,7 @@ # task and env env: - name: HalfCheetah-v3 + name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency task: "" - exp_name: ${env.name}_TD3 library: gymnasium seed: 42 max_episode_steps: 1000 @@ -22,7 +21,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: ${env.exp_name}_${env.seed} + scratch_dir: ${logger.exp_name}_${env.seed} # optim optim: @@ -47,5 +46,8 @@ network: # logging logger: backend: wandb + project_name: torchrl_example_td3 + group_name: null + exp_name: ${env.name}_TD3 mode: online eval_iter: 25000 diff --git a/examples/td3/td3.py b/examples/td3/td3.py index ab21db76b15..1f42e7273d1 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) # Create logger - exp_name = generate_exp_name("TD3", cfg.env.exp_name) + exp_name = generate_exp_name("TD3", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( logger_type=cfg.logger.backend, logger_name="td3_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, ) # Set seeds diff --git a/sota-check/README.md b/sota-check/README.md new file mode 100644 index 00000000000..3d20540ff73 --- /dev/null +++ b/sota-check/README.md @@ -0,0 +1,35 @@ +# SOTA Performance checks + +This folder contains a `submitit-release-check.sh` file that executes all +the training scripts using `sbatch` with the default configuration and long them +into a common WandB project. + +This script is to be executed before every release to assess the performance of +the various algorithms available in torchrl. The name of the project will include +the specific commit of torchrl used to run the scripts (e.g. `torchrl-examples-check-`). + +## Usage + +To display the script usage, you can use the `--help` option: + +```bash +./submitit-release-check.sh --help +``` + +## Setup + +The following setup should allow you to run the scripts: + +```bash +export MUJOCO_GL=egl + +conda create -n rl-sota-bench python=3.10 -y +conda install anaconda::libglu -y +pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 +pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame moviepy imageio submitit hydra-core transformers + +cd /path/to/tensordict +python setup.py develop +cd /path/to/torchrl +python setup.py develop +``` diff --git a/sota-check/run_a2c_atari.sh b/sota-check/run_a2c_atari.sh new file mode 100644 index 00000000000..610cf5389f8 --- /dev/null +++ b/sota-check/run_a2c_atari.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +#SBATCH --job-name=a2c_atari +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/a2c_atari_%j.txt +#SBATCH --error=slurm_errors/a2c_atari_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="a2c_atari" + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/a2c/a2c_atari.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log +fi diff --git a/sota-check/run_a2c_mujoco.sh b/sota-check/run_a2c_mujoco.sh new file mode 100644 index 00000000000..f26bc96fe01 --- /dev/null +++ b/sota-check/run_a2c_mujoco.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=a2c_mujoco +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/a2c_mujoco_%j.txt +#SBATCH --error=slurm_errors/a2c_mujoco_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="a2c_mujoco" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/a2c/a2c_mujoco.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log +fi diff --git a/sota-check/run_cql_offline.sh b/sota-check/run_cql_offline.sh new file mode 100644 index 00000000000..fa3a42c7429 --- /dev/null +++ b/sota-check/run_cql_offline.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +#SBATCH --job-name=cql_offline +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/cql_offline_%j.txt +#SBATCH --error=slurm_errors/cql_offline_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="cql_offline" + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/cql/cql_offline.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_cql_online.sh b/sota-check/run_cql_online.sh new file mode 100644 index 00000000000..78548d9f418 --- /dev/null +++ b/sota-check/run_cql_online.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=cql_online +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/cql_online_%j.txt +#SBATCH --error=slurm_errors/cql_online_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="cql_online" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/cql/cql_online.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_ddpg.sh b/sota-check/run_ddpg.sh new file mode 100644 index 00000000000..7131db8a6e7 --- /dev/null +++ b/sota-check/run_ddpg.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=ddpg +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/ddpg_%j.txt +#SBATCH --error=slurm_errors/ddpg_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="ddpg" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/ddpg/ddpg.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_discrete_sac.sh b/sota-check/run_discrete_sac.sh new file mode 100644 index 00000000000..dfb6a68ce02 --- /dev/null +++ b/sota-check/run_discrete_sac.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=discrete_sac +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/discrete_sac_%j.txt +#SBATCH --error=slurm_errors/discrete_sac_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="discrete_sac" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/discrete_sac/discrete_sac.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_dqn_atari.sh b/sota-check/run_dqn_atari.sh new file mode 100644 index 00000000000..35aa2adb3be --- /dev/null +++ b/sota-check/run_dqn_atari.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=dqn_atari +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/dqn_atari_%j.txt +#SBATCH --error=slurm_errors/dqn_atari_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="dqn_atari" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/dqn/dqn_atari.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_dqn_cartpole.sh b/sota-check/run_dqn_cartpole.sh new file mode 100644 index 00000000000..cfe954a4f09 --- /dev/null +++ b/sota-check/run_dqn_cartpole.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=dqn_cartpole +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/dqn_cartpole_%j.txt +#SBATCH --error=slurm_errors/dqn_cartpole_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="dqn_cartpole" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/dqn/dqn_cartpole.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_dt.sh b/sota-check/run_dt.sh new file mode 100644 index 00000000000..41ec685664d --- /dev/null +++ b/sota-check/run_dt.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=dt +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/dt_offline_%j.txt +#SBATCH --error=slurm_errors/dt_offline_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="dt_offline" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/decision_transformer/dt.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_dt_online.sh b/sota-check/run_dt_online.sh new file mode 100644 index 00000000000..2f116aa3bcf --- /dev/null +++ b/sota-check/run_dt_online.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=dt_online +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/dt_online_%j.txt +#SBATCH --error=slurm_errors/dt_online_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="dt_online" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/decision_transformer/online_dt.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_impala_single_node.sh b/sota-check/run_impala_single_node.sh new file mode 100644 index 00000000000..3dc3cd56ac2 --- /dev/null +++ b/sota-check/run_impala_single_node.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=impala_1node +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/impala_1node_%j.txt +#SBATCH --error=slurm_errors/impala_1node_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="impala_1node" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/impala/impala_single_node.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_iql_discrete.sh b/sota-check/run_iql_discrete.sh new file mode 100644 index 00000000000..b659ed6dc31 --- /dev/null +++ b/sota-check/run_iql_discrete.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=iql_discrete +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/iql_discrete_%j.txt +#SBATCH --error=slurm_errors/iql_discrete_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="iql_discrete" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/iql/discrete_iql.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_iql_offline.sh b/sota-check/run_iql_offline.sh new file mode 100644 index 00000000000..bd4ef8f6e69 --- /dev/null +++ b/sota-check/run_iql_offline.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=iql_offline +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/iql_offline_%j.txt +#SBATCH --error=slurm_errors/iql_offline_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="iql_offline" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/iql/iql_offline.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log +fi diff --git a/sota-check/run_iql_online.sh b/sota-check/run_iql_online.sh new file mode 100644 index 00000000000..702d2b8cbff --- /dev/null +++ b/sota-check/run_iql_online.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=iql_online +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/iql_online_%j.txt +#SBATCH --error=slurm_errors/iql_online_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="iql_online" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/iql/iql_online.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_multiagent_iddpg.sh b/sota-check/run_multiagent_iddpg.sh new file mode 100644 index 00000000000..4629fbff228 --- /dev/null +++ b/sota-check/run_multiagent_iddpg.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=marl_iddpg +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/marl_iddpg_%j.txt +#SBATCH --error=slurm_errors/marl_iddpg_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="marl_iddpg" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/multiagent/maddpg_iddpg.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_multiagent_ippo.sh b/sota-check/run_multiagent_ippo.sh new file mode 100644 index 00000000000..036f739e2e2 --- /dev/null +++ b/sota-check/run_multiagent_ippo.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=marl_ippo +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/marl_ippo_%j.txt +#SBATCH --error=slurm_errors/marl_ippo_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="mappo_ippo" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/multiagent/mappo_ippo.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_multiagent_iql.sh b/sota-check/run_multiagent_iql.sh new file mode 100644 index 00000000000..f5bb6a7af23 --- /dev/null +++ b/sota-check/run_multiagent_iql.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=marl_iql +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/marl_iql_%j.txt +#SBATCH --error=slurm_errors/marl_iql_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="marl_iql" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/multiagent/iql.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_multiagent_qmix.sh b/sota-check/run_multiagent_qmix.sh new file mode 100644 index 00000000000..08b32ce257a --- /dev/null +++ b/sota-check/run_multiagent_qmix.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=marl_qmix_vdn +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/marl_qmix_vdn_%j.txt +#SBATCH --error=slurm_errors/marl_qmix_vdn_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="marl_qmix_vdn" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/multiagent/qmix_vdn.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_multiagent_sac.sh b/sota-check/run_multiagent_sac.sh new file mode 100644 index 00000000000..10e1bbb2d4d --- /dev/null +++ b/sota-check/run_multiagent_sac.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +#SBATCH --job-name=marl_sac +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/marl_sac_%j.txt +#SBATCH --error=slurm_errors/marl_sac_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="marl_sac" + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/multiagent/sac.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_ppo_atari.sh b/sota-check/run_ppo_atari.sh new file mode 100644 index 00000000000..764727acb7e --- /dev/null +++ b/sota-check/run_ppo_atari.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=ppo_atari +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/ppo_atari_%j.txt +#SBATCH --error=slurm_errors/ppo_atari_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="ppo_atari" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/ppo/ppo_atari.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_ppo_mujoco.sh b/sota-check/run_ppo_mujoco.sh new file mode 100644 index 00000000000..0e35974ffcc --- /dev/null +++ b/sota-check/run_ppo_mujoco.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=ppo_mujoco +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/ppo_mujoco_%j.txt +#SBATCH --error=slurm_errors/ppo_mujoco_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="ppo_mujoco" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/ppo/ppo_mujoco.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_sac.sh b/sota-check/run_sac.sh new file mode 100644 index 00000000000..8c7b8ffa5ab --- /dev/null +++ b/sota-check/run_sac.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=sac +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/sac_%j.txt +#SBATCH --error=slurm_errors/sac_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="sac" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/sac/sac.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/run_td3.sh b/sota-check/run_td3.sh new file mode 100644 index 00000000000..314ba68b4ac --- /dev/null +++ b/sota-check/run_td3.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=td3 +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/td3_%j.txt +#SBATCH --error=slurm_errors/td3_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="td3" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/examples/td3/td3.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh new file mode 100755 index 00000000000..cad2783c653 --- /dev/null +++ b/sota-check/submitit-release-check.sh @@ -0,0 +1,91 @@ +#!/bin/bash + +# Function to display script usage +display_usage() { + cat < --n_runs 5 + +EOF + return 1 +} + +# Check if the script is called with --help or without any arguments +if [ "$1" == "--help" ]; then + display_usage +fi + +# Initialize variables with default values +n_runs="1" +slurm_partition="" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --n_runs) + n_runs="$2" + shift 2 + ;; + --partition) + slurm_partition="$2" + shift 2 + ;; + *) + echo "$1 is not a valid argument. See './submitit-release-check.sh --help'." + return 0 + ;; + esac +done + +scripts=( + run_a2c_atari.sh + run_a2c_mujoco.sh + run_cql_offline.sh + run_cql_online.sh + run_ddpg.sh + run_discrete_sac.sh + run_dqn_atari.sh + run_dqn_cartpole.sh + run_impala_single_node.sh + run_iql_offline.sh + run_iql_online.sh + run_iql_discrete.sh + run_multiagent_iddpg.sh + run_multiagent_ippo.sh + run_multiagent_iql.sh + run_multiagent_qmix.sh + run_multiagent_sac.sh + run_ppo_atari.sh + run_ppo_mujoco.sh + run_sac.sh + run_td3.sh + run_dt.sh + run_dt_online.sh +) + +mkdir -p "slurm_errors" +mkdir -p "slurm_logs" + +# remove the previous report +rm -f report.log + +# Submit jobs with the specified partition the specified number of times +if [ -z "$slurm_partition" ]; then + for script in "${scripts[@]}"; do + for ((i=1; i<=$n_runs; i++)); do + sbatch "$script" + done + done +else + for script in "${scripts[@]}"; do + for ((i=1; i<=$n_runs; i++)); do + sbatch --partition="$slurm_partition" "$script" + done + done +fi diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 7d5e635a5a9..ffb8c0f5270 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1058,7 +1058,7 @@ def rollout(self) -> TensorDictBase: if self.storing_device is not None: tensordicts.append( - self._shuttle.to(self.storing_device, non_blocking=True) + self._shuttle.to(self.storing_device, non_blocking=False) ) else: tensordicts.append(self._shuttle) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index c37cac634e4..8a0510b11b4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -884,7 +884,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: # to be deprecated in v0.4 def map_device(tensor): if tensor.device != self.device: - return tensor.to(self.device, non_blocking=True) + return tensor.to(self.device, non_blocking=False) return tensor if is_tensor_collection(result): diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 9c6b3d1e58a..0824cff585f 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -394,7 +394,7 @@ def get_dataloader( ) out = TensorDictReplayBuffer( storage=TensorStorage(data), - collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True), + collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False), sampler=SamplerWithoutReplacement(drop_last=True), batch_size=batch_size, prefetch=prefetch, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 655db99c983..e22531c1c22 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -783,7 +783,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: elif device is None: out = out.clone().clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=False) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -804,7 +804,7 @@ def _step( # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. env_device = self._envs[i].device if env_device != self.device: - data_in = tensordict_in[i].to(env_device, non_blocking=True) + data_in = tensordict_in[i].to(env_device, non_blocking=False) else: data_in = tensordict_in[i] out_td = self._envs[i]._step(data_in) @@ -826,7 +826,7 @@ def _step( elif device is None: out = out.clone().clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=False) return out def __getattr__(self, attr: str) -> Any: @@ -1156,8 +1156,8 @@ def step_and_maybe_reset( next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: - next_td = next_td.to(device, non_blocking=True) - tensordict_ = tensordict_.to(device, non_blocking=True) + next_td = next_td.to(device, non_blocking=False) + tensordict_ = tensordict_.to(device, non_blocking=False) else: next_td = next_td.clone().clear_device_() tensordict_ = tensordict_.clone().clear_device_() @@ -1217,7 +1217,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if out.device == device: out = out.clone() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=False) return out @_check_start @@ -1293,7 +1293,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: elif device is None: out = out.clear_device_().clone() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=False) return out @_check_start @@ -1584,5 +1584,5 @@ def look_for_cuda(tensor, has_cuda=has_cuda): def _update_cuda(t_dest, t_source): if t_source is None: return - t_dest.copy_(t_source.pin_memory(), non_blocking=True) + t_dest.copy_(t_source.pin_memory(), non_blocking=False) return diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 61cd211b6ae..b2b201922e1 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2056,7 +2056,7 @@ def reset( tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly # if tensordict_reset.device != self.device: - # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) + # tensordict_reset = tensordict_reset.to(self.device, non_blocking=False) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -2418,13 +2418,13 @@ def _rollout_stop_early( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict = tensordict.to(policy_device, non_blocking=True) + tensordict = tensordict.to(policy_device, non_blocking=False) else: tensordict.clear_device_() tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: - tensordict = tensordict.to(env_device, non_blocking=True) + tensordict = tensordict.to(env_device, non_blocking=False) else: tensordict.clear_device_() tensordict = self.step(tensordict) @@ -2472,13 +2472,13 @@ def _rollout_nonstop( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict_ = tensordict_.to(policy_device, non_blocking=True) + tensordict_ = tensordict_.to(policy_device, non_blocking=False) else: tensordict_.clear_device_() tensordict_ = policy(tensordict_) if auto_cast_to_device: if env_device is not None: - tensordict_ = tensordict_.to(env_device, non_blocking=True) + tensordict_ = tensordict_.to(env_device, non_blocking=False) else: tensordict_.clear_device_() tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 60cb026c658..3ce3d2d630c 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -322,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, _run_checks=False ) - tensordict_out = tensordict_out.to(self.device, non_blocking=True) + tensordict_out = tensordict_out.to(self.device, non_blocking=False) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -366,7 +366,7 @@ def _reset( for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() - tensordict_out = tensordict_out.to(self.device, non_blocking=True) + tensordict_out = tensordict_out.to(self.device, non_blocking=False) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index 5645785117d..b5aed62d503 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -147,7 +147,7 @@ def _step(self, tensordict, next_tensordict): raise RuntimeError(self.NO_PARENT_ERR.format(type(self))) lives = self._get_lives() - end_of_life = torch.tensor( + end_of_life = torch.as_tensor( tensordict.get(self.lives_key) > lives, device=self.parent.device ) try: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index c2c2a6aa047..52ac8e8f66d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3618,10 +3618,10 @@ def set_container(self, container: Union[Transform, EnvBase]) -> None: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=True) + return tensordict.to(self.device, non_blocking=False) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=True) + return tensordict.to(self.device, non_blocking=False) def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -3634,8 +3634,8 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if parent is None: if self.orig_device is None: return tensordict - return tensordict.to(self.orig_device, non_blocking=True) - return tensordict.to(parent.device, non_blocking=True) + return tensordict.to(self.orig_device, non_blocking=False) + return tensordict.to(parent.device, non_blocking=False) def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec.to(self.device) @@ -5152,7 +5152,7 @@ def _reset( if step_count is None: step_count = self.container.observation_spec[step_count_key].zero() if step_count.device != reset.device: - step_count = step_count.to(reset.device, non_blocking=True) + step_count = step_count.to(reset.device, non_blocking=False) # zero the step count if reset is needed step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 0b978e5ef68..e03cb4043cb 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -269,7 +269,7 @@ def _set_single_key( dest = new_val else: if device is not None and val.device != device: - val = val.to(device, non_blocking=True) + val = val.to(device, non_blocking=False) elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index c8629be7f15..96f8d98477d 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -702,7 +702,7 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: def sample(self, batch: TensorDictBase) -> TensorDictBase: sample = self.replay_buffer.sample(batch_size=self.batch_size) - return sample.to(self.device, non_blocking=True) + return sample.to(self.device, non_blocking=False) def update_priority(self, batch: TensorDictBase) -> None: self.replay_buffer.update_tensordict_priority(batch) From 69453a66db52256d060634e3b5a905bf12e35c4d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 11:47:08 +0000 Subject: [PATCH 24/35] [BugFix] Fix flaky gym penv test (#1853) --- test/_utils_internal.py | 24 +++++++++---------- torchrl/data/datasets/minari_data.py | 4 ++-- torchrl/data/datasets/openx.py | 2 +- torchrl/data/replay_buffers/replay_buffers.py | 4 ++-- torchrl/data/replay_buffers/samplers.py | 4 ++-- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/replay_buffers/utils.py | 4 ++-- torchrl/data/replay_buffers/writers.py | 4 ++-- torchrl/data/rlhf/utils.py | 6 ++--- torchrl/data/tensor_specs.py | 12 +++++----- torchrl/envs/libs/dm_control.py | 4 ++-- torchrl/envs/libs/envpool.py | 6 ++--- torchrl/envs/libs/gym.py | 1 + torchrl/envs/libs/pettingzoo.py | 6 ++--- torchrl/envs/transforms/gym_transforms.py | 4 ++-- torchrl/envs/transforms/r3m.py | 4 ++-- torchrl/envs/transforms/rlhf.py | 2 +- torchrl/envs/transforms/transforms.py | 8 +++---- torchrl/envs/transforms/vc1.py | 4 ++-- torchrl/envs/transforms/vip.py | 4 ++-- torchrl/modules/distributions/continuous.py | 4 ++-- torchrl/modules/models/exploration.py | 4 +++- torchrl/modules/planners/mppi.py | 2 +- .../modules/tensordict_module/exploration.py | 12 +++++----- torchrl/objectives/a2c.py | 6 +++-- torchrl/objectives/deprecated.py | 14 ++++++----- 26 files changed, 78 insertions(+), 73 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8d473dbf4ee..c9fdc7e39ba 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -330,9 +330,9 @@ def rollout_consistency_assertion( ): """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" - done = rollout[:, :-1]["next", done_key].squeeze(-1) + done = rollout[..., :-1]["next", done_key].squeeze(-1) # data resulting from step, when it's not done - r_not_done = rollout[:, :-1]["next"][~done] + r_not_done = rollout[..., :-1]["next"][~done] # data resulting from step, when it's not done, after step_mdp r_not_done_tp1 = rollout[:, 1:][~done] torch.testing.assert_close( @@ -343,17 +343,15 @@ def rollout_consistency_assertion( if done_strict and not done.any(): raise RuntimeError("No done detected, test could not complete.") - - # data resulting from step, when it's done - r_done = rollout[:, :-1]["next"][done] - # data resulting from step, when it's done, after step_mdp and reset - r_done_tp1 = rollout[:, 1:][done] - assert ( - (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1 - ).all(), ( - f"Entries in next tensordict do not match entries in root " - f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}" - ) + if done.any(): + # data resulting from step, when it's done + r_done = rollout[..., :-1]["next"][done] + # data resulting from step, when it's done, after step_mdp and reset + r_done_tp1 = rollout[..., 1:][done] + # check that at least one obs after reset does not match the version before reset + assert not torch.isclose( + r_done[observation_key], r_done_tp1[observation_key] + ).all() def rand_reset(env): diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 5deeccd3253..babe5638c91 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -412,8 +412,8 @@ def _proc_spec(spec): ) return BoundedTensorSpec( shape=spec["shape"], - low=torch.tensor(spec["low"]), - high=torch.tensor(spec["high"]), + low=torch.as_tensor(spec["low"]), + high=torch.as_tensor(spec["high"]), dtype=_DTYPE_DIR[spec["dtype"]], ) elif spec["type"] == "Discrete": diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 0b825188a5b..598ab782147 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -684,7 +684,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value): truncated, dim=data.ndim - 1, value=True, - index=torch.tensor(-1, device=truncated.device), + index=torch.as_tensor(-1, device=truncated.device), ) done = data.get(("next", "done")) data.set(("next", "truncated"), truncated) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 79bf3b9b180..c3999806aaf 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -867,7 +867,7 @@ def add(self, data: TensorDictBase) -> int: device=data.device, ) if data.batch_size: - data_add["_rb_batch_size"] = torch.tensor(data.batch_size) + data_add["_rb_batch_size"] = torch.as_tensor(data.batch_size) else: data_add = data @@ -1441,7 +1441,7 @@ def __getitem__( if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 3460f6ed51c..15e46ae1038 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -461,10 +461,10 @@ def dumps(self, path): filename=path / "mintree.memmap", ) mm_st.copy_( - torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)]) + torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)]) ) mm_mt.copy_( - torch.tensor([self._min_tree[i] for i in range(self._max_capacity)]) + torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)]) ) with open(path / "sampler_metadata.json", "w") as file: json.dump( diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 8a0510b11b4..fd847f25c74 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1005,7 +1005,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index c042f54c652..7846a6bb9d4 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -28,11 +28,11 @@ def _to_torch( data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False ) -> torch.Tensor: if isinstance(data, np.generic): - return torch.tensor(data, device=device) + return torch.as_tensor(data, device=device) elif isinstance(data, np.ndarray): data = torch.from_numpy(data) elif not isinstance(data, Tensor): - data = torch.tensor(data, device=device) + data = torch.as_tensor(data, device=device) if pin_memory: data = data.pin_memory() diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 41d551535ac..156d32f9539 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -357,7 +357,7 @@ def __getstate__(self): def dumps(self, path): path = Path(path).absolute() path.mkdir(exist_ok=True) - t = torch.tensor(self._current_top_values) + t = torch.as_tensor(self._current_top_values) try: MemoryMappedTensor.from_filename( filename=path / "current_top_values.memmap", @@ -453,7 +453,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError( diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index ed7c7d1d35f..311b2584aa5 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -100,7 +100,7 @@ def update(self, kl_values: Sequence[float]): ) n_steps = len(kl_values) # renormalize kls - kl_value = -torch.tensor(kl_values).mean() / self.coef + kl_value = -torch.as_tensor(kl_values).mean() / self.coef proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ mult = 1 + proportional_error * n_steps / self.horizon self.coef *= mult # βₜ₊₁ @@ -314,10 +314,10 @@ def _get_done_status(self, generated, batch): # of generated tokens done_idx = torch.minimum( (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex, - torch.tensor(self.max_new_tokens) - 1, + torch.as_tensor(self.max_new_tokens) - 1, ) truncated_idx = ( - torch.tensor(self.max_new_tokens, device=generated.device).expand_as( + torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as( done_idx ) - 1 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b4d628a9051..1cfc970e61f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1374,9 +1374,9 @@ def encode( ) -> torch.Tensor: if not isinstance(val, torch.Tensor): if ignore_device: - val = torch.tensor(val) + val = torch.as_tensor(val) else: - val = torch.tensor(val, device=self.device) + val = torch.as_tensor(val, device=self.device) if space is None: space = self.space @@ -1555,9 +1555,9 @@ def __init__( dtype = torch.get_default_dtype() if not isinstance(low, torch.Tensor): - low = torch.tensor(low, dtype=dtype, device=device) + low = torch.as_tensor(low, dtype=dtype, device=device) if not isinstance(high, torch.Tensor): - high = torch.tensor(high, dtype=dtype, device=device) + high = torch.as_tensor(high, dtype=dtype, device=device) if high.device != device: high = high.to(device) if low.device != device: @@ -1857,8 +1857,8 @@ def __init__( dtype, device = _default_dtype_and_device(dtype, device) box = ( ContinuousBox( - torch.tensor(-np.inf, device=device).expand(shape), - torch.tensor(np.inf, device=device).expand(shape), + torch.as_tensor(-np.inf, device=device).expand(shape), + torch.as_tensor(np.inf, device=device).expand(shape), ) if shape == _DEFAULT_SHAPE else None diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index b2fdac0a802..2e96efcaf6a 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -102,9 +102,9 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]: def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: if isinstance(array, np.ndarray): - return torch.tensor(array.copy()) + return torch.as_tensor(array.copy()) else: - return torch.tensor(array) + return torch.as_tensor(array) class DMControlWrapper(GymLikeEnv): diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index acf2da598b1..410e25a1b28 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -264,7 +264,7 @@ def _transform_step_output( f"The output of step was had {len(out)} elements, but only 4 or 5 are supported." ) obs = self._treevalue_or_numpy_to_tensor_or_dict(obs) - reward_and_done = {self.reward_key: torch.tensor(reward)} + reward_and_done = {self.reward_key: torch.as_tensor(reward)} reward_and_done["done"] = done reward_and_done["terminated"] = terminated reward_and_done["truncated"] = truncated @@ -290,7 +290,7 @@ def _treevalue_or_numpy_to_tensor_or_dict( if isinstance(x, treevalue.TreeValue): ret = self._treevalue_to_dict(x) elif not isinstance(x, dict): - ret = {"observation": torch.tensor(x)} + ret = {"observation": torch.as_tensor(x)} else: ret = x return ret @@ -304,7 +304,7 @@ def _treevalue_to_dict( """ import treevalue - return {k[0]: torch.tensor(v) for k, v in treevalue.flatten(tv)} + return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)} def _set_seed(self, seed: Optional[int]): if seed is not None: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 48da354e7ba..59730c6df8c 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1506,6 +1506,7 @@ def _read_obs(self, obs, key, tensor, index): def __call__(self, info_dict, tensordict): terminal_obs = info_dict.get(self.backend_key[self.backend], None) for key, item in self.info_spec.items(True, True): + key = (key,) if isinstance(key, str) else key final_obs_buffer = item.zero() if terminal_obs is not None: for i, obs in enumerate(terminal_obs): diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 14e45eb4bc4..a1470776f10 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -462,7 +462,7 @@ def _init_env(self): "info": CompositeSpec( { key: UnboundedContinuousTensorSpec( - shape=torch.tensor(value).shape, + shape=torch.as_tensor(value).shape, device=self.device, ) for key, value in info_dict[agent].items() @@ -501,7 +501,7 @@ def _init_env(self): device=self.device, ) except AttributeError: - state_example = torch.tensor(self.state(), device=self.device) + state_example = torch.as_tensor(self.state(), device=self.device) state_spec = UnboundedContinuousTensorSpec( shape=state_example.shape, dtype=state_example.dtype, @@ -560,7 +560,7 @@ def _reset( if group_info is not None: agent_info_dict = info_dict[agent] for agent_info, value in agent_info_dict.items(): - group_info.get(agent_info)[index] = torch.tensor( + group_info.get(agent_info)[index] = torch.as_tensor( value, device=self.device ) diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index b5aed62d503..99f38ebb32c 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -135,7 +135,7 @@ def _get_lives(self): if callable(lives): lives = lives() elif isinstance(lives, list) and all(callable(_lives) for _lives in lives): - lives = torch.tensor([_lives() for _lives in lives]) + lives = torch.as_tensor([_lives() for _lives in lives]) return lives def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -170,7 +170,7 @@ def _reset(self, tensordict, tensordict_reset): end_of_life = False tensordict_reset.set( self.eol_key, - torch.tensor(end_of_life).expand( + torch.as_tensor(end_of_life).expand( parent.full_done_spec[self.done_key].shape ), ) diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 05017a8a8ec..1c12cf9be15 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -292,8 +292,8 @@ def _init(self): std = [0.229, 0.224, 0.225] normalize = ObservationNorm( in_keys=in_keys, - loc=torch.tensor(mean).view(3, 1, 1), - scale=torch.tensor(std).view(3, 1, 1), + loc=torch.as_tensor(mean).view(3, 1, 1), + scale=torch.as_tensor(std).view(3, 1, 1), standard_normal=True, ) transforms.append(normalize) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 79ee94318cb..623bc2864fe 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -146,7 +146,7 @@ def find_sample_log_prob(module): self.functional_actor.apply(find_sample_log_prob) if not isinstance(coef, torch.Tensor): - coef = torch.tensor(coef) + coef = torch.as_tensor(coef) self.register_buffer("coef", coef) def _reset( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 52ac8e8f66d..e59c481419c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1332,7 +1332,7 @@ def check_val(val): if val is None: return None, None, torch.finfo(torch.get_default_dtype()).max if not isinstance(val, torch.Tensor): - val = torch.tensor(val) + val = torch.as_tensor(val) if not val.dtype.is_floating_point: val = val.float() eps = torch.finfo(val.dtype).resolution @@ -1626,10 +1626,10 @@ def __init__( out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) clamp_min_tensor = ( - clamp_min if isinstance(clamp_min, Tensor) else torch.tensor(clamp_min) + clamp_min if isinstance(clamp_min, Tensor) else torch.as_tensor(clamp_min) ) clamp_max_tensor = ( - clamp_max if isinstance(clamp_max, Tensor) else torch.tensor(clamp_max) + clamp_max if isinstance(clamp_max, Tensor) else torch.as_tensor(clamp_max) ) self.register_buffer("clamp_min", clamp_min_tensor) self.register_buffer("clamp_max", clamp_max_tensor) @@ -2396,7 +2396,7 @@ def __init__( out_keys_inv=out_keys_inv, ) if not isinstance(standard_normal, torch.Tensor): - standard_normal = torch.tensor(standard_normal) + standard_normal = torch.as_tensor(standard_normal) self.register_buffer("standard_normal", standard_normal) self.eps = 1e-6 diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index 252ddfc4a90..746bfe52f1d 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -132,8 +132,8 @@ def _map_tv_to_torchrl( elif isinstance(model_transforms, transforms.Normalize): return ObservationNorm( in_keys=in_keys, - loc=torch.tensor(model_transforms.mean).reshape(3, 1, 1), - scale=torch.tensor(model_transforms.std).reshape(3, 1, 1), + loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1), + scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1), standard_normal=True, ) elif isinstance(model_transforms, transforms.ToTensor): diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index ed288fdea9e..9c272b42b89 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -266,8 +266,8 @@ def _init(self): std = [0.229, 0.224, 0.225] normalize = ObservationNorm( in_keys=in_keys, - loc=torch.tensor(mean).view(3, 1, 1), - scale=torch.tensor(std).view(3, 1, 1), + loc=torch.as_tensor(mean).view(3, 1, 1), + scale=torch.as_tensor(std).view(3, 1, 1), standard_normal=True, ) transforms.append(normalize) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index d4256dcd61f..eb5f2a38944 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -240,11 +240,11 @@ def __init__( if isinstance(max, torch.Tensor): max = max.to(self.device) else: - max = torch.tensor(max, device=self.device) + max = torch.as_tensor(max, device=self.device) if isinstance(min, torch.Tensor): min = min.to(self.device) else: - min = torch.tensor(min, device=self.device) + min = torch.as_tensor(min, device=self.device) self.min = min self.max = max self.update(loc, scale) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index f909b6568c6..59819d940d0 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -345,7 +345,9 @@ def __init__( ) if sigma_init != 0.0: - self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device)) + self.register_buffer( + "sigma_init", torch.as_tensor(sigma_init, device=device) + ) @property def sigma(self): diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index c65b81eb11d..9c0bbc8f147 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -145,7 +145,7 @@ def __init__( self.num_candidates = num_candidates self.top_k = top_k self.reward_key = reward_key - self.register_buffer("temperature", torch.tensor(temperature)) + self.register_buffer("temperature", torch.as_tensor(temperature)) def planning(self, tensordict: TensorDictBase) -> torch.Tensor: batch_size = tensordict.batch_size diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index f641fdfef88..c8fa9cc040f 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -107,10 +107,10 @@ def __init__( super().__init__() - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) + self.register_buffer("eps_init", torch.as_tensor([eps_init])) + self.register_buffer("eps_end", torch.as_tensor([eps_end])) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: @@ -254,12 +254,12 @@ def __init__( ) super().__init__(policy) - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) + self.register_buffer("eps_init", torch.as_tensor([eps_init])) + self.register_buffer("eps_end", torch.as_tensor([eps_end])) if self.eps_end > self.eps_init: raise RuntimeError("eps should decrease over time or be constant") self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) self.action_key = action_key self.action_mask_key = action_mask_key if spec is not None: diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index c32a795a2a0..de963bcfdb9 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -283,8 +283,10 @@ def __init__( except AttributeError: device = torch.device("cpu") - self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) - self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device)) + self.register_buffer( + "entropy_coef", torch.as_tensor(entropy_coef, device=device) + ) + self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device)) if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index e920bc83960..6ef7ab7386e 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -174,22 +174,24 @@ def __init__( except AttributeError: device = torch.device("cpu") - self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device)) self.register_buffer( - "min_log_alpha", torch.tensor(min_alpha, device=device).log() + "min_log_alpha", torch.as_tensor(min_alpha, device=device).log() ) self.register_buffer( - "max_log_alpha", torch.tensor(max_alpha, device=device).log() + "max_log_alpha", torch.as_tensor(max_alpha, device=device).log() ) self.fixed_alpha = fixed_alpha if fixed_alpha: self.register_buffer( - "log_alpha", torch.tensor(math.log(alpha_init), device=device) + "log_alpha", torch.as_tensor(math.log(alpha_init), device=device) ) else: self.register_parameter( "log_alpha", - torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + torch.nn.Parameter( + torch.as_tensor(math.log(alpha_init), device=device) + ), ) self._target_entropy = target_entropy @@ -230,7 +232,7 @@ def target_entropy(self): np.prod(action_spec[self.tensor_keys.action].shape) ) self.register_buffer( - "target_entropy_buffer", torch.tensor(target_entropy, device=device) + "target_entropy_buffer", torch.as_tensor(target_entropy, device=device) ) return self.target_entropy_buffer return target_entropy From b0653c44fb6e3776653973861a4732c89801173e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 12:13:03 +0000 Subject: [PATCH 25/35] [CI] Fix macos build (#1856) --- .github/workflows/build-wheels-m1.yml | 21 ++++++++++++++------- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 9cbdf460894..84fe79d09d2 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -13,6 +13,10 @@ on: - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ workflow_dispatch: +permissions: + id-token: write + contents: read + jobs: generate-matrix: uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main @@ -23,20 +27,23 @@ jobs: test-infra-ref: main build: needs: generate-matrix + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/rl + smoke-test-script: test/smoke_test.py + package-name: torchrl name: pytorch/rl uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main with: - repository: pytorch/rl + repository: ${{ matrix.repository }} ref: "" test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} - post-script: "" - package-name: torchrl + package-name: ${{ matrix.package-name }} runner-type: macos-m1-stable - smoke-test-script: "" + smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} env-var-script: .github/scripts/m1_script.sh - secrets: - AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} - AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 428997ba3e8..47c1b0c6fec 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -53,7 +53,7 @@ jobs: path: dist/*.whl build-wheel-mac: - runs-on: macos-latest + runs-on: macos-11 strategy: matrix: python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]] @@ -121,7 +121,7 @@ jobs: needs: [build-wheel-linux, build-wheel-mac] strategy: matrix: - os: [["linux", "ubuntu-20.04"], ["mac", "macos-latest"]] + os: [["linux", "ubuntu-20.04"], ["mac", "macos-11"]] python_version: [ "3.8", "3.9", "3.10", "3.11" ] runs-on: ${{ matrix.os[1] }} steps: From b5e90c4f29ff782437f009b8a15dd511074033c9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 18:31:53 +0000 Subject: [PATCH 26/35] [Deprecation] Deprecate in prep for release (#1820) --- .../linux_examples/scripts/run_test.sh | 2 + examples/bandits/dqn.py | 17 +- examples/cql/discrete_cql_config.yaml | 4 +- examples/cql/discrete_cql_online.py | 2 +- examples/cql/offline_config.yaml | 2 +- examples/cql/online_config.yaml | 2 +- examples/cql/utils.py | 7 +- examples/ddpg/config.yaml | 2 +- examples/ddpg/ddpg.py | 2 +- examples/ddpg/utils.py | 6 +- examples/decision_transformer/dt_config.yaml | 2 +- examples/decision_transformer/odt_config.yaml | 2 +- examples/decision_transformer/utils.py | 2 +- examples/discrete_sac/config.yaml | 2 +- examples/discrete_sac/discrete_sac.py | 2 +- examples/discrete_sac/utils.py | 6 +- .../collectors/multi_nodes/ray_train.py | 2 +- examples/iql/utils.py | 6 +- examples/multiagent/iql.py | 22 +- examples/multiagent/maddpg_iddpg.py | 2 +- examples/multiagent/mappo_ippo.py | 2 +- examples/multiagent/qmix_vdn.py | 22 +- examples/multiagent/sac.py | 2 +- examples/redq/config.yaml | 2 +- examples/rlhf/train_rlhf.py | 2 +- examples/sac/config.yaml | 2 +- examples/sac/sac.py | 2 +- examples/sac/utils.py | 6 +- examples/td3/config.yaml | 2 +- examples/td3/td3.py | 2 +- examples/td3/utils.py | 7 +- test/mocking_classes.py | 7 +- test/test_collector.py | 65 +-- test/test_cost.py | 4 +- test/test_distributed.py | 3 +- test/test_exploration.py | 8 - test/test_helpers.py | 18 +- test/test_transforms.py | 452 ++++++++++++++---- torchrl/collectors/collectors.py | 113 +++-- torchrl/collectors/distributed/generic.py | 14 +- torchrl/collectors/distributed/ray.py | 19 +- torchrl/collectors/distributed/rpc.py | 14 +- torchrl/collectors/distributed/sync.py | 14 +- torchrl/data/datasets/d4rl.py | 31 +- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/tensor_specs.py | 6 +- torchrl/envs/gym_like.py | 2 +- torchrl/envs/transforms/transforms.py | 12 +- torchrl/modules/models/models.py | 9 +- torchrl/modules/tensordict_module/actors.py | 10 +- .../modules/tensordict_module/exploration.py | 99 +--- torchrl/modules/tensordict_module/rnn.py | 9 +- torchrl/objectives/a2c.py | 31 +- torchrl/objectives/common.py | 8 +- torchrl/objectives/cql.py | 8 +- torchrl/objectives/ddpg.py | 6 +- torchrl/objectives/deprecated.py | 6 +- torchrl/objectives/dqn.py | 5 +- torchrl/objectives/dreamer.py | 9 +- torchrl/objectives/iql.py | 17 +- torchrl/objectives/multiagent/qmixer.py | 5 +- torchrl/objectives/ppo.py | 8 +- torchrl/objectives/redq.py | 6 +- torchrl/objectives/reinforce.py | 31 +- torchrl/objectives/sac.py | 5 +- torchrl/objectives/td3.py | 6 +- torchrl/objectives/utils.py | 9 +- torchrl/objectives/value/advantages.py | 18 +- torchrl/trainers/helpers/logger.py | 2 + torchrl/trainers/trainers.py | 6 +- 70 files changed, 729 insertions(+), 513 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 0cbcb70ad15..e75f4b1bc1c 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -114,6 +114,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari. buffer.batch_size=10 \ device=cuda:0 \ loss.num_updates=1 \ + logger.backend= \ buffer.buffer_size=120 python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \ collector.total_frames=48 \ @@ -256,6 +257,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari. buffer.batch_size=10 \ device=cuda:0 \ loss.num_updates=1 \ + logger.backend= \ buffer.buffer_size=120 python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ num_workers=2 \ diff --git a/examples/bandits/dqn.py b/examples/bandits/dqn.py index 847cfbfc124..0d9ca828ee6 100644 --- a/examples/bandits/dqn.py +++ b/examples/bandits/dqn.py @@ -7,11 +7,12 @@ import torch import tqdm -from torch import nn +from tensordict.nn import TensorDictSequential +from torch import nn from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import DistributionalQValueActor, EGreedyWrapper, MLP, QValueActor +from torchrl.modules import DistributionalQValueActor, EGreedyModule, MLP, QValueActor from torchrl.objectives import DistributionalDQNLoss, DQNLoss parser = argparse.ArgumentParser() @@ -85,12 +86,14 @@ actor(env.reset()) loss = DQNLoss(actor, loss_function="smooth_l1", action_space=env.action_spec) loss.make_value_estimator(gamma=0.0) - policy = EGreedyWrapper( + policy = TensorDictSequential( actor, - eps_init=eps_greedy, - eps_end=0.0, - annealing_num_steps=n_steps, - spec=env.action_spec, + EGreedyModule( + eps_init=eps_greedy, + eps_end=0.0, + annealing_num_steps=n_steps, + spec=env.action_spec, + ), ) optim = torch.optim.Adam(loss.parameters(), lr, weight_decay=wd) diff --git a/examples/cql/discrete_cql_config.yaml b/examples/cql/discrete_cql_config.yaml index b7f8d527ba3..807479d45bd 100644 --- a/examples/cql/discrete_cql_config.yaml +++ b/examples/cql/discrete_cql_config.yaml @@ -2,7 +2,7 @@ env: name: CartPole-v1 task: "" - backend: gym + backend: gymnasium n_samples_stats: 1000 max_episode_steps: 200 seed: 0 @@ -36,7 +36,7 @@ replay_buffer: prb: 0 buffer_prefetch: 64 size: 1_000_000 - scratch_dir: ${env.exp_name}_${env.seed} + scratch_dir: null # Optimization optim: diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py index facbcc49bf9..107739f3aba 100644 --- a/examples/cql/discrete_cql_online.py +++ b/examples/cql/discrete_cql_online.py @@ -73,7 +73,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/cql/offline_config.yaml b/examples/cql/offline_config.yaml index d41db847077..0047b74d14c 100644 --- a/examples/cql/offline_config.yaml +++ b/examples/cql/offline_config.yaml @@ -5,7 +5,7 @@ env: library: gym n_samples_stats: 1000 seed: 0 - backend: gym # D4RL uses gym so we make sure gymnasium is hidden + backend: gymnasium # logger logger: diff --git a/examples/cql/online_config.yaml b/examples/cql/online_config.yaml index 367d4755cac..9b3e5b5bf24 100644 --- a/examples/cql/online_config.yaml +++ b/examples/cql/online_config.yaml @@ -6,7 +6,7 @@ env: seed: 0 train_num_envs: 1 eval_num_envs: 1 - backend: gym + backend: gymnasium # Collector collector: diff --git a/examples/cql/utils.py b/examples/cql/utils.py index 0af1a082e28..350b105b441 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -121,7 +121,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -133,7 +133,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -144,7 +144,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -320,7 +320,6 @@ def make_discrete_loss(loss_cfg, model): model, loss_function=loss_cfg.loss_function, delay_value=True, - gamma=loss_cfg.gamma, ) loss_module.make_value_estimator(gamma=loss_cfg.gamma) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index fb4a3fa4725..7d17038330b 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -21,7 +21,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optimization optim: diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index ea5a1386e4f..92fdd850fbd 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 935fb426988..4006fc27b38 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -119,7 +119,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -131,7 +131,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -142,7 +142,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index 80915c4f93a..b42d8b58d35 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -36,7 +36,7 @@ replay_buffer: stacked_frames: 20 buffer_prefetch: 64 capacity: 1_000_000 - buffer_scratch_dir: + scratch_dir: device: cpu prefetch: 3 diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index b6137ac62a1..f06972fd46b 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -36,7 +36,7 @@ replay_buffer: stacked_frames: 20 buffer_prefetch: 64 capacity: 1_000_000 - buffer_scratch_dir: + scratch_dir: device: cuda:0 prefetch: 3 diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 8bd9f3bebbf..9d479a8118d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -296,7 +296,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): ) storage = LazyMemmapStorage( max_size=rb_cfg.capacity, - scratch_dir=rb_cfg.buffer_scratch_dir, + scratch_dir=rb_cfg.scratch_dir, device=rb_cfg.device, ) diff --git a/examples/discrete_sac/config.yaml b/examples/discrete_sac/config.yaml index 03ae3999f87..df26c835ef0 100644 --- a/examples/discrete_sac/config.yaml +++ b/examples/discrete_sac/config.yaml @@ -22,7 +22,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optim optim: diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 2976cf8806d..16c5de80a64 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/discrete_sac/utils.py b/examples/discrete_sac/utils.py index 49ec8bc1204..5821ed53465 100644 --- a/examples/discrete_sac/utils.py +++ b/examples/discrete_sac/utils.py @@ -120,14 +120,14 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): with ( tempfile.TemporaryDirectory() - if buffer_scratch_dir is None - else nullcontext(buffer_scratch_dir) + if scratch_dir is None + else nullcontext(scratch_dir) ) as scratch_dir: if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 2db86b9f917..7d456367a5a 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -36,7 +36,7 @@ if __name__ == "__main__": # 1. Define Hyperparameters - device = "cpu" # if not torch.has_cuda else "cuda:0" + device = "cpu" # if not torch.cuda.device_count() else "cuda:0" num_cells = 256 max_grad_norm = 1.0 frame_skip = 1 diff --git a/examples/iql/utils.py b/examples/iql/utils.py index fe1e5ce32b8..997df401b82 100644 --- a/examples/iql/utils.py +++ b/examples/iql/utils.py @@ -125,7 +125,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -137,7 +137,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -148,7 +148,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 4af5da62c91..011e04cde77 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -8,7 +8,7 @@ import hydra import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer @@ -17,7 +17,7 @@ from torchrl.envs import RewardSum, TransformedEnv from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential +from torchrl.modules import EGreedyModule, QValueModule, SafeSequential from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators from utils.logging import init_logging, log_evaluation, log_training @@ -31,7 +31,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="iql") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding @@ -96,13 +96,15 @@ def train(cfg: "DictConfig"): # noqa: F821 ) qnet = SafeSequential(module, value_module) - qnet_explore = EGreedyWrapper( + qnet_explore = TensorDictSequential( qnet, - eps_init=0.3, - eps_end=0, - annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), - action_key=env.action_key, - spec=env.unbatched_action_spec, + EGreedyModule( + eps_init=0.3, + eps_end=0, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + spec=env.unbatched_action_spec, + ), ) collector = SyncDataCollector( @@ -174,7 +176,7 @@ def train(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() target_net_updater.step() - qnet_explore.step(frames=current_frames) # Update exploration annealing + qnet_explore[1].step(frames=current_frames) # Update exploration annealing collector.update_policy_weights_() training_time = time.time() - training_start diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py index 4e6b821604c..e4fd4a25e12 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/examples/multiagent/maddpg_iddpg.py @@ -36,7 +36,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="maddpg_iddpg") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index b00bb18a2a0..d4481c93071 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -31,7 +31,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="mappo_ippo") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 5822bda39da..e53c47e04f4 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -8,7 +8,7 @@ import hydra import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer @@ -17,7 +17,7 @@ from torchrl.envs import RewardSum, TransformedEnv from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential +from torchrl.modules import EGreedyModule, QValueModule, SafeSequential from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer from torchrl.objectives import SoftUpdate, ValueEstimators from torchrl.objectives.multiagent.qmixer import QMixerLoss @@ -31,7 +31,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="qmix_vdn") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding @@ -96,13 +96,15 @@ def train(cfg: "DictConfig"): # noqa: F821 ) qnet = SafeSequential(module, value_module) - qnet_explore = EGreedyWrapper( + qnet_explore = TensorDictSequential( qnet, - eps_init=0.3, - eps_end=0, - annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), - action_key=env.action_key, - spec=env.unbatched_action_spec, + EGreedyModule( + eps_init=0.3, + eps_end=0, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + spec=env.unbatched_action_spec, + ), ) if cfg.loss.mixer_type == "qmix": @@ -209,7 +211,7 @@ def train(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() target_net_updater.step() - qnet_explore.step(frames=current_frames) # Update exploration annealing + qnet_explore[1].step(frames=current_frames) # Update exploration annealing collector.update_policy_weights_() training_time = time.time() - training_start diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index 1c01b5e50b7..528b5422921 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -33,7 +33,7 @@ def rendering_callback(env, td): @hydra.main(version_base="1.1", config_path=".", config_name="sac") def train(cfg: "DictConfig"): # noqa: F821 # Device - cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0" + cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0" cfg.env.device = cfg.train.device # Seeding diff --git a/examples/redq/config.yaml b/examples/redq/config.yaml index fc77974cb38..c67543716dc 100644 --- a/examples/redq/config.yaml +++ b/examples/redq/config.yaml @@ -68,7 +68,7 @@ buffer: prb: 1 sub_traj_len: size: 500_000 - scratch_dir: + scratch_dir: null prefetch: 64 network: diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 6f3e80649d7..a921e58bad6 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -62,7 +62,7 @@ def main(cfg): wandb_kwargs={ "config": dict(cfg), "project": cfg.io.project_name, - "group": cfg.logger.group_name, + "group": cfg.io.group_name, }, ) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index b6675ecc9a0..6546f1e30b7 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -20,7 +20,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optim optim: diff --git a/examples/sac/sac.py b/examples/sac/sac.py index a93e3a833dd..db23071867a 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 1e157ce85cd..afb731dcc95 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -108,7 +108,7 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): @@ -120,7 +120,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, @@ -131,7 +131,7 @@ def make_replay_buffer( prefetch=prefetch, storage=LazyMemmapStorage( buffer_size, - scratch_dir=buffer_scratch_dir, + scratch_dir=scratch_dir, device=device, ), batch_size=batch_size, diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 561766cd5a4..e94a5b6b774 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -21,7 +21,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: ${logger.exp_name}_${env.seed} + scratch_dir: null # optim optim: diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 1f42e7273d1..003a3bf228c 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 0abc769d365..fed055f98bf 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -121,14 +121,14 @@ def make_replay_buffer( batch_size, prb=False, buffer_size=1000000, - buffer_scratch_dir=None, + scratch_dir=None, device="cpu", prefetch=3, ): with ( tempfile.TemporaryDirectory() - if buffer_scratch_dir is None - else nullcontext(buffer_scratch_dir) + if scratch_dir is None + else nullcontext(scratch_dir) ) as scratch_dir: if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( @@ -248,7 +248,6 @@ def make_loss_module(cfg, model): loss_function=cfg.optim.loss_function, delay_actor=True, delay_qvalue=True, - gamma=cfg.optim.gamma, action_spec=model[0][1].spec, policy_noise=cfg.optim.policy_noise, noise_clip=cfg.optim.noise_clip, diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 9e5b2ff6879..7a32c9a38ef 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -646,7 +646,7 @@ def _obs_step(self, obs, a): return obs + a / self.maxstep -class DiscreteActionVecPolicy: +class DiscreteActionVecPolicy(TensorDictModuleBase): in_keys = ["observation"] out_keys = ["action"] @@ -979,10 +979,13 @@ def forward(self, observation, action): return self.linear(torch.cat([observation, action], dim=-1)) -class CountingEnvCountPolicy: +class CountingEnvCountPolicy(TensorDictModuleBase): def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() self.action_spec = action_spec self.action_key = action_key + self.in_keys = [] + self.out_keys = [action_key] def __call__(self, td: TensorDictBase) -> TensorDictBase: return td.set(self.action_key, self.action_spec.zero() + 1) diff --git a/test/test_collector.py b/test/test_collector.py index 027cf776ee4..8369be1578e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1491,37 +1491,38 @@ def test_auto_wrap_modules( collector.shutdown() del collector - def test_no_wrap_compatible_module(self, collector_class, env_maker): - policy = TensorDictCompatiblePolicy( - out_features=env_maker().action_spec.shape[-1] - ) - policy(env_maker().reset()) - - collector = collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy) - ) - - if collector_class is not SyncDataCollector: - # We now do the casting only on the remote workers - pass - else: - assert isinstance(collector.policy, TensorDictCompatiblePolicy) - assert collector.policy.out_keys == ["action"] - assert collector.policy is policy - - for i, data in enumerate(collector): - if i == 0: - assert (data["action"] != 0).any() - for p in policy.parameters(): - p.data.zero_() - assert p.device == torch.device("cpu") - collector.update_policy_weights_() - elif i == 4: - assert (data["action"] == 0).all() - break - - collector.shutdown() - del collector + # Deprecated as from v0.3 + # def test_no_wrap_compatible_module(self, collector_class, env_maker): + # policy = TensorDictCompatiblePolicy( + # out_features=env_maker().action_spec.shape[-1] + # ) + # policy(env_maker().reset()) + # + # collector = collector_class( + # **self._create_collector_kwargs(env_maker, collector_class, policy) + # ) + # + # if collector_class is not SyncDataCollector: + # # We now do the casting only on the remote workers + # pass + # else: + # assert isinstance(collector.policy, TensorDictCompatiblePolicy) + # assert collector.policy.out_keys == ["action"] + # assert collector.policy is policy + # + # for i, data in enumerate(collector): + # if i == 0: + # assert (data["action"] != 0).any() + # for p in policy.parameters(): + # p.data.zero_() + # assert p.device == torch.device("cpu") + # collector.update_policy_weights_() + # elif i == 4: + # assert (data["action"] == 0).all() + # break + # + # collector.shutdown() + # del collector def test_auto_wrap_error(self, collector_class, env_maker): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) @@ -2062,7 +2063,7 @@ def _reset(self, tensordict=None): def _set_seed(self, seed): return seed - class Policy(nn.Module): + class Policy(TensorDictModuleBase): def __init__(self): super().__init__() self.param = nn.Parameter(torch.zeros(())) diff --git a/test/test_cost.py b/test/test_cost.py index 87e17eb252c..c6eb27172ee 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6091,7 +6091,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): else: raise NotImplementedError - loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2") + loss_fn = loss_class(actor, value, loss_critic_type="l2") params = TensorDict.from_module(loss_fn, as_module=True) @@ -11960,7 +11960,7 @@ def test_set_deprecated_keys(self, adv, kwargs): nn.Linear(3, 1), in_keys=["obs"], out_keys=["test_value"] ) - with pytest.warns(DeprecationWarning): + with pytest.raises(RuntimeError, match="via constructor is deprecated"): if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] diff --git a/test/test_distributed.py b/test/test_distributed.py index debfa058ace..6215abd7ceb 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -14,6 +14,7 @@ import time import pytest +from tensordict.nn import TensorDictModuleBase try: import ray @@ -49,7 +50,7 @@ pytest.skip("skipping windows tests in windows", allow_module_level=True) -class CountingPolicy(nn.Module): +class CountingPolicy(TensorDictModuleBase): """A policy for counting env. Returns a step of 1 by default but weights can be adapted. diff --git a/test/test_exploration.py b/test/test_exploration.py index 777f2714edb..d0735a53ae8 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -156,14 +156,6 @@ def test_egreedy_masked(self, module, eps_init, spec_class): assert not (action[~action_mask] == 0).all() assert (masked_action[~action_mask] == 0).all() - def test_egreedy_wrapper_deprecation(self): - torch.manual_seed(0) - spec = BoundedTensorSpec(1, 1, torch.Size([4])) - module = torch.nn.Linear(4, 4, bias=False) - policy = Actor(spec=spec, module=module) - with pytest.deprecated_call(): - EGreedyWrapper(policy) - def test_no_spec_error( self, ): diff --git a/test/test_helpers.py b/test/test_helpers.py index 1843a3f738f..eb9620001c7 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -501,9 +501,13 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial stats = {"loc": None, "scale": None} if initialized: stats = {"loc": 0.0, "scale": 1.0} - t_env.transform = ObservationNorm(standard_normal=True, **stats) + t_env.transform = ObservationNorm( + in_keys=["observation"], standard_normal=True, **stats + ) if composed: - t_env.append_transform(ObservationNorm(standard_normal=True, **stats)) + t_env.append_transform( + ObservationNorm(in_keys=["observation"], standard_normal=True, **stats) + ) if not initialized: with pytest.raises( ValueError, match="Attempted to use an uninitialized parameter" @@ -539,7 +543,7 @@ def test_initialize_stats_from_non_obs_transform(device): def test_initialize_obs_transform_stats_raise_exception(): env = ContinuousActionVecMockEnv() t_env = TransformedEnv(env) - t_env.transform = ObservationNorm() + t_env.transform = ObservationNorm(in_keys=["observation"]) with pytest.raises( RuntimeError, match="More than one key exists in the observation_specs" ): @@ -553,10 +557,14 @@ def test_retrieve_observation_norms_state_dict(device, composed): env.set_seed(1) t_env = TransformedEnv(env) - t_env.transform = ObservationNorm(standard_normal=True, loc=0.5, scale=0.2) + t_env.transform = ObservationNorm( + standard_normal=True, loc=0.5, scale=0.2, in_keys=["observation"] + ) if composed: t_env.append_transform( - ObservationNorm(standard_normal=True, loc=1.0, scale=0.3) + ObservationNorm( + standard_normal=True, loc=1.0, scale=0.3, in_keys=["observation"] + ) ) initialize_observation_norm_transforms(proof_environment=t_env, num_iter=100) state_dicts = retrieve_observation_norms_state_dict(t_env) diff --git a/test/test_transforms.py b/test/test_transforms.py index b325a1ccd99..725945ef113 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -229,22 +229,28 @@ def test_parallel_trans_env_check(self): env = ParallelEnv( 2, lambda: TransformedEnv(ContinuousActionVecMockEnv(), BinarizeReward()) ) - check_env_specs(env) - env.close() + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, lambda: ContinuousActionVecMockEnv()), BinarizeReward() ) - check_env_specs(env) - env.close() + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), BinarizeReward() ) - check_env_specs(env) - env.close() + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) @@ -546,7 +552,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -575,7 +584,10 @@ def test_trans_parallel_env_check(self): high=0.1, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = ContinuousActionVecMockEnv() @@ -618,7 +630,10 @@ def test_parallel_trans_env_check(self): CatFrames(dim=-1, N=3, in_keys=["observation"]), ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -639,7 +654,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), CatFrames(dim=-1, N=3, in_keys=["observation"]), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -1171,7 +1189,10 @@ def make_env(): ) transformed_env = ParallelEnv(2, make_env) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_serial_trans_env_check(self, model, device): if model != "resnet18": @@ -1213,7 +1234,10 @@ def test_trans_parallel_env_check(self, model, device): transformed_env = TransformedEnv( ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), r3m ) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_trans_serial_env_check(self, model, device): if model != "resnet18": @@ -1545,7 +1569,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(10)) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -1558,7 +1585,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), StepCounter(10) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), StepCounter(10)) @@ -1839,7 +1869,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), ct) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): ct = CatTensors( @@ -1861,7 +1894,10 @@ def test_trans_parallel_env_check(self): ) env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), ct) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize( @@ -2210,7 +2246,10 @@ def make_env(): return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): keys = ["pixels"] @@ -2222,7 +2261,10 @@ def test_trans_parallel_env_check(self): keys = ["pixels"] ct = Compose(ToTensorImage(), CenterCrop(w=20, h=20, in_keys=keys)) env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="No Gym detected") @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) @@ -2266,7 +2308,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -2280,7 +2325,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), DiscreteActionProjection(7, 10), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("action_key", ["action", ("nested", "stuff")]) def test_transform_no_env(self, action_key): @@ -2526,7 +2574,10 @@ def test_trans_parallel_env_check(self, dtype_fixture): # noqa: F811 ParallelEnv(2, lambda: ContinuousActionVecMockEnv(dtype=torch.float64)), DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self, dtype_fixture): # noqa: F811 t = DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]) @@ -2681,7 +2732,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): t = Compose( @@ -2701,7 +2755,10 @@ def test_trans_parallel_env_check(self): ExcludeTransform("observation_copy"), ) env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -2907,7 +2964,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): t = Compose( @@ -2927,7 +2987,10 @@ def test_trans_parallel_env_check(self): SelectTransform("observation", "observation_orig"), ) env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -3094,6 +3157,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3113,7 +3180,10 @@ def test_trans_parallel_env_check(self): -1, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_tv, reason="no torchvision") @pytest.mark.parametrize("nchannels", [1, 3]) @@ -3265,7 +3335,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3277,7 +3350,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), FrameSkipTransform(2) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = FrameSkipTransform(2) @@ -3500,7 +3576,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): out_keys = None @@ -3516,7 +3595,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), GrayScale(out_keys=out_keys)), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) def test_transform_env(self, out_keys): @@ -3589,7 +3671,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), NoopResetEnv()) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), NoopResetEnv()) @@ -3759,7 +3844,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check( self, @@ -3785,7 +3873,10 @@ def test_trans_parallel_env_check( scale=1.0, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("standard_normal", [True, False]) @pytest.mark.parametrize("in_key", ["observation", ("some_other", "observation")]) @@ -4176,13 +4267,13 @@ def make_env(): ) def test_observationnorm_stats_already_initialized_error(self): - transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1) + transform = ObservationNorm(in_keys=["next_observation"], loc=0, scale=1) with pytest.raises(RuntimeError, match="Loc/Scale are already initialized"): transform.init_stats(num_iter=11) def test_observationnorm_wrong_catdim(self): - transform = ObservationNorm(in_keys="next_observation", loc=0, scale=1) + transform = ObservationNorm(in_keys=["next_observation"], loc=0, scale=1) with pytest.raises( ValueError, match="cat_dim must be part of or equal to reduce_dim" @@ -4336,7 +4427,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4350,7 +4444,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), Resize(20, 21, in_keys=["pixels"])), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="No gym") @pytest.mark.parametrize("out_key", ["pixels", ("agents", "pixels")]) @@ -4406,7 +4503,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4418,7 +4518,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), RewardClipping(-0.1, 0.1) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("reward_key", ["reward", ("agents", "reward")]) def test_transform_no_env(self, reward_key): @@ -4535,7 +4638,10 @@ def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), RewardScaling(0.5, 1.5)) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4547,7 +4653,10 @@ def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), RewardScaling(0.5, 1.5) ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("standard_normal", [True, False]) def test_transform_no_env(self, standard_normal): @@ -4660,9 +4769,12 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) - r = env.rollout(4) - assert r["next", "episode_reward"].unique().numel() > 1 + try: + check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4678,9 +4790,12 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, ContinuousActionVecMockEnv), Compose(RewardScaling(loc=-1, scale=1), RewardSum()), ) - check_env_specs(env) - r = env.rollout(4) - assert r["next", "episode_reward"].unique().numel() > 1 + try: + check_env_specs(env) + r = env.rollout(4) + assert r["next", "episode_reward"].unique().numel() > 1 + finally: + env.close() @pytest.mark.parametrize("has_in_keys,", [True, False]) @pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3]) @@ -5320,7 +5435,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -5334,7 +5452,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, ContinuousActionVecMockEnv), UnsqueezeTransform(-1, in_keys=["observation"]), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5619,19 +5740,28 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), self._circular_transform ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), self._circular_transform ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("squeeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5821,7 +5951,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -5831,7 +5964,10 @@ def test_trans_serial_env_check(self, mode, device): TargetReturn(target_return=10.0, mode=mode).to(device), device=device, ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -5841,7 +5977,10 @@ def test_trans_parallel_env_check(self, mode, device): TargetReturn(target_return=10.0, mode=mode), device=device, ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [SerialEnv, ParallelEnv]) @@ -6047,7 +6186,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6061,7 +6203,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ToTensorImage(in_keys=["pixels"], out_keys=None), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("out_keys", [None, ["stuff"], [("nested", "stuff")]]) @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) @@ -6189,9 +6334,12 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) - assert "mykey" in env.reset().keys() - assert ("next", "mykey") in env.rollout(3).keys(True) + try: + check_env_specs(env) + assert "mykey" in env.reset().keys() + assert ("next", "mykey") in env.rollout(3).keys(True) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -6201,20 +6349,26 @@ def make_env(): ) env = SerialEnv(2, make_env) - check_env_specs(env) - assert "mykey" in env.reset().keys() - assert ("next", "mykey") in env.rollout(3).keys(True) + try: + check_env_specs(env) + assert "mykey" in env.reset().keys() + assert ("next", "mykey") in env.rollout(3).keys(True) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, ContinuousActionVecMockEnv), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), ) - check_env_specs(env) - assert "mykey" in env.reset().keys() - r = env.rollout(3) - assert ("next", "mykey") in r.keys(True) - assert r["next", "mykey"].shape == torch.Size([2, 3, 4]) + try: + check_env_specs(env) + assert "mykey" in env.reset().keys() + r = env.rollout(3) + assert ("next", "mykey") in r.keys(True) + assert r["next", "mykey"].shape == torch.Size([2, 3, 4]) + finally: + env.close() def test_trans_serial_env_check(self): with pytest.raises(RuntimeError, match="The leading shape of the primer specs"): @@ -6414,7 +6568,10 @@ def test_parallel_trans_env_check(self): ), ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6434,7 +6591,10 @@ def test_trans_parallel_env_check(self): T=3, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -6595,7 +6755,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): state_dim = 7 @@ -6610,7 +6773,10 @@ def test_trans_serial_env_check(self): SerialEnv(2, ContinuousActionVecMockEnv), gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): state_dim = 7 @@ -6619,7 +6785,10 @@ def test_trans_parallel_env_check(self): ParallelEnv(2, ContinuousActionVecMockEnv), gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): state_dim = 7 @@ -6736,7 +6905,10 @@ def test_trans_parallel_env_check(self, model, device): transformed_env = TransformedEnv( ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), vip ) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_serial_trans_env_check(self, model, device): in_keys = ["pixels"] @@ -6774,7 +6946,10 @@ def make_env(): ) transformed_env = ParallelEnv(2, make_env) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_transform_model(self, model, device): in_keys = ["pixels"] @@ -7194,7 +7369,10 @@ def test_trans_parallel_env_check(self, device): transformed_env = TransformedEnv( ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), vc1 ) - check_env_specs(transformed_env) + try: + check_env_specs(transformed_env) + finally: + transformed_env.close() def test_serial_trans_env_check(self, device): in_keys = ["pixels"] @@ -7671,8 +7849,12 @@ def test_independent_obs_specs_from_shared_env(self): observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,))) ) base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec) - t1 = TransformedEnv(base_env, transform=ObservationNorm(loc=3, scale=2)) - t2 = TransformedEnv(base_env, transform=ObservationNorm(loc=1, scale=6)) + t1 = TransformedEnv( + base_env, transform=ObservationNorm(in_keys=["observation"], loc=3, scale=2) + ) + t2 = TransformedEnv( + base_env, transform=ObservationNorm(in_keys=["observation"], loc=1, scale=6) + ) t1_obs_spec = t1.observation_spec t2_obs_spec = t2.observation_spec @@ -8122,7 +8304,9 @@ def test_batch_unlocked_with_batch_size_transformed(device): ), pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), GrayScale, - ObservationNorm, + pytest.param( + partial(ObservationNorm, in_keys=["observation"]), id="ObservationNorm" + ), pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), FiniteTensorDictCheck, @@ -8308,7 +8492,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def make_env(): return TransformedEnv( @@ -8323,7 +8510,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self, create_copy): def make_env(): @@ -8362,7 +8552,10 @@ def make_env(): create_copy=create_copy, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() env = TransformedEnv( ParallelEnv(2, make_env), RenameTransform( @@ -8373,7 +8566,10 @@ def make_env(): create_copy=create_copy, ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("mode", ["forward", "_call"]) @pytest.mark.parametrize( @@ -8572,7 +8768,10 @@ def make_env(): return env env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): def make_env(): @@ -8581,7 +8780,10 @@ def make_env(): env = SerialEnv(2, make_env) env = TransformedEnv(env, InitTracker()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): def make_env(): @@ -8590,7 +8792,10 @@ def make_env(): env = ParallelEnv(2, make_env) env = TransformedEnv(env, InitTracker()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): with pytest.raises(ValueError, match="init_key can only be of type str"): @@ -8855,19 +9060,28 @@ def make_env(): return TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): out_key = "reward" base_env = SerialEnv(2, self.envclass) env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): out_key = "reward" base_env = ParallelEnv(2, self.envclass) env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_model(self): actor = self._make_actor() @@ -9003,15 +9217,24 @@ def test_serial_trans_env_check(self): def test_parallel_trans_env_check(self): env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask()) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = ActionMask() @@ -9084,7 +9307,10 @@ def make_env(): env = ParallelEnv(2, make_env) assert env.device == torch.device("cpu:1") - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): def make_env(): @@ -9100,7 +9326,10 @@ def make_env(): env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1")) assert env.device == torch.device("cpu:1") - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = DeviceCastTransform("cpu:1", "cpu:0") @@ -9189,21 +9418,30 @@ def test_parallel_trans_env_check(self): TestPermuteTransform.envclass(), TestPermuteTransform._get_permute() ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, TestPermuteTransform.envclass), TestPermuteTransform._get_permute(), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_parallel_env_check(self): env = TransformedEnv( ParallelEnv(2, TestPermuteTransform.envclass), TestPermuteTransform._get_permute(), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) def test_transform_compose(self, batch): @@ -9352,7 +9590,10 @@ def make(): ) env = ParallelEnv(2, make) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_transform_no_env(self): t = EndOfLifeTransform() @@ -9752,7 +9993,10 @@ def make_env(): ) env = ParallelEnv(2, make_env) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_serial_trans_env_check(self): def make_env(): @@ -9776,7 +10020,10 @@ def test_trans_parallel_env_check(self): in_keys_inv=["observation_orig"], ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() def test_trans_serial_env_check(self): env = TransformedEnv( @@ -9786,7 +10033,10 @@ def test_trans_serial_env_check(self): in_keys_inv=["observation_orig"], ), ) - check_env_specs(env) + try: + check_env_specs(env) + finally: + env.close() class TestRemoveEmptySpecs(TransformBase): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index ffb8c0f5270..ef972fd343e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -26,10 +26,11 @@ from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union import numpy as np + import torch import torch.nn as nn - from tensordict import ( + is_tensor_collection, LazyStackedTensorDict, TensorDict, TensorDictBase, @@ -169,16 +170,13 @@ def _policy_is_tensordict_compatible(policy: nn.Module): and hasattr(policy, "in_keys") and hasattr(policy, "out_keys") ): - warnings.warn( - "Passing a policy that is not a TensorDictModuleBase subclass but has in_keys and out_keys " - "will soon be deprecated. We'd like to motivate our users to inherit from this class (which " - "has very few restrictions) to make the experience smoother.", - category=DeprecationWarning, + raise RuntimeError( + "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys " + "is deprecated. Users should inherit from this class (which " + "has very few restrictions) to make the experience smoother. " + "Simply change your policy from `class Policy(nn.Module)` to `Policy(tensordict.nn.TensorDictModuleBase)` " + "and this error should disappear.", ) - # if the policy is a TensorDictModule or takes a single argument and defines - # in_keys and out_keys then we assume it can already deal with TensorDict input - # to forward and we return True - return True elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): # if it's not a TensorDictModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. @@ -235,7 +233,15 @@ def _make_compatible_policy(self, policy, observation_spec=None): key: value for key, value in observation_spec.rand().items() } # we check if all the mandatory params are there - if set(sig.parameters) == {"tensordict"} or set(sig.parameters) == {"td"}: + params = list(sig.parameters.keys()) + if ( + set(sig.parameters) == {"tensordict"} + or set(sig.parameters) == {"td"} + or ( + len(params) == 1 + and is_tensor_collection(sig.parameters[params[0]].annotation) + ) + ): pass elif not required_kwargs.difference(set(next_observation)): in_keys = [str(k) for k in sig.parameters if k in next_observation] @@ -266,6 +272,7 @@ def _make_compatible_policy(self, policy, observation_spec=None): then the arguments to policy.forward must correspond one-to-one with entries in env.observation_spec that are prefixed with 'next_'. For more complex behaviour and more control you can consider writing your own TensorDictModule. +Check the collector documentation to know more about accepted policies. """ ) return policy @@ -385,6 +392,18 @@ class SyncDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total @@ -978,10 +997,10 @@ def is_private(key): # >>> assert data0["done"] is not data1["done"] yield tensordict_out.clone() - def _update_traj_ids(self, tensordict) -> None: + def _update_traj_ids(self, env_output) -> None: # we can't use the reset keys because they're gone traj_sop = _aggregate_end_of_traj( - tensordict.get("next"), done_keys=self.env.done_keys + env_output.get("next"), done_keys=self.env.done_keys ) if traj_sop.any(): traj_ids = self._shuttle.get(("collector", "traj_ids")) @@ -1230,11 +1249,23 @@ class _MultiDataCollector(DataCollectorBase): Args: create_env_fn (List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable, optional): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the @@ -2299,8 +2330,23 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable, optional): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the @@ -2497,7 +2543,7 @@ def _main_async_collector( ) -> None: pipe_parent.close() # init variables that will be cleared when closing - tensordict = data = d = data_in = inner_collector = dc_iter = None + collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None inner_collector = SyncDataCollector( create_env_fn, @@ -2571,42 +2617,45 @@ def _main_async_collector( else: inner_collector.init_random_frames = -1 - d = next(dc_iter) + next_data = next(dc_iter) if pipe_child.poll(_MIN_TIMEOUT): # in this case, main send a message to the worker while it was busy collecting trajectories. # In that case, we skip the collected trajectory and get the message from main. This is faster than # sending the trajectory in the queue until timeout when it's never going to be received. continue if j == 0: - tensordict = d - if storing_device is not None and tensordict.device != storing_device: + collected_tensordict = next_data + if ( + storing_device is not None + and collected_tensordict.device != storing_device + ): raise RuntimeError( - f"expected device to be {storing_device} but got {tensordict.device}" + f"expected device to be {storing_device} but got {collected_tensordict.device}" ) # If policy and env are on cpu, we put in shared mem, # if policy is on cuda and env on cuda, we are fine with this # If policy is on cuda and env on cpu (or opposite) we put tensors that # are on cpu in shared mem. - if tensordict.device is not None: + if collected_tensordict.device is not None: # placehoder in case we need different behaviours - if tensordict.device.type in ("cpu", "mps"): - tensordict.share_memory_() - elif tensordict.device.type == "cuda": - tensordict.share_memory_() + if collected_tensordict.device.type in ("cpu", "mps"): + collected_tensordict.share_memory_() + elif collected_tensordict.device.type == "cuda": + collected_tensordict.share_memory_() else: raise NotImplementedError( - f"Device {tensordict.device} is not supported in multi-collectors yet." + f"Device {collected_tensordict.device} is not supported in multi-collectors yet." ) else: # make sure each cpu tensor is shared - assuming non-cpu devices are shared - tensordict.apply( + collected_tensordict.apply( lambda x: x.share_memory_() if x.device.type in ("cpu", "mps") else x ) - data = (tensordict, idx) + data = (collected_tensordict, idx) else: - if d is not tensordict: + if next_data is not collected_tensordict: raise RuntimeError( "SyncDataCollector should return the same tensordict modified in-place." ) @@ -2661,7 +2710,7 @@ def _main_async_collector( continue elif msg == "close": - del tensordict, data, d, data_in + del collected_tensordict, data, next_data, data_in inner_collector.shutdown() del inner_collector, dc_iter pipe_child.send("closed") diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 073d2f445ab..0c5c74b6510 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -260,8 +260,20 @@ class DistributedDataCollector(DataCollectorBase): policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index fa2d8e8191e..6788e48ee3a 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -118,8 +118,23 @@ class RayCollector(DataCollectorBase): Args: create_env_fn (Callable or List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 50729038b4a..dbfc5a7dfd9 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -101,8 +101,20 @@ class RPCDataCollector(DataCollectorBase): policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index d7a5c94487d..7ea805248c9 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -141,8 +141,20 @@ class DistributedSyncDataCollector(DataCollectorBase): policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: frames_per_batch (int): A keyword-only argument representing the total diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 10b9767de8e..468fcb9150c 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -146,7 +146,7 @@ def __init__( prefetch: int | None = None, transform: "torchrl.envs.Transform" | None = None, # noqa-F821 split_trajs: bool = False, - from_env: bool = None, + from_env: bool = False, use_truncated_as_done: bool = True, direct_download: bool = None, terminate_on_end: bool = None, @@ -165,29 +165,16 @@ def __init__( direct_download = not self._has_d4rl if not direct_download: - if from_env is None: - warnings.warn( - "from_env will soon default to ``False``, ie the data will be " - "downloaded without relying on d4rl by default. " - "For now, ``True`` will still be the default. " - "To disable this warning, explicitly pass the ``from_env`` argument " - "during construction of the dataset.", - category=DeprecationWarning, - ) - from_env = True - else: - warnings.warn( - "You are using the D4RL library for collecting data. " - "We advise against this use, as D4RL formatting can be " - "inconsistent. " - "To download the D4RL data without the D4RL library, use " - "direct_download=True in the dataset constructor. " - "Recurring to `direct_download=False` will soon be deprecated." - ) + warnings.warn( + "You are using the D4RL library for collecting data. " + "We advise against this use, as D4RL formatting can be " + "inconsistent. " + "To download the D4RL data without the D4RL library, use " + "direct_download=True in the dataset constructor. " + "Recurring to `direct_download=False` will soon be deprecated." + ) self.from_env = from_env else: - if from_env is None: - from_env = False self.from_env = from_env if (download == "force") or (download and not self._is_downloaded()): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index fd847f25c74..45c7be64a1a 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1087,7 +1087,7 @@ def _reset_batch_size(x): shape = x.get("_rb_batch_size", None) if shape is not None: warnings.warn( - "Reshaping nested tensordicts will be deprecated soon.", + "Reshaping nested tensordicts will be deprecated in v0.4.0.", category=DeprecationWarning, ) data = x.get("_data") diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1cfc970e61f..efe928856b9 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -379,7 +379,7 @@ def high(self, value): @property def minimum(self): warnings.warn( - f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low", + f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0", category=DeprecationWarning, ) return self._low.to(self.device) @@ -387,7 +387,7 @@ def minimum(self): @property def maximum(self): warnings.warn( - f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high", + f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0", category=DeprecationWarning, ) return self._high.to(self.device) @@ -1521,7 +1521,7 @@ class BoundedTensorSpec(TensorSpec): # SPEC_HANDLED_FUNCTIONS = {} DEPRECATED_KWARGS = ( "The `minimum` and `maximum` keyword arguments are now " - "deprecated in favour of `low` and `high`." + "deprecated in favour of `low` and `high` in v0.4.0." ) CONFLICTING_KWARGS = ( "The keyword arguments {} and {} conflict. Only one of these can be passed." diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 3ce3d2d630c..002b270cd84 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -520,7 +520,7 @@ def info_dict_reader(self, value: callable): warnings.warn( f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. " f"This method will append a reader to the list of existing readers (if any). " - f"Setting info_dict_reader directly will be soon deprecated.", + f"Setting info_dict_reader directly will be deprecated in v0.4.0.", category=DeprecationWarning, ) self._info_dict_reader.append(value) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e59c481419c..a661b152d39 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2372,15 +2372,9 @@ def __init__( standard_normal: bool = False, ): if in_keys is None: - warnings.warn( - "Not passing in_keys to ObservationNorm will soon be deprecated. " - "Ensure you specify the entries to be normalized", - category=DeprecationWarning, + raise RuntimeError( + "Not passing in_keys to ObservationNorm is a deprecated behaviour." ) - in_keys = [ - "observation", - "pixels", - ] if out_keys is None: out_keys = copy(in_keys) @@ -2719,7 +2713,7 @@ def __init__( raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}") if padding == "zeros": warnings.warn( - "Padding option 'zeros' will be deprecated in the future. " + "Padding option 'zeros' will be deprecated in v0.4.0. " "Please use 'constant' padding with padding_value 0 instead.", category=DeprecationWarning, ) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 1cc10316045..c610bb61350 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -872,9 +872,6 @@ class DistributionalDQNnet(TensorDictModuleBase): """Distributional Deep Q-Network. Args: - DQNet (nn.Module): (deprecated) Q-Network with output length equal - to the number of atoms: - output.shape = [*batch, atoms, actions]. in_keys (list of str or tuples of str): input keys to the log-softmax operation. Defaults to ``["action_value"]``. out_keys (list of str or tuples of str): output keys to the log-softmax @@ -888,11 +885,11 @@ class DistributionalDQNnet(TensorDictModuleBase): "instead." ) - def __init__(self, DQNet: nn.Module = None, in_keys=None, out_keys=None): + def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): super().__init__() if DQNet is not None: warnings.warn( - f"Passing a network to {type(self)} is going to be deprecated.", + f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.", category=DeprecationWarning, ) if not ( @@ -1280,7 +1277,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, ) -> None: warnings.warn( - "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.", + "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", category=DeprecationWarning, ) super().__init__() diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index bf81cfd5dfd..b7a044cae7d 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -445,7 +445,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -825,7 +825,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -922,7 +922,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -1043,7 +1043,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) @@ -1189,7 +1189,7 @@ def __init__( ): if isinstance(action_space, TensorSpec): warnings.warn( - "Using specs in action_space will be deprecated soon," + "Using specs in action_space will be deprecated in v0.4.0," " please use the 'spec' argument if you want to provide an action spec", category=DeprecationWarning, ) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index c8fa9cc040f..9a7f88844cc 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -247,105 +247,10 @@ def __init__( action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): - warnings.warn( - "EGreedyWrapper is deprecated and it will be removed in v0.3. " - "Please use torchrl.modules.EGreedyModule instead.", - category=DeprecationWarning, + raise RuntimeError( + "This class is not removed in favour of torchrl.modules.EGreedyModule." ) - super().__init__(policy) - self.register_buffer("eps_init", torch.as_tensor([eps_init])) - self.register_buffer("eps_end", torch.as_tensor([eps_end])) - if self.eps_end > self.eps_init: - raise RuntimeError("eps should decrease over time or be constant") - self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) - self.action_key = action_key - self.action_mask_key = action_mask_key - if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) - self._spec = spec - elif hasattr(self.td_module, "_spec"): - self._spec = self.td_module._spec.clone() - if action_key not in self._spec.keys(): - self._spec[action_key] = None - elif hasattr(self.td_module, "spec"): - self._spec = self.td_module.spec.clone() - if action_key not in self._spec.keys(): - self._spec[action_key] = None - else: - self._spec = spec - - @property - def spec(self): - return self._spec - - def step(self, frames: int = 1) -> None: - """A step of epsilon decay. - - After self.annealing_num_steps, this function is a no-op. - - Args: - frames (int): number of frames since last step. - - """ - for _ in range(frames): - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), - ) - - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = self.td_module.forward(tensordict) - if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: - if isinstance(self.action_key, tuple) and len(self.action_key) > 1: - action_tensordict = tensordict.get(self.action_key[:-1]) - action_key = self.action_key[-1] - else: - action_tensordict = tensordict - action_key = self.action_key - - out = action_tensordict.get(action_key) - eps = self.eps.item() - cond = ( - torch.rand(action_tensordict.shape, device=action_tensordict.device) - < eps - ).to(out.dtype) - cond = expand_as_right(cond, out) - spec = self.spec - if spec is not None: - if isinstance(spec, CompositeSpec): - spec = spec[self.action_key] - if spec.shape != out.shape: - # In batched envs if the spec is passed unbatched, the rand() will not - # cover all batched dims - if ( - not len(spec.shape) - or out.shape[-len(spec.shape) :] == spec.shape - ): - spec = spec.expand(out.shape) - else: - raise ValueError( - "Action spec shape does not match the action shape" - ) - if self.action_mask_key is not None: - action_mask = tensordict.get(self.action_mask_key, None) - if action_mask is None: - raise KeyError( - f"Action mask key {self.action_mask_key} not found in {tensordict}." - ) - spec.update_mask(action_mask) - out = cond * spec.rand().to(out.device) + (1 - cond) * out - else: - raise RuntimeError( - "spec must be provided by the policy or directly to the exploration wrapper." - ) - action_tensordict.set(action_key, out) - return tensordict - class AdditiveGaussianWrapper(TensorDictModuleWrapper): """Additive Gaussian PO wrapper. diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index b05cbd55356..13cbd05e877 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from typing import Optional, Tuple import torch @@ -555,11 +554,9 @@ def recurrent_mode(self, value): @property def temporal_mode(self): - warnings.warn( + raise RuntimeError( "temporal_mode is deprecated, use recurrent_mode instead.", - category=DeprecationWarning, ) - return self.recurrent_mode def set_recurrent_mode(self, mode: bool = True): """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). @@ -1255,11 +1252,9 @@ def recurrent_mode(self, value): @property def temporal_mode(self): - warnings.warn( + raise RuntimeError( "temporal_mode is deprecated, use recurrent_mode instead.", - category=DeprecationWarning, ) - return self.recurrent_mode def set_recurrent_mode(self, mode: bool = True): """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index de963bcfdb9..6edcda5c800 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import contextlib -import logging import warnings from copy import deepcopy from dataclasses import dataclass @@ -18,7 +17,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -288,8 +287,7 @@ def __init__( ) self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device)) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.loss_critic_type = loss_critic_type @property @@ -298,41 +296,46 @@ def functional(self): @property def actor(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network @property def critic(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network @property def actor_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network_params @property def critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network_params @property def target_critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.target_critic_network_params diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 04a7708e7db..1f5edcf26ed 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -113,14 +113,11 @@ def __init__(self): # self.register_forward_pre_hook(_parameters_to_tensordict) def _set_deprecated_ctor_keys(self, **kwargs) -> None: - """Helper function to set a tensordict key from a constructor and raise a warning simultaneously.""" for key, value in kwargs.items(): if value is not None: - warnings.warn( + raise RuntimeError( f"Setting '{key}' via the constructor is deprecated, use .set_keys(='some_key') instead.", - category=DeprecationWarning, ) - self.set_keys(**{key: value}) def set_keys(self, **kwargs) -> None: """Set tensordict key names. @@ -217,7 +214,8 @@ def convert_to_functional( """ if kwargs.pop("funs_to_decorate", None) is not None: warnings.warn( - "funs_to_decorate is without effect with the new objective API.", + "funs_to_decorate is without effect with the new objective API. This " + "warning will be replaced by an error in v0.4.0.", category=DeprecationWarning, ) if kwargs: diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 3d90d0174b9..f963f0e0b52 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -26,7 +26,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -332,8 +332,7 @@ def __init__( self.target_entropy_buffer = None if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.temperature = temperature self.min_q_weight = min_q_weight @@ -1030,8 +1029,7 @@ def __init__( self.action_space = _find_action_space(action_space) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 70239ea62e9..6572084c8ec 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -5,7 +5,6 @@ from __future__ import annotations -import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -19,7 +18,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -230,8 +229,7 @@ def __init__( self.loss_function = loss_function if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 6ef7ab7386e..3ff093d445c 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math -import warnings from dataclasses import dataclass from numbers import Number from typing import Tuple, Union @@ -22,7 +21,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator @@ -202,8 +201,7 @@ def __init__( self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @property def target_entropy(self): diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 623c3f7189a..37fd1cbdaea 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -24,7 +24,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -224,8 +224,7 @@ def __init__( self.action_space = _find_action_space(action_space) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 7bdfde573fa..9fd8a8a0bd2 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -15,7 +14,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, hold_out_net, @@ -247,11 +246,9 @@ def __init__( self.imagination_horizon = imagination_horizon self.discount_loss = discount_loss if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) if lmbda is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.lmbda = lmbda + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 1fd48675cb4..62d2a628af4 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -17,7 +17,7 @@ from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -285,22 +285,15 @@ def __init__( self.loss_function = loss_function if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) @property def device(self) -> torch.device: - warnings.warn( - "The device attributes of the looses will be deprecated in v0.3.", - category=DeprecationWarning, - ) - for p in self.parameters(): - return p.device raise RuntimeError( - "At least one of the networks of SACLoss must have trainable " "parameters." + "The device attributes of the losses is deprecated since v0.3.", ) def _set_in_keys(self): @@ -407,7 +400,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: ) # assert has no gradient exp_a = torch.exp((min_q - value) * self.temperature) - exp_a = torch.min(exp_a, torch.FloatTensor([100.0]).to(self.device)) + exp_a = exp_a.clamp_max(100) # write log_prob in tensordict for alpha loss tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) @@ -775,7 +768,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: ) # assert has no gradient exp_a = torch.exp((min_Q - value) * self.temperature) - exp_a = torch.min(exp_a, torch.FloatTensor([100.0]).to(self.device)) + exp_a = exp_a.clamp_max(100) # write log_prob in tensordict for alpha loss tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 38f56108784..f7b9307a962 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -28,7 +28,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -265,8 +265,7 @@ def __init__( self.action_space = _find_action_space(action_space) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 0f7ea835949..ac2244b9a23 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import logging import math import warnings @@ -21,7 +20,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -331,8 +330,7 @@ def __init__( self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._set_deprecated_ctor_keys( advantage=advantage_key, value_target=value_target_key, @@ -363,7 +361,7 @@ def critic(self): @property def actor_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " "link will be removed in v0.4.", category=DeprecationWarning, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index af0a94cbc96..61aaf5990e4 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math -import warnings from dataclasses import dataclass from numbers import Number from typing import Union @@ -21,7 +20,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -313,8 +312,7 @@ def __init__( self.gSDE = gSDE if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index c9cc8f383ad..4613810d0d3 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import logging import warnings from copy import deepcopy from dataclasses import dataclass @@ -17,7 +16,7 @@ from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, distance_loss, ValueEstimators, @@ -281,8 +280,7 @@ def __init__( self.target_critic_network_params = None if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @property def functional(self): @@ -290,41 +288,46 @@ def functional(self): @property def actor(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network @property def critic(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network @property def actor_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.actor_network_params @property def critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.critic_network_params @property def target_critic_params(self): - logging.warning( + warnings.warn( f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4." + "link will be removed in v0.4.", + category=DeprecationWarning, ) return self.target_critic_network_params diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 431296e7486..053da9e53d2 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -25,7 +25,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -374,8 +374,7 @@ def __init__( self.actor_network, self.value_network ) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index e1aeb253681..877a8f0c819 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -18,7 +17,7 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_WARNING, + _GAMMA_LMBDA_DEPREC_ERROR, _vmap_func, default_value_kwargs, distance_loss, @@ -293,8 +292,7 @@ def __init__( self.register_buffer("max_action", high) self.register_buffer("min_action", low) if gamma is not None: - warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) - self.gamma = gamma + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4c0b8ae67bd..43dfa65c0c4 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -25,9 +25,9 @@ raise err_ft from err from torchrl.envs.utils import step_mdp -_GAMMA_LMBDA_DEPREC_WARNING = ( +_GAMMA_LMBDA_DEPREC_ERROR = ( "Passing gamma / lambda parameters through the loss constructor " - "is deprecated and will be removed soon. To customize your value function, " + "is a deprecated feature. To customize your value function, " "run `loss_module.make_value_estimator(ValueEstimators., gamma=val)`." ) @@ -299,9 +299,8 @@ def __init__( tau: Optional[float] = None, ): if eps is None and tau is None: - warnings.warn( - "Neither eps nor tau was provided. Taking the default value " - "eps=0.999. This behaviour will soon be deprecated.", + raise RuntimeError( + "Neither eps nor tau was provided. " "This behaviour is deprecated.", category=DeprecationWarning, ) eps = 0.999 diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index fc2e58a19f6..dfa56e5c672 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -299,23 +299,17 @@ def __init__( self.shifted = shifted if advantage_key is not None: - warnings.warn( - "Setting 'advantage_key' via ctor is deprecated, use .set_keys(advantage_key='some_key') instead.", - category=DeprecationWarning, + raise RuntimeError( + "Setting 'advantage_key' via constructor is deprecated, use .set_keys(advantage_key='some_key') instead.", ) - self.dep_keys["advantage"] = advantage_key if value_target_key is not None: - warnings.warn( - "Setting 'value_target_key' via ctor is deprecated, use .set_keys(value_target_key='some_key') instead.", - category=DeprecationWarning, + raise RuntimeError( + "Setting 'value_target_key' via constructor is deprecated, use .set_keys(value_target_key='some_key') instead.", ) - self.dep_keys["value_target"] = value_target_key if value_key is not None: - warnings.warn( - "Setting 'value_key' via ctor is deprecated, use .set_keys(value_key='some_key') instead.", - category=DeprecationWarning, + raise RuntimeError( + "Setting 'value_key' via constructor is deprecated, use .set_keys(value_key='some_key') instead.", ) - self.dep_keys["value"] = value_key @property def tensor_keys(self) -> _AcceptedKeys: diff --git a/torchrl/trainers/helpers/logger.py b/torchrl/trainers/helpers/logger.py index 6e4e864aa2e..b0b37533519 100644 --- a/torchrl/trainers/helpers/logger.py +++ b/torchrl/trainers/helpers/logger.py @@ -28,3 +28,5 @@ class LoggerConfig: # Keys to log in the recorder offline_logging: bool = True # If True, Wandb will do the logging offline + project_name: str = "" + # The name of the project for WandB diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 96f8d98477d..0764bf9fb72 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -667,9 +667,11 @@ def __init__( self.device = device if flatten_tensordicts is None: warnings.warn( - "flatten_tensordicts default value will soon be changed " + "flatten_tensordicts default value has now changed " "to False for a faster execution. Make sure your " - "code is robust to this change.", + "code is robust to this change. To silence this warning, " + "pass flatten_tensordicts= in your code. " + "This warning will be removed in v0.4.", category=DeprecationWarning, ) flatten_tensordicts = True From d930f5c6b69aa7c3ae2bf7040363aa072d2e745e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 31 Jan 2024 21:23:23 +0000 Subject: [PATCH 27/35] [Feature] Logger (#1858) --- .../unittest/helpers/coverage_run_parallel.py | 3 +- benchmarks/benchmark_batched_envs.py | 13 ++-- benchmarks/conftest.py | 4 +- ...s_rllib_vs_torchrl_sampling_performance.py | 8 +-- .../benchmark_sample_latency_over_rpc.py | 12 ++-- examples/a2c/a2c_atari.py | 5 +- examples/a2c/a2c_mujoco.py | 5 +- examples/a2c/utils_atari.py | 4 +- examples/a2c/utils_mujoco.py | 4 +- examples/bandits/dqn.py | 2 +- examples/cql/cql_offline.py | 4 +- examples/cql/cql_online.py | 4 +- examples/cql/discrete_cql_online.py | 4 +- examples/ddpg/ddpg.py | 4 +- examples/decision_transformer/dt.py | 6 +- examples/decision_transformer/online_dt.py | 6 +- examples/discrete_sac/discrete_sac.py | 4 +- .../collectors/multi_nodes/delayed_dist.py | 4 +- .../collectors/multi_nodes/delayed_rpc.py | 4 +- .../collectors/multi_nodes/generic.py | 4 +- .../distributed/collectors/multi_nodes/ray.py | 5 +- .../collectors/multi_nodes/ray_train.py | 8 +-- .../distributed/collectors/multi_nodes/rpc.py | 4 +- .../collectors/multi_nodes/sync.py | 4 +- .../collectors/single_machine/generic.py | 4 +- .../collectors/single_machine/rpc.py | 4 +- .../collectors/single_machine/sync.py | 4 +- .../distributed_replay_buffer.py | 35 +++++---- examples/dqn/dqn_atari.py | 4 +- examples/dqn/dqn_cartpole.py | 4 +- examples/dreamer/dreamer.py | 8 +-- examples/impala/impala_multi_node_ray.py | 5 +- examples/impala/impala_multi_node_submitit.py | 5 +- examples/impala/impala_single_node.py | 5 +- examples/iql/discrete_iql.py | 4 +- examples/iql/iql_offline.py | 4 +- examples/iql/iql_online.py | 4 +- examples/iql/offline_config.yaml | 2 +- examples/iql/utils.py | 4 +- examples/multiagent/iql.py | 4 +- examples/multiagent/maddpg_iddpg.py | 4 +- examples/multiagent/mappo_ippo.py | 4 +- examples/multiagent/qmix_vdn.py | 4 +- examples/multiagent/sac.py | 4 +- examples/ppo/ppo_atari.py | 5 +- examples/ppo/ppo_mujoco.py | 5 +- examples/redq/utils.py | 8 +-- examples/rlhf/models/reward.py | 4 +- examples/rlhf/models/transformer.py | 5 +- examples/rlhf/train.py | 8 +-- examples/rlhf/train_reward.py | 8 +-- examples/sac/sac.py | 4 +- examples/td3/td3.py | 4 +- test/_utils_internal.py | 5 +- test/test_collector.py | 5 +- test/test_distributed.py | 6 +- test/test_libs.py | 25 +++---- test/test_rb_distributed.py | 4 +- test/test_shared.py | 20 +++--- torchrl/_utils.py | 25 +++++-- torchrl/collectors/collectors.py | 16 ++--- torchrl/collectors/distributed/generic.py | 71 ++++++++++--------- torchrl/collectors/distributed/ray.py | 13 ++-- torchrl/collectors/distributed/rpc.py | 41 ++++++----- torchrl/collectors/distributed/sync.py | 26 +++---- torchrl/collectors/distributed/utils.py | 9 ++- torchrl/data/datasets/atari_dqn.py | 8 +-- torchrl/data/datasets/d4rl.py | 10 +-- torchrl/data/datasets/gen_dgrl.py | 8 +-- torchrl/data/datasets/minari_data.py | 15 ++-- torchrl/data/datasets/roboset.py | 27 +++---- torchrl/data/datasets/vd4rl.py | 7 +- torchrl/data/replay_buffers/storages.py | 16 +++-- torchrl/data/rlhf/dataset.py | 4 +- torchrl/envs/batched_envs.py | 20 +++--- torchrl/envs/env_creator.py | 5 +- torchrl/envs/gym_like.py | 4 +- torchrl/envs/libs/dm_control.py | 5 +- torchrl/envs/libs/envpool.py | 6 +- torchrl/envs/transforms/vc1.py | 4 +- torchrl/envs/utils.py | 5 +- torchrl/record/loggers/csv.py | 2 +- torchrl/trainers/helpers/envs.py | 7 +- torchrl/trainers/helpers/models.py | 2 +- torchrl/trainers/helpers/trainers.py | 5 +- torchrl/trainers/trainers.py | 10 ++- 86 files changed, 372 insertions(+), 356 deletions(-) diff --git a/.github/unittest/helpers/coverage_run_parallel.py b/.github/unittest/helpers/coverage_run_parallel.py index ca156b72c2d..8c6251cf82b 100644 --- a/.github/unittest/helpers/coverage_run_parallel.py +++ b/.github/unittest/helpers/coverage_run_parallel.py @@ -11,7 +11,6 @@ nevertheless. It writes temporary coverage config files on the fly and invokes coverage with proper arguments """ -import logging import os import shlex import subprocess @@ -45,7 +44,7 @@ def write_config(config_path: Path, argv: List[str]) -> None: def main(argv: List[str]) -> int: if len(argv) < 1: - logging.info( + print( # noqa "Usage: 'python coverage_run_parallel.py [command arguments]'" ) sys.exit(1) diff --git a/benchmarks/benchmark_batched_envs.py b/benchmarks/benchmark_batched_envs.py index 3c21372a369..d207778d56c 100644 --- a/benchmarks/benchmark_batched_envs.py +++ b/benchmarks/benchmark_batched_envs.py @@ -15,11 +15,8 @@ """ -import logging - -logging.basicConfig(level=logging.ERROR) -logging.captureWarnings(True) import pandas as pd +from torchrl._utils import logger as torchrl_logger pd.set_option("display.max_columns", 100) pd.set_option("display.width", 1000) @@ -68,8 +65,8 @@ def run_env(env): devices.append("cuda") for device in devices: for num_workers in [1, 4, 16]: - logging.info(f"With num_workers={num_workers}, {device}") - logging.info("Multithreaded...") + torchrl_logger.info(f"With num_workers={num_workers}, {device}") + torchrl_logger.info("Multithreaded...") env_multithreaded = create_multithreaded(num_workers, device) res_multithreaded = Timer( stmt="run_env(env)", @@ -78,7 +75,7 @@ def run_env(env): ) time_multithreaded = res_multithreaded.blocked_autorange().mean - logging.info("Serial...") + torchrl_logger.info("Serial...") env_serial = create_serial(num_workers, device) res_serial = Timer( stmt="run_env(env)", @@ -87,7 +84,7 @@ def run_env(env): ) time_serial = res_serial.blocked_autorange().mean - logging.info("Parallel...") + torchrl_logger.info("Parallel...") env_parallel = create_parallel(num_workers, device) res_parallel = Timer( stmt="run_env(env)", diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py index bec558ac92d..307839afa3d 100644 --- a/benchmarks/conftest.py +++ b/benchmarks/conftest.py @@ -2,13 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import os import time import warnings from collections import defaultdict import pytest +from torchrl._utils import logger as torchrl_logger CALL_TIMES = defaultdict(lambda: 0.0) @@ -32,7 +32,7 @@ def pytest_sessionfinish(maxprint=50): out_str += f"\t{key}{spaces}{item: 4.4f}s\n" if i == maxprint - 1: break - logging.info(out_str) + torchrl_logger.info(out_str) @pytest.fixture(autouse=True) diff --git a/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py b/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py index 02526095a60..8b44599b656 100644 --- a/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +++ b/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. -import logging import os import pickle @@ -23,6 +22,7 @@ from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.tune import register_env +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.envs.libs.vmas import VmasEnv from vmas import Wrapper @@ -165,11 +165,11 @@ def run_comparison_torchrl_rllib( evaluation = {} for framework in ["TorchRL", "RLlib"]: if framework not in evaluation.keys(): - logging.info(f"\nFramework {framework}") + torchrl_logger.info(f"\nFramework {framework}") vmas_times = [] for n_envs in list_n_envs: n_envs = int(n_envs) - logging.info(f"Running {n_envs} environments") + torchrl_logger.info(f"Running {n_envs} environments") if framework == "TorchRL": vmas_times.append( (n_envs * n_steps) @@ -190,7 +190,7 @@ def run_comparison_torchrl_rllib( device=device, ) ) - logging.info(f"fps {vmas_times[-1]}s") + torchrl_logger.info(f"fps {vmas_times[-1]}s") evaluation[framework] = vmas_times store_pickled_evaluation(name=figure_name_pkl, evaluation=evaluation) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 693cbb9a462..4af76440290 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -14,7 +14,6 @@ This code is based on examples/distributed/distributed_replay_buffer.py. """ import argparse -import logging import os import pickle import sys @@ -25,6 +24,7 @@ import torch import torch.distributed.rpc as rpc from tensordict import TensorDict +from torchrl._utils import logger as torchrl_logger from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import ( @@ -106,10 +106,10 @@ def _create_replay_buffer(self) -> rpc.RRef: buffer_rref = rpc.remote( replay_buffer_info, ReplayBufferNode, args=(1000000,) ) - logging.info(f"Connected to replay buffer {replay_buffer_info}") + torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}") return buffer_rref except Exception: - logging.info("Failed to connect to replay buffer") + torchrl_logger.info("Failed to connect to replay buffer") time.sleep(RETRY_DELAY_SECS) @@ -144,7 +144,7 @@ def __init__(self, capacity: int): rank = args.rank storage_type = args.storage - logging.info(f"Rank: {rank}; Storage: {storage_type}") + torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" @@ -167,7 +167,7 @@ def __init__(self, capacity: int): if i == 0: continue results.append(result) - logging.info(i, results[-1]) + torchrl_logger.info(f"{i}, {results[-1]}") with open( f'./benchmark_{datetime.now().strftime("%d-%m-%Y%H:%M:%S")};batch_size={BATCH_SIZE};tensor_size={TENSOR_SIZE};repeat={REPEATS};storage={storage_type}.pkl', @@ -176,7 +176,7 @@ def __init__(self, capacity: int): pickle.dump(results, f) tensor_results = torch.tensor(results) - logging.info(f"Mean: {torch.mean(tensor_results)}") + torchrl_logger.info(f"Mean: {torch.mean(tensor_results)}") breakpoint() elif rank == 1: # rank 1 is the replay buffer diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py index f329bf7b120..d6e78ad1575 100644 --- a/examples/a2c/a2c_atari.py +++ b/examples/a2c/a2c_atari.py @@ -2,9 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main(config_path=".", config_name="config_atari", version_base="1.1") @@ -220,7 +219,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py index 2f38af032a8..6a95814fe4e 100644 --- a/examples/a2c/a2c_mujoco.py +++ b/examples/a2c/a2c_mujoco.py @@ -2,9 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") @@ -205,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/a2c/utils_atari.py b/examples/a2c/utils_atari.py index 7b3625b1e2b..89a51f7e64b 100644 --- a/examples/a2c/utils_atari.py +++ b/examples/a2c/utils_atari.py @@ -98,8 +98,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, + "min": proof_environment.action_spec.space.low, + "max": proof_environment.action_spec.space.high, } # Define input keys diff --git a/examples/a2c/utils_mujoco.py b/examples/a2c/utils_mujoco.py index cdc681da522..50780a9d086 100644 --- a/examples/a2c/utils_mujoco.py +++ b/examples/a2c/utils_mujoco.py @@ -51,8 +51,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, + "min": proof_environment.action_spec.space.low, + "max": proof_environment.action_spec.space.high, "tanh_loc": False, } diff --git a/examples/bandits/dqn.py b/examples/bandits/dqn.py index 0d9ca828ee6..55ba34f5010 100644 --- a/examples/bandits/dqn.py +++ b/examples/bandits/dqn.py @@ -122,4 +122,4 @@ f"training reward {data['next', 'reward'].sum() / env.numel() : 4.4f}, " f"loss {loss_val: 4.4f} (init: {init_loss: 4.4f})" ) - policy.step() + policy[1].step() diff --git a/examples/cql/cql_offline.py b/examples/cql/cql_offline.py index 8f1dc5e3897..e0f59a5f406 100644 --- a/examples/cql/cql_offline.py +++ b/examples/cql/cql_offline.py @@ -9,13 +9,13 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra import numpy as np import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - logging.info(f"Training time: {time.time() - start_time}") + torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/cql/cql_online.py b/examples/cql/cql_online.py index c42e733c31b..fec979e2289 100644 --- a/examples/cql/cql_online.py +++ b/examples/cql/cql_online.py @@ -11,7 +11,6 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra @@ -19,6 +18,7 @@ import torch import tqdm from tensordict import TensorDict +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -211,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") collector.shutdown() diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py index 107739f3aba..cb15919e252 100644 --- a/examples/cql/discrete_cql_online.py +++ b/examples/cql/discrete_cql_online.py @@ -10,7 +10,6 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra @@ -18,6 +17,7 @@ import torch import torch.cuda import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -196,7 +196,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 92fdd850fbd..e10507cc7f3 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -10,7 +10,6 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra @@ -19,6 +18,7 @@ import torch import torch.cuda import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 894562185d9..9dc4c855f30 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -6,13 +6,13 @@ This is a self-contained example of an offline Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra import numpy as np import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -79,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pretrain_log_interval = cfg.logger.pretrain_log_interval reward_scaling = cfg.env.reward_scaling - logging.info(" ***Pretraining*** ") + torchrl_logger.info(" ***Pretraining*** ") # Pretraining start_time = time.time() for i in range(pretrain_gradient_steps): @@ -116,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - logging.info(f"Training time: {time.time() - start_time}") + torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index a1df18e5fe6..0ea70c73093 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -6,13 +6,13 @@ This is a self-contained example of an Online Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra import numpy as np import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -81,7 +81,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pretrain_log_interval = cfg.logger.pretrain_log_interval reward_scaling = cfg.env.reward_scaling - logging.info(" ***Pretraining*** ") + torchrl_logger.info(" ***Pretraining*** ") # Pretraining start_time = time.time() for i in range(pretrain_gradient_steps): @@ -132,7 +132,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - logging.info(f"Training time: {time.time() - start_time}") + torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 16c5de80a64..6bc4ad91d1a 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -10,7 +10,6 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra @@ -18,6 +17,7 @@ import torch import torch.cuda import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -213,7 +213,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index e026912f698..9bf17b76c10 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -23,11 +23,11 @@ and DEFAULT_SLURM_CONF_MAIN dictionaries below). """ -import logging import time from argparse import ArgumentParser import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.distributed import submitit_delayed_launcher from torchrl.collectors.distributed.default_configs import ( @@ -150,7 +150,7 @@ def make_env(): if i == 10: t0 = time.time() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") collector.shutdown() exit() diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index 0f38d898dfc..890968c5aae 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -23,11 +23,11 @@ and DEFAULT_SLURM_CONF_MAIN dictionaries below). """ -import logging import time from argparse import ArgumentParser import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.distributed import RPCDataCollector, submitit_delayed_launcher from torchrl.collectors.distributed.default_configs import ( @@ -148,7 +148,7 @@ def make_env(): if i == 10: t0 = time.time() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") collector.shutdown() exit() diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index 07c83ba98fb..9338a0acea7 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -2,13 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time from argparse import ArgumentParser import gym import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import ( MultiSyncDataCollector, @@ -128,5 +128,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/multi_nodes/ray.py b/examples/distributed/collectors/multi_nodes/ray.py index 21d550281a2..e70b0a58d09 100644 --- a/examples/distributed/collectors/multi_nodes/ray.py +++ b/examples/distributed/collectors/multi_nodes/ray.py @@ -7,10 +7,9 @@ This example should create 3 collector instances, 1 local and 2 remote, but 4 instances seem to be created. Why? """ -import logging - from tensordict.nn import TensorDictModule from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.distributed.ray import RayCollector from torchrl.envs.libs.gym import GymEnv @@ -45,4 +44,4 @@ def env_maker(): for batch in distributed_collector: counter += 1 num_frames += batch.shape.numel() - logging.info(f"batch {counter}, total frames {num_frames}") + torchrl_logger.info(f"batch {counter}, total frames {num_frames}") diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 7d456367a5a..b05e92619fa 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -5,7 +5,6 @@ This script reproduces the PPO example in https://pytorch.org/rl/tutorials/coding_ppo.html with a RayCollector. """ -import logging from collections import defaultdict import matplotlib.pyplot as plt @@ -13,6 +12,7 @@ from tensordict.nn import TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.collectors.distributed.ray import RayCollector from torchrl.data.replay_buffers import ReplayBuffer @@ -85,8 +85,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.minimum, - "max": env.action_spec.space.maximum, + "min": env.action_spec.space.low, + "max": env.action_spec.space.high, }, return_log_prob=True, ) @@ -235,4 +235,4 @@ plt.title("Max step count (test)") save_name = "/tmp/results.jpg" plt.savefig(save_name) - logging.info(f"results saved in {save_name}") + torchrl_logger.info(f"results saved in {save_name}") diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index 2fdbdc47a4c..be30b9c3668 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time from argparse import ArgumentParser @@ -10,6 +9,7 @@ import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import ( MultiSyncDataCollector, @@ -116,5 +116,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index 65b93beb294..688090ca691 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -2,13 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time from argparse import ArgumentParser import gym import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import ( MultiSyncDataCollector, @@ -122,5 +122,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index 77dbf4a7cde..cd723a63806 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -17,7 +17,6 @@ `--env` flag. Any available gym env will work. """ -import logging import time from argparse import ArgumentParser @@ -25,6 +24,7 @@ import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import ( MultiaSyncDataCollector, @@ -160,5 +160,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 4ca9e9f4a3e..c001a6586b1 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -17,7 +17,6 @@ `--env` flag. Any available gym env will work. """ -import logging import time from argparse import ArgumentParser @@ -25,6 +24,7 @@ import torch.cuda import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.distributed import RPCDataCollector @@ -128,5 +128,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index 8b3bd02aad2..b5c77ebdb5b 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -18,7 +18,6 @@ `--env` flag. Any available gym env will work. """ -import logging import time from argparse import ArgumentParser @@ -26,6 +25,7 @@ import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import ( MultiSyncDataCollector, @@ -152,5 +152,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index 64f4627e2e5..0cb9aaaffbd 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -8,7 +8,6 @@ """ import argparse -import logging import os import random import sys @@ -17,7 +16,7 @@ import torch import torch.distributed.rpc as rpc from tensordict import TensorDict -from torchrl._utils import accept_remote_rref_invocation +from torchrl._utils import accept_remote_rref_invocation, logger as torchrl_logger from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -51,7 +50,7 @@ class DummyDataCollectorNode: def __init__(self, replay_buffer: rpc.RRef) -> None: self.id = rpc.get_worker_info().id self.replay_buffer = replay_buffer - logging.info("Data Collector Node constructed") + torchrl_logger.info("Data Collector Node constructed") def _submit_random_item_async(self) -> rpc.RRef: td = TensorDict({"a": torch.randint(100, (1,))}, []) @@ -69,7 +68,7 @@ def collect(self): """Method that begins experience collection (we just generate random TensorDicts in this example). `accept_remote_rref_invocation` enables this method to be invoked remotely provided the class instantiation `rpc.RRef` is provided in place of the object reference.""" for elem in range(50): time.sleep(random.randint(1, 4)) - logging.info( + torchrl_logger.info( f"Collector [{self.id}] submission {elem}: {self._submit_random_item_async().to_here()}" ) @@ -78,22 +77,22 @@ class DummyTrainerNode: """Trainer node responsible for learning from experiences sampled from an experience replay buffer.""" def __init__(self) -> None: - logging.info("DummyTrainerNode") + torchrl_logger.info("DummyTrainerNode") self.id = rpc.get_worker_info().id self.replay_buffer = self._create_replay_buffer() self._create_and_launch_data_collectors() def train(self, iterations: int) -> None: for iteration in range(iterations): - logging.info(f"[{self.id}] Training Iteration: {iteration}") + torchrl_logger.info(f"[{self.id}] Training Iteration: {iteration}") time.sleep(3) batch = rpc.rpc_sync( self.replay_buffer.owner(), ReplayBufferNode.sample, args=(self.replay_buffer, 16), ) - logging.info(f"[{self.id}] Sample Obtained Iteration: {iteration}") - logging.info(f"{batch}") + torchrl_logger.info(f"[{self.id}] Sample Obtained Iteration: {iteration}") + torchrl_logger.info(f"{batch}") def _create_replay_buffer(self) -> rpc.RRef: while True: @@ -102,10 +101,10 @@ def _create_replay_buffer(self) -> rpc.RRef: buffer_rref = rpc.remote( replay_buffer_info, ReplayBufferNode, args=(10000,) ) - logging.info(f"Connected to replay buffer {replay_buffer_info}") + torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}") return buffer_rref except Exception as e: - logging.info(f"Failed to connect to replay buffer: {e}") + torchrl_logger.info(f"Failed to connect to replay buffer: {e}") time.sleep(RETRY_DELAY_SECS) def _create_and_launch_data_collectors(self) -> None: @@ -119,7 +118,7 @@ def _create_and_launch_data_collectors(self) -> None: data_collector_info = rpc.get_worker_info( f"DataCollector{data_collector_number}" ) - logging.info(f"Data collector info: {data_collector_info}") + torchrl_logger.info(f"Data collector info: {data_collector_info}") dc_ref = rpc.remote( data_collector_info, DummyDataCollectorNode, @@ -131,11 +130,11 @@ def _create_and_launch_data_collectors(self) -> None: retries = 0 except Exception: retries += 1 - logging.info( + torchrl_logger.info( f"Failed to connect to DataCollector{data_collector_number} with {retries} retries" ) if retries >= RETRY_LIMIT: - logging.info(f"{len(data_collectors)} data collectors") + torchrl_logger.info(f"{len(data_collectors)} data collectors") for data_collector_info, data_collector in zip( data_collector_infos, data_collectors ): @@ -171,7 +170,7 @@ def __init__(self, capacity: int): if __name__ == "__main__": args = parser.parse_args() rank = args.rank - logging.info(f"Rank: {rank}") + torchrl_logger.info(f"Rank: {rank}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" @@ -188,21 +187,21 @@ def __init__(self, capacity: int): backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) - logging.info(f"Initialised Trainer Node {rank}") + torchrl_logger.info(f"Initialised Trainer Node {rank}") trainer = DummyTrainerNode() trainer.train(100) breakpoint() elif rank == 1: # rank 1 is the replay buffer # replay buffer waits passively for construction instructions from trainer node - logging.info(REPLAY_BUFFER_NODE) + torchrl_logger.info(REPLAY_BUFFER_NODE) rpc.init_rpc( REPLAY_BUFFER_NODE, rank=rank, backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) - logging.info(f"Initialised RB Node {rank}") + torchrl_logger.info(f"Initialised RB Node {rank}") breakpoint() elif rank >= 2: # rank 2+ is a new data collector node @@ -213,7 +212,7 @@ def __init__(self, capacity: int): backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) - logging.info(f"Initialised DC Node {rank}") + torchrl_logger.info(f"Initialised DC Node {rank}") breakpoint() else: sys.exit(1) diff --git a/examples/dqn/dqn_atari.py b/examples/dqn/dqn_atari.py index 34be877c320..1d7f5dd81b5 100644 --- a/examples/dqn/dqn_atari.py +++ b/examples/dqn/dqn_atari.py @@ -7,7 +7,6 @@ DQN: Reproducing experimental results from Mnih et al. 2015 for the Deep Q-Learning Algorithm on Atari Environments. """ -import logging import tempfile import time @@ -16,6 +15,7 @@ import torch.optim import tqdm from tensordict.nn import TensorDictSequential +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -215,7 +215,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/dqn/dqn_cartpole.py b/examples/dqn/dqn_cartpole.py index 5f6bc742cb7..74f5ea99249 100644 --- a/examples/dqn/dqn_cartpole.py +++ b/examples/dqn/dqn_cartpole.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra @@ -11,6 +10,7 @@ import tqdm from tensordict.nn import TensorDictSequential +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.envs import ExplorationType, set_exploration_type @@ -194,7 +194,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 1cf6c91856f..27732fd96f7 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -1,5 +1,4 @@ import dataclasses -import logging from pathlib import Path import hydra @@ -19,6 +18,7 @@ # float16 from torch.cuda.amp import autocast, GradScaler from torch.nn.utils import clip_grad_norm_ +from torchrl._utils import logger as torchrl_logger from torchrl.envs import EnvBase from torchrl.modules.tensordict_module.exploration import ( @@ -84,7 +84,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.model_device) else: device = torch.device("cpu") - logging.info(f"Using device {device}") + torchrl_logger.info(f"Using device {device}") exp_name = generate_exp_name("Dreamer", cfg.exp_name) logger = get_logger( @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_model_explore=exploration_policy, cfg=cfg, ) - logging.info("collector:", collector) + torchrl_logger.info(f"collector: {collector}") replay_buffer = make_replay_buffer("cpu", cfg) @@ -205,7 +205,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) final_seed = collector.set_seed(cfg.seed) - logging.info(f"init seed: {cfg.seed}, final seed: {final_seed}") + torchrl_logger.info(f"init seed: {cfg.seed}, final seed: {final_seed}") # Training loop collected_frames = 0 pbar = tqdm.tqdm(total=cfg.total_frames) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 0a2ce0d02e2..e52b3af8342 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -7,9 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") @@ -277,7 +276,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index d702a17a4e6..acd36f09d5a 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -7,9 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main( @@ -269,7 +268,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 836e85de8c3..1faff37d1e0 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -7,9 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main(config_path=".", config_name="config_single_node", version_base="1.1") @@ -247,7 +246,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/iql/discrete_iql.py b/examples/iql/discrete_iql.py index 8a8307366fc..9a685faa036 100644 --- a/examples/iql/discrete_iql.py +++ b/examples/iql/discrete_iql.py @@ -11,13 +11,13 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra import numpy as np import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/iql/iql_offline.py b/examples/iql/iql_offline.py index b6895592a20..d3e15221f30 100644 --- a/examples/iql/iql_offline.py +++ b/examples/iql/iql_offline.py @@ -9,13 +9,13 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra import numpy as np import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -129,7 +129,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - logging.info(f"Training time: {time.time() - start_time}") + torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index 461eb6bb37d..e1295eaabfe 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -11,13 +11,13 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra import numpy as np import torch import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/iql/offline_config.yaml b/examples/iql/offline_config.yaml index 341e995967a..f7486708c5a 100644 --- a/examples/iql/offline_config.yaml +++ b/examples/iql/offline_config.yaml @@ -1,6 +1,6 @@ # env and task env: - name: HalfCheetah-v2 + name: HalfCheetah-v4 task: "" exp_name: iql_${replay_buffer.dataset} n_samples_stats: 1000 diff --git a/examples/iql/utils.py b/examples/iql/utils.py index 997df401b82..8b594d3a60c 100644 --- a/examples/iql/utils.py +++ b/examples/iql/utils.py @@ -203,8 +203,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 011e04cde77..1408e47e915 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra @@ -10,6 +9,7 @@ from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - logging.info(f"\nIteration {i}") + torchrl_logger.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py index e4fd4a25e12..d4ed03ad3c6 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/examples/multiagent/maddpg_iddpg.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra @@ -10,6 +9,7 @@ from tensordict.nn import TensorDictModule from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -170,7 +170,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - logging.info(f"\nIteration {i}") + torchrl_logger.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index d4481c93071..8f4a2356c35 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra @@ -11,6 +10,7 @@ from tensordict.nn import TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -167,7 +167,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - logging.info(f"\nIteration {i}") + torchrl_logger.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index e53c47e04f4..e814ce8f79f 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra @@ -10,6 +9,7 @@ from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -170,7 +170,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - logging.info(f"\nIteration {i}") + torchrl_logger.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index 528b5422921..28317dba728 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra @@ -12,6 +11,7 @@ from tensordict.nn.distributions import NormalParamExtractor from torch import nn from torch.distributions import Categorical, OneHotCategorical +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -237,7 +237,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - logging.info(f"\nIteration {i}") + torchrl_logger.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index 238e612e614..6b9a18ae5bb 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -7,9 +7,8 @@ This script reproduces the Proximal Policy Optimization (PPO) Algorithm results from Schulman et al. 2017 for the on Atari Environments. """ -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main(config_path=".", config_name="config_atari", version_base="1.1") @@ -238,7 +237,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py index 83ee779c6ab..fa497230b6e 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/examples/ppo/ppo_mujoco.py @@ -7,9 +7,8 @@ This script reproduces the Proximal Policy Optimization (PPO) Algorithm results from Schulman et al. 2017 for the on MuJoCo Environments. """ -import logging - import hydra +from torchrl._utils import logger as torchrl_logger @hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1") @@ -230,7 +229,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/redq/utils.py b/examples/redq/utils.py index ef377903202..fe78fa83432 100644 --- a/examples/redq/utils.py +++ b/examples/redq/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import logging from copy import copy from typing import Callable, Dict, Optional, Sequence, Tuple, Union @@ -18,7 +17,8 @@ ) from torch import distributions as d, nn, optim from torch.optim.lr_scheduler import CosineAnnealingLR -from torchrl._utils import VERBOSE + +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.collectors.collectors import DataCollectorBase from torchrl.data import ReplayBuffer, TensorDictReplayBuffer @@ -217,7 +217,7 @@ def make_trainer( >>> logger = TensorboardLogger(exp_name=dir) >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration, ... replay_buffer, logger) - >>> logging.info(trainer) + >>> torchrl_logger.info(trainer) """ @@ -244,7 +244,7 @@ def make_trainer( raise NotImplementedError(f"lr scheduler {cfg.optim.lr_scheduler}") if VERBOSE: - logging.info( + torchrl_logger.info( f"collector = {collector}; \n" f"loss_module = {loss_module}; \n" f"recorder = {recorder}; \n" diff --git a/examples/rlhf/models/reward.py b/examples/rlhf/models/reward.py index c11f1c02244..c07861b2255 100644 --- a/examples/rlhf/models/reward.py +++ b/examples/rlhf/models/reward.py @@ -2,11 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import warnings import torch from tensordict.nn import TensorDictModule +from torchrl._utils import logger as torchrl_logger from torchrl.modules.models.rlhf import GPT2RewardModel @@ -31,7 +31,7 @@ def init_reward_model( model.to(device) if compile_model: - logging.info("Compiling the reward model...") + torchrl_logger.info("Compiling the reward model...") model = torch.compile(model) model = TensorDictModule( diff --git a/examples/rlhf/models/transformer.py b/examples/rlhf/models/transformer.py index d1c2b02d0a9..5609679bb1e 100644 --- a/examples/rlhf/models/transformer.py +++ b/examples/rlhf/models/transformer.py @@ -2,10 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging - import torch from tensordict.nn import TensorDictModule +from torchrl._utils import logger as torchrl_logger from transformers import GPT2LMHeadModel @@ -29,7 +28,7 @@ def init_transformer( model.to(device) if compile_model: - logging.info("Compiling transformer model...") + torchrl_logger.info("Compiling transformer model...") model = torch.compile(model) if as_tensordictmodule: diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py index f5551e47579..960c24ac969 100644 --- a/examples/rlhf/train.py +++ b/examples/rlhf/train.py @@ -9,13 +9,13 @@ To run on a single GPU, example: $ python train.py --batch_size=32 --compile=False """ -import logging import time import hydra import torch from models.transformer import init_transformer from torch.optim.lr_scheduler import CosineAnnealingLR +from torchrl._utils import logger as torchrl_logger from torchrl.data.rlhf.dataset import get_dataloader from torchrl.data.rlhf.prompt import PromptData @@ -135,20 +135,20 @@ def main(cfg): train_loss = estimate_loss(model, train_loader) val_loss = estimate_loss(model, val_loader) msg = f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}" - logging.info(msg) + torchrl_logger.info(msg) loss_logger.info(msg) if val_loss < best_val_loss or always_save_checkpoint: best_val_loss = val_loss if it > 0: msg = f"saving checkpoint to {out_dir}" - logging.info(msg) + torchrl_logger.info(msg) loss_logger.info(msg) model.module.save_pretrained(out_dir) elif it % log_interval == 0: # loss as float. note: this is a CPU-GPU sync point loss = batch.loss.item() msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms" - logging.info(msg) + torchrl_logger.info(msg) loss_logger.info(msg) diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py index ac1299f0175..5be1c4a3d65 100644 --- a/examples/rlhf/train_reward.py +++ b/examples/rlhf/train_reward.py @@ -2,13 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import time import hydra import torch from models.reward import init_reward_model from torch.optim.lr_scheduler import CosineAnnealingLR +from torchrl._utils import logger as torchrl_logger from torchrl.data.rlhf.dataset import get_dataloader from torchrl.data.rlhf.reward import PairwiseDataset from utils import get_file_logger, resolve_name_or_path, setup @@ -141,13 +141,13 @@ def main(cfg): f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}, " f"{train_acc=:.4f}, {val_acc=:.4f}" ) - logging.info(msg) + torchrl_logger.info(msg) loss_logger.info(msg) if val_loss < best_val_loss or always_save_checkpoint: best_val_loss = val_loss if it > 0: msg = f"saving checkpoint to {reward_out_dir}" - logging.info(msg) + torchrl_logger.info(msg) loss_logger.info(msg) model.module.save_pretrained(reward_out_dir) elif it % log_interval == 0: @@ -156,7 +156,7 @@ def main(cfg): batch.chosen_data.end_scores, batch.rejected_data.end_scores ) msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms" - logging.info(msg) + torchrl_logger.info(msg) loss_logger.info(msg) diff --git a/examples/sac/sac.py b/examples/sac/sac.py index db23071867a..5b0cad1a7c9 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -10,7 +10,6 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra @@ -20,6 +19,7 @@ import torch.cuda import tqdm from tensordict import TensorDict +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -209,7 +209,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 003a3bf228c..ef2edd578cb 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -10,7 +10,6 @@ The helper functions are coded in the utils.py associated with this script. """ -import logging import time import hydra @@ -18,6 +17,7 @@ import torch import torch.cuda import tqdm +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -207,7 +207,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - logging.info(f"Training took {execution_time:.2f} seconds to finish") + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/test/_utils_internal.py b/test/_utils_internal.py index c9fdc7e39ba..ec73812844b 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import contextlib -import logging import os import os.path @@ -19,7 +18,7 @@ import torch.cuda from tensordict import tensorclass, TensorDict -from torchrl._utils import implement_for, seed_generator +from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm @@ -120,7 +119,7 @@ def f_retry(*args, **kwargs): return f(*args, **kwargs) except ExceptionToCheck as e: msg = "%s, Retrying in %d seconds..." % (str(e), mdelay) - logging.info(msg) + torchrl_logger.info(msg) time.sleep(mdelay) mtries -= 1 try: diff --git a/test/test_collector.py b/test/test_collector.py index 8369be1578e..b5afe7f35d7 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -6,7 +6,6 @@ import argparse import gc -import logging import sys @@ -44,7 +43,7 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import nn -from torchrl._utils import _replace_last, prod, seed_generator +from torchrl._utils import _replace_last, logger as torchrl_logger, prod, seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( _Interruptor, @@ -2390,7 +2389,7 @@ def test_num_threads(self): c.shutdown() del c except Exception: - logging.info("Failed to shut down collector") + torchrl_logger.info("Failed to shut down collector") # reset vals collectors._main_async_collector = _main_async_collector_saved torch.set_num_threads(num_threads) diff --git a/test/test_distributed.py b/test/test_distributed.py index 6215abd7ceb..5f37d8bcac9 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -8,13 +8,13 @@ """ import abc import argparse -import logging import os import sys import time import pytest from tensordict.nn import TensorDictModuleBase +from torchrl._utils import logger as torchrl_logger try: import ray @@ -90,7 +90,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): cls._start_worker() env = ContinuousActionVecMockEnv policy = RandomPolicy(env().action_spec) - logging.info("creating collector") + torchrl_logger.info("creating collector") collector = cls.distributed_class()( [env] * 2, policy, @@ -99,7 +99,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): **cls.distributed_kwargs(), ) total = 0 - logging.info("getting data...") + torchrl_logger.info("getting data...") for data in collector: total += data.numel() assert data.numel() == frames_per_batch diff --git a/test/test_libs.py b/test/test_libs.py index a1414948817..427eef522d0 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. import importlib -import logging from contextlib import nullcontext +from torchrl._utils import logger as torchrl_logger + from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay from torchrl.envs.transforms import ActionMask, TransformedEnv @@ -2350,7 +2351,7 @@ def test_direct_download(self, task, tmpdir): def test_d4rl_dummy(self, task): t0 = time.time() _ = D4RLExperienceReplay(task, split_trajs=True, from_env=True, batch_size=2) - logging.info(f"terminated test after {time.time()-t0}s") + torchrl_logger.info(f"terminated test after {time.time()-t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -2371,7 +2372,7 @@ def test_dataset_build(self, task, split_trajs, from_env): offline = sample.get(key) # assert sim.dtype == offline.dtype, key assert sim.shape[-1] == offline.shape[-1], key - logging.info(f"terminated test after {time.time()-t0}s") + torchrl_logger.info(f"terminated test after {time.time()-t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -2390,7 +2391,7 @@ def test_d4rl_iteration(self, task, split_trajs): for sample in data: # noqa: B007 i += 1 assert len(data) // i == batch_size - logging.info(f"terminated test after {time.time()-t0}s") + torchrl_logger.info(f"terminated test after {time.time()-t0}s") _MINARI_DATASETS = [] @@ -2425,14 +2426,14 @@ def _minari_selected_datasets(): @pytest.mark.slow class TestMinari: def test_load(self, selected_dataset, split): - logging.info("dataset", selected_dataset) + torchrl_logger.info(f"dataset {selected_dataset}") data = MinariExperienceReplay( selected_dataset, batch_size=32, split_trajs=split ) t0 = time.time() for i, sample in enumerate(data): t1 = time.time() - logging.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") + torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") assert data.metadata["action_space"].is_in(sample["action"]) assert data.metadata["observation_space"].is_in(sample["observation"]) t0 = time.time() @@ -2451,7 +2452,7 @@ def test_load(self): t0 = time.time() for i, _ in enumerate(data): t1 = time.time() - logging.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") + torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") t0 = time.time() if i == 10: break @@ -2484,7 +2485,7 @@ def test_load(self, image_size): assert (batch.get("pixels") != 0).any() assert (batch.get(("next", "pixels")) != 0).any() t1 = time.time() - logging.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") + torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") t0 = time.time() if i == 10: break @@ -3000,16 +3001,16 @@ def test_robohive(self, from_pixels): substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") ): - logging.info("not testing envs with prebuilt rendering") + torchrl_logger.info("not testing envs with prebuilt rendering") return if "Adroit" in envname: - logging.info("tcdm are broken") + torchrl_logger.info("tcdm are broken") return try: env = RoboHiveEnv(envname) except AttributeError as err: if "'MjData' object has no attribute 'get_body_xipos'" in str(err): - logging.info("tcdm are broken") + torchrl_logger.info("tcdm are broken") return else: raise err @@ -3017,7 +3018,7 @@ def test_robohive(self, from_pixels): from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 ): - logging.info("no camera") + torchrl_logger.info("no camera") return check_env_specs(env) except Exception as err: diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index a31836a4e72..e34868fdf9f 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse -import logging import os import sys @@ -14,6 +13,7 @@ import torch.distributed.rpc as rpc import torch.multiprocessing as mp from tensordict import TensorDict +from torchrl._utils import logger as torchrl_logger from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -111,7 +111,7 @@ def _construct_buffer(target): buffer_rref = rpc.remote(target, ReplayBufferNode, args=(1000,)) return buffer_rref except Exception as e: - logging.info(f"Failed to connect: {e}") + torchrl_logger.info(f"Failed to connect: {e}") time.sleep(RETRY_BACKOFF) raise RuntimeError("Unable to connect to replay buffer") diff --git a/test/test_shared.py b/test/test_shared.py index a2d2a88d6ca..912f230e8cf 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse -import logging import time import warnings @@ -11,6 +10,7 @@ import torch from tensordict import LazyStackedTensorDict, TensorDict from torch import multiprocessing as mp +from torchrl._utils import logger as torchrl_logger class TestShared: @@ -20,7 +20,7 @@ def remote_process(command_pipe_child, command_pipe_parent, tensordict): assert tensordict.is_shared() t0 = time.time() tensordict.zero_() - logging.info(f"zeroing time: {time.time() - t0}") + torchrl_logger.info(f"zeroing time: {time.time() - t0}") command_pipe_child.send("done") command_pipe_child.close() del command_pipe_child, command_pipe_parent, tensordict @@ -112,7 +112,7 @@ def driver_func(td, stack): command_pipe_child.close() command_pipe_parent.send("stack" if stack else "serial") time_spent = command_pipe_parent.recv() - logging.info(f"stack {stack}: time={time_spent}") + torchrl_logger.info(f"stack {stack}: time={time_spent}") for item in td.values(): assert (item == 0).all() proc.join() @@ -121,7 +121,7 @@ def driver_func(td, stack): @pytest.mark.parametrize("shared", ["shared", "memmap"]) def test_shared(self, shared): - logging.info(f"test_shared: shared={shared}") + torchrl_logger.info(f"test_shared: shared={shared}") torch.manual_seed(0) tensordict = TensorDict( source={ @@ -163,36 +163,36 @@ def test_memmap(idx, dtype, large_scale=False): td_sm = td.clone().share_memory_() td_memmap = td.clone().memmap_() - logging.info("\nTesting reading from TD") + torchrl_logger.info("\nTesting reading from TD") for i in range(2): t0 = time.time() td_sm[idx].clone() if i == 1: - logging.info(f"sm: {time.time() - t0:4.4f} sec") + torchrl_logger.info(f"sm: {time.time() - t0:4.4f} sec") t0 = time.time() td_memmap[idx].clone() if i == 1: - logging.info(f"memmap: {time.time() - t0:4.4f} sec") + torchrl_logger.info(f"memmap: {time.time() - t0:4.4f} sec") td_to_copy = td[idx].contiguous() for k in td_to_copy.keys(): td_to_copy.set_(k, torch.ones_like(td_to_copy.get(k))) - logging.info("\nTesting writing to TD") + torchrl_logger.info("\nTesting writing to TD") for i in range(2): t0 = time.time() sub_td_sm = td_sm.get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: - logging.info(f"sm td: {time.time() - t0:4.4f} sec") + torchrl_logger.info(f"sm td: {time.time() - t0:4.4f} sec") torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a")) t0 = time.time() sub_td_sm = td_memmap.get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: - logging.info(f"memmap td: {time.time() - t0:4.4f} sec") + torchrl_logger.info(f"memmap td: {time.time() - t0:4.4f} sec") torch.testing.assert_close(sub_td_sm.get("a")._tensor, td_to_copy.get("a")) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 98abe9648c6..9538cecb026 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -8,6 +8,7 @@ import functools import inspect + import logging import math @@ -29,6 +30,20 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp +LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "DEBUG") +logger = logging.getLogger("torchrl") +logger.setLevel(getattr(logging, LOGGING_LEVEL)) +# Disable propagation to the root logger +logger.propagate = False +# Remove all attached handlers +while logger.hasHandlers(): + logger.removeHandler(logger.handlers[0]) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + VERBOSE = strtobool(os.environ.get("VERBOSE", "0")) _os_is_windows = sys.platform == "win32" RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1")) @@ -76,7 +91,7 @@ def print(prefix=None): # noqa: T202 strings.append( f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)" ) - logging.info(" -- ".join(strings)) + logger.info(" -- ".join(strings)) @staticmethod def erase(): @@ -424,7 +439,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None): """ if VERBOSE: - logging.info("resetting implement_for") + logger.info("resetting implement_for") if setters_dict is None: setters_dict = copy(cls._implementations) for setter in setters_dict.values(): @@ -671,17 +686,17 @@ def format_size(size): total_size_bytes = get_directory_size(path) formatted_size = format_size(total_size_bytes) - logging.info(f"Directory size: {formatted_size}") + logger.info(f"Directory size: {formatted_size}") if os.path.isdir(path): - logging.info(indent + os.path.basename(path) + "/") + logger.info(indent + os.path.basename(path) + "/") indent += " " for item in os.listdir(path): print_directory_tree( os.path.join(path, item), indent=indent, display_metadata=False ) else: - logging.info(indent + os.path.basename(path)) + logger.info(indent + os.path.basename(path)) def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index ef972fd343e..eff2434d487 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -12,7 +12,6 @@ import functools import inspect -import logging import os import queue import sys @@ -46,6 +45,7 @@ _check_for_faulty_process, _ProcessNoWarn, accept_remote_rref_udf_invocation, + logger as torchrl_logger, prod, RL_WARNINGS, VERBOSE, @@ -2564,7 +2564,7 @@ def _main_async_collector( interruptor=interruptor, ) if verbose: - logging.info("Sync data collector created") + torchrl_logger.info("Sync data collector created") dc_iter = iter(inner_collector) j = 0 pipe_child.send("instantiated") @@ -2577,10 +2577,10 @@ def _main_async_collector( counter = 0 data_in, msg = pipe_child.recv() if verbose: - logging.info(f"worker {idx} received {msg}") + torchrl_logger.info(f"worker {idx} received {msg}") else: if verbose: - logging.info(f"poll failed, j={j}, worker={idx}") + torchrl_logger.info(f"poll failed, j={j}, worker={idx}") # default is "continue" (after first iteration) # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe # in that case, the main process probably expects the worker to continue collect data @@ -2600,7 +2600,7 @@ def _main_async_collector( counter += _timeout if verbose: - logging.info(f"worker {idx} has counter {counter}") + torchrl_logger.info(f"worker {idx} has counter {counter}") if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): raise RuntimeError( f"This process waited for {counter} seconds " @@ -2663,13 +2663,13 @@ def _main_async_collector( try: queue_out.put((data, j), timeout=_TIMEOUT) if verbose: - logging.info(f"worker {idx} successfully sent data") + torchrl_logger.info(f"worker {idx} successfully sent data") j += 1 has_timed_out = False continue except queue.Full: if verbose: - logging.info(f"worker {idx} has timed out") + torchrl_logger.info(f"worker {idx} has timed out") has_timed_out = True continue @@ -2715,7 +2715,7 @@ def _main_async_collector( del inner_collector, dc_iter pipe_child.send("closed") if verbose: - logging.info(f"collector {idx} closed") + torchrl_logger.info(f"collector {idx} closed") break else: diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 0c5c74b6510..e69032d01c1 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -6,7 +6,6 @@ r"""Generic distributed data-collector using torch.distributed backend.""" from __future__ import annotations -import logging import os import socket import warnings @@ -18,7 +17,7 @@ from tensordict import TensorDict from torch import nn -from torchrl._utils import _ProcessNoWarn, VERBOSE +from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -52,10 +51,10 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): os.environ["MASTER_PORT"] = str(tcpport) if verbose: - logging.info( + torchrl_logger.info( f"Rank0 IP address: '{rank0_ip}' \ttcp port: '{tcpport}', backend={backend}." ) - logging.info( + torchrl_logger.info( f"node with rank {rank} with world_size {world_size} -- launching distributed" ) torch.distributed.init_process_group( @@ -66,7 +65,7 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - logging.info(f"Connected!\nNode with rank {rank} -- creating store") + torchrl_logger.info(f"Connected!\nNode with rank {rank} -- creating store") # The store carries instructions for the node _store = torch.distributed.TCPStore( host_name=rank0_ip, @@ -160,7 +159,7 @@ def _run_collector( ): rank = torch.distributed.get_rank() if verbose: - logging.info( + torchrl_logger.info( f"node with rank {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): @@ -196,30 +195,32 @@ def _run_collector( ) total_frames = 0 if verbose: - logging.info(f"node with rank {rank} -- loop") + torchrl_logger.info(f"node with rank {rank} -- loop") while True: instruction = _store.get(f"NODE_{rank}_in") if verbose: - logging.info(f"node with rank {rank} -- new instruction: {instruction}") + torchrl_logger.info( + f"node with rank {rank} -- new instruction: {instruction}" + ) _store.delete_key(f"NODE_{rank}_in") if instruction == b"continue": _store.set(f"NODE_{rank}_status", b"busy") if verbose: - logging.info(f"node with rank {rank} -- new data") + torchrl_logger.info(f"node with rank {rank} -- new data") data = collector.next() total_frames += data.numel() if verbose: - logging.info(f"got data, total frames = {total_frames}") - logging.info(f"node with rank {rank} -- sending {data}") + torchrl_logger.info(f"got data, total frames = {total_frames}") + torchrl_logger.info(f"node with rank {rank} -- sending {data}") if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) if verbose: - logging.info(f"node with rank {rank} -- setting to 'done'") + torchrl_logger.info(f"node with rank {rank} -- setting to 'done'") if not sync: _store.set(f"NODE_{rank}_status", b"done") elif instruction == b"shutdown": if verbose: - logging.info(f"node with rank {rank} -- shutting down") + torchrl_logger.info(f"node with rank {rank} -- shutting down") try: collector.shutdown() except Exception: @@ -599,7 +600,7 @@ def _init_master_dist( backend, ): if self._VERBOSE: - logging.info( + torchrl_logger.info( f"launching main node with tcp port '{self.tcp_port}' and " f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." ) @@ -615,7 +616,7 @@ def _init_master_dist( init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) if self._VERBOSE: - logging.info("main initiated! Launching store...", end="\t") + torchrl_logger.info("main initiated! Launching store...") self._store = torch.distributed.TCPStore( host_name=self.IPAddr, port=int(TCP_PORT) + 1, @@ -624,12 +625,12 @@ def _init_master_dist( timeout=timedelta(10), ) if self._VERBOSE: - logging.info("done. Setting status to 'alive'") + torchrl_logger.info("done. Setting status to 'alive'") self._store.set("TRAINER_status", b"alive") def _make_container(self): if self._VERBOSE: - logging.info("making container") + torchrl_logger.info("making container") env_constructor = self.env_constructors[0] pseudo_collector = SyncDataCollector( env_constructor, @@ -641,11 +642,11 @@ def _make_container(self): for _data in pseudo_collector: break if self._VERBOSE: - logging.info("got data", _data) - logging.info("expanding...") + torchrl_logger.info(f"got data {_data}") + torchrl_logger.info("expanding...") self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) if self._VERBOSE: - logging.info("locking") + torchrl_logger.info("locking") if self._sync: self._tensordict_out.lock_() self._tensordict_out_unbind = self._tensordict_out.unbind(0) @@ -656,11 +657,11 @@ def _make_container(self): for td in self._tensordict_out: td.lock_() if self._VERBOSE: - logging.info("storage created:") - logging.info("shutting down...") + torchrl_logger.info("storage created:") + torchrl_logger.info("shutting down...") pseudo_collector.shutdown() if self._VERBOSE: - logging.info("dummy collector shut down!") + torchrl_logger.info("dummy collector shut down!") del pseudo_collector def _init_worker_dist_submitit(self, executor, i): @@ -743,7 +744,7 @@ def _init_workers(self): else: IPAddr = "localhost" if self._VERBOSE: - logging.info("Server IP address:", IPAddr) + torchrl_logger.info(f"Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -759,20 +760,20 @@ def _init_workers(self): else: for i in range(self.num_workers): if self._VERBOSE: - logging.info("Submitting job") + torchrl_logger.info("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) if self._VERBOSE: - logging.info("job id", job.job_id) # ID of your job + torchrl_logger.info(f"job id {job.job_id}") # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) if self._VERBOSE: - logging.info("job launched") + torchrl_logger.info("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) @@ -781,13 +782,13 @@ def iterator(self): def _iterator_dist(self): if self._VERBOSE: - logging.info("iterating...") + torchrl_logger.info("iterating...") total_frames = 0 if not self._sync: for rank in range(1, self.num_workers + 1): if self._VERBOSE: - logging.info(f"sending 'continue' to {rank}") + torchrl_logger.info(f"sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): @@ -832,7 +833,7 @@ def _next_sync(self, total_frames): if total_frames < self.total_frames: for rank in range(1, self.num_workers + 1): if self._VERBOSE: - logging.info(f"sending 'continue' to {rank}") + torchrl_logger.info(f"sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): @@ -866,7 +867,7 @@ def _next_async(self, total_frames, trackers): total_frames += data.numel() if total_frames < self.total_frames: if self._VERBOSE: - logging.info(f"sending 'continue' to {rank}") + torchrl_logger.info(f"sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers[i] = self._tensordict_out[i].irecv( src=i + 1, return_premature=True @@ -889,7 +890,7 @@ def update_policy_weights_(self, worker_rank=None) -> None: for i in workers: rank = i + 1 if self._VERBOSE: - logging.info(f"updating weights of {rank}") + torchrl_logger.info(f"updating weights of {rank}") self._store.set(f"NODE_{rank}_in", b"update_weights") if self._sync: self.policy_weights.send(rank) @@ -925,12 +926,12 @@ def shutdown(self): for i in range(self.num_workers): rank = i + 1 if self._VERBOSE: - logging.info(f"shutting down node with rank={rank}") + torchrl_logger.info(f"shutting down node with rank={rank}") self._store.set(f"NODE_{rank}_in", b"shutdown") for i in range(self.num_workers): rank = i + 1 if self._VERBOSE: - logging.info(f"getting status of node {rank}", end="\t") + torchrl_logger.info(f"getting status of node {rank}") status = self._store.get(f"NODE_{rank}_out") if status != b"down": raise RuntimeError(f"Expected 'down' but got status {status}.") @@ -945,4 +946,4 @@ def shutdown(self): elif self.launcher == "submitit_delayed": pass if self._VERBOSE: - logging.info("collector shut down") + torchrl_logger.info("collector shut down") diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 6788e48ee3a..a467c763fa5 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -1,12 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from __future__ import annotations -import logging import warnings from typing import Callable, Dict, Iterator, List, OrderedDict, Union import torch import torch.nn as nn from tensordict import TensorDict, TensorDictBase + +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -18,7 +24,6 @@ from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator -logger = logging.getLogger(__name__) RAY_ERR = None try: @@ -64,8 +69,8 @@ def print_remote_collector_info(self): f"Created remote collector with in machine " f"{get_node_ip_address()} using gpus {ray.get_gpu_ids()}" ) - # logger.warning(s) - logging.info(s) + # torchrl_logger.warning(s) + torchrl_logger.info(s) @classmethod diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index dbfc5a7dfd9..c32dbc8fea9 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -7,7 +7,6 @@ from __future__ import annotations import collections -import logging import os import socket import time @@ -15,6 +14,8 @@ from copy import copy, deepcopy from typing import Callable, List, OrderedDict +from torchrl._utils import logger as torchrl_logger + from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -77,7 +78,7 @@ def _rpc_init_collection_node( **tensorpipe_options, ) if verbose: - logging.info( + torchrl_logger.info( f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" ) rpc.init_rpc( @@ -457,7 +458,7 @@ def _init_master_rpc( f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]} ) if self._VERBOSE: - logging.info("init rpc") + torchrl_logger.info("init rpc") rpc.init_rpc( "TRAINER_NODE", rank=0, @@ -488,7 +489,9 @@ def _start_workers( time.sleep(time_interval) try: if self._VERBOSE: - logging.info(f"trying to connect to collector node {i + 1}") + torchrl_logger.info( + f"trying to connect to collector node {i + 1}" + ) collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}") break except RuntimeError as err: @@ -503,7 +506,7 @@ def _start_workers( if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) if self._VERBOSE: - logging.info("Making collector in remote node") + torchrl_logger.info("Making collector in remote node") collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -527,7 +530,7 @@ def _start_workers( if not self._sync: for i in range(num_workers): if self._VERBOSE: - logging.info("Asking for the first batch") + torchrl_logger.info("Asking for the first batch") future = rpc.rpc_async( collector_infos[i], collector_class.next, @@ -557,7 +560,7 @@ def _init_worker_rpc(self, executor, i): self._VERBOSE, ) if self._VERBOSE: - logging.info("job id", job.job_id) # ID of your job + torchrl_logger.info(f"job id {job.job_id}") # ID of your job return job elif self.launcher == "mp": job = _ProcessNoWarn( @@ -601,7 +604,7 @@ def _init(self): self.jobs = [] for i in range(self.num_workers): if self._VERBOSE: - logging.info(f"Submitting job {i}") + torchrl_logger.info(f"Submitting job {i}") job = self._init_worker_rpc( executor, i, @@ -658,7 +661,7 @@ def update_policy_weights_(self, workers=None, wait=True) -> None: futures = [] for i in workers: if self._VERBOSE: - logging.info(f"calling update on worker {i}") + torchrl_logger.info(f"calling update on worker {i}") futures.append( rpc.rpc_async( self.collector_infos[i], @@ -669,14 +672,14 @@ def update_policy_weights_(self, workers=None, wait=True) -> None: if wait: for i in workers: if self._VERBOSE: - logging.info(f"waiting for worker {i}") + torchrl_logger.info(f"waiting for worker {i}") futures[i].wait() if self._VERBOSE: - logging.info("got it!") + torchrl_logger.info("got it!") def _next_async_rpc(self): if self._VERBOSE: - logging.info("next async") + torchrl_logger.info("next async") if not len(self.futures): raise StopIteration( f"The queue is empty, the collector has ran out of data after {self._collected_frames} collected frames." @@ -687,7 +690,7 @@ def _next_async_rpc(self): if self.update_after_each_batch: self.update_policy_weights_(workers=(i,), wait=False) if self._VERBOSE: - logging.info(f"future {i} is done") + torchrl_logger.info(f"future {i} is done") data = future.value() self._collected_frames += data.numel() if self._collected_frames < self.total_frames: @@ -702,7 +705,7 @@ def _next_async_rpc(self): def _next_sync_rpc(self): if self._VERBOSE: - logging.info("next sync: futures") + torchrl_logger.info("next sync: futures") if self.update_after_each_batch: self.update_policy_weights_() for i in range(self.num_workers): @@ -719,7 +722,7 @@ def _next_sync_rpc(self): if future.done(): data += [future.value()] if self._VERBOSE: - logging.info( + torchrl_logger.info( f"got data from {i} // data has len {len(data)} / {self.num_workers}" ) else: @@ -750,15 +753,15 @@ def shutdown(self): if self._shutdown: return if self._VERBOSE: - logging.info("shutting down") + torchrl_logger.info("shutting down") for future, i in self.futures: # clear the futures while future is not None and not future.done(): - logging.info(f"waiting for proc {i} to clear") + torchrl_logger.info(f"waiting for proc {i} to clear") future.wait() for i in range(self.num_workers): if self._VERBOSE: - logging.info(f"shutting down {i}") + torchrl_logger.info(f"shutting down {i}") rpc.rpc_sync( self.collector_infos[i], self.collector_class.shutdown, @@ -766,7 +769,7 @@ def shutdown(self): timeout=int(IDLE_TIMEOUT), ) if self._VERBOSE: - logging.info("rpc shutdown") + torchrl_logger.info("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) if self.launcher == "mp": for job in self.jobs: diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 7ea805248c9..3cd0728dd49 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -6,7 +6,6 @@ r"""Generic distributed data-collector using torch.distributed backend.""" from __future__ import annotations -import logging import os import socket from copy import copy, deepcopy @@ -16,7 +15,8 @@ import torch.cuda from tensordict import TensorDict from torch import nn -from torchrl._utils import _ProcessNoWarn, VERBOSE + +from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( @@ -65,7 +65,7 @@ def _distributed_init_collection_node( os.environ["MASTER_PORT"] = str(tcpport) if verbose: - logging.info( + torchrl_logger.info( f"node with rank {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): @@ -100,9 +100,9 @@ def _distributed_init_collection_node( **collector_kwargs, ) - logging.info("IP address:", rank0_ip, "\ttcp port:", tcpport) + torchrl_logger.info(f"IP address: {rank0_ip} \ttcp port: {tcpport}") if verbose: - logging.info(f"node with rank {rank} -- launching distributed") + torchrl_logger.info(f"node with rank {rank} -- launching distributed") torch.distributed.init_process_group( backend, rank=rank, @@ -111,9 +111,9 @@ def _distributed_init_collection_node( # init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - logging.info(f"node with rank {rank} -- creating store") + torchrl_logger.info(f"node with rank {rank} -- creating store") if verbose: - logging.info(f"node with rank {rank} -- loop") + torchrl_logger.info(f"node with rank {rank} -- loop") policy_weights.irecv(0) frames = 0 for i, data in enumerate(collector): @@ -454,7 +454,7 @@ def _init_master_dist( backend, ): TCP_PORT = self.tcp_port - logging.info("init master...", end="\t") + torchrl_logger.info("init master...") torch.distributed.init_process_group( backend, rank=0, @@ -462,7 +462,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - logging.info("done") + torchrl_logger.info("done") def _make_container(self): env_constructor = self.env_constructors[0] @@ -534,7 +534,7 @@ def _init_workers(self): hostname = socket.gethostname() IPAddr = socket.gethostbyname(hostname) - logging.info("Server IP address:", IPAddr) + torchrl_logger.info(f"Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -546,18 +546,18 @@ def _init_workers(self): executor = submitit.AutoExecutor(folder="log_test") executor.update_parameters(**self.slurm_kwargs) for i in range(self.num_workers): - logging.info("Submitting job") + torchrl_logger.info("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - logging.info("job id", job.job_id) # ID of your job + torchrl_logger.info(f"job id {job.job_id}") # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - logging.info("job launched") + torchrl_logger.info("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index 24444fc171d..5ae8fc3f60f 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -1,8 +1,7 @@ -import logging import subprocess import time -from torchrl._utils import VERBOSE +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, DEFAULT_SLURM_CONF_MAIN, @@ -97,7 +96,7 @@ def exec_fun(): executor.update_parameters(**self.submitit_main_conf) main_job = executor.submit(main_func) # listen to output file looking for IP address - logging.info(f"job id: {main_job.job_id}") + torchrl_logger.info(f"job id: {main_job.job_id}") time.sleep(2.0) node = None while not node: @@ -108,11 +107,11 @@ def exec_fun(): except ValueError: time.sleep(0.5) continue - logging.info(f"node: {node}") + torchrl_logger.info(f"node: {node}") # by default, sinfo will truncate the node name at char 20, we increase this to 200 cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1" rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip() - logging.info(f"IP: {rank0_ip}") + torchrl_logger.info(f"IP: {rank0_ip}") world_size = self.num_jobs + 1 # submit jobs diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py index 28fddb79fca..46f29669c63 100644 --- a/torchrl/data/datasets/atari_dqn.py +++ b/torchrl/data/datasets/atari_dqn.py @@ -8,7 +8,6 @@ import gzip import io import json -import logging import os import shutil @@ -21,6 +20,7 @@ import torch from tensordict import MemoryMappedTensor, TensorDict from torch import multiprocessing as mp +from torchrl._utils import logger as torchrl_logger from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import ( @@ -484,7 +484,7 @@ def _is_downloaded(self): return False def _download_and_preproc(self): - logging.info( + torchrl_logger.info( f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..." ) if os.path.exists(self.dataset_path): @@ -544,7 +544,7 @@ def _download_and_proc_split( tempdir = Path(tempdir) os.makedirs(tempdir / str(run)) files_str = " ".join(run_files) # .decode("utf-8") - logging.info("downloading", files_str) + torchrl_logger.info(f"downloading {files_str}") command = f"gsutil -m cp {files_str} {tempdir}/{run}" subprocess.run( command, shell=True @@ -559,7 +559,7 @@ def _download_and_proc_split( shutil.rmtree(path) raise shutil.rmtree(tempdir / str(run)) - logging.info(f"Concluded run {run} out of {total_episodes}") + torchrl_logger.info(f"Concluded run {run} out of {total_episodes}") @classmethod def _preproc_run(cls, path, gz_files, run): diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 468fcb9150c..d02c292a67c 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -5,8 +5,6 @@ from __future__ import annotations import importlib - -import logging import os import tempfile import urllib @@ -21,6 +19,8 @@ from tensordict import make_tensordict, PersistentTensorDict, TensorDict +from torchrl._utils import logger as torchrl_logger + from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS @@ -438,7 +438,7 @@ def _shift_reward_done(self, dataset): def _download_dataset_from_url(dataset_url, dataset_path): dataset_filepath = _filepath_from_url(dataset_url, dataset_path) if not os.path.exists(dataset_filepath): - logging.info("Downloading dataset:", dataset_url, "to", dataset_filepath) + torchrl_logger.info(f"Downloading dataset: {dataset_url} to {dataset_filepath}") urllib.request.urlretrieve(dataset_url, dataset_filepath) if not os.path.exists(dataset_filepath): raise IOError("Failed to download dataset from %s" % dataset_url) @@ -462,7 +462,7 @@ def _filepath_from_url(dataset_url, dataset_path): if __name__ == "__main__": data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128) - logging.info(data) + torchrl_logger.info(data) for sample in data: - logging.info(sample) + torchrl_logger.info(sample) break diff --git a/torchrl/data/datasets/gen_dgrl.py b/torchrl/data/datasets/gen_dgrl.py index d1ca0b15fb8..da5ca42f91b 100644 --- a/torchrl/data/datasets/gen_dgrl.py +++ b/torchrl/data/datasets/gen_dgrl.py @@ -5,7 +5,6 @@ from __future__ import annotations import importlib.util -import logging import os import tarfile import tempfile @@ -16,6 +15,7 @@ import torch from tensordict import TensorDict +from torchrl._utils import logger as torchrl_logger from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import TensorStorage @@ -240,7 +240,7 @@ def _unpack_category_file( batch = self._PROCESS_NPY_BATCH _, file_name, _ = link file_path = os.path.join(download_folder, file_name) - logging.info( + torchrl_logger.info( f"Unpacking dataset file {file_path} ({file_name}) to {download_folder}." ) idx = 0 @@ -326,14 +326,14 @@ def _download_category_file( file_path = os.path.join(download_folder, file_name) if skip_downloaded_files and os.path.isfile(file_path): - logging.info(f"Skipping {file_path}, already downloaded!") + torchrl_logger.info(f"Skipping {file_path}, already downloaded!") return file_name, True in_progress_folder = os.path.join(download_folder, "_in_progress") os.makedirs(in_progress_folder, exist_ok=True) in_progress_file_path = os.path.join(in_progress_folder, file_name) - logging.info( + torchrl_logger.info( f"Downloading dataset file {file_name} ({url}) to {in_progress_file_path}." ) cls._download_with_progress_bar(url, in_progress_file_path) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index babe5638c91..fa3962b0b82 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -6,7 +6,6 @@ import importlib.util import json -import logging import os.path import shutil import tempfile @@ -20,7 +19,7 @@ import torch from tensordict import PersistentTensorDict, TensorDict -from torchrl._utils import KeyDependentDefaultDict +from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler @@ -110,7 +109,7 @@ class MinariExperienceReplay(TensorDictReplayBuffer): >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force") >>> for sample in data: - ... logging.info(sample) + ... torchrl_logger.info(sample) ... break TensorDict( fields={ @@ -252,7 +251,7 @@ def _download_and_preproc(self): td_data = TensorDict({}, []) total_steps = 0 - logging.info("first read through data to create data structure...") + torchrl_logger.info("first read through data to create data structure...") h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") # populate the tensordict episode_dict = {} @@ -291,13 +290,11 @@ def _download_and_preproc(self): td_data["done"] = td_data["truncated"] | td_data["terminated"] td_data = td_data.expand(total_steps) # save to designated location - logging.info( - f"creating tensordict data in {self.data_path_root}: ", end="\t" - ) + torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ") td_data = td_data.memmap_like(self.data_path_root) - logging.info("tensordict structure:", td_data) + torchrl_logger.info(f"tensordict structure: {td_data}") - logging.info(f"Reading data from {max(*episode_dict) + 1} episodes") + torchrl_logger.info(f"Reading data from {max(*episode_dict) + 1} episodes") index = 0 with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: # iterate over episodes and populate the tensordict diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 8d8b84fb7a9..bf8316f8c3c 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -5,7 +5,6 @@ from __future__ import annotations import importlib.util -import logging import os.path import shutil import tempfile @@ -17,7 +16,11 @@ import torch from tensordict import PersistentTensorDict, TensorDict -from torchrl._utils import KeyDependentDefaultDict, print_directory_tree +from torchrl._utils import ( + KeyDependentDefaultDict, + logger as torchrl_logger, + print_directory_tree, +) from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler @@ -92,11 +95,11 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): >>> for batch in d: ... break >>> # data is organised by seed and episode, but stored contiguously - >>> logging.info(batch["seed"], batch["episode"]) + >>> torchrl_logger.info(f"{batch['seed']}, {batch['episode']}") tensor([2, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2, 0, 2, 0, 2, 2, 1, 0, 2, 0, 0, 1, 1, 2, 1]) tensor([17, 20, 18, 9, 6, 1, 12, 6, 2, 6, 8, 15, 8, 21, 17, 3, 9, 20, 23, 12, 3, 16, 19, 16, 16, 4, 4, 12, 1, 2, 15, 24]) - >>> logging.info(batch) + >>> torchrl_logger.info(batch) TensorDict( fields={ action: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float64, is_shared=False), @@ -241,13 +244,13 @@ def _download_and_preproc(self): def _preproc_h5(self, h5_data_files): td_data = TensorDict({}, []) total_steps = 0 - logging.info( + torchrl_logger.info( f"first read through data files {h5_data_files} to create data structure..." ) episode_dict = {} h5_datas = [] for seed, h5_data_name in enumerate(h5_data_files): - logging.info("\nReading", h5_data_name) + torchrl_logger.info(f"\nReading {h5_data_name}") h5_data = PersistentTensorDict.from_h5(h5_data_name) h5_datas.append(h5_data) for i, (episode_key, episode) in enumerate(h5_data.items()): @@ -256,7 +259,7 @@ def _preproc_h5(self, h5_data_files): episode_dict[(seed, episode_num)] = (episode_key, episode_len) # Get the total number of steps for the dataset total_steps += episode_len - logging.info("total_steps", total_steps, end="\t") + torchrl_logger.info(f"total_steps {total_steps}") if i == 0 and seed == 0: td_data.set("episode", 0) td_data.set("seed", 0) @@ -279,14 +282,14 @@ def _preproc_h5(self, h5_data_files): td_data = td_data.expand(total_steps) # save to designated location - logging.info(f"creating tensordict data in {self.data_path_root}: ", end="\t") + torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ") td_data = td_data.memmap_like(self.data_path_root) - # logging.info("tensordict structure:", td_data) - logging.info( - "Local dataset structure:", print_directory_tree(self.data_path_root) + # torchrl_logger.info(f"tensordict structure: {td_data}") + torchrl_logger.info( + f"Local dataset structure: {print_directory_tree(self.data_path_root)}" ) - logging.info(f"Reading data from {len(episode_dict)} episodes") + torchrl_logger.info(f"Reading data from {len(episode_dict)} episodes") index = 0 if _has_tqdm: from tqdm import tqdm diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 54e933f71f5..111efdfb011 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -8,7 +8,6 @@ import importlib import json -import logging import os import pathlib import shutil @@ -23,7 +22,7 @@ from tensordict import PersistentTensorDict, TensorDict from torch import multiprocessing as mp -from torchrl._utils import KeyDependentDefaultDict +from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler @@ -300,7 +299,7 @@ def _download_and_preproc(cls, dataset_id, data_path, num_workers): func(subfolder, filename) for (subfolder, filename) in zip(paths_to_proc, files_to_proc) ] - logging.info("Downloaded, processing files") + torchrl_logger.info("Downloaded, processing files") if _has_tqdm: import tqdm @@ -328,7 +327,7 @@ def _download_and_preproc(cls, dataset_id, data_path, num_workers): # From this point, the local paths are non needed anymore td_save = td_save.expand(total_steps).memmap_like(data_path, num_threads=32) - logging.info("Saved tensordict:", td_save) + torchrl_logger.info(f"Saved tensordict: {td_save}") idx0 = 0 idx1 = 0 while len(files): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 45c7be64a1a..d4d81f10bc1 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -5,7 +5,6 @@ import abc import json -import logging import os import textwrap import warnings @@ -25,7 +24,12 @@ from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten -from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE +from torchrl._utils import ( + _CKPT_BACKEND, + implement_for, + logger as torchrl_logger, + VERBOSE, +) from torchrl.data.replay_buffers.utils import INT_CLASSES try: @@ -688,7 +692,7 @@ def _init( data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 ) -> None: if VERBOSE: - logging.info("Creating a TensorStorage...") + torchrl_logger.info("Creating a TensorStorage...") if self.device == "auto": self.device = data.device if is_tensorclass(data): @@ -851,7 +855,7 @@ def load_state_dict(self, state_dict): def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if VERBOSE: - logging.info("Creating a MemmapStorage...") + torchrl_logger.info("Creating a MemmapStorage...") if self.device == "auto": self.device = data.device if self.device.type != "cpu": @@ -870,7 +874,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: ): if VERBOSE: filesize = os.path.getsize(tensor.filename) / 1024 / 1024 - logging.info( + torchrl_logger.info( f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." ) else: @@ -1272,7 +1276,7 @@ def _init_pytree_common(tensor_path, scratch_dir, max_size, tensor): ) if VERBOSE: filesize = os.path.getsize(out.filename) / 1024 / 1024 - logging.info( + torchrl_logger.info( f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." ) return out diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 0824cff585f..19090d3f4c5 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -5,7 +5,6 @@ from __future__ import annotations import importlib.util -import logging import os from pathlib import Path @@ -16,6 +15,7 @@ from tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey +from torchrl._utils import logger as torchrl_logger from torchrl.data.replay_buffers import ( SamplerWithoutReplacement, TensorDictReplayBuffer, @@ -141,7 +141,7 @@ def load(self): data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0] data_dir_total = data_dir / split / str(max_length) # search for data - logging.info("Looking for data in", data_dir_total) + torchrl_logger.info(f"Looking for data in {data_dir_total}") if os.path.exists(data_dir_total): dataset = TensorDict.load_memmap(data_dir_total) return dataset diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index e22531c1c22..5e88cf4e86d 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -6,7 +6,6 @@ from __future__ import annotations import gc -import logging import os from collections import OrderedDict @@ -22,7 +21,12 @@ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple, unravel_key from torch import multiprocessing as mp -from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE +from torchrl._utils import ( + _check_for_faulty_process, + _ProcessNoWarn, + logger as torchrl_logger, + VERBOSE, +) from torchrl.data.tensor_specs import CompositeSpec from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import _EnvPostInit, EnvBase @@ -623,7 +627,7 @@ def close(self) -> None: if self.is_closed: raise RuntimeError("trying to close a closed environment") if self._verbose: - logging.info(f"closing {self.__class__.__name__}") + torchrl_logger.info(f"closing {self.__class__.__name__}") self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None @@ -1047,7 +1051,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: - logging.info(f"initiating worker {idx}") + torchrl_logger.info(f"initiating worker {idx}") # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() env_fun = self.create_env_fn[idx] @@ -1305,7 +1309,7 @@ def _shutdown_workers(self) -> None: ) for i, channel in enumerate(self.parent_channels): if self._verbose: - logging.info(f"closing {i}") + torchrl_logger.info(f"closing {i}") channel.send(("close", None)) self._events[i].wait() self._events[i].clear() @@ -1473,7 +1477,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): elif cmd == "init": if verbose: - logging.info(f"initializing {pid}") + torchrl_logger.info(f"initializing {pid}") if initialized: raise RuntimeError("worker already initialized") i = 0 @@ -1489,7 +1493,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): elif cmd == "reset": if verbose: - logging.info(f"resetting worker {pid}") + torchrl_logger.info(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") cur_td = env.reset(tensordict=data) @@ -1541,7 +1545,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): mp_event.set() child_pipe.close() if verbose: - logging.info(f"{pid} closed") + torchrl_logger.info(f"{pid} closed") gc.collect() break diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 28c9e00c42a..89ee8cc5614 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -5,13 +5,14 @@ from __future__ import annotations -import logging from collections import OrderedDict from typing import Callable, Dict, Optional, Union import torch from tensordict import TensorDictBase +from torchrl._utils import logger as torchrl_logger + from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase, EnvMetaData @@ -98,7 +99,7 @@ def share_memory(self, state_dict: OrderedDict) -> None: if not item.is_shared(): item.share_memory_() else: - logging.info( + torchrl_logger.info( f"{self.env_type}: {item} is already shared" ) # , deleting key') del state_dict[key] diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 002b270cd84..38995a07a6b 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -7,13 +7,13 @@ import abc import itertools -import logging import warnings from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch from tensordict import TensorDict, TensorDictBase +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -454,7 +454,7 @@ def set_info_dict_reader( isinstance(info_dict_reader, default_info_dict_reader) and info_dict_reader.info_spec is None ): - logging.info( + torchrl_logger.info( "The info_dict_reader does not have specs. The only way to palliate to this issue automatically " "is to run a dummy rollout and gather the specs automatically. " "To silence this message, provide the specs directly to your spec reader." diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 2e96efcaf6a..9293dd195a0 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -7,14 +7,13 @@ import collections import importlib -import logging import os from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch -from torchrl._utils import VERBOSE +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -33,7 +32,7 @@ n = torch.cuda.device_count() - 1 os.environ["EGL_DEVICE_ID"] = str(1 + (os.getpid() % n)) if VERBOSE: - logging.info("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"]) + torchrl_logger.info(f"EGL_DEVICE_ID: {os.environ['EGL_DEVICE_ID']}") _has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index 410e25a1b28..a029a0beb5b 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -6,13 +6,13 @@ from __future__ import annotations import importlib -import logging from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch from tensordict import TensorDict, TensorDictBase +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, @@ -308,7 +308,7 @@ def _treevalue_to_dict( def _set_seed(self, seed: Optional[int]): if seed is not None: - logging.info( + torchrl_logger.info( "MultiThreadedEnvWrapper._set_seed ignored, as setting seed in an existing envorinment is not\ supported by envpool. Please create a new environment, passing the seed to the constructor." ) @@ -398,7 +398,7 @@ def _build_env( def _set_seed(self, seed: Optional[int]): """Library EnvPool only supports setting a seed by recreating the environment.""" if seed is not None: - logging.debug("Recreating EnvPool environment to set seed.") + torchrl_logger.debug("Recreating EnvPool environment to set seed.") self.create_env_kwargs["seed"] = seed self._env = self._build_env( env_name=self.env_name, diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index 746bfe52f1d..d8bec1cf524 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import importlib -import logging import os import subprocess from functools import partial @@ -13,6 +12,7 @@ import torch from tensordict import TensorDictBase from torch import nn +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -237,7 +237,7 @@ def install_vc_models(cls, auto_exit=False): try: from vc_models import models # noqa: F401 - logging.info("vc_models found, no need to install.") + torchrl_logger.info("vc_models found, no need to install.") except ModuleNotFoundError: HOME = os.environ.get("HOME") vcdir = HOME + "/.cache/torchrl/eai-vc" diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e03cb4043cb..ebb9100655c 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -7,7 +7,6 @@ import contextlib import importlib.util -import logging import os import re from enum import Enum @@ -32,7 +31,7 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey -from torchrl._utils import _replace_last +from torchrl._utils import _replace_last, logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -525,7 +524,7 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): f"spec check failed at root for spec {name}={spec} and data {td}." ) - logging.info("check_env_specs succeeded!") + torchrl_logger.info("check_env_specs succeeded!") def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1): diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index d9b5f45c25f..256d0a2e840 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -108,7 +108,7 @@ def __del__(self): class CSVLogger(Logger): - """A minimal-dependecy CSV-logger. + """A minimal-dependecy CSV logger. Args: exp_name (str): The name of the experiment. diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 265be40b785..8dc8d478ddf 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -2,14 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging from copy import copy from dataclasses import dataclass, field as dataclass_field from typing import Any, Callable, Optional, Sequence, Tuple, Union import torch -from torchrl._utils import VERBOSE +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import env_creator, EnvCreator @@ -394,7 +393,7 @@ def get_stats_random_rollout( )() if VERBOSE: - logging.info("computing state stats") + torchrl_logger.info("computing state stats") if not hasattr(cfg, "init_env_steps"): raise AttributeError("init_env_steps missing from arguments.") @@ -427,7 +426,7 @@ def get_stats_random_rollout( s[s == 0] = 1.0 if VERBOSE: - logging.info( + torchrl_logger.info( f"stats computed for {val_stats.numel()} steps. Got: \n" f"loc = {m}, \n" f"scale = {s}" diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index c57642a7237..05f566674f2 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -247,7 +247,7 @@ def make_redq_model( >>> import hydra >>> from hydra.core.config_store import ConfigStore >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["observation"]), + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v4"), Compose(DoubleToFloat(["observation"]), ... CatTensors(["observation"], "observation_vector"))) >>> device = torch.device("cpu") >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 13d5ae4c968..207bcec0ffd 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging from dataclasses import dataclass from typing import List, Optional, Union from warnings import warn @@ -13,7 +12,7 @@ from torch import optim from torch.optim.lr_scheduler import CosineAnnealingLR -from torchrl._utils import VERBOSE +from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.collectors.collectors import DataCollectorBase from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer from torchrl.envs.common import EnvBase @@ -174,7 +173,7 @@ def make_trainer( raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}") if VERBOSE: - logging.info( + torchrl_logger.info( f"collector = {collector}; \n" f"loss_module = {loss_module}; \n" f"recorder = {recorder}; \n" diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 0764bf9fb72..f844613432c 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -import logging import pathlib import warnings from collections import defaultdict, OrderedDict @@ -21,7 +20,12 @@ from tensordict.utils import expand_right from torch import nn, optim -from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict, VERBOSE +from torchrl._utils import ( + _CKPT_BACKEND, + KeyDependentDefaultDict, + logger as torchrl_logger, + VERBOSE, +) from torchrl.collectors.collectors import DataCollectorBase from torchrl.collectors.utils import split_trajectories from torchrl.data.replay_buffers import ( @@ -476,7 +480,7 @@ def __del__(self): def shutdown(self): if VERBOSE: - logging.info("shutting down collector") + torchrl_logger.info("shutting down collector") self.collector.shutdown() def optim_steps(self, batch: TensorDictBase) -> None: From 01a2216d1962616bdab3e7ba31c94fcc2d1138b3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 1 Feb 2024 10:10:04 +0000 Subject: [PATCH 28/35] [Versioning] v0.4.0 (#1860) --- .github/scripts/m1_script.sh | 2 +- .github/workflows/wheels.yml | 8 ++++---- setup.py | 2 +- version.txt | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh index 8e929443ef6..6d2f194e3bc 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/m1_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export BUILD_VERSION=0.3.0 +export BUILD_VERSION=0.4.0 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 47c1b0c6fec..e910ba4201b 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -4,7 +4,7 @@ on: types: [opened, synchronize, reopened] push: branches: - - release/0.3.0 + - release/0.4.0 concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.3.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -72,7 +72,7 @@ jobs: run: | export CC=clang CXX=clang++ python3 -mpip install wheel - BUILD_VERSION=0.3.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: @@ -104,7 +104,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - BUILD_VERSION=0.3.0 python3 setup.py bdist_wheel + BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/setup.py b/setup.py index f31a2ed9f5c..44e772528a7 100644 --- a/setup.py +++ b/setup.py @@ -170,7 +170,7 @@ def _main(argv): if is_nightly: tensordict_dep = "tensordict-nightly" else: - tensordict_dep = "tensordict>=0.3.0" + tensordict_dep = "tensordict>=0.4.0" if is_nightly: version = get_nightly_version() diff --git a/version.txt b/version.txt index 0d91a54c7d4..1d0ba9ea182 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.0 +0.4.0 From c2f43e8bd90a4df4f6eb22c766ae62a3d0c3895f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 2 Feb 2024 11:54:26 +0000 Subject: [PATCH 29/35] [Doc] Fix tutos (#1863) --- docs/source/conf.py | 5 + tutorials/sphinx-tutorials/coding_ddpg.py | 210 +++++++++--------- tutorials/sphinx-tutorials/coding_dqn.py | 51 +++-- tutorials/sphinx-tutorials/coding_ppo.py | 196 ++++++++-------- tutorials/sphinx-tutorials/dqn_with_rnn.py | 36 ++- tutorials/sphinx-tutorials/multi_task.py | 9 +- tutorials/sphinx-tutorials/multiagent_ppo.py | 11 +- tutorials/sphinx-tutorials/pendulum.py | 9 +- .../sphinx-tutorials/pretrained_models.py | 16 +- tutorials/sphinx-tutorials/rb_tutorial.py | 46 ++-- tutorials/sphinx-tutorials/torchrl_demo.py | 210 ++++++++++++------ tutorials/sphinx-tutorials/torchrl_envs.py | 13 +- 12 files changed, 486 insertions(+), 326 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index f0821ede0bf..060103b48b4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -189,3 +189,8 @@ generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial") # generate_tutorial_references("../../tutorials/src/", "src") generate_tutorial_references("../../tutorials/media/", "media") + +# We do this to indicate that the script is run by sphinx +import builtins + +builtins.__sphinx_build__ = True diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 85590c545fa..5f8bf2c0830 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -7,6 +7,9 @@ """ ############################################################################## +# Overview +# -------- +# # TorchRL separates the training of RL algorithms in various pieces that will be # assembled in your training script: the environment, the data collection and # storage, the model and finally the loss function. @@ -14,29 +17,33 @@ # TorchRL losses (or "objectives") are stateful objects that contain the # trainable parameters (policy and value models). # This tutorial will guide you through the steps to code a loss from the ground up -# using torchrl. +# using TorchRL. # # To this aim, we will be focusing on DDPG, which is a relatively straightforward # algorithm to code. -# DDPG (`Deep Deterministic Policy Gradient `_) +# `Deep Deterministic Policy Gradient `_ (DDPG) # is a simple continuous control algorithm. It consists in learning a # parametric value function for an action-observation pair, and -# then learning a policy that outputs actions that maximise this value +# then learning a policy that outputs actions that maximize this value # function given a certain observation. # -# Key learnings: +# What you will learn: # # - how to write a loss module and customize its value estimator; -# - how to build an environment in torchrl, including transforms -# (e.g. data normalization) and parallel execution; +# - how to build an environment in TorchRL, including transforms +# (for example, data normalization) and parallel execution; # - how to design a policy and value network; # - how to collect data from your environment efficiently and store them # in a replay buffer; # - how to store trajectories (and not transitions) in your replay buffer); -# - and finally how to evaluate your model. +# - how to evaluate your model. +# +# Prerequisites +# ~~~~~~~~~~~~~ # -# This tutorial assumes that you have completed the PPO tutorial which gives -# an overview of the torchrl components and dependencies, such as +# This tutorial assumes that you have completed the +# `PPO tutorial `_ which gives +# an overview of the TorchRL components and dependencies, such as # :class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModules`, # although it should be # sufficiently transparent to be understood without a deep understanding of @@ -44,17 +51,20 @@ # # .. note:: # We do not aim at giving a SOTA implementation of the algorithm, but rather -# to provide a high-level illustration of torchrl's loss implementations +# to provide a high-level illustration of TorchRL's loss implementations # and the library features that are to be used in the context of # this algorithm. # # Imports and setup # ----------------- # +# .. code-block:: bash +# +# %%bash +# pip3 install torchrl mujoco glfw # sphinx_gallery_start_ignore import warnings -from typing import Tuple warnings.filterwarnings("ignore") from torch import multiprocessing @@ -63,24 +73,34 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore -import torch.cuda + +import torch import tqdm ############################################################################### -# We will execute the policy on cuda if available +# We will execute the policy on CUDA if available +is_fork = multiprocessing.get_start_method() == "fork" device = ( - torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") ) +collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA ############################################################################### -# torchrl :class:`~torchrl.objectives.LossModule` +# TorchRL :class:`~torchrl.objectives.LossModule` # ----------------------------------------------- # # TorchRL provides a series of losses to use in your training scripts. @@ -89,11 +109,11 @@ # # The main characteristics of TorchRL losses are: # -# - they are stateful objects: they contain a copy of the trainable parameters +# - They are stateful objects: they contain a copy of the trainable parameters # such that ``loss_module.parameters()`` gives whatever is needed to train the # algorithm. -# - They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward` -# method will receive a tensordict as input that contains all the necessary +# - They follow the ``TensorDict`` convention: the :meth:`torch.nn.Module.forward` +# method will receive a TensorDict as input that contains all the necessary # information to return a loss value. # # >>> data = replay_buffer.sample() @@ -101,8 +121,9 @@ # # - They output a :class:`tensordict.TensorDict` instance with the loss values # written under a ``"loss_"`` where ``smth`` is a string describing the -# loss. Additional keys in the tensordict may be useful metrics to log during +# loss. Additional keys in the ``TensorDict`` may be useful metrics to log during # training time. +# # .. note:: # The reason we return independent losses is to let the user use a different # optimizer for different sets of parameters for instance. Summing the losses @@ -129,14 +150,14 @@ # # Let us start with the :meth:`~torchrl.objectives.LossModule.__init__` # method. DDPG aims at solving a control task with a simple strategy: -# training a policy to output actions that maximise the value predicted by +# training a policy to output actions that maximize the value predicted by # a value network. Hence, our loss module needs to receive two networks in its # constructor: an actor and a value networks. We expect both of these to be -# tensordict-compatible objects, such as +# TensorDict-compatible objects, such as # :class:`tensordict.nn.TensorDictModule`. # Our loss function will need to compute a target value and fit the value # network to this, and generate an action and fit the policy such that its -# value estimate is maximised. +# value estimate is maximized. # # The crucial step of the :meth:`LossModule.__init__` method is the call to # :meth:`~torchrl.LossModule.convert_to_functional`. This method will extract @@ -149,7 +170,7 @@ # model with different sets of parameters, called "trainable" and "target" # parameters. # The "trainable" parameters are those that the optimizer needs to fit. The -# "target" parameters are usually a copy of the formers with some time lag +# "target" parameters are usually a copy of the former's with some time lag # (absolute or diluted through a moving average). # These target parameters are used to compute the value associated with the # next observation. One the advantages of using a set of target parameters @@ -163,7 +184,7 @@ # accessible but this will just return a **detached** version of the # actor parameters. # -# Later, we will see how the target parameters should be updated in torchrl. +# Later, we will see how the target parameters should be updated in TorchRL. # from tensordict.nn import TensorDictModule @@ -235,27 +256,22 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): hp.update(hyperparams) value_key = "state_action_value" if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator( - value_network=self.actor_critic, value_key=value_key, **hp - ) + self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator( - value_network=self.actor_critic, value_key=value_key, **hp - ) + self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp) elif value_type == ValueEstimators.GAE: raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." ) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator( - value_network=self.actor_critic, value_key=value_key, **hp - ) + self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp) else: raise NotImplementedError(f"Unknown value type {value_type}") + self._value_estimator.set_keys(value=value_key) ############################################################################### -# The ``make_value_estimator`` method can but does not need to be called: if +# The ``make_value_estimator`` method can but does not need to be called: ifgg # not, the :class:`~torchrl.objectives.LossModule` will query this method with # its default estimator. # @@ -265,7 +281,7 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): # The central piece of an RL algorithm is the training loss for the actor. # In the case of DDPG, this function is quite simple: we just need to compute # the value associated with an action computed using the policy and optimize -# the actor weights to maximise this value. +# the actor weights to maximize this value. # # When computing this value, we must make sure to take the value parameters out # of the graph, otherwise the actor and value loss will be mixed up. @@ -302,7 +318,7 @@ def _loss_actor( def _loss_value( self, tensordict, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +): td_copy = tensordict.clone() # V(s, a) @@ -325,7 +341,7 @@ def _loss_value( tensordict, target_params=target_params ).squeeze(-1) - # Computes the value loss: L2, L1 or smooth L1 depending on self.loss_funtion + # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function` loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function) td_error = (pred_val - target_value).pow(2) @@ -337,7 +353,7 @@ def _loss_value( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The only missing piece is the forward method, which will glue together the -# value and actor loss, collect the cost values and write them in a tensordict +# value and actor loss, collect the cost values and write them in a ``TensorDict`` # delivered to the user. from tensordict import TensorDict, TensorDictBase @@ -397,7 +413,7 @@ class DDPGLoss(LossModule): # For this example, we will be using the ``"cheetah"`` task. The goal is to make # a half-cheetah run as fast as possible. # -# In TorchRL, one can create such a task by relying on dm_control or gym: +# In TorchRL, one can create such a task by relying on ``dm_control`` or ``gym``: # # .. code-block:: python # @@ -411,7 +427,7 @@ class DDPGLoss(LossModule): # # By default, these environment disable rendering. Training from states is # usually easier than training from images. To keep things simple, we focus -# on learning from states only. To pass the pixels to the tensordicts that +# on learning from states only. To pass the pixels to the ``tensordicts`` that # are collected by :func:`env.step()`, simply pass the ``from_pixels=True`` # argument to the constructor: # @@ -420,7 +436,7 @@ class DDPGLoss(LossModule): # env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True) # # We write a :func:`make_env` helper function that will create an environment -# with either one of the two backends considered above (dm-control or gym). +# with either one of the two backends considered above (``dm-control`` or ``gym``). # from torchrl.envs.libs.dm_control import DMControlEnv @@ -431,7 +447,7 @@ class DDPGLoss(LossModule): def make_env(from_pixels=False): - """Create a base env.""" + """Create a base ``env``.""" global env_library global env_name @@ -502,7 +518,7 @@ def make_env(from_pixels=False): def make_transformed_env( env, ): - """Apply transforms to the env (such as reward scaling and state normalization).""" + """Apply transforms to the ``env`` (such as reward scaling and state normalization).""" env = TransformedEnv(env) @@ -511,16 +527,6 @@ def make_transformed_env( # syntax. env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) - double_to_float_list = [] - double_to_float_inv_list = [] - if env_library is DMControlEnv: - # DMControl requires double-precision - double_to_float_list += [ - "reward", - "action", - ] - double_to_float_inv_list += ["action"] - # We concatenate all states into a single "observation_vector" # even if there is a single tensor, it'll be renamed in "observation_vector". # This facilitates the downstream operations as we know the name of the @@ -536,16 +542,12 @@ def make_transformed_env( # version of the transform env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True)) - double_to_float_list.append(out_key) - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) + env.append_transform(DoubleToFloat()) env.append_transform(StepCounter(max_frames_per_traj)) - # We need a marker for the start of trajectories for our OU exploration: + # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU) + # exploration: env.append_transform(InitTracker()) return env @@ -608,15 +610,16 @@ def make_t_env(): return env -# The backend can be gym or dm_control +# The backend can be ``gym`` or ``dm_control`` backend = "gym" ############################################################################### # .. note:: +# # ``frame_skip`` batches multiple step together with a single action -# If > 1, the other frame counts (e.g. frames_per_batch, total_frames) need to -# be adjusted to have a consistent total number of frames collected across -# experiments. This is important as raising the frame-skip but keeping the +# If > 1, the other frame counts (for example, frames_per_batch, total_frames) +# need to be adjusted to have a consistent total number of frames collected +# across experiments. This is important as raising the frame-skip but keeping the # total number of frames unchanged may seem like cheating: all things compared, # a dataset of 10M elements collected with a frame-skip of 2 and another with # a frame-skip of 1 actually have a ratio of interactions with the environment @@ -630,7 +633,7 @@ def make_t_env(): ############################################################################### # We also define when a trajectory will be truncated. A thousand steps (500 if -# frame-skip = 2) is a good number to use for cheetah: +# frame-skip = 2) is a good number to use for the cheetah task: max_frames_per_traj = 500 @@ -660,7 +663,7 @@ def get_env_stats(): ############################################################################### # Normalization stats # ~~~~~~~~~~~~~~~~~~~ -# Number of random steps used as for stats computation using ObservationNorm +# Number of random steps used as for stats computation using ``ObservationNorm`` init_env_steps = 5000 @@ -764,8 +767,8 @@ def make_ddpg_actor( module=q_net, ).to(device) - # init lazy moduless - qnet(actor(proof_environment.reset())) + # initialize lazy modules + qnet(actor(proof_environment.reset().to(device))) return actor, qnet @@ -779,7 +782,7 @@ def make_ddpg_actor( # ~~~~~~~~~~~ # # The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` -# exploration module, as suggesed in the original paper. +# exploration module, as suggested in the original paper. # Let's define the number of frames before OU noise reaches its minimum value annealing_frames = 1_000_000 @@ -801,24 +804,27 @@ def make_ddpg_actor( # environment and reset it when required. # Data collectors are designed to help developers have a tight control # on the number of frames per batch of data, on the (a)sync nature of this -# collection and on the resources allocated to the data collection (e.g. GPU, -# number of workers etc). +# collection and on the resources allocated to the data collection (for example +# GPU, number of workers, and so on). # # Here we will use -# :class:`~torchrl.collectors.MultiaSyncDataCollector`, a data collector that -# will be executed in an async manner (i.e. data will be collected while -# the policy is being optimized). With the :class:`MultiaSyncDataCollector`, -# multiple workers are running rollouts separately. When a batch is asked, it -# is gathered from the first worker that can provide it. +# :class:`~torchrl.collectors.SyncDataCollector`, a simple, single-process +# data collector. TorchRL offers other collectors, such as +# :class:`~torchrl.collectors.MultiaSyncDataCollector`, which executed the +# rollouts in an asynchronous manner (for example, data will be collected while +# the policy is being optimized, thereby decoupling the training and +# data collection). # # The parameters to specify are: # -# - the list of environment creation functions, +# - an environment factory or an environment, # - the policy, # - the total number of frames before the collector is considered empty, # - the maximum number of frames per trajectory (useful for non-terminating -# environments, like dm_control ones). +# environments, like ``dm_control`` ones). +# # .. note:: +# # The ``max_frames_per_traj`` passed to the collector will have the effect # of registering a new :class:`~torchrl.envs.StepCounter` transform # with the environment used for inference. We can achieve the same result @@ -837,8 +843,8 @@ def make_ddpg_actor( ############################################################################### # The number of frames returned by the collector at each iteration of the outer -# loop is equal to the length of each sub-trajectories times the number of envs -# run in parallel in each collector. +# loop is equal to the length of each sub-trajectories times the number of +# environments run in parallel in each collector. # # In other words, we expect batches from the collector to have a shape # ``[env_per_collector, traj_len]`` where @@ -849,26 +855,18 @@ def make_ddpg_actor( init_random_frames = 5000 num_collectors = 2 -from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.envs import ExplorationType -collector = MultiaSyncDataCollector( - create_env_fn=[ - parallel_env, - ] - * num_collectors, +collector = SyncDataCollector( + parallel_env, policy=actor_model_explore, total_frames=total_frames, - # max_frames_per_traj=max_frames_per_traj, # this is achieved by the env constructor frames_per_batch=frames_per_batch, init_random_frames=init_random_frames, reset_at_each_iter=False, split_trajs=False, - device=device, - # device for execution - storing_device=device, - # device where data will be stored and passed - update_at_each_batch=False, + device=collector_device, exploration_type=ExplorationType.RANDOM, ) @@ -961,7 +959,7 @@ def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb ############################################################################### -# We'll store the replay buffer in a temporary dirrectory on disk +# We'll store the replay buffer in a temporary directory on disk import tempfile @@ -977,17 +975,17 @@ def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb # size by dividing it by the length of the sub-trajectories yielded by our # data collector. # Regarding the batch-size, our sampling strategy will consist in sampling -# trajectories of length ``traj_len=200`` before selecting sub-trajecotries +# trajectories of length ``traj_len=200`` before selecting sub-trajectories # or length ``random_crop_len=25`` on which the loss will be computed. # This strategy balances the choice of storing whole trajectories of a certain -# length with the need for providing sampels with a sufficient heterogeneity +# length with the need for providing samples with a sufficient heterogeneity # to our loss. The following figure shows the dataflow from a collector # that gets 8 frames in each batch with 2 environments run in parallel, # feeds them to a replay buffer that contains 1000 trajectories and # samples sub-trajectories of 2 time steps each. # # .. figure:: /_static/img/replaybuffer_traj.png -# :alt: Storign trajectories in the replay buffer +# :alt: Storing trajectories in the replay buffer # # Let's start with the number of frames stored in the buffer @@ -1005,7 +1003,7 @@ def ceil_div(x, y): ############################################################################### # We also need to define how many updates we'll be doing per batch of data -# collected. This is known as the update-to-data or UTD ratio: +# collected. This is known as the update-to-data or ``UTD`` ratio: update_to_data = 64 ############################################################################### @@ -1032,7 +1030,7 @@ def ceil_div(x, y): # Loss module construction # ------------------------ # -# We build our loss module with the actor and qnet we've just created. +# We build our loss module with the actor and ``qnet`` we've just created. # Because we have target parameters to update, we _must_ create a target network # updater. # @@ -1189,7 +1187,7 @@ def ceil_div(x, y): # # .. note:: # As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 1M. +# use a greater value for ``total_frames`` for example, 1M. from matplotlib import pyplot as plt @@ -1205,7 +1203,7 @@ def ceil_div(x, y): # Conclusion # ---------- # -# In this tutorial, we have learnt how to code a loss module in TorchRL given +# In this tutorial, we have learned how to code a loss module in TorchRL given # the concrete example of DDPG. # # The key takeaways are: @@ -1215,3 +1213,11 @@ def ceil_div(x, y): # - How to use (or not) a target network, and how to update its parameters; # - How to create an optimizer associated with a loss module. # +# Next Steps +# ---------- +# +# To iterate further on this loss module we might consider: +# +# - Using `@dispatch` (see `[Feature] Distpatch IQL loss module `_.) +# - Allowing flexible TensorDict keys. +# diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index fcddd699b3a..f85f6bf1e14 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -86,6 +86,8 @@ import tempfile import warnings +from tensordict.nn import TensorDictSequential + warnings.filterwarnings("ignore") from torch import multiprocessing @@ -94,13 +96,17 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore - import os import uuid @@ -125,7 +131,7 @@ ToTensorImage, TransformedEnv, ) -from torchrl.modules import DuelingCnnDQNet, EGreedyWrapper, QValueActor +from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.record.loggers.csv import CSVLogger @@ -270,6 +276,7 @@ def get_norm_stats(): # let's check that normalizing constants have a size of ``[C, 1, 1]`` where # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`). print("state dict of the observation norm:", obs_norm_sd) + test_env.close() return obs_norm_sd @@ -328,13 +335,14 @@ def make_model(dummy_env): tensordict = dummy_env.fake_tensordict() actor(tensordict) - # we wrap our actor in an EGreedyWrapper for data collection - actor_explore = EGreedyWrapper( - actor, + # we join our actor with an EGreedyModule for data collection + exploration_module = EGreedyModule( + spec=dummy_env.action_spec, annealing_num_steps=total_frames, eps_init=eps_greedy_val, eps_end=eps_greedy_val_env, ) + actor_explore = TensorDictSequential(actor, exploration_module) return actor, actor_explore @@ -381,6 +389,13 @@ def get_replay_buffer(buffer_size, n_optim, batch_size): # We choose the following configuration: we will be running a series of # parallel environments synchronously in parallel in different collectors, # themselves running in parallel but asynchronously. +# +# .. note:: +# This feature is only available when running the code within the "spawn" +# start method of python multiprocessing library. If this tutorial is run +# directly as a script (thereby using the "fork" method) we will be using +# a regular :class:`~torchrl.collectors.SyncDataCollector`. +# # The advantage of this configuration is that we can balance the amount of # compute that is executed in batch with what we want to be executed # asynchronously. We encourage the reader to experiment how the collection @@ -409,11 +424,10 @@ def get_collector( total_frames, device, ): - data_collector = MultiaSyncDataCollector( - [ - make_env(parallel=True, obs_norm_sd=stats), - ] - * num_collectors, + cls = MultiaSyncDataCollector + env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors + data_collector = cls( + env_arg, policy=actor_explore, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -464,7 +478,12 @@ def get_loss_module(actor, gamma): # in practice, and the performance of the algorithm should hopefully not be # too sensitive to slight variations of these. -device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu" +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ############################################################################### # Optimizer @@ -642,6 +661,12 @@ def get_loss_module(actor, gamma): ) recorder.register(trainer) +############################################################################### +# The exploration module epsilon factor is also annealed: +# + +trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch) + ############################################################################### # - Any callable (including :class:`~torchrl.trainers.TrainerHookBase` # subclasses) can be registered using :meth:`~torchrl.trainers.Trainer.register_op`. diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 679d625220c..be82bbd3bd8 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -15,8 +15,8 @@ Key learnings: -- How to create an environment in TorchRL, transform its outputs, and collect data from this env; -- How to make your classes talk to each other using :class:`tensordict.TensorDict`; +- How to create an environment in TorchRL, transform its outputs, and collect data from this environment; +- How to make your classes talk to each other using :class:`~tensordict.TensorDict`; - The basics of building your training loop with TorchRL: - How to compute the advantage signal for policy gradient methods; @@ -56,7 +56,7 @@ # problem rather than re-inventing the wheel every time you want to train a policy. # # For completeness, here is a brief overview of what the loss computes, even though -# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows: +# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows: # 1. we will sample a batch of data by playing the # policy in the environment for a given number of steps. # 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using @@ -99,7 +99,7 @@ # 5. Finally, we will run our training loop and analyze the results. # # Throughout this tutorial, we'll be using the :mod:`tensordict` library. -# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract +# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract # what a module reads and writes and care less about the specific data # description and more about the algorithm itself. # @@ -114,9 +114,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -159,7 +164,12 @@ # actually return ``frame_skip`` frames). # -device = "cpu" if not torch.has_cuda else "cuda:0" +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) num_cells = 256 # number of cells in each layer i.e. output dim. lr = 3e-4 max_grad_norm = 1.0 @@ -174,22 +184,10 @@ # use. In general, the goal of an RL algorithm is to learn to solve the task # as fast as it can in terms of environment interactions: the lower the ``total_frames`` # the better. -# We also define a ``frame_skip``: in some contexts, repeating the same action -# multiple times over the course of a trajectory may be beneficial as it makes -# the behavior more consistent and less erratic. However, "skipping" -# too many frames will hamper training by reducing the reactivity of the actor -# to observation changes. -# -# When using ``frame_skip`` it is good practice to -# correct the other frame counts by the number of frames we are grouping -# together. If we configure a total count of X frames for training but -# use a ``frame_skip`` of Y, we will be actually collecting XY frames in total -# which exceeds our predefined budget. -# -frame_skip = 1 -frames_per_batch = 1000 // frame_skip +# +frames_per_batch = 1000 # For a complete training, bring the number of frames up to 1M -total_frames = 10_000 // frame_skip +total_frames = 10_000 ###################################################################### # PPO parameters @@ -220,23 +218,23 @@ # control system. Various libraries provide simulation environments for reinforcement # learning, including Gymnasium (previously OpenAI Gym), DeepMind control suite, and # many others. -# As a generalistic library, TorchRL's goal is to provide an interchangeable interface +# As a general library, TorchRL's goal is to provide an interchangeable interface # to a large panel of RL simulators, allowing you to easily swap one environment # with another. For example, creating a wrapped gym environment can be achieved with few characters: # -base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip) +base_env = GymEnv("InvertedDoublePendulum-v4", device=device) ###################################################################### # There are a few things to notice in this code: first, we created # the environment by calling the ``GymEnv`` wrapper. If extra keyword arguments # are passed, they will be transmitted to the ``gym.make`` method, hence covering -# the most common env construction commands. +# the most common environment construction commands. # Alternatively, one could also directly create a gym environment using ``gym.make(env_name, **kwargs)`` # and wrap it in a `GymWrapper` class. # # Also the ``device`` argument: for gym, this only controls the device where -# input action and observered states will be stored, but the execution will always +# input action and observed states will be stored, but the execution will always # be done on CPU. The reason for this is simply that gym does not support on-device # execution, unless specified otherwise. For other libraries, we have control over # the execution device and, as much as we can, we try to stay consistent in terms of @@ -248,9 +246,9 @@ # We will append some transforms to our environments to prepare the data for # the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different # approach, more similar to other pytorch domain libraries, through the use of transforms. -# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv` -# instance and append the sequence of transforms to it. The transformed env will inherit -# the device and meta-data of the wrapped env, and transform these depending on the sequence +# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv` +# instance and append the sequence of transforms to it. The transformed environment will inherit +# the device and meta-data of the wrapped environment, and transform these depending on the sequence # of transforms it contains. # # Normalization @@ -262,17 +260,17 @@ # run a certain number of random steps in the environment and compute # the summary statistics of these observations. # -# We'll append two other transforms: the :class:`DoubleToFloat` transform will +# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will # convert double entries to single-precision numbers, ready to be read by the -# policy. The :class:`StepCounter` transform will be used to count the steps before +# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before # the environment is terminated. We will use this measure as a supplementary measure # of performance. # -# As we will see later, many of the TorchRL's classes rely on :class:`tensordict.TensorDict` +# As we will see later, many of the TorchRL's classes rely on :class:`~tensordict.TensorDict` # to communicate. You could think of it as a python dictionary with some extra # tensor features. In practice, this means that many modules we will be working # with need to be told what key to read (``in_keys``) and what key to write -# (``out_keys``) in the tensordict they will receive. Usually, if ``out_keys`` +# (``out_keys``) in the ``tensordict`` they will receive. Usually, if ``out_keys`` # is omitted, it is assumed that the ``in_keys`` entries will be updated # in-place. For our transforms, the only entry we are interested in is referred # to as ``"observation"`` and our transform layers will be told to modify this @@ -284,22 +282,20 @@ Compose( # normalize observations ObservationNorm(in_keys=["observation"]), - DoubleToFloat( - in_keys=["observation"], - ), + DoubleToFloat(), StepCounter(), ), ) ###################################################################### # As you may have noticed, we have created a normalization layer but we did not -# set its normalization parameters. To do this, :class:`ObservationNorm` can +# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can # automatically gather the summary statistics of our environment: # env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0) ###################################################################### -# The :class:`ObservationNorm` transform has now been populated with a +# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a # location and a scale that will be used to normalize the data. # # Let us do a little sanity check for the shape of our summary stats: @@ -313,25 +309,23 @@ # For efficiency purposes, TorchRL is quite stringent when it comes to # environment specs, but you can easily check that your environment specs are # adequate. -# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits -# from it already take care of setting the proper specs for your env so +# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and +# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits +# from it already take care of setting the proper specs for your environment so # you should not have to care about this. # # Nevertheless, let's see a concrete example using our transformed # environment by looking at its specs. -# There are five specs to look at: ``observation_spec`` which defines what +# There are three specs to look at: ``observation_spec`` which defines what # is to be expected when executing an action in the environment, -# ``reward_spec`` which indicates the reward domain, -# ``done_spec`` which indicates the done state of an environment, -# the ``action_spec`` which defines the action space, dtype and device and -# the ``state_spec`` which groups together the specs of all the other inputs -# (if any) to the environment. +# ``reward_spec`` which indicates the reward domain and finally the +# ``input_spec`` (which contains the ``action_spec``) and which represents +# everything an environment requires to execute a single step. # print("observation_spec:", env.observation_spec) print("reward_spec:", env.reward_spec) -print("done_spec:", env.done_spec) -print("action_spec:", env.action_spec) -print("state_spec:", env.state_spec) +print("input_spec:", env.input_spec) +print("action_spec (as defined by input_spec):", env.action_spec) ###################################################################### # the :func:`check_env_specs` function runs a small rollout and compares its output against the environment @@ -349,9 +343,9 @@ # action as input, and outputs an observation, a reward and a done state. The # observation may be composite, meaning that it could be composed of more than one # tensor. This is not a problem for TorchRL, since the whole set of observations -# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout -# (ie a sequence of environment steps and random action generations) over a given -# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape +# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout +# (for example, a sequence of environment steps and random action generations) over a given +# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape # that matches this trajectory length: # rollout = env.rollout(3) @@ -361,8 +355,8 @@ ###################################################################### # Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps # we ran it for. The ``"next"`` entry points to the data coming after the current step. -# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this -# may not be the case if we are using some specific transformations (e.g. multi-step). +# In most cases, the ``"next"`` data at time `t` matches the data at ``t+1``, but this +# may not be the case if we are using some specific transformations (for example, multi-step). # # Policy # ------ @@ -388,10 +382,9 @@ # # 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``. # -# 2. Append a :class:`NormalParamExtractor` to extract a location and a scale (ie splits the input in two equal parts -# and applies a positive transformation to the scale parameter). +# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter). # -# 3. Create a probabilistic :class:`TensorDictModule` that can generate this distribution and sample from it. +# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it. # actor_net = nn.Sequential( @@ -406,8 +399,8 @@ ) ###################################################################### -# To enable the policy to "talk" with the environment through the tensordict -# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This +# To enable the policy to "talk" with the environment through the ``tensordict`` +# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This # class will simply ready the ``in_keys`` it is provided with and write the # outputs in-place at the registered ``out_keys``. # @@ -417,18 +410,19 @@ ###################################################################### # We now need to build a distribution out of the location and scale of our -# normal distribution. To do so, we instruct the :class:`ProbabilisticActor` -# class to build a :class:`TanhNormal` out of the location and scale +# normal distribution. To do so, we instruct the +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` +# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale # parameters. We also provide the minimum and maximum values of this # distribution, which we gather from the environment specs. # # The name of the ``in_keys`` (and hence the name of the ``out_keys`` from -# the :class:`TensorDictModule` above) cannot be set to any value one may -# like, as the :class:`TanhNormal` distribution constructor will expect the +# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may +# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the # ``loc`` and ``scale`` keyword arguments. That being said, -# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys`` -# where the key-value pair indicates what ``in_key`` string should be used for -# every keyword argument that is to be used. +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts +# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates +# what ``in_key`` string should be used for every keyword argument that is to be used. # policy_module = ProbabilisticActor( module=policy_module, @@ -436,8 +430,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.minimum, - "max": env.action_spec.space.maximum, + "min": env.action_spec.space.low, + "max": env.action_spec.space.high, }, return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights @@ -451,7 +445,7 @@ # won't be used at inference time. This module will read the observations and # return an estimation of the discounted return for the following trajectory. # This allows us to amortize learning by relying on the some utility estimation -# that is learnt on-the-fly during training. Our value network share the same +# that is learned on-the-fly during training. Our value network share the same # structure as the policy, but for simplicity we assign it its own set of # parameters. # @@ -472,7 +466,7 @@ ###################################################################### # let's try our policy and value modules. As we said earlier, the usage of -# :class:`TensorDictModule` makes it possible to directly read the output +# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output # of the environment to run these modules, as they know what information to read # and where to write it: # @@ -483,11 +477,11 @@ # Data collector # -------------- # -# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these -# classes execute three operations: reset an environment, compute an action -# given the latest observation, execute a step in the environment, and repeat -# the last two steps until the environment signals a stop (or reaches a done -# state). +# TorchRL provides a set of `DataCollector classes `__. +# Briefly, these classes execute three operations: reset an environment, +# compute an action given the latest observation, execute a step in the environment, +# and repeat the last two steps until the environment signals a stop (or reaches +# a done state). # # They allow you to control how many frames to collect at each iteration # (through the ``frames_per_batch`` parameter), @@ -495,18 +489,19 @@ # on which ``device`` the policy should be executed, etc. They are also # designed to work efficiently with batched and multiprocessed environments. # -# The simplest data collector is the :class:`SyncDataCollector`: it is an -# iterator that you can use to get batches of data of a given length, and +# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`: +# it is an iterator that you can use to get batches of data of a given length, and # that will stop once a total number of frames (``total_frames``) have been # collected. -# Other data collectors (``MultiSyncDataCollector`` and -# ``MultiaSyncDataCollector``) will execute the same operations in synchronous -# and asynchronous manner over a set of multiprocessed workers. +# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute +# the same operations in synchronous and asynchronous manner over a +# set of multiprocessed workers. # # As for the policy and environment before, the data collector will return -# :class:`tensordict.TensorDict` instances with a total number of elements that will -# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the -# training loop allows you to write dataloading pipelines +# :class:`~tensordict.TensorDict` instances with a total number of elements that will +# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the +# training loop allows you to write data loading pipelines # that are 100% oblivious to the actual specificities of the rollout content. # collector = SyncDataCollector( @@ -528,10 +523,10 @@ # of epochs. # # TorchRL's replay buffers are built using a common container -# :class:`ReplayBuffer` which takes as argument the components of the buffer: -# a storage, a writer, a sampler and possibly some transforms. Only the -# storage (which indicates the replay buffer capacity) is mandatory. We -# also specify a sampler without repetition to avoid sampling multiple times +# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components +# of the buffer: a storage, a writer, a sampler and possibly some transforms. +# Only the storage (which indicates the replay buffer capacity) is mandatory. +# We also specify a sampler without repetition to avoid sampling multiple times # the same item in one epoch. # Using a replay buffer for PPO is not mandatory and we could simply # sample the sub-batches from the collected batch, but using these classes @@ -539,7 +534,7 @@ # replay_buffer = ReplayBuffer( - storage=LazyTensorStorage(frames_per_batch), + storage=LazyTensorStorage(max_size=frames_per_batch), sampler=SamplerWithoutReplacement(), ) @@ -547,8 +542,8 @@ # Loss function # ------------- # -# The PPO loss can be directly imported from torchrl for convenience using the -# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO: +# The PPO loss can be directly imported from TorchRL for convenience using the +# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO: # it hides away the mathematical operations of PPO and the control flow that # goes with it. # @@ -558,11 +553,11 @@ # To compute the advantage, one just needs to (1) build the advantage module, which # utilizes our value operator, and (2) pass each batch of data through it before each # epoch. -# The GAE module will update the input :class:`TensorDict` with new ``"advantage"`` and +# The GAE module will update the input ``tensordict`` with new ``"advantage"`` and # ``"value_target"`` entries. # The ``"value_target"`` is a gradient-free tensor that represents the empirical # value that the value network should represent with the input observation. -# Both of these will be used by :class:`ClipPPOLoss` to +# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to # return the policy and value losses. # @@ -577,9 +572,7 @@ entropy_bonus=bool(entropy_eps), entropy_coef=entropy_eps, # these keys match by default but we set this for completeness - value_target_key=advantage_module.value_target_key, critic_coef=1.0, - gamma=0.99, loss_critic_type="smooth_l1", ) @@ -610,7 +603,7 @@ logs = defaultdict(list) -pbar = tqdm(total=total_frames * frame_skip) +pbar = tqdm(total=total_frames) eval_str = "" # We iterate over the collector until it reaches the total number of frames it was @@ -621,8 +614,7 @@ # We'll need an "advantage" signal to make PPO work. # We re-compute it at each epoch as its value depends on the value # network which is updated in the inner loop. - with torch.no_grad(): - advantage_module(tensordict_data) + advantage_module(tensordict_data) data_view = tensordict_data.reshape(-1) replay_buffer.extend(data_view.cpu()) for _ in range(frames_per_batch // sub_batch_size): @@ -634,7 +626,7 @@ + loss_vals["loss_entropy"] ) - # Optimization: backward, grad clipping and optim step + # Optimization: backward, grad clipping and optimization step loss_value.backward() # this is not strictly mandatory but it's good practice to keep # your gradient norm bounded @@ -643,7 +635,7 @@ optim.zero_grad() logs["reward"].append(tensordict_data["next", "reward"].mean().item()) - pbar.update(tensordict_data.numel() * frame_skip) + pbar.update(tensordict_data.numel()) cum_reward_str = ( f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})" ) @@ -655,8 +647,8 @@ # We evaluate the policy once every 10 batches of data. # Evaluation is rather simple: execute the policy without exploration # (take the expected value of the action distribution) for a given - # number of steps (1000, which is our env horizon). - # The ``rollout`` method of the env can take a policy as argument: + # number of steps (1000, which is our ``env`` horizon). + # The ``rollout`` method of the ``env`` can take a policy as argument: # it will then execute this policy at each step. with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): # execute a rollout with the trained policy @@ -717,7 +709,7 @@ # we could run several simulations in parallel to speed up data collection. # Check :class:`~torchrl.envs.ParallelEnv` for further information. # -# * From a logging perspective, one could add a :class:`~torchrl.record.VideoRecorder` transform to +# * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to # the environment after asking for rendering to get a visual rendering of the # inverted pendulum in action. Check :py:mod:`torchrl.record` to # know more. diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index a1c82d5c429..b71a112c91a 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -78,15 +78,24 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore import torch import tqdm -from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictSequential, + TensorDictSequential as Seq, +) from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -104,10 +113,15 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv -from torchrl.modules import ConvNet, EGreedyWrapper, LSTMModule, MLP, QValueModule +from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule from torchrl.objectives import DQNLoss, SoftUpdate -device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu") +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ###################################################################### # Environment @@ -309,11 +323,15 @@ # DQN being a deterministic algorithm, exploration is a crucial part of it. # We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying # progressively to 0. -# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step` +# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step` # (see training loop below). # -stoch_policy = EGreedyWrapper( - stoch_policy, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 +exploration_module = EGreedyModule( + annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 +) +stoch_policy = TensorDictSequential( + stoch_policy, + exploration_module, ) ###################################################################### @@ -419,7 +437,7 @@ pbar.set_description( f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}" ) - stoch_policy.step(data.numel()) + exploration_module.step(data.numel()) updater.step() with set_exploration_type(ExplorationType.MODE), torch.no_grad(): diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index a12c2b05ff8..68cb995a1a3 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -20,9 +20,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index 90fd82dab3c..7451d6b39e7 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -122,10 +122,12 @@ # Torch import torch -# Tensordict modules from tensordict.nn import TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +# Tensordict modules +from torch import multiprocessing + # Data collection from torchrl.collectors import SyncDataCollector from torchrl.data.replay_buffers import ReplayBuffer @@ -161,7 +163,12 @@ # # Devices -device = "cpu" if not torch.has_cuda else "cuda:0" # The divice where learning is run +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) vmas_device = device # The device where the simulator is run (VMAS can run on GPU) # Sampling diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 12c8bdc3193..a67976566d5 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -84,9 +84,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index e8abf33cef8..03265c50d2b 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -24,9 +24,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -37,7 +42,12 @@ from torchrl.envs.libs.gym import GymEnv from torchrl.modules import Actor -device = "cuda:0" if torch.cuda.device_count() else "cpu" +is_fork = multiprocessing.get_start_method() == "fork" +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ############################################################################## # Let us first create an environment. For the sake of simplicity, we will be using diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 6106e3cf65a..3d37ce3de83 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -57,9 +57,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -133,7 +138,7 @@ from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage # We define the maximum size of the buffer -size = 10_000 +size = 100 ###################################################################### # A buffer with a list storage buffer can store any kind of data (but we must @@ -260,10 +265,10 @@ class MyData: data = MyData( images=torch.randint( 255, - (1000, 64, 64, 3), + (10, 64, 64, 3), ), - labels=torch.randint(100, (1000,)), - batch_size=[1000], + labels=torch.randint(100, (10,)), + batch_size=[10], ) tempdir = tempfile.TemporaryDirectory() @@ -303,7 +308,7 @@ def transform(x): # Let's build our replay buffer on disk: -rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform) +rb = ReplayBuffer(storage=LazyMemmapStorage(size), transform=transform) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, @@ -344,12 +349,21 @@ def assert0(x): # # Fixed batch-size # ~~~~~~~~~~~~~~~~ -# If the batch-size is passed during construction, it should be ommited when +# If the batch-size is passed during construction, it should be omited when # sampling: +data = MyData( + images=torch.randint( + 255, + (10, 64, 64, 3), + ), + labels=torch.randint(100, (10,)), + batch_size=[10], +) + buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) -buffer_lazymemmap.extend(data) -buffer_lazymemmap.sample() +buffer_lazymemmap.add(data) +buffer_lazymemmap.sample() # will produces 128 identical samples ###################################################################### @@ -363,7 +377,7 @@ def assert0(x): buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), batch_size=128, prefetch=10 ) # creates a queue of 10 elements to be prefetched in the background -buffer_lazymemmap.extend(data) +buffer_lazymemmap.add(data) print(buffer_lazymemmap.sample()) @@ -397,10 +411,10 @@ def assert0(x): # we create a data that is big enough to get a couple of samples data = TensorDict( { - "a": torch.arange(512).view(128, 4), - ("b", "c"): torch.arange(1024).view(128, 8), + "a": torch.arange(64).view(16, 4), + ("b", "c"): torch.arange(128).view(16, 8), }, - batch_size=[128], + batch_size=[16], ) buffer_lazymemmap.extend(data) @@ -443,7 +457,7 @@ def assert0(x): from torchrl.data.replay_buffers.samplers import PrioritizedSampler -size = 1000 +size = 100 rb = ReplayBuffer( storage=ListStorage(size), @@ -718,7 +732,7 @@ def assert0(x): GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) -rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) +rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16) data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) rb.add(data_exclude) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index d1a261e63f5..5e00442fe36 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -32,75 +32,138 @@ # **Content**: # .. aafig:: # -# "torchrl" -# │ -# ├── "collectors" -# │ └── "collectors.py" -# ├── "data" -# │ ├── "tensor_specs.py" -# │ ├── "postprocs" -# │ │ └── "postprocs.py" -# │ └── "replay_buffers" -# │ ├── "replay_buffers.py" -# │ └── "storages.py" -# ├── "envs" -# │ ├── "common.py" -# │ ├── "env_creator.py" -# │ ├── "gym_like.py" -# │ ├── "vec_env.py" -# │ ├── "libs" -# │ │ ├── "dm_control.py" -# │ │ └── "gym.py" -# │ └── "transforms" -# │ ├── "functional.py" -# │ └── "transforms.py" -# ├── "modules" -# │ ├── "distributions" -# │ │ ├── "continuous.py" -# │ │ └── "discrete.py" -# │ ├── "models" -# │ │ ├── "models.py" -# │ │ └── "exploration.py" -# │ └── "tensordict_module" -# │ ├── "actors.py" -# │ ├── "common.py" -# │ ├── "exploration.py" -# │ ├── "probabilistic.py" -# │ └── "sequence.py" -# ├── "objectives" -# │ ├── "common.py" -# │ ├── "ddpg.py" -# │ ├── "dqn.py" -# │ ├── "functional.py" -# │ ├── "ppo.py" -# │ ├── "redq.py" -# │ ├── "reinforce.py" -# │ ├── "sac.py" -# │ ├── "utils.py" -# │ └── "value" -# │ ├── "advantages.py" -# │ ├── "functional.py" -# │ ├── "pg.py" -# │ ├── "utils.py" -# │ └── "vtrace.py" -# ├── "record" -# │ └── "recorder.py" -# └── "trainers" -# ├── "loggers" -# │ ├── "common.py" -# │ ├── "csv.py" -# │ ├── "mlflow.py" -# │ ├── "tensorboard.py" -# │ └── "wandb.py" -# ├── "trainers.py" -# └── "helpers" -# ├── "collectors.py" -# ├── "envs.py" -# ├── "loggers.py" -# ├── "losses.py" -# ├── "models.py" -# ├── "replay_buffer.py" -# └── "trainers.py" +# "torchrl" +# │ +# ├── "collectors" +# │ └── "collectors.py" +# │ │ +# │ └── "distributed" +# │ └── "default_configs.py" +# │ └── "generic.py" +# │ └── "ray.py" +# │ └── "rpc.py" +# │ └── "sync.py" +# ├── "data" +# │ │ +# │ ├── "datasets" +# │ │ └── "atari_dqn.py" +# │ │ └── "d4rl.py" +# │ │ └── "d4rl_infos.py" +# │ │ └── "gen_dgrl.py" +# │ │ └── "minari_data.py" +# │ │ └── "openml.py" +# │ │ └── "openx.py" +# │ │ └── "roboset.py" +# │ │ └── "vd4rl.py" +# │ ├── "postprocs" +# │ │ └── "postprocs.py" +# │ ├── "replay_buffers" +# │ │ └── "replay_buffers.py" +# │ │ └── "samplers.py" +# │ │ └── "storages.py" +# │ │ └── "transforms.py" +# │ │ └── "writers.py" +# │ ├── "rlhf" +# │ │ └── "dataset.py" +# │ │ └── "prompt.py" +# │ │ └── "reward.py" +# │ └── "tensor_specs.py" +# ├── "envs" +# │ └── "batched_envs.py" +# │ └── "common.py" +# │ └── "env_creator.py" +# │ └── "gym_like.py" +# │ ├── "libs" +# │ │ └── "brax.py" +# │ │ └── "dm_control.py" +# │ │ └── "envpool.py" +# │ │ └── "gym.py" +# │ │ └── "habitat.py" +# │ │ └── "isaacgym.py" +# │ │ └── "jumanji.py" +# │ │ └── "openml.py" +# │ │ └── "pettingzoo.py" +# │ │ └── "robohive.py" +# │ │ └── "smacv2.py" +# │ │ └── "vmas.py" +# │ ├── "model_based" +# │ │ └── "common.py" +# │ │ └── "dreamer.py" +# │ ├── "transforms" +# │ │ └── "functional.py" +# │ │ └── "gym_transforms.py" +# │ │ └── "r3m.py" +# │ │ └── "rlhf.py" +# │ │ └── "transforms.py" +# │ │ └── "vc1.py" +# │ │ └── "vip.py" +# │ └── "vec_envs.py" +# ├── "modules" +# │ ├── "distributions" +# │ │ └── "continuous.py" +# │ │ └── "discrete.py" +# │ │ └── "truncated_normal.py" +# │ ├── "models" +# │ │ └── "decision_transformer.py" +# │ │ └── "exploration.py" +# │ │ └── "model_based.py" +# │ │ └── "models.py" +# │ │ └── "multiagent.py" +# │ │ └── "rlhf.py" +# │ ├── "planners" +# │ │ └── "cem.py" +# │ │ └── "common.py" +# │ │ └── "mppi.py" +# │ └── "tensordict_module" +# │ └── "actors.py" +# │ └── "common.py" +# │ └── "exploration.py" +# │ └── "probabilistic.py" +# │ └── "rnn.py" +# │ └── "sequence.py" +# │ └── "world_models.py" +# ├── "objectives" +# │ └── "a2c.py" +# │ └── "common.py" +# │ └── "cql.py" +# │ └── "ddpg.py" +# │ └── "decision_transformer.py" +# │ └── "deprecated.py" +# │ └── "dqn.py" +# │ └── "dreamer.py" +# │ └── "functional.py" +# │ └── "iql.py" +# │ ├── "multiagent" +# │ │ └── "qmixer.py" +# │ └── "ppo.py" +# │ └── "redq.py" +# │ └── "reinforce.py" +# │ └── "sac.py" +# │ └── "td3.py" +# │ ├── "value" +# │ └── "advantages.py" +# │ └── "functional.py" +# │ └── "pg.py" +# ├── "record" +# │ ├── "loggers" +# │ │ └── "common.py" +# │ │ └── "csv.py" +# │ │ └── "mlflow.py" +# │ │ └── "tensorboard.py" +# │ │ └── "wandb.py" +# │ └── "recorder.py" +# ├── "trainers" +# │ │ +# │ ├── "helpers" +# │ │ └── "collectors.py" +# │ │ └── "envs.py" +# │ │ └── "logger.py" +# │ │ └── "losses.py" +# │ │ └── "models.py" +# │ │ └── "replay_buffer.py" +# │ │ └── "trainers.py" +# │ └── "trainers.py" +# └── "version.py" # # Unlike other domains, RL is less about media than *algorithms*. As such, it # is harder to make truly independent components. @@ -135,9 +198,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index dc836b43150..56896637a87 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -37,9 +37,14 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + pass # sphinx_gallery_end_ignore @@ -575,7 +580,7 @@ def env_make(env_name): parallel_env = ParallelEnv( 2, [env_make, env_make], - [{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], + create_env_kwargs=[{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], ) tensordict = parallel_env.reset() @@ -619,7 +624,7 @@ def env_make(env_name): parallel_env = ParallelEnv( 2, [env_make, env_make], - [{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], + create_env_kwargs=[{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], ) parallel_env = TransformedEnv(parallel_env, GrayScale()) # transforms on main process tensordict = parallel_env.reset() From 0672359e57cec0a80987e299b418760a10de62d2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 4 Feb 2024 08:43:17 +0000 Subject: [PATCH 30/35] [BugFix] Fix load_state_dict and is_empty td bugfix impact (#1869) --- test/test_transforms.py | 13 +++++-------- torchrl/data/replay_buffers/storages.py | 16 +++++++++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 725945ef113..503a24b1f71 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10113,17 +10113,15 @@ def test_trans_parallel_env_check(self): def test_transform_no_env(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) - assert not td.is_empty() t = RemoveEmptySpecs() t._call(td) - assert td.is_empty() + assert len(td.keys()) == 0 def test_transform_compose(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) - assert not td.is_empty() t = Compose(RemoveEmptySpecs()) t._call(td) - assert td.is_empty() + assert len(td.keys()) == 0 def test_transform_env(self): base_env = self.DummyEnv() @@ -10138,7 +10136,7 @@ def test_transform_model(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) t = nn.Sequential(Compose(RemoveEmptySpecs())) td = t(td) - assert td.is_empty(), td + assert len(td.keys()) == 0 @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, rbclass): @@ -10154,14 +10152,13 @@ def test_transform_rb(self, rbclass): td = rb.sample(1) if "index" in td.keys(): del td["index"] - assert td.is_empty() + assert len(td.keys()) == 0 def test_transform_inverse(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) - assert not td.is_empty() t = RemoveEmptySpecs() t.inv(td) - assert not td.is_empty() + assert len(td.keys()) != 0 env = TransformedEnv(self.DummyEnv(), RemoveEmptySpecs()) td2 = env.transform.inv(TensorDict({}, [])) assert ("state", "sub") in td2.keys(True) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index d4d81f10bc1..a1ad94c21fe 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -226,7 +226,9 @@ def load_state_dict(self, state_dict): if isinstance(elt, torch.Tensor): self._storage.append(elt) elif isinstance(elt, (dict, OrderedDict)): - self._storage.append(TensorDict({}, []).load_state_dict(elt)) + self._storage.append( + TensorDict({}, []).load_state_dict(elt, strict=False) + ) else: raise TypeError( f"Objects of type {type(elt)} are not supported by ListStorage.load_state_dict" @@ -497,9 +499,11 @@ def load_state_dict(self, state_dict): ) elif isinstance(_storage, (dict, OrderedDict)): if is_tensor_collection(self._storage): - self._storage.load_state_dict(_storage) + self._storage.load_state_dict(_storage, strict=False) elif self._storage is None: - self._storage = TensorDict({}, []).load_state_dict(_storage) + self._storage = TensorDict({}, []).load_state_dict( + _storage, strict=False + ) else: raise RuntimeError( f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}. If your storage is pytree-based, use the dumps/load API instead." @@ -832,7 +836,7 @@ def load_state_dict(self, state_dict): ) elif isinstance(_storage, (dict, OrderedDict)): if is_tensor_collection(self._storage): - self._storage.load_state_dict(_storage) + self._storage.load_state_dict(_storage, strict=False) self._storage.memmap_() elif self._storage is None: warnings.warn( @@ -840,7 +844,9 @@ def load_state_dict(self, state_dict): "It is preferable to load a storage onto a" "pre-allocated one whenever possible." ) - self._storage = TensorDict({}, []).load_state_dict(_storage) + self._storage = TensorDict({}, []).load_state_dict( + _storage, strict=False + ) self._storage.memmap_() else: raise RuntimeError( From 5f8260102f67698c69a46c06b83ccac4de3710ee Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 4 Feb 2024 09:56:23 +0000 Subject: [PATCH 31/35] [BugFix] better device consistency in EGreedy (#1867) --- torchrl/modules/tensordict_module/exploration.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 9a7f88844cc..763b50eaa60 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -149,10 +149,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: out = action_tensordict.get(action_key) eps = self.eps.item() - cond = ( - torch.rand(action_tensordict.shape, device=action_tensordict.device) - < eps - ).to(out.dtype) + cond = torch.rand(action_tensordict.shape, device=out.device) < eps cond = expand_as_right(cond, out) spec = self.spec if spec is not None: @@ -177,7 +174,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"Action mask key {self.action_mask_key} not found in {tensordict}." ) spec.update_mask(action_mask) - out = cond * spec.rand().to(out.device) + (1 - cond) * out + out = torch.where(cond, spec.rand().to(out.device), out) else: raise RuntimeError("spec must be provided to the exploration wrapper.") action_tensordict.set(action_key, out) From 80fc87fa6a3da46e102287227d1b3fa81552fcc3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Feb 2024 16:25:15 +0000 Subject: [PATCH 32/35] [Doc] Installation instructions in API ref (#1871) --- docs/source/index.rst | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 91906abb857..49bcde82488 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,7 +11,14 @@ TorchRL TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch. -It provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested. +You can install TorchRL directly from PyPI (see more about installation +instructions in the dedicated section below): + +.. code-block:: + + $ pip install torchrl + +TorchRL provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort. This repo attempts to align with the existing pytorch ecosystem libraries in that it has a "dataset pillar" @@ -30,6 +37,32 @@ TorchRL aims at a high modularity and good runtime performance. To read more about TorchRL philosophy and capabilities beyond this API reference, check the `TorchRL paper `__. +Installation +============ + +TorchRL releases are synced with PyTorch, so make sure you always enjoy the latest +features of the library with the `most recent version of PyTorch `__ (although core features +are guaranteed to be backward compatible with pytorch>=1.13). +Nightly releases can be installed via + +.. code-block:: + + $ pip install tensordict-nightly + $ pip install torchrl-nightly + +or via a ``git clone`` if you're willing to contribute to the library: + +.. code-block:: + + $ cd path/to/root + $ git clone https://github.com/pytorch/tensordict + $ git clone https://github.com/pytorch/rl + $ cd tensordict + $ python setup.py develop + $ cd ../rl + $ python setup.py develop + + Tutorials ========= From 19a920eed2ff055079dd03ccac3dbf32e11da2e8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Feb 2024 20:21:59 +0000 Subject: [PATCH 33/35] [BugFix] Fix update in serial / parallel env (#1866) --- test/mocking_classes.py | 19 +- test/test_collector.py | 18 +- test/test_env.py | 7 +- test/test_tensordictmodules.py | 35 ++- torchrl/collectors/collectors.py | 2 +- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/rlhf/dataset.py | 2 +- torchrl/envs/batched_envs.py | 343 ++++++++++++++---------- torchrl/envs/common.py | 10 +- torchrl/envs/gym_like.py | 4 +- torchrl/envs/transforms/transforms.py | 10 +- torchrl/envs/utils.py | 2 +- torchrl/trainers/trainers.py | 2 +- 13 files changed, 278 insertions(+), 178 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 7a32c9a38ef..d68c7f30aa3 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1072,7 +1072,7 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(torch.int).to(self.device) + self.count += action.to(dtype=torch.int, device=self.device) tensordict = TensorDict( source={ "observation": self.count.clone(), @@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): 3, ) ), + device=self.device, ) self.unbatched_action_spec = CompositeSpec( lazy=action_specs, + device=self.device, ) self.unbatched_reward_spec = CompositeSpec( { @@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): }, shape=(self.n_nested_dim,), ) - } + }, + device=self.device, ) self.unbatched_done_spec = CompositeSpec( { @@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): }, shape=(self.n_nested_dim,), ) - } + }, + device=self.device, ) self.action_spec = self.unbatched_action_spec.expand( @@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i): "lidar": lidar, "vector": vector_3d, "tensor_0": tensor_0, - } + }, + device=self.device, ) elif i == 1: return CompositeSpec( @@ -1497,7 +1502,8 @@ def get_agent_obs_spec(self, i): "lidar": lidar, "vector": vector_2d, "tensor_1": tensor_1, - } + }, + device=self.device, ) elif i == 2: return CompositeSpec( @@ -1505,7 +1511,8 @@ def get_agent_obs_spec(self, i): "camera": camera, "vector": vector_2d, "tensor_2": tensor_2, - } + }, + device=self.device, ) else: raise ValueError(f"Index {i} undefined for index 3") diff --git a/test/test_collector.py b/test/test_collector.py index b5afe7f35d7..09c6ee293c3 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1675,8 +1675,12 @@ def test_maxframes_error(): @pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) @pytest.mark.parametrize("env_device", [None, *get_available_devices()]) @pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("parallel", [False, True]) def test_reset_heterogeneous_envs( - policy_device: torch.device, env_device: torch.device, storing_device: torch.device + policy_device: torch.device, + env_device: torch.device, + storing_device: torch.device, + parallel, ): if ( policy_device is not None @@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs( env_device = torch.device("cpu") # explicit mapping elif env_device is not None and env_device.type == "cuda" and policy_device is None: policy_device = torch.device("cpu") - env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2)) - env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3)) - env = SerialEnv(2, [env1, env2], device=env_device) + env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2)) + env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3)) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + env = cls(2, [env1, env2], device=env_device) collector = SyncDataCollector( env, RandomPolicy(env.action_spec), @@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs( assert ( data[0]["next", "truncated"].squeeze() == torch.tensor([False, True], device=data_device).repeat(25)[:50] - ).all(), data[0]["next", "truncated"][:10] + ).all(), data[0]["next", "truncated"] assert ( data[1]["next", "truncated"].squeeze() == torch.tensor([False, False, True], device=data_device).repeat(17)[:50] diff --git a/test/test_env.py b/test/test_env.py index 22918c390df..e316e1ae10f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2095,7 +2095,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count): @pytest.mark.parametrize("batch_size", [(1, 2)]) @pytest.mark.parametrize("env_type", ["serial", "parallel"]) - def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): + @pytest.mark.parametrize("break_when_any_done", [False, True]) + def test_vec_env( + self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2 + ): env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size) if env_type == "serial": vec_env = SerialEnv(n_workers, env_fun) @@ -2109,7 +2112,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): rollout_steps, policy=policy, return_contiguous=False, - break_when_any_done=False, + break_when_any_done=break_when_any_done, ) td = dense_stack_tds(td) for i in range(env_fun().n_nested_dim): diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 83a283e4e56..c2df40be012 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -21,6 +21,7 @@ CompositeSpec, UnboundedContinuousTensorSpec, ) +from torchrl.envs import EnvCreator, SerialEnv from torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( AdditiveGaussianWrapper, @@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based): ) @pytest.mark.parametrize("python_based", [True, False]) - def test_lstm_parallel_env(self, python_based): + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + torch.manual_seed(0) device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( @@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based): device=device, python_based=python_based, ) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv def create_transformed_env(): primer = lstm_module.make_tensordict_primer() @@ -1807,7 +1815,12 @@ def create_transformed_env(): env.append_transform(primer) return env - env = ParallelEnv( + if heterogeneous: + create_transformed_env = [ + EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env), + ] + env = cls( create_env_fn=create_transformed_env, num_workers=2, ) @@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based): ) @pytest.mark.parametrize("python_based", [True, False]) - def test_gru_parallel_env(self, python_based): + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_gru_parallel_env(self, python_based, parallel, heterogeneous): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + torch.manual_seed(0) + device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs gru_module = GRUModule( @@ -2134,7 +2151,17 @@ def create_transformed_env(): env.append_transform(primer) return env - env = ParallelEnv( + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + if heterogeneous: + create_transformed_env = [ + EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env), + ] + + env = cls( create_env_fn=create_transformed_env, num_workers=2, ) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index eff2434d487..bea46bb6cd4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1077,7 +1077,7 @@ def rollout(self) -> TensorDictBase: if self.storing_device is not None: tensordicts.append( - self._shuttle.to(self.storing_device, non_blocking=False) + self._shuttle.to(self.storing_device, non_blocking=True) ) else: tensordicts.append(self._shuttle) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a1ad94c21fe..55c57a6a6b4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -894,7 +894,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: # to be deprecated in v0.4 def map_device(tensor): if tensor.device != self.device: - return tensor.to(self.device, non_blocking=False) + return tensor.to(self.device, non_blocking=True) return tensor if is_tensor_collection(result): diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 19090d3f4c5..8f039b317fc 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -394,7 +394,7 @@ def get_dataloader( ) out = TensorDictReplayBuffer( storage=TensorStorage(data), - collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False), + collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True), sampler=SamplerWithoutReplacement(drop_last=True), batch_size=batch_size, prefetch=prefetch, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 5e88cf4e86d..9669963cb33 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -8,6 +8,7 @@ import gc import os +import weakref from collections import OrderedDict from copy import deepcopy from functools import wraps @@ -19,7 +20,7 @@ import torch from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict._tensordict import _unravel_key_to_tuple, unravel_key +from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, @@ -40,7 +41,6 @@ from torchrl.envs.utils import ( _aggregate_end_of_traj, - _set_single_key, _sort_keys, _update_during_reset, clear_mpi_env_vars, @@ -419,7 +419,13 @@ def _check_for_empty_spec(specs: CompositeSpec): def map_device(key, value, device_map=device_map): return value.to(device_map[key]) - self._env_tensordict.named_apply(map_device, nested_keys=True) + # self._env_tensordict.named_apply( + # map_device, nested_keys=True, filter_empty=True + # ) + self._env_tensordict.named_apply( + map_device, + nested_keys=True, + ) self._batch_locked = meta_data.batch_locked else: @@ -535,22 +541,17 @@ def _create_td(self) -> None: self._selected_keys = self._selected_keys.union(reset_keys) # input keys - self._selected_input_keys = { - _unravel_key_to_tuple(key) for key in self._env_input_keys - } + self._selected_input_keys = {unravel_key(key) for key in self._env_input_keys} # output keys after reset self._selected_reset_keys = { - _unravel_key_to_tuple(key) - for key in self._env_obs_keys + self.done_keys + reset_keys + unravel_key(key) for key in self._env_obs_keys + self.done_keys + reset_keys } # output keys after reset, filtered self._selected_reset_keys_filt = { unravel_key(key) for key in self._env_obs_keys + self.done_keys } # output keys after step - self._selected_step_keys = { - _unravel_key_to_tuple(key) for key in self._env_output_keys - } + self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys} if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( @@ -689,11 +690,27 @@ def _start_workers(self) -> None: _num_workers = self.num_workers self._envs = [] - + weakref_set = set() for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - if self.device is not None: - env = env.to(self.device) + # We want to avoid having the same env multiple times + # so we try to deepcopy it if needed. If we can't, we make + # the user aware that this isn't a very good idea + wr = weakref.ref(env) + if wr in weakref_set: + try: + env = deepcopy(env) + except Exception: + warn( + "Deepcopying the env failed within SerialEnv " + "but more than one copy of the same env was found. " + "This is a dangerous situation if your env keeps track " + "of some variables (e.g., state) in-place. " + "We'll use the same copy of the environment be beaware that " + "this may have important, unwanted issues for stateful " + "environments!" + ) + weakref_set.add(wr) self._envs.append(env) self.is_closed = False @@ -755,8 +772,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = None else: env_device = _env.device - if env_device != self.device: - tensordict_ = tensordict_.to(env_device) + if env_device != self.device and env_device is not None: + tensordict_ = tensordict_.to(env_device, non_blocking=True) else: tensordict_ = tensordict_.clone(False) else: @@ -764,30 +781,33 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( - _td.select(*self._selected_reset_keys_filt, strict=False) + _td, + keys_to_update=list(self._selected_reset_keys_filt), ) selected_output_keys = self._selected_reset_keys_filt device = self.device - if self._single_task: - # select + clone creates 2 tds, but we can create one only - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in selected_output_keys: - _set_single_key( - self.shared_tensordict_parent, out, key, clone=True, device=device - ) - else: - out = self.shared_tensordict_parent.select( - *selected_output_keys, - strict=False, - ) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clone().clear_device_() + + # select + clone creates 2 tds, but we can create one only + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() + + # out = self.shared_tensordict_parent.named_apply( + # select_and_clone, + # nested_keys=True, + # filter_empty=True, + # ) + out = self.shared_tensordict_parent.named_apply( + select_and_clone, + nested_keys=True, + ) + del out["next"] + + if out.device != device: + if device is None: + out = out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=True) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -807,30 +827,29 @@ def _step( # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. env_device = self._envs[i].device - if env_device != self.device: - data_in = tensordict_in[i].to(env_device, non_blocking=False) + if env_device != self.device and env_device is not None: + data_in = tensordict_in[i].to(env_device, non_blocking=True) else: data_in = tensordict_in[i] out_td = self._envs[i]._step(data_in) - next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) + next_td[i].update_(out_td, keys_to_update=list(self._env_output_keys)) + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps device = self.device - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True, device=device) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clone().clear_device_() - else: - out = out.to(device, non_blocking=False) + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() + + # out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) + out = next_td.named_apply(select_and_clone, nested_keys=True) + + if out.device != device: + if device is None: + out = out.clear_device_() + elif out.device != device: + out = out.to(device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: @@ -1040,6 +1059,7 @@ def _start_workers(self) -> None: def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda + # self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True) self.shared_tensordict_parent.apply(look_for_cuda) has_cuda = has_cuda[0] if has_cuda: @@ -1119,32 +1139,29 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def step_and_maybe_reset( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - if self._single_task and not self.has_lazy_inputs: - # We must use the in_keys and nothing else for the following reasons: - # - efficiency: copying all the keys will in practice mean doing a lot - # of writing operations since the input tensordict may (and often will) - # contain all the previous output data. - # - value mismatch: if the batched env is placed within a transform - # and this transform overrides an observation key (eg, CatFrames) - # the shape, dtype or device may not necessarily match and writing - # the value in-place will fail. - for key in self._env_input_keys: - self.shared_tensordict_parent.set_(key, tensordict.get(key)) - next_td = tensordict.get("next", None) - if next_td is not None: - # we copy the input keys as well as the keys in the 'next' td, if any - # as this mechanism can be used by a policy to set anticipatively the - # keys of the next call (eg, with recurrent nets) - for key in next_td.keys(True, True): - key = unravel_key(("next", key)) - if key in self.shared_tensordict_parent.keys(True, True): - self.shared_tensordict_parent.set_(key, next_td.get(key[1:])) + # We must use the in_keys and nothing else for the following reasons: + # - efficiency: copying all the keys will in practice mean doing a lot + # of writing operations since the input tensordict may (and often will) + # contain all the previous output data. + # - value mismatch: if the batched env is placed within a transform + # and this transform overrides an observation key (eg, CatFrames) + # the shape, dtype or device may not necessarily match and writing + # the value in-place will fail. + self.shared_tensordict_parent.update_( + tensordict, keys_to_update=self._env_input_keys + ) + next_td_passthrough = tensordict.get("next", None) + if next_td_passthrough is not None: + # if we have input "next" data (eg, RNNs which pass the next state) + # the sub-envs will need to process them through step_and_maybe_reset. + # We keep track of which keys are present to let the worker know what + # should be passd to the env (we don't want to pass done states for instance) + next_td_keys = list(next_td_passthrough.keys(True, True)) + self.shared_tensordict_parent.get("next").update_(next_td_passthrough) else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, "next", strict=False) - ) + next_td_keys = None for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", None)) + self.parent_channels[i].send(("step_and_maybe_reset", next_td_keys)) for i in range(self.num_workers): event = self._events[i] @@ -1160,8 +1177,20 @@ def step_and_maybe_reset( next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: - next_td = next_td.to(device, non_blocking=False) - tensordict_ = tensordict_.to(device, non_blocking=False) + next_td = next_td._fast_apply( + lambda x: x.to(device, non_blocking=True) + if x.device != device + else x.clone(), + device=device, + # filter_empty=True, + ) + tensordict_ = tensordict_._fast_apply( + lambda x: x.to(device, non_blocking=True) + if x.device != device + else x.clone(), + device=device, + # filter_empty=True, + ) else: next_td = next_td.clone().clear_device_() tensordict_ = tensordict_.clone().clear_device_() @@ -1170,35 +1199,33 @@ def step_and_maybe_reset( @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - if self._single_task and not self.has_lazy_inputs: - # We must use the in_keys and nothing else for the following reasons: - # - efficiency: copying all the keys will in practice mean doing a lot - # of writing operations since the input tensordict may (and often will) - # contain all the previous output data. - # - value mismatch: if the batched env is placed within a transform - # and this transform overrides an observation key (eg, CatFrames) - # the shape, dtype or device may not necessarily match and writing - # the value in-place will fail. - for key in tensordict.keys(True, True): - # we copy the input keys as well as the keys in the 'next' td, if any - # as this mechanism can be used by a policy to set anticipatively the - # keys of the next call (eg, with recurrent nets) - if key in self._env_input_keys or ( - isinstance(key, tuple) - and key[0] == "next" - and key in self.shared_tensordict_parent.keys(True, True) - ): - val = tensordict.get(key) - self.shared_tensordict_parent.set_(key, val) + # We must use the in_keys and nothing else for the following reasons: + # - efficiency: copying all the keys will in practice mean doing a lot + # of writing operations since the input tensordict may (and often will) + # contain all the previous output data. + # - value mismatch: if the batched env is placed within a transform + # and this transform overrides an observation key (eg, CatFrames) + # the shape, dtype or device may not necessarily match and writing + # the value in-place will fail. + self.shared_tensordict_parent.update_( + tensordict, keys_to_update=list(self._env_input_keys) + ) + next_td_passthrough = tensordict.get("next", None) + if next_td_passthrough is not None: + # if we have input "next" data (eg, RNNs which pass the next state) + # the sub-envs will need to process them through step_and_maybe_reset. + # We keep track of which keys are present to let the worker know what + # should be passd to the env (we don't want to pass done states for instance) + next_td_keys = list(next_td_passthrough.keys(True, True)) + self.shared_tensordict_parent.get("next").update_(next_td_passthrough) else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, "next", strict=False) - ) + next_td_keys = None + if self.event is not None: self.event.record() self.event.synchronize() for i in range(self.num_workers): - self.parent_channels[i].send(("step", None)) + self.parent_channels[i].send(("step", next_td_keys)) for i in range(self.num_workers): event = self._events[i] @@ -1209,19 +1236,21 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") device = self.device - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True, device=device) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False) - if out.device == device: - out = out.clone() + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() + + out = next_td.named_apply( + select_and_clone, + nested_keys=True, + # filter_empty=True, + ) + if out.device != device: + if device is None: + out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=True) return out @_check_start @@ -1258,13 +1287,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). self.shared_tensordicts[i].update_( - self.shared_tensordicts[i] - .get("next") - .select(*self._selected_reset_keys, strict=False) + self.shared_tensordicts[i].get("next"), + keys_to_update=list(self._selected_reset_keys), ) if tensordict_ is not None: self.shared_tensordicts[i].update_( - tensordict_.select(*self._selected_reset_keys, strict=False) + tensordict_, keys_to_update=list(self._selected_reset_keys) ) continue out = ("reset", tensordict_) @@ -1278,26 +1306,23 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: selected_output_keys = self._selected_reset_keys_filt device = self.device - if self._single_task: - # select + clone creates 2 tds, but we can create one only - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in selected_output_keys: - _set_single_key( - self.shared_tensordict_parent, out, key, clone=True, device=device - ) - else: - out = self.shared_tensordict_parent.select( - *selected_output_keys, - strict=False, - ) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clear_device_().clone() + + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() + + out = self.shared_tensordict_parent.named_apply( + select_and_clone, + nested_keys=True, + # filter_empty=True, + ) + del out["next"] + + if out.device != device: + if device is None: + out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=True) return out @_check_start @@ -1427,6 +1452,7 @@ def _run_worker_pipe_shared_mem( def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda + # shared_tensordict.apply(look_for_cuda, filter_empty=True) shared_tensordict.apply(look_for_cuda) has_cuda = has_cuda[0] else: @@ -1498,7 +1524,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda): raise RuntimeError("call 'init' before resetting") cur_td = env.reset(tensordict=data) shared_tensordict.update_( - cur_td.select(*_selected_reset_keys, strict=False) + cur_td, + keys_to_update=list(_selected_reset_keys), ) if event is not None: event.record() @@ -1510,7 +1537,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + # No need to copy here since we don't write in-place + if data: + next_td_passthrough_keys = data + input = root_shared_tensordict.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + else: + input = root_shared_tensordict + next_td = env._step(input) next_shared_tensordict.update_(next_td) if event is not None: event.record() @@ -1522,9 +1557,25 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict) - next_shared_tensordict.update_(td.get("next")) + # We must copy the root shared td here, or at least get rid of done: + # if we don't `td is root_shared_tensordict` + # which means that root_shared_tensordict will carry the content of next + # in the next iteration. When using StepCounter, it will look for an + # existing done state, find it and consider the env as done by input (not + # by output) of the step! + # Caveat: for RNN we may need some keys of the "next" TD so we pass the list + # through data + if data: + next_td_passthrough_keys = data + input = root_shared_tensordict.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + else: + input = root_shared_tensordict + td, root_next_td = env.step_and_maybe_reset(input) + next_shared_tensordict.update_(td.pop("next")) root_shared_tensordict.update_(root_next_td) + if event is not None: event.record() event.synchronize() @@ -1588,5 +1639,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): def _update_cuda(t_dest, t_source): if t_source is None: return - t_dest.copy_(t_source.pin_memory(), non_blocking=False) + t_dest.copy_(t_source.pin_memory(), non_blocking=True) return + + +def _filter_empty(tensordict): + return tensordict.select(*tensordict.keys(True, True)) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b2b201922e1..61cd211b6ae 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2056,7 +2056,7 @@ def reset( tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly # if tensordict_reset.device != self.device: - # tensordict_reset = tensordict_reset.to(self.device, non_blocking=False) + # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -2418,13 +2418,13 @@ def _rollout_stop_early( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict = tensordict.to(policy_device, non_blocking=False) + tensordict = tensordict.to(policy_device, non_blocking=True) else: tensordict.clear_device_() tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: - tensordict = tensordict.to(env_device, non_blocking=False) + tensordict = tensordict.to(env_device, non_blocking=True) else: tensordict.clear_device_() tensordict = self.step(tensordict) @@ -2472,13 +2472,13 @@ def _rollout_nonstop( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict_ = tensordict_.to(policy_device, non_blocking=False) + tensordict_ = tensordict_.to(policy_device, non_blocking=True) else: tensordict_.clear_device_() tensordict_ = policy(tensordict_) if auto_cast_to_device: if env_device is not None: - tensordict_ = tensordict_.to(env_device, non_blocking=False) + tensordict_ = tensordict_.to(env_device, non_blocking=True) else: tensordict_.clear_device_() tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 38995a07a6b..d3b3dfd659c 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -322,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, _run_checks=False ) - tensordict_out = tensordict_out.to(self.device, non_blocking=False) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -366,7 +366,7 @@ def _reset( for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() - tensordict_out = tensordict_out.to(self.device, non_blocking=False) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a661b152d39..efa59e25c26 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3612,10 +3612,10 @@ def set_container(self, container: Union[Transform, EnvBase]) -> None: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=False) + return tensordict.to(self.device, non_blocking=True) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=False) + return tensordict.to(self.device, non_blocking=True) def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -3628,8 +3628,8 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if parent is None: if self.orig_device is None: return tensordict - return tensordict.to(self.orig_device, non_blocking=False) - return tensordict.to(parent.device, non_blocking=False) + return tensordict.to(self.orig_device, non_blocking=True) + return tensordict.to(parent.device, non_blocking=True) def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec.to(self.device) @@ -5146,7 +5146,7 @@ def _reset( if step_count is None: step_count = self.container.observation_spec[step_count_key].zero() if step_count.device != reset.device: - step_count = step_count.to(reset.device, non_blocking=False) + step_count = step_count.to(reset.device, non_blocking=True) # zero the step count if reset is needed step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ebb9100655c..46c923ccfec 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -268,7 +268,7 @@ def _set_single_key( dest = new_val else: if device is not None and val.device != device: - val = val.to(device, non_blocking=False) + val = val.to(device, non_blocking=True) elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index f844613432c..03a7be37573 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -708,7 +708,7 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: def sample(self, batch: TensorDictBase) -> TensorDictBase: sample = self.replay_buffer.sample(batch_size=self.batch_size) - return sample.to(self.device, non_blocking=False) + return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: self.replay_buffer.update_tensordict_priority(batch) From 528faa19fe51a65c2ab50f3eb7cf89bd7b701fd0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Feb 2024 20:53:31 +0000 Subject: [PATCH 34/35] [BugFix] check_env_specs seeding logic (#1872) --- test/test_utils.py | 20 +++++++++++++++++++- torchrl/_utils.py | 37 +++++++++++++++++++++++++++++++++++++ torchrl/envs/utils.py | 23 ++++++++++++++++++----- 3 files changed, 74 insertions(+), 6 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 620149daeb6..c2ce2eae6b9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,10 @@ import _utils_internal import pytest -from torchrl._utils import get_binary_env_var, implement_for +import torch + +from _utils_internal import get_default_devices +from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @@ -358,6 +361,21 @@ class MockGym: ) # would break with gymnasium +@pytest.mark.parametrize("device", get_default_devices()) +def test_rng_decorator(device): + with torch.device(device): + torch.manual_seed(10) + s0a = torch.randn(3) + with _rng_decorator(0): + torch.randn(3) + s0b = torch.randn(3) + torch.manual_seed(10) + s1a = torch.randn(3) + s1b = torch.randn(3) + torch.testing.assert_close(s0a, s1a) + torch.testing.assert_close(s0b, s1b) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 9538cecb026..6c52b1d66e7 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -704,3 +704,40 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: return new_ending else: return key[:-1] + (new_ending,) + + +class _rng_decorator(_DecoratorContextManager): + """Temporarily sets the seed and sets back the rng state when exiting.""" + + def __init__(self, seed, device=None): + self.seed = seed + self.device = device + self.has_cuda = torch.cuda.is_available() + + def __enter__(self): + self._get_state() + torch.manual_seed(self.seed) + + def _get_state(self): + if self.has_cuda: + if self.device is None: + self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state()) + else: + self._state = ( + torch.random.get_rng_state(), + torch.cuda.get_rng_state(self.device), + ) + + else: + self.state = torch.random.get_rng_state() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.has_cuda: + torch.random.set_rng_state(self._state[0]) + if self.device is not None: + torch.cuda.set_rng_state(self._state[1], device=self.device) + else: + torch.cuda.set_rng_state(self._state[1]) + + else: + torch.random.set_rng_state(self._state) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 46c923ccfec..71b15d1dfae 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -31,7 +31,7 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey -from torchrl._utils import _replace_last, logger as torchrl_logger +from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( CompositeSpec, @@ -419,7 +419,9 @@ def _per_level_env_check(data0, data1, check_dtype): ) -def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): +def check_env_specs( + env, return_contiguous=True, check_dtype=True, seed: int | None = None +): """Tests an environment specs against the results of short rollout. This test function should be used as a sanity check for an env wrapped with @@ -436,7 +438,12 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of inputs/outputs). Defaults to True. check_dtype (bool, optional): if False, dtype checks will be skipped. Defaults to True. - seed (int, optional): for reproducibility, a seed is set. + seed (int, optional): for reproducibility, a seed can be set. + The seed will be set in pytorch temporarily, then the RNG state will + be reverted to what it was before. For the env, we set the seed but since + setting the rng state back to what is was isn't a feature of most environment, + we leave it to the user to accomplish that. + Defaults to ``None``. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -444,8 +451,14 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): """ if seed is not None: - torch.manual_seed(seed) - env.set_seed(seed) + device = ( + env.device if env.device is not None and env.device.type == "cuda" else None + ) + with _rng_decorator(seed, device=device): + env.set_seed(seed) + return check_env_specs( + env, return_contiguous=return_contiguous, check_dtype=check_dtype + ) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous) From eec9f9e4bb0270e2d9c567836dd397f154c4f2fc Mon Sep 17 00:00:00 2001 From: Vlad Sobal Date: Tue, 6 Feb 2024 03:53:42 -0500 Subject: [PATCH 35/35] [BugFix] Fix a bug in SliceSampler, indexes outside sampler lengths were produced (#1874) --- torchrl/data/replay_buffers/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 15e46ae1038..0352b803b66 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -790,7 +790,7 @@ def _get_stop_and_length(self, storage, fallback=True): raise RuntimeError( "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." ) - vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] + vals = self._find_start_stop_traj(end=done.squeeze()[: len(storage)]) if self.cache_values: self._cache["stop-and-length"] = vals return vals