Skip to content

Commit

Permalink
factor out actx init (#918)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jun 30, 2023
1 parent c8dae8b commit 27130a3
Show file tree
Hide file tree
Showing 24 changed files with 396 additions and 506 deletions.
3 changes: 3 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -- Project information -----------------------------------------------------
import sys

project = "mirgecom"
copyright = ("2020, University of Illinois Board of Trustees")
Expand Down Expand Up @@ -97,3 +98,5 @@
nitpick_ignore_regex = [
("py:class", r".*BoundaryDomainTag.*")
]

sys._BUILDING_SPHINX_DOCS = True
1 change: 1 addition & 0 deletions doc/support/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Random Pile'o'Tools

.. automodule:: mirgecom.simutil
.. automodule:: mirgecom.utils
.. automodule:: mirgecom.array_context
35 changes: 12 additions & 23 deletions examples/autoignition-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"""
import logging
import numpy as np
import pyopencl as cl
from functools import partial

from meshmode.mesh import BTAG_ALL
Expand Down Expand Up @@ -75,13 +74,11 @@ class MyRuntimeError(RuntimeError):


@mpi_entry_point
def main(actx_class, ctx_factory=cl.create_some_context, use_logmgr=True,
use_leap=False, use_overintegration=False, use_profiling=False,
casename=None, lazy=False, rst_filename=None, log_dependent=True,
def main(actx_class, use_logmgr=True,
use_leap=False, use_overintegration=False,
casename=None, rst_filename=None, log_dependent=True,
viscous_terms_on=False):
"""Drive example."""
cl_ctx = ctx_factory()

if casename is None:
casename = "mirgecom"

Expand All @@ -96,19 +93,10 @@ def main(actx_class, ctx_factory=cl.create_some_context, use_logmgr=True,
logmgr = initialize_logmgr(use_logmgr,
filename=f"{casename}.sqlite", mode="wu", mpi_comm=comm)

if use_profiling:
queue = cl.CommandQueue(cl_ctx,
properties=cl.command_queue_properties.PROFILING_ENABLE)
else:
queue = cl.CommandQueue(cl_ctx)

from mirgecom.simutil import get_reasonable_memory_pool
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
else:
actx = actx_class(comm, queue, allocator=alloc, force_device_scalars=True)
from mirgecom.array_context import initialize_actx, actx_class_is_profiling
actx = initialize_actx(actx_class, comm)
queue = getattr(actx, "queue", None)
use_profiling = actx_class_is_profiling(actx_class)

# Some discretization parameters
dim = 2
Expand Down Expand Up @@ -690,8 +678,9 @@ def my_rhs(t, state):
if lazy:
raise ValueError("Can't use lazy and profiling together.")

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(
lazy=lazy, distributed=True, profiling=args.profiling)

logging.basicConfig(format="%(message)s", level=logging.INFO)
if args.casename:
Expand All @@ -701,8 +690,8 @@ def my_rhs(t, state):
rst_filename = args.restart_file

main(actx_class, use_logmgr=args.log, use_leap=args.leap,
use_overintegration=args.overintegration, use_profiling=args.profiling,
lazy=lazy, casename=casename, rst_filename=rst_filename,
use_overintegration=args.overintegration,
casename=casename, rst_filename=rst_filename,
log_dependent=log_dependent, viscous_terms_on=args.navierstokes)

# vim: foldmethod=marker
33 changes: 10 additions & 23 deletions examples/combozzle-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import time
import yaml
import numpy as np
import pyopencl as cl
from functools import partial

from meshmode.array_context import PyOpenCLArrayContext
Expand Down Expand Up @@ -160,15 +159,12 @@ def __call__(self, x_vec, *, time=0.0):


@mpi_entry_point
def main(ctx_factory=cl.create_some_context, use_logmgr=True,
use_leap=False, use_overintegration=False,
use_profiling=False, casename=None, lazy=False,
def main(use_logmgr=True,
use_overintegration=False, casename=None,
rst_filename=None, actx_class=PyOpenCLArrayContext,
log_dependent=False, input_file=None,
force_eval=True):
"""Drive example."""
cl_ctx = ctx_factory()

if casename is None:
casename = "mirgecom"

Expand Down Expand Up @@ -600,19 +596,10 @@ def main(ctx_factory=cl.create_some_context, use_logmgr=True,
print(f"ACTX setup start: {time.ctime(time.time())}")
comm.Barrier()

if use_profiling:
queue = cl.CommandQueue(cl_ctx,
properties=cl.command_queue_properties.PROFILING_ENABLE)
else:
queue = cl.CommandQueue(cl_ctx)

from mirgecom.simutil import get_reasonable_memory_pool
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
else:
actx = actx_class(comm, queue, allocator=alloc, force_device_scalars=True)
from mirgecom.array_context import initialize_actx, actx_class_is_profiling
actx = initialize_actx(actx_class, comm)
queue = getattr(actx, "queue", None)
use_profiling = actx_class_is_profiling(actx_class)

rst_path = "restart_data/"
rst_pattern = (
Expand Down Expand Up @@ -1287,8 +1274,9 @@ def dummy_rhs(t, state):
if lazy:
raise ValueError("Can't use lazy and profiling together.")

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(
lazy=lazy, distributed=True, profiling=args.profiling)

logging.basicConfig(format="%(message)s", level=logging.INFO)
if args.casename:
Expand All @@ -1306,9 +1294,8 @@ def dummy_rhs(t, state):

print(f"Calling main: {time.ctime(time.time())}")

main(use_logmgr=args.log, use_leap=args.leap, input_file=input_file,
main(use_logmgr=args.log, input_file=input_file,
use_overintegration=args.overintegration,
use_profiling=args.profiling, lazy=lazy,
casename=casename, rst_filename=rst_filename, actx_class=actx_class,
log_dependent=log_dependent, force_eval=force_eval)

Expand Down
35 changes: 12 additions & 23 deletions examples/doublemach-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import logging
import numpy as np
import pyopencl as cl
from functools import partial

from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa
Expand Down Expand Up @@ -116,15 +115,13 @@ def get_doublemach_mesh():


@mpi_entry_point
def main(ctx_factory=cl.create_some_context, use_logmgr=True,
use_leap=False, use_profiling=False, use_overintegration=False,
casename=None, rst_filename=None, actx_class=None, lazy=False):
def main(use_logmgr=True,
use_leap=False, use_overintegration=False,
casename=None, rst_filename=None, actx_class=None):
"""Drive the example."""
if actx_class is None:
raise RuntimeError("Array context class missing.")

cl_ctx = ctx_factory()

if casename is None:
casename = "mirgecom"

Expand All @@ -136,19 +133,10 @@ def main(ctx_factory=cl.create_some_context, use_logmgr=True,
logmgr = initialize_logmgr(use_logmgr,
filename=f"{casename}.sqlite", mode="wu", mpi_comm=comm)

if use_profiling:
queue = cl.CommandQueue(
cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)
else:
queue = cl.CommandQueue(cl_ctx)

from mirgecom.simutil import get_reasonable_memory_pool
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
else:
actx = actx_class(comm, queue, allocator=alloc, force_device_scalars=True)
from mirgecom.array_context import initialize_actx, actx_class_is_profiling
actx = initialize_actx(actx_class, comm)
queue = getattr(actx, "queue", None)
use_profiling = actx_class_is_profiling(actx_class)

# Timestepping control
current_step = 0
Expand Down Expand Up @@ -461,8 +449,9 @@ def my_rhs(t, state):
if lazy:
raise ValueError("Can't use lazy and profiling together.")

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True,
profiling=args.profiling)

logging.basicConfig(format="%(message)s", level=logging.INFO)
if args.casename:
Expand All @@ -471,8 +460,8 @@ def my_rhs(t, state):
if args.restart_file:
rst_filename = args.restart_file

main(use_logmgr=args.log, use_leap=args.leap, use_profiling=args.profiling,
use_overintegration=args.overintegration, lazy=lazy,
main(use_logmgr=args.log, use_leap=args.leap,
use_overintegration=args.overintegration,
casename=casename, rst_filename=rst_filename, actx_class=actx_class)

# vim: foldmethod=marker
35 changes: 12 additions & 23 deletions examples/doublemach_physical_av-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import logging
import numpy as np
import pyopencl as cl
from functools import partial

from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa
Expand Down Expand Up @@ -120,15 +119,13 @@ def get_doublemach_mesh():


@mpi_entry_point
def main(ctx_factory=cl.create_some_context, use_logmgr=True,
use_leap=False, use_profiling=False, use_overintegration=False,
casename=None, rst_filename=None, actx_class=None, lazy=False):
def main(use_logmgr=True,
use_leap=False, use_overintegration=False,
casename=None, rst_filename=None, actx_class=None):
"""Drive the example."""
if actx_class is None:
raise RuntimeError("Array context class missing.")

cl_ctx = ctx_factory()

if casename is None:
casename = "mirgecom"

Expand All @@ -151,19 +148,10 @@ def main(ctx_factory=cl.create_some_context, use_logmgr=True,
logmgr = initialize_logmgr(use_logmgr,
filename=logname, mode="wo", mpi_comm=comm)

if use_profiling:
queue = cl.CommandQueue(
cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)
else:
queue = cl.CommandQueue(cl_ctx)

from mirgecom.simutil import get_reasonable_memory_pool
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
else:
actx = actx_class(comm, queue, allocator=alloc, force_device_scalars=True)
from mirgecom.array_context import initialize_actx, actx_class_is_profiling
actx = initialize_actx(actx_class, comm)
queue = getattr(actx, "queue", None)
use_profiling = actx_class_is_profiling(actx_class)

# Timestepping control
current_step = 0
Expand Down Expand Up @@ -732,9 +720,10 @@ def _my_rhs_phys_visc_div_av(t, state):
if lazy:
raise ValueError("Can't use lazy and profiling together.")

from grudge.array_context import get_reasonable_array_context_class
from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy,
distributed=True)
distributed=True,
profiling=args.profiling)

logging.basicConfig(format="%(message)s", level=logging.INFO)
if args.casename:
Expand All @@ -743,8 +732,8 @@ def _my_rhs_phys_visc_div_av(t, state):
if args.restart_file:
rst_filename = args.restart_file

main(use_logmgr=args.log, use_leap=args.leap, use_profiling=args.profiling,
use_overintegration=args.overintegration, lazy=lazy,
main(use_logmgr=args.log, use_leap=args.leap,
use_overintegration=args.overintegration,
casename=casename, rst_filename=rst_filename, actx_class=actx_class)

# vim: foldmethod=marker
35 changes: 11 additions & 24 deletions examples/heat-source-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import numpy as np
import numpy.linalg as la # noqa
import pyopencl as cl

from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa
import grudge.op as op
Expand All @@ -48,33 +47,20 @@


@mpi_entry_point
def main(actx_class, ctx_factory=cl.create_some_context, use_logmgr=True,
use_leap=False, use_profiling=False, casename=None, lazy=False,
rst_filename=None):
def main(actx_class, use_logmgr=True,
use_leap=False, casename=None, rst_filename=None):
"""Run the example."""
cl_ctx = cl.create_some_context()
queue = cl.CommandQueue(cl_ctx)

from mpi4py import MPI
comm = MPI.COMM_WORLD
num_parts = comm.Get_size()

logmgr = initialize_logmgr(use_logmgr,
filename="heat-source.sqlite", mode="wu", mpi_comm=comm)

if use_profiling:
queue = cl.CommandQueue(
cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)
else:
queue = cl.CommandQueue(cl_ctx)

from mirgecom.simutil import get_reasonable_memory_pool
alloc = get_reasonable_memory_pool(cl_ctx, queue)

if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000, allocator=alloc)
else:
actx = actx_class(comm, queue, allocator=alloc, force_device_scalars=True)
from mirgecom.array_context import initialize_actx, actx_class_is_profiling
actx = initialize_actx(actx_class, comm)
queue = getattr(actx, "queue", None)
use_profiling = actx_class_is_profiling(actx_class)

from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
mesh_dist = MPIMeshDistributor(comm)
Expand Down Expand Up @@ -208,8 +194,9 @@ def rhs(t, u):
if lazy:
raise ValueError("Can't use lazy and profiling together.")

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(
lazy=lazy, distributed=True, profiling=args.profiling)

logging.basicConfig(format="%(message)s", level=logging.INFO)
if args.casename:
Expand All @@ -218,7 +205,7 @@ def rhs(t, u):
if args.restart_file:
rst_filename = args.restart_file

main(actx_class, use_logmgr=args.log, use_leap=args.leap, lazy=lazy,
use_profiling=args.profiling, casename=casename, rst_filename=rst_filename)
main(actx_class, use_logmgr=args.log, use_leap=args.leap,
casename=casename, rst_filename=rst_filename)

# vim: foldmethod=marker
Loading

0 comments on commit 27130a3

Please sign in to comment.