Skip to content

Commit

Permalink
fix(bmi-ems): ExplicitModelSolution for PRT fixed for XMI/API calls (M…
Browse files Browse the repository at this point in the history
…ODFLOW-USGS#1962)

* fix(bmi-ems): ExplicitModelSolution for PRT fixed to work with XMI/API calls

* fprettify

* add test to make sure PRT runs with API

* ruff

* more ruff
  • Loading branch information
langevin-usgs authored Jul 26, 2024
1 parent ae6cc66 commit c178700
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 24 deletions.
110 changes: 110 additions & 0 deletions autotest/test_prt_libmf6_budget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
This test runs the simulations in test_prt_budget.py, but uses the api
to run the PRT test.
"""

from pathlib import Path

import pytest
from framework import TestFramework
from modflowapi import ModflowApi
from test_prt_budget import (
HorizontalCase,
build_mp7_sim,
build_prt_sim,
check_output,
)

simname = "prt_libmf6"
cases = [simname]


def build_models(idx, test):
# build MODFLOW 6 files
ws = test.workspace
name = cases[idx]

gwf_sim = HorizontalCase.get_gwf_sim(
test.name, test.workspace, test.targets["mf6"]
)
prt_sim = build_prt_sim(
test.name,
test.workspace,
test.workspace / "prt",
test.targets["libmf6"],
)
mp7_sim = build_mp7_sim(
test.name,
test.workspace / "mp7",
test.targets["mp7"],
gwf_sim.get_model(),
)

return gwf_sim, prt_sim, mp7_sim


def api_func(exe, idx, model_ws=None):
name = cases[idx].upper()
if model_ws is None:
model_ws = Path(".")

output_file_path = model_ws / "mfsim.stdout"

try:
mf6 = ModflowApi(exe, working_directory=model_ws)
except Exception as e:
print("Failed to load " + exe)
print("with message: " + str(e))
return False, open(output_file_path).readlines()

# initialize the model
try:
mf6.initialize()
except:
return False, open(output_file_path).readlines()

# time loop
current_time = mf6.get_current_time()
end_time = mf6.get_end_time()

# model time loop
idx = 0
while current_time < end_time:
# get dt and prepare for non-linear iterations
dt = mf6.get_time_step()
mf6.prepare_time_step(dt)

# convergence loop
kiter = 0
mf6.prepare_solve()
has_converged = mf6.solve()
mf6.finalize_solve()

# finalize time step and update time
mf6.finalize_time_step()
current_time = mf6.get_current_time()

# increment counter
idx += 1

# cleanup
try:
mf6.finalize()
except:
return False, open(output_file_path).readlines()

# cleanup and return
return True, open(output_file_path).readlines()


@pytest.mark.parametrize("idx, name", enumerate(cases))
def test_mf6model(idx, name, function_tmpdir, targets):
test = TestFramework(
name=name,
workspace=function_tmpdir,
targets=targets,
build=lambda t: build_models(idx, t),
api_func=lambda exe, ws: api_func(exe, idx, ws),
check=lambda t: check_output(idx, t),
)
test.run()
26 changes: 26 additions & 0 deletions src/Solution/BaseSolution.f90
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ module BaseSolutionModule
procedure(slnaddexchange), deferred :: add_exchange
procedure(slngetmodels), deferred :: get_models
procedure(slngetexchanges), deferred :: get_exchanges

! Expose these for use through the BMI/XMI:
procedure(prepareSolve), deferred :: prepareSolve
procedure(solve), deferred :: solve
procedure(finalizeSolve), deferred :: finalizeSolve
end type BaseSolutionType

abstract interface
Expand Down Expand Up @@ -119,6 +124,27 @@ subroutine sln_da(this)
class(BaseSolutionType) :: this
end subroutine

subroutine prepareSolve(this)
import BaseSolutionType
class(BaseSolutionType) :: this
end subroutine prepareSolve

subroutine solve(this, kiter)
use KindModule, only: I4B
import BaseSolutionType
class(BaseSolutionType) :: this
integer(I4B), intent(in) :: kiter
end subroutine solve

subroutine finalizeSolve(this, kiter, isgcnvg, isuppress_output)
use KindModule, only: I4B
import BaseSolutionType
class(BaseSolutionType) :: this
integer(I4B), intent(in) :: kiter
integer(I4B), intent(inout) :: isgcnvg
integer(I4B), intent(in) :: isuppress_output
end subroutine finalizeSolve

end interface

contains
Expand Down
13 changes: 9 additions & 4 deletions src/Solution/ExplicitSolution.f90
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ subroutine sln_ca(this, isgcnvg, isuppress_output)
character(len=LINELENGTH) :: line
character(len=LINELENGTH) :: fmt
integer(I4B) :: im
integer(I4B) :: kiter

kiter = 1

! advance the models and solution
call this%prepareSolve()
Expand All @@ -213,10 +216,10 @@ subroutine sln_ca(this, isgcnvg, isuppress_output)
case (MNORMAL)

! solve the models
call this%solve()
call this%solve(kiter)

! finish up
call this%finalizeSolve(isgcnvg, isuppress_output)
call this%finalizeSolve(kiter, isgcnvg, isuppress_output)
end select
end subroutine sln_ca

Expand All @@ -241,9 +244,10 @@ end subroutine prepareSolve

!> @ brief Solve each model
!<
subroutine solve(this)
subroutine solve(this, kiter)
! -- dummy variables
class(ExplicitSolutionType) :: this !< ExplicitSolutionType instance
integer(I4B), intent(in) :: kiter !< Picard iteration (1 for explicit)
! -- local variables
class(NumericalModelType), pointer :: mp => null()
integer(I4B) :: im
Expand All @@ -260,9 +264,10 @@ end subroutine solve

!> @ brief Finalize solve
!<
subroutine finalizeSolve(this, isgcnvg, isuppress_output)
subroutine finalizeSolve(this, kiter, isgcnvg, isuppress_output)
! -- dummy variables
class(ExplicitSolutionType) :: this !< ExplicitSolutionType instance
integer(I4B), intent(in) :: kiter !< Picard iteration number (always 1 for explicit)
integer(I4B), intent(inout) :: isgcnvg !< solution group convergence flag
integer(I4B), intent(in) :: isuppress_output !< flag for suppressing output
! -- local variables
Expand Down
6 changes: 3 additions & 3 deletions srcbmi/mf6bmiUtil.f90
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,19 @@ end function get_model_name
function getSolution(subcomponent_idx) result(solution)
! -- modules
use SolutionGroupModule
use NumericalSolutionModule
use BaseSolutionModule, only: BaseSolutionType, GetBaseSolutionFromList
use ListsModule, only: basesolutionlist, solutiongrouplist
! -- dummy variables
integer(I4B), intent(in) :: subcomponent_idx !< index of solution
class(NumericalSolutionType), pointer :: solution !< Numerical Solution
class(BaseSolutionType), pointer :: solution !< Base Solution
! -- local variables
class(SolutionGroupType), pointer :: sgp
integer(I4B) :: solutionIdx

! this is equivalent to how it's done in sgp_ca
sgp => GetSolutionGroupFromList(solutiongrouplist, 1)
solutionIdx = sgp%idsolutions(subcomponent_idx)
solution => GetNumericalSolutionFromList(basesolutionlist, solutionIdx)
solution => GetBaseSolutionFromList(basesolutionlist, solutionIdx)
end function getSolution

!> @brief Get the grid type for a named model as a fortran string
Expand Down
41 changes: 24 additions & 17 deletions srcbmi/mf6xmi.F90
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,13 @@ function xmi_prepare_solve(subcomponent_idx) result(bmi_status) &
!DIR$ ATTRIBUTES DLLEXPORT :: xmi_prepare_solve
! -- modules
use ListsModule, only: solutiongrouplist
use NumericalSolutionModule
use BaseSolutionModule, only: BaseSolutionType
use SimVariablesModule, only: istdout
! -- dummy variables
integer(kind=c_int) :: subcomponent_idx !< index of the subcomponent (i.e. Numerical Solution)
integer(kind=c_int) :: bmi_status !< BMI status code
! -- local variables
class(NumericalSolutionType), pointer :: ns
class(BaseSolutionType), pointer :: bs

! people might not call 'xmi_get_subcomponent_count' first, so let's repeat this:
if (solutiongrouplist%Count() /= 1) then
Expand All @@ -238,11 +238,11 @@ function xmi_prepare_solve(subcomponent_idx) result(bmi_status) &
return
end if

! get the numerical solution we are running
ns => getSolution(subcomponent_idx)
! get the solution we are running
bs => getSolution(subcomponent_idx)

! *_ad (model, exg, sln)
call ns%prepareSolve()
call bs%prepareSolve()

! reset counter
allocate (iterationCounter)
Expand All @@ -262,27 +262,34 @@ function xmi_solve(subcomponent_idx, has_converged) result(bmi_status) &
bind(C, name="solve")
!DIR$ ATTRIBUTES DLLEXPORT :: xmi_solve
! -- modules
use NumericalSolutionModule
use BaseSolutionModule, only: BaseSolutionType
use NumericalSolutionModule, only: NumericalSolutionType
use ExplicitSolutionModule, only: ExplicitSolutionType
! -- dummy variables
integer(kind=c_int), intent(in) :: subcomponent_idx !< index of the subcomponent (i.e. Numerical Solution)
integer(kind=c_int), intent(out) :: has_converged !< equal to 1 for convergence, 0 otherwise
integer(kind=c_int) :: bmi_status !< BMI status code
! -- local variables
class(NumericalSolutionType), pointer :: ns
class(BaseSolutionType), pointer :: bs

! get the numerical solution we are running
ns => getSolution(subcomponent_idx)
bs => getSolution(subcomponent_idx)

! execute the nth iteration
iterationCounter = iterationCounter + 1
call ns%solve(iterationCounter)
call bs%solve(iterationCounter)

! the following check is equivalent to that in NumericalSolution%sln_ca
if (ns%icnvg == 1) then
select type (bs)
class is (NumericalSolutionType)
if (bs%icnvg == 1) then
has_converged = 1
else
has_converged = 0
end if
class is (ExplicitSolutionType)
has_converged = 1
else
has_converged = 0
end if
end select

bmi_status = BMI_SUCCESS

Expand All @@ -300,23 +307,23 @@ function xmi_finalize_solve(subcomponent_idx) result(bmi_status) &
bind(C, name="finalize_solve")
!DIR$ ATTRIBUTES DLLEXPORT :: xmi_finalize_solve
! -- modules
use NumericalSolutionModule
use BaseSolutionModule, only: BaseSolutionType
! -- dummy variables
integer(kind=c_int), intent(in) :: subcomponent_idx !< index of the subcomponent (i.e. Numerical Solution)
integer(kind=c_int) :: bmi_status !< BMI status code
! -- local variables
class(NumericalSolutionType), pointer :: ns
class(BaseSolutionType), pointer :: bs
integer(I4B) :: hasConverged

! get the numerical solution we are running
ns => getSolution(subcomponent_idx)
bs => getSolution(subcomponent_idx)

! hasConverged is equivalent to the isgcnvg variable which is initialized to 1,
! see the body of the picard loop in SolutionGroupType%sgp_ca
hasConverged = 1

! finish up
call ns%finalizeSolve(iterationCounter, hasConverged, 0)
call bs%finalizeSolve(iterationCounter, hasConverged, 0)

! check convergence on solution
if (.not. hasConverged == 1) then
Expand Down

0 comments on commit c178700

Please sign in to comment.