Skip to content

Commit

Permalink
[Feature] TensorDictMap
Browse files Browse the repository at this point in the history
ghstack-source-id: 3b0df1e0dd416a87fbbf65725c9271e579ccd49b
Pull Request resolved: #2306
  • Loading branch information
vmoens committed Aug 4, 2024
1 parent a453b92 commit bef503f
Show file tree
Hide file tree
Showing 5 changed files with 517 additions and 13 deletions.
125 changes: 124 additions & 1 deletion test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import functools
import importlib.util

import pytest

import torch

from tensordict import TensorDict
from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash
from torchrl.data import LazyTensorStorage, ListStorage
from torchrl.data.map import (
BinaryToDecimal,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
)
from torchrl.envs import GymEnv

_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
"gym", None
Expand Down Expand Up @@ -114,6 +123,120 @@ def test_query(self, clone, index_key):
for i in range(1, 3):
assert res[index_key][i].item() != res[index_key][i + 1].item()

def test_query_module(self):
query_module = QueryModule(
in_keys=["key1", "key2"],
index_key="index",
hash_module=SipHash(),
)

embedding_storage = LazyTensorStorage(23)

tensor_dict_storage = TensorDictMap(
query_module=query_module,
storage=embedding_storage,
)

index = TensorDict(
{
"key1": torch.Tensor([[-1], [1], [3], [-3]]),
"key2": torch.Tensor([[0], [2], [4], [-4]]),
},
batch_size=(4,),
)

value = TensorDict(
{"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
)

tensor_dict_storage[index] = value
assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

new_index = index.clone(True)
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert (retrieve_value["index"] == value["index"]).all()


class TesttTensorDictMap:
@pytest.mark.parametrize(
"storage_type",
[
functools.partial(ListStorage, 1000),
functools.partial(LazyTensorStorage, 1000),
],
)
def test_map(self, storage_type):
query_module = QueryModule(
in_keys=["key1", "key2"],
index_key="index",
hash_module=SipHash(),
)

embedding_storage = storage_type()

tensor_dict_storage = TensorDictMap(
query_module=query_module,
storage=embedding_storage,
)

index = TensorDict(
{
"key1": torch.Tensor([[-1], [1], [3], [-3]]),
"key2": torch.Tensor([[0], [2], [4], [-4]]),
},
batch_size=(4,),
)

value = TensorDict(
{"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
)
assert not hasattr(tensor_dict_storage, "out_keys")

tensor_dict_storage[index] = value
if isinstance(embedding_storage, LazyTensorStorage):
assert hasattr(tensor_dict_storage, "out_keys")
else:
assert not hasattr(tensor_dict_storage, "out_keys")
assert tensor_dict_storage._has_lazy_out_keys()
assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

new_index = index.clone(True)
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert (retrieve_value["index"] == value["index"]).all()

@pytest.mark.skipif(not _has_gym, reason="gym not installed")
def test_map_rollout(self):
torch.manual_seed(0)
env = GymEnv("CartPole-v1")
env.set_seed(0)
rollout = env.rollout(100)
source, dest = rollout.exclude("next"), rollout.get("next")
storage = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=["observation", "action"],
)
storage_indices = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=["observation"],
out_keys=["_index"],
)
# maps the (obs, action) tuple to a corresponding next state
storage[source] = dest
storage_indices[source] = source
contains = storage.contains(source)
assert len(contains) == rollout.shape[-1]
assert contains.all()
contains = storage.contains(torch.cat([source, source + 1]))
assert len(contains) == rollout.shape[-1] * 2
assert contains[: rollout.shape[-1]].all()
assert not contains[rollout.shape[-1] :].any()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
10 changes: 9 additions & 1 deletion torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash
from .map import (
BinaryToDecimal,
HashToInt,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
TensorMap,
)
from .postprocs import MultiStep
from .replay_buffers import (
Flat2TED,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
from .query import HashToInt, QueryModule
from .tdstorage import TensorDictMap, TensorMap
Loading

0 comments on commit bef503f

Please sign in to comment.