Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] GNN position and velocity key #132

Merged
merged 4 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ def _get_excluded_keys(self, group: str):
for other_group in self.group_map.keys():
if other_group != group:
excluded_keys += [other_group, ("next", other_group)]
excluded_keys += ["info", (group, "info"), ("next", group, "info")]
return excluded_keys

def _optimizer_loop(self, group: str) -> TensorDictBase:
Expand Down
23 changes: 16 additions & 7 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,23 @@ class Gnn(Model):
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest to use the "info" dict) representing the agent position. This key will be processed as a
node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features.
(in the `observation_spec`) representing the agent position.
To do this, your environment needs to have dictionary observations and one of the keys needs to be `position_key`.
This key will be processed as a node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features.
In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence.
pos_features (int, optional): Needed when position_key is specified.
It has to match to the last element of the shape the tensor under position_key.
exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided,
wether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest to use the "info" dict) representing the agent velocity. This key will be processed as a node feature, and
(in the `observation_spec`) representing the agent position.
To do this, your environment needs to have dictionary observations and one of the keys needs to be `velocity_key`.
This key will be processed as a node feature, and
it will be used to construct edge features. In particular, it will be used to compute relative velocities
(``vel_node_1 - vel_node_2``) for all neighbours in the graph.
If you want to use this feature in a :class:`~benchmarl.models.SequenceModel`, the GNN needs to be first in sequence.
vel_features (int, optional): Needed when velocity_key is specified.
It has to match to the last element of the shape the tensor under velocity_key.
edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.
Expand Down Expand Up @@ -170,8 +175,7 @@ def __init__(
) and not self.gnn_supports_edge_attrs:
warnings.warn(
"Position key or velocity key provided but GNN class does not support edge attributes. "
"These input keys will be ignored. If instead you want to process them as node features, "
"just set them (position_key or velocity_key) to null."
"These keys will not be used for computing edge features."
)
if (
position_key is not None or velocity_key is not None
Expand Down Expand Up @@ -369,10 +373,15 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def _get_key_terminating_with(self, keys: List[NestedKey], key: str) -> NestedKey:
for k in keys:
k_tuple = _unravel_key_to_tuple(k)
if k_tuple[-1] == key and self.agent_group in k_tuple:
if (
k_tuple[-1] == key
and self.agent_group in k_tuple
and not "next" == k_tuple[0]
):
return k
raise KeyError(
f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}"
f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}. "
f"If you are using the GNN in a `SequenceModel` and want to use this key, it needs to be the first model."
)


Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"sphinx": ("https://www.sphinx-doc.org/en/master/", None),
"torch": ("https://pytorch.org/docs/master", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"torchrl": ("https://pytorch.org/rl/stable/", None),
"tensordict": ("https://pytorch.org/tensordict/stable", None),
}
Expand Down
Loading