From 477076efebd381922ed2b9adef1fb6bcda04b9a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Schartum=20Dokken?= Date: Mon, 13 Jan 2025 15:48:00 +0100 Subject: [PATCH] Fix mapping of coefficients and constants for form independent compilation. (#3597) * Fix packing of constants and coefficients * Fix naming * Add test * Fix numbering of constants --- python/dolfinx/fem/forms.py | 35 +++++++-- .../test_assemble_mesh_independent_form.py | 73 +++++++++++++++++++ 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index 85ed796d011..335ef212594 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -508,13 +508,34 @@ def create_form( for _, idomain in _subdomain_data.items(): idomain.sort(key=lambda x: x[0]) - # Extract name of ufl objects and map them to their corresponding C++ object - ufl_coefficients = ufl.algorithms.extract_coefficients(form.ufl_form) - coefficients = { - f"w{ufl_coefficients.index(u)}": uh._cpp_object for (u, uh) in coefficient_map.items() - } - ufl_constants = ufl.algorithms.analysis.extract_constants(form.ufl_form) - constants = {f"c{ufl_constants.index(u)}": uh._cpp_object for (u, uh) in constant_map.items()} + # Extract all coefficients of the compiled form in correct order + coefficients = {} + original_coefficients = ufl.algorithms.extract_coefficients(form.ufl_form) + num_coefficients = form.ufcx_form.num_coefficients + for c in range(num_coefficients): + original_index = form.ufcx_form.original_coefficient_positions[c] + original_coeff = original_coefficients[original_index] + try: + coefficients[f"w{c}"] = coefficient_map[original_coeff]._cpp_object + except KeyError: + raise RuntimeError(f"Missing coefficient {original_coeff}") + + # Extract all constants of the compiled form in correct order + # NOTE: Constants are not eliminated + original_constants = ufl.algorithms.analysis.extract_constants(form.ufl_form) + num_constants = form.ufcx_form.num_constants + if num_constants != len(original_constants): + raise RuntimeError( + f"Number of constants in compiled form ({num_constants})", + f"does not match the original form {len(original_constants)}", + ) + constants = {} + for counter, constant in enumerate(original_constants): + try: + mapped_constant = constant_map[constant] + constants[f"c{counter}"] = mapped_constant._cpp_object + except KeyError: + raise RuntimeError(f"Missing constant {constant}") ftype = form_cpp_creator(form.dtype) f = ftype( diff --git a/python/test/unit/fem/test_assemble_mesh_independent_form.py b/python/test/unit/fem/test_assemble_mesh_independent_form.py index ce8895bba22..0c7ed671ca3 100644 --- a/python/test/unit/fem/test_assemble_mesh_independent_form.py +++ b/python/test/unit/fem/test_assemble_mesh_independent_form.py @@ -1,3 +1,9 @@ +# Copyright (C) 2024-2025 Jørgen S. Dokken +# +# This file is part of DOLFINx (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + from mpi4py import MPI import numpy as np @@ -152,3 +158,70 @@ def g(x): # to dolfinx functions and constants for i in range(1, 4): create_and_integrate(i, compiled_form) + + +@pytest.mark.parametrize( + "dtype", + [ + np.float32, + np.float64, + pytest.param(np.complex64, marks=pytest.mark.xfail_win32_complex), + pytest.param(np.complex128, marks=pytest.mark.xfail_win32_complex), + ], +) +def test_eliminated_data(dtype): + """ + Test that mesh independent compilation handles the re-ordering of coefficients and constants + when removed through differentiation + """ + + cell_name = "triangle" + real_type = dtype(0).real.dtype + c_el = basix.ufl.element("Lagrange", cell_name, 1, shape=(2,), dtype=real_type) + domain = ufl.Mesh(c_el) + el = basix.ufl.element("Lagrange", cell_name, 2, dtype=real_type) + + V = ufl.FunctionSpace(domain, el) + + c = ufl.Constant(domain) + d = ufl.Constant(domain) + u = ufl.Coefficient(V) + v = ufl.Coefficient(V) + + J = (c * u**2 + d * v**2) * ufl.dx + dv = ufl.conj(ufl.TestFunction(V)) + L = ufl.derivative(J, v, dv) + + # Compile form using dolfinx.jit.ffcx_jit + compiled_form = dolfinx.fem.compile_form( + MPI.COMM_WORLD, L, form_compiler_options={"scalar_type": dtype} + ) + + # Pack discrete data + cell_type = dolfinx.mesh.to_type(cell_name) + mesh = dolfinx.mesh.create_unit_square( + MPI.COMM_WORLD, 5, 2, dtype=real_type, cell_type=cell_type + ) + Vh = dolfinx.fem.functionspace(mesh, el) + uh = dolfinx.fem.Function(Vh, dtype=dtype) + uh.interpolate(lambda x: x[0]) + vh = dolfinx.fem.Function(Vh, dtype=dtype) + vh.interpolate(lambda x: x[1]) + dh = dolfinx.fem.Constant(mesh, dtype(3.0)) + ch = dolfinx.fem.Constant(mesh, dtype(2.0)) + + # Assemble discrete vector + form = dolfinx.fem.create_form(compiled_form, [Vh], mesh, {}, {u: uh, v: vh}, {c: ch, d: dh}) + b = dolfinx.fem.assemble_vector(form) + b.scatter_reverse(dolfinx.la.InsertMode.add) + b.scatter_forward() + + # Compare to reference solution + dvh = ufl.conj(ufl.TestFunction(Vh)) + exact_form = 2 * dh * vh * dvh * ufl.dx + b_exact = dolfinx.fem.assemble_vector(dolfinx.fem.form(exact_form, dtype=dtype)) + b_exact.scatter_reverse(dolfinx.la.InsertMode.add) + b_exact.scatter_forward() + + tol = np.finfo(dtype).resolution * 1e3 + np.testing.assert_allclose(b.array, b_exact.array, atol=tol)