From bef503fedc26193e2c4111ef1833d8d4b6d062d8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 4 Aug 2024 17:09:26 -0400 Subject: [PATCH] [Feature] TensorDictMap ghstack-source-id: 3b0df1e0dd416a87fbbf65725c9271e579ccd49b Pull Request resolved: https://github.com/pytorch/rl/pull/2306 --- test/test_storage_map.py | 125 ++++++++- torchrl/data/__init__.py | 10 +- torchrl/data/map/__init__.py | 1 + torchrl/data/map/tdstorage.py | 323 ++++++++++++++++++++++++ torchrl/data/replay_buffers/storages.py | 71 +++++- 5 files changed, 517 insertions(+), 13 deletions(-) create mode 100644 torchrl/data/map/tdstorage.py diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 846ebba9e8e..c5f748e2e9d 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import functools import importlib.util import pytest @@ -10,7 +11,15 @@ import torch from tensordict import TensorDict -from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash +from torchrl.data import LazyTensorStorage, ListStorage +from torchrl.data.map import ( + BinaryToDecimal, + QueryModule, + RandomProjectionHash, + SipHash, + TensorDictMap, +) +from torchrl.envs import GymEnv _has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec( "gym", None @@ -114,6 +123,120 @@ def test_query(self, clone, index_key): for i in range(1, 3): assert res[index_key][i].item() != res[index_key][i + 1].item() + def test_query_module(self): + query_module = QueryModule( + in_keys=["key1", "key2"], + index_key="index", + hash_module=SipHash(), + ) + + embedding_storage = LazyTensorStorage(23) + + tensor_dict_storage = TensorDictMap( + query_module=query_module, + storage=embedding_storage, + ) + + index = TensorDict( + { + "key1": torch.Tensor([[-1], [1], [3], [-3]]), + "key2": torch.Tensor([[0], [2], [4], [-4]]), + }, + batch_size=(4,), + ) + + value = TensorDict( + {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ) + + tensor_dict_storage[index] = value + assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 + + new_index = index.clone(True) + new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + retrieve_value = tensor_dict_storage[new_index] + + assert (retrieve_value["index"] == value["index"]).all() + + +class TesttTensorDictMap: + @pytest.mark.parametrize( + "storage_type", + [ + functools.partial(ListStorage, 1000), + functools.partial(LazyTensorStorage, 1000), + ], + ) + def test_map(self, storage_type): + query_module = QueryModule( + in_keys=["key1", "key2"], + index_key="index", + hash_module=SipHash(), + ) + + embedding_storage = storage_type() + + tensor_dict_storage = TensorDictMap( + query_module=query_module, + storage=embedding_storage, + ) + + index = TensorDict( + { + "key1": torch.Tensor([[-1], [1], [3], [-3]]), + "key2": torch.Tensor([[0], [2], [4], [-4]]), + }, + batch_size=(4,), + ) + + value = TensorDict( + {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ) + assert not hasattr(tensor_dict_storage, "out_keys") + + tensor_dict_storage[index] = value + if isinstance(embedding_storage, LazyTensorStorage): + assert hasattr(tensor_dict_storage, "out_keys") + else: + assert not hasattr(tensor_dict_storage, "out_keys") + assert tensor_dict_storage._has_lazy_out_keys() + assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 + + new_index = index.clone(True) + new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + retrieve_value = tensor_dict_storage[new_index] + + assert (retrieve_value["index"] == value["index"]).all() + + @pytest.mark.skipif(not _has_gym, reason="gym not installed") + def test_map_rollout(self): + torch.manual_seed(0) + env = GymEnv("CartPole-v1") + env.set_seed(0) + rollout = env.rollout(100) + source, dest = rollout.exclude("next"), rollout.get("next") + storage = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=["observation", "action"], + ) + storage_indices = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=["observation"], + out_keys=["_index"], + ) + # maps the (obs, action) tuple to a corresponding next state + storage[source] = dest + storage_indices[source] = source + contains = storage.contains(source) + assert len(contains) == rollout.shape[-1] + assert contains.all() + contains = storage.contains(torch.cat([source, source + 1])) + assert len(contains) == rollout.shape[-1] * 2 + assert contains[: rollout.shape[-1]].all() + assert not contains[rollout.shape[-1] :].any() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index c894c15724b..e8b72dad4bb 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -3,7 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash +from .map import ( + BinaryToDecimal, + HashToInt, + QueryModule, + RandomProjectionHash, + SipHash, + TensorDictMap, + TensorMap, +) from .postprocs import MultiStep from .replay_buffers import ( Flat2TED, diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py index 96ca381365e..1fba910884d 100644 --- a/torchrl/data/map/__init__.py +++ b/torchrl/data/map/__init__.py @@ -5,3 +5,4 @@ from .hash import BinaryToDecimal, RandomProjectionHash, SipHash from .query import HashToInt, QueryModule +from .tdstorage import TensorDictMap, TensorMap diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py new file mode 100644 index 00000000000..77dbc229b9e --- /dev/null +++ b/torchrl/data/map/tdstorage.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import abc +import functools +from abc import abstractmethod +from typing import Any, Callable, Dict, Generic, List, TypeVar + +import torch +from tensordict import is_tensor_collection, NestedKey, TensorDictBase +from tensordict.nn.common import TensorDictModuleBase +from torchrl.data.map.hash import RandomProjectionHash, SipHash +from torchrl.data.map.query import QueryModule +from torchrl.data.replay_buffers.storages import ( + _get_default_collate, + LazyTensorStorage, + TensorStorage, +) + +K = TypeVar("K") +V = TypeVar("V") + + +class TensorMap(abc.ABC, Generic[K, V]): + """An Abstraction for implementing different storage. + + This class is for internal use, please use derived classes instead. + """ + + @abstractmethod + def clear(self) -> None: + raise NotImplementedError + + @abstractmethod + def __getitem__(self, item: K) -> V: + raise NotImplementedError + + @abstractmethod + def __setitem__(self, key: K, value: V) -> None: + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def contains(self, item: K) -> torch.Tensor: + raise NotImplementedError + + def __contains__(self, item): + return self.contains(item) + + +class TensorDictMap( + TensorDictModuleBase, TensorMap[TensorDictModuleBase, TensorDictModuleBase] +): + """A Map-Storage for TensorDict. + + This module resembles a storage. It takes a tensordict as its input and + returns another tensordict as output similar to TensorDictModuleBase. However, + it provides additional functionality like python map: + + Keyword Args: + query_module (TensorDictModuleBase): a query module, typically an instance of + :class:`~tensordict.nn.QueryModule`, used to map a set of tensordict + entries to a hash key. + storage (Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]]): + a dictionary representing the map from an index key to a tensor storage. + collate_fn (callable, optional): a function to use to collate samples from the + storage. Defaults to a custom value for each known storage type (stack for + :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage` + subtypes and others). + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from typing import cast + >>> from torchrl.data import LazyTensorStorage + >>> query_module = QueryModule( + ... in_keys=["key1", "key2"], + ... index_key="index", + ... ) + >>> embedding_storage = LazyTensorStorage(1000) + >>> tensor_dict_storage = TensorDictMap( + ... query_module=query_module, + ... storage={"out": embedding_storage}, + ... ) + >>> index = TensorDict( + ... { + ... "key1": torch.Tensor([[-1], [1], [3], [-3]]), + ... "key2": torch.Tensor([[0], [2], [4], [-4]]), + ... }, + ... batch_size=(4,), + ... ) + >>> value = TensorDict( + ... {"out": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ... ) + >>> tensor_dict_storage[index] = value + >>> tensor_dict_storage[index] + TensorDict( + fields={ + out: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False) + >>> assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 + >>> new_index = index.clone(True) + >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + >>> retrieve_value = tensor_dict_storage[new_index] + >>> assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all() + """ + + def __init__( + self, + *, + query_module: QueryModule, + storage: Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]], + collate_fn: Callable[[Any], Any] | None = None, + out_keys: List[NestedKey] | None = None, + write_fn: Callable[[Any, Any], Any] | None = None, + ): + super().__init__() + + self.in_keys = query_module.in_keys + if out_keys is not None: + self.out_keys = out_keys + assert not self._has_lazy_out_keys() + + self.query_module = query_module + self.index_key = query_module.index_key + self.storage = storage + self.batch_added = False + if collate_fn is None: + collate_fn = _get_default_collate(self.storage) + self.collate_fn = collate_fn + self.write_fn = write_fn + + @property + def out_keys(self) -> List[NestedKey]: + out_keys = self.__dict__.get("_out_keys_and_lazy") + if out_keys is not None: + return out_keys[0] + storage = self.storage + if isinstance(storage, TensorStorage) and is_tensor_collection( + storage._storage + ): + out_keys = list(storage._storage.keys(True, True)) + self._out_keys_and_lazy = (out_keys, True) + return self.out_keys + raise AttributeError( + f"No out-keys found in the storage of type {type(storage)}" + ) + + @out_keys.setter + def out_keys(self, value): + self._out_keys_and_lazy = (value, False) + + def _has_lazy_out_keys(self): + _out_keys_and_lazy = self.__dict__.get("_out_keys_and_lazy") + if _out_keys_and_lazy is None: + return True + return self._out_keys_and_lazy[1] + + @classmethod + def from_tensordict_pair( + cls, + source, + dest, + in_keys: List[NestedKey], + out_keys: List[NestedKey] | None = None, + storage_constructor: type | None = None, + hash_module: Callable | None = None, + collate_fn: Callable[[Any], Any] | None = None, + write_fn: Callable[[Any, Any], Any] | None = None, + ): + """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. + + Args: + source (TensorDict): An example of source tensordict, used as index in the storage. + dest (TensorDict): An example of dest tensordict, used as data in the storage. + in_keys (List[NestedKey]): a list of keys to use in the map. + out_keys (List[NestedKey]): a list of keys to return in the output tensordict. + All keys absent from out_keys, even if present in ``dest``, will not be stored + in the storage. Defaults to ``None`` (all keys are registered). + storage_constructor (type, optional): a type of tensor storage. + Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`. + Other options include :class:`~tensordict.nn.storage.FixedStorage`. + hash_module (Callable, optional): a hash function to use in the :class:`~tensordict.nn.storage.QueryModule`. + Defaults to :class:`SipHash` for low-dimensional inputs, and :class:`~tensordict.nn.storage.RandomProjectionHash` + for larger inputs. + collate_fn (callable, optional): a function to use to collate samples from the + storage. Defaults to a custom value for each known storage type (stack for + :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage` + subtypes and others). + + Examples: + >>> # The following example requires torchrl and gymnasium to be installed + >>> from torchrl.envs import GymEnv + >>> torch.manual_seed(0) + >>> env = GymEnv("CartPole-v1") + >>> env.set_seed(0) + >>> rollout = env.rollout(100) + >>> source, dest = rollout.exclude("next"), rollout.get("next") + >>> storage = TensorDictMap.from_tensordict_pair( + ... source, dest, + ... in_keys=["observation", "action"], + ... ) + >>> # maps the (obs, action) tuple to a corresponding next state + >>> storage[source] = dest + >>> print(source["_index"]) + tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]) + >>> storage[source] + TensorDict( + fields={ + done: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([14, 4]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([14]), + device=None, + is_shared=False) + + """ + # Build query module + if hash_module is None: + # Count the features, if they're greater than RandomProjectionHash._N_COMPONENTS_DEFAULT + # use that module to project them to that dimensionality. + n_feat = 0 + hash_module = [] + for in_key in in_keys: + n_feat = source[in_key].shape[-1] + if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT: + _hash_module = RandomProjectionHash() + else: + _hash_module = SipHash() + hash_module.append(_hash_module) + query_module = QueryModule(in_keys, hash_module=hash_module) + + # Build key_to_storage + if storage_constructor is None: + storage_constructor = functools.partial(LazyTensorStorage, 1000) + storage = storage_constructor() + result = cls( + query_module=query_module, + storage=storage, + collate_fn=collate_fn, + out_keys=out_keys, + write_fn=write_fn, + ) + return result + + def clear(self) -> None: + for mem in self.storage.values(): + mem.clear() + + def _to_index(self, item: TensorDictBase, extend: bool) -> torch.Tensor: + item = self.query_module(item, extend=extend) + return item[self.index_key] + + def _maybe_add_batch( + self, item: TensorDictBase, value: TensorDictBase | None + ) -> TensorDictBase: + self.batch_added = False + if len(item.batch_size) == 0: + self.batch_added = True + + item = item.unsqueeze(dim=0) + if value is not None: + value = value.unsqueeze(dim=0) + + return item, value + + def _maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase: + if self.batch_added: + item = item.squeeze(dim=0) + return item + + def __getitem__(self, item: TensorDictBase) -> TensorDictBase: + item, _ = self._maybe_add_batch(item, None) + + index = self._to_index(item, extend=False) + + res = self.storage[index] + res = self.collate_fn(res) + res = self._maybe_remove_batch(res) + return res + + def __setitem__(self, item: TensorDictBase, value: TensorDictBase): + if not self._has_lazy_out_keys(): + # TODO: make this work with pytrees and avoid calling select if keys match + value = value.select(*self.out_keys, strict=False) + if self.write_fn is not None: + if len(self): + modifiable = self.contains(item) + if modifiable.any(): + to_modify = (value[modifiable], self[item[modifiable]]) + v1 = self.write_fn(*to_modify) + result = value.empty() + result[modifiable] = v1 + result[~modifiable] = self.write_fn(value[~modifiable]) + value = result + else: + value = self.write_fn(value) + else: + value = self.write_fn(value) + item, value = self._maybe_add_batch(item, value) + index = self._to_index(item, extend=True) + self.storage.set(index, value) + + def __len__(self): + return len(self.storage) + + def contains(self, item: TensorDictBase) -> torch.Tensor: + item, _ = self._maybe_add_batch(item, None) + index = self._to_index(item, extend=False) + + res = self.storage.contains(index) + res = self._maybe_remove_batch(res) + return res diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 58b1729296d..eb399fbc029 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -185,6 +185,13 @@ def load(self, *args, **kwargs): """Alias for :meth:`~.loads`.""" return self.loads(*args, **kwargs) + def __contains__(self, item): + return self.contains(item) + + @abc.abstractmethod + def contains(self, item): + ... + class ListStorage(Storage): """A storage stored in a list. @@ -194,13 +201,16 @@ class ListStorage(Storage): (like lists, tuples, tensors or tensordicts with non-empty batch-size). Args: - max_size (int): the maximum number of elements stored in the storage. + max_size (int, optional): the maximum number of elements stored in the storage. + If not provided, an unlimited storage is created. """ _default_checkpointer = ListStorageCheckpointer - def __init__(self, max_size: int): + def __init__(self, max_size: int | None = None): + if max_size is None: + max_size = torch.iinfo(torch.int64).max super().__init__(max_size) self._storage = [] @@ -233,7 +243,7 @@ def set( np.ndarray, ), ): - for _cursor, _data in zip(cursor, data): + for _cursor, _data in zip(cursor, data, strict=True): self.set(_cursor, _data, set_cursor=set_cursor) else: raise TypeError( @@ -305,6 +315,20 @@ def __getstate__(self): def __repr__(self): return f"{self.__class__.__name__}(items=[{self._storage[0]}, ...])" + def contains(self, item): + if isinstance(item, int): + if item < 0: + item += len(self._storage) + + return 0 <= item < len(self._storage) + if isinstance(item, torch.Tensor): + return torch.tensor( + [self.contains(elt) for elt in item.tolist()], + dtype=torch.bool, + device=item.device, + ).reshape_as(item) + raise NotImplementedError(f"type {type(item)} is not supported yet.") + class TensorStorage(Storage): """A storage for tensors and tensordicts. @@ -782,6 +806,30 @@ def repr_item(x): maxsize_str = textwrap.indent(f"max_size={self.max_size}", 4 * " ") return f"{self.__class__.__name__}(\n{storage_str}, \n{shape_str}, \n{len_str}, \n{maxsize_str})" + def contains(self, item): + if isinstance(item, int): + if item < 0: + item += self._len_along_dim0 + + return 0 <= item < self._len_along_dim0 + if isinstance(item, torch.Tensor): + + def _is_valid_index(idx): + try: + torch.zeros(self.shape, device="meta")[idx] + return True + except IndexError: + return False + + if item.ndim: + return torch.tensor( + [_is_valid_index(idx) for idx in item], + dtype=torch.bool, + device=item.device, + ) + return torch.tensor(_is_valid_index(item), device=item.device) + raise NotImplementedError(f"type {type(item)} is not supported yet.") + class LazyTensorStorage(TensorStorage): """A pre-allocated tensor storage for tensors and tensordicts. @@ -1269,10 +1317,14 @@ def _collate_list_tensordict(x): return out -def _stack_anything(x): - if is_tensor_collection(x[0]): - return LazyStackedTensorDict.maybe_dense_stack(x) - return torch.stack(x) +def _stack_anything(data): + if is_tensor_collection(data[0]): + return LazyStackedTensorDict.maybe_dense_stack(data) + return torch.utils._pytree.tree_map( + lambda *x: torch.stack(x), + *data, + is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x), + ) def _collate_id(x): @@ -1281,10 +1333,7 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, ListStorage): - if _is_tensordict: - return _collate_list_tensordict - else: - return torch.utils.data._utils.collate.default_collate + return _stack_anything elif isinstance(storage, TensorStorage): return _collate_id else: