diff --git a/FIAT/__init__.py b/FIAT/__init__.py index 26c467a9..a68db2bb 100644 --- a/FIAT/__init__.py +++ b/FIAT/__init__.py @@ -11,6 +11,10 @@ from FIAT.mixed import MixedElement # noqa: F401 from FIAT.restricted import RestrictedElement # noqa: F401 from FIAT.quadrature_element import QuadratureElement # noqa: F401 +from FIAT.tensor_product import TensorProductElement # noqa: F401 +from FIAT.enriched import EnrichedElement # noqa: F401 +from FIAT.nodal_enriched import NodalEnrichedElement # noqa: F401 +from FIAT.discontinuous import DiscontinuousElement # noqa: F401 # Import finite element classes from FIAT.argyris import Argyris @@ -24,9 +28,9 @@ from FIAT.christiansen_hu import ChristiansenHu from FIAT.johnson_mercier import JohnsonMercier from FIAT.brezzi_douglas_marini import BrezziDouglasMarini -from FIAT.Sminus import TrimmedSerendipityEdge, TrimmedSerendipityFace # noqa: F401 -from FIAT.SminusDiv import TrimmedSerendipityDiv # noqa: F401 -from FIAT.SminusCurl import TrimmedSerendipityCurl # noqa: F401 +from FIAT.Sminus import TrimmedSerendipityEdge, TrimmedSerendipityFace +from FIAT.SminusDiv import TrimmedSerendipityDiv +from FIAT.SminusCurl import TrimmedSerendipityCurl from FIAT.brezzi_douglas_fortin_marini import BrezziDouglasFortinMarini from FIAT.discontinuous_lagrange import DiscontinuousLagrange from FIAT.discontinuous_taylor import DiscontinuousTaylor @@ -56,10 +60,6 @@ from FIAT.hu_zhang import HuZhang from FIAT.mardal_tai_winther import MardalTaiWinther from FIAT.bubble import Bubble, FacetBubble -from FIAT.tensor_product import TensorProductElement -from FIAT.enriched import EnrichedElement -from FIAT.nodal_enriched import NodalEnrichedElement -from FIAT.discontinuous import DiscontinuousElement from FIAT.hdiv_trace import HDivTrace from FIAT.kong_mulder_veldhuizen import KongMulderVeldhuizen from FIAT.fdm_element import FDMLagrange, FDMDiscontinuousLagrange, FDMQuadrature, FDMBrokenH1, FDMBrokenL2, FDMHermite # noqa: F401 @@ -76,6 +76,10 @@ "Crouzeix-Raviart": CrouzeixRaviart, "Discontinuous Lagrange": DiscontinuousLagrange, "S": Serendipity, + "SminusF": TrimmedSerendipityFace, + "SminusDiv": TrimmedSerendipityDiv, + "SminusE": TrimmedSerendipityEdge, + "SminusCurl": TrimmedSerendipityCurl, "Brezzi-Douglas-Marini Cube Face": BrezziDouglasMariniCubeFace, "Brezzi-Douglas-Marini Cube Edge": BrezziDouglasMariniCubeEdge, "DPC": DPC, @@ -83,6 +87,8 @@ "Discontinuous Raviart-Thomas": DiscontinuousRaviartThomas, "Hermite": CubicHermite, "Hsieh-Clough-Tocher": HsiehCloughTocher, + "QuadraticPowellSabin6": QuadraticPowellSabin6, + "QuadraticPowellSabin12": QuadraticPowellSabin12, "Alfeld-Sorokina": AlfeldSorokina, "Arnold-Qin": ArnoldQin, "Christiansen-Hu": ChristiansenHu, @@ -102,12 +108,6 @@ "Nedelec 2nd kind H(curl)": NedelecSecondKind, "Raviart-Thomas": RaviartThomas, "Regge": Regge, - "EnrichedElement": EnrichedElement, - "NodalEnrichedElement": NodalEnrichedElement, - "QuadraticPowellSabin6": QuadraticPowellSabin6, - "QuadraticPowellSabin12": QuadraticPowellSabin12, - "TensorProductElement": TensorProductElement, - "BrokenElement": DiscontinuousElement, "HDiv Trace": HDivTrace, "Hellan-Herrmann-Johnson": HellanHerrmannJohnson, "Gopalakrishnan-Lederer-Schoberl 1st kind": GopalakrishnanLedererSchoberlFirstKind, diff --git a/finat/__init__.py b/finat/__init__.py index cad01821..0bf386c7 100644 --- a/finat/__init__.py +++ b/finat/__init__.py @@ -1,48 +1,44 @@ -from .fiat_elements import Bernstein # noqa: F401 -from .fiat_elements import Bubble, CrouzeixRaviart, DiscontinuousTaylor # noqa: F401 -from .fiat_elements import Lagrange, DiscontinuousLagrange, Real # noqa: F401 -from .fiat_elements import DPC, Serendipity, BrezziDouglasMariniCubeEdge, BrezziDouglasMariniCubeFace # noqa: F401 -from .fiat_elements import TrimmedSerendipityFace, TrimmedSerendipityEdge # noqa: F401 -from .fiat_elements import TrimmedSerendipityDiv # noqa: F401 -from .fiat_elements import TrimmedSerendipityCurl # noqa: F401 -from .fiat_elements import BrezziDouglasMarini, BrezziDouglasFortinMarini # noqa: F401 -from .fiat_elements import Nedelec, NedelecSecondKind, RaviartThomas # noqa: F401 -from .fiat_elements import HellanHerrmannJohnson, Regge # noqa: F401 -from .fiat_elements import GopalakrishnanLedererSchoberlFirstKind # noqa: F401 -from .fiat_elements import GopalakrishnanLedererSchoberlSecondKind # noqa: F401 -from .fiat_elements import FacetBubble # noqa: F401 -from .fiat_elements import KongMulderVeldhuizen # noqa: F401 +from .fiat_elements import (Bernstein, Bubble, BrezziDouglasFortinMarini, # noqa: F401 + BrezziDouglasMarini, BrezziDouglasMariniCubeEdge, # noqa: F401 + BrezziDouglasMariniCubeFace, CrouzeixRaviart, # noqa: F401 + DiscontinuousLagrange, DiscontinuousTaylor, DPC, # noqa: F401 + FacetBubble, GopalakrishnanLedererSchoberlFirstKind, # noqa: F401 + GopalakrishnanLedererSchoberlSecondKind, HellanHerrmannJohnson, # noqa: F401 + KongMulderVeldhuizen, Lagrange, Real, Serendipity, # noqa: F401 + TrimmedSerendipityCurl, TrimmedSerendipityDiv, # noqa: F401 + TrimmedSerendipityEdge, TrimmedSerendipityFace, # noqa: F401 + Nedelec, NedelecSecondKind, RaviartThomas, Regge) # noqa: F401 -from .argyris import Argyris # noqa: F401 -from .aw import ArnoldWinther # noqa: F401 -from .aw import ArnoldWintherNC # noqa: F401 -from .hz import HuZhang # noqa: F401 -from .bell import Bell # noqa: F401 -from .bernardi_raugel import BernardiRaugel, BernardiRaugelBubble # noqa: F401 -from .hct import HsiehCloughTocher, ReducedHsiehCloughTocher # noqa: F401 -from .arnold_qin import ArnoldQin, ReducedArnoldQin # noqa: F401 -from .christiansen_hu import ChristiansenHu # noqa: F401 -from .alfeld_sorokina import AlfeldSorokina # noqa: F401 -from .guzman_neilan import GuzmanNeilanFirstKindH1, GuzmanNeilanSecondKindH1, GuzmanNeilanBubble, GuzmanNeilanH1div # noqa: F401 +from .argyris import Argyris # noqa: F401 +from .aw import ArnoldWinther, ArnoldWintherNC # noqa: F401 +from .hz import HuZhang # noqa: F401 +from .bell import Bell # noqa: F401 +from .bernardi_raugel import BernardiRaugel, BernardiRaugelBubble # noqa: F401 +from .hct import HsiehCloughTocher, ReducedHsiehCloughTocher # noqa: F401 +from .arnold_qin import ArnoldQin, ReducedArnoldQin # noqa: F401 +from .christiansen_hu import ChristiansenHu # noqa: F401 +from .alfeld_sorokina import AlfeldSorokina # noqa: F401 +from .guzman_neilan import GuzmanNeilanFirstKindH1, GuzmanNeilanSecondKindH1, GuzmanNeilanBubble, GuzmanNeilanH1div # noqa: F401 from .powell_sabin import QuadraticPowellSabin6, QuadraticPowellSabin12 # noqa: F401 -from .hermite import Hermite # noqa: F401 -from .johnson_mercier import JohnsonMercier # noqa: F401 -from .mtw import MardalTaiWinther # noqa: F401 -from .morley import Morley # noqa: F401 -from .trace import HDivTrace # noqa: F401 -from .direct_serendipity import DirectSerendipity # noqa: F401 +from .hermite import Hermite # noqa: F401 +from .johnson_mercier import JohnsonMercier # noqa: F401 +from .mtw import MardalTaiWinther # noqa: F401 +from .morley import Morley # noqa: F401 +from .trace import HDivTrace # noqa: F401 +from .direct_serendipity import DirectSerendipity # noqa: F401 from .spectral import GaussLobattoLegendre, GaussLegendre, Legendre, IntegratedLegendre, FDMLagrange, FDMQuadrature, FDMDiscontinuousLagrange, FDMBrokenH1, FDMBrokenL2, FDMHermite # noqa: F401 -from .tensorfiniteelement import TensorFiniteElement # noqa: F401 -from .tensor_product import TensorProductElement # noqa: F401 -from .cube import FlattenedDimensions # noqa: F401 -from .discontinuous import DiscontinuousElement # noqa: F401 -from .enriched import EnrichedElement # noqa: F401 -from .hdivcurl import HCurlElement, HDivElement # noqa: F401 -from .mixed import MixedElement # noqa: F401 -from .nodal_enriched import NodalEnrichedElement # noqa: 401 +from .tensorfiniteelement import TensorFiniteElement # noqa: F401 +from .tensor_product import TensorProductElement # noqa: F401 +from .cube import FlattenedDimensions # noqa: F401 +from .discontinuous import DiscontinuousElement # noqa: F401 +from .enriched import EnrichedElement # noqa: F401 +from .hdivcurl import HCurlElement, HDivElement # noqa: F401 +from .mixed import MixedElement # noqa: F401 +from .nodal_enriched import NodalEnrichedElement # noqa: F401 from .quadrature_element import QuadratureElement, make_quadrature_element # noqa: F401 -from .restricted import RestrictedElement # noqa: F401 -from .runtime_tabulated import RuntimeTabulated # noqa: F401 -from . import quadrature # noqa: F401 -from . import cell_tools # noqa: F401 +from .restricted import RestrictedElement # noqa: F401 +from .runtime_tabulated import RuntimeTabulated # noqa: F401 +from . import quadrature # noqa: F401 +from . import cell_tools # noqa: F401 +from . import element_factory # noqa: F401 diff --git a/finat/element_factory.py b/finat/element_factory.py new file mode 100644 index 00000000..317fa7a0 --- /dev/null +++ b/finat/element_factory.py @@ -0,0 +1,368 @@ +# This file was modified from FFC +# (http://bitbucket.org/fenics-project/ffc), copyright notice +# reproduced below. +# +# Copyright (C) 2009-2013 Kristian B. Oelgaard and Anders Logg +# +# This file is part of FFC. +# +# FFC is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# FFC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with FFC. If not, see . + +import weakref +from functools import singledispatch, cache + +import finat +import finat.ufl +import ufl + +from FIAT import ufc_cell + +__all__ = ("as_fiat_cell", "create_base_element", + "create_element", "supported_elements") + + +# List of supported elements and mapping to element classes +supported_elements = {"Argyris": finat.Argyris, + "Bell": finat.Bell, + "Bernardi-Raugel": finat.BernardiRaugel, + "Bernardi-Raugel Bubble": finat.BernardiRaugelBubble, + "Bernstein": finat.Bernstein, + "Brezzi-Douglas-Fortin-Marini": finat.BrezziDouglasFortinMarini, + "Brezzi-Douglas-Marini Cube Face": finat.BrezziDouglasMariniCubeFace, + "Brezzi-Douglas-Marini": finat.BrezziDouglasMarini, + "Brezzi-Douglas-Marini Cube Edge": finat.BrezziDouglasMariniCubeEdge, + "Bubble": finat.Bubble, + "FacetBubble": finat.FacetBubble, + "Crouzeix-Raviart": finat.CrouzeixRaviart, + "Direct Serendipity": finat.DirectSerendipity, + "Discontinuous Lagrange": finat.DiscontinuousLagrange, + "Discontinuous Lagrange L2": finat.DiscontinuousLagrange, + "Discontinuous Taylor": finat.DiscontinuousTaylor, + "Discontinuous Raviart-Thomas": lambda *args, **kwargs: finat.DiscontinuousElement(finat.RaviartThomas(*args, **kwargs)), + "DPC": finat.DPC, + "DPC L2": finat.DPC, + "Hermite": finat.Hermite, + "Hsieh-Clough-Tocher": finat.HsiehCloughTocher, + "Reduced-Hsieh-Clough-Tocher": finat.ReducedHsiehCloughTocher, + "QuadraticPowellSabin6": finat.QuadraticPowellSabin6, + "QuadraticPowellSabin12": finat.QuadraticPowellSabin12, + "Alfeld-Sorokina": finat.AlfeldSorokina, + "Arnold-Qin": finat.ArnoldQin, + "Reduced-Arnold-Qin": finat.ReducedArnoldQin, + "Christiansen-Hu": finat.ChristiansenHu, + "Guzman-Neilan 1st kind H1": finat.GuzmanNeilanFirstKindH1, + "Guzman-Neilan 2nd kind H1": finat.GuzmanNeilanSecondKindH1, + "Guzman-Neilan H1(div)": finat.GuzmanNeilanH1div, + "Guzman-Neilan Bubble": finat.GuzmanNeilanBubble, + "Johnson-Mercier": finat.JohnsonMercier, + "Lagrange": finat.Lagrange, + "Kong-Mulder-Veldhuizen": finat.KongMulderVeldhuizen, + "Gauss-Lobatto-Legendre": finat.GaussLobattoLegendre, + "Gauss-Legendre": finat.GaussLegendre, + "Gauss-Legendre L2": finat.GaussLegendre, + "Legendre": finat.Legendre, + "Integrated Legendre": finat.IntegratedLegendre, + "Morley": finat.Morley, + "Nedelec 1st kind H(curl)": finat.Nedelec, + "Nedelec 2nd kind H(curl)": finat.NedelecSecondKind, + "Raviart-Thomas": finat.RaviartThomas, + "Real": finat.Real, + "S": finat.Serendipity, + "SminusF": finat.TrimmedSerendipityFace, + "SminusDiv": finat.TrimmedSerendipityDiv, + "SminusE": finat.TrimmedSerendipityEdge, + "SminusCurl": finat.TrimmedSerendipityCurl, + "Regge": finat.Regge, + "HDiv Trace": finat.HDivTrace, + "Hellan-Herrmann-Johnson": finat.HellanHerrmannJohnson, + "Gopalakrishnan-Lederer-Schoberl 1st kind": finat.GopalakrishnanLedererSchoberlFirstKind, + "Gopalakrishnan-Lederer-Schoberl 2nd kind": finat.GopalakrishnanLedererSchoberlSecondKind, + "Conforming Arnold-Winther": finat.ArnoldWinther, + "Nonconforming Arnold-Winther": finat.ArnoldWintherNC, + "Hu-Zhang": finat.HuZhang, + "Mardal-Tai-Winther": finat.MardalTaiWinther, + # These require special treatment + "Q": None, + "DQ": None, + "DQ L2": None, + "RTCE": None, + "RTCF": None, + "NCE": None, + "NCF": None, + } +"""A :class:`.dict` mapping UFL element family names to their +FInAT-equivalent constructors. If the value is ``None``, the UFL +element is supported, but must be handled specially because it doesn't +have a direct FInAT equivalent.""" + + +@cache +def as_fiat_cell(cell): + """Convert a ufl cell to a FIAT cell. + + :arg cell: the :class:`ufl.Cell` to convert.""" + if not isinstance(cell, ufl.AbstractCell): + raise ValueError("Expecting a UFL Cell") + return ufc_cell(cell) + + +@singledispatch +def convert(element, **kwargs): + """Handler for converting UFL elements to FInAT elements. + + :arg element: The UFL element to convert. + + Do not use this function directly, instead call + :func:`create_element`.""" + if element.family() in supported_elements: + raise ValueError("Element %s supported, but no handler provided" % element) + raise ValueError("Unsupported element type %s" % type(element)) + + +cg_interval_variants = { + "fdm": finat.FDMLagrange, + "fdm_ipdg": finat.FDMLagrange, + "fdm_quadrature": finat.FDMQuadrature, + "fdm_broken": finat.FDMBrokenH1, + "fdm_hermite": finat.FDMHermite, +} + + +dg_interval_variants = { + "fdm": finat.FDMDiscontinuousLagrange, + "fdm_quadrature": finat.FDMDiscontinuousLagrange, + "fdm_ipdg": lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args)), + "fdm_broken": finat.FDMBrokenL2, +} + + +# Base finite elements first +@convert.register(finat.ufl.FiniteElement) +def convert_finiteelement(element, **kwargs): + cell = as_fiat_cell(element.cell) + if element.family() == "Quadrature": + degree = element.degree() + scheme = element.quadrature_scheme() + if degree is None or scheme is None: + raise ValueError("Quadrature scheme and degree must be specified!") + + return finat.make_quadrature_element(cell, degree, scheme), set() + lmbda = supported_elements[element.family()] + if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}: + lmbda = None + element = finat.ufl.FiniteElement("DQ", element.cell, 0) + if lmbda is None: + if element.cell.cellname() == "quadrilateral": + # Handle quadrilateral short names like RTCF and RTCE. + element = element.reconstruct(cell=quadrilateral_tpc) + elif element.cell.cellname() == "hexahedron": + # Handle hexahedron short names like NCF and NCE. + element = element.reconstruct(cell=hexahedron_tpc) + else: + raise ValueError("%s is supported, but handled incorrectly" % + element.family()) + finat_elem, deps = _create_element(element, **kwargs) + return finat.FlattenedDimensions(finat_elem), deps + + finat_kwargs = {} + kind = element.variant() + if kind is None: + kind = 'spectral' # default variant + + if element.family() == "Lagrange": + if kind == 'spectral': + lmbda = finat.GaussLobattoLegendre + elif element.cell.cellname() == "interval" and kind in cg_interval_variants: + lmbda = cg_interval_variants[kind] + elif any(map(kind.startswith, ['integral', 'demkowicz', 'fdm'])): + lmbda = finat.IntegratedLegendre + finat_kwargs["variant"] = kind + elif kind in ['mgd', 'feec', 'qb', 'mse']: + degree = element.degree() + shift_axes = kwargs["shift_axes"] + restriction = kwargs["restriction"] + deps = {"shift_axes", "restriction"} + return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps + else: + # Let FIAT handle the general case + lmbda = finat.Lagrange + finat_kwargs["variant"] = kind + + elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]: + if kind == 'spectral': + lmbda = finat.GaussLegendre + elif element.cell.cellname() == "interval" and kind in dg_interval_variants: + lmbda = dg_interval_variants[kind] + elif any(map(kind.startswith, ['integral', 'demkowicz', 'fdm'])): + lmbda = finat.Legendre + finat_kwargs["variant"] = kind + elif kind in ['mgd', 'feec', 'qb', 'mse']: + degree = element.degree() + shift_axes = kwargs["shift_axes"] + restriction = kwargs["restriction"] + deps = {"shift_axes", "restriction"} + return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps + else: + # Let FIAT handle the general case + lmbda = finat.DiscontinuousLagrange + finat_kwargs["variant"] = kind + + elif element.variant() is not None: + finat_kwargs["variant"] = element.variant() + + return lmbda(cell, element.degree(), **finat_kwargs), set() + + +# Element modifiers and compound element types +@convert.register(finat.ufl.BrokenElement) +def convert_brokenelement(element, **kwargs): + finat_elem, deps = _create_element(element._element, **kwargs) + return finat.DiscontinuousElement(finat_elem), deps + + +@convert.register(finat.ufl.EnrichedElement) +def convert_enrichedelement(element, **kwargs): + elements, deps = zip(*[_create_element(elem, **kwargs) + for elem in element._elements]) + return finat.EnrichedElement(elements), set.union(*deps) + + +@convert.register(finat.ufl.NodalEnrichedElement) +def convert_nodalenrichedelement(element, **kwargs): + elements, deps = zip(*[_create_element(elem, **kwargs) + for elem in element._elements]) + return finat.NodalEnrichedElement(elements), set.union(*deps) + + +@convert.register(finat.ufl.MixedElement) +def convert_mixedelement(element, **kwargs): + elements, deps = zip(*[_create_element(elem, **kwargs) + for elem in element.sub_elements]) + return finat.MixedElement(elements), set.union(*deps) + + +@convert.register(finat.ufl.VectorElement) +@convert.register(finat.ufl.TensorElement) +def convert_tensorelement(element, **kwargs): + inner_elem, deps = _create_element(element.sub_elements[0], **kwargs) + shape = element.reference_value_shape + shape = shape[:len(shape) - len(inner_elem.value_shape)] + shape_innermost = kwargs["shape_innermost"] + return (finat.TensorFiniteElement(inner_elem, shape, not shape_innermost), + deps | {"shape_innermost"}) + + +@convert.register(finat.ufl.TensorProductElement) +def convert_tensorproductelement(element, **kwargs): + cell = element.cell + if type(cell) is not ufl.TensorProductCell: + raise ValueError("TensorProductElement not on TensorProductCell?") + shift_axes = kwargs["shift_axes"] + dim_offset = 0 + elements = [] + deps = set() + for elem in element.sub_elements: + kwargs["shift_axes"] = shift_axes + dim_offset + dim_offset += elem.cell.topological_dimension() + finat_elem, ds = _create_element(elem, **kwargs) + elements.append(finat_elem) + deps.update(ds) + return finat.TensorProductElement(elements), deps + + +@convert.register(finat.ufl.HDivElement) +def convert_hdivelement(element, **kwargs): + finat_elem, deps = _create_element(element._element, **kwargs) + return finat.HDivElement(finat_elem), deps + + +@convert.register(finat.ufl.HCurlElement) +def convert_hcurlelement(element, **kwargs): + finat_elem, deps = _create_element(element._element, **kwargs) + return finat.HCurlElement(finat_elem), deps + + +@convert.register(finat.ufl.WithMapping) +def convert_withmapping(element, **kwargs): + return _create_element(element.wrapee, **kwargs) + + +@convert.register(finat.ufl.RestrictedElement) +def convert_restrictedelement(element, **kwargs): + finat_elem, deps = _create_element(element._element, **kwargs) + return finat.RestrictedElement(finat_elem, element.restriction_domain()), deps + + +hexahedron_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval, ufl.interval) +quadrilateral_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval) +_cache = weakref.WeakKeyDictionary() + + +def create_element(ufl_element, shape_innermost=True, shift_axes=0, restriction=None): + """Create a FInAT element (suitable for tabulating with) given a UFL element. + + :arg ufl_element: The UFL element to create a FInAT element from. + :arg shape_innermost: Vector/tensor indices come after basis function indices + :arg restriction: cell restriction in interior facet integrals + (only for runtime tabulated elements) + """ + finat_element, deps = _create_element(ufl_element, + shape_innermost=shape_innermost, + shift_axes=shift_axes, + restriction=restriction) + return finat_element + + +def _create_element(ufl_element, **kwargs): + """A caching wrapper around :py:func:`convert`. + + Takes a UFL element and an unspecified set of parameter options, + and returns the converted element with the set of keyword names + that were relevant for conversion. + """ + # Look up conversion in cache + try: + cache = _cache[ufl_element] + except KeyError: + _cache[ufl_element] = {} + cache = _cache[ufl_element] + + for key, finat_element in cache.items(): + # Cache hit if all relevant parameter values match. + if all(kwargs[param] == value for param, value in key): + return finat_element, set(param for param, value in key) + + # Convert if cache miss + if ufl_element.cell is None: + raise ValueError("Don't know how to build element when cell is not given") + + finat_element, deps = convert(ufl_element, **kwargs) + + # Store conversion in cache + key = frozenset((param, kwargs[param]) for param in deps) + cache[key] = finat_element + + # Forward result + return finat_element, deps + + +def create_base_element(ufl_element, **kwargs): + """Create a "scalar" base FInAT element given a UFL element. + Takes a UFL element and an unspecified set of parameter options, + and returns the converted element. + """ + finat_element = create_element(ufl_element, **kwargs) + if isinstance(finat_element, finat.TensorFiniteElement): + finat_element = finat_element.base_element + return finat_element diff --git a/finat/ufl/elementlist.py b/finat/ufl/elementlist.py index 172eb78f..d487f83e 100644 --- a/finat/ufl/elementlist.py +++ b/finat/ufl/elementlist.py @@ -184,8 +184,8 @@ def show_elements(): register_element("S", None, 0, H1, "identity", (1, None), cubes) register_element("DPC", None, 0, L2, "identity", (0, None), cubes) -register_element("BDMCE", None, 1, HCurl, "covariant Piola", (1, None), ("quadrilateral",)) -register_element("BDMCF", None, 1, HDiv, "contravariant Piola", (1, None), ("quadrilateral",)) +register_element("Brezzi-Douglas-Marini Cube Edge", "BDMCE", 1, HCurl, "covariant Piola", (1, None), ("quadrilateral",)) +register_element("Brezzi-Douglas-Marini Cube Face", "BDMCF", 1, HDiv, "contravariant Piola", (1, None), ("quadrilateral",)) register_element("SminusE", "SminusE", 1, HCurl, "covariant Piola", (1, None), cubes[1:3]) register_element("SminusF", "SminusF", 1, HDiv, "contravariant Piola", (1, None), cubes[1:2]) register_element("SminusDiv", "SminusDiv", 1, HDiv, "contravariant Piola", (1, None), cubes[1:3]) diff --git a/test/finat/test_create_fiat_element.py b/test/finat/test_create_fiat_element.py new file mode 100644 index 00000000..f99053ab --- /dev/null +++ b/test/finat/test_create_fiat_element.py @@ -0,0 +1,150 @@ +import pytest + +import FIAT +from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange + +import ufl +import finat.ufl +from finat.element_factory import create_element as _create_element + + +supported_elements = { + # These all map directly to FIAT elements + "Brezzi-Douglas-Marini": FIAT.BrezziDouglasMarini, + "Brezzi-Douglas-Fortin-Marini": FIAT.BrezziDouglasFortinMarini, + "Lagrange": FIAT.Lagrange, + "Nedelec 1st kind H(curl)": FIAT.Nedelec, + "Nedelec 2nd kind H(curl)": FIAT.NedelecSecondKind, + "Raviart-Thomas": FIAT.RaviartThomas, + "Regge": FIAT.Regge, +} +"""A :class:`.dict` mapping UFL element family names to their +FIAT-equivalent constructors.""" + + +def create_element(ufl_element): + """Create a FIAT element given a UFL element.""" + finat_element = _create_element(ufl_element) + return finat_element.fiat_equivalent + + +@pytest.fixture(params=["BDM", + "BDFM", + "Lagrange", + "N1curl", + "N2curl", + "RT", + "Regge"]) +def triangle_names(request): + return request.param + + +@pytest.fixture +def ufl_element(triangle_names): + return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2) + + +def test_triangle_basic(ufl_element): + element = create_element(ufl_element) + assert isinstance(element, supported_elements[ufl_element.family()]) + + +@pytest.fixture(params=["CG", "DG", "DG L2"], scope="module") +def tensor_name(request): + return request.param + + +@pytest.fixture(params=[ufl.interval, ufl.triangle, + ufl.quadrilateral], + ids=lambda x: x.cellname(), + scope="module") +def ufl_A(request, tensor_name): + return finat.ufl.FiniteElement(tensor_name, request.param, 1) + + +@pytest.fixture +def ufl_B(tensor_name): + return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1) + + +def test_tensor_prod_simple(ufl_A, ufl_B): + tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B) + + tensor = create_element(tensor_ufl) + A = create_element(ufl_A) + B = create_element(ufl_B) + + assert isinstance(tensor, FIAT.TensorProductElement) + + assert tensor.A is A + assert tensor.B is B + + +@pytest.mark.parametrize(('family', 'expected_cls'), + [('P', FIAT.GaussLobattoLegendre), + ('DP', FIAT.GaussLegendre), + ('DP L2', FIAT.GaussLegendre)]) +def test_interval_variant_default(family, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3) + assert isinstance(create_element(ufl_element), expected_cls) + + +@pytest.mark.parametrize(('family', 'variant', 'expected_cls'), + [('P', 'equispaced', FIAT.Lagrange), + ('P', 'spectral', FIAT.GaussLobattoLegendre), + ('DP', 'equispaced', FIAT_DiscontinuousLagrange), + ('DP', 'spectral', FIAT.GaussLegendre), + ('DP L2', 'equispaced', FIAT_DiscontinuousLagrange), + ('DP L2', 'spectral', FIAT.GaussLegendre)]) +def test_interval_variant(family, variant, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant) + assert isinstance(create_element(ufl_element), expected_cls) + + +def test_triangle_variant_spectral(): + ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_triangle_variant_spectral_l2(): + ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_quadrilateral_variant_spectral_q(): + element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral')) + assert isinstance(element.element.A, FIAT.GaussLobattoLegendre) + assert isinstance(element.element.B, FIAT.GaussLobattoLegendre) + + +def test_quadrilateral_variant_spectral_dq(): + element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.element.A, FIAT.GaussLegendre) + assert isinstance(element.element.B, FIAT.GaussLegendre) + + +def test_quadrilateral_variant_spectral_dq_l2(): + element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.element.A, FIAT.GaussLegendre) + assert isinstance(element.element.B, FIAT.GaussLegendre) + + +def test_quadrilateral_variant_spectral_rtcf(): + element = create_element(finat.ufl.FiniteElement('RTCF', ufl.quadrilateral, 2, variant='spectral')) + assert isinstance(element.element._elements[0].A, FIAT.GaussLobattoLegendre) + assert isinstance(element.element._elements[0].B, FIAT.GaussLegendre) + assert isinstance(element.element._elements[1].A, FIAT.GaussLegendre) + assert isinstance(element.element._elements[1].B, FIAT.GaussLobattoLegendre) + + +def test_cache_hit(ufl_element): + A = create_element(ufl_element) + B = create_element(ufl_element) + + assert A is B + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]) diff --git a/test/finat/test_create_finat_element.py b/test/finat/test_create_finat_element.py new file mode 100644 index 00000000..7964824c --- /dev/null +++ b/test/finat/test_create_finat_element.py @@ -0,0 +1,138 @@ +import pytest + +import ufl +import finat.ufl +import finat +from finat.element_factory import create_element, supported_elements + + +@pytest.fixture(params=["BDM", + "BDFM", + "Lagrange", + "N1curl", + "N2curl", + "RT", + "Regge"]) +def triangle_names(request): + return request.param + + +@pytest.fixture +def ufl_element(triangle_names): + return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2) + + +def test_triangle_basic(ufl_element): + element = create_element(ufl_element) + assert isinstance(element, supported_elements[ufl_element.family()]) + + +@pytest.fixture +def ufl_vector_element(triangle_names): + return finat.ufl.VectorElement(triangle_names, ufl.triangle, 2) + + +def test_triangle_vector(ufl_element, ufl_vector_element): + scalar = create_element(ufl_element) + vector = create_element(ufl_vector_element) + + assert isinstance(vector, finat.TensorFiniteElement) + assert scalar == vector.base_element + + +@pytest.fixture(params=["CG", "DG", "DG L2"]) +def tensor_name(request): + return request.param + + +@pytest.fixture(params=[ufl.interval, ufl.triangle, + ufl.quadrilateral], + ids=lambda x: x.cellname()) +def ufl_A(request, tensor_name): + return finat.ufl.FiniteElement(tensor_name, request.param, 1) + + +@pytest.fixture +def ufl_B(tensor_name): + return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1) + + +def test_tensor_prod_simple(ufl_A, ufl_B): + tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B) + + tensor = create_element(tensor_ufl) + A = create_element(ufl_A) + B = create_element(ufl_B) + + assert isinstance(tensor, finat.TensorProductElement) + + assert tensor.factors == (A, B) + + +@pytest.mark.parametrize(('family', 'expected_cls'), + [('P', finat.GaussLobattoLegendre), + ('DP', finat.GaussLegendre), + ('DP L2', finat.GaussLegendre)]) +def test_interval_variant_default(family, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3) + assert isinstance(create_element(ufl_element), expected_cls) + + +@pytest.mark.parametrize(('family', 'variant', 'expected_cls'), + [('P', 'equispaced', finat.Lagrange), + ('P', 'spectral', finat.GaussLobattoLegendre), + ('DP', 'equispaced', finat.DiscontinuousLagrange), + ('DP', 'spectral', finat.GaussLegendre), + ('DP L2', 'equispaced', finat.DiscontinuousLagrange), + ('DP L2', 'spectral', finat.GaussLegendre)]) +def test_interval_variant(family, variant, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant) + assert isinstance(create_element(ufl_element), expected_cls) + + +def test_triangle_variant_spectral(): + ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_triangle_variant_spectral_l2(): + ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_quadrilateral_variant_spectral_q(): + element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral')) + assert isinstance(element.product.factors[0], finat.GaussLobattoLegendre) + assert isinstance(element.product.factors[1], finat.GaussLobattoLegendre) + + +def test_quadrilateral_variant_spectral_dq(): + element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.product.factors[0], finat.GaussLegendre) + assert isinstance(element.product.factors[1], finat.GaussLegendre) + + +def test_quadrilateral_variant_spectral_dq_l2(): + element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.product.factors[0], finat.GaussLegendre) + assert isinstance(element.product.factors[1], finat.GaussLegendre) + + +def test_cache_hit(ufl_element): + A = create_element(ufl_element) + B = create_element(ufl_element) + + assert A is B + + +def test_cache_hit_vector(ufl_vector_element): + A = create_element(ufl_vector_element) + B = create_element(ufl_vector_element) + + assert A is B + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])