Config-based prediction with Xarray-based output format (#132)
* use callback to write prediction embeddings
* moving over the script to compute infection score from contrastive_update
* delete unused stem module
* organize scripts and CLIs for contrastive phenotyping
* add dependencies for prediction
* export embedding dataset reader function
* add more plots to script
* use real paths in predict config
* do not require seaborn and umap-learn for base install
* use relative path in example job script
* add docstrings for embedding writer and reader
* don't assign unused grid object
* show time and id as hover data in interactive plot
* fix typo
* fix script to test data i/o
* ignore accidental lightning_logs
* add plotly and nbformat to visual dependencies
* tweak predict cli example
* add another plot type - raw features of random samples
* comment on speed of clustermap
* add prediction config example to specify log path
* simplify env var in job script and match cpu count with config
* vectorize string concatenation

---------

Co-authored-by: Shalin Mehta <[email protected]>
1 parent fb2ec0f, commit 308392c
Showing 19 changed files with 590 additions and 227 deletions.
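The new workflow runs prediction from a YAML config, writes embeddings to a Zarr store via the EmbeddingWriter callback, and reads them back as an Xarray dataset. A minimal sketch of the round trip, assuming the CLI and reader added in this commit (the output path here is a placeholder):

# run prediction first; the EmbeddingWriter callback in predict.yml writes the store:
#   python -m viscy.cli.contrastive_triplet predict -c predict.yml
from viscy.light.embedding_writer import read_embedding_dataset

dataset = read_embedding_dataset("predictions.zarr")  # placeholder path
features = dataset["features"]  # (sample, features) array with track metadata coords
print(features.sizes)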
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py (145 additions, 0 deletions)
# %%
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from umap import UMAP

from viscy.light.embedding_writer import read_embedding_dataset

# %%
dataset = read_embedding_dataset(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04-tokenized-drop_path_0_0.zarr"
)
dataset

# %%
# load all unprojected features:
features = dataset["features"]
# or select a well:
# features = features[features["fov_name"].str.contains("B/4")]
features

# %%
# examine raw features
random_samples = np.random.randint(0, dataset.sizes["sample"], 700)
# concatenate fov_name, track_id, and t to create a unique sample identifier
sample_id = (
    features["fov_name"][random_samples]
    + "-"
    + features["track_id"][random_samples].astype(str)
    + "-"
    + features["t"][random_samples].astype(str)
)
px.imshow(
    features.values[random_samples],
    labels={
        "x": "feature",
        "y": "sample",
        "color": "value",
    },  # change labels to match our metadata
    y=sample_id,
    # show fov_name as y-axis
)

# %%
scaled_features = StandardScaler().fit_transform(features.values)

umap = UMAP()

embedding = umap.fit_transform(scaled_features)
features = (
    features.assign_coords(UMAP1=("sample", embedding[:, 0]))
    .assign_coords(UMAP2=("sample", embedding[:, 1]))
    .set_index(sample=["UMAP1", "UMAP2"], append=True)
)
features

# %%
sns.scatterplot(
    x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
)


# %%
def load_annotation(da, path, name, categories: dict | None = None):
    annotation = pd.read_csv(path)
    annotation["fov_name"] = "/" + annotation["fov ID"]
    annotation = annotation.set_index(["fov_name", "id"])
    mi = pd.MultiIndex.from_arrays(
        [da["fov_name"].values, da["id"].values], names=["fov_name", "id"]
    )
    selected = annotation.loc[mi][name]
    if categories:
        selected = selected.astype("category").cat.rename_categories(categories)
    return selected


# %%
ann_root = Path(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track"
)

infection = load_annotation(
    features,
    ann_root / "tracking_v1_infection.csv",
    "infection class",
    {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
)
division = load_annotation(
    features,
    ann_root / "cell_division_state.csv",
    "division",
    {0: "non-dividing", 2: "dividing"},
)


# %%
sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=division, s=7, alpha=0.8)

# %%
sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8)

# %%
ax = sns.histplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, bins=64)
sns.move_legend(ax, loc="lower left")

# %%
sns.displot(
    x=features["UMAP1"],
    y=features["UMAP2"],
    kind="hist",
    col=infection,
    bins=64,
    cmap="inferno",
)

# %%
# interactive scatter plot to associate clusters with specific cells
px.scatter(
    data_frame=pd.DataFrame(
        {k: v for k, v in features.coords.items() if k != "features"}
    ),
    x="UMAP1",
    y="UMAP2",
    color=(infection.astype(str) + " " + division.astype(str)).rename("annotation"),
    hover_name="fov_name",
    hover_data=["id", "t"],
)

# %%
# cluster features in heatmap directly
# this is very slow for large datasets even with fastcluster installed
inf_codes = pd.Series(infection.values.codes, name="infection")
lut = dict(zip(inf_codes.unique(), "brw"))
row_colors = inf_codes.map(lut)

g = sns.clustermap(
    scaled_features, row_colors=row_colors.to_numpy(), col_cluster=False, cbar_pos=None
)
# clustermap returns a ClusterGrid; hide row ticks via its heatmap axes
g.ax_heatmap.set_yticks([])
# %%
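Because UMAP1 and UMAP2 are appended to the sample multi-index above, a single track's trajectory through embedding space can also be pulled out by its metadata coordinates. A small sketch, assuming fov_name and track_id are levels of the sample index (the identifier values are placeholders):

# select one track's embedding trajectory (hypothetical identifiers)
track = features.sel(fov_name="/B/4/0", track_id=1)
sns.lineplot(x=track["UMAP1"], y=track["UMAP2"], sort=False)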
applications/contrastive_phenotyping/contrastive_cli/predict.yml (50 additions, 0 deletions)
seed_everything: 42
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 32-true
  callbacks:
    - class_path: viscy.light.embedding_writer.EmbeddingWriter
      init_args:
        output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr"
  # edit the following lines to specify logging path
  # - class_path: lightning.pytorch.loggers.TensorBoardLogger
  #   init_args:
  #     save_dir: /path/to/save_dir
  #     version: name-of-experiment
  #     log_graph: True
  inference_mode: true
model:
  backbone: convnext_tiny
  in_channels: 2
  in_stack_depth: 15
  stem_kernel_size: [5, 4, 4]
data:
  data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
  tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
  source_channel:
    - Phase3D
    - RFP
  z_range: [28, 43]
  batch_size: 32
  num_workers: 16
  initial_yx_patch_size: [192, 192]
  final_yx_patch_size: [192, 192]
  normalizations:
    - class_path: viscy.transforms.NormalizeSampled
      init_args:
        keys: [Phase3D]
        level: fov_statistics
        subtrahend: mean
        divisor: std
    - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
      init_args:
        keys: [RFP]
        lower: 50
        upper: 99
        b_min: 0.0
        b_max: 1.0
return_predictions: false
ckpt_path: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/tokenized-drop-path-0.0/checkpoints/epoch=96-step=23377.ckpt
applications/contrastive_phenotyping/contrastive_cli/predict_slurm.sh (21 additions, 0 deletions)
#!/bin/bash

#SBATCH --job-name=contrastive_predict
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=gpu
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=7G
#SBATCH --time=0-01:00:00

module load anaconda/2022.05
# Update to use the actual prefix
conda activate $MYDATA/envs/viscy

scontrol show job $SLURM_JOB_ID

# use absolute path in production
config=./predict.yml
cat $config
srun python -m viscy.cli.contrastive_triplet predict -c $config
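To submit the job, run sbatch from the directory containing predict.yml, since the script uses a relative config path as noted above:

sbatch predict_slurm.sh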
File renamed without changes.
File renamed without changes.
File renamed without changes.
...cations/contrastive_phenotyping/contrastive_scripts/predict_infection_score_supervised.py (166 additions, 0 deletions)
from argparse import ArgumentParser
import os
import warnings

import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm

from viscy.data.triplet import TripletDataModule

warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).",
)

# %% Paths and constants
save_dir = (
    "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4"
)

# rechunked data
data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr"

# updated tracking data
tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

source_channel = ["background_mask", "uninfected_mask", "infected_mask"]
z_range = (0, 1)
batch_size = 1  # match the number of FOVs being processed so that no data is left over
# set to 15 for full, 12 for infected, and 8 for uninfected

# non-rechunked data
data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

# updated tracking data
tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

source_channel_1 = ["Nuclei_prediction_labels"]


# %% Define the main function for prediction
def main(hparams):
    # Initialize the data module for prediction; re-do embeddings at 224 x 224
    data_module = TripletDataModule(
        data_path=data_path,
        tracks_path=tracks_path,
        source_channel=source_channel,
        z_range=z_range,
        initial_yx_patch_size=(224, 224),
        final_yx_patch_size=(224, 224),
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    data_module.setup(stage="predict")

    print(f"Total prediction dataset size: {len(data_module.predict_dataset)}")

    dataloader = DataLoader(
        data_module.predict_dataset,
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    # Initialize the second data module for segmentation masks
    seg_data_module = TripletDataModule(
        data_path=data_path_1,
        tracks_path=tracks_path_1,
        source_channel=source_channel_1,
        z_range=z_range,
        initial_yx_patch_size=(224, 224),
        final_yx_patch_size=(224, 224),
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    seg_data_module.setup(stage="predict")

    seg_dataloader = DataLoader(
        seg_data_module.predict_dataset,
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    # Initialize lists to store average values
    background_avg = []
    uninfected_avg = []
    infected_avg = []

    for batch, seg_batch in tqdm(
        zip(dataloader, seg_dataloader),
        desc="Processing batches",
        total=len(data_module.predict_dataset),
    ):
        anchor = batch["anchor"]
        seg_anchor = seg_batch["anchor"].int()

        # Extract the fov_name and id from the batch
        fov_name = batch["index"]["fov_name"][0]
        cell_id = batch["index"]["id"].item()

        fov_dirs = fov_name.split("/")
        # Construct the path to the CSV file
        csv_path = os.path.join(
            tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv"
        )

        # Read the CSV file
        df = pd.read_csv(csv_path)

        # Find the row with the specified id and extract the track_id
        track_id = df.loc[df["id"] == cell_id, "track_id"].values[0]

        # Create a boolean mask where segmentation values equal the track_id
        mask = seg_anchor == track_id
        # mask = (seg_anchor > 0)

        # Find the most frequent non-zero value in seg_anchor
        # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True)
        # most_frequent_value = unique[np.argmax(counts)]

        # # Create a boolean mask where segmentation values equal the most frequent value
        # mask = (seg_anchor == most_frequent_value)

        # Expand the mask to match the anchor tensor shape
        mask = mask.expand(1, 3, 1, 224, 224)

        # Calculate average values for each channel (background, uninfected, infected) using the mask
        background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item())
        uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item())
        infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item())

    # Convert lists to numpy arrays
    background_avg = np.array(background_avg)
    uninfected_avg = np.array(uninfected_avg)
    infected_avg = np.array(infected_avg)

    print("Average values per cell for each mask calculated.")
    print("Background average shape:", background_avg.shape)
    print("Uninfected average shape:", uninfected_avg.shape)
    print("Infected average shape:", infected_avg.shape)

    # Save the averages as .npy files
    np.save(os.path.join(save_dir, "background_avg.npy"), background_avg)
    np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg)
    np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--margin", type=float, default=0.5)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--schedule", type=str, default="Constant")
    parser.add_argument("--log_steps_per_epoch", type=int, default=10)
    parser.add_argument("--embedding_len", type=int, default=256)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--log_every_n_steps", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()
    main(args)
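The saved per-cell channel averages can then be reduced to a class call per cell. A hypothetical downstream snippet (the argmax rule and local paths are illustrative, not part of this commit):

import numpy as np

# load the per-cell averages written by the script above (hypothetical local paths)
background = np.load("embeddings4/background_avg.npy")
uninfected = np.load("embeddings4/uninfected_avg.npy")
infected = np.load("embeddings4/infected_avg.npy")

# one illustrative scoring rule: call each cell by its highest-scoring mask channel
scores = np.stack([background, uninfected, infected], axis=1)
labels = scores.argmax(axis=1)  # 0=background, 1=uninfected, 2=infected
print(np.bincount(labels, minlength=3))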