Skip to content

Commit

Permalink
typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jun 27, 2023
1 parent 58695e8 commit 8582216
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
6 changes: 3 additions & 3 deletions examples/wave-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(actx_class, snapshot_pattern="wave-mpi-{step:04d}-{rank:04d}.pkl",
filename="wave-mpi.sqlite", mode="wu", mpi_comm=comm)

from mirgecom.simutil import initialize_actx, actx_class_is_profiling
actx, cl_ctx, queue, alloc = initialize_actx(actx_class)
actx, cl_ctx, queue, alloc = initialize_actx(actx_class, comm)
use_profiling = actx_class_is_profiling(actx_class)

if restart_step is None:
Expand Down Expand Up @@ -126,7 +126,7 @@ def main(actx_class, snapshot_pattern="wave-mpi-{step:04d}-{rank:04d}.pkl",
from grudge.dt_utils import characteristic_lengthscales
nodal_dt = characteristic_lengthscales(actx, dcoll) / wave_speed

dt = actx.to_numpy(current_cfl * op.nodal_min(dcoll, "vol", nodal_dt))[()]
dt = actx.to_numpy(current_cfl * op.nodal_min(dcoll, "vol", nodal_dt))[()] # type: ignore[index]

t_final = 1

Expand Down Expand Up @@ -235,7 +235,7 @@ def rhs(t, w):
logmgr.close()

final_soln = actx.to_numpy(op.norm(dcoll, fields[0], 2))
assert np.abs(final_soln - 0.04409852463947439) < 1e-14
assert np.abs(final_soln - 0.04409852463947439) < 1e-14 # type: ignore[operator]


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion examples/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def main(actx_class, use_logmgr: bool = False) -> None:
from grudge.dt_utils import characteristic_lengthscales
nodal_dt = characteristic_lengthscales(actx, dcoll) / wave_speed
dt = actx.to_numpy(current_cfl * op.nodal_min(dcoll, "vol",
nodal_dt))[()]
nodal_dt))[()] # type: ignore[index]

print("%d elements" % mesh.nelements)

Expand Down Expand Up @@ -130,6 +130,8 @@ def rhs(t, w):

if istep % 10 == 0:
if use_profiling:
from mirgecom.profiling import PyOpenCLProfilingArrayContext
assert isinstance(actx, PyOpenCLProfilingArrayContext)
print(actx.tabulate_profiling_data())
print(istep, t, actx.to_numpy(op.norm(dcoll, fields[0], 2)))
vis.write_vtk_file("fld-wave-%04d.vtu" % istep,
Expand Down
33 changes: 22 additions & 11 deletions mirgecom/simutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
----------------------------
.. autofunction:: configurate
.. autofunction:: get_reasonable_actx_class
.. autofunction:: get_reasonable_array_context_class
.. autofunction:: actx_class_is_lazy
.. autofunction:: actx_class_is_eager
.. autofunction:: actx_class_is_profiling
Expand Down Expand Up @@ -72,19 +72,22 @@
from functools import partial

import grudge.op as op
# from grudge.op import nodal_min, elementwise_min
from arraycontext import map_array_container, flatten

from arraycontext import map_array_container, flatten, ArrayContext
from meshmode.dof_array import DOFArray
from mirgecom.viscous import get_viscous_timestep

from typing import List, Dict, Optional, Tuple, Any
from typing import List, Dict, Optional, Tuple, Any, TYPE_CHECKING, Type
from grudge.discretization import DiscretizationCollection, PartID
from grudge.dof_desc import DD_VOLUME_ALL
from mirgecom.utils import normalize_boundaries
import pyopencl as cl

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mpi4py.MPI import Comm


def get_number_of_tetrahedron_nodes(dim, order, include_faces=False):
"""Get number of nodes (modes) in *dim* Tetrahedron of *order*."""
Expand Down Expand Up @@ -211,7 +214,7 @@ def get_sim_timestep(

def write_visfile(dcoll, io_fields, visualizer, vizname,
step=0, t=0, overwrite=False, vis_timer=None,
comm=None):
comm: Optional["Comm"] = None):
"""Write parallel VTK output for the fields specified in *io_fields*.
This routine writes a parallel-compatible unstructured VTK visualization
Expand Down Expand Up @@ -1134,7 +1137,7 @@ def configurate(config_key, config_object=None, default_value=None):
return default_value


def get_reasonable_actx_class(lazy: bool = False, distributed: bool = True,
def get_reasonable_array_context_class(lazy: bool = False, distributed: bool = True,
profiling: bool = False):
if lazy and profiling:
raise ValueError("Can't specify both lazy and profiling")
Expand Down Expand Up @@ -1164,7 +1167,10 @@ def actx_class_is_profiling(actx_class) -> bool:
return issubclass(actx_class, PyOpenCLProfilingArrayContext)


def initialize_actx(actx_class, comm) -> Tuple[Any]:
def initialize_actx(actx_class: Type[ArrayContext], comm: Optional["Comm"]) -> Tuple[ArrayContext, cl.Context, cl.CommandQueue, cl.tools.AllocatorBase]:
from arraycontext import PytatoPyOpenCLArrayContext, PyOpenCLArrayContext
from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoPyOpenCLArrayContext

cl_ctx = cl.create_some_context()
if actx_class_is_profiling(actx_class):
queue = cl.CommandQueue(cl_ctx,
Expand All @@ -1175,16 +1181,21 @@ def initialize_actx(actx_class, comm) -> Tuple[Any]:
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if actx_class_is_lazy(actx_class):
assert issubclass(actx_class, PytatoPyOpenCLArrayContext)
if comm:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
assert issubclass(actx_class, MPIPytatoPyOpenCLArrayContext)
actx: ArrayContext = actx_class(mpi_communicator=comm, queue=queue, mpi_base_tag=12000, allocator=alloc) # type: ignore[call-arg]
else:
assert not issubclass(actx_class, MPIPytatoPyOpenCLArrayContext)
actx = actx_class(queue, allocator=alloc)
else:
assert actx_class_is_eager(actx_class)
assert issubclass(actx_class, PyOpenCLArrayContext)
if comm:
actx = actx_class(comm, queue, allocator=alloc,
force_device_scalars=True)
assert issubclass(actx_class, MPIPyOpenCLArrayContext)
actx = actx_class(mpi_communicator=comm, queue=queue, allocator=alloc,
force_device_scalars=True) # type: ignore[call-arg]
else:
assert not issubclass(actx_class, MPIPyOpenCLArrayContext)
actx = actx_class(queue, allocator=alloc, force_device_scalars=True)

return actx, cl_ctx, queue, alloc
Expand Down

0 comments on commit 8582216

Please sign in to comment.