Skip to content

Commit

Permalink
[BugFix] Non exclusive terminated and truncated (#1911)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 15, 2024
1 parent 45764b5 commit bd7e268
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5163,11 +5163,10 @@ def _reset(
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
for step_count_key, truncated_key, done_key, terminated_key in zip(
for step_count_key, truncated_key, done_key in zip(
self.step_count_keys,
self.truncated_keys,
self.done_keys,
self.terminated_keys,
):
step_count = tensordict.get(step_count_key)
next_step_count = step_count + 1
Expand All @@ -5178,9 +5177,12 @@ def _step(
truncated = truncated | next_tensordict.get(truncated_key, False)
if self.update_done:
done = next_tensordict.get(done_key, None)
terminated = next_tensordict.get(terminated_key, None)
if terminated is not None:
truncated = truncated & ~terminated

# we can have terminated and truncated
# terminated = next_tensordict.get(terminated_key, None)
# if terminated is not None:
# truncated = truncated & ~terminated

done = truncated | done # we assume no done after reset
next_tensordict.set(done_key, done)
next_tensordict.set(truncated_key, truncated)
Expand Down

0 comments on commit bd7e268

Please sign in to comment.