diff --git a/CHANGELOG.md b/CHANGELOG.md
index 84c50c4d411b..13770a9c242b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added the `BipartitePooling` operator from the "DeepTreeGANv2: Iterative Pooling of Point Clouds" paper.
+
 ### Changed

 ### Deprecated
diff --git a/test/nn/pool/test_bipartite_pool.py b/test/nn/pool/test_bipartite_pool.py
new file mode 100644
index 000000000000..2834db012e7f
--- /dev/null
+++ b/test/nn/pool/test_bipartite_pool.py
@@ -0,0 +1,54 @@
+import torch
+
+from torch_geometric import nn
+
+
+def test_bipartite_pooling():
+    num_nodes = 100
+    ratio = 10
+    in_channels = 5
+    out_channels = 8
+    num_graphs = 4
+    kw = dict(in_channels=in_channels, out_channels=out_channels)
+
+    gnnlist = [
+        nn.GINConv(torch.nn.Linear(in_channels, out_channels)),
+        # nn.GENConv(**kw),  # gradient test breaks
+        nn.GeneralConv(**kw),
+        nn.GraphConv(**kw),
+        nn.MFConv(**kw),
+        # nn.SimpleConv(),  # gradient test breaks
+        nn.SAGEConv(**kw),
+        nn.WLConvContinuous(),
+        nn.GATv2Conv(add_self_loops=False, **kw),
+        nn.GATConv(add_self_loops=False, **kw),
+    ]
+    batch = torch.arange(num_graphs).repeat_interleave(num_nodes)
+    for gnn in gnnlist:
+
+        pool = nn.BipartitePooling(in_channels, ratio=ratio, gnn=gnn)
+
+        x = torch.randn((num_graphs * num_nodes, in_channels)).requires_grad_()
+        x.retain_grad()
+        out, new_batchidx = pool(x, batch)
+
+        if isinstance(gnn, (nn.SimpleConv, nn.WLConvContinuous)):
+            assert out.shape == torch.Size([num_graphs * ratio, in_channels])
+        else:
+            assert out.shape == torch.Size([num_graphs * ratio, out_channels])
+
+        for grad_graph in range(num_graphs):
+
+            out[new_batchidx == grad_graph].sum().backward(retain_graph=True)
+            # Only the graph we backpropagated through receives a gradient:
+            for check_graph in range(num_graphs):
+                grad_graph_i = x.grad[batch == check_graph].abs().sum(1)
+                if grad_graph == check_graph:
+                    assert (grad_graph_i > 0).all()
+                else:
+                    assert (grad_graph_i == 0).all()
+
+            x.grad.zero_()
+        # All seed nodes receive a gradient:
+        assert (pool.seed_nodes.grad.abs().sum(1) > 0).all()
+        pool.seed_nodes.grad.zero_()
diff --git a/torch_geometric/nn/pool/__init__.py b/torch_geometric/nn/pool/__init__.py
index 09ef32d8536e..714efcee3b53 100644
--- a/torch_geometric/nn/pool/__init__.py
+++ b/torch_geometric/nn/pool/__init__.py
@@ -15,6 +15,7 @@
 from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
 from .topk_pool import TopKPooling
 from .sag_pool import SAGPooling
+from .bipartite_pool import BipartitePooling
 from .edge_pool import EdgePooling
 from .cluster_pool import ClusterPooling
 from .asap import ASAPooling
@@ -344,6 +345,7 @@ def nearest(
     'ApproxMIPSKNNIndex',
     'TopKPooling',
     'SAGPooling',
+    'BipartitePooling',
     'EdgePooling',
     'ClusterPooling',
     'ASAPooling',
diff --git a/torch_geometric/nn/pool/bipartite_pool.py b/torch_geometric/nn/pool/bipartite_pool.py
new file mode 100644
index 000000000000..067d99aa165f
--- /dev/null
+++ b/torch_geometric/nn/pool/bipartite_pool.py
@@ -0,0 +1,105 @@
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+from torch_geometric.typing import OptTensor
+
+
+class BipartitePooling(torch.nn.Module):
+    r"""The bipartite pooling operator from the `"DeepTreeGANv2: Iterative
+    Pooling of Point Clouds" `_ paper.
+
+    This pooling layer constructs a dense bipartite graph between the input
+    nodes and "seed" nodes that are trainable parameters of the layer.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        ratio (int): The number of seed nodes per graph.
+        gnn (torch.nn.Module): A graph neural network layer that
+            implements bipartite message passing, such as
+            :class:`torch_geometric.nn.conv.GATv2Conv`,
+            :class:`torch_geometric.nn.conv.GATConv`,
+            :class:`torch_geometric.nn.conv.GINConv`,
+            :class:`torch_geometric.nn.conv.GeneralConv`,
+            :class:`torch_geometric.nn.conv.GraphConv`,
+            :class:`torch_geometric.nn.conv.MFConv`,
+            :class:`torch_geometric.nn.conv.SAGEConv` or
+            :class:`torch_geometric.nn.conv.WLConvContinuous`.
+            Recommended: :class:`torch_geometric.nn.conv.GATv2Conv`
+            with :obj:`add_self_loops=False`.
+
+    Shapes:
+        - **inputs:**
+          node features :math:`(|\mathcal{V}|, F_{in})`,
+          batch vector :math:`(|\mathcal{V}|)`
+        - **outputs:**
+          node features :math:`(B \cdot \mathrm{ratio}, F_{out})`,
+          batch vector :math:`(B \cdot \mathrm{ratio})`
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        ratio: int,
+        gnn: torch.nn.Module,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.ratio = ratio
+
+        self.seed_nodes = torch.nn.Parameter(
+            torch.empty(size=(self.ratio, self.in_channels)))
+        self.gnn = gnn
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        r"""Resets all learnable parameters of the module."""
+        self.gnn.reset_parameters()
+        self.seed_nodes.data.normal_()
+
+    def forward(
+        self,
+        x: Tensor,
+        batch: OptTensor = None,
+    ) -> Tuple[Tensor, Tensor]:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The node feature matrix.
+            batch (torch.Tensor, optional): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each node to a specific example. (default: :obj:`None`)
+        """
+        if batch is None:
+            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
+        batch_size = int(batch.max()) + 1
+
+        # Replicate the trainable seed nodes once for each graph in the batch:
+        x_aggrs = self.seed_nodes.repeat(batch_size, 1)
+
+        source_graph_size = len(x)
+        # Connect every input node to all seed nodes of its own graph:
+        source = torch.arange(source_graph_size, device=x.device,
+                              dtype=torch.long).repeat_interleave(self.ratio)
+        target = torch.arange(self.ratio, device=x.device,
+                              dtype=torch.long).repeat(source_graph_size)
+        target += batch.repeat_interleave(self.ratio) * self.ratio
+
+        out = self.gnn(
+            x=(x, x_aggrs),
+            edge_index=torch.vstack([source, target]),
+            # size=(len(x), self.ratio * batch_size),
+        )
+
+        new_batchidx = torch.arange(batch_size, dtype=torch.long,
+                                    device=x.device).repeat_interleave(
+                                        self.ratio)
+
+        return out, new_batchidx
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '
+                f'{self.in_channels}, {self.ratio})')
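For illustration, a minimal usage sketch of the new operator; it mirrors the shapes asserted in `test_bipartite_pool.py` above and uses the docstring's recommended `GATv2Conv` configuration (the concrete sizes here are arbitrary):

```python
import torch

from torch_geometric.nn import BipartitePooling, GATv2Conv

# Two graphs with 100 nodes each and 5 features per node:
num_graphs, num_nodes, in_channels = 2, 100, 5
x = torch.randn(num_graphs * num_nodes, in_channels)
batch = torch.arange(num_graphs).repeat_interleave(num_nodes)

# Pool every graph down to `ratio` seed nodes via bipartite message passing:
pool = BipartitePooling(
    in_channels=in_channels,
    ratio=10,
    gnn=GATv2Conv(in_channels=in_channels, out_channels=8,
                  add_self_loops=False),
)
out, out_batch = pool(x, batch)

assert out.shape == (num_graphs * 10, 8)  # one row per seed node per graph
assert out_batch.shape == (num_graphs * 10,)
```

Since the seed nodes are trainable parameters shared across graphs, every input graph is pooled to a fixed-size representation of `ratio` nodes, regardless of its original size.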