[GNNLux] Add pooling layers #576

Merged · 14 commits · Jan 10, 2025
4 changes: 3 additions & 1 deletion GNNLux/src/GNNLux.jl
@@ -49,7 +49,9 @@ export TGCN,
EvolveGCNO

include("layers/pool.jl")
export GlobalPool
export GlobalPool,
GlobalAttentionPool,
TopKPool

end #module

93 changes: 92 additions & 1 deletion GNNLux/src/layers/pool.jl
@@ -39,4 +39,95 @@ end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

@doc raw"""
GlobalAttentionPool(fgate, ffeat=identity)

Global soft attention layer from the [Gated Graph Sequence Neural
Networks](https://arxiv.org/abs/1511.05493) paper

```math
\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
```

where the coefficients ``\alpha_i`` are given by a [`GNNlib.softmax_nodes`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNlib.jl/stable/api/utils/#GNNlib.softmax_nodes)
operation:

```math
\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
{\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
```

# Arguments

- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
It is typically expressed by a neural network.

- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
It is typically expressed by a neural network.

# Examples

```julia
using Graphs, LuxCore, Lux, GNNLux, Random

rng = Random.default_rng()
chin = 6
chout = 5

fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
pool = GlobalAttentionPool(fgate, ffeat)

g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
ndata=rand(Float32, chin, 10))
for i=1:3])

ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat))
st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat))

u, st = pool(g, g.ndata.x, ps, st)

@assert size(u) == (chout, g.num_graphs)
```
"""
struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)}
fgate
ffeat
end

GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

function (l::GlobalAttentionPool)(g, x, ps, st)
fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate))
ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat))
m = (; fgate, ffeat)
return GNNlib.global_attention_pool(m, g, x), st
end

(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
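
For reference, `GNNlib.global_attention_pool` amounts to a per-graph softmax over the gate scores followed by a weighted sum of the transformed node features. A minimal sketch of the equivalent computation, assuming the gates are plain functions and using GNNlib's `softmax_nodes` and `reduce_nodes` utilities (the function name `attention_pool_sketch` is illustrative):

```julia
using GNNlib: softmax_nodes, reduce_nodes

# Sketch: softmax the gate scores within each graph of the batch,
# then sum the gated feature vectors per graph.
function attention_pool_sketch(fgate, ffeat, g, x)
    α = softmax_nodes(g, fgate(x))            # (1, num_nodes), sums to 1 per graph
    return reduce_nodes(+, g, α .* ffeat(x))  # (D_out, num_graphs)
end
```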

"""
TopKPool(adj, k, in_channel)

Top-`k` pooling layer from the [Graph U-Nets](https://arxiv.org/abs/1905.05178) paper.

# Arguments

- `adj`: Adjacency matrix of the graph.
- `k`: The number of top-scoring nodes to keep.
- `in_channel`: The dimension of the input node features.
"""
struct TopKPool{T, S}
A::AbstractMatrix{T}
k::Int
p::AbstractVector{S}
Ã::AbstractMatrix{T}
end

function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform)
TopKPool(adj, k, init(in_channel), similar(adj, k, k))
end

(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x)
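
A minimal usage sketch for `TopKPool`, mirroring the test added below (sizes are illustrative; empty `ps`/`st` suffice since the layer carries its own parameters):

```julia
using GNNLux

N, k, in_channel = 10, 4, 7
adj = rand(Bool, N, N)            # adjacency matrix of a 10-node graph
X = rand(Float32, in_channel, N)  # node feature matrix

pool = GNNLux.TopKPool(adj, k, in_channel)
y = pool(X, (;), (;))             # select and gate the top-k nodes
@assert size(y) == (in_channel, k)
```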
43 changes: 37 additions & 6 deletions GNNLux/test/layers/pool.jl
@@ -1,15 +1,46 @@
@testitem "Pooling" setup=[TestModuleLux] begin
using .TestModuleLux
@testset "GlobalPool" begin
@testset "Pooling" begin

rng = StableRNG(1234)
g = rand_graph(rng, 10, 40)
in_dims = 3
x = randn(rng, Float32, in_dims, 10)

@testset "GCNConv" begin
@testset "GlobalPool" begin
g = rand_graph(rng, 10, 40)
in_dims = 3
x = randn(rng, Float32, in_dims, 10)
l = GlobalPool(mean)
test_lux_layer(rng, l, g, x, sizey=(in_dims,1))
end
@testset "GlobalAttentionPool" begin
n = 10
chin = 6
chout = 5
ng = 3
g = batch([GNNGraph(rand_graph(rng, 10, 40),
ndata = rand(Float32, chin, n)) for i in 1:ng])

fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
l = GlobalAttentionPool(fgate, ffeat)

test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true)
end

@testset "TopKPool" begin
N = 10
k, in_channel = 4, 7
X = rand(in_channel, N)
ps = (;)
st = (;)
for T in [Bool, Float64]
adj = rand(T, N, N)
p = GNNLux.TopKPool(adj, k, in_channel)
@test eltype(p.p) === Float32
@test size(p.p) == (in_channel,)
@test eltype(p.Ã) === T
@test size(p.Ã) == (k, k)
y = p(X, ps, st)
@test size(y) == (in_channel, k)
end
end
end
end
4 changes: 2 additions & 2 deletions GNNlib/src/layers/pool.jl
@@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
n_in = size(x, 1)
qstar = zeros_like(x, (2*n_in, g.num_graphs))
h = zeros_like(l.Wh, size(l.Wh, 2))
c = zeros_like(l.Wh, size(l.Wh, 2))
state = (h, c)
for t in 1:l.num_iters
q, state = l.lstm(qstar, state) # [n_in, n_graphs]
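
For context, this hunk is the Set2Set readout from [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391): each iteration runs the LSTM on the previous query, attends over the nodes of each graph, and concatenates the query with the attention readout, which is why `qstar` has width `2 * n_in`:

```math
\begin{aligned}
\mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^*_{t-1}) \\
\alpha_{i,t} &= \frac{\exp(\mathbf{x}_i^\top \mathbf{q}_t)}{\sum_{j \in V} \exp(\mathbf{x}_j^\top \mathbf{q}_t)} \\
\mathbf{r}_t &= \sum_{i \in V} \alpha_{i,t}\, \mathbf{x}_i \\
\mathbf{q}^*_t &= [\mathbf{q}_t; \mathbf{r}_t]
\end{aligned}
```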
7 changes: 4 additions & 3 deletions GraphNeuralNetworks/src/layers/pool.jl
@@ -61,10 +61,10 @@ operation:
# Arguments

- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
It is typically expressed by a neural network.

- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
It is typically expressed by a neural network.

# Examples

@@ -156,7 +156,8 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
end

function (l::Set2Set)(g, x)
m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh)
return GNNlib.set2set_pool(m, g, x)
end

(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
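
A minimal usage sketch for the updated `Set2Set` (GraphNeuralNetworks.jl Flux-style API; the graph and sizes are illustrative):

```julia
using GraphNeuralNetworks, Graphs

n_in = 3
g = GNNGraph(erdos_renyi(10, 0.4), ndata = rand(Float32, n_in, 10))

l = Set2Set(n_in, 2)        # 2 LSTM iterations, output width 2 * n_in
y = l(g, node_features(g))
@assert size(y) == (2 * n_in, g.num_graphs)
```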