Skip to content

Commit

Permalink
[BugFix] Fix offline CatFrames (#1953)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 24, 2024
1 parent 249b811 commit 931f70a
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 81 deletions.
74 changes: 51 additions & 23 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,16 +762,16 @@ def test_transform_env_clone(self):
).all()
assert cloned is not env.transform

@pytest.mark.parametrize("dim", [-2, -1])
@pytest.mark.parametrize("dim", [-1])
@pytest.mark.parametrize("N", [3, 4])
@pytest.mark.parametrize("padding", ["same", "zeros", "constant"])
@pytest.mark.parametrize("padding", ["zeros", "constant", "same"])
def test_transform_model(self, dim, N, padding):
# test equivalence between transforms within an env and within a rb
key1 = "observation"
keys = [key1]
out_keys = ["out_" + key1]
cat_frames = CatFrames(
N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding
N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding
)
cat_frames2 = CatFrames(
N=N,
Expand All @@ -781,23 +781,22 @@ def test_transform_model(self, dim, N, padding):
padding=padding,
)
envbase = ContinuousActionVecMockEnv()
env = TransformedEnv(
envbase,
Compose(
UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames
),
)
env = TransformedEnv(envbase, cat_frames)

torch.manual_seed(10)
env.set_seed(10)
td = env.rollout(10)

torch.manual_seed(10)
envbase.set_seed(10)
tdbase = envbase.rollout(10)

tdbase0 = tdbase.clone()

model = nn.Sequential(cat_frames2, nn.Identity())
model(tdbase)
assert (td == tdbase).all()
assert assert_allclose_td(td, tdbase)

with pytest.warns(UserWarning):
tdbase0.names = None
model(tdbase0)
Expand All @@ -816,7 +815,7 @@ def test_transform_model(self, dim, N, padding):
# check that swapping dims and names leads to same result
assert_allclose_td(v1, v2.transpose(0, 1))

@pytest.mark.parametrize("dim", [-2, -1])
@pytest.mark.parametrize("dim", [-1])
@pytest.mark.parametrize("N", [3, 4])
@pytest.mark.parametrize("padding", ["same", "zeros", "constant"])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
Expand All @@ -826,7 +825,7 @@ def test_transform_rb(self, dim, N, padding, rbclass):
keys = [key1]
out_keys = ["out_" + key1]
cat_frames = CatFrames(
N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding
N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding
)
cat_frames2 = CatFrames(
N=N,
Expand All @@ -836,12 +835,7 @@ def test_transform_rb(self, dim, N, padding, rbclass):
padding=padding,
)

env = TransformedEnv(
ContinuousActionVecMockEnv(),
Compose(
UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames
),
)
env = TransformedEnv(ContinuousActionVecMockEnv(), cat_frames)
td = env.rollout(10)

rb = rbclass(storage=LazyTensorStorage(20))
Expand Down Expand Up @@ -875,8 +869,8 @@ def test_transform_as_inverse(self, dim, N, padding):
td = env1.rollout(rollout_length)

transformed_td = cat_frames._inv_call(td)
assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim, N)
assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim, N)
assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim * N)
assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim * N)
with pytest.raises(
Exception,
match="CatFrames as inverse is not supported as a transform for environments, only for replay buffers.",
Expand Down Expand Up @@ -971,14 +965,48 @@ def test_transform_no_env(self, device, d, batch_size, dim, N):
# we don't want the same tensor to be returned twice, but they're all copies of the same buffer
assert v1 is not v2

@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
@pytest.mark.parametrize("padding", ["zeros", "constant", "same"])
def test_tranform_offline_against_online(self, padding):
    """Compare CatFrames applied inside an env with CatFrames applied offline.

    A batch of 3 ``CartPole-v1`` environments is rolled out for 100 steps
    (without stopping when one of them is done), each with an in-env
    ``CatFrames(N=5, dim=-1)`` that writes ``"observation_cat"``. A second,
    stand-alone ``CatFrames`` with the same settings is then called on the
    collected rollout for both ``"observation"`` and
    ``("next", "observation")``; its outputs must match the in-env ones.
    """
    torch.manual_seed(0)

    def make_env():
        # Online variant: the transform is attached to every sub-environment.
        online_cat = CatFrames(
            dim=-1,
            N=5,
            in_keys=["observation"],
            out_keys=["observation_cat"],
            padding=padding,
        )
        return TransformedEnv(GymEnv("CartPole-v1"), online_cat)

    env = SerialEnv(3, make_env)
    env.set_seed(0)

    rollout = env.rollout(100, break_when_any_done=False)

    # Offline variant: same settings, applied to the already-collected data
    # for the root and the "next" observations.
    offline_cat = CatFrames(
        dim=-1,
        N=5,
        in_keys=["observation", ("next", "observation")],
        out_keys=["observation_cat2", ("next", "observation_cat2")],
        padding=padding,
    )

    result = offline_cat(rollout)

    # Root entries: close first (readable failure report), then exactly equal.
    torch.testing.assert_close(result["observation_cat2"], result["observation_cat"])
    assert (result["observation_cat2"] == result["observation_cat"]).all()

    # "next" entries must be exactly equal as well.
    assert (
        result["next", "observation_cat2"] == result["next", "observation_cat"]
    ).all()

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)])
@pytest.mark.parametrize("d", range(2, 3))
@pytest.mark.parametrize(
"dim",
[
-3,
],
[-3],
)
@pytest.mark.parametrize("N", [2, 4])
def test_transform_compose(self, device, d, batch_size, dim, N):
Expand Down
190 changes: 132 additions & 58 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2620,6 +2620,8 @@ class CatFrames(ObservationTransform):
reset indicator. Must be unique. If not provided, defaults to the
only reset key of the parent environment (if it has only one)
and raises an exception otherwise.
done_key (NestedKey, optional): the done key to be used as partial
done indicator. Must be unique. If not provided, defaults to ``"done"``.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
Expand Down Expand Up @@ -2700,6 +2702,7 @@ def __init__(
padding_value=0,
as_inverse=False,
reset_key: NestedKey | None = None,
done_key: NestedKey | None = None,
):
if in_keys is None:
in_keys = IMAGE_KEYS
Expand Down Expand Up @@ -2733,6 +2736,19 @@ def __init__(
# keeps track of calls to _reset since it's only _call that will populate the buffer
self.as_inverse = as_inverse
self.reset_key = reset_key
self.done_key = done_key

@property
def done_key(self):
    """Key of the done entry used by this transform.

    The value is looked up directly in the instance ``__dict__`` under
    ``"_done_key"``. When it is missing or ``None``, ``"done"`` is stored
    on the instance and returned.
    """
    key = self.__dict__.get("_done_key")
    if key is not None:
        return key
    # Nothing configured yet: fall back to the default and remember it.
    key = "done"
    self._done_key = key
    return key

@done_key.setter
def done_key(self, value):
    """Store ``value`` as the done key (``None`` re-enables the default)."""
    self._done_key = value

@property
def reset_key(self):
Expand Down Expand Up @@ -2829,15 +2845,6 @@ def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
# make linter happy. An exception has already been raised
raise NotImplementedError

# # this duplicates the code below, but only for _reset values
# if _all:
# buffer.copy_(torch.roll(buffer_reset, shifts=-d, dims=dim))
# buffer_reset = buffer
# else:
# buffer_reset = buffer[_reset] = torch.roll(
# buffer_reset, shifts=-d, dims=dim
# )
# add new obs
if self.dim < 0:
n = buffer_reset.ndimension() + self.dim
else:
Expand Down Expand Up @@ -2906,69 +2913,136 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
if i != tensordict.ndim - 1:
tensordict = tensordict.transpose(tensordict.ndim - 1, i)
# first sort the in_keys with strings and non-strings
in_keys = list(
zip(
(in_key, out_key)
for in_key, out_key in zip(self.in_keys, self.out_keys)
if isinstance(in_key, str) or len(in_key) == 1
)
)
in_keys += list(
zip(
(in_key, out_key)
for in_key, out_key in zip(self.in_keys, self.out_keys)
if not isinstance(in_key, str) and not len(in_key) == 1
keys = [
(in_key, out_key)
for in_key, out_key in zip(self.in_keys, self.out_keys)
if isinstance(in_key, str)
]
keys += [
(in_key, out_key)
for in_key, out_key in zip(self.in_keys, self.out_keys)
if not isinstance(in_key, str)
]

def unfold_done(done, N):
prefix = (slice(None),) * (tensordict.ndim - 1)
reset = torch.cat(
[
torch.zeros_like(done[prefix + (slice(self.N - 1),)]),
torch.ones_like(done[prefix + (slice(1),)]),
done[prefix + (slice(None, -1),)],
],
tensordict.ndim - 1,
)
)
for in_key, out_key in zip(self.in_keys, self.out_keys):
reset_unfold = reset.unfold(tensordict.ndim - 1, self.N, 1)
reset_unfold_slice = reset_unfold[..., -1]
reset_unfold_list = [torch.zeros_like(reset_unfold_slice)]
for r in reversed(reset_unfold.unbind(-1)):
reset_unfold_list.append(r | reset_unfold_list[-1])
reset_unfold_slice = reset_unfold_list[-1]
reset_unfold = torch.stack(list(reversed(reset_unfold_list))[1:], -1)
reset = reset[prefix + (slice(self.N - 1, None),)]
reset[prefix + (0,)] = 1
return reset_unfold, reset

done = tensordict.get(("next", self.done_key))
done_mask, reset = unfold_done(done, self.N)

for in_key, out_key in keys:
# check if we have an obs in "next" that has already been processed.
# If so, we must add an offset
data = tensordict.get(in_key)
data_orig = data = tensordict.get(in_key)
n_feat = data_orig.shape[data.ndim + self.dim]
first_val = None
if isinstance(in_key, tuple) and in_key[0] == "next":
# let's get the out_key we have already processed
prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]]
prev_val = tensordict.get(prev_out_key)
# the first item is located along `dim+1` at the last index of the
# first time index
idx = (
[slice(None)] * (tensordict.ndim - 1)
+ [0]
+ [..., -1]
+ [slice(None)] * (abs(self.dim) - 1)
prev_out_key = dict(zip(self.in_keys, self.out_keys)).get(
in_key[1], None
)
first_val = prev_val[tuple(idx)].unsqueeze(tensordict.ndim - 1)
data0 = [first_val] * (self.N - 1)
if self.padding == "constant":
data0 = [
torch.full_like(elt, self.padding_value) for elt in data0[:-1]
] + data0[-1:]
elif self.padding == "same":
pass
else:
# make linter happy. An exception has already been raised
raise NotImplementedError
elif self.padding == "same":
idx = [slice(None)] * (tensordict.ndim - 1) + [0]
data0 = [data[tuple(idx)].unsqueeze(tensordict.ndim - 1)] * (self.N - 1)
elif self.padding == "constant":
idx = [slice(None)] * (tensordict.ndim - 1) + [0]
data0 = [
torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze(
tensordict.ndim - 1
if prev_out_key is not None:
prev_val = tensordict.get(prev_out_key)
# n_feat = prev_val.shape[data.ndim + self.dim] // self.N
first_val = prev_val.unflatten(
data.ndim + self.dim, (self.N, n_feat)
)
] * (self.N - 1)
else:
# make linter happy. An exception has already been raised
raise NotImplementedError

idx = [slice(None)] * (tensordict.ndim - 1) + [0]
data0 = [
torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze(
tensordict.ndim - 1
)
] * (self.N - 1)

data = torch.cat(data0 + [data], tensordict.ndim - 1)

data = data.unfold(tensordict.ndim - 1, self.N, 1)

# Place -1 dim at self.dim place before squashing
done_mask_expand = expand_as_right(done_mask, data)
data = data.permute(
*range(0, data.ndim + self.dim),
*range(0, data.ndim + self.dim - 1),
-1,
*range(data.ndim + self.dim - 1, data.ndim - 1),
)
done_mask_expand = done_mask_expand.permute(
*range(0, done_mask_expand.ndim + self.dim - 1),
-1,
*range(data.ndim + self.dim, data.ndim - 1),
*range(done_mask_expand.ndim + self.dim - 1, done_mask_expand.ndim - 1),
)
if self.padding != "same":
data = torch.where(done_mask_expand, self.padding_value, data)
else:
# TODO: This is a pretty bad implementation, could be
# made more efficient but it works!
reset_vals = list(data_orig[reset.squeeze(-1)].unbind(0))
j_ = float("inf")
reps = []
d = data.ndim + self.dim - 1
for j in done_mask_expand.sum(d).sum(d).view(-1) // n_feat:
if j > j_:
reset_vals = reset_vals[1:]
reps.extend([reset_vals[0]] * int(j))
j_ = j
reps = torch.stack(reps)
data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1))

if first_val is not None:
# Aggregate reset along last dim
reset = reset.any(-1, True)
rexp = reset.expand(*reset.shape[:-1], n_feat)
rexp = torch.cat(
[
torch.zeros_like(
data0[0].repeat_interleave(
len(data0), dim=tensordict.ndim - 1
),
dtype=torch.bool,
),
rexp,
],
tensordict.ndim - 1,
)
rexp = rexp.unfold(tensordict.ndim - 1, self.N, 1)
rexp_orig = rexp
rexp = torch.cat([rexp[..., 1:], torch.zeros_like(rexp[..., -1:])], -1)
if self.padding == "same":
rexp_orig = rexp_orig.flip(-1).cumsum(-1).flip(-1).bool()
rexp = rexp.flip(-1).cumsum(-1).flip(-1).bool()
rexp_orig = torch.cat(
[torch.zeros_like(rexp_orig[..., -1:]), rexp_orig[..., 1:]], -1
)
rexp = rexp.permute(
*range(0, rexp.ndim + self.dim - 1),
-1,
*range(rexp.ndim + self.dim - 1, rexp.ndim - 1),
)
rexp_orig = rexp_orig.permute(
*range(0, rexp_orig.ndim + self.dim - 1),
-1,
*range(rexp_orig.ndim + self.dim - 1, rexp_orig.ndim - 1),
)
data[rexp] = first_val[rexp_orig]
data = data.flatten(data.ndim + self.dim - 1, data.ndim + self.dim)
tensordict.set(out_key, data)
if tensordict_orig is not tensordict:
tensordict_orig = tensordict.transpose(tensordict.ndim - 1, i)
Expand Down

0 comments on commit 931f70a

Please sign in to comment.