Skip to content

Commit

Permalink
Simplify initialization of types for N, dNdx, and dNdξ (#858)
Browse files Browse the repository at this point in the history
This patch simplifies the initialization of types for `N`, `dNdx`, and
`dNdξ` used in the `FunctionValues` constructor. The change in LOC isn't
significant but I think it is easier to follow when there are no
auxiliary types and when they are grouped by "case" instead of function
name. Fixes #857.
  • Loading branch information
fredrikekre authored Dec 8, 2023
1 parent ca8fba1 commit 5350d2d
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions src/FEValues/FunctionValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,25 @@
# vdim = vector dimension (dimension of the field) #
#################################################################

# Helpers to get the correct types for FunctionValues for the given function and, if needed, geometric interpolations.
struct SInterpolationDims{rdim,sdim} end
struct VInterpolationDims{rdim,sdim,vdim} end
function InterpolationDims(::ScalarInterpolation, ip_geo::VectorizedInterpolation{sdim}) where sdim
return SInterpolationDims{getdim(ip_geo),sdim}()
end
function InterpolationDims(::VectorInterpolation{vdim}, ip_geo::VectorizedInterpolation{sdim}) where {vdim,sdim}
return VInterpolationDims{getdim(ip_geo),sdim,vdim}()
end

typeof_N(::Type{T}, ::SInterpolationDims) where T = T
typeof_N(::Type{T}, ::VInterpolationDims{<:Any,dim,dim}) where {T,dim} = Vec{dim,T}
typeof_N(::Type{T}, ::VInterpolationDims{<:Any,<:Any,vdim}) where {T,vdim} = SVector{vdim,T} # Why not ::Vec here?

typeof_dNdx(::Type{T}, ::SInterpolationDims{dim,dim}) where {T,dim} = Vec{dim,T}
typeof_dNdx(::Type{T}, ::SInterpolationDims{<:Any,sdim}) where {T,sdim} = SVector{sdim,T} # Why not ::Vec here?
typeof_dNdx(::Type{T}, ::VInterpolationDims{dim,dim,dim}) where {T,dim} = Tensor{2,dim,T}
typeof_dNdx(::Type{T}, ::VInterpolationDims{<:Any,sdim,vdim}) where {T,sdim,vdim} = SMatrix{vdim,sdim,T} # If vdim=sdim!=rdim Tensor would be possible...

typeof_dNdξ(::Type{T}, ::SInterpolationDims{dim,dim}) where {T,dim} = Vec{dim,T}
typeof_dNdξ(::Type{T}, ::SInterpolationDims{rdim}) where {T,rdim} = SVector{rdim,T} # Why not ::Vec here?
typeof_dNdξ(::Type{T}, ::VInterpolationDims{dim,dim,dim}) where {T,dim} = Tensor{2,dim,T}
typeof_dNdξ(::Type{T}, ::VInterpolationDims{rdim,<:Any,vdim}) where {T,rdim,vdim} = SMatrix{vdim,rdim,T} # If vdim=rdim!=sdim Tensor would be possible...
# Scalar, sdim == rdim sdim rdim
typeof_N( ::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = T
typeof_dNdx(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Vec{dim, T}
typeof_dNdξ(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Vec{dim, T}

# Vector, vdim == sdim == rdim vdim sdim rdim
typeof_N( ::Type{T}, ::VectorInterpolation{dim}, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Vec{dim, T}
typeof_dNdx(::Type{T}, ::VectorInterpolation{dim}, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Tensor{2, dim, T}
typeof_dNdξ(::Type{T}, ::VectorInterpolation{dim}, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Tensor{2, dim, T}

# Scalar, sdim != rdim (TODO: Use Vec if (s|r)dim <= 3?)
typeof_N( ::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, sdim, rdim} = T
typeof_dNdx(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, sdim, rdim} = SVector{sdim, T}
typeof_dNdξ(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, sdim, rdim} = SVector{rdim, T}

# Vector, vdim != sdim != rdim (TODO: Use Vec/Tensor if (s|r)dim <= 3?)
typeof_N( ::Type{T}, ::VectorInterpolation{vdim}, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, vdim, sdim, rdim} = SVector{vdim, T}
typeof_dNdx(::Type{T}, ::VectorInterpolation{vdim}, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, vdim, sdim, rdim} = SMatrix{vdim, sdim, T}
typeof_dNdξ(::Type{T}, ::VectorInterpolation{vdim}, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, vdim, sdim, rdim} = SMatrix{vdim, rdim, T}

"""
FunctionValues{DiffOrder}(::Type{T}, ip_fun, qr::QuadratureRule, ip_geo::VectorizedInterpolation)
Expand All @@ -51,18 +47,17 @@ struct FunctionValues{DiffOrder, IP, N_t, dNdx_t, dNdξ_t}
end
end
function FunctionValues{DiffOrder}(::Type{T}, ip::Interpolation, qr::QuadratureRule, ip_geo::VectorizedInterpolation) where {DiffOrder, T}
ip_dims = InterpolationDims(ip, ip_geo)
n_shape = getnbasefunctions(ip)
n_qpoints = getnquadpoints(qr)

= zeros(typeof_N(T, ip_dims), n_shape, n_qpoints)
= zeros(typeof_N(T, ip, ip_geo), n_shape, n_qpoints)
Nx = isa(mapping_type(ip), IdentityMapping) ?: similar(Nξ)

if DiffOrder == 0
dNdξ = dNdx = nothing
elseif DiffOrder == 1
dNdξ = zeros(typeof_dNdξ(T, ip_dims), n_shape, n_qpoints)
dNdx = fill(zero(typeof_dNdx(T, ip_dims)) * T(NaN), n_shape, n_qpoints)
dNdξ = zeros(typeof_dNdξ(T, ip, ip_geo), n_shape, n_qpoints)
dNdx = fill(zero(typeof_dNdx(T, ip, ip_geo)) * T(NaN), n_shape, n_qpoints)
else
throw(ArgumentError("Currently only values and gradients can be updated in FunctionValues"))
end
Expand Down

0 comments on commit 5350d2d

Please sign in to comment.