Skip to content

Commit

Permalink
edits to netcdf write output so netcdf collection phase can be done w…
Browse files Browse the repository at this point in the history
…ith spatial chunks
  • Loading branch information
bnb32 committed Nov 28, 2023
1 parent b32d7d6 commit 10b6ed4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
12 changes: 12 additions & 0 deletions sup3r/postprocessing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, file_paths):
"""
if not isinstance(file_paths, list):
file_paths = glob.glob(file_paths)
self.file_paths = file_paths
self.flist = sorted(file_paths)
self.data = None
self.file_attrs = {}
Expand Down Expand Up @@ -174,6 +175,17 @@ def collect(

logger.info('Finished file collection.')

def group_spatial_chunks(self):
    """Group files that share a spatial chunk so each group has the same
    spatial footprint but different time indices.

    Returns
    -------
    dict
        Mapping of collected output file name (``s_<chunk>.nc`` placed in
        the same directory as the inputs) to the list of source files in
        ``self.flist`` belonging to that spatial chunk.
    """
    chunks = {}
    for file in self.flist:
        # NOTE(review): assumes the spatial chunk index is the first
        # underscore-delimited token of the path — confirm against the
        # chunk file naming convention (an underscore anywhere in the
        # directory path would break this; os.path.basename may be safer).
        s_chunk = file.split('_')[0]
        dirname = os.path.dirname(file)
        s_file = os.path.join(dirname, f's_{s_chunk}.nc')
        # Bug fix: append the member file (``file``), not the group key
        # (``s_file``); the original built lists of the key repeated and
        # discarded the actual files to be collected.
        chunks[s_file] = [*chunks.get(s_file, []), file]
    return chunks


class CollectorH5(BaseCollector):
"""Sup3r H5 file collection framework"""
Expand Down
8 changes: 4 additions & 4 deletions sup3r/postprocessing/file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,13 +565,13 @@ def _write_output(cls, data, features, lat_lon, times, out_file,
List of coordinate indices used to label each lat lon pair and to
help with spatial chunk data collection
"""
coords = {'Times': (['Time'], [str(t).encode('utf-8') for t in times]),
'XLAT': (['south_north', 'east_west'], lat_lon[..., 0]),
'XLONG': (['south_north', 'east_west'], lat_lon[..., 1])}
coords = {'Time': [str(t).encode('utf-8') for t in times],
'south_north': lat_lon[:, 0, 0],
'west_east': lat_lon[0, :, 1]}

data_vars = {}
for i, f in enumerate(features):
data_vars[f] = (['Time', 'south_north', 'east_west'],
data_vars[f] = (['Time', 'south_north', 'west_east'],
np.transpose(data[..., i], (2, 0, 1)))

attrs = {}
Expand Down
4 changes: 2 additions & 2 deletions sup3r/preprocessing/data_handling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ def get_lat_lon(cls, file_paths, raster_index, invert_lat=False):
lat_lon = lat_lon[::-1]
# put angle between -180 and 180
lat_lon[..., 1] = (lat_lon[..., 1] + 180) % 360 - 180
return lat_lon
return lat_lon.astype(np.float32)

@classmethod
def get_node_cmd(cls, config):
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def run_all_data_init(self):
logger.info(f'Finished extracting data for {self.input_file_info} in '
f'{dt.now() - now}')

return self.data
return self.data.astype(np.float32)

def run_nn_fill(self):
"""Run nn nan fill on full data array."""
Expand Down
6 changes: 4 additions & 2 deletions sup3r/preprocessing/feature_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,9 +1077,11 @@ def compute(file_paths, raster_index):
fp = file_paths if isinstance(file_paths, str) else file_paths[0]
handle = xr.open_dataset(fp)
valid_vars = set(handle.variables)
lat_key = {'XLAT', 'lat', 'latitude'}.intersection(valid_vars)
lat_key = {'XLAT', 'lat', 'latitude', 'south_north'}.intersection(
valid_vars)
lat_key = next(iter(lat_key))
lon_key = {'XLONG', 'lon', 'longitude'}.intersection(valid_vars)
lon_key = {'XLONG', 'lon', 'longitude', 'west_east'}.intersection(
valid_vars)
lon_key = next(iter(lon_key))

if len(handle.variables[lat_key].dims) == 4:
Expand Down

0 comments on commit 10b6ed4

Please sign in to comment.