|
91 | 91 | "# Heatmap helper using GenerateHeatmap\n",
|
92 | 92 | "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n",
|
93 | 93 | "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n",
|
| 94 | + "# It now supports batched inputs.\n", |
94 | 95 | "\n",
|
95 | 96 | "sigma = 3.0\n",
|
96 | 97 | "\n",
|
|
99 | 100 | " s = float(sigma_override) if sigma_override is not None else float(sigma)\n",
|
100 | 101 | " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n",
|
101 | 102 | " # Reorder (x,y) -> (y,x) for the transform\n",
|
102 |
| - " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", |
103 |
| - " return tr(pts_yx) # (N,H,W) where pts interpreted as (row, col)" |
| 103 | + " # Support batched and non-batched inputs\n", |
| 104 | + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", |
| 105 | + " if pts.ndim == 2:\n", |
| 106 | + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", |
| 107 | + " pts_yx = pts[..., [1, 0]]\n", |
| 108 | + " return tr(pts_yx) # (B, N, H, W)\n" |
104 | 109 | ]
|
105 | 110 | },
|
106 | 111 | {
|
|
125 | 130 | " spatial_shape=None if use_ref else (H, W),\n",
|
126 | 131 | " sigma=s,\n",
|
127 | 132 | " )\n",
|
128 |
| - " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", |
| 133 | + " # Support batched and non-batched inputs\n", |
| 134 | + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", |
| 135 | + " if pts.ndim == 2:\n", |
| 136 | + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", |
| 137 | + " pts_yx = pts[..., [1, 0]]\n", |
129 | 138 | " data = {\"points\": pts_yx, \"ref\": ref_img}\n",
|
130 | 139 | " out = tr(data)\n",
|
131 |
| - " return out[\"heatmap\"]" |
| 140 | + " return out[\"heatmap\"]\n" |
132 | 141 | ]
|
133 | 142 | },
|
134 | 143 | {
|
|
179 | 188 | "num_points = 3 # number of random landmarks\n",
|
180 | 189 | "sigma_demo = 3.0 # Gaussian sigma\n",
|
181 | 190 | "combine_mode = \"max\" # or 'sum'\n",
|
| 191 | + "batched_input = True # Set to True to test batched input\n", |
182 | 192 | "\n",
|
183 | 193 | "# Sample random (x,y) points within image bounds (user-friendly)\n",
|
184 | 194 | "points_xy = np.array(\n",
|
|
189 | 199 | "\n",
|
190 | 200 | "# Convert to (y,x) for the transform\n",
|
191 | 201 | "yx_points = points_xy[:, [1, 0]].copy()\n",
|
| 202 | + "if batched_input:\n", |
| 203 | + " yx_points = yx_points[np.newaxis, ...] # Add a batch dimension\n", |
192 | 204 | "\n",
|
193 | 205 | "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n",
|
194 | 206 | "heatmaps = array_tr(yx_points) # now correct orientation\n",
|
195 | 207 | "\n",
|
| 208 | + "if batched_input:\n", |
| 209 | + " heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n", |
| 210 | + "\n", |
196 | 211 | "if combine_mode == \"max\":\n",
|
197 | 212 | " combined = heatmaps.max(axis=0)\n",
|
198 | 213 | "elif combine_mode == \"sum\":\n",
|
|
230 | 245 | " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n",
|
231 | 246 | " ax.set_axis_off()\n",
|
232 | 247 | "plt.tight_layout()\n",
|
233 |
| - "plt.show()" |
| 248 | + "plt.show()\n" |
234 | 249 | ]
|
235 | 250 | }
|
236 | 251 | ],
|
|
0 commit comments