Skip to content

Commit

Permalink
Fix adapt for vert topo and Topology2D
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie Kawczynski committed Feb 14, 2025
1 parent c19162a commit 85eea25
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 20 deletions.
2 changes: 2 additions & 0 deletions src/Topologies/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct IntervalTopology{
boundaries::B
end

Adapt.@adapt_structure IntervalTopology

## gpu
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
boundaries::B
Expand Down
39 changes: 39 additions & 0 deletions src/Topologies/topology2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,45 @@ mutable struct Topology2D{
ghost_face_neighbor_loc::Vector{Int}
end

function Adapt.adapt_structure(to, topo::Topology2D)
return Topology2D(
Adapt.adapt(to, topo.context),
Adapt.adapt(to, topo.mesh),
Adapt.adapt(to, topo.elemorder),
Adapt.adapt(to, topo.orderindex),
topo.elempid,
topo.local_elem_gidx,
topo.neighbor_pids,
topo.send_elem_lidx,
topo.send_elem_lengths,
topo.recv_elem_gidx,
topo.recv_elem_lengths,
Adapt.adapt(to, topo.interior_faces),
Adapt.adapt(to, topo.ghost_faces),
Adapt.adapt(to, topo.local_vertices),
Adapt.adapt(to, topo.local_vertex_offset),
Adapt.adapt(to, topo.ghost_vertices),
Adapt.adapt(to, topo.ghost_vertex_offset),
Adapt.adapt(to, topo.local_neighbor_elem),
Adapt.adapt(to, topo.local_neighbor_elem_offset),
topo.ghost_neighbor_elem,
topo.ghost_neighbor_elem_offset,
Adapt.adapt(to, topo.boundaries),
topo.internal_elems,
topo.perimeter_elems,
topo.nglobalvertices,
topo.nglobalfaces,
topo.ghost_vertex_gcidx,
topo.ghost_face_gcidx,
topo.comm_vertex_lengths,
topo.comm_face_lengths,
topo.ghost_vertex_neighbor_loc,
topo.ghost_vertex_comm_idx_offset,
Adapt.adapt(to, topo.repr_ghost_vertex),
topo.ghost_face_neighbor_loc,
)
end

ClimaComms.device(topology::Topology2D) = ClimaComms.device(topology.context)
ClimaComms.array_type(topology::Topology2D) =
ClimaComms.array_type(topology.context.device)
Expand Down
66 changes: 46 additions & 20 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,24 @@ using ClimaCore.CommonSpaces
using ClimaCore.Grids
using Adapt

function test_adapt(cpu_space_in)
function test_adapt_types(space_fn; broken_space_type_match = false)
@static if ClimaComms.device() isa ClimaComms.CUDADevice
cpu_space = space_fn(ClimaComms.CPUSingleThreaded())
gpu_space = space_fn(ClimaComms.CUDADevice())
FT = Spaces.undertype(cpu_space)
f_cpu = Fields.Field(FT, cpu_space)
f_gpu = Fields.Field(FT, gpu_space)
f_cpu_from_gpu =
ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), f_gpu)
@test typeof(Fields.field_values(f_cpu_from_gpu)) ==
typeof(Fields.field_values(f_cpu))
@test typeof(axes(f_cpu_from_gpu)) == typeof(axes(f_cpu)) broken =
broken_space_type_match
end
end

function test_adapt(space_fn)
cpu_space_in = space_fn(ClimaComms.CPUSingleThreaded())
test_adapt_space(cpu_space_in)
cpu_f_in = Fields.Field(Float64, cpu_space_in)
cpu_f_out = Adapt.adapt(Array, cpu_f_in)
Expand All @@ -730,7 +747,10 @@ function test_adapt(cpu_space_in)
# cpu -> gpu
gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in)
@test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray
# gpu -> gpu
@test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice
@test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray

# gpu -> cpu
cpu_f_out =
ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out)
@test parent(Fields.field_values(cpu_f_out)) isa Array
Expand Down Expand Up @@ -772,8 +792,8 @@ function test_adapt_space(cpu_space_in)
end

@testset "Test Adapt" begin
space = ExtrudedCubedSphereSpace(;
device = ClimaComms.CPUSingleThreaded(),
ecs_space_fn(dev) = ExtrudedCubedSphereSpace(;
device = dev,
z_elem = 10,
z_min = 0,
z_max = 1,
Expand All @@ -782,27 +802,30 @@ end
n_quad_points = 4,
staggering = Grids.CellCenter(),
)
test_adapt(space)
test_adapt(ecs_space_fn)
test_adapt_types(ecs_space_fn; broken_space_type_match = true)

space = CubedSphereSpace(;
device = ClimaComms.CPUSingleThreaded(),
cs_space_fn(dev) = CubedSphereSpace(;
device = dev,
radius = 10,
n_quad_points = 4,
h_elem = 10,
)
test_adapt(space)
test_adapt(cs_space_fn)
test_adapt_types(cs_space_fn; broken_space_type_match = true)

space = ColumnSpace(;
device = ClimaComms.CPUSingleThreaded(),
column_space_fn(dev) = ColumnSpace(;
device = dev,
z_elem = 10,
z_min = 0,
z_max = 10,
staggering = CellCenter(),
)
test_adapt(space)
test_adapt(column_space_fn)
test_adapt_types(column_space_fn)

space = Box3DSpace(;
device = ClimaComms.CPUSingleThreaded(),
box_space_fn(dev) = Box3DSpace(;
device = dev,
z_elem = 10,
x_min = 0,
x_max = 1,
Expand All @@ -817,10 +840,11 @@ end
y_elem = 4,
staggering = CellCenter(),
)
test_adapt(space)
test_adapt(box_space_fn)
test_adapt_types(box_space_fn; broken_space_type_match = true)

space = SliceXZSpace(;
device = ClimaComms.CPUSingleThreaded(),
slice_space_fn(dev) = SliceXZSpace(;
device = dev,
z_elem = 10,
x_min = 0,
x_max = 1,
Expand All @@ -831,10 +855,11 @@ end
x_elem = 4,
staggering = CellCenter(),
)
test_adapt(space)
test_adapt(slice_space_fn)
# test_adapt_types(slice_space_fn) # not yet supported on gpus

space = RectangleXYSpace(;
device = ClimaComms.CPUSingleThreaded(),
rect_space_fn(dev) = RectangleXYSpace(;
device = dev,
x_min = 0,
x_max = 1,
y_min = 0,
Expand All @@ -845,7 +870,8 @@ end
x_elem = 3,
y_elem = 4,
)
test_adapt(space)
test_adapt(rect_space_fn)
test_adapt_types(rect_space_fn; broken_space_type_match = true)

# FieldVector
cspace = ExtrudedCubedSphereSpace(;
Expand Down

0 comments on commit 85eea25

Please sign in to comment.