Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix assembly of Real matrices #3846

Merged
merged 11 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 81 additions & 76 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from firedrake.utils import ScalarType, assert_empty, tuplify
from pyop2 import op2
from pyop2.exceptions import MapValueError, SparsityFormatError
from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload
from pyop2.utils import cached_property


Expand Down Expand Up @@ -965,22 +966,24 @@ def assemble(self, tensor=None):
Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms.

"""
self._check_tensor(tensor)
if tensor is None:
tensor = self.allocate()
needs_zeroing = False
else:
needs_zeroing = self._needs_zeroing
if annotate_tape():
raise NotImplementedError(
"Taping with explicit FormAssembler objects is not supported yet. "
"Use assemble instead."
)
if needs_zeroing:
type(self)._as_pyop2_type(tensor).zero()

if tensor is None:
tensor = self.allocate()
else:
self._check_tensor(tensor)
if self._needs_zeroing:
self._as_pyop2_type(tensor).zero()

self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)

return self.result(tensor)

@abc.abstractmethod
Expand All @@ -992,9 +995,9 @@ def _check_tensor(self, tensor):
"""Check input tensor."""

@staticmethod
def _as_pyop2_type(tensor):
"""Return tensor as pyop2 type."""
raise NotImplementedError
@abc.abstractmethod
def _as_pyop2_type(tensor, indices=None):
"""Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it."""

def execute_parloops(self, tensor):
for parloop in self.parloops(tensor):
Expand All @@ -1003,29 +1006,27 @@ def execute_parloops(self, tensor):
def parloops(self, tensor):
if hasattr(self, "_parloops"):
for (lknl, _), parloop in zip(self.local_kernels, self._parloops):
data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal)
data = self._as_pyop2_type(tensor, lknl.indices)
parloop.arguments[0].data = data

else:
# Make parloops for one concrete output tensor and cache them.
# TODO: Make parloops only with some symbolic information of the output tensor.
self._parloops = tuple(parloop_builder.build(tensor) for parloop_builder in self.parloop_builders)
return self._parloops

@cached_property
def parloop_builders(self):
out = []
for local_kernel, subdomain_id in self.local_kernels:
out.append(
ParloopBuilder(
parloops_ = []
for local_kernel, subdomain_id in self.local_kernels:
parloop_builder = ParloopBuilder(
self._form,
self._bcs,
local_kernel,
subdomain_id,
self.all_integer_subdomain_ids,
diagonal=self.diagonal,
)
)
return tuple(out)
pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices)
parloop = parloop_builder.build(pyop2_tensor)
parloops_.append(parloop)
self._parloops = tuple(parloops_)

return self._parloops

@cached_property
def local_kernels(self):
Expand Down Expand Up @@ -1120,10 +1121,11 @@ def _apply_bc(self, tensor, bc):
pass

def _check_tensor(self, tensor):
assert tensor is None
pass

@staticmethod
def _as_pyop2_type(tensor):
def _as_pyop2_type(tensor, indices=None):
assert not indices
return tensor

def result(self, tensor):
Expand Down Expand Up @@ -1198,15 +1200,16 @@ def _apply_dirichlet_bc(self, tensor, bc):
bc.zero(tensor)

def _check_tensor(self, tensor):
rank = len(self._form.arguments())
if rank == 1:
test, = self._form.arguments()
if tensor is not None and test.function_space() != tensor.function_space():
raise ValueError("Form's argument does not match provided result tensor")
if tensor.function_space() != self._form.arguments()[0].function_space():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
def _as_pyop2_type(tensor):
return tensor.dat
def _as_pyop2_type(tensor, indices=None):
if indices is not None and any(index is not None for index in indices):
i, = indices
return tensor.dat[i]
else:
return tensor.dat

def execute_parloops(self, tensor):
# We are repeatedly incrementing into the same Dat so intermediate halo exchanges
Expand Down Expand Up @@ -1454,12 +1457,26 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set):
dat.zero(subset=node_set)

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
if tensor.a.arguments() != self._form.arguments():
raise ValueError("Form's arguments do not match provided result tensor")

@staticmethod
def _as_pyop2_type(tensor):
return tensor.M
def _as_pyop2_type(tensor, indices=None):
if indices is not None and any(index is not None for index in indices):
i, j = indices
mat = tensor.M[i, j]
else:
mat = tensor.M

if mat.handle.getType() == "python":
mat_context = mat.handle.getPythonContext()
if isinstance(mat_context, _GlobalMatPayload):
mat = mat_context.global_
else:
assert isinstance(mat_context, _DatMatPayload)
mat = mat_context.dat

return mat

def result(self, tensor):
tensor.M.assemble()
Expand All @@ -1471,7 +1488,7 @@ class MatrixFreeAssembler(FormAssembler):

Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
2-form.

Notes
Expand All @@ -1498,14 +1515,15 @@ def allocate(self):
appctx=self._appctx or {})

def assemble(self, tensor=None):
self._check_tensor(tensor)
if tensor is None:
tensor = self.allocate()
else:
self._check_tensor(tensor)
tensor.assemble()
return tensor

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
if tensor.a.arguments() != self._form.arguments():
raise ValueError("Form's arguments do not match provided result tensor")


Expand Down Expand Up @@ -1820,12 +1838,12 @@ def __init__(self, form, bcs, local_knl, subdomain_id,
self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)

def build(self, tensor):
def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop:
"""Construct the parloop.

Parameters
----------
tensor : op2.Global or firedrake.cofunction.Cofunction or matrix.MatrixBase
tensor :
The output tensor.

"""
Expand Down Expand Up @@ -1909,17 +1927,28 @@ def collect_lgmaps(self):
:param local_knl: A :class:`tsfc_interface.SplitKernel`.
:param bcs: Iterable of boundary conditions.
"""

if len(self._form.arguments()) == 2 and not self._diagonal:
if not self._bcs:
return None
lgmaps = []
for i, j in self.get_indicess():

if any(i is not None for i in self._local_knl.indices):
i, j = self._local_knl.indices
row_bcs, col_bcs = self._filter_bcs(i, j)
rlgmap, clgmap = self._tensor.M[i, j].local_to_global_maps
# the tensor is already indexed
rlgmap, clgmap = self._tensor.local_to_global_maps
rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap)
clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap)
lgmaps.append((rlgmap, clgmap))
return tuple(lgmaps)
return ((rlgmap, clgmap),)
else:
lgmaps = []
for i, j in self.get_indicess():
row_bcs, col_bcs = self._filter_bcs(i, j)
rlgmap, clgmap = self._tensor[i, j].local_to_global_maps
rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap)
clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap)
lgmaps.append((rlgmap, clgmap))
return tuple(lgmaps)
else:
return None

Expand All @@ -1939,10 +1968,6 @@ def _integral_type(self):
def _indexed_function_spaces(self):
return _FormHandler.index_function_spaces(self._form, self._indices)

@property
def _indexed_tensor(self):
return _FormHandler.index_tensor(self._tensor, self._form, self._indices, self._diagonal)

@cached_property
def _mesh(self):
return self._form.ufl_domains()[self._kinfo.domain_number]
Expand Down Expand Up @@ -1990,28 +2015,27 @@ def _as_parloop_arg(tsfc_arg, self):
@_as_parloop_arg.register(kernel_args.OutputKernelArg)
def _as_parloop_arg_output(_, self):
rank = len(self._form.arguments())
tensor = self._indexed_tensor
Vs = self._indexed_function_spaces

if rank == 0:
return op2.GlobalParloopArg(tensor)
return op2.GlobalParloopArg(self._tensor)
elif rank == 1 or rank == 2 and self._diagonal:
V, = Vs
if V.ufl_element().family() == "Real":
return op2.GlobalParloopArg(tensor)
return op2.GlobalParloopArg(self._tensor)
else:
return op2.DatParloopArg(tensor, self._get_map(V))
return op2.DatParloopArg(self._tensor, self._get_map(V))
elif rank == 2:
rmap, cmap = [self._get_map(V) for V in Vs]

if all(V.ufl_element().family() == "Real" for V in Vs):
assert rmap is None and cmap is None
return op2.GlobalParloopArg(tensor.handle.getPythonContext().global_)
return op2.GlobalParloopArg(self._tensor)
elif any(V.ufl_element().family() == "Real" for V in Vs):
m = rmap or cmap
return op2.DatParloopArg(tensor.handle.getPythonContext().dat, m)
return op2.DatParloopArg(self._tensor, m)
else:
return op2.MatParloopArg(tensor, (rmap, cmap), lgmaps=self.collect_lgmaps())
return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps())
else:
raise AssertionError

Expand Down Expand Up @@ -2122,22 +2146,3 @@ def index_function_spaces(form, indices):
return tuple(a.ufl_function_space()[i] for i, a in zip(indices, form.arguments()))
else:
raise AssertionError

@staticmethod
def index_tensor(tensor, form, indices, diagonal):
"""Return the PyOP2 data structure tied to ``tensor``, indexed
if necessary.
"""
rank = len(form.arguments())
is_indexed = any(i is not None for i in indices)

if rank == 0:
return tensor
elif rank == 1 or rank == 2 and diagonal:
i, = indices
return tensor.dat[i] if is_indexed else tensor.dat
elif rank == 2:
i, j = indices
return tensor.M[i, j] if is_indexed else tensor.M
else:
raise AssertionError
18 changes: 18 additions & 0 deletions tests/regression/test_assemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np
from firedrake import *
from firedrake.assemble import TwoFormAssembler
from firedrake.utils import ScalarType, IntType


Expand Down Expand Up @@ -125,6 +126,23 @@ def test_assemble_mat_with_tensor(mesh):
assert np.allclose(M.M.values, 2*assemble(a).M.values, rtol=1e-14)


@pytest.mark.skipcomplex
def test_mat_nest_real_block_assembler_correctly_reuses_tensor(mesh):
V = FunctionSpace(mesh, "CG", 1)
R = FunctionSpace(mesh, "R", 0)
W = V * R

u = TrialFunction(W)
v = TestFunction(W)
a = inner(v, u) * dx

assembler = TwoFormAssembler(a, mat_type="nest")
A1 = assembler.assemble()
A2 = assembler.assemble(tensor=A1)

assert A2.M is A1.M


def test_assemble_diagonal(mesh):
V = FunctionSpace(mesh, "P", 3)
u = TrialFunction(V)
Expand Down
Loading