diff --git a/README.md b/README.md
index 2cef750..1417c0d 100644
--- a/README.md
+++ b/README.md
@@ -183,7 +183,7 @@ The input clip is processed in 3-step stages, For each reference block:
This final estimate can be realized as a refinement. It can significantly improve the denoising quality, keeping more details and fine structures that were removed in basic estimate.
```python
-bm3d.Final(clip input, clip ref[, string profile="fast", float[] sigma=[10,10,10], int block_size, int block_step, int group_size, int bm_range, int bm_step, float th_mse, int matrix=2])
+bm3d.Final(clip input, clip ref[, clip wref=ref, string profile="fast", float[] sigma=[10,10,10], int block_size, int block_step, int group_size, int bm_range, int bm_step, float th_mse, int matrix=2])
```
- input:
@@ -195,6 +195,21 @@ bm3d.Final(clip input, clip ref[, string profile="fast", float[] sigma=[10,10,10
It must be specified. In original BM3D algorithm, it is the basic estimate.
Alternatively, you can choose any other decent denoising filter as basic estimate, and take this final estimate as a refinement.
+- wref:
+ The reference clip for empirical Wiener filtering. If specified, wref will replace ref as the empirical estimate for Wiener filtering.
+ You should specify this parameter if input and ref are sampled from different domains, e.g. input might be a difference image while ref could be a natural image.
+ In such case you should assign something sampled from the input’s domain to wref.
+ A common use case of wref would be to refine the denoising result of an alternative denoiser using BM3D.
+
+```python
+flt = some_filter(src)
+dif = core.std.MakeDiff(src, flt)
+ref_dif = core.bm3d.Basic(dif, flt)
+ref = core.std.MergeDiff(flt, ref_dif)
+dif = core.bm3d.Final(dif, ref, wref=ref_dif)
+flt = core.std.MergeDiff(flt, dif)
+```
+
- profile, sigma, block_size, block_step, group_size, bm_range, bm_step, th_mse, matrix:
Same as those in bm3d.Basic.
@@ -252,10 +267,10 @@ bm3d.VBasic(clip input[, clip ref=input, string profile="fast", float[] sigma=[1
#### final estimate of V-BM3D denoising filter
```python
-bm3d.VFinal(clip input, clip ref[, string profile="fast", float[] sigma=[10,10,10], int radius, int block_size, int block_step, int group_size, int bm_range, int bm_step, int ps_num, int ps_range, int ps_step, float th_mse, int matrix=2])
+bm3d.VFinal(clip input, clip ref[, clip wref=ref, string profile="fast", float[] sigma=[10,10,10], int radius, int block_size, int block_step, int group_size, int bm_range, int bm_step, int ps_num, int ps_range, int ps_step, float th_mse, int matrix=2])
```
-- input, ref:
+- input, ref, wref:
Same as those in bm3d.Final.
- profile, sigma, block_size, block_step, group_size, bm_range, bm_step, th_mse, matrix:
diff --git a/include/BM3D_Base.h b/include/BM3D_Base.h
index 219d1c2..54738ed 100644
--- a/include/BM3D_Base.h
+++ b/include/BM3D_Base.h
@@ -47,6 +47,10 @@ class BM3D_Data_Base
VSNodeRef *rnode = nullptr;
const VSVideoInfo *rvi = nullptr;
+ bool wdef = false;
+ VSNodeRef* wnode = nullptr;
+ const VSVideoInfo* wvi = nullptr;
+
bool wiener;
ColorMatrix matrix;
@@ -126,11 +130,19 @@ class BM3D_Process_Base
const VSFrameRef *ref = nullptr;
const VSFormat *rfi = nullptr;
+ const VSFrameRef* wref = nullptr;
+ const VSFormat* wfi = nullptr;
+
PCType ref_height[VSMaxPlaneCount];
PCType ref_width[VSMaxPlaneCount];
PCType ref_stride[VSMaxPlaneCount];
PCType ref_pcount[VSMaxPlaneCount];
+ PCType wref_height[VSMaxPlaneCount];
+ PCType wref_width[VSMaxPlaneCount];
+ PCType wref_stride[VSMaxPlaneCount];
+ PCType wref_pcount[VSMaxPlaneCount];
+
bool full = true;
private:
@@ -220,11 +232,12 @@ class BM3D_Process_Base
_NewFrame(width, height, dfi == fi);
}
- void Kernel(FLType *dst, const FLType *src, const FLType *ref) const;
+ void Kernel(FLType *dst, const FLType *src, const FLType *ref, const FLType *wref) const;
void Kernel(FLType *dstY, FLType *dstU, FLType *dstV,
const FLType *srcY, const FLType *srcU, const FLType *srcV,
- const FLType *refY, const FLType *refU, const FLType *refV) const;
+ const FLType *refY,
+ const FLType *wrefY, const FLType *wrefU, const FLType *wrefV) const;
PosPairCode BlockMatching(const FLType *ref, PCType j, PCType i) const;
diff --git a/include/BM3D_Final.h b/include/BM3D_Final.h
index 2debbf2..f11de35 100644
--- a/include/BM3D_Final.h
+++ b/include/BM3D_Final.h
@@ -49,7 +49,9 @@ class BM3D_Final_Data
_Myt &operator=(const _Myt &right) = delete;
_Myt &operator=(_Myt &&right) = delete;
- virtual ~BM3D_Final_Data() override {}
+ virtual ~BM3D_Final_Data() override {
+ if (wdef && wnode) vsapi->freeNode(wnode);
+ }
virtual int arguments_process(const VSMap *in, VSMap *out) override;
};
@@ -70,11 +72,29 @@ class BM3D_Final_Process
const _Mydata &d;
public:
- BM3D_Final_Process(_Mydata &_d, int _n, VSFrameContext *_frameCtx, VSCore *_core, const VSAPI *_vsapi)
- : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d)
- {}
-
- virtual ~BM3D_Final_Process() override {}
+ BM3D_Final_Process(_Mydata& _d, int _n, VSFrameContext* _frameCtx, VSCore* _core, const VSAPI* _vsapi)
+ : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d) {
+ if (d.wdef) {
+ wref = vsapi->getFrameFilter(n, d.wnode, frameCtx);
+ wfi = vsapi->getFrameFormat(wref);
+ }
+ else {
+ wref = ref;
+ wfi = rfi;
+ }
+
+ if (!skip)
+ for (int i = 0; i < PlaneCount; ++i) {
+ wref_height[i] = vsapi->getFrameHeight(wref, i);
+ wref_width[i] = vsapi->getFrameWidth(wref, i);
+ wref_stride[i] = vsapi->getStride(wref, i) / wfi->bytesPerSample;
+ wref_pcount[i] = wref_height[i] * wref_stride[i];
+ }
+ }
+
+ virtual ~BM3D_Final_Process() override {
+ if (d.wdef) vsapi->freeFrame(wref);
+ }
protected:
virtual void CollaborativeFilter(int plane,
diff --git a/include/VBM3D_Base.h b/include/VBM3D_Base.h
index 831c3f4..000bee0 100644
--- a/include/VBM3D_Base.h
+++ b/include/VBM3D_Base.h
@@ -63,6 +63,10 @@ class VBM3D_Data_Base
VSNodeRef *rnode = nullptr;
const VSVideoInfo *rvi = nullptr;
+ bool wdef = false;
+ VSNodeRef* wnode = nullptr;
+ const VSVideoInfo* wvi = nullptr;
+
bool wiener;
ColorMatrix matrix;
@@ -135,8 +139,10 @@ class VBM3D_Process_Base
std::vector v_src;
std::vector v_ref;
-
+ std::vector v_wref;
+
const VSFormat *rfi = nullptr;
+ const VSFormat* wfi = nullptr;
PCType dst_height[VSMaxPlaneCount];
PCType dst_pcount[VSMaxPlaneCount];
@@ -146,6 +152,11 @@ class VBM3D_Process_Base
PCType ref_stride[VSMaxPlaneCount];
PCType ref_pcount[VSMaxPlaneCount];
+ PCType wref_height[VSMaxPlaneCount];
+ PCType wref_width[VSMaxPlaneCount];
+ PCType wref_stride[VSMaxPlaneCount];
+ PCType wref_pcount[VSMaxPlaneCount];
+
bool full = true;
private:
@@ -310,11 +321,12 @@ class VBM3D_Process_Base
vsapi->propSetIntArray(dst_map, "BM3D_V_process", process, VSMaxPlaneCount);
}
- void Kernel(const std::vector &dst, const std::vector &src, const std::vector &ref) const;
+ void Kernel(const std::vector &dst, const std::vector &src, const std::vector &ref, const std::vector &wref) const;
void Kernel(const std::vector &dstY, const std::vector &dstU, const std::vector &dstV,
const std::vector &srcY, const std::vector &srcU, const std::vector &srcV,
- const std::vector &refY, const std::vector &refU, const std::vector &refV) const;
+ const std::vector &refY,
+ const std::vector &wrefY, const std::vector &wrefU, const std::vector &wrefV) const;
Pos3PairCode BlockMatching(const std::vector &ref, PCType j, PCType i) const;
diff --git a/include/VBM3D_Final.h b/include/VBM3D_Final.h
index 9d61424..74b5c56 100644
--- a/include/VBM3D_Final.h
+++ b/include/VBM3D_Final.h
@@ -49,7 +49,9 @@ class VBM3D_Final_Data
_Myt &operator=(const _Myt &right) = delete;
_Myt &operator=(_Myt &&right) = delete;
- virtual ~VBM3D_Final_Data() override {}
+ virtual ~VBM3D_Final_Data() override {
+ if (wdef && wnode) vsapi->freeNode(wnode);
+ }
virtual int arguments_process(const VSMap *in, VSMap *out) override;
};
@@ -71,10 +73,32 @@ class VBM3D_Final_Process
public:
VBM3D_Final_Process(const _Mydata &_d, int _n, VSFrameContext *_frameCtx, VSCore *_core, const VSAPI *_vsapi)
- : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d)
- {}
-
- virtual ~VBM3D_Final_Process() override {}
+ : _Mybase(_d, _n, _frameCtx, _core, _vsapi), d(_d) {
+ if (d.wdef) {
+ for (int o = b_offset; o <= f_offset; ++o)
+ v_wref.push_back(vsapi->getFrameFilter(n + o, d.wnode, frameCtx));
+
+ wfi = vsapi->getFrameFormat(v_wref[cur]);
+ }
+ else {
+ v_wref = v_ref;
+ wfi = rfi;
+ }
+
+ if (!skip)
+ for (int i = 0; i < PlaneCount; ++i) {
+ wref_height[i] = vsapi->getFrameHeight(v_wref[cur], i);
+ wref_width[i] = vsapi->getFrameWidth(v_wref[cur], i);
+ wref_stride[i] = vsapi->getStride(v_wref[cur], i) / wfi->bytesPerSample;
+ wref_pcount[i] = wref_height[i] * wref_stride[i];
+ }
+ }
+
+ virtual ~VBM3D_Final_Process() override {
+ if (d.wdef)
+ for (int i = 0; i < frames; ++i)
+ vsapi->freeFrame(v_wref[i]);
+ }
protected:
virtual void CollaborativeFilter(int plane,
diff --git a/msvc/BM3D.vcxproj b/msvc/BM3D.vcxproj
index 6110b4c..6ff7b76 100644
--- a/msvc/BM3D.vcxproj
+++ b/msvc/BM3D.vcxproj
@@ -24,32 +24,32 @@
{E93BF469-E2BA-4C9A-9E5A-C08D085FE469}
Win32Proj
BM3D
- 10.0.17763.0
+ 10.0
DynamicLibrary
true
- v141
+ v142
Unicode
DynamicLibrary
false
- v141
+ v142
true
Unicode
DynamicLibrary
true
- v141
+ v142
Unicode
DynamicLibrary
false
- v141
+ v142
true
Unicode
@@ -125,6 +125,7 @@
false
true
true
+ stdcpplatest
true
diff --git a/source/BM3D_Base.cpp b/source/BM3D_Base.cpp
index 1513bdd..b3faa80 100644
--- a/source/BM3D_Base.cpp
+++ b/source/BM3D_Base.cpp
@@ -291,10 +291,10 @@ void BM3D_Data_Base::init_filter_data()
// Functions of class BM3D_Process_Base
-void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref) const
+void BM3D_Process_Base::Kernel(FLType* dst, const FLType* src, const FLType* ref, const FLType* wref) const
{
std::thread::id threadId = std::this_thread::get_id();
- FLType *ResNum = dst, *ResDen = nullptr;
+ FLType* ResNum = dst, * ResDen = nullptr;
if (!d.buffer0.count(threadId))
{
@@ -340,26 +340,27 @@ void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref
PosPairCode matchCode = BlockMatching(ref, j, i);
// Get the filtered result through collaborative filtering and aggregation of matched blocks
- CollaborativeFilter(0, ResNum, ResDen, src, ref, matchCode);
+ CollaborativeFilter(0, ResNum, ResDen, src, wref, matchCode);
}
}
// The filtered blocks are sumed and averaged to form the final filtered image
LOOP_VH(dst_height[0], dst_width[0], dst_stride[0], [&](PCType i)
- {
- dst[i] = ResNum[i] / ResDen[i];
- });
+ {
+ dst[i] = ResNum[i] / ResDen[i];
+ });
}
-void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV,
- const FLType *srcY, const FLType *srcU, const FLType *srcV,
- const FLType *refY, const FLType *refU, const FLType *refV) const
+void BM3D_Process_Base::Kernel(FLType* dstY, FLType* dstU, FLType* dstV,
+ const FLType* srcY, const FLType* srcU, const FLType* srcV,
+ const FLType* refY,
+ const FLType* wrefY, const FLType* wrefU, const FLType* wrefV) const
{
std::thread::id threadId = std::this_thread::get_id();
- FLType *ResNumY = dstY, *ResDenY = nullptr;
- FLType *ResNumU = dstU, *ResDenU = nullptr;
- FLType *ResNumV = dstV, *ResDenV = nullptr;
+ FLType* ResNumY = dstY, * ResDenY = nullptr;
+ FLType* ResNumU = dstU, * ResDenU = nullptr;
+ FLType* ResNumV = dstV, * ResDenV = nullptr;
if (d.process[0])
{
@@ -440,27 +441,27 @@ void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV,
PosPairCode matchCode = BlockMatching(refY, j, i);
// Get the filtered result through collaborative filtering and aggregation of matched blocks
- if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, refY, matchCode);
- if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, refU, matchCode);
- if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, refV, matchCode);
+ if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, wrefY, matchCode);
+ if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, wrefU, matchCode);
+ if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, wrefV, matchCode);
}
}
// The filtered blocks are sumed and averaged to form the final filtered image
if (d.process[0]) LOOP_VH(dst_height[0], dst_width[0], dst_stride[0], [&](PCType i)
- {
- dstY[i] = ResNumY[i] / ResDenY[i];
- });
+ {
+ dstY[i] = ResNumY[i] / ResDenY[i];
+ });
if (d.process[1]) LOOP_VH(dst_height[1], dst_width[1], dst_stride[1], [&](PCType i)
- {
- dstU[i] = ResNumU[i] / ResDenU[i];
- });
+ {
+ dstU[i] = ResNumU[i] / ResDenU[i];
+ });
if (d.process[2]) LOOP_VH(dst_height[2], dst_width[2], dst_stride[2], [&](PCType i)
- {
- dstV[i] = ResNumV[i] / ResDenV[i];
- });
+ {
+ dstV[i] = ResNumV[i] / ResDenV[i];
+ });
}
@@ -512,25 +513,32 @@ void BM3D_Process_Base::process_core()
template < typename _Ty >
void BM3D_Process_Base::process_core_gray()
{
- FLType *dstYd = nullptr, *srcYd = nullptr, *refYd = nullptr;
+ FLType *dstYd = nullptr, *srcYd = nullptr, *refYd = nullptr, *wrefYd = nullptr;
// Get write/read pointer
auto dstY = reinterpret_cast<_Ty *>(vsapi->getWritePtr(dst, 0));
auto srcY = reinterpret_cast(vsapi->getReadPtr(src, 0));
auto refY = reinterpret_cast(vsapi->getReadPtr(ref, 0));
+ auto wrefY = static_cast(nullptr);
+ if (wref != nullptr)
+ wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0));
+
// Allocate memory for floating point Y data
AlignedMalloc(dstYd, dst_pcount[0]);
AlignedMalloc(srcYd, src_pcount[0]);
if (d.rdef) AlignedMalloc(refYd, ref_pcount[0]);
else refYd = srcYd;
+ if (d.wdef) AlignedMalloc(wrefYd, wref_pcount[0]);
+ else wrefYd = refYd;
// Convert src and ref from integer Y data to floating point Y data
Int2Float(srcYd, srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false);
if (d.rdef) Int2Float(refYd, refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false);
+ if (d.wdef) Int2Float(wrefYd, wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false);
// Execute kernel
- Kernel(dstYd, srcYd, refYd);
+ Kernel(dstYd, srcYd, refYd, wrefYd);
// Convert dst from floating point Y data to integer Y data
Float2Int(dstY, dstYd, dst_height[0], dst_width[0], dst_stride[0], dst_stride[0], false, full, !isFloat(_Ty));
@@ -539,6 +547,7 @@ void BM3D_Process_Base::process_core_gray()
AlignedFree(dstYd);
AlignedFree(srcYd);
if (d.rdef) AlignedFree(refYd);
+ if (d.wdef) AlignedFree(wrefYd);
}
template <>
@@ -549,8 +558,12 @@ void BM3D_Process_Base::process_core_gray()
auto srcY = reinterpret_cast(vsapi->getReadPtr(src, 0));
auto refY = reinterpret_cast(vsapi->getReadPtr(ref, 0));
+ auto wrefY = static_cast(nullptr);
+ if (wref != nullptr)
+ wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0));
+
// Execute kernel
- Kernel(dstY, srcY, refY);
+ Kernel(dstY, srcY, refY, wrefY);
}
@@ -559,7 +572,8 @@ void BM3D_Process_Base::process_core_yuv()
{
FLType *dstYd = nullptr, *dstUd = nullptr, *dstVd = nullptr;
FLType *srcYd = nullptr, *srcUd = nullptr, *srcVd = nullptr;
- FLType *refYd = nullptr, *refUd = nullptr, *refVd = nullptr;
+ FLType *refYd = nullptr;
+ FLType *wrefYd = nullptr, *wrefUd = nullptr, *wrefVd = nullptr;
// Get write/read pointer
auto dstY = reinterpret_cast<_Ty *>(vsapi->getWritePtr(dst, 0));
@@ -574,6 +588,16 @@ void BM3D_Process_Base::process_core_yuv()
auto refU = reinterpret_cast(vsapi->getReadPtr(ref, 1));
auto refV = reinterpret_cast(vsapi->getReadPtr(ref, 2));
+ auto wrefY = static_cast(nullptr);
+ auto wrefU = static_cast(nullptr);
+ auto wrefV = static_cast(nullptr);
+
+ if (wref != nullptr) {
+ wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0));
+ wrefU = reinterpret_cast(vsapi->getReadPtr(wref, 1));
+ wrefV = reinterpret_cast(vsapi->getReadPtr(wref, 2));
+ }
+
// Allocate memory for floating point YUV data
if (d.process[0]) AlignedMalloc(dstYd, dst_pcount[0]);
if (d.process[1]) AlignedMalloc(dstUd, dst_pcount[1]);
@@ -584,17 +608,16 @@ void BM3D_Process_Base::process_core_yuv()
if (d.process[2]) AlignedMalloc(srcVd, src_pcount[2]);
if (d.rdef)
- {
AlignedMalloc(refYd, ref_pcount[0]);
- if (d.wiener && d.process[1]) AlignedMalloc(refUd, ref_pcount[1]);
- if (d.wiener && d.process[2]) AlignedMalloc(refVd, ref_pcount[2]);
- }
else
- {
refYd = srcYd;
- refUd = srcUd;
- refVd = srcVd;
- }
+
+ if (d.wdef && d.process[0])
+ AlignedMalloc(wrefYd, wref_pcount[0]);
+ else
+ wrefYd = refYd;
+ if (d.wiener && d.process[1]) AlignedMalloc(wrefUd, wref_pcount[1]);
+ if (d.wiener && d.process[2]) AlignedMalloc(wrefVd, wref_pcount[2]);
// Convert src and ref from integer YUV data to floating point YUV data
if (d.process[0] || !d.rdef) Int2Float(srcYd, srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false);
@@ -602,14 +625,20 @@ void BM3D_Process_Base::process_core_yuv()
if (d.process[2]) Int2Float(srcVd, srcV, src_height[2], src_width[2], src_stride[2], src_stride[2], true, full, false);
if (d.rdef)
- {
Int2Float(refYd, refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false);
- if (d.wiener && d.process[1]) Int2Float(refUd, refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false);
- if (d.wiener && d.process[2]) Int2Float(refVd, refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false);
+
+ if (d.wdef) {
+ if (d.process[0]) Int2Float(wrefYd, wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false);
+ if (d.process[1]) Int2Float(wrefUd, wrefU, wref_height[1], wref_width[1], wref_stride[1], wref_stride[1], true, full, false);
+ if (d.process[2]) Int2Float(wrefVd, wrefV, wref_height[2], wref_width[2], wref_stride[2], wref_stride[2], true, full, false);
+ }
+ else if (d.wiener) {
+ if (d.process[1]) Int2Float(wrefUd, refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false);
+ if (d.process[2]) Int2Float(wrefVd, refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false);
}
// Execute kernel
- Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, refUd, refVd);
+ Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, wrefYd, wrefUd, wrefVd);
// Convert dst from floating point YUV data to integer YUV data
if (d.process[0]) Float2Int(dstY, dstYd, dst_height[0], dst_width[0], dst_stride[0], dst_stride[0], false, full, !isFloat(_Ty));
@@ -625,12 +654,11 @@ void BM3D_Process_Base::process_core_yuv()
if (d.process[1]) AlignedFree(srcUd);
if (d.process[2]) AlignedFree(srcVd);
- if (d.rdef)
- {
- AlignedFree(refYd);
- if (d.wiener && d.process[1]) AlignedFree(refUd);
- if (d.wiener && d.process[2]) AlignedFree(refVd);
- }
+ if (d.rdef) AlignedFree(refYd);
+
+ if (d.wdef && d.process[0]) AlignedFree(wrefYd);
+ if (d.wiener && d.process[1]) AlignedFree(wrefUd);
+ if (d.wiener && d.process[2]) AlignedFree(wrefVd);
}
template <>
@@ -646,11 +674,19 @@ void BM3D_Process_Base::process_core_yuv()
auto srcV = reinterpret_cast(vsapi->getReadPtr(src, 2));
auto refY = reinterpret_cast(vsapi->getReadPtr(ref, 0));
- auto refU = reinterpret_cast(vsapi->getReadPtr(ref, 1));
- auto refV = reinterpret_cast(vsapi->getReadPtr(ref, 2));
+
+ auto wrefY = static_cast(nullptr);
+ auto wrefU = static_cast(nullptr);
+ auto wrefV = static_cast(nullptr);
+
+ if (wref != nullptr) {
+ wrefY = reinterpret_cast(vsapi->getReadPtr(wref, 0));
+ wrefU = reinterpret_cast(vsapi->getReadPtr(wref, 1));
+ wrefV = reinterpret_cast(vsapi->getReadPtr(wref, 2));
+ }
// Execute kernel
- Kernel(dstY, dstU, dstV, srcY, srcU, srcV, refY, refU, refV);
+ Kernel(dstY, dstU, dstV, srcY, srcU, srcV, refY, wrefY, wrefU, wrefV);
}
@@ -659,7 +695,8 @@ void BM3D_Process_Base::process_core_rgb()
{
FLType *dstYd = nullptr, *dstUd = nullptr, *dstVd = nullptr;
FLType *srcYd = nullptr, *srcUd = nullptr, *srcVd = nullptr;
- FLType *refYd = nullptr, *refUd = nullptr, *refVd = nullptr;
+ FLType *refYd = nullptr;
+ FLType *wrefYd = nullptr, *wrefUd = nullptr, *wrefVd = nullptr;
// Get write/read pointer
auto dstR = reinterpret_cast<_Ty *>(vsapi->getWritePtr(dst, 0));
@@ -674,6 +711,16 @@ void BM3D_Process_Base::process_core_rgb()
auto refG = reinterpret_cast(vsapi->getReadPtr(ref, 1));
auto refB = reinterpret_cast(vsapi->getReadPtr(ref, 2));
+ auto wrefR = static_cast(nullptr);
+ auto wrefG = static_cast(nullptr);
+ auto wrefB = static_cast(nullptr);
+
+ if (wref != nullptr) {
+ wrefR = reinterpret_cast(vsapi->getReadPtr(wref, 0));
+ wrefG = reinterpret_cast(vsapi->getReadPtr(wref, 1));
+ wrefB = reinterpret_cast(vsapi->getReadPtr(wref, 2));
+ }
+
// Allocate memory for floating point YUV data
AlignedMalloc(dstYd, dst_pcount[0]);
AlignedMalloc(dstUd, dst_pcount[1]);
@@ -684,17 +731,16 @@ void BM3D_Process_Base::process_core_rgb()
AlignedMalloc(srcVd, src_pcount[2]);
if (d.rdef)
- {
AlignedMalloc(refYd, ref_pcount[0]);
- if (d.wiener) AlignedMalloc(refUd, ref_pcount[1]);
- if (d.wiener) AlignedMalloc(refVd, ref_pcount[2]);
- }
else
- {
refYd = srcYd;
- refUd = srcUd;
- refVd = srcVd;
- }
+
+ if (d.wdef)
+ AlignedMalloc(wrefYd, wref_pcount[0]);
+ else
+ wrefYd = refYd;
+ if (d.wiener) AlignedMalloc(wrefUd, wref_pcount[1]);
+ if (d.wiener) AlignedMalloc(wrefVd, wref_pcount[2]);
// Convert src and ref from RGB data to floating point YUV data
RGB2FloatYUV(srcYd, srcUd, srcVd, srcR, srcG, srcB,
@@ -702,23 +748,21 @@ void BM3D_Process_Base::process_core_rgb()
ColorMatrix::OPP, true, false);
if (d.rdef)
- {
- if (d.wiener)
- {
- RGB2FloatYUV(refYd, refUd, refVd, refR, refG, refB,
- ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
- ColorMatrix::OPP, true, false);
- }
- else
- {
- RGB2FloatY(refYd, refR, refG, refB,
- ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
- ColorMatrix::OPP, true, false);
- }
- }
+ RGB2FloatY(refYd, refR, refG, refB,
+ ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
+ ColorMatrix::OPP, true, false);
+
+ if (d.wdef)
+ RGB2FloatYUV(wrefYd, wrefUd, wrefVd, wrefR, wrefG, wrefB,
+ wref_height[0], wref_width[0], wref_stride[0], wref_stride[0],
+ ColorMatrix::OPP, true, false);
+ else if (d.wiener)
+ RGB2FloatYUV(wrefYd, wrefUd, wrefVd, refR, refG, refB,
+ ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
+ ColorMatrix::OPP, true, false);
// Execute kernel
- Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, refUd, refVd);
+ Kernel(dstYd, dstUd, dstVd, srcYd, srcUd, srcVd, refYd, wrefYd, wrefUd, wrefVd);
// Convert dst from floating point YUV data to RGB data
FloatYUV2RGB(dstR, dstG, dstB, dstYd, dstUd, dstVd,
@@ -735,11 +779,12 @@ void BM3D_Process_Base::process_core_rgb()
AlignedFree(srcVd);
if (d.rdef)
- {
AlignedFree(refYd);
- if (d.wiener) AlignedFree(refUd);
- if (d.wiener) AlignedFree(refVd);
- }
+
+ if (d.wdef)
+ AlignedFree(wrefYd);
+ if (d.wiener) AlignedFree(wrefUd);
+ if (d.wiener) AlignedFree(wrefVd);
}
diff --git a/source/BM3D_Final.cpp b/source/BM3D_Final.cpp
index f39f641..f3a2220 100644
--- a/source/BM3D_Final.cpp
+++ b/source/BM3D_Final.cpp
@@ -36,6 +36,36 @@ int BM3D_Final_Data::arguments_process(const VSMap *in, VSMap *out)
return 1;
}
+ auto error = 0;
+ wnode = vsapi->propGetNode(in, "wref", 0, &error);
+
+ if (error) {
+ wdef = false;
+ wnode = rnode;
+ wvi = rvi;
+ }
+ else {
+ wdef = true;
+ wvi = vsapi->getVideoInfo(wnode);
+
+ if (!isConstantFormat(wvi)) {
+ setError(out, "Invalid clip \"wref\", only constant format input supported");
+ return 1;
+ }
+ if (wvi->format != vi->format) {
+ setError(out, "input clip and clip \"wref\" must be of the same format");
+ return 1;
+ }
+ if (wvi->width != vi->width || wvi->height != vi->height) {
+ setError(out, "input clip and clip \"wref\" must be of the same width and height");
+ return 1;
+ }
+ if (wvi->numFrames != vi->numFrames) {
+ setError(out, "input clip and clip \"wref\" must have the same number of frames");
+ return 1;
+ }
+ }
+
// Initialize filter data for empirical Wiener filtering
init_filter_data();
diff --git a/source/VBM3D_Base.cpp b/source/VBM3D_Base.cpp
index 83ddc5d..800bab0 100644
--- a/source/VBM3D_Base.cpp
+++ b/source/VBM3D_Base.cpp
@@ -412,7 +412,7 @@ void VBM3D_Data_Base::init_filter_data()
void VBM3D_Process_Base::Kernel(const std::vector &dst,
- const std::vector &src, const std::vector &ref) const
+ const std::vector &src, const std::vector &ref, const std::vector &wref) const
{
std::vector ResNum(frames), ResDen(frames);
@@ -455,7 +455,7 @@ void VBM3D_Process_Base::Kernel(const std::vector &dst,
Pos3PairCode matchCode = BlockMatching(ref, j, i);
// Get the filtered result through collaborative filtering and aggregation of matched blocks
- CollaborativeFilter(0, ResNum, ResDen, src, ref, matchCode);
+ CollaborativeFilter(0, ResNum, ResDen, src, wref, matchCode);
}
}
}
@@ -463,7 +463,8 @@ void VBM3D_Process_Base::Kernel(const std::vector &dst,
void VBM3D_Process_Base::Kernel(const std::vector &dstY, const std::vector &dstU, const std::vector &dstV,
const std::vector &srcY, const std::vector &srcU, const std::vector &srcV,
- const std::vector &refY, const std::vector &refU, const std::vector &refV) const
+ const std::vector &refY,
+ const std::vector &wrefY, const std::vector &wrefU, const std::vector &wrefV) const
{
std::vector ResNumY(frames), ResDenY(frames);
std::vector ResNumU(frames), ResDenU(frames);
@@ -533,9 +534,9 @@ void VBM3D_Process_Base::Kernel(const std::vector &dstY, const std::ve
Pos3PairCode matchCode = BlockMatching(refY, j, i);
// Get the filtered result through collaborative filtering and aggregation of matched blocks
- if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, refY, matchCode);
- if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, refU, matchCode);
- if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, refV, matchCode);
+ if (d.process[0]) CollaborativeFilter(0, ResNumY, ResDenY, srcY, wrefY, matchCode);
+ if (d.process[1]) CollaborativeFilter(1, ResNumU, ResDenU, srcU, wrefU, matchCode);
+ if (d.process[2]) CollaborativeFilter(2, ResNumV, ResDenV, srcV, wrefV, matchCode);
}
}
}
@@ -694,8 +695,9 @@ void VBM3D_Process_Base::process_core_gray()
std::vector dstYv;
std::vector srcYv;
std::vector refYv;
+ std::vector wrefYv;
- std::vector srcYd(frames, nullptr), refYd(frames, nullptr);
+ std::vector srcYd(frames, nullptr), refYd(frames, nullptr), wrefYd(frames, nullptr);
// Get write pointer
auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0))
@@ -707,30 +709,39 @@ void VBM3D_Process_Base::process_core_gray()
auto srcY = reinterpret_cast(vsapi->getReadPtr(v_src[i], 0));
auto refY = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 0));
+ auto wrefY = static_cast(nullptr);
+ if (v_wref.size() > i)
+ wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0));
+
// Allocate memory for floating point Y data
AlignedMalloc(srcYd[i], src_pcount[0]);
if (d.rdef) AlignedMalloc(refYd[i], ref_pcount[0]);
else refYd[i] = srcYd[i];
+ if (d.wdef) AlignedMalloc(wrefYd[i], wref_pcount[0]);
+ else wrefYd[i] = refYd[i];
// Convert src and ref from integer Y data to floating point Y data
Int2Float(srcYd[i], srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false);
if (d.rdef) Int2Float(refYd[i], refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false);
+ if (d.wdef) Int2Float(wrefYd[i], wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false);
// Store pointer to floating point Y data into corresponding frame of the vector
dstYv.push_back(dstY + dst_pcount[0] * (i * 2));
dstYv.push_back(dstY + dst_pcount[0] * (i * 2 + 1));
srcYv.push_back(srcYd[i]);
refYv.push_back(refYd[i]);
+ wrefYv.push_back(wrefYd[i]);
}
// Execute kernel
- Kernel(dstYv, srcYv, refYv);
+ Kernel(dstYv, srcYv, refYv, wrefYv);
// Free memory for floating point Y data
for (int i = 0; i < frames; ++i)
{
AlignedFree(srcYd[i]);
if (d.rdef) AlignedFree(refYd[i]);
+ if (d.wdef) AlignedFree(wrefYd[i]);
}
}
@@ -740,6 +751,7 @@ void VBM3D_Process_Base::process_core_gray()
std::vector dstYv;
std::vector srcYv;
std::vector refYv;
+ std::vector wrefYv;
// Get write pointer
auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0))
@@ -751,15 +763,20 @@ void VBM3D_Process_Base::process_core_gray()
auto srcY = reinterpret_cast(vsapi->getReadPtr(v_src[i], 0));
auto refY = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 0));
+ auto wrefY = static_cast(nullptr);
+ if (v_wref.size() > i)
+ wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0));
+
// Store pointer to floating point Y data into corresponding frame of the vector
dstYv.push_back(dstY + dst_pcount[0] * (i * 2));
dstYv.push_back(dstY + dst_pcount[0] * (i * 2 + 1));
srcYv.push_back(srcY);
refYv.push_back(refY);
+ wrefYv.push_back(wrefY);
}
// Execute kernel
- Kernel(dstYv, srcYv, refYv);
+ Kernel(dstYv, srcYv, refYv, wrefYv);
}
@@ -775,11 +792,14 @@ void VBM3D_Process_Base::process_core_yuv()
std::vector srcVv;
std::vector refYv;
- std::vector refUv;
- std::vector refVv;
+
+ std::vector wrefYv;
+ std::vector wrefUv;
+ std::vector wrefVv;
std::vector srcYd(frames, nullptr), srcUd(frames, nullptr), srcVd(frames, nullptr);
- std::vector refYd(frames, nullptr), refUd(frames, nullptr), refVd(frames, nullptr);
+ std::vector refYd(frames, nullptr);
+ std::vector wrefYd(frames, nullptr), wrefUd(frames, nullptr), wrefVd(frames, nullptr);
// Get write pointer
auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0))
@@ -800,23 +820,32 @@ void VBM3D_Process_Base::process_core_yuv()
auto refU = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 1));
auto refV = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 2));
+ auto wrefY = static_cast(nullptr);
+ auto wrefU = static_cast(nullptr);
+ auto wrefV = static_cast(nullptr);
+
+ if (v_wref.size() > i) {
+ wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0));
+ wrefU = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 1));
+ wrefV = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 2));
+ }
+
// Allocate memory for floating point YUV data
if (d.process[0] || !d.rdef) AlignedMalloc(srcYd[i], src_pcount[0]);
if (d.process[1]) AlignedMalloc(srcUd[i], src_pcount[1]);
if (d.process[2]) AlignedMalloc(srcVd[i], src_pcount[2]);
if (d.rdef)
- {
AlignedMalloc(refYd[i], ref_pcount[0]);
- if (d.wiener && d.process[1]) AlignedMalloc(refUd[i], ref_pcount[1]);
- if (d.wiener && d.process[2]) AlignedMalloc(refVd[i], ref_pcount[2]);
- }
else
- {
refYd[i] = srcYd[i];
- refUd[i] = srcUd[i];
- refVd[i] = srcVd[i];
- }
+
+ if (d.wdef && d.process[0])
+ AlignedMalloc(wrefYd[i], wref_pcount[0]);
+ else
+ wrefYd[i] = refYd[i];
+ if (d.wiener && d.process[1]) AlignedMalloc(wrefUd[i], wref_pcount[1]);
+ if (d.wiener && d.process[2]) AlignedMalloc(wrefVd[i], wref_pcount[2]);
// Convert src and ref from integer YUV data to floating point YUV data
if (d.process[0] || !d.rdef) Int2Float(srcYd[i], srcY, src_height[0], src_width[0], src_stride[0], src_stride[0], false, full, false);
@@ -824,10 +853,16 @@ void VBM3D_Process_Base::process_core_yuv()
if (d.process[2]) Int2Float(srcVd[i], srcV, src_height[2], src_width[2], src_stride[2], src_stride[2], true, full, false);
if (d.rdef)
- {
Int2Float(refYd[i], refY, ref_height[0], ref_width[0], ref_stride[0], ref_stride[0], false, full, false);
- if (d.wiener && d.process[1]) Int2Float(refUd[i], refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false);
- if (d.wiener && d.process[2]) Int2Float(refVd[i], refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false);
+
+ if (d.wdef) {
+ if (d.process[0]) Int2Float(wrefYd[i], wrefY, wref_height[0], wref_width[0], wref_stride[0], wref_stride[0], false, full, false);
+ if (d.process[1]) Int2Float(wrefUd[i], wrefU, wref_height[1], wref_width[1], wref_stride[1], wref_stride[1], true, full, false);
+ if (d.process[2]) Int2Float(wrefVd[i], wrefV, wref_height[2], wref_width[2], wref_stride[2], wref_stride[2], true, full, false);
+ }
+ else if (d.wiener) {
+ if (d.process[1]) Int2Float(wrefUd[i], refU, ref_height[1], ref_width[1], ref_stride[1], ref_stride[1], true, full, false);
+ if (d.process[2]) Int2Float(wrefVd[i], refV, ref_height[2], ref_width[2], ref_stride[2], ref_stride[2], true, full, false);
}
// Store pointer to floating point YUV data into corresponding frame in the vector
@@ -844,12 +879,14 @@ void VBM3D_Process_Base::process_core_yuv()
srcVv.push_back(srcVd[i]);
refYv.push_back(refYd[i]);
- refUv.push_back(refUd[i]);
- refVv.push_back(refVd[i]);
+
+ wrefYv.push_back(wrefYd[i]);
+ wrefUv.push_back(wrefUd[i]);
+ wrefVv.push_back(wrefVd[i]);
}
// Execute kernel
- Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, refUv, refVv);
+ Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, wrefYv, wrefUv, wrefVv);
// Free memory for floating point YUV data
for (int i = 0; i < frames; ++i)
@@ -858,12 +895,11 @@ void VBM3D_Process_Base::process_core_yuv()
if (d.process[1]) AlignedFree(srcUd[i]);
if (d.process[2]) AlignedFree(srcVd[i]);
- if (d.rdef)
- {
- AlignedFree(refYd[i]);
- if (d.wiener && d.process[1]) AlignedFree(refUd[i]);
- if (d.wiener && d.process[2]) AlignedFree(refVd[i]);
- }
+ if (d.rdef) AlignedFree(refYd[i]);
+
+ if (d.wdef && d.process[0]) AlignedFree(wrefYd[i]);
+ if (d.wiener && d.process[1]) AlignedFree(wrefUd[i]);
+ if (d.wiener && d.process[2]) AlignedFree(wrefVd[i]);
}
}
@@ -879,8 +915,10 @@ void VBM3D_Process_Base::process_core_yuv()
std::vector srcVv;
std::vector refYv;
- std::vector refUv;
- std::vector refVv;
+
+ std::vector wrefYv;
+ std::vector wrefUv;
+ std::vector wrefVv;
// Get write/read pointer
auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0))
@@ -898,8 +936,16 @@ void VBM3D_Process_Base::process_core_yuv()
auto srcV = reinterpret_cast(vsapi->getReadPtr(v_src[i], 2));
auto refY = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 0));
- auto refU = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 1));
- auto refV = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 2));
+
+ auto wrefY = static_cast(nullptr);
+ auto wrefU = static_cast(nullptr);
+ auto wrefV = static_cast(nullptr);
+
+ if (v_wref.size() > i) {
+ wrefY = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0));
+ wrefU = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 1));
+ wrefV = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 2));
+ }
// Store pointer to floating point YUV data into corresponding frame in the vector
dstYv.push_back(dstY + dst_pcount[0] * (i * 2));
@@ -915,12 +961,14 @@ void VBM3D_Process_Base::process_core_yuv()
srcVv.push_back(srcV);
refYv.push_back(refY);
- refUv.push_back(refU);
- refVv.push_back(refV);
+
+ wrefYv.push_back(wrefY);
+ wrefUv.push_back(wrefU);
+ wrefVv.push_back(wrefV);
}
// Execute kernel
- Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, refUv, refVv);
+ Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, wrefYv, wrefUv, wrefVv);
}
@@ -936,11 +984,14 @@ void VBM3D_Process_Base::process_core_rgb()
std::vector srcVv;
std::vector refYv;
- std::vector refUv;
- std::vector refVv;
+
+ std::vector wrefYv;
+ std::vector wrefUv;
+ std::vector wrefVv;
std::vector srcYd(frames, nullptr), srcUd(frames, nullptr), srcVd(frames, nullptr);
- std::vector refYd(frames, nullptr), refUd(frames, nullptr), refVd(frames, nullptr);
+ std::vector refYd(frames, nullptr);
+ std::vector wrefYd(frames, nullptr), wrefUd(frames, nullptr), wrefVd(frames, nullptr);
// Get write pointer
auto dstY = reinterpret_cast(vsapi->getWritePtr(dst, 0))
@@ -961,23 +1012,32 @@ void VBM3D_Process_Base::process_core_rgb()
auto refG = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 1));
auto refB = reinterpret_cast(vsapi->getReadPtr(v_ref[i], 2));
+ auto wrefR = static_cast(nullptr);
+ auto wrefG = static_cast(nullptr);
+ auto wrefB = static_cast(nullptr);
+
+ if (v_wref.size() > i) {
+ wrefR = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 0));
+ wrefG = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 1));
+ wrefB = reinterpret_cast(vsapi->getReadPtr(v_wref[i], 2));
+ }
+
// Allocate memory for floating point YUV data
AlignedMalloc(srcYd[i], src_pcount[0]);
AlignedMalloc(srcUd[i], src_pcount[1]);
AlignedMalloc(srcVd[i], src_pcount[2]);
if (d.rdef)
- {
AlignedMalloc(refYd[i], ref_pcount[0]);
- if (d.wiener) AlignedMalloc(refUd[i], ref_pcount[1]);
- if (d.wiener) AlignedMalloc(refVd[i], ref_pcount[2]);
- }
else
- {
refYd[i] = srcYd[i];
- refUd[i] = srcUd[i];
- refVd[i] = srcVd[i];
- }
+
+ if (d.wdef)
+ AlignedMalloc(wrefYd[i], wref_pcount[0]);
+ else
+ wrefYd[i] = refYd[i];
+ if (d.wiener) AlignedMalloc(wrefUd[i], wref_pcount[1]);
+ if (d.wiener) AlignedMalloc(wrefVd[i], wref_pcount[2]);
// Convert src and ref from RGB data to floating point YUV data
RGB2FloatYUV(srcYd[i], srcUd[i], srcVd[i], srcR, srcG, srcB,
@@ -985,20 +1045,18 @@ void VBM3D_Process_Base::process_core_rgb()
ColorMatrix::OPP, true, false);
if (d.rdef)
- {
- if (d.wiener)
- {
- RGB2FloatYUV(refYd[i], refUd[i], refVd[i], refR, refG, refB,
- ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
- ColorMatrix::OPP, true, false);
- }
- else
- {
- RGB2FloatY(refYd[i], refR, refG, refB,
- ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
- ColorMatrix::OPP, true, false);
- }
- }
+ RGB2FloatY(refYd[i], refR, refG, refB,
+ ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
+ ColorMatrix::OPP, true, false);
+
+ if (d.wdef)
+ RGB2FloatYUV(wrefYd[i], wrefUd[i], wrefVd[i], wrefR, wrefG, wrefB,
+ wref_height[0], wref_width[0], wref_stride[0], wref_stride[0],
+ ColorMatrix::OPP, true, false);
+ else if (d.wiener)
+ RGB2FloatYUV(wrefYd[i], wrefUd[i], wrefVd[i], refR, refG, refB,
+ ref_height[0], ref_width[0], ref_stride[0], ref_stride[0],
+ ColorMatrix::OPP, true, false);
// Store pointer to floating point YUV data into corresponding frame in the vector
dstYv.push_back(dstY + dst_pcount[0] * (i * 2));
@@ -1014,12 +1072,14 @@ void VBM3D_Process_Base::process_core_rgb()
srcVv.push_back(srcVd[i]);
refYv.push_back(refYd[i]);
- refUv.push_back(refUd[i]);
- refVv.push_back(refVd[i]);
+
+ wrefYv.push_back(wrefYd[i]);
+ wrefUv.push_back(wrefUd[i]);
+ wrefVv.push_back(wrefVd[i]);
}
// Execute kernel
- Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, refUv, refVv);
+ Kernel(dstYv, dstUv, dstVv, srcYv, srcUv, srcVv, refYv, wrefYv, wrefUv, wrefVv);
// Free memory for floating point YUV data
for (int i = 0; i < frames; ++i)
@@ -1029,11 +1089,12 @@ void VBM3D_Process_Base::process_core_rgb()
AlignedFree(srcVd[i]);
if (d.rdef)
- {
AlignedFree(refYd[i]);
- if (d.wiener) AlignedFree(refUd[i]);
- if (d.wiener) AlignedFree(refVd[i]);
- }
+
+ if (d.wdef)
+ AlignedFree(wrefYd[i]);
+ if (d.wiener) AlignedFree(wrefUd[i]);
+ if (d.wiener) AlignedFree(wrefVd[i]);
}
}
diff --git a/source/VBM3D_Final.cpp b/source/VBM3D_Final.cpp
index ce0b3e3..566c2c1 100644
--- a/source/VBM3D_Final.cpp
+++ b/source/VBM3D_Final.cpp
@@ -36,6 +36,36 @@ int VBM3D_Final_Data::arguments_process(const VSMap *in, VSMap *out)
return 1;
}
+ auto error = 0;
+ wnode = vsapi->propGetNode(in, "wref", 0, &error);
+
+ if (error) {
+ wdef = false;
+ wnode = rnode;
+ wvi = rvi;
+ }
+ else {
+ wdef = true;
+ wvi = vsapi->getVideoInfo(wnode);
+
+ if (!isConstantFormat(wvi)) {
+ setError(out, "Invalid clip \"wref\", only constant format input supported");
+ return 1;
+ }
+ if (wvi->format != vi->format) {
+ setError(out, "input clip and clip \"wref\" must be of the same format");
+ return 1;
+ }
+ if (wvi->width != vi->width || wvi->height != vi->height) {
+ setError(out, "input clip and clip \"wref\" must be of the same width and height");
+ return 1;
+ }
+ if (wvi->numFrames != vi->numFrames) {
+ setError(out, "input clip and clip \"wref\" must have the same number of frames");
+ return 1;
+ }
+ }
+
// Initialize filter data for empirical Wiener filtering
init_filter_data();
diff --git a/source/VSPlugin.cpp b/source/VSPlugin.cpp
index 8183023..30d5c7d 100644
--- a/source/VSPlugin.cpp
+++ b/source/VSPlugin.cpp
@@ -213,6 +213,7 @@ static const VSFrameRef *VS_CC BM3D_Final_GetFrame(int n, int activationReason,
{
vsapi->requestFrameFilter(n, d->node, frameCtx);
if (d->rdef) vsapi->requestFrameFilter(n, d->rnode, frameCtx);
+ if (d->wdef) vsapi->requestFrameFilter(n, d->wnode, frameCtx);
}
else if (activationReason == arAllFramesReady)
{
@@ -342,6 +343,7 @@ static const VSFrameRef *VS_CC VBM3D_Final_GetFrame(int n, int activationReason,
{
vsapi->requestFrameFilter(n + o, d->node, frameCtx);
if (d->rdef) vsapi->requestFrameFilter(n + o, d->rnode, frameCtx);
+ if (d->wdef) vsapi->requestFrameFilter(n + o, d->wnode, frameCtx);
}
}
else if (activationReason == arAllFramesReady)
@@ -479,6 +481,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegiste
registerFunc("Final",
"input:clip;"
"ref:clip;"
+ "wref:clip:opt;"
"profile:data:opt;"
"sigma:float[]:opt;"
"block_size:int:opt;"
@@ -512,6 +515,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegiste
registerFunc("VFinal",
"input:clip;"
"ref:clip;"
+ "wref:clip:opt;"
"profile:data:opt;"
"sigma:float[]:opt;"
"radius:int:opt;"