Skip to content

Commit

Permalink
Fix FacetBasis for quadratic meshes
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala committed Dec 23, 2023
1 parent 1dca95a commit deabbee
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ with respect to documented and/or tested features.
- Fixed: `MeshTet` uniform refine was reindexing subdomains incorrectly
- Fixed: `MeshDG.draw` did not work; now calls `Basis.draw` which
works for any mesh topology
- Fixed: `FacetBasis` now works with `MeshTri2`, `MeshQuad2`,
`MeshTet2` and `MeshHex2`

## [8.1.0] - 2023-06-16

Expand Down
14 changes: 14 additions & 0 deletions skfem/mapping/mapping_isoparametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def Fmap(self, i, X, tind=None):
def bndmap(self, i, X, find=None):
p = self.mesh.doflocs
facets = self.mesh.facets
if len(self.mesh.dofs.edge_dofs) > 0:
facets = np.vstack((facets,
self.mesh.dofs.edge_dofs[0, self.mesh.f2e]))
# TODO currently supports only one DOF per edge (slice 0 idx)
if len(self.mesh.dofs.facet_dofs) > 0:
facets = np.vstack((facets,
self.mesh.dofs.facet_dofs))
if find is None:
out = np.zeros((facets.shape[1], X.shape[1]))
for itr in range(facets.shape[0]):
Expand All @@ -82,6 +89,13 @@ def bndmap(self, i, X, find=None):
def bndJ(self, i, j, X, find=None):
p = self.mesh.doflocs
facets = self.mesh.facets
if len(self.mesh.dofs.edge_dofs) > 0:
facets = np.vstack((facets,
self.mesh.dofs.edge_dofs[0, self.mesh.f2e]))
# TODO currently supports only one DOF per edge (slice 0 idx)
if len(self.mesh.dofs.facet_dofs) > 0:
facets = np.vstack((facets,
self.mesh.dofs.facet_dofs))
if find is None:
out = np.zeros((facets.shape[1], X.shape[1]))
for itr in range(facets.shape[0]):
Expand Down
9 changes: 9 additions & 0 deletions skfem/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def f2t(self):
self._f2t = self.build_inverse(self.t, self.t2f)
return self._f2t

@property
def f2e(self):
if not hasattr(self, '_f2e'):
_, self._f2e = self.build_entities(
self.facets,
self.bndelem.refdom.facets,
)
return self._f2e

@property
def edges(self):
if not hasattr(self, '_edges'):
Expand Down
4 changes: 4 additions & 0 deletions skfem/mesh/mesh_2d_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ def element_finder(self, *args, **kwargs):
@classmethod
def init_refdom(cls):
return cls.__bases__[-1].init_refdom()

def draw(self, *args, **kwargs):
from ..assembly import CellBasis
return CellBasis(self, self.elem()).draw(*args, **kwargs)
4 changes: 4 additions & 0 deletions skfem/refdom.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class RefTet(Refdom):
[0, 3],
[1, 3],
[2, 3]]
normals = np.array([[0., 0., -1.],
[0., -1., 0.],
[-1., 0., 0.],
[1., 1., 1.]])
brefdom = RefTri
nnodes = 4
nfacets = 4
Expand Down
24 changes: 22 additions & 2 deletions tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from skfem.mesh import (Mesh, MeshHex, MeshLine, MeshQuad, MeshTet, MeshTri,
MeshTri2, MeshQuad2, MeshTet2, MeshHex2, MeshLine1DG,
MeshQuad1DG, MeshHex2, MeshTri1DG)
from skfem.assembly import Basis, LinearForm
from skfem.element import ElementTetP1
from skfem.assembly import Basis, LinearForm, Functional, FacetBasis
from skfem.element import (ElementTetP1, ElementTriP0, ElementQuad0,
ElementHex0)
from skfem.utils import projection
from skfem.io.meshio import to_meshio, from_meshio
from skfem.io.json import to_dict, from_dict
Expand Down Expand Up @@ -717,3 +718,22 @@ def test_refine_subdomains_uniform_hexs():
m1 = MeshHex().refined().with_subdomains(sdef).refined()
m2 = MeshHex().refined().refined().with_subdomains(sdef)
np.testing.assert_equal(m1.subdomains, m2.subdomains)


@pytest.mark.parametrize(
"fbasis,refval,dec",
[
(FacetBasis(MeshTri2.init_circle(), ElementTriP0()), 2 * np.pi, 5),
(FacetBasis(MeshQuad2(), ElementQuad0()), 4, 5),
(FacetBasis(MeshTet2.init_ball(), ElementTetP1()), 4 * np.pi, 3),
(FacetBasis(MeshHex2(), ElementHex0()), 6, 3),
]
)
def test_integrate_quadratic_boundary(fbasis, refval, dec):

@Functional
def unit(w):
return 1 + 0 * w.x[0]

np.testing.assert_almost_equal(unit.assemble(fbasis),
refval, decimal=dec)

0 comments on commit deabbee

Please sign in to comment.