From 7a47084c650e676e60bd04ad237971af818df1d8 Mon Sep 17 00:00:00 2001 From: Benjamin Aron Date: Mon, 12 Aug 2024 17:46:48 -0400 Subject: [PATCH] add optional maxb argument and add logging to tmi --- designer2/tmi.py | 232 +++++++++++++++++++++++++++++++++-------- docs/docs/TMI/usage.md | 10 ++ lib/tensor.py | 16 +-- 3 files changed, 208 insertions(+), 50 deletions(-) diff --git a/designer2/tmi.py b/designer2/tmi.py index 7121053f..dcd01caf 100644 --- a/designer2/tmi.py +++ b/designer2/tmi.py @@ -1,9 +1,66 @@ #!/usr/bin/env python3 import os +import logging +import json +from logging import StreamHandler, FileHandler from lib.designer_input_utils import get_input_info, convert_input_data, create_shell_table, assert_inputs from lib.designer_fit_wrappers import refit_or_smooth, save_params +# List of keys to exclude from logs +EXCLUDED_KEYS = { + "msg", "levelname", "levelno", "exc_info", "exc_text", + "stack_info", "created", "msecs", "relativeCreated", + "thread", "threadName", "processName", "process", "args" +} + +# Set up logging in JSON format +class JsonFormatter(logging.Formatter): + def format(self, record): + log_record = { + 'level': record.levelname, + 'time': self.formatTime(record, self.datefmt), + 'message': record.getMessage(), + 'name': record.name, + 'pathname': record.pathname, + 'lineno': record.lineno, + } + log_record.update({ + key: value for key, value in record.__dict__.items() + if key not in EXCLUDED_KEYS and key not in log_record + }) + return json.dumps(log_record) + +# Simple Formatter for StreamHandler (console output) +class SimpleFormatter(logging.Formatter): + def format(self, record): + if (record.levelname == "WARNING") or (record.levelname == "ERROR") or (record.levelname == "CRITICAL"): + log_message = f"{record.levelname} - {record.getMessage()}" + else: + log_message = f"... 
{record.getMessage()}" + extra_info = { + key: value for key, value in record.__dict__.items() + if key not in EXCLUDED_KEYS and key not in ['levelname', 'message', 'name', 'pathname', 'lineno', 'levelno', 'exc_info', 'exc_text', 'args', 'msg', 'filename', 'module', 'funcName'] + } + if extra_info: + log_message += f" | {extra_info}" + return log_message + +def setup_logging(output_dir): + stream_handler = StreamHandler() + stream_handler.setFormatter(SimpleFormatter()) + + # FileHandler to write logs to a JSON file + file_handler = FileHandler(f"{output_dir}/execution_log.json") + file_handler.setFormatter(JsonFormatter()) + + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + logger.addHandler(stream_handler) + logger.addHandler(file_handler) + + return logger + def usage(cmdline): #pylint: disable=unused-variable from mrtrix3 import app #pylint: disable=no-name-in-module, import-outside-toplevel @@ -61,16 +118,18 @@ def usage(cmdline): #pylint: disable=unused-variable dki_options.add_argument('-fit_constraints',metavar=(''),help='constrain the wlls fit (default 0,1,0)') dki_options.add_argument('-fit_smoothing',metavar=(''),help='NLM smoothing on wlls fit') dki_options.add_argument('-polyreg',action='store_true',help='polynomial regression based DKI estimation') + dki_options.add_argument('-maxb', metavar=(''),help='maximum b-value for DKI fitting, default=3.') smi_options = cmdline.add_argument_group('tensor options for the TMI script') smi_options.add_argument('-SMI', action='store_true',help='Perform estimation of SMI (standard model of Diffusion in White Matter). 
Please use in conjunction with the -bshape, -echo_time, -sigma, and -compartments options.') smi_options.add_argument('-compartments', metavar=(''),help='SMI compartments (IAS, EAS, and FW), default=IAS,EAS') smi_options.add_argument('-sigma', metavar=(''),help='path to noise map for SMI parameter estimation') - smi_options.add_argument('-lmax', metavar=(''),help='lmax for polynomial regression. must be 0,2,4, or 6.') + smi_options.add_argument('-lmax', metavar=(''),help='lmax for SMI polynomial regression. must be 0,2,4, or 6.') #wmti_options = cmdline.add_argument_group('tensor options for the TMI script') #wmti_options.add_argument('-WMTI', action='store_true', help='Include WMTI parameters in output folder (awf,ias_params,eas_params)') +logger = None def execute(): #pylint: disable=unused-variable from mrtrix3 import app, path, run, MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel @@ -79,19 +138,30 @@ def execute(): #pylint: disable=unused-variable from ants import image_read import pandas as pd + outdir = path.from_user(app.ARGS.output, True) + if not os.path.exists(outdir): + os.makedirs(outdir) + + logger = setup_logging(outdir) + logger.info(f"Output directory created: {outdir}") + + logger.info("Starting execution...") + app.make_scratch_dir() dwi_metadata = get_input_info( app.ARGS.input, app.ARGS.fslbval, app.ARGS.fslbvec, app.ARGS.bids) + logger.info("Input information obtained.", extra={"dwi_metadata": dwi_metadata}) assert_inputs(dwi_metadata, None, None) - convert_input_data(dwi_metadata) + logger.info("Input data converted successfully.") shell_table = create_shell_table(dwi_metadata) shell_rows = ['b-value', 'b-shape', 'n volumes', 'echo time'] shell_df = pd.DataFrame(data = shell_table, index = shell_rows) + print('input DWI data has properties:') print(shell_df) @@ -100,26 +170,28 @@ def execute(): #pylint: disable=unused-variable run.command('mrconvert dwi.mif -export_grad_fsl dwi.bvec dwi.bval dwi.nii', show=False) nii = 
image_read('dwi.nii') dwi = nii.numpy() + logger.info("DWI data converted to NIfTI format.") - outdir = path.from_user(app.ARGS.output, True) - if not os.path.exists(outdir): - os.makedirs(outdir) - - nii = image_read('dwi.nii') - dwi = nii.numpy() + # nii = image_read('dwi.nii') + # dwi = nii.numpy() bvec = np.loadtxt('dwi.bvec') bval = np.loadtxt('dwi.bval') + logger.info("Loaded bvec and bval data.", extra={"bvec_shape": bvec.shape, "bval_shape": bval.shape}) order = np.floor(np.log(abs(np.max(bval)+1)) / np.log(10)) if order >= 2: bval = bval / 1000 + logger.info("bval normalized to ms/um^2.", extra={"bval": list(set(np.round(bval, 2)))}) grad = np.hstack((bvec.T, bval[None,...].T)) + logger.info("Gradient table created.", extra={"grad_shape": grad.shape}) if app.ARGS.mask: mask = image_read(path.from_user(app.ARGS.mask)).numpy() + logger.info("Loaded mask from file.", extra={"mask_shape": mask.shape}) else: mask = np.ones(dwi.shape[:-1]) + logger.info("No mask provided. Using default mask with all ones.", extra={"mask_shape": mask.shape}) mask = mask.astype(bool) if app.ARGS.fit_constraints: @@ -129,35 +201,40 @@ def execute(): #pylint: disable=unused-variable constraints = [int(i) for i in constraints] else: raise MRtrixError("Constraints must be a 3 element comma separated string (i.e. 0,1,0)") + logger.info("Fit constraints provided.", extra={"constraints": constraints}) else: constraints = [0,0,0] + logger.info("No fit constraints provided. 
Using default constraints.", extra={"constraints": constraints}) if (len(set(dwi_metadata['bshape_per_volume'])) > 1) or (len(set(dwi_metadata['echo_time_per_volume'])) > 1): multi_te_beta = True dwi_orig = dwi.copy() bval_orig = bval.copy() bvec_orig = bvec.copy() + logger.info("Multi-TE or Multi-beta detected.", extra={"multi_te_beta": multi_te_beta}) else: multi_te_beta = False if len(set(dwi_metadata['bshape_per_volume'])) > 1: if (app.ARGS.DTI) or (app.ARGS.DKI) or (app.ARGS.WDKI): - print('For variable b-shape data DTI/DKI are run only on LTE part') + logger.info("Variable b-shape detected, filtering for LTE volumes.") lte_idx = (dwi_metadata['bshape_per_volume'] == 1) - if np.sum(lte_idx) < 6: + logger.error("Not enough LTE measurements for DTI/DKI.", extra={"lte_count": np.sum(lte_idx)}) raise MRtrixError("Not enough LTE measurements for DTI/DKI") dwi = dwi_orig[:,:,:,lte_idx] bval = bval_orig[lte_idx] bvec = bvec_orig[:,lte_idx] + logger.info("Filtered DWI, bval, and bvec for LTE volumes.", extra={"dwi_shape": dwi.shape}) if len(set(dwi_metadata['echo_time_per_volume'])) == 1: + logger.info("Single echo time detected.") if app.ARGS.DTI: from lib.mpunits import vectorize - print('...Single shell DTI fit...') + logger.info("Starting DTI fit...") dtishell = (bval <= 0.1) | ((bval > .5) & (bval <= 1.5)) dwi_dti = dwi[:,:,:,dtishell] bval_dti = bval[dtishell] @@ -165,52 +242,68 @@ def execute(): #pylint: disable=unused-variable grad_dti = np.hstack((bvec_dti.T, bval_dti[None,...].T)) dti = tensor.TensorFitting(grad_dti, int(app.ARGS.n_cores)) dt_dti, s0_dti, b_dti = dti.dti_fit(dwi_dti, mask) + logger.info("DTI fit completed.", extra={"dt_dti_shape": dt_dti.shape}) dt_ = {} dt_dti_ = vectorize(dt_dti, mask) dt_['dt'] = dt_dti_ save_params(dt_, nii, model='dti', outdir=outdir) + logger.info("DT saved.") if app.ARGS.DKI or app.ARGS.WDKI: from lib.mpunits import vectorize - print('...Multi shell DKI fit with constraints = ' + str(constraints)) - maxb = 2.5 + 
logger.info("Starting DKI fit...") + + if app.ARGS.maxb: + maxb = float(app.ARGS.maxb) + logger.info("Maximum b-value set.", extra={"maxb": maxb}) + else: + maxb = 3.01 + dwi_dki = dwi[:,:,:,bval < maxb] bvec_dki = bvec[:,bval < maxb] bval_dki = bval[bval < maxb] + if len(set(np.round(bval_dki, 2))) <= 2: + logger.warning("Fewer than 2 nonzero shells found for DKI fitting.", extra={"bval_dki": bval_dki.tolist()}) + grad_dki = np.hstack((bvec_dki.T, bval_dki[None,...].T)) dki = tensor.TensorFitting(grad_dki, int(app.ARGS.n_cores)) dt_dki, s0_dki, b_dki = dki.dki_fit(dwi_dki, mask, constraints=constraints) + logger.info("DKI fit completed.", extra={"dt_dki_shape": dt_dki.shape}) dt_ = {} dt_dki_ = vectorize(dt_dki, mask) dt_['dt'] = dt_dki_ save_params(dt_, nii, model='dki', outdir=outdir) + logger.info("DKT saved.") if app.ARGS.polyreg: if app.ARGS.WDKI: dt_poly_dki = dki.train_rotated_bayes_fit(dwi_dki, dt_dki, s0_dki, b_dki, mask) + logger.info("Polyreg WDKI fit completed.", extra={"dt_poly_dki_shape": dt_poly_dki.shape}) if app.ARGS.DTI: - dt_poly_dti = dti.train_rotated_bayes_fit(dwi_dti, dt_dti, s0_dti, b_dti, mask, 'True') + dt_poly_dti = dti.train_rotated_bayes_fit(dwi_dti, dt_dti, s0_dti, b_dti, mask, True) + logger.info("Polyreg DTI fit completed.", extra={"dt_poly_dti_shape": dt_poly_dti.shape}) if app.ARGS.akc_outliers: from lib.mpunits import vectorize import scipy.io as sio + logger.info("Starting AKC outlier detection...") dwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) mat = sio.loadmat(os.path.join(dwd,'constant','dirs10000.mat')) dir = mat['dir'] - print('...Outlier detection...') if not (app.ARGS.DKI or app.ARGS.WDKI): + logger.error("AKC Outlier detection must be accompanied by DKI option") raise MRtrixError("AKC Outlier detection must be accompanied by DKI option") else: akc_mask = dki.outlierdetection(dt_dki, mask, dir) akc_mask = vectorize(akc_mask, mask).astype(bool) - print('N outliers = %s' % (np.sum(akc_mask))) + 
logger.info("Outlier detection completed.", extra={"num_outliers": np.sum(akc_mask)}) dwi_new = refit_or_smooth(akc_mask, dwi_dki, n_cores=int(app.ARGS.n_cores)) dt_new,_,_ = dki.dki_fit(dwi_new, akc_mask) @@ -221,18 +314,20 @@ def execute(): #pylint: disable=unused-variable dt_dki = vectorize(DT, mask) akc_mask = dki.outlierdetection(dt_dki, mask, dir) akc_mask = vectorize(akc_mask, mask).astype(bool) - print('N outliers = %s' % (np.sum(akc_mask))) + logger.info("AKC outlier post-processing completed.", extra={"num_outliers": np.sum(akc_mask)}) else: akc_mask = np.zeros_like(mask) if app.ARGS.fit_smoothing: - print('...Nonlocal smoothing...') + logger.info("Starting tensor based smoothing...") if app.ARGS.DTI: dwi_new = refit_or_smooth(akc_mask, dwi_dti, mask=mask, smoothlevel=int(app.ARGS.fit_smoothing)) dt_dti,_,_ = dti.dti_fit(dwi_new, mask) + logger.info("DTI fit after smoothing completed.", extra={"dt_dti_shape": dt_dti.shape}) if (app.ARGS.DKI or app.ARGS.WDKI): dwi_new = refit_or_smooth(akc_mask, dwi_dki, mask=mask, smoothlevel=int(app.ARGS.fit_smoothing)) dt_dki,_,_ = dki.dki_fit(dwi_new, mask) + logger.info("DKI fit after smoothing completed.", extra={"dt_dki_shape": dt_dki.shape}) if app.ARGS.DTI or app.ARGS.DKI or app.ARGS.polyreg or app.ARGS.WDKI: rdwi = tensor.vectorize(dwi, mask) @@ -250,33 +345,41 @@ def execute(): #pylint: disable=unused-variable trace = tensor.vectorize(trace.T, mask) params_trace = {'trace': trace} save_params(params_trace, nii, model='allshells', outdir=outdir) + logger.info("Trace parameters saved for all shells.", extra={"trace_shape": trace.shape}) if app.ARGS.DTI: - print('...extracting and saving DTI maps...') + logger.info("Extracting and saving DTI maps...") params_dti = dti.extract_parameters(dt_dti, b_dti, mask, extract_dti=True, extract_dki=False, fit_w=False) save_params(params_dti, nii, model='dti', outdir=outdir) + logger.info("DTI maps saved.") if app.ARGS.DKI: - print('...extracting and saving DKI maps...') + 
logger.info("Extracting and saving DKI maps...") params_dki = dki.extract_parameters(dt_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=False) save_params(params_dki, nii, model='dki', outdir=outdir) + logger.info("DKI maps saved.") if app.ARGS.polyreg: if app.ARGS.WDKI: - params_dki_poly = dki.extract_parameters(dt_poly_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=False) - save_params(params_dki_poly, nii, model='dki_poly', outdir=outdir) + logger.info("Extracting and saving polyreg WDKI maps...") + params_dki_poly = dki.extract_parameters(dt_poly_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=True) + save_params(params_dki_poly, nii, model='wdki_poly', outdir=outdir) + logger.info("Polyreg WDKI maps saved.") if app.ARGS.DTI: + logger.info("Extracting and saving polyreg DTI maps...") params_dti_poly = dti.extract_parameters(dt_poly_dti, b_dti, mask, extract_dti=True, extract_dki=False, fit_w=False) save_params(params_dti_poly, nii, model='dti_poly', outdir=outdir) + logger.info("Polyreg DTI maps saved.") if app.ARGS.WDKI: - print('...extracting and saving WDKI maps...') - params_dwi = dki.extract_parameters(dt_dki, b_dki, mask, extract_dti=False, extract_dki=True, fit_w=True) + logger.info("Extracting and saving WDKI maps...") + params_dwi = dki.extract_parameters(dt_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=True) save_params(params_dwi, nii, model='wdki', outdir=outdir) - + logger.info("WDKI maps saved.") else: for te in set(dwi_metadata['echo_time_per_volume']): + logger.info(f"Processing echo time {te}...") te_idx = (dwi_metadata['echo_time_per_volume'] == te) & (dwi_metadata['bshape_per_volume'] == 1) dwi = dwi_orig[:,:,:,te_idx] @@ -286,9 +389,10 @@ def execute(): #pylint: disable=unused-variable if app.ARGS.DTI: if np.sum(te_idx) < 6: + logger.warning(f"Fewer than 6 measurements found for TE={te}. 
Skipping DTI fit.") continue - - print('...Single shell DTI fit for TE=' + str(te) + '...') + + logger.info(f"Starting DTI fit for TE={te}...") dtishell = (bval <= 0.1) | ((bval > .5) & (bval <= 1.5)) dwi_dti = dwi[:,:,:,dtishell] bval_dti = bval[dtishell] @@ -296,45 +400,59 @@ def execute(): #pylint: disable=unused-variable grad_dti = np.hstack((bvec_dti.T, bval_dti[None,...].T)) dti = tensor.TensorFitting(grad_dti, int(app.ARGS.n_cores)) dt_dti, s0_dti, b_dti = dti.dti_fit(dwi_dti, mask) + logger.info(f"DTI fit completed for TE={te}.", extra={"dt_dti_shape": dt_dti.shape}) if app.ARGS.DKI or app.ARGS.WDKI: if np.sum(te_idx) < 21: + logger.warning(f"Fewer than 21 measurements found for TE={te}. Skipping DKI fit.") continue - print('...Multi shell DKI fit for TE=' + str(te) + ' with constraints = ' + str(constraints)) + logger.info(f"Starting DKI fit for TE={te} with constraints = {constraints}...") + + if app.ARGS.maxb: + maxb = float(app.ARGS.maxb) + logger.info(f"Maximum b-value set to {maxb} for TE={te}.") + else: + maxb = 3.01 - maxb = 2.5 dwi_dki = dwi[:,:,:,bval < maxb] bvec_dki = bvec[:,bval < maxb] bval_dki = bval[bval < maxb] + if len(set(np.round(bval_dki, 2))) <= 2: + logger.warning(f"Fewer than 2 nonzero shells found for DKI fitting at TE={te}.", extra={"bval_dki": bval_dki.tolist()}) + grad = np.hstack((bvec_dki.T, bval_dki[None,...].T)) dki = tensor.TensorFitting(grad, int(app.ARGS.n_cores)) dt_dki, s0_dki, b_dki = dki.dki_fit(dwi_dki, mask, constraints=constraints) + logger.info(f"DKI fit completed for TE={te}.", extra={"dt_dki_shape": dt_dki.shape}) if app.ARGS.polyreg: if app.ARGS.WDKI: dt_poly_dki = dki.train_rotated_bayes_fit(dwi_dki, dt_dki, s0_dki, b_dki, mask) + logger.info(f"Polyreg WDKI fit completed for TE={te}.", extra={"dt_poly_dki_shape": dt_poly_dki.shape}) if app.ARGS.DTI: - dt_poly_dti = dti.train_rotated_bayes_fit(dwi_dti, dt_dti, s0_dti, b_dti, mask, 'True') + dt_poly_dti = dti.train_rotated_bayes_fit(dwi_dti, dt_dti, s0_dti, 
b_dti, mask, True) + logger.info(f"Polyreg DTI fit completed for TE={te}.", extra={"dt_poly_dti_shape": dt_poly_dti.shape}) if app.ARGS.akc_outliers: from lib.mpunits import vectorize import scipy.io as sio + logger.info(f"Starting AKC outlier detection for TE={te}...") dwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) mat = sio.loadmat(os.path.join(dwd,'constant','dirs10000.mat')) dir = mat['dir'] - print('...Outlier detection...') if not (app.ARGS.DKI or app.ARGS.WDKI): + logger.error(f"AKC Outlier detection for TE={te} must be accompanied by DKI option.") raise MRtrixError("AKC Outlier detection must be accompanied by DKI option") else: akc_mask = dki.outlierdetection(dt_dki, mask, dir) akc_mask = vectorize(akc_mask, mask).astype(bool) - print('N outliers = %s' % (np.sum(akc_mask))) + logger.info(f"Outlier detection completed for TE={te}.", extra={"num_outliers": np.sum(akc_mask)}) dwi_new = refit_or_smooth(akc_mask, dwi_dki, n_cores=int(app.ARGS.n_cores)) dt_new,_,_ = dki.dki_fit(dwi_new, akc_mask) @@ -345,18 +463,20 @@ def execute(): #pylint: disable=unused-variable dt_dki = vectorize(DT, mask) akc_mask = dki.outlierdetection(dt_dki, mask, dir) akc_mask = vectorize(akc_mask, mask).astype(bool) - print('N outliers = %s' % (np.sum(akc_mask))) + logger.info(f"AKC outlier post-processing completed for TE={te}.", extra={"num_outliers": np.sum(akc_mask)}) else: akc_mask = np.zeros_like(mask) if app.ARGS.fit_smoothing: - print('...Nonlocal smoothing...') + logger.info(f"Starting tensor based smoothing for TE={te}...") if app.ARGS.DTI: dwi_new = refit_or_smooth(akc_mask, dwi_dti, mask=mask, smoothlevel=int(app.ARGS.fit_smoothing)) dt_dti,_,_ = dti.dti_fit(dwi_new, mask) + logger.info(f"DTI fit after smoothing completed for TE={te}.", extra={"dt_dti_shape": dt_dti.shape}) if (app.ARGS.DKI or app.ARGS.WDKI): dwi_new = refit_or_smooth(akc_mask, dwi_dki, mask=mask, smoothlevel=int(app.ARGS.fit_smoothing)) dt_dki,_,_ = dki.dki_fit(dwi_new, mask) + 
logger.info(f"DKI fit after smoothing completed for TE={te}.", extra={"dt_dki_shape": dt_dki.shape}) if app.ARGS.DTI or app.ARGS.DKI or app.ARGS.polyreg or app.ARGS.WDKI: rdwi = tensor.vectorize(dwi, mask) @@ -374,30 +494,37 @@ def execute(): #pylint: disable=unused-variable trace = tensor.vectorize(trace.T, mask) params_trace = {'trace': trace} save_params(params_trace, nii, model='te'+str(te)+'_shells', outdir=outdir) + logger.info(f"Trace parameters saved for TE={te}.") if app.ARGS.DTI: - print('...extracting and saving DTI maps...') + logger.info(f"Extracting and saving DTI maps for TE={te}...") params_dti = dti.extract_parameters(dt_dti, b_dti, mask, extract_dti=True, extract_dki=False, fit_w=False) save_params(params_dti, nii, model='dti_te'+str(te), outdir=outdir) + logger.info(f"DTI maps saved for TE={te}.") if app.ARGS.DKI: - print('...extracting and saving DKI maps...') + logger.info(f"Extracting and saving DKI maps for TE={te}...") params_dki = dki.extract_parameters(dt_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=False) save_params(params_dki, nii, model='dki_te'+str(te), outdir=outdir) + logger.info(f"DKI maps saved for TE={te}.") if app.ARGS.polyreg: if app.ARGS.WDKI: - params_dki_poly = dki.extract_parameters(dt_poly_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=False) - save_params(params_dki_poly, nii, model='dki_poly_te'+str(te), outdir=outdir) + logger.info(f"Extracting and saving polyreg WDKI maps for TE={te}...") + params_dki_poly = dki.extract_parameters(dt_poly_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=True) + save_params(params_dki_poly, nii, model='wdki_poly_te'+str(te), outdir=outdir) + logger.info(f"Polyreg WDKI maps saved for TE={te}.") if app.ARGS.DTI: + logger.info(f"Extracting and saving polyreg DTI maps for TE={te}...") params_dti_poly = dti.extract_parameters(dt_poly_dti, b_dti, mask, extract_dti=True, extract_dki=False, fit_w=False) save_params(params_dti_poly, nii, 
model='dti_poly_te'+str(te), outdir=outdir) + logger.info(f"Polyreg DTI maps saved for TE={te}.") if app.ARGS.WDKI: - print('...extracting and saving WDKI maps...') - params_dwi = dki.extract_parameters(dt_dki, b_dki, mask, extract_dti=False, extract_dki=True, fit_w=True) + logger.info(f"Extracting and saving WDKI maps for TE={te}...") + params_dwi = dki.extract_parameters(dt_dki, b_dki, mask, extract_dti=True, extract_dki=True, fit_w=True) save_params(params_dwi, nii, model='wdki_te'+str(te), outdir=outdir) - + logger.info(f"WDKI maps saved for TE={te}.") # if app.ARGS.WMTI: # import dipy.reconst.dki as dki @@ -422,11 +549,14 @@ def execute(): #pylint: disable=unused-variable import warnings warnings.simplefilter('always', UserWarning) + logger.info("Starting SMI fitting process...") + if not app.ARGS.sigma: - warnings.warn('SMI is poorly conditioned without prior estimate of sigma') + logger.warning("No sigma map provided. SMI may be poorly conditioned.") sigma = None else: sigma = image_read(path.from_user(app.ARGS.sigma)).numpy() + logger.info("Sigma map loaded.", extra={"sigma_shape": sigma.shape}) if app.ARGS.compartments: compartments = app.ARGS.compartments @@ -435,32 +565,48 @@ def execute(): #pylint: disable=unused-variable compartments = [str(i) for i in compartments] else: raise MRtrixError(" Compartments must be a comma sepearated string (i.e. IAS,EAS)") + logger.info("SMI compartments specified.", extra={"compartments": compartments}) else: compartments = ['IAS', 'EAS'] + logger.info("Using default SMI compartments.", extra={"compartments": compartments}) if app.ARGS.lmax: if int(app.ARGS.lmax) not in {0, 2, 4, 6}: raise ValueError("lmax value must be 0, 2, 4, or 6.") else: lmax = int(app.ARGS.lmax) + logger.info("lmax for SMI specified.", extra={"lmax": lmax}) else: lmax = None + logger.info("No lmax specified for SMI. 
Using default.") - print('...SMI fit...') + logger.info("Initializing SMI fitting...") if multi_te_beta: smi = SMI(bval=bval_orig, bvec=bvec_orig, rotinv_lmax=lmax) smi.set_compartments(compartments) smi.set_echotime(dwi_metadata['echo_time_per_volume']) smi.set_bshape(dwi_metadata['bshape_per_volume']) + logger.info("SMI model initialized for multi-TE/beta data.") + params_smi = smi.fit(dwi_orig, mask=mask, sigma=sigma) + logger.info("SMI fitting completed for multi-TE/beta data.", extra={"params_smi_shape": {key: value.shape for key, value in params_smi.items()}}) + save_params(params_smi, nii, model='smi', outdir=outdir) + logger.info("SMI parameters saved for multi-TE/beta data.", extra={"outdir": outdir}) else: smi = SMI(bval=bval, bvec=bvec, rotinv_lmax=lmax) smi.set_compartments(compartments) smi.set_echotime(dwi_metadata['echo_time_per_volume']) smi.set_bshape(dwi_metadata['bshape_per_volume']) + logger.info("SMI model initialized for single-TE/beta data.") + params_smi = smi.fit(dwi, mask=mask, sigma=sigma) + logger.info("SMI fitting completed for single-TE/beta data.", extra={"params_smi_shape": {key: value.shape for key, value in params_smi.items()}}) + save_params(params_smi, nii, model='smi', outdir=outdir) + logger.info("SMI parameters saved for single-TE/beta data.", extra={"outdir": outdir}) + + logger.info("Execution completed successfully.") def main(): import mrtrix3 diff --git a/docs/docs/TMI/usage.md b/docs/docs/TMI/usage.md index a5cf4977..e21d55fb 100644 --- a/docs/docs/TMI/usage.md +++ b/docs/docs/TMI/usage.md @@ -84,12 +84,22 @@ By default, if none of the below options are used, TMI will not estimate paramet - Path to noise map for SMI parameter estimation. Not required but recommended. - We recommend computing sigma prior to running `tmi` using `designer`. +### `-lmax ` +- Value of L-max to use for SMI parameter estimation. +- Reasonable values of L-max will be computed automatically if this option is not used. 
+- L-max should be 0, 2, 4, or 6. + ### `-bshape ` - Specify the b-shape used in the acquisition (comma separated list the same length as number of inputs). ### `-echo_time ` - Specify the echo time used in the acquisition (comma separated list the same length as number of inputs). +### `-maxb ` +- Specify the maximum b-value to use during DKI estimation. +- Units should be in ms/µm^2. +- default maximum b-value for DKI estimation is 3.0 ms/µm^2 + ## Other options for TMI ### `-mask ` diff --git a/lib/tensor.py b/lib/tensor.py index 733238dc..90cedf98 100644 --- a/lib/tensor.py +++ b/lib/tensor.py @@ -11,7 +11,6 @@ import multiprocessing import warnings -warnings.filterwarnings("ignore") class TensorFitting(object): @@ -204,6 +203,7 @@ def extract_parameters(self, dt, b, mask, extract_dti, extract_dki, fit_w=False) fa = np.sqrt(1/2)*np.sqrt((l1-l2)**2+(l2-l3)**2+(l3-l1)**2)/np.sqrt(l1**2+l2**2+l3**2) #trace = vectorize(trace.T, mask) fe = np.abs(np.stack((fa*v1[:,:,:,0], fa*v1[:,:,:,1], fa*v1[:,:,:,2]), axis=3)) + parameters = {} if extract_dti: @@ -254,6 +254,7 @@ def dki_fit(self, dwi, mask, constraints=None): Outputs S0 the true quantitative mean signal at zero gradient strength the gradient tensor b """ + warnings.filterwarnings("ignore") # run the fit order = np.floor(np.log(np.abs(np.max(self.grad[:,-1])+1))/np.log(10)) if order >= 2: @@ -547,7 +548,7 @@ def compute_rotation_matrices(self, n_rotations): return rotation_matrices - def train_rotated_bayes_fit(self, dwi, dt, s0, b, mask, flag_dti='False'): + def train_rotated_bayes_fit(self, dwi, dt, s0, b, mask, flag_dti=False): """ This function trains and evaluates a polynomial regression model to update a wlls fit without outlier voxels. 
@@ -567,19 +568,20 @@ def train_rotated_bayes_fit(self, dwi, dt, s0, b, mask, flag_dti='False'): SNR = 100 sigma = 1 / SNR - maxb = 3 # grad_orig = self.grad # grad_keep = self.grad[:,3] < maxb # dwi = dwi[..., grad_keep] + + dt = dt.copy() np.maximum(dwi, np.finfo(dwi.dtype).eps, out=dwi) D_apprSq = 1/(np.sum(dt[(0,3,5),:], axis=0)/3)**2 - if not flag_dti == 'True': + if not flag_dti: dt[6:,:] /= np.tile(D_apprSq, (15,1)) # first identify and remove outliers in dt - if flag_dti == 'True': + if flag_dti: outlier_range = (1, 99) all_tinds = np.arange(6) trace = [0,3,5] @@ -717,7 +719,7 @@ def train_rotated_bayes_fit(self, dwi, dt, s0, b, mask, flag_dti='False'): s0_poly = dt_poly[0,:] dt_poly = dt_poly[1:,:] D_apprSq = 1/(np.sum(dt_poly[(0,3,5),:], axis=0)/3)**2 - if not flag_dti == 'True': + if not flag_dti: dt_poly[6:,:] *= np.tile(D_apprSq, (15,1)) return dt_poly @@ -740,7 +742,7 @@ def compute_rotations(n_coeffs, Evec, DD, WT, n_rotations, n_brain_voxels, flag_ dtx[:6, :, rot] = np.vstack((Dx[0, 0, :], Dx[0, 1, :], Dx[0, 2, :], Dx[1, 1, :], Dx[1, 2, :], Dx[2, 2, :])) - if not flag_fit_dti == 'True': + if not flag_fit_dti: Wx = np.zeros((9, 9, n_brain_voxels)) for i in range(3): for j in range(3):