[Refactor] Refactor trees
ghstack-source-id: 368ba4c4402b6db0bc8b0688802ce161db9776b7
Pull Request resolved: #2634
vmoens committed Dec 12, 2024
1 parent 19dfefc commit 57dc25a
Showing 6 changed files with 678 additions and 56 deletions.
104 changes: 101 additions & 3 deletions test/test_storage_map.py
@@ -301,6 +301,7 @@ def _state0(self) -> TensorDict:
def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1)
reward = action.clone()
action = action + torch.arange(action.shape[-1]) / action.shape[-1]

return TensorDict(
{
@@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest:
forest.extend(r4)
return forest

def _make_forest_intersect(self) -> MCTSForest:
def _make_forest_rebranching(self) -> MCTSForest:
"""
├── 0
│ ├── 16
@@ -449,7 +450,7 @@ def test_forest_check_ids(self):

def test_forest_intersect(self):
state0 = self._state0()
forest = self._make_forest_intersect()
forest = self._make_forest_rebranching()
tree = forest.get_tree(state0)
subtree = forest.get_tree(TensorDict(observation=19))

@@ -467,13 +468,110 @@ def test_forest_intersect(self):

def test_forest_intersect_vertices(self):
state0 = self._state0()
forest = self._make_forest_intersect()
forest = self._make_forest_rebranching()
tree = forest.get_tree(state0)
assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash"))
assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash"))
with pytest.raises(ValueError, match="key_type must be"):
tree.vertices(key_type="another key type")

@pytest.mark.skipif(not _has_gym, reason="requires gym")
def test_simple_tree(self):
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
r = env.rollout(10)
state0 = r[0]
forest = MCTSForest()
forest.extend(r)
# forest = self._make_forest_intersect()
tree = forest.get_tree(state0, compact=False)
assert tree.max_length() == 9
for p in tree.valid_paths():
assert len(p) == 9

@pytest.mark.parametrize(
"tree_type,compact",
[
["simple", False],
["forest", False],
# parents of rebranching trees are still buggy
# ["rebranching", False],
# ["rebranching", True],
],
)
def test_forest_parent(self, tree_type, compact):
if tree_type == "simple":
if not _has_gym:
pytest.skip("requires gym")
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
r = env.rollout(10)
state0 = r[0]
forest = MCTSForest()
forest.extend(r)
tree = forest.get_tree(state0, compact=compact)
elif tree_type == "forest":
state0 = self._state0()
forest = self._make_forest()
tree = forest.get_tree(state0, compact=compact)
else:
state0 = self._state0()
forest = self._make_forest_rebranching()
tree = forest.get_tree(state0, compact=compact)
# Check access
tree.subtree.parent
tree.subtree.subtree.parent
tree.subtree.subtree.subtree.parent

# check presence of weakref
assert tree.subtree[0]._parent is not None
assert tree.subtree[0].subtree[0]._parent is not None

# Check content
assert_close(tree.subtree.parent, tree)
for p in tree.valid_paths():
root = tree
for it in p:
node = root.subtree[it]
assert_close(node.parent, root)
root = node

def test_forest_action_attr(self):
state0 = self._state0()
forest = self._make_forest()
tree = forest.get_tree(state0)
assert tree.branching_action is None
assert (tree.subtree.branching_action != tree.subtree.prev_action).any()
assert (
tree.subtree[0].subtree.branching_action
!= tree.subtree[0].subtree.prev_action
).any()
assert tree.prev_action is None

@pytest.mark.parametrize("intersect", [False, True])
def test_forest_check_obs_match(self, intersect):
state0 = self._state0()
if intersect:
forest = self._make_forest_rebranching()
else:
forest = self._make_forest()
tree = forest.get_tree(state0)
for path in tree.valid_paths():
prev_tree = tree
for p in path:
subtree = prev_tree.subtree[p]
assert (
subtree.node_data["observation"]
== subtree.rollout[..., -1]["next", "observation"]
).all()
assert (
subtree.node_observation
== subtree.rollout[..., -1]["next", "observation"]
).all()
prev_tree = subtree


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
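For readers skimming the new tests, the sketch below strings together the MCTSForest calls they exercise (get_tree, subtree, parent, valid_paths). It is a minimal usage sketch, assuming gym is installed and that MCTSForest is importable from torchrl.data; it is not part of the commit.

# Minimal usage sketch of the tree-navigation API exercised above (assumes gym is
# installed; the MCTSForest import path may differ across TorchRL versions).
from torchrl.data import MCTSForest
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
rollout = env.rollout(10)

forest = MCTSForest()
forest.extend(rollout)

# compact=False keeps one node per step: a 10-step rollout gives a tree of depth 9
tree = forest.get_tree(rollout[0], compact=False)
assert tree.max_length() == 9

for path in tree.valid_paths():
    node = tree
    for branch in path:
        child = node.subtree[branch]
        assert child.parent is not None   # each child keeps a reference back to its parent
        node = child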
3 changes: 2 additions & 1 deletion torchrl/data/map/hash.py
@@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
class SipHash(Module):
"""A Module to Compute SipHash values for given tensors.
A hash function module based on SipHash implementation in python.
A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]``
and the output shape will be ``[batch_size]``.
Args:
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
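A short sketch of the shape contract documented above; the import path and the as_tensor default are assumptions rather than part of the diff.

# Hedged sketch of the SipHash shape contract from the updated docstring:
# [batch_size, num_features] in, [batch_size] out. Import path is an assumption.
import torch
from torchrl.data.map.hash import SipHash

hasher = SipHash(as_tensor=True)
features = torch.rand(4, 8)   # [batch_size=4, num_features=8]
hashes = hasher(features)
assert hashes.shape == (4,)   # one hash value per row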
32 changes: 28 additions & 4 deletions torchrl/data/map/tdstorage.py
@@ -138,6 +138,10 @@ def __init__(
self.collate_fn = collate_fn
self.write_fn = write_fn

@property
def max_size(self):
return self.storage.max_size

@property
def out_keys(self) -> List[NestedKey]:
out_keys = self.__dict__.get("_out_keys_and_lazy")
@@ -177,7 +181,7 @@ def from_tensordict_pair(
collate_fn: Callable[[Any], Any] | None = None,
write_fn: Callable[[Any, Any], Any] | None = None,
consolidated: bool | None = None,
):
) -> TensorDictMap:
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
Args:
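As a reminder of how the classmethod above is meant to be used, here is a hedged usage sketch; only source, dest, in_keys and out_keys come from the signature shown, everything else (import path, defaults) is an assumption.

# Hedged usage sketch of TensorDictMap.from_tensordict_pair; import path and
# defaults are assumptions and may differ across TorchRL versions.
import torch
from tensordict import TensorDict
from torchrl.data.map import TensorDictMap

source = TensorDict({"observation": torch.rand(3, 4)}, batch_size=[3])
dest = TensorDict({"index": torch.arange(3)}, batch_size=[3])

tdmap = TensorDictMap.from_tensordict_pair(
    source, dest, in_keys=["observation"], out_keys=["index"]
)
tdmap[source] = dest              # hash the observations and store the indices
assert (tdmap[source]["index"] == dest["index"]).all()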
@@ -238,7 +242,13 @@ def from_tensordict_pair(
n_feat = 0
hash_module = []
for in_key in in_keys:
n_feat = source[in_key].shape[-1]
entry = source[in_key]
if entry.ndim == source.ndim:
# this is a good example of why td/tc are useful - carrying metadata
# allows us to know if there's a feature dim or not
n_feat = 0
else:
n_feat = entry.shape[-1]
if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT:
_hash_module = RandomProjectionHash()
else:
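A tiny standalone illustration of the metadata check above (plain tensordict, outside the storage class): an entry whose ndim matches the TensorDict's batch ndim has no feature dimension, so n_feat falls back to 0; otherwise the trailing dimension is the feature size.

# Illustration of the feature-dim detection used above (hypothetical standalone example).
import torch
from tensordict import TensorDict

source = TensorDict(
    {
        "scalar": torch.zeros(4),     # ndim == batch ndim -> no feature dimension
        "vector": torch.zeros(4, 3),  # trailing dimension of size 3 is the feature dim
    },
    batch_size=[4],
)
for key in ("scalar", "vector"):
    entry = source[key]
    n_feat = 0 if entry.ndim == source.ndim else entry.shape[-1]
    print(key, n_feat)                # scalar 0, vector 3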
@@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
if not self._has_lazy_out_keys():
# TODO: make this work with pytrees and avoid calling select if keys match
value = value.select(*self.out_keys, strict=False)
item, value = self._maybe_add_batch(item, value)
index = self._to_index(item, extend=True)
if index.unique().numel() < index.numel():
# If multiple values point to the same place in the storage, we cannot process them by batch
# There could be a better way to deal with this, using unique ids.
vals = []
for it, val in zip(item.split(1), value.split(1)):
self[it] = val
vals.append(val)
# __setitem__ may affect the content of the input data
value.update(TensorDictBase.lazy_stack(vals))
return
if self.write_fn is not None:
# We use this block in the following context: the value written in the storage is already present,
# but it needs to be updated.
# We first check if the value is already there using `contains`. If so, we pass the new value and the
# previous one to write_fn. The values that are not present are passed alone.
if len(self):
modifiable = self.contains(item)
if modifiable.any():
Expand All @@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
value = self.write_fn(value)
else:
value = self.write_fn(value)
item, value = self._maybe_add_batch(item, value)
index = self._to_index(item, extend=True)
self.storage.set(index, value)

def __len__(self):
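To see why the duplicate-index branch in __setitem__ above falls back to a per-item loop, here is a minimal sketch in plain torch (not the TensorDictMap internals): a batched write with colliding indices silently keeps only one of the values.

# Why colliding indices cannot be written in one batched call (plain-torch sketch).
import torch

storage = torch.zeros(4)
index = torch.tensor([1, 1])             # two items map to the same slot
values = torch.tensor([10.0, 20.0])
storage[index] = values                  # batched write: one of the values is dropped
print(storage)                           # slot 1 ends up with 10. or 20., never both

storage = torch.zeros(4)
for i, v in zip(index.split(1), values.split(1)):
    storage[i] = v                       # per-item writes, mirroring the loop in __setitem__
print(storage)                           # tensor([ 0., 20.,  0.,  0.])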