From bd7e268aacc1db08c0e2a446a03f114712c6da20 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 15 Feb 2024 15:28:25 +0000 Subject: [PATCH] [BugFix] Non exclusive terminated and truncated (#1911) --- torchrl/envs/transforms/transforms.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 40df963ec5e..a96d67fbfa5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5163,11 +5163,10 @@ def _reset( def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - for step_count_key, truncated_key, done_key, terminated_key in zip( + for step_count_key, truncated_key, done_key in zip( self.step_count_keys, self.truncated_keys, self.done_keys, - self.terminated_keys, ): step_count = tensordict.get(step_count_key) next_step_count = step_count + 1 @@ -5178,9 +5177,12 @@ def _step( truncated = truncated | next_tensordict.get(truncated_key, False) if self.update_done: done = next_tensordict.get(done_key, None) - terminated = next_tensordict.get(terminated_key, None) - if terminated is not None: - truncated = truncated & ~terminated + + # we can have terminated and truncated + # terminated = next_tensordict.get(terminated_key, None) + # if terminated is not None: + # truncated = truncated & ~terminated + done = truncated | done # we assume no done after reset next_tensordict.set(done_key, done) next_tensordict.set(truncated_key, truncated)