diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 3c9a2689a4..1656b8ef4e 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -10,6 +10,7 @@ DEFAULT_SNES_PARAMETERS ) from firedrake.function import Function +from firedrake.cofunction import Cofunction from firedrake.functionspace import RestrictedFunctionSpace from firedrake.ufl_expr import TrialFunction, TestFunction from firedrake.bcs import DirichletBC, EquationBC @@ -305,6 +306,10 @@ def solve(self, bounds=None): for dbc in problem.dirichlet_bcs(): dbc.apply(problem.u_restrict) + for coeff in coefficients: + if isinstance(coeff, Cofunction): + # Apply the DirichletBC to the right hand side of the equation. + dbc.apply(coeff) if bounds is not None: lower, upper = bounds diff --git a/tests/regression/test_cofunction.py b/tests/regression/test_cofunction.py index ef8262fea8..19ce2d4612 100644 --- a/tests/regression/test_cofunction.py +++ b/tests/regression/test_cofunction.py @@ -60,3 +60,43 @@ def test_scalar_cofunction_zero_with_subset(V): assert f is g assert np.allclose(f.dat.data_ro[:2], 0.0) assert np.allclose(f.dat.data_ro[2:], 1.0) + + +def test_diriclet_bc_rhs(V): + # Issue https://github.com/firedrakeproject/firedrake/issues/3498 + # Apply DirichletBC to RHS (Cofunction) in LinearVariationalSolver + mesh = UnitIntervalMesh(2) + space = FunctionSpace(mesh, "Lagrange", 1) + test, trial = TestFunction(space), TrialFunction(space) + + # Form RHS + u = Function(space, name="u") + problem = LinearVariationalProblem( + inner(trial, test) * dx, inner(Constant(1.0), test) * dx, u, + DirichletBC(space, 0.0, "on_boundary")) + solver = LinearVariationalSolver(problem) + solver.solve() + + assert np.allclose(assemble(inner(u, u) * ds), 0.0) + + # Cofunction RHS + b = assemble(inner(Constant(1.0), test) * dx) + u = Function(space, name="u") + problem = LinearVariationalProblem( + inner(trial, test) * dx, b, u, + DirichletBC(space, 0.0, "on_boundary")) + solver = LinearVariationalSolver(problem) + solver.solve() + + assert np.allclose(assemble(inner(u, u) * ds), 0.0) + + # FormSum RHS + b = assemble(inner(Constant(0.5), test) * dx) + inner(Constant(0.5), test) * dx + u = Function(space, name="u") + problem = LinearVariationalProblem( + inner(trial, test) * dx, b, u, + DirichletBC(space, 0.0, "on_boundary")) + solver = LinearVariationalSolver(problem) + solver.solve() + + assert np.allclose(assemble(inner(u, u) * ds), 0.0)