 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import functools
 import importlib.util
 
 import pytest
 
 import torch
 
 from tensordict import TensorDict
-from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash
+from torchrl.data import LazyTensorStorage, ListStorage
+from torchrl.data.map import (
+    BinaryToDecimal,
+    QueryModule,
+    RandomProjectionHash,
+    SipHash,
+    TensorDictMap,
+)
+from torchrl.envs import GymEnv
 
 _has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
     "gym", None
@@ -114,6 +123,120 @@ def test_query(self, clone, index_key):
         for i in range(1, 3):
             assert res[index_key][i].item() != res[index_key][i + 1].item()
 
+    def test_query_module(self):
+        query_module = QueryModule(
+            in_keys=["key1", "key2"],
+            index_key="index",
+            hash_module=SipHash(),
+        )
+
+        embedding_storage = LazyTensorStorage(23)
+
+        tensor_dict_storage = TensorDictMap(
+            query_module=query_module,
+            storage=embedding_storage,
+        )
+
+        index = TensorDict(
+            {
+                "key1": torch.Tensor([[-1], [1], [3], [-3]]),
+                "key2": torch.Tensor([[0], [2], [4], [-4]]),
+            },
+            batch_size=(4,),
+        )
+
+        value = TensorDict(
+            {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
+        )
+
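+        # a write hashes the in_keys ("key1", "key2") into storage indices and stores the value rows there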
+        tensor_dict_storage[index] = value
+        assert torch.sum(tensor_dict_storage.contains(index)).item() == 4
+
+        new_index = index.clone(True)
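+        # keys outside of in_keys should not affect the lookup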
+        new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
+        retrieve_value = tensor_dict_storage[new_index]
+
+        assert (retrieve_value["index"] == value["index"]).all()
+
+
+class TestTensorDictMap:
+    @pytest.mark.parametrize(
+        "storage_type",
+        [
+            functools.partial(ListStorage, 1000),
+            functools.partial(LazyTensorStorage, 1000),
+        ],
+    )
+    def test_map(self, storage_type):
+        query_module = QueryModule(
+            in_keys=["key1", "key2"],
+            index_key="index",
+            hash_module=SipHash(),
+        )
+
+        embedding_storage = storage_type()
+
+        tensor_dict_storage = TensorDictMap(
+            query_module=query_module,
+            storage=embedding_storage,
+        )
+
+        index = TensorDict(
+            {
+                "key1": torch.Tensor([[-1], [1], [3], [-3]]),
+                "key2": torch.Tensor([[0], [2], [4], [-4]]),
+            },
+            batch_size=(4,),
+        )
+
+        value = TensorDict(
+            {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
+        )
+        assert not hasattr(tensor_dict_storage, "out_keys")
+
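+        # out_keys are lazy: they should only become available after the first write, and only for a tensor-based storage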
+        tensor_dict_storage[index] = value
+        if isinstance(embedding_storage, LazyTensorStorage):
+            assert hasattr(tensor_dict_storage, "out_keys")
+        else:
+            assert not hasattr(tensor_dict_storage, "out_keys")
+            assert tensor_dict_storage._has_lazy_out_keys()
+        assert torch.sum(tensor_dict_storage.contains(index)).item() == 4
+
+        new_index = index.clone(True)
+        new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
+        retrieve_value = tensor_dict_storage[new_index]
+
+        assert (retrieve_value["index"] == value["index"]).all()
+
+    @pytest.mark.skipif(not _has_gym, reason="gym not installed")
+    def test_map_rollout(self):
+        torch.manual_seed(0)
+        env = GymEnv("CartPole-v1")
+        env.set_seed(0)
+        rollout = env.rollout(100)
+        source, dest = rollout.exclude("next"), rollout.get("next")
+        storage = TensorDictMap.from_tensordict_pair(
+            source,
+            dest,
+            in_keys=["observation", "action"],
+        )
+        storage_indices = TensorDictMap.from_tensordict_pair(
+            source,
+            dest,
+            in_keys=["observation"],
+            out_keys=["_index"],
+        )
+        # maps the (obs, action) tuple to a corresponding next state
+        storage[source] = dest
+        storage_indices[source] = source
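+        # transitions from the rollout should be found; shifted copies (source + 1) should not be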
+        contains = storage.contains(source)
+        assert len(contains) == rollout.shape[-1]
+        assert contains.all()
+        contains = storage.contains(torch.cat([source, source + 1]))
+        assert len(contains) == rollout.shape[-1] * 2
+        assert contains[: rollout.shape[-1]].all()
+        assert not contains[rollout.shape[-1] :].any()
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()