Skip to content

Commit

Permalink
Add symmetric scheme support for reg_nmi_gpu #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Jul 20, 2023
1 parent 4006362 commit 37c3370
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 199 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
283
284
21 changes: 15 additions & 6 deletions reg-lib/cuda/CudaMeasure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,28 @@ void CudaMeasure::Initialise(reg_measure& measure, F3dContent& con, F3dContent *
// TODO Implement symmetric scheme for CUDA measure types
reg_measure_gpu& measureGpu = dynamic_cast<reg_measure_gpu&>(measure);
CudaF3dContent& cudaCon = dynamic_cast<CudaF3dContent&>(con);
CudaF3dContent *cudaConBw = dynamic_cast<CudaF3dContent*>(conBw);
measureGpu.InitialiseMeasure(cudaCon.Content::GetReference(),
cudaCon.GetReferenceCuda(),
cudaCon.Content::GetFloating(),
cudaCon.GetFloatingCuda(),
cudaCon.Content::GetReferenceMask(),
cudaCon.GetReferenceMaskCuda(),
cudaCon.GetActiveVoxelNumber(),
cudaCon.Content::GetWarped(),
cudaCon.GetWarpedCuda(),
cudaCon.F3dContent::GetWarpedGradient(),
cudaCon.GetWarpedGradientCuda(),
cudaCon.F3dContent::GetVoxelBasedMeasureGradient(),
cudaCon.GetVoxelBasedMeasureGradientCuda(),
cudaCon.F3dContent::GetLocalWeightSim(),
cudaCon.GetReferenceCuda(),
cudaCon.GetFloatingCuda(),
cudaCon.GetReferenceMaskCuda(),
cudaCon.GetWarpedCuda(),
cudaCon.GetWarpedGradientCuda(),
cudaCon.GetVoxelBasedMeasureGradientCuda());
cudaConBw ? cudaConBw->Content::GetReferenceMask() : nullptr,
cudaConBw ? cudaConBw->GetReferenceMaskCuda() : nullptr,
cudaConBw ? cudaConBw->Content::GetWarped() : nullptr,
cudaConBw ? cudaConBw->GetWarpedCuda() : nullptr,
cudaConBw ? cudaConBw->F3dContent::GetWarpedGradient() : nullptr,
cudaConBw ? cudaConBw->GetWarpedGradientCuda() : nullptr,
cudaConBw ? cudaConBw->F3dContent::GetVoxelBasedMeasureGradient() : nullptr,
cudaConBw ? cudaConBw->GetVoxelBasedMeasureGradientCuda() : nullptr);
}
/* *************************************************************** */
135 changes: 105 additions & 30 deletions reg-lib/cuda/_reg_measure_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,63 @@ class reg_measure_gpu {
virtual ~reg_measure_gpu() {}

virtual void InitialiseMeasure(nifti_image *refImg,
cudaArray *refImgCuda,
nifti_image *floImg,
cudaArray *floImgCuda,
int *refMask,
int *refMaskCuda,
size_t activeVoxNum,
nifti_image *warpedImg,
nifti_image *warpedGrad,
nifti_image *voxelBasedGrad,
nifti_image *localWeightSim,
cudaArray *refImgCuda,
cudaArray *floImgCuda,
int *refMaskCuda,
float *warpedImgCuda,
nifti_image *warpedGrad,
float4 *warpedGradCuda,
float4 *voxelBasedGradCuda) = 0;
nifti_image *voxelBasedGrad,
float4 *voxelBasedGradCuda,
nifti_image *localWeightSim = nullptr,
int *floMask = nullptr,
int *floMaskCuda = nullptr,
nifti_image *warpedImgBw = nullptr,
float *warpedImgBwCuda = nullptr,
nifti_image *warpedGradBw = nullptr,
float4 *warpedGradBwCuda = nullptr,
nifti_image *voxelBasedGradBw = nullptr,
float4 *voxelBasedGradBwCuda = nullptr) {
// Check that the input image are of type float
if (refImg->datatype != NIFTI_TYPE_FLOAT32 || warpedImg->datatype != NIFTI_TYPE_FLOAT32) {
reg_print_fct_error("reg_measure_gpu::InitialiseMeasure");
reg_print_msg_error("Only single precision is supported on the GPU");
reg_exit();
}
// Bind the required pointers
this->referenceImageCuda = refImgCuda;
this->floatingImageCuda = floImgCuda;
this->referenceMaskCuda = refMaskCuda;
this->activeVoxelNumber = activeVoxNum;
this->warpedImageCuda = warpedImgCuda;
this->warpedGradientCuda = warpedGradCuda;
this->voxelBasedGradientCuda = voxelBasedGradCuda;
// Check if the symmetric mode is used
if (floMask != nullptr && warpedImgBw != nullptr && warpedGradBw != nullptr && voxelBasedGradBw != nullptr &&
floMaskCuda != nullptr && warpedImgBwCuda != nullptr && warpedGradBwCuda != nullptr && voxelBasedGradBwCuda != nullptr) {
if (floImg->datatype != NIFTI_TYPE_FLOAT32 || warpedImgBw->datatype != NIFTI_TYPE_FLOAT32) {
reg_print_fct_error("reg_measure_gpu::InitialiseMeasure");
reg_print_msg_error("Only single precision is supported on the GPU");
reg_exit();
}
this->floatingMaskCuda = floMaskCuda;
this->warpedImageBwCuda = warpedImgBwCuda;
this->warpedGradientBwCuda = warpedGradBwCuda;
this->voxelBasedGradientBwCuda = voxelBasedGradBwCuda;
} else {
this->floatingMaskCuda = nullptr;
this->warpedImageBwCuda = nullptr;
this->warpedGradientBwCuda = nullptr;
this->voxelBasedGradientBwCuda = nullptr;
}
#ifndef NDEBUG
reg_print_msg_debug("reg_measure_gpu::InitialiseMeasure() called");
#endif
}

protected:
cudaArray *referenceImageCuda;
Expand All @@ -44,6 +88,11 @@ class reg_measure_gpu {
float *warpedImageCuda;
float4 *warpedGradientCuda;
float4 *voxelBasedGradientCuda;

int *floatingMaskCuda;
float *warpedImageBwCuda;
float4 *warpedGradientBwCuda;
float4 *voxelBasedGradientBwCuda;
};
/* *************************************************************** */
class reg_lncc_gpu: public reg_lncc, public reg_measure_gpu {
Expand All @@ -57,19 +106,27 @@ class reg_lncc_gpu: public reg_lncc, public reg_measure_gpu {
virtual ~reg_lncc_gpu() {}

virtual void InitialiseMeasure(nifti_image *refImg,
cudaArray *refImgCuda,
nifti_image *floImg,
cudaArray *floImgCuda,
int *refMask,
int *refMaskCuda,
size_t activeVoxNum,
nifti_image *warpedImg,
nifti_image *warpedGrad,
nifti_image *voxelBasedGrad,
nifti_image *localWeightSim,
cudaArray *refImgCuda,
cudaArray *floImgCuda,
int *refMaskCuda,
float *warpedImgCuda,
nifti_image *warpedGrad,
float4 *warpedGradCuda,
float4 *voxelBasedGradCuda) override {}
nifti_image *voxelBasedGrad,
float4 *voxelBasedGradCuda,
nifti_image *localWeightSim = nullptr,
int *floMask = nullptr,
int *floMaskCuda = nullptr,
nifti_image *warpedImgBw = nullptr,
float *warpedImgBwCuda = nullptr,
nifti_image *warpedGradBw = nullptr,
float4 *warpedGradBwCuda = nullptr,
nifti_image *voxelBasedGradBw = nullptr,
float4 *voxelBasedGradBwCuda = nullptr) override {}
/// @brief Returns the lncc value
virtual double GetSimilarityMeasureValue() override { return 0; }
/// @brief Compute the voxel based lncc gradient
Expand All @@ -80,26 +137,35 @@ class reg_kld_gpu: public reg_kld, public reg_measure_gpu {
public:
/// @brief reg_kld_gpu class constructor
reg_kld_gpu() {
fprintf(stderr, "[ERROR] CUDA CANNOT BE USED WITH KLD YET\n");
reg_print_fct_error("reg_kld_gpu::reg_kld_gpu");
reg_print_msg_error("CUDA CANNOT BE USED WITH KLD YET");
reg_exit();
}
/// @brief reg_kld_gpu class destructor
virtual ~reg_kld_gpu() {}

virtual void InitialiseMeasure(nifti_image *refImg,
cudaArray *refImgCuda,
nifti_image *floImg,
cudaArray *floImgCuda,
int *refMask,
int *refMaskCuda,
size_t activeVoxNum,
nifti_image *warpedImg,
nifti_image *warpedGrad,
nifti_image *voxelBasedGrad,
nifti_image *localWeightSim,
cudaArray *refImgCuda,
cudaArray *floImgCuda,
int *refMaskCuda,
float *warpedImgCuda,
nifti_image *warpedGrad,
float4 *warpedGradCuda,
float4 *voxelBasedGradCuda) override {}
nifti_image *voxelBasedGrad,
float4 *voxelBasedGradCuda,
nifti_image *localWeightSim = nullptr,
int *floMask = nullptr,
int *floMaskCuda = nullptr,
nifti_image *warpedImgBw = nullptr,
float *warpedImgBwCuda = nullptr,
nifti_image *warpedGradBw = nullptr,
float4 *warpedGradBwCuda = nullptr,
nifti_image *voxelBasedGradBw = nullptr,
float4 *voxelBasedGradBwCuda = nullptr) override {}
/// @brief Returns the kld value
virtual double GetSimilarityMeasureValue() override { return 0; }
/// @brief Compute the voxel based kld gradient
Expand All @@ -110,26 +176,35 @@ class reg_dti_gpu: public reg_dti, public reg_measure_gpu {
public:
/// @brief reg_dti_gpu class constructor
reg_dti_gpu() {
fprintf(stderr, "[ERROR] CUDA CANNOT BE USED WITH DTI YET\n");
reg_print_fct_error("reg_dti_gpu::reg_dti_gpu");
reg_print_msg_error("CUDA CANNOT BE USED WITH DTI YET");
reg_exit();
}
/// @brief reg_dti_gpu class destructor
virtual ~reg_dti_gpu() {}

virtual void InitialiseMeasure(nifti_image *refImg,
cudaArray *refImgCuda,
nifti_image *floImg,
cudaArray *floImgCuda,
int *refMask,
int *refMaskCuda,
size_t activeVoxNum,
nifti_image *warpedImg,
nifti_image *warpedGrad,
nifti_image *voxelBasedGrad,
nifti_image *localWeightSim,
cudaArray *refImgCuda,
cudaArray *floImgCuda,
int *refMaskCuda,
float *warpedImgCuda,
nifti_image *warpedGrad,
float4 *warpedGradCuda,
float4 *voxelBasedGradCuda) override {}
nifti_image *voxelBasedGrad,
float4 *voxelBasedGradCuda,
nifti_image *localWeightSim = nullptr,
int *floMask = nullptr,
int *floMaskCuda = nullptr,
nifti_image *warpedImgBw = nullptr,
float *warpedImgBwCuda = nullptr,
nifti_image *warpedGradBw = nullptr,
float4 *warpedGradBwCuda = nullptr,
nifti_image *voxelBasedGradBw = nullptr,
float4 *voxelBasedGradBwCuda = nullptr) override {}
/// @brief Returns the dti value
virtual double GetSimilarityMeasureValue() override { return 0; }
/// @brief Compute the voxel based dti gradient
Expand Down
Loading

0 comments on commit 37c3370

Please sign in to comment.