diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b701b2f6bf7..32e61bc3ede 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4339,6 +4339,33 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any: return default raise KeyError(f"{key} not found in composite spec.") + def separates(self, *keys: NestedKey, default: Any = None) -> Composite: + """Splits the composite spec by extracting specified keys and their associated values into a new composite spec. + + This method iterates over the provided keys, removes them from the current composite spec, and adds them to a new + composite spec. If a key is not found, the specified default value is used. The new composite spec is returned. + + Args: + *keys (NestedKey): + One or more keys to be extracted from the composite spec. Each key can be a single key or a nested key. + default (Any, optional): + The value to use if a specified key is not found in the composite spec. Defaults to `None`. + + Returns: + Composite: A new composite spec containing the extracted keys and their associated values. + + Note: + If none of the specified keys are found, the method returns `None`. + """ + out = None + for key in keys: + result = self.pop(key, default=default) + if result is not None: + if out is None: + out = Composite(batch_size=self.batch_size, device=self.device) + out[key] = result + return out + def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.")