diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 72704ed47..39ef610b3 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -38,6 +38,64 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T}) g.ndata, g.edata, g.gdata) end +""" + add_self_loops(g::GNNHeteroGraph, edge_t::EType) + +Return a graph with the same features as `g` +but also adding self-loops of the specified type, edge_t + +Nodes with already existing self-loops of type edge_t will obtain a second self-loop of type edge_t. + +If the graphs has edge weights for edges of type edge_t, the new edges will have weight 1. + +If no edges of type edge_t exist, or all existing edges have no weight, then all new self loops will have no weight. +""" +function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V} + function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + get(g.graph, edge_t, (nothing, nothing, nothing))[3] + end + + src_t, _, tgt_t = edge_t + (src_t === tgt_t) || + @error "cannot add a self-loop with different source and target types" + + n = get(g.num_nodes, src_t, 0) + + if haskey(g.graph, edge_t) + x = g.graph[edge_t] + s, t = x[1:2] + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + else + nodes = convert(T, [1:n;]) + s = nodes + t = nodes + end + + graph = g.graph |> copy + ew = get(g.graph, edge_t, (nothing, nothing, nothing))[3] + + if ew !== nothing + ew = [ew; fill!(similar(ew, n), 1)] + end + + graph[edge_t] = (s, t, ew) + edata = g.edata |> copy + ndata = g.ndata |> copy + ntypes = g.ntypes |> copy + etypes = g.etypes |> copy + num_nodes = g.num_nodes |> copy + num_edges = g.num_edges |> copy + num_edges[edge_t] = length(get(graph, edge_t, ([],[]))[1]) + + return GNNHeteroGraph(graph, + num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + ndata, edata, g.gdata, + ntypes, etypes) +end + """ remove_self_loops(g::GNNGraph) diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index 93a97d83b..d56ba5d1c 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -462,4 +462,30 @@ end @test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8] end end + + @testset "add self-loops heterographs" begin + g = rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20)) + # Case in which haskey(g.graph, edge_t) passes + g = add_self_loops(g, (:A, :to1, :A)) + + @test g.num_edges[(:A, :to1, :A)] == 5 + 10 + @test g.num_edges[(:A, :to1, :B)] == 20 + # This test should not use length(keys(g.num_edges)) since that may be undefined behavior + @test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 2 + + # Case in which haskey(g.graph, edge_t) fails + g = add_self_loops(g, (:A, :to3, :A)) + + @test g.num_edges[(:A, :to1, :A)] == 5 + 10 + @test g.num_edges[(:A, :to1, :B)] == 20 + @test g.num_edges[(:A, :to3, :A)] == 10 + @test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 3 + + # Case with edge weights + g = GNNHeteroGraph(Dict((:A, :to1, :A) => ([1, 2, 3], [3, 2, 1], [2, 2, 2]), (:A, :to2, :B) => ([1, 4, 5], [1, 2, 3]))) + n = g.num_nodes[:A] + g = add_self_loops(g, (:A, :to1, :A)) + + @test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n)) + end end \ No newline at end of file