From 8582216b83e61570007316892498652284a880ba Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Tue, 27 Jun 2023 15:32:13 -0500
Subject: [PATCH] typing fixes

---
 examples/wave-mpi.py |  6 +++---
 examples/wave.py     |  4 +++-
 mirgecom/simutil.py  | 33 ++++++++++++++++++++++-----------
 3 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/examples/wave-mpi.py b/examples/wave-mpi.py
index 2e50fc849..4f71d8b16 100644
--- a/examples/wave-mpi.py
+++ b/examples/wave-mpi.py
@@ -78,7 +78,7 @@ def main(actx_class, snapshot_pattern="wave-mpi-{step:04d}-{rank:04d}.pkl",
         filename="wave-mpi.sqlite", mode="wu", mpi_comm=comm)
 
     from mirgecom.simutil import initialize_actx, actx_class_is_profiling
-    actx, cl_ctx, queue, alloc = initialize_actx(actx_class)
+    actx, cl_ctx, queue, alloc = initialize_actx(actx_class, comm)
     use_profiling = actx_class_is_profiling(actx_class)
 
     if restart_step is None:
@@ -126,7 +126,7 @@ def main(actx_class, snapshot_pattern="wave-mpi-{step:04d}-{rank:04d}.pkl",
     from grudge.dt_utils import characteristic_lengthscales
     nodal_dt = characteristic_lengthscales(actx, dcoll) / wave_speed
 
-    dt = actx.to_numpy(current_cfl * op.nodal_min(dcoll, "vol", nodal_dt))[()]
+    dt = actx.to_numpy(current_cfl * op.nodal_min(dcoll, "vol", nodal_dt))[()]  # type: ignore[index]
 
     t_final = 1
 
@@ -235,7 +235,7 @@ def rhs(t, w):
         logmgr.close()
 
     final_soln = actx.to_numpy(op.norm(dcoll, fields[0], 2))
-    assert np.abs(final_soln - 0.04409852463947439) < 1e-14
+    assert np.abs(final_soln - 0.04409852463947439) < 1e-14  # type: ignore[operator]
 
 
 if __name__ == "__main__":
diff --git a/examples/wave.py b/examples/wave.py
index 86d3fcd27..c152484cc 100644
--- a/examples/wave.py
+++ b/examples/wave.py
@@ -85,7 +85,7 @@ def main(actx_class, use_logmgr: bool = False) -> None:
     from grudge.dt_utils import characteristic_lengthscales
     nodal_dt = characteristic_lengthscales(actx, dcoll) / wave_speed
     dt = actx.to_numpy(current_cfl * op.nodal_min(dcoll, "vol",
-                                                  nodal_dt))[()]
+                                                  nodal_dt))[()]  # type: ignore[index]
 
     print("%d elements" % mesh.nelements)
 
@@ -130,6 +130,8 @@ def rhs(t, w):
 
         if istep % 10 == 0:
             if use_profiling:
+                from mirgecom.profiling import PyOpenCLProfilingArrayContext
+                assert isinstance(actx, PyOpenCLProfilingArrayContext)
                 print(actx.tabulate_profiling_data())
             print(istep, t, actx.to_numpy(op.norm(dcoll, fields[0], 2)))
             vis.write_vtk_file("fld-wave-%04d.vtu" % istep,
diff --git a/mirgecom/simutil.py b/mirgecom/simutil.py
index 4fc83676e..cdb063e0a 100644
--- a/mirgecom/simutil.py
+++ b/mirgecom/simutil.py
@@ -30,7 +30,7 @@
 ----------------------------
 
 .. autofunction:: configurate
-.. autofunction:: get_reasonable_actx_class
+.. autofunction:: get_reasonable_array_context_class
 .. autofunction:: actx_class_is_lazy
 .. autofunction:: actx_class_is_eager
 .. autofunction:: actx_class_is_profiling
@@ -72,12 +72,12 @@
 from functools import partial
 
 import grudge.op as op
-# from grudge.op import nodal_min, elementwise_min
-from arraycontext import map_array_container, flatten
+
+from arraycontext import map_array_container, flatten, ArrayContext
 from meshmode.dof_array import DOFArray
 from mirgecom.viscous import get_viscous_timestep
 
-from typing import List, Dict, Optional, Tuple, Any
+from typing import List, Dict, Optional, Tuple, Any, TYPE_CHECKING, Type
 from grudge.discretization import DiscretizationCollection, PartID
 from grudge.dof_desc import DD_VOLUME_ALL
 from mirgecom.utils import normalize_boundaries
@@ -85,6 +85,9 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from mpi4py.MPI import Comm
+
 
 def get_number_of_tetrahedron_nodes(dim, order, include_faces=False):
     """Get number of nodes (modes) in *dim* Tetrahedron of *order*."""
@@ -211,7 +214,7 @@ def get_sim_timestep(
 
 def write_visfile(dcoll, io_fields, visualizer, vizname,
                   step=0, t=0, overwrite=False, vis_timer=None,
-                  comm=None):
+                  comm: Optional["Comm"] = None):
     """Write parallel VTK output for the fields specified in *io_fields*.
 
     This routine writes a parallel-compatible unstructured VTK visualization
@@ -1134,7 +1137,7 @@ def configurate(config_key, config_object=None, default_value=None):
     return default_value
 
 
-def get_reasonable_actx_class(lazy: bool = False, distributed: bool = True,
+def get_reasonable_array_context_class(lazy: bool = False, distributed: bool = True,
                               profiling: bool = False):
     if lazy and profiling:
         raise ValueError("Can't specify both lazy and profiling")
@@ -1164,7 +1167,10 @@ def actx_class_is_profiling(actx_class) -> bool:
     return issubclass(actx_class, PyOpenCLProfilingArrayContext)
 
 
-def initialize_actx(actx_class, comm) -> Tuple[Any]:
+def initialize_actx(actx_class: Type[ArrayContext], comm: Optional["Comm"]) -> Tuple[ArrayContext, cl.Context, cl.CommandQueue, cl.tools.AllocatorBase]:
+    from arraycontext import PytatoPyOpenCLArrayContext, PyOpenCLArrayContext
+    from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoPyOpenCLArrayContext
+
     cl_ctx = cl.create_some_context()
     if actx_class_is_profiling(actx_class):
         queue = cl.CommandQueue(cl_ctx,
@@ -1175,16 +1181,21 @@ def initialize_actx(actx_class, comm) -> Tuple[Any]:
     alloc = get_reasonable_memory_pool(cl_ctx, queue)
 
     if actx_class_is_lazy(actx_class):
+        assert issubclass(actx_class, PytatoPyOpenCLArrayContext)
         if comm:
-            actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
+            assert issubclass(actx_class, MPIPytatoPyOpenCLArrayContext)
+            actx: ArrayContext = actx_class(mpi_communicator=comm, queue=queue, mpi_base_tag=12000, allocator=alloc)  # type: ignore[call-arg]
         else:
+            assert not issubclass(actx_class, MPIPytatoPyOpenCLArrayContext)
             actx = actx_class(queue, allocator=alloc)
     else:
-        assert actx_class_is_eager(actx_class)
+        assert issubclass(actx_class, PyOpenCLArrayContext)
         if comm:
-            actx = actx_class(comm, queue, allocator=alloc,
-                              force_device_scalars=True)
+            assert issubclass(actx_class, MPIPyOpenCLArrayContext)
+            actx = actx_class(mpi_communicator=comm, queue=queue, allocator=alloc,
+                              force_device_scalars=True)  # type: ignore[call-arg]
        else:
+            assert not issubclass(actx_class, MPIPyOpenCLArrayContext)
             actx = actx_class(queue, allocator=alloc, force_device_scalars=True)
 
     return actx, cl_ctx, queue, alloc
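
Usage sketch: with the rename to get_reasonable_array_context_class and the
new initialize_actx signature above, a driver obtains and constructs an array
context roughly as follows. This is a minimal sketch assuming an mpi4py
communicator; the run() wrapper is hypothetical, while the two imported
functions are exactly those changed in this patch.

    from mpi4py import MPI

    from mirgecom.simutil import (
        get_reasonable_array_context_class,
        initialize_actx,
    )

    def run(lazy: bool = False, profiling: bool = False) -> None:
        comm = MPI.COMM_WORLD

        # Select an eager or lazy (and optionally profiling) actx class;
        # lazy=True together with profiling=True raises ValueError.
        actx_class = get_reasonable_array_context_class(
            lazy=lazy, distributed=True, profiling=profiling)

        # Construct it; passing comm=None instead would take the non-MPI
        # branches asserted in initialize_actx above.
        actx, cl_ctx, queue, alloc = initialize_actx(actx_class, comm)

Passing the communicator explicitly, as wave-mpi.py now does, is what lets
the issubclass assertions in initialize_actx narrow actx_class to the
MPI-enabled array-context types for the type checker.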