From 0154bb1a69d4f3c9390ccc58e31b9fa28ce6da6a Mon Sep 17 00:00:00 2001 From: Charlie Kawczynski Date: Fri, 14 Feb 2025 11:04:06 -0800 Subject: [PATCH] Fix adapt for vert topo and Topology2D --- src/Topologies/interval.jl | 2 ++ src/Topologies/topology2d.jl | 39 +++++++++++++++++++++ test/Fields/unit_field.jl | 66 +++++++++++++++++++++++++----------- 3 files changed, 87 insertions(+), 20 deletions(-) diff --git a/src/Topologies/interval.jl b/src/Topologies/interval.jl index f1084e0d31..72c584d70a 100644 --- a/src/Topologies/interval.jl +++ b/src/Topologies/interval.jl @@ -16,6 +16,8 @@ struct IntervalTopology{ boundaries::B end +Adapt.@adapt_structure IntervalTopology + ## gpu struct DeviceIntervalTopology{B} <: AbstractIntervalTopology boundaries::B diff --git a/src/Topologies/topology2d.jl b/src/Topologies/topology2d.jl index 607a4e40d1..3b5b4bb369 100644 --- a/src/Topologies/topology2d.jl +++ b/src/Topologies/topology2d.jl @@ -124,6 +124,45 @@ mutable struct Topology2D{ ghost_face_neighbor_loc::Vector{Int} end +function Adapt.adapt_structure(to, topo::Topology2D) + return Topology2D( + Adapt.adapt(to, topo.context), + Adapt.adapt(to, topo.mesh), + Adapt.adapt(to, topo.elemorder), + Adapt.adapt(to, topo.orderindex), + topo.elempid, + topo.local_elem_gidx, + topo.neighbor_pids, + topo.send_elem_lidx, + topo.send_elem_lengths, + topo.recv_elem_gidx, + topo.recv_elem_lengths, + Adapt.adapt(to, topo.interior_faces), + Adapt.adapt(to, topo.ghost_faces), + Adapt.adapt(to, topo.local_vertices), + Adapt.adapt(to, topo.local_vertex_offset), + Adapt.adapt(to, topo.ghost_vertices), + Adapt.adapt(to, topo.ghost_vertex_offset), + Adapt.adapt(to, topo.local_neighbor_elem), + Adapt.adapt(to, topo.local_neighbor_elem_offset), + topo.ghost_neighbor_elem, + topo.ghost_neighbor_elem_offset, + Adapt.adapt(to, topo.boundaries), + topo.internal_elems, + topo.perimeter_elems, + topo.nglobalvertices, + topo.nglobalfaces, + topo.ghost_vertex_gcidx, + topo.ghost_face_gcidx, + topo.comm_vertex_lengths, + topo.comm_face_lengths, + topo.ghost_vertex_neighbor_loc, + topo.ghost_vertex_comm_idx_offset, + Adapt.adapt(to, topo.repr_ghost_vertex), + topo.ghost_face_neighbor_loc, + ) +end + ClimaComms.device(topology::Topology2D) = ClimaComms.device(topology.context) ClimaComms.array_type(topology::Topology2D) = ClimaComms.array_type(topology.context.device) diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index f05afb048f..fa0492f4a9 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -706,7 +706,24 @@ using ClimaCore.CommonSpaces using ClimaCore.Grids using Adapt -function test_adapt(cpu_space_in) +function test_adapt_types(space_fn; broken_space_type_match = false) + @static if ClimaComms.device() isa ClimaComms.CUDADevice + cpu_space = space_fn(ClimaComms.CPUSingleThreaded()) + gpu_space = space_fn(ClimaComms.CUDADevice()) + FT = Spaces.undertype(cpu_space) + f_cpu = Fields.Field(FT, cpu_space) + f_gpu = Fields.Field(FT, gpu_space) + f_cpu_from_gpu = + ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), f_gpu) + @test typeof(Fields.field_values(f_cpu_from_gpu)) == + typeof(Fields.field_values(f_cpu)) + @test typeof(axes(f_cpu_from_gpu)) == typeof(axes(f_cpu)) broken = + broken_space_type_match + end +end + +function test_adapt(space_fn) + cpu_space_in = space_fn(ClimaComms.CPUSingleThreaded()) test_adapt_space(cpu_space_in) cpu_f_in = Fields.Field(Float64, cpu_space_in) cpu_f_out = Adapt.adapt(Array, cpu_f_in) @@ -730,7 +747,10 @@ function test_adapt(cpu_space_in) # cpu -> gpu gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in) @test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray - # gpu -> gpu + @test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice + @test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray + + # gpu -> cpu cpu_f_out = ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out) @test parent(Fields.field_values(cpu_f_out)) isa Array @@ -772,8 +792,8 @@ function test_adapt_space(cpu_space_in) end @testset "Test Adapt" begin - space = ExtrudedCubedSphereSpace(; - device = ClimaComms.CPUSingleThreaded(), + ecs_space_fn(dev) = ExtrudedCubedSphereSpace(; + device = dev, z_elem = 10, z_min = 0, z_max = 1, @@ -782,27 +802,30 @@ end n_quad_points = 4, staggering = Grids.CellCenter(), ) - test_adapt(space) + test_adapt(ecs_space_fn) + test_adapt_types(ecs_space_fn; broken_space_type_match = true) - space = CubedSphereSpace(; - device = ClimaComms.CPUSingleThreaded(), + cs_space_fn(dev) = CubedSphereSpace(; + device = dev, radius = 10, n_quad_points = 4, h_elem = 10, ) - test_adapt(space) + test_adapt(cs_space_fn) + test_adapt_types(cs_space_fn; broken_space_type_match = true) - space = ColumnSpace(; - device = ClimaComms.CPUSingleThreaded(), + column_space_fn(dev) = ColumnSpace(; + device = dev, z_elem = 10, z_min = 0, z_max = 10, staggering = CellCenter(), ) - test_adapt(space) + test_adapt(column_space_fn) + test_adapt_types(column_space_fn) - space = Box3DSpace(; - device = ClimaComms.CPUSingleThreaded(), + box_space_fn(dev) = Box3DSpace(; + device = dev, z_elem = 10, x_min = 0, x_max = 1, @@ -817,10 +840,11 @@ end y_elem = 4, staggering = CellCenter(), ) - test_adapt(space) + test_adapt(box_space_fn) + test_adapt_types(box_space_fn; broken_space_type_match = true) - space = SliceXZSpace(; - device = ClimaComms.CPUSingleThreaded(), + slice_space_fn(dev) = SliceXZSpace(; + device = dev, z_elem = 10, x_min = 0, x_max = 1, @@ -831,10 +855,11 @@ end x_elem = 4, staggering = CellCenter(), ) - test_adapt(space) + test_adapt(slice_space_fn) + # test_adapt_types(slice_space_fn) # not yet supported on gpus - space = RectangleXYSpace(; - device = ClimaComms.CPUSingleThreaded(), + rect_space_fn(dev) = RectangleXYSpace(; + device = dev, x_min = 0, x_max = 1, y_min = 0, @@ -845,7 +870,8 @@ end x_elem = 3, y_elem = 4, ) - test_adapt(space) + test_adapt(rect_space_fn) + test_adapt_types(rect_space_fn; broken_space_type_match = true) # FieldVector cspace = ExtrudedCubedSphereSpace(;