diff --git a/test/test_transforms.py b/test/test_transforms.py index 16e68a28bf4..9ca0af99b8a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -762,16 +762,16 @@ def test_transform_env_clone(self): ).all() assert cloned is not env.transform - @pytest.mark.parametrize("dim", [-2, -1]) + @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) + @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) def test_transform_model(self, dim, N, padding): # test equivalence between transforms within an env and within a rb key1 = "observation" keys = [key1] out_keys = ["out_" + key1] cat_frames = CatFrames( - N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding + N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding ) cat_frames2 = CatFrames( N=N, @@ -781,23 +781,22 @@ def test_transform_model(self, dim, N, padding): padding=padding, ) envbase = ContinuousActionVecMockEnv() - env = TransformedEnv( - envbase, - Compose( - UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames - ), - ) + env = TransformedEnv(envbase, cat_frames) + torch.manual_seed(10) env.set_seed(10) td = env.rollout(10) + torch.manual_seed(10) envbase.set_seed(10) tdbase = envbase.rollout(10) + tdbase0 = tdbase.clone() model = nn.Sequential(cat_frames2, nn.Identity()) model(tdbase) - assert (td == tdbase).all() + assert assert_allclose_td(td, tdbase) + with pytest.warns(UserWarning): tdbase0.names = None model(tdbase0) @@ -816,7 +815,7 @@ def test_transform_model(self, dim, N, padding): # check that swapping dims and names leads to same result assert_allclose_td(v1, v2.transpose(0, 1)) - @pytest.mark.parametrize("dim", [-2, -1]) + @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) @@ -826,7 +825,7 @@ def test_transform_rb(self, dim, N, padding, rbclass): keys = [key1] out_keys = ["out_" + key1] cat_frames = CatFrames( - N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding + N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding ) cat_frames2 = CatFrames( N=N, @@ -836,12 +835,7 @@ def test_transform_rb(self, dim, N, padding, rbclass): padding=padding, ) - env = TransformedEnv( - ContinuousActionVecMockEnv(), - Compose( - UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames - ), - ) + env = TransformedEnv(ContinuousActionVecMockEnv(), cat_frames) td = env.rollout(10) rb = rbclass(storage=LazyTensorStorage(20)) @@ -875,8 +869,8 @@ def test_transform_as_inverse(self, dim, N, padding): td = env1.rollout(rollout_length) transformed_td = cat_frames._inv_call(td) - assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim, N) - assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim, N) + assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim * N) + assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim * N) with pytest.raises( Exception, match="CatFrames as inverse is not supported as a transform for environments, only for replay buffers.", @@ -971,14 +965,48 @@ def test_transform_no_env(self, device, d, batch_size, dim, N): # we don't want the same tensor to be returned twice, but they're all copies of the same buffer assert v1 is not v2 + @pytest.mark.skipif(not _has_gym, reason="gym required for this test") + @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) + def test_tranform_offline_against_online(self, padding): + torch.manual_seed(0) + env = SerialEnv( + 3, + lambda: TransformedEnv( + GymEnv("CartPole-v1"), + CatFrames( + dim=-1, + N=5, + in_keys=["observation"], + out_keys=["observation_cat"], + padding=padding, + ), + ), + ) + env.set_seed(0) + + r = env.rollout(100, break_when_any_done=False) + + c = CatFrames( + dim=-1, + N=5, + in_keys=["observation", ("next", "observation")], + out_keys=["observation_cat2", ("next", "observation_cat2")], + padding=padding, + ) + + r2 = c(r) + + torch.testing.assert_close(r2["observation_cat2"], r2["observation_cat"]) + assert (r2["observation_cat2"] == r2["observation_cat"]).all() + + assert (r2["next", "observation_cat2"] == r2["next", "observation_cat"]).all() + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)]) @pytest.mark.parametrize("d", range(2, 3)) @pytest.mark.parametrize( "dim", - [ - -3, - ], + [-3], ) @pytest.mark.parametrize("N", [2, 4]) def test_transform_compose(self, device, d, batch_size, dim, N): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f94eb548d15..c6aecc74312 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2620,6 +2620,8 @@ class CatFrames(ObservationTransform): reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise. + done_key (NestedKey, optional): the done key to be used as partial + done indicator. Must be unique. If not provided, defaults to ``"done"``. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -2700,6 +2702,7 @@ def __init__( padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, + done_key: NestedKey | None = None, ): if in_keys is None: in_keys = IMAGE_KEYS @@ -2733,6 +2736,19 @@ def __init__( # keeps track of calls to _reset since it's only _call that will populate the buffer self.as_inverse = as_inverse self.reset_key = reset_key + self.done_key = done_key + + @property + def done_key(self): + done_key = self.__dict__.get("_done_key", None) + if done_key is None: + done_key = "done" + self._done_key = done_key + return done_key + + @done_key.setter + def done_key(self, value): + self._done_key = value @property def reset_key(self): @@ -2829,15 +2845,6 @@ def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase: # make linter happy. An exception has already been raised raise NotImplementedError - # # this duplicates the code below, but only for _reset values - # if _all: - # buffer.copy_(torch.roll(buffer_reset, shifts=-d, dims=dim)) - # buffer_reset = buffer - # else: - # buffer_reset = buffer[_reset] = torch.roll( - # buffer_reset, shifts=-d, dims=dim - # ) - # add new obs if self.dim < 0: n = buffer_reset.ndimension() + self.dim else: @@ -2906,69 +2913,136 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: if i != tensordict.ndim - 1: tensordict = tensordict.transpose(tensordict.ndim - 1, i) # first sort the in_keys with strings and non-strings - in_keys = list( - zip( - (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) - if isinstance(in_key, str) or len(in_key) == 1 - ) - ) - in_keys += list( - zip( - (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) - if not isinstance(in_key, str) and not len(in_key) == 1 + keys = [ + (in_key, out_key) + for in_key, out_key in zip(self.in_keys, self.out_keys) + if isinstance(in_key, str) + ] + keys += [ + (in_key, out_key) + for in_key, out_key in zip(self.in_keys, self.out_keys) + if not isinstance(in_key, str) + ] + + def unfold_done(done, N): + prefix = (slice(None),) * (tensordict.ndim - 1) + reset = torch.cat( + [ + torch.zeros_like(done[prefix + (slice(self.N - 1),)]), + torch.ones_like(done[prefix + (slice(1),)]), + done[prefix + (slice(None, -1),)], + ], + tensordict.ndim - 1, ) - ) - for in_key, out_key in zip(self.in_keys, self.out_keys): + reset_unfold = reset.unfold(tensordict.ndim - 1, self.N, 1) + reset_unfold_slice = reset_unfold[..., -1] + reset_unfold_list = [torch.zeros_like(reset_unfold_slice)] + for r in reversed(reset_unfold.unbind(-1)): + reset_unfold_list.append(r | reset_unfold_list[-1]) + reset_unfold_slice = reset_unfold_list[-1] + reset_unfold = torch.stack(list(reversed(reset_unfold_list))[1:], -1) + reset = reset[prefix + (slice(self.N - 1, None),)] + reset[prefix + (0,)] = 1 + return reset_unfold, reset + + done = tensordict.get(("next", self.done_key)) + done_mask, reset = unfold_done(done, self.N) + + for in_key, out_key in keys: # check if we have an obs in "next" that has already been processed. # If so, we must add an offset - data = tensordict.get(in_key) + data_orig = data = tensordict.get(in_key) + n_feat = data_orig.shape[data.ndim + self.dim] + first_val = None if isinstance(in_key, tuple) and in_key[0] == "next": # let's get the out_key we have already processed - prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]] - prev_val = tensordict.get(prev_out_key) - # the first item is located along `dim+1` at the last index of the - # first time index - idx = ( - [slice(None)] * (tensordict.ndim - 1) - + [0] - + [..., -1] - + [slice(None)] * (abs(self.dim) - 1) + prev_out_key = dict(zip(self.in_keys, self.out_keys)).get( + in_key[1], None ) - first_val = prev_val[tuple(idx)].unsqueeze(tensordict.ndim - 1) - data0 = [first_val] * (self.N - 1) - if self.padding == "constant": - data0 = [ - torch.full_like(elt, self.padding_value) for elt in data0[:-1] - ] + data0[-1:] - elif self.padding == "same": - pass - else: - # make linter happy. An exception has already been raised - raise NotImplementedError - elif self.padding == "same": - idx = [slice(None)] * (tensordict.ndim - 1) + [0] - data0 = [data[tuple(idx)].unsqueeze(tensordict.ndim - 1)] * (self.N - 1) - elif self.padding == "constant": - idx = [slice(None)] * (tensordict.ndim - 1) + [0] - data0 = [ - torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze( - tensordict.ndim - 1 + if prev_out_key is not None: + prev_val = tensordict.get(prev_out_key) + # n_feat = prev_val.shape[data.ndim + self.dim] // self.N + first_val = prev_val.unflatten( + data.ndim + self.dim, (self.N, n_feat) ) - ] * (self.N - 1) - else: - # make linter happy. An exception has already been raised - raise NotImplementedError + + idx = [slice(None)] * (tensordict.ndim - 1) + [0] + data0 = [ + torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze( + tensordict.ndim - 1 + ) + ] * (self.N - 1) data = torch.cat(data0 + [data], tensordict.ndim - 1) data = data.unfold(tensordict.ndim - 1, self.N, 1) + + # Place -1 dim at self.dim place before squashing + done_mask_expand = expand_as_right(done_mask, data) data = data.permute( - *range(0, data.ndim + self.dim), + *range(0, data.ndim + self.dim - 1), + -1, + *range(data.ndim + self.dim - 1, data.ndim - 1), + ) + done_mask_expand = done_mask_expand.permute( + *range(0, done_mask_expand.ndim + self.dim - 1), -1, - *range(data.ndim + self.dim, data.ndim - 1), + *range(done_mask_expand.ndim + self.dim - 1, done_mask_expand.ndim - 1), ) + if self.padding != "same": + data = torch.where(done_mask_expand, self.padding_value, data) + else: + # TODO: This is a pretty bad implementation, could be + # made more efficient but it works! + reset_vals = list(data_orig[reset.squeeze(-1)].unbind(0)) + j_ = float("inf") + reps = [] + d = data.ndim + self.dim - 1 + for j in done_mask_expand.sum(d).sum(d).view(-1) // n_feat: + if j > j_: + reset_vals = reset_vals[1:] + reps.extend([reset_vals[0]] * int(j)) + j_ = j + reps = torch.stack(reps) + data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1)) + + if first_val is not None: + # Aggregate reset along last dim + reset = reset.any(-1, True) + rexp = reset.expand(*reset.shape[:-1], n_feat) + rexp = torch.cat( + [ + torch.zeros_like( + data0[0].repeat_interleave( + len(data0), dim=tensordict.ndim - 1 + ), + dtype=torch.bool, + ), + rexp, + ], + tensordict.ndim - 1, + ) + rexp = rexp.unfold(tensordict.ndim - 1, self.N, 1) + rexp_orig = rexp + rexp = torch.cat([rexp[..., 1:], torch.zeros_like(rexp[..., -1:])], -1) + if self.padding == "same": + rexp_orig = rexp_orig.flip(-1).cumsum(-1).flip(-1).bool() + rexp = rexp.flip(-1).cumsum(-1).flip(-1).bool() + rexp_orig = torch.cat( + [torch.zeros_like(rexp_orig[..., -1:]), rexp_orig[..., 1:]], -1 + ) + rexp = rexp.permute( + *range(0, rexp.ndim + self.dim - 1), + -1, + *range(rexp.ndim + self.dim - 1, rexp.ndim - 1), + ) + rexp_orig = rexp_orig.permute( + *range(0, rexp_orig.ndim + self.dim - 1), + -1, + *range(rexp_orig.ndim + self.dim - 1, rexp_orig.ndim - 1), + ) + data[rexp] = first_val[rexp_orig] + data = data.flatten(data.ndim + self.dim - 1, data.ndim + self.dim) tensordict.set(out_key, data) if tensordict_orig is not tensordict: tensordict_orig = tensordict.transpose(tensordict.ndim - 1, i)