Skip to content

Commit

Permalink
add ARKStepCreateAdjointSolver routine
Browse files Browse the repository at this point in the history
  • Loading branch information
balos1 committed May 14, 2024
1 parent 9753833 commit c6c9900
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 14 deletions.
4 changes: 4 additions & 0 deletions include/arkode/arkode.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <arkode/arkode_butcher.h>
#include <stdio.h>
#include <sundials/sundials_core.h>
#include <sunadjoint/sunadjoint_solver.h>

#ifdef __cplusplus /* wrapper to enable C++ usage */
extern "C" {
Expand Down Expand Up @@ -275,6 +276,9 @@ SUNDIALS_EXPORT int ARKodeSetPostprocessStageFn(void* arkode_mem,
ARKPostProcessFn ProcessStage);
SUNDIALS_EXPORT int ARKodeSetStagePredictFn(void* arkode_mem,
ARKStagePredictFn PredictStage);
SUNDIALS_EXPORT
int ARKodeSetCheckpointScheme(void* arkode_mem,
SUNAdjointCheckpointScheme checkpoint_scheme);

/* Integrate the ODE over an interval in t */
SUNDIALS_EXPORT int ARKodeEvolve(void* arkode_mem, sunrealtype tout,
Expand Down
6 changes: 6 additions & 0 deletions include/arkode/arkode_arkstep.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <sunadaptcontroller/sunadaptcontroller_imexgus.h>
#include <sunadaptcontroller/sunadaptcontroller_soderlind.h>
#include <sundials/sundials_stepper.h>
#include <sunadjoint/sunadjoint_solver.h>

#ifdef __cplusplus /* wrapper to enable C++ usage */
extern "C" {
Expand Down Expand Up @@ -435,6 +436,11 @@ SUNDIALS_EXPORT int ARKStepCreateMRIStepInnerStepper(void* arkode_mem,
SUNDIALS_EXPORT int ARKStepCreateSUNStepper(void* arkode_mem,
SUNStepper* stepper);

/* Adjoint solver functions */
SUNDIALS_EXPORT
int ARKStepCreateAdjointSolver(void* arkode_mem, sunindextype num_cost,
N_Vector sf, SUNAdjointSolver* adj_solver_ptr);

/* Relaxation functions */
SUNDIALS_DEPRECATED_EXPORT_MSG("use ARKodeSetRelaxFn instead")
int ARKStepSetRelaxFn(void* arkode_mem, ARKRelaxFn rfn, ARKRelaxJacFn rjac);
Expand Down
6 changes: 3 additions & 3 deletions include/sunadjoint/sunadjoint_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ typedef struct SUNAdjointSolver_* SUNAdjointSolver;
extern "C" {
#endif

// IDEA: In lieu of Stepper_ID each package that supports adjoint can have a function that creates the adjoint solver.
// E.g., SUNAdjointSolver ARKStepCreateAdjointSolver();
// TODO(CJB): I think this should be a private function that is only used
// within the package CreateAdjointSolver routines.
SUNDIALS_EXPORT
SUNErrCode SUNAdjointSolver_Create(SUNStepper stepper,
sunindextype num_cost_fns, N_Vector sf,
Expand All @@ -52,7 +52,7 @@ SUNErrCode SUNAdjointSolver_Create(SUNStepper stepper,
Solves the adjoint system.
:param adj_solver: The adjoint solver object.
:param tf: The final output time from the forward integration.
:param tf: The final output time from the forward integration.
This is the "starting" time for adjoint solver's backwards integration.
:param tout: The time at which the adjoint solution is desired.
:param sens: The vector of sensitivity solutions dg/dy0 and dg/dp.
Expand Down
45 changes: 45 additions & 0 deletions src/arkode/arkode_arkstep.c
Original file line number Diff line number Diff line change
Expand Up @@ -3247,6 +3247,51 @@ int arkStep_SUNStepperReset(SUNStepper stepper, sunrealtype tR, N_Vector yR)
return (ARKodeReset(arkode_mem, tR, yR));
}

/*---------------------------------------------------------------
Utility routines for interfacing with SUNAdjointSolver
---------------------------------------------------------------*/

int arkStep_fe_Adj(sunrealtype t, N_Vector y, N_Vector ydot, void* user_data)
{
return 0;
}

int arkStep_fi_Adj(sunrealtype t, N_Vector y, N_Vector ydot, void* user_data)
{
return 0;
}

int ARKStepCreateAdjointSolver(void* arkode_mem, sunindextype num_cost,
N_Vector sf, SUNAdjointSolver* adj_solver_ptr)
{
ARKodeMem ark_mem;
ARKodeARKStepMem step_mem;
int retval = arkStep_AccessARKODEStepMem(arkode_mem,
"ARKStepCreateAdjointSolver",
&ark_mem, &step_mem);
if (retval)
{
arkProcessError(NULL, ARK_ILL_INPUT, __LINE__, __func__, __FILE__,
"The ARKStep memory pointer is NULL");
return ARK_ILL_INPUT;
}

// TODO(CJB): should we reinit to tretlast or tcur? tcur could be past the time the
// user asked for in the forward integration if they do not use tstop mode.
ARKodeResize(arkode_mem, sf, -ONE, ark_mem->tretlast, NULL, NULL);
ARKRhsFn fe_adj = step_mem->fe ? arkStep_fe_Adj : NULL;
ARKRhsFn fi_adj = step_mem->fi ? arkStep_fi_Adj: NULL;
ARKStepReInit(arkode_mem, fe_adj, fi_adj, ark_mem->tretlast, sf);

// SUNAdjointSolver will own the SUNStepper and destroy it
SUNStepper stepper;
ARKStepCreateSUNStepper(arkode_mem, &stepper);
SUNAdjointSolver_Create(stepper, num_cost, sf, ark_mem->checkpoint_scheme,
ark_mem->sunctx, adj_solver_ptr);

return ARK_SUCCESS;
}

/*---------------------------------------------------------------
Utility routines for interfacing with MRIStep
---------------------------------------------------------------*/
Expand Down
5 changes: 5 additions & 0 deletions src/arkode/arkode_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <sundials/sundials_adaptcontroller.h>
#include <sundials/sundials_context.h>
#include <sundials/sundials_linearsolver.h>
#include <sunadjoint/sunadjoint_checkpointscheme.h>
#include <sunadjoint/sunadjoint_solver.h>

#include "arkode_adapt_impl.h"
#include "arkode_relaxation_impl.h"
Expand Down Expand Up @@ -488,6 +490,9 @@ struct ARKodeMemRec

sunbooleantype use_compensated_sums;

/* Adjoint solver data */
SUNAdjointCheckpointScheme checkpoint_scheme;

/* XBraid interface variables */
sunbooleantype force_pass; /* when true the step attempt loop will ignore the
return value (kflag) from arkCheckTemporalError
Expand Down
16 changes: 16 additions & 0 deletions src/arkode/arkode_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -2219,6 +2219,22 @@ int ARKodeSetMaxConvFails(void* arkode_mem, int maxncf)
return (ARK_SUCCESS);
}

int ARKodeSetCheckpointScheme(void* arkode_mem, SUNAdjointCheckpointScheme checkpoint_scheme)
{
ARKodeMem ark_mem;
if (arkode_mem == NULL)
{
arkProcessError(NULL, ARK_MEM_NULL, __LINE__, __func__, __FILE__,
MSG_ARK_NO_MEM);
return (ARK_MEM_NULL);
}
ark_mem = (ARKodeMem)arkode_mem;

ark_mem->checkpoint_scheme = checkpoint_scheme;

return (ARK_SUCCESS);
}

/*===============================================================
ARKODE optional output utility functions
===============================================================*/
Expand Down
1 change: 1 addition & 0 deletions src/sunadjoint/sunadjoint_solver.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ SUNErrCode SUNAdjointSolver_Destroy(SUNAdjointSolver* adj_solver_ptr)
{
SUNAdjointSolver adj_solver = *adj_solver_ptr;
// SUNAdjointCheckpointScheme_Destroy(adj_solver->checkpoint_scheme);
SUNStepper_Destroy(&adj_solver->stepper);
free(adj_solver);
*adj_solver_ptr = NULL;
return SUN_SUCCESS;
Expand Down
16 changes: 5 additions & 11 deletions test/unit_tests/arkode/C_serial/ark_test_sunadjoint.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,9 @@ int adjoint_solution(SUNContext sunctx, void* arkode_mem,
// TODO(CJB): Load sf with the sensitivity terminal conditions
N_VConst(0.0, sf);

// TODO(CJB): this block of code needs to be less complicated, should wrap it up in something like ARKStepCreateAdjointSolver()
SUNAdjointSolver adj_solver;
ARKStepCreateAdjointSolver(arkode_mem, num_cost, sf, &adj_solver);
// lotka_volterra_adjoint is J*lambda - user will provide J, we internally will create RHS that is J*lambda
ARKodeResize(arkode_mem, sf, 1.0, tf, NULL, NULL);
ARKStepReInit(arkode_mem, lotka_volterra_adjoint, NULL, tf, sf);
SUNStepper stepper = NULL;
ARKStepCreateSUNStepper(arkode_mem, &stepper);
SUNAdjointSolver adj_solver = NULL;
SUNAdjointSolver_Create(stepper, num_cost, sf, checkpoint_scheme, sunctx,
&adj_solver);
// SUNAdjointSolver_SetJacFn(adj_solver, );
// SUNAdjointSolver_SetJacPFn(adj_solver, );

Expand All @@ -127,8 +121,9 @@ int adjoint_solution(SUNContext sunctx, void* arkode_mem,
N_VPrint(sf);

N_VDestroy(sf);
SUNStepper_Destroy(&stepper);
SUNAdjointSolver_Destroy(&adj_solver);

return 0;
}

int main(int argc, char* argv[])
Expand All @@ -154,8 +149,7 @@ int main(int argc, char* argv[])

// Enable checkpointing during the forward solution
SUNAdjointCheckpointScheme checkpoint_scheme = NULL;
// SUNAdjointCheckpointScheme_NewEmpty(sunctx, &checkpoint_scheme);
// ARKodeSetCheckpointScheme(arkode_mem, checkpoint_scheme);
ARKodeSetCheckpointScheme(arkode_mem, checkpoint_scheme);

//
// Compute the forward solution
Expand Down

0 comments on commit c6c9900

Please sign in to comment.