Skip to content

Commit

Permalink
Custom Norm Func for GCNConv
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Jan 1, 2024
1 parent fc97137 commit 094ee01
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
16 changes: 9 additions & 7 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ and optionally an edge weight vector.
# Forward
(::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing) -> AbstractMatrix
(::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing, normalization_fn::Function = (d) -> 1 ./ sqrt.(d)) -> AbstractMatrix
Takes as input a graph `g`,ca node feature matrix `x` of size `[in, num_nodes]`,
and optionally an edge weight vector. Returns a node feature matrix of size
Expand All @@ -53,9 +53,10 @@ l = GCNConv(3 => 5)
# forward pass
y = l(g, x) # size: 5 × num_nodes
# convolution with edge weights
# convolution with edge weights and custom normalization function
w = [1.1, 0.1, 2.3, 0.5]
y = l(g, x, w)
custom_norm_fn(d) = 1 ./ sqrt.(d + 1) # Custom normalization function
y = l(g, x, w, custom_norm_fn)
# Edge weights can also be embedded in the graph.
g = GNNGraph(s, t, w)
Expand Down Expand Up @@ -98,7 +99,8 @@ check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing

function (l::GCNConv)(g::GNNGraph,
x::AbstractMatrix{T},
edge_weight::EW = nothing
edge_weight::EW = nothing,
normalization_fn::Function = (d) -> 1 ./ sqrt.(d)
) where {T, EW <: Union{Nothing, AbstractVector}}

check_gcnconv_input(g, edge_weight)
Expand All @@ -122,7 +124,7 @@ function (l::GCNConv)(g::GNNGraph,
else
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
end
c = 1 ./ sqrt.(d)
c = normalization_fn(d)
x = x .* c'
if edge_weight !== nothing
x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
Expand All @@ -139,9 +141,9 @@ function (l::GCNConv)(g::GNNGraph,
end

function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
edge_weight::AbstractVector)
edge_weight::AbstractVector, normalization_fn::Function)
g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
return l(g, x, edge_weight)
return l(g, x, edge_weight, normalization_fn)
end

function Base.show(io::IO, l::GCNConv)
Expand Down
5 changes: 3 additions & 2 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ test_graphs = [g1, g_single_vertex]
l = GCNConv(in_channel => out_channel, add_self_loops = false)
test_layer(l, g1, rtol = RTOL_HIGH, outsize = (out_channel, g1.num_nodes))

@testset "edge weights" begin
@testset "edge weights & custom normalization" begin
s = [2, 3, 1, 3, 1, 2]
t = [1, 1, 2, 2, 3, 3]
w = T[1, 2, 3, 4, 5, 6]
g = GNNGraph((s, t, w), ndata = ones(T, 1, 3), graph_type = GRAPH_T)
x = g.ndata.x
custom_norm_fn(d) = 1 ./ sqrt.(d)
l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true)
l.weight .= 1
d = degree(g, dir = :in, edge_weight = true)
y = l(g, x)
@test y[1, 1] w[1] / (d[1] * d[2]) + w[2] / (d[1] * d[3])
@test y[1, 2] w[3] / (d[2] * d[1]) + w[4] / (d[2] * d[3])
@test y l(g, x, w)
@test y l(g, x, w, custom_norm_fn)

# test gradient with respect to edge weights
w = rand(T, 6)
Expand Down

0 comments on commit 094ee01

Please sign in to comment.