mpi4py #1070

Merged — 43 commits, Aug 2, 2023

Changes from 35 commits (of 43 total)

Commits
78a1a03
Added module, method, and attribute replacements to allow parsing of …
alexnick83 Jul 26, 2022
a13a81d
ProcessGrids now appear in defined variables and are explicitly retu…
alexnick83 Jul 26, 2022
a8d5690
Added MPIResolver to resolve mpi4py-related constants during preproce…
alexnick83 Jul 26, 2022
ac480bd
Added mpi4py compatibility tests.
alexnick83 Jul 26, 2022
d229266
Made opaque type for MPI_Request a basic dace type.
alexnick83 Jul 27, 2022
f933974
Adjusted existing Isend/Irecv replacements and added new ones for mpi…
alexnick83 Jul 27, 2022
c224013
Adjusted visit_Attribute of MPI_Resolver to not trigger on calls of M…
alexnick83 Jul 27, 2022
ebe22ed
Replacement for numpy full now also works with (scalar) data.
alexnick83 Jul 27, 2022
2bfcea9
Isend/Irecv can now use communicators other than COMM_WORLD.
alexnick83 Jul 27, 2022
fd0dc40
Added mpi4py-compatible Isend/Irecv test.
alexnick83 Jul 27, 2022
7ca527c
Updated mpi_allgather_test.py for coding style consistency
Com1t Jun 19, 2023
38be749
Added alltoall node basic version based on other collectives
Com1t Jun 19, 2023
e5085ae
Fixed mpi_send_recv_test.py
Com1t Jun 22, 2023
f8c9550
Added mpi4py replacement for send/recv
Com1t Jun 22, 2023
345c36b
Updated mpi_send_recv_test.py for correctness of blocking comm
Com1t Jul 6, 2023
1cea59e
Updated Isend/Irecv test
Com1t Jul 6, 2023
b82b06a
Updated alltoall library node for logical correctness
Com1t Jul 6, 2023
a115db6
Added replacement and test for mpi4py alltoall
Com1t Jul 6, 2023
eebefe4
Corrected the out_desc in alltoall replacement
Com1t Jul 7, 2023
110d0f2
Added alltoall replacement for ProcessGrid and Intracomm
Com1t Jul 7, 2023
dd06eb6
Merge pull request #1288 from Com1t/mpi4py_dev
alexnick83 Jul 7, 2023
73f90cf
Merge branch 'master' into mpi4py
alexnick83 Jul 7, 2023
8626b9a
Fixed bad merge.
alexnick83 Jul 7, 2023
442a873
Updated tests.
alexnick83 Jul 7, 2023
832c203
Uncommented previously commented-out tests.
alexnick83 Jul 7, 2023
6dd00bc
Merge branch 'master' into mpi4py
alexnick83 Jul 11, 2023
e67aa8e
The COMM_WORLD communicator object does not have its name changed to …
alexnick83 Jul 12, 2023
2741579
All (mpi4py) communicators in the global context are now registered i…
alexnick83 Jul 12, 2023
6fe26c4
The Bcast LibraryNode can now accept as a string the name of a variab…
alexnick83 Jul 12, 2023
af624c3
Replacements for COMM_WORLD were removed. Instead, the Intracomm's cl…
alexnick83 Jul 12, 2023
fe22182
Added two new Bcast tests for COMM_WORLD and Intracomm object.
alexnick83 Jul 12, 2023
5bcf53b
Merge branch 'master' into mpi4py
alexnick83 Jul 12, 2023
6c5ffa1
Restored replacements needed for full name of COMM_WORLD. Cleaned up …
alexnick83 Jul 12, 2023
c3b1a4b
Further clean up
alexnick83 Jul 12, 2023
31d7b84
Merge branch 'master' into mpi4py
alexnick83 Jul 21, 2023
311f5b9
Merge branch 'master' into mpi4py
alexnick83 Jul 26, 2023
70198d5
Added comm-comparison tests.
alexnick83 Jul 26, 2023
727afa7
Refactored communicator comparison replacements.
alexnick83 Jul 26, 2023
ac177bd
Addressed review comments.
alexnick83 Jul 26, 2023
01f82fa
YAPF
alexnick83 Jul 26, 2023
7bfd960
Added extra exception to catch.
alexnick83 Jul 26, 2023
1aff2df
Merge branch 'master' into mpi4py
tbennun Aug 1, 2023
8d4ad6d
Merge branch 'master' into mpi4py
alexnick83 Aug 2, 2023
1 change: 1 addition & 0 deletions dace/dtypes.py
@@ -1200,6 +1200,7 @@ def isconstant(var):
complex64 = typeclass(numpy.complex64)
complex128 = typeclass(numpy.complex128)
string = stringtype()
MPI_Request = opaque('MPI_Request')


@undefined_safe_enum
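For orientation, here is a hedged sketch of the kind of mpi4py-style program this PR teaches the frontend to parse. It is modeled on the PR's compatibility tests, so the exact test code may differ; the opaque `MPI_Request` type added above is what lets `np.empty(..., dtype=MPI.Request)` type-check once the resolver pass (see preprocessing.py below) rewrites `MPI.Request` to `dace.MPI_Request`:

```python
from mpi4py import MPI
import dace
import numpy as np

commworld = MPI.COMM_WORLD

@dace.program
def ring_exchange(rank: dace.int32, size: dace.int32):
    src = (rank - 1) % size   # ModuloConverter (below) rewrites % for C semantics
    dst = (rank + 1) % size
    req = np.empty((2,), dtype=MPI.Request)  # MPIResolver maps this to dace.MPI_Request
    sbuf = np.full((1,), rank, dtype=np.int32)
    req[0] = commworld.Isend(sbuf, dst, tag=0)
    rbuf = np.empty((1,), dtype=np.int32)
    req[1] = commworld.Irecv(rbuf, src, tag=0)
    MPI.Request.Waitall(req)
    return rbuf
```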
474 changes: 391 additions & 83 deletions dace/frontend/common/distr.py

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions dace/frontend/python/newast.py
@@ -1303,6 +1303,14 @@ def defined(self):
# Add SDFG arrays, in case a replacement added a new output
result.update(self.sdfg.arrays)

# MPI-related stuff
result.update({k: self.sdfg.process_grids[v] for k, v in self.variables.items() if v in self.sdfg.process_grids})
try:
from mpi4py import MPI
result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)})
Collaborator: I'm not happy about this. Can we generalize to add more "supported global types" instead?

alexnick83 (Contributor, Author): I guess that we could add another API method to register supported global types, where the user would specify how exactly they should be handled. Shouldn't this be a new PR, though?
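For illustration, a minimal sketch of what such a registration API could look like — `register_global_type` and everything below is hypothetical, not existing DaCe API:

```python
# Hypothetical sketch only: none of these names exist in DaCe today.
from typing import Any, Callable, Dict, Type

_GLOBAL_TYPE_HANDLERS: Dict[Type, Callable[[Any], Any]] = {}

def register_global_type(tp: Type, handler: Callable[[Any], Any]) -> None:
    """Declare globals of type `tp` as supported, with a conversion handler."""
    _GLOBAL_TYPE_HANDLERS[tp] = handler

def supported_globals(globals_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Filter `globals_dict` down to entries whose type has a registered handler."""
    out: Dict[str, Any] = {}
    for name, value in globals_dict.items():
        for tp, handler in _GLOBAL_TYPE_HANDLERS.items():
            if isinstance(value, tp):
                out[name] = handler(value)
                break
    return out

# The mpi4py special case above would then reduce to:
#   from mpi4py import MPI
#   register_global_type(MPI.Comm, lambda comm: comm)
#   result.update(supported_globals(self.globals))
```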

except:
(alexnick83 marked this conversation as resolved.)
pass

return result

def _add_state(self, label=None):
@@ -4356,8 +4364,11 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
# Add object as first argument
if modname in self.variables.keys():
arg = self.variables[modname]
else:
elif modname in self.scope_vars.keys():
arg = self.scope_vars[modname]
else:
# Fallback to (name, object)
arg = (modname, self.defined[modname])
args.append(arg)
# Otherwise, try to find a default implementation for the SDFG
elif not found_ufunc:
@@ -4667,7 +4678,9 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]:

result = []
for operand in operands:
if isinstance(operand, str) and operand in self.sdfg.arrays:
if isinstance(operand, str) and operand in self.sdfg.process_grids:
result.append((operand, type(self.sdfg.process_grids[operand]).__name__))
Collaborator: Why do process grids take precedence over everything else?

alexnick83 (Contributor, Author): A communicator is registered in the process grids but has a corresponding Scalar with the same name (currently, an integer) to allow creating AccessNodes. So far, we only had persistent communicators: COMM_WORLD or communicators created in the program's init. However, going forward, we want to support communicator creation/deletion during the program's execution, raising the need for expressing dependencies among communicator-creation routines and communication calls on the SDFG. I suppose the best solution is to make an opaque type and only check SDFG.arrays.

elif isinstance(operand, str) and operand in self.sdfg.arrays:
result.append((operand, type(self.sdfg.arrays[operand])))
elif isinstance(operand, str) and operand in self.scope_arrays:
result.append((operand, type(self.scope_arrays[operand])))
62 changes: 62 additions & 0 deletions dace/frontend/python/preprocessing.py
@@ -1503,6 +1503,62 @@ def find_disallowed_statements(node: ast.AST):
return None


class MPIResolver(ast.NodeTransformer):
""" Resolves mpi4py-related constants, e.g., mpi4py.MPI.COMM_WORLD. """
def __init__(self, globals: Dict[str, Any]):
from mpi4py import MPI
self.globals = globals
self.MPI = MPI
self.parent = None

def visit(self, node):
node.parent = self.parent
self.parent = node
node = super().visit(node)
if isinstance(node, ast.AST):
self.parent = node.parent
return node

def visit_Name(self, node: ast.Name) -> Union[ast.Name, ast.Attribute]:
self.generic_visit(node)
if node.id in self.globals:
obj = self.globals[node.id]
if isinstance(obj, self.MPI.Comm):
lattr = ast.Attribute(ast.Name(id='mpi4py', ctx=ast.Load), attr='MPI')
if obj is self.MPI.COMM_NULL:
newnode = ast.copy_location(ast.Attribute(value=lattr, attr='COMM_NULL'), node)
newnode.parent = node.parent
return newnode
return node

def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
self.generic_visit(node)
if isinstance(node.attr, str) and node.attr == 'Request':
try:
val = astutils.evalnode(node, self.globals)
if val is self.MPI.Request and not isinstance(node.parent, ast.Attribute):
newnode = ast.copy_location(
ast.Attribute(value=ast.Name(id='dace', ctx=ast.Load), attr='MPI_Request'), node)
newnode.parent = node.parent
return newnode
except SyntaxError:
pass
return node
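A hedged, self-contained illustration of the resolver's effect (requires Python 3.9+ for `ast.unparse` and an installed mpi4py; `MPIResolver` as defined above):

```python
import ast
from mpi4py import MPI

# Globals as the frontend would see them: `comm` is bound to COMM_NULL.
resolver = MPIResolver({'comm': MPI.COMM_NULL, 'MPI': MPI})

tree = resolver.visit(ast.parse("ok = comm != MPI.COMM_NULL"))
# The global name bound to COMM_NULL becomes its fully qualified constant:
print(ast.unparse(tree))  # ok = mpi4py.MPI.COMM_NULL != MPI.COMM_NULL

tree = resolver.visit(ast.parse("reqs = np.empty((2,), dtype=MPI.Request)"))
# MPI.Request is rewritten to the opaque DaCe type added in dtypes.py:
print(ast.unparse(tree))  # reqs = np.empty((2,), dtype=dace.MPI_Request)
```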


class ModuloConverter(ast.NodeTransformer):
""" Converts a % b expressions to (a + b) % b for C/C++ compatibility. """
Collaborator: This will make some memlets that use % more complex. Is there a way to limit this behavior? Maybe do it in the code generator (i.e., cppunparse or Mod in the runtime .h file)?

alexnick83 (Contributor, Author): This is frontend-specific (see, e.g., -2 % 5 in Python vs. C), so I would definitely not handle it in code generation. I don't believe the added complexity is an issue (sympy is probably not going to simplify the modulo operator anyway, so I doubt it breaks any optimization). We could instead map it to a pymod call, but that is probably an even worse solution if the concern is the simplification of symbolic expressions.

alexnick83 (Contributor, Author): I think this issue comes up now and then, and we forget about it after a while. The problem is that we like to use Python syntax on the SDFG to take advantage of the AST module, but we want this syntax to follow C semantics. Therefore, the frontend has to rewrite certain Python expressions so that they return the same result in both Python and C/C++.

alexnick83 (Contributor, Author, Jul 23, 2023): Maybe you are thinking that SDFG elements that include Python code should be unparsed to semantically equivalent C/C++ expressions during code generation. However, then the other frontends must change to convert C/Fortran/whatever semantics to Python. Is this a good idea?
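To make the semantic gap concrete, a small hedged example — note the rewrite matches Python semantics only when b > 0 and a >= -b, which holds for typical rank arithmetic such as (rank - 1) % size:

```python
# Python's % follows the sign of the divisor; C's % follows the dividend.
assert -2 % 5 == 3        # Python
# int c = -2 % 5;         // C/C++: c == -2
# After rewriting a % b to (a + b) % b, both languages agree:
assert (-2 + 5) % 5 == 3  # == 3 in Python and in C
```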


def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
if isinstance(node.op, ast.Mod):
left = self.generic_visit(node.left)
right = self.generic_visit(node.right)
newleft = ast.copy_location(ast.BinOp(left=left, op=ast.Add(), right=copy.deepcopy(right)), left)
(alexnick83 marked this conversation as resolved.)
node.left = newleft
return node
return self.generic_visit(node)


def preprocess_dace_program(f: Callable[..., Any],
argtypes: Dict[str, data.Data],
global_vars: Dict[str, Any],
@@ -1544,6 +1600,12 @@ def preprocess_dace_program(f: Callable[..., Any],
newmod = global_vars[mod]
#del global_vars[mod]
global_vars[modval] = newmod

try:
src_ast = MPIResolver(global_vars).visit(src_ast)
except ModuleNotFoundError:
pass
src_ast = ModuloConverter().visit(src_ast)

# Resolve constants to their values (if they are not already defined in this scope)
# and symbols to their names
27 changes: 19 additions & 8 deletions dace/frontend/python/replacements.py
@@ -282,26 +282,37 @@ def _numpy_full(pv: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
shape: Shape,
fill_value: Union[sp.Expr, Number],
fill_value: Union[sp.Expr, Number, data.Scalar],
dtype: dace.typeclass = None):
""" Creates and array of the specified shape and initializes it with
the fill value.
"""
is_data = False
if isinstance(fill_value, (Number, np.bool_)):
vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)]
elif isinstance(fill_value, sp.Expr):
vtype = _sym_type(fill_value)
else:
raise mem_parser.DaceSyntaxError(pv, None, "Fill value {f} must be a number!".format(f=fill_value))
is_data = True
vtype = sdfg.arrays[fill_value].dtype
dtype = dtype or vtype
name, _ = sdfg.add_temp_transient(shape, dtype)

state.add_mapped_tasklet(
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)}, {},
"__out = {}".format(fill_value),
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
if is_data:
state.add_mapped_tasklet(
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)},
dict(__inp=dace.Memlet(data=fill_value, subset='0')),
"__out = __inp",
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
else:
state.add_mapped_tasklet(
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)}, {},
"__out = {}".format(fill_value),
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)

return name

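A hedged usage sketch of what the extended replacement permits — filling with runtime scalar data rather than a compile-time constant (illustrative; not taken verbatim from the PR's tests):

```python
import dace
import numpy as np

@dace.program
def full_from_scalar(fill: dace.int32):
    # `fill` is Scalar data in the SDFG, so the replacement now reads it
    # through an input memlet instead of inlining a constant in the tasklet.
    return np.full((4, 4), fill)
```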
1 change: 1 addition & 0 deletions dace/libraries/mpi/nodes/__init__.py
@@ -10,5 +10,6 @@
from .reduce import Reduce
from .allreduce import Allreduce
from .allgather import Allgather
from .alltoall import Alltoall
from .dummy import Dummy
from .redistribute import Redistribute
84 changes: 84 additions & 0 deletions dace/libraries/mpi/nodes/alltoall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
(alexnick83 marked this conversation as resolved.)
import dace.library
import dace.properties
import dace.sdfg.nodes
from dace.transformation.transformation import ExpandTransformation
from .. import environments
from dace.libraries.mpi.nodes.node import MPINode


@dace.library.expansion
class ExpandAlltoallMPI(ExpandTransformation):

environments = [environments.mpi.MPI]

@staticmethod
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
(inbuffer, in_count_str), (outbuffer, out_count_str) = node.validate(parent_sdfg, parent_state)
in_mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(inbuffer.dtype.base_type)
out_mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(outbuffer.dtype.base_type)

if inbuffer.dtype.veclen > 1:
raise (NotImplementedError)

comm = "MPI_COMM_WORLD"
if node.grid:
comm = f"__state->{node.grid}_comm"

code = f"""
int size;
MPI_Comm_size({comm}, &size);
int sendrecv_amt = {in_count_str} / size;
MPI_Alltoall(_inbuffer, sendrecv_amt, {in_mpi_dtype_str}, \
_outbuffer, sendrecv_amt, {out_mpi_dtype_str}, \
{comm});
"""
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
code,
language=dace.dtypes.Language.CPP)
return tasklet


@dace.library.node
class Alltoall(MPINode):

# Global properties
implementations = {
"MPI": ExpandAlltoallMPI,
}
default_implementation = "MPI"

grid = dace.properties.Property(dtype=str, allow_none=True, default=None)

def __init__(self, name, grid=None, *args, **kwargs):
super().__init__(name, *args, inputs={"_inbuffer"}, outputs={"_outbuffer"}, **kwargs)
self.grid = grid

def validate(self, sdfg, state):
"""
:return: Two (buffer, count) pairs describing the input and output
data of this node in the parent SDFG.
"""

inbuffer, outbuffer = None, None
for e in state.out_edges(self):
if e.src_conn == "_outbuffer":
outbuffer = sdfg.arrays[e.data.data]
for e in state.in_edges(self):
if e.dst_conn == "_inbuffer":
inbuffer = sdfg.arrays[e.data.data]

in_count_str = "XXX"
out_count_str = "XXX"
for _, src_conn, _, _, data in state.out_edges(self):
if src_conn == '_outbuffer':
dims = [str(e) for e in data.subset.size_exact()]
out_count_str = "*".join(dims)
for _, _, _, dst_conn, data in state.in_edges(self):
if dst_conn == '_inbuffer':
dims = [str(e) for e in data.subset.size_exact()]
in_count_str = "*".join(dims)

return (inbuffer, in_count_str), (outbuffer, out_count_str)
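A hedged sketch of the mpi4py-style call the new replacement targets; since the expansion sends `count / size` elements to each rank, the buffer sizes should divide evenly by the communicator size (illustrative only):

```python
from mpi4py import MPI
import dace

commworld = MPI.COMM_WORLD

@dace.program
def alltoall(sbuf: dace.int32[8], rbuf: dace.int32[8]):
    # Each rank scatters 8 / size elements to every other rank.
    commworld.Alltoall(sbuf, rbuf)
```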
9 changes: 8 additions & 1 deletion dace/libraries/mpi/nodes/bcast.py
@@ -42,11 +42,16 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
if isinstance(buffer, dace.data.Scalar):
ref = "&"

init = ""
comm = "MPI_COMM_WORLD"
if node.grid:
comm = f"__state->{node.grid}_comm"
elif node.fcomm:
init = f"MPI_Comm __comm = MPI_Comm_f2c({node.fcomm});"
comm = "__comm"

code = f"""
{init}
MPI_Bcast({ref}_inbuffer, {count_str}, {mpi_dtype_str}, _root, {comm});
_outbuffer = _inbuffer;"""
tasklet = dace.sdfg.nodes.Tasklet(node.name,
@@ -67,10 +72,12 @@ class Bcast(MPINode):
default_implementation = "MPI"

grid = dace.properties.Property(dtype=str, allow_none=True, default=None)
fcomm = dace.properties.Property(dtype=str, allow_none=True, default=None)

def __init__(self, name, grid=None, *args, **kwargs):
def __init__(self, name, grid=None, fcomm=None, *args, **kwargs):
super().__init__(name, *args, inputs={"_inbuffer", "_root"}, outputs={"_outbuffer"}, **kwargs)
self.grid = grid
self.fcomm = fcomm

def validate(self, sdfg, state):
"""
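A hedged sketch of broadcasting through a non-world Intracomm, the case the new `fcomm` property serves (the PR adds Bcast tests for COMM_WORLD and an Intracomm object; this code is illustrative, not the PR's test):

```python
from mpi4py import MPI
import dace

commworld = MPI.COMM_WORLD
# An Intracomm created at module scope; globals of type MPI.Comm are
# registered by the frontend (see the `defined` changes in newast.py).
comm = commworld.Split(color=commworld.Get_rank() % 2, key=0)

@dace.program
def bcast(buf: dace.int32[4]):
    comm.Bcast(buf, root=0)
```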
12 changes: 10 additions & 2 deletions dace/libraries/mpi/nodes/irecv.py
@@ -20,6 +20,11 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):

if buffer.dtype.veclen > 1:
raise NotImplementedError

comm = "MPI_COMM_WORLD"
if node.grid:
comm = f"__state->{node.grid}_comm"

code = ""
if ddt is not None:
code = f"""static MPI_Datatype newtype;
@@ -33,7 +38,7 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
mpi_dtype_str = "newtype"
count_str = "1"
buffer_offset = 0 #this is here because the frontend already changes the pointer
code += f"MPI_Irecv(_buffer, {count_str}, {mpi_dtype_str}, _src, _tag, MPI_COMM_WORLD, _request);"
code += f"MPI_Irecv(_buffer, {count_str}, {mpi_dtype_str}, int(_src), int(_tag), {comm}, _request);"
if ddt is not None:
code += f"""// MPI_Type_free(&newtype);
"""
@@ -58,8 +63,11 @@ class Irecv(MPINode):
}
default_implementation = "MPI"

def __init__(self, name, *args, **kwargs):
grid = dace.properties.Property(dtype=str, allow_none=True, default=None)

def __init__(self, name, grid=None, *args, **kwargs):
super().__init__(name, *args, inputs={"_src", "_tag"}, outputs={"_buffer", "_request"}, **kwargs)
self.grid = grid

def validate(self, sdfg, state):
"""
22 changes: 13 additions & 9 deletions dace/libraries/mpi/nodes/isend.py
@@ -20,6 +20,10 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):

if buffer.dtype.veclen > 1:
raise NotImplementedError

comm = "MPI_COMM_WORLD"
if node.grid:
comm = f"__state->{node.grid}_comm"

code = ""

@@ -40,7 +44,7 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
mpi_dtype_str = "newtype"
count_str = "1"
buffer_offset = 0
code += f"MPI_Isend(&(_buffer[{buffer_offset}]), {count_str}, {mpi_dtype_str}, _dest, _tag, MPI_COMM_WORLD, _request);"
code += f"MPI_Isend(&(_buffer[{buffer_offset}]), {count_str}, {mpi_dtype_str}, int(_dest), int(_tag), {comm}, _request);"
if ddt is not None:
code += f"""// MPI_Type_free(&newtype);
"""
@@ -69,13 +73,12 @@ class Isend(MPINode):
}
default_implementation = "MPI"

# Object fields
n = dace.properties.SymbolicProperty(allow_none=True, default=None)

grid = dace.properties.Property(dtype=str, allow_none=True, default=None)
nosync = dace.properties.Property(dtype=bool, default=False, desc="Do not sync if memory is on GPU")

def __init__(self, name, *args, **kwargs):
def __init__(self, name, grid=None, *args, **kwargs):
super().__init__(name, *args, inputs={"_buffer", "_dest", "_tag"}, outputs={"_request"}, **kwargs)
self.grid = grid

def validate(self, sdfg, state):
"""
@@ -94,10 +97,11 @@ def validate(self, sdfg, state):
if e.src_conn == "_request":
req = sdfg.arrays[e.data.data]

if dest.dtype.base_type != dace.dtypes.int32:
raise ValueError("Source must be an integer!")
if tag.dtype.base_type != dace.dtypes.int32:
raise ValueError("Tag must be an integer!")
# TODO: Should we expect any integer type here and cast to int32 later?. Investigate further in the future.
# if dest.dtype.base_type != dace.dtypes.int32:
# raise ValueError("Destination must be an integer!")
# if tag.dtype.base_type != dace.dtypes.int32:
# raise ValueError("Tag must be an integer!")
(alexnick83 marked this conversation as resolved.)

count_str = "XXX"
for _, _, _, dst_conn, data in state.in_edges(self):