Skip to content

Commit

Permalink
add option to compile the full timestepper in advance_state
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Nov 22, 2024
1 parent 7f94b9c commit 89dd6e8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
7 changes: 6 additions & 1 deletion mirgecom/integrators/lsrk.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def lsrk_step(coefs, state, t, dt, rhs):

def euler_step(state, t, dt, rhs):
"""Take one step using the explicit, 1st-order accurate, Euler method."""
return lsrk_step(EulerCoefs, state, t, dt, rhs)
# Full-timestepper compilation doesn't like this version; triggers loop
# nest error in meshmode array context:
# NotImplementedError: Cannot fit loop nest 'frozenset()' into known
# set of loop-nest patterns.
# return lsrk_step(EulerCoefs, state, t, dt, rhs)
return state + dt * rhs(t, state)


LSRK54CarpenterKennedyCoefs = LSRKCoefficients(
Expand Down
32 changes: 23 additions & 9 deletions mirgecom/steppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"""

import numpy as np
from functools import partial
from mirgecom.utils import force_evaluation
from pytools import memoize_in
from arraycontext import get_container_context_recursively_opt
Expand All @@ -39,9 +40,9 @@ def _compile_timestepper(actx, timestepper, rhs):
@memoize_in(actx, ("mirgecom_compiled_operator",
timestepper, rhs))
def get_timestepper():
return actx.compile(lambda y, t, dt: timestepper(state=y, t=t,
dt=dt,
rhs=rhs))
return actx.compile(
lambda state, t, dt: timestepper(
state=state, t=t, dt=dt, rhs=rhs))

return get_timestepper()

Expand Down Expand Up @@ -76,7 +77,8 @@ def _is_unevaluated(actx, ary):
def _advance_state_stepper_func(rhs, timestepper, state, t_final, dt=0,
t=0.0, istep=0, pre_step_callback=None,
post_step_callback=None, force_eval=None,
local_dt=False, max_steps=None, compile_rhs=True):
local_dt=False, max_steps=None, compile_rhs=True,
compile_timestepper=False):
"""Advance state from some time (t) to some time (t_final).
Parameters
Expand Down Expand Up @@ -119,6 +121,8 @@ def _advance_state_stepper_func(rhs, timestepper, state, t_final, dt=0,
the domain.
compile_rhs
An optional boolean indicating whether *rhs* can be compiled.
compile_timestepper
An optional boolean indicating whether *timestepper* can be compiled.
Returns
-------
Expand All @@ -145,10 +149,14 @@ def _advance_state_stepper_func(rhs, timestepper, state, t_final, dt=0,

state = force_evaluation(actx, state)

if compile_rhs:
maybe_compiled_rhs = _compile_rhs(actx, rhs)
if compile_timestepper:
maybe_compiled_timestepper = _compile_timestepper(actx, timestepper, rhs)
else:
maybe_compiled_rhs = rhs
if compile_rhs:
maybe_compiled_rhs = _compile_rhs(actx, rhs)
else:
maybe_compiled_rhs = rhs
maybe_compiled_timestepper = partial(timestepper, rhs=maybe_compiled_rhs)

while marching_loc < marching_limit:
if max_steps is not None:
Expand All @@ -161,7 +169,7 @@ def _advance_state_stepper_func(rhs, timestepper, state, t_final, dt=0,
if force_eval:
state = force_evaluation(actx, state)

state = timestepper(state=state, t=t, dt=dt, rhs=maybe_compiled_rhs)
state = maybe_compiled_timestepper(state=state, t=t, dt=dt)

if force_eval is None:
if _is_unevaluated(actx, state):
Expand Down Expand Up @@ -350,7 +358,7 @@ def generate_singlerate_leap_advancer(timestepper, component_id, rhs, t, dt,
def advance_state(rhs, timestepper, state, t_final, t=0, istep=0, dt=0,
max_steps=None, component_id="state", pre_step_callback=None,
post_step_callback=None, force_eval=None, local_dt=False,
compile_rhs=True):
compile_rhs=True, compile_timestepper=False):
"""Determine what stepper to use and advance the state from (t) to (t_final).
Parameters
Expand Down Expand Up @@ -399,6 +407,8 @@ def advance_state(rhs, timestepper, state, t_final, t=0, istep=0, dt=0,
the domain.
compile_rhs
An optional boolean indicating whether *rhs* can be compiled.
compile_timestepper
An optional boolean indicating whether *timestepper* can be compiled.
Returns
-------
Expand All @@ -424,6 +434,9 @@ def advance_state(rhs, timestepper, state, t_final, t=0, istep=0, dt=0,
raise ValueError("Local timestepping is not supported for Leap-based"
" integrators.")
if leap_timestepper:
if compile_timestepper:
raise ValueError(
"Leap timestepper is not compatible with compile_timestepper=True")
(current_step, current_t, current_state) = \
_advance_state_leap(
rhs=rhs, timestepper=timestepper,
Expand All @@ -444,6 +457,7 @@ def advance_state(rhs, timestepper, state, t_final, t=0, istep=0, dt=0,
istep=istep, force_eval=force_eval,
max_steps=max_steps, local_dt=local_dt,
compile_rhs=compile_rhs,
compile_timestepper=compile_timestepper,
)

return current_step, current_t, current_state

0 comments on commit 89dd6e8

Please sign in to comment.