diff --git a/firedrake/assemble.py b/firedrake/assemble.py index afd7076114..a1d0009eb1 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -385,6 +385,12 @@ def visitor(e, *operands): visited = {} result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited) + # Apply BCs after assembly + rank = len(self._form.arguments()) + if rank == 1: + for bc in self._bcs: + bc.zero(result) + if tensor: BaseFormAssembler.update_tensor(result, tensor) return tensor @@ -409,7 +415,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if rank == 0: assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params) elif rank == 1 or (rank == 2 and self._diagonal): - assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, + assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params, zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight) elif rank == 2: assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, @@ -811,9 +817,9 @@ def restructure_base_form(expr, visited=None): return ufl.action(expr, ustar) # -- Case (6) -- # - if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()): + if isinstance(expr, ufl.FormSum) and all(not isinstance(c, ufl.form.BaseForm) for c in expr.components()): # Return ufl.Sum - return sum([c for c in expr.components()]) + return sum(w*c for w, c in zip(expr.weights(), expr.components())) return expr @staticmethod diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 488b42d55c..3c8fc8b930 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -9,11 +9,12 @@ PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS ) -from firedrake.function import Function, Cofunction -from firedrake.ufl_expr import TrialFunction, TestFunction +from firedrake.function import Function +from firedrake.ufl_expr import TrialFunction, TestFunction, action from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin -from ufl import replace +from firedrake.__future__ import interpolate +from ufl import replace, Form __all__ = ["LinearVariationalProblem", "LinearVariationalSolver", @@ -91,12 +92,11 @@ def __init__(self, F, u, bcs=None, J=None, bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs] self.u_restrict = Function(V_res).interpolate(u) v_res, u_res = TestFunction(V_res), TrialFunction(V_res) - F_arg, = F.arguments() - replace_F = {F_arg: v_res, self.u: self.u_restrict} - for c in F.coefficients(): - if c.function_space() == V.dual(): - replace_F[c] = Cofunction(V_res.dual()).interpolate(c) - self.F = replace(F, replace_F) + if isinstance(F, Form): + F_arg, = F.arguments() + self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict}) + else: + self.F = action(replace(F, {self.u: self.u_restrict}), interpolate(v_res, V)) v_arg, u_arg = self.J.arguments() self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) if self.Jp: