Skip to content

Commit

Permalink
rename to CatFrames
Browse files Browse the repository at this point in the history
  • Loading branch information
btx0424 committed Feb 7, 2024
1 parent 02920ed commit 6c1109e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 42 deletions.
2 changes: 1 addition & 1 deletion torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
VecNorm,
VIPRewardTransform,
VIPTransform,
History
StackFrames
)
from .utils import (
check_env_specs,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
UnsqueezeTransform,
VecGymEnvTransform,
VecNorm,
History
StackFrames
)
from .vc1 import VC1Transform
from .vip import VIPRewardTransform, VIPTransform
97 changes: 57 additions & 40 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7491,38 +7491,48 @@ def _reset(
forward = _call


class History(Transform):
"""Records and maintains a history of specified observations.
class StackFrames(Transform):
"""Stacks successive observation frames into a single tensor.
This transform tracks the history of selected observations over a specified number of steps.
It takes a list of observation keys and records their values, maintaining a history buffer
of length 'steps'. The shape of each recorded observation in the output is extended to
`[*shape, steps]`, where `shape` is the original shape of the observation. The most recent
observation is indexed at `[..., -1]`.
This transform stacks the history of selected observations over a specified number of steps,
which can be useful for inferring the state in a partially observable environment. The shape
of each stacked observation in the output is extended to `[*shape, N]`, where `shape` is
the original shape of the observation. The most recent observation is indexed at `[..., -1]`.
Note that this transform is stateless. Also see :class:`CatFrames`.
Args:
N (int): Number of steps for which the observation history is maintained.
in_keys (list of NestedKeys, optional): Keys of the observations in the environment's
observation spec that need to be recorded.
out_keys (list of NestedKeys, optional): Keys under which the recorded observation histories
will be stored in the output. Defaults to `f"{in_key}_h"` for each key in `in_keys`.
steps (int): Number of steps for which the observation history is maintained.
include_last (bool): If True, includes the current step's observation in the history.
padding (str, optional): the padding method. One of ``"same"`` or ``"constant"``.
Defaults to ``"same"``, i.e. the first observation after a reset is repeated to fill the stack.
padding_value (float, optional): the value to use for padding if ``padding="constant"``.
Defaults to 0.
Examples:
>>> from torchrl.envs.transforms import TransformedEnv, History
>>> from torchrl.envs.transforms import TransformedEnv, StackFrames
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("CartPole-v1"), History(["observation"]))
>>> env = TransformedEnv(GymEnv("CartPole-v1"), StackFrames(N=16, in_keys=["observation"]))
>>> td = env.reset()
>>> print(td["observation_h"].shape)
torch.Size([4, 16])
"""

ACCEPTED_PADDING = {"same", "constant", "zeros"}

def __init__(
self,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey] = None,
steps: int = 16,
include_last: bool = True,
N: int = 1,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
padding="same",
padding_value=0,
):
if in_keys is None:
in_keys = ["observation"]
if out_keys is None:
out_keys = [
f"{key}_h" if isinstance(key, str) else key[:-1] + (f"{key[-1]}_h",)
Expand All @@ -7533,34 +7543,42 @@ def __init__(
f"out_keys {out_keys} cannot duplicate with in_keys {in_keys}"
)
super().__init__(in_keys=in_keys, out_keys=out_keys)
self.steps = steps
self.include_last = include_last
self.N = N
if padding not in self.ACCEPTED_PADDING:
raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}")
if padding == "zeros":
warnings.warn(
"Padding option 'zeros' will be deprecated in v0.4.0. "
"Please use 'constant' padding with padding_value 0 instead.",
category=DeprecationWarning,
)
padding = "constant"
padding_value = 0
self.padding = padding
self.padding_value = padding_value

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
    """Add the stacked entries to the observation spec.

    For every ``(in_key, out_key)`` pair whose ``in_key`` is present in
    ``observation_spec``, a copy of the input spec with one extra trailing
    dimension of size ``self.N`` is stored under ``out_key``. Pairs whose
    ``in_key`` is absent are skipped.

    Args:
        observation_spec: the spec to extend; it is modified in place.

    Returns:
        The same ``observation_spec`` object.
    """
    for in_key, out_key in zip(self.in_keys, self.out_keys):
        # Tuple keys address nested entries, so nested keys are only
        # enumerated when the key itself is a tuple.
        is_tuple = isinstance(in_key, tuple)
        if in_key in observation_spec.keys(include_nested=is_tuple):
            spec = observation_spec[in_key]
            # Expand exactly once, using the stack length ``N``; the stale
            # ``self.steps`` expansion from the pre-rename code is dropped.
            observation_spec[out_key] = spec.unsqueeze(-1).expand(*spec.shape, self.N)
    return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
    """Declare the stacked entries under ``input_spec["full_state_spec"]``.

    For every ``(in_key, out_key)`` pair, the spec of ``in_key`` is read from
    ``self.parent.observation_spec``, given one extra trailing dimension of
    size ``self.N`` and stored under ``out_key`` in the state spec.

    Args:
        input_spec: the spec to extend; its ``"full_state_spec"`` entry is
            updated and written back.

    Returns:
        The same ``input_spec`` object.
    """
    state_spec = input_spec["full_state_spec"]
    for in_key, out_key in zip(self.in_keys, self.out_keys):
        spec = self.parent.observation_spec[in_key]
        # Single assignment using ``N``; the stale ``self.steps`` assignment
        # from the pre-rename code is dropped.
        state_spec[out_key] = spec.unsqueeze(-1).expand(*spec.shape, self.N)
    input_spec["full_state_spec"] = state_spec
    return input_spec

def _step(self, tensordict: TensorDictBase, next_tensordict: TensorDictBase) -> TensorDictBase:
    """Roll each stacked observation forward by one step.

    For every ``(in_key, out_key)`` pair, the previous stack is read from
    ``tensordict[out_key]``, its oldest slice (index 0 of the last dimension)
    is dropped, and the new observation from ``next_tensordict[in_key]`` is
    appended at ``[..., -1]``. The result is written to
    ``next_tensordict[out_key]``.

    Args:
        tensordict: data from before the step; must hold the previous stack
            under each ``out_key``.
        next_tensordict: data produced by the step; must hold the new
            observation under each ``in_key``.

    Returns:
        ``next_tensordict``, updated in place.
    """
    for in_key, out_key in zip(self.in_keys, self.out_keys):
        current = next_tensordict.get(in_key)
        prev_stacked = tensordict.get(out_key)
        # Sliding window: discard the oldest frame, append the newest last.
        val_next = torch.cat([prev_stacked[..., 1:], current.unsqueeze(-1)], dim=-1)
        next_tensordict.set(out_key, val_next)
    return next_tensordict

Expand All @@ -7572,21 +7590,20 @@ def _reset(
_reset = torch.ones(tensordict.batch_size, dtype=bool, device=tensordict.device)
# Visible part of `_reset`: rebuild the stacked entry of every key for the
# environments selected by the boolean `_reset` mask.
for in_key, out_key in zip(self.in_keys, self.out_keys):
    current = tensordict_reset.get(in_key)
    # Stack carried over from before the reset, if there is one.
    prev_stacked = tensordict.get(out_key, None)
    if prev_stacked is None:
        # No stack yet: start from a zero-filled buffer built from the spec.
        spec = self.parent.full_observation_spec[in_key]
        stacked = spec.unsqueeze(-1).expand(*spec.shape, self.N).zero()
    else:
        # Clone so the masked assignment below does not write into the
        # tensor held by the incoming tensordict.
        stacked = prev_stacked.clone()
    reset_mask = _reset.squeeze()
    val = current[reset_mask]
    # The first N-1 slots are padding; the fresh observation goes last.
    if self.padding == "same":
        padding_val = val.unsqueeze(-1).expand(*val.shape, self.N - 1)
    elif self.padding == "constant":
        shape = val.shape + (self.N - 1,)
        # Must be an assignment: the original `==` left `padding_val` unbound.
        padding_val = torch.full(shape, self.padding_value, dtype=val.dtype, device=val.device)
    else:
        raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}")
    stacked[reset_mask] = torch.cat([padding_val, val.unsqueeze(-1)], dim=-1)
    tensordict_reset.set(out_key, stacked)
return tensordict_reset

0 comments on commit 6c1109e

Please sign in to comment.