
Plotting multiple elements in the same ax seems to work only when show() is not called. #71

Open
LucaMarconato opened this issue May 14, 2023 · 2 comments
Labels
bug Something isn't working enhancement New feature or request priority: low

@LucaMarconato
Member

I am referring to the code from this other issue: #68

This code here:

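    import matplotlib.pyplot as plt
    import spatialdata_plot  # noqa: F401  (registers the .pl accessor on SpatialData objects)
    # sdata: the SpatialData object from #68
    ax = plt.gca()
    sdata.pl.render_shapes(element='s', na_color=(0.5, 0.5, 0.5, 0.5)).pl.render_points().pl.show(ax=ax)
    sdata.pl.render_shapes(element='c', na_color=(0.7, 0.7, 0.7, 0.5)).pl.show(ax=ax)
    plt.show()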

doesn't work when I run it as a script, but it does work in interactive mode (where, because of a bug, the plots are not shown until I call plt.show()). I suggest doing as scanpy does and adding a parameter show: bool. I also suggest that if the parameter ax is not None, then show is set to False. I don't remember whether scanpy also behaves this way, but I think it's reasonable.
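A minimal sketch of one possible reading of this proposal, written as a standalone helper. `resolve_show` is a hypothetical name, not part of spatialdata-plot; in this reading an explicit `show` always wins, and otherwise the plot is shown only when no `ax` was passed, since a caller who supplies `ax` usually wants to keep drawing on it and call `plt.show()` themselves.

```python
from typing import Any, Optional


def resolve_show(ax: Optional[Any] = None, show: Optional[bool] = None) -> bool:
    """Hypothetical helper illustrating the proposed default for `show`.

    - If the caller passes `show` explicitly, respect it.
    - Otherwise, show the figure only when no `ax` was supplied.
    """
    if show is not None:
        return show
    return ax is None


# Usage: a caller that passes its own axes keeps control of plt.show()
print(resolve_show())                       # True: no ax, show the plot
print(resolve_show(ax=object()))            # False: ax given, caller shows later
print(resolve_show(ax=object(), show=True)) # True: explicit override
```

Defaulting `show` to `None` rather than `True` is what lets the helper tell "not specified" apart from an explicit choice.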

@josenimo

josenimo commented Sep 4, 2024

Hello devs,

I have a really cool function on my hands, but saving a summary plot is proving quite difficult, so I am kind of reviving this issue.

My function takes an image as input, segments it with Cellpose via SOPA, and produces a PNG file summarizing a hyperparameter search, to decide which segmentation is best.

Currently I am running this code to plot into each ax of a figure that has many axes:

    sdata.pl.render_images(
        element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
    ).pl.render_shapes(
        element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
    ).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))

When this line is reached while running from the CLI, a matplotlib window pops up showing the entire figure, but with only a single ax filled.
I have to close this first window manually; then the other axes are plotted and the entire figure is saved (I think overwriting itself on each iteration).

I have looked through the matplotlib docs but found no clear answer.

Any tips, ideas, or comments are very welcome, whether about the plotting or about the function in general.

Best,
Jose

Entire script (function):
#system
from loguru import logger
import argparse
import sys
import os
import time

import spatialdata
import spatialdata_plot

#imports
import skimage.segmentation as segmentation
import skimage.io as io
import numpy as np

#yaml
import yaml
#utilities
import math
import re
#plotting
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

#sopa
import sopa.segmentation
import sopa.io

def get_args():
    """ Get arguments from command line """
    description = """Run a Cellpose hyperparameter search with SOPA and plot the resulting segmentations."""
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    inputs = parser.add_argument_group(title="Required Input", description="Path to required input file")
    inputs.add_argument("-i", "--input",    dest="input",   action="store", required=True, help="File path to the input .tif/.tiff image")
    inputs.add_argument("-c", "--config",   dest="config",  action="store", required=True, help="Path to config.yaml for cellpose parameters")
    inputs.add_argument("-o", "--output",   dest="output",  action="store", required=True, help="Path to the output folder (created if it does not exist)")
    inputs.add_argument("-l", "--log-level",dest="loglevel", default='INFO', choices=["DEBUG", "INFO"], help='Set the log level (default: INFO)')
    arg = parser.parse_args()
    arg.input = os.path.abspath(arg.input)
    arg.config = os.path.abspath(arg.config)
    arg.output = os.path.abspath(arg.output)
    return arg

def check_input_outputs(args):
    """ Check if input and output files exist """
    #input
    assert os.path.isfile(args.input), "Input must be a file"
    assert args.input.endswith((".tif", ".tiff")), "Input file must be a .tif or .tiff file"
    #config
    assert os.path.isfile(args.config), "Config must exist"
    assert args.config.endswith(".yaml"), "Config file must be a .yaml file"
    #output
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    assert os.path.isdir(args.output), "Output must be a folder"
    #create output folders
    os.makedirs(os.path.join(args.output, "pngs"), exist_ok=True)
    ### os.makedirs(os.path.join(args.output, "zarrs"), exist_ok=True)
    args.filename = os.path.basename(args.input).split(".")[0]
    args.zarr_path = os.path.join(args.output, f"{args.filename}.zarr")
    
    logger.info(f"Input, output and config files exist and checked.")

def create_sdata(args):
    """ Create sdata object """
    logger.info(f"Creating spatialdata object.")
    time_start = time.time()

    sdata = sopa.io.ome_tif(args.input)
    args.image_key = list(sdata.images.keys())[0]

    time_end = time.time()
    logger.info(f"Creating spatialdata object took {time_end - time_start} seconds.")
    return sdata

def prepare_for_segmentation_search(sdata, args):
    """ Create image patches and reset channel coordinates before the segmentation search """
    logger.info(f"Preparing for segmentation search.")
    time_start = time.time()

    patches = sopa.segmentation.Patches2D(sdata, element_name=args.image_key, patch_width=1000, patch_overlap=100)
    patches.write()

    #reset channel names to their indexes, because the channel metadata is inconsistent
    new_c = list(range(len(sdata.images[args.image_key]['scale0'].coords['c'].values)))
    sdata.images[args.image_key] = sdata.images[args.image_key].assign_coords(c=new_c)

    time_end = time.time()
    logger.info(f"Preparation for segmentation took {time_end - time_start} seconds.")
    return sdata

def read_yaml(file_path):
    """ Read yaml file """
    logger.info(f"Reading yaml file.")
    with open(file_path, 'r') as file:
        data = yaml.safe_load(file)
    return data

def segmentation_loop(sdata, args, config):
    """ Loop through different cellpose parameters """
    logger.info(f"Starting segmentation loop.")

    for ft in config['flow_thresholds']:
        for cpt in config['cellprob_thresholds']:

            logger.info(f"Segmenting with FT: {ft} and CT: {cpt}")
            FT_str = str(ft).replace(".", "")
            #create method for segmenting
            method = sopa.segmentation.methods.cellpose_patch(
                diameter=config['cell_pixel_diameter'], 
                channels=config['channels'], 
                flow_threshold=ft, 
                cellprob_threshold=cpt, 
                model_type=config['model_type']
            )
            segmentation = sopa.segmentation.StainingSegmentation(sdata, method, channels=config['channels'], min_area=config['min_area'])
            #create temp dir to store segmentation of each tile
            cellpose_temp_dir = os.path.join(args.output, ".sopa_cache", "cellpose", f"run_FT{FT_str}_CPT{cpt}")
            #segment
            segmentation.write_patches_cells(cellpose_temp_dir)
            #read and solve conflicts
            cells = sopa.segmentation.StainingSegmentation.read_patches_cells(cellpose_temp_dir)
            cells = sopa.segmentation.shapes.solve_conflicts(cells)
            #save segmentation of entire image as shapes
            sopa.segmentation.StainingSegmentation.add_shapes(
                sdata, cells, image_key=args.image_key, shapes_key=f"cellpose_boundaries_FT{FT_str}_CT{cpt}")
    
    logger.info(f"Saving zarr to {args.zarr_path}")
    sdata.write(args.zarr_path, overwrite=True)
    logger.info(f"Segmentation loop finished.")

def extract_ft_values(shape_titles):
    """Extract all unique FT and CT values from a list of shape titles."""
    ft_values = set()
    cpt_values = set()
    for title in shape_titles:
        match = re.search(r'_FT(\d+)_CT(\d+)', title)
        if match:
            ft_values.add(match.group(1))
            cpt_values.add(match.group(2))
        else:
            logger.warning(f"{title} does not match the expected pattern.")
    return sorted(ft_values), sorted(cpt_values)

def plot(sdata, args, config):
    """ Plot all segmentations on one grid: rows are CT values, columns are FT values """
    shape_titles = list(sdata.shapes.keys())
    shape_titles.remove("sopa_patches")
    logger.info(f"Plotting {shape_titles} segmentations")

    unique_ft_values, unique_cpt_values = extract_ft_values(shape_titles)
    num_cols = len(unique_ft_values)
    num_rows = len(unique_cpt_values)
    logger.info(f"Unique FT values: {unique_ft_values} and Unique CT values: {unique_cpt_values}")
    ft_to_index = {ft: i for i, ft in enumerate(unique_ft_values)}
    cpt_to_index = {cpt: i for i, cpt in enumerate(unique_cpt_values)}
    
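    #use a single GridSpec attached to the figure; also calling plt.subplots() would create a second, overlapping set of axes
    fig = plt.figure(figsize=(num_cols*6, num_rows*6), facecolor='black')
    gs = gridspec.GridSpec(num_rows, num_cols, figure=fig, wspace=0.1, hspace=0.1)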

    for i, title in enumerate(shape_titles):
        #print number of title out of all titles
        logger.info(f"Rendering {i+1}/{len(shape_titles)} ||| {title}")
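        match = re.search(r'FT(\d+)_CT(\d+)', title)
        if match is None:
            logger.warning(f"Skipping {title}: it does not match the expected pattern")
            continue
        ft, cpt = match.groups()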
        row = cpt_to_index[cpt]
        col = ft_to_index[ft]

        ax = fig.add_subplot(gs[row, col])
        ax.set_facecolor('black')
        ax.title.set_color('white')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)

        try:
            logger.info(f"  Rendering image")
            sdata.pl.render_images(
                element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
            ).pl.render_shapes(
                element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
            ).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))
            logger.info(f"Saving plot to {os.path.join(args.output, 'pngs', 'segment_search.png')}")
            # plt.savefig(os.path.join(args.output, "pngs", "segment_search.png"))
        except Exception as e:
            logger.warning(f"Could not render shapes of {title}: {e}")

def main():
    args = get_args()
    logger.remove()
    logger.add(sys.stdout, format="<green>{time:HH:mm:ss.SS}</green> | <level>{level}</level> | {message}", level=args.loglevel.upper())
    check_input_outputs(args)
    sdata = create_sdata(args)
    sdata = prepare_for_segmentation_search(sdata, args)
    config = read_yaml(args.config)
    segmentation_loop(sdata, args, config=config)
    plot(sdata, args, config=config)

if __name__ == "__main__":
    main()

"""
Example:

python ./scripts/segment_search.py \
--input ./data/input/Exemplar001.ome.tif \
--config ./data/configs/config.yaml \
--output ./data/output/
"""
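A small sketch of the layout logic in `plot()`: each shape key is mapped to a grid cell, with CT values as rows and FT values as columns. It assumes keys follow the `..._FT{ft}_CT{cpt}` pattern written by `segmentation_loop()` and that cellprob thresholds are integers. Unlike the pattern `CT(\d+)` in the script, the optional `-` here also accepts negative thresholds such as `CT-1` (Cellpose accepts negative `cellprob_threshold` values), and non-matching keys such as `sopa_patches` are skipped. `grid_positions` is a hypothetical helper, not part of SOPA or spatialdata-plot.

```python
import re
from typing import Dict, Iterable, Tuple

# Assumption: keys look like "cellpose_boundaries_FT04_CT0" or "..._FT04_CT-1".
KEY_PATTERN = re.compile(r"_FT(\d+)_CT(-?\d+)$")


def grid_positions(shape_keys: Iterable[str]) -> Dict[str, Tuple[int, int]]:
    """Map each segmentation key to a (row, col) cell: rows = CT, cols = FT."""
    parsed = {}
    for key in shape_keys:
        match = KEY_PATTERN.search(key)
        if match is None:
            continue  # e.g. "sopa_patches": not a segmentation result
        parsed[key] = match.groups()

    # FT strings are sorted lexically, as in extract_ft_values();
    # CT strings are sorted numerically so that "-1" comes before "0".
    fts = sorted({ft for ft, _ in parsed.values()})
    cts = sorted({ct for _, ct in parsed.values()}, key=int)
    return {key: (cts.index(ct), fts.index(ft)) for key, (ft, ct) in parsed.items()}


positions = grid_positions([
    "sopa_patches",
    "cellpose_boundaries_FT04_CT0",
    "cellpose_boundaries_FT08_CT0",
    "cellpose_boundaries_FT04_CT-1",
    "cellpose_boundaries_FT08_CT-1",
])
print(positions)
```

Each `(row, col)` pair can then index the figure's GridSpec directly, e.g. `fig.add_subplot(gs[row, col])`.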

@timtreis added the enhancement label on Sep 4, 2024
@timtreis removed their assignment on Sep 30, 2024
@LucaMarconato
Member Author

Note for us: this bug and the newly reported bug #362 are related.

Thanks @josenimo for the bug report. We will try to address it soon. Meanwhile, I suggest checking whether setting the matplotlib backend to Agg works for you, or using plt.ion()/plt.ioff() as described in issue #68.
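A minimal sketch of the Agg-backend workaround using plain matplotlib only: the spatialdata-plot calls are replaced by `imshow` on random data as a stand-in, and the grid size and output filename are placeholders. The idea is to select the non-interactive Agg backend before importing pyplot, so no window can pop up, draw every panel into its own axes, and call `fig.savefig` once after the loop rather than saving on every iteration.

```python
import os
import tempfile

import matplotlib

# Non-interactive backend: no GUI window is ever opened.
# Selecting it before pyplot is imported is the safest place to do so.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np


def save_grid(path: str, num_rows: int = 2, num_cols: int = 3) -> None:
    """Draw one panel per grid cell, then write the whole figure once."""
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), squeeze=False
    )
    rng = np.random.default_rng(0)
    for row in range(num_rows):
        for col in range(num_cols):
            ax = axes[row, col]
            # Stand-in for sdata.pl.render_images(...).pl.render_shapes(...).pl.show(ax=ax)
            ax.imshow(rng.random((32, 32)), cmap="Greens")
            ax.set_title(f"panel {row},{col}")
            ax.set_axis_off()
    fig.savefig(path, dpi=100)  # save once, after every axes has been drawn
    plt.close(fig)  # free the figure's memory


out_path = os.path.join(tempfile.gettempdir(), "segment_search_demo.png")
save_grid(out_path)
print(os.path.getsize(out_path) > 0)
```

With Agg selected, a `plt.show()` call (for example, one triggered inside a plotting helper) has nothing to display, so the loop is never blocked waiting for a window to be closed.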
