From 75c383be6d8db06330a5d6d91c3f6030d15b5d86 Mon Sep 17 00:00:00 2001 From: Andreas Boltres Date: Thu, 2 Jan 2025 12:26:12 +0100 Subject: [PATCH] LineDiGraph transform: added option to specify which node features will become edge features in the line digraph --- torch_geometric/transforms/line_digraph.py | 42 ++++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/torch_geometric/transforms/line_digraph.py b/torch_geometric/transforms/line_digraph.py index df94c58508c7..99674ddae255 100644 --- a/torch_geometric/transforms/line_digraph.py +++ b/torch_geometric/transforms/line_digraph.py @@ -21,12 +21,26 @@ class LineDiGraph(BaseTransform): Line-digraph node indices are equal to indices in the original graph's coalesced :obj:`edge_index`. + + Args: + node_to_edge_features (str, optional): If set to :obj:`'none'`, the + node attributes of the original graph will not be put into the line + digraph. Otherwise, for each line digraph edge + :obj:`e = ((u, v), (v, x))`, :obj:`'inner'` will use node + attributes of :obj:`v` as edge attributes, :obj:`'outer'` will use + node attributes of :obj:`u` and :obj:`x` as edge attributes, and + :obj:`'all'` will use node attributes of :obj:`u`, :obj:`v` and + :obj:`x` as edge attributes. """ + def __init__(self, node_to_edge_features: str = 'none'): + assert node_to_edge_features in ['none', 'inner', 'outer', 'all'] + self.node_to_edge_features = node_to_edge_features + def forward(self, data: Data) -> Data: assert data.edge_index is not None assert data.is_directed() edge_index, edge_attr = data.edge_index, data.edge_attr - E = data.num_edges + new_num_nodes = data.num_edges edge_index, edge_attr = coalesce(edge_index, edge_attr, data.num_nodes) row, col = edge_index @@ -35,8 +49,30 @@ def forward(self, data: Data) -> Data: mask = row.unsqueeze(0) == col.unsqueeze(1) # (num_edges, num_edges) new_edge_index = torch.nonzero(mask).T + new_num_edges = new_edge_index.size(1) + + # Obtain new edge attributes + if data.x is None or self.node_to_edge_features == 'none': + new_edge_attr = None + else: + node_features = data.x + edge_src = row[new_edge_index[0]] + edge_mid = col[new_edge_index[0]] + edge_dst = col[new_edge_index[1]] + + if self.node_to_edge_features == "inner": + new_edge_attr = node_features[edge_mid] + elif self.node_to_edge_features == "outer": + new_edge_attr = torch.cat([node_features[edge_src], + node_features[edge_dst]], dim=-1) + else: # self.node_to_edge_features == "all" + new_edge_attr = torch.cat([node_features[edge_src], + node_features[edge_mid], + node_features[edge_dst]], dim=-1) + data.edge_index = new_edge_index data.x = edge_attr - data.num_nodes = E - data.edge_attr = None + data.num_nodes = new_num_nodes + data.num_edges = new_num_edges + data.edge_attr = new_edge_attr return data