Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to WSI datasets to save patch coordinates #674

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ data:
height: 224
target_mpp: 0.25
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ data:
height: 224
target_mpp: 0.25
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
1 change: 1 addition & 0 deletions configs/vision/pathology/offline/classification/panda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ data:
height: 224
target_mpp: 0.5
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ data:
height: 224
target_mpp: 0.5
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
7 changes: 3 additions & 4 deletions src/eva/core/callbacks/writers/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,14 @@ def _get_item_metadata(

def _check_if_exists(self) -> None:
"""Checks if the output directory already exists and if it should be overwritten."""
try:
os.makedirs(self._output_dir, exist_ok=self._overwrite)
except FileExistsError as e:
os.makedirs(self._output_dir, exist_ok=True)
if os.path.exists(os.path.join(self._output_dir, "manifest.csv")) and not self._overwrite:
raise FileExistsError(
f"The embeddings output directory already exists: {self._output_dir}. This "
"either means that they have been computed before or that a wrong output "
"directory is being used. Consider using `eva fit` instead, selecting a "
"different output directory or setting overwrite=True."
) from e
)
os.makedirs(self._output_dir, exist_ok=True)


Expand Down
5 changes: 4 additions & 1 deletion src/eva/vision/data/datasets/classification/camelyon16.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
target_mpp: float = 0.5,
backend: str = "openslide",
image_transforms: Callable | None = None,
coords_path: str | None = None,
seed: int = 42,
) -> None:
"""Initializes the dataset.
Expand All @@ -100,6 +101,7 @@ def __init__(
target_mpp: Target microns per pixel (mpp) for the patches.
backend: The backend to use for reading the whole-slide images.
image_transforms: Transforms to apply to the extracted image patches.
coords_path: File path to save the patch coordinates as .csv.
seed: Random seed for reproducibility.
"""
self._split = split
Expand All @@ -119,6 +121,7 @@ def __init__(
target_mpp=target_mpp,
backend=backend,
image_transforms=image_transforms,
coords_path=coords_path,
)

@property
Expand Down Expand Up @@ -207,7 +210,7 @@ def load_target(self, index: int) -> torch.Tensor:

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
return {"wsi_id": self.filename(index).split(".")[0]}
return wsi.MultiWsiDataset.load_metadata(self, index)

def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
"""Loads the file paths of the corresponding dataset split."""
Expand Down
5 changes: 4 additions & 1 deletion src/eva/vision/data/datasets/classification/panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
target_mpp: float = 0.5,
backend: str = "openslide",
image_transforms: Callable | None = None,
coords_path: str | None = None,
seed: int = 42,
) -> None:
"""Initializes the dataset.
Expand All @@ -62,6 +63,7 @@ def __init__(
target_mpp: Target microns per pixel (mpp) for the patches.
backend: The backend to use for reading the whole-slide images.
image_transforms: Transforms to apply to the extracted image patches.
coords_path: File path to save the patch coordinates as .csv.
seed: Random seed for reproducibility.
"""
self._split = split
Expand All @@ -80,6 +82,7 @@ def __init__(
target_mpp=target_mpp,
backend=backend,
image_transforms=image_transforms,
coords_path=coords_path,
)

@property
Expand Down Expand Up @@ -132,7 +135,7 @@ def load_target(self, index: int) -> torch.Tensor:

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
return {"wsi_id": self.filename(index).split(".")[0]}
return wsi.MultiWsiDataset.load_metadata(self, index)

def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
"""Loads the file paths of the corresponding dataset split."""
Expand Down
5 changes: 4 additions & 1 deletion src/eva/vision/data/datasets/classification/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
split: Literal["train", "val", "test"] | None = None,
image_transforms: Callable | None = None,
column_mapping: Dict[str, str] = default_column_mapping,
coords_path: str | None = None,
):
"""Initializes the dataset.

Expand All @@ -51,6 +52,7 @@ def __init__(
split: The split of the dataset to load.
image_transforms: Transforms to apply to the extracted image patches.
column_mapping: Mapping of the columns in the manifest file.
coords_path: File path to save the patch coordinates as .csv.
"""
self._split = split
self._column_mapping = self.default_column_mapping | column_mapping
Expand All @@ -66,6 +68,7 @@ def __init__(
target_mpp=target_mpp,
backend=backend,
image_transforms=image_transforms,
coords_path=coords_path,
)

@override
Expand All @@ -88,7 +91,7 @@ def load_target(self, index: int) -> np.ndarray:

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
return {"wsi_id": self.filename(index).split(".")[0]}
return wsi.MultiWsiDataset.load_metadata(self, index)

def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
df = pd.read_csv(manifest_path)
Expand Down
38 changes: 37 additions & 1 deletion src/eva/vision/data/datasets/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import bisect
import os
from typing import Callable, List
from typing import Any, Callable, Dict, List

import pandas as pd
from loguru import logger
from torch.utils.data import dataset as torch_datasets
from torchvision import tv_tensors
Expand Down Expand Up @@ -85,6 +86,17 @@ def __getitem__(self, index: int) -> tv_tensors.Image:
patch = self._apply_transforms(patch)
return patch

def load_metadata(self, index: int) -> Dict[str, Any]:
"""Loads the metadata for the patch at the specified index."""
x, y = self._coords.x_y[index]
return {
"x": x,
"y": y,
"width": self._coords.width,
"height": self._coords.height,
"level_idx": self._coords.level_idx,
}

def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
if self._image_transforms is not None:
image = self._image_transforms(image)
Expand All @@ -105,6 +117,7 @@ def __init__(
overwrite_mpp: float | None = None,
backend: str = "openslide",
image_transforms: Callable | None = None,
coords_path: str | None = None,
):
"""Initializes a new dataset instance.

Expand All @@ -118,6 +131,7 @@ def __init__(
sampler: The sampler to use for sampling patch coordinates.
backend: The backend to use for reading the whole-slide images.
image_transforms: Transforms to apply to the extracted image patches.
coords_path: File path to save the patch coordinates as .csv.
"""
super().__init__()

Expand All @@ -130,6 +144,7 @@ def __init__(
self._sampler = sampler
self._backend = backend
self._image_transforms = image_transforms
self._coords_path = coords_path

self._concat_dataset: torch_datasets.ConcatDataset

Expand All @@ -146,6 +161,7 @@ def cumulative_sizes(self) -> List[int]:
@override
def configure(self) -> None:
self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
self._save_coords_to_file()

@override
def __len__(self) -> int:
Expand All @@ -159,6 +175,12 @@ def __getitem__(self, index: int) -> tv_tensors.Image:
def filename(self, index: int) -> str:
return os.path.basename(self._file_paths[self._get_dataset_idx(index)])

def load_metadata(self, index: int) -> Dict[str, Any]:
"""Loads the metadata for the patch at the specified index."""
dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata

def _load_datasets(self) -> list[WsiDataset]:
logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
wsi_datasets = []
Expand All @@ -185,3 +207,17 @@ def _load_datasets(self) -> list[WsiDataset]:

def _get_dataset_idx(self, index: int) -> int:
return bisect.bisect_right(self.cumulative_sizes, index)

def _get_sample_idx(self, index: int) -> int:
dataset_idx = self._get_dataset_idx(index)
return index if dataset_idx == 0 else index - self.cumulative_sizes[dataset_idx - 1]

def _save_coords_to_file(self):
if self._coords_path is not None:
coords = [
{"file": self._file_paths[i]} | dataset._coords.to_dict()
for i, dataset in enumerate(self.datasets)
]
os.makedirs(os.path.abspath(os.path.join(self._coords_path, os.pardir)), exist_ok=True)
pd.DataFrame(coords).to_csv(self._coords_path, index=False)
logger.info(f"Saved patch coordinates to: {self._coords_path}")
10 changes: 9 additions & 1 deletion src/eva/vision/data/wsi/patching/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import functools
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

from eva.vision.data.wsi import backends
from eva.vision.data.wsi.patching import samplers
Expand Down Expand Up @@ -75,6 +75,14 @@ def from_file(

return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))

def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
"""Convert the coordinates to a dictionary."""
include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
coord_dict = dataclasses.asdict(self)
if include_keys:
coord_dict = {key: coord_dict[key] for key in include_keys}
return coord_dict


@functools.lru_cache(LRU_CACHE_SIZE)
def get_cached_coords(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def _check_batch_shape(batch: Any):
assert isinstance(target, torch.Tensor)
assert isinstance(metadata, dict)
assert "wsi_id" in metadata
assert "x" in metadata
assert "y" in metadata
assert "width" in metadata
assert "height" in metadata
assert "level_idx" in metadata


@pytest.fixture
Expand Down
5 changes: 5 additions & 0 deletions tests/eva/vision/data/datasets/classification/test_panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def _check_batch_shape(batch: Any):
assert isinstance(target, torch.Tensor)
assert isinstance(metadata, dict)
assert "wsi_id" in metadata
assert "x" in metadata
assert "y" in metadata
assert "width" in metadata
assert "height" in metadata
assert "level_idx" in metadata


@pytest.fixture
Expand Down
5 changes: 5 additions & 0 deletions tests/eva/vision/data/datasets/classification/test_wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def _check_batch_shape(batch: Any):

assert isinstance(metadata, dict)
assert "wsi_id" in metadata
assert "x" in metadata
assert "y" in metadata
assert "width" in metadata
assert "height" in metadata
assert "level_idx" in metadata


@pytest.fixture
Expand Down
12 changes: 10 additions & 2 deletions tests/eva/vision/data/datasets/test_wsi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""WsiDataset & MultiWsiDataset tests."""

import os
import pathlib
from typing import Tuple

import pandas as pd
import pytest

from eva.vision.data import datasets
Expand Down Expand Up @@ -69,14 +71,14 @@ def test_patch_shape(width: int, height: int, target_mpp: float, root: str, back
assert dataset[0].shape == (3, scaled_width, scaled_height)


def test_multi_dataset(root: str):
def test_multi_dataset(root: str, tmp_path: pathlib.Path):
"""Test MultiWsiDataset with multiple whole-slide image paths."""
coords_path = (tmp_path / "coords.csv").as_posix()
file_paths = [
os.path.join(root, "0/a.tiff"),
os.path.join(root, "0/b.tiff"),
os.path.join(root, "1/a.tiff"),
]

width, height = 32, 32
dataset = datasets.MultiWsiDataset(
root=root,
Expand All @@ -86,6 +88,7 @@ def test_multi_dataset(root: str):
target_mpp=0.25,
sampler=samplers.GridSampler(max_samples=None),
backend="openslide",
coords_path=coords_path,
)
dataset.setup()

Expand All @@ -94,6 +97,11 @@ def test_multi_dataset(root: str):
assert len(dataset) == _expected_n_patches(layer_shape, width, height, (0, 0)) * len(file_paths)
assert dataset.cumulative_sizes == [64, 128, 192]

assert os.path.exists(coords_path)
df_coords = pd.read_csv(coords_path)
assert "file" in df_coords.columns
assert "x_y" in df_coords.columns


def _expected_n_patches(layer_shape, width, height, overlap):
"""Calculate the expected number of patches."""
Expand Down