Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
enh: improve CLI of the dwi/error experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Oct 29, 2024
1 parent 26ef287 commit c875bc5
Showing 1 changed file with 91 additions and 10 deletions.
101 changes: 91 additions & 10 deletions scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,60 @@ def _build_arg_parser() -> argparse.ArgumentParser:
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"hsph_dirs",
"--bval-shell",
help="Shell b-value",
type=float,
default=1000,
)
parser.add_argument("--S0", help="S0 value", type=float, default=100)
parser.add_argument(
"--hsph-dirs",
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
default=60,
)
<<<<<<< Updated upstream
parser.add_argument("bval_shell", help="Shell b-value", type=int)
parser.add_argument("S0", help="S0 value", type=float)
=======
>>>>>>> Stashed changes
parser.add_argument(
"error_data_fname",
"--output-scores",
help="Filename of TSV file containing the data to plot",
type=Path,
default=Path() / "scores.tsv",
)
<<<<<<< Updated upstream
=======
parser.add_argument(
"-n",
"--n-voxels",
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
default=100,
)
parser.add_argument(
"--write-inputs",
help="Filename of NIfTI file containing the generated DWI signal",
type=Path,
default=None,
)
parser.add_argument(
"--output-predicted",
help="Filename of NIfTI file containing the predicted DWI signal",
type=Path,
default=None,
)
>>>>>>> Stashed changes
parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float)
parser.add_argument("--snr", help="Signal to noise ratio", type=float)
parser.add_argument("--repeats", help="Number of repeats", type=int, default=5)
parser.add_argument(
"--kfold", help="Number of directions to leave out/predict", nargs="+", type=int
"--kfold",
help="Number of folds in repeated-k-fold cross-validation",
nargs="+",
type=int,
default=None,
)
return parser

Expand All @@ -134,36 +172,48 @@ def main() -> None:
parser = _build_arg_parser()
args = _parse_args(parser)

n_voxels = 100

data, gtab = testsims.simulate_voxels(
args.S0,
args.hsph_dirs,
bval_shell=args.bval_shell,
snr=args.snr,
n_voxels=n_voxels,
n_voxels=args.n_voxels,
evals=args.evals,
seed=None,
)

<<<<<<< Updated upstream
=======
# Save the generated signal and gradient table
if args.write_inputs:
testsims.serialize_dmri(
data,
gtab,
args.write_inputs,
args.write_inputs.with_suffix(".bval"),
args.write_inputs.with_suffix(".bvec"),
)

>>>>>>> Stashed changes
X = gtab[~gtab.b0s_mask].bvecs
y = data[:, ~gtab.b0s_mask]

snr_str = args.snr if args.snr is not None else "None"

a = 1.15
lambda_s = 120
alpha = 100
alpha = 1
gpr = EddyMotionGPR(
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
kernel=SphericalKriging(beta_a=a, beta_l=lambda_s),
alpha=alpha,
optimizer=None,
# optimizer="Nelder-Mead",
# optimizer=None,
optimizer="cross-validation",
# disp=True,
# ftol=1,
# max_iter=2e5,
)

<<<<<<< Updated upstream
# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
Expand All @@ -183,6 +233,37 @@ def main() -> None:
grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())
=======
if args.kfold:
# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
scores["snr"] += [snr_str] * len(cv_scores)

print(f"Finished {n}-fold cross-validation")

scores_df = pd.DataFrame(scores)
scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a")

grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())
else:
gpr.n_trials = 1000
gpr.fit(X, y.T)
print(gpr.kernel_)

if args.output_predicted:
cv = KFold(n_splits=3, shuffle=False, random_state=None)
predictions = cross_val_predict(gpr, X, y.T, cv=cv)

testsims.serialize_dwi(predictions.T, args.output_predicted)
>>>>>>> Stashed changes


if __name__ == "__main__":
Expand Down

0 comments on commit c875bc5

Please sign in to comment.