Enhance NUOPC cap to support MOM_mask_table.
- Determine masked blocks.
- Evenly distribute eliminated cells.
- Fill ESMF gindex array accordingly.
- During Export phase, set fields of eliminated cells to zero.
alperaltuntas committed Oct 22, 2023
1 parent c58aa8d commit b017639
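The "evenly distribute eliminated cells" step in the commit message works by giving each PE num_elim_cells_global / npes of the masked-out cells and assigning the integer remainder to the last PE. A minimal standalone Fortran sketch of this rule (not part of the commit; the PE count and cell count below are assumed purely for illustration):

program elim_cell_distribution
  implicit none
  integer, parameter :: npes = 4                    ! assumed PE count
  integer, parameter :: num_elim_cells_global = 10  ! assumed number of masked-out cells
  integer :: num_elim_cells_local, num_elim_cells_remaining
  integer :: pe_ix, lo, hi

  ! Each PE takes an equal share; the remainder goes to the last PE.
  num_elim_cells_local = num_elim_cells_global / npes
  num_elim_cells_remaining = num_elim_cells_global - num_elim_cells_local * npes

  do pe_ix = 0, npes - 1   ! zero-based PE index, as in the cap
    lo = pe_ix * num_elim_cells_local + 1
    hi = (pe_ix + 1) * num_elim_cells_local
    if (pe_ix == npes - 1) hi = hi + num_elim_cells_remaining
    print '(a,i0,a,i0,a,i0)', 'PE ', pe_ix, ' takes eliminated cells ', lo, ' .. ', hi
  end do
end program elim_cell_distribution

With these assumed values, PEs 0-2 take two cells each and PE 3 takes four, mirroring how num_elim_cells_local and num_elim_cells_remaining are used in the cap below.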
Showing 2 changed files with 126 additions and 21 deletions.
138 changes: 118 additions & 20 deletions config_src/drivers/nuopc_cap/mom_cap.F90
@@ -16,7 +16,7 @@ module MOM_cap_mod
use MOM_domains, only: MOM_infra_init, MOM_infra_end
use MOM_file_parser, only: get_param, log_version, param_file_type, close_param_file
use MOM_get_input, only: get_MOM_input, directories
use MOM_domains, only: pass_var
use MOM_domains, only: pass_var, pe_here
use MOM_error_handler, only: MOM_error, FATAL, is_root_pe
use MOM_grid, only: ocean_grid_type, get_global_grid_size
use MOM_ocean_model_nuopc, only: ice_ocean_boundary_type
@@ -29,6 +29,7 @@ module MOM_cap_mod
use MOM_cap_methods, only: med2mod_areacor, state_diagnose
use MOM_cap_methods, only: ChkErr
use MOM_ensemble_manager, only: ensemble_manager_init
use MOM_coms, only: sum_across_PEs

#ifdef CESMCOUPLED
use shr_log_mod, only: shr_log_setLogUnit
@@ -826,6 +827,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
type(ocean_grid_type) , pointer :: ocean_grid
type(ocean_internalstate_wrapper) :: ocean_internalstate
integer :: npet, ntiles
integer :: npes ! number of PEs (from FMS).
integer :: nxg, nyg, cnt
integer :: isc,iec,jsc,jec
integer, allocatable :: xb(:),xe(:),yb(:),ye(:),pe(:)
@@ -852,6 +854,8 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
integer :: lsize
integer :: ig,jg, ni,nj,k
integer, allocatable :: gindex(:) ! global index space
integer, allocatable :: gindex_ocn(:) ! global index space for ocean cells (excl. masked cells)
integer, allocatable :: gindex_elim(:) ! global index space for eliminated cells
character(len=128) :: fldname
character(len=256) :: cvalue
character(len=256) :: frmt ! format specifier for several error msgs
@@ -875,6 +879,11 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
real(ESMF_KIND_R8) :: min_areacor_glob(2)
real(ESMF_KIND_R8) :: max_areacor_glob(2)
character(len=*), parameter :: subname='(MOM_cap:InitializeRealize)'
integer :: niproc, njproc
integer :: ip, jp, pe_ix
integer :: num_elim_blocks ! number of blocks to be eliminated
integer :: num_elim_cells_global, num_elim_cells_local, num_elim_cells_remaining
integer, allocatable :: cell_mask(:,:)
!--------------------------------

rc = ESMF_SUCCESS
@@ -919,19 +928,19 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
rc = ESMF_FAILURE
call ESMF_LogWrite(subname//' ntiles must be 1', ESMF_LOGMSG_ERROR)
endif
ntiles = mpp_get_domain_npes(ocean_public%domain)
write(tmpstr,'(a,1i6)') subname//' ntiles = ',ntiles
npes = mpp_get_domain_npes(ocean_public%domain)
write(tmpstr,'(a,1i6)') subname//' npes = ',npes
call ESMF_LogWrite(trim(tmpstr), ESMF_LOGMSG_INFO)

!---------------------------------
! get start and end indices of each tile and their PET
!---------------------------------

allocate(xb(ntiles),xe(ntiles),yb(ntiles),ye(ntiles),pe(ntiles))
allocate(xb(npes),xe(npes),yb(npes),ye(npes),pe(npes))
call mpp_get_compute_domains(ocean_public%domain, xbegin=xb, xend=xe, ybegin=yb, yend=ye)
call mpp_get_pelist(ocean_public%domain, pe)
if (dbug > 1) then
do n = 1,ntiles
do n = 1,npes
write(tmpstr,'(a,6i6)') subname//' tiles ',n,pe(n),xb(n),xe(n),yb(n),ye(n)
call ESMF_LogWrite(trim(tmpstr), ESMF_LOGMSG_INFO)
enddo
@@ -953,17 +962,102 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
call get_global_grid_size(ocean_grid, ni, nj)
lsize = ( ocean_grid%iec - ocean_grid%isc + 1 ) * ( ocean_grid%jec - ocean_grid%jsc + 1 )

! Create the global index space for the computational domain
allocate(gindex(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex(k) = ni * (jg - 1) + ig
enddo
enddo
num_elim_blocks = 0
num_elim_cells_global = 0
num_elim_cells_local = 0
num_elim_cells_remaining = 0

! Compute the number of eliminated blocks (specified in MOM_mask_table)
if (associated(ocean_grid%Domain%maskmap)) then
njproc = size(ocean_grid%Domain%maskmap, 1)
niproc = size(ocean_grid%Domain%maskmap, 2)

do ip = 1, niproc
do jp = 1, njproc
if (.not. ocean_grid%Domain%maskmap(jp,ip)) then
num_elim_blocks = num_elim_blocks+1
endif
enddo
enddo
endif

! Apply land block elimination to ESMF gindex
! (Here we assume that each processor is assigned a single tile. If a multi-tile implementation is added
! to the MOM6 NUOPC cap in the future, the code below must be updated accordingly.)
if (num_elim_blocks>0) then

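! Build a global cell-ownership mask: each PE marks the cells of its compute domain,
! then the per-PE masks are summed across PEs. Cells left at zero belong to eliminated
! (masked) blocks; a value greater than one would indicate overlapping compute domains.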
allocate(cell_mask(ni, nj), source=0)
allocate(gindex_ocn(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex_ocn(k) = ni * (jg - 1) + ig
cell_mask(ig, jg) = 1
enddo
enddo
call sum_across_PEs(cell_mask, ni*nj)

if (maxval(cell_mask) /= 1 ) then
call MOM_error(FATAL, "Encountered cells shared by multiple PEs while attempting to determine masked cells.")
endif

num_elim_cells_global = ni * nj - sum(cell_mask)
num_elim_cells_local = num_elim_cells_global / npes

if (pe_here() == pe(npes)) then
! assign all remaining cells to the last PE.
num_elim_cells_remaining = num_elim_cells_global - num_elim_cells_local * npes
allocate(gindex_elim(num_elim_cells_local+num_elim_cells_remaining))
else
allocate(gindex_elim(num_elim_cells_local))
endif

! Zero-based PE index.
pe_ix = pe_here() - pe(1)

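! Walk the global grid in the same row-major order used for gindex and count the masked
! cells; the k-th eliminated cell is assigned to this PE if k falls within this PE's
! share (the last PE also takes the remainder).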
k = 0
do jg = 1, nj
do ig = 1, ni
if (cell_mask(ig, jg) == 0) then
k = k + 1
if (k > pe_ix * num_elim_cells_local .and. &
k <= ((pe_ix+1) * num_elim_cells_local + num_elim_cells_remaining)) then
gindex_elim(k - pe_ix * num_elim_cells_local) = ni * (jg -1) + ig
endif
endif
enddo
enddo

allocate(gindex(lsize + num_elim_cells_local + num_elim_cells_remaining))
do k = 1, lsize
gindex(k) = gindex_ocn(k)
enddo
do k = 1, num_elim_cells_local + num_elim_cells_remaining
gindex(k+lsize) = gindex_elim(k)
enddo

deallocate(cell_mask)
deallocate(gindex_ocn)
deallocate(gindex_elim)

else ! no eliminated land blocks

! Create the global index space for the computational domain
allocate(gindex(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex(k) = ni * (jg - 1) + ig
enddo
enddo

endif

DistGrid = ESMF_DistGridCreate(arbSeqIndexList=gindex, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return
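Both hunks of mom_cap.F90 rely on the same row-major flattening of the global (ig, jg) grid indices, gindex = ni*(jg-1) + ig. A minimal standalone sketch verifying that mapping and its inverse (the grid dimensions are assumed for illustration only, not taken from the commit):

program global_index_sketch
  implicit none
  integer, parameter :: ni = 5, nj = 3   ! assumed global grid size
  integer :: ig, jg, idx

  do jg = 1, nj
    do ig = 1, ni
      idx = ni * (jg - 1) + ig           ! same formula as gindex in the cap
      ! recover (ig, jg) from the 1-D index and check the round trip
      if (mod(idx - 1, ni) + 1 /= ig .or. (idx - 1) / ni + 1 /= jg) stop 'mapping error'
    end do
  end do
  print *, 'row-major mapping verified for', ni * nj, 'cells'
end program global_index_sketch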
@@ -987,6 +1081,10 @@
call ESMF_MeshGet(Emesh, spatialDim=spatialDim, numOwnedElements=numOwnedElements, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return

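! Consistency check: the number of mesh elements owned by this PE must equal the MOM6
! compute-domain size plus the eliminated cells assigned to this PE.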
if (lsize /= numOwnedElements - num_elim_cells_local - num_elim_cells_remaining) then
call MOM_error(FATAL, "Discrepancy detected between ESMF mesh and internal MOM6 domain sizes. Check mask table.")
endif

allocate(ownedElemCoords(spatialDim*numOwnedElements))
allocate(lonMesh(numOwnedElements), lon(numOwnedElements))
allocate(latMesh(numOwnedElements), lat(numOwnedElements))
@@ -1018,7 +1116,7 @@
end do

eps_omesh = get_eps_omesh(ocean_state)
do n = 1,numOwnedElements
do n = 1,lsize
diff_lon = abs(mod(lonMesh(n) - lon(n),360.0))
if (diff_lon > eps_omesh) then
frmt = "('ERROR: Difference between ESMF Mesh and MOM6 domain coords is "//&
@@ -1122,11 +1220,11 @@

! generate delayout and dist_grid

allocate(deBlockList(2,2,ntiles))
allocate(petMap(ntiles))
allocate(deLabelList(ntiles))
allocate(deBlockList(2,2,npes))
allocate(petMap(npes))
allocate(deLabelList(npes))

do n = 1, ntiles
do n = 1, npes
deLabelList(n) = n
deBlockList(1,1,n) = xb(n)
deBlockList(1,2,n) = xe(n)
9 changes: 8 additions & 1 deletion config_src/drivers/nuopc_cap/mom_cap_methods.F90
@@ -852,7 +852,7 @@ subroutine State_SetExport(state, fldname, isc, iec, jsc, jec, input, ocean_grid

! local variables
type(ESMF_StateItem_Flag) :: itemFlag
integer :: n, i, j, i1, j1, ig,jg
integer :: n, i, j, k, i1, j1, ig,jg
integer :: lbnd1,lbnd2
real(ESMF_KIND_R8), pointer :: dataPtr1d(:)
real(ESMF_KIND_R8), pointer :: dataPtr2d(:,:)
@@ -888,6 +888,13 @@
enddo
end if

! if a maskmap is provided, set exports of all eliminated cells to zero.
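! (Entries of dataPtr1d beyond the owned ocean cells filled above correspond to the
! eliminated cells appended to this PE's index list.)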
if (associated(ocean_grid%Domain%maskmap)) then
do k = n+1, size(dataPtr1d)
dataPtr1d(k) = 0.0
enddo
endif

else if (geomtype == ESMF_GEOMTYPE_GRID) then

call state_getfldptr(state, trim(fldname), dataptr2d, rc)
