Skip to content

Commit

Permalink
Fix MIL heatmaps with stride > 1 [#331]
Browse files Browse the repository at this point in the history
- Fix a bug where heatmaps cannot be generated when training/evaluating an MIL model with a non-standard stride (any value other than 1).
  • Loading branch information
jamesdolezal committed Jan 11, 2024
1 parent df51255 commit 632fbf7
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 6 deletions.
1 change: 1 addition & 0 deletions slideflow/mil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
ModelConfigCLAM,
ModelConfigFastAI
)
from .utils import load_model_weights
115 changes: 109 additions & 6 deletions slideflow/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,9 +1101,24 @@ def map_values_to_slide_grid(
background: str = 'min',
*,
interpolation: Optional[str] = 'bicubic',
):
"""Map heatmap values to a slide grid, using tile location information."""
) -> np.ndarray:
"""Map heatmap values to a slide grid, using tile location information.
Args:
locations (np.ndarray): Array of shape ``(n_tiles, 2)`` containing x, y
coordinates for all image tiles. Coordinates represent the center
for an associated tile, and must be in a grid.
values (np.ndarray): Array of shape ``(n_tiles,)`` containing heatmap
values for each tile.
wsi (slideflow.wsi.WSI): WSI object.
Keyword args:
background (str, optional): Background strategy for heatmap. Can be
'min', 'mean', 'median', 'max', or 'mask'. Defaults to 'min'.
interpolation (str, optional): Interpolation strategy for smoothing
heatmap. Defaults to 'bicubic'.
"""
no_interpolation = (interpolation is None or interpolation == 'nearest')

# Slide coordinate information
Expand Down Expand Up @@ -1159,6 +1174,74 @@ def map_values_to_slide_grid(
return masked_grid


def bin_values_to_slide_grid(
    locations: np.ndarray,
    values: np.ndarray,
    wsi: "sf.WSI",
    background: str = 'min',
) -> np.ndarray:
    """Bin heatmap values to a slide grid, using tile location information.

    Tile values are averaged into a 2D histogram whose bin counts come from
    ``wsi.grid.shape`` and whose extent spans ``wsi.dimensions``. Because
    values are binned rather than indexed, tile locations do not need to be
    aligned to the slide's coordinate grid.

    Args:
        locations (np.ndarray): Array of shape ``(n_tiles, 2)`` containing x, y
            coordinates for all image tiles. Any array-like is accepted.
        values (np.ndarray): Array of shape ``(n_tiles,)`` containing heatmap
            values for each tile. Any array-like is accepted.
        wsi (slideflow.wsi.WSI): WSI object. Only ``wsi.grid.shape`` and
            ``wsi.dimensions`` are read.
        background (str, optional): Fill strategy for bins that contain no
            tiles. Can be 'min', 'mean', 'median', 'max', or 'mask'.
            With 'mask', empty bins are left as NaN. Defaults to 'min'.

    Returns:
        np.ndarray: Float array of binned values, transposed relative to
        ``wsi.grid.shape`` (the first axis follows y, the second follows x).

    Raises:
        ValueError: If ``background`` is not one of the recognized options.
    """
    from scipy.stats import binned_statistic_2d

    # Reducers that compute the fill value for bins containing no tiles.
    fill_reducers = {
        'min': np.min,
        'mean': np.mean,
        'median': np.median,
        'max': np.max,
    }
    # Validate before binning so a bad argument fails immediately, rather
    # than after the histogram has already been computed.
    if background not in ('mask', *fill_reducers):
        raise ValueError(f"Unrecognized value for background: {background}")

    locations = np.asarray(locations)
    values = np.asarray(values)

    masked_grid, *_ = binned_statistic_2d(
        locations[:, 0],
        locations[:, 1],
        values,
        statistic='mean',
        bins=wsi.grid.shape,
        range=[[0, wsi.dimensions[0]], [0, wsi.dimensions[1]]],
    )
    # The statistic is indexed as [x_bin, y_bin]; transpose so that the
    # first axis follows y and the second follows x.
    masked_grid = masked_grid.T

    if background != 'mask':
        # Bins without any tiles come back as NaN; replace them with the
        # requested summary of the tile values.
        masked_grid[np.isnan(masked_grid)] = fill_reducers[background](values)

    return masked_grid


def infer_stride(locations, wsi):
    """Infer the stride divisor of a tile grid from its tile locations.

    The stride in pixels is taken as the smallest gap between distinct
    coordinate values along either axis. The divisor is then the ratio of
    ``wsi.full_extract_px`` to that gap.

    Args:
        locations (np.ndarray): Nx2 array of x, y locations. Any array-like
            of that shape is accepted.
        wsi (slideflow.wsi.WSI): WSI object. Only ``wsi.full_extract_px``
            is read.

    Returns:
        float: Inferred stride divisor (a unitless ratio, not a pixel count).
        Returns 1.0 when the stride cannot be inferred because neither axis
        has at least two distinct coordinate values (e.g. a single tile).
    """
    locations = np.asarray(locations)

    # np.unique returns sorted values, so consecutive differences are the
    # strictly positive gaps between neighboring grid lines on each axis.
    axis_gaps = (np.diff(np.unique(locations[:, axis])) for axis in (0, 1))

    # An axis with fewer than two distinct coordinates (a single row or
    # column of tiles) has no gaps; skip it instead of calling .min() on an
    # empty array, which would raise ValueError.
    min_gaps = [gaps.min() for gaps in axis_gaps if gaps.size]
    if not min_gaps:
        # Nothing to measure on either axis; fall back to the default stride.
        return 1.0

    inferred_stride_px = min(min_gaps)
    return wsi.full_extract_px / inferred_stride_px


def location_heatmap(
locations: np.ndarray,
values: np.ndarray,
Expand Down Expand Up @@ -1203,7 +1286,19 @@ def location_heatmap(
slide_name = sf.util.path_to_name(slide)
log.info(f'Generating heatmap for [green]{slide}[/]...')
log.debug(f"Plotting {len(values)} values")
wsi = sf.slide.WSI(slide, tile_px, tile_um, verbose=False)
wsi = sf.WSI(slide, tile_px, tile_um, verbose=False)
stride = infer_stride(locations, wsi)
if stride > 32:
# Large inferred strides are likely due to unaligned grid.
# Rather than attempting to build a coordinate grid for verifying
# grid alignment, we will assume that the grid is unaligned and
# use the default stride (1). This will cause map_values_to_slide_grid
# to recognize that the grid is unaligned, and the heatmap will be built
# by binning values with bin_values_to_slide_grid (binned_statistic_2d).
log.debug(f"Failed sanity check for inferred stride ({stride})")
elif stride != 1:
log.debug(f"Inferred stride: {stride}")
wsi = sf.WSI(slide, tile_px, tile_um, stride_div=stride, verbose=False)

stats = {
slide_name: {
Expand All @@ -1212,9 +1307,17 @@ def location_heatmap(
}
}

masked_grid = map_values_to_slide_grid(
locations, values, wsi, background=background, interpolation=interpolation
)
try:
masked_grid = map_values_to_slide_grid(
locations, values, wsi, background=background, interpolation=interpolation
)
except errors.CoordinateAlignmentError as e:
log.debug("Coordinate alignment error: {}".format(e))
log.info("Unable to align grid for plotting heatmap. Heatmap will be "
"binned with a stride of 1.")
masked_grid = bin_values_to_slide_grid(
locations, values, wsi, background=background
)

fig = plt.figure(figsize=(18, 16))
ax = fig.add_subplot(111)
Expand Down

0 comments on commit 632fbf7

Please sign in to comment.