-
Notifications
You must be signed in to change notification settings - Fork 159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trefftz support for Firedrake #3775
base: master
Are you sure you want to change the base?
Changes from 5 commits
1d05e50
48109d3
4e19060
b84833c
fe9643f
4cb3d06
cf7fbfa
08cd4a4
00dad68
f763ba0
4c62f42
9292c1d
99df6dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,231 @@ | ||||||||||||||
""" | ||||||||||||||
This module provides a class to compute the Trefftz embedding of a given function space. | ||||||||||||||
It is also used to compute aggregation embedding of a given function space. | ||||||||||||||
""" | ||||||||||||||
from firedrake.petsc import PETSc | ||||||||||||||
from firedrake.cython.dmcommon import FACE_SETS_LABEL, CELL_SETS_LABEL | ||||||||||||||
from firedrake.assemble import assemble | ||||||||||||||
from firedrake.mesh import Mesh | ||||||||||||||
from firedrake.functionspace import FunctionSpace | ||||||||||||||
from firedrake.function import Function | ||||||||||||||
from firedrake.ufl_expr import TestFunction, TrialFunction | ||||||||||||||
from firedrake.constant import Constant | ||||||||||||||
from ufl import dx, dS, inner, jump, grad, dot, CellDiameter, FacetNormal | ||||||||||||||
import scipy.sparse as sp | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class TrefftzEmbedding(object): | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
""" | ||||||||||||||
This class computes the Trefftz embedding of a given function space | ||||||||||||||
Parameters | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly for other docstrings |
||||||||||||||
---------- | ||||||||||||||
V : :class:`.FunctionSpace` | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove types from here and add type-hinting, see policy. |
||||||||||||||
Ambient function space. | ||||||||||||||
b : :class:`.ufl.form.Form` | ||||||||||||||
Bilinear form defining the Trefftz operator. | ||||||||||||||
dim : int, optional | ||||||||||||||
Dimension of the embedding. | ||||||||||||||
Default is the dimension of the function space. | ||||||||||||||
tol : float, optional | ||||||||||||||
Tolerance for the singular values cutoff. | ||||||||||||||
Default is 1e-12. | ||||||||||||||
backend : str, optional | ||||||||||||||
Backend to use for the computation of the SVD. | ||||||||||||||
Default is "scipy". | ||||||||||||||
""" | ||||||||||||||
def __init__(self, V, b, dim=None, tol=1e-12, backend="scipy"): | ||||||||||||||
self.V = V | ||||||||||||||
self.b = b | ||||||||||||||
self.dim = V.dim() if not dim else dim + 1 | ||||||||||||||
self.tol = tol | ||||||||||||||
self.backend = backend | ||||||||||||||
|
||||||||||||||
def assemble(self): | ||||||||||||||
""" | ||||||||||||||
Assemble the embedding, compute the SVD and return the embedding matrix | ||||||||||||||
""" | ||||||||||||||
self.B = assemble(self.b).M.handle | ||||||||||||||
if self.backend == "scipy": | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work in parallel? We should have a parallel test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure having a backend way we want to do this either... |
||||||||||||||
indptr, indices, data = self.B.getValuesCSR() | ||||||||||||||
Bsp = sp.csr_matrix((data, indices, indptr), shape=self.B.getSize()) | ||||||||||||||
_, sig, VT = sp.linalg.svds(Bsp, k=self.dim-1, which="SM") | ||||||||||||||
QT = sp.csr_matrix(VT[0:sum(sig < self.tol), :]) | ||||||||||||||
QTpsc = PETSc.Mat().createAIJ(size=QT.shape, csr=(QT.indptr, QT.indices, QT.data)) | ||||||||||||||
self.dimT = QT.shape[0] | ||||||||||||||
self.sig = sig | ||||||||||||||
else: | ||||||||||||||
raise NotImplementedError("Only scipy backend is supported") | ||||||||||||||
return QTpsc, sig | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class trefftz_ksp(object): | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
""" | ||||||||||||||
This class wraps a PETSc KSP object to solve the reduced | ||||||||||||||
system obtained by the Trefftz embedding. | ||||||||||||||
""" | ||||||||||||||
def __init__(self): | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
@staticmethod | ||||||||||||||
def get_appctx(ksp): | ||||||||||||||
""" | ||||||||||||||
Get the application context from the KSP | ||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
ksp : :class:`PETSc.KSP` | ||||||||||||||
The KSP object | ||||||||||||||
""" | ||||||||||||||
from firedrake.dmhooks import get_appctx | ||||||||||||||
return get_appctx(ksp.getDM()).appctx | ||||||||||||||
|
||||||||||||||
def setUp(self, ksp): | ||||||||||||||
""" | ||||||||||||||
Set up the Trefftz KSP | ||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
ksp : :class:`PETSc.KSP` | ||||||||||||||
The KSP object | ||||||||||||||
""" | ||||||||||||||
appctx = self.get_appctx(ksp) | ||||||||||||||
self.QT, _ = appctx["trefftz_embedding"].assemble() | ||||||||||||||
|
||||||||||||||
def solve(self, ksp, b, x): | ||||||||||||||
""" | ||||||||||||||
Solve the Trefftz KSP | ||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
ksp : :class:`PETSc.KSP` | ||||||||||||||
The KSP object | ||||||||||||||
b : :class:`PETSc.Vec` | ||||||||||||||
The right-hand side | ||||||||||||||
x : :class:`PETSc.Vec` | ||||||||||||||
The solution | ||||||||||||||
""" | ||||||||||||||
A, P = ksp.getOperators() | ||||||||||||||
self.Q = PETSc.Mat().createTranspose(self.QT) | ||||||||||||||
ATF = self.QT @ A @ self.Q | ||||||||||||||
PTF = self.QT @ P @ self.Q | ||||||||||||||
bTF = self.QT.createVecLeft() | ||||||||||||||
self.QT.mult(b, bTF) | ||||||||||||||
|
||||||||||||||
tiny_ksp = PETSc.KSP().create() | ||||||||||||||
tiny_ksp.setOperators(ATF, PTF) | ||||||||||||||
tiny_ksp.setOptionsPrefix("trefftz_") | ||||||||||||||
tiny_ksp.setFromOptions() | ||||||||||||||
xTF = ATF.createVecRight() | ||||||||||||||
tiny_ksp.solve(bTF, xTF) | ||||||||||||||
self.QT.multTranspose(xTF, x) | ||||||||||||||
ksp.setConvergedReason(tiny_ksp.getConvergedReason()) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class AggregationEmbedding(TrefftzEmbedding): | ||||||||||||||
""" | ||||||||||||||
This class computes the aggregation embedding of a given function space. | ||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
V : :class:`.FunctionSpace` | ||||||||||||||
Ambient function space. | ||||||||||||||
mesh : :class:`.Mesh` | ||||||||||||||
The mesh on which the aggregation is defined. | ||||||||||||||
polyMesh : :class:`.Function` | ||||||||||||||
The function defining the aggregation. | ||||||||||||||
dim : int | ||||||||||||||
Dimension of the embedding. | ||||||||||||||
Default is the dimension of the function space. | ||||||||||||||
tol : float | ||||||||||||||
Tolerance for the singular values cutoff. | ||||||||||||||
Default is 1e-12. | ||||||||||||||
""" | ||||||||||||||
def __init__(self, V, mesh, polyMesh, dim=None, tol=1e-12): | ||||||||||||||
# Relabel facets that are inside an aggregated region | ||||||||||||||
offset = 1 + mesh.topology_dm.getLabelSize(FACE_SETS_LABEL) | ||||||||||||||
offset += mesh.topology_dm.getLabelSize(CELL_SETS_LABEL) | ||||||||||||||
nPoly = int(max(polyMesh.dat.data[:])) # Number of aggregates | ||||||||||||||
getIdx = mesh._cell_numbering.getOffset | ||||||||||||||
plex = mesh.topology_dm | ||||||||||||||
pStart, pEnd = plex.getDepthStratum(2) | ||||||||||||||
self.facet_index = [] | ||||||||||||||
for poly in range(nPoly+1): | ||||||||||||||
facets = [] | ||||||||||||||
for i in range(pStart, pEnd): | ||||||||||||||
if polyMesh.dat.data[getIdx(i)] == poly: | ||||||||||||||
for f in plex.getCone(i): | ||||||||||||||
if f in facets: | ||||||||||||||
plex.setLabelValue(FACE_SETS_LABEL, f, offset+poly) | ||||||||||||||
if offset+poly not in self.facet_index: | ||||||||||||||
self.facet_index = self.facet_index + [offset+poly] | ||||||||||||||
facets = facets + list(plex.getCone(i)) | ||||||||||||||
self.mesh = Mesh(plex) | ||||||||||||||
h = CellDiameter(self.mesh) | ||||||||||||||
n = FacetNormal(self.mesh) | ||||||||||||||
W = FunctionSpace(self.mesh, V.ufl_element()) | ||||||||||||||
u = TrialFunction(W) | ||||||||||||||
v = TestFunction(W) | ||||||||||||||
self.b = Constant(0)*inner(u, v)*dx | ||||||||||||||
for i in self.facet_index: | ||||||||||||||
self.b += inner(jump(u), jump(v))*dS(i) | ||||||||||||||
for k in range(1, V.ufl_element().degree()+1): | ||||||||||||||
for i in self.facet_index: | ||||||||||||||
self.b += ((0.5 * h("+") + 0.5 * h("-"))**(2*k)) *\ | ||||||||||||||
inner(jump_normal(u, n("+"), k), jump_normal(v, n("+"), k))*dS(i) | ||||||||||||||
super().__init__(W, self.b, dim, tol) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def jump_normal(u, n, k): | ||||||||||||||
""" | ||||||||||||||
Compute the jump of the normal derivative of a function u | ||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
u : :class:`.Function` | ||||||||||||||
The function. | ||||||||||||||
n : :class:`.ufc.Normal` | ||||||||||||||
The normal vector. | ||||||||||||||
k : int | ||||||||||||||
The order of the normal derivative we aim to compute. | ||||||||||||||
""" | ||||||||||||||
j = 0.5*dot(n, (grad(u)("+")-grad(u)("-"))) | ||||||||||||||
for _ in range(1, k): | ||||||||||||||
j = 0.5*dot(n, (grad(j)-grad(j))) | ||||||||||||||
return j | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def dumb_aggregation(mesh): | ||||||||||||||
""" | ||||||||||||||
Compute a dumb aggregation of the mesh | ||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
mesh : :class:`.Mesh` | ||||||||||||||
The mesh we aim to aggregate. | ||||||||||||||
""" | ||||||||||||||
if mesh.comm.size > 1: | ||||||||||||||
raise NotImplementedError("Parallel mesh aggregation not supported") | ||||||||||||||
plex = mesh.topology_dm | ||||||||||||||
pStart, pEnd = plex.getDepthStratum(2) | ||||||||||||||
_, eEnd = plex.getDepthStratum(1) | ||||||||||||||
adjacency = [] | ||||||||||||||
for i in range(pStart, pEnd): | ||||||||||||||
ad = plex.getAdjacency(i) | ||||||||||||||
local = [] | ||||||||||||||
for a in ad: | ||||||||||||||
supp = plex.getSupport(a) | ||||||||||||||
supp = supp[supp < eEnd] | ||||||||||||||
for s in supp: | ||||||||||||||
if s < pEnd and s != ad[0]: | ||||||||||||||
local = local + [s] | ||||||||||||||
adjacency = adjacency + [(i, local)] | ||||||||||||||
adjacency = sorted(adjacency, key=lambda x: len(x[1]))[::-1] | ||||||||||||||
u = Function(FunctionSpace(mesh, "DG", 0)) | ||||||||||||||
|
||||||||||||||
getIdx = mesh._cell_numbering.getOffset | ||||||||||||||
av = list(range(pStart, pEnd)) | ||||||||||||||
col = 0 | ||||||||||||||
for a in adjacency: | ||||||||||||||
if a[0] in av: | ||||||||||||||
for k in a[1]: | ||||||||||||||
if k in av: | ||||||||||||||
av.remove(k) | ||||||||||||||
u.dat.data[getIdx(k)] = col | ||||||||||||||
av.remove(a[0]) | ||||||||||||||
u.dat.data[getIdx(a[0])] = col | ||||||||||||||
col = col + 1 | ||||||||||||||
return u |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from firedrake import * | ||
Check failure on line 1 in tests/regression/test_trefftz.py GitHub Actions / Firedrake complextest_trefftz.tests.regression.test_trefftz
|
||
from firedrake.trefftz import TrefftzEmbedding, AggregationEmbedding, dumb_aggregation | ||
|
||
|
||
@pytest.mark.skipcomplex | ||
def test_trefftz_laplace(): | ||
order = 6 | ||
mesh = UnitSquareMesh(2, 2) | ||
x, y = SpatialCoordinate(mesh) | ||
h = CellDiameter(mesh) | ||
n = FacetNormal(mesh) | ||
V = FunctionSpace(mesh, "DG", order) | ||
u = TrialFunction(V) | ||
v = TestFunction(V) | ||
|
||
def delta(u): | ||
return div(grad(u)) | ||
|
||
a = inner(delta(u), delta(v)) * dx | ||
alpha = 4 | ||
mean_dudn = 0.5 * dot(grad(u("+"))+grad(u("-")), n("+")) | ||
mean_dvdn = 0.5 * dot(grad(v("+"))+grad(v("-")), n("+")) | ||
aDG = inner(grad(u), grad(v)) * dx | ||
aDG += inner((alpha*order**2/(h("+")+h("-")))*jump(u), jump(v))*dS | ||
aDG += inner(-mean_dudn, jump(v))*dS-inner(mean_dvdn, jump(u))*dS | ||
aDG += alpha*order**2/h*inner(u, v)*ds | ||
aDG += -inner(dot(n, grad(u)), v)*ds - inner(dot(n, grad(v)), u)*ds | ||
f = Function(V).interpolate(exp(x)*sin(y)) | ||
L = alpha*order**2/h*inner(f, v)*ds - inner(dot(n, grad(v)), f)*ds | ||
# Solve the problem | ||
uDG = Function(V) | ||
uDG.rename("uDG") | ||
embd = TrefftzEmbedding(V, a, tol=1e-8) | ||
appctx = {"trefftz_embedding": embd} | ||
uDG = Function(V) | ||
solve(aDG == L, uDG, solver_parameters={"ksp_type": "python", | ||
"ksp_python_type": "firedrake.trefftz.trefftz_ksp"}, | ||
appctx=appctx) | ||
assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-6) | ||
assert (embd.dimT < V.dim()/2) | ||
|
||
|
||
@pytest.mark.skipcomplex | ||
def test_trefftz_aggregation(): | ||
from netgen.occ import WorkPlane, OCCGeometry | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test needs to be skipped if ngsPETSc is not installed |
||
|
||
Rectangle = WorkPlane().Rectangle(1, 1).Face() | ||
geo = OCCGeometry(Rectangle, dim=2) | ||
ngmesh = geo.GenerateMesh(maxh=0.3) | ||
mesh = Mesh(ngmesh) | ||
|
||
polymesh = dumb_aggregation(mesh) | ||
|
||
order = 3 | ||
x, y = SpatialCoordinate(mesh) | ||
h = CellDiameter(mesh) | ||
n = FacetNormal(mesh) | ||
V = FunctionSpace(mesh, "DG", order) | ||
u = TrialFunction(V) | ||
v = TestFunction(V) | ||
|
||
alpha = 1e3 | ||
mean_dudn = 0.5 * dot(grad(u("+"))+grad(u("-")), n("+")) | ||
mean_dvdn = 0.5 * dot(grad(v("+"))+grad(v("-")), n("+")) | ||
aDG = inner(grad(u), grad(v)) * dx | ||
aDG += inner((alpha*order**2/(h("+")+h("-")))*jump(u), jump(v))*dS | ||
aDG += inner(-mean_dudn, jump(v))*dS-inner(mean_dvdn, jump(u))*dS | ||
aDG += alpha*order**2/h*inner(u, v)*ds | ||
aDG += -inner(dot(n, grad(u)), v)*ds - inner(dot(n, grad(v)), u)*ds | ||
f = Function(V).interpolate(exp(x)*sin(y)) | ||
L = alpha*order**2/h*inner(f, v)*ds - inner(dot(n, grad(v)), f)*ds | ||
agg_embd = AggregationEmbedding(V, mesh, polymesh) | ||
appctx = {"trefftz_embedding": agg_embd} | ||
|
||
uDG = Function(V) | ||
solve(aDG == L, uDG, solver_parameters={"ksp_type": "python", | ||
"ksp_python_type": "firedrake.trefftz.trefftz_ksp"}, | ||
appctx=appctx) | ||
|
||
assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-9) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please define
__all__
to avoid pollution of the namespaceThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or equivalently you can replace
from firedrake.trefftz import *
withfrom firedrake.trefftz import OnlyWhatIWant
inside__init__.py
.