diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f451b3f596..6b1bb7d3e4 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1201,7 +1201,7 @@ def _apply_dirichlet_bc(self, tensor, bc): bc.zero(tensor) def _check_tensor(self, tensor): - if tensor.function_space() != self._form.arguments()[0].function_space(): + if tensor.function_space() != self._form.arguments()[0].function_space().dual(): raise ValueError("Form's argument does not match provided result tensor") @staticmethod diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 8fc81244f7..7e7ecfbcbf 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -13,6 +13,7 @@ import ufl import finat.ufl +from ufl.duals import is_dual, is_primal from pyop2 import op2, mpi from pyop2.utils import as_tuple @@ -296,6 +297,9 @@ def restore_work_function(self, function): cache[function] = False def __eq__(self, other): + if is_primal(self) != is_primal(other) or \ + is_dual(self) != is_dual(other): + return False try: return self.topological == other.topological and \ self.mesh() is other.mesh() diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4a1ac396c5..488b42d55c 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -9,7 +9,7 @@ PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS ) -from firedrake.function import Function +from firedrake.function import Function, Cofunction from firedrake.ufl_expr import TrialFunction, TestFunction from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin @@ -92,7 +92,11 @@ def __init__(self, F, u, bcs=None, J=None, self.u_restrict = Function(V_res).interpolate(u) v_res, u_res = TestFunction(V_res), TrialFunction(V_res) F_arg, = F.arguments() - self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict}) + replace_F = {F_arg: v_res, self.u: self.u_restrict} + for c in F.coefficients(): + if c.function_space() == V.dual(): + replace_F[c] = Cofunction(V_res.dual()).interpolate(c) + self.F = replace(F, replace_F) v_arg, u_arg = self.J.arguments() self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) if self.Jp: diff --git a/tests/firedrake/regression/test_restricted_function_space.py b/tests/firedrake/regression/test_restricted_function_space.py index dc9a2ecc64..6dd093f58a 100644 --- a/tests/firedrake/regression/test_restricted_function_space.py +++ b/tests/firedrake/regression/test_restricted_function_space.py @@ -146,7 +146,8 @@ def test_poisson_inhomogeneous_bcs_2(j): @pytest.mark.parallel(nprocs=3) -def test_poisson_inhomogeneous_bcs_high_level_interface(): +@pytest.mark.parametrize("assembled_rhs", [False, True], ids=("Form", "Cofunction")) +def test_poisson_inhomogeneous_bcs_high_level_interface(assembled_rhs): mesh = UnitSquareMesh(8, 8) V = FunctionSpace(mesh, "CG", 2) bc1 = DirichletBC(V, 0., 1) @@ -155,9 +156,11 @@ def test_poisson_inhomogeneous_bcs_high_level_interface(): v = TestFunction(V) a = inner(grad(u), grad(v)) * dx u = Function(V) - L = inner(Constant(0), v) * dx + L = inner(Constant(-2), v) * dx + if assembled_rhs: + L = assemble(L) solve(a == L, u, bcs=[bc1, bc2], restrict=True) - assert errornorm(SpatialCoordinate(mesh)[0], u) < 1.e-12 + assert errornorm(SpatialCoordinate(mesh)[0]**2, u) < 1.e-12 @pytest.mark.parametrize("j", [1, 2, 5])