Skip to content

Commit

Permalink
Fix high order in parallel
Browse files Browse the repository at this point in the history
Signed-off-by: Umberto Zerbinati <[email protected]>
  • Loading branch information
Umberto Zerbinati committed Jan 19, 2024
1 parent d933c82 commit 3bfb9c4
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions ngsPETSc/utils/firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def curveField(self, order, digits=8):
element = low_order_element.reconstruct(degree=order)
space = fd.VectorFunctionSpace(self, fd.BrokenElement(element))
newFunctionCoordinates = fd.assemble(fd.interpolate(self.coordinates, space))

self.netgen_mesh = self.comm.bcast(self.netgen_mesh, root=0)
#Computing reference points using fiat
fiat_element = newFunctionCoordinates.function_space().finat_element.fiat_equivalent
entity_ids = fiat_element.entity_dofs()
Expand All @@ -99,9 +99,7 @@ def curveField(self, order, digits=8):
# Assert singleton point for each node.
pt, = nodes[dof].get_point_dict().keys()
refPts.append(pt)

V = newFunctionCoordinates.dat.data
getIdx = self._cell_numbering.getOffset
refPts = np.array(refPts)
rnd = lambda x: round(x, digits)
if self.geometric_dimension() == 2:
Expand All @@ -116,17 +114,18 @@ def curveField(self, order, digits=8):
self.netgen_mesh.CalcElementMapping(refPts, curvedPhysPts)
cellMap = newFunctionCoordinates.cell_node_map()
for i, el in enumerate(self.netgen_mesh.Elements2D()):
if el.curved:
Idx = self.locate_cell(sum(physPts[i])/len(physPts[i]))
isInMesh = (0<=Idx<len(cellMap.values)) if Idx is not None else False
if el.curved and isInMesh:
pts = [tuple(map(rnd, pts))
for pts in physPts[i][0:refPts.shape[0]]]
dofMap = {k: v for v, k in enumerate(pts)}
p = [dofMap[tuple(map(rnd, pts))]
for pts in V[cellMap.values[getIdx(i)]][0:refPts.shape[0]]]
for pts in V[cellMap.values[Idx]][0:refPts.shape[0]]]
curvedPhysPts[i] = curvedPhysPts[i][p]
for j, datIdx in enumerate(cellMap.values[getIdx(i)][0:refPts.shape[0]]):
for j, datIdx in enumerate(cellMap.values[Idx][0:refPts.shape[0]]):
newFunctionCoordinates.sub(0).dat.data[datIdx] = curvedPhysPts[i][j][0]
newFunctionCoordinates.sub(1).dat.data[datIdx] = curvedPhysPts[i][j][1]

if self.geometric_dimension() == 3:
#Mapping to the physical domain
physPts = np.ndarray((len(self.netgen_mesh.Elements3D()),
Expand All @@ -139,14 +138,16 @@ def curveField(self, order, digits=8):
self.netgen_mesh.CalcElementMapping(refPts, curvedPhysPts)
cellMap = newFunctionCoordinates.cell_node_map()
for i, el in enumerate(self.netgen_mesh.Elements3D()):
if el.curved:
Idx = self.locate_cell(sum(physPts[i])/len(physPts[i]))
isInMesh = (0<=Idx<len(cellMap.values)) if Idx is not None else False
if el.curved and isInMesh:
pts = [tuple(map(rnd, pts))
for pts in physPts[i][0:refPts.shape[0]]]
dofMap = {k: v for v, k in enumerate(pts)}
p = [dofMap[tuple(map(rnd, pts))]
for pts in V[cellMap.values[getIdx(i)]][0:refPts.shape[0]]]
for pts in V[cellMap.values[Idx]][0:refPts.shape[0]]]
curvedPhysPts[i] = curvedPhysPts[i][p]
for j, datIdx in enumerate(cellMap.values[getIdx(i)][0:refPts.shape[0]]):
for j, datIdx in enumerate(cellMap.values[Idx][0:refPts.shape[0]]):
newFunctionCoordinates.sub(0).dat.data[datIdx] = curvedPhysPts[i][j][0]
newFunctionCoordinates.sub(1).dat.data[datIdx] = curvedPhysPts[i][j][1]
newFunctionCoordinates.sub(2).dat.data[datIdx] = curvedPhysPts[i][j][2]
Expand Down

0 comments on commit 3bfb9c4

Please sign in to comment.