Skip to content

Commit

Permalink
Added bipartite pooling operator from `"DeepTreeGANv2: Iterative Pool…
Browse files Browse the repository at this point in the history
…ing of Point Clouds" <https://arxiv.org/abs/2312.00042>`.
  • Loading branch information
mova committed Sep 13, 2024
1 parent 6d9e850 commit 22a7216
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added bipartite pooling operator from `"DeepTreeGANv2: Iterative Pooling of Point Clouds" <https://arxiv.org/abs/2312.00042>`.
- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
Expand Down
54 changes: 54 additions & 0 deletions test/nn/pool/test_bipartite_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch

from torch_geometric import nn


def test_bipartite_pooling():
num_nodes = 100
ratio = 10
in_channels = 5
out_channels = 8
num_graphs = 4
kw = dict(in_channels=in_channels, out_channels=out_channels)

gnnlist = [
nn.GINConv(torch.nn.Linear(in_channels, out_channels)),
# nn.GENConv(**kw), # gradient test breaks
nn.GeneralConv(**kw),
nn.GraphConv(**kw),
nn.MFConv(**kw),
# nn.SimpleConv(), # gradient test breaks
nn.SAGEConv(**kw),
nn.WLConvContinuous(),
nn.GATv2Conv(add_self_loops=False, **kw),
nn.GATConv(add_self_loops=False, **kw),
]
batch = torch.arange(num_graphs).repeat_interleave(num_nodes)
for gnn in gnnlist:

pool = nn.BipartitePooling(in_channels, ratio=ratio, gnn=gnn)

x = torch.randn((num_graphs * num_nodes, in_channels)).requires_grad_()
x.retain_grad()
out, new_batchidx = pool(x, batch)

if isinstance(gnn, (nn.SimpleConv, nn.WLConvContinuous)):
assert out.shape == torch.Size([num_graphs * ratio, in_channels])
else:
assert out.shape == torch.Size([num_graphs * ratio, out_channels])

for grad_graph in range(num_graphs):

out[new_batchidx == grad_graph].sum().backward(retain_graph=True)
# only graph igraph gets a gradient
for check_graph in range(num_graphs):
grad_grap_i = x.grad[batch == check_graph].abs().sum(1)
if grad_graph == check_graph:
assert (grad_grap_i > 0).all()
else:
assert (grad_grap_i == 0).all()

x.grad.zero_()
# all seed nodes get a gradient
assert (pool.seed_nodes.grad.abs().sum(1) > 0).all()
pool.seed_nodes.grad.zero_()
2 changes: 2 additions & 0 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
from .topk_pool import TopKPooling
from .sag_pool import SAGPooling
from .bipartite_pool import BipartitePooling
from .edge_pool import EdgePooling
from .cluster_pool import ClusterPooling
from .asap import ASAPooling
Expand Down Expand Up @@ -344,6 +345,7 @@ def nearest(
'ApproxMIPSKNNIndex',
'TopKPooling',
'SAGPooling',
'BipartitePooling',
'EdgePooling',
'ClusterPooling',
'ASAPooling',
Expand Down
105 changes: 105 additions & 0 deletions torch_geometric/nn/pool/bipartite_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Tuple

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor


class BipartitePooling(torch.nn.Module):
r"""The bipartite pooling operator from the `"DeepTreeGANv2: Iterative
Pooling of Point Clouds" <https://arxiv.org/abs/2312.00042>`_ paper.
The Pooling layer constructs a dense bipartite graph between the input
nodes and the "Seed" nodes that are trainable parameters of the layer.
Args:
in_channels (int): Size of each input sample.
ratio (int): Number of seed nodes
gnn (torch.nn.Module): A graph neural network layer that
implements the bipartite messages passing methode, such as
:class:`torch_geometric.nn.conv.GATv2Conv`,
:class:`torch_geometric.nn.conv.GATConv`.
:class:`torch_geometric.nn.conv.GINConv`,
:class:`torch_geometric.nn.conv.GeneralConv`,
:class:`torch_geometric.nn.conv.GraphConv`,
:class:`torch_geometric.nn.conv.MFConv`,
:class:`torch_geometric.nn.conv.SAGEConv`,
:class:`torch_geometric.nn.conv.WLConvContinuous`,
Recommended: `:class:`torch_geometric.nn.conv.GATv2Conv`
with `add_self_loops=False`.
Shapes:
- **inputs:**
node features :math:`(|\mathcal{V}|, F_{in})`,
batch :math:`(|\mathcal{V}|)`
- **outputs:**
node features (`ratio`, :math:`F_{out}`)
batch (`"ratio"`)
"""
def __init__(
self,
in_channels: int,
ratio: int,
gnn: torch.nn.Module,
**kwargs,
):
super().__init__()

self.in_channels = in_channels
self.ratio = ratio

self.seed_nodes = torch.nn.Parameter(
torch.empty(size=(self.ratio, self.in_channels)))
self.gnn = gnn

self.reset_parameters()

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.gnn.reset_parameters()
self.seed_nodes.data.normal_()

def forward(
self,
x: Tensor,
batch: OptTensor = None,
) -> Tuple[Tensor, Tensor]:
r"""Forward pass.
Args:
x (torch.Tensor): The node feature matrix.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each node to a specific example. (default: :obj:`None`)
"""
if batch is None:
batch = torch.zeros((x.size(0)), dtype=torch.long).to(x.device)
batch_size = batch.max() + 1

x_aggrs = self.seed_nodes.repeat(batch_size, 1)

source_graph_size = len(x)

source = torch.arange(source_graph_size, device=x.device,
dtype=torch.long).repeat_interleave(self.ratio)

target = torch.arange(self.ratio, device=x.device,
dtype=torch.long).repeat(source_graph_size)
target += batch.repeat_interleave(self.ratio) * self.ratio

out = self.gnn(
x=(x, x_aggrs),
edge_index=torch.vstack([source, target]),
# size=(len(x), self.ratio * int(batch_size)),
)

new_batchidx = torch.arange(batch_size, dtype=torch.long,
device=x.device).repeat_interleave(
self.ratio)

return (out, new_batchidx)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '
f'{self.in_channels}, {self.ratio})')

0 comments on commit 22a7216

Please sign in to comment.