diff --git a/autotest/test_prt_libmf6_budget.py b/autotest/test_prt_libmf6_budget.py new file mode 100644 index 00000000000..a2d6c53f809 --- /dev/null +++ b/autotest/test_prt_libmf6_budget.py @@ -0,0 +1,110 @@ +""" +This test runs the simulations in test_prt_budget.py, but uses the api +to run the PRT test. +""" + +from pathlib import Path + +import pytest +from framework import TestFramework +from modflowapi import ModflowApi +from test_prt_budget import ( + HorizontalCase, + build_mp7_sim, + build_prt_sim, + check_output, +) + +simname = "prt_libmf6" +cases = [simname] + + +def build_models(idx, test): + # build MODFLOW 6 files + ws = test.workspace + name = cases[idx] + + gwf_sim = HorizontalCase.get_gwf_sim( + test.name, test.workspace, test.targets["mf6"] + ) + prt_sim = build_prt_sim( + test.name, + test.workspace, + test.workspace / "prt", + test.targets["libmf6"], + ) + mp7_sim = build_mp7_sim( + test.name, + test.workspace / "mp7", + test.targets["mp7"], + gwf_sim.get_model(), + ) + + return gwf_sim, prt_sim, mp7_sim + + +def api_func(exe, idx, model_ws=None): + name = cases[idx].upper() + if model_ws is None: + model_ws = Path(".") + + output_file_path = model_ws / "mfsim.stdout" + + try: + mf6 = ModflowApi(exe, working_directory=model_ws) + except Exception as e: + print("Failed to load " + exe) + print("with message: " + str(e)) + return False, open(output_file_path).readlines() + + # initialize the model + try: + mf6.initialize() + except: + return False, open(output_file_path).readlines() + + # time loop + current_time = mf6.get_current_time() + end_time = mf6.get_end_time() + + # model time loop + idx = 0 + while current_time < end_time: + # get dt and prepare for non-linear iterations + dt = mf6.get_time_step() + mf6.prepare_time_step(dt) + + # convergence loop + kiter = 0 + mf6.prepare_solve() + has_converged = mf6.solve() + mf6.finalize_solve() + + # finalize time step and update time + mf6.finalize_time_step() + current_time = mf6.get_current_time() + + # increment counter + idx += 1 + + # cleanup + try: + mf6.finalize() + except: + return False, open(output_file_path).readlines() + + # cleanup and return + return True, open(output_file_path).readlines() + + +@pytest.mark.parametrize("idx, name", enumerate(cases)) +def test_mf6model(idx, name, function_tmpdir, targets): + test = TestFramework( + name=name, + workspace=function_tmpdir, + targets=targets, + build=lambda t: build_models(idx, t), + api_func=lambda exe, ws: api_func(exe, idx, ws), + check=lambda t: check_output(idx, t), + ) + test.run() diff --git a/src/Solution/BaseSolution.f90 b/src/Solution/BaseSolution.f90 index 69a16ada697..8f885bd4932 100644 --- a/src/Solution/BaseSolution.f90 +++ b/src/Solution/BaseSolution.f90 @@ -28,6 +28,11 @@ module BaseSolutionModule procedure(slnaddexchange), deferred :: add_exchange procedure(slngetmodels), deferred :: get_models procedure(slngetexchanges), deferred :: get_exchanges + + ! Expose these for use through the BMI/XMI: + procedure(prepareSolve), deferred :: prepareSolve + procedure(solve), deferred :: solve + procedure(finalizeSolve), deferred :: finalizeSolve end type BaseSolutionType abstract interface @@ -119,6 +124,27 @@ subroutine sln_da(this) class(BaseSolutionType) :: this end subroutine + subroutine prepareSolve(this) + import BaseSolutionType + class(BaseSolutionType) :: this + end subroutine prepareSolve + + subroutine solve(this, kiter) + use KindModule, only: I4B + import BaseSolutionType + class(BaseSolutionType) :: this + integer(I4B), intent(in) :: kiter + end subroutine solve + + subroutine finalizeSolve(this, kiter, isgcnvg, isuppress_output) + use KindModule, only: I4B + import BaseSolutionType + class(BaseSolutionType) :: this + integer(I4B), intent(in) :: kiter + integer(I4B), intent(inout) :: isgcnvg + integer(I4B), intent(in) :: isuppress_output + end subroutine finalizeSolve + end interface contains diff --git a/src/Solution/ExplicitSolution.f90 b/src/Solution/ExplicitSolution.f90 index 35fa3ee1f6f..d6b0d5b71a9 100644 --- a/src/Solution/ExplicitSolution.f90 +++ b/src/Solution/ExplicitSolution.f90 @@ -198,6 +198,9 @@ subroutine sln_ca(this, isgcnvg, isuppress_output) character(len=LINELENGTH) :: line character(len=LINELENGTH) :: fmt integer(I4B) :: im + integer(I4B) :: kiter + + kiter = 1 ! advance the models and solution call this%prepareSolve() @@ -213,10 +216,10 @@ subroutine sln_ca(this, isgcnvg, isuppress_output) case (MNORMAL) ! solve the models - call this%solve() + call this%solve(kiter) ! finish up - call this%finalizeSolve(isgcnvg, isuppress_output) + call this%finalizeSolve(kiter, isgcnvg, isuppress_output) end select end subroutine sln_ca @@ -241,9 +244,10 @@ end subroutine prepareSolve !> @ brief Solve each model !< - subroutine solve(this) + subroutine solve(this, kiter) ! -- dummy variables class(ExplicitSolutionType) :: this !< ExplicitSolutionType instance + integer(I4B), intent(in) :: kiter !< Picard iteration (1 for explicit) ! -- local variables class(NumericalModelType), pointer :: mp => null() integer(I4B) :: im @@ -260,9 +264,10 @@ end subroutine solve !> @ brief Finalize solve !< - subroutine finalizeSolve(this, isgcnvg, isuppress_output) + subroutine finalizeSolve(this, kiter, isgcnvg, isuppress_output) ! -- dummy variables class(ExplicitSolutionType) :: this !< ExplicitSolutionType instance + integer(I4B), intent(in) :: kiter !< Picard iteration number (always 1 for explicit) integer(I4B), intent(inout) :: isgcnvg !< solution group convergence flag integer(I4B), intent(in) :: isuppress_output !< flag for suppressing output ! -- local variables diff --git a/srcbmi/mf6bmiUtil.f90 b/srcbmi/mf6bmiUtil.f90 index 7b8a940fd94..b0ff1770aef 100644 --- a/srcbmi/mf6bmiUtil.f90 +++ b/srcbmi/mf6bmiUtil.f90 @@ -218,11 +218,11 @@ end function get_model_name function getSolution(subcomponent_idx) result(solution) ! -- modules use SolutionGroupModule - use NumericalSolutionModule + use BaseSolutionModule, only: BaseSolutionType, GetBaseSolutionFromList use ListsModule, only: basesolutionlist, solutiongrouplist ! -- dummy variables integer(I4B), intent(in) :: subcomponent_idx !< index of solution - class(NumericalSolutionType), pointer :: solution !< Numerical Solution + class(BaseSolutionType), pointer :: solution !< Base Solution ! -- local variables class(SolutionGroupType), pointer :: sgp integer(I4B) :: solutionIdx @@ -230,7 +230,7 @@ function getSolution(subcomponent_idx) result(solution) ! this is equivalent to how it's done in sgp_ca sgp => GetSolutionGroupFromList(solutiongrouplist, 1) solutionIdx = sgp%idsolutions(subcomponent_idx) - solution => GetNumericalSolutionFromList(basesolutionlist, solutionIdx) + solution => GetBaseSolutionFromList(basesolutionlist, solutionIdx) end function getSolution !> @brief Get the grid type for a named model as a fortran string diff --git a/srcbmi/mf6xmi.F90 b/srcbmi/mf6xmi.F90 index 8116eb117a4..9c476236054 100644 --- a/srcbmi/mf6xmi.F90 +++ b/srcbmi/mf6xmi.F90 @@ -222,13 +222,13 @@ function xmi_prepare_solve(subcomponent_idx) result(bmi_status) & !DIR$ ATTRIBUTES DLLEXPORT :: xmi_prepare_solve ! -- modules use ListsModule, only: solutiongrouplist - use NumericalSolutionModule + use BaseSolutionModule, only: BaseSolutionType use SimVariablesModule, only: istdout ! -- dummy variables integer(kind=c_int) :: subcomponent_idx !< index of the subcomponent (i.e. Numerical Solution) integer(kind=c_int) :: bmi_status !< BMI status code ! -- local variables - class(NumericalSolutionType), pointer :: ns + class(BaseSolutionType), pointer :: bs ! people might not call 'xmi_get_subcomponent_count' first, so let's repeat this: if (solutiongrouplist%Count() /= 1) then @@ -238,11 +238,11 @@ function xmi_prepare_solve(subcomponent_idx) result(bmi_status) & return end if - ! get the numerical solution we are running - ns => getSolution(subcomponent_idx) + ! get the solution we are running + bs => getSolution(subcomponent_idx) ! *_ad (model, exg, sln) - call ns%prepareSolve() + call bs%prepareSolve() ! reset counter allocate (iterationCounter) @@ -262,27 +262,34 @@ function xmi_solve(subcomponent_idx, has_converged) result(bmi_status) & bind(C, name="solve") !DIR$ ATTRIBUTES DLLEXPORT :: xmi_solve ! -- modules - use NumericalSolutionModule + use BaseSolutionModule, only: BaseSolutionType + use NumericalSolutionModule, only: NumericalSolutionType + use ExplicitSolutionModule, only: ExplicitSolutionType ! -- dummy variables integer(kind=c_int), intent(in) :: subcomponent_idx !< index of the subcomponent (i.e. Numerical Solution) integer(kind=c_int), intent(out) :: has_converged !< equal to 1 for convergence, 0 otherwise integer(kind=c_int) :: bmi_status !< BMI status code ! -- local variables - class(NumericalSolutionType), pointer :: ns + class(BaseSolutionType), pointer :: bs ! get the numerical solution we are running - ns => getSolution(subcomponent_idx) + bs => getSolution(subcomponent_idx) ! execute the nth iteration iterationCounter = iterationCounter + 1 - call ns%solve(iterationCounter) + call bs%solve(iterationCounter) ! the following check is equivalent to that in NumericalSolution%sln_ca - if (ns%icnvg == 1) then + select type (bs) + class is (NumericalSolutionType) + if (bs%icnvg == 1) then + has_converged = 1 + else + has_converged = 0 + end if + class is (ExplicitSolutionType) has_converged = 1 - else - has_converged = 0 - end if + end select bmi_status = BMI_SUCCESS @@ -300,23 +307,23 @@ function xmi_finalize_solve(subcomponent_idx) result(bmi_status) & bind(C, name="finalize_solve") !DIR$ ATTRIBUTES DLLEXPORT :: xmi_finalize_solve ! -- modules - use NumericalSolutionModule + use BaseSolutionModule, only: BaseSolutionType ! -- dummy variables integer(kind=c_int), intent(in) :: subcomponent_idx !< index of the subcomponent (i.e. Numerical Solution) integer(kind=c_int) :: bmi_status !< BMI status code ! -- local variables - class(NumericalSolutionType), pointer :: ns + class(BaseSolutionType), pointer :: bs integer(I4B) :: hasConverged ! get the numerical solution we are running - ns => getSolution(subcomponent_idx) + bs => getSolution(subcomponent_idx) ! hasConverged is equivalent to the isgcnvg variable which is initialized to 1, ! see the body of the picard loop in SolutionGroupType%sgp_ca hasConverged = 1 ! finish up - call ns%finalizeSolve(iterationCounter, hasConverged, 0) + call bs%finalizeSolve(iterationCounter, hasConverged, 0) ! check convergence on solution if (.not. hasConverged == 1) then