Skip to content

Commit

Permalink
LineDiGraph transform: added option to specify which node features wi…
Browse files Browse the repository at this point in the history
…ll become edge features in the line digraph
  • Loading branch information
Flunzmas committed Jan 2, 2025
1 parent f887c29 commit 75c383b
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions torch_geometric/transforms/line_digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,26 @@ class LineDiGraph(BaseTransform):
Line-digraph node indices are equal to indices in the original graph's
coalesced :obj:`edge_index`.
Args:
node_to_edge_features (str, optional): If set to :obj:`'none'`, the
node attributes of the original graph will not be put into the line
digraph. Otherwise, for each line digraph edge
:obj:`e = ((u, v), (v, x))`, :obj:`'inner'` will use node
attributes of :obj:`v` as edge attributes, :obj:`'outer'` will use
node attributes of :obj:`u` and :obj:`x` as edge attributes, and
:obj:`'all'` will use node attributes of :obj:`u`, :obj:`v` and
:obj:`x` as edge attributes.
"""
def __init__(self, node_to_edge_features: str = 'none'):
assert node_to_edge_features in ['none', 'inner', 'outer', 'all']
self.node_to_edge_features = node_to_edge_features

def forward(self, data: Data) -> Data:
assert data.edge_index is not None
assert data.is_directed()
edge_index, edge_attr = data.edge_index, data.edge_attr
E = data.num_edges
new_num_nodes = data.num_edges

edge_index, edge_attr = coalesce(edge_index, edge_attr, data.num_nodes)
row, col = edge_index
Expand All @@ -35,8 +49,30 @@ def forward(self, data: Data) -> Data:
mask = row.unsqueeze(0) == col.unsqueeze(1) # (num_edges, num_edges)
new_edge_index = torch.nonzero(mask).T

new_num_edges = new_edge_index.size(1)

# Obtain new edge attributes
if data.x is None or self.node_to_edge_features == 'none':
new_edge_attr = None
else:
node_features = data.x
edge_src = row[new_edge_index[0]]
edge_mid = col[new_edge_index[0]]
edge_dst = col[new_edge_index[1]]

if self.node_to_edge_features == "inner":
new_edge_attr = node_features[edge_mid]
elif self.node_to_edge_features == "outer":
new_edge_attr = torch.cat([node_features[edge_src],
node_features[edge_dst]], dim=-1)
else: # self.node_to_edge_features == "all"
new_edge_attr = torch.cat([node_features[edge_src],
node_features[edge_mid],
node_features[edge_dst]], dim=-1)

data.edge_index = new_edge_index
data.x = edge_attr
data.num_nodes = E
data.edge_attr = None
data.num_nodes = new_num_nodes
data.num_edges = new_num_edges
data.edge_attr = new_edge_attr
return data

0 comments on commit 75c383b

Please sign in to comment.