Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 18, 2024
1 parent ecd42b4 commit 0ec7655
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 55 deletions.
124 changes: 77 additions & 47 deletions simpleDS/delay_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,7 @@ def __init__(self, uv=None, uvb=None, trcvr=None, taper=None):
expected_units=_conversion_units,
)
_tconversion_units = (
(
units.mK**2
* units.Mpc**3
/ (units.K * units.sr * units.Hz**3) ** 2
),
(units.mK**2 * units.Mpc**3 / (units.K * units.sr * units.Hz**3) ** 2),
(
units.mK**2
* (units.Mpc / littleh) ** 3
Expand Down Expand Up @@ -1478,9 +1474,11 @@ def _select_preprocess(
"numbers and converting to antenna pairs."
)
bls = [
uvutils.baseline_to_antnums(bl, self.Nants_telescope)
if isinstance(bl, (int, np.int_, np.intc))
else bl
(
uvutils.baseline_to_antnums(bl, self.Nants_telescope)
if isinstance(bl, (int, np.int_, np.intc))
else bl
)
for bl in bls
]
else:
Expand Down Expand Up @@ -2316,9 +2314,11 @@ def _get_data(
nsamples = nsamples[slice]
else:
slice = tuple(
np.s_[:]
if _cnt != visdata_axes[axis_ind]
else (visdata_ind_set[axis_ind] or np.s_[:])
(
np.s_[:]
if _cnt != visdata_axes[axis_ind]
else (visdata_ind_set[axis_ind] or np.s_[:])
)
for _cnt in range(len(visdata_dset.shape))
)
if count == 0:
Expand Down Expand Up @@ -2363,9 +2363,11 @@ def _get_data(

if _power:
slice = tuple(
np.s_[:]
if _cnt != power_axes[axis_ind]
else (power_ind_set[axis_ind] or np.s_[:])
(
np.s_[:]
if _cnt != power_axes[axis_ind]
else (power_ind_set[axis_ind] or np.s_[:])
)
for _cnt in range(len(power_dset.shape))
)
if count == 0:
Expand All @@ -2377,9 +2379,11 @@ def _get_data(

if _noise:
slice = tuple(
np.s_[:]
if _cnt != power_axes[axis_ind]
else (power_ind_set[axis_ind] or np.s_[:])
(
np.s_[:]
if _cnt != power_axes[axis_ind]
else (power_ind_set[axis_ind] or np.s_[:])
)
for _cnt in range(len(noise_power_dset.shape))
)
if count == 0:
Expand All @@ -2406,9 +2410,11 @@ def _get_data(

for count, axis_ind in enumerate(power_inds):
slice = tuple(
np.s_[:]
if _cnt != power_axes[axis_ind]
else (power_ind_set[axis_ind] or np.s_[:])
(
np.s_[:]
if _cnt != power_axes[axis_ind]
else (power_ind_set[axis_ind] or np.s_[:])
)
for _cnt in range(len(thermal_dset.shape))
)
if count == 0:
Expand Down Expand Up @@ -3068,18 +3074,26 @@ def write_partial(self, filename):
non_reg = [np.ravel(indices[i]) for i in non_reg_inds]
for mesh_ind in product(*non_reg):
_inds = tuple(
indices[_cnt]
if _cnt in reg_spaced
else mesh_ind[np.nonzero(_cnt == non_reg_inds)[0].item()]
(
indices[_cnt]
if _cnt in reg_spaced
else mesh_ind[
np.nonzero(_cnt == non_reg_inds)[0].item()
]
)
for _cnt in range(len(visdata_dset.shape))
)
data_inds = tuple(
indices[_cnt]
if _cnt in reg_spaced
else np.nonzero(
non_reg[np.nonzero(_cnt == non_reg_inds)[0].item()]
== mesh_ind[np.nonzero(_cnt == non_reg_inds)[0].item()]
)[0].item()
(
indices[_cnt]
if _cnt in reg_spaced
else np.nonzero(
non_reg[np.nonzero(_cnt == non_reg_inds)[0].item()]
== mesh_ind[
np.nonzero(_cnt == non_reg_inds)[0].item()
]
)[0].item()
)
for _cnt in range(len(visdata_dset.shape))
)
visdata_dset[_inds] = self.data_array[data_inds].to_value(
Expand Down Expand Up @@ -3126,18 +3140,26 @@ def write_partial(self, filename):
non_reg = [np.ravel(indices[i]) for i in non_reg_inds]
for mesh_ind in product(*non_reg):
_inds = tuple(
indices[_cnt]
if _cnt in reg_spaced
else mesh_ind[np.nonzero(_cnt == non_reg_inds)[0].item()]
(
indices[_cnt]
if _cnt in reg_spaced
else mesh_ind[
np.nonzero(_cnt == non_reg_inds)[0].item()
]
)
for _cnt in range(len(visdata_dset.shape))
)
data_inds = tuple(
indices[_cnt]
if _cnt in reg_spaced
else np.nonzero(
non_reg[np.nonzero(_cnt == non_reg_inds)[0].item()]
== mesh_ind[np.nonzero(_cnt == non_reg_inds)[0].item()]
)[0].item()
(
indices[_cnt]
if _cnt in reg_spaced
else np.nonzero(
non_reg[np.nonzero(_cnt == non_reg_inds)[0].item()]
== mesh_ind[
np.nonzero(_cnt == non_reg_inds)[0].item()
]
)[0].item()
)
for _cnt in range(len(visdata_dset.shape))
)
data_power_dset[_inds] = self.power_array[data_inds].to_value(
Expand Down Expand Up @@ -3168,18 +3190,26 @@ def write_partial(self, filename):
non_reg = [np.ravel(indices[i]) for i in non_reg_inds]
for mesh_ind in product(*non_reg):
_inds = tuple(
indices[_cnt]
if _cnt in reg_spaced
else mesh_ind[np.nonzero(_cnt == non_reg_inds)[0].item()]
(
indices[_cnt]
if _cnt in reg_spaced
else mesh_ind[
np.nonzero(_cnt == non_reg_inds)[0].item()
]
)
for _cnt in range(len(thermal_dset.shape))
)
data_inds = tuple(
indices[_cnt]
if _cnt in reg_spaced
else np.nonzero(
non_reg[np.nonzero(_cnt == non_reg_inds)[0].item()]
== mesh_ind[np.nonzero(_cnt == non_reg_inds)[0].item()]
)[0].item()
(
indices[_cnt]
if _cnt in reg_spaced
else np.nonzero(
non_reg[np.nonzero(_cnt == non_reg_inds)[0].item()]
== mesh_ind[
np.nonzero(_cnt == non_reg_inds)[0].item()
]
)[0].item()
)
for _cnt in range(len(thermal_dset.shape))
)

Expand Down
2 changes: 1 addition & 1 deletion simpleDS/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def __eq__(self, other):
self_lower[key], other_lower[key]
):
values_close = False
except (TypeError):
except TypeError:
# this isn't a type that can be
# handled by np.isclose,
# test for equality
Expand Down
14 changes: 7 additions & 7 deletions simpleDS/tests/test_delay_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_properties(self):
this_param = getattr(self.dspec_object, v)
try:
assert rand_num == this_param.value
except (AssertionError):
except AssertionError:
print(
"setting {prop_name} to a random number failed".format(prop_name=k)
)
Expand Down Expand Up @@ -1072,9 +1072,7 @@ def test_delay_spectrum_thermal_power_units():

dspec_object.calculate_delay_spectrum()
dspec_object.add_trcvr(144 * units.K)
assert (units.mK**2 * units.Mpc**3).is_equivalent(
dspec_object.thermal_power.unit
)
assert (units.mK**2 * units.Mpc**3).is_equivalent(dspec_object.thermal_power.unit)


def test_delay_spectrum_thermal_power_shape():
Expand Down Expand Up @@ -1555,9 +1553,11 @@ def test_select(ds_uvfits_and_uvb, input):
}
if "bls" in uvd_input and not isinstance(uvd_input["bls"], tuple):
uvd_input["bls"] = [
uvd.baseline_to_antnums(bl)
if isinstance(bl, (int, np.int_, np.intc))
else bl
(
uvd.baseline_to_antnums(bl)
if isinstance(bl, (int, np.int_, np.intc))
else bl
)
for bl in uvd_input["bls"]
]

Expand Down

0 comments on commit 0ec7655

Please sign in to comment.