diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 2ef74bb4521..b701b2f6bf7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4257,6 +4257,14 @@ def __new__(cls, *args, **kwargs): cls._locked = False return super().__new__(cls) + @property + def batch_size(self): + return self._shape + + @batch_size.setter + def batch_size(self, value: torch.Size): + self._shape = value + @property def shape(self): return self._shape @@ -4278,8 +4286,22 @@ def shape(self, value: torch.Size): ) self._shape = _size(value) - def is_empty(self): - """Whether the composite spec contains specs or not.""" + def is_empty(self, recurse: bool = False): + """Whether the composite spec contains specs or not. + + Args: + recurse (bool): whether to recursively assess if the spec is empty. + If ``True``, will return ``True`` if there are no leaves. If ``False`` + (default) will return whether there is any spec defined at the root level. + + """ + if recurse: + for spec in self._specs.values(): + if spec is None: + continue + if isinstance(spec, Composite) and spec.is_empty(recurse=True): + continue + return False return len(self._specs) == 0 @property @@ -4289,6 +4311,34 @@ def ndim(self): def ndimension(self): return len(self.shape) + def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any: + """Removes and returns the value associated with the specified key from the composite spec. + + This method searches for the given key in the composite spec, removes it, and returns its associated value. + If the key is not found, it returns the provided default value if specified, otherwise raises a `KeyError`. + + Args: + key (NestedKey): + The key to be removed from the composite spec. It can be a single key or a nested key. + default (Any, optional): + The value to return if the specified key is not found in the composite spec. + If not provided and the key is not found, a `KeyError` is raised. + + Returns: + Any: The value associated with the specified key that was removed from the composite spec. + + Raises: + KeyError: If the specified key is not found in the composite spec and no default value is provided. + """ + key = unravel_key(key) + if key in self.keys(True, True): + result = self[key] + del self[key] + return result + elif default is not NO_DEFAULT: + return default + raise KeyError(f"{key} not found in composite spec.") + def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.")