Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl committed Jan 22, 2025
1 parent 82fe46e commit 925b2ac
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 47 deletions.
69 changes: 26 additions & 43 deletions src/pint/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ def step(self) -> np.ndarray:
# Note, here we could do various checks like report
# matrix condition number or zero out low singular values.
# print 'log_10 cond=', np.log10(s.max()/s.min())
# Note, Check the threshold from data precision level.Borrowed from
# Note, Check the threshold from data precision level. Borrowed from
# np Curve fit.
threshold = self.threshold
if threshold is None:
Expand Down Expand Up @@ -2421,20 +2421,20 @@ def make_resids(self, model):
"""Update the residuals. Run after updating a model parameter."""
return self.make_combined_residuals(add_args=self.additional_args, model=model)

def get_designmatrix(self):
design_matrixs = []
fit_params = self.model.free_params
if len(self.fit_data) == 1:
design_matrixs.extend(
dmatrix_maker(self.fit_data[0], self.model, fit_params, offset=True)
for dmatrix_maker in self.designmatrix_makers
)
else:
design_matrixs.extend(
dmatrix_maker(self.fit_data[ii], self.model, fit_params, offset=True)
for ii, dmatrix_maker in enumerate(self.designmatrix_makers)
)
return combine_design_matrices_by_quantity(design_matrixs)
# def get_designmatrix(self):
# design_matrixs = []
# fit_params = self.model.free_params
# if len(self.fit_data) == 1:
# design_matrixs.extend(
# dmatrix_maker(self.fit_data[0], self.model, fit_params, offset=True)
# for dmatrix_maker in self.designmatrix_makers
# )
# else:
# design_matrixs.extend(
# dmatrix_maker(self.fit_data[ii], self.model, fit_params, offset=True)
# for ii, dmatrix_maker in enumerate(self.designmatrix_makers)
# )
# return combine_design_matrices_by_quantity(design_matrixs)

def get_noise_covariancematrix(self):
# TODO This needs to be more general
Expand Down Expand Up @@ -2534,37 +2534,17 @@ def fit_toas(
fitpv = self.model.get_params_dict("free", "num")
fitperrs = self.model.get_params_dict("free", "uncertainty")

# Define the linear system
d_matrix = self.get_designmatrix()
M, params, units = (
d_matrix.matrix,
d_matrix.derivative_params,
d_matrix.param_units,
)
ntmpar = len(fitp)

# Get residuals and TOA uncertainties in seconds
if i == 0:
self.update_resids()
# Since the residuals may not have the same unit. Thus the residual here
# has no unit.
residuals = self.resids._combined_resids
residuals = self.resids.calc_combined_resids()

# get any noise design matrices and weight vectors
if not full_cov:
# We assume the fit date type is toa
Mn = self.noise_designmatrix_maker(self.toas, self.model)
phi = self.model.noise_model_basis_weight(self.toas)
phiinv = np.zeros(M.shape[1])
if Mn is not None and phi is not None:
phiinv = np.concatenate((phiinv, 1 / phi))
new_d_matrix = combine_design_matrices_by_param(d_matrix, Mn)
M, params, units = (
new_d_matrix.matrix,
new_d_matrix.derivative_params,
new_d_matrix.param_units,
)

ntmpar = len(fitp)
M, params, units, units_d = self.get_designmatrix(full=(not full_cov))

# normalize the design matrix
M, norm = normalize_designmatrix(M, params)
Expand All @@ -2577,14 +2557,15 @@ def fit_toas(
cm = scipy.linalg.cho_solve(cf, M)
mtcm = np.dot(M.T, cm)
mtcy = np.dot(cm.T, residuals)

else:
phiinv /= norm**2
Nvec = self.scaled_all_sigma() ** 2
phi = self.model.full_basis_weight(self.toas)
phiinv_norm = 1 / phi / norm**2

Nvec = self.scaled_all_sigma() ** 2
cinv = 1 / Nvec

mtcm = np.dot(M.T, cinv[:, None] * M)
mtcm += np.diag(phiinv)
mtcm += np.diag(phiinv_norm)
mtcy = np.dot(M.T, cinv * residuals)

xhat, xvar = None, None
Expand Down Expand Up @@ -2618,12 +2599,14 @@ def fit_toas(

xvar = np.dot(Vt.T / s, Vt)
xhat = np.dot(Vt.T, np.dot(U.T, mtcy) / s)

newres = residuals - np.dot(M, xhat)

# compute linearized chisq
if full_cov:
chi2 = np.dot(newres, scipy.linalg.cho_solve(cf, newres))
else:
chi2 = np.dot(newres, cinv * newres) + np.dot(xhat, phiinv * xhat)
chi2 = np.dot(newres, cinv * newres) + np.dot(xhat, phiinv_norm * xhat)

# compute absolute estimates, normalized errors, covariance matrix
dpars = xhat / norm
Expand Down
9 changes: 5 additions & 4 deletions src/pint/models/timing_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,10 +1795,11 @@ def full_designmatrix(self, toas: TOAs) -> Union[

M = np.hstack((M_tm, M_nm)) if M_nm is not None else M_tm

par.extend([f"_NOISE_{ii}" for ii in range(M_nm.shape[1])])
M_units.extend(np.repeat(u.dimensionless_unscaled, M_nm.shape[1]))
if toas.is_wideband():
M_units_d.extend(np.repeat(pint.dmu / u.s, M_nm.shape[1]))
if M_nm is not None:
par.extend([f"_NOISE_{ii}" for ii in range(M_nm.shape[1])])
M_units.extend(np.repeat(u.dimensionless_unscaled, M_nm.shape[1]))
if toas.is_wideband():
M_units_d.extend(np.repeat(pint.dmu / u.s, M_nm.shape[1]))

return (M, par, M_units, M_units_d) if toas.is_wideband() else (M, par, M_units)

Expand Down

0 comments on commit 925b2ac

Please sign in to comment.