Skip to content

Commit

Permalink
Restricted Cofunction RHS
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 11, 2024
1 parent 17c4106 commit a86614a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ def _apply_dirichlet_bc(self, tensor, bc):
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
Expand Down
4 changes: 4 additions & 0 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ufl
import finat.ufl

from ufl.duals import is_dual, is_primal
from pyop2 import op2, mpi
from pyop2.utils import as_tuple

Expand Down Expand Up @@ -296,6 +297,9 @@ def restore_work_function(self, function):
cache[function] = False

def __eq__(self, other):
if is_primal(self) != is_primal(other) or \
is_dual(self) != is_dual(other):
return False
try:
return self.topological == other.topological and \
self.mesh() is other.mesh()
Expand Down
8 changes: 6 additions & 2 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS,
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function
from firedrake.function import Function, Cofunction
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
Expand Down Expand Up @@ -92,7 +92,11 @@ def __init__(self, F, u, bcs=None, J=None,
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
replace_F = {F_arg: v_res, self.u: self.u_restrict}
for c in F.coefficients():
if c.function_space() == V.dual():
replace_F[c] = Cofunction(V_res.dual()).interpolate(c)
self.F = replace(F, replace_F)
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
Expand Down
9 changes: 6 additions & 3 deletions tests/firedrake/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def test_poisson_inhomogeneous_bcs_2(j):


@pytest.mark.parallel(nprocs=3)
def test_poisson_inhomogeneous_bcs_high_level_interface():
@pytest.mark.parametrize("assembled_rhs", [False, True], ids=("Form", "Cofunction"))
def test_poisson_inhomogeneous_bcs_high_level_interface(assembled_rhs):
mesh = UnitSquareMesh(8, 8)
V = FunctionSpace(mesh, "CG", 2)
bc1 = DirichletBC(V, 0., 1)
Expand All @@ -155,9 +156,11 @@ def test_poisson_inhomogeneous_bcs_high_level_interface():
v = TestFunction(V)
a = inner(grad(u), grad(v)) * dx
u = Function(V)
L = inner(Constant(0), v) * dx
L = inner(Constant(-2), v) * dx
if assembled_rhs:
L = assemble(L)
solve(a == L, u, bcs=[bc1, bc2], restrict=True)
assert errornorm(SpatialCoordinate(mesh)[0], u) < 1.e-12
assert errornorm(SpatialCoordinate(mesh)[0]**2, u) < 1.e-12


@pytest.mark.parametrize("j", [1, 2, 5])
Expand Down

0 comments on commit a86614a

Please sign in to comment.