Skip to content

Commit

Permalink
Optimise Optimiser #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Nov 24, 2023
1 parent 8182839 commit 592d01d
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 231 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
369
370
12 changes: 0 additions & 12 deletions reg-lib/cuda/BlockSize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ struct BlockSize {
unsigned reg_defField_compose2D;
unsigned reg_defField_compose3D;
unsigned reg_defField_getJacobianMatrix;
unsigned reg_initialiseConjugateGradient;
unsigned reg_getConjugateGradient1;
unsigned reg_getConjugateGradient2;
unsigned reg_updateControlPointPosition;
unsigned reg_voxelCentricToNodeCentric;
unsigned reg_convertNmiGradientFromVoxelToRealSpace;
unsigned reg_ApplyConvolutionWindowAlongX;
Expand Down Expand Up @@ -68,10 +64,6 @@ struct BlockSize100: public BlockSize {
reg_defField_compose2D = 512; // 15 reg - 24 smem - 08 cmem - 16 lmem
reg_defField_compose3D = 384; // 21 reg - 24 smem - 08 cmem - 24 lmem
reg_defField_getJacobianMatrix = 512; // 16 reg - 24 smem - 04 cmem
reg_initialiseConjugateGradient = 384; // 09 reg - 24 smem
reg_getConjugateGradient1 = 320; // 12 reg - 24 smem
reg_getConjugateGradient2 = 384; // 10 reg - 40 smem
reg_updateControlPointPosition = 384; // 08 reg - 24 smem
reg_voxelCentricToNodeCentric = 320; // 11 reg - 24 smem - 16 cmem
reg_convertNmiGradientFromVoxelToRealSpace = 512; // 16 reg - 24 smem
reg_ApplyConvolutionWindowAlongX = 512; // 14 reg - 28 smem - 08 cmem
Expand Down Expand Up @@ -106,10 +98,6 @@ struct BlockSize300: public BlockSize {
reg_defField_compose2D = 1024; // 23 reg
reg_defField_compose3D = 1024; // 24 reg
reg_defField_getJacobianMatrix = 768; // 34 reg
reg_initialiseConjugateGradient = 1024; // 20 reg
reg_getConjugateGradient1 = 1024; // 22 reg
reg_getConjugateGradient2 = 1024; // 25 reg
reg_updateControlPointPosition = 1024; // 22 reg
reg_voxelCentricToNodeCentric = 1024; // 23 reg
reg_convertNmiGradientFromVoxelToRealSpace = 1024; // 23 reg
reg_ApplyConvolutionWindowAlongX = 1024; // 25 reg
Expand Down
63 changes: 55 additions & 8 deletions reg-lib/cuda/CudaCompute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,68 @@ void CudaCompute::GetDeformationField(bool composition, bool bspline) {
bspline);
}
/* *************************************************************** */
/**
 * Element-wise control point update with compile-time axis selection.
 *
 * For every index i in [0, nVoxels) the float4 at currentDofCuda[i] is loaded,
 * each enabled component is overwritten with best[i] + scale * gradient[i],
 * and the value is stored back. Components whose template flag is false, and
 * the .w component, are written back with the value that was loaded.
 *
 * @tparam optimiseX  When true, the .x component is updated.
 * @tparam optimiseY  When true, the .y component is updated.
 * @tparam optimiseZ  When true, the .z component is updated.
 * @param currentDofCuda   float4 array that is read and written at each index.
 * @param bestDofTexture   Texture object fetched as float4 with tex1Dfetch.
 * @param gradientTexture  Texture object fetched as float4 with tex1Dfetch.
 * @param nVoxels          Number of float4 elements to process.
 * @param scale            Factor applied to the gradient before it is added.
 *
 * NOTE(review): indices are generated as int (counting iterator seeded with 0,
 * lambda parameter is int), so nVoxels is presumably expected to fit in an
 * int - confirm against callers.
 */
template<bool optimiseX, bool optimiseY, bool optimiseZ>
inline void UpdateControlPointPosition(float4 *currentDofCuda,
cudaTextureObject_t bestDofTexture,
cudaTextureObject_t gradientTexture,
const size_t nVoxels,
const float scale) {
// Runs the lambda once per index under the thrust::device execution policy
thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), nVoxels, [=]__device__(const int index) {
// The bare `scale;` statement names scale outside of any `if constexpr`
// branch so that the by-value lambda captures it; presumably a workaround
// for nvcc's extended-lambda capture rules - TODO confirm before removing.
float4 dofValue = currentDofCuda[index]; scale; // To capture scale
const float4 bestValue = tex1Dfetch<float4>(bestDofTexture, index);
const float4 gradValue = tex1Dfetch<float4>(gradientTexture, index);
// Only the branches enabled at compile time are emitted
if constexpr (optimiseX)
dofValue.x = bestValue.x + scale * gradValue.x;
if constexpr (optimiseY)
dofValue.y = bestValue.y + scale * gradValue.y;
if constexpr (optimiseZ)
dofValue.z = bestValue.z + scale * gradValue.z;
// Store the (possibly partially) updated value back in place
currentDofCuda[index] = dofValue;
});
}
/* *************************************************************** */
/**
 * Turns the runtime optimiseZ flag into a compile-time template argument and
 * forwards to the fully specialised three-flag overload.
 *
 * @tparam optimiseX  Forwarded unchanged as the first template flag.
 * @tparam optimiseY  Forwarded unchanged as the second template flag.
 * @param currentDofCuda   Forwarded unchanged.
 * @param bestDofTexture   Forwarded unchanged.
 * @param gradientTexture  Forwarded unchanged.
 * @param nVoxels          Forwarded unchanged.
 * @param scale            Forwarded unchanged.
 * @param optimiseZ        Selects the <optimiseX, optimiseY, true> or
 *                         <optimiseX, optimiseY, false> instantiation.
 */
template<bool optimiseX, bool optimiseY>
static inline void UpdateControlPointPosition(float4 *currentDofCuda,
                                              cudaTextureObject_t bestDofTexture,
                                              cudaTextureObject_t gradientTexture,
                                              const size_t nVoxels,
                                              const float scale,
                                              const bool optimiseZ) {
    if (optimiseZ) {
        UpdateControlPointPosition<optimiseX, optimiseY, true>(currentDofCuda, bestDofTexture, gradientTexture,
                                                               nVoxels, scale);
    } else {
        UpdateControlPointPosition<optimiseX, optimiseY, false>(currentDofCuda, bestDofTexture, gradientTexture,
                                                                nVoxels, scale);
    }
}
/* *************************************************************** */
/**
 * Turns the runtime optimiseY flag into a compile-time template argument and
 * forwards to the two-flag overload, which in turn resolves optimiseZ.
 *
 * @tparam optimiseX  Forwarded unchanged as the first template flag.
 * @param currentDofCuda   Forwarded unchanged.
 * @param bestDofTexture   Forwarded unchanged.
 * @param gradientTexture  Forwarded unchanged.
 * @param nVoxels          Forwarded unchanged.
 * @param scale            Forwarded unchanged.
 * @param optimiseY        Selects the <optimiseX, true> or <optimiseX, false>
 *                         instantiation.
 * @param optimiseZ        Passed through as a runtime argument.
 */
template<bool optimiseX>
static inline void UpdateControlPointPosition(float4 *currentDofCuda,
                                              cudaTextureObject_t bestDofTexture,
                                              cudaTextureObject_t gradientTexture,
                                              const size_t nVoxels,
                                              const float scale,
                                              const bool optimiseY,
                                              const bool optimiseZ) {
    // Both candidates share the same signature, so a single pointer can hold either
    const auto dispatch = optimiseY ? UpdateControlPointPosition<optimiseX, true>
                                    : UpdateControlPointPosition<optimiseX, false>;
    dispatch(currentDofCuda, bestDofTexture, gradientTexture, nVoxels, scale, optimiseZ);
}
/* *************************************************************** */
/**
 * Updates the control point positions as best + scale * gradient along the
 * selected axes.
 *
 * The three float buffers are reinterpreted as float4 arrays; bestDof and
 * gradient are read through texture objects while currentDof is updated in
 * place. The element count comes from the control point grid of the content
 * object, and the z axis is only optimised when the grid has nz > 1.
 *
 * The update is issued exactly once, through the templated helpers in this
 * file; the former call to Cuda::UpdateControlPointPosition has been dropped
 * because it repeated the same update (and the grid lookup) before this path.
 *
 * @param currentDof  Buffer receiving the updated values (reinterpreted as float4).
 * @param bestDof     Buffer holding the reference values (reinterpreted as float4).
 * @param gradient    Buffer holding the gradient values (reinterpreted as float4).
 * @param scale       Factor applied to the gradient.
 * @param optimiseX   Whether the x component is updated.
 * @param optimiseY   Whether the y component is updated.
 * @param optimiseZ   Whether the z component is updated; ignored for 2D grids.
 *
 * NOTE(review): all three pointers are presumably device pointers - confirm
 * against callers.
 */
void CudaCompute::UpdateControlPointPosition(float *currentDof,
                                             const float *bestDof,
                                             const float *gradient,
                                             const float scale,
                                             const bool optimiseX,
                                             const bool optimiseY,
                                             const bool optimiseZ) {
    const nifti_image *controlPointGrid = dynamic_cast<CudaF3dContent&>(con).F3dContent::GetControlPointGrid();
    const bool is3d = controlPointGrid->nz > 1;
    const size_t nVoxels = NiftiImage::calcVoxelNumber(controlPointGrid, 3);

    // Texture handles are kept in named locals so they outlive the update call below
    auto bestDofTexturePtr = Cuda::CreateTextureObject(reinterpret_cast<const float4*>(bestDof), nVoxels, cudaChannelFormatKindFloat, 4);
    auto gradientTexturePtr = Cuda::CreateTextureObject(reinterpret_cast<const float4*>(gradient), nVoxels, cudaChannelFormatKindFloat, 4);

    // Resolve optimiseX at compile time; the helpers resolve optimiseY and optimiseZ in turn
    auto updateControlPointPosition = ::UpdateControlPointPosition<true>;
    if (!optimiseX) updateControlPointPosition = ::UpdateControlPointPosition<false>;
    updateControlPointPosition(reinterpret_cast<float4*>(currentDof), *bestDofTexturePtr, *gradientTexturePtr,
                               nVoxels, scale, optimiseY, is3d && optimiseZ);
}
/* *************************************************************** */
void CudaCompute::GetImageGradient(int interpolation, float paddingValue, int activeTimePoint) {
Expand Down
Loading

0 comments on commit 592d01d

Please sign in to comment.