Skip to content

Commit

Permalink
era_downloader test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 9, 2024
1 parent a181ae8 commit fd13f49
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions tests/utilities/test_era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import numpy as np
import pandas as pd

from sup3r.preprocessing.names import FEATURE_NAMES
from sup3r.utilities.era_downloader import EraDownloader
Expand All @@ -17,10 +18,25 @@ class EraDownloaderTester(EraDownloader):
# pylint: disable=unused-argument
@classmethod
def download_file(
cls, variables, out_file, level_type, levels=None, **kwargs
cls,
variables,
time_dict,
area, # noqa
out_file,
level_type,
levels=None,
**kwargs, # noqa
):
"""Download either single-level or pressure-level file"""
shape = (10, 10, 100)
n_days = pd.Period(
f'{time_dict["year"]}-{time_dict["month"]}-01'
).days_in_month
ti = pd.date_range(
f'{time_dict["year"]}-{time_dict["month"]}-01',
f'{time_dict["year"]}-{time_dict["month"]}-{n_days}',
freq='D',
)
shape = (10, 10, len(ti))
if levels is not None:
shape = (*shape, len(levels))

Expand All @@ -40,6 +56,7 @@ def download_file(
features.extend([v for f, v in name_map.items() if f in variables])

nc = make_fake_dset(shape=shape, features=features)
nc['time'] = ti
if 'z' in nc:
if level_type == 'single':
nc['z'] = (nc['z'].dims, np.zeros(nc['z'].shape))
Expand All @@ -48,7 +65,7 @@ def download_file(
for i in range(nc['z'].shape[1]):
arr[:, i, ...] = i * 100
nc['z'] = (nc['z'].dims, arr)
nc.to_netcdf(out_file, format='NETCDF4', engine='h5netcdf')
nc.to_netcdf(out_file)


def test_era_dl(tmpdir_factory):
Expand Down Expand Up @@ -98,11 +115,6 @@ def test_era_dl_year(tmpdir_factory):
yearly_file_pattern=yearly_file_pattern,
max_workers=1,
combine_all_files=True,
res_kwargs={
'engine': 'netcdf4',
'combine': 'nested',
'concat_dim': 'time',
},
)

combined_file = yearly_file_pattern.replace('_{var}_', '').format(
Expand Down

0 comments on commit fd13f49

Please sign in to comment.