Skip to content

Commit 3097baf

Browse files
committed
Enables batched input for heatmap generation
Extends the `GenerateHeatmap` transform to support batched inputs, allowing for more efficient processing of multiple landmark sets. This change modifies the transform to handle inputs with a batch dimension (B, N, spatial_dims) in addition to single-point inputs (N, spatial_dims). It also includes a demonstration of 3D heatmap generation using PyVista for visualization.
1 parent 226bf90 commit 3097baf

File tree

5 files changed

+347
-30
lines changed

5 files changed

+347
-30
lines changed

2d_mdtest.ipynb

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
"# Heatmap helper using GenerateHeatmap\n",
9292
"# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n",
9393
"# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n",
94+
"# It now supports batched inputs.\n",
9495
"\n",
9596
"sigma = 3.0\n",
9697
"\n",
@@ -99,8 +100,12 @@
99100
" s = float(sigma_override) if sigma_override is not None else float(sigma)\n",
100101
" tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n",
101102
" # Reorder (x,y) -> (y,x) for the transform\n",
102-
" pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n",
103-
" return tr(pts_yx) # (N,H,W) where pts interpreted as (row, col)"
103+
" # Support batched and non-batched inputs\n",
104+
" pts = np.array(list(zip(y, x)), dtype=np.float32)\n",
105+
" if pts.ndim == 2:\n",
106+
" pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n",
107+
" pts_yx = pts[..., [1, 0]]\n",
108+
" return tr(pts_yx) # (B, N, H, W)\n"
104109
]
105110
},
106111
{
@@ -125,10 +130,14 @@
125130
" spatial_shape=None if use_ref else (H, W),\n",
126131
" sigma=s,\n",
127132
" )\n",
128-
" pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n",
133+
" # Support batched and non-batched inputs\n",
134+
" pts = np.array(list(zip(y, x)), dtype=np.float32)\n",
135+
" if pts.ndim == 2:\n",
136+
" pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n",
137+
" pts_yx = pts[..., [1, 0]]\n",
129138
" data = {\"points\": pts_yx, \"ref\": ref_img}\n",
130139
" out = tr(data)\n",
131-
" return out[\"heatmap\"]"
140+
" return out[\"heatmap\"]\n"
132141
]
133142
},
134143
{
@@ -179,6 +188,7 @@
179188
"num_points = 3 # number of random landmarks\n",
180189
"sigma_demo = 3.0 # Gaussian sigma\n",
181190
"combine_mode = \"max\" # or 'sum'\n",
191+
"batched_input = True # Set to True to test batched input\n",
182192
"\n",
183193
"# Sample random (x,y) points within image bounds (user-friendly)\n",
184194
"points_xy = np.array(\n",
@@ -189,10 +199,15 @@
189199
"\n",
190200
"# Convert to (y,x) for the transform\n",
191201
"yx_points = points_xy[:, [1, 0]].copy()\n",
202+
"if batched_input:\n",
203+
" yx_points = yx_points[np.newaxis, ...] # Add a batch dimension\n",
192204
"\n",
193205
"array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n",
194206
"heatmaps = array_tr(yx_points) # now correct orientation\n",
195207
"\n",
208+
"if batched_input:\n",
209+
" heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n",
210+
"\n",
196211
"if combine_mode == \"max\":\n",
197212
" combined = heatmaps.max(axis=0)\n",
198213
"elif combine_mode == \"sum\":\n",
@@ -230,7 +245,7 @@
230245
" ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n",
231246
" ax.set_axis_off()\n",
232247
"plt.tight_layout()\n",
233-
"plt.show()"
248+
"plt.show()\n"
234249
]
235250
}
236251
],

0 commit comments

Comments
 (0)