Bnb/ddh norm #180

Merged: 4 commits, Nov 27, 2023
18 changes: 11 additions & 7 deletions sup3r/preprocessing/batch_handling.py
@@ -618,7 +618,8 @@ def _parallel_normalization(self):
         max_workers = self.norm_workers
         if max_workers == 1:
             for dh in self.data_handlers:
-                dh.normalize(self.means, self.stds)
+                dh.normalize(self.means, self.stds,
+                             max_workers=dh.norm_workers)
         else:
             with ThreadPoolExecutor(max_workers=max_workers) as exe:
                 futures = {}
@@ -691,7 +692,8 @@ def _get_stats(self):
                 future = exe.submit(dh._get_stats)
                 futures[future] = idh
 
-            for i, _ in enumerate(as_completed(futures)):
+            for i, future in enumerate(as_completed(futures)):
+                _ = future.result()
                 logger.debug(f'{i+1} out of {len(self.data_handlers)} '
                              'means calculated.')

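A note on the `future.result()` addition above: exceptions raised inside a `ThreadPoolExecutor` worker are stored on the future and only re-raised when `.result()` is called, so the old loop could let a failed `_get_stats` call pass silently. A minimal standalone sketch of this stdlib behavior (toy function, not sup3r code):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def might_fail(x):
        if x == 2:
            raise ValueError('bad input')
        return x * x

    with ThreadPoolExecutor(max_workers=4) as exe:
        futures = {exe.submit(might_fail, x): x for x in range(4)}
        for i, future in enumerate(as_completed(futures)):
            try:
                # .result() re-raises any exception from the worker thread;
                # without this call the ValueError above would go unnoticed.
                print(f'{i + 1} of {len(futures)} done: {future.result()}')
            except ValueError as e:
                print(f'worker for x={futures[future]} failed: {e}')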
@@ -731,10 +733,10 @@ def check_cached_stats(self):
         means_check = means_check and os.path.exists(self.means_file)
         if stdevs_check and means_check:
             logger.info(f'Loading stdevs from {self.stdevs_file}')
-            with open(self.stdevs_file, 'r') as fh:
+            with open(self.stdevs_file) as fh:
                 self.stds = json.load(fh)
             logger.info(f'Loading means from {self.means_file}')
-            with open(self.means_file, 'r') as fh:
+            with open(self.means_file) as fh:
                 self.means = json.load(fh)
 
         msg = ('The training features and cached statistics are '
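For context, the cached statistics are plain feature-keyed JSON dicts, so a round trip looks like the sketch below (hypothetical file name and values; the real handler also verifies that the cached features match the training features, which is what the msg being built above is for):

    import json

    means = {'U_100m': 4.2, 'V_100m': -0.3}

    with open('means.json', 'w') as fh:
        json.dump(means, fh)

    # Mode 'r' is the default, hence the simplification in the diff above.
    with open('means.json') as fh:
        assert json.load(fh) == means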
@@ -777,8 +779,7 @@ def _get_feature_means(self, feature):
         feature : str
             Feature to get mean for
         """
-
-        logger.debug(f'Calculating mean for {feature}')
+        logger.debug(f'Calculating multi-handler mean for {feature}')
         for idh, dh in enumerate(self.data_handlers):
             self.means[feature] += (self.handler_weights[idh]
                                     * dh.means[feature])
@@ -798,7 +799,7 @@ def _get_feature_stdev(self, feature):
             Feature to get stdev for
         """
 
-        logger.debug(f'Calculating stdev for {feature}')
+        logger.debug(f'Calculating multi-handler stdev for {feature}')
         for idh, dh in enumerate(self.data_handlers):
             variance = dh.stds[feature]**2
             self.stds[feature] += (variance * self.handler_weights[idh])
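For context, `_get_feature_means` and `_get_feature_stdev` combine per-handler statistics as weighted averages: the multi-handler mean is the weight-averaged handler mean, and the multi-handler variance is the weight-averaged handler variance, with the square root presumably taken after all handlers are accumulated (outside the hunks shown). A toy sketch with made-up numbers for three hypothetical handlers:

    import numpy as np

    # Hypothetical per-handler stats for one feature; handler_weights would
    # reflect each handler's share of the total data and must sum to 1.
    handler_means = [10.0, 12.0, 11.0]
    handler_stds = [2.0, 3.0, 2.5]
    handler_weights = [0.5, 0.3, 0.2]

    mean = sum(w * m for w, m in zip(handler_weights, handler_means))
    variance = sum(w * s**2 for w, s in zip(handler_weights, handler_stds))
    print(mean, np.sqrt(variance))  # 10.8, ~2.44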
@@ -823,6 +824,9 @@ def normalize(self, means=None, stds=None):
             feature names and values: standard deviations. if None, this will
             be calculated. if norm is true these will be used for data
             normalization
+        features : list | None
+            Optional list of features used to index data array during
+            normalization. If this is None self.features will be used.
         """
         if means is None or stds is None:
             self.get_stats()
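Since `means` and `stds` are feature-keyed dicts, the normalization itself amounts to standardizing each channel of the trailing feature axis. A minimal sketch of the semantics, with toy shapes and feature names (an illustration, not the actual `_normalize` implementation, which also handles validation data and threading):

    import numpy as np

    def normalize_array(data, features, means, stds):
        """Standardize each feature channel in place: (x - mean) / std."""
        for idf, fname in enumerate(features):
            data[..., idf] = (data[..., idf] - means[fname]) / stds[fname]
        return data

    # (spatial_1, spatial_2, temporal, features) array with two toy features
    data = np.random.rand(4, 4, 8, 2)
    features = ['U_100m', 'V_100m']
    means = {f: float(np.nanmean(data[..., i])) for i, f in enumerate(features)}
    stds = {f: float(np.nanstd(data[..., i])) for i, f in enumerate(features)}
    normalize_array(data, features, means, stds)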
112 changes: 20 additions & 92 deletions sup3r/preprocessing/data_handling/base.py
@@ -47,6 +47,7 @@
     estimate_max_workers,
     get_chunk_slices,
     get_raster_shape,
+    nn_fill_array,
     np_to_pd_times,
     spatial_coarsening,
     uniform_box_sampler,
@@ -449,17 +450,6 @@ def load_workers(self):
                                           n_procs)
         return load_workers
 
-    @property
-    def norm_workers(self):
-        """Get upper bound on workers used for normalization."""
-        if self.data is not None:
-            norm_workers = estimate_max_workers(self._norm_workers,
-                                                2 * self.feature_mem,
-                                                self.shape[-1])
-        else:
-            norm_workers = self._norm_workers
-        return norm_workers
-
     @property
     def time_chunks(self):
         """Get time chunks which will be extracted from source data
@@ -543,8 +533,7 @@ def get_handle_features(cls, file_paths):
         handle_features = []
         for f in file_paths:
             handle = cls.source_handler([f])
-            for r in handle:
-                handle_features.append(Feature.get_basename(r))
+            handle_features += [Feature.get_basename(r) for r in handle]
         return list(set(handle_features))
 
     @property
@@ -921,72 +910,6 @@ def get_cache_file_names(self,
                                        target,
                                        features)
 
-    @property
-    def means(self):
-        """Get the mean values for each feature.
-
-        Returns
-        -------
-        dict
-        """
-        self._get_stats()
-        return self._means
-
-    @property
-    def stds(self):
-        """Get the standard deviation values for each feature.
-
-        Returns
-        -------
-        dict
-        """
-        self._get_stats()
-        return self._stds
-
-    def _get_stats(self):
-        if self._means is None or self._stds is None:
-            msg = (f'DataHandler has {len(self.features)} features '
-                   f'and mismatched shape of {self.shape}')
-            assert len(self.features) == self.shape[-1], msg
-            self._stds = {}
-            self._means = {}
-            for idf, fname in enumerate(self.features):
-                self._means[fname] = np.nanmean(self.data[..., idf])
-                self._stds[fname] = np.nanstd(self.data[..., idf])
-
-    def normalize(self, means=None, stds=None, max_workers=None):
-        """Normalize all data features.
-
-        Parameters
-        ----------
-        means : dict | none
-            Dictionary of means for all features with keys: feature names and
-            values: mean values. If this is None, the self.means attribute will
-            be used. If this is not None, this DataHandler object means
-            attribute will be updated.
-        stds : dict | none
-            dictionary of standard deviation values for all features with keys:
-            feature names and values: standard deviations. If this is None, the
-            self.stds attribute will be used. If this is not None, this
-            DataHandler object stds attribute will be updated.
-        max_workers : None | int
-            Max workers to perform normalization. if None, self.norm_workers
-            will be used
-        """
-        if means is not None:
-            self._means = means
-        if stds is not None:
-            self._stds = stds
-
-        max_workers = max_workers or self.norm_workers
-        if self._is_normalized:
-            logger.info('Skipping DataHandler, already normalized')
-        else:
-            self._normalize(self.data,
-                            self.val_data,
-                            max_workers=max_workers)
-            self._is_normalized = True
-
     def get_next(self):
         """Get data for observation using random observation index. Loops
         repeatedly over randomized time index
@@ -1159,7 +1082,7 @@ def run_all_data_init(self):
         self.run_data_compute()
 
         logger.info('Building final data array')
-        self.parallel_data_fill(shifted_time_chunks, self.extract_workers)
+        self.data_fill(shifted_time_chunks, self.extract_workers)
 
         if self.invert_lat:
             self.data = self.data[::-1]
@@ -1182,8 +1105,16 @@ def run_all_data_init(self):

logger.info(f'Finished extracting data for {self.input_file_info} in '
f'{dt.now() - now}')

self.run_nn_fill()
return self.data

def run_nn_fill(self):
"""Run nn nan fill on full data array."""
for i in range(self.data.shape[-1]):
if np.isnan(self.data[..., i]).any():
self.data[..., i] = nn_fill_array(self.data[..., i])

def run_data_extraction(self):
"""Run the raw dataset extraction process from disk to raw
un-manipulated datasets.
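The new `run_nn_fill` step delegates to the `nn_fill_array` utility added to the imports above. A common recipe for this kind of nearest-neighbor NaN fill, and only a guess at how the utility is implemented, is `scipy.ndimage.distance_transform_edt` with `return_indices=True`, which maps every NaN cell to its nearest valid cell:

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    def nn_fill(arr):
        """Fill NaNs with the value of the nearest non-NaN element."""
        nan_mask = np.isnan(arr)
        # For each position, indices of the nearest non-NaN cell.
        indices = distance_transform_edt(nan_mask, return_distances=False,
                                         return_indices=True)
        return arr[tuple(indices)]

    x = np.array([[1.0, np.nan], [np.nan, 4.0]])
    print(nn_fill(x))  # NaNs replaced by their nearest valid neighbors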
@@ -1238,7 +1169,7 @@ def run_data_compute(self):
         logger.info(f'Finished computing {self.derive_features} for '
                     f'{self.input_file_info}')
 
-    def data_fill(self, t, t_slice, f_index, f):
+    def _single_data_fill(self, t, t_slice, f_index, f):
         """Place single extracted / computed chunk in final data array
 
         Parameters
@@ -1269,14 +1200,12 @@ def serial_data_fill(self, shifted_time_chunks):
         for t, ts in enumerate(shifted_time_chunks):
             for _, f in enumerate(self.noncached_features):
                 f_index = self.features.index(f)
-                self.data_fill(t, ts, f_index, f)
-            interval = int(np.ceil(len(shifted_time_chunks) / 10))
-            if t % interval == 0:
-                logger.info(f'Added {t + 1} of {len(shifted_time_chunks)} '
-                            'chunks to final data array')
+                self._single_data_fill(t, ts, f_index, f)
+            logger.info(f'Added {t + 1} of {len(shifted_time_chunks)} '
+                        'chunks to final data array')
             self._raw_data.pop(t)
 
-    def parallel_data_fill(self, shifted_time_chunks, max_workers=None):
+    def data_fill(self, shifted_time_chunks, max_workers=None):
Review thread on the data_fill rename:

Member:
Not a huge fan of this naming convention. I always find it confusing when there is a .method and a ._method. I think you're trying to hide the serial method from the "public" view, but it makes the developer's job more confusing. I'd suggest _single_data_fill() and _data_fill().

Collaborator Author:
I renamed this just because it includes the serial method as well, so it's not strictly "parallel".

Member:
Yeah, understood, but I think _single*() and _data_fill() would be the clearest, don't you think?
"""Fill final data array with extracted / computed chunks

Parameters
Expand Down Expand Up @@ -1304,13 +1233,13 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None):
             for t, ts in enumerate(shifted_time_chunks):
                 for _, f in enumerate(self.noncached_features):
                     f_index = self.features.index(f)
-                    future = exe.submit(self.data_fill, t, ts, f_index, f)
+                    future = exe.submit(self._single_data_fill,
+                                        t, ts, f_index, f)
                     futures[future] = {'t': t, 'fidx': f_index}
 
             logger.info(f'Started adding {len(futures)} chunks '
                         f'to data array in {dt.now() - now}.')
 
-            interval = int(np.ceil(len(futures) / 10))
             for i, future in enumerate(as_completed(futures)):
                 try:
                     future.result()
@@ -1320,9 +1249,8 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None):
                                'final data array.')
                     logger.exception(msg)
                     raise RuntimeError(msg) from e
-                if i % interval == 0:
-                    logger.debug(f'Added {i+1} out of {len(futures)} '
-                                 'chunks to final data array')
+                logger.debug(f'Added {i+1} out of {len(futures)} '
+                             'chunks to final data array')
             logger.info('Finished building data array')
 
     @abstractmethod
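The futures-to-metadata dict in this method is a useful pattern on its own: keying each future to its chunk and feature index lets a failure be re-raised with a message naming exactly which chunk broke. A generic, self-contained sketch of the pattern (stand-in function and names, not sup3r code):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def fill_chunk(t, f_index):
        # Stand-in for placing one (time chunk, feature) block in an array.
        return t, f_index

    with ThreadPoolExecutor(max_workers=4) as exe:
        futures = {}
        for t in range(3):
            for f_index in range(2):
                future = exe.submit(fill_chunk, t, f_index)
                futures[future] = {'t': t, 'fidx': f_index}

        for i, future in enumerate(as_completed(futures)):
            try:
                future.result()
            except Exception as e:
                meta = futures[future]
                raise RuntimeError(f'Error adding chunk {meta["t"]}, feature '
                                   f'{meta["fidx"]} to final data array') from e
            print(f'Added {i + 1} out of {len(futures)} chunks')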