From 5ee41b79e27fb0b501e03f5240fc8835e42c27ea Mon Sep 17 00:00:00 2001 From: Charlie Kawczynski Date: Fri, 14 Feb 2025 11:04:06 -0800 Subject: [PATCH] Return expected adapt devices for spaces with vertical topologies --- src/Grids/Grids.jl | 2 ++ src/Topologies/interval.jl | 2 ++ test/Fields/unit_field.jl | 11 ++++++++--- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/Grids/Grids.jl b/src/Grids/Grids.jl index 400aa269fb..9662dc4432 100644 --- a/src/Grids/Grids.jl +++ b/src/Grids/Grids.jl @@ -68,6 +68,8 @@ function vertical_topology end ClimaComms.context(grid::AbstractGrid) = ClimaComms.context(topology(grid)) ClimaComms.device(grid::AbstractGrid) = ClimaComms.device(topology(grid)) +ClimaComms.device(grid::ExtrudedFiniteDifferenceGrid) = + ClimaComms.device(vertical_topology(grid)) Meshes.domain(grid::AbstractGrid) = Meshes.domain(topology(grid)) diff --git a/src/Topologies/interval.jl b/src/Topologies/interval.jl index f1084e0d31..72c584d70a 100644 --- a/src/Topologies/interval.jl +++ b/src/Topologies/interval.jl @@ -16,6 +16,8 @@ struct IntervalTopology{ boundaries::B end +Adapt.@adapt_structure IntervalTopology + ## gpu struct DeviceIntervalTopology{B} <: AbstractIntervalTopology boundaries::B diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index f05afb048f..897950d2d0 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -706,7 +706,7 @@ using ClimaCore.CommonSpaces using ClimaCore.Grids using Adapt -function test_adapt(cpu_space_in) +function test_adapt(cpu_space_in; broken_device_adapt = false) test_adapt_space(cpu_space_in) cpu_f_in = Fields.Field(Float64, cpu_space_in) cpu_f_out = Adapt.adapt(Array, cpu_f_in) @@ -730,6 +730,11 @@ function test_adapt(cpu_space_in) # cpu -> gpu gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in) @test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray + @test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice broken = + broken_device_adapt + @test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray broken = + broken_device_adapt + # gpu -> gpu cpu_f_out = ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out) @@ -790,7 +795,7 @@ end n_quad_points = 4, h_elem = 10, ) - test_adapt(space) + test_adapt(space; broken_device_adapt = true) space = ColumnSpace(; device = ClimaComms.CPUSingleThreaded(), @@ -845,7 +850,7 @@ end x_elem = 3, y_elem = 4, ) - test_adapt(space) + test_adapt(space; broken_device_adapt = true) # FieldVector cspace = ExtrudedCubedSphereSpace(;