diff --git a/parm/archive/enkf.yaml.j2 b/parm/archive/enkf.yaml.j2 index d3f16e8e69..a95046d4d6 100644 --- a/parm/archive/enkf.yaml.j2 +++ b/parm/archive/enkf.yaml.j2 @@ -54,12 +54,14 @@ enkf: "radstat.ensmean"] %} {% else %} {% if lobsdiag_forenkf %} - {% set da_files = ["atmens_observer.yaml", - "atmens_solver.yaml", + {% set da_files = ["atmensanlobs.yaml", + "atmensanlsol.yaml", + "atmensanlfv3inc.yaml", "atminc.ensmean.nc", "atmensstat"] %} {% else %} - {% set da_files = ["atmens.yaml", + {% set da_files = ["atmensanlletkf.yaml", + "atmensanlfv3inc.yaml", "atminc.ensmean.nc", "atmensstat"] %} {% endif %} diff --git a/parm/archive/gdas.yaml.j2 b/parm/archive/gdas.yaml.j2 index db92141ede..56e47e595a 100644 --- a/parm/archive/gdas.yaml.j2 +++ b/parm/archive/gdas.yaml.j2 @@ -58,7 +58,8 @@ gdas: # Analysis state {% if DO_JEDIATMVAR %} - - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmvar.yaml" + - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmanlvar.yaml" + - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmanlfv3inc.yaml" - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmstat" {% else %} - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}gsistat" diff --git a/parm/archive/gfsa.yaml.j2 b/parm/archive/gfsa.yaml.j2 index 4a86778e2e..226a7178fa 100644 --- a/parm/archive/gfsa.yaml.j2 +++ b/parm/archive/gfsa.yaml.j2 @@ -32,7 +32,8 @@ gfsa: # State data {% if DO_JEDIATMVAR %} - - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmvar.yaml" + - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmanlvar.yaml" + - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmanlfv3inc.yaml" - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}atmstat" {% else %} - "{{ COMIN_ATMOS_ANALYSIS | relpath(ROTDIR) }}/{{ head }}gsistat" diff --git a/parm/config/gfs/config.atmensanl b/parm/config/gfs/config.atmensanl index ddd3d88659..f5a1278248 100644 --- a/parm/config/gfs/config.atmensanl +++ 
b/parm/config/gfs/config.atmensanl @@ -6,7 +6,11 @@ echo "BEGIN: config.atmensanl" export JCB_BASE_YAML="${PARMgfs}/gdas/atm/jcb-base.yaml.j2" -export JCB_ALGO_YAML=@JCB_ALGO_YAML@ +if [[ ${lobsdiag_forenkf} = ".false." ]] ; then + export JCB_ALGO_YAML=@JCB_ALGO_YAML_LETKF@ +else + export JCB_ALGO_YAML=@JCB_ALGO_YAML_OBS@ +fi export INTERP_METHOD='barycentric' diff --git a/parm/config/gfs/yaml/defaults.yaml b/parm/config/gfs/yaml/defaults.yaml index b423601df3..05e1b24012 100644 --- a/parm/config/gfs/yaml/defaults.yaml +++ b/parm/config/gfs/yaml/defaults.yaml @@ -31,7 +31,8 @@ atmanl: IO_LAYOUT_Y: 1 atmensanl: - JCB_ALGO_YAML: "${PARMgfs}/gdas/atm/jcb-prototype_lgetkf.yaml.j2" + JCB_ALGO_YAML_LETKF: "${PARMgfs}/gdas/atm/jcb-prototype_lgetkf.yaml.j2" + JCB_ALGO_YAML_OBS: "${PARMgfs}/gdas/atm/jcb-prototype_lgetkf_observer.yaml.j2" LAYOUT_X_ATMENSANL: 8 LAYOUT_Y_ATMENSANL: 8 IO_LAYOUT_X: 1 diff --git a/scripts/exglobal_atm_analysis_finalize.py b/scripts/exglobal_atm_analysis_finalize.py index 3f4313631c..35220928c9 100755 --- a/scripts/exglobal_atm_analysis_finalize.py +++ b/scripts/exglobal_atm_analysis_finalize.py @@ -21,4 +21,6 @@ # Instantiate the atm analysis task AtmAnl = AtmAnalysis(config) + + # Finalize JEDI variational analysis AtmAnl.finalize() diff --git a/scripts/exglobal_atm_analysis_fv3_increment.py b/scripts/exglobal_atm_analysis_fv3_increment.py index 66f6796343..72413ddbd4 100755 --- a/scripts/exglobal_atm_analysis_fv3_increment.py +++ b/scripts/exglobal_atm_analysis_fv3_increment.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # exglobal_atm_analysis_fv3_increment.py # This script creates an AtmAnalysis object -# and runs the init_fv3_increment and fv3_increment methods +# and runs the initialize_jedi and execute methods # which convert the JEDI increment into an FV3 increment import os @@ -17,7 +17,9 @@ # Take configuration from environment and cast it as python dictionary config = cast_strdict_as_dtypedict(os.environ) - # Instantiate the atm analysis 
task - AtmAnl = AtmAnalysis(config) - AtmAnl.init_fv3_increment() - AtmAnl.fv3_increment() + # Instantiate the atm analysis object + AtmAnl = AtmAnalysis(config, 'atmanlfv3inc') + + # Initialize and execute FV3 increment converter + AtmAnl.initialize_jedi() + AtmAnl.execute(config.APRUN_ATMANLFV3INC) diff --git a/scripts/exglobal_atm_analysis_initialize.py b/scripts/exglobal_atm_analysis_initialize.py index 1793b24b0b..9deae07bb3 100755 --- a/scripts/exglobal_atm_analysis_initialize.py +++ b/scripts/exglobal_atm_analysis_initialize.py @@ -20,5 +20,8 @@ config = cast_strdict_as_dtypedict(os.environ) # Instantiate the atm analysis task - AtmAnl = AtmAnalysis(config) - AtmAnl.initialize() + AtmAnl = AtmAnalysis(config, 'atmanlvar') + + # Initialize JEDI variational analysis + AtmAnl.initialize_jedi() + AtmAnl.initialize_analysis() diff --git a/scripts/exglobal_atm_analysis_variational.py b/scripts/exglobal_atm_analysis_variational.py index 07bc208331..8359532069 100755 --- a/scripts/exglobal_atm_analysis_variational.py +++ b/scripts/exglobal_atm_analysis_variational.py @@ -18,5 +18,7 @@ config = cast_strdict_as_dtypedict(os.environ) # Instantiate the atm analysis task - AtmAnl = AtmAnalysis(config) - AtmAnl.variational() + AtmAnl = AtmAnalysis(config, 'atmanlvar') + + # Execute JEDI variational analysis + AtmAnl.execute(config.APRUN_ATMANLVAR, ['fv3jedi', 'variational']) diff --git a/scripts/exglobal_atmens_analysis_finalize.py b/scripts/exglobal_atmens_analysis_finalize.py index b49cb3c413..d68c260e78 100755 --- a/scripts/exglobal_atmens_analysis_finalize.py +++ b/scripts/exglobal_atmens_analysis_finalize.py @@ -21,4 +21,6 @@ # Instantiate the atmens analysis task AtmEnsAnl = AtmEnsAnalysis(config) + + # Finalize ensemble DA analysis AtmEnsAnl.finalize() diff --git a/scripts/exglobal_atmens_analysis_fv3_increment.py b/scripts/exglobal_atmens_analysis_fv3_increment.py index c50b00548f..48eb6a6a1e 100755 --- a/scripts/exglobal_atmens_analysis_fv3_increment.py +++ 
b/scripts/exglobal_atmens_analysis_fv3_increment.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # exglobal_atmens_analysis_fv3_increment.py # This script creates an AtmEnsAnalysis object -# and runs the init_fv3_increment and fv3_increment methods +# and runs the initialize_jedi and execute methods # which convert the JEDI increment into an FV3 increment import os @@ -17,7 +17,9 @@ # Take configuration from environment and cast it as python dictionary config = cast_strdict_as_dtypedict(os.environ) - # Instantiate the atmens analysis task - AtmEnsAnl = AtmEnsAnalysis(config) - AtmEnsAnl.init_fv3_increment() - AtmEnsAnl.fv3_increment() + # Instantiate the atmens analysis object + AtmEnsAnl = AtmEnsAnalysis(config, 'atmensanlfv3inc') + + # Initialize and execute JEDI FV3 increment converter + AtmEnsAnl.initialize_jedi() + AtmEnsAnl.execute(config.APRUN_ATMENSANLFV3INC) diff --git a/scripts/exglobal_atmens_analysis_initialize.py b/scripts/exglobal_atmens_analysis_initialize.py index 1d578b44f2..326fe80628 100755 --- a/scripts/exglobal_atmens_analysis_initialize.py +++ b/scripts/exglobal_atmens_analysis_initialize.py @@ -20,5 +20,11 @@ config = cast_strdict_as_dtypedict(os.environ) # Instantiate the atmens analysis task - AtmEnsAnl = AtmEnsAnalysis(config) - AtmEnsAnl.initialize() + if not config.lobsdiag_forenkf: + AtmEnsAnl = AtmEnsAnalysis(config, 'atmensanlletkf') + else: + AtmEnsAnl = AtmEnsAnalysis(config, 'atmensanlobs') + + # Initialize JEDI ensemble DA analysis + AtmEnsAnl.initialize_jedi() + AtmEnsAnl.initialize_analysis() diff --git a/scripts/exglobal_atmens_analysis_letkf.py b/scripts/exglobal_atmens_analysis_letkf.py index 30394537cd..45b06524fe 100755 --- a/scripts/exglobal_atmens_analysis_letkf.py +++ b/scripts/exglobal_atmens_analysis_letkf.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # exglobal_atmens_analysis_letkf.py # This script creates an AtmEnsAnalysis object -# and runs the letkf method -# which executes the global atm local ensemble analysis +# and 
runs the execute method which executes +# the global atm local ensemble analysis import os from wxflow import Logger, cast_strdict_as_dtypedict @@ -18,5 +18,7 @@ config = cast_strdict_as_dtypedict(os.environ) # Instantiate the atmens analysis task - AtmEnsAnl = AtmEnsAnalysis(config) - AtmEnsAnl.letkf() + AtmEnsAnl = AtmEnsAnalysis(config, 'atmensanlletkf') + + # Execute the JEDI ensemble DA analysis + AtmEnsAnl.execute(config.APRUN_ATMENSANLLETKF, ['fv3jedi', 'localensembleda']) diff --git a/scripts/exglobal_atmens_analysis_obs.py b/scripts/exglobal_atmens_analysis_obs.py index e4b5c98952..c701f8cb4e 100755 --- a/scripts/exglobal_atmens_analysis_obs.py +++ b/scripts/exglobal_atmens_analysis_obs.py @@ -18,6 +18,7 @@ config = cast_strdict_as_dtypedict(os.environ) # Instantiate the atmens analysis task - AtmEnsAnl = AtmEnsAnalysis(config) - AtmEnsAnl.init_observer() - AtmEnsAnl.observe() + AtmEnsAnl = AtmEnsAnalysis(config, 'atmensanlobs') + + # Execute JEDI ensemble DA analysis in observer mode + AtmEnsAnl.execute(config.APRUN_ATMENSANLOBS, ['fv3jedi', 'localensembleda']) diff --git a/scripts/exglobal_atmens_analysis_sol.py b/scripts/exglobal_atmens_analysis_sol.py index db55959753..be78e694b1 100755 --- a/scripts/exglobal_atmens_analysis_sol.py +++ b/scripts/exglobal_atmens_analysis_sol.py @@ -18,6 +18,8 @@ config = cast_strdict_as_dtypedict(os.environ) # Instantiate the atmens analysis task - AtmEnsAnl = AtmEnsAnalysis(config) - AtmEnsAnl.init_solver() - AtmEnsAnl.solve() + AtmEnsAnl = AtmEnsAnalysis(config, 'atmensanlsol') + + # Initialize and execute JEDI ensemble DA analysis in solver mode + AtmEnsAnl.initialize_jedi() + AtmEnsAnl.execute(config.APRUN_ATMENSANLSOL, ['fv3jedi', 'localensembleda']) diff --git a/sorc/gdas.cd b/sorc/gdas.cd index 09594d1c03..faa95efb18 160000 --- a/sorc/gdas.cd +++ b/sorc/gdas.cd @@ -1 +1 @@ -Subproject commit 09594d1c032fd187f9869ac74b2b5b351112e93c +Subproject commit faa95efb18f0f52acab2cf09b17f78406f9b48b1 diff 
--git a/ush/python/pygfs/jedi/__init__.py b/ush/python/pygfs/jedi/__init__.py new file mode 100644 index 0000000000..5d7e85057c --- /dev/null +++ b/ush/python/pygfs/jedi/__init__.py @@ -0,0 +1 @@ +from .jedi import Jedi diff --git a/ush/python/pygfs/jedi/jedi.py b/ush/python/pygfs/jedi/jedi.py new file mode 100644 index 0000000000..62dcb517ca --- /dev/null +++ b/ush/python/pygfs/jedi/jedi.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 + +import os +from logging import getLogger +from typing import List, Dict, Any, Optional +from jcb import render +from wxflow import (AttrDict, + chdir, rm_p, + parse_j2yaml, + logit, + Task, + Executable, + WorkflowException) + +logger = getLogger(__name__.split('.')[-1]) + + +class Jedi: + """ + Class for initializing and executing JEDI applications + """ + @logit(logger, name="Jedi") + def __init__(self, task_config: AttrDict, yaml_name: Optional[str] = None) -> None: + """Constructor for JEDI objects + + This method will construct a Jedi object. + This includes: + - save a copy of task_config for provenance + - set the default JEDI YAML and executable names + - set an empty AttrDict for the JEDI config + - set the default directory for J2-YAML templates + + Parameters + ---------- + task_config: AttrDict + Attribute-dictionary of all configuration variables associated with a GDAS task. 
+ yaml_name: str, optional + Name of YAML file for JEDI configuration + + Returns + ---------- + None + """ + + # For provenance, save incoming task_config as a private attribute of JEDI object + self._task_config = task_config + + _exe_name = os.path.basename(task_config.JEDIEXE) + + self.exe = os.path.join(task_config.DATA, _exe_name) + if yaml_name: + self.yaml = os.path.join(task_config.DATA, yaml_name + '.yaml') + else: + self.yaml = os.path.join(task_config.DATA, os.path.splitext(_exe_name)[0] + '.yaml') + self.config = AttrDict() + self.j2tmpl_dir = os.path.join(task_config.PARMgfs, 'gdas') + + @logit(logger) + def set_config(self, task_config: AttrDict, algorithm: Optional[str] = None) -> AttrDict: + """Compile a JEDI configuration dictionary from a template file and save to a YAML file + + Parameters + ---------- + task_config : AttrDict + Dictionary of all configuration variables associated with a GDAS task. + algorithm (optional) : str + Name of the algorithm used to generate the JEDI configuration dictionary. + It will override the algorithm set in the task_config.JCB_<>_YAML file. + + Returns + ---------- + None + """ + + if 'JCB_BASE_YAML' in task_config.keys(): + # Step 1: Fill templates of the JCB base YAML file + jcb_config = parse_j2yaml(task_config.JCB_BASE_YAML, task_config) + + # Step 2: If algorithm is present then override the algorithm in the JEDI + # config. Otherwise, if the algorithm J2-YAML is present, fill + # its templates and merge. 
+ if algorithm: + jcb_config['algorithm'] = algorithm + elif 'JCB_ALGO' in task_config.keys(): + jcb_config['algorithm'] = task_config.JCB_ALGO + elif 'JCB_ALGO_YAML' in task_config.keys(): + jcb_algo_config = parse_j2yaml(task_config.JCB_ALGO_YAML, task_config) + jcb_config.update(jcb_algo_config) + + # Step 3: Generate the JEDI YAML using JCB + self.config = render(jcb_config) + elif 'JEDIYAML' in task_config.keys(): + # Generate JEDI YAML without using JCB + self.config = parse_j2yaml(task_config.JEDIYAML, task_config, + searchpath=self.j2tmpl_dir) + else: + logger.exception(f"FATAL ERROR: Unable to compile JEDI configuration dictionary, ABORT!") + raise KeyError(f"FATAL ERROR: Task config must contain JCB_BASE_YAML or JEDIYAML") + + @logit(logger) + def execute(self, task_config: AttrDict, aprun_cmd: str, jedi_args: Optional[List] = None) -> None: + """Execute JEDI application + + Parameters + ---------- + task_config: AttrDict + Attribute-dictionary of all configuration variables associated with a GDAS task. + aprun_cmd: str + String comprising the run command for the JEDI executable. + jedi_args (optional): List + List of strings comprising optional input arguments for the JEDI executable. + + Returns + ---------- + jedi_config: AttrDict + Attribute-dictionary of JEDI configuration rendered from a template. 
+ """ + + chdir(task_config.DATA) + + exec_cmd = Executable(aprun_cmd) + exec_cmd.add_default_arg(self.exe) + if jedi_args: + for arg in jedi_args: + exec_cmd.add_default_arg(arg) + exec_cmd.add_default_arg(self.yaml) + + try: + exec_cmd() + except OSError: + raise OSError(f"FATAL ERROR: Failed to execute {exec_cmd}") + except Exception: + raise WorkflowException(f"FATAL ERROR: An error occurred during execution of {exec_cmd}") + + @staticmethod + @logit(logger) + def link_exe(task_config: AttrDict) -> None: + """Link JEDI executable to run directory + + Parameters + ---------- + task_config: AttrDict + Attribute-dictionary of all configuration variables associated with a GDAS task. + + Returns + ---------- + None + """ + + # TODO: linking is not permitted per EE2. + # Needs work in JEDI to be able to copy the exec. [NOAA-EMC/GDASApp#1254] + logger.warn("Linking is not permitted per EE2.") + exe_dest = os.path.join(task_config.DATA, os.path.basename(task_config.JEDIEXE)) + if os.path.exists(exe_dest): + rm_p(exe_dest) + os.symlink(task_config.JEDIEXE, exe_dest) + + @logit(logger) + def get_obs_dict(self, task_config: AttrDict) -> Dict[str, Any]: + """Compile a dictionary of observation files to copy + + This method extracts 'observers' from the JEDI yaml and from that list, extracts a list of + observation files that are to be copied to the run directory + from the observation input directory + + Parameters + ---------- + task_config: AttrDict + Attribute-dictionary of all configuration variables associated with a GDAS task. 
+ + Returns + ---------- + obs_dict: Dict + a dictionary containing the list of observation files to copy for FileHandler + """ + + observations = find_value_in_nested_dict(self.config, 'observations') + + copylist = [] + for ob in observations['observers']: + obfile = ob['obs space']['obsdatain']['engine']['obsfile'] + basename = os.path.basename(obfile) + copylist.append([os.path.join(task_config.COM_OBS, basename), obfile]) + obs_dict = { + 'mkdir': [os.path.join(task_config.DATA, 'obs')], + 'copy': copylist + } + return obs_dict + + @logit(logger) + def get_bias_dict(self, task_config: AttrDict) -> Dict[str, Any]: + """Compile a dictionary of observation bias correction files to copy + + This method extracts 'observers' from the JEDI yaml and from that list, extracts a list of + observation bias correction files that are to be copied to the run directory + from the component directory. + TODO: COM_ATMOS_ANALYSIS_PREV is hardwired here and this method is not appropriate in + `analysis.py` and should be implemented in the component where this is applicable. + + Parameters + ---------- + task_config: AttrDict + Attribute-dictionary of all configuration variables associated with a GDAS task. + + Returns + ---------- + bias_dict: Dict + a dictionary containing the list of observation bias files to copy for FileHandler + """ + + observations = find_value_in_nested_dict(self.config, 'observations') + + copylist = [] + for ob in observations['observers']: + if 'obs bias' in ob.keys(): + obfile = ob['obs bias']['input file'] + obdir = os.path.dirname(obfile) + basename = os.path.basename(obfile) + prefix = '.'.join(basename.split('.')[:-2]) + for file in ['satbias.nc', 'satbias_cov.nc', 'tlapse.txt']: + bfile = f"{prefix}.{file}" + copylist.append([os.path.join(task_config.COM_ATMOS_ANALYSIS_PREV, bfile), os.path.join(obdir, bfile)]) + # TODO: Why is this specific to ATMOS? 
+ + bias_dict = { + 'mkdir': [os.path.join(task_config.DATA, 'bc')], + 'copy': copylist + } + return bias_dict + + +@logit(logger) +def find_value_in_nested_dict(nested_dict: Dict, target_key: str) -> Any: + """ + Recursively search through a nested dictionary and return the value for the target key. + This returns the first target key it finds. So if a key exists in a subsequent + nested dictionary, it will not be found. + + Parameters + ---------- + nested_dict : Dict + Dictionary to search + target_key : str + Key to search for + + Returns + ------- + Any + Value of the target key + + Raises + ------ + KeyError + If key is not found in dictionary + + TODO: if this gives issues due to landing on an incorrect key in the nested + dictionary, we will have to implement a more concrete method to search for a key + given a more complete address. See resolved conversations in PR 2387 + + # Example usage: + nested_dict = { + 'a': { + 'b': { + 'c': 1, + 'd': { + 'e': 2, + 'f': 3 + } + }, + 'g': 4 + }, + 'h': { + 'i': 5 + }, + 'j': { + 'k': 6 + } + } + + user_key = input("Enter the key to search for: ") + result = find_value_in_nested_dict(nested_dict, user_key) + """ + + if not isinstance(nested_dict, dict): + raise TypeError(f"Input is not of type(dict)") + + result = nested_dict.get(target_key) + if result is not None: + return result + + for value in nested_dict.values(): + if isinstance(value, dict): + try: + result = find_value_in_nested_dict(value, target_key) + if result is not None: + return result + except KeyError: + pass + + raise KeyError(f"Key '{target_key}' not found in the nested dictionary") diff --git a/ush/python/pygfs/task/atm_analysis.py b/ush/python/pygfs/task/atm_analysis.py index 4e9d37335c..8d340a5b73 100644 --- a/ush/python/pygfs/task/atm_analysis.py +++ b/ush/python/pygfs/task/atm_analysis.py @@ -5,33 +5,49 @@ import gzip import tarfile from logging import getLogger -from typing import Dict, List, Any +from pprint import pformat +from typing 
import Optional, Dict, Any from wxflow import (AttrDict, FileHandler, add_to_datetime, to_fv3time, to_timedelta, to_YMDH, - chdir, + Task, parse_j2yaml, save_as_yaml, - logit, - Executable, - WorkflowException) -from pygfs.task.analysis import Analysis + logit) +from pygfs.jedi import Jedi logger = getLogger(__name__.split('.')[-1]) -class AtmAnalysis(Analysis): +class AtmAnalysis(Task): """ - Class for global atm analysis tasks + Class for JEDI-based global atm analysis tasks """ @logit(logger, name="AtmAnalysis") - def __init__(self, config): + def __init__(self, config: Dict[str, Any], yaml_name: Optional[str] = None): + """Constructor global atm analysis task + + This method will construct a global atm analysis task. + This includes: + - extending the task_config attribute AttrDict to include parameters required for this task + - instantiate the Jedi attribute object + + Parameters + ---------- + config: Dict + dictionary object containing task configuration + yaml_name: str, optional + name of YAML file for JEDI configuration + + Returns + ---------- + None + """ super().__init__(config) _res = int(self.task_config.CASE[1:]) _res_anl = int(self.task_config.CASE_ANL[1:]) _window_begin = add_to_datetime(self.task_config.current_cycle, -to_timedelta(f"{self.task_config.assim_freq}H") / 2) - _jedi_yaml = os.path.join(self.task_config.DATA, f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.atmvar.yaml") # Create a local dictionary that is repeatedly used across this class local_dict = AttrDict( @@ -48,7 +64,6 @@ def __init__(self, config): 'OPREFIX': f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.", 'APREFIX': f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.", 'GPREFIX': f"gdas.t{self.task_config.previous_cycle.hour:02d}z.", - 'jedi_yaml': _jedi_yaml, 'atm_obsdatain_path': f"{self.task_config.DATA}/obs/", 'atm_obsdataout_path': f"{self.task_config.DATA}/diags/", 'BKG_TSTEP': "PT1H" # Placeholder for 4D applications @@ -58,30 +73,87 @@ def 
__init__(self, config): # Extend task_config with local_dict self.task_config = AttrDict(**self.task_config, **local_dict) + # Create JEDI object + self.jedi = Jedi(self.task_config, yaml_name) + @logit(logger) - def initialize(self: Analysis) -> None: + def initialize_jedi(self): + """Initialize JEDI application + + This method will initialize a JEDI application used in the global atm analysis. + This includes: + - generating and saving JEDI YAML config + - linking the JEDI executable + + Parameters + ---------- + None + + Returns + ---------- + None + """ + + # generate the JEDI config for this application from its templates + logger.info(f"Generating JEDI YAML config: {self.jedi.yaml}") + self.jedi.set_config(self.task_config) + logger.debug(f"JEDI config:\n{pformat(self.jedi.config)}") + + # save JEDI config to YAML file + logger.debug(f"Writing JEDI YAML config to: {self.jedi.yaml}") + save_as_yaml(self.jedi.config, self.jedi.yaml) + + # link JEDI executable + logger.info(f"Linking JEDI executable {self.task_config.JEDIEXE} to {self.jedi.exe}") + self.jedi.link_exe(self.task_config) + + @logit(logger) + def initialize_analysis(self) -> None: """Initialize a global atm analysis - This method will initialize a global atm analysis using JEDI. + This method will initialize a global atm analysis. 
This includes: + - staging observation files + - staging bias correction files - staging CRTM fix files - staging FV3-JEDI fix files - staging B error files - staging model backgrounds - - generating a YAML file for the JEDI executable - creating output directories + + Parameters + ---------- + None + + Returns + ---------- + None """ super().initialize() + # stage observations + logger.info(f"Staging list of observation files generated from JEDI config") + obs_dict = self.jedi.get_obs_dict(self.task_config) + FileHandler(obs_dict).sync() + logger.debug(f"Observation files:\n{pformat(obs_dict)}") + + # stage bias corrections + logger.info(f"Staging list of bias correction files generated from JEDI config") + bias_dict = self.jedi.get_bias_dict(self.task_config) + FileHandler(bias_dict).sync() + logger.debug(f"Bias correction files:\n{pformat(bias_dict)}") + # stage CRTM fix files logger.info(f"Staging CRTM fix files from {self.task_config.CRTM_FIX_YAML}") - crtm_fix_list = parse_j2yaml(self.task_config.CRTM_FIX_YAML, self.task_config) - FileHandler(crtm_fix_list).sync() + crtm_fix_dict = parse_j2yaml(self.task_config.CRTM_FIX_YAML, self.task_config) + FileHandler(crtm_fix_dict).sync() + logger.debug(f"CRTM fix files:\n{pformat(crtm_fix_dict)}") # stage fix files logger.info(f"Staging JEDI fix files from {self.task_config.JEDI_FIX_YAML}") - jedi_fix_list = parse_j2yaml(self.task_config.JEDI_FIX_YAML, self.task_config) - FileHandler(jedi_fix_list).sync() + jedi_fix_dict = parse_j2yaml(self.task_config.JEDI_FIX_YAML, self.task_config) + FileHandler(jedi_fix_dict).sync() + logger.debug(f"JEDI fix files:\n{pformat(jedi_fix_dict)}") # stage static background error files, otherwise it will assume ID matrix logger.info(f"Stage files for STATICB_TYPE {self.task_config.STATICB_TYPE}") @@ -90,22 +162,20 @@ def initialize(self: Analysis) -> None: else: berror_staging_dict = {} FileHandler(berror_staging_dict).sync() + logger.debug(f"Background error 
files:\n{pformat(berror_staging_dict)}") # stage ensemble files for use in hybrid background error if self.task_config.DOHYBVAR: logger.debug(f"Stage ensemble files for DOHYBVAR {self.task_config.DOHYBVAR}") fv3ens_staging_dict = parse_j2yaml(self.task_config.FV3ENS_STAGING_YAML, self.task_config) FileHandler(fv3ens_staging_dict).sync() + logger.debug(f"Ensemble files:\n{pformat(fv3ens_staging_dict)}") # stage backgrounds logger.info(f"Staging background files from {self.task_config.VAR_BKG_STAGING_YAML}") bkg_staging_dict = parse_j2yaml(self.task_config.VAR_BKG_STAGING_YAML, self.task_config) FileHandler(bkg_staging_dict).sync() - - # generate variational YAML file - logger.debug(f"Generate variational YAML file: {self.task_config.jedi_yaml}") - save_as_yaml(self.task_config.jedi_config, self.task_config.jedi_yaml) - logger.info(f"Wrote variational YAML to: {self.task_config.jedi_yaml}") + logger.debug(f"Background files:\n{pformat(bkg_staging_dict)}") # need output dir for diags and anl logger.debug("Create empty output [anl, diags] directories to receive output from executable") @@ -116,54 +186,32 @@ def initialize(self: Analysis) -> None: FileHandler({'mkdir': newdirs}).sync() @logit(logger) - def variational(self: Analysis) -> None: - - chdir(self.task_config.DATA) - - exec_cmd = Executable(self.task_config.APRUN_ATMANLVAR) - exec_name = os.path.join(self.task_config.DATA, 'gdas.x') - exec_cmd.add_default_arg(exec_name) - exec_cmd.add_default_arg('fv3jedi') - exec_cmd.add_default_arg('variational') - exec_cmd.add_default_arg(self.task_config.jedi_yaml) + def execute(self, aprun_cmd: str, jedi_args: Optional[str] = None) -> None: + """Run JEDI executable - try: - logger.debug(f"Executing {exec_cmd}") - exec_cmd() - except OSError: - raise OSError(f"Failed to execute {exec_cmd}") - except Exception: - raise WorkflowException(f"An error occured during execution of {exec_cmd}") + This method will run JEDI executables for the global atm analysis - pass + Parameters 
+ ---------- + aprun_cmd : str + Run command for JEDI application on HPC system + jedi_args : List + List of additional optional arguments for JEDI application - @logit(logger) - def init_fv3_increment(self: Analysis) -> None: - # Setup JEDI YAML file - self.task_config.jedi_yaml = os.path.join(self.task_config.DATA, - f"{self.task_config.JCB_ALGO}.yaml") - save_as_yaml(self.get_jedi_config(self.task_config.JCB_ALGO), self.task_config.jedi_yaml) + Returns + ---------- + None + """ - # Link JEDI executable to run directory - self.task_config.jedi_exe = self.link_jediexe() + if jedi_args: + logger.info(f"Executing {self.jedi.exe} {' '.join(jedi_args)} {self.jedi.yaml}") + else: + logger.info(f"Executing {self.jedi.exe} {self.jedi.yaml}") - @logit(logger) - def fv3_increment(self: Analysis) -> None: - # Run executable - exec_cmd = Executable(self.task_config.APRUN_ATMANLFV3INC) - exec_cmd.add_default_arg(self.task_config.jedi_exe) - exec_cmd.add_default_arg(self.task_config.jedi_yaml) - - try: - logger.debug(f"Executing {exec_cmd}") - exec_cmd() - except OSError: - raise OSError(f"Failed to execute {exec_cmd}") - except Exception: - raise WorkflowException(f"An error occured during execution of {exec_cmd}") + self.jedi.execute(self.task_config, aprun_cmd, jedi_args) @logit(logger) - def finalize(self: Analysis) -> None: + def finalize(self) -> None: """Finalize a global atm analysis This method will finalize a global atm analysis using JEDI. 
@@ -171,9 +219,16 @@ def finalize(self: Analysis) -> None: - tar output diag files and place in ROTDIR - copy the generated YAML file from initialize to the ROTDIR - copy the updated bias correction files to ROTDIR - - write UFS model readable atm incrment file + Parameters + ---------- + None + + Returns + ---------- + None """ + # ---- tar up diags # path of output tar statfile atmstat = os.path.join(self.task_config.COM_ATMOS_ANALYSIS, f"{self.task_config.APREFIX}atmstat") @@ -196,16 +251,19 @@ def finalize(self: Analysis) -> None: diaggzip = f"{diagfile}.gz" archive.add(diaggzip, arcname=os.path.basename(diaggzip)) + # get list of yamls to copy to ROTDIR + yamls = glob.glob(os.path.join(self.task_config.DATA, '*atm*yaml')) + # copy full YAML from executable to ROTDIR - logger.info(f"Copying {self.task_config.jedi_yaml} to {self.task_config.COM_ATMOS_ANALYSIS}") - src = os.path.join(self.task_config.DATA, f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.atmvar.yaml") - dest = os.path.join(self.task_config.COM_ATMOS_ANALYSIS, f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.atmvar.yaml") - logger.debug(f"Copying {src} to {dest}") - yaml_copy = { - 'mkdir': [self.task_config.COM_ATMOS_ANALYSIS], - 'copy': [[src, dest]] - } - FileHandler(yaml_copy).sync() + for src in yamls: + yaml_base = os.path.splitext(os.path.basename(src))[0] + dest_yaml_name = f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.{yaml_base}.yaml" + dest = os.path.join(self.task_config.COM_ATMOS_ANALYSIS, dest_yaml_name) + logger.debug(f"Copying {src} to {dest}") + yaml_copy = { + 'copy': [[src, dest]] + } + FileHandler(yaml_copy).sync() # copy bias correction files to ROTDIR logger.info("Copy bias correction files from DATA/ to COM/") diff --git a/ush/python/pygfs/task/atmens_analysis.py b/ush/python/pygfs/task/atmens_analysis.py index 2e51f82d59..55e72702b1 100644 --- a/ush/python/pygfs/task/atmens_analysis.py +++ b/ush/python/pygfs/task/atmens_analysis.py @@ -5,34 +5,52 @@ import 
gzip import tarfile from logging import getLogger -from typing import Dict, List +from pprint import pformat +from typing import Optional, Dict, Any from wxflow import (AttrDict, FileHandler, add_to_datetime, to_fv3time, to_timedelta, to_YMDH, to_YMD, chdir, + Task, parse_j2yaml, save_as_yaml, logit, Executable, WorkflowException, Template, TemplateConstants) -from pygfs.task.analysis import Analysis -from jcb import render +from pygfs.jedi import Jedi logger = getLogger(__name__.split('.')[-1]) -class AtmEnsAnalysis(Analysis): +class AtmEnsAnalysis(Task): """ - Class for global atmens analysis tasks + Class for JEDI-based global atmens analysis tasks """ @logit(logger, name="AtmEnsAnalysis") - def __init__(self, config): + def __init__(self, config: Dict[str, Any], yaml_name: Optional[str] = None): + """Constructor global atmens analysis task + + This method will construct a global atmens analysis task. + This includes: + - extending the task_config attribute AttrDict to include parameters required for this task + - instantiate the Jedi attribute object + + Parameters + ---------- + config: Dict + dictionary object containing task configuration + yaml_name: str, optional + name of YAML file for JEDI configuration + + Returns + ---------- + None + """ super().__init__(config) _res = int(self.task_config.CASE_ENS[1:]) _window_begin = add_to_datetime(self.task_config.current_cycle, -to_timedelta(f"{self.task_config.assim_freq}H") / 2) - _jedi_yaml = os.path.join(self.task_config.DATA, f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.atmens.yaml") # Create a local dictionary that is repeatedly used across this class local_dict = AttrDict( @@ -46,7 +64,6 @@ def __init__(self, config): 'OPREFIX': f"{self.task_config.EUPD_CYC}.t{self.task_config.cyc:02d}z.", 'APREFIX': f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.", 'GPREFIX': f"gdas.t{self.task_config.previous_cycle.hour:02d}z.", - 'jedi_yaml': _jedi_yaml, 'atm_obsdatain_path': f"./obs/", 
'atm_obsdataout_path': f"./diags/", 'BKG_TSTEP': "PT1H" # Placeholder for 4D applications @@ -56,21 +73,56 @@ def __init__(self, config): # Extend task_config with local_dict self.task_config = AttrDict(**self.task_config, **local_dict) + # Create JEDI object + self.jedi = Jedi(self.task_config, yaml_name) + + @logit(logger) + def initialize_jedi(self): + """Initialize JEDI application + + This method will initialize a JEDI application used in the global atmens analysis. + This includes: + - generating and saving JEDI YAML config + - linking the JEDI executable + + Parameters + ---------- + None + + Returns + ---------- + None + """ + + # get JEDI config and save to YAML file + logger.info(f"Generating JEDI config: {self.jedi.yaml}") + self.jedi.set_config(self.task_config) + logger.debug(f"JEDI config:\n{pformat(self.jedi.config)}") + + # save JEDI config to YAML file + logger.info(f"Writing JEDI config to YAML file: {self.jedi.yaml}") + save_as_yaml(self.jedi.config, self.jedi.yaml) + + # link JEDI-to-FV3 increment converter executable + logger.info(f"Linking JEDI executable {self.task_config.JEDIEXE} to {self.jedi.exe}") + self.jedi.link_exe(self.task_config) + @logit(logger) - def initialize(self: Analysis) -> None: + def initialize_analysis(self) -> None: """Initialize a global atmens analysis - This method will initialize a global atmens analysis using JEDI. + This method will initialize a global atmens analysis. 
This includes: + - staging observation files + - staging bias correction files - staging CRTM fix files - staging FV3-JEDI fix files - staging model backgrounds - - generating a YAML file for the JEDI executable - creating output directories Parameters ---------- - Analysis: parent class for GDAS task + None Returns ---------- @@ -78,26 +130,35 @@ def initialize(self: Analysis) -> None: """ super().initialize() + # stage observations + logger.info(f"Staging list of observation files generated from JEDI config") + obs_dict = self.jedi.get_obs_dict(self.task_config) + FileHandler(obs_dict).sync() + logger.debug(f"Observation files:\n{pformat(obs_dict)}") + + # stage bias corrections + logger.info(f"Staging list of bias correction files generated from JEDI config") + bias_dict = self.jedi.get_bias_dict(self.task_config) + FileHandler(bias_dict).sync() + logger.debug(f"Bias correction files:\n{pformat(bias_dict)}") + # stage CRTM fix files logger.info(f"Staging CRTM fix files from {self.task_config.CRTM_FIX_YAML}") - crtm_fix_list = parse_j2yaml(self.task_config.CRTM_FIX_YAML, self.task_config) - FileHandler(crtm_fix_list).sync() + crtm_fix_dict = parse_j2yaml(self.task_config.CRTM_FIX_YAML, self.task_config) + FileHandler(crtm_fix_dict).sync() + logger.debug(f"CRTM fix files:\n{pformat(crtm_fix_dict)}") # stage fix files logger.info(f"Staging JEDI fix files from {self.task_config.JEDI_FIX_YAML}") - jedi_fix_list = parse_j2yaml(self.task_config.JEDI_FIX_YAML, self.task_config) - FileHandler(jedi_fix_list).sync() + jedi_fix_dict = parse_j2yaml(self.task_config.JEDI_FIX_YAML, self.task_config) + FileHandler(jedi_fix_dict).sync() + logger.debug(f"JEDI fix files:\n{pformat(jedi_fix_dict)}") # stage backgrounds logger.info(f"Stage ensemble member background files") bkg_staging_dict = parse_j2yaml(self.task_config.LGETKF_BKG_STAGING_YAML, self.task_config) FileHandler(bkg_staging_dict).sync() - - # generate ensemble da YAML file - if not self.task_config.lobsdiag_forenkf: - 
logger.debug(f"Generate ensemble da YAML file: {self.task_config.jedi_yaml}") - save_as_yaml(self.task_config.jedi_config, self.task_config.jedi_yaml) - logger.info(f"Wrote ensemble da YAML to: {self.task_config.jedi_yaml}") + logger.debug(f"Ensemble member background files:\n{pformat(bkg_staging_dict)}") # need output dir for diags and anl logger.debug("Create empty output [anl, diags] directories to receive output from executable") @@ -108,187 +169,47 @@ def initialize(self: Analysis) -> None: FileHandler({'mkdir': newdirs}).sync() @logit(logger) - def observe(self: Analysis) -> None: - """Execute a global atmens analysis in observer mode + def execute(self, aprun_cmd: str, jedi_args: Optional[str] = None) -> None: + """Run JEDI executable - This method will execute a global atmens analysis in observer mode using JEDI. - This includes: - - changing to the run directory - - running the global atmens analysis executable in observer mode + This method will run JEDI executables for the global atmens analysis Parameters ---------- - Analysis: parent class for GDAS task - + aprun_cmd : str + Run command for JEDI application on HPC system + jedi_args : List + List of additional optional arguments for JEDI application Returns ---------- None """ - chdir(self.task_config.DATA) - - exec_cmd = Executable(self.task_config.APRUN_ATMENSANLOBS) - exec_name = os.path.join(self.task_config.DATA, 'gdas.x') - - exec_cmd.add_default_arg(exec_name) - exec_cmd.add_default_arg('fv3jedi') - exec_cmd.add_default_arg('localensembleda') - exec_cmd.add_default_arg(self.task_config.jedi_yaml) - try: - logger.debug(f"Executing {exec_cmd}") - exec_cmd() - except OSError: - raise OSError(f"Failed to execute {exec_cmd}") - except Exception: - raise WorkflowException(f"An error occured during execution of {exec_cmd}") + if jedi_args: + logger.info(f"Executing {self.jedi.exe} {' '.join(jedi_args)} {self.jedi.yaml}") + else: + logger.info(f"Executing {self.jedi.exe} {self.jedi.yaml}") - pass + 
self.jedi.execute(self.task_config, aprun_cmd, jedi_args) @logit(logger) - def solve(self: Analysis) -> None: - """Execute a global atmens analysis in solver mode - - This method will execute a global atmens analysis in solver mode using JEDI. - This includes: - - changing to the run directory - - running the global atmens analysis executable in solver mode - - Parameters - ---------- - Analysis: parent class for GDAS task - - Returns - ---------- - None - """ - chdir(self.task_config.DATA) - - exec_cmd = Executable(self.task_config.APRUN_ATMENSANLSOL) - exec_name = os.path.join(self.task_config.DATA, 'gdas.x') - - exec_cmd.add_default_arg(exec_name) - exec_cmd.add_default_arg('fv3jedi') - exec_cmd.add_default_arg('localensembleda') - exec_cmd.add_default_arg(self.task_config.jedi_yaml) - - try: - logger.debug(f"Executing {exec_cmd}") - exec_cmd() - except OSError: - raise OSError(f"Failed to execute {exec_cmd}") - except Exception: - raise WorkflowException(f"An error occured during execution of {exec_cmd}") - - pass - - @logit(logger) - def letkf(self: Analysis) -> None: - """Execute a global atmens analysis - - This method will execute a global atmens analysis using JEDI. 
- This includes: - - changing to the run directory - - running the global atmens analysis executable - - Parameters - ---------- - Analysis: parent class for GDAS task - - Returns - ---------- - None - """ - chdir(self.task_config.DATA) - - exec_cmd = Executable(self.task_config.APRUN_ATMENSANLLETKF) - exec_name = os.path.join(self.task_config.DATA, 'gdas.x') - - exec_cmd.add_default_arg(exec_name) - exec_cmd.add_default_arg('fv3jedi') - exec_cmd.add_default_arg('localensembleda') - exec_cmd.add_default_arg(self.task_config.jedi_yaml) - - try: - logger.debug(f"Executing {exec_cmd}") - exec_cmd() - except OSError: - raise OSError(f"Failed to execute {exec_cmd}") - except Exception: - raise WorkflowException(f"An error occured during execution of {exec_cmd}") - - pass - - @logit(logger) - def init_observer(self: Analysis) -> None: - # Setup JEDI YAML file - jcb_config = parse_j2yaml(self.task_config.JCB_BASE_YAML, self.task_config) - jcb_algo_config = parse_j2yaml(self.task_config.JCB_ALGO_YAML, self.task_config) - jcb_config.update(jcb_algo_config) - jedi_config = render(jcb_config) - - self.task_config.jedi_yaml = os.path.join(self.task_config.DATA, f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.atmens_observer.yaml") - - logger.debug(f"Generate ensemble da observer YAML file: {self.task_config.jedi_yaml}") - save_as_yaml(jedi_config, self.task_config.jedi_yaml) - logger.info(f"Wrote ensemble da observer YAML to: {self.task_config.jedi_yaml}") - - @logit(logger) - def init_solver(self: Analysis) -> None: - # Setup JEDI YAML file - jcb_config = parse_j2yaml(self.task_config.JCB_BASE_YAML, self.task_config) - jcb_algo_config = parse_j2yaml(self.task_config.JCB_ALGO_YAML, self.task_config) - jcb_config.update(jcb_algo_config) - jedi_config = render(jcb_config) - - self.task_config.jedi_yaml = os.path.join(self.task_config.DATA, f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.atmens_solver.yaml") - - logger.debug(f"Generate ensemble da solver YAML file: 
{self.task_config.jedi_yaml}") - save_as_yaml(jedi_config, self.task_config.jedi_yaml) - logger.info(f"Wrote ensemble da solver YAML to: {self.task_config.jedi_yaml}") - - @logit(logger) - def init_fv3_increment(self: Analysis) -> None: - # Setup JEDI YAML file - self.task_config.jedi_yaml = os.path.join(self.task_config.DATA, - f"{self.task_config.JCB_ALGO}.yaml") - save_as_yaml(self.get_jedi_config(self.task_config.JCB_ALGO), self.task_config.jedi_yaml) - - # Link JEDI executable to run directory - self.task_config.jedi_exe = self.link_jediexe() - - @logit(logger) - def fv3_increment(self: Analysis) -> None: - # Run executable - exec_cmd = Executable(self.task_config.APRUN_ATMENSANLFV3INC) - exec_cmd.add_default_arg(self.task_config.jedi_exe) - exec_cmd.add_default_arg(self.task_config.jedi_yaml) - - try: - logger.debug(f"Executing {exec_cmd}") - exec_cmd() - except OSError: - raise OSError(f"Failed to execute {exec_cmd}") - except Exception: - raise WorkflowException(f"An error occured during execution of {exec_cmd}") - - @logit(logger) - def finalize(self: Analysis) -> None: + def finalize(self) -> None: """Finalize a global atmens analysis This method will finalize a global atmens analysis using JEDI. 
This includes: - tar output diag files and place in ROTDIR - copy the generated YAML file from initialize to the ROTDIR - - write UFS model readable atm incrment file Parameters ---------- - Analysis: parent class for GDAS task + None Returns ---------- None """ + # ---- tar up diags # path of output tar statfile atmensstat = os.path.join(self.task_config.COM_ATMOS_ANALYSIS_ENS, f"{self.task_config.APREFIX}atmensstat") @@ -317,7 +238,9 @@ def finalize(self: Analysis) -> None: # copy full YAML from executable to ROTDIR for src in yamls: logger.info(f"Copying {src} to {self.task_config.COM_ATMOS_ANALYSIS_ENS}") - dest = os.path.join(self.task_config.COM_ATMOS_ANALYSIS_ENS, os.path.basename(src)) + yaml_base = os.path.splitext(os.path.basename(src))[0] + dest_yaml_name = f"{self.task_config.RUN}.t{self.task_config.cyc:02d}z.{yaml_base}.yaml" + dest = os.path.join(self.task_config.COM_ATMOS_ANALYSIS_ENS, dest_yaml_name) logger.debug(f"Copying {src} to {dest}") yaml_copy = { 'copy': [[src, dest]] @@ -337,6 +260,7 @@ def finalize(self: Analysis) -> None: logger.info("Copy UFS model readable atm increment file") cdate = to_fv3time(self.task_config.current_cycle) cdate_inc = cdate.replace('.', '_') + # loop over ensemble members for imem in range(1, self.task_config.NMEM_ENS + 1): memchar = f"mem{imem:03d}"