From 23a2f71b4e17bc2f6d8d5c5abc088c7354defcec Mon Sep 17 00:00:00 2001
From: Julius Busecke
Date: Mon, 24 Jun 2024 10:40:33 -0400
Subject: [PATCH 1/2] Fix CI errors

---
 ci/environment.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ci/environment.yaml b/ci/environment.yaml
index 289d64f..75a6a79 100644
--- a/ci/environment.yaml
+++ b/ci/environment.yaml
@@ -10,6 +10,7 @@ dependencies:
   - fsspec
   - gcsfs
   - intake-esm
+  - intake-xarray
   - numpy
   - pip
   - pytest

From 5e9f92f6e2d3d7a780746a67aae568e5b85dd5df Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 24 Jun 2024 14:42:00 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .github/workflows/build_image.yaml | 2 +-
 Dockerfile | 4 +-
 README.md | 2 +-
 .../new-workspace.jupyterlab-workspace | 2 +-
 notebooks/jbusecke/testing/xesmf_test.py | 1 -
 .../jbusecke/workflow/cm26_pipeline-debug.py | 15 +-
 notebooks/jbusecke/workflow/cm26_pipeline.py | 10 +-
 notebooks/jbusecke/workflow/cm26_utils.py | 24 +--
 notebooks/jbusecke/workflow/cm26_utils_old.py | 36 ++--
 .../jbusecke/workflow/debugging_script.py | 8 +-
 notebooks/jbusecke/workflow/utils.py | 22 +-
 pyproject.toml | 2 +-
 scale_aware_air_sea/old/cesm_utils.py | 85 ++++----
 scale_aware_air_sea/old/cm26_utils.py | 19 +-
 scale_aware_air_sea/parameters.py | 155 +++++++-------
 scale_aware_air_sea/plotting.py | 25 ++-
 scale_aware_air_sea/stages.py | 175 ++++++++--------
 scale_aware_air_sea/stages_tests.py | 132 +++++++-----
 scale_aware_air_sea/utils.py | 105 +++++-----
 tests/test_utils.py | 195 ++++++++++--------
 20 files changed, 545 insertions(+), 474 deletions(-)

diff --git a/.github/workflows/build_image.yaml b/.github/workflows/build_image.yaml
index c1da60a..53e33ed 100644
--- a/.github/workflows/build_image.yaml
+++ b/.github/workflows/build_image.yaml
@@ -6,7 +6,7 @@ on:
     - main
     paths: # only run on changes to the Dockerfile
    - 'Dockerfile'
-  
+
 
 jobs:
   build-and-push:

diff --git a/Dockerfile b/Dockerfile
index f5d28ed..8656c3b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,7 +1,7 @@
-FROM quay.io/pangeo/pangeo-notebook:2023.10.24 
+FROM quay.io/pangeo/pangeo-notebook:2023.10.24
 LABEL maintainer="Julius Busecked"
 LABEL repo="https://github.com/ocean-transport/scale-aware-air-sea"
 
-RUN mamba install -n=notebook aerobulk-python -y 
+RUN mamba install -n=notebook aerobulk-python -y
 
 RUN pip install coiled

diff --git a/README.md b/README.md
index c0cdccf..fb998e9 100644
--- a/README.md
+++ b/README.md
@@ -17,4 +17,4 @@ pip install .
Follow the above instructions but install the package via ``` pip install -e ".[dev]" -``` \ No newline at end of file +``` diff --git a/notebooks/jbusecke/new-workspace.jupyterlab-workspace b/notebooks/jbusecke/new-workspace.jupyterlab-workspace index 83e464e..8a7032f 100644 --- a/notebooks/jbusecke/new-workspace.jupyterlab-workspace +++ b/notebooks/jbusecke/new-workspace.jupyterlab-workspace @@ -1 +1 @@ -{"data":{"layout-restorer:data":{"main":{"dock":{"type":"split-area","orientation":"horizontal","sizes":[0.5,0.5],"children":[{"type":"tab-area","currentIndex":2,"widgets":["notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/aerobulk-python_performance.ipynb","terminal:1","notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/cm26_pipeline.ipynb","notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/reproduce_cm26_flux.ipynb","notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/testing/cm26_xesmf_hack.ipynb"]},{"type":"tab-area","currentIndex":3,"widgets":["dask-dashboard-launcher:/individual-progress","dask-dashboard-launcher:/individual-workers","dask-dashboard-launcher:/individual-task-stream","dask-dashboard-launcher:/individual-workers-memory"]}]},"current":"dask-dashboard-launcher:/individual-workers-memory"},"down":{"size":0,"widgets":[]},"left":{"collapsed":true,"widgets":["filebrowser","running-sessions","dask-dashboard-launcher","git-sessions","@jupyterlab/toc:plugin","extensionmanager.main-view"]},"right":{"collapsed":true,"widgets":["jp-property-inspector","debugger-sidebar"]},"relativeSizes":[0,1,0]},"file-browser-filebrowser:cwd":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke"},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/aerobulk-python_performance.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/aerobulk-python_performance.ipynb","factory":"Notebook"}},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/cm26_pipeline.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/cm26_pipeline.ipynb","factory":"Notebook"}},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/reproduce_cm26_flux.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/reproduce_cm26_flux.ipynb","factory":"Notebook"}},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/testing/cm26_xesmf_hack.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/testing/cm26_xesmf_hack.ipynb","factory":"Notebook"}},"terminal:1":{"data":{"name":"1"}},"dask-dashboard-launcher":{"url":"https://us-central1-b.gcp.pangeo.io/services/dask-gateway/clusters/prod.1e69eadfa71b4df0842dbaaeb5c7b01d/","cluster":""},"dask-dashboard-launcher:/individual-progress":{"data":{"route":"/individual-progress","label":"Progress","key":"Progress"}},"dask-dashboard-launcher:/individual-workers":{"data":{"route":"/individual-workers","label":"Workers","key":"Workers"}},"dask-dashboard-launcher:/individual-task-stream":{"data":{"route":"/individual-task-stream","label":"Task Stream","key":"Task Stream"}},"dask-dashboard-launcher:/individual-workers-memory":{"data":{"route":"/individual-workers-memory","label":"Workers Memory","key":"Workers Memory"}}},"metadata":{"id":"new-workspace","last_modified":"2022-06-23T17:40:03.102459+00:00","created":"2022-06-23T17:40:03.102459+00:00"}} \ No newline at end of file 
+{"data":{"layout-restorer:data":{"main":{"dock":{"type":"split-area","orientation":"horizontal","sizes":[0.5,0.5],"children":[{"type":"tab-area","currentIndex":2,"widgets":["notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/aerobulk-python_performance.ipynb","terminal:1","notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/cm26_pipeline.ipynb","notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/reproduce_cm26_flux.ipynb","notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/testing/cm26_xesmf_hack.ipynb"]},{"type":"tab-area","currentIndex":3,"widgets":["dask-dashboard-launcher:/individual-progress","dask-dashboard-launcher:/individual-workers","dask-dashboard-launcher:/individual-task-stream","dask-dashboard-launcher:/individual-workers-memory"]}]},"current":"dask-dashboard-launcher:/individual-workers-memory"},"down":{"size":0,"widgets":[]},"left":{"collapsed":true,"widgets":["filebrowser","running-sessions","dask-dashboard-launcher","git-sessions","@jupyterlab/toc:plugin","extensionmanager.main-view"]},"right":{"collapsed":true,"widgets":["jp-property-inspector","debugger-sidebar"]},"relativeSizes":[0,1,0]},"file-browser-filebrowser:cwd":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke"},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/aerobulk-python_performance.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/aerobulk-python_performance.ipynb","factory":"Notebook"}},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/cm26_pipeline.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/cm26_pipeline.ipynb","factory":"Notebook"}},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/reproduce_cm26_flux.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/workflow/reproduce_cm26_flux.ipynb","factory":"Notebook"}},"notebook:1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/testing/cm26_xesmf_hack.ipynb":{"data":{"path":"1_PROJECTS/scale-aware-air-sea/notebooks/jbusecke/testing/cm26_xesmf_hack.ipynb","factory":"Notebook"}},"terminal:1":{"data":{"name":"1"}},"dask-dashboard-launcher":{"url":"https://us-central1-b.gcp.pangeo.io/services/dask-gateway/clusters/prod.1e69eadfa71b4df0842dbaaeb5c7b01d/","cluster":""},"dask-dashboard-launcher:/individual-progress":{"data":{"route":"/individual-progress","label":"Progress","key":"Progress"}},"dask-dashboard-launcher:/individual-workers":{"data":{"route":"/individual-workers","label":"Workers","key":"Workers"}},"dask-dashboard-launcher:/individual-task-stream":{"data":{"route":"/individual-task-stream","label":"Task Stream","key":"Task Stream"}},"dask-dashboard-launcher:/individual-workers-memory":{"data":{"route":"/individual-workers-memory","label":"Workers Memory","key":"Workers Memory"}}},"metadata":{"id":"new-workspace","last_modified":"2022-06-23T17:40:03.102459+00:00","created":"2022-06-23T17:40:03.102459+00:00"}} diff --git a/notebooks/jbusecke/testing/xesmf_test.py b/notebooks/jbusecke/testing/xesmf_test.py index 4a5c975..776807f 100644 --- a/notebooks/jbusecke/testing/xesmf_test.py +++ b/notebooks/jbusecke/testing/xesmf_test.py @@ -12,4 +12,3 @@ ds_atmos = xr.open_zarr('gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr', **kwargs) regridder = xe.Regridder(ds_atmos.olr.isel(time=0), ds.surface_temp.isel(time=0), 'bilinear', periodic=True) # all the atmos data is on the cell center AFAIK - diff --git a/notebooks/jbusecke/workflow/cm26_pipeline-debug.py 
b/notebooks/jbusecke/workflow/cm26_pipeline-debug.py index b05ec72..58a165d 100644 --- a/notebooks/jbusecke/workflow/cm26_pipeline-debug.py +++ b/notebooks/jbusecke/workflow/cm26_pipeline-debug.py @@ -2,7 +2,7 @@ # coding: utf-8 # # Debug issues with exceeding wind stress -# +# # ## Tasks # - Find timestep(s) that exhibit the behavior # - Confirm that this behavior only affects certain algorithms @@ -31,7 +31,7 @@ from dask.diagnostics import ProgressBar from cm26_utils import write_split_zarr, noskin_ds_wrapper, load_and_merge_cm26 -# 👇 replace with your key +# 👇 replace with your key with open('/home/jovyan/keys/pangeo-forge-ocean-transport-4967-347e2048c5a1.json') as token_file: token = json.load(token_file) @@ -63,7 +63,7 @@ # I did `jupyter nbconvert --to python cm26_pipeline-debug.ipynb` # and then I get the error about the wind stress. -# I have executed this with all algos and I get crashes for: +# I have executed this with all algos and I get crashes for: # 'coare3p6' # 'andreas' # 'coare3p0' @@ -87,7 +87,7 @@ # ## Investigate the max windstress values we are getting with the working algos -# +# # Is there a correlation betweewn max wind speeds and stresses? Yeah definitely! # In[ ]: @@ -118,8 +118,8 @@ # ## Ok can we actually get around this and get some results at all? # If not we need to raise the tau cut of in aerobulk. -# -# My simple approach right here is to set every wind value larger than `threshold` to zero. This is not a feasible solution for our processing, but I just want to see how low we have to go to get all algos to go through! +# +# My simple approach right here is to set every wind value larger than `threshold` to zero. This is not a feasible solution for our processing, but I just want to see how low we have to go to get all algos to go through! # In[ ]: @@ -133,7 +133,7 @@ mask = ds_masked.wind>threshold ds_masked['u_ref'] = ds_masked['u_ref'].where(mask, 0) ds_masked['v_ref'] = ds_masked['v_ref'].where(mask, 0) - + break ds_out = noskin_ds_wrapper(ds_merged, algo=algo, input_range_check=False) with ProgressBar(): @@ -143,4 +143,3 @@ stress_max = stress.max(['xt_ocean', 'yt_ocean']).assign_coords(algo=algo) print(stress_max) datasets.append(stress_max) - diff --git a/notebooks/jbusecke/workflow/cm26_pipeline.py b/notebooks/jbusecke/workflow/cm26_pipeline.py index d299b56..86f2645 100644 --- a/notebooks/jbusecke/workflow/cm26_pipeline.py +++ b/notebooks/jbusecke/workflow/cm26_pipeline.py @@ -18,7 +18,7 @@ from dask.diagnostics import ProgressBar from cm26_utils import write_split_zarr, noskin_ds_wrapper -# 👇 replace with your key +# 👇 replace with your key with open('/home/jovyan/keys/pangeo-forge-ocean-transport-4967-347e2048c5a1.json') as token_file: token = json.load(token_file) fs = gcsfs.GCSFileSystem(token=token) @@ -87,7 +87,7 @@ ############################# wind cutting ############## # ok so lets not add nans into fields like above. 
Instead, lets see in which timesteps this actually occurs and for noe completely ignore these timesteps -# This is not ideal in the long run, but maybe at least gives us a way to output +# This is not ideal in the long run, but maybe at least gives us a way to output # ds_cut = ds_merged.isel(time=slice(0,500)) # wind = ds_cut.wind @@ -96,7 +96,7 @@ threshold = 30 with ProgressBar(): strong_wind_cells = (wind > threshold).sum(['xt_ocean','yt_ocean']).load() - + strong_wind_index = strong_wind_cells > 0 # double check that these events are still rare in space and time @@ -121,6 +121,6 @@ if fs.exists(path) and overwrite: # # # delete the mapper (only uncomment if you want to start from scratch!) print("DELETE existing store") - fs.rm(path, recursive=True) + fs.rm(path, recursive=True) -write_split_zarr(mapper, ds_out, split_interval=64) \ No newline at end of file +write_split_zarr(mapper, ds_out, split_interval=64) diff --git a/notebooks/jbusecke/workflow/cm26_utils.py b/notebooks/jbusecke/workflow/cm26_utils.py index 981eb98..096f0b6 100644 --- a/notebooks/jbusecke/workflow/cm26_utils.py +++ b/notebooks/jbusecke/workflow/cm26_utils.py @@ -14,7 +14,7 @@ # - Adjust units for aerobulk input # - Calculate relative wind components # """ -# kwargs = dict(consolidated=True, use_cftime=True, inline_array=inline_array, engine='zarr')#, +# kwargs = dict(consolidated=True, use_cftime=True, inline_array=inline_array, engine='zarr')#, # print('Load Data') # mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/control/surface") # # ds_ocean = xr.open_dataset(mapper, chunks='auto', **kwargs) @@ -22,14 +22,14 @@ # ds_ocean = xr.open_dataset(mapper, chunks={'time':3}, **kwargs) # # cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean/GFDL_CM2.6.yaml") # # ds_ocean = cat["GFDL_CM2_6_control_ocean_surface"].to_dask() - + # # ds_flux = cat["GFDL_CM2_6_control_ocean_boundary_flux"].to_dask() # mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/control/ocean_boundary") # # ds_flux = xr.open_dataset(mapper, chunks='auto', **kwargs) # # ds_flux = xr.open_dataset(mapper, chunks={'time':2}, **kwargs) # ds_flux = xr.open_dataset(mapper, chunks={'time':3}, **kwargs) - - + + # # xarray says not to do this # # ds_atmos = xr.open_zarr('gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr', chunks={'time':1}, **kwargs) # mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr") @@ -37,14 +37,14 @@ # # ds_atmos = xr.open_dataset(mapper, chunks={'time':2}, **kwargs) # # ds_atmos = xr.open_dataset(mapper, chunks={'time':3}, **kwargs) # ds_atmos = xr.open_dataset(mapper, chunks={'time':120}, **kwargs).chunk({'time':3}) - + # # # instead do this # # ds_atmos = ds_atmos.chunk({'time':1}) - + # mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/grid") # ds_oc_grid = xr.open_dataset(mapper, chunks={}, **kwargs) # # ds_oc_grid = cat["GFDL_CM2_6_grid"].to_dask() - + # print('Align in time') # # cut to same time # all_dims = set(list(ds_ocean.dims)+list(ds_atmos.dims)) @@ -78,7 +78,7 @@ # mapper = filesystem.get_mapper(path) # ds_regridder = xr.open_zarr(mapper).load() # regridder = xe.Regridder( -# ds_atmos.olr.to_dataset(name='dummy').isel(time=0).reset_coords(drop=True),# this is the same dumb problem I keep having with +# ds_atmos.olr.to_dataset(name='dummy').isel(time=0).reset_coords(drop=True),# this is the same dumb problem I keep having with # ds_ocean.surface_temp.to_dataset(name='dummy').isel(time=0).reset_coords(drop=True), # 'bilinear', 
# weights=ds_regridder, @@ -100,17 +100,17 @@ # # fix units for aerobulk # ds_merged['surface_temp'] = ds_merged['surface_temp'] + 273.15 # ds_merged['slp'] = ds_merged['slp'] * 100 # check this - + # print('Mask nans') # # atmos missing values are filled with 0s, which causes issues with the filtering # # Ideally this should be masked before the regridding, but xesmf fills with 0 again... # mask = ~np.isnan(ds_merged['surface_temp']) # for mask_var in ['slp', 't_ref', 'q_ref']: # ds_merged[mask_var] = ds_merged[mask_var].where(mask) - + # # Calculate relative wind # print('Calculate relative wind') # ds_merged['u_relative'] = ds_merged['u_ref'] - ds_merged['u_ocean'] # ds_merged['v_relative'] = ds_merged['v_ref'] - ds_merged['v_ocean'] - -# return ds_merged \ No newline at end of file + +# return ds_merged diff --git a/notebooks/jbusecke/workflow/cm26_utils_old.py b/notebooks/jbusecke/workflow/cm26_utils_old.py index fa95bbc..9eb1cab 100644 --- a/notebooks/jbusecke/workflow/cm26_utils_old.py +++ b/notebooks/jbusecke/workflow/cm26_utils_old.py @@ -11,15 +11,15 @@ def write_split_zarr(store, ds, split_dim='time', chunks=1, split_interval=180): This can be helpful to e.g. avoid problems with overly eager dask schedulers """ # Got my inspiration for this mostly here: https://github.com/pydata/xarray/issues/6069 - + # determine the variables and coordinates that depend on the split_dim other_dims = [di for di in ds.dims if di != split_dim] split_vars_coords = [va for va in ds.variables if split_dim in ds[va].dims and va not in ds.dims] non_split_vars_coords = [va for va in ds.variables if va not in split_vars_coords and va not in ds.dims] - + # Generate a stripped dataset that only contains variables/coordinates that do not depend on `split_dim` ds_stripped = ds.drop_vars(split_vars_coords+[split_dim]) - + # initialize the store without writing values print('initializing store') ds.to_zarr( @@ -28,15 +28,15 @@ def write_split_zarr(store, ds, split_dim='time', chunks=1, split_interval=180): encoding={split_dim:{"chunks":[chunks]}}, consolidated=True, # TODO: Not sure if this is proper. Might have to consolidate the whole thing as a last step? ) - + # Write out only the non-split variables/coordinates if len(non_split_vars_coords) > 0: print('Writing coordinates') ds_stripped.to_zarr(store, mode='a') #I guess a is 'add'. This is honestly not clear enough in the xarray docs. - # with `w` there are issues with the shape. - + # with `w` there are issues with the shape. + # TODO: what about the attrs? - + # Populate split chunks as regions n = len(ds[split_dim]) splits = list(range(0,n,split_interval)) @@ -44,29 +44,29 @@ def write_split_zarr(store, ds, split_dim='time', chunks=1, split_interval=180): # Make sure the last item in the list covers the full length of the time on our dataset if splits[-1] != n: splits = splits + [n] - + for ii in tqdm(range(len(splits)-1)): print(f'Writing split {ii}') # TODO: put some retry logic in here... start = splits[ii] stop = splits[ii+1] - + ds_write = ds.isel({split_dim:slice(start, stop)}) print(f'Start: {ds_write[split_dim][0].data}') print(f'Stop: {ds_write[split_dim][-1].data}') - + # strip everything except the values drop_vars = non_split_vars_coords+other_dims ds_write = ds_write.drop_vars(drop_vars) - + with ProgressBar(): ds_write.to_zarr(store, region={split_dim:slice(start, stop)}, mode='a')#why are the variables not instantiated in the init step - -# TODO: This is model agnostic and should live somewhere else? 
+ +# TODO: This is model agnostic and should live somewhere else? def noskin_ds_wrapper(ds_in, algo='ecmwf', **kwargs): ds_out = xr.Dataset() ds_in = ds_in.copy(deep=False) - + sst = ds_in.surface_temp + 273.15 t_zt = ds_in.t_ref hum_zt = ds_in.q_ref @@ -75,7 +75,7 @@ def noskin_ds_wrapper(ds_in, algo='ecmwf', **kwargs): slp = ds_in.slp * 100 # check this zu = 10 zt = 2 - + ql, qh, taux, tauy, evap = noskin( sst, t_zt, @@ -91,7 +91,7 @@ def noskin_ds_wrapper(ds_in, algo='ecmwf', **kwargs): ds_out['ql'] = ql ds_out['qh'] = qh ds_out['evap'] = evap - ds_out['taux'] = taux + ds_out['taux'] = taux ds_out['tauy'] = tauy return ds_out @@ -114,7 +114,7 @@ def load_and_merge_cm26(regridder_token): ) # instead do this ds_atmos = ds_atmos.chunk({'time':1}) - + fs = gcsfs.GCSFileSystem(token=regridder_token) path = 'ocean-transport-group/scale-aware-air-sea/regridding_weights/CM26_atmos2ocean.zarr' mapper = fs.get_mapper(path) @@ -138,4 +138,4 @@ def load_and_merge_cm26(regridder_token): ds_merged = ds_merged.transpose( 'xt_ocean', 'yt_ocean', 'time' ) - return ds_merged \ No newline at end of file + return ds_merged diff --git a/notebooks/jbusecke/workflow/debugging_script.py b/notebooks/jbusecke/workflow/debugging_script.py index ba5dc4b..3d67310 100644 --- a/notebooks/jbusecke/workflow/debugging_script.py +++ b/notebooks/jbusecke/workflow/debugging_script.py @@ -4,7 +4,7 @@ def noskin_ds_wrapper(ds_in): ds_out = xr.Dataset() ds_in = ds_in.copy(deep=False) - + sst = ds_in.surface_temp + 273.15 t_zt = ds_in.t_ref hum_zt = ds_in.q_ref @@ -13,7 +13,7 @@ def noskin_ds_wrapper(ds_in): slp = ds_in.slp * 100 # check this zu = 10 zt = 2 - + ql, qh, taux, tauy, evap = noskin( sst, t_zt, @@ -28,10 +28,10 @@ def noskin_ds_wrapper(ds_in): ds_out['ql'] = ql ds_out['qh'] = qh ds_out['evap'] = evap - ds_out['taux'] = taux + ds_out['taux'] = taux ds_out['tauy'] = tauy return ds_out # load the tempsave file from `cm26_pipeline.ipynb` ds_coarsened = xr.open_dataset('test_coarsened_filled.nc') -ds_coarse_res = noskin_ds_wrapper(ds_coarsened) \ No newline at end of file +ds_coarse_res = noskin_ds_wrapper(ds_coarsened) diff --git a/notebooks/jbusecke/workflow/utils.py b/notebooks/jbusecke/workflow/utils.py index 0585ffb..f443091 100644 --- a/notebooks/jbusecke/workflow/utils.py +++ b/notebooks/jbusecke/workflow/utils.py @@ -28,7 +28,7 @@ # mask_da = da.isel({timedim:0}) # else: # mask_da = da - + # wet_mask = (~np.isnan(mask_da)).astype(int) # ds_out[var] = smooth_inputs(da, wet_mask, dims, filter_scale) # return ds_out @@ -39,19 +39,19 @@ # ['yt_ocean', 'xt_ocean'], # filter_scale # ) - + # all_smoothing_options_except_full = [s for s in ds.smoothing.data if 'full' not in s] - - + + # diff_filtered = ds_filtered.sel(smoothing='smooth_full')-ds_filtered.sel(smoothing=all_smoothing_options_except_full) # diff_unfiltered = ds_filtered.sel(smoothing='smooth_full')-ds.sel(smoothing=all_smoothing_options_except_full) - + # # assigne scale datasets # ds_full=ds_filtered.sel(smoothing='smooth_full') - + # ds_large_scale = ds_filtered.sel(smoothing='smooth_all') - - + + # ds_small_scale = xr.concat( # [ # diff_unfiltered.sel(smoothing='smooth_all'), # the main result, @@ -59,10 +59,10 @@ # ], # dim='smoothing' # ) - + # # mask the outputs # ds_full = ds_full.where(mask) # ds_large_scale = ds_large_scale.where(mask) # ds_small_scale = ds_small_scale.where(mask) - -# return ds_full, ds_large_scale, ds_small_scale \ No newline at end of file + +# return ds_full, ds_large_scale, ds_small_scale diff --git a/pyproject.toml 
b/pyproject.toml index a8ae0d0..6370d2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,4 +122,4 @@ indent-style = "space" skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. -line-ending = "auto" \ No newline at end of file +line-ending = "auto" diff --git a/scale_aware_air_sea/old/cesm_utils.py b/scale_aware_air_sea/old/cesm_utils.py index 746a22d..1ace645 100644 --- a/scale_aware_air_sea/old/cesm_utils.py +++ b/scale_aware_air_sea/old/cesm_utils.py @@ -23,32 +23,33 @@ def load_and_combine_cesm( # Load ocean data mapper = filesystem.get_mapper("gs://pangeo-cesm-pop/control") ds_ocean = xr.open_dataset(mapper, chunks={"time": 1}, **kwargs) - + # Load atmospheric data - mapper = fsspec.get_mapper("https://ncsa.osn.xsede.org/Pangeo/pangeo-forge-test/prod/recipe-run-502/pangeo-forge/cesm-atm-025deg-feedstock/cesm-atm-025deg.zarr") + mapper = fsspec.get_mapper( + "https://ncsa.osn.xsede.org/Pangeo/pangeo-forge-test/prod/recipe-run-502/pangeo-forge/cesm-atm-025deg-feedstock/cesm-atm-025deg.zarr" + ) ds_atmos = xr.open_dataset(mapper, chunks={}, **kwargs) - ds_atmos = ds_atmos.chunk({"time":1}) + ds_atmos = ds_atmos.chunk({"time": 1}) print("Align in time") # cut to same time - all_dims = set(list(ds_ocean.dims)+list(ds_atmos.dims)) + all_dims = set(list(ds_ocean.dims) + list(ds_atmos.dims)) ds_ocean, ds_atmos = xr.align( ds_ocean, ds_atmos, - join='inner', - exclude=(di for di in all_dims if di !='time') + join="inner", + exclude=(di for di in all_dims if di != "time"), ) print("Interpolating ocean velocities") # interpolate ocean velocities onto the tracer points using xgcm - from xgcm import Grid import pop_tools - + grid, ds_ocean = pop_tools.to_xgcm_grid_dataset(ds_ocean, periodic=False) - + # Fill missing ocean values with 0 - sst_wet_mask = ~np.isnan(ds_ocean['SST'].isel(time=0)) - + sst_wet_mask = ~np.isnan(ds_ocean["SST"].isel(time=0)) + # Do the interpolation ds_ocean["u_ocean"] = grid.interp_like( ds_ocean["U1_1"].fillna(0), ds_ocean["SST"] @@ -59,61 +60,67 @@ def load_and_combine_cesm( print("Regrid Atmospheric Data") # Start regridding the atmosphere onto the ocean grid - + # CESM grid variables are in the datasets, so grab for single time step atmos_grid = ds_atmos.isel(time=0) - ocean_grid = ds_ocean.isel(time=0).drop_vars([co for co in ds_ocean.coords if co not in ['TLONG','TLAT']]) + ocean_grid = ds_ocean.isel(time=0).drop_vars( + [co for co in ds_ocean.coords if co not in ["TLONG", "TLAT"]] + ) # Load precalculated regridder weights from group bucket - path = 'gs://leap-persistent/jbusecke/scale-aware-air-sea/regridding_weights/ncar_atmos2ocean.zarr' + path = "gs://leap-persistent/jbusecke/scale-aware-air-sea/regridding_weights/ncar_atmos2ocean.zarr" mapper_regrid = filesystem.get_mapper(path) ds_regridder = xr.open_zarr(mapper_regrid).load() regridder = xe.Regridder( - atmos_grid, - ocean_grid, - 'bilinear', - weights=ds_regridder, - periodic=True + atmos_grid, ocean_grid, "bilinear", weights=ds_regridder, periodic=True + ) + ds_atmos_regridded = regridder( + ds_atmos[["TS", "UBOT", "VBOT", "QREFHT", "PSL", "U10", "TREFHT"]] ) - ds_atmos_regridded = regridder(ds_atmos[['TS', 'UBOT', 'VBOT', 'QREFHT', 'PSL','U10','TREFHT']]) # Combine into merged dataset ds_merged = xr.merge( [ - ds_atmos_regridded.chunk({'time':1}), # to have more manageable sized chunks (same as ocean variables) - ds_ocean[['SST','u_ocean','v_ocean']] #.chunk({'time':73}), # to have same chunksizes as atmospheric variables + ds_atmos_regridded.chunk( + 
{"time": 1} + ), # to have more manageable sized chunks (same as ocean variables) + ds_ocean[ + ["SST", "u_ocean", "v_ocean"] + ], # .chunk({'time':73}), # to have same chunksizes as atmospheric variables ] ) - + print("Mask nans") # Atmos missing values are filled with single floats (not sure how these values are chosen) # Ideally this should be masked before the regridding, but xesmf fills with 0 again... - mask = ~np.isnan(ds_ocean['SST'].isel(time=0).reset_coords(drop=True)) + mask = ~np.isnan(ds_ocean["SST"].isel(time=0).reset_coords(drop=True)) # mask = ds_ocean['SST'].reset_coords(drop=True)>3 # for mask_var in ['PSL', 'TREFHT', 'QREFHT', 'VBOT', 'UBOT','SST']: ds_merged = ds_merged.where(mask) - + # also apply this mask to certain coordinates from the grid dataset - for mask_coord in ['TAREA']: + for mask_coord in ["TAREA"]: # ds_merged.coords[mask_coord] = ds_merged[mask_coord].where(mask.isel(time=0).drop('time'),0.0).astype(np.float64) - ds_merged.coords[mask_coord] = ds_merged[mask_coord].where(mask,0.0).astype(np.float64) -# # The casting to float64 is needed to avoid that weird bug where the manual global weighted ave -# # is not close to the xarray weighted mean (I was not able to reproduce this with an example) - -# # Ideally this should be masked before the regridding, -# # but xesmf fills with 0 again... -# mask = ~np.isnan(ds_merged["surface_temp"]) -# for mask_var in ["slp", "t_ref", "q_ref"]: -# ds_merged[mask_var] = ds_merged[mask_var].where(mask) + ds_merged.coords[mask_coord] = ( + ds_merged[mask_coord].where(mask, 0.0).astype(np.float64) + ) + # # The casting to float64 is needed to avoid that weird bug where the manual global weighted ave + # # is not close to the xarray weighted mean (I was not able to reproduce this with an example) + + # # Ideally this should be masked before the regridding, + # # but xesmf fills with 0 again... 
+ # mask = ~np.isnan(ds_merged["surface_temp"]) + # for mask_var in ["slp", "t_ref", "q_ref"]: + # ds_merged[mask_var] = ds_merged[mask_var].where(mask) # Define ice mask and save for later use print("Modify units") # fix units for aerobulk - ds_merged["SST"] = ds_merged["SST"] + 273.15 # convert from degC to K - ds_merged["TAREA"] = ds_merged["TAREA"] / 10000 # convert from cm^2 to m^2 - ds_merged["u_ocean"] = 0.01*ds_merged["u_ocean"] # convert from cm/s to m/s - ds_merged["v_ocean"] = 0.01*ds_merged["v_ocean"] # convert from cm/s to m/s + ds_merged["SST"] = ds_merged["SST"] + 273.15 # convert from degC to K + ds_merged["TAREA"] = ds_merged["TAREA"] / 10000 # convert from cm^2 to m^2 + ds_merged["u_ocean"] = 0.01 * ds_merged["u_ocean"] # convert from cm/s to m/s + ds_merged["v_ocean"] = 0.01 * ds_merged["v_ocean"] # convert from cm/s to m/s # Calculate relative wind print("Calculate relative wind") diff --git a/scale_aware_air_sea/old/cm26_utils.py b/scale_aware_air_sea/old/cm26_utils.py index 77189dd..da4ef4c 100644 --- a/scale_aware_air_sea/old/cm26_utils.py +++ b/scale_aware_air_sea/old/cm26_utils.py @@ -7,11 +7,11 @@ consolidated=True, use_cftime=True, inline_array=inline_array, engine="zarr" ) + def _load_oc_grid(filesystem: gcsfs.GCSFileSystem) -> xr.Dataset: mapper = filesystem.get_mapper("gs://cmip6/GFDL_CM2_6/grid") ds_oc_grid = xr.open_dataset(mapper, chunks={}, **kwargs) return ds_oc_grid - def load_and_combine_cm26( @@ -114,17 +114,18 @@ def load_and_combine_cm26( print("Mask nans") # atmos missing values are filled with 0s, which causes issues with the filtering # Ideally this should be masked before the regridding, but xesmf fills with 0 again... - mask = ~np.isnan(ds_merged['surface_temp'].isel(time=0).reset_coords(drop=True)) - for mask_var in ['slp', 't_ref', 'q_ref', 'v_ref', 'u_ref', 'wind']: + mask = ~np.isnan(ds_merged["surface_temp"].isel(time=0).reset_coords(drop=True)) + for mask_var in ["slp", "t_ref", "q_ref", "v_ref", "u_ref", "wind"]: ds_merged[mask_var] = ds_merged[mask_var].where(mask) - - - # also apply this mask to certain coordinates from the grid dataset (for now only tracer_area since that - for mask_coord in ['area_t']: - ds_merged.coords[mask_coord] = ds_oc_grid[mask_coord].where(mask,0.0).astype(np.float64) + + # also apply this mask to certain coordinates from the grid dataset (for now only tracer_area since that + for mask_coord in ["area_t"]: + ds_merged.coords[mask_coord] = ( + ds_oc_grid[mask_coord].where(mask, 0.0).astype(np.float64) + ) # The casting to float64 is needed to avoid that weird bug where the manual global weighted ave # is not close to the xarray weighted mean (I was not able to reproduce this with an example) - + # Ideally this should be masked before the regridding, # but xesmf fills with 0 again... 
mask = ~np.isnan(ds_merged["surface_temp"]) diff --git a/scale_aware_air_sea/parameters.py b/scale_aware_air_sea/parameters.py index 3f08cbb..3625395 100644 --- a/scale_aware_air_sea/parameters.py +++ b/scale_aware_air_sea/parameters.py @@ -1,91 +1,84 @@ -def get_params(version:str, test:bool=True) -> dict[str, str]: - bucket = 'gs://leap-persistent/jbusecke' # equivalent to os.environ['PERSISTENT_BUCKET'], but this should work for all collaborators - scratch = 'gs://leap-scratch/jbusecke' - suffix = 'test' if test else '' +def get_params(version: str, test: bool = True) -> dict[str, str]: + bucket = "gs://leap-persistent/jbusecke" # equivalent to os.environ['PERSISTENT_BUCKET'], but this should work for all collaborators + scratch = "gs://leap-scratch/jbusecke" + suffix = "test" if test else "" n_coarsen = 50 - project_path = f"scale-aware-air-sea" - version_full = version+suffix + project_path = "scale-aware-air-sea" + version_full = version + suffix global_params = { - 'filter_type':"gaussian", - 'filter_scale':50, - 'n_coarsen': n_coarsen, - 'version': version_full, - 'paths':{ - model:{ - 'preprocessing': - { - 'scratch': f"{scratch}/{project_path}/{version_full}/temp/{model}.zarr" + "filter_type": "gaussian", + "filter_scale": 50, + "n_coarsen": n_coarsen, + "version": version_full, + "paths": { + model: { + "preprocessing": { + "scratch": f"{scratch}/{project_path}/{version_full}/temp/{model}.zarr" + }, + "smoothing": { + "filter": f"{bucket}/{project_path}/{version_full}/smoothing/{model}_filter.zarr", + "coarse": f"{bucket}/{project_path}/{version_full}/smoothing/{model}_coarse_{n_coarsen}.zarr", + }, + "fluxes": { + "filter": { + "prod": f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_filter_prod.zarr", + "appendix": f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_filter_appendix.zarr", + }, + "coarse": { + "prod": f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_coarse_{n_coarsen}_prod.zarr", + "appendix": f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_coarse_{n_coarsen}_appendix.zarr", + }, + }, + "results": { + "filter": { + "native": { + "prod": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_prod.zarr", + "appendix": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_appendix.zarr", + "all_terms": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_all_terms.zarr", }, - 'smoothing': - { - 'filter': f"{bucket}/{project_path}/{version_full}/smoothing/{model}_filter.zarr", - 'coarse': f"{bucket}/{project_path}/{version_full}/smoothing/{model}_coarse_{n_coarsen}.zarr", + "mean": { + "prod": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_prod.zarr", + "appendix": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_appendix.zarr", + "all_terms": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_all_terms.zarr", }, - 'fluxes': - { - 'filter': - { - 'prod': f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_filter_prod.zarr", - 'appendix': f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_filter_appendix.zarr", - }, - 'coarse': - { - 'prod': f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_coarse_{n_coarsen}_prod.zarr", - 'appendix': f"{bucket}/{project_path}/{version_full}/fluxes/{model}_fluxes_coarse_{n_coarsen}_appendix.zarr", - }, - }, - 'results': - { - 
'filter':{ - 'native':{ - 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_prod.zarr", - 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_appendix.zarr", - 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_all_terms.zarr", - }, - 'mean':{ - 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_prod.zarr", - 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_appendix.zarr", - 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_all_terms.zarr", - }, - - }, - 'coarse':{ - 'native':{ - 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_native_prod_{n_coarsen}.zarr", - 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_native_appendix_{n_coarsen}.zarr", - 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_native_all_terms.zarr", - }, - 'mean':{ - 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_mean_prod_{n_coarsen}.zarr", - 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_mean_appendix_{n_coarsen}.zarr", - 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_mean_all_terms.zarr", - }, - - }, + }, + "coarse": { + "native": { + "prod": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_native_prod_{n_coarsen}.zarr", + "appendix": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_native_appendix_{n_coarsen}.zarr", + "all_terms": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_native_all_terms.zarr", + }, + "mean": { + "prod": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_mean_prod_{n_coarsen}.zarr", + "appendix": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_mean_appendix_{n_coarsen}.zarr", + "all_terms": f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_coarse_decomposed_mean_all_terms.zarr", + }, + }, }, - 'plotting': - { - 'max_ice_mask': f"{bucket}/{project_path}/{version_full}/plotting/{model}_max_ice_mask.zarr", - 'full_fluxes':{ + "plotting": { + "max_ice_mask": f"{bucket}/{project_path}/{version_full}/plotting/{model}_max_ice_mask.zarr", + "full_fluxes": { k: { - kk:f"{bucket}/{project_path}/{version_full}/plotting/{model}_full_flux_{k}_{kk}.zarr" for kk in ['global_mean', 'time_mean'] - } for k in ['online', 'offline'] + kk: f"{bucket}/{project_path}/{version_full}/plotting/{model}_full_flux_{k}_{kk}.zarr" + for kk in ["global_mean", "time_mean"] + } + for k in ["online", "offline"] }, -# 'filter':{ -# 'native':{ -# 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_prod.zarr", -# 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_appendix.zarr", -# 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_all_terms.zarr", -# }, -# 'mean':{ -# 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_prod.zarr", -# 
'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_appendix.zarr", -# 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_all_terms.zarr", -# }, - -# }, + # 'filter':{ + # 'native':{ + # 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_prod.zarr", + # 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_appendix.zarr", + # 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_native_all_terms.zarr", + # }, + # 'mean':{ + # 'prod':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_prod.zarr", + # 'appendix':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_appendix.zarr", + # 'all_terms':f"{bucket}/{project_path}/{version_full}/results/{model}_fluxes_filter_decomposed_mean_all_terms.zarr", + # }, + # }, }, - } for model in ['CM26','CESM'] - } + } + for model in ["CM26", "CESM"] + }, } return global_params diff --git a/scale_aware_air_sea/plotting.py b/scale_aware_air_sea/plotting.py index 49c00eb..58a0586 100644 --- a/scale_aware_air_sea/plotting.py +++ b/scale_aware_air_sea/plotting.py @@ -2,10 +2,11 @@ import xarray as xr from scipy import ndimage as nd + # from https://stackoverflow.com/questions/3662361/fill-in-missing-values-with-nearest-neighbour-in-python-numpy-masked-arrays -def fill(data:np.ndarray, invalid: np.ndarray): +def fill(data: np.ndarray, invalid: np.ndarray): """ - Replace the value of invalid 'data' cells (indicated by 'invalid') + Replace the value of invalid 'data' cells (indicated by 'invalid') by the value of the nearest valid data cell Input: @@ -14,14 +15,17 @@ def fill(data:np.ndarray, invalid: np.ndarray): value should be replaced. If None (default), use: invalid = np.isnan(data) - Output: - Return a filled array. + Output: + Return a filled array. 
""" - #import numpy as np - #import scipy.ndimage as nd - ind = nd.distance_transform_edt(invalid, return_distances=False, return_indices=True) + # import numpy as np + # import scipy.ndimage as nd + ind = nd.distance_transform_edt( + invalid, return_distances=False, return_indices=True + ) return data[tuple(ind)] + def fill_da(da: xr.DataArray) -> xr.DataArray: """fills nans in dataarray""" data = da.data @@ -29,10 +33,11 @@ def fill_da(da: xr.DataArray) -> xr.DataArray: da.data = filled_data return da + def centered_shrink_axes(ax, factor): bbox = ax.get_position() - left = bbox.x0+(bbox.width*factor/2) - bottom = bbox.y0+(bbox.height*factor/2) + left = bbox.x0 + (bbox.width * factor / 2) + bottom = bbox.y0 + (bbox.height * factor / 2) width = bbox.width * factor height = bbox.height * factor - ax.set_position([left, bottom, width, height]) \ No newline at end of file + ax.set_position([left, bottom, width, height]) diff --git a/scale_aware_air_sea/stages.py b/scale_aware_air_sea/stages.py index feb6448..7e11116 100644 --- a/scale_aware_air_sea/stages.py +++ b/scale_aware_air_sea/stages.py @@ -5,29 +5,27 @@ import numpy as np import xesmf as xe import xarray as xr -import random + def load_cesm_data(fs: gcsfs.GCSFileSystem) -> tuple[xr.Dataset, xr.Dataset]: - kwargs = dict( - consolidated=True, use_cftime=True, engine="zarr" - ) - + kwargs = dict(consolidated=True, use_cftime=True, engine="zarr") + # Load ocean data ocean_path = "gs://pangeo-cesm-pop/control" ds_ocean = xr.open_dataset(fs.get_mapper(ocean_path), chunks={"time": 1}, **kwargs) - + # Load atmospheric data atmos_path = "https://ncsa.osn.xsede.org/Pangeo/pangeo-forge-test/prod/recipe-run-502/pangeo-forge/cesm-atm-025deg-feedstock/cesm-atm-025deg.zarr" ds_atmos = xr.open_dataset(fsspec.get_mapper(atmos_path), chunks={}, **kwargs) - ds_atmos = ds_atmos.chunk({"time":1}) + ds_atmos = ds_atmos.chunk({"time": 1}) print("Interpolating ocean velocities") # interpolate ocean velocities onto the tracer points using xgcm grid, ds_ocean = pop_tools.to_xgcm_grid_dataset(ds_ocean, periodic=True) - + # Fill missing ocean values with 0 - sst_wet_mask = ~np.isnan(ds_ocean['SST'].isel(time=0)) - + sst_wet_mask = ~np.isnan(ds_ocean["SST"].isel(time=0)) + # Do the interpolation ds_ocean["u_ocean"] = grid.interp_like( ds_ocean["U1_1"].fillna(0), ds_ocean["SST"] @@ -35,57 +33,60 @@ def load_cesm_data(fs: gcsfs.GCSFileSystem) -> tuple[xr.Dataset, xr.Dataset]: ds_ocean["v_ocean"] = grid.interp_like( ds_ocean["V1_1"].fillna(0), ds_ocean["SST"] ).where(sst_wet_mask) - + # rename this dataset to CM2.6 conventions (purely subjective choice here) rename_dict = { - 'nlon_t':'xt_ocean', - 'nlat_t':'yt_ocean', - 'TLAT': 'geolat_t', - 'TLONG': 'geolon_t', - 'UBOT':'u_ref', - 'VBOT':'v_ref', - 'SST':'surface_temp', - 'TREFHT':'t_ref', - 'QREFHT':'q_ref', - 'PSL':'slp', - 'TAREA': 'area_t', + "nlon_t": "xt_ocean", + "nlat_t": "yt_ocean", + "TLAT": "geolat_t", + "TLONG": "geolon_t", + "UBOT": "u_ref", + "VBOT": "v_ref", + "SST": "surface_temp", + "TREFHT": "t_ref", + "QREFHT": "q_ref", + "PSL": "slp", + "TAREA": "area_t", } - ds_ocean = ds_ocean.rename({k:v for k,v in rename_dict.items() if k in ds_ocean.variables}) - ds_atmos = ds_atmos.rename({k:v for k,v in rename_dict.items() if k in ds_atmos.variables}) - - + ds_ocean = ds_ocean.rename( + {k: v for k, v in rename_dict.items() if k in ds_ocean.variables} + ) + ds_atmos = ds_atmos.rename( + {k: v for k, v in rename_dict.items() if k in ds_atmos.variables} + ) + # fix units for aerobulk (TODO: Maybe this 
could be handled better with pint-xarray? print("Modify units") ds_ocean["surface_temp"] = ds_ocean["surface_temp"] + 273.15 - ds_ocean["u_ocean"] = 0.01 * ds_ocean["u_ocean"] # convert from cm/s to m/s - ds_ocean["v_ocean"] = 0.01 * ds_ocean["v_ocean"] # convert from cm/s to m/s - ds_ocean.coords["area_t"] = ds_ocean["area_t"] / 10000 # convert from cm^2 to m^2 - + ds_ocean["u_ocean"] = 0.01 * ds_ocean["u_ocean"] # convert from cm/s to m/s + ds_ocean["v_ocean"] = 0.01 * ds_ocean["v_ocean"] # convert from cm/s to m/s + ds_ocean.coords["area_t"] = ds_ocean["area_t"] / 10000 # convert from cm^2 to m^2 + return ds_ocean, ds_atmos + def load_cm26_data(fs: gcsfs.GCSFileSystem) -> tuple[xr.Dataset, xr.Dataset]: kwargs = dict(consolidated=True, use_cftime=True, engine="zarr") - + print("Load Data") ocean_path = "gs://cmip6/GFDL_CM2_6/control/surface" ds_ocean = xr.open_dataset(fs.get_mapper(ocean_path), chunks={"time": 3}, **kwargs) - + grid_path = "gs://cmip6/GFDL_CM2_6/grid" ds_ocean_grid = xr.open_dataset(fs.get_mapper(grid_path), chunks={}, **kwargs) - + # combine all dataset on the ocean grid together - ds_ocean = xr.merge([ds_ocean_grid, ds_ocean], compat='override') + ds_ocean = xr.merge([ds_ocean_grid, ds_ocean], compat="override") # xarray says not to do this # ds_atmos = xr.open_zarr('gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr', chunks={'time':1}, **kwargs) # noqa: E501 atmos_path = "gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr" - ds_atmos = xr.open_dataset(fs.get_mapper(atmos_path), chunks={"time": 120}, **kwargs).chunk( - {"time": 3} - ) + ds_atmos = xr.open_dataset( + fs.get_mapper(atmos_path), chunks={"time": 120}, **kwargs + ).chunk({"time": 3}) print("Interpolating ocean velocities") # interpolate ocean velocities onto the tracer points using xgcm - from xgcm import Grid # add xgcm comodo attrs ds_ocean["xu_ocean"].attrs["axis"] = "X" @@ -97,7 +98,7 @@ def load_cm26_data(fs: gcsfs.GCSFileSystem) -> tuple[xr.Dataset, xr.Dataset]: ds_ocean["yu_ocean"].attrs["c_grid_axis_shift"] = 0.5 ds_ocean["yt_ocean"].attrs["c_grid_axis_shift"] = 0.0 grid = Grid(ds_ocean) - + # fill missing values with 0, then interpolate. tracer_ref = ds_ocean["surface_temp"] sst_wet_mask = ~np.isnan(tracer_ref) @@ -108,10 +109,10 @@ def load_cm26_data(fs: gcsfs.GCSFileSystem) -> tuple[xr.Dataset, xr.Dataset]: ds_ocean["v_ocean"] = grid.interp_like( ds_ocean["vsurf"].fillna(0), tracer_ref ).where(sst_wet_mask) - + # rename the atmos data coordinates only to CESM conventions - ds_atmos = ds_atmos.rename({'grid_xt':'lon', 'grid_yt':'lat'}) - + ds_atmos = ds_atmos.rename({"grid_xt": "lon", "grid_yt": "lat"}) + # fix units for aerobulk (TODO: Maybe this could be handled better with pint-xarray? print("Modify units") ds_ocean["surface_temp"] = ds_ocean["surface_temp"] + 273.15 @@ -119,45 +120,53 @@ def load_cm26_data(fs: gcsfs.GCSFileSystem) -> tuple[xr.Dataset, xr.Dataset]: return ds_ocean, ds_atmos + def regrid_atmos(ds_ocean: xr.Dataset, ds_atmos: xr.Dataset) -> xr.Dataset: - # Create the regridder (I could save this out to a nc file and upload it as a zarr - # (as in pipeline/old_code/generate_regridding_weights_cm26.ipynb) but since we are + # Create the regridder (I could save this out to a nc file and upload it as a zarr + # (as in pipeline/old_code/generate_regridding_weights_cm26.ipynb) but since we are # writing the full preprocessed dataset out to scratch anyways, why bother? # NOTE: The regridder weight calculation needs 36+ GB of ram available. Which might be important for later beam stages. 
- + # create stripped down versions of the input datasets, to make sure we are using these coordinates for regridding. - ocean_sample = ds_ocean.drop([co for co in ds_ocean.coords if co not in ['geolon_t','geolat_t']]) - atmos_sample = ds_atmos.drop([co for co in ds_atmos.coords if co not in ['lon','lat']]) - + ocean_sample = ds_ocean.drop( + [co for co in ds_ocean.coords if co not in ["geolon_t", "geolat_t"]] + ) + atmos_sample = ds_atmos.drop( + [co for co in ds_atmos.coords if co not in ["lon", "lat"]] + ) + regridder = regridder = xe.Regridder( atmos_sample, ocean_sample, - 'bilinear', + "bilinear", periodic=True, - unmapped_to_nan=True # I think i need to cut out the lon/lat values before this! Might save some .where()'s later + unmapped_to_nan=True, # I think i need to cut out the lon/lat values before this! Might save some .where()'s later ) return regridder(ds_atmos) - - -def construct_ice_mask(ds:xr.Dataset) -> xr.DataArray: - # Define Ice mask (TODO: raise issue discussing this) - # I am choosing to use the same temp criterion here for both models. - # We should add an appendix looking at the global difference between + + +def construct_ice_mask(ds: xr.Dataset) -> xr.DataArray: + # Define Ice mask (TODO: raise issue discussing this) + # I am choosing to use the same temp criterion here for both models. + # We should add an appendix looking at the global difference between # using a time resolved vs maximally excluding ice_mask (e.g. max extent over a year) # I prototyped another method using the melt rate. # But the computation is super gnarly (see pipeline/step_00_ice_mask_brute_force_cm2.6.ipynb). return ds.surface_temp > 273.15 - -def preprocess(fs: gcsfs.GCSFileSystem, model:str, include_fluxes:bool=False) -> xr.Dataset: + + +def preprocess( + fs: gcsfs.GCSFileSystem, model: str, include_fluxes: bool = False +) -> xr.Dataset: # loading data print(f"{model}: Loading Data") - if model == 'CM26': + if model == "CM26": load_func = load_cm26_data - elif model == 'CESM': + elif model == "CESM": load_func = load_cesm_data - + ds_ocean, ds_atmos = load_func(fs) - + print(f"{model}: Align in time") # cut to same time all_dims = set(list(ds_ocean.dims) + list(ds_atmos.dims)) @@ -167,47 +176,51 @@ def preprocess(fs: gcsfs.GCSFileSystem, model:str, include_fluxes:bool=False) -> join="inner", exclude=(di for di in all_dims if di != "time"), ) - - ds_ocean.coords['ice_mask'] = construct_ice_mask(ds_ocean) - + + ds_ocean.coords["ice_mask"] = construct_ice_mask(ds_ocean) + # regrid atmospheric data - print(f"{model}: Regridding atmosphere (this takes a while, because we are computing the weights on the fly)") + print( + f"{model}: Regridding atmosphere (this takes a while, because we are computing the weights on the fly)" + ) # TODO: maybe get some non-gapped lon/lats and only put those out after regridding? - - #make sure that the ocean lon/lat values have nans in the same locations as the ocean tracer fields - tracer_ref_mask = ~np.isnan(ds_ocean['surface_temp'].isel(time=0, drop=True)) + + # make sure that the ocean lon/lat values have nans in the same locations as the ocean tracer fields + tracer_ref_mask = ~np.isnan(ds_ocean["surface_temp"].isel(time=0, drop=True)) lon_masked = ds_ocean.geolon_t.where(tracer_ref_mask) lat_masked = ds_ocean.geolat_t.where(tracer_ref_mask) area_masked = ds_ocean.area_t.where(tracer_ref_mask, 0.0) # # WHy the fuck does this not work? # ds_ocean_masked = ds_ocean.assign_coords({'geolat_t':lat_masked, 'geolon_t':lon_masked}) # this works, but seriously wtf? 
The above also works in a cell below, just not in this function? - ds_ocean.coords['geolon_t'].data = lon_masked.data - ds_ocean.coords['geolat_t'].data = lat_masked.data - ds_ocean.coords['area_t'].data = area_masked.data + ds_ocean.coords["geolon_t"].data = lon_masked.data + ds_ocean.coords["geolat_t"].data = lat_masked.data + ds_ocean.coords["area_t"].data = area_masked.data ds_atmos_regridded = regrid_atmos(ds_ocean, ds_atmos) - + # merge data on the ocean grid (and discard variables not needed for analysis) print(f"{model}: Merging on ocean tracer grid") atmos_vars = ["slp", "v_ref", "u_ref", "t_ref", "q_ref"] ocean_vars = ["surface_temp", "u_ocean", "v_ocean"] if include_fluxes: - if model == 'CESM': - atmos_vars = atmos_vars + ['LHFLX', 'SHFLX'] + if model == "CESM": + atmos_vars = atmos_vars + ["LHFLX", "SHFLX"] else: raise - + merge_datasets = [ds_ocean[ocean_vars], ds_atmos_regridded[atmos_vars]] ds_combined = xr.merge(merge_datasets) - + # Calculate relative wind print(f"{model}: Calculate relative wind") ds_combined["u_relative"] = ds_combined["u_ref"] - ds_combined["u_ocean"] ds_combined["v_relative"] = ds_combined["v_ref"] - ds_combined["v_ocean"] - + # Drop coordinates print(f"{model}: Drop extra coords") - keep_coords = ['time', 'geolon_t', 'geolat_t', 'area_t', 'ice_mask'] - ds_combined = ds_combined.drop([co for co in ds_combined.coords if co not in keep_coords]) - return ds_combined \ No newline at end of file + keep_coords = ["time", "geolon_t", "geolat_t", "area_t", "ice_mask"] + ds_combined = ds_combined.drop( + [co for co in ds_combined.coords if co not in keep_coords] + ) + return ds_combined diff --git a/scale_aware_air_sea/stages_tests.py b/scale_aware_air_sea/stages_tests.py index 9b0c890..847cef1 100644 --- a/scale_aware_air_sea/stages_tests.py +++ b/scale_aware_air_sea/stages_tests.py @@ -1,119 +1,139 @@ import xarray as xr import random -def _test_timesteps(ds:xr.Dataset): - assert 'model' in ds.attrs - prod_spec = ds.attrs.get('production_spec') - if prod_spec == 'appendix': + +def _test_timesteps(ds: xr.Dataset): + assert "model" in ds.attrs + prod_spec = ds.attrs.get("production_spec") + if prod_spec == "appendix": assert len(ds.time) == 365 - elif prod_spec == 'prod': - if ds.attrs['model'] == 'CESM': + elif prod_spec == "prod": + if ds.attrs["model"] == "CESM": assert len(ds.time) == 730 - elif ds.attrs['model'] == 'CM26': + elif ds.attrs["model"] == "CM26": assert len(ds.time) == 7305 -def test_data_preprocessing(ds:xr.Dataset, full_check=False): + +def test_data_preprocessing(ds: xr.Dataset, full_check=False): # check that no nans are in the lon/lat fields (warn only, I think we do not have any without nans atm) # Note. Actually this might be useful for the regridding (and masking within), but we should be able to attache fully filled lon/lats for later plotting. 
_test_timesteps(ds) # check that all variables are on the tracer point for va in ds.data_vars: - assert set(ds[va].dims) == set(['time', 'xt_ocean', 'yt_ocean']) - + assert set(ds[va].dims) == set(["time", "xt_ocean", "yt_ocean"]) + # check that necessary coordinates are included - for co in ['ice_mask', 'area_t', 'geolon_t', 'geolat_t']: + for co in ["ice_mask", "area_t", "geolon_t", "geolat_t"]: assert co in ds.coords - + # Range check on naive global mean for variables ranges = { - 'surface_temp':[270, 310], - 'q_ref':[0.005, 0.02], - 'slp': [100000, 110000], - 't_ref': [270, 310], - 'u_ocean': [0.05, 2], - 'v_ocean': [0.05, 2], - 'u_relative' : [1, 20], - 'v_relative' : [1, 20], - 'u_ref' : [1, 20], - 'v_ref' : [1, 20], - 'area_t': [5e7, 2e8], + "surface_temp": [270, 310], + "q_ref": [0.005, 0.02], + "slp": [100000, 110000], + "t_ref": [270, 310], + "u_ocean": [0.05, 2], + "v_ocean": [0.05, 2], + "u_relative": [1, 20], + "v_relative": [1, 20], + "u_ref": [1, 20], + "v_ref": [1, 20], + "area_t": [5e7, 2e8], } range_test_ds = ds.isel(time=random.randint(0, len(ds.time))).load() for va, r in ranges.items(): test_val = abs(range_test_ds[va]).quantile(0.75).data if not (test_val >= r[0] and test_val <= r[1]): - raise ValueError(f"{va =} failed the range test. Got value={test_val} and range={r}") - + raise ValueError( + f"{va =} failed the range test. Got value={test_val} and range={r}" + ) + # TODO: Check the proper units? if full_check: # test that there are no all nan maps anywhere - nan_test = np.isnan(ds).all(['xt_ocean', 'yt_ocean']).to_array().sum() + nan_test = np.isnan(ds).all(["xt_ocean", "yt_ocean"]).to_array().sum() assert nan_test.data == 0 - - + # finally check that each variable has nans in the same position a = np.isnan(ds.surface_temp.isel(time=0, drop=True)).load() - b = np.isnan(ds.isel(time=0, drop=True).to_array()).all('variable').load() - xr.testing.assert_allclose(a,b) + b = np.isnan(ds.isel(time=0, drop=True).to_array()).all("variable").load() + xr.testing.assert_allclose(a, b) + - - import matplotlib.pyplot as plt import numpy as np + def test_smoothed_data(ds_raw, ds, plot=False, full_check=False): _test_timesteps(ds) - assert 'smoothing_method' in ds.attrs.keys() - - if ds.attrs['smoothing_method'] == 'coarse': - assert 'n_coarsen' in ds.attrs.keys() - + assert "smoothing_method" in ds.attrs.keys() + + if ds.attrs["smoothing_method"] == "coarse": + assert "n_coarsen" in ds.attrs.keys() + # Test that raw and coarse datasets preserver the global mean tracer value # This ensures that both the values and the coarsened area are calculated consistently - test_var = 'surface_temp' - test_roi = dict(time=slice(0,200)) + test_var = "surface_temp" + test_roi = dict(time=slice(0, 200)) # FIXME: THERE IS THIS BIZARRE precision error with weighted again...WTF. Take the `.astype(...)` out to see this mess!!! 
- raw_test = ds_raw[test_var].isel(**test_roi).astype(np.float64).weighted(ds_raw.area_t).mean(['xt_ocean', 'yt_ocean']).load() - test = ds[test_var].isel(**test_roi).astype(np.float64).weighted(ds.area_t).mean(['xt_ocean', 'yt_ocean']).load() + raw_test = ( + ds_raw[test_var] + .isel(**test_roi) + .astype(np.float64) + .weighted(ds_raw.area_t) + .mean(["xt_ocean", "yt_ocean"]) + .load() + ) + test = ( + ds[test_var] + .isel(**test_roi) + .astype(np.float64) + .weighted(ds.area_t) + .mean(["xt_ocean", "yt_ocean"]) + .load() + ) if plot: plt.figure() - raw_test.plot(label='raw', ls='-') - test.plot(label='coarse', ls=':') - plt.title(f'Global weighted {test_var} average {model}') + raw_test.plot(label="raw", ls="-") + test.plot(label="coarse", ls=":") + plt.title(f"Global weighted {test_var} average {model}") plt.legend() plt.show() xr.testing.assert_allclose(raw_test, test) - - elif ds.attrs['smoothing_method'] == 'filter': - assert 'filter_type' in ds.attrs.keys() - assert 'filter_scale' in ds.attrs.keys() + + elif ds.attrs["smoothing_method"] == "filter": + assert "filter_type" in ds.attrs.keys() + assert "filter_scale" in ds.attrs.keys() if full_check: # test that there are no all nan maps anywhere - nan_test = np.isnan(ds).all(['xt_ocean', 'yt_ocean']).to_array().sum() + nan_test = np.isnan(ds).all(["xt_ocean", "yt_ocean"]).to_array().sum() assert nan_test.data == 0 - + ## Tests for all smoothed datasets ## are eddies visually eliminated? if plot: plt.figure() - ds.isel(time=[0, 100, 300]).surface_temp.plot.contourf(col='time', levels=21, size=4) + ds.isel(time=[0, 100, 300]).surface_temp.plot.contourf( + col="time", levels=21, size=4 + ) plt.show() -def test_data_flux(ds:xr.Dataset, plot=False, full_check=False): - for attr in ['smoothing_method', 'production_spec', 'model']: + +def test_data_flux(ds: xr.Dataset, plot=False, full_check=False): + for attr in ["smoothing_method", "production_spec", "model"]: print(attr) assert attr in ds.attrs.keys() _test_timesteps(ds) # test that there are no all nan maps anywhere if full_check: - nan_test = np.isnan(ds).all(['xt_ocean', 'yt_ocean']).to_array().sum() + nan_test = np.isnan(ds).all(["xt_ocean", "yt_ocean"]).to_array().sum() assert nan_test.data == 0 # Check the ice-mask if plot: plt.figure() - ds.qh.isel(time=[0,90, 180], algo=0, smoothing=0).plot(col='time', robust=True) - plt.show() \ No newline at end of file + ds.qh.isel(time=[0, 90, 180], algo=0, smoothing=0).plot(col="time", robust=True) + plt.show() diff --git a/scale_aware_air_sea/utils.py b/scale_aware_air_sea/utils.py index b22a243..8228a12 100644 --- a/scale_aware_air_sea/utils.py +++ b/scale_aware_air_sea/utils.py @@ -9,28 +9,26 @@ def open_zarr(mapper, chunks={}): return xr.open_dataset( - mapper, - engine='zarr', - chunks=chunks, - consolidated=True, - inline_array=True + mapper, engine="zarr", chunks=chunks, consolidated=True, inline_array=True ) + def maybe_save_and_reload(ds, path, overwrite=False, fs=None): if fs is None: fs = gcsfs.GCSFileSystem() - + if not fs.exists(path): - print(f'Saving the dataset to zarr at {path}') + print(f"Saving the dataset to zarr at {path}") ds.to_zarr(path) elif fs.exists(path) and overwrite: - print(f'Overwriting dataset at {path}') - ds.to_zarr(path, mode='w') - + print(f"Overwriting dataset at {path}") + ds.to_zarr(path, mode="w") + print(f"Reload dataset from {path}") - ds_reloaded = xr.open_dataset(path, engine='zarr', chunks={}) + ds_reloaded = xr.open_dataset(path, engine="zarr", chunks={}) return ds_reloaded + def filter_inputs( 
da: xr.DataArray, wet_mask: xr.DataArray, @@ -64,7 +62,7 @@ def filter_inputs( dx_min=1, filter_shape=gcm_filters.FilterShape.GAUSSIAN, grid_type=gcm_filters.GridType.TRIPOLAR_REGULAR_WITH_LAND_AREA_WEIGHTED, - grid_vars={"area":da.TAREA,"wet_mask": wet_mask}, + grid_vars={"area": da.TAREA, "wet_mask": wet_mask}, ) else: @@ -175,60 +173,75 @@ def to_zarr_split(ds, mapper, split_dim="time", split_interval=1): # what xr.to_zarr would do g = zarr.open_group(mapper) del g[split_dim] - - ds[[split_dim]].load().to_zarr(mapper, mode='a') + + ds[[split_dim]].load().to_zarr(mapper, mode="a") zarr.consolidate_metadata(mapper) - -def weighted_coarsen(ds:xr.Dataset, dim: Mapping[Any, int], weight_coord:str, timedim='time', **kwargs) -> xr.Dataset: - + + +def weighted_coarsen( + ds: xr.Dataset, dim: Mapping[Any, int], weight_coord: str, timedim="time", **kwargs +) -> xr.Dataset: # Check that the weights have no missing values weights = ds[weight_coord] - if np.isnan(weights).sum()>0: - raise ValueError(f'Found missing values in weights coordinate ({weight_coord}). Please fill with zeros before.') - - # Make sure that the weights are matching the missing values in the input data + if np.isnan(weights).sum() > 0: + raise ValueError( + f"Found missing values in weights coordinate ({weight_coord}). Please fill with zeros before." + ) + + # Make sure that the weights are matching the missing values in the input data # (otherwise creation of aggregated area will be ambigous and depend on each variable) # the important thing to check is if a) all variables have the same mask and variable_missing = np.isnan(ds.to_array()) - + if timedim in ds.dims: - variable_missing = variable_missing.isel({timedim:0}) - - variable_mask = variable_missing.any('variable').load() # loading because we need it multiple times - variable_test = variable_missing.all('variable') + variable_missing = variable_missing.isel({timedim: 0}) + + variable_mask = variable_missing.any( + "variable" + ).load() # loading because we need it multiple times + variable_test = variable_missing.all("variable") if not variable_mask.equals(variable_test): - raise ValueError('Found variables with non-matching missing values. ', - 'Make sure that the missing values in **all** variables are in the same position.') - - # and b) if the weights have nonzero values that do not match the variables (this would lead to additional area being counted below) - weights_test = weights<=0 - + raise ValueError( + "Found variables with non-matching missing values. ", + "Make sure that the missing values in **all** variables are in the same position.", + ) + + # and b) if the weights have nonzero values that do not match the variables (this would lead to additional area being counted below) + weights_test = weights <= 0 + a = variable_mask.squeeze(drop=True) b = weights_test.squeeze(drop=True) - if not np.allclose(a, b.transpose(*a.dims)): # need to transpose this, which too me still seems un xarray-like (I discussed this in an issue once, but whatever). + if not np.allclose( + a, b.transpose(*a.dims) + ): # need to transpose this, which too me still seems un xarray-like (I discussed this in an issue once, but whatever). raise ValueError( - 'Missing values in variables are not matching locations of <=0 values in weights array. ', - 'Please change your weights to only have missing values or zeros where variables have missing values.' + "Missing values in variables are not matching locations of <=0 values in weights array. 
", + "Please change your weights to only have missing values or zeros where variables have missing values.", ) - + # start the actual calculation ds_coarse = ds.coarsen(**dim, **kwargs) # construct internal/external dims - construct_kwargs = {di:(di+'_external', di+'_internal') for di in dim} + construct_kwargs = {di: (di + "_external", di + "_internal") for di in dim} ds_construct = ds_coarse.construct(**construct_kwargs) - + # apply weighted mean over internal dimensions weights_coarse = ds_construct[weight_coord] - aggregate_dims = [di+'_internal' for di in dim] + aggregate_dims = [di + "_internal" for di in dim] ds_out = ds_construct.weighted(weights_coarse).mean(aggregate_dims) - + # add new area that corresponds to the area that was used for each coarse cell - ds_out = ds_out.assign_coords(**{weight_coord:weights_coarse.sum(aggregate_dims)}) - + ds_out = ds_out.assign_coords(**{weight_coord: weights_coarse.sum(aggregate_dims)}) + # add other coordinates back - coords_to_treat = [co for co in ds.coords if co != weight_coord and co not in ds_out.coords] - treated_coords = {co:ds[co].coarsen({k:v for k,v in dim.items() if k in ds[co].dims}).mean() for co in coords_to_treat} + coords_to_treat = [ + co for co in ds.coords if co != weight_coord and co not in ds_out.coords + ] + treated_coords = { + co: ds[co].coarsen({k: v for k, v in dim.items() if k in ds[co].dims}).mean() + for co in coords_to_treat + } ds_out = ds_out.assign_coords(**treated_coords) - + # rename to original names and return - return ds_out.rename({di+'_external': di for di in dim}) + return ds_out.rename({di + "_external": di for di in dim}) diff --git a/tests/test_utils.py b/tests/test_utils.py index aa6e573..ee8b47e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,18 +12,23 @@ def dataset(): nt = 50 nx, ny = 20, 100 data = dsa.random.random((nx, ny, nt), chunks=(nx, ny, 1)) - x = xr.DataArray(np.linspace(-10, 10, nx), dims=['x']) - y = xr.DataArray(np.linspace(20, 120, ny), dims=['y']) + x = xr.DataArray(np.linspace(-10, 10, nx), dims=["x"]) + y = xr.DataArray(np.linspace(20, 120, ny), dims=["y"]) lon = x * xr.ones_like(y) lat = xr.ones_like(x) * y ds_test = xr.Dataset( { - k: xr.DataArray(data, dims=["x", "y", "time"], coords={"time": range(nt), 'x':x, 'y':y, 'lon':lon, 'lat':lat}) + k: xr.DataArray( + data, + dims=["x", "y", "time"], + coords={"time": range(nt), "x": x, "y": y, "lon": lon, "lat": lat}, + ) for k in ["a", "b"] - }, + }, ) return ds_test + class Test_to_zarr_split: def test_roundtrip(self, tmp_path): d = tmp_path / "sub" @@ -38,7 +43,6 @@ def test_roundtrip(self, tmp_path): assert ds.chunks == ds_reloaded.chunks xr.testing.assert_equal(ds, ds_reloaded) - def test_metadata(self, tmp_path): d = tmp_path / "sub" d.mkdir() @@ -58,123 +62,140 @@ def test_metadata(self, tmp_path): meta_split = json.load(f) assert meta == meta_split - + + class Test_weighted_coarsen: def test_simple_2_x_2(self): - data_full = np.random.rand(4,4) - weights_full = np.random.rand(4,4) + data_full = np.random.rand(4, 4) + weights_full = np.random.rand(4, 4) d = data_full * weights_full - weights_expected = np.hstack([ - np.vstack([weights_full[0:2, 0:2].sum(), weights_full[2:5, 0:2].sum()]), - np.vstack([weights_full[0:2, 2:5].sum(), weights_full[2:5, 2:5].sum()]), - ]) + weights_expected = np.hstack( + [ + np.vstack([weights_full[0:2, 0:2].sum(), weights_full[2:5, 0:2].sum()]), + np.vstack([weights_full[0:2, 2:5].sum(), weights_full[2:5, 2:5].sum()]), + ] + ) - data_expected = np.hstack([ - np.vstack([d[0:2, 
0:2].sum(), d[2:5, 0:2].sum()]), - np.vstack([d[0:2, 2:5].sum(), d[2:5, 2:5].sum()]), - ]) / weights_expected + data_expected = ( + np.hstack( + [ + np.vstack([d[0:2, 0:2].sum(), d[2:5, 0:2].sum()]), + np.vstack([d[0:2, 2:5].sum(), d[2:5, 2:5].sum()]), + ] + ) + / weights_expected + ) - ds = xr.DataArray(data_full, coords={'area':(['x','y'],weights_full)}, dims=['x','y']).to_dataset(name='data') + ds = xr.DataArray( + data_full, coords={"area": (["x", "y"], weights_full)}, dims=["x", "y"] + ).to_dataset(name="data") - da_coarse = weighted_coarsen(ds, {'x':2, 'y':2}, 'area') + da_coarse = weighted_coarsen(ds, {"x": 2, "y": 2}, "area") np.testing.assert_allclose(da_coarse.data, data_expected) np.testing.assert_allclose(da_coarse.area, weights_expected) def test_nan_mismatch_variables(self): - data_full = np.random.rand(4,4) - data_2_full = np.random.rand(4,4) - data_2_full[0,1] = np.nan - - weights_full = np.random.rand(4,4) - ds = xr.Dataset({ - 'data1':xr.DataArray(data_full, dims=['x','y']), - 'data2':xr.DataArray(data_2_full, dims=['x','y']), - }, - coords={'area':(['x','y'],weights_full)}, + data_full = np.random.rand(4, 4) + data_2_full = np.random.rand(4, 4) + data_2_full[0, 1] = np.nan + + weights_full = np.random.rand(4, 4) + ds = xr.Dataset( + { + "data1": xr.DataArray(data_full, dims=["x", "y"]), + "data2": xr.DataArray(data_2_full, dims=["x", "y"]), + }, + coords={"area": (["x", "y"], weights_full)}, ) - with pytest.raises(ValueError, match='Found variables with non-matching missing values.'): - weighted_coarsen(ds, {'x':2, 'y':2}, 'area') - + with pytest.raises( + ValueError, match="Found variables with non-matching missing values." + ): + weighted_coarsen(ds, {"x": 2, "y": 2}, "area") def test_nan_mismatch_weights(self): - data_full = np.random.rand(4,4) - data_full[0,1] = np.nan - data_2_full = np.random.rand(4,4) - data_2_full[0,1] = np.nan - - weights_full = np.random.rand(4,4) - ds = xr.Dataset({ - 'data1':xr.DataArray(data_full, dims=['x','y']), - 'data2':xr.DataArray(data_2_full, dims=['x','y']), - }, - coords={'area':(['x','y'],weights_full)}, + data_full = np.random.rand(4, 4) + data_full[0, 1] = np.nan + data_2_full = np.random.rand(4, 4) + data_2_full[0, 1] = np.nan + + weights_full = np.random.rand(4, 4) + ds = xr.Dataset( + { + "data1": xr.DataArray(data_full, dims=["x", "y"]), + "data2": xr.DataArray(data_2_full, dims=["x", "y"]), + }, + coords={"area": (["x", "y"], weights_full)}, ) - with pytest.raises(ValueError, match='Missing values in variables are not matching locations of <=0 values in weights array.'): - weighted_coarsen(ds, {'x':2, 'y':2}, 'area') + with pytest.raises( + ValueError, + match="Missing values in variables are not matching locations of <=0 values in weights array.", + ): + weighted_coarsen(ds, {"x": 2, "y": 2}, "area") def test_weights_nan(self): - data_full = np.random.rand(4,4) - data_2_full = np.random.rand(4,4) - - weights_full = np.random.rand(4,4) - weights_full[0,1] = np.nan - - ds = xr.Dataset({ - 'data1':xr.DataArray(data_full, dims=['x','y']), - 'data2':xr.DataArray(data_2_full, dims=['x','y']), - }, - coords={'area':(['x','y'],weights_full)}, + data_full = np.random.rand(4, 4) + data_2_full = np.random.rand(4, 4) + + weights_full = np.random.rand(4, 4) + weights_full[0, 1] = np.nan + + ds = xr.Dataset( + { + "data1": xr.DataArray(data_full, dims=["x", "y"]), + "data2": xr.DataArray(data_2_full, dims=["x", "y"]), + }, + coords={"area": (["x", "y"], weights_full)}, ) - with pytest.raises(ValueError, match='Found missing values in 
weights coordinate '): - weighted_coarsen(ds, {'x':2, 'y':2}, 'area') - + with pytest.raises( + ValueError, match="Found missing values in weights coordinate " + ): + weighted_coarsen(ds, {"x": 2, "y": 2}, "area") def test_preserve_integral(self): - data_full = np.random.rand(4,4, 10) - data_2_full = np.random.rand(4,4, 10) + data_full = np.random.rand(4, 4, 10) + data_2_full = np.random.rand(4, 4, 10) - weights_full = np.random.rand(4,4) + weights_full = np.random.rand(4, 4) - ds = xr.Dataset({ - 'data1':xr.DataArray(data_full, dims=['x','y', 'time']), - 'data2':xr.DataArray(data_2_full, dims=['x','y', 'time']), - }, - coords={'area':(['x','y'],weights_full)}, + ds = xr.Dataset( + { + "data1": xr.DataArray(data_full, dims=["x", "y", "time"]), + "data2": xr.DataArray(data_2_full, dims=["x", "y", "time"]), + }, + coords={"area": (["x", "y"], weights_full)}, ) - ds_coarse = weighted_coarsen(ds, {'x':2, 'y':2}, 'area') + ds_coarse = weighted_coarsen(ds, {"x": 2, "y": 2}, "area") # We expect the weighted mean of both the original and coarsened dataset to stay the same - mean_fine = ds.weighted(ds.area).mean(['x','y']) - mean_coarse = ds_coarse.weighted(ds_coarse.area).mean(['x','y']) + mean_fine = ds.weighted(ds.area).mean(["x", "y"]) + mean_coarse = ds_coarse.weighted(ds_coarse.area).mean(["x", "y"]) xr.testing.assert_allclose(mean_fine, mean_coarse) - + def test_1d_coords(self): ds = dataset() - ds['area'] = xr.ones_like(ds.a.isel(time=0)) - - ds_coarse = weighted_coarsen(ds, {'x':2, 'y':4}, 'area') - - x_expected = ds.x.coarsen({'x':2}).mean() - y_expected = ds.y.coarsen({'y':4}).mean() - + ds["area"] = xr.ones_like(ds.a.isel(time=0)) + + ds_coarse = weighted_coarsen(ds, {"x": 2, "y": 4}, "area") + + x_expected = ds.x.coarsen({"x": 2}).mean() + y_expected = ds.y.coarsen({"y": 4}).mean() + xr.testing.assert_equal(x_expected, ds_coarse.x) xr.testing.assert_equal(y_expected, ds_coarse.y) - + def test_2d_coords(self): ds = dataset() - ds['area'] = xr.ones_like(ds.a.isel(time=0)) - - ds_coarse = weighted_coarsen(ds, {'x':2, 'y':4}, 'area') - - lon_expected = ds.reset_coords().lon.coarsen({'x':2, 'y':4}).mean() - lat_expected = ds.reset_coords().lat.coarsen({'x':2, 'y':4}).mean() - + ds["area"] = xr.ones_like(ds.a.isel(time=0)) + + ds_coarse = weighted_coarsen(ds, {"x": 2, "y": 4}, "area") + + lon_expected = ds.reset_coords().lon.coarsen({"x": 2, "y": 4}).mean() + lat_expected = ds.reset_coords().lat.coarsen({"x": 2, "y": 4}).mean() + xr.testing.assert_allclose(lon_expected, ds_coarse.reset_coords().lon) xr.testing.assert_allclose(lat_expected, ds_coarse.reset_coords().lat) - -
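
A quick usage sketch of the reformatted weighted_coarsen helper, mirroring Test_weighted_coarsen.test_preserve_integral above. The import path assumes the package is installed as described in the README; the toy field names and sizes are illustrative only and are not part of this patch.

import numpy as np
import xarray as xr

from scale_aware_air_sea.utils import weighted_coarsen

# Toy 4x4 field with a strictly positive cell-area coordinate
# (weighted_coarsen raises on NaN weights, and on <=0 weights wherever the variables are valid).
data = np.random.rand(4, 4)
area = np.random.rand(4, 4) + 0.1
ds = xr.Dataset(
    {"data": (("x", "y"), data)},
    coords={"area": (("x", "y"), area)},
)

# Coarsen by a factor of 2 in each horizontal dimension, weighting by "area".
ds_coarse = weighted_coarsen(ds, {"x": 2, "y": 2}, "area")

# The coarse cells carry the aggregated area, so the area-weighted mean
# (and hence the global integral) is preserved by the operation.
xr.testing.assert_allclose(
    ds.weighted(ds.area).mean(["x", "y"]),
    ds_coarse.weighted(ds_coarse.area).mean(["x", "y"]),
)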