Skip to content

Commit

Permalink
make poly fitting compatible with freq slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Wilensky committed Feb 22, 2024
1 parent 87f6210 commit 7c7e455
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions SSINS/incoherent_noise_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,16 @@ def mean_subtract(self, freq_slice=slice(None), return_coeffs=False):
"""


wt = np.where(np.logical_not(self.metric_array.mask), self.weights_array[:, freq_slice], 0)
wtsq = np.where(np.logical_not(self.metric_array.mask), self.weights_square_array[:, freq_slice], 0)
wt = np.where(np.logical_not(self.metric_array.mask), self.weights_array, 0)
wtsq = np.where(np.logical_not(self.metric_array.mask), self.weights_square_array, 0)
if np.any(wt > 0):
weights_factor = np.where(wt > 0, wt / np.sqrt(self.C * wtsq), 0)
if self.dmatr is None:
fitspec = np.ma.average(self.metric_array[:, freq_slice], axis=0, weights=wt)
fitspec = np.ma.average(self.metric_array[:, freq_slice],
axis=0, weights=wt[:, freq_slice])
else:
tmatr, fmatr = self.dmatr
data = self.metric_array[:, freq_slice].data
data = self.metric_array.data

wt_data = wt * data # shape tfp

Expand All @@ -482,49 +483,72 @@ def mean_subtract(self, freq_slice=slice(None), return_coeffs=False):
if fmatr is None: # Separates over frequency

# make the operator on the left-hand-side of normal equations
lhs_op = np.tensordot(wt, ttmatr, axes=((0, ), (0, ))) # shape fpAa
lhs_op = np.tensordot(wt[:, freq_slice], ttmatr,
axes=((0, ), (0, ))) # shape fpAa

# Make the vector on the rhs
rhs_vec = np.tensordot(wt_data, tmatr, axes=((0, ), (0, ))) # shape fpa
rhs_vec = np.tensordot(wt_data[:, freq_slice], tmatr,
axes=((0, ), (0, ))) # shape fpa

soln = np.linalg.solve(lhs_op, rhs_vec) # shape fpa
fitspec = np.tensordot(tmatr, soln, axes=((-1,), (-1,))) # shape tfp
else:
if freq_slice == slice(None): # Using the whole band
Nsb = self.Nsubband
Nfreqs_fit = self.Nfreqs
freq_slice_fit = slice(None)
freq_slice_into_fitspec = slice(None)
else: # FIXME: Have to compute which subbands should be refit. Should really precompute these.
low_ind = np.digitize(freq_slice.start, self.subband_freq_chans) - 1
high_ind = np.digitize(freq_slice.stop, self.subband_freq_chans)
Nsb = high_ind - low_ind

low_freq = self.subband_freq_chans[low_ind]
high_freq = self.subband_freq_chans[high_ind]
Nfreqs_fit = high_freq - low_freq

freq_slice_fit = slice(low_freq, high_freq)

# Invert the slice above to be compatible with what MF expects
freq_chans = np.arange(self.Nfreqs)
freq_chans_fit = freq_chans[freq_slice_fit]
low_ind = np.where(freq_chans_fit == freq_slice.start)[0][0]
high_ind = np.where(freq_chans_fit == freq_slice.stop)[0][0]
freq_slice_into_fitspec = slice(low_ind, high_ind)

new_shape = (self.Ntimes, self.Nsubband, self.Nfreq_sb,
self.Npols)
new_shape = (self.Ntimes, Nsb, self.Nfreq_sb, self.Npols)

# Make RHS vec by multiplying by design matrix transpose
wt_data_res = wt_data.reshape(new_shape)
wt_data_res = wt_data[:, freq_slice_fit].reshape(new_shape)
rhs_tmult = np.tensordot(wt_data_res, tmatr, axes=((0, ), (0, ))) # shape Nwpa
rhs_vec = np.tensordot(rhs_tmult, fmatr, axes=((1,), (0,))) # shape Npab
Ncoeff = (self.time_order + 1) * (self.freq_order + 1)
rhs_vec = rhs_vec.reshape(self.Nsubband, self.Npols, Ncoeff)
rhs_vec = rhs_vec.reshape(Nsb, self.Npols, Ncoeff)

# Make the lhs_op as above but with extra steps for freq axis
wt_res = wt.reshape(new_shape)
wt_res = wt[:, freq_slice_fit].reshape(new_shape)
ffmatr = fmatr[:, np.newaxis] * fmatr[:, :, np.newaxis] # shape fBb

lhs_tmult = np.tensordot(wt_res, ttmatr, axes=((0,), (0,))) # shape NwpAa
lhs_op = np.tensordot(lhs_tmult, ffmatr, axes=((1,), (0, ))) # shape NpAaBb
lhs_op = lhs_op.swapaxes(3, 4) # shape NpABab
lhs_op = lhs_op.reshape(self.Nsubband, self.Npols, Ncoeff, Ncoeff)
lhs_op = lhs_op.reshape(Nsb, self.Npols, Ncoeff, Ncoeff)

soln = np.linalg.solve(lhs_op, rhs_vec)
soln = soln.reshape(self.Nsubband, self.Npols,
soln = soln.reshape(Nsb, self.Npols,
self.time_order + 1, self.freq_order + 1)
fitspec_tmult = np.tensordot(tmatr, soln, axes=((1, ), (2, ))) # Shape tNpb
fitspec_res = np.tensordot(fmatr, fitspec_tmult, axes=((1, ), (3, ))) # shape wtNp
fitspec_res = fitspec_res.transpose(1, 2, 0, 3)

fitspec = fitspec_res.reshape(self.Ntimes, self.Nfreqs, self.Npols)
fitspec = fitspec_res.reshape(self.Ntimes, Nfreqs_fit, self.Npols)
fitspec = fitspec[:, freq_slice_into_fitspec]

MS = (self.metric_array[:, freq_slice] / fitspec - 1) * weights_factor

MS = (self.metric_array[:, freq_slice] / fitspec - 1) * weights_factor[:, freq_slice]
else: # Whole slice has been flagged. Don't rely on solve returning 0.
MS[:] = np.ma.masked



if return_coeffs:
return(MS, soln)
else:
Expand Down

0 comments on commit 7c7e455

Please sign in to comment.