Skip to content

Commit

Permalink
Only interpolate the residual, not every cofunction in the RHS
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 21, 2024
1 parent 474edb3 commit 73be82b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
12 changes: 9 additions & 3 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ def visitor(e, *operands):
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)

# Apply BCs after assembly
rank = len(self._form.arguments())
if rank == 1:
for bc in self._bcs:
bc.zero(result)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
return tensor
Expand All @@ -409,7 +415,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
if rank == 0:
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
elif rank == 1 or (rank == 2 and self._diagonal):
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
Expand Down Expand Up @@ -811,9 +817,9 @@ def restructure_base_form(expr, visited=None):
return ufl.action(expr, ustar)

# -- Case (6) -- #
if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()):
if isinstance(expr, ufl.FormSum) and all(not isinstance(c, ufl.form.BaseForm) for c in expr.components()):
# Return ufl.Sum
return sum([c for c in expr.components()])
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
return expr

@staticmethod
Expand Down
18 changes: 9 additions & 9 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS,
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function, Cofunction
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.function import Function
from firedrake.ufl_expr import TrialFunction, TestFunction, action
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
from ufl import replace
from firedrake.__future__ import interpolate
from ufl import replace, Form

__all__ = ["LinearVariationalProblem",
"LinearVariationalSolver",
Expand Down Expand Up @@ -91,12 +92,11 @@ def __init__(self, F, u, bcs=None, J=None,
bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
replace_F = {F_arg: v_res, self.u: self.u_restrict}
for c in F.coefficients():
if c.function_space() == V.dual():
replace_F[c] = Cofunction(V_res.dual()).interpolate(c)
self.F = replace(F, replace_F)
if isinstance(F, Form):
F_arg, = F.arguments()
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
else:
self.F = action(replace(F, {self.u: self.u_restrict}), interpolate(v_res, V))
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
Expand Down

0 comments on commit 73be82b

Please sign in to comment.