Skip to content

Commit

Permalink
Update compose deformation field test to include CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Oct 12, 2023
1 parent b6d5097 commit a8f1232
Show file tree
Hide file tree
Showing 14 changed files with 251 additions and 251 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
341
342
4 changes: 4 additions & 0 deletions reg-lib/Compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,7 @@ void Compute::SymmetriseVelocityFields(Content& conBwIn) {
nifti_image_free(warpedTransBw);
}
/* *************************************************************** */
void Compute::DefFieldCompose(const nifti_image *defField) {
reg_defField_compose(defField, con.GetDeformationField(), nullptr);
}
/* *************************************************************** */
1 change: 1 addition & 0 deletions reg-lib/Compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Compute {
#ifdef NR_TESTING
public:
#endif
virtual void DefFieldCompose(const nifti_image *defField);
virtual void VoxelCentricToNodeCentric(float weight);

private:
Expand Down
104 changes: 42 additions & 62 deletions reg-lib/cpu/_reg_localTrans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2258,25 +2258,25 @@ void reg_spline_refineControlPointGrid(nifti_image *controlPointGrid,
}
/* *************************************************************** */
template <class DataType>
void reg_defField_compose2D(nifti_image *deformationField,
void reg_defField_compose2D(const nifti_image *deformationField,
nifti_image *dfToUpdate,
int *mask) {
const size_t DFVoxelNumber = NiftiImage::calcVoxelNumber(deformationField, 2);
const int *mask) {
const size_t dfVoxelNumber = NiftiImage::calcVoxelNumber(deformationField, 2);
#ifdef _WIN32
long i;
const long warVoxelNumber = (long)NiftiImage::calcVoxelNumber(dfToUpdate, 2);
#else
size_t i;
const size_t warVoxelNumber = NiftiImage::calcVoxelNumber(dfToUpdate, 2);
#endif
DataType *defPtrX = static_cast<DataType*>(deformationField->data);
DataType *defPtrY = &defPtrX[DFVoxelNumber];
const DataType *defPtrX = static_cast<DataType*>(deformationField->data);
const DataType *defPtrY = &defPtrX[dfVoxelNumber];

DataType *resPtrX = static_cast<DataType*>(dfToUpdate->data);
DataType *resPtrY = &resPtrX[warVoxelNumber];

const mat44 *df_real2Voxel;
mat44 *df_voxel2Real;
const mat44 *df_voxel2Real;
if (deformationField->sform_code > 0) {
df_real2Voxel = &dfToUpdate->sto_ijk;
df_voxel2Real = &deformationField->sto_xyz;
Expand All @@ -2302,12 +2302,14 @@ void reg_defField_compose2D(nifti_image *deformationField,
realDefY = resPtrY[i];

// Conversion from real to voxel in the deformation field
voxelX = realDefX * df_real2Voxel->m[0][0]
+ realDefY * df_real2Voxel->m[0][1]
+ df_real2Voxel->m[0][3];
voxelY = realDefX * df_real2Voxel->m[1][0]
+ realDefY * df_real2Voxel->m[1][1]
+ df_real2Voxel->m[1][3];
voxelX =
realDefX * df_real2Voxel->m[0][0] +
realDefY * df_real2Voxel->m[0][1] +
df_real2Voxel->m[0][3];
voxelY =
realDefX * df_real2Voxel->m[1][0] +
realDefY * df_real2Voxel->m[1][1] +
df_real2Voxel->m[1][3];

// Linear interpolation to compute the new deformation
pre[0] = Floor(voxelX);
Expand All @@ -2316,12 +2318,12 @@ void reg_defField_compose2D(nifti_image *deformationField,
relX[0] = 1.f - relX[1];
relY[1] = voxelY - static_cast<DataType>(pre[1]);
relY[0] = 1.f - relY[1];
realDefX = realDefY = 0.f;
realDefX = realDefY = 0;
for (b = 0; b < 2; ++b) {
for (a = 0; a < 2; ++a) {
basis = relX[a] * relY[b];
if (pre[0] + a > -1 && pre[0] + a<deformationField->nx &&
pre[1] + b>-1 && pre[1] + b < deformationField->ny) {
if (pre[0] + a > -1 && pre[0] + a < deformationField->nx &&
pre[1] + b > -1 && pre[1] + b < deformationField->ny) {
// Uses the deformation field if voxel is in its space
index = (pre[1] + b) * deformationField->nx + pre[0] + a;
defX = defPtrX[index];
Expand Down Expand Up @@ -2349,22 +2351,20 @@ void reg_defField_compose2D(nifti_image *deformationField,
}
/* *************************************************************** */
template <class DataType>
void reg_defField_compose3D(nifti_image *deformationField,
void reg_defField_compose3D(const nifti_image *deformationField,
nifti_image *dfToUpdate,
int *mask) {
const int DefFieldDim[3] = { deformationField->nx, deformationField->ny, deformationField->nz };
const size_t DFVoxelNumber = (size_t)DefFieldDim[0] * DefFieldDim[1] * DefFieldDim[2];
const int *mask) {
const size_t dfVoxelNumber = NiftiImage::calcVoxelNumber(deformationField, 3);
#ifdef _WIN32
long i;
const long warVoxelNumber = (long)NiftiImage::calcVoxelNumber(dfToUpdate, 3);
#else
size_t i;
const size_t warVoxelNumber = NiftiImage::calcVoxelNumber(dfToUpdate, 3);
#endif

DataType *defPtrX = static_cast<DataType*>(deformationField->data);
DataType *defPtrY = &defPtrX[DFVoxelNumber];
DataType *defPtrZ = &defPtrY[DFVoxelNumber];
const DataType *defPtrX = static_cast<DataType*>(deformationField->data);
const DataType *defPtrY = &defPtrX[dfVoxelNumber];
const DataType *defPtrZ = &defPtrY[dfVoxelNumber];

DataType *resPtrX = static_cast<DataType*>(dfToUpdate->data);
DataType *resPtrY = &resPtrX[warVoxelNumber];
Expand All @@ -2375,7 +2375,7 @@ void reg_defField_compose3D(nifti_image *deformationField,
#else
mat44 df_real2Voxel __attribute__((aligned(16)));
#endif
mat44 *df_voxel2Real;
const mat44 *df_voxel2Real;
if (deformationField->sform_code > 0) {
df_real2Voxel = deformationField->sto_ijk;
df_voxel2Real = &deformationField->sto_xyz;
Expand All @@ -2391,7 +2391,7 @@ void reg_defField_compose3D(nifti_image *deformationField,
bool inY, inZ;
#ifdef _OPENMP
#pragma omp parallel for default(none) \
shared(warVoxelNumber, mask, df_real2Voxel, df_voxel2Real, DefFieldDim, \
shared(warVoxelNumber, mask, df_real2Voxel, df_voxel2Real, \
defPtrX, defPtrY, defPtrZ, resPtrX, resPtrY, resPtrZ, deformationField) \
private(a, b, c, currentX, currentY, currentZ, index, tempIndex, pre, \
realDef, voxel, tempBasis, defX, defY, defZ, relX, relY, relZ, basis, inY, inZ)
Expand Down Expand Up @@ -2429,21 +2429,21 @@ void reg_defField_compose3D(nifti_image *deformationField,
relY[0] = 1.f - relY[1];
relZ[1] = voxel[2] - static_cast<DataType>(pre[2]);
relZ[0] = 1.f - relZ[1];
realDef[0] = realDef[1] = realDef[2] = 0.;
realDef[0] = realDef[1] = realDef[2] = 0;
for (c = 0; c < 2; ++c) {
currentZ = pre[2] + c;
tempIndex = currentZ * DefFieldDim[0] * DefFieldDim[1];
if (currentZ > -1 && currentZ < DefFieldDim[2]) inZ = true;
tempIndex = currentZ * deformationField->nx * deformationField->ny;
if (currentZ > -1 && currentZ < deformationField->nz) inZ = true;
else inZ = false;
for (b = 0; b < 2; ++b) {
currentY = pre[1] + b;
index = tempIndex + currentY * DefFieldDim[0] + pre[0];
index = tempIndex + currentY * deformationField->nx + pre[0];
tempBasis = relY[b] * relZ[c];
if (currentY > -1 && currentY < DefFieldDim[1]) inY = true;
if (currentY > -1 && currentY < deformationField->ny) inY = true;
else inY = false;
for (a = 0; a < 2; ++a) {
currentX = pre[0] + a;
if (currentX > -1 && currentX < DefFieldDim[0] && inY && inZ) {
if (currentX > -1 && currentX < deformationField->nx && inY && inZ) {
// Uses the deformation field if voxel is in its space
defX = defPtrX[index];
defY = defPtrY[index];
Expand Down Expand Up @@ -2478,43 +2478,23 @@ void reg_defField_compose3D(nifti_image *deformationField,
}// loop over every voxel
}
/* *************************************************************** */
void reg_defField_compose(nifti_image *deformationField,
void reg_defField_compose(const nifti_image *deformationField,
nifti_image *dfToUpdate,
int *mask) {
const int *mask) {
if (deformationField->datatype != dfToUpdate->datatype)
NR_FATAL_ERROR("Both deformation fields are expected to have the same type");

bool freeMask = false;
if (mask == nullptr) {
mask = (int*)calloc(NiftiImage::calcVoxelNumber(dfToUpdate, 3), sizeof(int));
freeMask = true;
}

if (dfToUpdate->nu == 2) {
switch (deformationField->datatype) {
case NIFTI_TYPE_FLOAT32:
reg_defField_compose2D<float>(deformationField, dfToUpdate, mask);
break;
case NIFTI_TYPE_FLOAT64:
reg_defField_compose2D<double>(deformationField, dfToUpdate, mask);
break;
default:
NR_FATAL_ERROR("Deformation field pixel type is unsupported");
}
} else {
switch (deformationField->datatype) {
case NIFTI_TYPE_FLOAT32:
reg_defField_compose3D<float>(deformationField, dfToUpdate, mask);
break;
case NIFTI_TYPE_FLOAT64:
reg_defField_compose3D<double>(deformationField, dfToUpdate, mask);
break;
default:
NR_FATAL_ERROR("Deformation field pixel type is unsupported");
}
unique_ptr<int[]> currentMask;
if (!mask) {
currentMask.reset(new int[NiftiImage::calcVoxelNumber(dfToUpdate, 3)]());
mask = currentMask.get();
}

if (freeMask) free(mask);
std::visit([&](auto&& defFieldDataType) {
using DefFieldDataType = std::decay_t<decltype(defFieldDataType)>;
auto defFieldCompose = dfToUpdate->nu == 2 ? reg_defField_compose2D<DefFieldDataType> : reg_defField_compose3D<DefFieldDataType>;
defFieldCompose(deformationField, dfToUpdate, mask);
}, NiftiImage::getFloatingDataType(deformationField));
}
/* *************************************************************** */
/// @brief Internal data structure to pass user data into optimizer that get passed to cost_function
Expand Down
6 changes: 3 additions & 3 deletions reg-lib/cpu/_reg_localTrans.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ int reg_spline_cppComposition(nifti_image *grid1,
* @param dfToUpdate Image that contains the deformation field that
* is being updated
* @param mask Mask overlaid on the dfToUpdate field where only voxel
* within the mask will be updated. All positive values in the maks
* within the mask will be updated. All positive values in the mask
* are considered as belonging to the mask.
*/
void reg_defField_compose(nifti_image *deformationField,
void reg_defField_compose(const nifti_image *deformationField,
nifti_image *dfToUpdate,
int *mask);
const int *mask);
/* *************************************************************** */
/** @brief Compute the inverse of a deformation field
* @author Marcel van Herk (CMIC / NKI / AVL)
Expand Down
40 changes: 20 additions & 20 deletions reg-lib/cpu/_reg_splineBasis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,13 +460,13 @@ template void set_second_order_bspline_basis_values<double>(double*, double*, do
template <class DataType>
void get_SlidedValues(DataType& defX,
DataType& defY,
int X,
int Y,
DataType *defPtrX,
DataType *defPtrY,
mat44 *df_voxel2Real,
int *dim,
bool displacement) {
const int X,
const int Y,
const DataType *defPtrX,
const DataType *defPtrY,
const mat44 *df_voxel2Real,
const int *dim,
const bool displacement) {
int newX = X;
int newY = Y;
if (X < 0) {
Expand All @@ -493,22 +493,22 @@ void get_SlidedValues(DataType& defX,
defX = defPtrX[index] + shiftValueX;
defY = defPtrY[index] + shiftValueY;
}
template void get_SlidedValues<float>(float&, float&, int, int, float*, float*, mat44*, int*, bool);
template void get_SlidedValues<double>(double&, double&, int, int, double*, double*, mat44*, int*, bool);
template void get_SlidedValues<float>(float&, float&, const int, const int, const float*, const float*, const mat44*, const int*, const bool);
template void get_SlidedValues<double>(double&, double&, const int, const int, const double*, const double*, const mat44*, const int*, const bool);
/* *************************************************************** */
template <class DataType>
void get_SlidedValues(DataType& defX,
DataType& defY,
DataType& defZ,
int X,
int Y,
int Z,
DataType *defPtrX,
DataType *defPtrY,
DataType *defPtrZ,
mat44 *df_voxel2Real,
int *dim,
bool displacement) {
const int X,
const int Y,
const int Z,
const DataType *defPtrX,
const DataType *defPtrY,
const DataType *defPtrZ,
const mat44 *df_voxel2Real,
const int *dim,
const bool displacement) {
int newX = X;
int newY = Y;
int newZ = Z;
Expand Down Expand Up @@ -552,8 +552,8 @@ void get_SlidedValues(DataType& defX,
defY = defPtrY[index] + shiftValueY;
defZ = defPtrZ[index] + shiftValueZ;
}
template void get_SlidedValues<float>(float&, float&, float&, int, int, int, float*, float*, float*, mat44*, int*, bool);
template void get_SlidedValues<double>(double&, double&, double&, int, int, int, double*, double*, double*, mat44*, int*, bool);
template void get_SlidedValues<float>(float&, float&, float&, const int, const int, const int, const float*, const float*, const float*, const mat44*, const int*, const bool);
template void get_SlidedValues<double>(double&, double&, double&, const int, const int, const int, const double*, const double*, const double*, const mat44*, const int*, const bool);
/* *************************************************************** */
template <class DataType>
void get_GridValues(int startX,
Expand Down
32 changes: 16 additions & 16 deletions reg-lib/cpu/_reg_splineBasis.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,26 @@ void get_SplineBasisValues(DataType basis,
template <class DataType>
void get_SlidedValues(DataType &defX,
DataType &defY,
int X,
int Y,
DataType *defPtrX,
DataType *defPtrY,
mat44 *df_voxel2Real,
int *dim,
bool displacement);
const int X,
const int Y,
const DataType *defPtrX,
const DataType *defPtrY,
const mat44 *df_voxel2Real,
const int *dim,
const bool displacement);
template <class DataType>
void get_SlidedValues(DataType &defX,
DataType &defY,
DataType &defZ,
int X,
int Y,
int Z,
DataType *defPtrX,
DataType *defPtrY,
DataType *defPtrZ,
mat44 *df_voxel2Real,
int *dim,
bool displacement);
const int X,
const int Y,
const int Z,
const DataType *defPtrX,
const DataType *defPtrY,
const DataType *defPtrZ,
const mat44 *df_voxel2Real,
const int *dim,
const bool displacement);


template <class DataType>
Expand Down
2 changes: 1 addition & 1 deletion reg-lib/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ set(NIFTYREG_LIBRARIES "${NIFTYREG_LIBRARIES};${NAME}")
set(NAME _reg_cuda_kernels)
cuda_add_library(${NAME} ${NIFTYREG_LIBRARY_TYPE}
CudaAladinContent.cpp
CudaCompute.cpp
CudaCompute.cu
CudaContent.cpp
CudaContext.cpp
CudaDefContent.cpp
Expand Down
8 changes: 8 additions & 0 deletions reg-lib/cuda/CudaCompute.cpp → reg-lib/cuda/CudaCompute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,11 @@ void CudaCompute::SymmetriseVelocityFields(Content& conBwIn) {
dynamic_cast<CudaF3dContent&>(conBwIn).UpdateControlPointGrid();
}
/* *************************************************************** */
void CudaCompute::DefFieldCompose(const nifti_image *defField) {
CudaContent& con = dynamic_cast<CudaContent&>(this->con);
const size_t& voxelNumber = NiftiImage::calcVoxelNumber(defField, 3);
thrust::device_vector<float4> defFieldCuda(voxelNumber);
Cuda::TransferNiftiToDevice(defFieldCuda.data().get(), defField);
reg_defField_compose_gpu(defField, defFieldCuda.data().get(), con.GetDeformationFieldCuda());
}
/* *************************************************************** */
1 change: 1 addition & 0 deletions reg-lib/cuda/CudaCompute.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class CudaCompute: public Compute {
#ifndef NR_TESTING
protected:
#endif
virtual void DefFieldCompose(const nifti_image *defField) override;
virtual void VoxelCentricToNodeCentric(float weight) override;

private:
Expand Down
7 changes: 3 additions & 4 deletions reg-lib/cuda/_reg_localTransformation_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -541,15 +541,14 @@ void reg_spline_getFlowFieldFromVelocityGrid_gpu(nifti_image *velocityFieldGrid,
/* *************************************************************** */
void reg_defField_compose_gpu(const nifti_image *deformationField,
const float4 *deformationFieldCuda,
float4 *deformationFieldCudaOut,
const size_t activeVoxelNumber) {
float4 *deformationFieldCudaOut) {
auto blockSize = CudaContext::GetBlockSize();
const size_t voxelNumber = NiftiImage::calcVoxelNumber(deformationField, 3);
const int3 referenceImageDim{ deformationField->nx, deformationField->ny, deformationField->nz };
const mat44& affineMatrixB = deformationField->sform_code > 0 ? deformationField->sto_ijk : deformationField->qto_ijk;
const mat44& affineMatrixC = deformationField->sform_code > 0 ? deformationField->sto_xyz : deformationField->qto_xyz;
auto deformationFieldTexture = Cuda::CreateTextureObject(deformationFieldCuda, cudaResourceTypeLinear,
activeVoxelNumber * sizeof(float4), cudaChannelFormatKindFloat, 4);
voxelNumber * sizeof(float4), cudaChannelFormatKindFloat, 4);

if (deformationField->nz > 1) {
const unsigned blocks = blockSize->reg_defField_compose3D;
Expand Down Expand Up @@ -634,7 +633,7 @@ void reg_defField_getDeformationFieldFromFlowField_gpu(nifti_image *flowField,
// The deformation field is squared
for (int i = 0; i < squaringNumber; ++i) {
// The deformation field is applied to itself
reg_defField_compose_gpu(deformationField, deformationFieldCuda, flowFieldCuda, voxelNumber);
reg_defField_compose_gpu(deformationField, deformationFieldCuda, flowFieldCuda);
// The computed scaled deformation field is copied over
thrust::copy(thrust::device, flowFieldCuda, flowFieldCuda + voxelNumber, deformationFieldCuda);
NR_DEBUG("Squaring (composition) step " << i + 1 << "/" << squaringNumber);
Expand Down
Loading

0 comments on commit a8f1232

Please sign in to comment.