From 1eb88f30db036f2873c0ed913533a2219c2f0ab9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 27 Oct 2023 01:53:04 +0100 Subject: [PATCH 1/3] Add RecursiveArrayTools.jl extension --- Project.toml | 6 +++++- ext/DynamicQuantitiesRecursiveArrayToolsExt.jl | 13 +++++++++++++ src/utils.jl | 6 ++++-- test/unittests.jl | 9 +++++++++ 4 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 ext/DynamicQuantitiesRecursiveArrayToolsExt.jl diff --git a/Project.toml b/Project.toml index ae6fafb0..1a5e42e6 100644 --- a/Project.toml +++ b/Project.toml @@ -11,12 +11,14 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [weakdeps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] DynamicQuantitiesLinearAlgebraExt = "LinearAlgebra" DynamicQuantitiesMeasurementsExt = "Measurements" +DynamicQuantitiesRecursiveArrayToolsExt = "RecursiveArrayTools" DynamicQuantitiesScientificTypesExt = "ScientificTypes" DynamicQuantitiesUnitfulExt = "Unitful" @@ -24,6 +26,7 @@ DynamicQuantitiesUnitfulExt = "Unitful" Compat = "3.42, 4" Measurements = "2" PackageExtensionCompat = "1.0.2" +RecursiveArrayTools = "2" ScientificTypes = "3" Tricks = "0.1" Unitful = "1" @@ -34,6 +37,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" Ratios = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SaferIntegers = "88634af6-177f-5301-88b8-7819386cfa38" ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" @@ -42,4 +46,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Aqua", "LinearAlgebra", "Measurements", "Ratios", "SaferIntegers", "SafeTestsets", "ScientificTypes", "StaticArrays", "Test", "Unitful"] +test = ["Aqua", "LinearAlgebra", "Measurements", "Ratios", "RecursiveArrayTools", "SaferIntegers", "SafeTestsets", "ScientificTypes", "StaticArrays", "Test", "Unitful"] diff --git a/ext/DynamicQuantitiesRecursiveArrayToolsExt.jl b/ext/DynamicQuantitiesRecursiveArrayToolsExt.jl new file mode 100644 index 00000000..9190402e --- /dev/null +++ b/ext/DynamicQuantitiesRecursiveArrayToolsExt.jl @@ -0,0 +1,13 @@ +module DynamicQuantitiesRecursiveArrayToolsExt + +import DynamicQuantities: AbstractQuantity +import RecursiveArrayTools: RecursiveArrayTools as RAT + +function RAT.recursive_unitless_bottom_eltype(::Type{Q}) where {T,Q<:AbstractQuantity{T}} + return T +end +function RAT.recursive_unitless_eltype(::Type{Q}) where {T,Q<:AbstractQuantity{T}} + return T +end + +end diff --git a/src/utils.jl b/src/utils.jl index 41ee4650..44ea1739 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,6 +35,8 @@ Base.convert(::Type{T}, q::AbstractQuantity) where {T<:Real} = Base.promote_rule(::Type{Dimensions{R1}}, ::Type{Dimensions{R2}}) where {R1,R2} = Dimensions{promote_type(R1,R2)} Base.promote_rule(::Type{Q1}, ::Type{Q2}) where {T1,T2,D1,D2,Q1<:Quantity{T1,D1},Q2<:Quantity{T2,D2}} = Quantity{promote_type(T1,T2),promote_type(D1,D2)} +Base.eltype(::Type{Q}) where {Q<:AbstractQuantity} = Q + Base.keys(d::AbstractDimensions) = static_fieldnames(typeof(d)) Base.getindex(d::AbstractDimensions, k::Symbol) = getfield(d, k) @@ -96,8 +98,8 @@ Base.:(==)(::AbstractQuantity, ::WeakRef) = error("Cannot compare a quantity to Base.:(==)(::WeakRef, ::AbstractQuantity) = error("Cannot compare a weakref to a quantity") -# Simple flags: -for f in (:iszero, :isfinite, :isinf, :isnan, :isreal) +# Simple flags and returns: +for f in (:iszero, :isfinite, :isinf, :isnan, :isreal, :sign) @eval Base.$f(q::AbstractQuantity) = $f(ustrip(q)) end diff --git a/test/unittests.jl b/test/unittests.jl index d53176c9..6f581f83 100644 --- a/test/unittests.jl +++ b/test/unittests.jl @@ -3,6 +3,7 @@ using DynamicQuantities: FixedRational using DynamicQuantities: DEFAULT_DIM_BASE_TYPE, DEFAULT_DIM_TYPE, DEFAULT_VALUE_TYPE using DynamicQuantities: array_type, value_type, dim_type, quantity_type using Ratios: SimpleRatio +using RecursiveArrayTools: RecursiveArrayTools as RAT using SaferIntegers: SafeInt16 using StaticArrays: SArray, MArray using LinearAlgebra: norm @@ -915,3 +916,11 @@ end @test DynamicQuantities.materialize_first(ref) === x[1] end end + +@testset "RecursiveArrayTools" begin + for f in (RAT.recursive_unitless_bottom_eltype, RAT.recursive_unitless_eltype) + @test f([0.3u"km/s"]) == Float64 + @test f([Quantity{Float32}(0.3u"km/s")]) == Float64 + @test f([0.3Unitful.u"km/s"]) == Float64 + end +end From 2e33b6728fb6e835ec6af07143febedf0d8df6a6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 12 Nov 2023 21:09:13 +0000 Subject: [PATCH 2/3] Start on DiffEqBase extension --- Project.toml | 6 ++++- ext/DynamicQuantitiesDiffEqBaseExt.jl | 36 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 ext/DynamicQuantitiesDiffEqBaseExt.jl diff --git a/Project.toml b/Project.toml index 1a5e42e6..dbdeb01f 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [weakdeps] +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -16,6 +17,7 @@ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] +DynamicQuantitiesDiffEqBaseExt = "DiffEqBase" DynamicQuantitiesLinearAlgebraExt = "LinearAlgebra" DynamicQuantitiesMeasurementsExt = "Measurements" DynamicQuantitiesRecursiveArrayToolsExt = "RecursiveArrayTools" @@ -24,6 +26,7 @@ DynamicQuantitiesUnitfulExt = "Unitful" [compat] Compat = "3.42, 4" +DiffEqBase = "6" Measurements = "2" PackageExtensionCompat = "1.0.2" RecursiveArrayTools = "2" @@ -34,6 +37,7 @@ julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" Ratios = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" @@ -46,4 +50,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Aqua", "LinearAlgebra", "Measurements", "Ratios", "RecursiveArrayTools", "SaferIntegers", "SafeTestsets", "ScientificTypes", "StaticArrays", "Test", "Unitful"] +test = ["Aqua", "DiffEqBase", "LinearAlgebra", "Measurements", "Ratios", "RecursiveArrayTools", "SaferIntegers", "SafeTestsets", "ScientificTypes", "StaticArrays", "Test", "Unitful"] diff --git a/ext/DynamicQuantitiesDiffEqBaseExt.jl b/ext/DynamicQuantitiesDiffEqBaseExt.jl new file mode 100644 index 00000000..945f8e44 --- /dev/null +++ b/ext/DynamicQuantitiesDiffEqBaseExt.jl @@ -0,0 +1,36 @@ +module DynamicQuantitiesDiffEqBaseExt + +using DynamicQuantities: UnionAbstractQuantity, ustrip + +import DiffEqBase + +DiffEqBase.value(x::UnionAbstractQuantity) = ustrip(x) +DiffEqBase.recursive_length(u::UnionAbstractQuantity) = length(u) +DiffEqBase.recursive_length(u::QuantityArray) = length(u) + +# @inline function DiffEqBase.UNITLESS_ABS2(x::UnionAbstractQuantity) +# abs(DiffEqBase.value(x)) +# end +# function DiffEqBase.abs2_and_sum(x::DynamicQuantities.Quantity, y::Float64) +# reduce(Base.add_sum, DiffEqBase.value(x), init = zero(real(DiffEqBase.value(x)))) + +# reduce(Base.add_sum, y, init = zero(real(DiffEqBase.value(eltype(y))))) +# end + +# Base.sign(x::DynamicQuantities.Quantity) = Base.sign(DiffEqBase.value(x)) + +# function DiffEqBase.prob2dtmin(prob; use_end_time = true) +# DiffEqBase.prob2dtmin(prob.tspan, oneunit(first(prob.tspan)), use_end_time) +# end + +# DiffEqBase.NAN_CHECK(x::DynamicQuantities.Quantity) = isnan(x) +# Base.zero(x::Array{T}) where {T<:DynamicQuantities.Quantity} = zero.(x) + +# @inline function DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t) +# @. DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t) +# end + +# f(u, p, t) = u / t; +# problem = ODEProblem(f, [1.0u"km/s"], (0.0u"s", 1.0u"s")); +# sol = solve(problem, Tsit5(), dt = 0.1u"s") + +end \ No newline at end of file From 8e03c10916438e7bce3468bbe1eb2aa8fbbb4df0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 12 Nov 2023 21:29:36 +0000 Subject: [PATCH 3/3] Add other missing parts of DiffEqBase --- ext/DynamicQuantitiesDiffEqBaseExt.jl | 42 +++++++++------------------ src/utils.jl | 1 + 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/ext/DynamicQuantitiesDiffEqBaseExt.jl b/ext/DynamicQuantitiesDiffEqBaseExt.jl index 945f8e44..cd8e2c93 100644 --- a/ext/DynamicQuantitiesDiffEqBaseExt.jl +++ b/ext/DynamicQuantitiesDiffEqBaseExt.jl @@ -1,36 +1,22 @@ module DynamicQuantitiesDiffEqBaseExt -using DynamicQuantities: UnionAbstractQuantity, ustrip +using DynamicQuantities: + UnionAbstractQuantity, ustrip, QuantityArray import DiffEqBase DiffEqBase.value(x::UnionAbstractQuantity) = ustrip(x) -DiffEqBase.recursive_length(u::UnionAbstractQuantity) = length(u) -DiffEqBase.recursive_length(u::QuantityArray) = length(u) - -# @inline function DiffEqBase.UNITLESS_ABS2(x::UnionAbstractQuantity) -# abs(DiffEqBase.value(x)) -# end -# function DiffEqBase.abs2_and_sum(x::DynamicQuantities.Quantity, y::Float64) -# reduce(Base.add_sum, DiffEqBase.value(x), init = zero(real(DiffEqBase.value(x)))) + -# reduce(Base.add_sum, y, init = zero(real(DiffEqBase.value(eltype(y))))) -# end - -# Base.sign(x::DynamicQuantities.Quantity) = Base.sign(DiffEqBase.value(x)) - -# function DiffEqBase.prob2dtmin(prob; use_end_time = true) -# DiffEqBase.prob2dtmin(prob.tspan, oneunit(first(prob.tspan)), use_end_time) -# end - -# DiffEqBase.NAN_CHECK(x::DynamicQuantities.Quantity) = isnan(x) -# Base.zero(x::Array{T}) where {T<:DynamicQuantities.Quantity} = zero.(x) - -# @inline function DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t) -# @. DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t) -# end - -# f(u, p, t) = u / t; -# problem = ODEProblem(f, [1.0u"km/s"], (0.0u"s", 1.0u"s")); -# sol = solve(problem, Tsit5(), dt = 0.1u"s") +DiffEqBase.recursive_length(u::UnionAbstractQuantity) = recursive_length(ustrip(u)) +DiffEqBase.recursive_length(u::QuantityArray) = recursive_length(ustrip(u)) + +@inline function DiffEqBase.UNITLESS_ABS2(x::UnionAbstractQuantity) + abs2(ustrip(x)) +end +function DiffEqBase.abs2_and_sum(x::UnionAbstractQuantity, y) + reduce(Base.add_sum, ustrip(x), init = zero(real(DiffEqBase.value(x)))) + + reduce(Base.add_sum, y, init = zero(real(DiffEqBase.value(eltype(y))))) +end + +DiffEqBase.NAN_CHECK(x::UnionAbstractQuantity) = NAN_CHECK(ustrip(x)) end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index f8f48e85..161017b9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -183,6 +183,7 @@ Base.zero(q::Q) where {Q<:UnionAbstractQuantity} = new_quantity(Q, zero(ustrip(q Base.zero(::AbstractDimensions) = error("There is no such thing as an additive identity for a `AbstractDimensions` object, as + is only defined for `UnionAbstractQuantity`.") Base.zero(::Type{<:UnionAbstractQuantity}) = error("Cannot create an additive identity for a `UnionAbstractQuantity` type, as the dimensions are unknown. Please use `zero(::UnionAbstractQuantity)` instead.") Base.zero(::Type{<:AbstractDimensions}) = error("There is no such thing as an additive identity for a `AbstractDimensions` type, as + is only defined for `UnionAbstractQuantity`.") +Base.zero(x::Array{Q}) where {Q<:UnionAbstractQuantity} = zero.(x) # Dimensionful 1 (oneunit) Base.oneunit(q::Q) where {Q<:UnionAbstractQuantity} = new_quantity(Q, oneunit(ustrip(q)), dimension(q))