Skip to content

Commit

Permalink
Merge pull request #112 from yuehhua/develop
Browse files Browse the repository at this point in the history
FeaturedSubgraph supports DataLoader
  • Loading branch information
yuehhua authored Aug 2, 2022
2 parents 98ecaee + ab58d3d commit de22869
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function MLUtils.numobs(fg::FeaturedGraph)
function MLUtils.numobs(fg::AbstractFeaturedGraph)
obs_size = 0
if has_node_feature(fg)
nf_obs_size = numobs(node_feature(fg))
Expand Down Expand Up @@ -27,10 +27,10 @@ function check_obs_size(obs_size, feat_obs_size, feat::String)
return feat_obs_size
end

function MLUtils.getobs(fg::FeaturedGraph, idx)
function MLUtils.getobs(fg::AbstractFeaturedGraph, idx)
nf = has_node_feature(fg) ? getobs(node_feature(fg), idx) : node_feature(fg)
ef = has_edge_feature(fg) ? getobs(edge_feature(fg), idx) : edge_feature(fg)
gf = has_global_feature(fg) ? getobs(global_feature(fg), idx) : global_feature(fg)
pf = has_positional_feature(fg) ? getobs(positional_feature(fg), idx) : positional_feature(fg)
return FeaturedGraph(fg, nf=nf, ef=ef, gf=gf, pf=pf)
return ConcreteFeaturedGraph(fg, nf=nf, ef=ef, gf=gf, pf=pf)
end
6 changes: 6 additions & 0 deletions src/subgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,18 @@ end

Graphs.adjacency_matrix(fsg::FeaturedSubgraph) = view(adjacency_matrix(fsg.fg), fsg.nodes, fsg.nodes)

has_node_feature(fsg::FeaturedSubgraph) = has_node_feature(fsg.fg)
node_feature(fsg::FeaturedSubgraph) = node_feature(fsg.fg)

has_edge_feature(fsg::FeaturedSubgraph) = has_edge_feature(fsg.fg)
edge_feature(fsg::FeaturedSubgraph) = edge_feature(fsg.fg)

has_global_feature(fsg::FeaturedSubgraph) = has_global_feature(fsg.fg)
global_feature(fsg::FeaturedSubgraph) = global_feature(fsg.fg)

has_positional_feature(fsg::FeaturedSubgraph) = has_positional_feature(fsg.fg)
positional_feature(fsg::FeaturedSubgraph) = positional_feature(fsg.fg)

Graphs.is_directed(fsg::FeaturedSubgraph) = is_directed(fsg.fg)

Graphs.neighbors(fsg::FeaturedSubgraph) = mapreduce(i -> neighbors(graph(fsg), i), vcat, fsg.nodes)
Expand Down
13 changes: 13 additions & 0 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@
@test global_feature(idxed_fg) == global_feature(fg)
@test positional_feature(idxed_fg) == positional_feature(fg)
end

fg = subgraph(FeaturedGraph(adjm; nf=nf, ef=ef), [1, 3, 4, 5])
@test numobs(fg) == obs_size
@test getobs(fg) == fg

for idx in (2, 2:5, [1, 3, 5])
idxed_fg = getobs(fg, idx)
@test graph(idxed_fg) == graph(fg)
@test node_feature(idxed_fg) == node_feature(fg)[:, :, idx]
@test edge_feature(idxed_fg) == edge_feature(fg)[:, :, idx]
@test global_feature(idxed_fg) == global_feature(fg)
@test positional_feature(idxed_fg) == positional_feature(fg)
end
end

@testset "shuffleobs" begin
Expand Down

0 comments on commit de22869

Please sign in to comment.