Skip to content

Commit

Permalink
Return expected adapt devices for spaces with vertical topologies
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie Kawczynski committed Feb 14, 2025
1 parent c19162a commit 5ee41b7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/Grids/Grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ function vertical_topology end

ClimaComms.context(grid::AbstractGrid) = ClimaComms.context(topology(grid))
ClimaComms.device(grid::AbstractGrid) = ClimaComms.device(topology(grid))
ClimaComms.device(grid::ExtrudedFiniteDifferenceGrid) =
ClimaComms.device(vertical_topology(grid))

Meshes.domain(grid::AbstractGrid) = Meshes.domain(topology(grid))

Expand Down
2 changes: 2 additions & 0 deletions src/Topologies/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct IntervalTopology{
boundaries::B
end

Adapt.@adapt_structure IntervalTopology

## gpu
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
boundaries::B
Expand Down
11 changes: 8 additions & 3 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ using ClimaCore.CommonSpaces
using ClimaCore.Grids
using Adapt

function test_adapt(cpu_space_in)
function test_adapt(cpu_space_in; broken_device_adapt = false)
test_adapt_space(cpu_space_in)
cpu_f_in = Fields.Field(Float64, cpu_space_in)
cpu_f_out = Adapt.adapt(Array, cpu_f_in)
Expand All @@ -730,6 +730,11 @@ function test_adapt(cpu_space_in)
# cpu -> gpu
gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in)
@test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray
@test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice broken =
broken_device_adapt
@test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray broken =
broken_device_adapt

# gpu -> gpu
cpu_f_out =
ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out)
Expand Down Expand Up @@ -790,7 +795,7 @@ end
n_quad_points = 4,
h_elem = 10,
)
test_adapt(space)
test_adapt(space; broken_device_adapt = true)

space = ColumnSpace(;
device = ClimaComms.CPUSingleThreaded(),
Expand Down Expand Up @@ -845,7 +850,7 @@ end
x_elem = 3,
y_elem = 4,
)
test_adapt(space)
test_adapt(space; broken_device_adapt = true)

# FieldVector
cspace = ExtrudedCubedSphereSpace(;
Expand Down

0 comments on commit 5ee41b7

Please sign in to comment.