From ab4250ec712094d978d2071b195ed7f6dab00dd8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:46 +0000 Subject: [PATCH] [BugFix] Fix batching envs with non tensor data ghstack-source-id: daba8a95459cfa978da09291757b6380fab4f308 Pull Request resolved: https://github.com/pytorch/rl/pull/2674 --- torchrl/envs/batched_envs.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f7a25c1bd5c..5b6763f6910 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -730,19 +730,20 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) - env_obs_keys = [ - key for key in env_obs_keys if key not in self._non_tensor_keys - ] - env_input_keys = [ - key for key in env_input_keys if key not in self._non_tensor_keys - ] - env_output_keys = [ - key for key in env_output_keys if key not in self._non_tensor_keys - ] self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) + self._env_obs_keys = [ + key for key in self._env_obs_keys if key not in self._non_tensor_keys + ] + self._env_input_keys = [ + key for key in self._env_input_keys if key not in self._non_tensor_keys + ] + self._env_output_keys = [ + key for key in self._env_output_keys if key not in self._non_tensor_keys + ] + reset_keys = self.reset_keys self._selected_keys = ( set(self._env_output_keys)