Skip to content

Commit

Permalink
Remove Set2Set not working
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Jan 3, 2025
1 parent c4986f6 commit 6df678b
Showing 1 changed file with 0 additions and 38 deletions.
38 changes: 0 additions & 38 deletions GNNLux/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,41 +131,3 @@ function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_un
end

(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x)


@doc raw"""
Set2Set(n_in, n_iters, n_layers = 1)
Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391).
For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times:
```math
\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*)
\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)}
\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i
\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}]
```
where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, input size `2*n_in` and output size `n_in`.
Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`.
```
"""
struct Set2Set{L} <: GNNContainerLayer{(:lstm,)}
lstm::L
num_iters::Int
end

function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
@assert n_layers == 1 "multiple layers not implemented yet" #TODO
n_out = 2 * n_in
lstm = Lux.LSTMCell(n_out => n_in)
return Set2Set(lstm, n_iters)
end

function (l::Set2Set)(g, x, ps, st)
lstm = StatefulLuxLayer{true}(l.lstm, ps.lstm, _getstate(st, :lstm))
m = (; lstm, Wh = ps.lstm.weight_hh)
return GNNlib.set2set_pool(m, g, x)
end

(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

0 comments on commit 6df678b

Please sign in to comment.