Skip to content

Commit

Permalink
Fix mapping of coefficients and constants for form independent compil…
Browse files Browse the repository at this point in the history
…ation. (#3597)

* Fix packing of constants and coefficients

* Fix naming

* Add test

* Fix numbering of constants
  • Loading branch information
jorgensd authored Jan 13, 2025
1 parent f045c42 commit 477076e
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 7 deletions.
35 changes: 28 additions & 7 deletions python/dolfinx/fem/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,13 +508,34 @@ def create_form(
for _, idomain in _subdomain_data.items():
idomain.sort(key=lambda x: x[0])

# Extract name of ufl objects and map them to their corresponding C++ object
ufl_coefficients = ufl.algorithms.extract_coefficients(form.ufl_form)
coefficients = {
f"w{ufl_coefficients.index(u)}": uh._cpp_object for (u, uh) in coefficient_map.items()
}
ufl_constants = ufl.algorithms.analysis.extract_constants(form.ufl_form)
constants = {f"c{ufl_constants.index(u)}": uh._cpp_object for (u, uh) in constant_map.items()}
# Extract all coefficients of the compiled form in correct order
coefficients = {}
original_coefficients = ufl.algorithms.extract_coefficients(form.ufl_form)
num_coefficients = form.ufcx_form.num_coefficients
for c in range(num_coefficients):
original_index = form.ufcx_form.original_coefficient_positions[c]
original_coeff = original_coefficients[original_index]
try:
coefficients[f"w{c}"] = coefficient_map[original_coeff]._cpp_object
except KeyError:
raise RuntimeError(f"Missing coefficient {original_coeff}")

# Extract all constants of the compiled form in correct order
# NOTE: Constants are not eliminated
original_constants = ufl.algorithms.analysis.extract_constants(form.ufl_form)
num_constants = form.ufcx_form.num_constants
if num_constants != len(original_constants):
raise RuntimeError(
f"Number of constants in compiled form ({num_constants})",
f"does not match the original form {len(original_constants)}",
)
constants = {}
for counter, constant in enumerate(original_constants):
try:
mapped_constant = constant_map[constant]
constants[f"c{counter}"] = mapped_constant._cpp_object
except KeyError:
raise RuntimeError(f"Missing constant {constant}")

ftype = form_cpp_creator(form.dtype)
f = ftype(
Expand Down
73 changes: 73 additions & 0 deletions python/test/unit/fem/test_assemble_mesh_independent_form.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (C) 2024-2025 Jørgen S. Dokken
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from mpi4py import MPI

import numpy as np
Expand Down Expand Up @@ -152,3 +158,70 @@ def g(x):
# to dolfinx functions and constants
for i in range(1, 4):
create_and_integrate(i, compiled_form)


@pytest.mark.parametrize(
"dtype",
[
np.float32,
np.float64,
pytest.param(np.complex64, marks=pytest.mark.xfail_win32_complex),
pytest.param(np.complex128, marks=pytest.mark.xfail_win32_complex),
],
)
def test_eliminated_data(dtype):
"""
Test that mesh independent compilation handles the re-ordering of coefficients and constants
when removed through differentiation
"""

cell_name = "triangle"
real_type = dtype(0).real.dtype
c_el = basix.ufl.element("Lagrange", cell_name, 1, shape=(2,), dtype=real_type)
domain = ufl.Mesh(c_el)
el = basix.ufl.element("Lagrange", cell_name, 2, dtype=real_type)

V = ufl.FunctionSpace(domain, el)

c = ufl.Constant(domain)
d = ufl.Constant(domain)
u = ufl.Coefficient(V)
v = ufl.Coefficient(V)

J = (c * u**2 + d * v**2) * ufl.dx
dv = ufl.conj(ufl.TestFunction(V))
L = ufl.derivative(J, v, dv)

# Compile form using dolfinx.jit.ffcx_jit
compiled_form = dolfinx.fem.compile_form(
MPI.COMM_WORLD, L, form_compiler_options={"scalar_type": dtype}
)

# Pack discrete data
cell_type = dolfinx.mesh.to_type(cell_name)
mesh = dolfinx.mesh.create_unit_square(
MPI.COMM_WORLD, 5, 2, dtype=real_type, cell_type=cell_type
)
Vh = dolfinx.fem.functionspace(mesh, el)
uh = dolfinx.fem.Function(Vh, dtype=dtype)
uh.interpolate(lambda x: x[0])
vh = dolfinx.fem.Function(Vh, dtype=dtype)
vh.interpolate(lambda x: x[1])
dh = dolfinx.fem.Constant(mesh, dtype(3.0))
ch = dolfinx.fem.Constant(mesh, dtype(2.0))

# Assemble discrete vector
form = dolfinx.fem.create_form(compiled_form, [Vh], mesh, {}, {u: uh, v: vh}, {c: ch, d: dh})
b = dolfinx.fem.assemble_vector(form)
b.scatter_reverse(dolfinx.la.InsertMode.add)
b.scatter_forward()

# Compare to reference solution
dvh = ufl.conj(ufl.TestFunction(Vh))
exact_form = 2 * dh * vh * dvh * ufl.dx
b_exact = dolfinx.fem.assemble_vector(dolfinx.fem.form(exact_form, dtype=dtype))
b_exact.scatter_reverse(dolfinx.la.InsertMode.add)
b_exact.scatter_forward()

tol = np.finfo(dtype).resolution * 1e3
np.testing.assert_allclose(b.array, b_exact.array, atol=tol)

0 comments on commit 477076e

Please sign in to comment.