From 576b0e9b7573ae0b05dfb00702d5e56a758f55e9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 19 Dec 2024 15:03:32 +0000 Subject: [PATCH] [BugFix] Compatibility of tensordict primers with batched envs (specifically for LSTM and GRU) ghstack-source-id: d5981b4dbee8305250faa776c46424c7cf959578 Pull Request resolved: https://github.com/pytorch/rl/pull/2668 --- .github/workflows/nightly_build.yml | 39 +++---- .../decision_transformer/utils.py | 2 +- test/test_tensordictmodules.py | 105 +++++++++++++----- test/test_transforms.py | 45 ++++++-- torchrl/envs/batched_envs.py | 36 +++++- torchrl/envs/transforms/transforms.py | 29 ++++- torchrl/modules/tensordict_module/rnn.py | 14 ++- 7 files changed, 204 insertions(+), 66 deletions(-) diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 08eb61bfa6c..732077f4b58 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -21,11 +21,6 @@ on: branches: - "nightly" -env: - ACTIONS_RUNNER_FORCED_INTERNAL_NODE_VERSION: node16 - ACTIONS_RUNNER_FORCE_ACTIONS_NODE_VERSION: node16 - ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true # https://github.com/actions/checkout/issues/1809 - concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. @@ -41,12 +36,15 @@ jobs: matrix: python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] - container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 env: AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache" + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version[0] }} - name: Install PyTorch nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -67,7 +65,7 @@ jobs: python3 -mpip install auditwheel auditwheel show dist/* - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: dist/*.whl @@ -81,12 +79,15 @@ jobs: matrix: python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] - container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version[0] }} - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: /tmp/wheels @@ -121,7 +122,7 @@ jobs: env: AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache" - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch Nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -138,7 +139,7 @@ jobs: export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install numpy pytest pillow>=4.1.1 scipy networkx expecttest pyyaml - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: /tmp/wheels @@ -179,7 +180,7 @@ jobs: with: python-version: ${{ matrix.python_version[1] }} - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch nightly shell: bash run: | @@ -193,7 +194,7 @@ jobs: --package_name torchrl-nightly \ --python-tag=${{ matrix.python-tag }} - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: dist/*.whl @@ -212,7 +213,7 @@ jobs: with: python-version: ${{ matrix.python_version[1] }} - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch Nightly shell: bash run: | @@ -229,7 +230,7 @@ jobs: run: | python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: wheels @@ -265,9 +266,9 @@ jobs: python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: wheels diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index d4a67e7d3a9..415e19a1f7c 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): ) # copy action from the input tensordict to the output - transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) + transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec)) transformed_env.append_transform(DoubleToFloat()) obsnorm = ObservationNorm( diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index d3b7b7850f4..c2a34f3797d 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import functools import os import pytest @@ -12,6 +13,7 @@ import torchrl.modules from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential +from tensordict.utils import assert_close from torch import nn from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs import ( @@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based): @pytest.mark.parametrize("python_based", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) @pytest.mark.parametrize("heterogeneous", [True, False]) - def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): + @pytest.mark.parametrize("within", [False, True]) + def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv torch.manual_seed(0) + num_envs = 3 device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( @@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): else: cls = SerialEnv - def create_transformed_env(): - primer = lstm_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True, device=device + if within: + + def create_transformed_env(): + primer = lstm_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv( + categorical_action_encoding=True, device=device + ) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + else: + create_transformed_env = functools.partial( + DiscreteActionVecMockEnv, + categorical_action_encoding=True, + device=device, ) - env = TransformedEnv(env) - env.append_transform(InitTracker()) - env.append_transform(primer) - return env if heterogeneous: create_transformed_env = [ - EnvCreator(create_transformed_env), - EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env) for _ in range(num_envs) ] env = cls( create_env_fn=create_transformed_env, - num_workers=2, + num_workers=num_envs, ) + if not within: + env = env.append_transform(InitTracker()) + env.append_transform(lstm_module.make_tensordict_primer()) mlp = TensorDictModule( MLP( @@ -1002,6 +1017,19 @@ def create_transformed_env(): data = env.rollout(10, actor, break_when_any_done=break_when_any_done) assert (data.get(("next", "recurrent_state_c")) != 0.0).all() assert (data.get("recurrent_state_c") != 0.0).any() + return data + + @pytest.mark.parametrize("python_based", [True, False]) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_lstm_parallel_within(self, python_based, parallel, heterogeneous): + out_within = self.test_lstm_parallel_env( + python_based, parallel, heterogeneous, within=True + ) + out_not_within = self.test_lstm_parallel_env( + python_based, parallel, heterogeneous, within=False + ) + assert_close(out_within, out_not_within) @pytest.mark.skipif( not _has_functorch, reason="vmap can only be used with functorch" @@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based): @pytest.mark.parametrize("python_based", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) @pytest.mark.parametrize("heterogeneous", [True, False]) - def test_gru_parallel_env(self, python_based, parallel, heterogeneous): + @pytest.mark.parametrize("within", [False, True]) + def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv torch.manual_seed(0) + num_workers = 3 device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs @@ -1347,15 +1377,24 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous): python_based=python_based, ) - def create_transformed_env(): - primer = gru_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True, device=device + if within: + + def create_transformed_env(): + primer = gru_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv( + categorical_action_encoding=True, device=device + ) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + else: + create_transformed_env = functools.partial( + DiscreteActionVecMockEnv, + categorical_action_encoding=True, + device=device, ) - env = TransformedEnv(env) - env.append_transform(InitTracker()) - env.append_transform(primer) - return env if parallel: cls = ParallelEnv @@ -1363,14 +1402,17 @@ def create_transformed_env(): cls = SerialEnv if heterogeneous: create_transformed_env = [ - EnvCreator(create_transformed_env), - EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env) for _ in range(num_workers) ] - env = cls( + env: ParallelEnv | SerialEnv = cls( create_env_fn=create_transformed_env, - num_workers=2, + num_workers=num_workers, ) + if not within: + primer = gru_module.make_tensordict_primer() + env = env.append_transform(InitTracker()) + env.append_transform(primer) mlp = TensorDictModule( MLP( @@ -1396,6 +1438,19 @@ def create_transformed_env(): data = env.rollout(10, actor, break_when_any_done=break_when_any_done) assert (data.get("recurrent_state") != 0.0).any() assert (data.get(("next", "recurrent_state")) != 0.0).all() + return data + + @pytest.mark.parametrize("python_based", [True, False]) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_gru_parallel_within(self, python_based, parallel, heterogeneous): + out_within = self.test_gru_parallel_env( + python_based, parallel, heterogeneous, within=True + ) + out_not_within = self.test_gru_parallel_env( + python_based, parallel, heterogeneous, within=False + ) + assert_close(out_within, out_not_within) @pytest.mark.skipif( not _has_functorch, reason="vmap can only be used with functorch" diff --git a/test/test_transforms.py b/test/test_transforms.py index cc3ca40b059..44ebce72c5c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7408,7 +7408,7 @@ def make_env(): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=Unbounded([2, 4])), + TensorDictPrimer(mykey=Unbounded([4])), ) try: check_env_specs(env) @@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): pass @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) - def test_trans_serial_env_check(self, spec_shape): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=Unbounded(spec_shape)), - ) + @pytest.mark.parametrize("expand_specs", [True, False, None]) + def test_trans_serial_env_check(self, spec_shape, expand_specs): + if expand_specs is None: + with pytest.warns(FutureWarning, match=""): + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer( + mykey=Unbounded(spec_shape), expand_specs=expand_specs + ), + ) + env.observation_spec + elif expand_specs is True: + shape = spec_shape[:-1] + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer( + Composite(mykey=Unbounded(spec_shape), shape=shape), + expand_specs=expand_specs, + ), + ) + else: + # If we don't expand, we can't use [4] + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer( + mykey=Unbounded(spec_shape), expand_specs=expand_specs + ), + ) + if spec_shape == [4]: + with pytest.raises(ValueError): + env.observation_spec + return + check_env_specs(env) assert "mykey" in env.reset().keys() r = env.rollout(3) @@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env): transform = KLRewardTransform(actor, out_keys=out_key) return Compose( TensorDictPrimer( - primers={ - "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1]) - } + sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]), + shape=base_env.shape, ), transform, ) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 17bd28c8390..f7a25c1bd5c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We keep track of which keys are present to let the worker know what # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) + next_shared_tensordict_parent = shared_tensordict_parent.get("next") + + # We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset. + # The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of + # the batched env but part of the specs of a transformed batched env. + # If that is the case, `update_` will fail to find the entries to update. + # What we do instead is keeping the tensors on the side and putting them back after completing _step. + keys_to_update, keys_to_copy = zip( + *[ + (key, None) + if key in next_shared_tensordict_parent.keys(True, True) + else (None, key) + for key in next_td_keys + ] + ) + keys_to_update = [key for key in keys_to_update if key is not None] + keys_to_copy = [key for key in keys_to_copy if key is not None] data = [ - {"next_td_passthrough_keys": next_td_keys} + {"next_td_passthrough_keys": keys_to_update} for _ in range(self.num_workers) ] - shared_tensordict_parent.get("next").update_( - next_td_passthrough, non_blocking=self.non_blocking - ) + if keys_to_update: + next_shared_tensordict_parent.update_( + next_td_passthrough, + non_blocking=self.non_blocking, + keys_to_update=keys_to_update, + ) + if keys_to_copy: + next_td_passthrough = next_td_passthrough.select(*keys_to_copy) + else: + next_td_passthrough = None else: + next_td_passthrough = None data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: @@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor): LazyStackedTensorDict(*non_tensor_tds), keys_to_update=self._non_tensor_keys, ) + if next_td_passthrough is not None: + out.update(next_td_passthrough) + self._sync_w2m() if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3329d085df..5b2dafc9b04 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4984,6 +4984,7 @@ def __init__( | Dict[NestedKey, float] | Dict[NestedKey, Callable] = None, reset_key: NestedKey | None = None, + expand_specs: bool = None, **kwargs, ): self.device = kwargs.pop("device", None) @@ -4995,8 +4996,10 @@ def __init__( ) kwargs = primers if not isinstance(kwargs, Composite): - kwargs = Composite(kwargs) + kwargs = Composite(**kwargs) self.primers = kwargs + self.expand_specs = expand_specs + if random and default_value: raise ValueError( "Setting random to True and providing a default_value are incompatible." @@ -5089,12 +5092,26 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: ) if self.primers.shape != observation_spec.shape: - try: - # We try to set the primer shape to the observation spec shape - self.primers.shape = observation_spec.shape - except ValueError: - # If we fail, we expand them to that shape + if self.expand_specs: self.primers = self._expand_shape(self.primers) + elif self.expand_specs is None: + warnings.warn( + f"expand_specs wasn't specified in the {type(self).__name__} constructor. " + f"The current behaviour is that the transform will attempt to set the shape of the composite " + f"spec, and if this can't be done it will be expanded. " + f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size " + f"will raise an exception.", + category=FutureWarning, + ) + try: + # We try to set the primer shape to the observation spec shape + self.primers.shape = observation_spec.shape + except ValueError: + # If we fail, we expand them to that shape + self.primers = self._expand_shape(self.primers) + else: + self.primers.shape = observation_spec.shape + device = observation_spec.device observation_spec.update(self.primers.clone().to(device)) return observation_spec diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index f4ceb648665..68309c346cd 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -592,6 +592,10 @@ def make_tensordict_primer(self): inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly. + When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the + single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the + batched env instance level (i.e., a transformed batch of regular envs). + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states @@ -649,7 +653,8 @@ def make_tuple(key): { in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), - } + }, + expand_specs=True, ) @property @@ -1410,6 +1415,10 @@ def make_tensordict_primer(self): tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the + single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the + batched env instance level (i.e., a transformed batch of regular envs). + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given module. @@ -1459,7 +1468,8 @@ def make_tuple(key): return TensorDictPrimer( { in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)), - } + }, + expand_specs=True, ) @property