Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DifferentialEquations.jl extensions #74

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,39 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[weakdeps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
DynamicQuantitiesDiffEqBaseExt = "DiffEqBase"
DynamicQuantitiesLinearAlgebraExt = "LinearAlgebra"
DynamicQuantitiesMeasurementsExt = "Measurements"
DynamicQuantitiesRecursiveArrayToolsExt = "RecursiveArrayTools"
DynamicQuantitiesScientificTypesExt = "ScientificTypes"
DynamicQuantitiesUnitfulExt = "Unitful"

[compat]
Compat = "3.42, 4"
DiffEqBase = "6"
Measurements = "2"
PackageExtensionCompat = "1.0.2"
RecursiveArrayTools = "2"
ScientificTypes = "3"
Tricks = "0.1"
Unitful = "1"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Ratios = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SaferIntegers = "88634af6-177f-5301-88b8-7819386cfa38"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Expand All @@ -42,4 +50,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "LinearAlgebra", "Measurements", "Ratios", "SaferIntegers", "SafeTestsets", "ScientificTypes", "StaticArrays", "Test", "Unitful"]
test = ["Aqua", "DiffEqBase", "LinearAlgebra", "Measurements", "Ratios", "RecursiveArrayTools", "SaferIntegers", "SafeTestsets", "ScientificTypes", "StaticArrays", "Test", "Unitful"]
22 changes: 22 additions & 0 deletions ext/DynamicQuantitiesDiffEqBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module DynamicQuantitiesDiffEqBaseExt

using DynamicQuantities:
UnionAbstractQuantity, ustrip, QuantityArray

import DiffEqBase

DiffEqBase.value(x::UnionAbstractQuantity) = ustrip(x)
DiffEqBase.recursive_length(u::UnionAbstractQuantity) = recursive_length(ustrip(u))
DiffEqBase.recursive_length(u::QuantityArray) = recursive_length(ustrip(u))

@inline function DiffEqBase.UNITLESS_ABS2(x::UnionAbstractQuantity)
abs2(ustrip(x))
end
function DiffEqBase.abs2_and_sum(x::UnionAbstractQuantity, y)
reduce(Base.add_sum, ustrip(x), init = zero(real(DiffEqBase.value(x)))) +
reduce(Base.add_sum, y, init = zero(real(DiffEqBase.value(eltype(y)))))
end

DiffEqBase.NAN_CHECK(x::UnionAbstractQuantity) = NAN_CHECK(ustrip(x))

end
13 changes: 13 additions & 0 deletions ext/DynamicQuantitiesRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DynamicQuantitiesRecursiveArrayToolsExt

import DynamicQuantities: AbstractQuantity
import RecursiveArrayTools: RecursiveArrayTools as RAT

function RAT.recursive_unitless_bottom_eltype(::Type{Q}) where {T,Q<:AbstractQuantity{T}}
return T
end
function RAT.recursive_unitless_eltype(::Type{Q}) where {T,Q<:AbstractQuantity{T}}
return T
end

end
1 change: 1 addition & 0 deletions src/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ Base.cbrt(q::UnionAbstractQuantity) = new_quantity(typeof(q), cbrt(ustrip(q)), c

Base.abs2(q::UnionAbstractQuantity) = new_quantity(typeof(q), abs2(ustrip(q)), dimension(q)^2)
Base.angle(q::UnionAbstractQuantity{T}) where {T<:Complex} = angle(ustrip(q))
Base.sign(q::UnionAbstractQuantity) = sign(ustrip(q))

############################## Require dimensionless input ##############################
# Note that :clamp, :cmp, :sign already work
Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ single value type.
end
@inline promote_except_value(q1::Q, q2::Q) where {Q<:UnionAbstractQuantity} = (q1, q2)

Base.eltype(::Type{Q}) where {Q<:UnionAbstractQuantity} = Q
Base.keys(d::AbstractDimensions) = dimension_names(typeof(d))
Base.getindex(d::AbstractDimensions, k::Symbol) = getfield(d, k)

Expand Down Expand Up @@ -182,6 +183,7 @@ Base.zero(q::Q) where {Q<:UnionAbstractQuantity} = new_quantity(Q, zero(ustrip(q
Base.zero(::AbstractDimensions) = error("There is no such thing as an additive identity for a `AbstractDimensions` object, as + is only defined for `UnionAbstractQuantity`.")
Base.zero(::Type{<:UnionAbstractQuantity}) = error("Cannot create an additive identity for a `UnionAbstractQuantity` type, as the dimensions are unknown. Please use `zero(::UnionAbstractQuantity)` instead.")
Base.zero(::Type{<:AbstractDimensions}) = error("There is no such thing as an additive identity for a `AbstractDimensions` type, as + is only defined for `UnionAbstractQuantity`.")
Base.zero(x::Array{Q}) where {Q<:UnionAbstractQuantity} = zero.(x)

# Dimensionful 1 (oneunit)
Base.oneunit(q::Q) where {Q<:UnionAbstractQuantity} = new_quantity(Q, oneunit(ustrip(q)), dimension(q))
Expand Down
9 changes: 9 additions & 0 deletions test/unittests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using DynamicQuantities: DEFAULT_DIM_BASE_TYPE, DEFAULT_DIM_TYPE, DEFAULT_VALUE_
using DynamicQuantities: array_type, value_type, dim_type, quantity_type
using DynamicQuantities: GenericQuantity
using Ratios: SimpleRatio
using RecursiveArrayTools: RecursiveArrayTools as RAT
using SaferIntegers: SafeInt16
using StaticArrays: SArray, MArray
using LinearAlgebra: norm
Expand Down Expand Up @@ -1261,3 +1262,11 @@ end
end
end
end

@testset "RecursiveArrayTools" begin
for f in (RAT.recursive_unitless_bottom_eltype, RAT.recursive_unitless_eltype)
@test f([0.3u"km/s"]) == Float64
@test f([Quantity{Float32}(0.3u"km/s")]) == Float64
@test f([0.3Unitful.u"km/s"]) == Float64
end
end
Loading