Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 12, 2024
1 parent 1a56816 commit a630d56
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from firedrake.adjoint_utils import annotate_assemble
from firedrake.ufl_expr import extract_unique_domain
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace, RestrictedFunctionSpace
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
from firedrake.petsc import PETSc
from firedrake.slate import slac, slate
Expand Down Expand Up @@ -1180,24 +1180,31 @@ def allocate(self):

def _apply_bc(self, tensor, bc):
# TODO Maybe this could be a singledispatchmethod?
# Handle special diagonal case first.
if self._diagonal:
assert isinstance(bc, DirichletBC)
assert not self._zero_bc_nodes
if not isinstance(bc, DirichletBC):
raise TypeError(f"diagonal expects a DirichletBC: got {bc}")
# Ignore self._zero_bc_nodes.
tensor_func = tensor.riesz_representation(riesz_map="l2")
bc.set(tensor_func, 1)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
else:
test, = self._form.arguments()
if test.function_space() == bc.function_space_parent:
if isinstance(bc, DirichletBC):
assert self._zero_bc_nodes
bc.zero(tensor)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
else:
raise AssertionError
if isinstance(bc, DirichletBC):
# Ignore column bcs in Petrov-Galerkin formulation.
if bc.function_space_parent == test.function_space():
if not self._zero_bc_nodes:
tensor_func = tensor.riesz_representation(riesz_map="l2")
bc.apply(tensor_func)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
else:
bc.zero(tensor)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
else:
raise AssertionError

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
Expand Down Expand Up @@ -1420,29 +1427,36 @@ def _apply_bc(self, tensor, bc):
index = 0 if V.index is None else V.index
space = V if V.parent is None else V.parent
if isinstance(bc, DirichletBC):
if space == spaces[0] and space == spaces[1]:
if all(isinstance(s.topological, RestrictedFunctionSpace) for s in spaces):
# Make this the primal (the only) path.
# -- This path should work fine with Petrov-Galerkin formulations.
pass
elif all(not isinstance(s.topological, RestrictedFunctionSpace) for s in spaces):
if space != spaces[0]:
raise TypeError("bc space does not match the test function space")
elif space != spaces[1]:
raise TypeError("bc space does not match the trial function space")
# Set diagonal entries on bc nodes to 1 if the current
# block is on the matrix diagonal and its index matches the
# index of the function space the bc is defined on.
op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight)
# Handle off-diagonal block involving real function space.
# "lgmaps" is correctly constructed in _matrix_arg, but
# is ignored by PyOP2 in this case.
# Walk through row blocks associated with index.
if space == spaces[0]:
for j, s in enumerate(spaces[1]):
# Handle off-diagonal block involving real function space.
# "lgmaps" is correctly constructed in _matrix_arg, but
# is ignored by PyOP2 in this case.
# Walk through row blocks associated with index.
for j, s in enumerate(space):
if j != index and s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
# Walk through col blocks associated with index.
if space == spaces[1]:
for i, s in enumerate(spaces[0]):
# Walk through col blocks associated with index.
for i, s in enumerate(space):
if i != index and s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set)
else:
raise TypeError("Must define bcs all on regular function spaces or all on restricted function spaces")
elif isinstance(bc, EquationBCSplit):
if space == spaces[0]:
for j, s in enumerate(spaces[1]):
if s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
for j, s in enumerate(spaces[1]):
if s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False).assemble(tensor=tensor)
else:
raise AssertionError
Expand Down

0 comments on commit a630d56

Please sign in to comment.