Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add heterogeneous add_self_loop support #345

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,62 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T})
g.ndata, g.edata, g.gdata)
end

"""
add_self_loops(g::GNNHeteroGraph, edge_t::EType)

Return a graph with the same features as `g`
but also adding self-loops of the specified type, edge_t

Nodes with already existing self-loops of type edge_t will obtain a second self-loop of type edge_t.

If the graphs has edge weights for edges of type edge_t, the new edges will have weight 1.
"""
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
end

src_t, _, tgt_t = edge_t
(src_t === tgt_t) ||
@error "cannot add a self-loop with different source and target types"

n = get(g.num_nodes, src_t, 0)

# By avoiding using haskey, this only calls ht_keyindex once instead of twice
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# By avoiding using haskey, this only calls ht_keyindex once instead of twice

These kinds of performance concerns for dictionary queries are irrelevant. The heavy operations are the copies and the concatenations. Therefore the concern here should be to make the code has readable as possible, which means using haskey in my opinion.

if (x = get(g.graph, edge_t, nothing)) !== nothing
s, t = x[1:2]
nodes = convert(typeof(s), [1:n;])
s = [s; nodes]
t = [t; nodes]
else
nodes = [1:n;]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates a cpu array even if the rest of graph lives on gpu. We should create a new array similar to one in the existing relations.

s = nodes
t = nodes
end

graph = g.graph |> copy
ew = get_edge_weight_nullable(g, edge_t)

if ew !== nothing
ew = [ew; fill!(similar(ew, n), 1)]
end

graph[edge_t] = (s, t, ew)
edata = g.edata |> copy
ndata = g.ndata |> copy
ntypes = g.ntypes |> copy
etypes = g.etypes |> copy
num_nodes = g.num_nodes |> copy
num_edges = g.num_edges |> copy
num_edges[edge_t] = length(get(graph, edge_t, ([],[]))[1])

return GNNHeteroGraph(graph,
num_nodes, num_edges, g.num_graphs,
g.graph_indicator,
ndata, edata, g.gdata,
ntypes, etypes)
end

"""
remove_self_loops(g::GNNGraph)

Expand Down
17 changes: 17 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,4 +462,21 @@ end
@test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8]
end
end

@testset "add self-loops heterographs" begin
g = rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20))
g = add_self_loops(g, (:A, :to1, :A))

@test g.num_edges[(:A, :to1, :A)] == 5 + 10
@test g.num_edges[(:A, :to1, :B)] == 20

# This test should not use length(keys(g.num_edges)) since that may be undefined behavior
@test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 2

g = GNNHeteroGraph(Dict((:A, :to1, :A) => ([1, 2, 3], [3, 2, 1], [2, 2, 2]), (:A, :to2, :B) => ([1, 4, 5], [1, 2, 3])))
n = g.num_nodes[:A]
g = add_self_loops(g, (:A, :to1, :A))

@test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n))
end
end