-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
mpi4py #1070
mpi4py #1070
Changes from 35 commits
78a1a03
a13a81d
a8d5690
ac480bd
d229266
f933974
c224013
ebe22ed
2bfcea9
fd0dc40
7ca527c
38be749
e5085ae
f8c9550
345c36b
1cea59e
b82b06a
a115db6
eebefe4
110d0f2
dd06eb6
73f90cf
8626b9a
442a873
832c203
6dd00bc
e67aa8e
2741579
6fe26c4
af624c3
fe22182
5bcf53b
6c5ffa1
c3b1a4b
31d7b84
311f5b9
70198d5
727afa7
ac177bd
01f82fa
7bfd960
1aff2df
8d4ad6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1303,6 +1303,14 @@ def defined(self): | |
# Add SDFG arrays, in case a replacement added a new output | ||
result.update(self.sdfg.arrays) | ||
|
||
# MPI-related stuff | ||
result.update({k: self.sdfg.process_grids[v] for k, v in self.variables.items() if v in self.sdfg.process_grids}) | ||
try: | ||
from mpi4py import MPI | ||
result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)}) | ||
except: | ||
alexnick83 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pass | ||
|
||
return result | ||
|
||
def _add_state(self, label=None): | ||
|
@@ -4356,8 +4364,11 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): | |
# Add object as first argument | ||
if modname in self.variables.keys(): | ||
arg = self.variables[modname] | ||
else: | ||
elif modname in self.scope_vars.keys(): | ||
arg = self.scope_vars[modname] | ||
else: | ||
# Fallback to (name, object) | ||
arg = (modname, self.defined[modname]) | ||
args.append(arg) | ||
# Otherwise, try to find a default implementation for the SDFG | ||
elif not found_ufunc: | ||
|
@@ -4667,7 +4678,9 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]: | |
|
||
result = [] | ||
for operand in operands: | ||
if isinstance(operand, str) and operand in self.sdfg.arrays: | ||
if isinstance(operand, str) and operand in self.sdfg.process_grids: | ||
result.append((operand, type(self.sdfg.process_grids[operand]).__name__)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do process grids take precedence over everything else? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A communicator is registered in the process grids but has a corresponding Scalar with the same name (currently, an integer) to allow creating AccessNodes. So far, we only had persistent communicators, COMM_WORLD or communicators created in the program's init. However, going forward, we want to support communicator creation/deletion during the program's execution, raising the need for expressing dependencies among communicator creation routines and communication calls on the SDFG. I suppose the best solution is to make an opaque type and only check SDFG.arrays. |
||
elif isinstance(operand, str) and operand in self.sdfg.arrays: | ||
result.append((operand, type(self.sdfg.arrays[operand]))) | ||
elif isinstance(operand, str) and operand in self.scope_arrays: | ||
result.append((operand, type(self.scope_arrays[operand]))) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1503,6 +1503,62 @@ def find_disallowed_statements(node: ast.AST): | |
return None | ||
|
||
|
||
class MPIResolver(ast.NodeTransformer):
    """
    Resolves mpi4py-related constants found in a program's AST.

    Two rewrites are performed:

    * A name bound in ``globals`` to ``mpi4py.MPI.COMM_NULL`` is replaced by the
      attribute expression ``mpi4py.MPI.COMM_NULL``.
    * An attribute expression ending in ``.Request`` that evaluates to
      ``mpi4py.MPI.Request`` is replaced by ``dace.MPI_Request``, unless it is
      itself the value of an enclosing attribute access.

    While traversing, every visited node receives a ``parent`` attribute
    pointing to its enclosing node (``None`` for the root).
    """

    def __init__(self, globals: Dict[str, Any]):
        """
        :param globals: Mapping from global names to the objects they are bound to.
        :raises ModuleNotFoundError: If mpi4py is not installed.
        """
        # Imported lazily so that only instantiating the resolver requires mpi4py.
        from mpi4py import MPI
        self.globals = globals
        self.MPI = MPI
        self.parent = None

    def visit(self, node):
        """ Visits ``node`` while tracking its parent in ``node.parent``. """
        enclosing = self.parent
        node.parent = enclosing
        self.parent = node
        try:
            return super().visit(node)
        finally:
            # Always restore the parent, even if a visitor returns a non-AST
            # value (e.g., None or a list) or raises.
            self.parent = enclosing

    def visit_Name(self, node: ast.Name) -> Union[ast.Name, ast.Attribute]:
        """ Replaces names bound to ``MPI.COMM_NULL`` with ``mpi4py.MPI.COMM_NULL``. """
        self.generic_visit(node)
        if node.id not in self.globals:
            return node
        obj = self.globals[node.id]
        # Only the null communicator is resolved here; other communicators are
        # left untouched.
        if not (isinstance(obj, self.MPI.Comm) and obj is self.MPI.COMM_NULL):
            return node
        module_attr = ast.Attribute(value=ast.Name(id='mpi4py', ctx=ast.Load()), attr='MPI', ctx=ast.Load())
        newnode = ast.copy_location(ast.Attribute(value=module_attr, attr='COMM_NULL', ctx=ast.Load()), node)
        newnode.parent = node.parent
        return newnode

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        """ Replaces expressions evaluating to ``MPI.Request`` with ``dace.MPI_Request``. """
        self.generic_visit(node)
        if not (isinstance(node.attr, str) and node.attr == 'Request'):
            return node
        try:
            val = astutils.evalnode(node, self.globals)
        except SyntaxError:
            # The expression could not be evaluated; leave it unchanged.
            return node
        # Skip attributes that are the value of a longer attribute chain: only the
        # outermost ``...Request`` expression is replaced.
        if val is self.MPI.Request and not isinstance(node.parent, ast.Attribute):
            newnode = ast.copy_location(
                ast.Attribute(value=ast.Name(id='dace', ctx=ast.Load()), attr='MPI_Request', ctx=ast.Load()), node)
            newnode.parent = node.parent
            return newnode
        return node
|
||
|
||
class ModuloConverter(ast.NodeTransformer):
    """
    Converts ``a % b`` expressions to ``(a + b) % b`` for C/C++ compatibility.

    The rewrite does not change the value of the expression in Python, since
    ``(a + b) % b == a % b``.

    NOTE: This is a frontend-specific rewrite. Expressions on the SDFG are
    written in Python syntax but are meant to follow C semantics, so the
    frontend rewrites certain Python expressions such that they return the same
    result both in Python and in C/C++.
    """

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        """
        Rewrites a modulo operation in place and recurses into its operands.

        :param node: The binary operation to process.
        :return: The (possibly modified) binary operation.
        """
        if not isinstance(node.op, ast.Mod):
            return self.generic_visit(node)

        # Use ``visit`` (not ``generic_visit``) on the operands so that an operand
        # that is itself a modulo operation, e.g. ``(a % b) % c``, is converted too.
        node.left = self.visit(node.left)
        node.right = self.visit(node.right)

        # The divisor is deep-copied so that no AST node object appears twice
        # in the tree.
        shifted = ast.BinOp(left=node.left, op=ast.Add(), right=copy.deepcopy(node.right))
        node.left = ast.copy_location(shifted, node.left)
        return node
|
||
|
||
def preprocess_dace_program(f: Callable[..., Any], | ||
argtypes: Dict[str, data.Data], | ||
global_vars: Dict[str, Any], | ||
|
@@ -1544,6 +1600,12 @@ def preprocess_dace_program(f: Callable[..., Any], | |
newmod = global_vars[mod] | ||
#del global_vars[mod] | ||
global_vars[modval] = newmod | ||
|
||
try: | ||
src_ast = MPIResolver(global_vars).visit(src_ast) | ||
except ModuleNotFoundError: | ||
pass | ||
src_ast = ModuloConverter().visit(src_ast) | ||
|
||
# Resolve constants to their values (if they are not already defined in this scope) | ||
# and symbols to their names | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. | ||
alexnick83 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import dace.library | ||
import dace.properties | ||
import dace.sdfg.nodes | ||
from dace.transformation.transformation import ExpandTransformation | ||
from .. import environments | ||
from dace.libraries.mpi.nodes.node import MPINode | ||
|
||
|
||
@dace.library.expansion
class ExpandAlltoallMPI(ExpandTransformation):
    """ Expands an ``Alltoall`` library node into a C++ tasklet that calls ``MPI_Alltoall``. """

    environments = [environments.mpi.MPI]

    @staticmethod
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """
        Creates the tasklet implementing the all-to-all exchange.

        :param node: The ``Alltoall`` library node to expand.
        :param parent_state: The state that contains the node.
        :param parent_sdfg: The SDFG that contains the state.
        :param n: Unused.
        :return: A C++ tasklet with the same name and connectors as ``node``.
        :raises NotImplementedError: If the input buffer has a vector data type.
        """
        (inbuffer, in_count_str), (outbuffer, out_count_str) = node.validate(parent_sdfg, parent_state)
        in_mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(inbuffer.dtype.base_type)
        out_mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(outbuffer.dtype.base_type)

        if inbuffer.dtype.veclen > 1:
            raise NotImplementedError("Alltoall expansion does not support vector data types (veclen > 1)")

        # Use the communicator of the process grid if one was given; the world
        # communicator otherwise.
        comm = f"__state->{node.grid}_comm" if node.grid else "MPI_COMM_WORLD"

        # The element count is a symbolic expression; it is parenthesized so that
        # the division applies to the whole count (e.g., "(N + 1) / size").
        code = f"""
            int size;
            MPI_Comm_size({comm}, &size);
            int sendrecv_amt = ({in_count_str}) / size;
            MPI_Alltoall(_inbuffer, sendrecv_amt, {in_mpi_dtype_str}, \
                         _outbuffer, sendrecv_amt, {out_mpi_dtype_str}, \
                         {comm});
            """
        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)
        return tasklet
|
||
|
||
@dace.library.node
class Alltoall(MPINode):
    """
    Library node for an all-to-all exchange, with one input connector
    (``_inbuffer``) and one output connector (``_outbuffer``).
    """

    # Global properties
    implementations = {
        "MPI": ExpandAlltoallMPI,
    }
    default_implementation = "MPI"

    # Name of the process grid whose communicator is used (None selects MPI_COMM_WORLD
    # in the MPI expansion).
    grid = dace.properties.Property(dtype=str, allow_none=True, default=None)

    def __init__(self, name, grid=None, *args, **kwargs):
        super().__init__(name, *args, inputs={"_inbuffer"}, outputs={"_outbuffer"}, **kwargs)
        self.grid = grid

    def validate(self, sdfg, state):
        """
        Collects the data descriptors and element counts of the node's buffers.

        :param sdfg: The SDFG that contains the node.
        :param state: The state that contains the node.
        :return: A two-tuple ``((inbuffer, in_count_str), (outbuffer, out_count_str))``
                 with the data descriptors of the input and output buffers in the
                 parent SDFG and, for each, the number of elements accessed as a
                 string expression. If a connector has no edge attached, its
                 descriptor is None and its count is "XXX".
        """
        inbuffer, outbuffer = None, None
        in_count_str = "XXX"
        out_count_str = "XXX"

        for edge in state.in_edges(self):
            if edge.dst_conn == "_inbuffer":
                inbuffer = sdfg.arrays[edge.data.data]
                in_count_str = self._count_str(edge.data)
        for edge in state.out_edges(self):
            if edge.src_conn == "_outbuffer":
                outbuffer = sdfg.arrays[edge.data.data]
                out_count_str = self._count_str(edge.data)

        return (inbuffer, in_count_str), (outbuffer, out_count_str)

    @staticmethod
    def _count_str(memlet):
        """ Returns the product of the exact subset sizes of ``memlet`` as a string. """
        # Each factor is parenthesized so that sizes that are sums (e.g., "N + 1")
        # keep their meaning inside the product.
        return "*".join(f"({dim})" for dim in memlet.subset.size_exact())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not happy about this. Can we generalize to add more "supported global types" instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that we could add another API method to register supported global types, where the user would specify how exactly they should be handled. Shouldn't this be a new PR, though?