Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Fix interpolation for Arguments of zany spaces #286

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/test_dual_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,16 @@ def test_ufl_only_shape_mismatch():
assert to_element.value_shape == (2,)
with pytest.raises(ValueError):
compile_expression_dual_evaluation(expr, to_element, W.ufl_element())


@pytest.mark.parametrize("modifier",
[lambda x: x, ufl.VectorElement, ufl.TensorElement],
ids=["scalar", "vector", "tensor"])
def test_interpolate_zany_argument(modifier):
mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 2))
V = ufl.FunctionSpace(mesh, modifier(ufl.FiniteElement("P", ufl.triangle, 1)))
Q = ufl.FunctionSpace(mesh, modifier(ufl.FiniteElement("Argyris", ufl.triangle, 5)))
expr = ufl.TestFunction(Q)
to_element = create_element(V.ufl_element())
kernel = compile_expression_dual_evaluation(expr, to_element, V.ufl_element())
assert kernel.needs_external_coords
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not checking for correctness, just that a kernel actually comes out the other end.

5 changes: 4 additions & 1 deletion tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
builder.set_coefficient_numbers(coefficient_numbers)

needs_external_coords = False
if has_type(expression, GeometricQuantity) or any(fem.needs_coordinate_mapping(c.ufl_element()) for c in coefficients):
if has_type(expression, GeometricQuantity) or any(
fem.needs_coordinate_mapping(c.ufl_element())
for c in chain(coefficients, arguments)
):
# Create a fake coordinate coefficient for a domain.
coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element()))
builder.domain_coordinate[domain] = coords_coefficient
Expand Down
56 changes: 55 additions & 1 deletion tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from gem.unconcatenate import unconcatenate
from gem.utils import cached_property

import finat
from finat.finiteelementbase import FiniteElementBase
from finat.physically_mapped import PhysicalGeometry, NeedsCoordinateMappingElement
from finat.point_set import PointSet, PointSingleton
from finat.quadrature import make_quadrature
Expand Down Expand Up @@ -246,12 +248,64 @@ def physical_vertices(self):
return self.physical_points(vs)


@singledispatch
def needs_coordinate_mapping(element):
raise AssertionError(f"Don't know how to handle {type(element)}")


@needs_coordinate_mapping.register(ufl.FiniteElementBase)
def _needs_coordinate_mapping_ufl(element):
"""Does this UFL element require a CoordinateMapping for translation?"""
if element.family() == 'Real':
return False
else:
return isinstance(create_element(element), NeedsCoordinateMappingElement)
return needs_coordinate_mapping(create_element(element))


@needs_coordinate_mapping.register(NeedsCoordinateMappingElement)
def _needs_coordinate_mapping_finat_needs_coordinate_mapping(element):
return True


@needs_coordinate_mapping.register(FiniteElementBase)
def _needs_coordinate_mapping_finat_base(element):
return False


@needs_coordinate_mapping.register(finat.DiscontinuousElement)
def _needs_coordinate_mapping_finat_discontinuous(element):
return needs_coordinate_mapping(element.element)


@needs_coordinate_mapping.register(finat.FlattenedDimensions)
def _needs_coordinate_mapping_finat_cube(element):
return needs_coordinate_mapping(element.product)


@needs_coordinate_mapping.register(finat.TensorProductElement)
def _needs_coordinate_mapping_finat_tpe(element):
return any(map(needs_coordinate_mapping, element.factors))


@needs_coordinate_mapping.register(finat.TensorFiniteElement)
def _needs_coordinate_mapping_finat_tfe(element):
return needs_coordinate_mapping(element.base_element)


@needs_coordinate_mapping.register(finat.EnrichedElement)
def _needs_coordinate_mapping_finat_enriched(element):
return any(map(needs_coordinate_mapping, element.elements))


@needs_coordinate_mapping.register(finat.mixed.MixedSubElement)
def _needs_coordinate_mapping_finat_mixed(element):
return needs_coordinate_mapping(element.element)


@needs_coordinate_mapping.register(finat.HDivElement)
@needs_coordinate_mapping.register(finat.HCurlElement)
def _needs_coordinate_mapping_finat_hdivcurl(element):
return needs_coordinate_mapping(element.wrappee)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put this here because it meant I didn't need a PR in finat as well, but probably finat should offer a traverse function that yields all the elements in a structured element. Everything needs special-casing because elements don't have regular "children" slots.

e.g. ideally you would just write (if you don't care about ordering)

def traverse(element):
    yield element
    yield from map(traverse, element.children)



class PointSetContext(ContextBase):
Expand Down