Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conversion script for GIF to 3D MHA #3

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions convert_attention_gif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
from PIL import Image
import SimpleITK as sitk
import time

def split_gif_to_frames(gif_path, output_folder):
    """Extract every frame of a GIF and save each one as a numbered PNG.

    Parameters:
        gif_path (str): Path to the source .gif file.
        output_folder (str): Directory for the frame_###.png files
            (created if it does not exist).
    """
    os.makedirs(output_folder, exist_ok=True)

    with Image.open(gif_path) as gif:
        total = gif.n_frames
        print(f"Total frames: {total}")

        for idx in range(total):
            # seek() positions the GIF at frame idx; save() then writes that frame.
            gif.seek(idx)
            out_path = os.path.join(output_folder, f"frame_{idx:03d}.png")
            gif.save(out_path, format="PNG")

    n_frames = len(os.listdir(output_folder))
    print(f"Saved {n_frames} frames")

def stack_png_to_mha(input_folder, output_mha_path):
    """Stack sequential RGB PNG slices from a folder into one 3D MHA volume.

    Parameters:
        input_folder (str): Directory of sequentially named .png slice files.
        output_mha_path (str): Destination path for the stacked .mha file.

    Raises:
        ValueError: If the folder contains no PNG files, or none of them
            are 3-channel (RGB) images.
    """
    # PNG files sorted by filename (assuming sequential frame_###.png naming).
    png_files = sorted(f for f in os.listdir(input_folder) if f.endswith(".png"))

    if not png_files:
        raise ValueError("No PNG files found in the provided directory")

    # Read the first PNG to report slice geometry and channel count.
    first_image = sitk.ReadImage(
        os.path.join(input_folder, png_files[0]),
        outputPixelType=sitk.sitkVectorUInt8,
    )
    print(
        "slice size:", first_image.GetSize(),
        ", num_components:", first_image.GetNumberOfComponentsPerPixel(),
    )

    # Load every RGB slice; frames without exactly 3 channels (e.g. grayscale
    # or palette images) are skipped, but now with a visible warning instead
    # of silently shrinking the stack.
    image_list = []
    skipped = 0
    for png_file in png_files:
        img = sitk.ReadImage(
            os.path.join(input_folder, png_file), sitk.sitkVectorUInt8
        )
        if img.GetNumberOfComponentsPerPixel() == 3:
            image_list.append(img)
        else:
            skipped += 1
    if skipped:
        print(f"Warning: skipped {skipped} non-RGB frame(s)")

    if not image_list:
        # Fix: JoinSeries on an empty list fails with an obscure ITK error;
        # fail early with a clear message instead.
        raise ValueError("No 3-channel PNG frames found to stack")

    # Stack the slices along a new third dimension (Z-axis).
    print("number of frames in stack:", len(image_list))
    stacked_image = sitk.JoinSeries(image_list)
    print(
        "stack size:", stacked_image.GetSize(),
        ", stack n_components:", stacked_image.GetNumberOfComponentsPerPixel(),
    )

    # Save the stacked image as a single 3D MHA file.
    sitk.WriteImage(stacked_image, output_mha_path)
    print(f"Stacked MHA saved as {output_mha_path}")

if __name__ == "__main__":
    ## Input directory (including subfolder for attention GIFs).
    ATTENTION_PARENT_DIR = "/data/bodyct/experiments/lung-malignancy-fairness-shaurya/nlst/sybil_attentions"
    ATTENTION_GIF_DIR = f"{ATTENTION_PARENT_DIR}/attention_gifs"
    # Fix: original referenced undefined ATTENTION_GIF_DIR_DIR (NameError) and
    # stripped the last 4 chars of *every* directory entry; only .gif files
    # should yield series IDs.
    subfolders = [f[:-4] for f in os.listdir(ATTENTION_GIF_DIR) if f.endswith(".gif")]

    ## Make subdirectories for output filetypes.
    ATTENTION_PNG_DIR = f"{ATTENTION_PARENT_DIR}/attention_pngs"
    os.makedirs(ATTENTION_PNG_DIR, exist_ok=True)
    ATTENTION_MHA_DIR = f"{ATTENTION_PARENT_DIR}/attention_mhas"
    os.makedirs(ATTENTION_MHA_DIR, exist_ok=True)

    print("Starting conversion!")
    start_time_conv = time.time()

    for i, seriesid in enumerate(subfolders):
        print(f"\n{i+1} / {len(subfolders)}: converting {seriesid} ...")

        gif_path = f"{ATTENTION_GIF_DIR}/{seriesid}.gif"
        png_folder = f"{ATTENTION_PNG_DIR}/{seriesid}"
        mha_path = f"{ATTENTION_MHA_DIR}/{seriesid}.mha"

        split_gif_to_frames(gif_path, png_folder)
        # Fix: original passed undefined name output_mha_path (NameError);
        # the computed destination is mha_path.
        stack_png_to_mha(png_folder, mha_path)

    end_time_conv = time.time()
    print(f"Total time for conversion: {end_time_conv - start_time_conv} seconds")
201 changes: 201 additions & 0 deletions inference_shaurya.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import os
from sybil import Serie, Sybil
import time
import json
import pandas as pd
import numpy as np
import csv
import traceback
from sybil import visualize_attentions

# Root of the experiment data on the cluster filesystem.
EXPERIMENT_DIR = r"/data/bodyct/experiments/lung-malignancy-fairness-shaurya"

# Name of CSV file with series instance UIDs to extract
csv_input = rf"{EXPERIMENT_DIR}/nlst/sybil_tp_train_top25.csv"
# All inference outputs (CSVs, error logs, attention GIFs) are written here.
INFERENCE_DIR = rf"{EXPERIMENT_DIR}/nlst/sybil_attentions_tp_train"

os.makedirs(INFERENCE_DIR, exist_ok=True)
# Parent folder holding one DICOM subfolder per series.
ParentDirectory = rf"{EXPERIMENT_DIR}/nlst/DICOM_files"

# Name of the output csv file
csvoutput = rf"{INFERENCE_DIR}/inference.csv"  # per-series scores, appended incrementally
csvoutput2 = rf"{INFERENCE_DIR}/inference2.csv"  # same table, rewritten once at the end
csv_error = rf"{INFERENCE_DIR}/error_ids.csv"  # UIDs of series that raised during inference
csvattentionoutput = rf"{INFERENCE_DIR}/output_attention_scores.csv"
log_file = rf"{INFERENCE_DIR}/errorlog.txt"  # full tracebacks for failed series


start_time_model = time.time()
# Load a trained model
model = Sybil("sybil_ensemble")


def get_series_instance_uids(csv_input, n=None):
    """Read series instance UIDs from a CSV, de-duplicated in first-seen order.

    Parameters:
        csv_input (str): Path to a CSV with a "SeriesInstanceUID" column.
        n (int | None): Optional cap on how many UIDs to return.

    Returns:
        List[str]: Up to n unique SeriesInstanceUID values.
    """
    frame = pd.read_csv(csv_input)
    # df = df[(df["Thijmen_mean"].isna()) & (df["InSybilTrain"] == False)]
    unique_ids = pd.unique(frame["SeriesInstanceUID"]).tolist()
    return unique_ids if n is None else unique_ids[:n]


def get_subfolder_paths(parent_folder, id_list=None):
    """Return a list of all subfolder paths in a parent folder.

    Parameters:
        parent_folder (str): The path to the parent folder.
        id_list (List[str] | None): If given, keep only subfolder names that
            also appear in this list (result order is not guaranteed).

    Returns:
        List[str]: A list of paths for the matching subfolders.
    """
    names = os.listdir(parent_folder)
    if id_list is not None:
        # Intersect with the allow-list of series IDs.
        names = list(set(names) & set(id_list))

    return [
        os.path.join(parent_folder, name)
        for name in names
        if os.path.isdir(os.path.join(parent_folder, name))
    ]


def get_dcm_filepaths(folder_path):
    """Parse through a folder of .dcm files and return a list of all file paths.

    Parameters:
        folder_path (str): The path to the folder containing the .dcm files.

    Returns:
        List[str]: A list of file paths for all .dcm files in the folder
            (searched recursively, in os.walk order).
    """
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(folder_path)
        for name in files
        if name.endswith(".dcm")
    ]


# Resolve which DICOM subfolders to process: those whose folder name
# matches a SeriesInstanceUID listed in the input CSV.
seriesids = get_series_instance_uids(csv_input)
subfolders = get_subfolder_paths(ParentDirectory, seriesids)
print(f"Examining {len(subfolders)} subfolders")

# Initialize an empty dictionary
data_dict = {}  # series UID -> row of yearly scores; filled by collectscores()


def collectscores(seriesuid, scores):
    """Record the six yearly risk scores for one series in the global data_dict.

    Parameters:
        seriesuid (str): Path (or name) of the series; only its basename is kept.
        scores: Prediction result; scores[0][0] holds the six per-year values.

    Returns:
        dict: The updated module-level data_dict.
    """
    global data_dict

    entry_key = os.path.basename(seriesuid)
    yearly = scores[0][0]

    # Build the row: UID column plus year1..year6 from the six score values.
    row = {"SeriesInstanceUID": entry_key}
    for year_idx in range(6):
        row[f"year{year_idx + 1}"] = yearly[year_idx]

    # Use the seriesuid basename as the key for this entry in data_dict.
    data_dict[entry_key] = row
    return data_dict


def save_data_as_csv(data_dict, output_filename, fieldnames=None):
    """Write the collected per-series score rows to a CSV file.

    Parameters:
        data_dict (dict): Mapping of series UID -> row dict whose keys match
            the CSV column names.
        output_filename (str): Destination CSV path (overwritten).
        fieldnames (List[str] | None): CSV column order; defaults to the
            module-level headers_scores for backward compatibility.
    """
    if fieldnames is None:
        fieldnames = headers_scores

    # Fix: the original opened the file twice — once "to create an empty
    # file" and once to write. Opening in "w" mode already truncates, so a
    # single open is sufficient.
    with open(output_filename, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        # Write the header row, then one row per collected series.
        writer.writeheader()
        for row in data_dict.values():
            writer.writerow(row)


def log_error(error_message, subfolder):
    """Append an error message plus the current traceback to the shared log file.

    Parameters:
        error_message (str): Human-readable description of the failure.
        subfolder (str): The series folder that was being processed.
    """
    with open(log_file, "a") as log:
        log.write(f"Error in subfolder '{subfolder}': {error_message}\n\n")
        # Dump the active exception's traceback into the same log.
        traceback.print_exc(file=log)


# Column order for the per-series score CSVs.
headers_scores = [
    "SeriesInstanceUID",
    "year1",
    "year2",
    "year3",
    "year4",
    "year5",
    "year6",
]

# Create the score CSV with just its header; rows are appended per series
# inside the main loop below.
with open(csvoutput, "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=headers_scores)
    writer.writeheader()

# Create the error-UID CSV with its single-column header.
with open(csv_error, "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=["SeriesInstanceUID"])
    writer.writeheader()

# New headers for attention scores CSV
headers_attention_scores = ["series instance uid", "attention_score"]

# Open the CSV file for attention scores
with open(csvattentionoutput, "w", newline="") as attention_csvfile:
    attention_writer = csv.writer(attention_csvfile)
    attention_writer.writerow(headers_attention_scores)

# Main inference loop: one Sybil prediction per DICOM series folder.
# Failures are logged and recorded, then the loop continues (best-effort).
for i, subfolder in enumerate(subfolders):
    try:
        # print(f"{i+1} / {len(subfolders)}: examining {subfolder} ...")
        dcm_filepaths = get_dcm_filepaths(subfolder)
        serie = Serie(dcm_filepaths)
        # return_attentions=True -> the result exposes .attentions alongside
        # the risk scores.
        scores = model.predict([serie], return_attentions=True)
        attentions = scores.attentions

        # Save attention scores to the attention CSV file
        # (reopened in append mode per series so partial runs keep their rows).
        with open(csvattentionoutput, "a", newline="") as attention_csvfile:
            attention_writer = csv.writer(attention_csvfile)
            for attention_score in attentions:
                attention_writer.writerow(
                    [os.path.basename(subfolder), attention_score]
                )

        # NOTE(review): folder_type is computed but never used below.
        folder_type = os.path.basename(os.path.normpath(ParentDirectory))
        data_dict = collectscores(subfolder, scores)

        # Append this series' score row immediately so progress survives a crash.
        with open(csvoutput, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=headers_scores)
            writer.writerow(data_dict[os.path.basename(subfolder)])

        # save_data_as_csv(data_dict, csvoutput)

        # Render attention-overlay images, saved under the series UID.
        series_with_attention = visualize_attentions(
            serie,
            attentions=attentions,
            save_directory=rf"{INFERENCE_DIR}/attention_gifs/",
            gain=3,
            series_uids=str(os.path.basename(subfolder)),
        )

    except Exception as e:
        # Record the failure (full traceback to log_file, UID to csv_error)
        # and move on to the next series.
        log_error(str(e), subfolder)
        seriesuid = os.path.basename(subfolder)
        with open(csv_error, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["SeriesInstanceUID"])
            writer.writerow({"SeriesInstanceUID": seriesuid})
        continue

end_time_model = time.time()
print(f"Time taken for inference: {end_time_model - start_time_model} seconds")

# Write the full accumulated score table once more as a second copy.
save_data_as_csv(data_dict, csvoutput2)
8 changes: 4 additions & 4 deletions sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from urllib.request import urlopen
from zipfile import ZipFile


import torch
import numpy as np

Expand All @@ -14,7 +15,6 @@
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info


# Leaving this here for a bit; these are IDs to download the models from Google Drive
NAME_TO_FILE = {
"sybil_base": {
Expand Down Expand Up @@ -354,9 +354,9 @@ def predict(
for i in range(len(series)):
att = {}
for key in attention_keys:
att[key] = np.stack([
attentions_[j][i][key] for j in range(len(self.ensemble))
])
att[key] = np.stack(
[attentions_[j][i][key] for j in range(len(self.ensemble))]
)
attentions.append(att)

return Prediction(scores=calib_scores, attentions=attentions)
Expand Down
16 changes: 14 additions & 2 deletions sybil/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def visualize_attentions(
attentions: List[Dict[str, np.ndarray]],
save_directory: str = None,
gain: int = 3,
series_uids: Union[str, List[str]] = None,
) -> List[List[np.ndarray]]:
"""
Args:
Expand All @@ -67,6 +68,9 @@ def visualize_attentions(
if isinstance(series, Serie):
series = [series]

if isinstance(series_uids, str):
series_uids = [series_uids]

series_overlays = []
for serie_idx, serie in enumerate(series):
images = serie.get_raw_images()
Expand All @@ -76,8 +80,16 @@ def visualize_attentions(
overlayed_images = build_overlayed_images(images, cur_attention, gain)

if save_directory is not None:
save_path = os.path.join(save_directory, f"serie_{serie_idx}")
save_images(overlayed_images, save_path, f"serie_{serie_idx}")
if series_uids is not None:
# save_path = os.path.join(
# save_directory, f"serie_{series_uids[serie_idx]}"
# )
save_images(
overlayed_images, save_directory, f"serie_{series_uids[serie_idx]}"
)
else:
# save_path = os.path.join(save_directory, f"serie_{serie_idx}")
save_images(overlayed_images, save_directory, f"serie_{serie_idx}")

series_overlays.append(overlayed_images)
return series_overlays
Expand Down