Skip to content

Commit

Permalink
new function restricted_function_space
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 20, 2024
1 parent df65c80 commit 88cd868
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 37 deletions.
37 changes: 36 additions & 1 deletion firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin
from firedrake.petsc import PETSc

__all__ = ['DirichletBC', 'homogenize', 'EquationBC']
__all__ = ['DirichletBC', 'homogenize', 'EquationBC', 'restricted_function_space']


class BCBase(object):
Expand Down Expand Up @@ -690,3 +690,38 @@ def homogenize(bc):
return DirichletBC(bc.function_space(), 0, bc.sub_domain)
else:
raise TypeError("homogenize only takes a DirichletBC or a list/tuple of DirichletBCs")


@PETSc.Log.EventDecorator("CreateFunctionSpace")
def restricted_function_space(V, bcs, name=None):
"""Create a :class:`.RestrictedFunctionSpace` from a list of boundary conditions.
Parameters
----------
V :
FunctionSpace object to restrict
bcs :
A list of boundary conditions.
name :
An optional name for the function space.
"""
if len(V) > 1:
spaces = [restricted_function_space(Vsub, bcs) for Vsub in V]
return firedrake.MixedFunctionSpace(spaces, name=name)

if not isinstance(bcs, (tuple, list)):
bcs = (bcs,)

boundary_set = []
for bc in bcs:
if bc.function_space() != V:
continue
for dbc in bc.dirichlet_bcs():
if isinstance(dbc.sub_domain, (str, int)):
boundary_set.append(dbc.sub_domain)
else:
boundary_set.extend(dbc.sub_domain)
if len(boundary_set) == 0:
return V
return firedrake.RestrictedFunctionSpace(V, boundary_set=boundary_set, name=name)
4 changes: 2 additions & 2 deletions firedrake/eigensolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Specify and solve finite element eigenproblems."""
from firedrake.assemble import assemble
from firedrake.bcs import restricted_function_space
from firedrake.function import Function
from firedrake.functionspace import RestrictedFunctionSpace
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake import utils
from firedrake.petsc import OptionsManager, flatten_parameters
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, A, M=None, bcs=None, bc_shift=0.0, restrict=True):
M = inner(u, v) * dx

if restrict and bcs: # assumed u and v are in the same space here
V_res = RestrictedFunctionSpace(self.output_space, bcs)
V_res = restricted_function_space(self.output_space, bcs)
u_res = TrialFunction(V_res)
v_res = TestFunction(V_res)
self.M = replace(M, {u: u_res, v: v_res})
Expand Down
24 changes: 1 addition & 23 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,33 +317,11 @@ def RestrictedFunctionSpace(function_space, boundary_set=[], name=None):
FunctionSpace object to restrict
boundary_set :
A set of subdomains of the mesh in which Dirichlet boundary conditions
will be applied. Alternatively, an iterable of boundary conditions.
will be applied.
name :
An optional name for the function space.
"""
from firedrake.bcs import BCBase
if len(function_space) > 1:
return MixedFunctionSpace([RestrictedFunctionSpace(Vsub, boundary_set=boundary_set)
for Vsub in function_space], name=name)

if not isinstance(boundary_set, (tuple, list, set, frozenset)):
boundary_set = (boundary_set,)

flat_boundary_set = []
for sub_domain in boundary_set:
if isinstance(sub_domain, BCBase):
bc = sub_domain
if bc.function_space() == function_space:
sub_domain = bc.sub_domain
else:
continue
if isinstance(sub_domain, (str, int)):
flat_boundary_set.append(sub_domain)
else:
flat_boundary_set.extend(sub_domain)
boundary_set = flat_boundary_set

return impl.WithGeometry.create(impl.RestrictedFunctionSpace(function_space,
boundary_set=boundary_set,
name=name),
Expand Down
5 changes: 2 additions & 3 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function
from firedrake.functionspace import RestrictedFunctionSpace
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.bcs import DirichletBC, EquationBC
from firedrake.bcs import DirichletBC, EquationBC, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
from ufl import replace

Expand Down Expand Up @@ -88,7 +87,7 @@ def __init__(self, F, u, bcs=None, J=None,
self.restrict = restrict

if restrict and bcs:
V_res = RestrictedFunctionSpace(V, bcs)
V_res = restricted_function_space(V, bcs)
bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
Expand Down
12 changes: 4 additions & 8 deletions tests/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ def test_restricted_mixed_space():
Q = FunctionSpace(mesh, "DG", 0)
Z = V * Q
bcs = [DirichletBC(Z.sub(0), 0, [1])]
Z_restricted = RestrictedFunctionSpace(Z, bcs)
Z_restricted = restricted_function_space(Z, bcs)
compare_function_space_assembly(Z, Z_restricted, bcs)


def test_poisson_restricted_mixed_space():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V*Q
Z = V * Q

u, p = TrialFunctions(Z)
v, q = TestFunctions(Z)
Expand All @@ -197,14 +197,10 @@ def test_poisson_restricted_mixed_space():
bcs = [DirichletBC(Z.sub(0), 0, [1])]

w = Function(Z)
problem = LinearVariationalProblem(a, L, w, bcs=bcs, restrict=False)
solver = LinearVariationalSolver(problem)
solver.solve()
solve(a == L, w, bcs=bcs, restrict=False)

w2 = Function(Z)
problem = LinearVariationalProblem(a, L, w2, bcs=bcs, restrict=True)
solver = LinearVariationalSolver(problem)
solver.solve()
solve(a == L, w2, bcs=bcs, restrict=True)

assert errornorm(w.subfunctions[0], w2.subfunctions[0]) < 1.e-12
assert errornorm(w.subfunctions[1], w2.subfunctions[1]) < 1.e-12
Expand Down

0 comments on commit 88cd868

Please sign in to comment.