diff --git a/tests/test_dual_evaluation.py b/tests/test_dual_evaluation.py index 461226c2..c02cb688 100644 --- a/tests/test_dual_evaluation.py +++ b/tests/test_dual_evaluation.py @@ -59,3 +59,16 @@ def test_ufl_only_shape_mismatch(): assert to_element.value_shape == (2,) with pytest.raises(ValueError): compile_expression_dual_evaluation(expr, to_element, W.ufl_element()) + + +@pytest.mark.parametrize("modifier", + [lambda x: x, ufl.VectorElement, ufl.TensorElement], + ids=["scalar", "vector", "tensor"]) +def test_interpolate_zany_argument(modifier): + mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 2)) + V = ufl.FunctionSpace(mesh, modifier(ufl.FiniteElement("P", ufl.triangle, 1))) + Q = ufl.FunctionSpace(mesh, modifier(ufl.FiniteElement("Argyris", ufl.triangle, 5))) + expr = ufl.TestFunction(Q) + to_element = create_element(V.ufl_element()) + kernel = compile_expression_dual_evaluation(expr, to_element, V.ufl_element()) + assert kernel.needs_external_coords diff --git a/tsfc/driver.py b/tsfc/driver.py index e01f0820..79ee4f58 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -222,7 +222,10 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, builder.set_coefficient_numbers(coefficient_numbers) needs_external_coords = False - if has_type(expression, GeometricQuantity) or any(fem.needs_coordinate_mapping(c.ufl_element()) for c in coefficients): + if has_type(expression, GeometricQuantity) or any( + fem.needs_coordinate_mapping(c.ufl_element()) + for c in chain(coefficients, arguments) + ): # Create a fake coordinate coefficient for a domain. coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element())) builder.domain_coordinate[domain] = coords_coefficient diff --git a/tsfc/fem.py b/tsfc/fem.py index 09882208..1639a6bc 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -31,6 +31,8 @@ from gem.unconcatenate import unconcatenate from gem.utils import cached_property +import finat +from finat.finiteelementbase import FiniteElementBase from finat.physically_mapped import PhysicalGeometry, NeedsCoordinateMappingElement from finat.point_set import PointSet, PointSingleton from finat.quadrature import make_quadrature @@ -246,12 +248,64 @@ def physical_vertices(self): return self.physical_points(vs) +@singledispatch def needs_coordinate_mapping(element): + raise AssertionError(f"Don't know how to handle {type(element)}") + + +@needs_coordinate_mapping.register(ufl.FiniteElementBase) +def _needs_coordinate_mapping_ufl(element): """Does this UFL element require a CoordinateMapping for translation?""" if element.family() == 'Real': return False else: - return isinstance(create_element(element), NeedsCoordinateMappingElement) + return needs_coordinate_mapping(create_element(element)) + + +@needs_coordinate_mapping.register(NeedsCoordinateMappingElement) +def _needs_coordinate_mapping_finat_needs_coordinate_mapping(element): + return True + + +@needs_coordinate_mapping.register(FiniteElementBase) +def _needs_coordinate_mapping_finat_base(element): + return False + + +@needs_coordinate_mapping.register(finat.DiscontinuousElement) +def _needs_coordinate_mapping_finat_discontinuous(element): + return needs_coordinate_mapping(element.element) + + +@needs_coordinate_mapping.register(finat.FlattenedDimensions) +def _needs_coordinate_mapping_finat_cube(element): + return needs_coordinate_mapping(element.product) + + +@needs_coordinate_mapping.register(finat.TensorProductElement) +def _needs_coordinate_mapping_finat_tpe(element): + return any(map(needs_coordinate_mapping, element.factors)) + + +@needs_coordinate_mapping.register(finat.TensorFiniteElement) +def _needs_coordinate_mapping_finat_tfe(element): + return needs_coordinate_mapping(element.base_element) + + +@needs_coordinate_mapping.register(finat.EnrichedElement) +def _needs_coordinate_mapping_finat_enriched(element): + return any(map(needs_coordinate_mapping, element.elements)) + + +@needs_coordinate_mapping.register(finat.mixed.MixedSubElement) +def _needs_coordinate_mapping_finat_mixed(element): + return needs_coordinate_mapping(element.element) + + +@needs_coordinate_mapping.register(finat.HDivElement) +@needs_coordinate_mapping.register(finat.HCurlElement) +def _needs_coordinate_mapping_finat_hdivcurl(element): + return needs_coordinate_mapping(element.wrappee) class PointSetContext(ContextBase):