Skip to content

Commit

Permalink
refactor finitization code
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Oct 22, 2024
1 parent c76035e commit 446074a
Showing 1 changed file with 66 additions and 36 deletions.
102 changes: 66 additions & 36 deletions py/specprodDB/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def convert(cls, data, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
#
# Version has no floating-point columns.
#
# data = finitize(data)
data_columns = list()
for column in cls.__table__.columns:
if column.name == 'id':
Expand Down Expand Up @@ -256,6 +260,7 @@ def convert(cls, data, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
expand_dchisq = ('dchisq_psf', 'dchisq_rex', 'dchisq_dev', 'dchisq_exp', 'dchisq_ser',)
data_columns = list()
for column in cls.__table__.columns:
Expand Down Expand Up @@ -353,6 +358,7 @@ def convert(cls, data, survey=None, tileid=None, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
check_columns = {'survey': survey, 'tileid': tileid}
for column in check_columns:
if check_columns[column] is None:
Expand Down Expand Up @@ -457,6 +463,7 @@ def convert(cls, data, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
data_columns = list()
for column in cls.__table__.columns:
data_column = data[column.name.upper()][row_index].tolist()
Expand Down Expand Up @@ -549,6 +556,7 @@ def convert(cls, data, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
data_columns = list()
for column in cls.__table__.columns:
if column.name == 'date_obs':
Expand Down Expand Up @@ -639,6 +647,7 @@ def convert(cls, data, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
data_columns = list()
for column in cls.__table__.columns:
if column.name == 'frameid':
Expand Down Expand Up @@ -727,6 +736,7 @@ def convert(cls, data, tileid=None, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
if tileid is None:
try:
tileid = data.meta['TILEID']
Expand Down Expand Up @@ -801,6 +811,10 @@ def convert(cls, data, tileid=None, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
#
# Potential table has no floating point columns.
#
# data = finitize(data)
if tileid is None:
try:
tileid = data.meta['TILEID']
Expand Down Expand Up @@ -980,6 +994,7 @@ def convert(cls, data, survey=None, program=None, row_index=None):
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
default_columns = {'spgrp': 'healpix',
'sv_nspec': 0, 'main_nspec': 0, 'zcat_nspec': 0,
'sv_primary': False, 'main_primary': False, 'zcat_primary': False}
Expand Down Expand Up @@ -1182,6 +1197,7 @@ def convert(cls, data, survey=None, program=None, tileid=None, night=None,
row_index = np.arange(len(data))
if len(row_index) == 0:
return []
data = finitize(data)
default_columns = {'spgrp': spgrp,
'sv_nspec': 0, 'main_nspec': 0, 'zcat_nspec': 0,
'sv_primary': False, 'main_primary': False, 'zcat_primary': False}
Expand Down Expand Up @@ -1292,8 +1308,56 @@ def deduplicate_targetid(data):
return load_rows


def load_file(filepaths, tcls, hdu=1, row_filter=None, q3c=None, chunksize=50000,
replacement_value=-9999.0):
def finitize(data, replacement_value=-9999.0):
"""Convert ``NaN`` and other non-finite floating point values.
Parameters
----------
data : :class:`~astropy.table.Table`
Data table to convert.
replacement_value : :class:`float`, optional
Replace ``NaN`` or other non-finite values with this value (default -9999.0).
Returns
-------
:class:`~astropy.table.Table`
The input `data` modified in-place.
"""
try:
colnames = data.names
except AttributeError:
colnames = data.colnames
masked = dict()
for col in colnames:
if data[col].dtype.kind == 'f':
if isinstance(data[col], MaskedColumn):
bad = ~np.isfinite(data[col].data.data)
masked[col] = True
else:
bad = ~np.isfinite(data[col])
if np.any(bad):
if bad.ndim == 1:
log.warning("%d rows of bad data detected in column " +
"%s.", bad.sum(), col)
elif bad.ndim == 2:
nbadrows = len(bad.sum(1).nonzero()[0])
nbaditems = bad.sum(1).sum()
log.warning("%d rows (%d items) of bad data detected in column " +
"%s.", nbadrows, nbaditems, col)
else:
log.warning("Bad data detected in high-dimensional column %s.", col)
if col in masked:
log.debug("data['%s'].data.data[bad] = %f", col, replacement_value)
log.debug("data['%s'].mask[bad] = False", col)
data[col].data.data[bad] = replacement_value
data[col].mask[bad] = False
else:
log.debug("data['%s'][bad] = %f", col, replacement_value)
data[col][bad] = replacement_value
return data


def load_file(filepaths, tcls, hdu=1, row_filter=None, q3c=None, chunksize=50000):
"""Load data file into the database, assuming that column names map
to database column names with no surprises.
Expand All @@ -1313,8 +1377,6 @@ def load_file(filepaths, tcls, hdu=1, row_filter=None, q3c=None, chunksize=50000
named `q3c`.
chunksize : :class:`int`, optional
If set, load database `chunksize` rows at a time (default 50000).
replacement_value : :class:`float`, optional
Replace ``NaN`` or other non-finite values with this value (default -9999.0).
Returns
-------
Expand All @@ -1339,38 +1401,6 @@ def load_file(filepaths, tcls, hdu=1, row_filter=None, q3c=None, chunksize=50000
else:
log.error("Unrecognized data file, %s!", filepath)
return
try:
colnames = data.names
except AttributeError:
colnames = data.colnames
masked = dict()
for col in colnames:
if data[col].dtype.kind == 'f':
if isinstance(data[col], MaskedColumn):
bad = ~np.isfinite(data[col].data.data)
masked[col] = True
else:
bad = ~np.isfinite(data[col])
if np.any(bad):
if bad.ndim == 1:
log.warning("%d rows of bad data detected in column " +
"%s of %s.", bad.sum(), col, filepath)
elif bad.ndim == 2:
nbadrows = len(bad.sum(1).nonzero()[0])
nbaditems = bad.sum(1).sum()
log.warning("%d rows (%d items) of bad data detected in column " +
"%s of %s.", nbadrows, nbaditems, col, filepath)
else:
log.warning("Bad data detected in high-dimensional column %s of %s.", col, filepath)
if col in masked:
log.debug("data['%s'].data.data[bad] = %f", col, replacement_value)
log.debug("data['%s'].mask[bad] = False", col)
data[col].data.data[bad] = replacement_value
data[col].mask[bad] = False
else:
log.debug("data['%s'][bad] = %f", col, replacement_value)
data[col][bad] = replacement_value
log.info("Integrity check complete on %s.", tn)
if row_filter is None:
good_rows = np.ones((len(data),), dtype=bool)
else:
Expand Down

0 comments on commit 446074a

Please sign in to comment.