Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AllAtOnceReducedFunctional - initial time-parallel implementation #3870

Draft
wants to merge 16 commits into
base: allatoncereducedfunctional
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
from firedrake.vector import *
from firedrake.version import __version__ as ver, __version_info__, check # noqa: F401
from firedrake.ensemble import *
from firedrake.ensemblefunction import *
from firedrake.randomfunctiongen import *
from firedrake.external_operators import *
from firedrake.progress_bar import ProgressBar # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions firedrake/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \
UFLEqualityConstraint # noqa F401
from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401
from firedrake.adjoint.all_at_once_reduced_functional import AllAtOnceReducedFunctional # noqa F401
import numpy_adjoint # noqa F401
import firedrake.ufl_expr
import types
Expand Down
1,106 changes: 748 additions & 358 deletions firedrake/adjoint/all_at_once_reduced_functional.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions firedrake/adjoint_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from firedrake.adjoint_utils.solving import * # noqa: F401
from firedrake.adjoint_utils.mesh import * # noqa: F401
from firedrake.adjoint_utils.checkpointing import * # noqa: F401
from firedrake.adjoint_utils.ensemblefunction import * # noqa: F401
101 changes: 101 additions & 0 deletions firedrake/adjoint_utils/ensemblefunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from pyadjoint.overloaded_type import OverloadedType
from firedrake.petsc import PETSc
from .checkpointing import disk_checkpointing

from functools import wraps


class EnsembleFunctionMixin(OverloadedType):

@staticmethod
def _ad_annotate_init(init):
@wraps(init)
def wrapper(self, *args, **kwargs):
OverloadedType.__init__(self)
init(self, *args, **kwargs)
return wrapper

@staticmethod
def _ad_to_list(m):
with m.vec_ro() as gvec:
lcomm = PETSc.COMM_SELF
gsize = gvec.size
lvec = PETSc.Vec().createSeq(gsize, comm=lcomm)
is_ = PETSc.IS().createStride(gsize, 0, 1, comm=lcomm)

mode = PETSc.InsertMode.INSERT_VALUES
scatter = PETSc.Scatter().create(gvec, is_, lvec, None)
scatter.scatterBegin(gvec, lvec, addv=mode)
scatter.scatterEnd(gvec, lvec, addv=mode)

return lvec.array_r.tolist()

@staticmethod
def _ad_assign_numpy(dst, src, offset):
with dst.vec_wo() as vec:
begin, end = vec.owner_range
src_array = src[offset + begin: offset + end]
vec.array[:] = src_array
offset += vec.size
return dst, offset

def _ad_dot(self, other, options=None):
# local dot product
ldot = sum(
uself._ad_dot(uother, options=options)
for uself, uother in zip(self.subfunctions,
other.subfunctions))
# global dot product
gdot = self.ensemble.ensemble_comm.allreduce(ldot)
return gdot

def _ad_add(self, other):
new = self.copy()
new += other
return new

def _ad_mul(self, other):
new = self.copy()
# `self` can be a Cofunction in which case only left multiplication with a scalar is allowed.
other = other._fbuf if type(other) is type(self) else other
new._fbuf.assign(other*new._fbuf)
return new

def _ad_iadd(self, other):
self += other
return self

def _ad_imul(self, other):
self *= other
return self

def _ad_copy(self):
return self.copy()

def _ad_convert_riesz(self, value, options=None):
raise ValueError("NotImplementedYet")

def _ad_create_checkpoint(self):
if disk_checkpointing():
raise NotImplementedError(
"Disk checkpointing not implemented for EnsembleFunctions")
else:
return self.copy()

def _ad_restore_at_checkpoint(self, checkpoint):
if isinstance(checkpoint, type(self)):
return checkpoint
raise NotImplementedError(
"Checkpointing not implemented for EnsembleFunctions")

def _ad_from_petsc(self, vec):
with self.vec_wo as self_v:
vec.copy(result=self_v)

def _ad_to_petsc(self, vec=None):
with self.vec_ro as self_v:
if vec:
self_v.copy(result=vec)
else:
vec = self_v.copy()
return vec
3 changes: 3 additions & 0 deletions firedrake/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def component_tensor(self, o, a, _):
def coefficient(self, o):
return ((o, 1),)

def cofunction(self, o):
return ((o, 1),)

def constant_value(self, o):
return ((o, 1),)

Expand Down
62 changes: 61 additions & 1 deletion firedrake/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import weakref
from contextlib import contextmanager
from itertools import zip_longest

from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from pyop2.mpi import MPI, internal_comm
from itertools import zip_longest

__all__ = ("Ensemble", )

Expand Down Expand Up @@ -283,3 +286,60 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r
requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag)
for dat in frecv.dat])
return requests

@contextmanager
def sequential(self, **kwargs):
"""
Context manager for executing code on each ensemble
member in turn.

Any data in `kwargs` will be made available in the context
and will be communicated forward after each ensemble member
exits. Firedrake Functions/Cofunctions will be send with the
corresponding Ensemble methods.

with ensemble.sequential(index=0) as ctx:
print(ensemble.ensemble_comm.rank, ctx.index)
ctx.index += 2

Would print:
0 0
1 2
2 4
3 6
... etc ...

"""
rank = self.ensemble_comm.rank
first_rank = (rank == 0)
last_rank = (rank == self.ensemble_comm.size - 1)

if not first_rank:
src = rank - 1
for i, (k, v) in enumerate(kwargs.items()):
recv_kwargs = {'source': src, 'tag': rank+i*100}
if isinstance(v, (Function, Cofunction)):
self.recv(kwargs[k], **recv_kwargs)
else:
kwargs[k] = self.ensemble_comm.recv(
**recv_kwargs)

ctx = _EnsembleContext(**kwargs)

yield ctx

if not last_rank:
dst = rank + 1
for i, v in enumerate((getattr(ctx, k)
for k in kwargs.keys())):
send_kwargs = {'dest': dst, 'tag': dst+i*100}
if isinstance(v, (Function, Cofunction)):
self.send(v, **send_kwargs)
else:
self.ensemble_comm.send(v, **send_kwargs)


class _EnsembleContext:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
Loading
Loading