Skip to content

Commit

Permalink
skip W/DKI and SMI if max b-value < 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Jenny Chen authored and Jenny Chen committed Oct 14, 2024
1 parent 289a2ad commit 74b88cf
Showing 1 changed file with 94 additions and 81 deletions.
175 changes: 94 additions & 81 deletions designer2/tmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,24 @@ def execute(): #pylint: disable=unused-variable
bvec_dki = bvec[:,bval < maxb]
bval_dki = bval[bval < maxb]

if len(set(np.round(bval_dki, 2))) <= 2:
logger.warning("Fewer than 2 nonzero shells found for DKI fitting.", extra={"bval_dki": bval_dki.tolist()})
if np.max(bval_dki) < 2:
logger.warning("Max b-value is <2. Skipping DKI and WDKI fit.")
app.ARGS.DKI=False
app.ARGS.WDKI=False
else:
if len(set(np.round(bval_dki, 2))) <= 2:
logger.warning("Fewer than 2 nonzero shells found for DKI fitting.", extra={"bval_dki": bval_dki.tolist()})

grad_dki = np.hstack((bvec_dki.T, bval_dki[None,...].T))
dki = tensor.TensorFitting(grad_dki, int(app.ARGS.n_cores))
dt_dki, s0_dki, b_dki = dki.dki_fit(dwi_dki, mask, constraints=constraints)
logger.info("DKI fit completed.", extra={"dt_dki_shape": dt_dki.shape})
grad_dki = np.hstack((bvec_dki.T, bval_dki[None,...].T))
dki = tensor.TensorFitting(grad_dki, int(app.ARGS.n_cores))
dt_dki, s0_dki, b_dki = dki.dki_fit(dwi_dki, mask, constraints=constraints)
logger.info("DKI fit completed.", extra={"dt_dki_shape": dt_dki.shape})

dt_ = {}
dt_dki_ = vectorize(dt_dki, mask)
dt_['dt'] = dt_dki_
save_params(dt_, nii, model='dki', outdir=outdir)
logger.info("DKT saved.")
dt_ = {}
dt_dki_ = vectorize(dt_dki, mask)
dt_['dt'] = dt_dki_
save_params(dt_, nii, model='dki', outdir=outdir)
logger.info("DKT saved.")

if app.ARGS.polyreg:
if app.ARGS.WDKI:
Expand Down Expand Up @@ -420,21 +425,26 @@ def execute(): #pylint: disable=unused-variable
bvec_dki = bvec[:,bval < maxb]
bval_dki = bval[bval < maxb]

if len(set(np.round(bval_dki, 2))) <= 2:
logger.warning(f"Fewer than 2 nonzero shells found for DKI fitting at TE={te}.", extra={"bval_dki": bval_dki.tolist()})

grad = np.hstack((bvec_dki.T, bval_dki[None,...].T))
dki = tensor.TensorFitting(grad, int(app.ARGS.n_cores))
dt_dki, s0_dki, b_dki = dki.dki_fit(dwi_dki, mask, constraints=constraints)
logger.info(f"DKI fit completed for TE={te}.", extra={"dt_dki_shape": dt_dki.shape})

if app.ARGS.polyreg:
if app.ARGS.WDKI:
dt_poly_dki = dki.train_rotated_bayes_fit(dwi_dki, dt_dki, s0_dki, b_dki, mask)
logger.info(f"Polyreg WDKI fit completed for TE={te}.", extra={"dt_poly_dki_shape": dt_poly_dki.shape})
if app.ARGS.DTI:
dt_poly_dti = dti.train_rotated_bayes_fit(dwi_dti, dt_dti, s0_dti, b_dti, mask, True)
logger.info(f"Polyreg DTI fit completed for TE={te}.", extra={"dt_poly_dti_shape": dt_poly_dti.shape})
if np.max(bval_dki) < 2:
logger.warning("Max b-value is <2. Skipping DKI and WDKI fit.")
app.ARGS.DKI=False
app.ARGS.WDKI=False
else:
if len(set(np.round(bval_dki, 2))) <= 2:
logger.warning(f"Fewer than 2 nonzero shells found for DKI fitting at TE={te}.", extra={"bval_dki": bval_dki.tolist()})

grad = np.hstack((bvec_dki.T, bval_dki[None,...].T))
dki = tensor.TensorFitting(grad, int(app.ARGS.n_cores))
dt_dki, s0_dki, b_dki = dki.dki_fit(dwi_dki, mask, constraints=constraints)
logger.info(f"DKI fit completed for TE={te}.", extra={"dt_dki_shape": dt_dki.shape})

if app.ARGS.polyreg:
if app.ARGS.WDKI:
dt_poly_dki = dki.train_rotated_bayes_fit(dwi_dki, dt_dki, s0_dki, b_dki, mask)
logger.info(f"Polyreg WDKI fit completed for TE={te}.", extra={"dt_poly_dki_shape": dt_poly_dki.shape})
if app.ARGS.DTI:
dt_poly_dti = dti.train_rotated_bayes_fit(dwi_dti, dt_dti, s0_dti, b_dti, mask, True)
logger.info(f"Polyreg DTI fit completed for TE={te}.", extra={"dt_poly_dti_shape": dt_poly_dti.shape})

if app.ARGS.akc_outliers:
from lib.mpunits import vectorize
Expand Down Expand Up @@ -550,66 +560,69 @@ def execute(): #pylint: disable=unused-variable
warnings.simplefilter('always', UserWarning)

logger.info("Starting SMI fitting process...")

if not app.ARGS.sigma:
logger.warning("No sigma map provided. SMI may be poorly conditioned.")
sigma = None
if np.max(bval) < 2:
logger.warning("Max b-value is <2. Skipping SMI fit.")
app.ARGS.SMI=False
else:
sigma = image_read(path.from_user(app.ARGS.sigma)).numpy()
logger.info("Sigma map loaded.", extra={"sigma_shape": sigma.shape})

if app.ARGS.compartments:
compartments = app.ARGS.compartments
if type(compartments) == str:
compartments = compartments.split(",")
compartments = [str(i) for i in compartments]
if not app.ARGS.sigma:
logger.warning("No sigma map provided. SMI may be poorly conditioned.")
sigma = None
else:
raise MRtrixError(" Compartments must be a comma sepearated string (i.e. IAS,EAS)")
logger.info("SMI compartments specified.", extra={"compartments": compartments})
else:
compartments = ['IAS', 'EAS']
logger.info("Using default SMI compartments.", extra={"compartments": compartments})

if app.ARGS.lmax:
if int(app.ARGS.lmax) not in {0, 2, 4, 6}:
raise ValueError("lmax value must be 0, 2, 4, or 6.")
sigma = image_read(path.from_user(app.ARGS.sigma)).numpy()
logger.info("Sigma map loaded.", extra={"sigma_shape": sigma.shape})

if app.ARGS.compartments:
compartments = app.ARGS.compartments
if type(compartments) == str:
compartments = compartments.split(",")
compartments = [str(i) for i in compartments]
else:
raise MRtrixError(" Compartments must be a comma sepearated string (i.e. IAS,EAS)")
logger.info("SMI compartments specified.", extra={"compartments": compartments})
else:
compartments = ['IAS', 'EAS']
logger.info("Using default SMI compartments.", extra={"compartments": compartments})

if app.ARGS.lmax:
if int(app.ARGS.lmax) not in {0, 2, 4, 6}:
raise ValueError("lmax value must be 0, 2, 4, or 6.")
else:
lmax = int(app.ARGS.lmax)
logger.info("lmax for SMI specified.", extra={"lmax": lmax})
else:
lmax = int(app.ARGS.lmax)
logger.info("lmax for SMI specified.", extra={"lmax": lmax})
else:
lmax = None
logger.info("No lmax specified for SMI. Using default.")

logger.info("Initializing SMI fitting...")
echo_times = dwi_metadata['echo_time_per_volume']
if (np.min(echo_times) < 1.0) and (np.min(echo_times) > 0):
logger.info("Echo times in s, converting to ms.")
echo_times *= 1000

if multi_te_beta:
smi = SMI(bval=bval_orig, bvec=bvec_orig, rotinv_lmax=lmax,
compartments=compartments, echo_time=echo_times,
beta=dwi_metadata['bshape_per_volume'])

logger.info("SMI model initialized for multi-TE/beta data.")

params_smi = smi.fit(dwi_orig, mask=mask, sigma=sigma)
logger.info("SMI fitting completed for multi-TE/beta data.", extra={"params_smi_shape": {key: value.shape for key, value in params_smi.items()}})

save_params(params_smi, nii, model='smi', outdir=outdir)
logger.info("SMI parameters saved for multi-TE/beta data.", extra={"outdir": outdir})
else:
smi = SMI(bval=bval, bvec=bvec, rotinv_lmax=lmax,
compartments=compartments, echo_time=echo_times,
beta=dwi_metadata['bshape_per_volume'])

logger.info("SMI model initialized for single-TE/beta data.")
lmax = None
logger.info("No lmax specified for SMI. Using default.")

logger.info("Initializing SMI fitting...")
echo_times = dwi_metadata['echo_time_per_volume']
if (np.min(echo_times) < 1.0) and (np.min(echo_times) > 0):
logger.info("Echo times in s, converting to ms.")
echo_times *= 1000

if multi_te_beta:
smi = SMI(bval=bval_orig, bvec=bvec_orig, rotinv_lmax=lmax,
compartments=compartments, echo_time=echo_times,
beta=dwi_metadata['bshape_per_volume'])

logger.info("SMI model initialized for multi-TE/beta data.")

params_smi = smi.fit(dwi_orig, mask=mask, sigma=sigma)
logger.info("SMI fitting completed for multi-TE/beta data.", extra={"params_smi_shape": {key: value.shape for key, value in params_smi.items()}})

save_params(params_smi, nii, model='smi', outdir=outdir)
logger.info("SMI parameters saved for multi-TE/beta data.", extra={"outdir": outdir})
else:
smi = SMI(bval=bval, bvec=bvec, rotinv_lmax=lmax,
compartments=compartments, echo_time=echo_times,
beta=dwi_metadata['bshape_per_volume'])

logger.info("SMI model initialized for single-TE/beta data.")

params_smi = smi.fit(dwi, mask=mask, sigma=sigma)
logger.info("SMI fitting completed for single-TE/beta data.", extra={"params_smi_shape": {key: value.shape for key, value in params_smi.items()}})
params_smi = smi.fit(dwi, mask=mask, sigma=sigma)
logger.info("SMI fitting completed for single-TE/beta data.", extra={"params_smi_shape": {key: value.shape for key, value in params_smi.items()}})

save_params(params_smi, nii, model='smi', outdir=outdir)
logger.info("SMI parameters saved for single-TE/beta data.", extra={"outdir": outdir})
save_params(params_smi, nii, model='smi', outdir=outdir)
logger.info("SMI parameters saved for single-TE/beta data.", extra={"outdir": outdir})

logger.info("Execution completed successfully.")

Expand Down

0 comments on commit 74b88cf

Please sign in to comment.