@@ -528,7 +528,7 @@ def __init__(
528528 heatmap_keys : KeysCollection | None = None ,
529529 ref_image_keys : KeysCollection | None = None ,
530530 spatial_shape : Sequence [int ] | Sequence [Sequence [int ]] | None = None ,
531- truncate : float = 3 .0 ,
531+ truncated : float = 4 .0 ,
532532 normalize : bool = True ,
533533 dtype : np .dtype | type = np .float32 ,
534534 allow_missing_keys : bool = False ,
@@ -540,7 +540,7 @@ def __init__(
540540 self .generator = GenerateHeatmap (
541541 sigma = sigma ,
542542 spatial_shape = None ,
543- truncate = truncate ,
543+ truncated = truncated ,
544544 normalize = normalize ,
545545 dtype = dtype ,
546546 )
@@ -632,11 +632,25 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
632632
633633 def _prepare_output (self , heatmap : NdarrayOrTensor , reference : Any ) -> Any :
634634 if isinstance (reference , MetaTensor ):
635- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = reference .dtype , device = reference .device )
636- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
635+ # Use heatmap's dtype (from generator), not reference's dtype
636+ converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
637+ # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial)
638+ if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
639+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
640+ elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641+ # Need to check if this is batched 2D or non-batched 3D
642+ if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
643+ # Non-batched 3D
644+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
645+ else :
646+ # Batched 2D
647+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
648+ else : # 2D non-batched: (C, H, W)
649+ converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
637650 return converted
638651 if isinstance (reference , torch .Tensor ):
639- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = reference .dtype , device = reference .device )
652+ # Use heatmap's dtype (from generator), not reference's dtype
653+ converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
640654 return converted
641655 return heatmap
642656
0 commit comments