-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Plotting multiple elements in the same ax seems to work only when `show()` is not called.
#71
Comments
Hello devs, I have a really cool function on my hands, but saving a summary plot is proving to be quite difficult, so I am kind of restarting this issue. My function takes an image as input, segments it using Cellpose via SOPA, and produces a PNG file summarizing a hyperparameter search, so that I can decide which segmentation is best. Currently I am running this code to plot each ax object in a figure that has many axes: sdata.pl.render_images(
element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
).pl.render_shapes(
element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png")) When this line is reached in the CLI, a matplotlib popup opens showing the entire figure, but with only a single ax object filled in. I have looked through the matplotlib docs but found no clear answer. Any tips, ideas, or comments are very welcome, whether about the plotting or the function in general. Best, Entire script Function#system
from loguru import logger
import argparse
import sys
import os
import time
import spatialdata
import spatialdata_plot
#imports
import skimage.segmentation as segmentation
import skimage.io as io
import numpy as np
#yaml
import yaml
import math
import matplotlib.pyplot as plt
import re
import os
import matplotlib.gridspec as gridspec
#sopa
import sopa.segmentation
import sopa.io
def get_args(argv=None):
    """Parse and normalise the command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; lets callers and tests drive the parser directly.

    Returns:
        argparse.Namespace with ``input``, ``config`` and ``output`` converted
        to absolute paths, plus ``loglevel`` (``"DEBUG"`` or ``"INFO"``).
    """
    description = (
        "Run a Cellpose (via SOPA) segmentation hyperparameter search over "
        "flow and cellprob thresholds on a .tif/.tiff image, store the results "
        "in a zarr store and save a summary PNG of all segmentations."
    )
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    inputs = parser.add_argument_group(title="Required Input", description="Paths to required input files and output folder")
    inputs.add_argument("-i", "--input", dest="input", action="store", required=True, help="File path to the input .tif/.tiff image")
    inputs.add_argument("-c", "--config", dest="config", action="store", required=True, help="Path to config.yaml for cellpose parameters")
    inputs.add_argument("-o", "--output", dest="output", action="store", required=True, help="Folder where the zarr store and PNGs are written (created if missing)")
    # Optional, so it lives on the parser itself rather than in the "Required Input" group.
    parser.add_argument("-l", "--log-level", dest="loglevel", default="INFO", choices=["DEBUG", "INFO"], help="Set the log level (default: INFO)")
    args = parser.parse_args(argv)
    # Absolute paths keep later os.path.join calls independent of the working directory.
    args.input = os.path.abspath(args.input)
    args.config = os.path.abspath(args.config)
    args.output = os.path.abspath(args.output)
    return args
def check_input_outputs(args):
    """Validate the CLI paths and prepare the output folder.

    Explicit exceptions are raised instead of ``assert`` statements so the
    checks still run under ``python -O``.

    Side effects: creates ``args.output`` and ``args.output/pngs`` when they
    are missing, and sets ``args.filename`` and ``args.zarr_path``.

    Raises:
        FileNotFoundError: the input image or the config file does not exist.
        ValueError: the input or config file has an unexpected extension.
        NotADirectoryError: ``args.output`` exists but is not a directory.
    """
    # input
    if not os.path.isfile(args.input):
        raise FileNotFoundError(f"Input must be a file: {args.input}")
    if not args.input.lower().endswith((".tif", ".tiff")):
        raise ValueError(f"Input file must be a .tif or .tiff file: {args.input}")
    # config
    if not os.path.isfile(args.config):
        raise FileNotFoundError(f"Config must exist: {args.config}")
    if not args.config.lower().endswith((".yaml", ".yml")):
        raise ValueError(f"Config file must be a .yaml file: {args.config}")
    # output: reject an existing non-directory before makedirs turns it into a FileExistsError
    if os.path.exists(args.output) and not os.path.isdir(args.output):
        raise NotADirectoryError(f"Output must be a folder: {args.output}")
    # makedirs also creates args.output itself; exist_ok avoids a check-then-create race
    os.makedirs(os.path.join(args.output, "pngs"), exist_ok=True)
    # Everything after the first dot is dropped, e.g. "Exemplar001.ome.tif" -> "Exemplar001"
    args.filename = os.path.basename(args.input).split(".")[0]
    args.zarr_path = os.path.join(args.output, f"{args.filename}.zarr")
    logger.info("Input, output and config files exist and checked.")
def create_sdata(args):
    """Read ``args.input`` with SOPA's OME-TIFF reader.

    Side effect: the key of the first image element in the resulting
    SpatialData object is stored on ``args.image_key`` so that later
    pipeline steps can address the image.

    Returns:
        The newly created SpatialData object.
    """
    logger.info(f"Creating spatialdata object.")
    t0 = time.time()
    sdata = sopa.io.ome_tif(args.input)
    image_keys = list(sdata.images.keys())
    args.image_key = image_keys[0]
    elapsed = time.time() - t0
    logger.info(f"Creating spatialdata object took {elapsed} seconds.")
    return sdata
def prepare_for_segmentation_search(sdata, args):
    """Tile the image into overlapping patches and renumber its channels.

    Builds ``sopa.segmentation.Patches2D`` over the image ``args.image_key``
    with ``patch_width=1000`` and ``patch_overlap=100`` (presumably pixels -
    confirm against the sopa docs) and writes the patches into ``sdata``.
    Note: ``plot`` later skips a shapes element called ``sopa_patches``;
    presumably that is the element written here - TODO confirm.
    It then replaces the image's ``c`` coordinate with integer indexes 0..n-1.

    Args:
        sdata: SpatialData object returned by ``create_sdata``.
        args: Namespace providing ``image_key``.

    Returns:
        The same ``sdata`` object (its image entry is reassigned in place).
    """
    logger.info(f"Preparing for segmentation search.")
    time_start = time.time()
    patches = sopa.segmentation.Patches2D(sdata, element_name=args.image_key, patch_width=1000, patch_overlap=100)
    patches.write()
    # Replace the channel coordinate with plain integer indexes 0..n-1.
    # NOTE(review): presumably done because the channel names in the metadata
    # are inconsistent - confirm. The channel count is read from the 'scale0'
    # level, so this assumes a multiscale image.
    new_c = list(range(len(sdata.images[args.image_key]['scale0'].coords['c'].values)))
    sdata.images[args.image_key] = sdata.images[args.image_key].assign_coords(c=new_c)
    time_end = time.time()
    logger.info(f"Preparation for segmentation took {time_end - time_start} seconds.")
    return sdata
def read_yaml(file_path):
    """Load a YAML file and return its parsed content.

    ``yaml.safe_load`` is used, so only plain YAML types are constructed
    (never arbitrary Python objects). An empty file yields ``None``.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed document (typically a dict for the config file).
    """
    logger.info(f"Reading yaml file.")
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) may fail on UTF-8 configs.
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
def segmentation_loop(sdata, args, config):
    """Run Cellpose once for every (flow threshold, cellprob threshold) pair.

    For each combination from ``config['flow_thresholds']`` x
    ``config['cellprob_thresholds']`` the image is segmented patch by patch,
    the per-patch cells are read back and their conflicts resolved, and the
    result is added to ``sdata`` as the shapes element
    ``cellpose_boundaries_FT<flow threshold without dots>_CT<cellprob threshold>``.
    Afterwards ``sdata`` is written to ``args.zarr_path`` (overwriting it).
    """
    logger.info(f"Starting segmentation loop.")
    # Same visiting order as two nested loops: flow threshold outer, cellprob inner.
    grid = [(flow, prob) for flow in config['flow_thresholds'] for prob in config['cellprob_thresholds']]
    for flow, prob in grid:
        logger.info(f"Segmenting with FT: {flow} and CT: {prob}")
        flow_tag = str(flow).replace(".", "")
        cellpose_method = sopa.segmentation.methods.cellpose_patch(
            diameter=config['cell_pixel_diameter'],
            channels=config['channels'],
            flow_threshold=flow,
            cellprob_threshold=prob,
            model_type=config['model_type'],
        )
        staining = sopa.segmentation.StainingSegmentation(
            sdata, cellpose_method, channels=config['channels'], min_area=config['min_area']
        )
        # One cache directory per parameter pair, named after FT and CPT.
        run_dir = os.path.join(args.output, ".sopa_cache", "cellpose", f"run_FT{flow_tag}_CPT{prob}")
        staining.write_patches_cells(run_dir)
        per_patch_cells = sopa.segmentation.StainingSegmentation.read_patches_cells(run_dir)
        merged_cells = sopa.segmentation.shapes.solve_conflicts(per_patch_cells)
        sopa.segmentation.StainingSegmentation.add_shapes(
            sdata,
            merged_cells,
            image_key=args.image_key,
            shapes_key=f"cellpose_boundaries_FT{flow_tag}_CT{prob}",
        )
    logger.info(f"Saving zarr to {args.zarr_path}")
    sdata.write(args.zarr_path, overwrite=True)
    logger.info(f"Segmentation loop finished.")
def extract_ft_values(shape_titles):
    """Collect the unique FT and CT tags from segmentation shape names.

    Names are expected to contain ``_FT<digits>_CT<digits>``; names that do
    not (e.g. ``sopa_patches``) are reported and skipped.

    Returns:
        ``(ft_values, cpt_values)`` as two lists of digit strings.
        FT tags are sorted lexicographically. Because the decimal point was
        stripped when the name was built ("0.4" -> "04"), this keeps e.g.
        0.4 < 0.5 < 1.5 in order. CT tags are sorted numerically, so "2"
        comes before "10".
    """
    ft_values = set()
    cpt_values = set()
    for title in shape_titles:
        match = re.search(r'_FT(\d+)_CT(\d+)', title)
        if match is None:
            print(f"Warning: {title} does not match the expected pattern.")
            continue
        ft_values.add(match.group(1))
        cpt_values.add(match.group(2))
    return sorted(ft_values), sorted(cpt_values, key=int)
def plot(sdata, args, config):
    """Render every Cellpose segmentation in one grid and save it as a PNG.

    Each shapes element whose name contains ``_FT<n>_CT<n>`` (i.e. every
    element except ``sopa_patches``) is drawn as yellow outlines over the
    green image. It goes into the cell at row = CT value, column = FT value.
    The figure is saved once, after all cells are drawn, to
    ``<args.output>/pngs/segment_search.png``.

    Args:
        sdata: SpatialData object holding the image and segmentation shapes.
        args: Namespace providing ``image_key`` and ``output``.
        config: Parsed YAML config; ``config['channels']`` selects the
            rendered image channels.
    """
    # spatialdata-plot's pl.show() can open a blocking window when run as a
    # CLI script (the bug reported in the issue); this function only needs
    # the saved PNG, so render off-screen.
    plt.switch_backend("agg")
    output_png = os.path.join(args.output, "pngs", "segment_search.png")

    shape_titles = [key for key in sdata.shapes.keys() if key != "sopa_patches"]
    logger.info(f"Plotting {shape_titles} segmentations")
    unique_ft_values, unique_cpt_values = extract_ft_values(shape_titles)
    if not unique_ft_values or not unique_cpt_values:
        logger.warning("No segmentation shapes named like '_FT<n>_CT<n>' found; nothing to plot.")
        return
    num_cols = len(unique_ft_values)
    num_rows = len(unique_cpt_values)
    logger.info(f"Unique FT values: {unique_ft_values} and Unique CT values: {unique_cpt_values}")
    ft_to_index = {ft: i for i, ft in enumerate(unique_ft_values)}
    cpt_to_index = {cpt: i for i, cpt in enumerate(unique_cpt_values)}

    # Use the axes grid created here directly. Adding subplots from a separate
    # GridSpec would stack a second set of axes on top of this one.
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(num_cols * 6, num_rows * 6),
        facecolor="black",
        squeeze=False,
        gridspec_kw={"wspace": 0.1, "hspace": 0.1},
    )
    for ax in axes.flat:
        ax.set_facecolor("black")
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)

    for i, title in enumerate(shape_titles):
        logger.info(f"Rendering {i + 1}/{len(shape_titles)} ||| {title}")
        # Same pattern as extract_ft_values, so every match has a grid position.
        match = re.search(r'_FT(\d+)_CT(\d+)', title)
        if match is None:
            logger.warning(f"Skipping {title}: name does not contain _FT<n>_CT<n>")
            continue
        ft, cpt = match.groups()
        ax = axes[cpt_to_index[cpt], ft_to_index[ft]]
        logger.info(" Rendering image")
        try:
            sdata.pl.render_images(
                element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
            ).pl.render_shapes(
                element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
            ).pl.show(ax=ax, title=title)
        except Exception:
            # Best effort: one failing panel should not abort the whole summary figure.
            logger.exception(f"Could not render shapes of {title}")
            continue
        ax.title.set_color("white")

    logger.info(f"Saving plot to {output_png}")
    fig.savefig(output_png, facecolor=fig.get_facecolor(), bbox_inches="tight")
    plt.close(fig)
def main():
    """CLI entry point: validate paths, run the segmentation grid, plot it."""
    args = get_args()
    logger.remove()
    logger.add(sys.stdout, format="<green>{time:HH:mm:ss.SS}</green> | <level>{level}</level> | {message}", level=args.loglevel.upper())
    check_input_outputs(args)
    # Parse the config once, before the expensive image load, so a broken
    # config fails fast; both the segmentation loop and the plot use it.
    config = read_yaml(args.config)
    sdata = create_sdata(args)
    sdata = prepare_for_segmentation_search(sdata, args)
    segmentation_loop(sdata, args, config=config)
    plot(sdata, args, config=config)


if __name__ == "__main__":
    main()

# Example:
# python ./scripts/segment_search.py \
#     --input ./data/input/Exemplar001.ome.tif \
#     --config ./data/configs/config.yaml \
#     --output ./data/output/
|
Note for us. This bug and the newly reported bug #362 are related. Thanks @josenimo for the bug report. We will try to address this bug soon. Meanwhile, I would suggest checking if setting the |
I refer to the code mentioned in this other issue: #68
This code here:
doesn't work if I run the code as a script, but it works in interactive mode (where, because of a bug, the plots are not shown until I call `plt.show()`). I suggest doing what scanpy does and adding a parameter `show: bool`. I also suggest that if the parameter `ax` is not `None`, then `show` should be set to `False`. I don't remember whether scanpy also behaves this way, but I think it's reasonable.
The text was updated successfully, but these errors were encountered: