From 724a0362fb309703da90686fc06127cf7ee910ea Mon Sep 17 00:00:00 2001 From: sambles Date: Fri, 24 May 2024 10:17:54 +0100 Subject: [PATCH] Fix running V2 workers with custom OED specification files (#1051) * fix "oed_schema_info" getting removed from task params * Pass oed_custom_spec into OedExposure load * pep * Updated Package Requirements: pymysql==1.1.1 --------- Co-authored-by: awsbuild --- requirements-server.txt | 2 +- requirements-worker.txt | 2 +- requirements.txt | 2 +- .../distributed_tasks.py | 19 +++++++++++++------ src/model_execution_worker/tasks.py | 7 ++++--- src/model_execution_worker/utils.py | 14 ++++++++++++++ 6 files changed, 34 insertions(+), 12 deletions(-) diff --git a/requirements-server.txt b/requirements-server.txt index 3aee6238d..454e1939c 100644 --- a/requirements-server.txt +++ b/requirements-server.txt @@ -246,7 +246,7 @@ pycparser==2.21 # via cffi pyjwt==2.8.0 # via djangorestframework-simplejwt -pymysql==1.1.0 +pymysql==1.1.1 # via -r requirements-server.in pyopenssl==24.0.0 # via diff --git a/requirements-worker.txt b/requirements-worker.txt index 7508954ea..dd3e24709 100644 --- a/requirements-worker.txt +++ b/requirements-worker.txt @@ -236,7 +236,7 @@ pyjwt[crypto]==2.8.0 # via # msal # pyjwt -pymysql==1.1.0 +pymysql==1.1.1 # via -r requirements-worker.in pyproj==3.6.1 # via geopandas diff --git a/requirements.txt b/requirements.txt index 37e584683..56a254709 100644 --- a/requirements.txt +++ b/requirements.txt @@ -464,7 +464,7 @@ pyjwt[crypto]==2.8.0 # via # djangorestframework-simplejwt # msal -pymysql==1.1.0 +pymysql==1.1.1 # via # -r ./requirements-server.in # -r ./requirements-worker.in diff --git a/src/model_execution_worker/distributed_tasks.py b/src/model_execution_worker/distributed_tasks.py index 831f59c5a..fd12dfc75 100644 --- a/src/model_execution_worker/distributed_tasks.py +++ b/src/model_execution_worker/distributed_tasks.py @@ -36,6 +36,7 @@ get_worker_versions, merge_dirs, prepare_complex_model_file_inputs, + config_strip_default_exposure, ) @@ -87,7 +88,7 @@ def notify_subtask_status(analysis_id, initiator_id, task_slug, subtask_status, ).delay() -def load_location_data(loc_filepath): +def load_location_data(loc_filepath, oed_schema_info=None): """ Returns location file as DataFrame Returns a DataFrame of Loaction data with 'loc_id' row assgined @@ -102,7 +103,10 @@ def load_location_data(loc_filepath): from oasislmf.utils.data import prepare_location_df from ods_tools.oed.exposure import OedExposure - exposure = OedExposure(location=pathlib.Path(os.path.abspath(loc_filepath))) + exposure = OedExposure( + location=pathlib.Path(os.path.abspath(loc_filepath)), + oed_schema_info=oed_schema_info, + ) exposure.location.dataframe = prepare_location_df(exposure.location.dataframe) return exposure.location.dataframe @@ -468,8 +472,8 @@ def prepare_input_generation_params( model_id = settings.get('worker', 'model_id') config_path = get_oasislmf_config_path(settings, model_id) - config = get_json(config_path) - lookup_params = {**{k: v for k, v in config.items() if not k.startswith('oed_')}, **params} + config = config_strip_default_exposure(get_json(config_path)) + lookup_params = {**config, **params} from oasislmf.manager import OasisManager gen_files_params = OasisManager()._params_generate_files(**lookup_params) @@ -571,7 +575,10 @@ def prepare_keys_file_chunk( output_directory=chunk_target_dir, ) - location_df = load_location_data(params['oed_location_csv']) + location_df = load_location_data( + loc_filepath=params['oed_location_csv'], + oed_schema_info=params.get('oed_schema_info', None) + ) location_df = np.array_split(location_df, num_chunks)[chunk_idx] location_df.reset_index(drop=True, inplace=True) @@ -874,7 +881,7 @@ def prepare_losses_generation_params( model_id = settings.get('worker', 'model_id') config_path = get_oasislmf_config_path(settings, model_id) - config = get_json(config_path) + config = config_strip_default_exposure(get_json(config_path)) run_params = {**config, **params} from oasislmf.manager import OasisManager diff --git a/src/model_execution_worker/tasks.py b/src/model_execution_worker/tasks.py index dacd861df..eb423a573 100755 --- a/src/model_execution_worker/tasks.py +++ b/src/model_execution_worker/tasks.py @@ -34,6 +34,7 @@ get_model_settings, get_worker_versions, prepare_complex_model_file_inputs, + config_strip_default_exposure, ) ''' @@ -297,7 +298,7 @@ def start_analysis(analysis_settings, input_location, complex_data_files=None, * # oasislmf.json config_path = get_oasislmf_config_path(settings) - config = get_json(config_path) + config = config_strip_default_exposure(get_json(config_path)) # model settings model_settings_fp = settings.get('worker', 'MODEL_SETTINGS_FILE', fallback='') @@ -466,8 +467,8 @@ def generate_input(self, task_params['user_data_dir'] = input_data_dir config_path = get_oasislmf_config_path(settings) - config = get_json(config_path) - lookup_params = {**{k: v for k, v in config.items() if not k.startswith('oed_')}, **task_params} + config = config_strip_default_exposure(get_json(config_path)) + lookup_params = {**config, **task_params} gen_files_params = OasisManager()._params_generate_files(**lookup_params) pre_hook_params = OasisManager()._params_exposure_pre_analysis(**lookup_params) diff --git a/src/model_execution_worker/utils.py b/src/model_execution_worker/utils.py index fadefa673..c6fbef9cf 100644 --- a/src/model_execution_worker/utils.py +++ b/src/model_execution_worker/utils.py @@ -9,6 +9,7 @@ 'InvalidInputsException', 'MissingModelDataException', 'prepare_complex_model_file_inputs', + 'config_strip_default_exposure', ] import logging @@ -88,6 +89,19 @@ def paths_to_absolute_paths(dictionary, config_path=''): return params +def config_strip_default_exposure(config): + """ Safeguard to make sure any 'oasislmf.json' files have platform default stripped out + """ + exclude_list = [ + 'oed_location_csv', + 'oed_accounts_csv', + 'oed_info_csv', + 'oed_scope_csv', + 'analysis_settings_json' + ] + return {k: v for k, v in config.items() if k not in exclude_list} + + class TemporaryDir(object): """Context manager for mkdtemp() with option to persist"""