Skip to content

Commit

Permalink
Make reg_getVoxelBasedNmiGradient_gpu() on a par with CPU #92
Browse files Browse the repository at this point in the history
- Optimise reg_getVoxelBasedNmiGradient_gpu()
- Get the function ready for multi-timepoint support
  • Loading branch information
onurulgen committed Nov 15, 2023
1 parent 52204d7 commit bc4c672
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 595 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
361
362
6 changes: 0 additions & 6 deletions reg-lib/cuda/BlockSize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
namespace NiftyReg {
/* *************************************************************** */
struct BlockSize {
unsigned reg_getVoxelBasedNmiGradientUsingPw2D;
unsigned reg_getVoxelBasedNmiGradientUsingPw3D;
unsigned reg_affine_getDeformationField;
unsigned reg_spline_getDeformationField2D;
unsigned reg_spline_getDeformationField3D;
Expand Down Expand Up @@ -54,8 +52,6 @@ struct BlockSize {
/* *************************************************************** */
struct BlockSize100: public BlockSize {
BlockSize100() {
reg_getVoxelBasedNmiGradientUsingPw2D = 384; // 21 reg - 24 smem - 32 cmem
reg_getVoxelBasedNmiGradientUsingPw3D = 320; // 25 reg - 24 smem - 32 cmem
reg_affine_getDeformationField = 512; // 16 reg - 24 smem
reg_spline_getDeformationField2D = 384; // 20 reg - 6168 smem - 28 cmem
reg_spline_getDeformationField3D = 192; // 37 reg - 6168 smem - 28 cmem
Expand Down Expand Up @@ -96,8 +92,6 @@ struct BlockSize100: public BlockSize {
/* *************************************************************** */
struct BlockSize300: public BlockSize {
BlockSize300() {
reg_getVoxelBasedNmiGradientUsingPw2D = 768; // 38 reg
reg_getVoxelBasedNmiGradientUsingPw3D = 640; // 45 reg
reg_affine_getDeformationField = 1024; // 23 reg
reg_spline_getDeformationField2D = 1024; // 34 reg
reg_spline_getDeformationField3D = 1024; // 34 reg
Expand Down
184 changes: 115 additions & 69 deletions reg-lib/cuda/_reg_nmi_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
*/

#include "_reg_nmi_gpu.h"
#include "_reg_nmi_kernels.cu"
#include "_reg_common_cuda_kernels.cu"

/* *************************************************************** */
reg_nmi_gpu::reg_nmi_gpu(): reg_nmi::reg_nmi() {
Expand Down Expand Up @@ -298,95 +298,141 @@ double reg_nmi_gpu::GetSimilarityMeasureValueBw() {
this->approximatePw);
}
/* *************************************************************** */
// Compile-time selection of the accumulator type used for the NMI derivatives:
// the primary template (used for is3d == true) exposes double3 ...
template<bool is3d>
struct Derivative {
    typedef double3 Type;
};
// ... and the explicit specialisation for is3d == false exposes double2.
template<>
struct Derivative<false> {
    typedef double2 Type;
};
/* *************************************************************** */
/// Called when we only have one target and one source image
// NOTE(review): the line above predates the currentTimePoint / timePointWeight
// parameters visible below -- confirm it still describes the function.
//
// Adds the weighted voxel-wise NMI gradient contribution of one time point to
// voxelBasedGradientCuda. The template flag is3d enables the z component and the
// z texture coordinate at compile time.
//
// NOTE(review): this span reads like a rendered diff in which removed and added
// lines are interleaved (two histogram pointer parameters, refBinning/floBinning
// next to refBinNumber/floBinNumber, float and double definitions of
// normalisedJE/nmi, kernel launches next to the thrust loop). Confirm which lines
// are live before relying on the notes below; they describe the thrust::for_each_n
// path and the declarations it reads.
//
// Parameters as used by the thrust::for_each_n body:
//   referenceImage         - supplies nx/ny/nz and the voxel count of one volume
//   referenceImageCuda     - array sampled with tex3D at (voxel + 0.5) / size
//   warpedImageCuda        - linear buffer, offset by currentTimePoint * voxelNumber
//   warpedGradientCuda     - float4 buffer holding the warped image gradient
//   jointHistogramLogCuda  - read at r + w * refBinNumber, r + referenceOffset and
//                            w + floatingOffset
//   voxelBasedGradientCuda - read, incremented and written back at targetIndex
//   maskCuda               - maps the loop index to targetIndex
//   activeVoxelNumber      - number of indices visited by the loop
//   entropies              - elements [0]..[3] are read to form nmi and normalisedJE
//   refBinNumber, floBinNumber - upper bounds of the r and w bin loops
//   totalBinNumber         - not referenced in the lines visible here
//   timePointWeight        - scales the contribution added to the gradient
//   currentTimePoint       - selects the warped image volume
template<bool is3d>
void reg_getVoxelBasedNmiGradient_gpu(const nifti_image *referenceImage,
const cudaArray *referenceImageCuda,
const float *warpedImageCuda,
const float4 *warpedGradientCuda,
const float *logJointHistogramCuda,
const double *jointHistogramLogCuda,
float4 *voxelBasedGradientCuda,
const int *maskCuda,
const size_t activeVoxelNumber,
const double *entropies,
const int refBinning,
const int floBinning) {
auto blockSize = CudaContext::GetBlockSize();
const int refBinNumber,
const int floBinNumber,
const int totalBinNumber,
const double timePointWeight,
const int currentTimePoint) {
const size_t voxelNumber = NiftiImage::calcVoxelNumber(referenceImage, 3);
const int3 imageSize = make_int3(referenceImage->nx, referenceImage->ny, referenceImage->nz);
const int binNumber = refBinning * floBinning + refBinning + floBinning;
const float normalisedJE = (float)(entropies[2] * entropies[3]);
const float nmi = (float)((entropies[0] + entropies[1]) / entropies[2]);
const double normalisedJE = entropies[2] * entropies[3];
const double nmi = (entropies[0] + entropies[1]) / entropies[2];
// Offsets into jointHistogramLogCuda: the first refBinNumber * floBinNumber entries
// are indexed as r + w * refBinNumber, followed by refBinNumber entries indexed by r,
// followed by entries indexed by w.
const int referenceOffset = refBinNumber * floBinNumber;
const int floatingOffset = referenceOffset + refBinNumber;

auto referenceImageTexture = Cuda::CreateTextureObject(referenceImageCuda, cudaResourceTypeArray, 0,
cudaChannelFormatKindNone, 1, cudaFilterModePoint, true);
auto warpedImageTexture = Cuda::CreateTextureObject(warpedImageCuda, cudaResourceTypeLinear, voxelNumber * sizeof(float),
cudaChannelFormatKindFloat, 1);
auto warpedGradientTexture = Cuda::CreateTextureObject(warpedGradientCuda, cudaResourceTypeLinear, voxelNumber * sizeof(float4),
cudaChannelFormatKindFloat, 4);
auto histogramTexture = Cuda::CreateTextureObject(logJointHistogramCuda, cudaResourceTypeLinear, binNumber * sizeof(float),
cudaChannelFormatKindFloat, 1);
auto maskTexture = Cuda::CreateTextureObject(maskCuda, cudaResourceTypeLinear, activeVoxelNumber * sizeof(int),
cudaChannelFormatKindSigned, 1);
auto referenceImageTexturePtr = Cuda::CreateTextureObject(referenceImageCuda, cudaResourceTypeArray, 0,
cudaChannelFormatKindNone, 1, cudaFilterModePoint, true);
// Only the warped image is offset by the time point; the warped gradient buffer is not.
// NOTE(review): confirm the gradient buffer holds a single time point.
auto warpedImageTexturePtr = Cuda::CreateTextureObject(warpedImageCuda + currentTimePoint * voxelNumber, cudaResourceTypeLinear,
voxelNumber * sizeof(float), cudaChannelFormatKindFloat, 1);
auto warpedGradientTexturePtr = Cuda::CreateTextureObject(warpedGradientCuda, cudaResourceTypeLinear, voxelNumber * sizeof(float4),
cudaChannelFormatKindFloat, 4);
auto maskTexturePtr = Cuda::CreateTextureObject(maskCuda, cudaResourceTypeLinear, activeVoxelNumber * sizeof(int),
cudaChannelFormatKindSigned, 1);
// The texture handles are dereferenced into locals so that the [=] lambda below
// captures them by value.
// NOTE(review): presumably the *Ptr objects own the textures and release them when
// this function returns -- confirm against Cuda::CreateTextureObject.
auto referenceImageTexture = *referenceImageTexturePtr;
auto warpedImageTexture = *warpedImageTexturePtr;
auto warpedGradientTexture = *warpedGradientTexturePtr;
auto maskTexture = *maskTexturePtr;

if (referenceImage->nz > 1) {
const unsigned blocks = blockSize->reg_getVoxelBasedNmiGradientUsingPw3D;
const unsigned grids = (unsigned)Ceil(sqrtf((float)activeVoxelNumber / (float)blocks));
const dim3 gridDims(grids, grids, 1);
const dim3 blockDims(blocks, 1, 1);
reg_getVoxelBasedNmiGradientUsingPw3D_kernel<<<gridDims, blockDims>>>(voxelBasedGradientCuda, *referenceImageTexture, *warpedImageTexture,
*warpedGradientTexture, *histogramTexture, *maskTexture,
imageSize, refBinning, floBinning, normalisedJE, nmi,
(unsigned)activeVoxelNumber);
NR_CUDA_CHECK_KERNEL(gridDims, blockDims);
} else {
const unsigned blocks = blockSize->reg_getVoxelBasedNmiGradientUsingPw2D;
const unsigned grids = (unsigned)Ceil(sqrtf((float)activeVoxelNumber / (float)blocks));
const dim3 gridDims(grids, grids, 1);
const dim3 blockDims(blocks, 1, 1);
reg_getVoxelBasedNmiGradientUsingPw2D_kernel<<<gridDims, blockDims>>>(voxelBasedGradientCuda, *referenceImageTexture, *warpedImageTexture,
*warpedGradientTexture, *histogramTexture, *maskTexture,
imageSize, refBinning, floBinning, normalisedJE, nmi,
(unsigned)activeVoxelNumber);
NR_CUDA_CHECK_KERNEL(gridDims, blockDims);
}
// One invocation of the lambda per index in [0, activeVoxelNumber).
thrust::for_each_n(thrust::device, thrust::make_counting_iterator<unsigned>(0), activeVoxelNumber, [=]__device__(const unsigned index) {
// The mask maps the loop index to the voxel index used for the image buffers.
const int targetIndex = tex1Dfetch<int>(maskTexture, index);
const float warpedImageValue = tex1Dfetch<float>(warpedImageTexture, targetIndex);
// x != x holds only for NaN: skip voxels whose warped value is NaN.
if (warpedImageValue != warpedImageValue) return;
const auto&& [x, y, z] = reg_indexToDims_cuda<is3d>(targetIndex, imageSize);
// Sample the reference at the voxel centre; in 2D the z coordinate is fixed to 0.5f.
const float referenceImageValue = tex3D<float>(referenceImageTexture,
(float(x) + 0.5f) / float(imageSize.x),
(float(y) + 0.5f) / float(imageSize.y),
is3d ? (float(z) + 0.5f) / float(imageSize.z) : 0.5f);
// Skip voxels whose reference value is NaN.
if (referenceImageValue != referenceImageValue) return;
// NOTE(review): the gradient is fetched with `index` while the warped image above is
// fetched with `targetIndex` -- confirm the gradient buffer is laid out per active voxel.
const float4& warpedGradValue = tex1Dfetch<float4>(warpedGradientTexture, index);
float4 gradValue = voxelBasedGradientCuda[targetIndex];

// No computation is performed if any of the points is part of the background.
// Two is added because the image is resampled between 2 and bin+2:
// if 64 bins are used, the histogram will have 68 bins and the image will be between 2 and 65.
typename Derivative<is3d>::Type jointDeriv{}, refDeriv{}, warDeriv{};
// Visit the four reference bins and four warped bins around the truncated intensities,
// ignoring bins outside [0, refBinNumber) and [0, floBinNumber).
for (int r = (int)referenceImageValue - 1; r < (int)referenceImageValue + 3; ++r) {
if (-1 < r && r < refBinNumber) {
for (int w = (int)warpedImageValue - 1; w < (int)warpedImageValue + 3; ++w) {
if (-1 < w && w < floBinNumber) {
// Spline value along the reference axis times spline derivative along the warped axis.
const double commonValue = (GetBasisSplineValue<double>(referenceImageValue - r) *
GetBasisSplineDerivativeValue<double>(warpedImageValue - w));
const double jointLog = jointHistogramLogCuda[r + w * refBinNumber];
const double refLog = jointHistogramLogCuda[r + referenceOffset];
const double warLog = jointHistogramLogCuda[w + floatingOffset];
// Each gradient component is accumulated only when it is not NaN.
if (warpedGradValue.x == warpedGradValue.x) {
const double commonMultGrad = commonValue * warpedGradValue.x;
jointDeriv.x += commonMultGrad * jointLog;
refDeriv.x += commonMultGrad * refLog;
warDeriv.x += commonMultGrad * warLog;
}
if (warpedGradValue.y == warpedGradValue.y) {
const double commonMultGrad = commonValue * warpedGradValue.y;
jointDeriv.y += commonMultGrad * jointLog;
refDeriv.y += commonMultGrad * refLog;
warDeriv.y += commonMultGrad * warLog;
}
if constexpr (is3d) {
if (warpedGradValue.z == warpedGradValue.z) {
const double commonMultGrad = commonValue * warpedGradValue.z;
jointDeriv.z += commonMultGrad * jointLog;
refDeriv.z += commonMultGrad * refLog;
warDeriv.z += commonMultGrad * warLog;
}
}
}
}
}
}

// (Marc) I removed the normalisation by the voxel number as each gradient has to be normalised in the same way
// The contribution is added to the existing gradient value; the w component is left untouched.
gradValue.x += static_cast<float>(timePointWeight * (refDeriv.x + warDeriv.x - nmi * jointDeriv.x) / normalisedJE);
gradValue.y += static_cast<float>(timePointWeight * (refDeriv.y + warDeriv.y - nmi * jointDeriv.y) / normalisedJE);
if constexpr (is3d)
gradValue.z += static_cast<float>(timePointWeight * (refDeriv.z + warDeriv.z - nmi * jointDeriv.z) / normalisedJE);
voxelBasedGradientCuda[targetIndex] = gradValue;
});
}
/* *************************************************************** */
// Forward voxel-based NMI gradient for the given time point: refreshes the similarity
// measure, then runs reg_getVoxelBasedNmiGradient_gpu on the reference-space buffers.
//
// NOTE(review): this span reads like a rendered diff -- the device_vector<float> copy and
// the first, [0]-indexed call look like the removed lines shown next to their replacement
// (the call through getVoxelBasedNmiGradient). Confirm which version is live.
void reg_nmi_gpu::GetVoxelBasedSimilarityMeasureGradientFw(int currentTimePoint) {
// Call compute similarity measure to calculate joint histogram
this->GetSimilarityMeasureValue();

// The latest joint histogram is transferred onto the GPU
thrust::device_vector<float> jointHistogramLogCuda(this->jointHistogramLog[0], this->jointHistogramLog[0] + this->totalBinNumber[0]);

// The gradient of the NMI is computed on the GPU
reg_getVoxelBasedNmiGradient_gpu(this->referenceImage,
this->referenceImageCuda,
this->warpedImageCuda,
this->warpedGradientCuda,
jointHistogramLogCuda.data().get(),
this->voxelBasedGradientCuda,
this->referenceMaskCuda,
this->activeVoxelNumber,
this->entropyValues[0],
this->referenceBinNumber[0],
this->floatingBinNumber[0]);
// Pick the 3D or 2D instantiation from the reference image depth (nz > 1 means 3D).
auto getVoxelBasedNmiGradient = this->referenceImage->nz > 1 ? reg_getVoxelBasedNmiGradient_gpu<true> : reg_getVoxelBasedNmiGradient_gpu<false>;
// Histogram logs, entropies, bin counts and weight are all taken at currentTimePoint.
getVoxelBasedNmiGradient(this->referenceImage,
this->referenceImageCuda,
this->warpedImageCuda,
this->warpedGradientCuda,
this->jointHistogramLogCudaVecs[currentTimePoint].data().get(),
this->voxelBasedGradientCuda,
this->referenceMaskCuda,
this->activeVoxelNumber,
this->entropyValues[currentTimePoint],
this->referenceBinNumber[currentTimePoint],
this->floatingBinNumber[currentTimePoint],
this->totalBinNumber[currentTimePoint],
this->timePointWeights[currentTimePoint],
currentTimePoint);
}
/* *************************************************************** */
// Backward voxel-based NMI gradient for the given time point: runs
// reg_getVoxelBasedNmiGradient_gpu with the roles swapped, i.e. the floating image, its
// mask and the *Bw buffers are passed where the forward pass uses the reference ones, and
// floatingBinNumber / referenceBinNumber are passed as refBinNumber / floBinNumber.
// Unlike the forward pass, this function does not call GetSimilarityMeasureValue() itself.
// NOTE(review): presumably the backward histogram is expected to be up to date already --
// confirm against the caller.
//
// NOTE(review): this span reads like a rendered diff -- the device_vector<float> copy and
// the first, [0]-indexed call look like the removed lines shown next to their replacement
// (the call through getVoxelBasedNmiGradient). Confirm which version is live.
void reg_nmi_gpu::GetVoxelBasedSimilarityMeasureGradientBw(int currentTimePoint) {
// The latest joint histogram is transferred onto the GPU
thrust::device_vector<float> jointHistogramLogCudaBw(this->jointHistogramLogBw[0], this->jointHistogramLogBw[0] + this->totalBinNumber[0]);

// The gradient of the NMI is computed on the GPU
reg_getVoxelBasedNmiGradient_gpu(this->floatingImage,
this->floatingImageCuda,
this->warpedImageBwCuda,
this->warpedGradientBwCuda,
jointHistogramLogCudaBw.data().get(),
this->voxelBasedGradientBwCuda,
this->floatingMaskCuda,
this->activeVoxelNumber,
this->entropyValuesBw[0],
this->floatingBinNumber[0],
this->referenceBinNumber[0]);
// Pick the 3D or 2D instantiation from the floating image depth (nz > 1 means 3D).
auto getVoxelBasedNmiGradient = this->floatingImage->nz > 1 ? reg_getVoxelBasedNmiGradient_gpu<true> : reg_getVoxelBasedNmiGradient_gpu<false>;
// Histogram logs, entropies, bin counts and weight are all taken at currentTimePoint.
getVoxelBasedNmiGradient(this->floatingImage,
this->floatingImageCuda,
this->warpedImageBwCuda,
this->warpedGradientBwCuda,
this->jointHistogramLogBwCudaVecs[currentTimePoint].data().get(),
this->voxelBasedGradientBwCuda,
this->floatingMaskCuda,
this->activeVoxelNumber,
this->entropyValuesBw[currentTimePoint],
this->floatingBinNumber[currentTimePoint],
this->referenceBinNumber[currentTimePoint],
this->totalBinNumber[currentTimePoint],
this->timePointWeights[currentTimePoint],
currentTimePoint);
}
/* *************************************************************** */
Loading

0 comments on commit bc4c672

Please sign in to comment.