From f8ba8b5b0377fb117e767a827ce3a304a2396329 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 15 Feb 2024 14:54:18 +0000 Subject: [PATCH] init --- torchrl/envs/transforms/transforms.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 40df963ec5e..65599f1d51f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3667,13 +3667,18 @@ class CatTensors(Transform): the transform is used. This behaviour will only work if a parent is set. out_key (NestedKey): key of the resulting tensor. dim (int, optional): dimension along which the concatenation will occur. - Default is -1. + Default is ``-1``. + + Keyword Args: del_keys (bool, optional): if ``True``, the input values will be deleted after - concatenation. Default is True. + concatenation. Default is ``True``. unsqueeze_if_oor (bool, optional): if ``True``, CatTensor will check that the dimension indicated exist for the tensors to concatenate. If not, the tensors will be unsqueezed along that dimension. Default is ``False``. + sort (bool, optional): if ``True``, the keys will be sorted in the + transform. Otherwise, the order provided by the user will prevail. + Defaults to ``True``. Examples: >>> transform = CatTensors(in_keys=["key1", "key2"]) @@ -3698,8 +3703,10 @@ def __init__( in_keys: Sequence[NestedKey] | None = None, out_key: NestedKey = "observation_vector", dim: int = -1, + *, del_keys: bool = True, unsqueeze_if_oor: bool = False, + sort: bool = True, ): self._initialized = in_keys is not None if not self._initialized: @@ -3707,7 +3714,7 @@ def __init__( raise ValueError( "Lazy call to CatTensors is only supported when `dim=-1`." ) - else: + elif sort: in_keys = sorted(in_keys, key=_sort_keys) if not isinstance(out_key, (str, tuple)): raise Exception("CatTensors requires out_key to be of type NestedKey")