Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function spaces check in linear solver #3214

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion demos/netgen/netgen_mesh.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order
bc = DirichletBC(V, 0.0, [1])
A = assemble(a, bcs=bc)
b = assemble(l)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b = assemble(l)
b = assemble(l, bcs=bc, zero_bc_nodes=True)

bc.apply(b)
b_riesz = b.riesz_representation(riesz_map="l2")
bc.apply(b_riesz)
Comment on lines +385 to +386
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b_riesz = b.riesz_representation(riesz_map="l2")
bc.apply(b_riesz)

solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"})

VTKFile("output/Sphere.pvd").write(sol)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/solving-interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ pass in. In the pre-assembled case, we are solving a linear system:
Where :math:`A` is a known matrix, :math:`\vec{b}` is a known right
hand side vector and :math:`\vec{x}` is the unknown solution vector.
In Firedrake, :math:`A` is represented as a
:py:class:`~.Matrix`, while :math:`\vec{x}` is a :py:class:`~.Function`, and
:math:`\vec{b}` a :py:class:`~.Cofunction`.
:py:class:`~.MatrixBase`, while :math:`\vec{b}` and
:math:`\vec{x}` can be :py:class:`~.Function`\s or :py:class:`~.Cofunction`\s.
We build these values by calling ``assemble`` on the UFL forms that
define our problem, which, as before are denoted ``a`` and ``L``.
Similarly to the linear variational case, we first need a function in
Expand Down
7 changes: 4 additions & 3 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if isconstant(c):
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
if self.function_space != self.parent_space:
vec = extract_bc_subvector(
Expand Down Expand Up @@ -87,11 +87,12 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
r = extract_bc_subvector(
output = extract_bc_subvector(
adj_value, c.function_space(), bc
)
r = output.riesz_representation(riesz_map="l2")
if adj_output is None:
adj_output = r
else:
Expand Down
8 changes: 5 additions & 3 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
kwargs["bcs"] = bcs
dFdu = firedrake.assemble(dFdu_adj_form, **kwargs)

dJdu_func = dJdu.riesz_representation(riesz_map="l2")
for bc in bcs:
bc.apply(dJdu)
bc.apply(dJdu_func)
dJdu.assign(dJdu_func.riesz_representation(riesz_map="l2"))

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
Expand All @@ -201,7 +203,7 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
adj_sol_bdy = None
if compute_bdy:
adj_sol_bdy = firedrake.Function(
self.function_space.dual(),
self.function_space,
dJdu_copy.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)
).dat
Expand Down Expand Up @@ -811,7 +813,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
self.add_dependency(bc, no_duplicates=True)

def apply_mixedmass(self, a):
b = firedrake.Function(self.target_space)
b = self.backend.Cofunction(self.target_space.dual())
with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs:
self.mixed_mass.mult(vsrc, vrhs)
return b
Expand Down
2 changes: 1 addition & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ def _check_tensor(self, tensor):
rank = len(self._form.arguments())
if rank == 1:
test, = self._form.arguments()
if tensor is not None and test.function_space() != tensor.function_space():
if tensor is not None and test.function_space() != tensor.function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def restore_work_function(self, function):

def __eq__(self, other):
try:
return self.topological == other.topological and \
return type(self) == type(other) and \
self.topological == other.topological and \
self.mesh() is other.mesh()
except AttributeError:
return False
Expand Down
5 changes: 5 additions & 0 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def solve(self, x, b):
if not isinstance(b, (function.Function, cofunction.Cofunction)):
raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__)

if x.function_space() != self.trial_space or b.function_space() != self.test_space.dual():
# When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*,
# we need to make sure that x and b belong to V and U*, respectively.
raise ValueError("Mismatching function spaces.")

if len(self.trial_space) > 1 and self.nullspace is not None:
self.nullspace._apply(self.trial_space.dof_dset.field_ises)
if len(self.test_space) > 1 and self.transpose_nullspace is not None:
Expand Down
6 changes: 4 additions & 2 deletions tests/regression/test_netgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def poisson(h, degree=2):
# Assembling matrix
A = assemble(a, bcs=bc)
b = assemble(l)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b = assemble(l)
b = assemble(l, bcs=bc, zero_bc_nodes=True)

bc.apply(b)
b_riesz = b.riesz_representation(riesz_map="l2")
bc.apply(b_riesz)
Comment on lines +55 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b_riesz = b.riesz_representation(riesz_map="l2")
bc.apply(b_riesz)


# Solving the problem
solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"})
Expand Down Expand Up @@ -96,7 +97,8 @@ def poisson3D(h, degree=2):
# Assembling matrix
A = assemble(a, bcs=bc)
b = assemble(l)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b = assemble(l)
b = assemble(l, bcs=bc, zero_bc_nodes=True)

bc.apply(b)
b_riesz = b.riesz_representation(riesz_map="l2")
bc.apply(b_riesz)
Comment on lines +100 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b_riesz = b.riesz_representation(riesz_map="l2")
bc.apply(b_riesz)


# Solving the problem
solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"})
Expand Down
21 changes: 21 additions & 0 deletions tests/regression/test_solving_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,24 @@ def test_solve_cofunction_rhs():
Aw = assemble(action(a, w))
assert isinstance(Aw, Cofunction)
assert np.allclose(Aw.dat.data_ro, L.dat.data_ro)


def test_linear_solver_check_spaces():
mesh = UnitSquareMesh(10, 10)
V = FunctionSpace(mesh, "CG", 1)

u = TrialFunction(V)
v = TestFunction(V)
a = inner(u, v) * dx
A = assemble(a)

L = Cofunction(V.dual())
L.vector()[:] = 1.

Lf = L.riesz_representation(riesz_map="l2")

w = Function(V)
solve(A, w, L)

with pytest.raises(ValueError):
solve(A, w, Lf)
Loading