Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JDBetteridge/refactor curvefield #48

Merged
merged 21 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ We are now finally ready to install ngsPETSc:
Authors
----------

Patrick E. Farrell, Stefano Zampini, Umberto Zerbinati
Jack Betteridge, Patrick E. Farrell, Stefano Zampini, Umberto Zerbinati

License
---------------
Expand Down
17 changes: 12 additions & 5 deletions ngsPETSc/utils/firedrake/hierarchies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''
This module contains all the functions related
This module contains all the functions related
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@UZerbinati what is this meant to say?

'''
try:
import firedrake as fd
Expand Down Expand Up @@ -204,7 +204,7 @@ def NetgenHierarchy(mesh, levs, flags):
order = flagsUtils(flags, "degree", 1)
if isinstance(order, int):
order= [order]*(levs+1)
tol = flagsUtils(flags, "tol", 1e-8)
permutation_tol = flagsUtils(flags, "tol", 1e-8)
refType = flagsUtils(flags, "refinement_type", "uniform")
optMoves = flagsUtils(flags, "optimisation_moves", False)
snap = flagsUtils(flags, "snap_to", "geometry")
Expand All @@ -221,7 +221,11 @@ def NetgenHierarchy(mesh, levs, flags):
raise RuntimeError("Cannot refine parallel overlapped meshes ")
#We curve the mesh
if order[0]>1:
ho_field = mesh.curve_field(order=order[0], tol=tol, CG=cg)
ho_field = mesh.curve_field(
order=order[0],
permutation_tol=permutation_tol,
cg_field=cg
)
mesh = fd.Mesh(ho_field,distribution_parameters=params, comm=comm)
meshes += [mesh]
cdm = meshes[-1].topology_dm
Expand All @@ -248,8 +252,11 @@ def NetgenHierarchy(mesh, levs, flags):
#We curve the mesh
if order[l+1] > 1:
if snap == "geometry":
mesh = fd.Mesh(mesh.curve_field(order=order[l+1], tol=tol),
distribution_parameters=params, comm=comm)
mesh = fd.Mesh(
mesh.curve_field(order=order[l+1], permutation_tol=permutation_tol),
distribution_parameters=params,
comm=comm
)
elif snap == "coarse":
mesh = snapToCoarse(ho_field, mesh, order[l+1], snap_smoothing, cg)
meshes += [mesh]
Expand Down
210 changes: 127 additions & 83 deletions ngsPETSc/utils/firedrake/meshes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''
This module contains all the functions related to wrapping NGSolve meshes to Firedrake
We adopt the same docstring conventiona as the Firedrake project, since this part of
We adopt the same docstring conventions as the Firedrake project, since this part of
the package will only be used in combination with Firedrake.
'''
try:
Expand All @@ -11,6 +11,7 @@

import numpy as np
from petsc4py import PETSc
from scipy.spatial.distance import cdist

import netgen
import netgen.meshing as ngm
Expand Down Expand Up @@ -70,105 +71,148 @@ def refineMarkedElements(self, mark):
else:
raise NotImplementedError("No implementation for dimension other than 2 and 3.")

def curveField(self, order, tol=1e-8, CG=False):

@PETSc.Log.EventDecorator()
UZerbinati marked this conversation as resolved.
Show resolved Hide resolved
def find_permutation(points_a, points_b, tol=1e-5):
""" Find all permutations between a list of two sets of points.

Given two numpy arrays of shape (ncells, npoints, dim) containing
floating point coordinates for each cell, determine each index
permutation that takes `points_a` to `points_b`. Ie:
```
permutation = find_permutation(points_a, points_b)
assert np.allclose(points_a[permutation], points_b, rtol=0, atol=tol)
```
"""
if points_a.shape != points_b.shape:
raise ValueError("`points_a` and `points_b` must have the same shape.")

p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)]
try:
permutation = np.array(p, ndmin=2)
except ValueError as e:
raise ValueError(
"It was not possible to find a permutation for every cell"
" within the provided tolerance"
) from e

if permutation.shape != points_a.shape[0:2]:
raise ValueError(
"It was not possible to find a permutation for every cell"
" within the provided tolerance"
)

return permutation


@PETSc.Log.EventDecorator()
def curveField(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=False):
'''
This method returns a curved mesh as a Firedrake function.

:arg order: the order of the curved mesh
:arg order: the order of the curved mesh.
:arg permutation_tol: tolerance used to construct the permutation of the reference element.
:arg location_tol: tolerance used to locate the cell a point belongs to.
:arg cg_field: return a CG function field representing the mesh, rather than the
default DG field.

'''
#Checking if the mesh is a surface mesh or two dimensional mesh
surf = len(self.netgen_mesh.Elements3D()) == 0
#Constructing mesh as a function
if CG:
space = fd.VectorFunctionSpace(self, "CG", order)
# Check if the mesh is a surface mesh or two dimensional mesh
if len(self.netgen_mesh.Elements3D()) == 0:
ng_element = self.netgen_mesh.Elements2D
else:
ng_element = self.netgen_mesh.Elements3D
ng_dimension = len(ng_element())
geom_dim = self.geometric_dimension()

# Construct the mesh as a Firedrake function
if cg_field:
firedrake_space = fd.VectorFunctionSpace(self, "CG", order)
else:
low_order_element = self.coordinates.function_space().ufl_element().sub_elements[0]
element = low_order_element.reconstruct(degree=order)
space = fd.VectorFunctionSpace(self, fd.BrokenElement(element))
newFunctionCoordinates = fd.assemble(interpolate(self.coordinates, space))
#Computing reference points using fiat
fiat_element = newFunctionCoordinates.function_space().finat_element.fiat_equivalent
ufl_element = low_order_element.reconstruct(degree=order)
firedrake_space = fd.VectorFunctionSpace(self, fd.BrokenElement(ufl_element))
new_coordinates = fd.assemble(interpolate(self.coordinates, firedrake_space))

# Compute reference points using fiat
fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent
entity_ids = fiat_element.entity_dofs()
nodes = fiat_element.dual_basis()
refPts = []
ref = []
for dim in entity_ids:
for entity in entity_ids[dim]:
for dof in entity_ids[dim][entity]:
# Assert singleton point for each node.
pt, = nodes[dof].get_point_dict().keys()
refPts.append(pt)
V = newFunctionCoordinates.dat.data
refPts = np.array(refPts)
els = {True: self.netgen_mesh.Elements2D, False: self.netgen_mesh.Elements3D}
#Mapping to the physical domain
ref.append(pt)
reference_space_points = np.array(ref)

# Curve the mesh on rank 0 only
if self.comm.rank == 0:
physPts = np.ndarray((len(els[surf]()),
refPts.shape[0], self.geometric_dimension()))
self.netgen_mesh.CalcElementMapping(refPts, physPts)
#Cruving the mesh
# Construct numpy arrays for physical domain data
physical_space_points = np.zeros(
(ng_dimension, reference_space_points.shape[0], geom_dim)
)
curved_space_points = np.zeros(
(ng_dimension, reference_space_points.shape[0], geom_dim)
)
self.netgen_mesh.CalcElementMapping(reference_space_points, physical_space_points)
self.netgen_mesh.Curve(order)
curvedPhysPts = np.ndarray((len(els[surf]()),
refPts.shape[0], self.geometric_dimension()))
self.netgen_mesh.CalcElementMapping(refPts, curvedPhysPts)
curved = els[surf]().NumPy()["curved"]
self.netgen_mesh.CalcElementMapping(reference_space_points, curved_space_points)
curved = ng_element().NumPy()["curved"]
# Broadcast a boolean array identifying curved cells
curved = self.comm.bcast(curved, root=0)
physical_space_points = physical_space_points[curved]
curved_space_points = curved_space_points[curved]
else:
physPts = np.ndarray((len(els[surf]()),
refPts.shape[0], self.geometric_dimension()))
curvedPhysPts = np.ndarray((len(els[surf]()),
refPts.shape[0], self.geometric_dimension()))
curved = np.array((len(els[surf]()),1))
physPts = self.comm.bcast(physPts, root=0)
curvedPhysPts = self.comm.bcast(curvedPhysPts, root=0)
curved = self.comm.bcast(curved, root=0)
cellMap = newFunctionCoordinates.cell_node_map()
for i in range(physPts.shape[0]):
#Inefficent code but runs only on curved elements
if curved[i]:
pts = physPts[i][0:refPts.shape[0]]
bary = sum([np.array(pts[i]) for i in range(len(pts))])/len(pts)
Idx = self.locate_cell(bary)
isInMesh = (0<=Idx<len(cellMap.values)) if Idx is not None else False
#Check if element is shared across processes
shared = self.comm.gather(isInMesh, root=0)
shared = self.comm.bcast(shared, root=0)
#Bend if not shared
if np.sum(shared) == 1:
if isInMesh:
p = [np.argmin(np.sum((pts - pt)**2, axis=1))
for pt in V[cellMap.values[Idx]][0:refPts.shape[0]]]
curvedPhysPts[i] = curvedPhysPts[i][p]
res = np.linalg.norm(pts[p]-V[cellMap.values[Idx]][0:refPts.shape[0]])
if res > tol:
fd.logging.warning("[{}] Not able to curve Firedrake element {} \
({}) -- residual: {}".format(self.comm.rank, Idx,i, res))
else:
for j, datIdx in enumerate(cellMap.values[Idx][0:refPts.shape[0]]):
for dim in range(self.geometric_dimension()):
coo = curvedPhysPts[i][j][dim]
newFunctionCoordinates.sub(dim).dat.data[datIdx] = coo
else:
if isInMesh:
p = [np.argmin(np.sum((pts - pt)**2, axis=1))
for pt in V[cellMap.values[Idx]][0:refPts.shape[0]]]
curvedPhysPts[i] = curvedPhysPts[i][p]
res = np.linalg.norm(pts[p]-V[cellMap.values[Idx]][0:refPts.shape[0]])
else:
res = np.inf
res = self.comm.gather(res, root=0)
res = self.comm.bcast(res, root=0)
rank = np.argmin(res)
if self.comm.rank == rank:
if res[rank] > tol:
fd.logging.warning("[{}, {}] Not able to curve Firedrake element {} \
({}) -- residual: {}".format(self.comm.rank, shared, Idx,i, res))
else:
for j, datIdx in enumerate(cellMap.values[Idx][0:refPts.shape[0]]):
for dim in range(self.geometric_dimension()):
coo = curvedPhysPts[i][j][dim]
newFunctionCoordinates.sub(dim).dat.data[datIdx] = coo

return newFunctionCoordinates
curved = self.comm.bcast(None, root=0)
# Construct numpy arrays as buffers to receive physical domain data
ncurved = np.sum(curved)
physical_space_points = np.zeros(
(ncurved, reference_space_points.shape[0], geom_dim)
)
curved_space_points = np.zeros(
(ncurved, reference_space_points.shape[0], geom_dim)
)

# Broadcast curved cell point data
self.comm.Bcast(physical_space_points, root=0)
self.comm.Bcast(curved_space_points, root=0)
cell_node_map = new_coordinates.cell_node_map()

# Select only the points in curved cells
barycentres = np.average(physical_space_points, axis=1)
ng_index = [*map(lambda x: self.locate_cell(x, tolerance=location_tol), barycentres)]

# Select only the indices of points owned by this rank
owned = [(0 <= ii < len(cell_node_map.values)) if ii is not None else False for ii in ng_index]

# Select only the points owned by this rank
physical_space_points = physical_space_points[owned]
curved_space_points = curved_space_points[owned]
barycentres = barycentres[owned]
ng_index = [idx for idx, o in zip(ng_index, owned) if o]

# Get the PyOP2 indices corresponding to the netgen indices
pyop2_index = []
for ngidx in ng_index:
pyop2_index.extend(cell_node_map.values[ngidx])

# Find the correct coordinate permutation for each cell
permutation = find_permutation(
physical_space_points,
new_coordinates.dat.data[pyop2_index].reshape(physical_space_points.shape),
tol=permutation_tol
)

# Apply the permutation to each cell in turn
for ii, p in enumerate(curved_space_points):
curved_space_points[ii] = p[permutation[ii]]

# Assign the curved coordinates to the dat
new_coordinates.dat.data[pyop2_index] = curved_space_points.reshape(-1, geom_dim)

return new_coordinates

def splitToQuads(plex, dim, comm):
'''
Expand Down
38 changes: 38 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,41 @@
[project]
name = "ngsPETSc"
version = "0.0.5"
description = "NGSolve/Netgen interface to PETSc."
readme = "README.md"
authors = [
{name = "Umberto Zerbinati", email = "[email protected]"},
{name = "Patrick E. Farrell", email = "[email protected]"},
{name = "Stefano Zampini", email = "[email protected]"},
{name = "Jack Betteridge", email = "[email protected]"},
]
maintainers = [
{name = "Umberto Zerbinati", email = "[email protected]"},
]
license = {file = "LICENSE.txt"}
dependencies = [
"mpi4py",
"numpy",
"scipy",
]
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Development Status :: 3 - Alpha",
]

[project.urls]
Documentation = "https://ngspetsc.readthedocs.io/en/latest/"
Repository = "https://github.com/NGSolve/ngsPETSc"

[project.optional-dependencies]
dev = [
"pytest",
"pylint",
]

[build-system]
requires = ['setuptools>=42']
build-backend = 'setuptools.build_meta'
26 changes: 4 additions & 22 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import setuptools, os

with open("README.md", "r", encoding = "utf-8") as fh:
long_description = fh.read()

if 'NGSPETSC_NO_INSTALL_REQUIRED' in os.environ:
install_requires = []
elif 'NGS_FROM_SOURCE' in os.environ:
install_requires = [
'petsc4py',
'mpi4py',
'numpy',
'scipy',
'pytest', #For testing
'pylint', #For formatting
]
Expand All @@ -19,28 +17,12 @@
'ngsolve',
'petsc4py',
'mpi4py',
'scipy',
'pytest', #For testing
'pylint', #For formatting
]

setuptools.setup(
name = "ngsPETSc",
version = "0.0.5",
author = "Umberto Zerbinati",
author_email = "[email protected]",
description = "NGSolve/Netgen interface to PETSc.",
long_description = long_description,
long_description_content_type = "text/markdown",
url = "",
project_urls = {
},
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
packages=["ngsPETSc", "ngsPETSc.utils", "ngsPETSc.utils.firedrake", "ngsPETSc.utils.ngs"],
python_requires = ">=3.8",
install_requires=install_requires

install_requires=install_requires,
packages=["ngsPETSc", "ngsPETSc.utils", "ngsPETSc.utils.firedrake", "ngsPETSc.utils.ngs"]
)