Eigenpooling #90
Draft · wants to merge 6 commits into base: main
Changes from 4 commits
2 changes: 1 addition & 1 deletion src/AtomicGraphNets.jl
@@ -1,7 +1,7 @@
module AtomicGraphNets

export AGNConv, AGNPool#, AGNConvDEQ
-include("layers.jl")
+include("layers/layers.jl")
using .Layers: AGNConv, AGNPool

include("models.jl")
235 changes: 0 additions & 235 deletions src/layers.jl

This file was deleted.

102 changes: 102 additions & 0 deletions src/layers/conv/agnconv.jl
@@ -0,0 +1,102 @@
using Flux
using Flux: glorot_uniform, normalise, @functor#, destructure
using Zygote: @adjoint, @nograd
using LinearAlgebra, SparseArrays
using Statistics
using ChemistryFeaturization
#using DifferentialEquations, DiffEqSensitivity

"""
AGNConv(in=>out)

Atomic graph convolutional layer. Almost identical to GCNConv from GeometricFlux, but adapted to follow Tian's original AGNN structure more closely, so it explicitly keeps separate self and convolutional weights.

# Fields
- `selfweight::Array{T,2}`: weights applied to features at a node
- `convweight::Array{T,2}`: convolutional weights
- `bias::Array{T,2}`: additive bias (second dimension is always 1 because only learnable per-feature, not per-node)
- `σ::F`: activation function (applied to the outputs before normalisation), defaults to softplus

# Arguments
- `in::Integer`: the dimension of input features.
- `out::Integer`: the dimension of output features.
- `σ=softplus`: activation function
- `initW=glorot_uniform`: initialization function for weights
- `initb=zeros`: initialization function for biases

"""
struct AGNConv{T,F}
    selfweight::Array{T,2}
    convweight::Array{T,2}
    bias::Array{T,2}
    σ::F
end

function AGNConv(
    ch::Pair{<:Integer,<:Integer},
    σ = softplus;
    initW = glorot_uniform,
    initb = zeros,
    T::DataType = Float64,
)
    selfweight = T.(initW(ch[2], ch[1]))
    convweight = T.(initW(ch[2], ch[1]))
    b = T.(initb(ch[2], 1))
    AGNConv(selfweight, convweight, b, σ)
end

@functor AGNConv
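
For reference, a minimal construction sketch (not part of this diff, assuming only the constructor above and the imports at the top of this file): with the default `T = Float64`, both weight matrices come out `out x in` and the bias is a single learnable column.

# Hypothetical usage sketch, not in the PR:
layer = AGNConv(40 => 20)        # 40 input features per node, 20 output features
size(layer.convweight)           # (20, 40)
size(layer.selfweight)           # (20, 40)
size(layer.bias)                 # (20, 1)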

"""
Define the action of the layer on its inputs: perform a graph convolution and add the result (weighted by the convolutional weights) to the features themselves (weighted by the self weights) and to the per-feature bias (repeated across columns to match the number of nodes in the graph).

# Arguments
- input: a `FeaturizedAtoms` object, or a pair of matrices `graph_laplacian, encoded_features`

# Note
In the case of providing two matrices, the following conditions must hold:
- `lapl` must be square and of dimension N x N where N is the number of nodes in the graph
- `X` (encoded features) must be of dimension M x N, where M is `size(l.convweight)[2]` (or equivalently, `size(l.selfweight)[2]`)
"""
function (l::AGNConv{T,F})(lapl::Matrix{<:Real}, X::Matrix{<:Real}) where {T<:Real,F}
    # Should we put dimension checks here? They could allow more informative errors,
    # but would likely introduce a performance penalty. For now the conditions are
    # documented only in the docstring.
    out_mat =
        T.(
            normalise(
                l.σ.(
                    l.convweight * X * lapl +
                    l.selfweight * X +
                    reduce(hcat, l.bias for i = 1:size(X, 2)),
                ),
                dims = [1, 2],
            ),
        )
    lapl, out_mat
end

# alternate signature so FeaturizedAtoms can be fed into first layer
(l::AGNConv)(a::FeaturizedAtoms{AtomGraph,GraphNodeFeaturization}) =
l(a.atoms.laplacian, a.encoded_features)

# signature to splat appropriately
(l::AGNConv)(t::Tuple{Matrix{R1},Matrix{R2}}) where {R1<:Real,R2<:Real} = l(t...)
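
To make the dimension conventions above concrete, here is a hedged call sketch (not part of this diff): a 3-node path graph with 2 features per node, passed through a freshly constructed layer.

# Hypothetical example, not in the PR:
s = 1 / sqrt(2)
lapl = [1.0 -s 0.0; -s 1.0 -s; 0.0 -s 1.0]   # 3 x 3 normalized Laplacian of the path 1-2-3
X = rand(2, 3)                               # M x N: 2 features x 3 nodes
l = AGNConv(2 => 4)
lapl_out, feat_out = l(lapl, X)              # feat_out is 4 x 3; l((lapl, X)) also works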

# Custom adjoints (fixes from Dhairya) so that backpropagation works through
# sparse-matrix construction and element-type conversion.
@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end
@nograd LinearAlgebra.diagm

@adjoint function Broadcast.broadcasted(Float32, a::SparseMatrixCSC{T,N}) where {T,N}
    Float32.(a), Δ -> (nothing, T.(Δ))
end
@nograd issymmetric

@adjoint function Broadcast.broadcasted(Float64, a::SparseMatrixCSC{T,N}) where {T,N}
    Float64.(a), Δ -> (nothing, T.(Δ))
end

# Analytical adjoint for softplus: d/dx softplus(x) == σ(x), the logistic sigmoid.
@adjoint function softplus(x::Real)
    y = softplus(x)
    return y, Δ -> (Δ * σ(x),)
end
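
The adjoints above patch backpropagation through sparse-matrix construction/conversion and softplus. As a quick, hedged sanity check (not in the PR), reusing `l`, `lapl`, and `X` from the call sketch earlier, gradients of the dense call can be taken directly with Zygote:

# Hypothetical gradient check, not in the PR:
using Zygote
gX, = Zygote.gradient(X -> sum(l(lapl, X)[2]), X)   # gX has the same 2 x 3 shape as X
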
48 changes: 48 additions & 0 deletions src/layers/layers.jl
@@ -0,0 +1,48 @@
module Layers
Member:
I'm not sure this really needs to be a separate module, since it's pretty much the main/only thing the package does apart from the convenience functions for building standard model architectures, and I don't really see a risk of any sort of namespace conflicts...

Member (Author):
I made it a module because, as I was reorganizing the files, I felt this could be organized more coherently if it were all in one place/module, now that we have different types of pooling layers and all that. I'm not really particular about it being a module or not, so whatever works.

Member:
Oh, I'm 💯 fine with the file reorganization; I just don't think we need an actual explicit module.

Member (Author):
Resolved in 77024f0.

#using DifferentialEquations, DiffEqSensitivity

include("conv/agnconv.jl")
include("pool/agnpool.jl")

# The following is commented out for now because it not only runs extremely slowly but also slows down precompilation a lot.
"""
# DEQ-style model where we treat the convolution as a SteadyStateProblem
struct AGNConvDEQ{T,F}
conv::AGNConv{T,F}
end

function AGNConvDEQ(ch::Pair{<:Integer,<:Integer}, σ=softplus; initW=glorot_uniform, initb=glorot_uniform, T::DataType=Float32, bias::Bool=true)
conv = AGNConv(ch, σ; initW=initW, initb=initb, T=T)
AGNConvDEQ(conv)
end

@functor AGNConvDEQ

# set up SteadyStateProblem where the derivative is the convolution operation
# (we want the "fixed point" of the convolution)
# need it in the form f(u,p,t) (but t doesn't matter)
# u is the features, p is the parameters of conv
# re(p) reconstructs the convolution with new parameters p
function (l::AGNConvDEQ)(fa::FeaturizedAtoms)
p,re = Flux.destructure(l.conv)
# do one convolution to get initial guess
guess = l.conv(fa)[2]

f = function (dfeat,feat,p,t)
input = fa
input.encoded_features = reshape(feat,size(guess))
output = re(p)(input)
dfeat .= vec(output[2]) .- vec(input.encoded_features)
end

prob = SteadyStateProblem{true}(f, vec(guess), p)
#return solve(prob, DynamicSS(Tsit5())).u
alg = SSRootfind()
#alg = SSRootfind(nlsolve = (f,u0,abstol) -> (res=SteadyStateDiffEq.NLsolve.nlsolve(f,u0,autodiff=:forward,method=:anderson,iterations=Int(1e6),ftol=abstol);res.zero))
out_mat = reshape(solve(prob, alg, sensealg = SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP())).u,size(guess))
return AtomGraph(gr.graph, gr.elements, out_mat, gr.featurization)
end
"""

end