Skip to content

Commit

Permalink
Apply Diriclet boundary on the Cofunction RHS.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Sep 2, 2024
1 parent e035709 commit 8cd4c2a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
5 changes: 5 additions & 0 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.functionspace import RestrictedFunctionSpace
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.bcs import DirichletBC, EquationBC
Expand Down Expand Up @@ -305,6 +306,10 @@ def solve(self, bounds=None):

for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)
for coeff in coefficients:
if isinstance(coeff, Cofunction):
# Apply the DirichletBC to the right hand side of the equation.
dbc.apply(coeff)

if bounds is not None:
lower, upper = bounds
Expand Down
40 changes: 40 additions & 0 deletions tests/regression/test_cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,43 @@ def test_scalar_cofunction_zero_with_subset(V):
assert f is g
assert np.allclose(f.dat.data_ro[:2], 0.0)
assert np.allclose(f.dat.data_ro[2:], 1.0)


def test_diriclet_bc_rhs(V):
# Issue https://github.com/firedrakeproject/firedrake/issues/3498
# Apply DirichletBC to RHS (Cofunction) in LinearVariationalSolver
mesh = UnitIntervalMesh(2)
space = FunctionSpace(mesh, "Lagrange", 1)
test, trial = TestFunction(space), TrialFunction(space)

# Form RHS
u = Function(space, name="u")
problem = LinearVariationalProblem(
inner(trial, test) * dx, inner(Constant(1.0), test) * dx, u,
DirichletBC(space, 0.0, "on_boundary"))
solver = LinearVariationalSolver(problem)
solver.solve()

assert np.allclose(assemble(inner(u, u) * ds), 0.0)

# Cofunction RHS
b = assemble(inner(Constant(1.0), test) * dx)
u = Function(space, name="u")
problem = LinearVariationalProblem(
inner(trial, test) * dx, b, u,
DirichletBC(space, 0.0, "on_boundary"))
solver = LinearVariationalSolver(problem)
solver.solve()

assert np.allclose(assemble(inner(u, u) * ds), 0.0)

# FormSum RHS
b = assemble(inner(Constant(0.5), test) * dx) + inner(Constant(0.5), test) * dx
u = Function(space, name="u")
problem = LinearVariationalProblem(
inner(trial, test) * dx, b, u,
DirichletBC(space, 0.0, "on_boundary"))
solver = LinearVariationalSolver(problem)
solver.solve()

assert np.allclose(assemble(inner(u, u) * ds), 0.0)

0 comments on commit 8cd4c2a

Please sign in to comment.