From 8a068a97082cc8bdfe54afd3868911fa7b17dd29 Mon Sep 17 00:00:00 2001 From: Robert Smith Date: Sun, 15 Dec 2024 17:35:39 +1100 Subject: [PATCH] dwidenoise: Add demeaning By default, if the input image series contains a gradient table that appears to be arranged into shells, then the demeaning will be computed per shell per voxel; otherwise a single mean will be computed per voxel. Demeaning can also be disabled using the -demean option. The code responsible for the demeaning has been integrated with the code for phase demodulation and the variance-stabilising transform as "preconditioning". The presence of demeaning does not affect the PCA or the noise level estimation in any way; it only influences the _reported_ signal rank. Supersedes #2363. --- cmd/dwi2noise.cpp | 88 ++++--- cmd/dwidenoise.cpp | 151 ++++++------ core/filter/demodulate.h | 5 + core/filter/kspace.h | 2 +- src/denoise/demodulate.cpp | 107 -------- src/denoise/demodulate.h | 46 ---- src/denoise/denoise.h | 1 + src/denoise/estimate.cpp | 35 --- src/denoise/estimate.h | 5 - src/denoise/estimator/estimator.cpp | 4 +- src/denoise/estimator/estimator.h | 3 +- src/denoise/estimator/import.h | 17 +- src/denoise/exports.h | 2 - src/denoise/precondition.cpp | 369 ++++++++++++++++++++++++++++ src/denoise/precondition.h | 95 +++++++ src/denoise/recon.cpp | 9 +- src/denoise/recon.h | 1 - 17 files changed, 624 insertions(+), 316 deletions(-) delete mode 100644 src/denoise/demodulate.cpp delete mode 100644 src/denoise/demodulate.h create mode 100644 src/denoise/precondition.cpp create mode 100644 src/denoise/precondition.h diff --git a/cmd/dwi2noise.cpp b/cmd/dwi2noise.cpp index 8e94586bc5..5cc5508efe 100644 --- a/cmd/dwi2noise.cpp +++ b/cmd/dwi2noise.cpp @@ -19,12 +19,13 @@ #include "algo/threaded_loop.h" #include "axes.h" #include "command.h" -#include "denoise/demodulate.h" #include "denoise/estimate.h" #include "denoise/estimator/estimator.h" #include "denoise/exports.h" #include "denoise/kernel/kernel.h" +#include "denoise/precondition.h" #include "denoise/subsample.h" +#include "dwi/gradient.h" #include "exception.h" #include "filter/demodulate.h" @@ -104,12 +105,10 @@ void usage() { + Estimator::estimator_option + Kernel::options + subsample_option - + demodulation_options - + Option("vst", - "apply a within-patch variance-stabilising transformation based on a pre-estimated noise level map; " - "note that this will be used for within-patch non-stationariy correction only, " - "the output noise level estimate will still be derived from the input data") - + Argument("image").type_image_in() + + precondition_options + + + DWI::GradImportOptions() + + DWI::GradExportOptions() + OptionGroup("Options for exporting additional data regarding PCA behaviour") + Option("rank", @@ -156,37 +155,58 @@ void usage() { // clang-format on template -void run(Header &data, +void run(Image &input, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, Exports &exports) { - auto input = data.get_image().with_direct_io(3); - Estimate func(data, subsample, kernel, vst_noise_image, estimator, exports); - ThreadedLoop("running MP-PCA noise level estimation", data, 0, 3).run(func, input); + Estimate func(input, subsample, kernel, estimator, exports); + ThreadedLoop("running MP-PCA noise level estimation", input, 0, 3).run(func, input); } template -void run(Header &data, +void run(Header &dwi, const Demodulation &demodulation, + const demean_type demean, + Image &vst_noise_image, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, Exports &exports) { - if (!demodulation) { - run(data, subsample, kernel, vst_noise_image, estimator, exports); + auto opt_preconditioned = get_options("preconditioned"); + if (!demodulation && demean == demean_type::NONE && !vst_noise_image.valid()) { + if (!opt_preconditioned.empty()) { + WARN("-preconditioned option ignored: no preconditioning taking place"); + } + Image input = dwi.get_image().with_direct_io(3); + run(input, subsample, kernel, estimator, exports); return; } - auto input = data.get_image(); - auto input_demod = Image::scratch(data, "Phase-demodulated version of \"" + data.name() + "\""); - { - Filter::Demodulate demodulator(input, demodulation.axes, demodulation.mode == demodulation_t::LINEAR); - demodulator(input, input_demod); + Image input(dwi.get_image()); + const Precondition preconditioner(input, demodulation, demean, vst_noise_image); + Header H_preconditioned(input); + Stride::set(H_preconditioned, Stride::contiguous_along_axis(3, input)); + Image input_preconditioned; + input_preconditioned = opt_preconditioned.empty() + ? Image::scratch(H_preconditioned, "Preconditioned version of \"" + input.name() + "\"") + : Image::create(opt_preconditioned[0][0], H_preconditioned); + preconditioner(input, input_preconditioned, false); + run(input_preconditioned, subsample, kernel, estimator, exports); + if (vst_noise_image.valid()) { + Interp::Cubic> vst(vst_noise_image); + const Transform transform(exports.noise_out); + for (auto l = Loop(exports.noise_out)(exports.noise_out); l; ++l) { + vst.scanner(transform.voxel2scanner * Eigen::Vector3d({default_type(exports.noise_out.index(0)), + default_type(exports.noise_out.index(1)), + default_type(exports.noise_out.index(2))})); + exports.noise_out.value() *= vst.value(); + } + } + if (preconditioner.rank() == 1 && exports.rank_input.valid()) { + for (auto l = Loop(exports.rank_input)(exports.rank_input); l; ++l) + exports.rank_input.value() = + std::max(uint16_t(exports.rank_input.value()) + uint16_t(1), uint16_t(dwi.size(3))); } - Estimate func(data, subsample, kernel, vst_noise_image, estimator, exports); - ThreadedLoop("running MP-PCA noise level estimation", data, 0, 3).run(func, input_demod); } void run() { @@ -195,7 +215,12 @@ void run() { throw Exception("input image must be 4-dimensional"); bool complex = dwi.datatype().is_complex(); - const Demodulation demodulation = get_demodulation(dwi); + const Demodulation demodulation = select_demodulation(dwi); + const demean_type demean = select_demean(dwi); + Image vst_noise_image; + auto opt = get_options("vst"); + if (!opt.empty()) + vst_noise_image = Image::open(opt[0][0]); auto subsample = Subsample::make(dwi); assert(subsample); @@ -203,12 +228,7 @@ void run() { auto kernel = Kernel::make_kernel(dwi, subsample->get_factors()); assert(kernel); - Image vst_noise_image; - auto opt = get_options("vst"); - if (!opt.empty()) - vst_noise_image = Image::open(opt[0][0]); - - auto estimator = Estimator::make_estimator(false); + auto estimator = Estimator::make_estimator(vst_noise_image, false); assert(estimator); Exports exports(dwi, subsample->header()); @@ -233,20 +253,20 @@ void run() { case 0: assert(demodulation.axes.empty()); INFO("select real float32 for processing"); - run(dwi, subsample, kernel, vst_noise_image, estimator, exports); + run(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports); break; case 1: assert(demodulation.axes.empty()); INFO("select real float64 for processing"); - run(dwi, subsample, kernel, vst_noise_image, estimator, exports); + run(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports); break; case 2: INFO("select complex float32 for processing"); - run(dwi, demodulation, subsample, kernel, vst_noise_image, estimator, exports); + run(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports); break; case 3: INFO("select complex float64 for processing"); - run(dwi, demodulation, subsample, kernel, vst_noise_image, estimator, exports); + run(dwi, demodulation, demean, vst_noise_image, subsample, kernel, estimator, exports); break; } } diff --git a/cmd/dwidenoise.cpp b/cmd/dwidenoise.cpp index b2da3e4524..67c1ebdfcf 100644 --- a/cmd/dwidenoise.cpp +++ b/cmd/dwidenoise.cpp @@ -25,7 +25,6 @@ #include #include -#include "denoise/demodulate.h" #include "denoise/denoise.h" #include "denoise/estimator/base.h" #include "denoise/estimator/estimator.h" @@ -39,6 +38,7 @@ #include "denoise/kernel/kernel.h" #include "denoise/kernel/sphere_radius.h" #include "denoise/kernel/sphere_ratio.h" +#include "denoise/precondition.h" #include "denoise/recon.h" #include "denoise/subsample.h" @@ -146,14 +146,7 @@ void usage() { + Estimator::estimator_denoise_options + Kernel::options + subsample_option - + demodulation_options - - + Option("vst", - "apply a within-patch variance-stabilising transformation based on a pre-estimated noise level map; " - "note that this will be used for within-patch non-stationariy correction only, " - "if noise level estimate is to be used for denoising also " - "it must be additionally provided via the -noise_in option") - + Argument("image").type_image_in() + + precondition_options + OptionGroup("Options that affect reconstruction of the output image series") + Option("filter", @@ -201,10 +194,6 @@ void usage() { + Option("sum_optshrink", "the sum of eigenvector weights computed for the denoising patch centred at each voxel " "as a result of performing optimal shrinkage") - + Argument("image").type_image_out() - + Option("noise_cov", - "export an image of the Coefficient of Variation (CoV) of noise level within each patch " - "(only applicable if -nonstationarity is specified)") + Argument("image").type_image_out(); COPYRIGHT = @@ -242,23 +231,16 @@ void usage() { std::complex operator/(const std::complex &c, const float n) { return c / double(n); } template -void run(Header &data, +void run(Image &input, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, filter_type filter, aggregator_type aggregator, - const std::string &output_name, + Image &output, Exports &exports) { - auto input = data.get_image().with_direct_io(3); - // create output - Header header(data); - header.datatype() = DataType::from(); - auto output = Image::create(output_name, header); - // run - Recon func(data, subsample, kernel, vst_noise_image, estimator, filter, aggregator, exports); - ThreadedLoop("running MP-PCA denoising", data, 0, 3).run(func, input, output); + Recon func(input, subsample, kernel, estimator, filter, aggregator, exports); + ThreadedLoop("running MP-PCA denoising", input, 0, 3).run(func, input, output); // Rescale output if aggregation was performed if (aggregator == aggregator_type::EXCLUSIVE) return; @@ -273,49 +255,75 @@ void run(Header &data, } template -void run(Header &data, +void run(Header &dwi, const Demodulation &demodulation, + const demean_type demean, + Image &vst_noise_image, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, filter_type filter, aggregator_type aggregator, const std::string &output_name, Exports &exports) { - if (!demodulation) { - run(data, subsample, kernel, vst_noise_image, estimator, filter, aggregator, output_name, exports); + auto opt_preconditioned = get_options("preconditioned"); + if (!demodulation && demean == demean_type::NONE && !vst_noise_image.valid()) { + if (!opt_preconditioned.empty()) { + WARN("-preconditioned option ignored: no preconditioning taking place"); + } + auto input = dwi.get_image().with_direct_io(3); + Header H(dwi); + H.datatype() = DataType::from(); + auto output = Image::create(output_name, H); + run(input, subsample, kernel, estimator, filter, aggregator, output, exports); return; } - auto input = data.get_image(); - // generate scratch version of DWI with phase demodulation - Header H_scratch(data); - Stride::set(H_scratch, Stride::contiguous_along_axis(3)); - H_scratch.datatype() = DataType::from(); - H_scratch.datatype().set_byte_order_native(); - auto input_demodulated = Image::scratch(H_scratch, "Phase-demodulated version of input DWI"); - Filter::Demodulate demodulate(input, demodulation.axes, demodulation.mode == demodulation_t::LINEAR); - demodulate(input, input_demodulated, false); - input = Image(); // free memory + auto input = dwi.get_image(); + // perform preconditioning + const Precondition preconditioner(input, demodulation, demean, vst_noise_image); + Header H_preconditioned(dwi); + Stride::set(H_preconditioned, Stride::contiguous_along_axis(3)); + H_preconditioned.datatype() = DataType::from(); + H_preconditioned.datatype().set_byte_order_native(); + Image input_preconditioned; + input_preconditioned = opt_preconditioned.empty() + ? Image::scratch(H_preconditioned, "Preconditioned version of \"" + dwi.name() + "\"") + : Image::create(opt_preconditioned[0][0], H_preconditioned); + preconditioner(input, input_preconditioned, false); // create output - Header header(data); - header.datatype() = DataType::from(); - auto output = Image::create(output_name, header); + Header H(dwi); + H.datatype() = DataType::from(); + auto output = Image::create(output_name, H); // run - Recon func(data, subsample, kernel, vst_noise_image, estimator, filter, aggregator, exports); - ThreadedLoop("running MP-PCA denoising", data, 0, 3).run(func, input_demodulated, output); - // Re-apply phase ramps that were previously demodulated - demodulate(output, true); - // Rescale output if performing aggregation - if (aggregator == aggregator_type::EXCLUSIVE) - return; - for (auto l_voxel = Loop(exports.sum_aggregation)(output, exports.sum_aggregation); l_voxel; ++l_voxel) { - for (auto l_volume = Loop(3)(output); l_volume; ++l_volume) - output.value() /= float(exports.sum_aggregation.value()); + run(input_preconditioned, subsample, kernel, estimator, filter, aggregator, output, exports); + // reverse effects of preconditioning + Image output2(output); + preconditioner(output, output2, true); + // compensate for effects of preconditioning where relevant + if (exports.noise_out.valid() && vst_noise_image.valid()) { + Interp::Cubic> vst(vst_noise_image); + const Transform transform(exports.noise_out); + for (auto l = Loop(exports.noise_out)(exports.noise_out); l; ++l) { + vst.scanner(transform.voxel2scanner * Eigen::Vector3d{default_type(exports.noise_out.index(0)), + default_type(exports.noise_out.index(1)), + default_type(exports.noise_out.index(2))}); + exports.noise_out.value() *= vst.value(); + } } - if (exports.rank_output.valid()) { - for (auto l = Loop(exports.sum_aggregation)(exports.rank_output, exports.sum_aggregation); l; ++l) - exports.rank_output.value() /= exports.sum_aggregation.value(); + if (preconditioner.rank() == 1) { + if (exports.rank_input.valid()) { + for (auto l = Loop(exports.rank_input)(exports.rank_input); l; ++l) + exports.rank_input.value() = + std::min(uint16_t(exports.rank_input.value()) + uint16_t(1), uint16_t(dwi.size(3))); + } + if (exports.rank_output.valid()) { + for (auto l = Loop(exports.rank_output)(exports.rank_output); l; ++l) + exports.rank_output.value() = std::min(float(exports.rank_output.value()) + 1.0f, float(dwi.size(3))); + } + if (exports.sum_optshrink.valid()) { + for (auto l = Loop(exports.sum_optshrink)(exports.sum_optshrink); l; ++l) + exports.sum_optshrink.value() = float(exports.sum_optshrink.value()) + 1.0f; + } } } @@ -324,7 +332,12 @@ void run() { if (dwi.ndim() != 4 || dwi.size(3) <= 1) throw Exception("input image must be 4-dimensional"); - const Demodulation demodulation = get_demodulation(dwi); + const Demodulation demodulation = select_demodulation(dwi); + const demean_type demean = select_demean(dwi); + Image vst_noise_image; + auto opt = get_options("vst"); + if (!opt.empty()) + vst_noise_image = Image::open(opt[0][0]); auto subsample = Subsample::make(dwi); assert(subsample); @@ -332,12 +345,7 @@ void run() { auto kernel = Kernel::make_kernel(dwi, subsample->get_factors()); assert(kernel); - Image vst_noise_image; - auto opt = get_options("vst"); - if (!opt.empty()) - vst_noise_image = Image::open(opt[0][0]); - - auto estimator = Estimator::make_estimator(true); + auto estimator = Estimator::make_estimator(vst_noise_image, true); assert(estimator); filter_type filter = get_options("fixed_rank").empty() ? filter_type::OPTSHRINK : filter_type::TRUNCATE; @@ -400,13 +408,6 @@ void run() { exports.set_sum_aggregation(""); } - opt = get_options("noise_cov"); - if (!opt.empty()) { - if (!vst_noise_image.valid()) - throw Exception("-noise_variance can only be specified if -nonstationarity option is used"); - exports.set_noise_cov(opt[0][0]); - } - int prec = get_option_value("datatype", 0); // default: single precision if (dwi.datatype().is_complex()) prec += 2; // support complex input data @@ -416,9 +417,11 @@ void run() { INFO("select real float32 for processing"); run( // dwi, // + demodulation, // + demean, // + vst_noise_image, // subsample, // kernel, // - vst_noise_image, // estimator, // filter, // aggregator, // @@ -430,9 +433,11 @@ void run() { INFO("select real float64 for processing"); run( // dwi, // + demodulation, // + demean, // + vst_noise_image, // subsample, // kernel, // - vst_noise_image, // estimator, // filter, // aggregator, // @@ -444,9 +449,10 @@ void run() { run( // dwi, // demodulation, // + demean, // + vst_noise_image, // subsample, // kernel, // - vst_noise_image, // estimator, // filter, // aggregator, // @@ -458,9 +464,10 @@ void run() { run( // dwi, // demodulation, // + demean, // + vst_noise_image, // subsample, // kernel, // - vst_noise_image, // estimator, // filter, // aggregator, // diff --git a/core/filter/demodulate.h b/core/filter/demodulate.h index 1fc16e55c7..aaa366d3ef 100644 --- a/core/filter/demodulate.h +++ b/core/filter/demodulate.h @@ -34,6 +34,9 @@ namespace MR::Filter { // From Manzano-Patron et al. 2024 constexpr default_type default_tukey_FWHM_demodulate = 0.58; +// TODO Ideally do more experimentation to figure out a reasonable default here +// Too high and everything ends up in the real axis; +// too low and disjointed phase cound drive up signal rank constexpr default_type default_tukey_alpha_demodulate = 2.0 * (1.0 - default_tukey_FWHM_demodulate); /*! Estimate a linear phase ramp of a complex image and demodulate by such @@ -245,6 +248,8 @@ class Demodulate : public Base { } } + Image operator()() const { return phase; } + protected: // TODO Change to Image; can produce complex value at processing time Image phase; diff --git a/core/filter/kspace.h b/core/filter/kspace.h index e6e99d3c45..dd5cb33d78 100644 --- a/core/filter/kspace.h +++ b/core/filter/kspace.h @@ -28,7 +28,7 @@ namespace MR::Filter { -std::vector kspace_window_choices({"tukey"}); +const std::vector kspace_window_choices({"tukey"}); enum class kspace_windowfn_t { TUKEY }; constexpr default_type default_tukey_width = 0.5; diff --git a/src/denoise/demodulate.cpp b/src/denoise/demodulate.cpp deleted file mode 100644 index f5c5a5e009..0000000000 --- a/src/denoise/demodulate.cpp +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright (c) 2008-2024 the MRtrix3 contributors. - * - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this - * file, You can obtain one at http://mozilla.org/MPL/2.0/. - * - * Covered Software is provided under this License on an "as is" - * basis, without warranty of any kind, either expressed, implied, or - * statutory, including, without limitation, warranties that the - * Covered Software is free of defects, merchantable, fit for a - * particular purpose or non-infringing. - * See the Mozilla Public License v. 2.0 for more details. - * - * For more details, see http://www.mrtrix.org/. - */ - -#include "denoise/demodulate.h" - -#include "app.h" -#include "axes.h" - -using namespace MR::App; - -namespace MR::Denoise { - -const char *const demodulation_description = - "If the input data are of complex type, " - "then a smooth non-linear phase will be demodulated removed from each k-space prior to PCA. " - "In the absence of metadata indicating otherwise, " - "it is inferred that the first two axes correspond to acquired slices, " - "and different slices / volumes will be demodulated individually; " - "this behaviour can be modified using the -demod_axes option. " - "A strictly linear phase term can instead be regressed from each k-space, " - "similarly to performed in Cordero-Grande et al. 2019, " - "by specifying -demodulate linear."; - -// clang-format off -const OptionGroup demodulation_options = OptionGroup("Options for phase demodulation of complex data") - + Option("demodulate", - "select form of phase demodulation; " - "options are: " + join(demodulation_choices, ",") + " " - "(default: nonlinear)") - + Argument("mode").type_choice(demodulation_choices) - + Option("demod_axes", - "comma-separated list of axis indices along which FFT can be applied for phase demodulation") - + Argument("axes").type_sequence_int(); -// clang-format on - -Demodulation get_demodulation(const Header &H) { - const bool complex = H.datatype().is_complex(); - auto opt_mode = get_options("demodulate"); - auto opt_axes = get_options("demod_axes"); - Demodulation result; - if (opt_mode.empty()) { - if (complex) { - result.mode = demodulation_t::NONLINEAR; - } else { - if (!opt_axes.empty()) { - throw Exception("Option -demod_axes cannot be specified: " - "no phase demodulation of magnitude data"); - } - } - } else { - result.mode = demodulation_t(int(opt_mode[0][0])); - if (!complex) { - switch (result.mode) { - case demodulation_t::NONE: - WARN("Specifying -demodulate none is redundant: " - "never any phase demodulation for magnitude input data"); - break; - default: - throw Exception("Phase modulation cannot be utilised for magnitude-only input data"); - } - } - } - if (!complex) - return result; - if (opt_axes.empty()) { - auto slice_encoding_it = H.keyval().find("SliceEncodingDirection"); - if (slice_encoding_it == H.keyval().end()) { - // TODO Ideally this would be the first two axes *on disk*, - // not following transform realignment - INFO("No header information on slice encoding; " - "assuming first two axes are within-slice"); - result.axes = {0, 1}; - } else { - auto dir = Axes::id2dir(slice_encoding_it->second); - for (size_t axis = 0; axis != 3; ++axis) { - if (!dir[axis]) - result.axes.push_back(axis); - } - INFO("For header SliceEncodingDirection=\"" + slice_encoding_it->second + - "\", " - "chose demodulation axes: " + - join(result.axes, ",")); - } - } else { - result.axes = parse_ints(opt_axes[0][0]); - for (auto axis : result.axes) { - if (axis > 2) - throw Exception("Phase demodulation implementation not yet robust to non-spatial axes"); - } - } - return result; -} - -} // namespace MR::Denoise diff --git a/src/denoise/demodulate.h b/src/denoise/demodulate.h deleted file mode 100644 index b0d9d04153..0000000000 --- a/src/denoise/demodulate.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2008-2024 the MRtrix3 contributors. - * - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this - * file, You can obtain one at http://mozilla.org/MPL/2.0/. - * - * Covered Software is provided under this License on an "as is" - * basis, without warranty of any kind, either expressed, implied, or - * statutory, including, without limitation, warranties that the - * Covered Software is free of defects, merchantable, fit for a - * particular purpose or non-infringing. - * See the Mozilla Public License v. 2.0 for more details. - * - * For more details, see http://www.mrtrix.org/. - */ - -#pragma once - -#include -#include - -#include "app.h" -#include "header.h" - -namespace MR::Denoise { - -extern const char *const demodulation_description; - -const std::vector demodulation_choices({"none", "linear", "nonlinear"}); -enum class demodulation_t { NONE, LINEAR, NONLINEAR }; - -extern const App::OptionGroup demodulation_options; - -class Demodulation { -public: - Demodulation(demodulation_t mode) : mode(mode) {} - Demodulation() : mode(demodulation_t::NONE) {} - explicit operator bool() const { return mode != demodulation_t::NONE; } - bool operator!() const { return mode == demodulation_t::NONE; } - demodulation_t mode; - std::vector axes; -}; - -Demodulation get_demodulation(const Header &); - -} // namespace MR::Denoise diff --git a/src/denoise/denoise.h b/src/denoise/denoise.h index ee17664d34..32cd596735 100644 --- a/src/denoise/denoise.h +++ b/src/denoise/denoise.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include "app.h" diff --git a/src/denoise/estimate.cpp b/src/denoise/estimate.cpp index 3fb8da2af7..fdc5e1617e 100644 --- a/src/denoise/estimate.cpp +++ b/src/denoise/estimate.cpp @@ -27,15 +27,12 @@ template Estimate::Estimate(const Header &header, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, Exports &exports) : m(header.size(3)), subsample(subsample), kernel(kernel), estimator(estimator), - transform(std::make_shared(header)), - vst_noise_image(vst_noise_image), X(m, kernel->estimated_size()), XtX(std::min(m, kernel->estimated_size()), std::min(m, kernel->estimated_size())), eig(std::min(m, kernel->estimated_size())), @@ -123,42 +120,10 @@ template void Estimate::operator()(Image &dwi) { exports.patchcount.value() = exports.patchcount.value() + 1; } } - if (exports.noise_cov.valid()) { - double variance(double(0)); - for (auto v : patch.voxels) - variance += Math::pow2(v.noise_level - patch.centre_noise); - variance /= (patch.voxels.size() - 1); - assign_pos_of(ss_index).to(exports.noise_cov); - exports.noise_cov.value() = std::sqrt(variance) / patch.centre_noise; - } } template void Estimate::load_data(Image &image) { const Kernel::Voxel::index_type pos({image.index(0), image.index(1), image.index(2)}); - if (vst_noise_image.valid()) { - assert(patch.centre_realspace.allFinite()); - Interp::Cubic> interp(vst_noise_image); - interp.scanner(patch.centre_realspace); - assert(!(!interp)); - patch.centre_noise = interp.value(); - if (patch.centre_noise > 0.0) { - for (ssize_t i = 0; i != patch.voxels.size(); ++i) { - interp.scanner(transform->voxel2scanner * patch.voxels[i].index.cast()); - // TODO Trying to pull intensity information from voxels beyond the extremities of the subsampled image - // may cause problems - assert(!(!interp)); - const double voxel_noise = interp.value(); - patch.voxels[i].noise_level = voxel_noise; - const double scaling_factor = voxel_noise > 0.0 ? (patch.centre_noise / voxel_noise) : 1.0; - assert(std::isfinite(scaling_factor)); - assign_pos_of(patch.voxels[i].index, 0, 3).to(image); - X.col(i) = image.row(3); - X.col(i) *= scaling_factor; - } - assign_pos_of(pos, 0, 3).to(image); - return; - } - } for (ssize_t i = 0; i != patch.voxels.size(); ++i) { assign_pos_of(patch.voxels[i].index, 0, 3).to(image); X.col(i) = image.row(3); diff --git a/src/denoise/estimate.h b/src/denoise/estimate.h index 622acdd7f7..7aa9f46589 100644 --- a/src/denoise/estimate.h +++ b/src/denoise/estimate.h @@ -43,7 +43,6 @@ template class Estimate { Estimate(const Header &header, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, Exports &exports); @@ -57,12 +56,8 @@ template class Estimate { std::shared_ptr kernel; std::shared_ptr estimator; - // Necessary for transform from input voxel locations to nonstationarity image - std::shared_ptr transform; - // Reusable memory Kernel::Data patch; - Image vst_noise_image; MatrixType X; MatrixType XtX; Eigen::SelfAdjointEigenSolver eig; diff --git a/src/denoise/estimator/estimator.cpp b/src/denoise/estimator/estimator.cpp index 41ffee1406..bf66aaffa5 100644 --- a/src/denoise/estimator/estimator.cpp +++ b/src/denoise/estimator/estimator.cpp @@ -53,7 +53,7 @@ const OptionGroup estimator_denoise_options = "set a fixed input signal rank rather than estimating the noise level from the data") + Argument("value").type_integer(1); -std::shared_ptr make_estimator(const bool permit_bypass) { +std::shared_ptr make_estimator(Image &vst_noise_in, const bool permit_bypass) { auto opt = get_options("estimator"); if (permit_bypass) { auto noise_in = get_options("noise_in"); @@ -63,7 +63,7 @@ std::shared_ptr make_estimator(const bool permit_bypass) { throw Exception("Cannot both provide an input noise level image and specify a noise level estimator"); if (!fixed_rank.empty()) throw Exception("Cannot both provide an input noise level image and request a fixed signal rank"); - return std::make_shared(noise_in[0][0]); + return std::make_shared(noise_in[0][0], vst_noise_in); } if (!fixed_rank.empty()) { if (!opt.empty()) diff --git a/src/denoise/estimator/estimator.h b/src/denoise/estimator/estimator.h index 322d65ca96..1384b157ee 100644 --- a/src/denoise/estimator/estimator.h +++ b/src/denoise/estimator/estimator.h @@ -21,6 +21,7 @@ #include #include "app.h" +#include "image.h" namespace MR::Denoise::Estimator { @@ -30,6 +31,6 @@ extern const App::Option estimator_option; extern const App::OptionGroup estimator_denoise_options; const std::vector estimators = {"exp1", "exp2", "med", "mrm2022"}; enum class estimator_type { EXP1, EXP2, MED, MRM2022 }; -std::shared_ptr make_estimator(const bool permit_bypass); +std::shared_ptr make_estimator(Image &vst_noise_in, const bool permit_bypass); } // namespace MR::Denoise::Estimator diff --git a/src/denoise/estimator/import.h b/src/denoise/estimator/import.h index 6705c11155..fde0801bba 100644 --- a/src/denoise/estimator/import.h +++ b/src/denoise/estimator/import.h @@ -28,7 +28,9 @@ namespace MR::Denoise::Estimator { class Import : public Base { public: - Import(const std::string &path) : noise_image(Image::open(path)) {} + Import(const std::string &path, Image &vst_noise_in) // + : noise_image(Image::open(path)), // + vst_noise_image(vst_noise_in) {} // Result operator()(const eigenvalues_type &s, // const ssize_t m, // const ssize_t n, // @@ -46,7 +48,17 @@ class Import : public Base { // where the patch centre is too close to the image edge for cubic interpolation if (!interp.scanner(pos)) return result; - result.sigma2 = Math::pow2(interp.value()); + // If the data have been preconditioned at input based on a pre-estimated noise level, + // then we need to rescale the threshold that we load from this image + // based on knowledge of that rescaling + if (vst_noise_image.valid()) { + Interp::Cubic> vst_interp(vst_noise_image); + if (!vst_interp.scanner(pos)) + return result; + result.sigma2 = Math::pow2(interp.value() / vst_interp.value()); + } else { + result.sigma2 = Math::pow2(interp.value()); + } } // From this noise level, // estimate the upper bound of the MP distribution and rank of signal @@ -69,6 +81,7 @@ class Import : public Base { private: Image noise_image; + Image vst_noise_image; }; } // namespace MR::Denoise::Estimator diff --git a/src/denoise/exports.h b/src/denoise/exports.h index 0cd1e8a715..d93008025c 100644 --- a/src/denoise/exports.h +++ b/src/denoise/exports.h @@ -59,7 +59,6 @@ class Exports { else sum_aggregation = Image::create(path, H_in); } - void set_noise_cov(const std::string &path) { noise_cov = Image::create(path, H_ss); } Image noise_out; Image rank_input; @@ -69,7 +68,6 @@ class Exports { Image voxelcount; Image patchcount; Image sum_aggregation; - Image noise_cov; protected: Header H_in; diff --git a/src/denoise/precondition.cpp b/src/denoise/precondition.cpp new file mode 100644 index 0000000000..6c6c9afada --- /dev/null +++ b/src/denoise/precondition.cpp @@ -0,0 +1,369 @@ +/* Copyright (c) 2008-2024 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "denoise/precondition.h" + +#include + +#include "algo/copy.h" +#include "app.h" +#include "axes.h" +#include "dwi/gradient.h" +#include "dwi/shells.h" +#include "transform.h" + +using namespace MR::App; + +namespace MR::Denoise { + +const char *const demodulation_description = + "If the input data are of complex type, " + "then a smooth non-linear phase will be demodulated removed from each k-space prior to PCA. " + "In the absence of metadata indicating otherwise, " + "it is inferred that the first two axes correspond to acquired slices, " + "and different slices / volumes will be demodulated individually; " + "this behaviour can be modified using the -demod_axes option. " + "A strictly linear phase term can instead be regressed from each k-space, " + "similarly to performed in Cordero-Grande et al. 2019, " + "by specifying -demodulate linear."; + +// clang-format off +const OptionGroup precondition_options = OptionGroup("Options for preconditioning data prior to PCA") + + Option("demodulate", + "select form of phase demodulation; " + "options are: " + join(demodulation_choices, ",") + " " + "(default: nonlinear)") + + Argument("mode").type_choice(demodulation_choices) + + Option("demod_axes", + "comma-separated list of axis indices along which FFT can be applied for phase demodulation") + + Argument("axes").type_sequence_int() + + Option("demean", + "select method of demeaning prior to PCA; " + "options are: " + join(demean_choices, ",") + " " + "(default: 'shells' if DWI gradient table available, 'all' otherwise)") + + Argument("mode").type_choice(demean_choices) + + Option("vst", + "apply a within-patch variance-stabilising transformation based on a pre-estimated noise level map") + + Argument("image").type_image_in() + + Option("preconditioned", + "export the preconditioned version of the input image that is the input to PCA") + + Argument("image").type_image_out(); +// clang-format on + +Demodulation select_demodulation(const Header &H) { + const bool complex = H.datatype().is_complex(); + auto opt_mode = get_options("demodulate"); + auto opt_axes = get_options("demod_axes"); + Demodulation result; + if (opt_mode.empty()) { + if (complex) { + result.mode = demodulation_t::NONLINEAR; + } else { + if (!opt_axes.empty()) { + throw Exception("Option -demod_axes cannot be specified: " + "no phase demodulation of magnitude data"); + } + } + } else { + result.mode = demodulation_t(int(opt_mode[0][0])); + if (!complex) { + switch (result.mode) { + case demodulation_t::NONE: + WARN("Specifying -demodulate none is redundant: " + "never any phase demodulation for magnitude input data"); + break; + default: + throw Exception("Phase modulation cannot be utilised for magnitude-only input data"); + } + } + } + if (!complex) + return result; + if (opt_axes.empty()) { + auto slice_encoding_it = H.keyval().find("SliceEncodingDirection"); + if (slice_encoding_it == H.keyval().end()) { + // TODO Ideally this would be the first two axes *on disk*, + // not following transform realignment + INFO("No header information on slice encoding; " + "assuming first two axes are within-slice"); + result.axes = {0, 1}; + } else { + auto dir = Axes::id2dir(slice_encoding_it->second); + for (size_t axis = 0; axis != 3; ++axis) { + if (!dir[axis]) + result.axes.push_back(axis); + } + INFO("For header SliceEncodingDirection=\"" + slice_encoding_it->second + + "\", " + "chose demodulation axes: " + + join(result.axes, ",")); + } + } else { + result.axes = parse_ints(opt_axes[0][0]); + for (auto axis : result.axes) { + if (axis > 2) + throw Exception("Phase demodulation implementation not yet robust to non-spatial axes"); + } + } + return result; +} + +demean_type select_demean(const Header &H) { + auto opt = get_options("demean"); + if (opt.empty()) { + try { + auto grad = DWI::get_DW_scheme(H); + auto shells = DWI::Shells(grad); + INFO("Choosing to demean per b-value shell based on input gradient table"); + return demean_type::SHELLS; + } catch (Exception &) { + INFO("Choosing to demean across all volumes based on absent / non-shelled gradient table"); + return demean_type::ALL; + } + } + return demean_type(int(opt[0][0])); +} + +template +Precondition::Precondition(Image &image, + const Demodulation &demodulation, + const demean_type demean, + Image &vst_image) + : H(image), // + vst_image(vst_image) { // + + // Step 1: Phase demodulation + Image dephased; + if (demodulation.mode == demodulation_t::NONE) { + dephased = image; + } else { + typename DemodulatorSelector::type demodulator(image, // + demodulation.axes, // + demodulation.mode == demodulation_t::LINEAR); // + phase_image = demodulator(); + // Only actually perform the dephasing of the input image + // if that result needs to be utilised in calculation of the mean + if (demean != demean_type::NONE) { + dephased = Image::scratch(H, "Scratch dephased version of \"" + image.name() + "\" for mean calculation"); + demodulator(image, dephased, false); + } + } + + // Step 2: Demeaning + Header H_mean(H); + switch (demean) { + case demean_type::NONE: + break; + case demean_type::SHELLS: { + Eigen::Matrix grad; + try { + grad = DWI::get_DW_scheme(H_mean); + } catch (Exception &e) { + throw Exception(e, "Cannot demean by shells as unable to obtain valid gradient table"); + } + try { + DWI::Shells shells(grad); + vol2shellidx.resize(image.size(3), -1); + for (ssize_t shell_idx = 0; shell_idx != shells.count(); ++shell_idx) { + for (auto v : shells[shell_idx].get_volumes()) + vol2shellidx[v] = shell_idx; + } + assert(*std::min_element(vol2shellidx.begin(), vol2shellidx.end()) == 0); + H_mean.size(3) = shells.count(); + DWI::stash_DW_scheme(H_mean, grad); + mean_image = Image::scratch(H_mean, "Scratch image for per-shell mean intensity"); + for (auto l_voxel = Loop("Computing mean intensities across shells", H_mean, 0, 3)(dephased, mean_image); // + l_voxel; // + ++l_voxel) { // + for (ssize_t volume_idx = 0; volume_idx != image.size(3); ++volume_idx) { + dephased.index(3) = volume_idx; + mean_image.index(3) = vol2shellidx[volume_idx]; + mean_image.value() += dephased.value(); + } + for (ssize_t shell_idx = 0; shell_idx != shells.count(); ++shell_idx) { + mean_image.index(3) = shell_idx; + mean_image.value() /= T(shells[shell_idx].count()); + } + } + } catch (Exception &e) { + throw Exception(e, "Cannot demean by shells as unable to establish b-value shell structure"); + } + } break; + case demean_type::ALL: { + H_mean.ndim() = 3; + DWI::clear_DW_scheme(H_mean); + mean_image = Image::scratch(H_mean, "Scratch image for mean intensity across all volumes"); + for (auto l_voxel = Loop("Computing mean intensity across all volumes", H_mean)(dephased, mean_image); // + l_voxel; // + ++l_voxel) { // + T mean(T(0)); + for (auto l_volume = Loop(3)(dephased); l_volume; ++l_volume) + mean += T(dephased.value()); + mean_image.value() = mean / T(image.size(3)); + } + } break; + } + + // Step 3: Variance-stabilising transform + // Image vst is already set within constructor definition; + // nothing to do here +} + +namespace { +// Private functions to prevent compiler attempting to create complex functions for real types +template +typename std::enable_if::value, T>::type demodulate(const cfloat in, const cfloat phase) { + return in * std::conj(phase); +} +template +typename std::enable_if::value, T>::type demodulate(const cdouble in, const cfloat phase) { + return in * std::conj(cdouble(phase)); +} +template +typename std::enable_if::value, T>::type demodulate(const T in, const cfloat phase) { + assert(false); + return in; +} +template +typename std::enable_if::value, T>::type modulate(const cfloat in, const cfloat phase) { + return in * phase; +} +template +typename std::enable_if::value, T>::type modulate(const cdouble in, const cfloat phase) { + return in * cdouble(phase); +} +template typename std::enable_if::value, T>::type modulate(const T in, const cfloat phase) { + assert(false); + return in; +} +} // namespace + +template void Precondition::operator()(Image input, Image output, const bool inverse) const { + + // For thread-safety / const-ness + const Transform transform(input); + Image phase(phase_image); + Image mean(mean_image); + std::unique_ptr>> vst; + if (vst_image.valid()) + vst.reset(new Interp::Cubic>(vst_image)); + + Eigen::Array data(input.size(3)); + if (inverse) { + for (auto l_voxel = Loop("Reversing data preconditioning", H, 0, 3)(input, output); l_voxel; ++l_voxel) { + + // Step 3: Reverse variance-stabilising transform + if (vst) { + vst->scanner(transform.voxel2scanner * // + Eigen::Vector3d({default_type(input.index(0)), // + default_type(input.index(1)), // + default_type(input.index(2))})); // + const T multiplier = T(vst->value()); + for (ssize_t v = 0; v != input.size(3); ++v) { + input.index(3) = v; + data[v] = T(input.value()) * multiplier; + } + } else { + for (ssize_t v = 0; v != input.size(3); ++v) { + input.index(3) = v; + data[v] = input.value(); + } + } + + // Step 2: Reverse demeaning + if (mean.valid()) { + assign_pos_of(input, 0, 3).to(mean); + if (mean.ndim() == 3) { + const T mean_value = mean.value(); + data += mean_value; + } else { + for (ssize_t v = 0; v != input.size(3); ++v) { + mean.index(3) = vol2shellidx[v]; + data[v] += T(mean.value()); + } + } + } + + // Step 1: Reverse phase demodulation + if (phase.valid()) { + assign_pos_of(input, 0, 3).to(phase); + for (ssize_t v = 0; v != input.size(3); ++v) { + phase.index(3) = v; + data[v] = modulate(data[v], phase.value()); + } + } + + // Write to output + for (ssize_t v = 0; v != input.size(3); ++v) { + output.index(3) = v; + output.value() = data[v]; + } + } + return; + } + + // Applying forward preconditioning + for (auto l_voxel = Loop("Applying data preconditioning", H, 0, 3)(input, output); l_voxel; ++l_voxel) { + + // Step 1: Phase demodulation + if (phase.valid()) { + assign_pos_of(input, 0, 3).to(phase); + for (ssize_t v = 0; v != input.size(3); ++v) { + input.index(3) = v; + phase.index(3) = v; + data[v] = demodulate(input.value(), phase.value()); + } + } else { + for (ssize_t v = 0; v != input.size(3); ++v) { + input.index(3) = v; + data[v] = input.value(); + } + } + + // Step 2: Demeaning + if (mean.valid()) { + assign_pos_of(input, 0, 3).to(mean); + if (mean.ndim() == 3) { + const T mean_value = mean.value(); + for (ssize_t v = 0; v != input.size(3); ++v) + data[v] -= mean_value; + } else { + for (ssize_t v = 0; v != input.size(3); ++v) { + mean.index(3) = vol2shellidx[v]; + data[v] -= T(mean.value()); + } + } + } + + // Step 3: Variance-stabilising transform + if (vst) { + vst->scanner(transform.voxel2scanner // + * Eigen::Vector3d({default_type(input.index(0)), // + default_type(input.index(1)), // + default_type(input.index(2))})); // + const default_type multiplier = 1.0 / vst->value(); + data *= multiplier; + } + + // Write to output + for (ssize_t v = 0; v != input.size(3); ++v) { + output.index(3) = v; + output.value() = data[v]; + } + } +} + +} // namespace MR::Denoise diff --git a/src/denoise/precondition.h b/src/denoise/precondition.h new file mode 100644 index 0000000000..7a1ff9c74c --- /dev/null +++ b/src/denoise/precondition.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2008-2024 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include +#include + +#include "app.h" +#include "denoise/kernel/voxel.h" +#include "filter/demodulate.h" +#include "header.h" +#include "image.h" +#include "interp/cubic.h" +#include "types.h" + +namespace MR::Denoise { + +extern const char *const demodulation_description; + +const std::vector demodulation_choices({"none", "linear", "nonlinear"}); +enum class demodulation_t { NONE, LINEAR, NONLINEAR }; + +const std::vector demean_choices = {"none", "shells", "all"}; +enum class demean_type { NONE, SHELLS, ALL }; + +extern const App::OptionGroup precondition_options; + +class Demodulation { +public: + Demodulation(demodulation_t mode) : mode(mode) {} + Demodulation() : mode(demodulation_t::NONE) {} + explicit operator bool() const { return mode != demodulation_t::NONE; } + bool operator!() const { return mode == demodulation_t::NONE; } + demodulation_t mode; + std::vector axes; +}; +Demodulation select_demodulation(const Header &); + +demean_type select_demean(const Header &); + +// Need to SFINAE define the demodulator type, +// so that it does not attempt to compile the demodulation filter for non-complex types +class DummyDemodulator { +public: + template DummyDemodulator(ImageType &, const std::vector &, const bool) {} + template + void operator()(InputImageType &, OutputImageType &, const bool) { + assert(false); + } + Image operator()() { return Image(); } +}; +template struct DemodulatorSelector { + using type = DummyDemodulator; +}; +template struct DemodulatorSelector> { + using type = Filter::Demodulate; +}; + +template class Precondition { +public: + Precondition(Image &image, const Demodulation &demodulation, const demean_type demean, Image &vst); + Precondition(Precondition &) = default; + void operator()(Image input, Image output, const bool inverse = false) const; + ssize_t rank() const { return phase_image.valid() || mean_image.valid() ? 1 : 0; } + +private: + const Header H; + // First step: Phase demodulation + Image phase_image; + // Second step: Demeaning + std::vector vol2shellidx; + Image mean_image; + // Third step: Variance-stabilising transform + Image vst_image; +}; +template class Precondition; +template class Precondition; +template class Precondition; +template class Precondition; + +} // namespace MR::Denoise diff --git a/src/denoise/recon.cpp b/src/denoise/recon.cpp index 9cfc177667..2573e83153 100644 --- a/src/denoise/recon.cpp +++ b/src/denoise/recon.cpp @@ -24,12 +24,11 @@ template Recon::Recon(const Header &header, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, filter_type filter, aggregator_type aggregator, Exports &exports) - : Estimate(header, subsample, kernel, vst_noise_image, estimator, exports), + : Estimate(header, subsample, kernel, estimator, exports), filter(filter), aggregator(aggregator), // FWHM = 2 x cube root of spacings between kernels @@ -75,12 +74,6 @@ template void Recon::operator()(Image &dwi, Image &out) { const double transition = 1.0 + std::sqrt(beta); for (ssize_t i = 0; i != r; ++i) { const double lam = std::max(Estimate::s[i], 0.0) / q; - // TODO Should this be based on the noise level, - // or on the estimated upper bound of the MP distribution? - // If based on upper bound, - // there will be an issue with importing this information from a pre-estimated noise map - // TODO Unexpected absence of sqrt() here - // const double y = lam / std::sqrt(Estimate::threshold.sigma2); const double y = lam / Estimate::threshold.sigma2; double nu = 0.0; if (y > transition) { diff --git a/src/denoise/recon.h b/src/denoise/recon.h index 3d96dedc3d..6db9e40d59 100644 --- a/src/denoise/recon.h +++ b/src/denoise/recon.h @@ -37,7 +37,6 @@ template class Recon : public Estimate { Recon(const Header &header, std::shared_ptr subsample, std::shared_ptr kernel, - Image &vst_noise_image, std::shared_ptr estimator, filter_type filter, aggregator_type aggregator,