Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remapping weights script #45

Merged
merged 7 commits into from
Jan 23, 2025
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions mesh_generation/generate_rof_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0

# =========================================================================================
# Generate remapping weights between two ESMF mesh files for remapping a runoff field from
# an unmasked mesh to a masked mesh without losing any water volume. Each field on the
# unmasked mesh is mapped to the nearest ocean cell in the resulting weights.
#
# To run:
# python generate_rof_weights.py --mesh_filename=<input_file> --weights_filename=<output_file>
#
# This script currently supports mesh files in the ESMF unstructured mesh format.
#
# There is not enough memory on the gadi login node to run this; it's simplest to run it in
# a terminal through are.nci.org.au
#
# The run command and full github url of the current version of this script are added to the
# metadata of the generated weights file. This is to uniquely identify the script and inputs used
# to generate the weights file. To produce weights files for sharing, ensure you are using a version
# of this script which is committed and pushed to github. For weights files intended for released
# configurations, use the latest version checked in to the main branch of the github repository.
#
# Contact:
# Anton Steketee <[email protected]>
#
# Dependencies:
# esmpy, xarray, numpy and scikit-learn (which itself requires scipy)
# =========================================================================================


import xarray as xr
import esmpy
from sklearn.neighbors import BallTree
from numpy import deg2rad
from copy import copy

from pathlib import Path
import sys
import os
from datetime import datetime

# Append the parent of this script's directory to sys.path before importing
# scripts_common, so the import below can resolve when this file is run directly.
# NOTE(review): presumably scripts_common lives in that parent directory - confirm.
path_root = Path(__file__).parents[1]
sys.path.append(str(path_root))
from scripts_common import get_provenance_metadata, md5sum

# Scratch file that ESMF writes the raw (unadjusted) weights to. It is a relative
# path, so it is resolved against the current working directory.
TEMP_WEIGHTS_F = "temp_weights.nc"
COMP_ENCODING = {"complevel": 1, "compression": "zlib"} # compression settings applied to every output variable


def drof_remapping_weights(mesh_filename, weights_filename, global_attrs=None):
    """Write runoff remapping weights that only target ocean cells.

    We need to generate remapping weights for use in the mediator, such that
    the overall volume of runoff is conserved and no runoff is mapped onto
    land cells. Inside the mediator, the grid doesn't change as we run the
    mediator with the ocean grid (the DROF component does the remapping from
    JRA grid to mediator grid). Therefore we use the same mesh file for the
    input and output mesh, however this same routine would work for differing
    input and output meshes.

    Steps:
      1. Ask ESMF for conservative weights from the mesh onto itself and have
         it write them to ``TEMP_WEIGHTS_F``.
      2. Redirect every destination index (``row``) to the nearest element
         whose ``elementMask`` is non-zero (an ocean cell).
      3. Scale each weight ``S`` by ``old_area / new_area`` to account for
         the difference in area between the old and new destination cell.
      4. Write the adjusted weights, compressed, to ``weights_filename``.

    Args:
        mesh_filename: Path to the mesh file in ESMF unstructured mesh format.
        weights_filename: Path of the netCDF weights file to create.
        global_attrs: Optional dict of extra global attributes merged into
            the output file's attributes (e.g. provenance history).

    Returns:
        True on success.

    Raises:
        ValueError: If the mesh has no element with a non-zero
            ``elementMask``, so there is no ocean cell to map onto.
    """
    model_mesh = esmpy.Mesh(
        filename=mesh_filename,
        filetype=esmpy.FileFormat.ESMFMESH,
    )

    med_in_fld = esmpy.Field(model_mesh, meshloc=esmpy.MeshLoc.ELEMENT)
    med_out_fld = esmpy.Field(model_mesh, meshloc=esmpy.MeshLoc.ELEMENT)

    # Remove any temp file left over from an earlier run (best effort).
    try:
        os.remove(TEMP_WEIGHTS_F)
    except OSError:
        pass

    try:
        # Generate remapping weights and write to file.
        esmpy.Regrid(
            med_in_fld,
            med_out_fld,
            filename=TEMP_WEIGHTS_F,
            regrid_method=esmpy.RegridMethod.CONSERVE,
        )

        # From
        # https://earthsystemmodeling.org/docs/release/ESMF_5_2_0rp3/ESMF_refdoc/node3.html :
        #
        #   "The indices and weights generated by ESMF_FieldRegridStore() are
        #   stored in the output file as variables col, row and S. Where col
        #   and row are the indices to the source and the destination grid
        #   cells. These are a one-dimension array with length defined by
        #   dimension n_s. S is the weight which is multiplied by the source
        #   value indicated by col and then summed with the destination value
        #   indicated by row to build the final interpolated value of the
        #   destination."
        #
        # Per the above note, we want to adjust all row values, so they are
        # ocean cells. When we do this, we want to adjust S, the weight, to
        # account for the difference in area.

        # Context managers close both files before the temp file is deleted.
        with xr.open_dataset(TEMP_WEIGHTS_F) as weights_ds, xr.open_dataset(
            mesh_filename
        ) as mod_mesh_ds:
            # Zero-based index of every ocean cell (non-zero elementMask).
            ocean_idx = (
                mod_mesh_ds.elementCount.where(mod_mesh_ds.elementMask, drop=True)
                .astype("int")
                .values
            )
            if ocean_idx.size == 0:
                raise ValueError(
                    f"No ocean cells (non-zero elementMask) found in {mesh_filename}"
                )

            # The ESMF unstructured mesh format stores centerCoords as
            # (lon, lat) in degrees, whereas the haversine metric expects
            # (lat, lon) in radians, so reverse the coordinate columns.
            # NOTE(review): assumes centerCoords is (elementCount, coordDim)
            # with coordDim ordered lon, lat - confirm against the mesh files.
            center_latlon_rad = deg2rad(mod_mesh_ds.centerCoords.values[:, ::-1])

            # Make a BallTree from the ocean cells
            mask_tree = BallTree(center_latlon_rad[ocean_idx], metric="haversine")

            # The weights are indexed from 1 (i.e. Fortran style) but numpy
            # starts from 0 (i.e. python style), so subtract one from the
            # destination grid cell indices.
            old_row_idx = weights_ds.row.values.astype(int) - 1

            # Using the tree, look up the nearest ocean cell to every
            # destination grid cell in the weights file.
            nearest = mask_tree.query(
                center_latlon_rad[old_row_idx], return_distance=False
            )
            new_row_idx = ocean_idx[nearest[:, 0]]

            # Get the mesh element areas and adjust:
            # n.b. per CMEPS we are using the internally calculated areas,
            # not the user provided ones.
            med_out_fld.get_area()
            area = copy(med_out_fld.data)
            old_area = area[old_row_idx]
            new_area = area[new_row_idx]

            # Back to one-based indices for the weights file.
            weights_ds["row"] = xr.DataArray(data=new_row_idx + 1, dims="n_s")
            weights_ds["S"] = weights_ds.S * old_area / new_area

            # add global attributes
            weights_ds.attrs = {
                "gridType": "unstructured mesh",
                "inputFile": f"{mesh_filename} (md5 hash: {md5sum(mesh_filename)})",
            }

            # add git info to history
            if global_attrs:
                weights_ds.attrs |= global_attrs

            # save (compressed); each variable gets its own copy of the settings
            encoding = {name: dict(COMP_ENCODING) for name in weights_ds.data_vars}
            weights_ds.to_netcdf(weights_filename, encoding=encoding)
    finally:
        # Always clean up the temp file, including when a step above failed.
        try:
            os.remove(TEMP_WEIGHTS_F)
        except OSError:
            pass

    return True


def main():
    """Parse the command line and generate the runoff remapping weights file.

    Reads ``--mesh_filename`` and ``--weights_filename`` (both required),
    converts them to absolute paths, builds a ``history`` attribute from
    ``get_provenance_metadata`` and passes everything to
    ``drof_remapping_weights``.

    Returns:
        True on success.
    """
    # Imported locally so main() also works when called from another module,
    # not only when this file is executed as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Create an remapping weights to transfer runoff from unmasked mesh to masked mesh using ESMF mesh file."
    )

    parser.add_argument(
        "--mesh_filename",
        type=str,
        required=True,
        help="The path to the mesh file specifying the model grid.",
    )
    parser.add_argument(
        "--weights_filename",
        type=str,
        required=True,
        help="The path to the weights file to output (netcdf).",
    )

    args = parser.parse_args()
    mesh_filename = os.path.abspath(args.mesh_filename)
    weights_filename = os.path.abspath(args.weights_filename)

    this_file = os.path.normpath(__file__)

    # Add some info about how the file was generated. The option names must
    # match the parser above so the recorded command can be re-run verbatim.
    runcmd = f"python3 {os.path.basename(this_file)} --mesh_filename={mesh_filename} --weights_filename={weights_filename} "

    global_attrs = {"history": get_provenance_metadata(this_file, runcmd)}

    drof_remapping_weights(mesh_filename, weights_filename, global_attrs)

    return True


# Script entry point: runs main() only when this file is executed directly.
if __name__ == "__main__":
    # argparse is imported here, i.e. only when the file is executed as a script.
    import argparse

    main()
Loading