Eigenpooling #90

Draft · wants to merge 6 commits into base: main
4 changes: 2 additions & 2 deletions src/AtomicGraphNets.jl
@@ -1,8 +1,8 @@
module AtomicGraphNets

export AGNConv, AGNPool#, AGNConvDEQ
include("layers.jl")
using .Layers: AGNConv, AGNPool
include("layers/layers.jl")
export AGNConv, AGNPool

include("models.jl")
export build_CGCNN, build_SGCNN
235 changes: 0 additions & 235 deletions src/layers.jl

This file was deleted.

102 changes: 102 additions & 0 deletions src/layers/conv/agnconv.jl
@@ -0,0 +1,102 @@
using Flux
using Flux: glorot_uniform, normalise, @functor#, destructure
using Zygote: @adjoint, @nograd
using LinearAlgebra, SparseArrays
using Statistics
using ChemistryFeaturization
#using DifferentialEquations, DiffEqSensitivity

"""
AGNConv(in=>out)

Atomic graph convolutional layer. Almost identical to `GCNConv` from GeometricFlux, but adapted to match Tian's original AGNN structure more closely: the self and convolutional weights are kept as separate, explicit parameters.

# Fields
- `selfweight::Array{T,2}`: weights applied to features at a node
- `convweight::Array{T,2}`: convolutional weights
- `bias::Array{T,2}`: additive bias (second dimension is always 1 because the bias is learned per feature, not per node)
- `σ::F`: activation function (applied to the outputs before normalisation), defaults to softplus

# Arguments
- `in::Integer`: the dimension of input features.
- `out::Integer`: the dimension of output features.
- `σ=softplus`: activation function
- `initW=glorot_uniform`: initialization function for weights
- `initb=zeros`: initialization function for biases

"""
struct AGNConv{T,F}
selfweight::Array{T,2}
convweight::Array{T,2}
bias::Array{T,2}
σ::F
end

function AGNConv(
ch::Pair{<:Integer,<:Integer},
σ = softplus;
initW = glorot_uniform,
initb = zeros,
T::DataType = Float64,
)
selfweight = T.(initW(ch[2], ch[1]))
convweight = T.(initW(ch[2], ch[1]))
b = T.(initb(ch[2], 1))
AGNConv(selfweight, convweight, b, σ)
end

@functor AGNConv

"""
Action of the layer on its inputs: compute a graph convolution of the features (weighted by the convolutional weights), add the features themselves (weighted by the self weights) and the per-feature bias (repeated across columns to match the number of nodes in the graph), then apply the activation elementwise and normalise the result.

# Arguments
- input: a `FeaturizedAtoms` object, or a pair of matrices `(lapl, X)` (graph Laplacian and encoded features)

# Note
In the case of providing two matrices, the following conditions must hold:
- `lapl` must be square and of dimension N x N where N is the number of nodes in the graph
- `X` (encoded features) must be of dimension M x N, where M is `size(l.convweight)[2]` (or equivalently, `size(l.selfweight)[2]`)
"""
function (l::AGNConv{T,F})(lapl::Matrix{<:Real}, X::Matrix{<:Real}) where {T<:Real,F}
# should we put dimension checks here? Could allow more informative errors, but would likely introduce performance penalty. For now it's just in docstring.
out_mat =
T.(
normalise(
l.σ.(
l.convweight * X * lapl +
l.selfweight * X +
reduce(hcat, l.bias for i = 1:size(X, 2)),
),
dims = [1, 2],
),
)
lapl, out_mat
end

# alternate signature so FeaturizedAtoms can be fed into first layer
(l::AGNConv)(a::FeaturizedAtoms{AtomGraph,GraphNodeFeaturization}) =
l(a.atoms.laplacian, a.encoded_features)

# signature so the (laplacian, features) tuple returned by a previous layer can be splatted into the two-matrix method above
(l::AGNConv)(t::Tuple{Matrix{R1},Matrix{R2}}) where {R1<:Real,R2<:Real} = l(t...)

# custom adjoints (fixes from Dhairya) so backprop works through sparse-matrix construction and broadcasting
@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end
@nograd LinearAlgebra.diagm

@adjoint function Broadcast.broadcasted(Float32, a::SparseMatrixCSC{T,N}) where {T,N}
Float32.(a), Δ -> (nothing, T.(Δ))
end
@nograd issymmetric

@adjoint function Broadcast.broadcasted(Float64, a::SparseMatrixCSC{T,N}) where {T,N}
Float64.(a), Δ -> (nothing, T.(Δ))
end

@adjoint function softplus(x::Real)
y = softplus(x)
return y, Δ -> (Δ * σ(x),)
end
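
For orientation, a minimal usage sketch of the layer defined above. It assumes the package is loaded as `AtomicGraphNets`; the 6-node ring graph, the 40 => 20 feature widths, and the random features are illustrative stand-ins (in a real workflow the first layer would instead be fed a `FeaturizedAtoms` object built with ChemistryFeaturization, per the alternate signature):

using AtomicGraphNets: AGNConv

N = 6                          # number of atoms (nodes) in the toy graph
in_fea, out_fea = 40, 20       # input / output feature dimensions (arbitrary choices)

# normalized graph Laplacian of an N-node ring graph (N x N, symmetric)
lapl = zeros(Float64, N, N)
for i in 1:N
    j = mod1(i + 1, N)
    lapl[i, i] = 1.0
    lapl[i, j] = lapl[j, i] = -0.5
end

X = rand(Float64, in_fea, N)           # encoded features, M x N as in the docstring

l1 = AGNConv(in_fea => out_fea)        # σ = softplus by default
lapl_out, X_out = l1(lapl, X)          # Laplacian passes through unchanged; X_out is out_fea x N

# layers chain by handing the (laplacian, features) tuple along
l2 = AGNConv(out_fea => out_fea)
lapl2, X2 = l2(l1(lapl, X))            # the tuple method splats into the two-matrix method
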
44 changes: 44 additions & 0 deletions src/layers/layers.jl
@@ -0,0 +1,44 @@
#using DifferentialEquations, DiffEqSensitivity

include("conv/agnconv.jl")
include("pool/agnpool.jl")

# The following is commented out for now: it not only runs extremely slowly, it also slows down precompilation a lot.
"""
# DEQ-style model where we treat the convolution as a SteadyStateProblem
struct AGNConvDEQ{T,F}
conv::AGNConv{T,F}
end

function AGNConvDEQ(ch::Pair{<:Integer,<:Integer}, σ=softplus; initW=glorot_uniform, initb=glorot_uniform, T::DataType=Float32, bias::Bool=true)
conv = AGNConv(ch, σ; initW=initW, initb=initb, T=T)
AGNConvDEQ(conv)
end

@functor AGNConvDEQ

# set up SteadyStateProblem where the derivative is the convolution operation
# (we want the "fixed point" of the convolution)
# need it in the form f(u,p,t) (but t doesn't matter)
# u is the features, p is the parameters of conv
# re(p) reconstructs the convolution with new parameters p
function (l::AGNConvDEQ)(fa::FeaturizedAtoms)
p,re = Flux.destructure(l.conv)
# do one convolution to get initial guess
guess = l.conv(fa)[2]

f = function (dfeat,feat,p,t)
input = fa
input.encoded_features = reshape(feat,size(guess))
output = re(p)(input)
dfeat .= vec(output[2]) .- vec(input.encoded_features)
end

prob = SteadyStateProblem{true}(f, vec(guess), p)
#return solve(prob, DynamicSS(Tsit5())).u
alg = SSRootfind()
#alg = SSRootfind(nlsolve = (f,u0,abstol) -> (res=SteadyStateDiffEq.NLsolve.nlsolve(f,u0,autodiff=:forward,method=:anderson,iterations=Int(1e6),ftol=abstol);res.zero))
out_mat = reshape(solve(prob, alg, sensealg = SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP())).u,size(guess))
return AtomGraph(fa.graph, fa.elements, out_mat, fa.featurization)
end
"""