From f30d544cbd67646b8a9ef2963e63a1b268dad016 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 6 Feb 2024 13:58:54 +0000 Subject: [PATCH] init --- torchrl/envs/batched_envs.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9669963cb33..6d7e0f473bc 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1295,7 +1295,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_, keys_to_update=list(self._selected_reset_keys) ) continue - out = ("reset", tensordict_) + if tensordict_ is not None: + self.shared_tensordicts[i].update_(tensordict_) + out = ("reset", True) + else: + out = ("reset", False) + channel.send(out) workers.append(i) @@ -1522,7 +1527,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): torchrl_logger.info(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - cur_td = env.reset(tensordict=data) + # we use 'data' to tell the env whether there is anything + # to read in the input tensordict + cur_td = env.reset(tensordict=root_shared_tensordict if data else None) shared_tensordict.update_( cur_td, keys_to_update=list(_selected_reset_keys),