From 33ffe332e430f5017f06d12630683ba910b4233a Mon Sep 17 00:00:00 2001 From: mawen1250 Date: Tue, 26 May 2015 22:41:19 +0800 Subject: [PATCH] Fixed a stupid mistake that the order of applying DCT and IDCT is reversed Fixed a stupid mistake that the order of applying DCT and IDCT is reversed (it should be DCT->filter->IDCT), which leads to over-filtering in fine structures and produces artifacts such as blocking, ringing and desaturation. Now it acctually filters as expected. Adjust hard-thresholding table to fit the unnormalized DCT, to be equivalent to filtering in orthogonal transform Default value of hard_thr changed from 2.3 to 2.8 for profile="vn", from 2.2 to 2.7 for other profiles Some function arguments changed to const reference Some read-only references in function arguments changed to const Some read-only member functions changed to const Cosmetics --- README.md | 15 ++-- include/BM3D_Base.h | 8 +-- include/BM3D_Basic.h | 2 +- include/BM3D_Final.h | 2 +- include/Block.h | 74 +++++++++----------- include/Helper.h | 152 +++++++++++++++++++++++++++++++---------- include/Type.h | 44 ++++++------ include/VBM3D_Base.h | 16 ++--- include/VBM3D_Basic.h | 6 +- include/VBM3D_Final.h | 6 +- source/BM3D.cpp | 60 +++------------- source/BM3D_Base.cpp | 6 +- source/BM3D_Basic.cpp | 4 +- source/BM3D_Final.cpp | 2 +- source/VBM3D_Base.cpp | 11 +-- source/VBM3D_Basic.cpp | 8 +-- source/VBM3D_Final.cpp | 6 +- 17 files changed, 228 insertions(+), 194 deletions(-) diff --git a/README.md b/README.md index d158c6f..2febb34 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ bm3d.Basic(clip input[, clip ref=input, string profile="fast", float[] sigma=[10 - sigma:
The strength of denoising, valid range [0, +inf), default [10,10,10].
+ This denoising algorithm is very sensitive to the sigma, so adjust it carefully according to the source.
Technically, this is the standard deviation of i.i.d. zero mean additive white Gaussian noise in 8 bit scale. BM3D denoising filter is designed based on this noise model, and best fit for attenuating it.
An array up to 3 elements can be assigned to set different sigma for Y,U,V channels. If less than 3 elements assigned, the last element's value will be assigned to the undefined elements.
0 will disable the processing for corresponding channel. @@ -141,7 +142,7 @@ bm3d.Basic(clip input[, clip ref=input, string profile="fast", float[] sigma=[10 - hard_thr:
The threshold parameter for the hard-thresholding in 3D transformed domain, in 8 bit scale, valid range (0, +inf).
- Larger results in stronger hard-threshold filtering in frequency domain.
+ Larger value results in stronger hard-threshold filtering in frequency domain.
Usually, to tweak denoising strength, it's better to adjust "sigma" rather than "hard_thr". - matrix:
@@ -296,7 +297,7 @@ bm3d.Basic / bm3d.Final / bm3d.VBasic / bm3d.VFinal bm3d.VBasic / bm3d.VFinal --------------------------------------------------- | profile || radius | ps_num | ps_range | ps_step | --------------------------------------------------- +--------------------------------------------------- | "fast" || 1/1 | 2/2 | 4/5 | 1/1/1/1 | | "lc" || 2/2 | 2/2 | 4/5 | 1/1/1/1 | | "np" || 3/3 | 2/2 | 5/6 | 1/1/1/1 | @@ -310,11 +311,11 @@ bm3d.Basic & bm3d.VBasic / bm3d.Final & bm3d.VFinal -------------------------------------------------------------- | profile || th_mse | hard_thr | -------------------------------------------------------------- -| "fast" || sigma[0]*80+400 / sigma[0]*10+200 | 2.2 / NUL | -| "lc" || sigma[0]*80+400 / sigma[0]*10+200 | 2.2 / NUL | -| "np" || sigma[0]*80+400 / sigma[0]*10+200 | 2.2 / NUL | -| "high" || sigma[0]*80+400 / sigma[0]*10+200 | 2.2 / NUL | -| "vn" || sigma[0]*150+1000 / sigma[0]*40+400 | 2.3 / NUL | +| "fast" || sigma[0]*80+400 / sigma[0]*10+200 | 2.7 / NUL | +| "lc" || sigma[0]*80+400 / sigma[0]*10+200 | 2.7 / NUL | +| "np" || sigma[0]*80+400 / sigma[0]*10+200 | 2.7 / NUL | +| "high" || sigma[0]*80+400 / sigma[0]*10+200 | 2.7 / NUL | +| "vn" || sigma[0]*150+1000 / sigma[0]*40+400 | 2.8 / NUL | -------------------------------------------------------------- ``` diff --git a/include/BM3D_Base.h b/include/BM3D_Base.h index 18b318b..ae1e806 100644 --- a/include/BM3D_Base.h +++ b/include/BM3D_Base.h @@ -230,18 +230,18 @@ class BM3D_Process_Base _NewFrame(width, height, dfi == fi); } - void Kernel(FLType *dst, const FLType *src, const FLType *ref); + void Kernel(FLType *dst, const FLType *src, const FLType *ref) const; void Kernel(FLType *dstY, FLType *dstU, FLType *dstV, const FLType *srcY, const FLType *srcU, const FLType *srcV, - const FLType *refY, const FLType *refU, const FLType *refV); + const FLType *refY, const FLType *refU, const FLType *refV) const; - PosPairCode BlockMatching(const FLType *ref, PCType j, PCType i); + PosPairCode BlockMatching(const FLType *ref, PCType j, PCType i) const; virtual void CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, - const PosPairCode &code) = 0; + const PosPairCode &code) const = 0; }; diff --git a/include/BM3D_Basic.h b/include/BM3D_Basic.h index 0dc4d77..8f97839 100644 --- a/include/BM3D_Basic.h +++ b/include/BM3D_Basic.h @@ -85,7 +85,7 @@ class BM3D_Basic_Process virtual void CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, - const PosPairCode &code) override; + const PosPairCode &code) const override; }; diff --git a/include/BM3D_Final.h b/include/BM3D_Final.h index af38dac..5ed3bef 100644 --- a/include/BM3D_Final.h +++ b/include/BM3D_Final.h @@ -85,7 +85,7 @@ class BM3D_Final_Process virtual void CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, - const PosPairCode &code) override; + const PosPairCode &code) const override; }; diff --git a/include/Block.h b/include/Block.h index d2fb022..c081858 100644 --- a/include/Block.h +++ b/include/Block.h @@ -76,15 +76,15 @@ class Block } template < typename _St2, typename _Fn1 > - void for_each(_St2 &right, _Fn1 _Func) + void for_each(_St2 &data2, _Fn1 _Func) { - Block_For_each(*this, right, _Func); + Block_For_each(*this, data2, _Func); } template < typename _St2, typename _Fn1 > - void for_each(_St2 &right, _Fn1 _Func) const + void for_each(_St2 &data2, _Fn1 _Func) const { - Block_For_each(*this, right, _Func); + Block_For_each(*this, data2, _Func); } template < typename _Fn1 > @@ -103,7 +103,7 @@ class Block // Default constructor Block() {} - Block(PCType _Height, PCType _Width, PosType pos, bool Init = true, value_type Value = 0) + Block(PCType _Height, PCType _Width, const PosType &pos, bool Init = true, value_type Value = 0) : Height_(_Height), Width_(_Width), PixelCount_(Height_ * Width_), pos_(pos) { AlignedMalloc(Data_, size()); @@ -113,13 +113,13 @@ class Block // Constructor from plane pointer and PosType template < typename _St1 > - Block(const _St1 *src, PCType src_stride, PCType _Height, PCType _Width, PosType pos) + Block(const _St1 *src, PCType src_stride, PCType _Height, PCType _Width, const PosType &pos) : Block(_Height, _Width, pos, false) { From(src, src_stride); } - // Constructor from src + // Constructor from src Block Block(const _Myt &src, bool Init, value_type Value = 0) : Block(src.Height_, src.Width_, src.pos_, Init, Value) {} @@ -275,10 +275,9 @@ class Block } template < typename _St1 > - void From(const _St1 *src, PCType src_stride, PosType pos) + void From(const _St1 *src, PCType src_stride, const PosType &pos) { pos_ = pos; - From(src, src_stride); } @@ -303,7 +302,7 @@ class Block // Accumulate functions template < typename _St1 > - void AddFrom(const _St1 *src, PCType src_stride, PosType pos) + void AddFrom(const _St1 *src, PCType src_stride, const PosType &pos) { auto dstp = data(); auto srcp = src + pos.y * src_stride + pos.x; @@ -320,7 +319,7 @@ class Block } template < typename _St1, typename _Gt1 > - void AddFrom(const _St1 *src, PCType src_stride, PosType pos, _Gt1 gain) + void AddFrom(const _St1 *src, PCType src_stride, const PosType &pos, _Gt1 gain) { auto dstp = data(); auto srcp = src + pos.y * src_stride + pos.x; @@ -403,7 +402,7 @@ class Block } //////////////////////////////////////////////////////////////// - // Block matching functions + // Single block-matching functions template < typename _St1 > PosPair BlockMatching(const _St1 *src, PCType src_height, PCType src_width, PCType src_stride, _St1 src_range, @@ -474,6 +473,9 @@ class Block return PosPair(static_cast(distMin * distMul), pos); } + //////////////////////////////////////////////////////////////// + // Multiple block-matching functions + template < typename _St1 > void BlockMatchingMulti(PosPairCode &match_code, const _St1 *src, PCType src_stride, _St1 src_range, const PosCode &search_pos, double thMSE) const @@ -798,7 +800,7 @@ class BlockGroup // Constructor from plane pointer and Pos3PairCode template < typename _St1 > - BlockGroup(std::vector src, PCType src_stride, const Pos3PairCode &code, + BlockGroup(const std::vector &src, PCType src_stride, const Pos3PairCode &code, PCType _GroupSize = -1, PCType _Height = 16, PCType _Width = 16) : Height_(_Height), Width_(_Width) { @@ -984,7 +986,7 @@ class BlockGroup // Read/Store functions template < typename _St1 > - _Myt &From(const _St1 *src, PCType src_stride) + void From(const _St1 *src, PCType src_stride) { auto dstp = data(); @@ -1002,12 +1004,10 @@ class BlockGroup } } } - - return *this; } template < typename _St1 > - _Myt &From(std::vector src, PCType src_stride) + void From(const std::vector &src, PCType src_stride) { auto dstp = data(); @@ -1025,8 +1025,6 @@ class BlockGroup } } } - - return *this; } template < typename _Dt1 > @@ -1051,7 +1049,7 @@ class BlockGroup } template < typename _Dt1 > - void To(std::vector<_Dt1 *> dst, PCType dst_stride) const + void To(const std::vector<_Dt1 *> &dst, PCType dst_stride) const { auto srcp = data(); @@ -1117,7 +1115,7 @@ class BlockGroup } template < typename _Dt1 > - void AddTo(std::vector<_Dt1 *> dst, PCType dst_stride) const + void AddTo(const std::vector<_Dt1 *> &dst, PCType dst_stride) const { auto srcp = data(); @@ -1138,7 +1136,7 @@ class BlockGroup } template < typename _Dt1, typename _Gt1 > - void AddTo(std::vector<_Dt1 *> dst, PCType dst_stride, _Gt1 gain) const + void AddTo(const std::vector<_Dt1 *> &dst, PCType dst_stride, _Gt1 gain) const { auto srcp = data(); @@ -1197,7 +1195,7 @@ class BlockGroup } template < typename _Dt1 > - void CountTo(std::vector<_Dt1 *> dst, PCType dst_stride) const + void CountTo(const std::vector<_Dt1 *> &dst, PCType dst_stride) const { for (PCType z = 0; z < GroupSize(); ++z) { @@ -1216,7 +1214,7 @@ class BlockGroup } template < typename _Dt1 > - void CountTo(std::vector<_Dt1 *> dst, PCType dst_stride, _Dt1 value) const + void CountTo(const std::vector<_Dt1 *> &dst, PCType dst_stride, _Dt1 value) const { for (PCType z = 0; z < GroupSize(); ++z) { @@ -1239,7 +1237,7 @@ class BlockGroup //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -template < typename _St1, typename _Fn1 > inline +template < typename _St1, typename _Fn1 > void Block_For_each(_St1 &data, _Fn1 &&_Func) { auto datap = data.data(); @@ -1250,26 +1248,24 @@ void Block_For_each(_St1 &data, _Fn1 &&_Func) } } -template < typename _St1, typename _St2, typename _Fn1 > inline -void Block_For_each(_St1 &left, _St2 &right, _Fn1 &&_Func) +template < typename _St1, typename _St2, typename _Fn1 > +void Block_For_each(_St1 &data1, _St2 &data2, _Fn1 &&_Func) { - const char *FunctionName = "Block_For_each"; - if (left.size() != right.size()) + if (data1.size() != data2.size()) { - std::cerr << FunctionName << ": size() of left and right must be the same.\n"; - exit(EXIT_FAILURE); + DEBUG_FAIL("Block_For_each: size() of data1 and data2 must be the same."); } - auto leftp = left.data(); - auto rightp = right.data(); + auto data1p = data1.data(); + auto data2p = data2.data(); - for (auto upper = leftp + left.size(); leftp != upper; ++leftp, ++rightp) + for (auto upper = data1p + data1.size(); data1p != upper; ++data1p, ++data2p) { - _Func(*leftp, *rightp); + _Func(*data1p, *data2p); } } -template < typename _St1, typename _Fn1 > inline +template < typename _St1, typename _Fn1 > void Block_Transform(_St1 &data, _Fn1 &&_Func) { auto datap = data.data(); @@ -1280,14 +1276,12 @@ void Block_Transform(_St1 &data, _Fn1 &&_Func) } } -template < typename _Dt1, typename _St1, typename _Fn1 > inline +template < typename _Dt1, typename _St1, typename _Fn1 > void Block_Transform(_Dt1 &dst, const _St1 &src, _Fn1 &&_Func) { - const char *FunctionName = "Block_Transform"; if (dst.size() != src.size()) { - std::cerr << FunctionName << ": size() of dst and src must be the same.\n"; - exit(EXIT_FAILURE); + DEBUG_FAIL("Block_Transform: size() of dst and src must be the same."); } auto dstp = dst.data(); diff --git a/include/Helper.h b/include/Helper.h index 555a7d3..6665f3d 100644 --- a/include/Helper.h +++ b/include/Helper.h @@ -33,6 +33,20 @@ //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Exception handle + + +#ifdef _DEBUG +#define DEBUG_BREAK __debugbreak(); +#define DEBUG_FAIL(mesg) __debugbreak(); _STD _DEBUG_ERROR(mesg); +#else +#define DEBUG_BREAK exit(EXIT_FAILURE); +#define DEBUG_FAIL(mesg) std::cerr << mesg << std::endl; exit(EXIT_FAILURE); +#endif + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Convert to std::string template < typename _Ty > @@ -45,6 +59,70 @@ std::string GetStr(const _Ty &src) //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Memory allocation + + +const size_t MEMORY_ALIGNMENT = 32; + + +template < typename _Ty > +void AlignedMalloc(_Ty *&Memory, size_t Count, size_t Alignment = MEMORY_ALIGNMENT) +{ + Memory = vs_aligned_malloc<_Ty>(sizeof(_Ty) * Count, Alignment); +} + + +template < typename _Ty > +void AlignedFree(_Ty *&Memory) +{ + vs_aligned_free(Memory); + Memory = nullptr; +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// 2D array copy + + +template < typename _Dt1, typename _St1 > +void MatCopy(_Dt1 *dstp, const _St1 *srcp, PCType height, PCType width, PCType dst_stride, PCType src_stride) +{ + for (PCType j = 0; j < height; ++j) + { + for (PCType i = 0; i < width; ++i) + { + dstp[i] = static_cast<_Dt1>(srcp[i]); + } + + dstp += dst_stride; + srcp += src_stride; + } +} + +template < typename _Ty > +void MatCopy(_Ty *dstp, const _Ty *srcp, PCType height, PCType width, PCType dst_stride, PCType src_stride) +{ + if (height > 0) + { + if (src_stride == dst_stride && src_stride == width) + { + memcpy(dstp, srcp, sizeof(_Ty) * height * width); + } + else + { + for (PCType j = 0; j < height; ++j) + { + memcpy(dstp, srcp, sizeof(_Ty) * width); + dstp += dst_stride; + srcp += src_stride; + } + } + } +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Loop in 2D array #define LOOP_V _Loop_V @@ -123,6 +201,7 @@ void _Loop_VH(const PCType height, const PCType width, const PCType stride0, con //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Quantization parameters template < typename _Ty > @@ -223,76 +302,75 @@ bool isPCChroma(_Ty Floor, _Ty Neutral, _Ty Ceil) //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -template < typename T > -T Min(T a, T b) +template < typename _Ty > +_Ty Min(const _Ty &a, const _Ty &b) { return b < a ? b : a; } -template < typename T > -T Max(T a, T b) +template < typename _Ty > +_Ty Max(const _Ty &a, const _Ty &b) { return b > a ? b : a; } -template < typename T > -T Clip(T input, T Floor, T Ceil) +template < typename _Ty > +_Ty Clip(const _Ty &input, const _Ty &lower, const _Ty &upper) { - return input <= Floor ? Floor : input >= Ceil ? Ceil : input; + return input <= lower ? lower : input >= upper ? upper : input; } -template -inline T Abs(T input) + +template < typename _Ty > +_Ty Abs(const _Ty &input) { return input < 0 ? -input : input; } -template -inline T Round_Div(T dividend, T divisor) +template < typename _Ty > +_Ty AbsSub(const _Ty &a, const _Ty &b) { - return (dividend + divisor / 2) / divisor; + return b < a ? a - b : b - a; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - - -const size_t MEMORY_ALIGNMENT = 32; - - template < typename _Ty > -void AlignedMalloc(_Ty *&Memory, size_t Count, size_t Alignment = MEMORY_ALIGNMENT) +_Ty _RoundDiv(_Ty dividend, _Ty divisor, const std::false_type &) { - Memory = vs_aligned_malloc<_Ty>(sizeof(_Ty) * Count, Alignment); + return (dividend + divisor / 2) / divisor; } +template < typename _Ty > +_Ty _RoundDiv(_Ty dividend, _Ty divisor, const std::true_type &) +{ + return dividend / divisor; +} template < typename _Ty > -void AlignedFree(_Ty *&Memory) +_Ty RoundDiv(_Ty dividend, _Ty divisor) { - vs_aligned_free(Memory); - Memory = nullptr; + return _RoundDiv(dividend, divisor, _IsFloat<_Ty>); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -template < typename T > +template < typename _Ty > int stride_cal(int width) { - size_t Alignment2 = MEMORY_ALIGNMENT / sizeof(T); + size_t Alignment2 = MEMORY_ALIGNMENT / sizeof(_Ty); return static_cast(width % Alignment2 == 0 ? width : (width / Alignment2 + 1) * Alignment2); } -template < typename T > -void data2buff(T *dst, const T *src, int xoffset, int yoffset, +template < typename _Ty > +void data2buff(_Ty *dst, const _Ty *src, int xoffset, int yoffset, int bufheight, int bufwidth, int bufstride, int height, int width, int stride) { int x, y; - T *dstp; - const T *srcp; + _Ty *dstp; + const _Ty *srcp; for (y = 0; y < height; ++y) { @@ -300,7 +378,7 @@ void data2buff(T *dst, const T *src, int xoffset, int yoffset, srcp = src + y * stride; for (x = 0; x < xoffset; ++x) dstp[x] = srcp[0]; - memcpy(dstp + xoffset, srcp, sizeof(T) * width); + memcpy(dstp + xoffset, srcp, sizeof(_Ty) * width); for (x = xoffset + width; x < bufwidth; ++x) dstp[x] = srcp[width - 1]; } @@ -309,29 +387,29 @@ void data2buff(T *dst, const T *src, int xoffset, int yoffset, for (y = 0; y < yoffset; ++y) { dstp = dst + y * bufstride; - memcpy(dstp, srcp, sizeof(T) * bufwidth); + memcpy(dstp, srcp, sizeof(_Ty) * bufwidth); } srcp = dst + (yoffset + height - 1) * bufstride; for (y = yoffset + height; y < bufheight; ++y) { dstp = dst + y * bufstride; - memcpy(dstp, srcp, sizeof(T) * bufwidth); + memcpy(dstp, srcp, sizeof(_Ty) * bufwidth); } } -template < typename T > -T *newbuff(const T *src, int xoffset, int yoffset, +template < typename _Ty > +_Ty *newbuff(const _Ty *src, int xoffset, int yoffset, int bufheight, int bufwidth, int bufstride, int height, int width, int stride) { - T *dst; + _Ty *dst; AlignedMalloc(dst, bufheight * bufstride); data2buff(dst, src, xoffset, yoffset, bufheight, bufwidth, bufstride, height, width, stride); return dst; } -template < typename T > -void freebuff(T *&buff) +template < typename _Ty > +void freebuff(_Ty *&buff) { AlignedFree(buff); } diff --git a/include/Type.h b/include/Type.h index 2b1d5bd..fe047c9 100644 --- a/include/Type.h +++ b/include/Type.h @@ -47,36 +47,38 @@ typedef float FLType; // determine whether _Ty satisfies Signed Int requirements -template -struct _IsSInt : std::integral_constant::value || - std::is_same<_Ty, short>::value || - std::is_same<_Ty, int>::value || - std::is_same<_Ty, long>::value || - std::is_same<_Ty, long long>::value> +template < typename _Ty > +struct _IsSInt + : std::integral_constant::value + || std::is_same<_Ty, short>::value + || std::is_same<_Ty, int>::value + || std::is_same<_Ty, long>::value + || std::is_same<_Ty, long long>::value> {}; // determine whether _Ty satisfies Unsigned Int requirements -template -struct _IsUInt : std::integral_constant::value || - std::is_same<_Ty, unsigned short>::value || - std::is_same<_Ty, unsigned int>::value || - std::is_same<_Ty, unsigned long>::value || - std::is_same<_Ty, unsigned long long>::value> +template < typename _Ty > +struct _IsUInt + : std::integral_constant::value + || std::is_same<_Ty, unsigned short>::value + || std::is_same<_Ty, unsigned int>::value + || std::is_same<_Ty, unsigned long>::value + || std::is_same<_Ty, unsigned long long>::value> {}; // determine whether _Ty satisfies Int requirements -template -struct _IsInt : std::integral_constant::value || - _IsUInt<_Ty>::value> +template < typename _Ty > +struct _IsInt + : std::integral_constant::value + || _IsUInt<_Ty>::value> {}; // determine whether _Ty satisfies Float requirements -template -struct _IsFloat : std::integral_constant::value> +template < typename _Ty > +struct _IsFloat + : std::integral_constant::value + || std::is_same<_Ty, double>::value + || std::is_same<_Ty, long double>::value> {}; diff --git a/include/VBM3D_Base.h b/include/VBM3D_Base.h index 8a6ff4b..23cbf87 100644 --- a/include/VBM3D_Base.h +++ b/include/VBM3D_Base.h @@ -337,18 +337,18 @@ class VBM3D_Process_Base vsapi->propSetIntArray(dst_map, "BM3D_V_process", process, VSMaxPlaneCount); } - void Kernel(std::vector &dst, std::vector &src, std::vector &ref); + void Kernel(const std::vector &dst, const std::vector &src, const std::vector &ref) const; - void Kernel(std::vector &dstY, std::vector &dstU, std::vector &dstV, - std::vector &srcY, std::vector &srcU, std::vector &srcV, - std::vector &refY, std::vector &refU, std::vector &refV); + void Kernel(const std::vector &dstY, const std::vector &dstU, const std::vector &dstV, + const std::vector &srcY, const std::vector &srcU, const std::vector &srcV, + const std::vector &refY, const std::vector &refU, const std::vector &refV) const; - Pos3PairCode BlockMatching(std::vector &ref, PCType j, PCType i); + Pos3PairCode BlockMatching(const std::vector &ref, PCType j, PCType i) const; virtual void CollaborativeFilter(int plane, - std::vector &ResNum, std::vector &ResDen, - std::vector &src, std::vector &ref, - const Pos3PairCode &code) = 0; + const std::vector &ResNum, const std::vector &ResDen, + const std::vector &src, const std::vector &ref, + const Pos3PairCode &code) const = 0; }; diff --git a/include/VBM3D_Basic.h b/include/VBM3D_Basic.h index 5c50c70..e2b70e3 100644 --- a/include/VBM3D_Basic.h +++ b/include/VBM3D_Basic.h @@ -83,9 +83,9 @@ class VBM3D_Basic_Process protected: virtual void CollaborativeFilter(int plane, - std::vector &ResNum, std::vector &ResDen, - std::vector &src, std::vector &ref, - const Pos3PairCode &code) override; + const std::vector &ResNum, const std::vector &ResDen, + const std::vector &src, const std::vector &ref, + const Pos3PairCode &code) const override; }; diff --git a/include/VBM3D_Final.h b/include/VBM3D_Final.h index a88f0ed..31c791e 100644 --- a/include/VBM3D_Final.h +++ b/include/VBM3D_Final.h @@ -83,9 +83,9 @@ class VBM3D_Final_Process protected: virtual void CollaborativeFilter(int plane, - std::vector &ResNum, std::vector &ResDen, - std::vector &src, std::vector &ref, - const Pos3PairCode &code) override; + const std::vector &ResNum, const std::vector &ResDen, + const std::vector &src, const std::vector &ref, + const Pos3PairCode &code) const override; }; diff --git a/source/BM3D.cpp b/source/BM3D.cpp index 387c1cf..a065464 100644 --- a/source/BM3D.cpp +++ b/source/BM3D.cpp @@ -35,7 +35,7 @@ BM3D_Para::BM3D_Para(bool _wiener, std::string _profile) { BlockStep = 4; GroupSize = 16; - lambda = 2.2; + lambda = 2.7; } else { @@ -88,7 +88,7 @@ BM3D_Para::BM3D_Para(bool _wiener, std::string _profile) { BlockStep = 4; GroupSize = 32; - lambda = 2.3; + lambda = 2.8; } else { @@ -133,8 +133,8 @@ BM3D_FilterData::BM3D_FilterData(bool wiener, double sigma, PCType GroupSize, PC wienerSigmaSqr(wiener ? GroupSize : 0) { const unsigned int flags = FFTW_PATIENT; - const fftw::r2r_kind fkind = FFTW_REDFT01; - const fftw::r2r_kind bkind = FFTW_REDFT10; + const fftw::r2r_kind fkind = FFTW_REDFT10; + const fftw::r2r_kind bkind = FFTW_REDFT01; FLType *temp = nullptr; @@ -156,10 +156,11 @@ BM3D_FilterData::BM3D_FilterData(bool wiener, double sigma, PCType GroupSize, PC { double thrBase = sigma * lambda * forwardAMP; std::vector thr(4); + thr[0] = thrBase; - thr[1] = thrBase; - thr[2] = thrBase / double(2); - thr[3] = 0; + thr[1] = thrBase * sqrt(double(2)); + thr[2] = thrBase * double(2); + thr[3] = thrBase * sqrt(double(8)); thrTable[i - 1] = std::vector(i * BlockSize * BlockSize); auto thr_d = thrTable[i - 1].data(); @@ -171,64 +172,21 @@ BM3D_FilterData::BM3D_FilterData(bool wiener, double sigma, PCType GroupSize, PC for (PCType x = 0; x < BlockSize; ++x, ++thr_d) { int flag = 0; - double scale = 1; if (x == 0) { ++flag; } - else if (x < BlockSize / 4) - { - scale *= 1.00; - } - else if (x < BlockSize / 2) - { - scale *= 1.01; - } - else - { - scale *= 1.07; - } - if (y == 0) { ++flag; } - else if (y < BlockSize / 4) - { - scale *= 1.00; - } - else if (y < BlockSize / 2) - { - scale *= 1.01; - } - else - { - scale *= 1.07; - } - if (z == 0) { ++flag; } - else if (z < i / 8) - { - scale *= 1.01; - } - else if (z < i / 4) - { - scale *= 1.07; - } - else if (z < i / 2) - { - scale *= 1.16; - } - else - { - scale *= 1.40; - } - *thr_d = static_cast(thr[flag] * scale); + *thr_d = static_cast(thr[flag]); } } } diff --git a/source/BM3D_Base.cpp b/source/BM3D_Base.cpp index 2f21b77..ce43b5e 100644 --- a/source/BM3D_Base.cpp +++ b/source/BM3D_Base.cpp @@ -295,7 +295,7 @@ void BM3D_Data_Base::init_filter_data() // Functions of class BM3D_Process_Base -void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref) +void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref) const { FLType *ResNum = dst, *ResDen = nullptr; @@ -354,7 +354,7 @@ void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV, const FLType *srcY, const FLType *srcU, const FLType *srcV, - const FLType *refY, const FLType *refU, const FLType *refV) + const FLType *refY, const FLType *refU, const FLType *refV) const { FLType *ResNumY = dstY, *ResDenY = nullptr; FLType *ResNumU = dstU, *ResDenU = nullptr; @@ -442,7 +442,7 @@ void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV, BM3D_Process_Base::PosPairCode BM3D_Process_Base::BlockMatching( - const FLType *ref, PCType j, PCType i) + const FLType *ref, PCType j, PCType i) const { // Skip block matching if GroupSize is 1 or thMSE is not positive, // and take the reference block as the only element in the group diff --git a/source/BM3D_Basic.cpp b/source/BM3D_Basic.cpp index 2fe2dc6..30df5c6 100644 --- a/source/BM3D_Basic.cpp +++ b/source/BM3D_Basic.cpp @@ -60,7 +60,7 @@ int BM3D_Basic_Data::arguments_process(const VSMap *in, VSMap *out) void BM3D_Basic_Process::CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, - const PosPairCode &code) + const PosPairCode &code) const { PCType GroupSize = static_cast(code.size()); // When para.GroupSize > 0, limit GroupSize up to para.GroupSize @@ -78,7 +78,7 @@ void BM3D_Basic_Process::CollaborativeFilter(int plane, // Apply forward 3D transform to the source group d.f[plane].fp[GroupSize - 1].execute_r2r(srcGroup.data(), srcGroup.data()); - // Apply hard threshold filtering to the source group + // Apply hard-thresholding to the source group Block_For_each(srcGroup, d.f[plane].thrTable[GroupSize - 1], [&](FLType &x, FLType y) { if (Abs(x) <= y) diff --git a/source/BM3D_Final.cpp b/source/BM3D_Final.cpp index 590b1c0..c00d803 100644 --- a/source/BM3D_Final.cpp +++ b/source/BM3D_Final.cpp @@ -45,7 +45,7 @@ int BM3D_Final_Data::arguments_process(const VSMap *in, VSMap *out) void BM3D_Final_Process::CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, - const PosPairCode &code) + const PosPairCode &code) const { PCType GroupSize = static_cast(code.size()); // When para.GroupSize > 0, limit GroupSize up to para.GroupSize diff --git a/source/VBM3D_Base.cpp b/source/VBM3D_Base.cpp index 88096be..0ad98cd 100644 --- a/source/VBM3D_Base.cpp +++ b/source/VBM3D_Base.cpp @@ -419,7 +419,8 @@ void VBM3D_Data_Base::init_filter_data() // Functions of class VBM3D_Process_Base -void VBM3D_Process_Base::Kernel(std::vector &dst, std::vector &src, std::vector &ref) +void VBM3D_Process_Base::Kernel(const std::vector &dst, + const std::vector &src, const std::vector &ref) const { std::vector ResNum(frames), ResDen(frames); @@ -472,9 +473,9 @@ void VBM3D_Process_Base::Kernel(std::vector &dst, std::vector &dstY, std::vector &dstU, std::vector &dstV, - std::vector &srcY, std::vector &srcU, std::vector &srcV, - std::vector &refY, std::vector &refU, std::vector &refV) +void VBM3D_Process_Base::Kernel(const std::vector &dstY, const std::vector &dstU, const std::vector &dstV, + const std::vector &srcY, const std::vector &srcU, const std::vector &srcV, + const std::vector &refY, const std::vector &refU, const std::vector &refV) const { std::vector ResNumY(frames), ResDenY(frames); std::vector ResNumU(frames), ResDenU(frames); @@ -556,7 +557,7 @@ void VBM3D_Process_Base::Kernel(std::vector &dstY, std::vector &ref, PCType j, PCType i) + const std::vector &ref, PCType j, PCType i) const { // Skip block matching if GroupSize is 1 or thMSE is not positive, // and take the reference block as the only element in the group diff --git a/source/VBM3D_Basic.cpp b/source/VBM3D_Basic.cpp index 6b9cc8b..20b14ab 100644 --- a/source/VBM3D_Basic.cpp +++ b/source/VBM3D_Basic.cpp @@ -58,9 +58,9 @@ int VBM3D_Basic_Data::arguments_process(const VSMap *in, VSMap *out) void VBM3D_Basic_Process::CollaborativeFilter(int plane, - std::vector &ResNum, std::vector &ResDen, - std::vector &src, std::vector &ref, - const Pos3PairCode &code) + const std::vector &ResNum, const std::vector &ResDen, + const std::vector &src, const std::vector &ref, + const Pos3PairCode &code) const { PCType GroupSize = static_cast(code.size()); // When para.GroupSize > 0, limit GroupSize up to para.GroupSize @@ -78,7 +78,7 @@ void VBM3D_Basic_Process::CollaborativeFilter(int plane, // Apply forward 3D transform to the source group d.f[plane].fp[GroupSize - 1].execute_r2r(srcGroup.data(), srcGroup.data()); - // Apply hard threshold filtering to the source group + // Apply hard-thresholding to the source group Block_For_each(srcGroup, d.f[plane].thrTable[GroupSize - 1], [&](FLType &x, FLType y) { if (Abs(x) <= y) diff --git a/source/VBM3D_Final.cpp b/source/VBM3D_Final.cpp index f498d3b..74524ea 100644 --- a/source/VBM3D_Final.cpp +++ b/source/VBM3D_Final.cpp @@ -43,9 +43,9 @@ int VBM3D_Final_Data::arguments_process(const VSMap *in, VSMap *out) void VBM3D_Final_Process::CollaborativeFilter(int plane, - std::vector &ResNum, std::vector &ResDen, - std::vector &src, std::vector &ref, - const Pos3PairCode &code) + const std::vector &ResNum, const std::vector &ResDen, + const std::vector &src, const std::vector &ref, + const Pos3PairCode &code) const { PCType GroupSize = static_cast(code.size()); // When para.GroupSize > 0, limit GroupSize up to para.GroupSize