Skip to content

Commit

Permalink
more future proofing
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed May 8, 2024
1 parent e76e44e commit 440581d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 57 deletions.
25 changes: 10 additions & 15 deletions SSINS/incoherent_noise_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class INS(UVFlag):
the UVFlag class, a member of the pyuvdata software package.
"""

def __init__(self, indata=None, history="", label="", use_future_array_shapes=False, run_check=True,
def __init__(self, indata=None, history="", label="", run_check=True,
check_extra=True, run_check_acceptability=True, order=0, mask_file=None,
match_events_file=None, spectrum_type="cross",
use_integration_weights=False, nsample_default=1, **kwargs):
Expand All @@ -32,8 +32,6 @@ def __init__(self, indata=None, history="", label="", use_future_array_shapes=Fa
saved INS object. If None, initializes an empty object.
history (str): History to append to object's history string.
label (str): String used for labeling the object (e.g. 'MWA Highband').
use_future_array_shapes (bool): Option to convert to the future planned array shapes before the changes go
into effect by removing the spectral window axis (potentially necessary for initializing from SS).
run_check (bool): Whether to check that the object's parameters have the right shape (default True).
check_extra (bool): Whether to also check optional parameters (default True)
run_check_acceptability (bool): Whether to check that the object's parameters take appropriate values
Expand All @@ -58,13 +56,12 @@ def __init__(self, indata=None, history="", label="", use_future_array_shapes=Fa

self.set_extra_params(order=order, spectrum_type=spectrum_type, use_integration_weights=use_integration_weights,
nsample_default=nsample_default, mask_file=mask_file, match_events_file=match_events_file)

super().__init__(indata=indata, mode='metric', copy_flags=False, waterfall=False, history=history, label=label,
use_future_array_shapes=use_future_array_shapes, run_check=run_check, check_extra=check_extra,
use_future_array_shapes=True, run_check=run_check, check_extra=check_extra,
run_check_acceptability=run_check_acceptability, **kwargs)


def read(self, filename, history="", use_future_array_shapes=False, run_check=True, check_extra=True,
def read(self, filename, history="", run_check=True, check_extra=True,
run_check_acceptability=True, **kwargs):
"""
Populate the object by reading a file. This is called during instantiation, but due to inheritance issues, is not
Expand All @@ -74,8 +71,6 @@ def read(self, filename, history="", use_future_array_shapes=False, run_check=Tr
Args:
filename (str): Path to the file to be read.
history (str): History to be appended to the object's history string.
use_future_array_shapes (bool): Whether to assume a spectral index axis -- should do nothing since all INS
objects should be written out in waterfall mode.
run_check (bool): Whether to check that the object's parameters have the right shape (default True).
check_extra (bool): Whether to also check optional parameters (default True)
run_check_acceptability (bool): Whether to check that the object's parameters take appropriate values
Expand All @@ -92,7 +87,8 @@ def read(self, filename, history="", use_future_array_shapes=False, run_check=Tr
attrs = ("order", "use_integration_weights", "nsample_default", "mask_file", "match_events_file", "spectrum_type", "spec_type_str")
attr_dict = {attr: deepcopy(getattr(self, attr)) for attr in attrs}

super().read(filename, history=history, use_future_array_shapes=use_future_array_shapes, run_check=run_check,
kwargs.pop("use_future_array_shapes", None)
super().read(filename, history=history, use_future_array_shapes=True, run_check=run_check,
check_extra=check_extra, run_check_acceptability=run_check_acceptability, **kwargs)

self._pol_check()
Expand All @@ -118,7 +114,7 @@ def read(self, filename, history="", use_future_array_shapes=False, run_check=Tr
self.metric_array.mask = self.weights_array == 0
else:
# Read in the flag array
flag_uvf = UVFlag(self.mask_file)
flag_uvf = UVFlag(self.mask_file, use_future_array_shapes=True)
self.metric_array.mask = np.copy(flag_uvf.flag_array)
del flag_uvf

Expand Down Expand Up @@ -210,7 +206,7 @@ def _pol_check(self):
" currently support pseudo-Stokes spectra.")

def from_uvdata(self, indata, mode="metric", copy_flags=False, waterfall=False, history="",
label="", use_future_array_shapes=False, run_check=True, check_extra=True,
label="", run_check=True, check_extra=True,
run_check_acceptability=True, **kwargs):
"""
Construct an INS object from a UVData (SS) object. This is called during instantiation, but due to inheritance
Expand All @@ -223,8 +219,6 @@ def from_uvdata(self, indata, mode="metric", copy_flags=False, waterfall=False,
copy_flags (bool): Does nothing -- for compatibility with base class.
waterfall (bool): Does nothing -- for compatibility with base class.
history (str): History to be appended to history string of object.
use_future_array_shapes (bool): Option to convert to the future planned array shapes before the changes go
into effect by removing the spectral window axis (potentially necessary for initializing from SS).
run_check (bool): Whether to check that the object's parameters have the right shape (default True).
check_extra (bool): Whether to also check optional parameters (default True)
run_check_acceptability (bool): Whether to check that the object's parameters take appropriate values
Expand All @@ -235,8 +229,9 @@ def from_uvdata(self, indata, mode="metric", copy_flags=False, waterfall=False,
self._has_data_params_check()
# Must be in metric mode, do not copy flags -- have own flag handling
# will turn to waterfall later. These are just here to match signature.
kwargs.pop("use_future_array_shapes", None)
super().from_uvdata(indata, mode="metric", copy_flags=False, waterfall=False,
history=history, label=label, use_future_array_shapes=use_future_array_shapes,
history=history, label=label, use_future_array_shapes=True,
run_check=run_check, check_extra=check_extra, run_check_acceptability=run_check_acceptability,
**kwargs)

Expand All @@ -255,7 +250,7 @@ def from_uvdata(self, indata, mode="metric", copy_flags=False, waterfall=False,
# Set nsample default if some are zero
indata.nsample_array[indata.nsample_array == 0] = self.nsample_default
# broadcast problems with single pol
self.weights_array *= (indata.integration_time[:, np.newaxis, np.newaxis, np.newaxis] * indata.nsample_array)
self.weights_array *= (indata.integration_time[:, np.newaxis, np.newaxis] * indata.nsample_array)

cross_bool = self.ant_1_array != self.ant_2_array
auto_bool = self.ant_1_array == self.ant_2_array
Expand Down
17 changes: 9 additions & 8 deletions SSINS/sky_subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def read(self, filename, diff=False, flag_choice=None, INS=None, custom=None,
warnings.warn("SS.read will be renamed to SS.read_data soon to avoid"
" conflicts with UVData.read.", category=PendingDeprecationWarning)

super().read(filename, **kwargs)
kwargs.pop("use_future_array_shapes", None)
super().read(filename, use_future_array_shapes=True, **kwargs)
if self.Nphase > 1:
raise NotImplementedError("SSINS cannot handle files with more than one phase center.")

Expand Down Expand Up @@ -92,7 +93,7 @@ def apply_flags(self, flag_choice=None, INS=None, custom=None):
# Skip if nothing to flag
if len(freq_inds) > 0:
blt_inds = np.where(self.time_array == time)
self.data_array.mask[blt_inds, :, freq_inds, pol_inds] = True
self.data_array.mask[blt_inds, freq_inds, pol_inds] = True
elif flag_choice == 'custom':
self.data_array.mask[:] = False
if custom is not None:
Expand Down Expand Up @@ -181,14 +182,14 @@ def diff(self, flag_choice=None, INS=None, custom=None):
blend = bltaxisboundaries2[bl_num + 1] # index in baseline-time axis to end

blt_slice = slice(blstart, blstart + len_diff)
self.data_array[blt_slice, :, :, :] = diff_dat
self.data_array[blt_slice] = diff_dat
"""The differenced visibilities. Complex array of shape (Nblts, Nfreqs, Npols)."""
self.flag_array[blt_slice, :, :, :] = diff_flags
self.flag_array[blt_slice] = diff_flags
"""The flag array, which results from boolean OR of the flags corresponding to visibilities that are differenced from one another."""

self.time_array[blt_slice] = diff_times
"""The center time of the differenced visibilities. Length Nblts."""
self.nsample_array[blt_slice, :, :, :] = diff_nsamples
self.nsample_array[blt_slice] = diff_nsamples
"""See pyuvdata documentation. Here we average the nsample_array of the visibilities that are differenced"""

where_bl = np.where(self.baseline_array == bl)
Expand Down Expand Up @@ -245,7 +246,7 @@ def MLE_calc(self):
frequency. Used for developing a mixture fit.
"""

self.MLE = np.sqrt(0.5 * np.mean(np.absolute(self.data_array)**2, axis=(0, 1, -1)))
self.MLE = np.sqrt(0.5 * np.mean(np.absolute(self.data_array)**2, axis=(0, -1)))

def mixture_prob(self, bins):
"""
Expand All @@ -265,7 +266,7 @@ def mixture_prob(self, bins):
if type(bins) == str:
_, bins = np.histogram(np.abs(self.data_array[np.logical_not(self.data_array.mask)]))

N_spec = np.sum(np.logical_not(self.data_array.mask), axis=(0, 1, -1))
N_spec = np.sum(np.logical_not(self.data_array.mask), axis=(0, -1))
N_total = np.sum(N_spec)

# Calculate the fraction belonging to each frequency
Expand Down Expand Up @@ -345,7 +346,7 @@ def write(self, filename_out, file_type_out, UV=None, filename_in=None,
self.reorder_blts(order='baseline')
if UV is None:
UV = UVData()
UV.read(filename_in, **read_kwargs)
UV.read(filename_in, use_future_array_shapes=True, **read_kwargs)

# Option to keep old flags
if not combine:
Expand Down
12 changes: 6 additions & 6 deletions SSINS/tests/test_INS.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,22 @@ def test_mask_to_flags(tmp_path, tv_obs, tv_testfile):
ss.read(tv_testfile, diff=True)

uvd = UVData()
uvd.read(tv_testfile)
uvd.read(tv_testfile, use_future_array_shapes=True)

uvf = UVFlag(uvd, mode='flag', waterfall=True)
uvf = UVFlag(uvd, mode='flag', waterfall=True, use_future_array_shapes=True)
# start with some flags so that we can test the intended OR operation
uvf.flag_array[6, :] = True
ins = INS(ss)

# Check error handling
with pytest.raises(ValueError):
bad_uvf = UVFlag(uvd, mode='metric', waterfall=True)
bad_uvf = UVFlag(uvd, mode='metric', waterfall=True, use_future_array_shapes=True)
err_uvf = ins.flag_uvf(uvf=bad_uvf)
with pytest.raises(ValueError):
bad_uvf = UVFlag(uvd, mode='flag', waterfall=False)
bad_uvf = UVFlag(uvd, mode='flag', waterfall=False, use_future_array_shapes=True)
err_uvf = ins.flag_uvf(uvf=bad_uvf)
with pytest.raises(ValueError):
bad_uvf = UVFlag(uvd, mode='flag', waterfall=True)
bad_uvf = UVFlag(uvd, mode='flag', waterfall=True, use_future_array_shapes=True)
# Pretend the data is off by 1 day
bad_uvf.time_array += 1
err_uvf = ins.flag_uvf(uvf=bad_uvf)
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_mask_to_flags(tmp_path, tv_obs, tv_testfile):

# Test write/read
ins.write(prefix, output_type='flags', uvf=uvf)
read_uvf = UVFlag(flags_outfile, mode='flag', waterfall=True)
read_uvf = UVFlag(flags_outfile, mode='flag', waterfall=True, use_future_array_shapes=True)
# Check equality
assert read_uvf == uvf, "UVFlag object differs after read"

Expand Down
74 changes: 49 additions & 25 deletions SSINS/tests/test_SS.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,20 @@ def test_diff(tv_testfile):
uv = UVData()

# Read in two times and two baselines of data, so that the diff is obvious.
uv.read(tv_testfile, read_data=False)
uv.read(tv_testfile, read_data=False, use_future_array_shapes=True)
times = np.unique(uv.time_array)[:2]
bls = [(uv.antenna_numbers[0], uv.antenna_numbers[1]),
(uv.antenna_numbers[0], uv.antenna_numbers[2])]
uv.read(tv_testfile, times=times, bls=bls)
if hasattr(uv, "telescope"):
bls = [
(uv.telescope.antenna_numbers[0], uv.telescope.antenna_numbers[1]),
(uv.telescope.antenna_numbers[0], uv.telescope.antenna_numbers[2])
]
else:
# This can be removed when we require pyuvdata>=3.0
bls = [
(uv.antenna_numbers[0], uv.antenna_numbers[1]),
(uv.antenna_numbers[0], uv.antenna_numbers[2])
]
uv.read(tv_testfile, times=times, bls=bls, use_future_array_shapes=True)
uv.reorder_blts(order='baseline')

diff_dat = uv.data_array[1::2] - uv.data_array[::2]
Expand All @@ -71,8 +80,26 @@ def test_diff(tv_testfile):
assert np.all(ss.nsample_array == diff_nsamples), "nsample_array is different!"
assert np.all(ss.integration_time == diff_ints), "Integration times are different"
assert np.all(ss.uvw_array == diff_uvw), "uvw_arrays disagree!"
assert np.all(ss.ant_1_array == np.array([uv.antenna_numbers[0], uv.antenna_numbers[0]])), f"ant_1_array disagrees!"
assert np.all(ss.ant_2_array == np.array([uv.antenna_numbers[1], uv.antenna_numbers[2]])), "ant_2_array disagrees!"
if hasattr(uv, "telescope"):
# Newer pyuvdata (presumably >=3.0) keeps antenna_numbers on uv.telescope; keep this branch.
assert np.all(
ss.ant_1_array == np.array(
[uv.telescope.antenna_numbers[0], uv.telescope.antenna_numbers[0]]
)
), "ant_1_array disagrees!"
assert np.all(
ss.ant_2_array == np.array(
[uv.telescope.antenna_numbers[1], uv.telescope.antenna_numbers[2]]
)
), "ant_2_array disagrees!"
else:
# This can be removed when we require pyuvdata>=3.0
assert np.all(
ss.ant_1_array == np.array([uv.antenna_numbers[0], uv.antenna_numbers[0]])
), "ant_1_array disagrees!"
assert np.all(
ss.ant_2_array == np.array([uv.antenna_numbers[1], uv.antenna_numbers[2]])
), "ant_2_array disagrees!"
assert np.all(ss.phase_center_app_dec == diff_pcad)
assert np.all(ss.phase_center_app_ra == diff_pcar)
assert np.all(ss.phase_center_frame_pa == diff_pcfp)
Expand Down Expand Up @@ -109,9 +136,9 @@ def test_apply_flags(tv_obs, tv_testfile):
ins = INS(insfile)
ins.metric_array.mask[[2, 4], 1, :] = True
ss.apply_flags(flag_choice='INS', INS=ins)
assert np.all(ss.data_array.mask[2::ss.Ntimes, :, 1, :]), "The 2nd time was not flagged."
assert np.all(ss.data_array.mask[4::ss.Ntimes, :, 1, :]), "The 4th time was not flagged."
assert not np.any(ss.data_array.mask[:, :, [0] + list(range(2, ss.Nfreqs)), :]), "Channels were flagged that should not have been."
assert np.all(ss.data_array.mask[2::ss.Ntimes, 1, :]), "The 2nd time was not flagged."
assert np.all(ss.data_array.mask[4::ss.Ntimes, 1, :]), "The 4th time was not flagged."
assert not np.any(ss.data_array.mask[:, [0] + list(range(2, ss.Nfreqs)), :]), "Channels were flagged that should not have been."
assert ss.flag_choice == 'INS'

# Make a bad time array to test an error
Expand Down Expand Up @@ -176,8 +203,8 @@ def test_rev_ind(tv_testfile):
ind = np.unravel_index(np.absolute(ss.data_array).argmax(), ss.data_array.shape)
# Convert the blt to a time index
t = ind[0] // ss.Nbls
f = ind[2]
p = ind[3]
f = ind[1]
p = ind[2]

# Make the waterfall histogram
wf_hist = ss.rev_ind(band)
Expand Down Expand Up @@ -208,7 +235,7 @@ def test_write(tmp_path, tv_testfile):

blt_inds = np.where(ss.time_array == np.unique(ss.time_array)[10])
custom = np.zeros_like(ss.data_array.mask)
custom[blt_inds, :, 64:128, :] = 1
custom[blt_inds, 64:128, :] = 1

# Flags the first time and no others
ss.apply_flags(flag_choice='custom', custom=custom)
Expand All @@ -219,17 +246,17 @@ def test_write(tmp_path, tv_testfile):

# Check if the flags propagated correctly
UV = UVData()
UV.read(outfile)
UV.read(outfile, use_future_array_shapes=True)
blt_inds = np.isin(UV.time_array, np.unique(UV.time_array)[10:12])
assert np.all(UV.flag_array[blt_inds, :, 64:128, :]), "Not all expected flags were propagated"
assert np.all(UV.flag_array[blt_inds, 64:128, :]), "Not all expected flags were propagated"

new_blt_inds = np.logical_not(np.isin(UV.time_array, np.unique(UV.time_array)[10:12]))
assert not np.any(UV.flag_array[new_blt_inds, :, 64:128, :]), "More flags were made than expected"
assert not np.any(UV.flag_array[new_blt_inds, 64:128, :]), "More flags were made than expected"

# Test bad read.
bad_uv_filepath = os.path.join(DATA_PATH, '1061312640_mix.uvfits')
bad_uv = UVData()
bad_uv.read(bad_uv_filepath)
bad_uv.read(bad_uv_filepath, use_future_array_shapes=True)
with pytest.raises(ValueError, match="UVData and SS objects were found to be incompatible."):
ss.write(outfile, 'uvfits', bad_uv)

Expand All @@ -242,15 +269,15 @@ def test_read_multifiles(tmp_path, tv_obs, tv_testfile):

# Read in a file's metadata and split it into two objects
uvd_full = UVData()
uvd_full.read(tv_testfile, read_data=False)
uvd_full.read(tv_testfile, read_data=False, use_future_array_shapes=True)
times1 = np.unique(uvd_full.time_array)[:14]
times2 = np.unique(uvd_full.time_array)[14:]

# Write two separate files to be read in later
uvd_split1 = UVData()
uvd_split2 = UVData()
uvd_split1.read(tv_testfile, times=times1)
uvd_split2.read(tv_testfile, times=times2)
uvd_split1.read(tv_testfile, times=times1, use_future_array_shapes=True)
uvd_split2.read(tv_testfile, times=times2, use_future_array_shapes=True)
uvd_split1.write_uvfits(new_fp1)
uvd_split2.write_uvfits(new_fp2)

Expand Down Expand Up @@ -284,19 +311,16 @@ def test_newmask(tv_testfile):


def test_Nphase_gt_1(tmp_path, tv_testfile):
uvd = UVData()
uvd.read(tv_testfile, read_data=False)
uvd = UVData.from_file(tv_testfile, read_data=False, use_future_array_shapes=True)

# Split the object so we can phase to separate locations
unique_times = np.unique(uvd.time_array)
first_times = unique_times[:10]
last_times = unique_times[-10:]

uvfirst = UVData()
uvfirst.read(tv_testfile, times=first_times)
uvfirst = UVData.from_file(tv_testfile, times=first_times, use_future_array_shapes=True)

uvlast = UVData()
uvlast.read(tv_testfile, times=last_times)
uvlast = UVData.from_file(tv_testfile, times=last_times, use_future_array_shapes=True)

# Adjust phase of one object and write new file
og_pc_ra = uvd.phase_center_app_ra[0]
Expand Down
5 changes: 2 additions & 3 deletions SSINS/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,10 @@ def test_write_meta(tmp_path):
testfile = os.path.join(DATA_PATH, f"{obs}.uvfits")
prefix = os.path.join(tmp_path, f"{obs}_test")

uvd = UVData()
uvd.read(testfile, freq_chans=np.arange(32))
uvd = UVData.from_file(testfile, freq_chans=np.arange(32), use_future_array_shapes=True)
ss = SS()
ss.read(testfile, freq_chans=np.arange(32), diff=True)
uvf = UVFlag(uvd, mode="flag", waterfall=True)
uvf = UVFlag(uvd, mode="flag", waterfall=True, use_future_array_shapes=True)
ins = INS(ss)

ins.metric_array[:] = 1
Expand Down

0 comments on commit 440581d

Please sign in to comment.