Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
btx0424 committed Jan 15, 2024
1 parent 3ebd79c commit fe0dfdd
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6634,21 +6634,22 @@ def _reset(
if _reset is None:
_reset = torch.ones(tensordict.batch_size, dtype=bool, device=tensordict.device)
for in_key, out_key in zip(self.in_keys, self.out_keys):
# get previous observations
val_history = tensordict.get(out_key, None)
if val_history is None:
spec = self.parent.full_observation_spec[in_key]
val_history = spec.unsqueeze(-1).expand(*spec.shape, self.steps).zero()
else:
val_history = val_history.clone()
if self.include_last:
val_init = tensordict_reset.get(in_key)
val_init = val_init[expand_as_right(_reset, val_init)]
val_pad = torch.zeros(
*val_init.shape, self.steps-1, dtype=val_init.dtype, device=val_init.device
)
val = torch.cat([val_pad, val_init.unsqueeze(-1)], dim=-1).flatten()
else:
val = 0.0
val_history[expand_as_right(_reset, val_history)] = val
# reset
if self.include_last:
val_init = tensordict_reset.get(in_key)[_reset.squeeze()]
val_pad = torch.zeros(
*val_init.shape, self.steps-1, dtype=val_init.dtype, device=val_init.device
)
val = torch.cat([val_pad, val_init.unsqueeze(-1)], dim=-1)
else:
val = 0.0
val_history[_reset.squeeze()] = val
tensordict_reset.set(out_key, val_history)
return tensordict_reset

0 comments on commit fe0dfdd

Please sign in to comment.