diff --git a/NEWS.md b/NEWS.md index 2f7c4129..199f1ee6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,19 @@ ClimaUtilities.jl Release Notes main ------ +v0.1.21 +------ + +#### Simplified `TimeVaryingInput0D`, removed `context` argument. PR [#148](https://github.com/CliMA/ClimaUtilities.jl/pull/148) + +The need for a dedicated CUDA extension was by leveraging `ClimaCore` functions. +As a result, the code for `TimeVaryingInput0D` could be significantly +simplified, while attaining greater performance at the same time. + +As a result, the keyword argument `context` is no longer required in +constructing this type of `TimeVaryingInput`s. In the future, the argument will +be removed. + v0.1.20 ------ diff --git a/Project.toml b/Project.toml index 532aae5d..c9c59b66 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaUtilities" uuid = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" authors = ["Gabriele Bozzola ", "Julia Sloan "] -version = "0.1.20" +version = "0.1.21" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -21,7 +21,6 @@ ClimaUtilitiesClimaCoreExt = "ClimaCore" ClimaUtilitiesClimaCoreInterpolationsExt = ["ClimaCore", "Interpolations"] ClimaUtilitiesClimaCoreNCDatasetsExt = ["ClimaCore", "NCDatasets"] ClimaUtilitiesNCDatasetsExt = "NCDatasets" -ClimaUtilitiesCUDAExt = "CUDA" ClimaUtilitiesClimaCoreTempestRemapExt = "ClimaCoreTempestRemap" [compat] diff --git a/docs/src/inputs.md b/docs/src/inputs.md index cd237bbe..1750c43f 100644 --- a/docs/src/inputs.md +++ b/docs/src/inputs.md @@ -27,7 +27,7 @@ regridding onto the computational domains (using [`Regridders`](@ref regridder_m `TimeVaryingInputs` support: - analytic functions of time; -- pairs of 1D arrays (for `PointSpaces`); +- pairs of 1D arrays (e.g., for `PointSpaces` or constant fields); - 2/3D NetCDF files (including composing multiple variables from one or more files into one variable); - linear interpolation in time (default), nearest neighbors, and "period filling"; - boundary conditions and repeating periodic data. diff --git a/ext/ClimaUtilitiesCUDAExt.jl b/ext/ClimaUtilitiesCUDAExt.jl deleted file mode 100644 index 2a7e82df..00000000 --- a/ext/ClimaUtilitiesCUDAExt.jl +++ /dev/null @@ -1,31 +0,0 @@ -module ClimaUtilitiesCUDAExt - -import ClimaComms -import ClimaUtilities -import ClimaUtilities.TimeVaryingInputs -import CUDA - -function TimeVaryingInputs.evaluate!( - device::ClimaComms.CUDADevice, - destination, - itp, - time, - args...; - kwargs..., -) - # Cannot do type dispatch across extensions, we check here - @assert itp isa - Base.get_extension( - ClimaUtilities, - :ClimaUtilitiesClimaCoreExt, - ).TimeVaryingInputs0DExt.InterpolatingTimeVaryingInput0D - CUDA.@cuda TimeVaryingInputs.evaluate!( - parent(destination), - itp, - time, - itp.method, - ) - return nothing -end - -end diff --git a/ext/TimeVaryingInputs0DExt.jl b/ext/TimeVaryingInputs0DExt.jl index fe7c4b1d..7b35144c 100644 --- a/ext/TimeVaryingInputs0DExt.jl +++ b/ext/TimeVaryingInputs0DExt.jl @@ -2,7 +2,6 @@ module TimeVaryingInputs0DExt import ClimaCore import ClimaCore: ClimaComms -import ClimaCore: DeviceSideContext import ClimaCore.Fields: Adapt import ClimaUtilities.Utils: @@ -35,7 +34,6 @@ struct InterpolatingTimeVaryingInput0D{ AA1 <: AbstractArray, AA2 <: AbstractArray, M <: AbstractInterpolationMethod, - CC <: ClimaComms.AbstractCommsContext, R <: Tuple, } <: AbstractTimeVaryingInput # AA1 and AA2 could be different because of different FTs @@ -49,9 +47,6 @@ struct InterpolatingTimeVaryingInput0D{ """Interpolation method""" method::M - """ClimaComms context""" - context::CC - """Range of times over which the interpolator is defined. range is always defined on the CPU. Used by the in() function.""" range::R @@ -66,23 +61,6 @@ function Base.in(time, itp::InterpolatingTimeVaryingInput0D) return itp.range[1] <= time <= itp.range[2] end - -# GPU compatibility -function Adapt.adapt_structure(to, itp::InterpolatingTimeVaryingInput0D) - times = Adapt.adapt_structure(to, itp.times) - vals = Adapt.adapt_structure(to, itp.vals) - method = Adapt.adapt_structure(to, itp.method) - range = Adapt.adapt_structure(to, itp.range) - # On a GPU, we have a "ClimaCore.DeviceSideContext" - InterpolatingTimeVaryingInput0D( - times, - vals, - method, - DeviceSideContext(), - range, - ) -end - function TimeVaryingInputs.evaluate!( destination, itp::InterpolatingTimeVaryingInput0D, @@ -93,35 +71,29 @@ function TimeVaryingInputs.evaluate!( if extrapolation_bc(itp.method) isa Throw time in itp || error("TimeVaryingInput does not cover time $time") end - TimeVaryingInputs.evaluate!( - ClimaComms.device(itp.context), - parent(destination), - itp, - time, - itp.method, - ) + scalar_dest = [zero(eltype(destination))] - return nothing -end + TimeVaryingInputs.evaluate!(scalar_dest, itp, time, itp.method) + fill!(destination, scalar_dest[]) -function TimeVaryingInputs.evaluate!( - device::ClimaComms.AbstractCPUDevice, - destination, - itp::InterpolatingTimeVaryingInput0D, - time, - args...; - kwargs..., -) - TimeVaryingInputs.evaluate!(parent(destination), itp, time, itp.method) return nothing end function TimeVaryingInputs.TimeVaryingInput( times::AbstractArray, vals::AbstractArray; + context = nothing, method::AbstractInterpolationMethod = LinearInterpolation(), - context = ClimaComms.context(), ) + ########### DEPRECATED ############### + if !isnothing(context) + Base.depwarn( + "The keyword argument `context` is no longer required for TimeVaryingInputs. It will be removed.", + :TimeVaryingInput, + ) + end + ########### DEPRECATED ############### + issorted(times) || error("Can only interpolate with sorted times") length(times) == length(vals) || error("times and vals have different lengths") @@ -145,16 +117,11 @@ function TimeVaryingInputs.TimeVaryingInput( end end - # When device is CUDADevice, ArrayType will be a CUDADevice, so that times and vals get - # copied to the GPU. - ArrayType = ClimaComms.array_type(ClimaComms.device(context)) - range = (times[begin], times[end]) return InterpolatingTimeVaryingInput0D( - ArrayType(times), - ArrayType(vals), + copy(times), + copy(vals), method, - context, range, ) end diff --git a/test/time_varying_inputs.jl b/test/time_varying_inputs.jl index 7129fe22..f06e55bd 100644 --- a/test/time_varying_inputs.jl +++ b/test/time_varying_inputs.jl @@ -117,7 +117,6 @@ end input = TimeVaryingInputs.TimeVaryingInput( times, vals; - context, method = TimeVaryingInputs.NearestNeighbor(), ) @@ -125,7 +124,6 @@ end input_clamp = TimeVaryingInputs.TimeVaryingInput( times, vals; - context, method = TimeVaryingInputs.NearestNeighbor( TimeVaryingInputs.Flat(), ), @@ -135,7 +133,6 @@ end input_periodic_calendar = TimeVaryingInputs.TimeVaryingInput( times, vals; - context, method = TimeVaryingInputs.NearestNeighbor( TimeVaryingInputs.PeriodicCalendar(), ), @@ -145,7 +142,6 @@ end input_periodic_calendar_linear = TimeVaryingInputs.TimeVaryingInput( times, vals; - context, method = TimeVaryingInputs.LinearInterpolation( TimeVaryingInputs.PeriodicCalendar(), ), @@ -226,7 +222,7 @@ end @test Array(parent(dest))[1] == vals[10] # Linear interpolation - input = TimeVaryingInputs.TimeVaryingInput(times, vals; context) + input = TimeVaryingInputs.TimeVaryingInput(times, vals) TimeVaryingInputs.evaluate!(dest, input, 0.1)