@@ -753,7 +753,14 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
753753
754754class GenerateHeatmap (Transform ):
755755 """
756- Generate per-landmark gaussian response maps for 2D or 3D coordinates.
756+ Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.
757+
758+ Notes:
759+ - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
760+ - Output shape:
761+ - Non-batched points (N, S), S = number of spatial dims: (N, H, W) for 2D or (N, D, H, W) for 3D
762+ - Batched points (B, N, S): (B, N, H, W) for 2D or (B, N, D, H, W) for 3D
763+ - Each channel corresponds to one landmark.
757764
758765 Args:
759766 sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
@@ -829,11 +836,13 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
829836 continue
830837 region = heatmap [b_idx , idx ][window_slices ]
831838 gaussian = self ._evaluate_gaussian (coord_shifts , sigma )
832- torch .maximum (region , gaussian , out = region )
839+ updated = torch .maximum (region , gaussian )
840+ # write back
841+ region .copy_ (updated )
833842 if self .normalize :
834- max_val = heatmap [ b_idx , idx ] .max ()
835- if max_val .item () > 0 :
836- heatmap [b_idx , idx ] /= max_val
843+ peak = updated .max ()
844+ if peak .item () > 0 :
845+ heatmap [b_idx , idx ] /= peak
837846
838847 if not is_batched :
839848 heatmap = heatmap .squeeze (0 )
@@ -851,7 +860,9 @@ def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims:
851860 if len (shape_tuple ) == 1 :
852861 shape_tuple = shape_tuple * spatial_dims # type: ignore
853862 else :
854- raise ValueError ("spatial_shape length must match spatial dimension of the landmarks." )
863+ raise ValueError (
864+ "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
865+ )
855866 return tuple (int (s ) for s in shape_tuple )
856867
857868 def _resolve_sigma (self , spatial_dims : int ) -> tuple [float , ...]:
@@ -879,7 +890,7 @@ def _make_window(
879890 if start >= stop :
880891 return None , ()
881892 slices .append (slice (start , stop ))
882- coord_shifts .append (torch .arange (start , stop , device = device , dtype = self . torch_dtype ) - float (c ))
893+ coord_shifts .append (torch .arange (start , stop , device = device , dtype = torch . float32 ) - float (c ))
883894 return tuple (slices ), tuple (coord_shifts )
884895
885896 def _evaluate_gaussian (self , coord_shifts : tuple [torch .Tensor , ...], sigma : tuple [float , ...]) -> torch .Tensor :
@@ -897,13 +908,15 @@ def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tupl
897908 shape = tuple (len (axis ) for axis in coord_shifts )
898909 if 0 in shape :
899910 return torch .zeros (shape , dtype = self .torch_dtype , device = device )
900- exponent = torch .zeros (shape , dtype = self . torch_dtype , device = device )
911+ exponent = torch .zeros (shape , dtype = torch . float32 , device = device )
901912 for dim , (shift , sig ) in enumerate (zip (coord_shifts , sigma )):
902- scaled = (shift / float (sig )) ** 2
913+ shift32 = shift .to (torch .float32 )
914+ scaled = (shift32 / float (sig )) ** 2
903915 reshape_shape = [1 ] * len (coord_shifts )
904916 reshape_shape [dim ] = shift .numel ()
905917 exponent += scaled .reshape (reshape_shape )
906- return torch .exp (- 0.5 * exponent )
918+ gauss = torch .exp (- 0.5 * exponent )
919+ return gauss .to (dtype = self .torch_dtype )
907920
908921
909922class ProbNMS (Transform ):
0 commit comments