Skip to content

Commit

Permalink
dwidenoise: Add demeaning
Browse files Browse the repository at this point in the history
By default, if the input image series contains a gradient table that appears to be arranged into shells, then the demeaning will be computed per shell per voxel; otherwise a single mean will be computed per voxel.
Demeaning can also be disabled using the -demean option.
The code responsible for the demeaning has been integrated with the code for phase demodulation and the variance-stabilising transform as "preconditioning".
The presence of demeaning does not affect the PCA or the noise level estimation in any way; it only influences the _reported_ signal rank.
Supersedes #2363.
  • Loading branch information
Lestropie committed Dec 15, 2024
1 parent 4416f12 commit 8a068a9
Show file tree
Hide file tree
Showing 17 changed files with 624 additions and 316 deletions.
88 changes: 54 additions & 34 deletions cmd/dwi2noise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
#include "algo/threaded_loop.h"
#include "axes.h"
#include "command.h"
#include "denoise/demodulate.h"
#include "denoise/estimate.h"
#include "denoise/estimator/estimator.h"
#include "denoise/exports.h"
#include "denoise/kernel/kernel.h"
#include "denoise/precondition.h"
#include "denoise/subsample.h"
#include "dwi/gradient.h"
#include "exception.h"
#include "filter/demodulate.h"

Expand Down Expand Up @@ -104,12 +105,10 @@ void usage() {
+ Estimator::estimator_option
+ Kernel::options
+ subsample_option
+ demodulation_options
+ Option("vst",
"apply a within-patch variance-stabilising transformation based on a pre-estimated noise level map; "
"note that this will be used for within-patch non-stationariy correction only, "
"the output noise level estimate will still be derived from the input data")
+ Argument("image").type_image_in()
+ precondition_options

+ DWI::GradImportOptions()
+ DWI::GradExportOptions()

+ OptionGroup("Options for exporting additional data regarding PCA behaviour")
+ Option("rank",
Expand Down Expand Up @@ -156,37 +155,58 @@ void usage() {
// clang-format on

template <typename T>
void run(Header &data,
void run(Image<T> &input,
std::shared_ptr<Subsample> subsample,
std::shared_ptr<Kernel::Base> kernel,
Image<float> &vst_noise_image,
std::shared_ptr<Estimator::Base> estimator,
Exports &exports) {
auto input = data.get_image<T>().with_direct_io(3);
Estimate<T> func(data, subsample, kernel, vst_noise_image, estimator, exports);
ThreadedLoop("running MP-PCA noise level estimation", data, 0, 3).run(func, input);
Estimate<T> func(input, subsample, kernel, estimator, exports);
ThreadedLoop("running MP-PCA noise level estimation", input, 0, 3).run(func, input);
}

template <typename T>
void run(Header &data,
void run(Header &dwi,
const Demodulation &demodulation,
const demean_type demean,
Image<float> &vst_noise_image,
std::shared_ptr<Subsample> subsample,
std::shared_ptr<Kernel::Base> kernel,
Image<float> &vst_noise_image,
std::shared_ptr<Estimator::Base> estimator,
Exports &exports) {
if (!demodulation) {
run<T>(data, subsample, kernel, vst_noise_image, estimator, exports);
auto opt_preconditioned = get_options("preconditioned");
if (!demodulation && demean == demean_type::NONE && !vst_noise_image.valid()) {
if (!opt_preconditioned.empty()) {
WARN("-preconditioned option ignored: no preconditioning taking place");
}
Image<T> input = dwi.get_image<T>().with_direct_io(3);
run<T>(input, subsample, kernel, estimator, exports);
return;
}
auto input = data.get_image<T>();
auto input_demod = Image<T>::scratch(data, "Phase-demodulated version of \"" + data.name() + "\"");
{
Filter::Demodulate demodulator(input, demodulation.axes, demodulation.mode == demodulation_t::LINEAR);
demodulator(input, input_demod);
Image<T> input(dwi.get_image<T>());
const Precondition<T> preconditioner(input, demodulation, demean, vst_noise_image);
Header H_preconditioned(input);
Stride::set(H_preconditioned, Stride::contiguous_along_axis(3, input));
Image<T> input_preconditioned;
input_preconditioned = opt_preconditioned.empty()
? Image<T>::scratch(H_preconditioned, "Preconditioned version of \"" + input.name() + "\"")
: Image<T>::create(opt_preconditioned[0][0], H_preconditioned);
preconditioner(input, input_preconditioned, false);
run(input_preconditioned, subsample, kernel, estimator, exports);
if (vst_noise_image.valid()) {
Interp::Cubic<Image<float>> vst(vst_noise_image);
const Transform transform(exports.noise_out);
for (auto l = Loop(exports.noise_out)(exports.noise_out); l; ++l) {
vst.scanner(transform.voxel2scanner * Eigen::Vector3d({default_type(exports.noise_out.index(0)),
default_type(exports.noise_out.index(1)),
default_type(exports.noise_out.index(2))}));
exports.noise_out.value() *= vst.value();
}
}
if (preconditioner.rank() == 1 && exports.rank_input.valid()) {
for (auto l = Loop(exports.rank_input)(exports.rank_input); l; ++l)
exports.rank_input.value() =
std::max<uint16_t>(uint16_t(exports.rank_input.value()) + uint16_t(1), uint16_t(dwi.size(3)));
}
Estimate<T> func(data, subsample, kernel, vst_noise_image, estimator, exports);
ThreadedLoop("running MP-PCA noise level estimation", data, 0, 3).run(func, input_demod);
}

void run() {
Expand All @@ -195,20 +215,20 @@ void run() {
throw Exception("input image must be 4-dimensional");
bool complex = dwi.datatype().is_complex();

const Demodulation demodulation = get_demodulation(dwi);
const Demodulation demodulation = select_demodulation(dwi);
const demean_type demean = select_demean(dwi);
Image<float> vst_noise_image;
auto opt = get_options("vst");
if (!opt.empty())
vst_noise_image = Image<float>::open(opt[0][0]);

auto subsample = Subsample::make(dwi);
assert(subsample);

auto kernel = Kernel::make_kernel(dwi, subsample->get_factors());
assert(kernel);

Image<float> vst_noise_image;
auto opt = get_options("vst");
if (!opt.empty())
vst_noise_image = Image<float>::open(opt[0][0]);

auto estimator = Estimator::make_estimator(false);
auto estimator = Estimator::make_estimator(vst_noise_image, false);
assert(estimator);

Exports exports(dwi, subsample->header());
Expand All @@ -233,20 +253,20 @@ void run() {
case 0:
assert(demodulation.axes.empty());
INFO("select real float32 for processing");
run<float>(dwi, subsample, kernel, vst_noise_image, estimator, exports);
run<float>(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports);
break;
case 1:
assert(demodulation.axes.empty());
INFO("select real float64 for processing");
run<double>(dwi, subsample, kernel, vst_noise_image, estimator, exports);
run<double>(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports);
break;
case 2:
INFO("select complex float32 for processing");
run<cfloat>(dwi, demodulation, subsample, kernel, vst_noise_image, estimator, exports);
run<cfloat>(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports);
break;
case 3:
INFO("select complex float64 for processing");
run<cdouble>(dwi, demodulation, subsample, kernel, vst_noise_image, estimator, exports);
run<cdouble>(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports);
break;
}
}
Loading

0 comments on commit 8a068a9

Please sign in to comment.