From 5782db545fb8c377ed1ef839dffc0f0797608911 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 24 Jan 2023 21:46:23 +0000 Subject: [PATCH 1/2] Fix interpolation for Arguments of zany spaces Previously we only checked if any coefficients in the source expression needed a coordinate mapping, but we should also check if any arguments do. --- tests/test_dual_evaluation.py | 10 ++++++++++ tsfc/driver.py | 5 ++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_dual_evaluation.py b/tests/test_dual_evaluation.py index 461226c2..8a537e98 100644 --- a/tests/test_dual_evaluation.py +++ b/tests/test_dual_evaluation.py @@ -59,3 +59,13 @@ def test_ufl_only_shape_mismatch(): assert to_element.value_shape == (2,) with pytest.raises(ValueError): compile_expression_dual_evaluation(expr, to_element, W.ufl_element()) + + +def test_interpolate_zany_argument(): + mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 2)) + V = ufl.FunctionSpace(mesh, ufl.FiniteElement("P", ufl.triangle, 1)) + Q = ufl.FunctionSpace(mesh, ufl.FiniteElement("Argyris", ufl.triangle, 5)) + expr = ufl.TestFunction(Q) + to_element = create_element(V.ufl_element()) + kernel = compile_expression_dual_evaluation(expr, to_element, V.ufl_element()) + assert kernel.needs_external_coords diff --git a/tsfc/driver.py b/tsfc/driver.py index e01f0820..79ee4f58 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -222,7 +222,10 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, builder.set_coefficient_numbers(coefficient_numbers) needs_external_coords = False - if has_type(expression, GeometricQuantity) or any(fem.needs_coordinate_mapping(c.ufl_element()) for c in coefficients): + if has_type(expression, GeometricQuantity) or any( + fem.needs_coordinate_mapping(c.ufl_element()) + for c in chain(coefficients, arguments) + ): # Create a fake coordinate coefficient for a domain. coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element())) builder.domain_coordinate[domain] = coords_coefficient From f1706ab021c12c387550a185c73787d648ce7720 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Sun, 29 Jan 2023 13:28:10 +0000 Subject: [PATCH 2/2] Handle nested zany elements in interpolation When determinining if the kernel needs a coordinate mapping we need to check if terminal elements need it (e.g. VectorElement(Argyris)). Previously we only checked the top-level element type which was incorrect for this case. --- tests/test_dual_evaluation.py | 9 ++++-- tsfc/fem.py | 56 ++++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/tests/test_dual_evaluation.py b/tests/test_dual_evaluation.py index 8a537e98..c02cb688 100644 --- a/tests/test_dual_evaluation.py +++ b/tests/test_dual_evaluation.py @@ -61,10 +61,13 @@ def test_ufl_only_shape_mismatch(): compile_expression_dual_evaluation(expr, to_element, W.ufl_element()) -def test_interpolate_zany_argument(): +@pytest.mark.parametrize("modifier", + [lambda x: x, ufl.VectorElement, ufl.TensorElement], + ids=["scalar", "vector", "tensor"]) +def test_interpolate_zany_argument(modifier): mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 2)) - V = ufl.FunctionSpace(mesh, ufl.FiniteElement("P", ufl.triangle, 1)) - Q = ufl.FunctionSpace(mesh, ufl.FiniteElement("Argyris", ufl.triangle, 5)) + V = ufl.FunctionSpace(mesh, modifier(ufl.FiniteElement("P", ufl.triangle, 1))) + Q = ufl.FunctionSpace(mesh, modifier(ufl.FiniteElement("Argyris", ufl.triangle, 5))) expr = ufl.TestFunction(Q) to_element = create_element(V.ufl_element()) kernel = compile_expression_dual_evaluation(expr, to_element, V.ufl_element()) diff --git a/tsfc/fem.py b/tsfc/fem.py index 09882208..1639a6bc 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -31,6 +31,8 @@ from gem.unconcatenate import unconcatenate from gem.utils import cached_property +import finat +from finat.finiteelementbase import FiniteElementBase from finat.physically_mapped import PhysicalGeometry, NeedsCoordinateMappingElement from finat.point_set import PointSet, PointSingleton from finat.quadrature import make_quadrature @@ -246,12 +248,64 @@ def physical_vertices(self): return self.physical_points(vs) +@singledispatch def needs_coordinate_mapping(element): + raise AssertionError(f"Don't know how to handle {type(element)}") + + +@needs_coordinate_mapping.register(ufl.FiniteElementBase) +def _needs_coordinate_mapping_ufl(element): """Does this UFL element require a CoordinateMapping for translation?""" if element.family() == 'Real': return False else: - return isinstance(create_element(element), NeedsCoordinateMappingElement) + return needs_coordinate_mapping(create_element(element)) + + +@needs_coordinate_mapping.register(NeedsCoordinateMappingElement) +def _needs_coordinate_mapping_finat_needs_coordinate_mapping(element): + return True + + +@needs_coordinate_mapping.register(FiniteElementBase) +def _needs_coordinate_mapping_finat_base(element): + return False + + +@needs_coordinate_mapping.register(finat.DiscontinuousElement) +def _needs_coordinate_mapping_finat_discontinuous(element): + return needs_coordinate_mapping(element.element) + + +@needs_coordinate_mapping.register(finat.FlattenedDimensions) +def _needs_coordinate_mapping_finat_cube(element): + return needs_coordinate_mapping(element.product) + + +@needs_coordinate_mapping.register(finat.TensorProductElement) +def _needs_coordinate_mapping_finat_tpe(element): + return any(map(needs_coordinate_mapping, element.factors)) + + +@needs_coordinate_mapping.register(finat.TensorFiniteElement) +def _needs_coordinate_mapping_finat_tfe(element): + return needs_coordinate_mapping(element.base_element) + + +@needs_coordinate_mapping.register(finat.EnrichedElement) +def _needs_coordinate_mapping_finat_enriched(element): + return any(map(needs_coordinate_mapping, element.elements)) + + +@needs_coordinate_mapping.register(finat.mixed.MixedSubElement) +def _needs_coordinate_mapping_finat_mixed(element): + return needs_coordinate_mapping(element.element) + + +@needs_coordinate_mapping.register(finat.HDivElement) +@needs_coordinate_mapping.register(finat.HCurlElement) +def _needs_coordinate_mapping_finat_hdivcurl(element): + return needs_coordinate_mapping(element.wrappee) class PointSetContext(ContextBase):