Skip to content

Commit

Permalink
[BugFix] Fix batching envs with non tensor data
Browse files Browse the repository at this point in the history
ghstack-source-id: daba8a95459cfa978da09291757b6380fab4f308
Pull Request resolved: #2674
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent 84c3ec3 commit ab4250e
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,19 +730,20 @@ def _create_td(self) -> None:
)
)
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
env_obs_keys = [
key for key in env_obs_keys if key not in self._non_tensor_keys
]
env_input_keys = [
key for key in env_input_keys if key not in self._non_tensor_keys
]
env_output_keys = [
key for key in env_output_keys if key not in self._non_tensor_keys
]
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
self._env_output_keys = sorted(env_output_keys, key=_sort_keys)

self._env_obs_keys = [
key for key in self._env_obs_keys if key not in self._non_tensor_keys
]
self._env_input_keys = [
key for key in self._env_input_keys if key not in self._non_tensor_keys
]
self._env_output_keys = [
key for key in self._env_output_keys if key not in self._non_tensor_keys
]

reset_keys = self.reset_keys
self._selected_keys = (
set(self._env_output_keys)
Expand Down

0 comments on commit ab4250e

Please sign in to comment.