Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 23, 2024
1 parent b4b5944 commit d14f38f
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4257,6 +4257,14 @@ def __new__(cls, *args, **kwargs):
cls._locked = False
return super().__new__(cls)

@property
def batch_size(self):
return self._shape

@batch_size.setter
def batch_size(self, value: torch.Size):
self._shape = value

@property
def shape(self):
return self._shape
Expand All @@ -4278,8 +4286,22 @@ def shape(self, value: torch.Size):
)
self._shape = _size(value)

def is_empty(self):
"""Whether the composite spec contains specs or not."""
def is_empty(self, recurse: bool = False):
"""Whether the composite spec contains specs or not.
Args:
recurse (bool): whether to recursively assess if the spec is empty.
If ``True``, will return ``True`` if there are no leaves. If ``False``
(default) will return whether there is any spec defined at the root level.
"""
if recurse:
for spec in self._specs.values():
if spec is None:
continue
if isinstance(spec, Composite) and spec.is_empty(recurse=True):
continue
return False
return len(self._specs) == 0

@property
Expand All @@ -4289,6 +4311,34 @@ def ndim(self):
def ndimension(self):
return len(self.shape)

def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any:
"""Removes and returns the value associated with the specified key from the composite spec.
This method searches for the given key in the composite spec, removes it, and returns its associated value.
If the key is not found, it returns the provided default value if specified, otherwise raises a `KeyError`.
Args:
key (NestedKey):
The key to be removed from the composite spec. It can be a single key or a nested key.
default (Any, optional):
The value to return if the specified key is not found in the composite spec.
If not provided and the key is not found, a `KeyError` is raised.
Returns:
Any: The value associated with the specified key that was removed from the composite spec.
Raises:
KeyError: If the specified key is not found in the composite spec and no default value is provided.
"""
key = unravel_key(key)
if key in self.keys(True, True):
result = self[key]
del self[key]
return result
elif default is not NO_DEFAULT:
return default
raise KeyError(f"{key} not found in composite spec.")

def set(self, name, spec):
if self.locked:
raise RuntimeError("Cannot modify a locked Composite.")
Expand Down

0 comments on commit d14f38f

Please sign in to comment.