From bcc3ebe1efb43108609f4557f462244bbcb632f3 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 10:34:49 -0600 Subject: [PATCH] test_create_fi(n)at_element --- FIAT/macro.py | 3 - test/finat/test_create_fiat_element.py | 150 ++++++++++++++++++++++++ test/finat/test_create_finat_element.py | 138 ++++++++++++++++++++++ 3 files changed, 288 insertions(+), 3 deletions(-) create mode 100644 test/finat/test_create_fiat_element.py create mode 100644 test/finat/test_create_finat_element.py diff --git a/FIAT/macro.py b/FIAT/macro.py index 18b95114..77ab8938 100644 --- a/FIAT/macro.py +++ b/FIAT/macro.py @@ -1,6 +1,5 @@ from itertools import chain, combinations -from functools import cache import numpy from FIAT import expansions, polynomial_set @@ -196,7 +195,6 @@ def get_parent_complex(self): return self._parent_complex -@cache class IsoSplit(SplitSimplicialComplex): """Splits simplex into the simplicial complex obtained by connecting points on a regular lattice. @@ -302,7 +300,6 @@ def construct_subcomplex(self, dimension): return PowellSabinSplit(subcomplex, dimension=self.split_dimension) -@cache class AlfeldSplit(PowellSabinSplit): """Splits a simplicial complex by connecting cell vertices to their barycenter. diff --git a/test/finat/test_create_fiat_element.py b/test/finat/test_create_fiat_element.py new file mode 100644 index 00000000..f99053ab --- /dev/null +++ b/test/finat/test_create_fiat_element.py @@ -0,0 +1,150 @@ +import pytest + +import FIAT +from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange + +import ufl +import finat.ufl +from finat.element_factory import create_element as _create_element + + +supported_elements = { + # These all map directly to FIAT elements + "Brezzi-Douglas-Marini": FIAT.BrezziDouglasMarini, + "Brezzi-Douglas-Fortin-Marini": FIAT.BrezziDouglasFortinMarini, + "Lagrange": FIAT.Lagrange, + "Nedelec 1st kind H(curl)": FIAT.Nedelec, + "Nedelec 2nd kind H(curl)": FIAT.NedelecSecondKind, + "Raviart-Thomas": FIAT.RaviartThomas, + "Regge": FIAT.Regge, +} +"""A :class:`.dict` mapping UFL element family names to their +FIAT-equivalent constructors.""" + + +def create_element(ufl_element): + """Create a FIAT element given a UFL element.""" + finat_element = _create_element(ufl_element) + return finat_element.fiat_equivalent + + +@pytest.fixture(params=["BDM", + "BDFM", + "Lagrange", + "N1curl", + "N2curl", + "RT", + "Regge"]) +def triangle_names(request): + return request.param + + +@pytest.fixture +def ufl_element(triangle_names): + return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2) + + +def test_triangle_basic(ufl_element): + element = create_element(ufl_element) + assert isinstance(element, supported_elements[ufl_element.family()]) + + +@pytest.fixture(params=["CG", "DG", "DG L2"], scope="module") +def tensor_name(request): + return request.param + + +@pytest.fixture(params=[ufl.interval, ufl.triangle, + ufl.quadrilateral], + ids=lambda x: x.cellname(), + scope="module") +def ufl_A(request, tensor_name): + return finat.ufl.FiniteElement(tensor_name, request.param, 1) + + +@pytest.fixture +def ufl_B(tensor_name): + return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1) + + +def test_tensor_prod_simple(ufl_A, ufl_B): + tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B) + + tensor = create_element(tensor_ufl) + A = create_element(ufl_A) + B = create_element(ufl_B) + + assert isinstance(tensor, FIAT.TensorProductElement) + + assert tensor.A is A + assert tensor.B is B + + +@pytest.mark.parametrize(('family', 'expected_cls'), + [('P', FIAT.GaussLobattoLegendre), + ('DP', FIAT.GaussLegendre), + ('DP L2', FIAT.GaussLegendre)]) +def test_interval_variant_default(family, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3) + assert isinstance(create_element(ufl_element), expected_cls) + + +@pytest.mark.parametrize(('family', 'variant', 'expected_cls'), + [('P', 'equispaced', FIAT.Lagrange), + ('P', 'spectral', FIAT.GaussLobattoLegendre), + ('DP', 'equispaced', FIAT_DiscontinuousLagrange), + ('DP', 'spectral', FIAT.GaussLegendre), + ('DP L2', 'equispaced', FIAT_DiscontinuousLagrange), + ('DP L2', 'spectral', FIAT.GaussLegendre)]) +def test_interval_variant(family, variant, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant) + assert isinstance(create_element(ufl_element), expected_cls) + + +def test_triangle_variant_spectral(): + ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_triangle_variant_spectral_l2(): + ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_quadrilateral_variant_spectral_q(): + element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral')) + assert isinstance(element.element.A, FIAT.GaussLobattoLegendre) + assert isinstance(element.element.B, FIAT.GaussLobattoLegendre) + + +def test_quadrilateral_variant_spectral_dq(): + element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.element.A, FIAT.GaussLegendre) + assert isinstance(element.element.B, FIAT.GaussLegendre) + + +def test_quadrilateral_variant_spectral_dq_l2(): + element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.element.A, FIAT.GaussLegendre) + assert isinstance(element.element.B, FIAT.GaussLegendre) + + +def test_quadrilateral_variant_spectral_rtcf(): + element = create_element(finat.ufl.FiniteElement('RTCF', ufl.quadrilateral, 2, variant='spectral')) + assert isinstance(element.element._elements[0].A, FIAT.GaussLobattoLegendre) + assert isinstance(element.element._elements[0].B, FIAT.GaussLegendre) + assert isinstance(element.element._elements[1].A, FIAT.GaussLegendre) + assert isinstance(element.element._elements[1].B, FIAT.GaussLobattoLegendre) + + +def test_cache_hit(ufl_element): + A = create_element(ufl_element) + B = create_element(ufl_element) + + assert A is B + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]) diff --git a/test/finat/test_create_finat_element.py b/test/finat/test_create_finat_element.py new file mode 100644 index 00000000..7964824c --- /dev/null +++ b/test/finat/test_create_finat_element.py @@ -0,0 +1,138 @@ +import pytest + +import ufl +import finat.ufl +import finat +from finat.element_factory import create_element, supported_elements + + +@pytest.fixture(params=["BDM", + "BDFM", + "Lagrange", + "N1curl", + "N2curl", + "RT", + "Regge"]) +def triangle_names(request): + return request.param + + +@pytest.fixture +def ufl_element(triangle_names): + return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2) + + +def test_triangle_basic(ufl_element): + element = create_element(ufl_element) + assert isinstance(element, supported_elements[ufl_element.family()]) + + +@pytest.fixture +def ufl_vector_element(triangle_names): + return finat.ufl.VectorElement(triangle_names, ufl.triangle, 2) + + +def test_triangle_vector(ufl_element, ufl_vector_element): + scalar = create_element(ufl_element) + vector = create_element(ufl_vector_element) + + assert isinstance(vector, finat.TensorFiniteElement) + assert scalar == vector.base_element + + +@pytest.fixture(params=["CG", "DG", "DG L2"]) +def tensor_name(request): + return request.param + + +@pytest.fixture(params=[ufl.interval, ufl.triangle, + ufl.quadrilateral], + ids=lambda x: x.cellname()) +def ufl_A(request, tensor_name): + return finat.ufl.FiniteElement(tensor_name, request.param, 1) + + +@pytest.fixture +def ufl_B(tensor_name): + return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1) + + +def test_tensor_prod_simple(ufl_A, ufl_B): + tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B) + + tensor = create_element(tensor_ufl) + A = create_element(ufl_A) + B = create_element(ufl_B) + + assert isinstance(tensor, finat.TensorProductElement) + + assert tensor.factors == (A, B) + + +@pytest.mark.parametrize(('family', 'expected_cls'), + [('P', finat.GaussLobattoLegendre), + ('DP', finat.GaussLegendre), + ('DP L2', finat.GaussLegendre)]) +def test_interval_variant_default(family, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3) + assert isinstance(create_element(ufl_element), expected_cls) + + +@pytest.mark.parametrize(('family', 'variant', 'expected_cls'), + [('P', 'equispaced', finat.Lagrange), + ('P', 'spectral', finat.GaussLobattoLegendre), + ('DP', 'equispaced', finat.DiscontinuousLagrange), + ('DP', 'spectral', finat.GaussLegendre), + ('DP L2', 'equispaced', finat.DiscontinuousLagrange), + ('DP L2', 'spectral', finat.GaussLegendre)]) +def test_interval_variant(family, variant, expected_cls): + ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant) + assert isinstance(create_element(ufl_element), expected_cls) + + +def test_triangle_variant_spectral(): + ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_triangle_variant_spectral_l2(): + ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral') + create_element(ufl_element) + + +def test_quadrilateral_variant_spectral_q(): + element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral')) + assert isinstance(element.product.factors[0], finat.GaussLobattoLegendre) + assert isinstance(element.product.factors[1], finat.GaussLobattoLegendre) + + +def test_quadrilateral_variant_spectral_dq(): + element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.product.factors[0], finat.GaussLegendre) + assert isinstance(element.product.factors[1], finat.GaussLegendre) + + +def test_quadrilateral_variant_spectral_dq_l2(): + element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral')) + assert isinstance(element.product.factors[0], finat.GaussLegendre) + assert isinstance(element.product.factors[1], finat.GaussLegendre) + + +def test_cache_hit(ufl_element): + A = create_element(ufl_element) + B = create_element(ufl_element) + + assert A is B + + +def test_cache_hit_vector(ufl_vector_element): + A = create_element(ufl_vector_element) + B = create_element(ufl_vector_element) + + assert A is B + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])