diff --git a/modules/cudaimgproc/include/opencv2/cudaimgproc.hpp b/modules/cudaimgproc/include/opencv2/cudaimgproc.hpp index 9ee50c7305..4c9ee0f48e 100644 --- a/modules/cudaimgproc/include/opencv2/cudaimgproc.hpp +++ b/modules/cudaimgproc/include/opencv2/cudaimgproc.hpp @@ -57,6 +57,7 @@ @{ @defgroup cudaimgproc_color Color space processing @defgroup cudaimgproc_hist Histogram Calculation + @defgroup cudaimgproc_shape Structural Analysis and Shape Descriptors @defgroup cudaimgproc_hough Hough Transform @defgroup cudaimgproc_feature Feature Detection @} @@ -779,9 +780,84 @@ CV_EXPORTS_AS(connectedComponentsWithAlgorithm) void connectedComponents(InputAr CV_EXPORTS_W void connectedComponents(InputArray image, OutputArray labels, int connectivity = 8, int ltype = CV_32S); - //! @} +//! @addtogroup cudaimgproc_shape +//! @{ + + /** @brief Order of image moments. + * @param FIRST_ORDER_MOMENTS First order moments + * @param SECOND_ORDER_MOMENTS Second order moments. + * @param THIRD_ORDER_MOMENTS Third order moments. + * */ +enum MomentsOrder { + FIRST_ORDER_MOMENTS = 1, + SECOND_ORDER_MOMENTS = 2, + THIRD_ORDER_MOMENTS = 3 +}; + +/** @brief Returns the number of image moments less than or equal to the largest image moments \a order. +@param order Order of largest moments to calculate with lower order moments requiring less computation. +@returns number of image moments. + +@sa cuda::moments, cuda::spatialMoments, cuda::MomentsOrder + */ +CV_EXPORTS_W int numMoments(const MomentsOrder order); + +/** @brief Calculates all of the spatial moments up to the 3rd order of a rasterized shape. + +Asynchronous version of cuda::moments() which only calculates the spatial (not centralized or normalized) moments, up to the 3rd order, of a rasterized shape. +Each moment is returned as a column entry in the 1D \a moments array. + +@param src Raster image (single-channel 2D array). +@param [out] moments 1D array with each column entry containing a spatial image moment. +@param binaryImage If it is true, all non-zero image pixels are treated as 1's. +@param order Order of largest moments to calculate with lower order moments requiring less computation. +@param momentsType Precision to use when calculating moments. Available types are `CV_32F` and `CV_64F` with the performance of `CV_32F` an order of magnitude greater than `CV_64F`. If the image is small the accuracy from `CV_32F` can be equal or very close to `CV_64F`. +@param stream Stream for the asynchronous version. + +@note For maximum performance pre-allocate a 1D GpuMat for \a moments of the correct type and size large enough to store the all the image moments of up to the desired \a order. e.g. With \a order === MomentsOrder::SECOND_ORDER_MOMENTS and \a momentsType == `CV_32F` \a moments can be allocated as +``` +GpuMat momentsDevice(1,numMoments(MomentsOrder::SECOND_ORDER_MOMENTS),CV_32F) +``` +The central and normalized moments can easily be calculated on the host by downloading the \a moments array and using the cv::Moments constructor. e.g. +``` +HostMem momentsHostMem(1, numMoments(MomentsOrder::SECOND_ORDER_MOMENTS), CV_32F); +momentsDevice.download(momentsHostMem, stream); +stream.waitForCompletion(); +Mat momentsMat = momentsHostMem.createMatHeader(); +cv::Moments cvMoments(momentsMat.at(0), momentsMat.at(1), momentsMat.at(2), momentsMat.at(3), momentsMat.at(4), momentsMat.at(5), momentsMat.at(6), momentsMat.at(7), momentsMat.at(8), momentsMat.at(9)); +``` +see the \a CUDA_TEST_P(Moments, Async) test inside opencv_contrib_source_code/modules/cudaimgproc/test/test_moments.cpp for an example. +@returns cv::Moments. +@sa cuda::moments +*/ +CV_EXPORTS_W void spatialMoments(InputArray src, OutputArray moments, const bool binaryImage = false, const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS, const int momentsType = CV_64F, Stream& stream = Stream::Null()); + +/** @brief Calculates all of the moments up to the 3rd order of a rasterized shape. + +The function computes moments, up to the 3rd order, of a rasterized shape. The +results are returned in the structure cv::Moments. + +@param src Raster image (single-channel 2D array). +@param binaryImage If it is true, all non-zero image pixels are treated as 1's. +@param order Order of largest moments to calculate with lower order moments requiring less computation. + @param momentsType Precision to use when calculating moments. Available types are `CV_32F` and `CV_64F` with the performance of `CV_32F` an order of magnitude greater than `CV_64F`. If the image is small the accuracy from `CV_32F` can be equal or very close to `CV_64F`. + +@note For maximum performance use the asynchronous version cuda::spatialMoments() as this version interally allocates and deallocates both GpuMat and HostMem to respectively perform the calculation on the device and download the result to the host. +The costly HostMem allocation cannot be avoided however the GpuMat device allocation can be by using BufferPool, e.g. +``` + setBufferPoolUsage(true); + setBufferPoolConfig(getDevice(), numMoments(order) * ((momentsType == CV_64F) ? sizeof(double) : sizeof(float)), 1); +``` +see the \a CUDA_TEST_P(Moments, Accuracy) test inside opencv_contrib_source_code/modules/cudaimgproc/test/test_moments.cpp for an example. +@returns cv::Moments. +@sa cuda::spatialMoments + */ +CV_EXPORTS_W Moments moments(InputArray src, const bool binaryImage = false, const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS, const int momentsType = CV_64F); + +//! @} cudaimgproc_shape + }} // namespace cv { namespace cuda { #endif /* OPENCV_CUDAIMGPROC_HPP */ diff --git a/modules/cudaimgproc/misc/python/test/test_cudaimgproc.py b/modules/cudaimgproc/misc/python/test/test_cudaimgproc.py index 0548cbcd8b..f07617f53e 100644 --- a/modules/cudaimgproc/misc/python/test/test_cudaimgproc.py +++ b/modules/cudaimgproc/misc/python/test/test_cudaimgproc.py @@ -89,5 +89,30 @@ def test_cvtColor(self): self.assertTrue(np.allclose(cv.cuda.cvtColor(cuMat, cv.COLOR_BGR2HSV).download(), cv.cvtColor(npMat, cv.COLOR_BGR2HSV))) + def test_moments(self): + # setup + src_host = (np.ones([10,10])).astype(np.uint8)*255 + cpu_moments = cv.moments(src_host, True) + moments_order = cv.cuda.THIRD_ORDER_MOMENTS + n_moments = cv.cuda.numMoments(cv.cuda.THIRD_ORDER_MOMENTS) + src_device = cv.cuda.GpuMat(src_host) + + # synchronous + cv.cuda.setBufferPoolUsage(True) + cv.cuda.setBufferPoolConfig(cv.cuda.getDevice(), n_moments * np.dtype(float).itemsize, 1); + gpu_moments = cv.cuda.moments(src_device, True, moments_order, cv.CV_64F) + self.assertTrue(len([1 for moment_type in cpu_moments if moment_type in gpu_moments and cpu_moments[moment_type] == gpu_moments[moment_type]]) == 24) + + # asynchronous + stream = cv.cuda.Stream() + moments_array_host = np.empty([1, n_moments], np.float64) + cv.cuda.registerPageLocked(moments_array_host) + moments_array_device = cv.cuda.GpuMat(1, n_moments, cv.CV_64F) + cv.cuda.spatialMoments(src_device, moments_array_device, True, moments_order, cv.CV_64F, stream) + moments_array_device.download(stream, moments_array_host); + stream.waitForCompletion() + cv.cuda.unregisterPageLocked(moments_array_host) + self.assertTrue(len([ 1 for moment_type,gpu_moment in zip(cpu_moments,moments_array_host[0]) if cpu_moments[moment_type] == gpu_moment]) == 10) + if __name__ == '__main__': NewOpenCVTests.bootstrap() \ No newline at end of file diff --git a/modules/cudaimgproc/perf/perf_moments.cpp b/modules/cudaimgproc/perf/perf_moments.cpp new file mode 100644 index 0000000000..ba91afbacf --- /dev/null +++ b/modules/cudaimgproc/perf/perf_moments.cpp @@ -0,0 +1,61 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "perf_precomp.hpp" + +namespace opencv_test { namespace { +static void drawCircle(cv::Mat& dst, const cv::Vec3i& circle, bool fill) +{ + dst.setTo(Scalar::all(0)); + cv::circle(dst, Point2i(circle[0], circle[1]), circle[2], Scalar::all(255), fill ? -1 : 1, cv::LINE_AA); +} + +DEF_PARAM_TEST(Sz_Depth, Size, MatDepth); +PERF_TEST_P(Sz_Depth, SpatialMoments, Combine(CUDA_TYPICAL_MAT_SIZES, Values(MatDepth(CV_32F), MatDepth((CV_64F))))) +{ + const cv::Size size = GET_PARAM(0); + const int momentsType = GET_PARAM(1); + Mat imgHost(size, CV_8U); + const Vec3i circle(size.width / 2, size.height / 2, static_cast(static_cast(size.width / 2) * 0.9)); + drawCircle(imgHost, circle, true); + if (PERF_RUN_CUDA()) { + const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS; + const int nMoments = numMoments(order); + GpuMat momentsDevice(1, nMoments, momentsType); + const GpuMat imgDevice(imgHost); + TEST_CYCLE() cuda::spatialMoments(imgDevice, momentsDevice, false, order, momentsType); + SANITY_CHECK_NOTHING(); + } + else { + cv::Moments momentsHost; + TEST_CYCLE() momentsHost = cv::moments(imgHost, false); + SANITY_CHECK_NOTHING(); + } +} + +PERF_TEST_P(Sz_Depth, Moments, Combine(CUDA_TYPICAL_MAT_SIZES, Values(MatDepth(CV_32F), MatDepth(CV_64F)))) +{ + const cv::Size size = GET_PARAM(0); + const int momentsType = GET_PARAM(1); + Mat imgHost(size, CV_8U); + const Vec3i circle(size.width / 2, size.height / 2, static_cast(static_cast(size.width / 2) * 0.9)); + drawCircle(imgHost, circle, true); + if (PERF_RUN_CUDA()) { + const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS; + const int nMoments = numMoments(order); + setBufferPoolUsage(true); + setBufferPoolConfig(getDevice(), nMoments * ((momentsType == CV_64F) ? sizeof(double) : sizeof(float)), 1); + const GpuMat imgDevice(imgHost); + cv::Moments momentsHost; + TEST_CYCLE() momentsHost = cuda::moments(imgDevice, false, order, momentsType); + SANITY_CHECK_NOTHING(); + } + else { + cv::Moments momentsHost; + TEST_CYCLE() momentsHost = cv::moments(imgHost, false); + SANITY_CHECK_NOTHING(); + } +} + +}} diff --git a/modules/cudaimgproc/src/cuda/moments.cu b/modules/cudaimgproc/src/cuda/moments.cu new file mode 100644 index 0000000000..9828c5614b --- /dev/null +++ b/modules/cudaimgproc/src/cuda/moments.cu @@ -0,0 +1,186 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#if !defined CUDA_DISABLER + +#include +#include +#include "moments.cuh" + +namespace cv { namespace cuda { namespace device { namespace imgproc { + +constexpr int blockSizeX = 32; +constexpr int blockSizeY = 16; + +template +__device__ T butterflyWarpReduction(T value) { + for (int i = 16; i >= 1; i /= 2) + value += __shfl_xor_sync(0xffffffff, value, i, 32); + return value; +} + +template +__device__ T butterflyHalfWarpReduction(T value) { + for (int i = 8; i >= 1; i /= 2) + value += __shfl_xor_sync(0xffff, value, i, 32); + return value; +} + +template +__device__ void updateSums(const T val, const unsigned int x, T r[4]) { + const T x2 = x * x; + const T x3 = static_cast(x) * x2; + r[0] += val; + r[1] += val * x; + if (nMoments >= n12) r[2] += val * x2; + if (nMoments >= n123) r[3] += val * x3; +} + +template +__device__ void rowReductions(const PtrStepSz img, const bool binary, const unsigned int y, TMoments r[4], TMoments smem[][nMoments + 1]) { + for (int x = threadIdx.x; x < img.cols; x += blockDim.x) { + const TMoments val = (!binary || img(y, x) == 0) ? img(y, x) : 1; + updateSums(val, x, r); + } +} + +template +__device__ void rowReductionsCoalesced(const PtrStepSz img, const bool binary, const unsigned int y, TMoments r[4], const int offsetX, TMoments smem[][nMoments + 1]) { + const int alignedOffset = fourByteAligned ? 0 : 4 - offsetX; + // load uncoalesced head + if (!fourByteAligned && threadIdx.x == 0) { + for (int x = 0; x < ::min(alignedOffset, static_cast(img.cols)); x++) { + const TMoments val = (!binary || img(y, x) == 0) ? img(y, x) : 1; + updateSums(val, x, r); + } + } + + // coalesced loads + const unsigned int* rowPtrIntAligned = (const unsigned int*)(fourByteAligned ? img.ptr(y) : img.ptr(y) + alignedOffset); + const int cols4 = fourByteAligned ? img.cols / 4 : (img.cols - alignedOffset) / 4; + for (int x = threadIdx.x; x < cols4; x += blockDim.x) { + const unsigned int data = rowPtrIntAligned[x]; +#pragma unroll 4 + for (int i = 0; i < 4; i++) { + const int iX = alignedOffset + 4 * x + i; + const uchar ucharVal = ((data >> i * 8) & 0xFFU); + const TMoments val = (!binary || ucharVal == 0) ? ucharVal : 1; + updateSums(val, iX, r); + } + } + + // load uncoalesced tail + if (threadIdx.x == 0) { + const int iTailStart = fourByteAligned ? cols4 * 4 : cols4 * 4 + alignedOffset; + for (int x = iTailStart; x < img.cols; x++) { + const TMoments val = (!binary || img(y, x) == 0) ? img(y, x) : 1; + updateSums(val, x, r); + } + } +} + +template +__global__ void spatialMoments(const PtrStepSz img, const bool binary, TMoments* moments, const int offsetX = 0) { + const unsigned int y = blockIdx.x * blockDim.y + threadIdx.y; + __shared__ TMoments smem[blockSizeY][nMoments + 1]; + if (threadIdx.y < nMoments && threadIdx.x < blockSizeY) + smem[threadIdx.x][threadIdx.y] = 0; + __syncthreads(); + + TMoments r[4] = { 0 }; + if (y < img.rows) { + if (coalesced) + rowReductionsCoalesced(img, binary, y, r, offsetX, smem); + else + rowReductions(img, binary, y, r, smem); + } + + const unsigned long y2 = y * y; + const TMoments y3 = static_cast(y2) * y; + const TMoments res = butterflyWarpReduction(r[0]); + if (res) { + smem[threadIdx.y][0] = res; //0th + smem[threadIdx.y][1] = butterflyWarpReduction(r[1]); //1st + smem[threadIdx.y][2] = y * res; //1st + if (nMoments >= n12) { + smem[threadIdx.y][3] = butterflyWarpReduction(r[2]); //2nd + smem[threadIdx.y][4] = smem[threadIdx.y][1] * y; //2nd + smem[threadIdx.y][5] = y2 * res; //2nd + } + if (nMoments >= n123) { + smem[threadIdx.y][6] = butterflyWarpReduction(r[3]); //3rd + smem[threadIdx.y][7] = smem[threadIdx.y][3] * y; //3rd + smem[threadIdx.y][8] = smem[threadIdx.y][1] * y2; //3rd + smem[threadIdx.y][9] = y3 * res; //3rd + } + } + __syncthreads(); + + if (threadIdx.x < blockSizeY && threadIdx.y < nMoments) + smem[threadIdx.y][nMoments] = butterflyHalfWarpReduction(smem[threadIdx.x][threadIdx.y]); + __syncthreads(); + + if (threadIdx.y == 0 && threadIdx.x < nMoments) { + if (smem[threadIdx.x][nMoments]) + cudev::atomicAdd(&moments[threadIdx.x], smem[threadIdx.x][nMoments]); + } +} + +template struct momentsDispatcherNonChar { + static void call(const PtrStepSz src, PtrStepSz moments, const bool binary, const int offsetX, const cudaStream_t stream) { + dim3 blockSize(blockSizeX, blockSizeY); + dim3 gridSize = dim3(divUp(src.rows, blockSizeY)); + spatialMoments << > > (src, binary, moments.ptr()); + if (stream == 0) + cudaSafeCall(cudaStreamSynchronize(stream)); + }; +}; + +template struct momentsDispatcherChar { + static void call(const PtrStepSz src, PtrStepSz moments, const bool binary, const int offsetX, const cudaStream_t stream) { + dim3 blockSize(blockSizeX, blockSizeY); + dim3 gridSize = dim3(divUp(src.rows, blockSizeY)); + if (offsetX) + spatialMoments << > > (src, binary, moments.ptr(), offsetX); + else + spatialMoments << > > (src, binary, moments.ptr()); + + if (stream == 0) + cudaSafeCall(cudaStreamSynchronize(stream)); + }; +}; + +template struct momentsDispatcher : momentsDispatcherNonChar {}; +template struct momentsDispatcher : momentsDispatcherChar {}; +template struct momentsDispatcher : momentsDispatcherChar {}; + +template +void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream) { + if (order == 1) + momentsDispatcher::call(static_cast>(src), static_cast>(moments), binary, offsetX, stream); + else if (order == 2) + momentsDispatcher::call(static_cast>(src), static_cast>(moments), binary, offsetX, stream); + else if (order == 3) + momentsDispatcher::call(static_cast>(src), static_cast>(moments), binary, offsetX, stream); +}; + +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); + +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +template void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); + +}}}} + +#endif /* CUDA_DISABLER */ diff --git a/modules/cudaimgproc/src/cuda/moments.cuh b/modules/cudaimgproc/src/cuda/moments.cuh new file mode 100644 index 0000000000..0041882b64 --- /dev/null +++ b/modules/cudaimgproc/src/cuda/moments.cuh @@ -0,0 +1,6 @@ +#pragma once +namespace cv { namespace cuda { namespace device { namespace imgproc { + constexpr int n1 = 3; + constexpr int n12 = 6; + constexpr int n123 = 10; +}}}} diff --git a/modules/cudaimgproc/src/moments.cpp b/modules/cudaimgproc/src/moments.cpp new file mode 100644 index 0000000000..ced5b2f8c6 --- /dev/null +++ b/modules/cudaimgproc/src/moments.cpp @@ -0,0 +1,67 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "precomp.hpp" +#include "cuda/moments.cuh" + +using namespace cv; +using namespace cv::cuda; + +int cv::cuda::numMoments(const MomentsOrder order) { + return order == MomentsOrder::FIRST_ORDER_MOMENTS ? device::imgproc::n1 : order == MomentsOrder::SECOND_ORDER_MOMENTS ? device::imgproc::n12 : device::imgproc::n123; +} + +#if !defined (HAVE_CUDA) || defined (CUDA_DISABLER) + Moments cv::cuda::moments(InputArray src, const bool binary, const MomentsOrder order, const int momentsType) { throw_no_cuda(); } + void spatialMoments(InputArray src, OutputArray moments, const bool binary, const MomentsOrder order, const int momentsType, Stream& stream) { throw_no_cuda(); } +#else /* !defined (HAVE_CUDA) */ + +namespace cv { namespace cuda { namespace device { namespace imgproc { + template + void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); +}}}} + +void cv::cuda::spatialMoments(InputArray src, OutputArray moments, const bool binary, const MomentsOrder order, const int momentsType, Stream& stream) { + CV_Assert(src.depth() <= CV_64F); + const GpuMat srcDevice = getInputMat(src, stream); + + CV_Assert(momentsType == CV_32F || momentsType == CV_64F); + const int nMoments = numMoments(order); + const int momentsCols = nMoments < moments.cols() ? moments.cols() : nMoments; + GpuMat momentsDevice = getOutputMat(moments, 1, momentsCols, momentsType, stream); + momentsDevice.setTo(0); + + Point ofs; Size wholeSize; + srcDevice.locateROI(wholeSize, ofs); + + typedef void (*func_t)(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream); + static const func_t funcs[7][2] = + { + {device::imgproc::moments, device::imgproc::moments }, + {device::imgproc::moments, device::imgproc::moments }, + {device::imgproc::moments, device::imgproc::moments}, + {device::imgproc::moments, device::imgproc::moments }, + {device::imgproc::moments, device::imgproc::moments }, + {device::imgproc::moments, device::imgproc::moments }, + {device::imgproc::moments, device::imgproc::moments } + }; + + const func_t func = funcs[srcDevice.depth()][momentsType == CV_64F]; + func(srcDevice, momentsDevice, binary, static_cast(order), ofs.x, StreamAccessor::getStream(stream)); + syncOutput(momentsDevice, moments, stream); +} + +Moments cv::cuda::moments(InputArray src, const bool binary, const MomentsOrder order, const int momentsType) { + Stream& stream = Stream::Null(); + HostMem dst; + spatialMoments(src, dst, binary, order, momentsType, stream); + stream.waitForCompletion(); + Mat moments = dst.createMatHeader(); + if(momentsType == CV_32F) + return Moments(moments.at(0), moments.at(1), moments.at(2), moments.at(3), moments.at(4), moments.at(5), moments.at(6), moments.at(7), moments.at(8), moments.at(9)); + else + return Moments(moments.at(0), moments.at(1), moments.at(2), moments.at(3), moments.at(4), moments.at(5), moments.at(6), moments.at(7), moments.at(8), moments.at(9)); +} + +#endif /* !defined (HAVE_CUDA) */ diff --git a/modules/cudaimgproc/test/test_moments.cpp b/modules/cudaimgproc/test/test_moments.cpp new file mode 100644 index 0000000000..c5dd889f09 --- /dev/null +++ b/modules/cudaimgproc/test/test_moments.cpp @@ -0,0 +1,124 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "test_precomp.hpp" + +#ifdef HAVE_CUDA + +namespace opencv_test { namespace { + +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// Moments + +CV_ENUM(MaxMomentsOrder, MomentsOrder::FIRST_ORDER_MOMENTS, MomentsOrder::SECOND_ORDER_MOMENTS, MomentsOrder::THIRD_ORDER_MOMENTS) + +PARAM_TEST_CASE(Moments, cv::cuda::DeviceInfo, cv::Size, bool, MatDepth, MatDepth, UseRoi, MaxMomentsOrder) +{ + DeviceInfo devInfo; + Size size; + bool isBinary; + float pcWidth = 0.6f; + int momentsType; + int imgType; + bool useRoi; + MomentsOrder order; + + virtual void SetUp() + { + devInfo = GET_PARAM(0); + size = GET_PARAM(1); + isBinary = GET_PARAM(2); + momentsType = GET_PARAM(3); + imgType = GET_PARAM(4); + useRoi = GET_PARAM(5); + order = static_cast(static_cast(GET_PARAM(6))); + cv::cuda::setDevice(devInfo.deviceID()); + } + + static void drawCircle(cv::Mat& dst, const cv::Vec3i& circle, bool fill) + { + dst.setTo(Scalar::all(0)); + cv::circle(dst, Point2i(circle[0], circle[1]), circle[2], Scalar::all(255), fill ? -1 : 1, cv::LINE_AA); + } +}; + +bool Equal(const double m0, const double m1, const double absPcErr) { + if (absPcErr == 0) return m0 == m1; + if (m0 == 0) { + if (m1 < absPcErr) return true; + else return false; + } + const double pcDiff = abs(m0 - m1) / m1; + return pcDiff < absPcErr; +} + +void CheckMoments(const cv::Moments m0, const cv::Moments m1, const MomentsOrder order, const int momentsType) { + double absPcErr = momentsType == CV_64F ? 0 : 5e-7; + ASSERT_TRUE(Equal(m0.m00, m1.m00, absPcErr)) << "m0.m00: " << m0.m00 << ", m1.m00: " << m1.m00 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m10, m1.m10, absPcErr)) << "m0.m10: " << m0.m10 << ", m1.m10: " << m1.m10 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m01, m1.m01, absPcErr)) << "m0.m01: " << m0.m01 << ", m1.m01: " << m1.m01 << ", absPcErr: " << absPcErr; + if (static_cast(order) >= static_cast(MomentsOrder::SECOND_ORDER_MOMENTS)) { + ASSERT_TRUE(Equal(m0.m20, m1.m20, absPcErr)) << "m0.m20: " << m0.m20 << ", m1.m20: " << m1.m20 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m11, m1.m11, absPcErr)) << "m0.m11: " << m0.m11 << ", m1.m11: " << m1.m11 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m02, m1.m02, absPcErr)) << "m0.m02: " << m0.m02 << ", m1.m02: " << m1.m02 << ", absPcErr: " << absPcErr; + } + if (static_cast(order) >= static_cast(MomentsOrder::THIRD_ORDER_MOMENTS)) { + ASSERT_TRUE(Equal(m0.m30, m1.m30, absPcErr)) << "m0.m30: " << m0.m30 << ", m1.m30: " << m1.m30 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m21, m1.m21, absPcErr)) << "m0.m21: " << m0.m21 << ", m1.m21: " << m1.m21 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m12, m1.m12, absPcErr)) << "m0.m12: " << m0.m12 << ", m1.m12: " << m1.m12 << ", absPcErr: " << absPcErr; + ASSERT_TRUE(Equal(m0.m03, m1.m03, absPcErr)) << "m0.m03: " << m0.m03 << ", m1.m03: " << m1.m03 << ", absPcErr: " << absPcErr; + } +} + +CUDA_TEST_P(Moments, Accuracy) +{ + Mat imgHost(size, imgType); + const Rect roi = useRoi ? Rect(1, 0, imgHost.cols - 2, imgHost.rows) : Rect(0, 0, imgHost.cols, imgHost.rows); + const Vec3i circle(size.width / 2, size.height / 2, static_cast(static_cast(size.width/2) * pcWidth)); + drawCircle(imgHost, circle, true); + const GpuMat imgDevice(imgHost); + const int nMoments = numMoments(order); + setBufferPoolUsage(true); + setBufferPoolConfig(getDevice(), nMoments * ((momentsType == CV_64F) ? sizeof(double) : sizeof(float)), 1); + const cv::Moments moments = cuda::moments(imgDevice(roi), isBinary, order, momentsType); + Mat imgHostFloat; imgHost(roi).convertTo(imgHostFloat, CV_32F); + const cv::Moments momentsGs = cv::moments(imgHostFloat, isBinary); + CheckMoments(momentsGs, moments, order, momentsType); +} + +CUDA_TEST_P(Moments, Async) +{ + Stream stream; + const int nMoments = numMoments(order); + GpuMat momentsDevice(1, nMoments, momentsType); + Mat imgHost(size, imgType); + const Rect roi = useRoi ? Rect(1, 0, imgHost.cols - 2, imgHost.rows) : Rect(0, 0, imgHost.cols, imgHost.rows); + const Vec3i circle(size.width / 2, size.height / 2, static_cast(static_cast(size.width/2) * pcWidth)); + drawCircle(imgHost, circle, true); + const GpuMat imgDevice(imgHost); + cuda::spatialMoments(imgDevice(roi), momentsDevice, isBinary, order, momentsType, stream); + HostMem momentsHost(1, nMoments, momentsType); + momentsDevice.download(momentsHost, stream); + stream.waitForCompletion(); + Mat momentsHost64F = momentsHost.createMatHeader(); + if (momentsType == CV_32F) + momentsHost.createMatHeader().convertTo(momentsHost64F, CV_64F); + const cv::Moments moments = cv::Moments(momentsHost64F.at(0), momentsHost64F.at(1), momentsHost64F.at(2), momentsHost64F.at(3), momentsHost64F.at(4), momentsHost64F.at(5), momentsHost64F.at(6), momentsHost64F.at(7), momentsHost64F.at(8), momentsHost64F.at(9)); + Mat imgHostAdjustedType = imgHost(roi); + if (imgType != CV_8U && imgType != CV_32F) + imgHost(roi).convertTo(imgHostAdjustedType, CV_32F); + const cv::Moments momentsGs = cv::moments(imgHostAdjustedType, isBinary); + CheckMoments(momentsGs, moments, order, momentsType); +} + +#define SIZES DIFFERENT_SIZES +#define GRAYSCALE_BINARY testing::Bool() +#define MOMENTS_TYPE testing::Values(MatDepth(CV_32F), MatDepth(CV_64F)) +#define IMG_TYPE ALL_DEPTH +#define USE_ROI WHOLE_SUBMAT +#define MOMENTS_ORDER testing::Values(MaxMomentsOrder(MomentsOrder::FIRST_ORDER_MOMENTS), MaxMomentsOrder(MomentsOrder::SECOND_ORDER_MOMENTS), MaxMomentsOrder(MomentsOrder::THIRD_ORDER_MOMENTS)) +INSTANTIATE_TEST_CASE_P(CUDA_ImgProc, Moments, testing::Combine(ALL_DEVICES, SIZES, GRAYSCALE_BINARY, MOMENTS_TYPE, IMG_TYPE, USE_ROI, MOMENTS_ORDER)); +}} // namespace + +#endif // HAVE_CUDA diff --git a/modules/cudev/include/opencv2/cudev/util/atomic.hpp b/modules/cudev/include/opencv2/cudev/util/atomic.hpp index 190e8ee48b..600f836749 100644 --- a/modules/cudev/include/opencv2/cudev/util/atomic.hpp +++ b/modules/cudev/include/opencv2/cudev/util/atomic.hpp @@ -83,7 +83,7 @@ __device__ __forceinline__ float atomicAdd(float* address, float val) __device__ static double atomicAdd(double* address, double val) { -#if CV_CUDEV_ARCH >= 130 +#if CV_CUDEV_ARCH < 600 unsigned long long int* address_as_ull = (unsigned long long int*) address; unsigned long long int old = *address_as_ull, assumed; do { @@ -93,9 +93,7 @@ __device__ static double atomicAdd(double* address, double val) } while (assumed != old); return __longlong_as_double(old); #else - CV_UNUSED(address); - CV_UNUSED(val); - return 0.0; + return ::atomicAdd(address, val); #endif }