Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 15, 2024
1 parent 45764b5 commit f8ba8b5
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3667,13 +3667,18 @@ class CatTensors(Transform):
the transform is used. This behaviour will only work if a parent is set.
out_key (NestedKey): key of the resulting tensor.
dim (int, optional): dimension along which the concatenation will occur.
Default is -1.
Default is ``-1``.
Keyword Args:
del_keys (bool, optional): if ``True``, the input values will be deleted after
concatenation. Default is True.
concatenation. Default is ``True``.
unsqueeze_if_oor (bool, optional): if ``True``, CatTensor will check that
the dimension indicated exist for the tensors to concatenate. If not,
the tensors will be unsqueezed along that dimension.
Default is ``False``.
sort (bool, optional): if ``True``, the keys will be sorted in the
transform. Otherwise, the order provided by the user will prevail.
Defaults to ``True``.
Examples:
>>> transform = CatTensors(in_keys=["key1", "key2"])
Expand All @@ -3698,16 +3703,18 @@ def __init__(
in_keys: Sequence[NestedKey] | None = None,
out_key: NestedKey = "observation_vector",
dim: int = -1,
*,
del_keys: bool = True,
unsqueeze_if_oor: bool = False,
sort: bool = True,
):
self._initialized = in_keys is not None
if not self._initialized:
if dim != -1:
raise ValueError(
"Lazy call to CatTensors is only supported when `dim=-1`."
)
else:
elif sort:
in_keys = sorted(in_keys, key=_sort_keys)
if not isinstance(out_key, (str, tuple)):
raise Exception("CatTensors requires out_key to be of type NestedKey")
Expand Down

0 comments on commit f8ba8b5

Please sign in to comment.