Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 1, 2024
1 parent 3b0b083 commit e79edc6
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
10 changes: 3 additions & 7 deletions GNNlib/test/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ end

@testset "copy_xj +" begin
for g in TEST_GRAPHS
dev = gpu_device(force=true)
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
broken = get_graph_type(g) == :sparse && gpu_backend() == "AMDGPU"
f(g, x) = propagate(copy_xj, g, +, xj = x)
@test test_gradients(
f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false
Expand All @@ -179,8 +178,7 @@ end

@testset "copy_xj mean" begin
for g in TEST_GRAPHS
dev = gpu_device(force=true)
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
broken = get_graph_type(g) == :sparse && gpu_backend() == "AMDGPU"
f(g, x) = propagate(copy_xj, g, mean, xj = x)
@test test_gradients(
f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false
Expand All @@ -190,8 +188,7 @@ end

@testset "e_mul_xj +" begin
for g in TEST_GRAPHS
dev = gpu_device(force=true)
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
broken = get_graph_type(g) == :sparse && gpu_backend() == "AMDGPU"
e = rand(Float32, size(g.x, 1), g.num_edges)
f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e)
@test test_gradients(
Expand All @@ -207,7 +204,6 @@ end
g = set_edge_weight(g, w)
return propagate(w_mul_xj, g, +, xj = x)
end
dev = gpu_device(force=true)
# @show get_graph_type(g) has_isolated_nodes(g)
# broken = get_graph_type(g) == :sparse
broken = true
Expand Down
16 changes: 15 additions & 1 deletion GNNlib/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using Flux: Flux
# from this module
export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
test_gradients, finitediff_withgradient,
check_equal_leaves
check_equal_leaves, gpu_backend


const D_IN = 3
Expand Down Expand Up @@ -177,4 +177,18 @@ TEST_GRAPHS = [generate_test_graphs(:coo)...,
generate_test_graphs(:dense)...,
generate_test_graphs(:sparse)...]


function gpu_backend()
dev = gpu_device()
if dev isa CUDADevice
return "CUDA"
elseif dev isa AMDGPUDevice
return "AMDGPU"
elseif dev isa MetalDevice
return "Metal"
else
return "Unknown"
end
end

end # module
2 changes: 1 addition & 1 deletion GraphNeuralNetworks/test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
l = ChebConv(D_IN => D_OUT, k)
for g in TEST_GRAPHS
has_isolated_nodes(g) && continue
broken = get_graph_type(g) == :sparse || gpu_device() isa AMDGPUDevice
broken = get_graph_type(g) == :sparse || gpu_backend() == "AMDGPU"
@test size(l(g, g.x)) == (D_OUT, g.num_nodes) broken=broken
@test test_gradients(
l, g, g.x, rtol = RTOL_LOW, test_gpu = true, compare_finite_diff = false
Expand Down
26 changes: 23 additions & 3 deletions GraphNeuralNetworks/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,15 @@ using SparseArrays
# from Base
export mean, randn, SparseArrays, AbstractSparseMatrix

# from other packages
export Flux, gradient, Dense, Chain, relu, random_regular_graph, erdos_renyi,
BatchNorm, LayerNorm, Dropout, Parallel
# from Flux.jl
export Flux, gradient, Dense, Chain, relu
BatchNorm, LayerNorm, Dropout, Parallel,
gpu_device, cpu_device, get_device,
CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,
gpu_backend

# from Graphs.jl
export random_regular_graph, erdos_renyi

# from this module
export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
Expand Down Expand Up @@ -178,5 +184,19 @@ TEST_GRAPHS = [generate_test_graphs(:coo)...,
generate_test_graphs(:dense)...,
generate_test_graphs(:sparse)...]


function gpu_backend()
dev = gpu_device()
if dev isa CUDADevice
return "CUDA"
elseif dev isa AMDGPUDevice
return "AMDGPU"
elseif dev isa MetalDevice
return "Metal"
else
return "Unknown"
end
end

end # testmodule

0 comments on commit e79edc6

Please sign in to comment.