Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 6, 2024
1 parent ff3a350 commit f30d544
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict_, keys_to_update=list(self._selected_reset_keys)
)
continue
out = ("reset", tensordict_)
if tensordict_ is not None:
self.shared_tensordicts[i].update_(tensordict_)
out = ("reset", True)
else:
out = ("reset", False)

channel.send(out)
workers.append(i)

Expand Down Expand Up @@ -1522,7 +1527,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
torchrl_logger.info(f"resetting worker {pid}")
if not initialized:
raise RuntimeError("call 'init' before resetting")
cur_td = env.reset(tensordict=data)
# we use 'data' to tell the env whether there is anything
# to read in the input tensordict
cur_td = env.reset(tensordict=root_shared_tensordict if data else None)
shared_tensordict.update_(
cur_td,
keys_to_update=list(_selected_reset_keys),
Expand Down

0 comments on commit f30d544

Please sign in to comment.