From 5350d2d5f18fb8bc4746d877dab393d33380a81f Mon Sep 17 00:00:00 2001
From: Fredrik Ekre <ekrefredrik@gmail.com>
Date: Fri, 8 Dec 2023 01:02:13 -0800
Subject: [PATCH] =?UTF-8?q?Simplify=20initialization=20of=20types=20for=20?=
 =?UTF-8?q?`N`,=20`dNdx`,=20and=20`dNd=CE=BE`=20(#858)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch simplifies the initialization of types for `N`, `dNdx`, and
`dNdξ` used in the `FunctionValues` constructor. The change in LOC isn't
significant but I think it is easier to follow when there are no
auxiliary types and when they are grouped by "case" instead of function
name. Fixes #857.
---
 src/FEValues/FunctionValues.jl | 49 +++++++++++++++-------------------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/src/FEValues/FunctionValues.jl b/src/FEValues/FunctionValues.jl
index 7250116892..167dff7d22 100644
--- a/src/FEValues/FunctionValues.jl
+++ b/src/FEValues/FunctionValues.jl
@@ -5,29 +5,25 @@
 # vdim = vector dimension (dimension of the field)              #
 #################################################################
 
-# Helpers to get the correct types for FunctionValues for the given function and, if needed, geometric interpolations.
-struct SInterpolationDims{rdim,sdim} end
-struct VInterpolationDims{rdim,sdim,vdim} end
-function InterpolationDims(::ScalarInterpolation, ip_geo::VectorizedInterpolation{sdim}) where sdim
-    return SInterpolationDims{getdim(ip_geo),sdim}()
-end
-function InterpolationDims(::VectorInterpolation{vdim}, ip_geo::VectorizedInterpolation{sdim}) where {vdim,sdim}
-    return VInterpolationDims{getdim(ip_geo),sdim,vdim}()
-end
-
-typeof_N(::Type{T}, ::SInterpolationDims) where T = T
-typeof_N(::Type{T}, ::VInterpolationDims{<:Any,dim,dim}) where {T,dim} = Vec{dim,T}
-typeof_N(::Type{T}, ::VInterpolationDims{<:Any,<:Any,vdim}) where {T,vdim} = SVector{vdim,T} # Why not ::Vec here?
-
-typeof_dNdx(::Type{T}, ::SInterpolationDims{dim,dim}) where {T,dim} = Vec{dim,T}
-typeof_dNdx(::Type{T}, ::SInterpolationDims{<:Any,sdim}) where {T,sdim} = SVector{sdim,T} # Why not ::Vec here?
-typeof_dNdx(::Type{T}, ::VInterpolationDims{dim,dim,dim}) where {T,dim} = Tensor{2,dim,T}
-typeof_dNdx(::Type{T}, ::VInterpolationDims{<:Any,sdim,vdim}) where {T,sdim,vdim} = SMatrix{vdim,sdim,T} # If vdim=sdim!=rdim Tensor would be possible...
-
-typeof_dNdξ(::Type{T}, ::SInterpolationDims{dim,dim}) where {T,dim} = Vec{dim,T}
-typeof_dNdξ(::Type{T}, ::SInterpolationDims{rdim}) where {T,rdim} = SVector{rdim,T} # Why not ::Vec here?
-typeof_dNdξ(::Type{T}, ::VInterpolationDims{dim,dim,dim}) where {T,dim} = Tensor{2,dim,T}
-typeof_dNdξ(::Type{T}, ::VInterpolationDims{rdim,<:Any,vdim}) where {T,rdim,vdim} = SMatrix{vdim,rdim,T} # If vdim=rdim!=sdim Tensor would be possible...
+# Scalar, sdim == rdim                                                 sdim                     rdim
+typeof_N(   ::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = T
+typeof_dNdx(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Vec{dim, T}
+typeof_dNdξ(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Vec{dim, T}
+
+# Vector, vdim == sdim == rdim              vdim                            sdim                     rdim
+typeof_N(   ::Type{T}, ::VectorInterpolation{dim}, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Vec{dim, T}
+typeof_dNdx(::Type{T}, ::VectorInterpolation{dim}, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Tensor{2, dim, T}
+typeof_dNdξ(::Type{T}, ::VectorInterpolation{dim}, ::VectorizedInterpolation{dim, <: AbstractRefShape{dim}}) where {T, dim} = Tensor{2, dim, T}
+
+# Scalar, sdim != rdim (TODO: Use Vec if (s|r)dim <= 3?)
+typeof_N(   ::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, sdim, rdim} = T
+typeof_dNdx(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, sdim, rdim} = SVector{sdim, T}
+typeof_dNdξ(::Type{T}, ::ScalarInterpolation, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, sdim, rdim} = SVector{rdim, T}
+
+# Vector, vdim != sdim != rdim (TODO: Use Vec/Tensor if (s|r)dim <= 3?)
+typeof_N(   ::Type{T}, ::VectorInterpolation{vdim}, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, vdim, sdim, rdim} = SVector{vdim, T}
+typeof_dNdx(::Type{T}, ::VectorInterpolation{vdim}, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, vdim, sdim, rdim} = SMatrix{vdim, sdim, T}
+typeof_dNdξ(::Type{T}, ::VectorInterpolation{vdim}, ::VectorizedInterpolation{sdim, <: AbstractRefShape{rdim}}) where {T, vdim, sdim, rdim} = SMatrix{vdim, rdim, T}
 
 """
     FunctionValues{DiffOrder}(::Type{T}, ip_fun, qr::QuadratureRule, ip_geo::VectorizedInterpolation)
@@ -51,18 +47,17 @@ struct FunctionValues{DiffOrder, IP, N_t, dNdx_t, dNdξ_t}
     end
 end
 function FunctionValues{DiffOrder}(::Type{T}, ip::Interpolation, qr::QuadratureRule, ip_geo::VectorizedInterpolation) where {DiffOrder, T}
-    ip_dims = InterpolationDims(ip, ip_geo)
     n_shape = getnbasefunctions(ip)
     n_qpoints = getnquadpoints(qr)
     
-    Nξ = zeros(typeof_N(T, ip_dims), n_shape, n_qpoints)
+    Nξ = zeros(typeof_N(T, ip, ip_geo), n_shape, n_qpoints)
     Nx = isa(mapping_type(ip), IdentityMapping) ? Nξ : similar(Nξ)
 
     if DiffOrder == 0
         dNdξ = dNdx = nothing
     elseif DiffOrder == 1
-        dNdξ = zeros(typeof_dNdξ(T, ip_dims),               n_shape, n_qpoints)
-        dNdx = fill(zero(typeof_dNdx(T, ip_dims)) * T(NaN), n_shape, n_qpoints)
+        dNdξ = zeros(typeof_dNdξ(T, ip, ip_geo),               n_shape, n_qpoints)
+        dNdx = fill(zero(typeof_dNdx(T, ip, ip_geo)) * T(NaN), n_shape, n_qpoints)
     else
         throw(ArgumentError("Currently only values and gradients can be updated in FunctionValues"))
     end