Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use phase data in main workflow and add option for phase noise map #10

Merged
merged 9 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions src/patch_denoise/bindings/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,17 @@ def _get_parser():
type=IsFile,
help=(
"Phase of the input data. This MUST be in radians. "
"No conversion would be applied."
"No rescaling will be applied."
),
)
data_group.add_argument(
"--noise-map-phase",
metavar="FILE",
default=None,
type=IsFile,
help=(
"Phase component of the noise map estimation file. "
"This MUST be in radians. No rescaling will be applied."
),
)

Expand Down Expand Up @@ -283,17 +293,19 @@ def main():

if args.input_phase is not None:
input_data, affine = load_complex_nifti(args.input_file, args.input_phase)
input_data, affine = load_as_array(args.input_file)
else:
input_data, affine = load_as_array(args.input_file)

kwargs = args.extra

if args.nan_to_num is not None:
input_data = np.nan_to_num(input_data, nan=args.nan_to_num)

n_nans = np.isnan(input_data).sum()
if n_nans > 0:
logging.warning(
f"{n_nans}/{np.prod(input_data.shape)} voxels are NaN."
" You might want to use --nan-to-num=<value>",
f"{n_nans}/{input_data.size} voxels are NaN. "
"You might want to use --nan-to-num=<value>",
stacklevel=0,
)

Expand All @@ -302,14 +314,30 @@ def main():
affine_mask = None
else:
mask, affine_mask = load_as_array(args.mask)
noise_map, affine_noise_map = load_as_array(args.noise_map)

if args.noise_map is not None and args.noise_map_phase is not None:
noise_map, affine_noise_map = load_complex_nifti(
args.noise_map,
args.noise_map_phase,
)
elif args.noise_map is not None:
noise_map, affine_noise_map = load_as_array(args.noise_map)
elif args.noise_map_phase is not None:
raise ValueError(
"The phase component of the noise map has been provided, "
"but not the magnitude."
)
else:
noise_map = None
affine_noise_map = None

if affine is not None:
if affine_mask is not None and np.allclose(affine, affine_mask):
if (affine_mask is not None) and not np.allclose(affine, affine_mask):
logging.warning(
"Affine matrix of input and mask does not match", stacklevel=2
)
if affine_noise_map is not None and np.allclose(affine, affine_noise_map):

if (affine_noise_map is not None) and not np.allclose(affine, affine_noise_map):
logging.warning(
"Affine matrix of input and noise map does not match", stacklevel=2
)
Expand Down Expand Up @@ -344,7 +372,7 @@ def main():
if noise_map is None:
raise RuntimeError("A noise map must be specified for this method.")

denoised_data, patchs_weight, noise_std_map, rank_map = denoise_func(
denoised_data, _, noise_std_map, _ = denoise_func(
input_data,
patch_shape=args.patch_shape,
patch_overlap=args.patch_overlap,
Expand Down
31 changes: 24 additions & 7 deletions tests/test_spacetime_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,30 @@ def f(x):


@pytest.mark.parametrize("block_dim", range(5, 10))
def test_noise_estimation(medium_random_matrix, block_dim):
"""Test noise estimation."""
noise_map = estimate_noise(medium_random_matrix, block_dim)

real_std = np.nanstd(medium_random_matrix)
err = np.nanmean(noise_map - real_std)
assert err <= 0.1 * real_std
def test_noise_estimation(block_dim):
"""Test noise estimation.

The mean patch-wise standard deviation should be close to the overall
standard deviation.
"""
for seed in range(15):
print(f"Seed: {seed}")
rng = np.random.RandomState(seed)
medium_random_matrix = rng.randn(200, 200, 100)
print(f"Mean of raw: {np.nanmean(medium_random_matrix)}")
print(f"Max of raw: {np.nanmax(medium_random_matrix)}")
print(f"Min of raw: {np.nanmin(medium_random_matrix)}")
real_std = np.nanstd(medium_random_matrix)
print(f"SD of raw: {real_std}")

noise_map = estimate_noise(medium_random_matrix, block_dim)
print(f"Mean of noise map: {np.nanmean(noise_map)}")
print(f"Max of noise map: {np.nanmax(noise_map)}")
print(f"Min of noise map: {np.nanmin(noise_map)}")
print(f"SD of noise map: {np.nanstd(noise_map)}")
err = np.nanmean(noise_map - real_std)
print(f"Err: {err}")
assert err <= 0.1 * real_std


@parametrize_random_matrix
Expand Down
Loading