Skip to content

Commit fed01f1

Browse files
author
Vincent Moens
committed
[Feature] TensorDictMap
ghstack-source-id: 57d1544 Pull Request resolved: #2306
1 parent 770a87d commit fed01f1

File tree

5 files changed

+517
-13
lines changed

5 files changed

+517
-13
lines changed

test/test_storage_map.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,23 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import argparse
6+
import functools
67
import importlib.util
78

89
import pytest
910

1011
import torch
1112

1213
from tensordict import TensorDict
13-
from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash
14+
from torchrl.data import LazyTensorStorage, ListStorage
15+
from torchrl.data.map import (
16+
BinaryToDecimal,
17+
QueryModule,
18+
RandomProjectionHash,
19+
SipHash,
20+
TensorDictMap,
21+
)
22+
from torchrl.envs import GymEnv
1423

1524
_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
1625
"gym", None
@@ -114,6 +123,120 @@ def test_query(self, clone, index_key):
114123
for i in range(1, 3):
115124
assert res[index_key][i].item() != res[index_key][i + 1].item()
116125

126+
def test_query_module(self):
    """Smoke-test the QueryModule -> TensorDictMap round trip.

    Builds a SipHash-backed query over two float keys, writes four
    (index, value) pairs into a LazyTensorStorage-backed map, and checks
    that (a) all four entries are reported as contained and (b) lookup is
    driven only by the declared in_keys — an extra, undeclared key on the
    query tensordict must not change the retrieved values.
    """
    # Hash ("key1", "key2") pairs into a single integer "index" slot.
    query_module = QueryModule(
        in_keys=["key1", "key2"],
        index_key="index",
        hash_module=SipHash(),
    )

    # NOTE(review): capacity 23 is arbitrary — any value >= 4 works here.
    embedding_storage = LazyTensorStorage(23)

    tensor_dict_storage = TensorDictMap(
        query_module=query_module,
        storage=embedding_storage,
    )

    # Four distinct (key1, key2) pairs, batch dimension of 4.
    index = TensorDict(
        {
            "key1": torch.Tensor([[-1], [1], [3], [-3]]),
            "key2": torch.Tensor([[0], [2], [4], [-4]]),
        },
        batch_size=(4,),
    )

    value = TensorDict(
        {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
    )

    tensor_dict_storage[index] = value
    # All four keys written above must be found again.
    assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

    # An extra key not listed in in_keys must be ignored by the lookup.
    new_index = index.clone(True)
    new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
    retrieve_value = tensor_dict_storage[new_index]

    assert (retrieve_value["index"] == value["index"]).all()
162+
class TesttTensorDictMap:
    # NOTE(review): "Testt" looks like a typo for "TestTensorDictMap".
    # pytest still collects it (the name starts with "Test"), so the name is
    # kept here to avoid changing the public identifier; consider renaming
    # in a follow-up.
    """Tests for TensorDictMap over both list- and tensor-backed storages."""

    @pytest.mark.parametrize(
        "storage_type",
        [
            functools.partial(ListStorage, 1000),
            functools.partial(LazyTensorStorage, 1000),
        ],
    )
    def test_map(self, storage_type):
        """Write/read a 4-element batch through TensorDictMap.

        Also checks the lazy materialization of ``out_keys``: they are
        presumably discovered on first write for tensor-backed storage —
        TODO confirm against TensorDictMap's implementation.
        """
        query_module = QueryModule(
            in_keys=["key1", "key2"],
            index_key="index",
            hash_module=SipHash(),
        )

        embedding_storage = storage_type()

        tensor_dict_storage = TensorDictMap(
            query_module=query_module,
            storage=embedding_storage,
        )

        index = TensorDict(
            {
                "key1": torch.Tensor([[-1], [1], [3], [-3]]),
                "key2": torch.Tensor([[0], [2], [4], [-4]]),
            },
            batch_size=(4,),
        )

        value = TensorDict(
            {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
        )
        # Before any write, out_keys must not have been materialized yet.
        assert not hasattr(tensor_dict_storage, "out_keys")

        tensor_dict_storage[index] = value
        # Only the tensor-backed storage materializes out_keys on write.
        if isinstance(embedding_storage, LazyTensorStorage):
            assert hasattr(tensor_dict_storage, "out_keys")
        else:
            assert not hasattr(tensor_dict_storage, "out_keys")
        assert tensor_dict_storage._has_lazy_out_keys()
        assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

        # Lookup must ignore keys outside the declared in_keys.
        new_index = index.clone(True)
        new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
        retrieve_value = tensor_dict_storage[new_index]

        assert (retrieve_value["index"] == value["index"]).all()

    @pytest.mark.skipif(not _has_gym, reason="gym not installed")
    def test_map_rollout(self):
        """Build a (state, action) -> next-state map from a CartPole rollout.

        Seeds both torch and the env for a deterministic rollout, then
        checks containment: every transition of the rollout is found, and
        perturbed observations (source + 1) are not.
        """
        torch.manual_seed(0)
        env = GymEnv("CartPole-v1")
        env.set_seed(0)
        rollout = env.rollout(100)
        source, dest = rollout.exclude("next"), rollout.get("next")
        storage = TensorDictMap.from_tensordict_pair(
            source,
            dest,
            in_keys=["observation", "action"],
        )
        storage_indices = TensorDictMap.from_tensordict_pair(
            source,
            dest,
            in_keys=["observation"],
            out_keys=["_index"],
        )
        # maps the (obs, action) tuple to a corresponding next state
        storage[source] = dest
        storage_indices[source] = source
        contains = storage.contains(source)
        assert len(contains) == rollout.shape[-1]
        assert contains.all()
        # Shifted observations must not be found: first half hits, second misses.
        contains = storage.contains(torch.cat([source, source + 1]))
        assert len(contains) == rollout.shape[-1] * 2
        assert contains[: rollout.shape[-1]].all()
        assert not contains[rollout.shape[-1] :].any()
239+
117240

118241
if __name__ == "__main__":
119242
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash
6+
from .map import (
7+
BinaryToDecimal,
8+
HashToInt,
9+
QueryModule,
10+
RandomProjectionHash,
11+
SipHash,
12+
TensorDictMap,
13+
TensorMap,
14+
)
715
from .postprocs import MultiStep
816
from .replay_buffers import (
917
Flat2TED,

torchrl/data/map/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55

66
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
77
from .query import HashToInt, QueryModule
8+
from .tdstorage import TensorDictMap, TensorMap

0 commit comments

Comments
 (0)