From 8d16c12bd783c4e36dc24dca56c7cc24f115d37c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:38 +0100 Subject: [PATCH] [Feature] Composite.pop ghstack-source-id: 64d5bd736657ef56e37d57726dfcfd25b16b699f Pull Request resolved: https://github.com/pytorch/rl/pull/2598 --- torchrl/data/tensor_specs.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 5404beb0ec0..b701b2f6bf7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4311,6 +4311,34 @@ def ndim(self): def ndimension(self): return len(self.shape) + def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any: + """Removes and returns the value associated with the specified key from the composite spec. + + This method searches for the given key in the composite spec, removes it, and returns its associated value. + If the key is not found, it returns the provided default value if specified, otherwise raises a `KeyError`. + + Args: + key (NestedKey): + The key to be removed from the composite spec. It can be a single key or a nested key. + default (Any, optional): + The value to return if the specified key is not found in the composite spec. + If not provided and the key is not found, a `KeyError` is raised. + + Returns: + Any: The value associated with the specified key that was removed from the composite spec. + + Raises: + KeyError: If the specified key is not found in the composite spec and no default value is provided. + """ + key = unravel_key(key) + if key in self.keys(True, True): + result = self[key] + del self[key] + return result + elif default is not NO_DEFAULT: + return default + raise KeyError(f"{key} not found in composite spec.") + def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.")