diff --git a/README.md b/README.md index 2cef750..1417c0d 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ The input clip is processed in 3-step stages, For each reference block:
This final estimate can be realized as a refinement. It can significantly improve the denoising quality, keeping more details and fine structures that were removed in basic estimate. ```python -bm3d.Final(clip input, clip ref[, string profile="fast", float[] sigma=[10,10,10], int block_size, int block_step, int group_size, int bm_range, int bm_step, float th_mse, int matrix=2]) +bm3d.Final(clip input, clip ref[, clip wref=ref, string profile="fast", float[] sigma=[10,10,10], int block_size, int block_step, int group_size, int bm_range, int bm_step, float th_mse, int matrix=2]) ``` - input:
@@ -195,6 +195,21 @@ bm3d.Final(clip input, clip ref[, string profile="fast", float[] sigma=[10,10,10 It must be specified. In original BM3D algorithm, it is the basic estimate.
Alternatively, you can choose any other decent denoising filter as basic estimate, and take this final estimate as a refinement. +- wref:
+ The reference clip for empirical Wiener filtering. If specified, wref will replace ref as the empirical estimate for Wiener filtering.
+ You should specify this parameter if input and ref are sampled from different domains, e.g. input might be a difference image while ref could be a natural image.
+ In such case you should assign something sampled from the input’s domain to wref.
+ A common use case of wref would be to refine the denoising result of an alternative denoiser using BM3D. + +```python +flt = some_filter(src) +dif = core.std.MakeDiff(src, flt) +ref_dif = core.bm3d.Basic(dif, flt) +ref = core.std.MergeDiff(flt, ref_dif) +dif = core.bm3d.Final(dif, ref, wref=ref_dif) +flt = core.std.MergeDiff(flt, dif) +``` + - profile, sigma, block_size, block_step, group_size, bm_range, bm_step, th_mse, matrix:
Same as those in bm3d.Basic. @@ -252,10 +267,10 @@ bm3d.VBasic(clip input[, clip ref=input, string profile="fast", float[] sigma=[1 #### final estimate of V-BM3D denoising filter ```python -bm3d.VFinal(clip input, clip ref[, string profile="fast", float[] sigma=[10,10,10], int radius, int block_size, int block_step, int group_size, int bm_range, int bm_step, int ps_num, int ps_range, int ps_step, float th_mse, int matrix=2]) +bm3d.VFinal(clip input, clip ref[, clip wref=ref, string profile="fast", float[] sigma=[10,10,10], int radius, int block_size, int block_step, int group_size, int bm_range, int bm_step, int ps_num, int ps_range, int ps_step, float th_mse, int matrix=2]) ``` -- input, ref:
+- input, ref, wref:
Same as those in bm3d.Final. - profile, sigma, block_size, block_step, group_size, bm_range, bm_step, th_mse, matrix:
diff --git a/include/BM3D_Base.h b/include/BM3D_Base.h index 219d1c2..54738ed 100644 --- a/include/BM3D_Base.h +++ b/include/BM3D_Base.h @@ -47,6 +47,10 @@ class BM3D_Data_Base VSNodeRef *rnode = nullptr; const VSVideoInfo *rvi = nullptr; + bool wdef = false; + VSNodeRef* wnode = nullptr; + const VSVideoInfo* wvi = nullptr; + bool wiener; ColorMatrix matrix; @@ -126,11 +130,19 @@ class BM3D_Process_Base const VSFrameRef *ref = nullptr; const VSFormat *rfi = nullptr; + const VSFrameRef* wref = nullptr; + const VSFormat* wfi = nullptr; + PCType ref_height[VSMaxPlaneCount]; PCType ref_width[VSMaxPlaneCount]; PCType ref_stride[VSMaxPlaneCount]; PCType ref_pcount[VSMaxPlaneCount]; + PCType wref_height[VSMaxPlaneCount]; + PCType wref_width[VSMaxPlaneCount]; + PCType wref_stride[VSMaxPlaneCount]; + PCType wref_pcount[VSMaxPlaneCount]; + bool full = true; private: @@ -220,11 +232,12 @@ class BM3D_Process_Base _NewFrame(width, height, dfi == fi); } - void Kernel(FLType *dst, const FLType *src, const FLType *ref) const; + void Kernel(FLType *dst, const FLType *src, const FLType *ref, const FLType *wref) const; void Kernel(FLType *dstY, FLType *dstU, FLType *dstV, const FLType *srcY, const FLType *srcU, const FLType *srcV, - const FLType *refY, const FLType *refU, const FLType *refV) const; + const FLType *refY, + const FLType *wrefY, const FLType *wrefU, const FLType *wrefV) const; PosPairCode BlockMatching(const FLType *ref, PCType j, PCType i) const; diff --git a/include/BM3D_Final.h b/include/BM3D_Final.h index 2debbf2..f11de35 100644 --- a/include/BM3D_Final.h +++ b/include/BM3D_Final.h @@ -49,7 +49,9 @@ class BM3D_Final_Data _Myt &operator=(const _Myt &right) = delete; _Myt &operator=(_Myt &&right) = delete; - virtual ~BM3D_Final_Data() override {} + virtual ~BM3D_Final_Data() override { + if (wdef && wnode) vsapi->freeNode(wnode); + } virtual int arguments_process(const VSMap *in, VSMap *out) override; }; @@ -70,11 +72,29 @@ class BM3D_Final_Process const _Mydata &d; public: - BM3D_Final_Process(_Mydata &_d, int _n, VSFrameContext *_frameCtx, VSCore *_core, const VSAPI *_vsapi) - : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d) - {} - - virtual ~BM3D_Final_Process() override {} + BM3D_Final_Process(_Mydata& _d, int _n, VSFrameContext* _frameCtx, VSCore* _core, const VSAPI* _vsapi) + : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d) { + if (d.wdef) { + wref = vsapi->getFrameFilter(n, d.wnode, frameCtx); + wfi = vsapi->getFrameFormat(wref); + } + else { + wref = ref; + wfi = rfi; + } + + if (!skip) + for (int i = 0; i < PlaneCount; ++i) { + wref_height[i] = vsapi->getFrameHeight(wref, i); + wref_width[i] = vsapi->getFrameWidth(wref, i); + wref_stride[i] = vsapi->getStride(wref, i) / wfi->bytesPerSample; + wref_pcount[i] = wref_height[i] * wref_stride[i]; + } + } + + virtual ~BM3D_Final_Process() override { + if (d.wdef) vsapi->freeFrame(wref); + } protected: virtual void CollaborativeFilter(int plane, diff --git a/include/VBM3D_Base.h b/include/VBM3D_Base.h index 831c3f4..000bee0 100644 --- a/include/VBM3D_Base.h +++ b/include/VBM3D_Base.h @@ -63,6 +63,10 @@ class VBM3D_Data_Base VSNodeRef *rnode = nullptr; const VSVideoInfo *rvi = nullptr; + bool wdef = false; + VSNodeRef* wnode = nullptr; + const VSVideoInfo* wvi = nullptr; + bool wiener; ColorMatrix matrix; @@ -135,8 +139,10 @@ class VBM3D_Process_Base std::vector v_src; std::vector v_ref; - + std::vector v_wref; + const VSFormat *rfi = nullptr; + const VSFormat* wfi = nullptr; PCType dst_height[VSMaxPlaneCount]; PCType dst_pcount[VSMaxPlaneCount]; @@ -146,6 +152,11 @@ class VBM3D_Process_Base PCType ref_stride[VSMaxPlaneCount]; PCType ref_pcount[VSMaxPlaneCount]; + PCType wref_height[VSMaxPlaneCount]; + PCType wref_width[VSMaxPlaneCount]; + PCType wref_stride[VSMaxPlaneCount]; + PCType wref_pcount[VSMaxPlaneCount]; + bool full = true; private: @@ -310,11 +321,12 @@ class VBM3D_Process_Base vsapi->propSetIntArray(dst_map, "BM3D_V_process", process, VSMaxPlaneCount); } - void Kernel(const std::vector &dst, const std::vector &src, const std::vector &ref) const; + void Kernel(const std::vector &dst, const std::vector &src, const std::vector &ref, const std::vector &wref) const; void Kernel(const std::vector &dstY, const std::vector &dstU, const std::vector &dstV, const std::vector &srcY, const std::vector &srcU, const std::vector &srcV, - const std::vector &refY, const std::vector &refU, const std::vector &refV) const; + const std::vector &refY, + const std::vector &wrefY, const std::vector &wrefU, const std::vector &wrefV) const; Pos3PairCode BlockMatching(const std::vector &ref, PCType j, PCType i) const; diff --git a/include/VBM3D_Final.h b/include/VBM3D_Final.h index 9d61424..74b5c56 100644 --- a/include/VBM3D_Final.h +++ b/include/VBM3D_Final.h @@ -49,7 +49,9 @@ class VBM3D_Final_Data _Myt &operator=(const _Myt &right) = delete; _Myt &operator=(_Myt &&right) = delete; - virtual ~VBM3D_Final_Data() override {} + virtual ~VBM3D_Final_Data() override { + if (wdef && wnode) vsapi->freeNode(wnode); + } virtual int arguments_process(const VSMap *in, VSMap *out) override; }; @@ -71,10 +73,32 @@ class VBM3D_Final_Process public: VBM3D_Final_Process(const _Mydata &_d, int _n, VSFrameContext *_frameCtx, VSCore *_core, const VSAPI *_vsapi) - : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d) - {} - - virtual ~VBM3D_Final_Process() override {} + : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d) { + if (d.wdef) { + for (int o = b_offset; o <= f_offset; ++o) + v_wref.push_back(vsapi->getFrameFilter(n + o, d.wnode, frameCtx)); + + wfi = vsapi->getFrameFormat(v_wref[cur]); + } + else { + v_wref = v_ref; + wfi = rfi; + } + + if (!skip) + for (int i = 0; i < PlaneCount; ++i) { + wref_height[i] = vsapi->getFrameHeight(v_wref[cur], i); + wref_width[i] = vsapi->getFrameWidth(v_wref[cur], i); + wref_stride[i] = vsapi->getStride(v_wref[cur], i) / wfi->bytesPerSample; + wref_pcount[i] = wref_height[i] * wref_stride[i]; + } + } + + virtual ~VBM3D_Final_Process() override { + if (d.wdef) + for (int i = 0; i < frames; ++i) + vsapi->freeFrame(v_wref[i]); + } protected: virtual void CollaborativeFilter(int plane, diff --git a/msvc/BM3D.vcxproj b/msvc/BM3D.vcxproj index 6110b4c..6ff7b76 100644 --- a/msvc/BM3D.vcxproj +++ b/msvc/BM3D.vcxproj @@ -24,32 +24,32 @@ {E93BF469-E2BA-4C9A-9E5A-C08D085FE469} Win32Proj BM3D - 10.0.17763.0 + 10.0 DynamicLibrary true - v141 + v142 Unicode DynamicLibrary false - v141 + v142 true Unicode DynamicLibrary true - v141 + v142 Unicode DynamicLibrary false - v141 + v142 true Unicode @@ -125,6 +125,7 @@ false true true + stdcpplatest true diff --git a/source/BM3D_Base.cpp b/source/BM3D_Base.cpp index 1513bdd..b3faa80 100644 --- a/source/BM3D_Base.cpp +++ b/source/BM3D_Base.cpp @@ -291,10 +291,10 @@ void BM3D_Data_Base::init_filter_data() // Functions of class BM3D_Process_Base -void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref) const +void BM3D_Process_Base::Kernel(FLType* dst, const FLType* src, const FLType* ref, const FLType* wref) const { std::thread::id threadId = std::this_thread::get_id(); - FLType *ResNum = dst, *ResDen = nullptr; + FLType* ResNum = dst, * ResDen = nullptr; if (!d.buffer0.count(threadId)) { @@ -340,26 +340,27 @@ void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref PosPairCode matchCode = BlockMatching(ref, j, i); // Get the filtered result through collaborative filtering and aggregation of matched blocks - CollaborativeFilter(0, ResNum, ResDen, src, ref, matchCode); + CollaborativeFilter(0, ResNum, ResDen, src, wref, matchCode); } } // The filtered blocks are sumed and averaged to form the final filtered image LOOP_VH(dst_height[0], dst_width[0], dst_stride[0], [&](PCType i) - { - dst[i] = ResNum[i] / ResDen[i]; - }); + { + dst[i] = ResNum[i] / ResDen[i]; + }); } -void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV, - const FLType *srcY, const FLType *srcU, const FLType *srcV, - const FLType *refY, const FLType *refU, const FLType *refV) const +void BM3D_Process_Base::Kernel(FLType* dstY, FLType* dstU, FLType* dstV, + const FLType* srcY, const FLType* srcU, const FLType* srcV, + const FLType* refY, + const FLType* wrefY, const FLType* wrefU, const FLType* wrefV) const { std::thread::id threadId = std::this_thread::get_id(); - FLType *ResNumY = dstY, *ResDenY = nullptr; - FLType *ResNumU = dstU, *ResDenU = nullptr; - FLType *ResNumV = dstV, *ResDenV = nullptr; + FLType* ResNumY = dstY, * ResDenY = nullptr; + FLType* ResNumU = dstU, * ResDenU = nullptr; + FLType* ResNumV = dstV, * ResDenV = nullptr; if (d.process[0]) { @@ -440,27 +441,27 @@ void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV, PosPairCode matchCode = BlockMatching(refY, j, i); // Get the filtered result through collaborative filtering and aggregation of matched blocks - if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, refY, matchCode); - if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, refU, matchCode); - if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, refV, matchCode); + if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, wrefY, matchCode); + if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, wrefU, matchCode); + if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, wrefV, matchCode); } } // The filtered blocks are sumed and averaged to form the final filtered image if (d.process[0]) LOOP_VH(dst_height[0], dst_width[0], dst_stride[0], [&](PCType i) - { - dstY[i] = ResNumY[i] / ResDenY[i]; - }); + { + dstY[i] = ResNumY[i] / ResDenY[i]; + }); if (d.process[1]) LOOP_VH(dst_height[1], dst_width[1], dst_stride[1], [&](PCType i) - { - dstU[i] = ResNumU[i] / ResDenU[i]; - }); + { + dstU[i] = ResNumU[i] / ResDenU[i]; + }); if (d.process[2]) LOOP_VH(dst_height[2], dst_width[2], dst_stride[2], [&](PCType i) - { - dstV[i] = ResNumV[i] / ResDenV[i]; - }); + { + dstV[i] = ResNumV[i] / ResDenV[i]; + }); } @@ -512,25 +513,32 @@ void BM3D_Process_Base::process_core() template < typename _Ty > void BM3D_Process_Base::process_core_gray() { - FLType *dstYd = nullptr, *srcYd = nullptr, *refYd = nullptr; + FLType *dstYd = nullptr, *srcYd = nullptr, *refYd = nullptr, *wrefYd = nullptr; // Get write/read pointer auto dstY = reinterpret_cast<_Ty *>(vsapi->getWritePtr(dst, 0)); auto srcY = reinterpret_cast(vsapi->getReadPtr(src, 0)); auto refY = reinterpret_cast(vsapi->getReadPtr(ref, 0)); + auto wrefY = static_cast(nullptr); + if (wref != nullptr) + wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0)); + // Allocate memory for floating point Y data AlignedMalloc(dstYd, dst_pcount[0]); AlignedMalloc(srcYd, src_pcount[0]); if (d.rdef) AlignedMalloc(refYd, ref_pcount[0]); else refYd = srcYd; + if (d.wdef) AlignedMalloc(wrefYd, wref_pcount[0]); + else wrefYd = refYd; // Convert src and ref from integer Y data to floating point Y data Int2Float(srcYd, srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false); if (d.rdef) Int2Float(refYd, refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false); + if (d.wdef) Int2Float(wrefYd, wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false); // Execute kernel - Kernel(dstYd, srcYd, refYd); + Kernel(dstYd, srcYd, refYd, wrefYd); // Convert dst from floating point Y data to integer Y data Float2Int(dstY, dstYd, dst_height[0], dst_width[0], dst_stride[0], dst_stride[0], false, full, !isFloat(_Ty)); @@ -539,6 +547,7 @@ void BM3D_Process_Base::process_core_gray() AlignedFree(dstYd); AlignedFree(srcYd); if (d.rdef) AlignedFree(refYd); + if (d.wdef) AlignedFree(wrefYd); } template <> @@ -549,8 +558,12 @@ void BM3D_Process_Base::process_core_gray() auto srcY = reinterpret_cast(vsapi->getReadPtr(src, 0)); auto refY = reinterpret_cast(vsapi->getReadPtr(ref, 0)); + auto wrefY = static_cast(nullptr); + if (wref != nullptr) + wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0)); + // Execute kernel - Kernel(dstY, srcY, refY); + Kernel(dstY, srcY, refY, wrefY); } @@ -559,7 +572,8 @@ void BM3D_Process_Base::process_core_yuv() { FLType *dstYd = nullptr, *dstUd = nullptr, *dstVd = nullptr; FLType *srcYd = nullptr, *srcUd = nullptr, *srcVd = nullptr; - FLType *refYd = nullptr, *refUd = nullptr, *refVd = nullptr; + FLType *refYd = nullptr; + FLType *wrefYd = nullptr, *wrefUd = nullptr, *wrefVd = nullptr; // Get write/read pointer auto dstY = reinterpret_cast<_Ty *>(vsapi->getWritePtr(dst, 0)); @@ -574,6 +588,16 @@ void BM3D_Process_Base::process_core_yuv() auto refU = reinterpret_cast(vsapi->getReadPtr(ref, 1)); auto refV = reinterpret_cast(vsapi->getReadPtr(ref, 2)); + auto wrefY = static_cast(nullptr); + auto wrefU = static_cast(nullptr); + auto wrefV = static_cast(nullptr); + + if (wref != nullptr) { + wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0)); + wrefU = reinterpret_cast(vsapi->getReadPtr(wref, 1)); + wrefV = reinterpret_cast(vsapi->getReadPtr(wref, 2)); + } + // Allocate memory for floating point YUV data if (d.process[0]) AlignedMalloc(dstYd, dst_pcount[0]); if (d.process[1]) AlignedMalloc(dstUd, dst_pcount[1]); @@ -584,17 +608,16 @@ void BM3D_Process_Base::process_core_yuv() if (d.process[2]) AlignedMalloc(srcVd, src_pcount[2]); if (d.rdef) - { AlignedMalloc(refYd, ref_pcount[0]); - if (d.wiener && d.process[1]) AlignedMalloc(refUd, ref_pcount[1]); - if (d.wiener && d.process[2]) AlignedMalloc(refVd, ref_pcount[2]); - } else - { refYd = srcYd; - refUd = srcUd; - refVd = srcVd; - } + + if (d.wdef && d.process[0]) + AlignedMalloc(wrefYd, wref_pcount[0]); + else + wrefYd = refYd; + if (d.wiener && d.process[1]) AlignedMalloc(wrefUd, wref_pcount[1]); + if (d.wiener && d.process[2]) AlignedMalloc(wrefVd, wref_pcount[2]); // Convert src and ref from integer YUV data to floating point YUV data if (d.process[0] || !d.rdef) Int2Float(srcYd, srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false); @@ -602,14 +625,20 @@ void BM3D_Process_Base::process_core_yuv() if (d.process[2]) Int2Float(srcVd, srcV, src_height[2], src_width[2], src_stride[2], src_stride[2], true, full, false); if (d.rdef) - { Int2Float(refYd, refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false); - if (d.wiener && d.process[1]) Int2Float(refUd, refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false); - if (d.wiener && d.process[2]) Int2Float(refVd, refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false); + + if (d.wdef) { + if (d.process[0]) Int2Float(wrefYd, wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false); + if (d.process[1]) Int2Float(wrefUd, wrefU, wref_height[1], wref_width[1], wref_stride[1], wref_stride[1], true, full, false); + if (d.process[2]) Int2Float(wrefVd, wrefV, wref_height[2], wref_width[2], wref_stride[2], wref_stride[2], true, full, false); + } + else if (d.wiener) { + if (d.process[1]) Int2Float(wrefUd, refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false); + if (d.process[2]) Int2Float(wrefVd, refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false); } // Execute kernel - Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, refUd, refVd); + Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, wrefYd, wrefUd, wrefVd); // Convert dst from floating point YUV data to integer YUV data if (d.process[0]) Float2Int(dstY, dstYd, dst_height[0], dst_width[0], dst_stride[0], dst_stride[0], false, full, !isFloat(_Ty)); @@ -625,12 +654,11 @@ void BM3D_Process_Base::process_core_yuv() if (d.process[1]) AlignedFree(srcUd); if (d.process[2]) AlignedFree(srcVd); - if (d.rdef) - { - AlignedFree(refYd); - if (d.wiener && d.process[1]) AlignedFree(refUd); - if (d.wiener && d.process[2]) AlignedFree(refVd); - } + if (d.rdef) AlignedFree(refYd); + + if (d.wdef && d.process[0]) AlignedFree(wrefYd); + if (d.wiener && d.process[1]) AlignedFree(wrefUd); + if (d.wiener && d.process[2]) AlignedFree(wrefVd); } template <> @@ -646,11 +674,19 @@ void BM3D_Process_Base::process_core_yuv() auto srcV = reinterpret_cast(vsapi->getReadPtr(src, 2)); auto refY = reinterpret_cast(vsapi->getReadPtr(ref, 0)); - auto refU = reinterpret_cast(vsapi->getReadPtr(ref, 1)); - auto refV = reinterpret_cast(vsapi->getReadPtr(ref, 2)); + + auto wrefY = static_cast(nullptr); + auto wrefU = static_cast(nullptr); + auto wrefV = static_cast(nullptr); + + if (wref != nullptr) { + wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0)); + wrefU = reinterpret_cast(vsapi->getReadPtr(wref, 1)); + wrefV = reinterpret_cast(vsapi->getReadPtr(wref, 2)); + } // Execute kernel - Kernel(dstY, dstU, dstV, srcY, srcU, srcV, refY, refU, refV); + Kernel(dstY, dstU, dstV, srcY, srcU, srcV, refY, wrefY, wrefU, wrefV); } @@ -659,7 +695,8 @@ void BM3D_Process_Base::process_core_rgb() { FLType *dstYd = nullptr, *dstUd = nullptr, *dstVd = nullptr; FLType *srcYd = nullptr, *srcUd = nullptr, *srcVd = nullptr; - FLType *refYd = nullptr, *refUd = nullptr, *refVd = nullptr; + FLType *refYd = nullptr; + FLType *wrefYd = nullptr, *wrefUd = nullptr, *wrefVd = nullptr; // Get write/read pointer auto dstR = reinterpret_cast<_Ty *>(vsapi->getWritePtr(dst, 0)); @@ -674,6 +711,16 @@ void BM3D_Process_Base::process_core_rgb() auto refG = reinterpret_cast(vsapi->getReadPtr(ref, 1)); auto refB = reinterpret_cast(vsapi->getReadPtr(ref, 2)); + auto wrefR = static_cast(nullptr); + auto wrefG = static_cast(nullptr); + auto wrefB = static_cast(nullptr); + + if (wref != nullptr) { + wrefR = reinterpret_cast(vsapi->getReadPtr(wref, 0)); + wrefG = reinterpret_cast(vsapi->getReadPtr(wref, 1)); + wrefB = reinterpret_cast(vsapi->getReadPtr(wref, 2)); + } + // Allocate memory for floating point YUV data AlignedMalloc(dstYd, dst_pcount[0]); AlignedMalloc(dstUd, dst_pcount[1]); @@ -684,17 +731,16 @@ void BM3D_Process_Base::process_core_rgb() AlignedMalloc(srcVd, src_pcount[2]); if (d.rdef) - { AlignedMalloc(refYd, ref_pcount[0]); - if (d.wiener) AlignedMalloc(refUd, ref_pcount[1]); - if (d.wiener) AlignedMalloc(refVd, ref_pcount[2]); - } else - { refYd = srcYd; - refUd = srcUd; - refVd = srcVd; - } + + if (d.wdef) + AlignedMalloc(wrefYd, wref_pcount[0]); + else + wrefYd = refYd; + if (d.wiener) AlignedMalloc(wrefUd, wref_pcount[1]); + if (d.wiener) AlignedMalloc(wrefVd, wref_pcount[2]); // Convert src and ref from RGB data to floating point YUV data RGB2FloatYUV(srcYd, srcUd, srcVd, srcR, srcG, srcB, @@ -702,23 +748,21 @@ void BM3D_Process_Base::process_core_rgb() ColorMatrix::OPP, true, false); if (d.rdef) - { - if (d.wiener) - { - RGB2FloatYUV(refYd, refUd, refVd, refR, refG, refB, - ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], - ColorMatrix::OPP, true, false); - } - else - { - RGB2FloatY(refYd, refR, refG, refB, - ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], - ColorMatrix::OPP, true, false); - } - } + RGB2FloatY(refYd, refR, refG, refB, + ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], + ColorMatrix::OPP, true, false); + + if (d.wdef) + RGB2FloatYUV(wrefYd, wrefUd, wrefVd, wrefR, wrefG, wrefB, + wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], + ColorMatrix::OPP, true, false); + else if (d.wiener) + RGB2FloatYUV(wrefYd, wrefUd, wrefVd, refR, refG, refB, + ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], + ColorMatrix::OPP, true, false); // Execute kernel - Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, refUd, refVd); + Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, wrefYd, wrefUd, wrefVd); // Convert dst from floating point YUV data to RGB data FloatYUV2RGB(dstR, dstG, dstB, dstYd, dstUd, dstVd, @@ -735,11 +779,12 @@ void BM3D_Process_Base::process_core_rgb() AlignedFree(srcVd); if (d.rdef) - { AlignedFree(refYd); - if (d.wiener) AlignedFree(refUd); - if (d.wiener) AlignedFree(refVd); - } + + if (d.wdef) + AlignedFree(wrefYd); + if (d.wiener) AlignedFree(wrefUd); + if (d.wiener) AlignedFree(wrefVd); } diff --git a/source/BM3D_Final.cpp b/source/BM3D_Final.cpp index f39f641..f3a2220 100644 --- a/source/BM3D_Final.cpp +++ b/source/BM3D_Final.cpp @@ -36,6 +36,36 @@ int BM3D_Final_Data::arguments_process(const VSMap *in, VSMap *out) return 1; } + auto error = 0; + wnode = vsapi->propGetNode(in, "wref", 0, &error); + + if (error) { + wdef = false; + wnode = rnode; + wvi = rvi; + } + else { + wdef = true; + wvi = vsapi->getVideoInfo(wnode); + + if (!isConstantFormat(wvi)) { + setError(out, "Invalid clip \"wref\", only constant format input supported"); + return 1; + } + if (wvi->format != vi->format) { + setError(out, "input clip and clip \"wref\" must be of the same format"); + return 1; + } + if (wvi->width != vi->width || wvi->height != vi->height) { + setError(out, "input clip and clip \"wref\" must be of the same width and height"); + return 1; + } + if (wvi->numFrames != vi->numFrames) { + setError(out, "input clip and clip \"wref\" must have the same number of frames"); + return 1; + } + } + // Initialize filter data for empirical Wiener filtering init_filter_data(); diff --git a/source/VBM3D_Base.cpp b/source/VBM3D_Base.cpp index 83ddc5d..800bab0 100644 --- a/source/VBM3D_Base.cpp +++ b/source/VBM3D_Base.cpp @@ -412,7 +412,7 @@ void VBM3D_Data_Base::init_filter_data() void VBM3D_Process_Base::Kernel(const std::vector &dst, - const std::vector &src, const std::vector &ref) const + const std::vector &src, const std::vector &ref, const std::vector &wref) const { std::vector ResNum(frames), ResDen(frames); @@ -455,7 +455,7 @@ void VBM3D_Process_Base::Kernel(const std::vector &dst, Pos3PairCode matchCode = BlockMatching(ref, j, i); // Get the filtered result through collaborative filtering and aggregation of matched blocks - CollaborativeFilter(0, ResNum, ResDen, src, ref, matchCode); + CollaborativeFilter(0, ResNum, ResDen, src, wref, matchCode); } } } @@ -463,7 +463,8 @@ void VBM3D_Process_Base::Kernel(const std::vector &dst, void VBM3D_Process_Base::Kernel(const std::vector &dstY, const std::vector &dstU, const std::vector &dstV, const std::vector &srcY, const std::vector &srcU, const std::vector &srcV, - const std::vector &refY, const std::vector &refU, const std::vector &refV) const + const std::vector &refY, + const std::vector &wrefY, const std::vector &wrefU, const std::vector &wrefV) const { std::vector ResNumY(frames), ResDenY(frames); std::vector ResNumU(frames), ResDenU(frames); @@ -533,9 +534,9 @@ void VBM3D_Process_Base::Kernel(const std::vector &dstY, const std::ve Pos3PairCode matchCode = BlockMatching(refY, j, i); // Get the filtered result through collaborative filtering and aggregation of matched blocks - if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, refY, matchCode); - if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, refU, matchCode); - if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, refV, matchCode); + if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, wrefY, matchCode); + if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, wrefU, matchCode); + if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, wrefV, matchCode); } } } @@ -694,8 +695,9 @@ void VBM3D_Process_Base::process_core_gray() std::vector dstYv; std::vector srcYv; std::vector refYv; + std::vector wrefYv; - std::vector srcYd(frames, nullptr), refYd(frames, nullptr); + std::vector srcYd(frames, nullptr), refYd(frames, nullptr), wrefYd(frames, nullptr); // Get write pointer auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0)) @@ -707,30 +709,39 @@ void VBM3D_Process_Base::process_core_gray() auto srcY = reinterpret_cast(vsapi->getReadPtr(v_src[i], 0)); auto refY = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 0)); + auto wrefY = static_cast(nullptr); + if (v_wref.size() > i) + wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0)); + // Allocate memory for floating point Y data AlignedMalloc(srcYd[i], src_pcount[0]); if (d.rdef) AlignedMalloc(refYd[i], ref_pcount[0]); else refYd[i] = srcYd[i]; + if (d.wdef) AlignedMalloc(wrefYd[i], wref_pcount[0]); + else wrefYd[i] = refYd[i]; // Convert src and ref from integer Y data to floating point Y data Int2Float(srcYd[i], srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false); if (d.rdef) Int2Float(refYd[i], refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false); + if (d.wdef) Int2Float(wrefYd[i], wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false); // Store pointer to floating point Y data into corresponding frame of the vector dstYv.push_back(dstY + dst_pcount[0] * (i * 2)); dstYv.push_back(dstY + dst_pcount[0] * (i * 2 + 1)); srcYv.push_back(srcYd[i]); refYv.push_back(refYd[i]); + wrefYv.push_back(wrefYd[i]); } // Execute kernel - Kernel(dstYv, srcYv, refYv); + Kernel(dstYv, srcYv, refYv, wrefYv); // Free memory for floating point Y data for (int i = 0; i < frames; ++i) { AlignedFree(srcYd[i]); if (d.rdef) AlignedFree(refYd[i]); + if (d.wdef) AlignedFree(wrefYd[i]); } } @@ -740,6 +751,7 @@ void VBM3D_Process_Base::process_core_gray() std::vector dstYv; std::vector srcYv; std::vector refYv; + std::vector wrefYv; // Get write pointer auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0)) @@ -751,15 +763,20 @@ void VBM3D_Process_Base::process_core_gray() auto srcY = reinterpret_cast(vsapi->getReadPtr(v_src[i], 0)); auto refY = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 0)); + auto wrefY = static_cast(nullptr); + if (v_wref.size() > i) + wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0)); + // Store pointer to floating point Y data into corresponding frame of the vector dstYv.push_back(dstY + dst_pcount[0] * (i * 2)); dstYv.push_back(dstY + dst_pcount[0] * (i * 2 + 1)); srcYv.push_back(srcY); refYv.push_back(refY); + wrefYv.push_back(wrefY); } // Execute kernel - Kernel(dstYv, srcYv, refYv); + Kernel(dstYv, srcYv, refYv, wrefYv); } @@ -775,11 +792,14 @@ void VBM3D_Process_Base::process_core_yuv() std::vector srcVv; std::vector refYv; - std::vector refUv; - std::vector refVv; + + std::vector wrefYv; + std::vector wrefUv; + std::vector wrefVv; std::vector srcYd(frames, nullptr), srcUd(frames, nullptr), srcVd(frames, nullptr); - std::vector refYd(frames, nullptr), refUd(frames, nullptr), refVd(frames, nullptr); + std::vector refYd(frames, nullptr); + std::vector wrefYd(frames, nullptr), wrefUd(frames, nullptr), wrefVd(frames, nullptr); // Get write pointer auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0)) @@ -800,23 +820,32 @@ void VBM3D_Process_Base::process_core_yuv() auto refU = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 1)); auto refV = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 2)); + auto wrefY = static_cast(nullptr); + auto wrefU = static_cast(nullptr); + auto wrefV = static_cast(nullptr); + + if (v_wref.size() > i) { + wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0)); + wrefU = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 1)); + wrefV = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 2)); + } + // Allocate memory for floating point YUV data if (d.process[0] || !d.rdef) AlignedMalloc(srcYd[i], src_pcount[0]); if (d.process[1]) AlignedMalloc(srcUd[i], src_pcount[1]); if (d.process[2]) AlignedMalloc(srcVd[i], src_pcount[2]); if (d.rdef) - { AlignedMalloc(refYd[i], ref_pcount[0]); - if (d.wiener && d.process[1]) AlignedMalloc(refUd[i], ref_pcount[1]); - if (d.wiener && d.process[2]) AlignedMalloc(refVd[i], ref_pcount[2]); - } else - { refYd[i] = srcYd[i]; - refUd[i] = srcUd[i]; - refVd[i] = srcVd[i]; - } + + if (d.wdef && d.process[0]) + AlignedMalloc(wrefYd[i], wref_pcount[0]); + else + wrefYd[i] = refYd[i]; + if (d.wiener && d.process[1]) AlignedMalloc(wrefUd[i], wref_pcount[1]); + if (d.wiener && d.process[2]) AlignedMalloc(wrefVd[i], wref_pcount[2]); // Convert src and ref from integer YUV data to floating point YUV data if (d.process[0] || !d.rdef) Int2Float(srcYd[i], srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false); @@ -824,10 +853,16 @@ void VBM3D_Process_Base::process_core_yuv() if (d.process[2]) Int2Float(srcVd[i], srcV, src_height[2], src_width[2], src_stride[2], src_stride[2], true, full, false); if (d.rdef) - { Int2Float(refYd[i], refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false); - if (d.wiener && d.process[1]) Int2Float(refUd[i], refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false); - if (d.wiener && d.process[2]) Int2Float(refVd[i], refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false); + + if (d.wdef) { + if (d.process[0]) Int2Float(wrefYd[i], wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false); + if (d.process[1]) Int2Float(wrefUd[i], wrefU, wref_height[1], wref_width[1], wref_stride[1], wref_stride[1], true, full, false); + if (d.process[2]) Int2Float(wrefVd[i], wrefV, wref_height[2], wref_width[2], wref_stride[2], wref_stride[2], true, full, false); + } + else if (d.wiener) { + if (d.process[1]) Int2Float(wrefUd[i], refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false); + if (d.process[2]) Int2Float(wrefVd[i], refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false); } // Store pointer to floating point YUV data into corresponding frame in the vector @@ -844,12 +879,14 @@ void VBM3D_Process_Base::process_core_yuv() srcVv.push_back(srcVd[i]); refYv.push_back(refYd[i]); - refUv.push_back(refUd[i]); - refVv.push_back(refVd[i]); + + wrefYv.push_back(wrefYd[i]); + wrefUv.push_back(wrefUd[i]); + wrefVv.push_back(wrefVd[i]); } // Execute kernel - Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, refUv, refVv); + Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, wrefYv, wrefUv, wrefVv); // Free memory for floating point YUV data for (int i = 0; i < frames; ++i) @@ -858,12 +895,11 @@ void VBM3D_Process_Base::process_core_yuv() if (d.process[1]) AlignedFree(srcUd[i]); if (d.process[2]) AlignedFree(srcVd[i]); - if (d.rdef) - { - AlignedFree(refYd[i]); - if (d.wiener && d.process[1]) AlignedFree(refUd[i]); - if (d.wiener && d.process[2]) AlignedFree(refVd[i]); - } + if (d.rdef) AlignedFree(refYd[i]); + + if (d.wdef && d.process[0]) AlignedFree(wrefYd[i]); + if (d.wiener && d.process[1]) AlignedFree(wrefUd[i]); + if (d.wiener && d.process[2]) AlignedFree(wrefVd[i]); } } @@ -879,8 +915,10 @@ void VBM3D_Process_Base::process_core_yuv() std::vector srcVv; std::vector refYv; - std::vector refUv; - std::vector refVv; + + std::vector wrefYv; + std::vector wrefUv; + std::vector wrefVv; // Get write/read pointer auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0)) @@ -898,8 +936,16 @@ void VBM3D_Process_Base::process_core_yuv() auto srcV = reinterpret_cast(vsapi->getReadPtr(v_src[i], 2)); auto refY = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 0)); - auto refU = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 1)); - auto refV = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 2)); + + auto wrefY = static_cast(nullptr); + auto wrefU = static_cast(nullptr); + auto wrefV = static_cast(nullptr); + + if (v_wref.size() > i) { + wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0)); + wrefU = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 1)); + wrefV = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 2)); + } // Store pointer to floating point YUV data into corresponding frame in the vector dstYv.push_back(dstY + dst_pcount[0] * (i * 2)); @@ -915,12 +961,14 @@ void VBM3D_Process_Base::process_core_yuv() srcVv.push_back(srcV); refYv.push_back(refY); - refUv.push_back(refU); - refVv.push_back(refV); + + wrefYv.push_back(wrefY); + wrefUv.push_back(wrefU); + wrefVv.push_back(wrefV); } // Execute kernel - Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, refUv, refVv); + Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, wrefYv, wrefUv, wrefVv); } @@ -936,11 +984,14 @@ void VBM3D_Process_Base::process_core_rgb() std::vector srcVv; std::vector refYv; - std::vector refUv; - std::vector refVv; + + std::vector wrefYv; + std::vector wrefUv; + std::vector wrefVv; std::vector srcYd(frames, nullptr), srcUd(frames, nullptr), srcVd(frames, nullptr); - std::vector refYd(frames, nullptr), refUd(frames, nullptr), refVd(frames, nullptr); + std::vector refYd(frames, nullptr); + std::vector wrefYd(frames, nullptr), wrefUd(frames, nullptr), wrefVd(frames, nullptr); // Get write pointer auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0)) @@ -961,23 +1012,32 @@ void VBM3D_Process_Base::process_core_rgb() auto refG = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 1)); auto refB = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 2)); + auto wrefR = static_cast(nullptr); + auto wrefG = static_cast(nullptr); + auto wrefB = static_cast(nullptr); + + if (v_wref.size() > i) { + wrefR = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0)); + wrefG = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 1)); + wrefB = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 2)); + } + // Allocate memory for floating point YUV data AlignedMalloc(srcYd[i], src_pcount[0]); AlignedMalloc(srcUd[i], src_pcount[1]); AlignedMalloc(srcVd[i], src_pcount[2]); if (d.rdef) - { AlignedMalloc(refYd[i], ref_pcount[0]); - if (d.wiener) AlignedMalloc(refUd[i], ref_pcount[1]); - if (d.wiener) AlignedMalloc(refVd[i], ref_pcount[2]); - } else - { refYd[i] = srcYd[i]; - refUd[i] = srcUd[i]; - refVd[i] = srcVd[i]; - } + + if (d.wdef) + AlignedMalloc(wrefYd[i], wref_pcount[0]); + else + wrefYd[i] = refYd[i]; + if (d.wiener) AlignedMalloc(wrefUd[i], wref_pcount[1]); + if (d.wiener) AlignedMalloc(wrefVd[i], wref_pcount[2]); // Convert src and ref from RGB data to floating point YUV data RGB2FloatYUV(srcYd[i], srcUd[i], srcVd[i], srcR, srcG, srcB, @@ -985,20 +1045,18 @@ void VBM3D_Process_Base::process_core_rgb() ColorMatrix::OPP, true, false); if (d.rdef) - { - if (d.wiener) - { - RGB2FloatYUV(refYd[i], refUd[i], refVd[i], refR, refG, refB, - ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], - ColorMatrix::OPP, true, false); - } - else - { - RGB2FloatY(refYd[i], refR, refG, refB, - ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], - ColorMatrix::OPP, true, false); - } - } + RGB2FloatY(refYd[i], refR, refG, refB, + ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], + ColorMatrix::OPP, true, false); + + if (d.wdef) + RGB2FloatYUV(wrefYd[i], wrefUd[i], wrefVd[i], wrefR, wrefG, wrefB, + wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], + ColorMatrix::OPP, true, false); + else if (d.wiener) + RGB2FloatYUV(wrefYd[i], wrefUd[i], wrefVd[i], refR, refG, refB, + ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], + ColorMatrix::OPP, true, false); // Store pointer to floating point YUV data into corresponding frame in the vector dstYv.push_back(dstY + dst_pcount[0] * (i * 2)); @@ -1014,12 +1072,14 @@ void VBM3D_Process_Base::process_core_rgb() srcVv.push_back(srcVd[i]); refYv.push_back(refYd[i]); - refUv.push_back(refUd[i]); - refVv.push_back(refVd[i]); + + wrefYv.push_back(wrefYd[i]); + wrefUv.push_back(wrefUd[i]); + wrefVv.push_back(wrefVd[i]); } // Execute kernel - Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, refUv, refVv); + Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, wrefYv, wrefUv, wrefVv); // Free memory for floating point YUV data for (int i = 0; i < frames; ++i) @@ -1029,11 +1089,12 @@ void VBM3D_Process_Base::process_core_rgb() AlignedFree(srcVd[i]); if (d.rdef) - { AlignedFree(refYd[i]); - if (d.wiener) AlignedFree(refUd[i]); - if (d.wiener) AlignedFree(refVd[i]); - } + + if (d.wdef) + AlignedFree(wrefYd[i]); + if (d.wiener) AlignedFree(wrefUd[i]); + if (d.wiener) AlignedFree(wrefVd[i]); } } diff --git a/source/VBM3D_Final.cpp b/source/VBM3D_Final.cpp index ce0b3e3..566c2c1 100644 --- a/source/VBM3D_Final.cpp +++ b/source/VBM3D_Final.cpp @@ -36,6 +36,36 @@ int VBM3D_Final_Data::arguments_process(const VSMap *in, VSMap *out) return 1; } + auto error = 0; + wnode = vsapi->propGetNode(in, "wref", 0, &error); + + if (error) { + wdef = false; + wnode = rnode; + wvi = rvi; + } + else { + wdef = true; + wvi = vsapi->getVideoInfo(wnode); + + if (!isConstantFormat(wvi)) { + setError(out, "Invalid clip \"wref\", only constant format input supported"); + return 1; + } + if (wvi->format != vi->format) { + setError(out, "input clip and clip \"wref\" must be of the same format"); + return 1; + } + if (wvi->width != vi->width || wvi->height != vi->height) { + setError(out, "input clip and clip \"wref\" must be of the same width and height"); + return 1; + } + if (wvi->numFrames != vi->numFrames) { + setError(out, "input clip and clip \"wref\" must have the same number of frames"); + return 1; + } + } + // Initialize filter data for empirical Wiener filtering init_filter_data(); diff --git a/source/VSPlugin.cpp b/source/VSPlugin.cpp index 8183023..30d5c7d 100644 --- a/source/VSPlugin.cpp +++ b/source/VSPlugin.cpp @@ -213,6 +213,7 @@ static const VSFrameRef *VS_CC BM3D_Final_GetFrame(int n, int activationReason, { vsapi->requestFrameFilter(n, d->node, frameCtx); if (d->rdef) vsapi->requestFrameFilter(n, d->rnode, frameCtx); + if (d->wdef) vsapi->requestFrameFilter(n, d->wnode, frameCtx); } else if (activationReason == arAllFramesReady) { @@ -342,6 +343,7 @@ static const VSFrameRef *VS_CC VBM3D_Final_GetFrame(int n, int activationReason, { vsapi->requestFrameFilter(n + o, d->node, frameCtx); if (d->rdef) vsapi->requestFrameFilter(n + o, d->rnode, frameCtx); + if (d->wdef) vsapi->requestFrameFilter(n + o, d->wnode, frameCtx); } } else if (activationReason == arAllFramesReady) @@ -479,6 +481,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegiste registerFunc("Final", "input:clip;" "ref:clip;" + "wref:clip:opt;" "profile:data:opt;" "sigma:float[]:opt;" "block_size:int:opt;" @@ -512,6 +515,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegiste registerFunc("VFinal", "input:clip;" "ref:clip;" + "wref:clip:opt;" "profile:data:opt;" "sigma:float[]:opt;" "radius:int:opt;"