add dask support and valis alignment #69

Merged
merged 103 commits into from Nov 5, 2024
103 commits
89bda5e
start implementing lazy shapes
pauldoucet Jul 24, 2024
6912985
refactoring of hest data shape mechanism
pauldoucet Jul 24, 2024
35a19cc
refactor tests and add seg_readers
pauldoucet Jul 24, 2024
af50700
cleanup
pauldoucet Jul 24, 2024
73d6552
minor changes 2
pauldoucet Jul 24, 2024
3d7bc12
speed up segmentation
pauldoucet Jul 25, 2024
9478c90
speed up tissue segmentation
pauldoucet Jul 25, 2024
5855798
fix segmentation and refactor wsi handling
pauldoucet Jul 26, 2024
03fa597
refactoring of tissue segmentation
pauldoucet Jul 29, 2024
bdbe308
fix spatial data cellvit conversion
pauldoucet Jul 29, 2024
1c5d69e
add spatialdata test
pauldoucet Jul 29, 2024
5324077
fix geojson contours and update tutorial 2
pauldoucet Jul 29, 2024
163db72
cleanup
pauldoucet Jul 29, 2024
94b9508
read .parquet cellvit and refactor seg readers
pauldoucet Jul 30, 2024
d5f51ee
optimize imports
pauldoucet Jul 30, 2024
19d6f5d
remove kwimage dependency
pauldoucet Jul 30, 2024
b2611d5
only warn once for cucim import
pauldoucet Jul 30, 2024
611a8fa
correct cucim circular import
pauldoucet Jul 30, 2024
952530f
warn_cucim not defined
pauldoucet Jul 30, 2024
b29b953
use singleton warn cucim
pauldoucet Jul 30, 2024
d18aba3
update tutorials
pauldoucet Jul 30, 2024
224d802
fix reader bug
pauldoucet Jul 30, 2024
4911fe1
add pipeline
pauldoucet Jul 30, 2024
a88319d
add registration
pauldoucet Jul 30, 2024
7a9f3f5
minor changes
pauldoucet Jul 30, 2024
b4767b7
add atlas
pauldoucet Jul 30, 2024
4705a22
add subtyping
pauldoucet Jul 30, 2024
7e8423d
optimize import
pauldoucet Jul 30, 2024
26fadbf
minor changes
pauldoucet Jul 30, 2024
76438e6
integrate valis into pipeline
pauldoucet Jul 31, 2024
aacd8b9
modify alignemnt for xenium>2.0
pauldoucet Jul 31, 2024
da702b3
refactoring
pauldoucet Aug 3, 2024
49ce521
comment
pauldoucet Aug 3, 2024
0255aac
voronoi
pauldoucet Aug 3, 2024
04ca9bb
finalize nuclei expansion
pauldoucet Aug 3, 2024
0d4965a
new changes
pauldoucet Aug 6, 2024
2bf271c
refactor WSIPatcher
pauldoucet Aug 7, 2024
aadfb15
put holes directly inside polygons
pauldoucet Aug 7, 2024
653cb73
correct typo in getitem
pauldoucet Aug 7, 2024
801655a
refactor patcher arguments
pauldoucet Aug 7, 2024
528ca8f
commit before erasing history
pauldoucet Aug 8, 2024
a5065d1
reimplement dunp_patches with the new wsipatcher
pauldoucet Aug 8, 2024
bf50ae4
delete hest bench
pauldoucet Aug 8, 2024
548ff17
Merge branch 'main' into wsi
pauldoucet Aug 8, 2024
9a6a82b
rm bench
pauldoucet Aug 8, 2024
55e8c20
Merge branch 'wsi' into develop
pauldoucet Aug 8, 2024
3dff001
changes
pauldoucet Aug 8, 2024
837574f
clean imports
pauldoucet Aug 8, 2024
4c3b41b
change gitignore
pauldoucet Aug 8, 2024
5897079
apt-get update before install
pauldoucet Aug 8, 2024
9500974
modify secret
pauldoucet Aug 8, 2024
b4c43ea
allow code exec
pauldoucet Aug 8, 2024
8c938f4
trust_remote_code=True
pauldoucet Aug 8, 2024
639ca00
reduce number of samples for test
pauldoucet Aug 8, 2024
bccbb3e
force custom_coords to be int
pauldoucet Aug 8, 2024
a951168
remove torch sync and throw exception on test fail
pauldoucet Aug 8, 2024
68303b1
include all tests
pauldoucet Aug 8, 2024
d8d33f7
wsi
pauldoucet Aug 8, 2024
a1b2efc
Merge branch 'wsi' into develop
pauldoucet Aug 8, 2024
6d33d70
get segmentation from wsi
pauldoucet Aug 8, 2024
9f15d6e
fix patching
pauldoucet Aug 9, 2024
f5c83cf
fix dump_patches
pauldoucet Aug 9, 2024
a6dabc5
fix wsi
pauldoucet Aug 9, 2024
17a6c59
changes
pauldoucet Aug 9, 2024
dc451be
final changes wsi
pauldoucet Aug 9, 2024
37ce7f3
improve pipeline
pauldoucet Aug 10, 2024
f87d595
small changes
pauldoucet Aug 10, 2024
911b313
update wsi
pauldoucet Aug 10, 2024
2ef0f9f
update dump patches
pauldoucet Aug 10, 2024
eb12634
Merge branch 'wsi' into develop
pauldoucet Aug 10, 2024
a734d08
download true
pauldoucet Aug 10, 2024
1c48fa0
download true
pauldoucet Aug 10, 2024
fd3549c
minor changes
pauldoucet Aug 12, 2024
7449e1a
clean dependencies
pauldoucet Aug 12, 2024
169d3a8
add hestcore as dependency
pauldoucet Aug 12, 2024
8420833
Merge branch 'wsi' into develop
pauldoucet Aug 12, 2024
7550d71
small changes
pauldoucet Aug 13, 2024
1ae7ff7
cleanup bench
pauldoucet Aug 14, 2024
908f9a7
add virchow1/2 hoptimus2
pauldoucet Aug 14, 2024
a6b96a4
hestcore 1.0.1
pauldoucet Aug 14, 2024
24b06c8
dont permute
pauldoucet Aug 14, 2024
5268919
fix permute
pauldoucet Aug 14, 2024
dcd8c72
small changes
pauldoucet Aug 14, 2024
e14da4f
fix indentation
pauldoucet Aug 15, 2024
434921d
Merge branch 'hest-bench-clean' into develop
pauldoucet Aug 15, 2024
88e1a77
small changes
pauldoucet Aug 15, 2024
e4f9a82
fix patching and add tests
pauldoucet Aug 15, 2024
b0fda59
Merge branch 'hest-bench-clean' into develop
pauldoucet Aug 15, 2024
aebbc0e
before commit
pauldoucet Aug 15, 2024
20853f3
small change
pauldoucet Aug 16, 2024
e5591b8
small changes
pauldoucet Aug 23, 2024
e66d4df
Merge remote-tracking branch 'origin/main' into develop
pauldoucet Aug 23, 2024
d1eb460
minor changes
pauldoucet Aug 26, 2024
ca43287
finish warping transcripts
pauldoucet Sep 20, 2024
7cb578a
minor changes
pauldoucet Sep 22, 2024
f746807
Merge branch 'main' into develop
pauldoucet Oct 29, 2024
4f1f02e
refactor xenium reader
pauldoucet Nov 1, 2024
eebf6c8
add DASK support for transcripts loading
pauldoucet Nov 5, 2024
2705e74
remove unused imports
pauldoucet Nov 5, 2024
a5aea93
Merge branch 'main' into develop
pauldoucet Nov 5, 2024
345ceea
optimize imports
pauldoucet Nov 5, 2024
8a4a12d
optimize imports further
pauldoucet Nov 5, 2024
c9d5018
remove incorrect import
pauldoucet Nov 5, 2024
81 changes: 23 additions & 58 deletions .gitignore
@@ -1,58 +1,23 @@
__pycache__
data
align_coord_img.ipynb
src/hiSTloader/yolov8n.pt
.vscode/launch.json
dist
src/hest.egg-info
make.bat
Makefile
*.parquet
bench_data
cell_seg
filtered
src/hest/bench/timm_ctp
my_notebooks
.gitattributes
cells_xenium.geojson
nuclei_xenium.geojson
nuclei.tif
cells.tif
src_slides
data64.h5
tests/assets
config
tests/output_tests
tissue_seg

results
atlas
figures/test_paul
bench_data.zip
old_bench_data
tutorials/downloads
tutorials/processed
bench_config/my_bench_config.yaml
src/hest/bench/private
str.csv
bench_data_old
ST_data_emb/
ST_pred_results/
hest_data
fm_v1
cufile.log
int.csv
docs/build
docs/source/generated
local
hest_vis
hest_vis2
hest_vis
vis
vis2
models/deeplabv3*
htmlcov
models/CellViT-SAM-H-x40.pth
debug_seg
replace_seg
test_vis
data
.vscode/launch.json
dist
src/hest.egg-info
bench_data
.gitattributes
tests/assets
config
tests/output_tests
HEST/

results
atlas
ST_data_emb/
ST_pred_results/
hest_data
fm_v1
docs/build
docs/source/generated
local
models/deeplabv3*
htmlcov
models/CellViT-SAM-H-x40.pth
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
"spatial_image >= 0.3.0",
"datasets",
"mygene",
"hestcore == 1.0.3"
"hestcore == 1.0.4"
]

requires-python = ">=3.9"
101 changes: 69 additions & 32 deletions src/hest/HESTData.py
@@ -4,18 +4,17 @@
import os
import shutil
import warnings
from typing import Dict, Iterator, List, Union
from typing import Dict, List, Union

import cv2
import geopandas as gpd
import numpy as np
from loguru import logger
from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI,
contours_to_img, wsi_factory)
from loguru import logger



from hest.io.seg_readers import TissueContourReader
from hest.io.seg_readers import TissueContourReader, write_geojson
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
from hest.segmentation.TissueMask import TissueMask, load_tissue_mask

@@ -31,7 +30,7 @@
from tqdm import tqdm

from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated,
find_first_file_endswith, get_path_from_meta_row,
find_first_file_endswith, get_k_genes_from_df, get_path_from_meta_row,
plot_verify_pixel_size, tiff_save, verify_paths)


@@ -100,7 +99,7 @@ class representing a single ST profile + its associated WSI image
else:
self._tissue_contours = tissue_contours

if 'total_counts' not in self.adata.var_names:
if 'total_counts' not in self.adata.var_names and len(self.adata) > 0:
sc.pp.calculate_qc_metrics(self.adata, inplace=True)


@@ -133,7 +132,7 @@ def load_wsi(self) -> None:
self.wsi = NumpyWSI(self.wsi.numpy())


def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False):
def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False, **kwargs):
"""Save a HESTData object to `path` as follows:
- aligned_adata.h5ad (contains expressions for each spots + their location on the fullres image + a downscaled version of the fullres image)
- metrics.json (contains useful metrics)
@@ -155,6 +154,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
self.adata.write(os.path.join(path, 'aligned_adata.h5ad'))
except:
# workaround from https://github.com/theislab/scvelo/issues/255
import traceback
traceback.print_exc()
self.adata.__dict__['_raw'].__dict__['_var'] = self.adata.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})
self.adata.write(os.path.join(path, 'aligned_adata.h5ad'))

@@ -172,7 +173,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
downscaled_img = self.adata.uns['spatial']['ST']['images']['downscaled_fullres']
down_fact = self.adata.uns['spatial']['ST']['scalefactors']['tissue_downscaled_fullres_scalef']
down_img = Image.fromarray(downscaled_img)
down_img.save(os.path.join(path, 'downscaled_fullres.jpeg'))
if len(downscaled_img) > 0:
down_img.save(os.path.join(path, 'downscaled_fullres.jpeg'))


if plot_pxl_size:
@@ -748,7 +750,9 @@ def __init__(
xenium_nuc_seg: pd.DataFrame=None,
xenium_cell_seg: pd.DataFrame=None,
cell_adata: sc.AnnData=None, # type: ignore
transcript_df: pd.DataFrame=None
transcript_df: pd.DataFrame=None,
dapi_path: str=None,
alignment_file_path: str=None
):
"""
class representing a single ST profile + its associated WSI image
@@ -765,16 +769,31 @@ class representing a single ST profile + its associated WSI image
xenium_cell_seg (pd.DataFrame): content of a xenium cell contour file as a dataframe (cell_boundaries.parquet)
cell_adata (sc.AnnData): ST cell data, each row in adata.obs is a cell, each row in obsm is the cell location on the H&E image in pixels
transcript_df (pd.DataFrame): dataframe of transcripts, each row is a transcript, he_x and he_y is the transcript location on the H&E image in pixels
dapi_path (str): path to a dapi focus image
alignment_file_path (np.ndarray): path to xenium alignment path
"""
super().__init__(adata=adata, img=img, pixel_size=pixel_size, meta=meta, tissue_seg=tissue_seg, tissue_contours=tissue_contours, shapes=shapes)

self.xenium_nuc_seg = xenium_nuc_seg
self.xenium_cell_seg = xenium_cell_seg
self.cell_adata = cell_adata
self.transcript_df = transcript_df


def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False):
self.dapi_path = dapi_path
self.alignment_file_path = alignment_file_path


def save(
self,
path: str,
save_img=True,
pyramidal=True,
bigtiff=False,
plot_pxl_size=False,
save_transcripts=False,
save_cell_seg=False,
save_nuclei_seg=False,
**kwargs
):
"""Save a HESTData object to `path` as follows:
- aligned_adata.h5ad (contains expressions for each spots + their location on the fullres image + a downscaled version of the fullres image)
- metrics.json (contains useful metrics)
@@ -795,21 +814,18 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
if self.cell_adata is not None:
self.cell_adata.write_h5ad(os.path.join(path, 'aligned_cells.h5ad'))

if self.transcript_df is not None:
if save_transcripts and self.transcript_df is not None:
self.transcript_df.to_parquet(os.path.join(path, 'aligned_transcripts.parquet'))

if save_cell_seg:
he_cells = self.get_shapes('tenx_cell', 'he').shapes
he_cells.to_parquet(os.path.join(path, 'he_cell_seg.parquet'))
write_geojson(he_cells, os.path.join(path, f'he_cell_seg.geojson'), '', chunk=True)

if self.xenium_nuc_seg is not None:
print('Saving Xenium nucleus boundaries... (can be slow)')
with open(os.path.join(path, 'nuclei_xenium.geojson'), 'w') as f:
json.dump(self.xenium_nuc_seg, f, indent=4)

if self.xenium_cell_seg is not None:
print('Saving Xenium cells boundaries... (can be slow)')
with open(os.path.join(path, 'cells_xenium.geojson'), 'w') as f:
json.dump(self.xenium_cell_seg, f, indent=4)


# TODO save segmentation
if save_nuclei_seg:
he_nuclei = self.get_shapes('tenx_nucleus', 'he').shapes
he_nuclei.to_parquet(os.path.join(path, 'he_nucleus_seg.parquet'))
write_geojson(he_nuclei, os.path.join(path, f'he_nucleus_seg.geojson'), '', chunk=True)


def read_HESTData(
@@ -936,19 +952,33 @@ def mask_and_patchify_bench(meta_df: pd.DataFrame, save_dir: str, use_mask=True,
i += 1


def create_benchmark_data(meta_df, save_dir:str, K, adata_folder, use_mask, keep_largest=None):
def create_benchmark_data(meta_df, save_dir:str, K):
os.makedirs(save_dir, exist_ok=True)
if K is not None:
splits = meta_df.groupby('patient')['id'].agg(list).to_dict()
create_splits(os.path.join(save_dir, 'splits'), splits, K=K)

meta_df['patient'] = meta_df['patient'].fillna('Patient 1')

get_k_genes_from_df(meta_df, 50, 'var', os.path.join(save_dir, 'var_50genes.json'))

splits = meta_df.groupby(['dataset_title', 'patient'])['id'].agg(list).to_dict()
create_splits(os.path.join(save_dir, 'splits'), splits, K=K)

os.makedirs(os.path.join(save_dir, 'patches'), exist_ok=True)
mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)
#mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)

os.makedirs(os.path.join(save_dir, 'patches_vis'), exist_ok=True)
os.makedirs(os.path.join(save_dir, 'adata'), exist_ok=True)
for index, row in meta_df.iterrows():
for _, row in meta_df.iterrows():
id = row['id']
src_adata = os.path.join(adata_folder, id + '.h5ad')
path = os.path.join(get_path_from_meta_row(row), 'processed')
src_patch = os.path.join(path, 'patches.h5')
dst_patch = os.path.join(save_dir, 'patches', id + '.h5')
shutil.copy(src_patch, dst_patch)

src_vis = os.path.join(path, 'patches_patch_vis.png')
dst_vis = os.path.join(save_dir, 'patches_vis', id + '.png')
shutil.copy(src_vis, dst_vis)

src_adata = os.path.join(path, 'aligned_adata.h5ad')
dst_adata = os.path.join(save_dir, 'adata', id + '.h5ad')
shutil.copy(src_adata, dst_adata)

@@ -1200,6 +1230,13 @@ def unify_gene_names(adata: sc.AnnData, species="human", drop=False) -> sc.AnnDa
mask = ~adata.var_names.duplicated(keep='first')
adata = adata[:, mask]

duplicated_genes_after = adata.var_names[adata.var_names.duplicated()]
if len(duplicated_genes_after) > len(duplicated_genes_before):
logger.warning(f"duplicated genes increased from {len(duplicated_genes_before)} to {len(duplicated_genes_after)} after resolving aliases")
logger.info('deduplicating...')
mask = ~adata.var_names.duplicated(keep='first')
adata = adata[:, mask]

if drop:
adata = adata[:, ~remaining]
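A minimal usage sketch of the new opt-in save flags introduced above; the `st` object and output path are illustrative, not part of this diff:

st.save(
    'processed/sample1',        # hypothetical output directory
    save_img=True,              # pyramidal H&E TIFF, as before
    save_transcripts=True,      # writes aligned_transcripts.parquet (now opt-in)
    save_cell_seg=True,         # writes he_cell_seg.parquet + he_cell_seg.geojson
    save_nuclei_seg=True,       # writes he_nucleus_seg.parquet + he_nucleus_seg.geojson
)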

12 changes: 9 additions & 3 deletions src/hest/LazyShapes.py
@@ -2,24 +2,30 @@
import pandas as pd
from shapely import Polygon

from hest.io.seg_readers import read_gdf
from hest.io.seg_readers import GDFReader, read_gdf
from hest.utils import verify_paths


class LazyShapes:

path: str = None

def __init__(self, path: str, name: str, coordinate_system: str):
def __init__(self, path: str, name: str, coordinate_system: str, reader: GDFReader=None, reader_kwargs = {}):
verify_paths([path])
self.path = path
self.name = name
self.coordinate_system = coordinate_system
self._shapes = None
self.reader_kwargs = reader_kwargs
self.reader = reader

def compute(self) -> None:
if self._shapes is None:
self._shapes = read_gdf(self.path)
if self.reader is None:
self._shapes = read_gdf(self.path, self.reader_kwargs)
else:
self._shapes = self.reader(**self.reader_kwargs).read_gdf(self.path)


@property
def shapes(self) -> gpd.GeoDataFrame:
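A sketch of how the new reader/reader_kwargs arguments could be used, assuming a Xenium cell-boundary parquet file; the path and pixel-size value are illustrative:

from hest.LazyShapes import LazyShapes
from hest.io.seg_readers import XeniumParquetCellReader

lazy = LazyShapes(
    'cell_boundaries.parquet',                    # hypothetical path
    name='tenx_cell',
    coordinate_system='he',
    reader=XeniumParquetCellReader,
    reader_kwargs={'pixel_size_morph': 0.2125},   # illustrative pixel size
)
gdf = lazy.shapes   # reading is deferred until the shapes property is accessed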
46 changes: 46 additions & 0 deletions src/hest/SlideReaderAdapter.py
@@ -0,0 +1,46 @@
# Slide Adapter class for Valis compatibility
import os

import numpy as np
from valis import slide_tools
from valis.slide_io import PIXEL_UNIT, MetaData, SlideReader

from hestcore.wsi import wsi_factory


class SlideReaderAdapter(SlideReader):
def __init__(self, src_f, *args, **kwargs):
super().__init__(src_f, *args, **kwargs)
self.wsi = wsi_factory(src_f)
self.metadata = self.create_metadata()

def create_metadata(self):
meta_name = f"{os.path.split(self.src_f)[1]}_Series(0)".strip("_")
slide_meta = MetaData(meta_name, 'SlideReaderAdapter')

slide_meta.is_rgb = True
slide_meta.channel_names = self._get_channel_names('NO_NAME')
slide_meta.n_channels = 1
slide_meta.pixel_physical_size_xyu = [0.25, 0.25, PIXEL_UNIT]
level_dim = self.wsi.level_dimensions() #self._get_slide_dimensions()
slide_meta.slide_dimensions = np.array([list(item) for item in level_dim])

return slide_meta

def slide2vips(self, level, xywh=None, *args, **kwargs):
img = self.slide2image(level, xywh=xywh, *args, **kwargs)
vips_img = slide_tools.numpy2vips(img)

return vips_img

def slide2image(self, level, xywh=None, *args, **kwargs):
level_dim = self.wsi.level_dimensions()[level]
img = self.wsi.get_thumbnail(level_dim[0], level_dim[1])

if xywh is not None:
xywh = np.array(xywh)
start_c, start_r = xywh[0:2]
end_c, end_r = xywh[0:2] + xywh[2:]
img = img[start_r:end_r, start_c:end_c]

return img
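A quick sanity-check sketch of the adapter, assuming Valis is installed and 'he_slide.tif' is a readable WSI (both assumptions, not part of this diff):

from hest.SlideReaderAdapter import SlideReaderAdapter

reader = SlideReaderAdapter('he_slide.tif')
print(reader.metadata.slide_dimensions)                 # one (width, height) per pyramid level
full = reader.slide2image(level=2)                      # numpy image of pyramid level 2
crop = reader.slide2image(2, xywh=(0, 0, 512, 512))     # (x, y, w, h) crop of that level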
2 changes: 1 addition & 1 deletion src/hest/bench/st_dataset.py
@@ -42,4 +42,4 @@ def load_adata(expr_path, genes = None, barcodes = None, normalize=False):
adata = adata[:, genes]
if normalize:
adata = normalize_adata(adata)
return adata.to_df()
return adata.to_df()
56 changes: 56 additions & 0 deletions src/hest/custom_readers.py
@@ -10,6 +10,22 @@
write_10X_h5)


def colon_atlas_to_adata(path):
h5_path = find_first_file_endswith(path, 'filtered.h5ad')
custom_adata = sc.read_h5ad(h5_path)
#custom_adata.obs['pxl_col_in_fullres'] = custom_adata.obsm['spatial'][:, 0]
#custom_adata.obs['pxl_row_in_fullres'] = custom_adata.obsm['spatial'][:, 1]
custom_adata = custom_adata[custom_adata.obs['in_tissue'] == 1]
return custom_adata

def heart_atlas_to_adata(path):
h5_path = find_first_file_endswith(path, '.raw.h5ad')
custom_data = sc.read_h5ad(h5_path)
custom_data.obs['pxl_col_in_fullres'] = custom_data.obsm['spatial'][:, 0]
custom_data.obs['pxl_row_in_fullres'] = custom_data.obsm['spatial'][:, 1]
custom_data.obs.index = [idx.split('_')[1] for idx in custom_data.obs.index]
return custom_data

def GSE238145_to_adata(path):
counts_path = find_first_file_endswith(path, 'counts.txt')
coords_path = find_first_file_endswith(path, 'coords.txt')
@@ -426,4 +442,44 @@ def raw_count_to_adata(raw_count_path):

adata = sc.AnnData(matrix)

return adata


def GSE144239_to_adata(raw_counts_path, spot_coord_path):
import scanpy as sc

raw_counts = pd.read_csv(raw_counts_path, sep='\t', index_col=0)
spot_coord = pd.read_csv(spot_coord_path, sep='\t')
spot_coord.index = spot_coord['x'].astype(str) + ['x' for _ in range(len(spot_coord))] + spot_coord['y'].astype(str)
merged = pd.merge(spot_coord, raw_counts, left_index=True, right_index=True)
raw_counts = raw_counts.reindex(merged.index)
adata = sc.AnnData(raw_counts)
col1 = merged['pixel_x'].values
col2 = merged['pixel_y'].values
matrix = (np.vstack((col1, col2))).T
adata.obsm['spatial'] = matrix
return adata


def ADT_to_adata(img_path, raw_counts_path):
import scanpy as sc

basedir = os.path.dirname(img_path)
# combine spot coordinates into a single dataframe
pre_adt_path= find_first_file_endswith(basedir, 'pre-ADT.tsv')
post_adt_path = find_first_file_endswith(basedir, 'postADT.tsv')
if post_adt_path is None:
post_adt_path = find_first_file_endswith(basedir, 'post-ADT.tsv')
counts = pd.read_csv(raw_counts_path, index_col=0, sep='\t')
pre_adt = pd.read_csv(pre_adt_path, sep='\t')
post_adt = pd.read_csv(post_adt_path, sep='\t')
merged_coords = pd.concat([pre_adt, post_adt], ignore_index=True)
merged_coords.index = [str(x) + 'x' + str(y) for x, y in zip(merged_coords['x'], merged_coords['y'])]
merged = pd.merge(merged_coords, counts, left_index=True, right_index=True, how='inner')
counts = counts.reindex(merged.index)
adata = sc.AnnData(counts)
col1 = merged['pixel_x'].values
col2 = merged['pixel_y'].values
matrix = (np.vstack((col1, col2))).T
adata.obsm['spatial'] = matrix
return adata
102 changes: 86 additions & 16 deletions src/hest/io/seg_readers.py
@@ -1,14 +1,18 @@
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
import json
import warnings
from abc import abstractmethod

import geopandas as gpd
from loguru import logger
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from shapely.geometry.polygon import Point, Polygon
from tqdm import tqdm

from hest.utils import align_xenium_df, get_n_threads


def _process(x, extra_props, index_key, class_name):
from shapely.geometry.polygon import Point, Polygon
@@ -60,26 +64,78 @@ class GDFReader:
@abstractmethod
def read_gdf(self, path) -> gpd.GeoDataFrame:
pass

def fn(block, i):
logger.debug(f'start fn block {i}')
groups = defaultdict(lambda: [])
[groups[row[0]].append(row[1]) for row in block]
g = np.array([Polygon(value) for _, value in groups.items()])
key = np.array([key for key, _ in groups.items()])
logger.debug(f'finish fn block {i}')
return np.column_stack((key, g))

def groupby_shape(df, col, n_threads, col_shape='xy'):
n_chunks = n_threads

if n_threads >= 1:
l = len(df) // n_chunks
start = 0
chunk_lens = []
while start < len(df):
end = min(start + l, len(df))
while end < len(df) and df.iloc[end][col] == df.iloc[end - 1][col]:
end += 1
chunk_lens.append((start, end))
start = end

dfs = []
with ProcessPoolExecutor(max_workers=n_threads) as executor:
future_results = [executor.submit(fn, df[[col, col_shape]].iloc[start:end].values, start) for start, end in chunk_lens]

for future in as_completed(future_results):
dfs.append(future.result())

concat = np.concatenate(dfs)
else:
concat = fn(df[[col, col_shape]].values, 0)

gdf = gpd.GeoDataFrame(geometry=concat[:, 1])
gdf.index = concat[:, 0]

return gdf

class XeniumParquetCellReader(GDFReader):

def read_gdf(self, path) -> gpd.GeoDataFrame:
def __init__(self, pixel_size_morph=None, alignment_matrix=None):
self.pixel_size_morph = pixel_size_morph
self.alignment_matrix = alignment_matrix

def read_gdf(self, path, n_workers=0) -> gpd.GeoDataFrame:

df = pd.read_parquet(path)



if self.alignment_matrix is not None:
df = align_xenium_df(
df,
self.alignment_matrix,
self.pixel_size_morph,
'vertex_x',
'vertex_y',
x_key_dist='vertex_x',
y_key_dist='vertex_y')
else:
df['vertex_x'], df['vertex_y'] = df['vertex_x'] / self.pixel_size_morph, df['vertex_y'] / self.pixel_size_morph

df['xy'] = list(zip(df['vertex_x'], df['vertex_y']))
df = df.drop(['vertex_x', 'vertex_y'], axis=1)
df = df.drop(['vertex_x', 'vertex_y'], axis=1)

df = df.groupby('cell_id').agg({
'xy': Polygon
}).reset_index()
n_threads = get_n_threads(n_workers)

gdf = gpd.GeoDataFrame(df, geometry=df['xy'])
gdf = gdf.drop(['xy'], axis=1)
gdf = groupby_shape(df, 'cell_id', n_threads)
return gdf


class GDFParquetCellReader(GDFReader):

def read_gdf(self, path) -> gpd.GeoDataFrame:
@@ -102,7 +158,7 @@ def read_gdf(self, path) -> gpd.GeoDataFrame:
return gdf


def write_geojson(gdf: gpd.GeoDataFrame, path: str, category_key: str, extra_prop=False, uniform_prop=True, index_key: str=None) -> None:
def write_geojson(gdf: gpd.GeoDataFrame, path: str, category_key: str, extra_prop=False, uniform_prop=True, index_key: str=None, chunk=False) -> None:

if isinstance(gdf.geometry.iloc[0], Point):
geometry = 'MultiPoint'
@@ -111,6 +167,19 @@ def write_geojson(gdf: gpd.GeoDataFrame, path: str, category_key: str, extra_pro
else:
raise ValueError(f"gdf.geometry[0] must be of type Point or Polygon, got {type(gdf.geometry.iloc[0])}")


if chunk:
n = 10
l = (len(gdf) // n) + 1
s = []
for i in range(n):
s.append(np.repeat(i, l))
cls = np.concatenate(s)

gdf['_chunked'] = cls[:len(gdf)]
category_key = '_chunked'


groups = np.unique(gdf[category_key])
colors = generate_colors(groups)
cells = []
@@ -169,6 +238,7 @@ def write_geojson(gdf: gpd.GeoDataFrame, path: str, category_key: str, extra_pro


def generate_colors(names):
from matplotlib import pyplot as plt
colors = plt.get_cmap('hsv', len(names))
color_dict = {}
for i in range(len(names)):
@@ -192,19 +262,19 @@ def read_parquet_schema_df(path: str) -> pd.DataFrame:
return schema


def cell_reader_factory(path) -> GDFReader:
def cell_reader_factory(path, reader_kwargs={}) -> GDFReader:
if path.endswith('.geojson'):
return GeojsonCellReader()
return GeojsonCellReader(**reader_kwargs)
elif path.endswith('.parquet'):
schema = read_parquet_schema_df(path)
if 'geometry' in schema['column'].values:
return GDFParquetCellReader()
return GDFParquetCellReader(**reader_kwargs)
else:
return XeniumParquetCellReader()
return XeniumParquetCellReader(**reader_kwargs)
else:
ext = path.split('.')[-1]
raise ValueError(f'Unknown file extension {ext} for a cell segmentation file, needs to be .geojson or .parquet')


def read_gdf(path) -> gpd.GeoDataFrame:
return cell_reader_factory(path).read_gdf(path)
def read_gdf(path, reader_kwargs={}) -> gpd.GeoDataFrame:
return cell_reader_factory(path, reader_kwargs).read_gdf(path)
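A short sketch combining the updated reader factory and the chunked GeoJSON writer above; file names and the pixel-size value are illustrative:

from hest.io.seg_readers import read_gdf, write_geojson

# a .parquet file without a 'geometry' column is routed to XeniumParquetCellReader,
# so reader_kwargs are forwarded to its constructor
gdf = read_gdf('cell_boundaries.parquet',
               reader_kwargs={'pixel_size_morph': 0.2125})
write_geojson(gdf, 'he_cell_seg.geojson', '', chunk=True)   # splits shapes into ~10 chunks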
451 changes: 451 additions & 0 deletions src/hest/pipeline.py

Large diffs are not rendered by default.

520 changes: 232 additions & 288 deletions src/hest/readers.py

Large diffs are not rendered by default.

162 changes: 162 additions & 0 deletions src/hest/registration.py
@@ -0,0 +1,162 @@
from __future__ import annotations

import os
from typing import Union

import geopandas as gpd
import numpy as np
from loguru import logger
from shapely import Polygon

from hest.io.seg_readers import groupby_shape, read_gdf
from hest.utils import (get_name_datetime,
value_error_str, verify_paths)
from hestcore.wsi import WSI


def register_dapi_he(
he_path: Union[str, WSI, np.ndarray, openslide.OpenSlide, CuImage], # type: ignore
dapi_path: str,
registrar_dir: str = "results/registration",
name = None,
max_non_rigid_registration_dim_px=10000,
micro_rigid_registrar_cls=None,
micro_rigid_registrar_params={},
micro_reg=True,
check_for_reflections=False,
reuse_registrar=False
) -> str:
""" Register the DAPI WSI to HE with a fine-grained ridig + non-rigid transform with Valis
Args:
dapi_path (str): path to a dapi WSI
he_path (str): path to an H&E WSI
registrar_dir (str, optional): the output base registration directory. Defaults to "results/registration".
name (str, optional): name of current experiment, the path to the output registrar will be {registrar_dir}/name if name is not None,
or {registrar_dir}/{date} otherwise. Defaults to None.
max_non_rigid_registration_dim_px (int, optional): largest edge of both WSI will be downscaled to this dimension during non-rigid registration. Defaults to 10000.
Returns:
str: path to the resulting Valis registrar
"""

try:
from valis import (affine_optimizer, feature_detectors, preprocessing,
registration)
from valis.micro_rigid_registrar import MicroRigidRegistrar
from valis.slide_io import BioFormatsSlideReader

from .SlideReaderAdapter import SlideReaderAdapter
except Exception:
import traceback
traceback.print_exc()
raise Exception("Valis needs to be installed independently. Please install Valis with `pip install valis-wsi` or follow instruction on their website")

verify_paths([dapi_path, he_path])

if name is None:
date = get_name_datetime()
registrar_dir = os.path.join(registrar_dir, date)
else:
registrar_dir = os.path.join(registrar_dir, name)

img_list = [
he_path,
dapi_path
]

registrar_path = os.path.join(registrar_dir, 'data/_registrar.pickle')

if reuse_registrar:
registration.init_jvm()
return registrar_path
registrar = registration.Valis(
'',
registrar_dir,
reference_img_f=he_path,
align_to_reference=True,
img_list=img_list,
check_for_reflections=check_for_reflections,
micro_rigid_registrar_params=micro_rigid_registrar_params,
micro_rigid_registrar_cls=micro_rigid_registrar_cls
)

registrar.register(
brightfield_processing_cls=preprocessing.HEDeconvolution,
reader_dict= {
he_path: [SlideReaderAdapter],
dapi_path: [BioFormatsSlideReader]
}
)

if micro_reg:
# Perform micro-registration on higher resolution images, aligning *directly to* the reference image
registrar.register_micro(
max_non_rigid_registration_dim_px=max_non_rigid_registration_dim_px,
align_to_reference=True,
brightfield_processing_cls=preprocessing.HEDeconvolution,
reference_img_f=he_path
)

return registrar_path


def warp_gdf_valis(
shapes: Union[gpd.GeoDataFrame, str],
path_registrar: str,
curr_slide_name: str,
n_workers=-1
) -> gpd.GeoDataFrame:
""" Warp some shapes (points or polygons) from an existing Valis registration
Args:
shapes (Union[gpd.GeoDataFrame, str]): shapes to warp. A `str` will be interpreted as a path a nucleus shape file, can be .geojson, or xenium .parquet (ex: nucleus_boundaries.parquet)
path_registrar (str): path to the .pickle file of an existing Valis registrar
Returns:
gpd.GeoDataFrame: warped shapes
"""

try:
from valis import registration
except Exception:
import traceback
traceback.print_exc()
raise Exception("Valis needs to be installed independently. Please install Valis with `pip install valis-wsi` or follow instruction on their website")


if isinstance(shapes, str):
gdf = read_gdf(shapes)
elif isinstance(shapes, gpd.GeoDataFrame):
gdf = shapes.copy()
else:
raise ValueError(value_error_str(shapes, 'shapes'))

registrar = registration.load_registrar(path_registrar)
slide_obj = registrar.get_slide(registrar.reference_img_f)
if isinstance(shapes.iloc[0].geometry, Polygon):
coords = gdf.geometry.get_coordinates(index_parts=True)
points_gdf = coords
idx = coords.index.get_level_values(0)
points_gdf['_polygons'] = idx # keep track of polygons
points = list(zip(points_gdf['x'], points_gdf['y']))
else:
points_gdf = gdf
gdf['_polygons'] = np.arange(len(points_gdf))
points = list(zip(gdf.geometry.x, gdf.geometry.y))

morph = registrar.get_slide(curr_slide_name)
logger.debug('warp with valis...')
warped = morph.warp_xy_from_to(points, slide_obj)
logger.debug('finished warping with valis')

if isinstance(shapes.iloc[0].geometry, Polygon):
points_gdf['xy'] = list(zip(warped[:, 0], warped[:, 1]))
aggr_df = groupby_shape(points_gdf, '_polygons', n_threads=0)
gdf.geometry = aggr_df.geometry
else:
gdf.geometry = gpd.points_from_xy(warped[:, 0], warped[:, 1])

return gdf
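A sketch of the intended two-step flow, DAPI-to-H&E registration followed by warping of shapes, based on the two functions above; paths and the slide name are illustrative, and Valis must be installed separately:

from hest.registration import register_dapi_he, warp_gdf_valis

registrar_path = register_dapi_he(
    'he.tif',                        # hypothetical H&E WSI (reference image)
    'morphology_focus.ome.tif',      # hypothetical DAPI WSI
    registrar_dir='results/registration',
    name='sample1',
)

warped = warp_gdf_valis(
    'nucleus_boundaries.parquet',    # shapes to warp (path or GeoDataFrame)
    registrar_path,
    curr_slide_name='morphology_focus.ome.tif',   # slide the shapes currently live on
)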

3 changes: 2 additions & 1 deletion src/hest/segmentation/TissueMask.py
@@ -1,7 +1,6 @@
import pickle
from typing import List

import cv2
import numpy as np
from PIL import Image

@@ -27,6 +26,8 @@ def load_tissue_mask(pkl_path: str, jpg_path: str, width: int, height: int) -> T
with Image.open(jpg_path) as img:
tissue_mask = np.array(img).copy()


import cv2
tissue_mask = cv2.resize(tissue_mask, (width, height))

mask = TissueMask(tissue_mask, contours_tissue, contours_holes)
310 changes: 143 additions & 167 deletions src/hest/segmentation/cell_segmenters.py

Large diffs are not rendered by default.

903 changes: 903 additions & 0 deletions src/hest/subtyping/atlas.py

Large diffs are not rendered by default.

375 changes: 375 additions & 0 deletions src/hest/subtyping/atlas_matchers.py
@@ -0,0 +1,375 @@
from __future__ import annotations

import os
import pickle
from abc import abstractmethod

import numpy as np
import pandas as pd
from loguru import logger

from hest.HESTData import unify_gene_names


def reduce_dim(X, indices=None, harmony=False):
import umap
from sklearn.decomposition import PCA

print('perform PCA...')
pca = PCA(n_components=50)
comps = pca.fit_transform(X)

if harmony:
print('perform harmony...')
meta_df = pd.DataFrame(indices, columns=['dataset'])
meta_df['dataset'] = meta_df['dataset'].astype(str)

import harmonypy as hm
vars_use = ['dataset']

X = hm.run_harmony(comps, meta_df, vars_use, verbose=False).Z_corr.transpose()
else:
X = comps

print('perform UMAP...')
reducer = umap.UMAP(verbose=False)
embedding = reducer.fit_transform(X)
return embedding


def get_per_cluster_types(cells: sc.AnnData, cluster_key='Cluster', type_key='cell_type_pred', pct=False):
clusters = np.unique(cells.obs[cluster_key])
ls = []
for cluster in clusters:
cell_types = cells[cells.obs[cluster_key] == cluster].obs[type_key]
types, counts = np.unique(cell_types, return_counts=True)
freqs = list(zip(list(types), list(counts)))
freqs = sorted(freqs, key=lambda x: x[1], reverse=True)
if pct:
s = np.array([a[1] for a in freqs]).sum()
freqs = [(a[0], str(round(100 * a[1] / s)) + '%') for a in freqs if (100 * a[1] / s) > 1]
ls.append([cluster, freqs])
return ls


def plot(plot_name, embedding, cell_types, indices=None):
import seaborn as sns
from matplotlib import pyplot as plt

print('find uniques...')

cell_types = np.array([str(s) for s in cell_types])
names, inverse = np.unique(cell_types, return_inverse=True)
plt.figure(figsize=(10, 6))


if indices is None:
indices = np.array([0 for _ in range(len(cell_types))])

palettes = ['tab10', 'tab10']
markers = ['o', '^']
for i in np.unique(indices):
obs_names = cell_types

sub_emb = embedding[indices==i, :]
sub_obs_names = obs_names[indices==i]
n = len(sub_emb)
k = 1000
idx = np.random.choice(np.arange(n), size=k, replace=False)
sub_emb = sub_emb[idx, :]
sub_obs_names = sub_obs_names[idx]

sns.scatterplot(
x=sub_emb[:, 0],
y=sub_emb[:, 1],
hue=sub_obs_names,
palette=palettes[i], # Choose a color palette
legend='full',
alpha=1,
s=8,
marker=markers[i]
)

plt.legend(
title='Your Legend Title',
bbox_to_anchor=(1.05, 1), # Adjust position to fit outside the plot area if needed
loc='upper left',
borderaxespad=0.,
ncol=4 # Number of columns
)


#plt.legend(markerscale=2, labels=names[np.unique(inverse[indices==i])], ncol=6, loc='upper center', bbox_to_anchor=(0.5, -0.1))
plt.title('UMAP Visualization of Sparse Matrix')
plt.xlabel('UMAP Component 1')
plt.ylabel('UMAP Component 2')
plt.grid(True)
plt.tight_layout()
plt.savefig(plot_name, dpi=300)


def cache_or_read(plot_name, X, indices=None, harmony=None):
cached_emb_path = os.path.join('emb_cache', plot_name + 'embedding.pkl')
if not os.path.exists(cached_emb_path):
print('cache empty, perform dim reduction...')
embedding = reduce_dim(X, indices=indices, harmony=harmony)

with open(cached_emb_path, 'wb') as f:
pickle.dump(embedding, f)

with open(cached_emb_path, 'rb') as f:
embedding = pickle.load(f)
return embedding


class SCMatcher:

def _filter_na_cells(self, cells, atlas_cells, level):
nan_cells = cells.obs.isna().any(axis=1)
cells = cells[~nan_cells]

nan_cells = (atlas_cells.obs[level] == 'NA')
atlas_cells = atlas_cells[~nan_cells]

return cells, atlas_cells

def _filter_common_genes(self, cells, atlas_cells, unify_genes):
inter_genes = np.intersect1d(cells.var_names, atlas_cells.var_names)
missing = set(cells.var_names) - set(atlas_cells.var_names)
if len(missing) > 0:
missing_str = missing if len(missing) < 100 else str(missing)[:500] + '...'

warning_str = f"{len(missing)} out of {len(cells.var_names)} genes are missing in the Atlas: {missing_str}"
if not unify_genes:
warning_str += ". Consider passing unify_genes=True"
logger.warning(warning_str)

return cells[:, inter_genes], atlas_cells[:, inter_genes]


def unify_genes(self, cells, atlas_cells, species):
logger.info('unifying source gene names')
cells = unify_gene_names(cells, species)
logger.info('unifying atlas gene names')
atlas_cells = unify_gene_names(atlas_cells, species)

return cells, atlas_cells


def match_atlas(
self,
name,
cells: sc.AnnData,
atlas_cells: sc.AnnData,
unify_genes=False,
species='hsapiens',
level='cell_types',
plot_atlas=False,
plot_preds=False,
**kwargs
) -> sc.AnnData:

if unify_genes:
cells, atlas_cells = self.unify_genes(cells, atlas_cells, species)

cells, atlas_cells = self._filter_common_genes(cells, atlas_cells, unify_genes)

cells, atlas_cells = self._filter_na_cells(cells, atlas_cells, level)

atlas_cells.obs['cell_types'] = atlas_cells.obs[level]

if plot_atlas:
embeddings = reduce_dim(atlas_cells.X)
plot(f'{name}_atlas_plot.jpg', embeddings, atlas_cells.obs['cell_types'].values)

preds = self.match_atlas_imp(name, cells, atlas_cells, level, **kwargs)

if plot_preds:
embeddings = reduce_dim(cells.X)
plot(f'{name}_preds_plot.jpg', embeddings, preds.values)

cells.obs['cell_type_pred'] = preds

return cells


@abstractmethod
def match_atlas_imp(
self,
name,
cells,
atlas_cells,
level
):
""" Output prediction for each cell """
pass


class HarmonyMatcher(SCMatcher):
def _prepare_data(self, atlas_cells, cells, sub_atlas, sub_cells, cluster_key, k, random_state):
import scanpy as sc

sc.pp.subsample(atlas_cells, fraction=sub_atlas, random_state=random_state)
sc.pp.subsample(cells, fraction=sub_cells, random_state=random_state)
assert np.array_equal(cells.to_df().columns, atlas_cells.to_df().columns)

df_combined = pd.concat([cells.to_df(), atlas_cells.to_df()])
adata_comb = sc.AnnData(df_combined)

if cluster_key is not None:
logger.info(f"Using custom clusters in .obs[{cluster_key}]")
cells.obs = cells.obs[['cluster_key']]
else:
from sklearn.cluster import KMeans
logger.info(f"Couldn't find custom clusters, perform kmeans...")
kmeans = KMeans(n_clusters=k, random_state=random_state).fit(cells.X)
cells.obs['kmeans_clusters'] = kmeans.labels_.astype(str)
cells.obs = cells.obs.rename(columns={'kmeans_clusters': 'cell_types'})

adata_comb.obs = pd.concat([cells.obs, atlas_cells.obs])
adata_comb.obs = pd.DataFrame(adata_comb.obs['cell_types'].values, columns=['cell_types'])

indices = np.concatenate((np.ones(len(cells.obs), dtype=np.int32), np.zeros(len(atlas_cells.obs), dtype=np.int32)))
return adata_comb, indices


def _get_k_nearest(self, X, indices, cell_types, k=5, n=30):
from sklearn.neighbors import KNeighborsClassifier

xen_emb = X[indices == 1]
xen_types = cell_types[indices == 1]['cell_types']
atlas_emb = X[indices == 0]
atlas_types = cell_types[indices == 0]['cell_types']
xen_clust_names = np.unique(xen_types)
logits = []
model = KNeighborsClassifier(metric='sqeuclidean', n_neighbors=k)
model.fit(atlas_emb, atlas_types)
for xen_clust_name in xen_clust_names:
xen_emb_i = xen_emb[(xen_types == xen_clust_name).values.flatten(), :]
idx = np.random.choice(np.arange(len(xen_emb_i)), n)
xen_emb_i = xen_emb_i[idx]


clusters_pred = model.predict(xen_emb_i)
names, counts = np.unique(clusters_pred, return_counts=True)
counts = counts / counts.sum()
logits_clust = dict(zip(names, counts))
logits_clust = dict(sorted(logits_clust.items(), key=lambda item: item[1], reverse=True))
logits.append([xen_clust_name, logits_clust])

logits_df = pd.DataFrame(logits)
mean_highest = np.array([max(v[1].values()) for v in logits]).mean()
return logits_df, mean_highest


def match_atlas_imp(
self,
name,
cells,
atlas_cells,
sub_atlas=1,
sub_cells=0.00005,
mode="cells",
chunk_len=None,
device=None,
level='cell_types',
cluster_key=None,
k=10,
random_state=None
):

adata_comb, indices = self._prepare_data(atlas_cells, cells, sub_atlas, sub_cells, cluster_key, k, random_state)

embedding = cache_or_read(name, adata_comb.X, indices=indices, harmony=True)


#full_adata_comb, full_indices = self._prepare_data(atlas_cells, cells, sub_atlas=1, sub_cells=1)
k_nearest, mean_highest = self._get_k_nearest(embedding, indices, adata_comb.obs)
k_nearest.to_csv(os.path.join('clusters', name + '_k_nearest.csv'))
cluster_map = k_nearest
cluster_map['cell_types'] = cluster_map[1].apply(lambda row: max(row, key=row.get))
cluster_map['cluster'] = cluster_map[0]
cluster_map = cluster_map[['cluster', 'cell_types']]
cells.obs['cell_id'] = cells.obs.index
#plot_umap(name + 'match_', embedding, adata_comb.obs, indices=indices)
merged = cells.obs.merge(cluster_map, left_on='cell_types', right_on='cluster', how='left')
merged.index = merged['cell_id']
merged = merged['cell_types_y'].values
return merged

class TangramMatcher(SCMatcher):

def match_atlas_imp(
self,
name,
cells,
atlas_cells,
level,
mode="clusters",
chunk_len=None,
device=None,
random_state=None
):
import scanpy as sc
import tangram as tg
import torch

ad_sp = cells
ad_sc = atlas_cells

ad_sc.obs['subclass_label'] = ad_sc.obs[level]


if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if mode == 'cells':
if chunk_len is None:
raise ValueError('please provide `chunk_len` when in mode `cells`')

n_chunks = len(cells) // chunk_len
chunks = []
for i in range(0, n_chunks + 1):
start = i * chunk_len
end = min(i * chunk_len + chunk_len, len(cells))
chunk = cells[start:end]

tg.pp_adatas(ad_sc, chunk, genes=None)

ad_map = tg.map_cells_to_space(
ad_sc,
chunk,
num_epochs=200,
mode=mode,
cluster_label='subclass_label',
device=device,
random_state=random_state)
tg.project_cell_annotations(ad_map, chunk, annotation='subclass_label')
chunks.append(chunk)
ad_sp = sc.concat(chunks)
ad_sp.obs.index = cells.obs.index
else:
tg.pp_adatas(ad_sc, ad_sp, genes=None)

ad_map = tg.map_cells_to_space(
ad_sc,
ad_sp,
num_epochs=200,
mode=mode,
cluster_label='subclass_label',
device=device)

tg.project_cell_annotations(ad_map, ad_sp, annotation='subclass_label')

preds = pd.Series(ad_sp.obsm['tangram_ct_pred'].idxmax(axis=1), name=level)
preds.index.name = 'cell_id'
return preds


def matcher_factory(name):
if name == 'harmony':
return HarmonyMatcher()
elif name == 'tangram':
return TangramMatcher()
else:
raise ValueError(f"unknown cell matcher {name}")
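A usage sketch of the matcher factory above, assuming `cells` and `atlas_cells` are already-loaded AnnData objects (both assumptions):

matcher = matcher_factory('tangram')   # or 'harmony'
cells = matcher.match_atlas(
    'sample1',            # hypothetical experiment name
    cells,
    atlas_cells,
    unify_genes=True,
    level='cell_types',
)
print(cells.obs['cell_type_pred'].value_counts())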
201 changes: 201 additions & 0 deletions src/hest/subtyping/subtyping.py
@@ -0,0 +1,201 @@
from __future__ import annotations

import os
import shutil

import numpy as np
import pandas as pd
from tqdm import tqdm

from hest.subtyping.atlas import (get_atlas_from_name, get_cells_with_clusters,
sc_atlas_factory)
from hest.subtyping.atlas_matchers import matcher_factory
from hest.utils import get_path_from_meta_row

level = 'predicted.celltypel1'


def assign_cell_types(cell_adata, atlas_name, name, method='tangram', full_atlas=False, organ=None, **matcher_kwargs) -> sc.AnnData:
matcher = matcher_factory(method)


if organ is not None:
atlas_cells = sc_atlas_factory(organ)
else:
atlas_cells = get_atlas_from_name(atlas_name)()

if not full_atlas:
atlas_cells = atlas_cells.get_downsampled()
else:
atlas_cells = atlas_cells.get_full()

preds = matcher.match_atlas(
name,
cell_adata,
atlas_cells,
**matcher_kwargs)
return preds


def assign_cell_types_hest(meta_df, method='tangram'):
for _, row in meta_df.iterrows():
path = get_path_from_meta_row(row)
organ = row['Organ']

cell_adata = get_cells_with_clusters(path, cluster_path=os.path.join(path, 'analysis/clustering/gene_expression_kmeans_10_clusters/clusters.csv'), k=None)
assign_cell_types(cell_adata, organ, method=method)


def place_in_right_folder(path, rename=False):
paths = os.listdir(path)
ids = np.unique([path.split('.')[0] for path in paths])
for id in ids:
os.makedirs(os.path.join(path, id), exist_ok=True)
for f in paths:
src = os.path.join(path, f)

if not os.path.isfile(src):
continue

name = f.split('.')[0]
if rename:
if 'barcodes' in f:
f = 'barcodes.tsv.gz'
elif 'features' in f or 'genes' in f:
f = 'features.tsv.gz'
elif 'matrix' in f:
f = 'matrix.mtx.gz'
dst = os.path.join(path, name, f)
shutil.move(src, dst)


def join_MEX(dir):
import scanpy as sc
joined_adata = None
for f in tqdm(os.listdir(dir)):
path = os.path.join(dir, f)
if os.path.isdir(path):
adata = sc.read_10x_mtx(path)
if joined_adata is None:
joined_adata = adata
else:
joined_adata = joined_adata.concatenate(adata, join='outer')
return joined_adata



xenium_cell_types_map = {
'Adipocytes': 'Stromal',
'B Cells': 'B-cells',
'CD163+ Macrophage': 'Myeloid',
'CD83+ Macrophage': 'Myeloid',
'CTLA4+ T Cells': 'T-cells',
'DST+ Myoepithelial': 'Normal Epithelial',
'ESR1+ Epithelial': 'Normal Epithelial',
'Endothelial': 'Endothelial',
'ITGAX+ Macrophage': 'Myeloid',
'Mast Cells': 'Myeloid',
'Not Plotted': 'NA',
'OPRPN+ Epithelial': 'Normal Epithelial',
'PIGR+ Epithelial': 'Normal Epithelial',
'Plasma Cells': 'B-cells',
'Plasmacytoid Dendritic': 'B-cells',
'Stromal Normal': 'Stromal',
'TRAC+ Cells': 'T-cells',
'Transitional Cells': 'NA',
'Tumor': 'Cancer Epithelial',
'Tumor Associated Stromal': 'Stromal',
'B_Cells': 'B-cells',
'CD4+_T_Cells': 'T-cells',
'CD8+_T_Cells': 'T-cells',
'DCIS_1': 'Cancer Epithelial',
'DCIS_2': 'Cancer Epithelial',
'IRF7+_DCs': 'Myeloid',
'Invasive_Tumor': 'Cancer Epithelial',
'LAMP3+_DCs': 'Myeloid',
'Macrophages_1': 'Myeloid',
'Macrophages_2': 'Myeloid',
'Mast_Cells': 'Myeloid',
'Myoepi_ACTA2+': 'Normal Epithelial',
'Myoepi_KRT15+': 'Normal Epithelial',
'Perivascular-Like': 'Stromal',
'Prolif_Invasive_Tumor': 'Cancer Epithelial',
'Stromal': 'Stromal',
'Stromal_&_T_Cell_Hybrid': 'NA',
'T_Cell_&_Tumor_Hybrid': 'NA',
'Unlabeled': 'NA',
'NK Cells': 'T-cells',
'Macrophage 1': 'Myeloid',
'Macrophage 2': 'Myeloid',
'Mast Cells': 'Myeloid',
'Plasmablast': 'B-cells',
'Invasive Tumor': 'Cancer Epithelial',
'Undefined': 'NA',
'T Cells': 'T-cells',
'DCIS': 'Cancer Epithelial',
'ACTA2+ Myoepithelial': 'Normal Epithelial',
'KRT15+ Myoepithelial': 'Normal Epithelial'
}

breast3_cell_types_map = {
'basal na': 'Normal Epithelial',
'bcells na': 'B-cells',
'fibroblasts na': 'Stromal',
'lumhr na': 'Normal Epithelial',
'lumsec na': 'Normal Epithelial',
'lumsec proliferating': 'Cancer Epithelial',
'lymphatic na': 'Endothelial',
'myeloid na': 'Myeloid',
'myeloid proliferating': 'Myeloid',
'pericytes na': 'Stromal',
'tcells na': 'T-cells',
'tcells proliferating': 'T-cells',
'vascular na': 'Endothelial'
}

def eval_cell_type_assignment(pred, path_gt, map_pred, map_gt, name, key='cell_type_pred'):
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import balanced_accuracy_score, confusion_matrix

if isinstance(pred, str):
df1 = pd.read_csv(pred)
else:
df1 = pred
df2 = pd.read_csv(path_gt)
if not isinstance(df2['Barcode'].iloc[0], str):
df2['Barcode'] = df2['Barcode'].astype(int).astype(str)

merged = df1.merge(df2, left_index=True, right_on='Barcode', how='inner')

merged['Mapped'] = [map_gt[x] for x in merged['Cluster'].values]



mask_NA = merged['Mapped'] == 'NA'
merged = merged[~mask_NA]

mapped_pred = [map_pred[x] for x in merged[key].values]
mapped_gt = [map_gt[x] for x in merged['Cluster'].values]

labels = sorted(list(set(mapped_gt) | set(mapped_pred)))
cm = confusion_matrix(mapped_gt, mapped_pred, labels=labels)

balanced_acc = round(balanced_accuracy_score(mapped_gt, mapped_pred), 4)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title(f'Confusion matrix, cell type prediction (balanced_acc={balanced_acc})')

plt.tight_layout()

plt.savefig(name + 'confusion_matrix.jpg', dpi=150)


def eval_all_cell_type_assignments(meta_df):
for id in meta_df['id']:
for method in ['tangram', 'harmony']:
eval_cell_type_assignment(f'cell_type_preds/{method}_{id}k=10.csv', f'cell_type_preds/gt_{id}.csv', breast3_cell_types_map, xenium_cell_types_map, name=f'{id}_{method}_')
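A sketch of the higher-level entry point above; the organ value and the `cell_adata` object are illustrative assumptions:

from hest.subtyping.subtyping import assign_cell_types

cells = assign_cell_types(
    cell_adata,            # AnnData of segmented cells (assumed loaded)
    atlas_name=None,       # ignored when an organ is given
    name='sample1',
    method='tangram',      # or 'harmony'
    organ='Breast',        # selects an organ-level atlas via sc_atlas_factory
)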

371 changes: 257 additions & 114 deletions src/hest/utils.py

Large diffs are not rendered by default.

6 changes: 0 additions & 6 deletions tests/hest_tests.py
@@ -19,12 +19,6 @@
from hest.readers import VisiumReader
from hest.utils import load_image

try:
from cucim import CuImage
except ImportError:
CuImage = None
CucimWarningSingleton.warn()


class TestHESTReader(unittest.TestCase):

2 changes: 1 addition & 1 deletion tutorials/README.md
@@ -23,6 +23,6 @@ This notebook is dedicated to visualizing batch effects within the HEST-1k datas

## Contributions

External contributions are welcome! If you have ideas for improving these tutorials or would like to contribute, please feel free to reach out to [gjaume@bwh.harvard.edu](mailto:gjaume@bwh.harvard.edu).
External contributions are welcome! If you have ideas for improving these tutorials or would like to contribute, please feel free to reach out to [gjaume@bwh.harvard.edu](mailto:gjaume@bwh.harvard.edu) (cc: [homedoucetpaul@gmail.com](mailto:homedoucetpaul@gmail.com)).

If you encounter any issues, please check the GitHub Issues section, as other users might have already faced similar challenges.