diff --git a/.hooks/startup.sh b/.hooks/startup.sh index bfcbdde..345066d 100644 --- a/.hooks/startup.sh +++ b/.hooks/startup.sh @@ -11,4 +11,4 @@ set -e pip install scikit-learn streamlit-extras squidpy split-file-reader st-pages dill pympler objsize mamba install -y -q natsort "foundry-transforms-lib-python>=0.578.0" -mamba install -y gcc_linux-64 gxx_linux-64 && mamba install -y hnswlib \ No newline at end of file +mamba install -y gcc_linux-64 gxx_linux-64 && mamba install -y python-annoy && mamba install -y hnswlib diff --git a/.streamlit/pages.toml b/.streamlit/pages.toml index badf4c1..43b44b1 100644 --- a/.streamlit/pages.toml +++ b/.streamlit/pages.toml @@ -34,33 +34,33 @@ name = "Main" icon = "" [[pages]] -name = "Phenotyping 🖊️" +name = "Phenotype Clustering Workflow ✨" icon = "" is_section = true [[pages]] -path = "pages/multiaxial_gating.py" -name = "Raw Intensities" +path = "pages/05a_Pheno_Cluster.py" +name = "Phenotype Clustering" icon = "" [[pages]] -path = "pages/02_phenotyping.py" -name = "Thresholded Intensities" +path = "pages/05b_Pheno_Cluster.py" +name = "Clusters Differential Expression" icon = "" [[pages]] -name = "Phenotype Clustering Workflow ✨" +name = "Phenotyping 🖊️" icon = "" is_section = true [[pages]] -path = "pages/05a_Pheno_Cluster.py" -name = "Phenotype Clustering" +path = "pages/multiaxial_gating.py" +name = "Raw Intensities" icon = "" [[pages]] -path = "pages/05b_Pheno_Cluster.py" -name = "Clusters Differential Expression" +path = "pages/02_phenotyping.py" +name = "Thresholded Intensities" icon = "" [[pages]] @@ -128,12 +128,17 @@ path = "pages/memory_analyzer.py" name = "Memory Analyzer" icon = "" -# [[pages]] -# name = "In Development" -# icon = "" -# is_section = true +[[pages]] +name = "Radial Profiles" +icon = "" +is_section = true + +[[pages]] +path = "pages/radial_profiles_app.py" +name = "Calculations" +icon = "" -# [[pages]] -# path = "pages/macro_radial_density.py" -# name = "Macro Radial Density" -# icon = "" +[[pages]] +path = "pages/radial_profiles_app_plotting_aggregated_results.py" +name = "Plots" +icon = "" diff --git a/Multiplex_Analysis_Web_Apps-orig.py b/Multiplex_Analysis_Web_Apps-orig.py new file mode 100644 index 0000000..7245eea --- /dev/null +++ b/Multiplex_Analysis_Web_Apps-orig.py @@ -0,0 +1,45 @@ +''' +Top level Streamlit Application for MAWA +''' + +# Import relevant libraries +import os +import streamlit as st + +# Import relevant libraries +import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP +import app_top_of_page as top +import streamlit_dataframe_editor as sde + +def main(): + ''' + Define the single, main function + ''' + + # Set a wide layout + st.set_page_config(layout="wide") + + input_path = './input' + if not os.path.exists(input_path): + os.makedirs(input_path) + + output_path = './output' + if not os.path.exists(output_path): + os.makedirs(output_path) + + # Run streamlit-dataframe-editor library initialization tasks at the top of the page + st.session_state = sde.initialize_session_state(st.session_state) + + # Run Top of Page (TOP) functions + st.session_state = top.top_of_page_reqs(st.session_state) + + # Markdown text + intro_markdown = ndl.read_markdown_file('markdown/MAWA_WelcomePage.md') + st.markdown(intro_markdown, unsafe_allow_html=True) + + # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page + st.session_state = sde.finalize_session_state(st.session_state) + +# Call the main function +if __name__ == '__main__': + main() diff --git 
a/Multiplex_Analysis_Web_Apps.py b/Multiplex_Analysis_Web_Apps.py index 7245eea..2f37a89 100644 --- a/Multiplex_Analysis_Web_Apps.py +++ b/Multiplex_Analysis_Web_Apps.py @@ -1,45 +1,194 @@ -''' -Top level Streamlit Application for MAWA -''' - -# Import relevant libraries -import os import streamlit as st - -# Import relevant libraries +import os +from streamlit_extras.app_logo import add_logo +import streamlit_session_state_management import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP -import app_top_of_page as top -import streamlit_dataframe_editor as sde +import streamlit_utils +import numpy as np +import subprocess +import platform_io +import install_missing_packages -def main(): +install_missing_packages.live_package_installation() + +# Note if any of the following imports having " # slow" are not commented out, there is a delay in running the forking test +from pages2 import data_import_and_export +from pages2 import datafile_format_unifier +from pages2 import open_file +from pages2 import robust_scatter_plotter +from pages2 import multiaxial_gating +from pages2 import thresholded_phenotyping # slow due to things ultimately importing umap +from pages2 import adaptive_phenotyping +from pages2 import Pheno_Cluster_a # "slow" for forking test initialization +from pages2 import Pheno_Cluster_b # "slow" for forking test initialization +from pages2 import Tool_parameter_selection +from pages2 import Run_workflow +from pages2 import Display_individual_ROI_heatmaps +from pages2 import Display_average_heatmaps +from pages2 import Display_average_heatmaps_per_annotation +from pages2 import Display_ROI_P_values_overlaid_on_slides +from pages2 import Neighborhood_Profiles # slow due to things ultimately importing umap +from pages2 import UMAP_Analyzer # slow due to things ultimately importing umap +from pages2 import Clusters_Analyzer # slow due to things ultimately importing umap +from pages2 import memory_analyzer +from pages2 import radial_bins_plots +from pages2 import radial_profiles_analysis +from pages2 import preprocessing +from pages2 import results_transfer +# from pages2 import forking_test + + +def welcome_page(): + # Markdown text + intro_markdown = ndl.read_markdown_file('markdown/MAWA_WelcomePage.md') + st.markdown(intro_markdown, unsafe_allow_html=True) + + +def platform_is_nidap(): + ''' + Check if the Streamlit application is operating on NIDAP + ''' + return np.any(['nidap.nih.gov' in x for x in subprocess.run('conda config --show channels', shell=True, capture_output=True).stdout.decode().split('\n')[1:-1]]) + + +def check_for_platform(session_state): ''' - Define the single, main function + Set the platform parameters based on the platform the Streamlit app is running on ''' + # Initialize the platform object + if 'platform' not in session_state: + session_state['platform'] = platform_io.Platform(platform=('nidap' if platform_is_nidap() else 'local')) + return session_state + + +def main(): - # Set a wide layout st.set_page_config(layout="wide") + # Use the new st.naviation()/st.Page() API to create a multi-page app + pg = st.navigation({ + 'Home 🏠': + [ + st.Page(welcome_page, title="Welcome", url_path='home') + ], + 'File Handling 🗄️': + [ + st.Page(data_import_and_export.main, title="Data Import and Export", url_path='data_import_and_export'), + st.Page(datafile_format_unifier.main, title="Datafile Unification", url_path='datafile_unification'), + st.Page(open_file.main, title="Open File", url_path='open_file') + ], + 'Coordinate Scatter 
Plotter 🌟': + [ + st.Page(robust_scatter_plotter.main, title="Coordinate Scatter Plotter", url_path='coordinate_scatter_plotter') + ], + 'Phenotyping 🧬': + [ + st.Page(multiaxial_gating.main, title="Using Raw Intensities", url_path='using_raw_intensities'), + st.Page(thresholded_phenotyping.main, title="Using Thresholded Intensities", url_path='using_thresholded_intensities'), + st.Page(adaptive_phenotyping.main, title="Adaptive Phenotyping", url_path='adaptive_phenotyping') + ], + 'Phenotype Clustering Workflow ✨': + [ + st.Page(Pheno_Cluster_a.main, title="Unsupervised Phenotype Clustering", url_path='unsupervised_phenotype_clustering'), + st.Page(Pheno_Cluster_b.main, title="Differential Intensity", url_path='differential_intensity') + ], + 'Spatial Interaction Tool 🗺️': + [ + st.Page(Tool_parameter_selection.main, title="Tool Parameter Selection", url_path='tool_parameter_selection'), + st.Page(Run_workflow.main, title="Run SIT Workflow", url_path='run_sit_workflow'), + st.Page(Display_individual_ROI_heatmaps.main, title="Display Individual ROI Heatmaps", url_path='display_individual_roi_heatmaps'), + st.Page(Display_average_heatmaps.main, title="Display Average Heatmaps", url_path='display_average_heatmaps'), + st.Page(Display_average_heatmaps_per_annotation.main, title="Display Average Heatmaps per Annotation", url_path='display_average_heatmaps_per_annotation'), + st.Page(Display_ROI_P_values_overlaid_on_slides.main, title="Display ROI P Values Overlaid on Slides", url_path='display_roi_p_values_overlaid_on_slides') + ], + 'Neighborhood Profiles Workflow 🌳': + [ + st.Page(Neighborhood_Profiles.main, title="Neighborhood Profiles", url_path='neighborhood_profiles'), + st.Page(UMAP_Analyzer.main, title="UMAP Differences Analyzer", url_path='umap_differences_analyzer'), + st.Page(Clusters_Analyzer.main, title="Clusters Analyzer", url_path='clusters_analyzer') + ], + 'Radial Profiles 🌀': + [ + st.Page(radial_bins_plots.main, title="Radial Bins Plots", url_path='radial_bins_plots'), + st.Page(radial_profiles_analysis.main, title="Radial Profiles Analysis", url_path='radial_profiles_analysis') + ], + 'Utilities 🛠️': + [ + st.Page(preprocessing.main, title="Preprocessing", url_path='preprocessing'), + st.Page(memory_analyzer.main, title="Memory Analyzer", url_path='memory_analyzer'), + st.Page(results_transfer.main, title="Results Transfer", url_path='results_transfer'), + # st.Page(forking_test.main, title="Forking Test", url_path='forking_test') + ] + }) + + # Ensure the input/output directories exist input_path = './input' if not os.path.exists(input_path): os.makedirs(input_path) - output_path = './output' if not os.path.exists(output_path): os.makedirs(output_path) - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) + # For widget persistence, we need always copy the session state to itself, being careful with widgets that cannot be persisted, like st.data_editor() (where we use the "__do_not_persist" suffix to avoid persisting it) + for key in st.session_state.keys(): + if (not key.endswith('__do_not_persist')) and (not key.startswith('FormSubmitter:')): + st.session_state[key] = st.session_state[key] - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) + # This is needed for the st.dataframe_editor() class (https://github.com/andrew-weisman/streamlit-dataframe-editor) but is also useful for seeing where we are and where we've been + 
st.session_state['current_page_name'] = pg.url_path if pg.url_path != '' else 'Home' + if 'previous_page_name' not in st.session_state: + st.session_state['previous_page_name'] = st.session_state['current_page_name'] - # Markdown text - intro_markdown = ndl.read_markdown_file('markdown/MAWA_WelcomePage.md') - st.markdown(intro_markdown, unsafe_allow_html=True) + # Add logo to sidebar + add_logo('app_images/mawa_logo-width315.png', height=250) + + # Determine whether this is the first time the app has been run + if 'app_has_been_run_at_least_once' not in st.session_state: + st.session_state['app_has_been_run_at_least_once'] = True + first_app_run = True + else: + first_app_run = False + + # Run session state management in the sidebar + streamlit_session_state_management.execute(first_app_run) + + # Initalize session_state values for streamlit processing + if 'init' not in st.session_state: + st.session_state = ndl.init_session_state(st.session_state) + + # Sidebar organization + with st.sidebar: + st.write('**:book: [Documentation](https://ncats.github.io/multiplex-analysis-web-apps/)**') + with st.expander('Advanced:'): + benchmark_button = True + if benchmark_button: + st.button('Record Benchmarking', on_click = st.session_state.bc.save_run_to_csv) + if st.button('Calculate memory used by Python session'): + streamlit_utils.write_python_session_memory_usage() + + # Check the platform + st.session_state = check_for_platform(st.session_state) + + # Format tooltips + tooltip_style = """ + + """ + st.markdown(tooltip_style,unsafe_allow_html=True) + + # On every page, display its title + st.title(pg.title) + + # Render the select page + pg.run() + + # Update the previous page location + st.session_state['previous_page_name'] = st.session_state['current_page_name'] - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) -# Call the main function +# Needed for rendering pages which use multiprocessing (https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods) if __name__ == '__main__': main() diff --git a/SpatialUMAP.py b/SpatialUMAP.py index c006777..f0e1981 100644 --- a/SpatialUMAP.py +++ b/SpatialUMAP.py @@ -220,9 +220,7 @@ def calculate_density_matrix_for_all_images(self, cpu_pool_size = 8): with mp.Pool(processes=cpu_pool_size) as pool: results = pool.starmap(utils.fast_neighbors_counts_for_block2, kwargs_list) - print('Finished calculating density matrix for all images. 
Concatenating results...') df_density_matrix = pd.concat(self.get_dataframes(results)) - print('Finished concatenating counts results.') full_array = None for ii, phenotype in enumerate(phenotypes): cols2Use = [f'{phenotype} in {x}' for x in range_strings] @@ -231,10 +229,9 @@ def calculate_density_matrix_for_all_images(self, cpu_pool_size = 8): full_array = array_set else: full_array = np.dstack((full_array, array_set)) - print('BBBBB') + full_array_nan = np.isnan(full_array) full_array[full_array_nan] = 0 - print('CCCCC') # Concatenate the results into a single dataframe return full_array @@ -274,10 +271,11 @@ def __init__(self, dist_bin_um, um_per_px, area_downsample): # Mean Densities self.dens_df = pd.DataFrame() self.prop_df = pd.DataFrame() - self.dens_df_mean = pd.DataFrame(data = {'clust_label': ['No Cluster'], + self.dens_df_mean = pd.DataFrame(data = {'clust_label': ['No Cluster'], 'phenotype': ['Other'], 'dist_bin': [25], - 'density_mean': [0]}) + 'density_mean': [0], + 'density_sem': [0]}) self.dens_df_se = pd.DataFrame() self.maxdens_df = pd.DataFrame() diff --git a/app_images/logo2c.png b/app_images/logo2c.png old mode 100755 new mode 100644 diff --git a/app_top_of_page.py b/app_top_of_page.py index 0d04ad6..9afeb7b 100644 --- a/app_top_of_page.py +++ b/app_top_of_page.py @@ -6,7 +6,7 @@ import numpy as np import platform_io import streamlit as st -from st_pages import show_pages_from_config, add_indentation +# from st_pages import show_pages_from_config, add_indentation from streamlit_extras.app_logo import add_logo import streamlit_session_state_management import streamlit_utils diff --git a/basic_phenotyper_lib.py b/basic_phenotyper_lib.py index 77b6f92..71c88c8 100644 --- a/basic_phenotyper_lib.py +++ b/basic_phenotyper_lib.py @@ -5,21 +5,20 @@ ''' import time +import math import numpy as np import pandas as pd -import umap +import umap # slow import warnings -import multiprocessing as mp warnings.simplefilter(action='ignore', category= FutureWarning) warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") pd.options.mode.chained_assignment = None # default='warn' import matplotlib.pyplot as plt import seaborn as sns from sklearn.cluster import KMeans # K-Means - -from benchmark_collector import benchmark_collector # Benchmark Collector Class from SpatialUMAP import SpatialUMAP import PlottingTools as umPT +import utils def preprocess_df(df_orig, marker_names, marker_col_prefix, bc): '''Perform some preprocessing on our dataset to apply tranforms @@ -75,6 +74,29 @@ def init_pheno_cols(df, marker_names, marker_col_prefix): # This was previously really slow. Code basically taken from new_phenotyping_lib.py marker_cols_first_row = df_markers.iloc[0, :].to_list() # get just the first row of marker values if (0 not in marker_cols_first_row) and (1 not in marker_cols_first_row): + + # Null values in df_markers will break the .map() step so check for and remove them here + ser_num_of_null_rows_in_each_column = df_markers.isnull().sum() + if ser_num_of_null_rows_in_each_column.sum() != 0: + + # For the time being, import Streamlit so warnings can be rendered. Otherwise, this file does not import streamlit and it should remain that way but this is a minimal fix for the time being + import streamlit as st + + st.warning('Null values have been detected in the phenotype columns. Next time, please check for and remove null rows in the datafile unification step (File Handling > Datafile Unification). 
We are removing them for you now but it would be *much* better to do this in the Datafile Unifier now! Otherwise, downstream functionality may not work. Here are the numbers of null rows found in each column containing them:') + ser_num_of_null_rows_in_each_column.name = 'Number of null rows' + st.write(ser_num_of_null_rows_in_each_column[ser_num_of_null_rows_in_each_column != 0]) + + # Perform the operation + row_count_before = len(df) + df = df.dropna(subset=marker_cols) + row_count_after = len(df) + + # Display a success message + st.write(f'{row_count_before - row_count_after} rows deleted') + + # Update df_markers + df_markers = df[marker_cols] + df_markers = df_markers.map(lambda x: {'+': '1', '-': '0'}[x[-1]]) df['mark_bits'] = df_markers.astype(str).apply(''.join, axis='columns') # efficiently create a series of strings that are the columns (in string format) concatenated together @@ -507,6 +529,7 @@ def setup_Spatial_UMAP(df, marker_names, pheno_order, smallest_image_size): df (Pandas dataframe): Dataframe containing the data marker_names (list): List of marker names pheno_order (list): List of phenotype order + smallest_image_size (int): The size of the smallest image in the dataset Returns: SpatialUMAP: SpatialUMAP object @@ -632,16 +655,14 @@ def perform_spatialUMAP(spatial_umap, bc, umap_subset_per_fit, umap_subset_toggl spatial_umap.set_train_test(n_fit=n_fit, n_tra = n_tra, groupby_label = 'TMA_core_id', seed=54321, umap_subset_toggle = umap_subset_toggle) # fit umap on training cells - bc.startTimer() print('Fitting Model') spatial_umap.umap_fit = umap.UMAP().fit(spatial_umap.density[spatial_umap.cells['umap_train'].values].reshape((spatial_umap.cells['umap_train'].sum(), -1))) - bc.printElapsedTime(f' Fitting {np.sum(spatial_umap.cells["umap_train"] == 1)} points to a model') + bc.printElapsedTime(f' Fitting {np.sum(spatial_umap.cells["umap_train"] == 1)} points to a model', split = True) # Transform test cells based on fitted model - bc.startTimer() print('Transforming Data') spatial_umap.umap_test = spatial_umap.umap_fit.transform(spatial_umap.density[spatial_umap.cells['umap_test'].values].reshape((spatial_umap.cells['umap_test'].sum(), -1))) - bc.printElapsedTime(f' Transforming {np.sum(spatial_umap.cells["umap_test"] == 1)} points with the model') + bc.printElapsedTime(f' Transforming {np.sum(spatial_umap.cells["umap_test"] == 1)} points with the model', split = True) spatial_umap.umap_completed = True @@ -701,9 +722,15 @@ def umap_clustering(spatial_umap, n_clusters, clust_minmax, cpu_pool_size = 8): ) ) - # Create a pool of worker processes - with mp.Pool(processes=cpu_pool_size) as pool: - results = pool.starmap(kmeans_calc, kwargs_list) + results = utils.execute_data_parallelism_potentially(kmeans_calc, + kwargs_list, + nworkers = cpu_pool_size, + task_description='KMeans Clustering', + use_starmap=True) + # mp_start_method = mp.get_start_method() + # # Create a pool of worker processes + # with mp.get_context(mp_start_method).Pool(processes=cpu_pool_size) as pool: + # results = pool.starmap(kmeans_calc, kwargs_list) wcss = [x.inertia_ for x in results] @@ -858,7 +885,7 @@ def neighProfileDraw(spatial_umap, ax, sel_clus, cmp_clus = None, cmp_style = No maxdens_df = 1.05*max(dens_df_mean_base['density_mean'] + dens_df_mean_base['density_sem']) dens_df_mean_sel = dens_df_mean_base.loc[dens_df_mean_base['clust_label'] == sel_clus, :].reset_index(drop=True) - ylim = [0, maxdens_df] + ylim = [1, maxdens_df] dens_df_mean = dens_df_mean_sel.copy() cluster_title = 
f'{sel_clus}' @@ -892,6 +919,9 @@ def neighProfileDraw(spatial_umap, ax, sel_clus, cmp_clus = None, cmp_style = No else: cmp_style = None + if not np.all([math.isfinite(x) for x in ylim]): + ylim = [1, 10] + umPT.plot_mean_neighborhood_profile(ax = ax, dist_bin = spatial_umap.dist_bin_um, pheno_order = spatial_umap.phenoLabel, diff --git a/benchmark_collector.py b/benchmark_collector.py index 5bc0e5f..d44a5f8 100644 --- a/benchmark_collector.py +++ b/benchmark_collector.py @@ -3,12 +3,11 @@ Specifically to mark the time it takes for functions to run and save their values in a spreadsheet (if wanted) ''' - +from datetime import datetime import os import time import numpy as np import pandas as pd -from datetime import datetime class benchmark_collector: ''' @@ -56,6 +55,7 @@ def __init__(self, fiol = None): self.benchmarkDF.loc[0, 'id'] = datetime.now() self.benchmarkDF.loc[0, 'on_NIDAP'] = self.on_nidap self.stTimer = None + self.stTimer_split = None self.spTimer = None def startTimer(self): @@ -63,6 +63,7 @@ def startTimer(self): Set the Start time to the current date-time ''' self.stTimer = time.time() + self.stTimer_split = self.stTimer def stopTimer(self): ''' @@ -70,22 +71,26 @@ def stopTimer(self): ''' self.spTimer = time.time() - def elapsedTime(self): + def elapsedTime(self, split = False): ''' Calculate the elapsed time from the spTimer and the stTimer ''' - if self.stTimer is not None: + if self.stTimer is not None and split is False: + self.stopTimer() + elapsed_time = np.round((self.spTimer - self.stTimer)/60, 2) + elif self.stTimer is not None and split is True: self.stopTimer() - elapsed_time = np.round(self.spTimer - self.stTimer, 3) + elapsed_time = np.round((self.spTimer - self.stTimer_split)/60, 2) + self.stTimer_split = self.spTimer else: elapsed_time = None return elapsed_time - def printElapsedTime(self, msg): + def printElapsedTime(self, msg, split = False): ''' - Print the current value of elapsed time + Print the current value of elapsed time ''' - print(f'{msg} took {self.elapsedTime()} s') + print(f'{msg} took {self.elapsedTime(split)} min') def check_df(self): ''' diff --git a/ci.yml b/ci.yml index 61bf33c..2b0b1b0 100644 --- a/ci.yml +++ b/ci.yml @@ -33,6 +33,10 @@ jobs: set -euo pipefail unset SUDO_UID SUDO_GID SUDO_USER + if [ -n "$EXTERNAL_CONNECTIONS_CA_PATH" ]; then + export CURL_CA_BUNDLE="${CURL_CA_BUNDLE:-$EXTERNAL_CONNECTIONS_CA_PATH}" + fi + export BUILD_PLUGIN_VERSION=latest export CI_PLUGIN_NAME="build.py" diff --git a/environment-2024-02-08.yml b/environment-2024-02-08.yml index 21baf76..e784675 100644 --- a/environment-2024-02-08.yml +++ b/environment-2024-02-08.yml @@ -34,4 +34,10 @@ dependencies: - parc - parmap - setuptools-scm + - annoy + - sklearn-ann + - pynndescent + - plotnine + + diff --git a/environment.yml b/environment.yml index ab1234f..a7bfdfc 100644 --- a/environment.yml +++ b/environment.yml @@ -15,4 +15,4 @@ dependencies: - parmap - setuptools-scm - scanpy - + - plotnine diff --git a/foundry_IO_lib.py b/foundry_IO_lib.py index 8fcdb3c..1305f8e 100644 --- a/foundry_IO_lib.py +++ b/foundry_IO_lib.py @@ -25,8 +25,8 @@ def __init__(self): token = os.environ.get('FOUNDRY_TOKEN', 'Not found') if (host_name == 'Not found') | (token == 'Not found'): # Import SDK handling library - from palantir.datasets import dataset - self.dataset = dataset + # from palantir.datasets import dataset + # self.dataset = dataset # Inform on working environment print('Not Operating on NIDAP') self.onNIDAP = False diff --git a/image_filter.py b/image_filter.py 
new file mode 100644 index 0000000..d3ada1d --- /dev/null +++ b/image_filter.py @@ -0,0 +1,156 @@ +# Import relevant libraries +import streamlit as st +import pandas as pd +import numpy as np + +# Global variable +st_key_prefix = 'imagefilter__' + + +class ImageFilter: + + def __init__(self, df, image_colname='Slide ID', st_key_prefix=st_key_prefix, possible_filtering_columns=None): + # self.df = df # no real need to save this + self.image_colname = image_colname + self.st_key_prefix = st_key_prefix + self.possible_filtering_columns = possible_filtering_columns + self.df_image_filter = get_filtering_dataframe(df, image_colname, st_key_prefix=self.st_key_prefix, possible_filtering_columns=possible_filtering_columns) + self.ready = False if self.df_image_filter is None else True + + + def select_images(self, key, color='red', return_df_masked=False): + return filter_images(self.df_image_filter, key=key, color=color, image_colname=self.image_colname, st_key_prefix=self.st_key_prefix, return_df_masked=return_df_masked) + + +def get_filtering_dataframe(df, image_colname='Slide ID', st_key_prefix=st_key_prefix, possible_filtering_columns=None): + + if possible_filtering_columns is None: + possible_columns_on_which_to_filter = df.columns + else: + possible_columns_on_which_to_filter = possible_filtering_columns + + # Allow the user to select the columns on which they want to filter + selected_cols_for_filtering = st.multiselect('Select columns on which to filter:', possible_columns_on_which_to_filter, key=st_key_prefix + 'selected_cols_for_filtering', on_change=reset_filtering_columns, kwargs={'st_key_prefix': st_key_prefix}) + + # Simplify the dataframe to presumably just the essentially categorical columns + if st.button('Prepare filtering data'): + st.session_state[st_key_prefix + 'df_deduped'] = df[[image_colname] + selected_cols_for_filtering].drop_duplicates().sort_values(selected_cols_for_filtering) + + # Ensure the deduplication based on the selected columns has been performed + if st_key_prefix + 'df_deduped' not in st.session_state: + st.warning('Please prepare the filtering data first') + return + + # Get a shortcut to the deduplicated dataframe + df_deduped = st.session_state[st_key_prefix + 'df_deduped'] + + # Return the resulting dataframe + return df_deduped + + +# This is an image filter that should behave somewhat like a Streamlit (macro) widget +def filter_images(df, key, color='red', image_colname='Slide ID', st_key_prefix=st_key_prefix, return_df_masked=False): + + selected_cols_for_filtering = df.columns[1:] + + with st.expander(f'Image filter for :{color}[{key}] group:', expanded=False): + + # Build a widget for all selected filtering columns + for col in selected_cols_for_filtering: + build_multiselect(df, col, key, st_key_prefix=st_key_prefix) + + # Create a mask for each filter + masks = [df[col].isin(st.session_state[st_key_prefix + 'filtering_multiselect_' + key + '_' + col]) for col in selected_cols_for_filtering if st.session_state[st_key_prefix + 'filtering_multiselect_' + key + '_' + col]] + + # Combine the masks + combined_mask = np.logical_and.reduce(masks) + + # Apply the combined mask to the DataFrame + if masks: + df_masked = df[combined_mask] + else: + df_masked = df.copy() + + # Output the number of images that passed the filter + st.write(f'Number of images filtered above: {len(df_masked)}') + + # Create an interactive dataframe to allow the user to customize the image selection + df_selection = st.dataframe(df_masked, on_select='rerun', hide_index=True, 
key=st_key_prefix + key + '_df_selection__do_not_persist') + + # Output the number of images that have been manually selected by the user + st.write(f'Number of images selected above: {len(df_selection["selection"]["rows"])}') + + # Output the filenames of the selected images + ser_selection = df_masked[image_colname].iloc[df_selection['selection']['rows']] + st.dataframe(ser_selection, hide_index=True) + + # Convert the list of selected images to a list + selected_images = ser_selection.tolist() + + # Save it to the session state + full_key = st_key_prefix + key + '_selected_images' + st.session_state[full_key] = selected_images + + # Also return it + if not return_df_masked: + return selected_images + else: + return selected_images, df_masked.set_index('Slide ID', drop=True) + + +# Reset the filtering columns +def reset_filtering_columns(st_key_prefix=st_key_prefix): + for key in st.session_state.keys(): + if key.startswith(st_key_prefix + 'filtering_multiselect_'): + st.session_state[key] = [] + if st_key_prefix + 'df_deduped' in st.session_state: + del st.session_state[st_key_prefix + 'df_deduped'] + + +# Build a multiselect widget for a given column +def build_multiselect(df, col, widget_key_prefix, st_key_prefix=st_key_prefix): + unique_vals = df[col].unique() + st.multiselect(f'Filter image on `{col}`:', unique_vals, key=st_key_prefix + 'filtering_multiselect_' + widget_key_prefix + '_' + col) + + +# Main function +def main(): + + # Load the full dataframe from disk + # Sample of how it's written to disk in the first place from preprocess_radial_profile_data.ipynb: + # import random + # input_filenames = df_transformed['input_filename'].unique() + # df_transformed[df_transformed['input_filename'].isin(random.sample(list(input_filenames), 10))].to_hdf(os.path.join('image_data.h5'), key='df_transformed_partial', mode='w', format='table', complevel=9) + if st.button('Load data from disk'): + st.session_state[st_key_prefix + 'df'] = pd.read_hdf('image_data.h5') + st.info(f'Data of shape {st.session_state[st_key_prefix + "df"].shape} loaded successfully') + + # Ensure the full dataset has been loaded from disk + if st_key_prefix + 'df' not in st.session_state: + st.warning('Please load the data first') + return + + # Get a shortcut to the full dataframe + df = st.session_state[st_key_prefix + 'df'] + + # Instantiate the object + image_filter = ImageFilter(df, image_colname='input_filename', st_key_prefix=st_key_prefix) + + # If the image filter is not ready (which means the filtering dataframe was not generated), return + if not image_filter.ready: + return + + # Create two image filters + selected_images_baseline = image_filter.select_images(key='baseline', color='blue') + selected_images_signal = image_filter.select_images(key='signal', color='red') + + # Output the selected images in each group + st.write('Selected images in the baseline group:') + st.write(selected_images_baseline) + st.write('Selected images in the signal group:') + st.write(selected_images_signal) # or could write, e.g., st.session_state[st_key_prefix + 'signal' + '_selected_images'] + + +# Run the main function +if __name__ == '__main__': + main() diff --git a/install_missing_packages.py b/install_missing_packages.py new file mode 100644 index 0000000..e7ec24c --- /dev/null +++ b/install_missing_packages.py @@ -0,0 +1,99 @@ +import subprocess +import importlib + +def is_mamba_installed(): + try: + # Run the 'mamba --version' command + result = subprocess.run(['mamba', '--version'], stdout=subprocess.PIPE, 
stderr=subprocess.PIPE, text=True) + + # Check if the command was successful + if result.returncode == 0: + print("&&&& Mamba is installed.") + return True + else: + print("&&&& Mamba is not installed.") + return False + except FileNotFoundError: + # The command was not found + print("&&&& Mamba is not installed.") + return False + +def install_with_mamba(packages): + print(f"&&&& Attempting to install {', '.join(packages)} with mamba.") + try: + # Run the 'mamba install ' command + result = subprocess.run(['mamba', 'install', '-y'] + packages, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # Check if the command was successful + if result.returncode == 0: + print(f"&&&& {', '.join(packages)} have been installed successfully with mamba.") + print(result.stdout) + else: + print(f"&&&& Failed to install {', '.join(packages)} with mamba.") + print(result.stderr) + except Exception as e: + print(f"&&&& An error occurred while trying to install {', '.join(packages)} with mamba: {e}") + +def install_with_conda(packages): + print(f"&&&& Attempting to install {', '.join(packages)} with conda.") + try: + # Run the 'conda install ' command + result = subprocess.run(['conda', 'install', '-y'] + packages, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # Check if the command was successful + if result.returncode == 0: + print(f"&&&& {', '.join(packages)} have been installed successfully with conda.") + print(result.stdout) + else: + print(f"&&&& Failed to install {', '.join(packages)} with conda.") + print(result.stderr) + except Exception as e: + print(f"&&&& An error occurred while trying to install {', '.join(packages)} with conda: {e}") + +def install_with_pip(packages): + print(f"&&&& Attempting to install {', '.join(packages)} with pip.") + try: + # Run the 'pip install ' command + result = subprocess.run(['pip', 'install'] + packages, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # Check if the command was successful + if result.returncode == 0: + print(f"&&&& {', '.join(packages)} have been installed successfully with pip.") + print(result.stdout) + else: + print(f"&&&& Failed to install {', '.join(packages)} with pip.") + print(result.stderr) + except Exception as e: + print(f"&&&& An error occurred while trying to install {', '.join(packages)} with pip: {e}") + +def live_package_installation(): + packages_to_install = ['hnswlib', 'parc', 'sklearn_ann', 'annoy', 'pyNNDescent'] # last two probably only needed for published dashboards + installers_to_use = ['mamba', 'conda', 'pip'] + + for package in packages_to_install: + try: + importlib.import_module(package.lower()) + print(f"&&&& {package} is already installed.") + except ImportError: + print(f"&&&& {package} is not installed.") + for installer in installers_to_use: + print(f"&&&& Trying to install {package} using {installer}.") + if installer == 'mamba': + if is_mamba_installed(): + install_with_mamba([package]) + else: + print(f"&&&& mamba is not installed. 
Trying the next installer.") + continue + elif installer == 'conda': + install_with_conda([package]) + elif installer == 'pip': + install_with_pip([package]) + + try: + importlib.import_module(package.lower()) + print(f"&&&& {package} has been successfully installed using {installer}.") + break + except ImportError: + print(f"&&&& {package} was not successfully installed with {installer}.") + else: + print(f"&&&& {package} could not be installed after trying all installers.") diff --git a/markdown/MAWA_WelcomePage.md b/markdown/MAWA_WelcomePage.md index 44718f6..0bf2859 100644 --- a/markdown/MAWA_WelcomePage.md +++ b/markdown/MAWA_WelcomePage.md @@ -1,93 +1,135 @@ -# Multiplex Analysis Web Apps -## NCATS-NCI-DMAP +## Multiplex Analysis Web Apps (MAWA) + +### NCI CBIIT ## Welcome + Welcome to the Multiple Analysis Web Apps (MAWA) presented by NCATS-NCI-DMAP. This is your one stop resource for data exploration, algorithm tuning, and figure generation. The below is a guide to using this app with suggested workflow, step-by-step instructions, and FAQs ## Available Apps -1. Data Import and Export + +1. File Handling 1. Multiaxial Gating 1. Phenotyping 1. Spatial Interaction Tool 1. Neighborhood Profiles -### 1. Data Import and Export +### 1. File Handling + As with most projects, the first step in starting your analysis workflow is with importing the data you intend to use. This is argueabely the most important step and one where the most issues may arise. If you any questions at all about importing and exporting data to not hesitate in seeking help from a DMAP team member (Dante and Andrew). #### NIDAP Infrastructure -NIDAP manages files using a product called Compass. Compass is akin to other file management systems like Windows File Explorer and Apple Finder. Data (.csv, .txt, .png, etc files) are stored in a DATASET, which can be thought of as a file folder. It is the goal of this app that as data is made available to be processed, it is stored in a NIDAP DATASET. Similarly, as results, figures, and updated datatables are generated, these data objects will be placed in a NIDAP DATASET for later sharing, downloading, and storing. -#### Data Import -Inside your NIDAP Project folder is a DATASET named `Input`. This is where you should upload files that you wish to analyze in the app. You can view those files inside the app from the dropdown +NIDAP manages files using a product called Compass. Compass is akin to other file management systems like Windows File Explorer and Apple Finder. Data (.csv, .txt, .png, etc files) are stored in a DATASET, which can be thought of as a file folder. It is the goal of this app that as data is made available to be processed, it is stored in a NIDAP DATASET. Similarly, as results, figures, and updated datatables are generated, these data objects will be placed in a NIDAP DATASET for later sharing, downloading, and storing. + +#### 1a. Data Import and Export + +1. Select any and all files that you want to import into MAWA from the the left-hand side of the screen in the section titled: **Available input data on NIDAP**. +2. Once those files are selected, click on the button in the middle of the screen that reads: **Load selected (at left) Input Data**. +3. Once the files have finished loading, make sure all selected files are visible on the right-hand side of the screen in the section titled: **Input data available to the tool.** -#### Data Export -There is another DATASET named `Output` which will hold the files that you have chosen to save. 
YOu can select anyone of these files to download for later use or delete if it cluttering the space. +#### 1B. Datafile Unification + +1. On Step 1 of this page, select any and all files that you wish to combine in the table on the left titled: **Select Datafiles** +2. Click the button that says **Combine selected files into single dataframe** +3. Once the files are done combining, review the table at the very bottom of the screen to identify that your data was loaded correctly. +4. Skip Step 2 in this section +5. On Step 3 of Datafile Unification, select the column labeled as `ShortName` for your Image identification. Click the **Assign Images** button to continue +6. On Step 4 of DU, simply click the button that says Assign ROIs and move on to the next step +7. On Step 5 of DU, keep the toggle which reads **Select number of columns that specify one coordinate axis** set to One Column. For the x-coordinate drop down select 'CentroidX' and for the y-coordinate select 'CentroidY'. Keep the micron conversion factor = 1um/pixel. Finalize this step by clicking the button titled: **Assign Coordinates** +8. For the first part of Step 6 of the DU, expand the collapsed container and in the first drop down box select the option titled 'One Column for all Phenotypes'. After the page updates, in the second drop down box, select the column titled: `Pheno` or `Pheno_InEx`. Finish this section by selecting the button titled **Select Phenotypes** +9. For the second part of Step 6 of the DU, once the phenotypes have been extracted, you should see a green box that says `# phenotypes extracted`. In a table below that, you will see the names of the phenotypes. You have the option now to rename them to something else, should you choose. I would recommend removing the '+' sign from the phenotype labels in the right-hand column of the renamed phenotypes (there are 5 phenotypes that require this). Once this is done, click on the button titled **Assign Phenotypes** +10. For Step 7 of the DU, save your version of the unified dataset with a memorable name. Keep in mind that it will be prefaced with 'mawa-unified_dataset_'. Click the button titled: 'Save dataframe to csv'. +11. Return to the Data Import and Export tab and look for your titled dataframe in the dropdown select box in the center of the screen titled: 'Save MAWA-unified datafile to NIDAP'. Once your dataframe is selected, click the button titled 'Save Selected (above) MAWA unified datafile to NIDAP'. This may take some time (~1 min) to complete. Once done, you will have a permanent version of the unified datafile to use in the future. Load it into the MAWA input space anytime you want, just like you would the base files. + +#### 1C. Open File + +1. If you have just completed the datafile unification process, you will see the toggle at the top of the screen titled **Load Dataset from Datafile Unifier** is ON. Feel free to keep it ON. +2. Go ahead and click the button below that reads **Load the selected input dataset**. This may take a moment to complete (~1 min). When it has, you will see a sample of the dataset. Feel free to review it. +3. In the future, if you have loaded a previously created mawa-unified dataset into MAWA memory, you can move directly to this screen and load it using the drop down select box. ### 2. Multiaxial Gating ### 3. Phenotyping + The second page to start your analysis pipeline is the phenotyping page. 
This is where you will load your data, view your different feature conditions, and select your phenotyping. There are two primary steps in performing phenotyping. -**Step 0**: Import Data. The phenotyper needs data to use perform phenotyping -**Step 1**: Select a Phenotyping Method: The app currently offers three different phenotyping methods. They are: +**Step 0**: Click the button at the top of the screen titled Load Data. It may take a few minutes to complete (<3 min, working to improve this). +**Step 1**: Once the data is loaded, select a phenotyping method in the top right. The different phenotyping methods are described below. Once a method has been selected, click the button titled **Apply Phenotyping Method**. The app currently offers three different phenotyping methods. They are: `Species`: The phenotype is set as a direct copy of the species name (species name short to be more precise). The species name is the combination of markers that the cell type is positive for. `Markers`: The phenotype is set as one of the given Markers that the cell is positive for. If the cell is positive for more than one Marker, the cell entry in the dataset is duplicated and represented by each positive marker value being studied. - `Custom`: The phenotype is set as a value of your own choise. Once the Custom phenotyping is selected, the Phenotype Summary Table on the right side of the page becomes editable. + `Custom`: The phenotype is set as a value of your own choice. Once the Custom phenotyping is selected, the Phenotype Summary Table on the right side of the page becomes editable. ### 4. Spatial Interaction Tool -### 5. Neighborhood Profiles (UMAP) -Once data is loaded and phenotyped appropriately, Neighborhood Profiles analysis can begin. There are 4 main steps required: -**Step 0**: Make sure your data is imported and phenotyped -**Step 1**: Perform Cell Counts/Areas: Clicking on this button starts the process of calculating the density of the cells in a given neighborhood profile. -**Step 2**: Perform UMAP: Clicking on this button starts the process of performing a 2-dimensional spatial UMAP on the dataset. -**Step 3**: Perform Clustering: Clicking this button performs k-means clustering on the 2-D UMAP coordinates produced from the previous step. The value of k can be set by using the slider located next to the button +### 5. Neighborhood Profiles (UMAP) Workflow + +#### 5A. Neighborhood Profiles + +1. Expand the collapsed container labeled *Neighborhood Profiles Settings*. In this container, make sure that `Number of CPUs` is set to 7 and `Calculate unique areas` is set to OFF. For the middle number boxes, set Percentage of cells to Subset for Fitting Step equal to 50 and Percentage of cells to Subset for Transforming Step equal to 50. Make sure the toggle titled Subset data transformed by UMAP is ON, and the toggle titled Load pre-generated UMAP is set to OFF. + +2. Click the button titled **Perform Cell Density Analysis**. This should complete in under 5 min +3. Click the button titled **Perform UMAP Analysis**. This should complete in 10-15 min +4. Once the UMAP is complete, you will see options for performing the clustering. There are two ways to perform clustering. + 1. Perform Clustering on UMAP Density Difference = OFF: This will not perform any difference metrics, but instead will perform clustering on the whole UMAP distribution. No distinction is made based on a feature of the data. Try selecting a random cluster number between 1 and 10. 
Then click the button titled Perform Clustering Analysis. An elbow plot will appear to allow you to adjust the number of clusters to a number of your choosing. Each time you adjust the number of clusters, you will need to resubmit the Perform Clustering Analysis button. This should be completed in roughly 3 min (we are working to improve the timing). + 2. Perform Clustering on UMAP Density Difference = ON: This will allow you to perform individual clustering steps on regions of the UMAP which include cells of a specific feature condition. For example, how do differences between large nuclei and small nuclei cells contribute to the distribution of the UMAP? Select a column from the dropdown select box that is a numeric value (like area) or has exactly two unique values (for example, TRUE/FALSE). If you attempt to choose a categorical or string feature that has only 1 unique value or more than 2 unique values, you cannot perform the difference UMAP clustering. Choose any number of clusters for the FALSE and TRUE clustering to start off. For now, ignore the box titled `Cutoff Percentage`. Rerun the clustering by clicking the button titled Perform Clustering Analysis. Once it has completed, elbow plots will appear under the cluster values for your investigation. Many other figures will appear as well. Anytime you want to adjust the column being observed, or the number of clusters to use, you will need to resubmit the Perform Clustering Analysis button. This should be completed in roughly 3 min (we are working to improve the timing). +5. Once clustering is complete, peruse the figures, then move on to the Neighborhood Profiles//UMAP Differences and Neighborhood Profiles//Clusters Analyzer sections. + +#### 5B. UMAP Differences Analyzer -#### 4. UMAP Differences Analyzer After completing the UMAP decomposition and clustering analysis, the user may now take a look at the downstream figures generated as a result of these analyses. While there are not many levers and knobs to change the data implicitly here, the user can generate different figures. + 1. Before starting to view these Clustering Differences, you must complete at least the UMAP processing seen on the previous page. To experience the full offering of the Clustering Differences page, you must also complete the Clustering step on the previous page. There are warnings on the page to help you remember what needs to be completed in order to see each figure. 2. The Figures that are available for viewing: 1. Full 2D UMAP - 2. 2D UMAP filtered by lineage and features - 3. Different UMAP scaled by features + 2. 2D UMAP filtered by lineage and features + 3. Different UMAP scaled by features + +#### 5C. Clusters Analyzer -#### 5. Clusters Analyzer After completing the UMAP decomposition and clustering analysis, the user may now take a look at the downstream figures generated as a result of these analyses. The Cluster Analyzer page contains two figures generated from the upstream data analysis: -1. `Cluster/Phenotype Heatmap` + +1. `Cluster/Phenotype Heatmap` 2. `Incidence Lineplot` These figures have been created to investigate the composition of phenotypes of cells in assigned clusters, and the feature expression of cells in assigned clusters. Once a given figure is generated, you can change the name of the output in the text below each and add it as a figure to be exported in the `Data Input and Output` Page. 
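For orientation, the snippet below is a minimal, hypothetical pandas sketch of how the Cluster/Phenotype Heatmap's count matrix and the three normalization options described under *Cluster/Phenotype Heatmap* below could be reproduced outside the app. The example data are made up, and the column names (`clust_label`, `phenotype`) simply mirror names used elsewhere in this changeset; this is not the app's actual implementation.

```python
import pandas as pd

# Toy per-cell table: one row per cell, with its assigned cluster and phenotype.
df = pd.DataFrame({
    'clust_label': ['Cluster 1', 'Cluster 1', 'Cluster 2', 'Cluster 2', 'Cluster 2'],
    'phenotype':   ['VIM+',      'ECAD+',     'VIM+',      'VIM+',      'Other'],
})

# "No Norm": raw counts of each phenotype within each cluster (entries sum to the total cell count).
counts = pd.crosstab(df['clust_label'], df['phenotype'])

# "Norm within Clusters": each row (cluster) sums to 1.
norm_within_clusters = counts.div(counts.sum(axis=1), axis=0)

# "Norm within Phenotypes": each column (phenotype) sums to 1.
norm_within_phenotypes = counts.div(counts.sum(axis=0), axis=1)

print(counts, norm_within_clusters, norm_within_phenotypes, sep='\n\n')
```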
The following sections are some general information about each figure: #### Cluster/Phenotype Heatmap + The heatmap offers a view of the number of each phenotyped cell located within each cluster. It offers three normalization options for viewing the heatmap: -1. `No Norm`: No normalization is applied to the heatmap. The relative colors for each cell is scaled for all cells in all phenotypes in all clusters. If you were to sum the numbers shown in the grid, they would sum to the total number of cells fit to the spatial-umap model. +1. `No Norm`: No normalization is applied to the heatmap. The relative color for each cell is scaled for all cells in all phenotypes in all clusters. If you were to sum the numbers shown in the grid, they would sum to the total number of cells fit to the spatial-umap model. 2. `Norm within Clusters`: The grid values are decimal values of the number of cells within a cluster assigned to a given phenotype. In this schema, the relative color of the grid is based on the within-cluster values. 3. `Norm within Phenotypes`: The grid values are decimal values of the number of cells of a given phenotype assigned to each cluster. #### Incidence Lineplot + The incidence lineplot details how the cells within each cluster differ in their expression of the data features recorded alongside the cell positions and marker values. These features include boolean values (True/False), continuous values (-1, 0, 1), and string values ('time0'). There are two selection boxes to augment the incidence line plot, and a radio button to select the type of comparison to perform. They are the following: `Feature Selectbox`: Features that can be considered for the Incidence lineplot. + - Cell Counts: The number of cells assigned to a given cluster -- HYPOXIC, NORMOXIC, NucArea, RelOrientation, etc: Any other feature that specified to Dante/Andrew as one that is worth showing. YOU MUST tell us which ones you want and we will set it up for you. +- HYPOXIC, NORMOXIC, NucArea, RelOrientation, etc: Any other feature that was specified to Dante/Andrew as one that is worth showing. YOU MUST tell us which ones you want and we will set it up for you. `Phenotype Selectionbox`: The phenotype of the cells being plotted. The options shown are: + - All Phenotypes: Shows all cells irrespective of phenotype - VIM+, ECAD+, VIM+ECAD+, Other, etc...: The other phenotypes that have been selected in the Phenotyping stage of the workflow. `DisplayAs Radio Button`: How the values of the Feature selectbox should be displayed. This radio button is disabled for the Cell Counts condition, but is enabled for any other Feature selection. The options to be displayed are: + - Count Differences: The value shown on the y-axis is the difference between the number of cells in a cluster in the Y>0 condition and the number of cells in that cluster in the Y<0 condition. - Percentages: The value shown on the y-axis is the percentage of cells that match a feature condition in that given cluster. If you were to sum all the values across the clusters, they would sum to 100%. - Ratios: The value shown on the y-axis is the ratio r1/r0, where r1 is the percentage of cells that match the feature condition shown on y>0 in that cluster, and r0 is the percentage of cells that match the feature condition shown on y<0 in that cluster. ## FAQs + Q: How do I add more features to view in the filtering step? A: Ask a member of NCATS-NCI-DMAP (Dante or Andrew) to add that column name to the dropdown menus Q: What do I do if I can't find the data I want in the drop-down menus? 
A: The easiest way to add new data for access is to add it to the following NIDAP dataset: - + Q: Where is the data/figure I exported to NIDAP? Q: How do I know if my data is in the right format for use in this app? -Q: Can I load a previously generated phenotyping summary file? \ No newline at end of file +Q: Can I load a previously generated phenotyping summary file? \ No newline at end of file diff --git a/neighborhood-profiles.ipynb b/neighborhood-profiles.ipynb index 6116ee9..c1087d2 100644 --- a/neighborhood-profiles.ipynb +++ b/neighborhood-profiles.ipynb @@ -122,9 +122,18 @@ "\n", " session_state.bc.startTimer()\n", " session_state.spatial_umap = bpl.perform_clusteringUMAP(session_state.spatial_umap,\n", - " session_state.slider_clus_val)\n", + " session_state.slider_clus_val,\n", + " session_state.clust_minmax,\n", + " session_state.cpu_pool_size)\n", + " session_state.cluster_dict = session_state.spatial_umap.cluster_dict\n", + " session_state.palette_dict = session_state.spatial_umap.palette_dict\n", " session_state.selected_nClus = session_state.slider_clus_val\n", "\n", + " # Draw the 2D histogram UMAP colored by the clusters\n", + " session_state.udp_full.cluster_dict = session_state.cluster_dict\n", + " session_state.udp_full.palette_dict = session_state.palette_dict\n", + " session_state.diff_clust_Fig = session_state.udp_full.umap_draw_clusters()\n", + "\n", " # Record time elapsed\n", " session_state.bc.printElapsedTime(msg = 'Setting Clusters')\n", " session_state.bc.set_value_df('time_to_run_cluster', session_state.bc.elapsedTime())\n", @@ -136,15 +145,22 @@ "\n", "def filter_and_plot(session_state):\n", " '''\n", - " function to update the filtering and the figure plotting\n", + " callback function to update the filtering and the \n", + " figure plotting\n", " '''\n", + "\n", + " session_state.prog_left_disabeled = False\n", + " session_state.prog_right_disabeled = False\n", + "\n", + " if session_state['idxSlide ID'] == 0:\n", + " session_state.prog_left_disabeled = True\n", + "\n", + " if session_state['idxSlide ID'] == session_state['numSlide ID']-1:\n", + " session_state.prog_right_disabeled = True\n", + "\n", " if session_state.umapCompleted:\n", " session_state.spatial_umap.df_umap_filt = session_state.spatial_umap.df_umap.loc[session_state.spatial_umap.df_umap['Slide ID'] == session_state['selSlide ID'], :]\n", - " if session_state['toggle_clust_diff']:\n", - " palette = 'bwr'\n", - " else:\n", - " palette = 'tab20'\n", - " session_state = ndl.setFigureObjs_UMAP(session_state, palette = palette)\n", + " session_state = ndl.setFigureObjs_UMAP(session_state, palette = st.session_state.palette_dict)\n", "\n", " return session_state" ] @@ -187,7 +203,7 @@ "# Run Top of Page (TOP) functions\n", "session_state = top_of_page_reqs(session_state)\n", "\n", - "file_index = 0\n", + "file_index = 2\n", "selectProj_u = 'C:/DATA/neighborhood-profiles/'\n", "filename = files[file_index]\n", "dataset_path = selectProj_u + filename\n", @@ -231,12 +247,12 @@ "session_state.bc.startTimer()\n", "npf.setup_spatial_umap(df = session_state.df,\n", " marker_names = session_state.marker_multi_sel,\n", - " pheno_order = session_state.phenoOrder)\n", - "\n", - "npf.perform_density_calc(session_state.cpu_pool_size)\n", + " pheno_order = session_state.phenoOrder,\n", + " smallest_image_size = session_state.datafile_min_img_size)\n", "\n", - "# Record time elapsed\n", - "session_state.bc.set_value_df('time_to_run_counts', session_state.bc.elapsedTime())\n", + 
"npf.perform_density_calc(calc_areas = False,\n", + " cpu_pool_size = session_state.cpu_pool_size,\n", + " area_threshold = 0.0)\n", "\n", "npf.cell_counts_completed = True" ] @@ -249,7 +265,10 @@ "source": [ "## Perform UMAP\n", "session_state.bc.startTimer()\n", - "session_state = npf.perform_spatial_umap(session_state, umap_style = 'Densities')\n", + "session_state = npf.perform_spatial_umap(session_state,\n", + " umap_subset_per_fit= 20,\n", + " umap_subset_toggle = False,\n", + " umap_subset_per = 100)\n", "\n", "# Record time elapsed\n", "session_state.bc.printElapsedTime(msg = 'Performing UMAP')\n", @@ -262,7 +281,117 @@ "metadata": {}, "outputs": [], "source": [ - "npf.spatial_umap = bpl.perform_clusteringUMAP(npf.spatial_umap, 5)" + "import seaborn as sns\n", + "import multiprocessing as mp\n", + "spatial_umap = npf.spatial_umap\n", + "n_clusters = 5\n", + "clust_minmax = session_state.clust_minmax\n", + "cpu_pool_size = session_state.cpu_pool_size\n", + "\n", + "clust_range = range(clust_minmax[0], clust_minmax[1]+1)\n", + "\n", + "kwargs_list = []\n", + "for clust in clust_range:\n", + " kwargs_list.append(\n", + " (\n", + " spatial_umap.umap_test,\n", + " clust\n", + " )\n", + " )\n", + "\n", + "# Create a pool of worker processes\n", + "with mp.Pool(processes=cpu_pool_size) as pool:\n", + " results = pool.starmap(bpl.kmeans_calc, kwargs_list)\n", + "\n", + "wcss = [x.inertia_ for x in results]\n", + "\n", + "# Create WCSS Elbow Plot\n", + "spatial_umap.elbow_fig = bpl.draw_wcss_elbow_plot(clust_range, wcss, n_clusters)\n", + "\n", + "# Identify the kmeans obj that matches the selected cluster number\n", + "kmeans_obj_targ = results[n_clusters-1]\n", + "\n", + "spatial_umap.cluster_dict = dict()\n", + "for i in range(n_clusters):\n", + " spatial_umap.cluster_dict[i+1] = f'Cluster {i+1}'\n", + "spatial_umap.cluster_dict[0] = 'No Cluster'\n", + "\n", + "spatial_umap.palette_dict = dict()\n", + "for i in range(n_clusters):\n", + " spatial_umap.palette_dict[f'Cluster {i+1}'] = sns.color_palette('tab20')[i]\n", + "spatial_umap.palette_dict['No Cluster'] = 'white'\n", + "\n", + "# Assign values to cluster_label column in df_umap\n", + "spatial_umap.df_umap.loc[:, 'clust_label'] = [spatial_umap.cluster_dict[key] for key in (kmeans_obj_targ.labels_+1)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# After assigning cluster labels, perform mean calculations\n", + "spatial_umap.mean_measures()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import multiprocessing as mp\n", + "spatial_umap = npf.spatial_umap\n", + "n_clusters = 5\n", + "clust_minmax = session_state.clust_minmax\n", + "cpu_pool_size = session_state.cpu_pool_size\n", + "\n", + "clust_range = range(clust_minmax[0], clust_minmax[1]+1)\n", + "\n", + "kwargs_list = []\n", + "for clust in clust_range:\n", + " kwargs_list.append(\n", + " (\n", + " spatial_umap.umap_test,\n", + " clust\n", + " )\n", + " )\n", + "\n", + "# Create a pool of worker processes\n", + "with mp.Pool(processes=cpu_pool_size) as pool:\n", + " results = pool.starmap(bpl.kmeans_calc, kwargs_list)\n", + "\n", + "wcss = [x.inertia_ for x in results]\n", + "\n", + "# Create WCSS Elbow Plot\n", + "spatial_umap.elbow_fig = bpl.draw_wcss_elbow_plot(clust_range, wcss, n_clusters)\n", + "\n", + "# Identify the kmeans obj that matches the selected cluster number\n", + "kmeans_obj_targ = 
results[n_clusters-1]\n", + "\n", + "spatial_umap.cluster_dict = dict()\n", + "for i in range(n_clusters):\n", + " spatial_umap.cluster_dict[i+1] = f'Cluster {i+1}'\n", + "spatial_umap.cluster_dict[0] = 'No Cluster'\n", + "\n", + "spatial_umap.palette_dict = dict()\n", + "for i in range(n_clusters):\n", + " spatial_umap.palette_dict[f'Cluster {i+1}'] = sns.color_palette('tab20')[i]\n", + "spatial_umap.palette_dict['No Cluster'] = 'white'\n", + "\n", + "# Assign values to cluster_label column in df_umap\n", + "spatial_umap.df_umap.loc[:, 'clust_label'] = [spatial_umap.cluster_dict[key] for key in (kmeans_obj_targ.labels_+1)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dens_umap_test = spatial_umap.density[spatial_umap.cells['umap_test'], :, :]" ] }, { @@ -271,7 +400,30 @@ "metadata": {}, "outputs": [], "source": [ - "print(npf.spatial_umap.density.shape)" + "print(dens_umap_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# After assigning cluster labels, perform mean calculations\n", + "spatial_umap.mean_measures()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "session_state.slider_clus_val = 5\n", + "npf.spatial_umap = bpl.umap_clustering(npf.spatial_umap,\n", + " session_state.slider_clus_val,\n", + " session_state.clust_minmax,\n", + " session_state.cpu_pool_size)" ] }, { diff --git a/neighborhood_profiles.py b/neighborhood_profiles.py index 47d8446..2f98490 100644 --- a/neighborhood_profiles.py +++ b/neighborhood_profiles.py @@ -9,20 +9,20 @@ Individual processing of UMAP density matrices ''' -import multiprocessing as mp import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.cluster import KMeans # K-Means -import umap +import umap # slow from scipy import ndimage as ndi import basic_phenotyper_lib as bpl # Useful functions for cell phenotyping import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP from benchmark_collector import benchmark_collector # Benchmark Collector Class import PlottingTools as umPT - +import utils +from natsort import natsorted class NeighborhoodProfiles: ''' Organization of the methods and attributes that are required to run @@ -108,14 +108,14 @@ def reset_neigh_profile_settings(self): self.inciOutcomeSel = self.definciOutcomes self.Inci_Value_display = 'Count Differences' - def setup_spatial_umap(self, df, marker_names, pheno_order): + def setup_spatial_umap(self, df, marker_names, pheno_order, smallest_image_size): ''' Silly I know. I will fix it later ''' - self.spatial_umap = bpl.setup_Spatial_UMAP(df, marker_names, pheno_order) + self.spatial_umap = bpl.setup_Spatial_UMAP(df, marker_names, pheno_order, smallest_image_size) - def perform_density_calc(self, cpu_pool_size = 1): + def perform_density_calc(self, calc_areas, cpu_pool_size = 1, area_threshold = 0.001): ''' Calculate the cell counts, cell areas, perform the cell densities and cell proportions analyses. 
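# --- Illustrative aside (editor's sketch; not part of this patch) ---
# The notebook cells above re-implement the k-means labelling inline: sweep a range of
# cluster counts, collect the WCSS (inertia) values for an elbow plot, then build the
# cluster_dict / palette_dict used for plotting. The self-contained sketch below shows
# that same pattern with stand-in data; it calls sklearn.cluster.KMeans directly in
# place of the repo's bpl.kmeans_calc helper, and the array sizes are made up.
import numpy as np
import seaborn as sns
from sklearn.cluster import KMeans

umap_test = np.random.rand(500, 2)   # stand-in for spatial_umap.umap_test
clust_range = range(1, 11)           # stand-in for clust_minmax = [1, 10]
n_clusters = 5

# WCSS (inertia) for each candidate cluster count, as used for the elbow plot
results = [KMeans(n_clusters=k, n_init=10).fit(umap_test) for k in clust_range]
wcss = [r.inertia_ for r in results]

# Pick the fit matching the selected cluster count and build the label/palette dicts
kmeans_obj_targ = results[n_clusters - 1]
cluster_dict = {0: 'No Cluster', **{i + 1: f'Cluster {i + 1}' for i in range(n_clusters)}}
palette_dict = {'No Cluster': 'white',
                **{f'Cluster {i + 1}': sns.color_palette('tab20')[i] for i in range(n_clusters)}}
clust_labels = [cluster_dict[key] for key in (kmeans_obj_targ.labels_ + 1)]
# --- end of aside; the neighborhood_profiles.py diff resumes below ---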
@@ -145,16 +145,19 @@ def perform_density_calc(self, cpu_pool_size = 1): self.bc.printElapsedTime(f'Calculating Counts for {len(self.spatial_umap.cells)} cells') # get the areas of cells and save to pickle file - area_threshold = 0.001 - print('\nStarting Cell Areas process') - self.spatial_umap.get_areas(area_threshold, pool_size=cpu_pool_size) + print(f'\nStarting Cell Areas process with area threshold of {area_threshold}') + self.bc.startTimer() + self.spatial_umap.get_areas(calc_areas, area_threshold, pool_size=cpu_pool_size) + self.bc.printElapsedTime(f'Calculating Areas for {len(self.spatial_umap.cells)} cells') # calculate density based on counts of cells / area of each arc examine self.spatial_umap.calc_densities(area_threshold) # calculate proportions based on species counts/# cells within an arc self.spatial_umap.calc_proportions(area_threshold) - def perform_spatial_umap(self, session_state, umap_style = 'density'): + self.spatial_umap.density_completed = True + + def perform_spatial_umap(self, session_state, umap_subset_per_fit, umap_subset_toggle, umap_subset_per): ''' Perform the spatial UMAP analysis @@ -167,9 +170,13 @@ def perform_spatial_umap(self, session_state, umap_style = 'density'): spatial_umap: spatial_umap object with the UMAP analysis performed ''' + min_image_size = self.spatial_umap.smallest_image_size + n_fit = int(min_image_size*umap_subset_per_fit/100) + n_tra = n_fit + int(min_image_size*umap_subset_per/100) + # set training and "test" cells for umap training and embedding, respectively print('Setting Train/Test Split') - self.spatial_umap.set_train_test(n=2500, groupby_label = 'TMA_core_id', seed=54321) + self.spatial_umap.set_train_test(n_fit=n_fit, n_tra = n_tra, groupby_label = 'TMA_core_id', seed=54321, umap_subset_toggle = umap_subset_toggle) # fit umap on training cells self.bc.startTimer() @@ -183,6 +190,8 @@ def perform_spatial_umap(self, session_state, umap_style = 'density'): self.spatial_umap.umap_test = self.spatial_umap.umap_fit.transform(self.spatial_umap.density[self.spatial_umap.cells['umap_test'].values].reshape((self.spatial_umap.cells['umap_test'].sum(), -1))) self.bc.printElapsedTime(f' Transforming {np.sum(self.spatial_umap.cells["umap_test"] == 1)} points with the model') + self.spatial_umap.umap_completed = True + # Identify all of the features in the dataframe self.outcomes = self.spatial_umap.cells.columns @@ -491,6 +500,38 @@ def umap_summary_stats(self): self.dens_max = np.max(self.dens_mat) self.minabs = np.min([np.abs(self.dens_min), np.abs(self.dens_max)]) + def check_feature_values(self, feature): + ''' + + Returns: + int: 0: Feature is inappropriate for splitting + int: 2: Feature is boolean and is easily split + int 3-15: Feature has a few different options but can be easily compared when values are selected + int: 100: Feature is a numerical range and can be split by finding the median + ''' + + col = self.df[feature] # Column in question + dtypes = col.dtype # Column Type + n_uni = col.nunique() # Number of unique values + + # If only 1 unique value, then the feature cannot be split + if n_uni <= 1: + return 0 + # If exactly 2 values, then the value can be easily split. 
+ elif n_uni == 2: + return 2 + # If more than 2 values but less than 15, then the values + # can be easily split by two chosen values + elif n_uni > 2 and n_uni <= 15: + return n_uni + else: + if dtypes == 'category' or dtypes == 'object': + return 0 + else: + # If there are more than 15 unique values, and the values are numerical, + # then the Feature can be split by the median + return 100 + def filter_by_lineage(self, display_toggle, drop_val, default_val): ''' Function for filtering UMAP function based on Phenotypes or Markers @@ -509,7 +550,7 @@ def filter_by_lineage(self, display_toggle, drop_val, default_val): elif display_toggle == 'Markers': self.df = self.df.loc[self.df['species_name_short'].str.contains(drop_val), :] - def split_df_by_feature(self, feature): + def split_df_by_feature(self, feature, val_fals=None, val_true=None, val_code=None): ''' split_df_by_feature takes in a feature from a dataframe and first identifies if the feature is boolean, if it contains @@ -526,34 +567,73 @@ def split_df_by_feature(self, feature): Args: feature (str): Feature to split the dataframe by + val_fals (int): Value to use for the false condition + val_true (int): Value to use for the true condition + val_code (int): Code to use for the split Returns: split_dict (dict): Dictionary of the outcomes of splitting - the dataframe + the dataframe with the following parameters + appro_feat (bool): True if the feature is appropriate for splitting + df_umap_fals (Pandas dataframe): Dataframe of the false condition + df_umap_true (Pandas dataframe): Dataframe of the true condition + fals_msg (str): Message for the false condition + true_msg (str): Message for the true condition ''' + # Set up the dictionary for the split split_dict = dict() - # Idenfify the column type that is splitting the UMAP - col_type = ndl.identify_col_type(self.df[feature]) - if col_type == 'not_bool': - # Identify UMAP by Condition - median = np.round(self.df[feature].median(), 2) + # Check the feature values + if val_code is None: + val_code = self.check_feature_values(feature) + + # Set default values for the false and true conditions + if val_fals is None: + # Get the unique values of the feature + feat_vals_uniq = natsorted(self.df[feature].unique()) + + if val_code == 0: + val_fals = None + val_true = None + elif val_code == 100: + # Get the median value of the feature + median_val = np.round(self.df[feature].median(), decimals = 2) + + val_fals = median_val + val_true = median_val + elif val_code == 2: + val_fals = feat_vals_uniq[0] + val_true = feat_vals_uniq[1] + else: + # We can later make this more sophisticated + # but this is only ever reached if the feature values + # are not otherwise previously identified. + # I dont think think this will be too much of a problem. + # If we need more specificity on this in the future, it can + # be easily added. 
+ val_fals = feat_vals_uniq[0] + val_true = feat_vals_uniq[1] + + if val_code == 0: + split_dict['appro_feat'] = False + split_dict['df_umap_fals'] = None + split_dict['df_umap_true'] = None + split_dict['fals_msg'] = 'Feature is inappropriate for splitting' + split_dict['true_msg'] = 'Feature is inappropriate for splitting' + elif val_code == 100: + median = val_fals + split_dict['appro_feat'] = True split_dict['df_umap_fals'] = self.df.loc[self.df[feature] <= median, :] split_dict['df_umap_true'] = self.df.loc[self.df[feature] > median, :] - split_dict['fals_msg'] = f'<= {median}' - split_dict['true_msg'] = f'> {median}' - split_dict['appro_feat'] = True - elif col_type == 'bool': - # Identify UMAP by Condition - values = self.df[feature].unique() - split_dict['df_umap_fals'] = self.df.loc[self.df[feature] == values[0], :] - split_dict['df_umap_true'] = self.df.loc[self.df[feature] == values[1], :] - split_dict['fals_msg'] = f'= {values[0]}' - split_dict['true_msg'] = f'= {values[1]}' - split_dict['appro_feat'] = True + split_dict['fals_msg'] = f'<= {median:.2f}' + split_dict['true_msg'] = f'> {median:.2f}' else: - split_dict['appro_feat'] = False + split_dict['appro_feat'] = True + split_dict['df_umap_fals'] = self.df.loc[self.df[feature] == val_fals, :] + split_dict['df_umap_true'] = self.df.loc[self.df[feature] == val_true, :] + split_dict['fals_msg'] = f'= {val_fals}' + split_dict['true_msg'] = f'= {val_true}' return split_dict @@ -658,13 +738,17 @@ def perform_clustering(self, dens_mat_cmp, num_clus_0, num_clus_1, clust_minmax, ) ) - # Create a pool of worker processes - with mp.Pool(processes=cpu_pool_size) as pool: - results_0 = pool.starmap(self.kmeans_calc, kwargs_list_0) - - # Create a pool of worker processes - with mp.Pool(processes=cpu_pool_size) as pool: - results_1 = pool.starmap(self.kmeans_calc, kwargs_list_1) + results_0 = utils.execute_data_parallelism_potentially(self.kmeans_calc, + kwargs_list_0, + nworkers = cpu_pool_size, + task_description='KMeans Clustering for False Condition', + use_starmap=True) + + results_1 = utils.execute_data_parallelism_potentially(self.kmeans_calc, + kwargs_list_1, + nworkers = cpu_pool_size, + task_description='KMeans Clustering for True Condition', + use_starmap=True) wcss_0 = [x.inertia_ for x in results_0] wcss_1 = [x.inertia_ for x in results_1] diff --git a/new_phenotyping_lib.py b/new_phenotyping_lib.py index 3e79a85..801d2db 100644 --- a/new_phenotyping_lib.py +++ b/new_phenotyping_lib.py @@ -73,6 +73,9 @@ def map_species_to_possibly_compound_phenotypes(df, phenotype_identification_fil # Get prefixed column names for the individual phenotypes in order to avoid possible duplication of columns phenotype_colnames = ['phenotype ' + x for x in full_phenotype_list] + # Determine all the species integers in the dataframe + species_int_not_in_id_file = df[species_int_colname].unique() + # For each row in the biologists' phenotype specification file... 
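# --- Illustrative aside (editor's sketch; not part of this patch) ---
# The new_phenotyping_lib.py hunk that follows drops rows whose species integer never
# appears in the biologists' phenotype identification file. The self-contained pandas
# example below shows that filtering idea only; the column name 'species_int' and the
# values are made up, whereas the real code uses species_int_colname and the set of
# integers read from the phenotype ID file.
import pandas as pd

df = pd.DataFrame({'species_int': [1, 2, 3, 2, 9], 'val': range(5)})
ids_in_spec_file = {1, 2, 3}   # species integers covered by the ID file

species_int_not_in_id_file = set(df['species_int'].unique()) - ids_in_spec_file   # -> {9}
if species_int_not_in_id_file:
    num_rows_before = len(df)
    df = df[~df['species_int'].isin(species_int_not_in_id_file)].copy()
    print(f'Filtered out {num_rows_before - len(df)} rows')   # -> Filtered out 1 rows
# --- end of aside; the new_phenotyping_lib.py diff resumes below ---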
num_updated_rows = 0 for spec_row in df_phenotype_spec.iterrows(): @@ -99,6 +102,17 @@ def map_species_to_possibly_compound_phenotypes(df, phenotype_identification_fil df.loc[curr_df_indexes, 'phenotype_int'] = curr_phenotype_int num_updated_rows = num_updated_rows + len(curr_df_indexes) + # Remove the current species integer from the list of all species integers in the dataframe + species_int_not_in_id_file = species_int_not_in_id_file[species_int_not_in_id_file != curr_species_int] + + # Filter out species with integer IDs that do not appear to be present in the phenotype identification file + if len(species_int_not_in_id_file) > 0: + print(f'Filtering out species with integer IDs {species_int_not_in_id_file} because they do not appear to be present in the phenotype iD file {phenotype_identification_file}...') + num_rows_before = len(df) + df = df[~df[species_int_colname].isin(species_int_not_in_id_file)].copy() + num_rows_after = len(df) + print(f'Filtered out {num_rows_before - num_rows_after} rows') + # Ensure the total number of rows modified equals the size of the dataframe itself assert num_updated_rows == len(df), 'ERROR: Not a one-to-one mapping of the rows' diff --git a/nidap_dashboard_lib.py b/nidap_dashboard_lib.py index bba1ea4..49e5cd6 100644 --- a/nidap_dashboard_lib.py +++ b/nidap_dashboard_lib.py @@ -10,16 +10,14 @@ import numpy as np import pandas as pd import altair as alt +alt.data_transformers.disable_max_rows() from natsort import natsorted from pathlib import Path from datetime import datetime -alt.data_transformers.disable_max_rows() - -# Import relevant libraries import basic_phenotyper_lib as bpl # Useful functions for cell phenotyping from foundry_IO_lib import foundry_IO_lib # Foundry Input/Output Class from benchmark_collector import benchmark_collector # Benchmark Collector Class -from neighborhood_profiles import NeighborhoodProfiles, UMAPDensityProcessing +from neighborhood_profiles import NeighborhoodProfiles, UMAPDensityProcessing # slow because this imports umap import PlottingTools as umPT def identify_col_type(col): @@ -133,13 +131,18 @@ def init_session_state(session_state): # General Neighborhood Profile Page Settings session_state.cpu_pool_size = 7 - session_state.umap_subset_toggle = False + session_state.umap_subset_toggle = True session_state.umap_subset_per = 20 session_state.area_filter_per = 0.001 session_state.clust_minmax = [1, 10] session_state.toggle_clust_diff = False + session_state.cluster_completed_diff = False session_state.appro_feat = False + # UMAP Differences Page Settings + session_state.umap_ins_msg = None + session_state.umap_diff_msg = None + # Set data_loaded = False. 
# This needs to happen at the end to counteract the 'loadDataButton' action session_state.data_loaded = False @@ -243,7 +246,7 @@ def loadDataButton(session_state, df_import, projectName, fileName): # Analysis Setting Init session_state.loaded_marker_names = session_state.marker_names session_state.marker_multi_sel = session_state.marker_names - session_state.pointstSliderVal_Sel = 100 + session_state.point_slider_val = 100 session_state.calcSliderVal = 100 session_state.selected_nClus = 1 # Clustering (If applicable) session_state.NormHeatRadio = 'No Norm' # Heatmap Radio @@ -286,8 +289,7 @@ def loadDataButton(session_state, df_import, projectName, fileName): # Set Figure Objects session_state.bc.startTimer() - session_state = setFigureObjs(session_state, df_plot) - session_state.pointstSliderVal_Sel = session_state.calcSliderVal + session_state = set_figure_objs(session_state, df_plot) # session_state.bc.printElapsedTime(msg = 'Setting Figure Objects') session_state.bc.set_value_df('file', fileName) @@ -368,7 +370,7 @@ def updatePhenotyping(session_state): df_plot = perform_filtering(session_state) # Update and reset Figure Objects - session_state = setFigureObjs(session_state, df_plot) + session_state = set_figure_objs(session_state, df_plot, session_state.point_slider_val) return session_state @@ -504,9 +506,17 @@ def export_results_dataset(fiol, df, path, filename, saveCompass=False, type = ' """ fiol.export_results_dataset(df, path, filename, saveCompass, type) -def setFigureObjs(session_state, df_plot, InSliderVal = None): +def set_figure_objs(session_state, df_plot, slider_val = None): """ - Organize Figure Objects to be used in plotting + Organize Figure Objects to be used in phenotyping plotting + + Args: + session_state: Streamlit data structure + df_plot: Filtered dataset to be plotted + slider_val: Value of the slider + + Returns: + session_state: Streamlit data structure """ title = [f'DATASET: {session_state.datafile}', @@ -519,23 +529,28 @@ def setFigureObjs(session_state, df_plot, InSliderVal = None): targ_cell_count = 150000 num_points = df_plot.shape[0] - if (num_points > targ_cell_count) & (InSliderVal is None): + print(f'Full image contains {num_points} points') + if (num_points > targ_cell_count) & (slider_val is None): n = targ_cell_count calc_slider_val = int(np.ceil(100*n/num_points)) df_plot = df_plot.sample(n) session_state.plotPointsCustom = False - elif InSliderVal is not None: + print(f' No slider_val selected. Randomly sampled {n} points') + elif slider_val is not None: - calc_slider_val = InSliderVal + calc_slider_val = slider_val df_plot = df_plot.sample(frac = calc_slider_val/100) session_state.plotPointsCustom = True + print(f' Slider_val selected. Randomly sampled {slider_val} points') else: n = num_points calc_slider_val = 100 session_state.plotPointsCustom = False - session_state.calcSliderVal = calc_slider_val + print(f' Number of points below {targ_cell_count}. 
Sampling the full image') + + session_state.point_slider_val = calc_slider_val session_state.drawnPoints = df_plot.shape[0] # Seaborn @@ -622,7 +637,7 @@ def setFigureObjs_UMAPDifferences(session_state): udp_ins = udp_true else: udp_ins = udp_ins_raw - session_state.umap_ins_msg = 'Please choose a boolean or numerical feature' + session_state.umap_ins_msg = 'Please choose a feature that is either boolean or numerical' else: udp_ins = udp_ins_raw diff --git a/pages/05a_Pheno_Cluster.py b/pages/05a_Pheno_Cluster.py deleted file mode 100644 index f1c400d..0000000 --- a/pages/05a_Pheno_Cluster.py +++ /dev/null @@ -1,564 +0,0 @@ -# Import relevant libraries -import streamlit as st -import subprocess - -try: - import parc - print("Successfully imported parc!") -except ImportError: - print("Installing parc...") - subprocess.run("pip install parc", shell=True) - try: - import parc - except ImportError: - print("Failed to import parc.") - -from ast import arg -from pyparsing import col -import streamlit as st -import app_top_of_page as top -import streamlit_dataframe_editor as sde -import streamlit as st -import pandas as pd -import anndata as ad -import scanpy as sc -import seaborn as sns -import os -import matplotlib.pyplot as plt -import phenograph -#from utag import utag -import numpy as np -import scanpy.external as sce -import plotly.express as px -import time -from utag.segmentation import utag - -def phenocluster__make_adata(df, x_cols, meta_cols): - mat = df[x_cols] - meta = df[meta_cols] - adata = ad.AnnData(mat) - adata.obs = meta - adata.layers["counts"] = adata.X.copy() - #adata.write("input/clust_dat.h5ad") - return adata - -# scanpy clustering -def RunNeighbClust(adata, n_neighbors, metric, resolution, random_state, n_principal_components): - if n_principal_components > 0: - sc.pp.pca(adata, n_comps=n_principal_components) - sc.pp.neighbors(adata, n_neighbors=n_neighbors, metric=metric, n_pcs=n_principal_components) - sc.tl.leiden(adata,resolution=resolution, random_state=random_state, n_iterations=5, flavor="igraph") - adata.obs['Cluster'] = adata.obs['leiden'] - #sc.tl.umap(adata) - adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) - return adata - -# phenograph clustering -def RunPhenographClust(adata, n_neighbors, clustering_algo, min_cluster_size, - primary_metric, resolution_parameter, nn_method, random_seed, n_principal_components): - #sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=0) - if n_principal_components == 0: - communities, graph, Q = phenograph.cluster(adata.X, clustering_algo=clustering_algo, k=n_neighbors, - min_cluster_size=min_cluster_size, primary_metric=primary_metric, - resolution_parameter=resolution_parameter, nn_method=nn_method, - seed=random_seed, n_iterations=5) - else: - sc.pp.pca(adata, n_comps=n_principal_components) - communities, graph, Q = phenograph.cluster(adata.obsm['X_pca'], clustering_algo=clustering_algo, k=n_neighbors, - min_cluster_size=min_cluster_size, primary_metric=primary_metric, - resolution_parameter=resolution_parameter, nn_method=nn_method, - seed=random_seed, n_iterations=5) - adata.obs['Cluster'] = communities - adata.obs['Cluster'] = adata.obs['Cluster'].astype(str) - #sc.tl.umap(adata) - adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) - return adata - -# parc clustering -def run_parc_clust(adata, n_neighbors, dist_std_local, jac_std_global, small_pop, - random_seed, resolution_parameter, 
hnsw_param_ef_construction, n_principal_components): - #sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=0) - if n_principal_components == 0: - parc_results = parc.PARC(adata.X, dist_std_local=dist_std_local, jac_std_global=jac_std_global, - small_pop=small_pop, random_seed=random_seed, knn=n_neighbors, - resolution_parameter=resolution_parameter, - hnsw_param_ef_construction=hnsw_param_ef_construction, - partition_type="RBConfigurationVP", - n_iter_leiden=5) - else: - sc.pp.pca(adata, n_comps=n_principal_components) - parc_results = parc.PARC(adata.obsm['X_pca'], dist_std_local=dist_std_local, jac_std_global=jac_std_global, - small_pop=small_pop, random_seed=random_seed, knn=n_neighbors, - resolution_parameter=resolution_parameter, - hnsw_param_ef_construction=hnsw_param_ef_construction, - partition_type="RBConfigurationVP", - n_iter_leiden=5) - parc_results.run_PARC() - adata.obs['Cluster'] = parc_results.labels - adata.obs['Cluster'] = adata.obs['Cluster'].astype(str) - #sc.tl.umap(adata) - adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) - return adata - -# utag clustering -# need to make image selection based on the variable -def run_utag_clust(adata, n_neighbors, resolutions, clustering_method, max_dist, n_principal_components): - #sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=0) - #sc.tl.umap(adata) - adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) - utag_results = utag(adata, - slide_key="Image ID_(standardized)", - max_dist=max_dist, - normalization_mode='l1_norm', - apply_clustering=True, - clustering_method = clustering_method, - resolutions = resolutions, - leiden_kwargs={"n_iterations": 5, "random_state": 42}, - pca_kwargs = {"n_comps": n_principal_components} - ) - - - curClusterCol = 'UTAG Label_leiden_' + str(resolutions[0]) - utag_results.obs['Cluster'] = utag_results.obs[curClusterCol] - adata.obs['Cluster'] = utag_results.obs[curClusterCol] - - curClusterCol = 'UTAG Label_leiden_' + str(resolutions[0]) - utag_results.obs['Cluster'] = utag_results.obs[curClusterCol] - adata.obs['Cluster'] = utag_results.obs[curClusterCol] - return utag_results - -def phenocluster__scanpy_umap(adata, n_neighbors, metric, n_principal_components): - if n_principal_components > 0: - sc.pp.pca(adata, n_comps=n_principal_components) - sc.pp.neighbors(adata, n_neighbors=n_neighbors, metric=metric, n_pcs=n_principal_components) - sc.tl.umap(adata) - st.session_state['phenocluster__clustering_adata'] = adata - -# plot umaps -def phenocluster__plotly_umaps(adata, umap_cur_col, umap_cur_groups, umap_color_col): - with phenocluster__col2: - subcol1, subcol2 = st.columns(2) - for i, umap_cur_group in enumerate(umap_cur_groups): - if umap_cur_group == "All": - subDat = adata - else: - subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] - umap_coords = subDat.obsm['X_umap'] - df = pd.DataFrame(umap_coords, columns=['UMAP_1', 'UMAP_2']) - clustersList = list(subDat.obs[umap_color_col] ) - df[umap_color_col] = clustersList - df[umap_color_col] = df[umap_color_col].astype(str) - # Create the seaborn plot - fig = px.scatter(df, - x="UMAP_1", - y="UMAP_2", - color=umap_color_col, - title="UMAP " + umap_cur_group - #color_discrete_sequence=px.colors.sequential.Plasma - ) - fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots - fig.update_layout( - title=dict( - text="UMAP " + umap_cur_group, - x=0.5, # Center the title - xanchor='center', - 
yanchor='top' - ), - legend=dict( - orientation="h", - yanchor="top", - y=-0.2, - xanchor="right", - x=1 - ) - ) - if i % 2 == 0: - subcol1.plotly_chart(fig, use_container_width=True) - else: - subcol2.plotly_chart(fig, use_container_width=True) - -# plot spatial -def spatial_plots_cust_2(adata, umap_cur_col, umap_cur_groups, umap_color_col): - with phenocluster__col2: - subcol3, subcol4 = st.columns(2) - for i, umap_cur_group in enumerate(umap_cur_groups): - if umap_cur_group == "All": - subDat = adata - else: - subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] - umap_coords = subDat.obs[['Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)']] - df = pd.DataFrame(umap_coords).reset_index().drop('index', axis = 1) - clustersList = list(subDat.obs[umap_color_col] ) - df[umap_color_col] = clustersList - df[umap_color_col] = df[umap_color_col].astype(str) - fig = px.scatter(df, - x="Centroid X (µm)_(standardized)", - y="Centroid Y (µm)_(standardized)", - color=umap_color_col, - title="Spatial " + umap_cur_group - #color_discrete_sequence=px.colors.sequential.Plasma - ) - fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots - fig.update_layout( - title=dict( - text="Spatial " + umap_cur_group, - x=0.5, # Center the title - xanchor='center', - yanchor='top' - ), - legend=dict( - orientation="h", - yanchor="top", - y=-0.2, - xanchor="right", - x=1 - ) - ) - if i % 2 == 0: - subcol3.plotly_chart(fig, use_container_width=True) - else: - subcol4.plotly_chart(fig, use_container_width=True) - -# make Umaps and Spatial Plots -def make_all_plots(): - # make umaps - phenocluster__plotly_umaps(st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__umap_cur_col'], - st.session_state['phenocluster__umap_cur_groups'], - st.session_state['phenocluster__umap_color_col']) - # make spatial plots - spatial_plots_cust_2(st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__umap_cur_col'], - st.session_state['phenocluster__umap_cur_groups'], - st.session_state['phenocluster__umap_color_col']) - - - # default session state values -def phenocluster__default_session_state(): - - if 'phenocluster__subset_data' not in st.session_state: - st.session_state['phenocluster__subset_data'] = False - - if 'phenocluster__cluster_method' not in st.session_state: - st.session_state['phenocluster__cluster_method'] = "phenograph" - - if 'phenocluster__resolution' not in st.session_state: - st.session_state['phenocluster__resolution'] = 1.0 - - # phenograph options - if 'phenocluster__n_neighbors_state' not in st.session_state: - st.session_state['phenocluster__n_neighbors_state'] = 10 - - if 'phenocluster__phenograph_k' not in st.session_state: - st.session_state['phenocluster__phenograph_k'] = 30 - - if 'phenocluster__phenograph_clustering_algo' not in st.session_state: - st.session_state['phenocluster__phenograph_clustering_algo'] = 'louvain' - - if 'phenocluster__phenograph_min_cluster_size' not in st.session_state: - st.session_state['phenocluster__phenograph_min_cluster_size'] = 10 - - if 'phenocluster__metric' not in st.session_state: - st.session_state['phenocluster__metric'] = 'euclidean' - - if 'phenocluster__phenograph_nn_method' not in st.session_state: - st.session_state['phenocluster__phenograph_nn_method'] = 'kdtree' - - if 'phenocluster__n_principal_components' not in st.session_state: - st.session_state['phenocluster__n_principal_components'] = 10 - - # parc options - # dist_std_local, jac_std_global, small_pop, 
random_seed, resolution_parameter, hnsw_param_ef_construction - if 'phenocluster__parc_dist_std_local' not in st.session_state: - st.session_state['phenocluster__parc_dist_std_local'] = 3 - - if 'phenocluster__parc_jac_std_global' not in st.session_state: - st.session_state['phenocluster__parc_jac_std_global'] = 0.15 - - if 'phenocluster__parc_small_pop' not in st.session_state: - st.session_state['phenocluster__parc_small_pop'] = 50 - - if 'phenocluster__random_seed' not in st.session_state: - st.session_state['phenocluster__random_seed'] = 42 - - if 'phenocluster__hnsw_param_ef_construction' not in st.session_state: - st.session_state['phenocluster__hnsw_param_ef_construction'] = 150 - - # utag options - #clustering_method ["leiden", "parc"]; resolutions; max_dist = 20 - if 'phenocluster__utag_clustering_method' not in st.session_state: - st.session_state['phenocluster__utag_clustering_method'] = 'leiden' - - if 'phenocluster__utag_max_dist' not in st.session_state: - st.session_state['phenocluster__utag_max_dist'] = 20 - - # umap options - #if 'phenocluster__umap_cur_col' not in st.session_state: - #st.session_state['phenocluster__umap_cur_col'] = "Image" - - if 'phenocluster__umap_color_col' not in st.session_state: - st.session_state['phenocluster__umap_color_col'] = "Cluster" - - if 'phenocluster__umap_cur_groups' not in st.session_state: - st.session_state['phenocluster__umap_cur_groups'] = ["All"] - - # differential intensity options - if 'phenocluster__de_col' not in st.session_state: - st.session_state['phenocluster__de_col'] = "Cluster" - - if 'phenocluster__de_sel_group' not in st.session_state: - st.session_state['phenocluster__de_sel_groups'] = ["All"] - - if 'phenocluster__plot_diff_intensity_method' not in st.session_state: - st.session_state['phenocluster__plot_diff_intensity_method'] = "Rank Plot" - - if 'phenocluster__plot_diff_intensity_n_genes' not in st.session_state: - st.session_state['phenocluster__plot_diff_intensity_n_genes'] = 10 - - -# subset data set -def phenocluster__subset_data(adata, subset_col, subset_vals): - adata_subset = adata[adata.obs[subset_col].isin(subset_vals)] - st.session_state['phenocluster__clustering_adata'] = adata_subset - -# clusters differential expression -def phenocluster__diff_expr(adata, phenocluster__de_col, phenocluster__de_sel_groups): - sc.tl.rank_genes_groups(adata, groupby = phenocluster__de_col, method="wilcoxon", layer="counts") - - if "All" in phenocluster__de_sel_groups: - phenocluster__de_results = sc.get.rank_genes_groups_df(adata, group=None) - else: - phenocluster__de_results = sc.get.rank_genes_groups_df(adata, group=phenocluster__de_sel_groups) - with phenocluster__col2: - st.dataframe(phenocluster__de_results, use_container_width=True) - - -# main -def main(): - """ - Main function for the page. 
- """ - #st.write(st.session_state['unifier__df'].head()) - - # set default values - phenocluster__default_session_state() - - - - # make layout with columns - # options - - with phenocluster__col_0[0]: - - numeric_cols = st.multiselect('Select numeric columns for clustering:', options = st.session_state['input_dataset'].data.columns, - key='phenocluster__X_cols') - meta_columns = st.multiselect('Select columns for metadata:', options = st.session_state['input_dataset'].data.columns, - key='phenocluster__meta_cols') - - items_to_add = ['Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)'] - - #meta_columns = st.session_state['phenocluster__meta_cols'] - # Add the new items if they don't already exist in the list - for item in items_to_add: - if item not in meta_columns: - meta_columns.append(item) - - if st.button('Submuit columns'): - st.session_state['phenocluster__clustering_adata'] = phenocluster__make_adata(st.session_state['input_dataset'].data, - numeric_cols, - meta_columns) - - if 'phenocluster__clustering_adata' in st.session_state: - - with phenocluster__col1: - - # subset data - st.checkbox('Subset Data:', key='phenocluster__subset_data') - if st.session_state['phenocluster__subset_data'] == True: - st.session_state['phenocluster__subset_options'] = list(st.session_state['phenocluster__clustering_adata'].obs.columns) - phenocluster__subset_col = st.selectbox('Select column for subsetting:', st.session_state['phenocluster__subset_options']) - st.session_state["phenocluster__subset_col"] = phenocluster__subset_col - st.session_state['phenocluster__subset_values_options'] = list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state["phenocluster__subset_col"]])) - phenocluster__subset_vals = st.multiselect('Select value for subsetting:', options = st.session_state['phenocluster__subset_values_options'], key='phenocluster__subset_vals_1') - st.session_state["phenocluster__subset_vals"] = phenocluster__subset_vals - if st.button('Subset Data'): - phenocluster__subset_data(st.session_state['phenocluster__clustering_adata'], - st.session_state["phenocluster__subset_col"], - st.session_state["phenocluster__subset_vals"]) - - # st.button('Subset Data' , on_click=phenocluster__subset_data, args = [st.session_state['phenocluster__clustering_adata'], - # st.session_state["phenocluster__subset_col"], - # st.session_state["phenocluster__subset_vals"] - - # ] - # ) - - clusteringMethods = ['phenograph', 'scanpy', 'parc', 'utag'] - selected_clusteringMethod = st.selectbox('Select Clustering method:', clusteringMethods, - key='clusteringMethods_dropdown') - - # Update session state on every change - st.session_state['phenocluster__cluster_method'] = selected_clusteringMethod - - # default widgets - - st.number_input(label = "Number of Principal Components", key='phenocluster__n_principal_components', step = 1) - - st.session_state['phenocluster__n_neighbors_state'] = st.number_input(label = "K Nearest Neighbors", - value=st.session_state['phenocluster__n_neighbors_state']) - - st.number_input(label = "Clustering resolution", key='phenocluster__resolution') - - if st.session_state['phenocluster__cluster_method'] == "phenograph": - # st.session_state['phenocluster__phenograph_k'] = st.number_input(label = "Phenograph k", - # value=st.session_state['phenocluster__phenograph_k']) - st.selectbox('Phenograph clustering algorithm:', ['louvain', 'leiden'], key='phenocluster__phenograph_clustering_algo') - st.number_input(label = "Phenograph min cluster size", 
key='phenocluster__phenograph_min_cluster_size', step = 1) - st.selectbox('Distance metric:', ['euclidean', 'manhattan', 'correlation', 'cosine'], key='phenocluster__metric') - st.selectbox('Phenograph nn method:', ['kdtree', 'brute'], key='phenocluster__phenograph_nn_method') - - elif st.session_state['phenocluster__cluster_method'] == "scanpy": - st.selectbox('Distance metric:', ['euclidean', 'manhattan', 'correlation', 'cosine'], key='phenocluster__metric') - - elif st.session_state['phenocluster__cluster_method'] == "parc": - # make parc specific widgets - st.number_input(label = "Parc dist std local", key='phenocluster__parc_dist_std_local', step = 1) - st.number_input(label = "Parc jac std global", key='phenocluster__parc_jac_std_global', step = 0.01) - st.number_input(label = "Minimum cluster size to be considered a separate population", - key='phenocluster__parc_small_pop', step = 1) - st.number_input(label = "Random seed", key='phenocluster__random_seed', step = 1) - st.number_input(label = "HNSW exploration factor for construction", - key='phenocluster__hnsw_param_ef_construction', step = 1) - elif st.session_state['phenocluster__cluster_method'] == "utag": - # make utag specific widgets - #st.selectbox('UTAG clustering method:', ['leiden', 'parc'], key='phenocluster__utag_clustering_method') - st.number_input(label = "UTAG max dist", key='phenocluster__utag_max_dist', step = 1) - - # add options if clustering has been run - if st.button('Run Clustering'): - start_time = time.time() - if st.session_state['phenocluster__cluster_method'] == "phenograph": - with st.spinner('Wait for it...'): - st.session_state['phenocluster__clustering_adata'] = RunPhenographClust(adata=st.session_state['phenocluster__clustering_adata'], n_neighbors=st.session_state['phenocluster__n_neighbors_state'], - clustering_algo=st.session_state['phenocluster__phenograph_clustering_algo'], - min_cluster_size=st.session_state['phenocluster__phenograph_min_cluster_size'], - primary_metric=st.session_state['phenocluster__metric'], - resolution_parameter=st.session_state['phenocluster__resolution'], - nn_method=st.session_state['phenocluster__phenograph_nn_method'], - random_seed=st.session_state['phenocluster__random_seed'], - n_principal_components=st.session_state['phenocluster__n_principal_components'] - ) - elif st.session_state['phenocluster__cluster_method'] == "scanpy": - with st.spinner('Wait for it...'): - st.session_state['phenocluster__clustering_adata'] = RunNeighbClust(adata=st.session_state['phenocluster__clustering_adata'], - n_neighbors=st.session_state['phenocluster__n_neighbors_state'], - metric=st.session_state['phenocluster__metric'], - resolution=st.session_state['phenocluster__resolution'], - random_state=st.session_state['phenocluster__random_seed'], - n_principal_components=st.session_state['phenocluster__n_principal_components'] - ) - #st.session_state['phenocluster__clustering_adata'] = adata - elif st.session_state['phenocluster__cluster_method'] == "parc": - with st.spinner('Wait for it...'): - st.session_state['phenocluster__clustering_adata'] = run_parc_clust(adata=st.session_state['phenocluster__clustering_adata'], - n_neighbors=st.session_state['phenocluster__n_neighbors_state'], - dist_std_local=st.session_state['phenocluster__parc_dist_std_local'], - jac_std_global= st.session_state['phenocluster__parc_jac_std_global'], - small_pop=st.session_state['phenocluster__parc_small_pop'], - random_seed=st.session_state['phenocluster__random_seed'], - 
resolution_parameter=st.session_state['phenocluster__resolution'], - hnsw_param_ef_construction=st.session_state['phenocluster__hnsw_param_ef_construction'], - n_principal_components=st.session_state['phenocluster__n_principal_components'] - ) - elif st.session_state['phenocluster__cluster_method'] == "utag": - phenocluster__utag_resolutions = [st.session_state['phenocluster__resolution']] - with st.spinner('Wait for it...'): - st.session_state['phenocluster__clustering_adata'] = run_utag_clust(adata=st.session_state['phenocluster__clustering_adata'], - n_neighbors=st.session_state['phenocluster__n_neighbors_state'], - resolutions=phenocluster__utag_resolutions, - clustering_method=st.session_state['phenocluster__utag_clustering_method'], - max_dist=st.session_state['phenocluster__utag_max_dist'], - n_principal_components=st.session_state['phenocluster__n_principal_components'] - ) - # save clustering result - #st.session_state['phenocluster__clustering_adata'].write("input/clust_dat.h5ad") - end_time = time.time() - execution_time = end_time - start_time - rounded_time = round(execution_time, 2) - st.write('Execution time: ', rounded_time, 'seconds') - - - # umap - if 'Cluster' in st.session_state['phenocluster__clustering_adata'].obs.columns: - st.session_state['input_dataset'].data["Phenotype_Cluster"] = st.session_state['phenocluster__clustering_adata'].obs["Cluster"] - #st.write(pd.unique(st.session_state['phenocluster__clustering_adata'].obs["Cluster"])) - - st.session_state['phenocluster__umeta_columns'] = list(st.session_state['phenocluster__clustering_adata'].obs.columns) - st.session_state['phenocluster__umap_color_col_index'] = st.session_state['phenocluster__umeta_columns'].index(st.session_state['phenocluster__umap_color_col']) - #st.write(st.session_state['phenocluster__umap_color_col_index']) - - # select column for umap coloring - st.session_state['phenocluster__umap_color_col'] = st.selectbox('Select column for groups coloring:', - st.session_state['phenocluster__umeta_columns'], - index=st.session_state['phenocluster__umap_color_col_index'] - ) - - # select column for umap subsetting - st.session_state['phenocluster__umap_cur_col'] = st.selectbox('Select column to subset plots:', - st.session_state['phenocluster__umeta_columns'], key='phenocluster__umap_col_dropdown_subset' - ) - - # list of available subsetting options - umap_cur_groups= ["All"] + list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state['phenocluster__umap_cur_col']])) - umap_sel_groups = st.multiselect('Select groups to be plotted', - options = umap_cur_groups) - st.session_state['phenocluster__umap_cur_groups'] = umap_sel_groups - - st.button('Make Spatial Plots' , on_click=spatial_plots_cust_2, args = [st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__umap_cur_col'], - st.session_state['phenocluster__umap_cur_groups'], - st.session_state['phenocluster__umap_color_col'] - ] - ) - - st.button("Compute UMAPs", on_click=phenocluster__scanpy_umap, args = [st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__n_neighbors_state'], - st.session_state['phenocluster__metric'], - st.session_state['phenocluster__n_principal_components'] - ] - ) - - st.button('Plot UMAPs' , on_click=phenocluster__plotly_umaps, args = [st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__umap_cur_col'], - st.session_state['phenocluster__umap_cur_groups'], - 
st.session_state['phenocluster__umap_color_col'], - ] - ) - - - - - -# Run the main function -if __name__ == '__main__': - - # Set page settings - page_name = 'Unsupervised Phenotype Clustering' - st.set_page_config(layout='wide', page_title=page_name) - st.title(page_name) - phenocluster__col_0 = st.columns(1) - phenocluster__col1, phenocluster__col2 = st.columns([1, 6]) - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Call the main function - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) - -# need to make differential expression on another page diff --git a/pages/05b_Pheno_Cluster.py b/pages/05b_Pheno_Cluster.py deleted file mode 100644 index 4c5e031..0000000 --- a/pages/05b_Pheno_Cluster.py +++ /dev/null @@ -1,291 +0,0 @@ -# Import relevant libraries -from ast import arg -from pyparsing import col -import streamlit as st -import app_top_of_page as top -import streamlit_dataframe_editor as sde -import streamlit as st -import pandas as pd -import anndata as ad -import scanpy as sc -import seaborn as sns -import os -import matplotlib.pyplot as plt -import phenograph -import parc -from utag import utag -import numpy as np -import scanpy.external as sce -import plotly.express as px -import streamlit_dataframe_editor as sde -import basic_phenotyper_lib as bpl -import nidap_dashboard_lib as ndl - - - -# Functions - -# clusters differential expression -def phenocluster__diff_expr(adata, phenocluster__de_col, phenocluster__de_sel_groups, plot_column): - sc.tl.rank_genes_groups(adata, groupby = phenocluster__de_col, method="wilcoxon", layer="counts") - - if "All" in phenocluster__de_sel_groups: - phenocluster__de_results = sc.get.rank_genes_groups_df(adata, group=None) - else: - phenocluster__de_results = sc.get.rank_genes_groups_df(adata, group=phenocluster__de_sel_groups) - with plot_column: - phenocluster__de_results[['pvals', 'pvals_adj']] = phenocluster__de_results[['pvals', 'pvals_adj']].applymap('{:.1e}'.format) - st.dataframe(phenocluster__de_results, use_container_width=True) - - -# change cluster names -def phenocluster__edit_cluster_names(adata, edit_names_result): - adata.obs['Edit_Cluster'] = adata.obs['Cluster'].map(edit_names_result.set_index('Cluster')['New_Name']) - st.session_state['phenocluster__clustering_adata'] = adata - -def phenocluster__edit_cluster_names_2(adata, edit_names_result): - edit_names_result_2 = edit_names_result.reconstruct_edited_dataframe() - adata.obs['Edit_Cluster'] = adata.obs['Cluster'].map(dict(zip(edit_names_result_2['Cluster'].to_list(), edit_names_result_2['New_Name'].to_list()))) - st.session_state['phenocluster__clustering_adata'] = adata - -# make differential intensity plots -def phenocluster__plot_diff_intensity(adata, groups, method, n_genes, plot_column): - if "All" in groups: - cur_groups = None - else: - cur_groups = groups - - if method == "Rank Plot": - cur_fig = sc.pl.rank_genes_groups(adata, n_genes=n_genes, - groups=cur_groups, sharey=False) - elif method == "Dot Plot": - cur_fig = sc.pl.rank_genes_groups_dotplot(adata, n_genes=n_genes, - groups=cur_groups) - elif method == "Heat Map": - cur_fig = sc.pl.rank_genes_groups_heatmap(adata, n_genes=n_genes, - groups=cur_groups) - elif method == "Violin Plot": - 
cur_fig = sc.pl.rank_genes_groups_stacked_violin(adata, n_genes=n_genes, - groups=cur_groups, split = False) - - with plot_column: - st.pyplot(fig = cur_fig, clear_figure=None, use_container_width=True) - -def data_editor_change_callback(): - ''' - data_editor_change_callback is a callback function for the streamlit data_editor widget - which updates the saved value of the user-created changes after every instance of the - data_editor on_change method. This ensures the dashboard can remake the edited data_editor - when the user navigates to a different page. - ''' - - st.session_state.df = bpl.assign_phenotype_custom(st.session_state.df, st.session_state['phenocluster__edit_names_result_2a'].reconstruct_edited_dataframe()) - - # Create Phenotypes Summary Table based on 'phenotype' column in df - st.session_state.pheno_summ = bpl.init_pheno_summ(st.session_state.df) - - # Perform filtering - st.session_state.df_filt = ndl.perform_filtering(st.session_state) - - # Set Figure Objects based on updated df - st.session_state = ndl.setFigureObjs(st.session_state, st.session_state.pointstSliderVal_Sel) - -def phenocluster__plotly_umaps_b(adata, umap_cur_col, umap_cur_groups, umap_color_col, plot_column): - with plot_column: - subcol1, subcol2 = st.columns(2) - for i, umap_cur_group in enumerate(umap_cur_groups): - if umap_cur_group == "All": - subDat = adata - else: - subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] - umap_coords = subDat.obsm['X_umap'] - df = pd.DataFrame(umap_coords, columns=['UMAP_1', 'UMAP_2']) - clustersList = list(subDat.obs[umap_color_col] ) - df[umap_color_col] = clustersList - df[umap_color_col] = df[umap_color_col].astype(str) - # Create the seaborn plot - fig = px.scatter(df, - x="UMAP_1", - y="UMAP_2", - color=umap_color_col, - title="UMAP " + umap_cur_group - #color_discrete_sequence=px.colors.sequential.Plasma - ) - fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots - fig.update_layout( - title=dict( - text="UMAP " + umap_cur_group, - x=0.5, # Center the title - xanchor='center', - yanchor='top' - ), - legend=dict( - orientation="h", - yanchor="top", - y=-0.2, - xanchor="right", - x=1 - ) - ) - if i % 2 == 0: - subcol1.plotly_chart(fig, use_container_width=True) - else: - subcol2.plotly_chart(fig, use_container_width=True) - -def spatial_plots_cust_2b(adata, umap_cur_col, umap_cur_groups, umap_color_col, plot_column): - with plot_column: - subcol3, subcol4 = st.columns(2) - for i, umap_cur_group in enumerate(umap_cur_groups): - if umap_cur_group == "All": - subDat = adata - else: - subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] - umap_coords = subDat.obs[['Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)']] - df = pd.DataFrame(umap_coords).reset_index().drop('index', axis = 1) - clustersList = list(subDat.obs[umap_color_col] ) - df[umap_color_col] = clustersList - df[umap_color_col] = df[umap_color_col].astype(str) - fig = px.scatter(df, - x="Centroid X (µm)_(standardized)", - y="Centroid Y (µm)_(standardized)", - color=umap_color_col, - title="Spatial " + umap_cur_group - #color_discrete_sequence=px.colors.sequential.Plasma - ) - fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots - fig.update_layout( - title=dict( - text="Spatial " + umap_cur_group, - x=0.5, # Center the title - xanchor='center', - yanchor='top' - ), - legend=dict( - orientation="h", - yanchor="top", - y=-0.2, - xanchor="right", - x=1 - ) - ) - if i % 2 == 0: - subcol3.plotly_chart(fig, use_container_width=True) - else: - 
subcol4.plotly_chart(fig, use_container_width=True) - - -def main(): - phenocluster__col1b, phenocluster__col2b = st.columns([1, 6]) - with phenocluster__col1b: - # differential expression - phenocluster__de_col_options = list(st.session_state['phenocluster__clustering_adata'].obs.columns) - st.selectbox('Select column for differential expression:', phenocluster__de_col_options, key='phenocluster__de_col') - phenocluster__de_groups = ["All"] + list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state['phenocluster__de_col']])) - phenocluster__de_selected_groups = st.multiselect('Select group for differential expression table:', options = phenocluster__de_groups) - st.session_state['phenocluster__de_sel_groups'] = phenocluster__de_selected_groups - # Differential expression - - st.button('Run Differential Expression', on_click=phenocluster__diff_expr, args = [st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__de_col'], - st.session_state['phenocluster__de_sel_groups'], - phenocluster__col2b - ]) - - phenocluster__col3b, phenocluster__col4b = st.columns([1, 6]) - phenocluster__col5b, phenocluster__col6b = st.columns([1, 6]) - cur_clusters = list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs["Cluster"])) - edit_names_df = pd.DataFrame({"Cluster": cur_clusters, "New_Name": cur_clusters}) - st.session_state['phenocluster__edit_names_df'] = edit_names_df - - with phenocluster__col3b: - # Plot differential intensity - phenocluster__dif_int_plot_methods = ["Rank Plot", "Dot Plot", "Heat Map", "Violin Plot"] - st.selectbox('Select Plot Type:', phenocluster__dif_int_plot_methods, key='phenocluster__plot_diff_intensity_method') - st.number_input(label = "Number of genes to plot", - key = 'phenocluster__plot_diff_intensity_n_genes', - step = 1) - - - st.button('Plot Markers', on_click=phenocluster__plot_diff_intensity, args = [st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__de_sel_groups'], - st.session_state['phenocluster__plot_diff_intensity_method'], - st.session_state['phenocluster__plot_diff_intensity_n_genes'], - phenocluster__col4b - ]) - - with phenocluster__col6b: - #st.table(st.session_state['phenocluster__edit_names_df']) - #edit_clustering_names = st.data_editor(edit_names_df) - #st.session_state['phenocluster__edit_names_result'] = edit_clustering_names - if 'phenocluster__edit_names_result_2' not in st.session_state: - st.session_state['phenocluster__edit_names_result_2'] = sde.DataframeEditor(df_name='phenocluster__edit_names_result_2a', default_df_contents=st.session_state['phenocluster__edit_names_df']) - - #st.session_state['phenocluster__edit_names_result_2'].dataframe_editor(on_change=data_editor_change_callback, reset_data_editor_button_text='Reset New Clusters Names') - st.session_state['phenocluster__edit_names_result_2'].dataframe_editor(reset_data_editor_button_text='Reset New Clusters Names') - - with phenocluster__col5b: - #Edit cluster names - st.button('Edit Clusters Names', on_click=phenocluster__edit_cluster_names_2, args = [st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__edit_names_result_2'] - ]) - phenocluster__col7b, phenocluster__col8b, phenocluster__col9b = st.columns([1, 3, 3]) - def make_all_plots_2(): - spatial_plots_cust_2b(st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__umap_cur_col'], - st.session_state['phenocluster__umap_cur_groups'], - 
st.session_state['phenocluster__umap_color_col_2'], - phenocluster__col9b - ) - # make umaps plots - if 'X_umap' in st.session_state['phenocluster__clustering_adata'].obsm.keys(): - phenocluster__plotly_umaps_b(st.session_state['phenocluster__clustering_adata'], - st.session_state['phenocluster__umap_cur_col'], - st.session_state['phenocluster__umap_cur_groups'], - st.session_state['phenocluster__umap_color_col_2'], - phenocluster__col8b - ) - with phenocluster__col7b: - # select column for umap coloring - st.session_state['phenocluster__umeta_columns'] = list(st.session_state['phenocluster__clustering_adata'].obs.columns) - if 'Edit_Cluster' in st.session_state['phenocluster__umeta_columns']: - st.session_state['phenocluster__umap_color_col_index'] = st.session_state['phenocluster__umeta_columns'].index('Edit_Cluster') - else: - st.session_state['phenocluster__umap_color_col_index'] = st.session_state['phenocluster__umeta_columns'].index(st.session_state['phenocluster__umap_color_col']) - - st.session_state['phenocluster__umap_color_col_2'] = st.selectbox('Select column to color groups:', - st.session_state['phenocluster__umeta_columns'], - index=st.session_state['phenocluster__umap_color_col_index'] - ) - # select column for umap subsetting - st.session_state['phenocluster__umap_cur_col'] = st.selectbox('Select column to subset plots:', - st.session_state['phenocluster__umeta_columns'], key='phenocluster__umap_col_dropdown_subset' - ) - # list of available subsetting options - umap_cur_groups= ["All"] + list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state['phenocluster__umap_cur_col']])) - umap_sel_groups = st.multiselect('Select groups to be plotted', - options = umap_cur_groups) - st.session_state['phenocluster__umap_cur_groups'] = umap_sel_groups - - st.button('Make Plots' , on_click=make_all_plots_2) - - -# Run the main function -if __name__ == '__main__': - - # Set page settings - page_name = 'Differential Intensity' - st.set_page_config(layout='wide', page_title=page_name) - st.title(page_name) - st.set_option('deprecation.showPyplotGlobalUse', False) - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Call the main function - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) \ No newline at end of file diff --git a/pages/macro_radial_density.py b/pages/macro_radial_density.py deleted file mode 100644 index 576a4ee..0000000 --- a/pages/macro_radial_density.py +++ /dev/null @@ -1,33 +0,0 @@ -# Import relevant libraries -import streamlit as st -import app_top_of_page as top -import streamlit_dataframe_editor as sde - - -def main(): - """ - Main function for the page. 
- """ - - st.write('Insert your code here.') - - -# Run the main function -if __name__ == '__main__': - - # Set page settings - page_name = 'Macro radial density' - st.set_page_config(layout='wide', page_title=page_name) - st.title(page_name) - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Call the main function - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/04c_Clusters Analyzer.py b/pages2/Clusters_Analyzer.py similarity index 79% rename from pages/04c_Clusters Analyzer.py rename to pages2/Clusters_Analyzer.py index 5826926..c72f25b 100644 --- a/pages/04c_Clusters Analyzer.py +++ b/pages2/Clusters_Analyzer.py @@ -3,11 +3,7 @@ ''' import streamlit as st - -# Import relevant libraries import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP -import app_top_of_page as top -import streamlit_dataframe_editor as sde def reset_phenotype_selection(): ''' @@ -73,19 +69,4 @@ def main(): st.pyplot(st.session_state.inciFig) if __name__ == '__main__': - - # Set a wide layout - st.set_page_config(page_title="Clusters Analyzer", - layout="wide") - st.title('Clusters Analyzer') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/03f_Display_ROI_P_values_overlaid_on_slides.py b/pages2/Display_ROI_P_values_overlaid_on_slides.py similarity index 91% rename from pages/03f_Display_ROI_P_values_overlaid_on_slides.py rename to pages2/Display_ROI_P_values_overlaid_on_slides.py index caea337..cd72ac4 100644 --- a/pages/03f_Display_ROI_P_values_overlaid_on_slides.py +++ b/pages2/Display_ROI_P_values_overlaid_on_slides.py @@ -3,9 +3,6 @@ import utils as utils import os -import app_top_of_page as top -import streamlit_dataframe_editor as sde - save_image_ext = 'jpg' def main(): @@ -24,18 +21,6 @@ def update_neighbor_species_index(neighbor_species_names): def update_neighbor_species_name(neighbor_species_names): st.session_state['neighbor_species_name_to_visualize'] = neighbor_species_names[st.session_state['neighbor_species_index_to_visualize']] - # Set a wide layout - st.set_page_config(layout="wide") - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Display page heading - st.title('ROI P values overlaid on slides') - if os.path.exists(os.path.join('.', 'output', 'images', 'density_pvals_over_slide_spatial_plot')): # Create an expander to hide some optional widgets @@ -103,8 +88,6 @@ def update_neighbor_species_name(neighbor_species_names): else: st.warning('The component "Plot density P values for each ROI over slide spatial plot" of the workflow does not appear to have been run; please select it on the "Run workflow" page', icon='⚠️') - # Run 
streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) if __name__ == '__main__': main() diff --git a/pages/03d_Display_average_heatmaps.py b/pages2/Display_average_heatmaps.py similarity index 83% rename from pages/03d_Display_average_heatmaps.py rename to pages2/Display_average_heatmaps.py index e894903..9397f04 100644 --- a/pages/03d_Display_average_heatmaps.py +++ b/pages2/Display_average_heatmaps.py @@ -3,8 +3,6 @@ import utils as utils import os -import app_top_of_page as top -import streamlit_dataframe_editor as sde def main(): @@ -14,17 +12,6 @@ def update_slide_index(slide_names): def update_slide_name(slide_names): st.session_state['slide_name_to_visualize'] = slide_names[st.session_state['slide_index_to_visualize']] - # Set a wide layout - st.set_page_config(layout="wide") - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Display page heading - st.title('Average heatmaps per slide') if os.path.exists(os.path.join('.', 'output', 'images', 'whole_slide_patches')) and os.path.exists(os.path.join('.', 'output', 'images', 'dens_pvals_per_slide')): @@ -66,7 +53,11 @@ def update_slide_name(slide_names): # Display the three images for the currently selected slide with display_col1: - st.image(df_paths_per_slide.loc[st.session_state['slide_name_to_visualize'], 'heatmap']) + image_path_entry = df_paths_per_slide.loc[st.session_state['slide_name_to_visualize'], 'heatmap'] + if not isinstance(image_path_entry, float): + st.image(image_path_entry) + else: + st.info('No heatmap data are available for this slide') st.radio('Display slide patching at right?', ['not patched', 'patched'], key='display_slide_patching') with display_col2: slide_suffix = ('' if st.session_state['display_slide_patching'] == 'not patched' else '_patched') @@ -75,8 +66,6 @@ def update_slide_name(slide_names): else: st.warning('At least one of the two sets of per-slide plots does not exist; please run all per-slide components of the workflow on the "Run workflow" page', icon='⚠️') - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) if __name__ == '__main__': main() diff --git a/pages/03e_Display_average_heatmaps_per_annotation.py b/pages2/Display_average_heatmaps_per_annotation.py similarity index 94% rename from pages/03e_Display_average_heatmaps_per_annotation.py rename to pages2/Display_average_heatmaps_per_annotation.py index eb16247..cbc7817 100644 --- a/pages/03e_Display_average_heatmaps_per_annotation.py +++ b/pages2/Display_average_heatmaps_per_annotation.py @@ -4,25 +4,10 @@ import os import streamlit_utils -import app_top_of_page as top -import streamlit_dataframe_editor as sde - save_image_ext = 'jpg' def main(): - # Set a wide layout - st.set_page_config(layout="wide") - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Display page heading - st.title('Annotation plots') - if os.path.exists(os.path.join('.', 'output', 'images', 'raw_weights_check')) and os.path.exists(os.path.join('.', 'output', 'images', 
f'all_annotation_data.{save_image_ext}')) and os.path.exists(os.path.join('.', 'output', 'images', 'weight_heatmaps_on_annot')) and os.path.exists(os.path.join('.', 'output', 'images', 'pixel_plot')) and os.path.exists(os.path.join('.', 'output', 'images', 'analysis_overlaid_on_annotation')) and os.path.exists(os.path.join('.', 'output', 'images', 'dens_pvals_per_annotation')): # Constant: annotation-dependent plot types (there are two more that are annotation-independent; see the "Optional plots" section) @@ -165,8 +150,6 @@ def main(): else: st.warning('The component "Average density P values over ROIs for each annotation region type" of the workflow does not appear to have been run; please select it on the "Run workflow" page', icon='⚠️') - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) if __name__ == '__main__': main() diff --git a/pages/03c_Display_individual_ROI_heatmaps.py b/pages2/Display_individual_ROI_heatmaps.py similarity index 83% rename from pages/03c_Display_individual_ROI_heatmaps.py rename to pages2/Display_individual_ROI_heatmaps.py index c7150e8..a0a14c4 100644 --- a/pages/03c_Display_individual_ROI_heatmaps.py +++ b/pages2/Display_individual_ROI_heatmaps.py @@ -3,8 +3,6 @@ import streamlit as st import utils as utils -import app_top_of_page as top -import streamlit_dataframe_editor as sde def main(): @@ -14,18 +12,6 @@ def update_roi_index(roi_names): def update_roi_name(roi_names): st.session_state['roi_name_to_visualize'] = roi_names[st.session_state['roi_index_to_visualize']] - # Set a wide layout - st.set_page_config(layout="wide") - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Display page heading - st.title('Individual ROI heatmaps') - if os.path.exists(os.path.join('.', 'output', 'images', 'single_roi_outlines_on_whole_slides')) and os.path.exists(os.path.join('.', 'output', 'images', 'roi_plots')) and os.path.exists(os.path.join('.', 'output', 'images', 'dens_pvals_per_roi')): # Create an expander to hide some optional widgets @@ -70,19 +56,16 @@ def update_roi_name(roi_names): with display_col1: st.image(df_paths_per_roi.loc[st.session_state['roi_name_to_visualize'], 'roi']) image_path_entry = df_paths_per_roi.loc[st.session_state['roi_name_to_visualize'], 'heatmap'] - # if isinstance(image_path_entry, str): if image_path_entry != '': st.image(image_path_entry) else: - st.write('No heatmap data available') + st.info('No heatmap data are available for this ROI') with display_col2: st.image(df_paths_per_roi.loc[st.session_state['roi_name_to_visualize'], 'outline']) else: st.warning('At least one of the three sets of per-ROI plots does not exist; please run all per-ROI components of the workflow on the "Run workflow" page', icon='⚠️') - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) if __name__ == '__main__': main() diff --git a/pages/04a_Neighborhood Profiles.py b/pages2/Neighborhood_Profiles.py similarity index 65% rename from pages/04a_Neighborhood Profiles.py rename to pages2/Neighborhood_Profiles.py index c03e05e..50b8ae4 100644 --- a/pages/04a_Neighborhood Profiles.py +++ b/pages2/Neighborhood_Profiles.py @@ -9,12 +9,9 @@ from 
streamlit_extras.add_vertical_space import add_vertical_space import matplotlib.pyplot as plt import pandas as pd - -# Import relevant libraries +from natsort import natsorted import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP import basic_phenotyper_lib as bpl # Useful functions for phenotyping collections of cells -import app_top_of_page as top -import streamlit_dataframe_editor as sde from neighborhood_profiles import NeighborhoodProfiles, UMAPDensityProcessing def get_spatialUMAP(spatial_umap, bc, umap_subset_per_fit, umap_subset_toggle, umap_subset_per): @@ -90,7 +87,7 @@ def init_spatial_umap(): st.session_state.density_completed = True # Save checkpoint for Neighborhood Profile structure - save_neipro_struct() + # save_neipro_struct() def apply_umap(umap_style): ''' @@ -98,13 +95,13 @@ def apply_umap(umap_style): ''' st.session_state.bc.startTimer() - # if togle for loading pre-generated UMAP is selected extract UMAP from file, works only with a specific dataset + # if toggle for loading pre-generated UMAP is selected extract UMAP from file, works only with a specific dataset if st.session_state['load_generated_umap_toggle']: st.session_state.spatial_umap = get_spatialUMAP(st.session_state.spatial_umap, - st.session_state.bc, - st.session_state.umap_subset_per_fit, - st.session_state.umap_subset_toggle, - st.session_state.umap_subset_per) + st.session_state.bc, + st.session_state.umap_subset_per_fit, + st.session_state.umap_subset_toggle, + st.session_state.umap_subset_per) else: with st.spinner('Calculating UMAP'): st.session_state.spatial_umap = bpl.perform_spatialUMAP(st.session_state.spatial_umap, @@ -128,6 +125,8 @@ def apply_umap(umap_style): st.session_state.outcomes = st.session_state.spatial_umap.cells.columns st.session_state.spatial_umap.outcomes = st.session_state.spatial_umap.cells.columns + st.session_state.dens_diff_feat_sel = st.session_state.outcomes[0] + # List of possible outcome variables as defined by the config yaml files st.session_state.umapOutcomes = [st.session_state.defumapOutcomes] st.session_state.umapOutcomes.extend(st.session_state.outcomes) @@ -143,7 +142,7 @@ def apply_umap(umap_style): # Create Neighborhood Profiles Object st.session_state.npf = NeighborhoodProfiles(bc = st.session_state.bc) - + # Create Full UMAP example st.session_state.udp_full = UMAPDensityProcessing(st.session_state.npf, st.session_state.spatial_umap.df_umap) st.session_state.UMAPFig = st.session_state.udp_full.UMAPdraw_density() @@ -152,7 +151,7 @@ def apply_umap(umap_style): filter_and_plot() # Save checkpoint for Neighborhood Profile structure - save_neipro_struct() + # save_neipro_struct() def set_clusters(): ''' @@ -163,98 +162,101 @@ def set_clusters(): with st.spinner('Calculating Clusters'): # If clustering is to be performed on the UMAP density difference - if st.session_state['toggle_clust_diff']: - - split_dict_full = st.session_state.udp_full.split_df_by_feature(st.session_state.dens_diff_feat_sel) - st.session_state.appro_feat = split_dict_full['appro_feat'] - - # If the feature is appropriate, perform the density difference split/clustering - if st.session_state.appro_feat: - - # Perform Density Calculations for each Condition - udp_fals = UMAPDensityProcessing(st.session_state.npf, split_dict_full['df_umap_fals'], xx=st.session_state.udp_full.xx, yy=st.session_state.udp_full.yy) - udp_true = UMAPDensityProcessing(st.session_state.npf, split_dict_full['df_umap_true'], xx=st.session_state.udp_full.xx, 
yy=st.session_state.udp_full.yy) - - ## Copy over - udp_diff = copy(udp_fals) - ## Perform difference calculation - udp_diff.dens_mat = np.log10(udp_fals.dens_mat) - np.log10(udp_true.dens_mat) - ## Rerun the min/max calcs - udp_diff.umap_summary_stats() - ## Set Feature Labels - udp_fals.set_feature_label(st.session_state.dens_diff_feat_sel, split_dict_full['fals_msg']) - udp_true.set_feature_label(st.session_state.dens_diff_feat_sel, split_dict_full['true_msg']) - udp_diff.set_feature_label(st.session_state.dens_diff_feat_sel, 'Difference') - - # Draw UMAPS - st.session_state.UMAPFig_fals = udp_fals.UMAPdraw_density() - st.session_state.UMAPFig_true = udp_true.UMAPdraw_density() - st.session_state.UMAPFig_diff = udp_diff.UMAPdraw_density(diff=True) - - # Assign Masking and plot - udp_mask = copy(udp_diff) - udp_mask.filter_density_matrix(st.session_state.dens_diff_cutoff, st.session_state.udp_full.empty_bin_ind) - udp_mask.set_feature_label(st.session_state.dens_diff_feat_sel, f'Difference- Masked, \ncutoff = {st.session_state.dens_diff_cutoff}') - st.session_state.UMAPFig_mask = udp_mask.UMAPdraw_density(diff=True) - - # Perform Clustering - udp_clus = copy(udp_mask) - udp_clus.perform_clustering(dens_mat_cmp=udp_mask.dens_mat, - num_clus_0=st.session_state.num_clus_0, - num_clus_1=st.session_state.num_clus_1, - clust_minmax=st.session_state.clust_minmax, - cpu_pool_size=st.session_state.cpu_pool_size) - udp_clus.set_feature_label(st.session_state.dens_diff_feat_sel, f'Clusters, False-{st.session_state.num_clus_0}, True-{st.session_state.num_clus_1}') - st.session_state.UMAPFig_clus = udp_clus.UMAPdraw_density(diff=True, legendtype='legend') - st.session_state.cluster_dict = udp_clus.cluster_dict - st.session_state.palette_dict = udp_clus.palette_dict - st.session_state.elbow_fig_0 = udp_clus.elbow_fig_0 - st.session_state.elbow_fig_1 = udp_clus.elbow_fig_1 - - # Add cluster label column to cells dataframe - st.session_state.spatial_umap.df_umap.loc[:, 'clust_label'] = 'No Cluster' - st.session_state.spatial_umap.df_umap.loc[:, 'cluster'] = 'No Cluster' - st.session_state.spatial_umap.df_umap.loc[:, 'Cluster'] = 'No Cluster' - - st.session_state.bc.startTimer() - for key, val in st.session_state.cluster_dict.items(): - if key != 0: - bin_clust = np.argwhere(udp_clus.dens_mat == key) - bin_clust = bin_clust[:, [1, 0]] # Swapping columns to by y, x - bin_clust = [tuple(x) for x in bin_clust] - - significant_groups = st.session_state.udp_full.bin_indices_df_group[st.session_state.udp_full.bin_indices_df_group.set_index(['indx', 'indy']).index.isin(bin_clust)] - - umap_ind = significant_groups.index.values - st.session_state.spatial_umap.df_umap.loc[umap_ind, 'clust_label'] = val - st.session_state.spatial_umap.df_umap.loc[umap_ind, 'cluster'] = val - st.session_state.spatial_umap.df_umap.loc[umap_ind, 'Cluster'] = val - - st.session_state.bc.printElapsedTime('Untangling bin indicies with UMAP indicies') - - # After assigning cluster labels, perform mean calculations - st.session_state.bc.startTimer() - st.session_state.spatial_umap.mean_measures() - st.session_state.bc.printElapsedTime('Performing Mean Measures') - - # Average False condition and Average True Condition - dens_df_fals = st.session_state.spatial_umap.dens_df_mean.loc[st.session_state.spatial_umap.dens_df_mean['clust_label'].str.contains('False'), :] - dens_df_true = st.session_state.spatial_umap.dens_df_mean.loc[st.session_state.spatial_umap.dens_df_mean['clust_label'].str.contains('True'), :] - - dens_df_fals['clust_label'] 
= 'Average False' - dens_df_mean_fals = dens_df_fals.groupby(['clust_label', 'phenotype', 'dist_bin'], as_index=False).mean() - - dens_df_true['clust_label'] = 'Average True' - dens_df_mean_true = dens_df_true.groupby(['clust_label', 'phenotype', 'dist_bin'], as_index=False).mean() - - st.session_state.spatial_umap.dens_df_mean = pd.concat([st.session_state.spatial_umap.dens_df_mean, dens_df_mean_fals, dens_df_mean_true], axis=0) + if st.session_state['toggle_clust_diff'] and st.session_state['appro_feat']: + + # Split the UMAP by the selected values of the feature + split_dict_full = st.session_state.udp_full.split_df_by_feature(st.session_state.dens_diff_feat_sel, + st.session_state.feature_value_fals, + st.session_state.feature_value_true, + st.session_state.clust_diff_vals_code) + + # Perform Density Calculations for each Condition + udp_fals = UMAPDensityProcessing(st.session_state.npf, split_dict_full['df_umap_fals'], xx=st.session_state.udp_full.xx, yy=st.session_state.udp_full.yy) + udp_true = UMAPDensityProcessing(st.session_state.npf, split_dict_full['df_umap_true'], xx=st.session_state.udp_full.xx, yy=st.session_state.udp_full.yy) + + ## Copy over + udp_diff = copy(udp_fals) + ## Perform difference calculation + udp_diff.dens_mat = np.log10(udp_fals.dens_mat) - np.log10(udp_true.dens_mat) + ## Rerun the min/max calcs + udp_diff.umap_summary_stats() + ## Set Feature Labels + udp_fals.set_feature_label(st.session_state.dens_diff_feat_sel, split_dict_full['fals_msg']) + udp_true.set_feature_label(st.session_state.dens_diff_feat_sel, split_dict_full['true_msg']) + udp_diff.set_feature_label(st.session_state.dens_diff_feat_sel, 'Difference') + + # Draw UMAPS + st.session_state.UMAPFig_fals = udp_fals.UMAPdraw_density() + st.session_state.UMAPFig_true = udp_true.UMAPdraw_density() + st.session_state.UMAPFig_diff = udp_diff.UMAPdraw_density(diff=True) + + # Assign Masking and plot + udp_mask = copy(udp_diff) + udp_mask.filter_density_matrix(st.session_state.dens_diff_cutoff, st.session_state.udp_full.empty_bin_ind) + udp_mask.set_feature_label(st.session_state.dens_diff_feat_sel, f'Difference- Masked, \ncutoff = {st.session_state.dens_diff_cutoff}') + st.session_state.UMAPFig_mask = udp_mask.UMAPdraw_density(diff=True) + + # Perform Clustering + udp_clus = copy(udp_mask) + udp_clus.perform_clustering(dens_mat_cmp=udp_mask.dens_mat, + num_clus_0=st.session_state.num_clus_0, + num_clus_1=st.session_state.num_clus_1, + clust_minmax=st.session_state.clust_minmax, + cpu_pool_size=3) + udp_clus.set_feature_label(st.session_state.dens_diff_feat_sel, f'Clusters, False-{st.session_state.num_clus_0}, True-{st.session_state.num_clus_1}') + st.session_state.UMAPFig_clus = udp_clus.UMAPdraw_density(diff=True, legendtype='legend') + st.session_state.cluster_dict = udp_clus.cluster_dict + st.session_state.palette_dict = udp_clus.palette_dict + st.session_state.elbow_fig_0 = udp_clus.elbow_fig_0 + st.session_state.elbow_fig_1 = udp_clus.elbow_fig_1 + + # Add cluster label column to cells dataframe + st.session_state.spatial_umap.df_umap.loc[:, 'clust_label'] = 'No Cluster' + st.session_state.spatial_umap.df_umap.loc[:, 'cluster'] = 'No Cluster' + st.session_state.spatial_umap.df_umap.loc[:, 'Cluster'] = 'No Cluster' + + for key, val in st.session_state.cluster_dict.items(): + if key != 0: + bin_clust = np.argwhere(udp_clus.dens_mat == key) + bin_clust = bin_clust[:, [1, 0]] # Swapping columns to by y, x + bin_clust = [tuple(x) for x in bin_clust] + + significant_groups = 
st.session_state.udp_full.bin_indices_df_group[st.session_state.udp_full.bin_indices_df_group.set_index(['indx', 'indy']).index.isin(bin_clust)] + + umap_ind = significant_groups.index.values + st.session_state.spatial_umap.df_umap.loc[umap_ind, 'clust_label'] = val + st.session_state.spatial_umap.df_umap.loc[umap_ind, 'cluster'] = val + st.session_state.spatial_umap.df_umap.loc[umap_ind, 'Cluster'] = val + + # Benchmark how long it took to untangle indicies + st.session_state.bc.printElapsedTime('Untangling bin indicies with UMAP indicies', split = True) + + # After assigning cluster labels, perform mean calculations + st.session_state.spatial_umap.mean_measures() + st.session_state.bc.printElapsedTime('Performing Mean Measures', split = True) + + # Average False condition and Average True Condition + dens_df_fals = st.session_state.spatial_umap.dens_df_mean.loc[st.session_state.spatial_umap.dens_df_mean['clust_label'].str.contains('False'), :] + dens_df_true = st.session_state.spatial_umap.dens_df_mean.loc[st.session_state.spatial_umap.dens_df_mean['clust_label'].str.contains('True'), :] + + dens_df_fals['clust_label'] = 'Average False' + dens_df_mean_fals = dens_df_fals.groupby(['clust_label', 'phenotype', 'dist_bin'], as_index=False).mean() + + dens_df_true['clust_label'] = 'Average True' + dens_df_mean_true = dens_df_true.groupby(['clust_label', 'phenotype', 'dist_bin'], as_index=False).mean() + + st.session_state.spatial_umap.dens_df_mean = pd.concat([st.session_state.spatial_umap.dens_df_mean, dens_df_mean_fals, dens_df_mean_true], axis=0) + + st.session_state.cluster_completed_diff = True else: - st.session_state.spatial_umap = bpl.umap_clustering(st.session_state.spatial_umap, - st.session_state.slider_clus_val, - st.session_state.clust_minmax, - st.session_state.cpu_pool_size) + st.session_state.spatial_umap = bpl.umap_clustering(spatial_umap = st.session_state.spatial_umap, + n_clusters = st.session_state.slider_clus_val, + clust_minmax = st.session_state.clust_minmax, + cpu_pool_size = 3) + st.session_state.appro_feat = True + st.session_state.cluster_completed_diff = False st.session_state.cluster_dict = st.session_state.spatial_umap.cluster_dict st.session_state.palette_dict = st.session_state.spatial_umap.palette_dict st.session_state.selected_nClus = st.session_state.slider_clus_val @@ -271,6 +273,45 @@ def set_clusters(): filter_and_plot() +def check_feature_approval_callback(): + ''' + Simple callback to test the current value of + st.session_state.dens_diff_feat_sel + ''' + + if not st.session_state['toggle_clust_diff']: + st.session_state.appro_feat = True + else: + + # Check feature values + st.session_state.clust_diff_vals_code = st.session_state.udp_full.check_feature_values(st.session_state.dens_diff_feat_sel) + + if st.session_state.clust_diff_vals_code == 0: + st.session_state.appro_feat = False + else: + st.session_state.appro_feat = True + + feature_vals = natsorted(st.session_state.udp_full.df[st.session_state.dens_diff_feat_sel].unique()) + if st.session_state.clust_diff_vals_code == 2: + options_fals = [feature_vals[0]] + options_true = [feature_vals[1]] + elif st.session_state.clust_diff_vals_code > 2 and st.session_state.clust_diff_vals_code <= 15: + options_fals = feature_vals + options_true = feature_vals + elif st.session_state.clust_diff_vals_code == 100: + median = np.round(st.session_state.udp_full.df[st.session_state.dens_diff_feat_sel].median(), decimals = 2) + + options_fals = [median] + options_true = [median] + else: + options_fals = ['None'] + 
options_true = ['None'] + + st.session_state.clus_diff_vals_fals = options_fals + st.session_state.clus_diff_vals_true = options_true + st.session_state.feature_value_fals = options_fals[0] + st.session_state.feature_value_true = options_true[0] + def slide_id_prog_left_callback(): ''' callback function when the left Cell_ID progression button is clicked @@ -304,8 +345,10 @@ def slide_id_callback(): def filter_and_plot(): ''' - function to update the filtering and the figure plotting + callback function to update the filtering and the + figure plotting ''' + st.session_state.prog_left_disabeled = False st.session_state.prog_right_disabeled = False @@ -538,10 +581,10 @@ def main(): with neipro_settings[2]: st.toggle('Subset data transformed by UMAP', - value = False, key = 'umap_subset_toggle', + key = 'umap_subset_toggle', help = '''The UMAP model is always trained on a percentage of data included - in the smallest image.vYou can choose to transform the entire dataset using - this trained model, or only transformva percentage of the data. This can be + in the smallest image. You can choose to transform the entire dataset using + this trained model, or only transform a percentage of the data. This can be useful for large datasets. If a percentage is chosen for transformation, it is always a different sample than what the model was trained on.''') add_vertical_space(2) @@ -595,14 +638,24 @@ def main(): if st.session_state.umap_completed: with st.expander('Clustering Settings', expanded = True): - st.toggle('Perform Clustering on UMAP Density Difference', value = False, key = 'toggle_clust_diff') + st.toggle('Perform Clustering on UMAP Density Difference', + value = False, key = 'toggle_clust_diff', + help = '''Perform clustering on the density difference between + two levels of a dataset feature.''', + on_change=check_feature_approval_callback) clust_exp_col = st.columns(2) with clust_exp_col[0]: # Run Clustering Normally if st.session_state['toggle_clust_diff'] is True: - st.selectbox('Feature', options = st.session_state.spatial_umap.outcomes, key = 'dens_diff_feat_sel') + st.selectbox('Feature', options = st.session_state.spatial_umap.outcomes, + key = 'dens_diff_feat_sel', + help = '''Select the feature to split the UMAP by.''', + on_change=check_feature_approval_callback) + + st.selectbox('Values for False Condition', key = 'feature_value_fals', + options = st.session_state.clus_diff_vals_fals) st.number_input('Number of Clusters for False Condition', min_value = 1, max_value = 10, value = 3, step = 1, key = 'num_clus_0') if st.session_state.elbow_fig_0 is not None: st.pyplot(st.session_state.elbow_fig_0) @@ -617,6 +670,8 @@ def main(): with clust_exp_col[1]: if st.session_state['toggle_clust_diff'] is True: st.number_input('Cutoff Percentage', min_value = 0.01, max_value = 0.99, value = 0.01, step = 0.01, key = 'dens_diff_cutoff') + st.selectbox('Values for True Condition', key = 'feature_value_true', + options = st.session_state.clus_diff_vals_true) st.number_input('Number of Clusters for True Condition', min_value = 1, max_value = 10, value = 3, step = 1, key = 'num_clus_1') if st.session_state.elbow_fig_1 is not None: st.pyplot(st.session_state.elbow_fig_1) @@ -638,9 +693,9 @@ def main(): with st.columns(3)[1]: st.pyplot(fig=st.session_state.UMAPFig) - if st.session_state.cluster_completed and st.session_state['toggle_clust_diff']: + if st.session_state.cluster_completed_diff and st.session_state['toggle_clust_diff']: if st.session_state['appro_feat']: - + diff_cols = st.columns(3) 
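                    # The density UMAPs drawn below were produced in set_clusters() above; as a rough
                    # sketch of that computation (names as used there):
                    #   dens_false = UMAPDensityProcessing(npf, df_umap_fals).dens_mat
                    #   dens_true  = UMAPDensityProcessing(npf, df_umap_true).dens_mat
                    #   diff       = np.log10(dens_false) - np.log10(dens_true)  # then masked by the cutoff and clustered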
with diff_cols[0]: st.pyplot(fig=st.session_state.UMAPFig_fals) @@ -654,13 +709,14 @@ def main(): st.pyplot(fig=st.session_state.UMAPFig_mask) with mor_cols[1]: st.pyplot(fig=st.session_state.diff_clust_Fig) - else: - st.write('Feature must be boolean or numeric to perform density difference analysis') + if st.session_state['appro_feat'] is False: + st.write('Feature must be boolean or numeric to perform density difference analysis') # Tab for Loading Previous UMAP Results with nei_pro_tabs[1]: - st.write('Checkpoint file: neighborhood_profiles_checkpoint.pkl') - st.button('Load checkpointed UMAP results', on_click=load_neipro_struct) + st.write('Feature coming soon!') + # st.write('Checkpoint file: neighborhood_profiles_checkpoint.pkl') + # st.button('Load checkpointed UMAP results', on_click=load_neipro_struct) add_vertical_space(19) if not st.session_state.phenotyping_completed: @@ -722,12 +778,20 @@ def main(): with viz_cols[1]: st.header('Neighborhood Profiles') with st.expander('Neighborhood Profile Options'): - nei_sett_col = st.columns([1, 2, 1]) + nei_sett_col = st.columns([1, 1, 1, 1]) with nei_sett_col[0]: + st.toggle('Manual Y-axis scaling', + key = 'toggle_manual_y_axis_scaling_main') st.toggle('Hide "Other" Phenotype', value = False, key = 'toggle_hide_other') - st.toggle('Plot on Log Scale', value = True, key = 'nei_pro_toggle_log_scale') with nei_sett_col[1]: + st.number_input('Y-axis Min', key = 'y_axis_min_main', + value = 0.1, step = 0.01,) st.toggle('Hide "No Cluster" Neighborhood Profile', value = False, key = 'toggle_hide_no_cluster') + with nei_sett_col[2]: + st.number_input('Y-axis Max', key = 'y_axis_max_main', + value = 10000, step = 10,) + with nei_sett_col[3]: + st.checkbox('Log Scale', key = 'nei_pro_toggle_log_scale', value = True) # If the spatial-umap is completed... 
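            # ...expose the cluster selection widgets below: a dropdown of available cluster labels,
            # an optional 'Compare Cluster Neighborhoods' toggle, and a Ratio/Difference choice for
            # how the two selected clusters are compared.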
if 'spatial_umap' in st.session_state: @@ -742,7 +806,7 @@ def main(): add_vertical_space(2) st.toggle('Compare Cluster Neighborhoods', value = False, key = 'toggle_compare_clusters') if st.session_state['toggle_compare_clusters']: - st.radio('Compare as:', ('Difference', 'Ratio'), index = 0, key = 'compare_clusters_as', horizontal=True) + st.radio('Compare as:', ('Ratio', 'Difference'), index = 0, key = 'compare_clusters_as', horizontal=True) # Cluster Select Widgets with cluster_sel_col[0]: @@ -751,9 +815,9 @@ def main(): if st.session_state['toggle_compare_clusters']: sel_npf_fig2 = st.selectbox('Select a cluster to compare', list_clusters) - if st.session_state.cluster_completed: - # Draw the Neighborhood Profile + if st.session_state.cluster_completed and st.session_state.appro_feat: + # Draw the Neighborhood Profile npf_fig, ax = bpl.draw_scatter_fig(figsize=(14, 16)) bpl.neighProfileDraw(st.session_state.spatial_umap, @@ -765,87 +829,148 @@ def main(): hide_no_cluster = st.session_state['toggle_hide_no_cluster']) if st.session_state['nei_pro_toggle_log_scale']: - ax.set_ylim([0.1, 10000]) ax.set_yscale('log') - st.pyplot(fig=npf_fig) - # Create widgets for exporting the Neighborhood Profile images - neigh_prof_col = st.columns([2, 1]) - with neigh_prof_col[0]: - st.text_input('.png file suffix (Optional)', key = 'neigh_prof_line_suffix') - with neigh_prof_col[1]: - add_vertical_space(2) - if st.button('Append Export List', key = 'appendexportbutton_neighproline__do_not_persist'): + if st.session_state['toggle_manual_y_axis_scaling_main']: + ax.set_ylim(st.session_state['y_axis_min_main'], st.session_state['y_axis_max_main']) - ndl.save_png(npf_fig, 'Neighborhood Profiles', st.session_state.neigh_prof_line_suffix) - st.toast(f'Added {st.session_state.neigh_prof_line_suffix} to export list') - - if st.session_state.umap_completed and st.session_state['toggle_clust_diff']: + if sel_npf_fig == sel_npf_fig2: + st.markdown('## Please choose two different clusters to compare') + else: + # Display the Neighborhood Profile + st.pyplot(fig=npf_fig) + + # Create widgets for exporting the Neighborhood Profile images + neigh_prof_col = st.columns([2, 1]) + with neigh_prof_col[0]: + st.text_input('.png file suffix (Optional)', key = 'neigh_prof_line_suffix') + with neigh_prof_col[1]: + add_vertical_space(2) + if st.button('Append Export List', key = 'appendexportbutton_neighproline__do_not_persist'): + + ndl.save_png(npf_fig, 'Neighborhood Profiles', st.session_state.neigh_prof_line_suffix) + st.toast(f'Added {st.session_state.neigh_prof_line_suffix} to export list') + + # Drawing the subplots of Neighborhood Profiles per cluster combinations + if st.session_state['appro_feat'] and st.session_state.cluster_completed_diff: + + supp_neipro_col = st.columns([4, 2]) + with supp_neipro_col[0]: + with st.expander('Neighborhood Profile Subplots Settings', expanded = False): + st.toggle('Manual Y-axis scaling', + key = 'toggle_manual_y_axis_scaling_supplemental') + + st.write('Settings for Individual Cluster Plots') + neipro_exp_full = st.columns([3, 3, 1]) + with neipro_exp_full[0]: + st.number_input('Y-axis Min', key = 'y_axis_min_supplemental', + value = 0.1, step = 0.001,) + with neipro_exp_full[1]: + st.number_input('Y-axis Max', key = 'y_axis_max_supplemental', + value = 10000, step = 10,) + with neipro_exp_full[2]: + st.checkbox('Log Scale', key = 'log_scale_supplemental', value = True) + + st.write('Settings for Individual Cluster Ratios') + neipro_exp_full = st.columns([3, 3, 1]) + with 
neipro_exp_full[0]: + st.number_input('Y-axis Min', key = 'y_axis_min_indiratio_supplemental', + value = 0.1, step = 0.001,) + with neipro_exp_full[1]: + st.number_input('Y-axis Max', key = 'y_axis_max_indiratio_supplemental', + value = 10, step = 1,) + with neipro_exp_full[2]: + st.checkbox('Log Scale', key = 'log_scale_indiratio_supplemental', value = True) + + st.write('Settings for Aggregated Cluster Ratios') + neipro_exp_full = st.columns([3, 3, 1]) + with neipro_exp_full[0]: + st.number_input('Y-axis Min', key = 'y_axis_min_aggratio_supplemental', + value = 0.1, step = 0.001,) + with neipro_exp_full[1]: + st.number_input('Y-axis Max', key = 'y_axis_max_aggratio_supplemental', + value = 10, step = 1,) + with neipro_exp_full[2]: + st.checkbox('Log Scale', key = 'log_scale_aggratio_supplemental', value = True) npf_fig_big = plt.figure(figsize=(16, 45), facecolor = '#0E1117') - list_figures = [['Average False', None, 'log', [1, 10000]], - ['Average True', None, 'log', [1, 10000]], - ['Average False', 'Average True', 'linear', [0, 4]], - ['False Cluster 1', None, 'log', [0.1, 10000]], - ['False Cluster 2', None, 'log', [0.1, 10000]], - ['False Cluster 3', None, 'log', [0.1, 10000]], - ['True Cluster 1', None, 'log', [0.1, 10000]], - ['True Cluster 2', None, 'log', [0.1, 10000]], - ['False Cluster 3', None, 'linear', [0, 2000]], - ['False Cluster 1', 'True Cluster 1', 'log', [0.01, 100]], - ['False Cluster 2', 'True Cluster 1', 'log', [0.01, 100]], - ['False Cluster 3', 'True Cluster 1', 'log', [0.01, 100]], - ['False Cluster 1', 'True Cluster 2', 'log', [0.01, 100]], - ['False Cluster 2', 'True Cluster 2', 'log', [0.01, 100]], - ['False Cluster 3', 'True Cluster 2', 'log', [0.01, 100]], - ['False Cluster 1', 'Average True', 'linear', [0, 15]], - ['False Cluster 2', 'Average True', 'linear', [0, 15]], - ['False Cluster 3', 'Average True', 'linear', [0, 15]], + title_supp = [f'DATASET: {st.session_state.datafile}', + f'FEATURE: {st.session_state.dens_diff_feat_sel}', + f'FALSE Val: {st.session_state.feature_value_fals}, TRUE Val: {st.session_state.feature_value_true}',] + + list_figures = [['Average False', None, 'Individual Cluster Plots'], + ['Average True', None, 'Individual Cluster Plots'], + ['Average False', 'Average True', 'Aggregate Cluster Ratios'], + ['False Cluster 1', None, 'Individual Cluster Plots'], + ['False Cluster 2', None, 'Individual Cluster Plots'], + ['False Cluster 3', None, 'Individual Cluster Plots'], + ['True Cluster 1', None, 'Individual Cluster Plots'], + ['True Cluster 2', None, 'Individual Cluster Plots'], + ['True Cluster 3', None, 'Individual Cluster Plots'], + ['False Cluster 1', 'True Cluster 1', 'Individual Cluster Ratios'], + ['False Cluster 2', 'True Cluster 1', 'Individual Cluster Ratios'], + ['False Cluster 3', 'True Cluster 1', 'Individual Cluster Ratios'], + ['False Cluster 1', 'True Cluster 2', 'Individual Cluster Ratios'], + ['False Cluster 2', 'True Cluster 2', 'Individual Cluster Ratios'], + ['False Cluster 3', 'True Cluster 2', 'Individual Cluster Ratios'], + ['False Cluster 1', 'Average True', 'Aggregate Cluster Ratios'], + ['False Cluster 2', 'Average True', 'Aggregate Cluster Ratios'], + ['False Cluster 3', 'Average True', 'Aggregate Cluster Ratios'], + ['True Cluster 1', 'Average False', 'Aggregate Cluster Ratios'], + ['True Cluster 2', 'Average False', 'Aggregate Cluster Ratios'], + ['True Cluster 3', 'Average False', 'Aggregate Cluster Ratios'], ] - if st.session_state.cluster_completed: - num_figs = len(list_figures) - num_cols = 3 - 
num_rows = np.ceil(num_figs/3).astype(int) - for ii, cluster in enumerate(list_figures): - axii = npf_fig_big.add_subplot(num_rows, 3, ii+1, facecolor = '#0E1117') + num_figs = len(list_figures) + num_cols = 3 + num_rows = np.ceil(num_figs/3).astype(int) + for ii, cluster in enumerate(list_figures): + axii = npf_fig_big.add_subplot(num_rows, 3, ii+1, facecolor = '#0E1117') - if ii == ((num_rows*num_cols)-3): - legend_flag = True - else: - legend_flag = False - - bpl.neighProfileDraw(st.session_state.spatial_umap, - ax = axii, - sel_clus = cluster[0], - cmp_clus = cluster[1], - cmp_style = 'Ratio', - hide_other = st.session_state['toggle_hide_other'], - hide_no_cluster = st.session_state['toggle_hide_no_cluster'], - legend_flag = legend_flag) + if ii == ((num_rows*num_cols)-3): + legend_flag = True + else: + legend_flag = False + + bpl.neighProfileDraw(st.session_state.spatial_umap, + ax = axii, + sel_clus = cluster[0], + cmp_clus = cluster[1], + cmp_style = 'Ratio', + hide_other = st.session_state['toggle_hide_other'], + hide_no_cluster = st.session_state['toggle_hide_no_cluster'], + legend_flag = legend_flag) + + if st.session_state['toggle_manual_y_axis_scaling_supplemental']: + if cluster[2] == 'Individual Cluster Plots': + axii.set_ylim([st.session_state['y_axis_min_supplemental'], + st.session_state['y_axis_max_supplemental']]) + if st.session_state['log_scale_supplemental']: + axii.set_yscale('log') + elif cluster[2] == 'Individual Cluster Ratios': + axii.set_ylim([st.session_state['y_axis_min_indiratio_supplemental'], + st.session_state['y_axis_max_indiratio_supplemental']]) + if st.session_state['log_scale_indiratio_supplemental']: + axii.set_yscale('log') + elif cluster[2] == 'Aggregate Cluster Ratios': + axii.set_ylim([st.session_state['y_axis_min_aggratio_supplemental'], + st.session_state['y_axis_max_aggratio_supplemental']]) + if st.session_state['log_scale_aggratio_supplemental']: + axii.set_yscale('log') + else: + axii.set_yscale('log') - if cluster[2] == 'log': - axii.set_yscale('log') + plot_title = '' + for i in title_supp: + plot_title = plot_title + i + '\n' - axii.set_ylim(cluster[3]) + # Super title with more information + npf_fig_big.suptitle(t = plot_title, color = '#FAFAFA', x = 0.1, y = 0.9, + horizontalalignment = 'left', + verticalalignment = 'top',) - st.pyplot(fig=npf_fig_big) + st.pyplot(fig=npf_fig_big) if __name__ == '__main__': - - # Set a wide layout - st.set_page_config(page_title="Neighborhood Profiles", - layout="wide") - st.title('Neighborhood Profiles') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages2/Pheno_Cluster_a.py b/pages2/Pheno_Cluster_a.py new file mode 100644 index 0000000..b459a3d --- /dev/null +++ b/pages2/Pheno_Cluster_a.py @@ -0,0 +1,1785 @@ +# Import relevant libraries +import streamlit as st +import hnswlib +import parc +from parc import PARC +import annoy +import sklearn_ann +from ast import arg +from pyparsing import col +import pandas as pd +import anndata as ad +import scanpy as sc +import seaborn as sns +import os +import matplotlib.pyplot as plt +import phenograph +import numpy as np +import scanpy.external as sce +import plotly.express as px 
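+# The approximate-nearest-neighbour and graph libraries imported in this module (hnswlib, annoy,
+# PARC, PhenoGraph, squidpy, leidenalg/igraph) provide the alternative neighbour-graph and
+# clustering backends implemented further down in this file.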
+import time
+from pynndescent import PyNNDescentTransformer
+from scipy.sparse import csr_matrix
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn_ann.utils import TransformerChecksMixin
+import typing as tp
+import anndata
+from tqdm import tqdm
+import parmap
+import typing as tp
+import scipy
+import squidpy as sq
+import leidenalg
+import igraph as ig
+from scipy import stats
+from igraph.community import _community_leiden
+community_leiden = _community_leiden
+
+def z_score(x):
+    """
+    Min-max scale array-like objects to the [0, 1] range
+    (note: despite the name, this is min-max scaling, not z-scoring).
+    """
+    return (x - x.min()) / (x.max() - x.min())
+
+def sparse_matrix_dstack(
+    matrices: tp.Sequence[scipy.sparse.csr_matrix],
+) -> scipy.sparse.csr_matrix:
+    """
+    Diagonally stack sparse matrices.
+    """
+    import scipy
+    from tqdm import tqdm
+
+    n = sum([x.shape[0] for x in matrices])
+    _res = list()
+    i = 0
+    for x in tqdm(matrices):
+        v = scipy.sparse.csr_matrix((x.shape[0], n))
+        v[:, i : i + x.shape[0]] = x
+        _res.append(v)
+        i += x.shape[0]
+    return scipy.sparse.vstack(_res)
+
+def utag(
+    adata,
+    channels_to_use = None,
+    slide_key = "Slide",
+    save_key: str = "UTAG Label",
+    filter_by_variance: bool = False,
+    max_dist: float = 20.0,
+    normalization_mode: str = "l1_norm",
+    keep_spatial_connectivity: bool = False,
+    pca_kwargs: tp.Dict[str, tp.Any] = dict(n_comps=10),
+    apply_umap: bool = False,
+    umap_kwargs: tp.Dict[str, tp.Any] = dict(),
+    apply_clustering: bool = True,
+    clustering_method: tp.Sequence[str] = ["leiden", "parc", "kmeans"],
+    resolutions: tp.Sequence[float] = [0.05, 0.1, 0.3, 1.0],
+    leiden_kwargs: tp.Dict[str, tp.Any] = None,
+    parc_kwargs: tp.Dict[str, tp.Any] = None,
+    parallel: bool = True,
+    processes: int = None,
+):
+    """
+    Discover tissue architecture in single-cell imaging data
+    by combining phenotypes and positional information of cells.
+
+    Parameters
+    ----------
+    adata: AnnData
+        AnnData object with spatial positioning of cells in obsm 'spatial' slot.
+    channels_to_use: Optional[Sequence[str]]
+        An optional sequence of strings used to subset variables to use.
+        Default (None) is to use all variables.
+    max_dist: float
+        Maximum distance to cut edges within a graph.
+        Should be adjusted depending on resolution of images.
+        For imaging mass cytometry, where resolution is 1um, 20 often gives good results;
+        recommended values are between 20 and 50 depending on magnification.
+        Default is 20.
+    slide_key: {str, None}
+        Key of adata.obs containing information on the batch structure of the data.
+        In general, for image data this will often be a variable indicating the image
+        so image-specific effects are removed from data.
+        Default is "Slide".
+    save_key: str
+        Key to be added to adata object holding the UTAG clusters.
+        Depending on the values of `clustering_method` and `resolutions`,
+        the final keys will be of the form: {save_key}_{method}_{resolution}.
+        Default is "UTAG Label".
+    filter_by_variance: bool
+        Whether to filter variables by variance.
+        Default is False, which keeps all variables.
+    normalization_mode: str
+        Method to normalize adjacency matrix.
+        Default is "l1_norm", any other value will not use normalization.
+    keep_spatial_connectivity: bool
+        Whether to keep sparse matrices of spatial connectivity and distance in the obsp attribute of the
+        resulting anndata object. This could be useful in downstream applications.
+        Default is not to (False).
+ pca_kwargs: Dict[str, Any] + Keyword arguments to be passed to scanpy.pp.pca for dimensionality reduction after message passing. + Default is to pass n_comps=10, which uses 10 Principal Components. + apply_umap: bool + Whether to build a UMAP representation after message passing. + Default is False. + umap_kwargs: Dict[str, Any] + Keyword arguments to be passed to scanpy.tl.umap for dimensionality reduction after message passing. + Default is 10.0. + apply_clustering: bool + Whether to cluster the message passed matrix. + Default is True. + clustering_method: Sequence[str] + Which clustering method(s) to use for clustering of the message passed matrix. + Default is ["leiden", "parc"]. + resolutions: Sequence[float] + What resolutions should the methods in `clustering_method` be run at. + Default is [0.05, 0.1, 0.3, 1.0]. + leiden_kwargs: dict[str, Any] + Keyword arguments to pass to scanpy.tl.leiden. + parc_kwargs: dict[str, Any] + Keyword arguments to pass to parc.PARC. + parallel: bool + Whether to run message passing part of algorithm in parallel. + Will accelerate the process but consume more memory. + Default is True. + processes: int + Number of processes to use in parallel. + Default is to use all available (-1). + + Returns + ------- + adata: AnnData + AnnData object with UTAG domain predictions for each cell in adata.obs, column `save_key`. + """ + ad = adata.copy() + + if channels_to_use: + ad = ad[:, channels_to_use] + + if filter_by_variance: + ad = low_variance_filter(ad) + + if isinstance(clustering_method, list): + clustering_method = [m.upper() for m in clustering_method] + elif isinstance(clustering_method, str): + clustering_method = [clustering_method.upper()] + else: + print( + "Invalid Clustering Method. Clustering Method Should Either be a string or a list" + ) + return + assert all(m in ["LEIDEN", "PARC", "KMEANS"] for m in clustering_method) + + if "PARC" in clustering_method: + from parc import PARC # early fail if not available + if "KMEANS" in clustering_method: + from sklearn.cluster import KMeans + + print("Applying UTAG Algorithm...") + if slide_key: + ads = [ + ad[ad.obs[slide_key] == slide].copy() for slide in ad.obs[slide_key].unique() + ] + ad_list = parmap.map( + _parallel_message_pass, + ads, + radius=max_dist, + coord_type="generic", + set_diag=True, + mode=normalization_mode, + pm_pbar=True, + pm_parallel=parallel, + pm_processes=processes, + ) + ad_result = anndata.concat(ad_list) + if keep_spatial_connectivity: + ad_result.obsp["spatial_connectivities"] = sparse_matrix_dstack( + [x.obsp["spatial_connectivities"] for x in ad_list] + ) + ad_result.obsp["spatial_distances"] = sparse_matrix_dstack( + [x.obsp["spatial_distances"] for x in ad_list] + ) + else: + sq.gr.spatial_neighbors(ad, radius=max_dist, coord_type="generic", set_diag=True) + ad_result = custom_message_passing(ad, mode=normalization_mode) + + if apply_clustering: + if "n_comps" in pca_kwargs: + if pca_kwargs["n_comps"] > ad_result.shape[1]: + pca_kwargs["n_comps"] = ad_result.shape[1] - 1 + print( + f"Overwriding provided number of PCA dimensions to match number of features: {pca_kwargs['n_comps']}" + ) + sc.tl.pca(ad_result, **pca_kwargs) + sc.pp.neighbors(ad_result) + + if apply_umap: + print("Running UMAP on Input Dataset...") + sc.tl.umap(ad_result, **umap_kwargs) + + for resolution in tqdm(resolutions): + + res_key1 = save_key + "_leiden_" + str(resolution) + res_key2 = save_key + "_parc_" + str(resolution) + res_key3 = save_key + "_kmeans_" + str(resolution) + if "LEIDEN" in 
clustering_method: + print(f"Applying Leiden Clustering at Resolution: {resolution}...") + kwargs = dict() + kwargs.update(leiden_kwargs or {}) + sc.tl.leiden( + ad_result, resolution=resolution, key_added=res_key1, **kwargs + ) + add_probabilities_to_centroid(ad_result, res_key1) + + if "PARC" in clustering_method: + from parc import PARC + + print(f"Applying PARC Clustering at Resolution: {resolution}...") + + kwargs = dict(random_seed=1, small_pop=1000) + kwargs.update(parc_kwargs or {}) + model = PARC( + ad_result.obsm["X_pca"], + neighbor_graph=ad_result.obsp["connectivities"], + resolution_parameter=resolution, + **kwargs, + ) + model.run_PARC() + ad_result.obs[res_key2] = pd.Categorical(model.labels) + ad_result.obs[res_key2] = ad_result.obs[res_key2].astype("category") + add_probabilities_to_centroid(ad_result, res_key2) + + if "KMEANS" in clustering_method: + print(f"Applying K-means Clustering at Resolution: {resolution}...") + k = int(np.ceil(resolution * 10)) + kmeans = KMeans(n_clusters=k, random_state=1).fit(ad_result.obsm["X_pca"]) + ad_result.obs[res_key3] = pd.Categorical(kmeans.labels_.astype(str)) + add_probabilities_to_centroid(ad_result, res_key3) + + return ad_result + + +def _parallel_message_pass( + ad, + radius: int, + coord_type: str, + set_diag: bool, + mode: str, +): + sq.gr.spatial_neighbors(ad, radius=radius, coord_type=coord_type, set_diag=set_diag) + ad = custom_message_passing(ad, mode=mode) + return ad + + +def custom_message_passing(adata, mode: str = "l1_norm"): + # from scipy.linalg import sqrtm + # import logging + if mode == "l1_norm": + A = adata.obsp["spatial_connectivities"] + from sklearn.preprocessing import normalize + affinity = normalize(A, axis=1, norm="l1") + else: + # Plain A_mod multiplication + A = adata.obsp["spatial_connectivities"] + affinity = A + # logging.info(type(affinity)) + adata.X = affinity @ adata.X + return adata + + +def low_variance_filter(adata): + return adata[:, adata.var["std"] > adata.var["std"].median()] + + +def add_probabilities_to_centroid( + adata, col: str, name_to_output: str = None +): + from scipy.special import softmax + + if name_to_output is None: + name_to_output = col + "_probabilities" + + mean = z_score(adata.to_df()).groupby(adata.obs[col]).mean() + probs = softmax(adata.to_df() @ mean.T, axis=1) + adata.obsm[name_to_output] = probs + return adata + + +class AnnoyTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator): + """Wrapper for using annoy.AnnoyIndex as sklearn's KNeighborsTransformer""" + + def __init__(self, n_neighbors=5, *, metric="euclidean", + n_trees=10, search_k=-1, n_jobs=-1): + self.n_neighbors = n_neighbors + self.n_trees = n_trees + self.search_k = search_k + self.metric = metric + self.n_jobs = n_jobs + + def fit(self, X, y=None): + X = self._validate_data(X) + self.n_samples_fit_ = X.shape[0] + metric = self.metric if self.metric != "sqeuclidean" else "euclidean" + self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric) + for i, x in enumerate(X): + self.annoy_.add_item(i, x.tolist()) + self.annoy_.build(self.n_trees, n_jobs = self.n_jobs) + return self + + def transform(self, X): + X = self._transform_checks(X, "annoy_") + return self._transform(X) + + def fit_transform(self, X, y=None): + return self.fit(X)._transform(X=None) + + def _transform(self, X): + """As `transform`, but handles X is None for faster `fit_transform`.""" + + n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0] + + # For compatibility reasons, as each sample is considered 
as its own + # neighbor, one extra neighbor will be computed. + n_neighbors = self.n_neighbors + 1 + + indices = np.empty((n_samples_transform, n_neighbors), dtype=int) + distances = np.empty((n_samples_transform, n_neighbors)) + + if X is None: + for i in range(self.annoy_.get_n_items()): + ind, dist = self.annoy_.get_nns_by_item( + i, n_neighbors, self.search_k, include_distances=True + ) + + indices[i], distances[i] = ind, dist + else: + for i, x in enumerate(X): + indices[i], distances[i] = self.annoy_.get_nns_by_vector( + x.tolist(), n_neighbors, self.search_k, include_distances=True + ) + + if self.metric == "sqeuclidean": + distances **= 2 + + indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors) + kneighbors_graph = csr_matrix( + (distances.ravel(), indices.ravel(), indptr), + shape=(n_samples_transform, self.n_samples_fit_), + ) + + return kneighbors_graph + + def _more_tags(self): + return { + "_xfail_checks": {"check_estimators_pickle": "Cannot pickle AnnoyIndex"}, + "requires_y": False, + } + +class PARC_2: + def __init__(self, data, true_label=None, dist_std_local=3, jac_std_global='median', keep_all_local_dist='auto', + too_big_factor=0.4, small_pop=10, jac_weighted_edges=True, knn=30, n_iter_leiden=5, random_seed=42, + num_threads=-1, distance='l2', time_smallpop=15, partition_type="ModularityVP", + resolution_parameter=1.0, + knn_struct=None, neighbor_graph=None, hnsw_param_ef_construction=150): + # higher dist_std_local means more edges are kept + # highter jac_std_global means more edges are kept + if keep_all_local_dist == 'auto': + if data.shape[0] > 300000: + keep_all_local_dist = True # skips local pruning to increase speed + else: + keep_all_local_dist = False + if resolution_parameter != 1: + partition_type = "RBVP" # Reichardt and Bornholdt’s Potts model. Note that this is the same as ModularityVertexPartition when setting 𝛾 = 1 and normalising by 2m + self.data = data + self.true_label = true_label + self.dist_std_local = dist_std_local # similar to the jac_std_global parameter. avoid setting local and global pruning to both be below 0.5 as this is very aggresive pruning. + self.jac_std_global = jac_std_global # 0.15 is also a recommended value performing empirically similar to 'median'. Generally values between 0-1.5 are reasonable. + self.keep_all_local_dist = keep_all_local_dist # decides whether or not to do local pruning. default is 'auto' which omits LOCAL pruning for samples >300,000 cells. + self.too_big_factor = too_big_factor # if a cluster exceeds this share of the entire cell population, then the PARC will be run on the large cluster. at 0.4 it does not come into play + self.small_pop = small_pop # smallest cluster population to be considered a community + self.jac_weighted_edges = jac_weighted_edges # boolean. whether to partition using weighted graph + self.knn = knn + self.n_iter_leiden = n_iter_leiden # the default is 5 in PARC + self.random_seed = random_seed # enable reproducible Leiden clustering + self.num_threads = num_threads # number of threads used in KNN search/construction + self.distance = distance # Euclidean distance 'l2' by default; other options 'ip' and 'cosine' + self.time_smallpop = time_smallpop # number of seconds trying to check an outlier + self.partition_type = partition_type # default is the simple ModularityVertexPartition where resolution_parameter =1. In order to change resolution_parameter, we switch to RBConfigurationVP + self.resolution_parameter = resolution_parameter # defaults to 1. 
expose this parameter in leidenalg + self.knn_struct = knn_struct # the hnsw index of the KNN graph on which we perform queries + self.neighbor_graph = neighbor_graph # CSR affinity matrix for pre-computed nearest neighbors + self.hnsw_param_ef_construction = hnsw_param_ef_construction # set at 150. higher value increases accuracy of index construction. Even for several 100,000s of cells 150-200 is adequate + + def make_knn_struct(self, too_big=False, big_cluster=None): + if self.knn > 190: print('consider using a lower K_in for KNN graph construction') + ef_query = max(100, self.knn + 1) # ef always should be >K. higher ef, more accurate query + if too_big == False: + num_dims = self.data.shape[1] + n_elements = self.data.shape[0] + p = hnswlib.Index(space=self.distance, dim=num_dims) # default to Euclidean distance + p.set_num_threads(self.num_threads) # allow user to set threads used in KNN construction + if n_elements < 10000: + ef_query = min(n_elements - 10, 500) + ef_construction = ef_query + else: + ef_construction = self.hnsw_param_ef_construction + if (num_dims > 30) & (n_elements <= 50000): + p.init_index(max_elements=n_elements, ef_construction=ef_construction, + M=48) ## good for scRNA seq where dimensionality is high + else: + p.init_index(max_elements=n_elements, ef_construction=ef_construction, M=24) # 30 + p.add_items(self.data) + if too_big == True: + num_dims = big_cluster.shape[1] + n_elements = big_cluster.shape[0] + p = hnswlib.Index(space='l2', dim=num_dims) + p.init_index(max_elements=n_elements, ef_construction=200, M=30) + p.add_items(big_cluster) + p.set_ef(ef_query) # ef should always be > k + + return p + + def knngraph_full(self): # , neighbor_array, distance_array): + k_umap = 15 + t0 = time.time() + # neighbors in array are not listed in in any order of proximity + self.knn_struct.set_ef(k_umap + 1) + neighbor_array, distance_array = self.knn_struct.knn_query(self.data, k=k_umap) + + row_list = [] + n_neighbors = neighbor_array.shape[1] + n_cells = neighbor_array.shape[0] + + row_list.extend(list(np.transpose(np.ones((n_neighbors, n_cells)) * range(0, n_cells)).flatten())) + + row_min = np.min(distance_array, axis=1) + row_sigma = np.std(distance_array, axis=1) + + distance_array = (distance_array - row_min[:, np.newaxis]) / row_sigma[:, np.newaxis] + + col_list = neighbor_array.flatten().tolist() + distance_array = distance_array.flatten() + distance_array = np.sqrt(distance_array) + distance_array = distance_array * -1 + + weight_list = np.exp(distance_array) + + threshold = np.mean(weight_list) + 2 * np.std(weight_list) + + weight_list[weight_list >= threshold] = threshold + + weight_list = weight_list.tolist() + + graph = csr_matrix((np.array(weight_list), (np.array(row_list), np.array(col_list))), + shape=(n_cells, n_cells)) + + graph_transpose = graph.T + prod_matrix = graph.multiply(graph_transpose) + + graph = graph_transpose + graph - prod_matrix + return graph + + def make_csrmatrix_noselfloop(self, neighbor_array, distance_array): + # neighbor array not listed in in any order of proximity + row_list = [] + col_list = [] + weight_list = [] + + n_neighbors = neighbor_array.shape[1] + n_cells = neighbor_array.shape[0] + rowi = 0 + discard_count = 0 + if self.keep_all_local_dist == False: # locally prune based on (squared) l2 distance + + print('commencing local pruning based on Euclidean distance metric at', + self.dist_std_local, 's.dev above mean') + distance_array = distance_array + 0.1 + for row in neighbor_array: + distlist = 
distance_array[rowi, :] + to_keep = np.where(distlist < np.mean(distlist) + self.dist_std_local * np.std(distlist))[0] # 0*std + updated_nn_ind = row[np.ix_(to_keep)] + updated_nn_weights = distlist[np.ix_(to_keep)] + discard_count = discard_count + (n_neighbors - len(to_keep)) + + for ik in range(len(updated_nn_ind)): + if rowi != row[ik]: # remove self-loops + row_list.append(rowi) + col_list.append(updated_nn_ind[ik]) + dist = np.sqrt(updated_nn_weights[ik]) + weight_list.append(1 / (dist + 0.1)) + + rowi = rowi + 1 + + if self.keep_all_local_dist == True: # dont prune based on distance + row_list.extend(list(np.transpose(np.ones((n_neighbors, n_cells)) * range(0, n_cells)).flatten())) + col_list = neighbor_array.flatten().tolist() + weight_list = (1. / (distance_array.flatten() + 0.1)).tolist() + + csr_graph = csr_matrix((np.array(weight_list), (np.array(row_list), np.array(col_list))), + shape=(n_cells, n_cells)) + return csr_graph + + def func_mode(self, ll): # return MODE of list + # If multiple items are maximal, the function returns the first one encountered. + return max(set(ll), key=ll.count) + + def run_toobig_subPARC(self, X_data, jac_std_toobig=0.3, + jac_weighted_edges=True): + n_elements = X_data.shape[0] + hnsw = self.make_knn_struct(too_big=True, big_cluster=X_data) + if n_elements <= 10: print('consider increasing the too_big_factor') + if n_elements > self.knn: + knnbig = self.knn + else: + knnbig = int(max(5, 0.2 * n_elements)) + + neighbor_array, distance_array = hnsw.knn_query(X_data, k=knnbig) + # print('shapes of neigh and dist array', neighbor_array.shape, distance_array.shape) + csr_array = self.make_csrmatrix_noselfloop(neighbor_array, distance_array) + sources, targets = csr_array.nonzero() + + #mask = np.zeros(len(sources), dtype=bool) + + #mask |= (csr_array.data > ( np.mean(csr_array.data) + np.std(csr_array.data) * 5)) # weights are set as 1/dist. 
so smaller distance means stronger edge + + #csr_array.data[mask] = 0 + #csr_array.eliminate_zeros() + #sources, targets = csr_array.nonzero() + edgelist = list(zip(sources.tolist(), targets.tolist())) + edgelist_copy = edgelist.copy() + G = ig.Graph(edgelist, edge_attrs={'weight': csr_array.data.tolist()}) + sim_list = G.similarity_jaccard(pairs=edgelist_copy) # list of jaccard weights + new_edgelist = [] + sim_list_array = np.asarray(sim_list) + if jac_std_toobig == 'median': + threshold = np.median(sim_list) + else: + threshold = np.mean(sim_list) - jac_std_toobig * np.std(sim_list) + print('jac threshold %.3f' % threshold) + print('jac std %.3f' % np.std(sim_list)) + print('jac mean %.3f' % np.mean(sim_list)) + strong_locs = np.where(sim_list_array > threshold)[0] + for ii in strong_locs: new_edgelist.append(edgelist_copy[ii]) + sim_list_new = list(sim_list_array[strong_locs]) + + if jac_weighted_edges == True: + G_sim = ig.Graph(n=n_elements, edges=list(new_edgelist), edge_attrs={'weight': sim_list_new}) + else: + G_sim = ig.Graph(n=n_elements, edges=list(new_edgelist)) + G_sim.simplify(combine_edges='sum') + import random + random.seed(self.random_seed) + if jac_weighted_edges == True: + if self.partition_type == 'ModularityVP': + partition = leidenalg.find_partition(G_sim, leidenalg.ModularityVertexPartition, weights='weight', + n_iterations=self.n_iter_leiden, seed=self.random_seed) + print('partition type MVP') + else: + # partition = leidenalg.find_partition(G_sim, leidenalg.RBConfigurationVertexPartition, weights='weight', + # n_iterations=self.n_iter_leiden, seed=self.random_seed, + # resolution_parameter=self.resolution_parameter) + print("custom partition") + print(self.resolution_parameter) + print(self.n_iter_leiden) + partition = G_sim.community_leiden(objective_function='modularity', weights='weight', n_iterations=self.n_iter_leiden, + resolution=self.resolution_parameter) + print('partition type RBC') + else: + if self.partition_type == 'ModularityVP': + print('partition type MVP') + partition = leidenalg.find_partition(G_sim, leidenalg.ModularityVertexPartition, + n_iterations=self.n_iter_leiden, seed=self.random_seed) + else: + # print('partition type RBC') + # partition = leidenalg.find_partition(G_sim, leidenalg.RBConfigurationVertexPartition, + # n_iterations=self.n_iter_leiden, seed=self.random_seed, + # resolution_parameter=self.resolution_parameter) + print("custom partition") + print(self.resolution_parameter) + print(self.n_iter_leiden) + partition = G_sim.community_leiden(objective_function='modularity', n_iterations=self.n_iter_leiden, + resolution=self.resolution_parameter) + # print('Q= %.2f' % partition.quality()) + PARC_labels_leiden = np.asarray(partition.membership) + PARC_labels_leiden = np.reshape(PARC_labels_leiden, (n_elements, 1)) + small_pop_list = [] + small_cluster_list = [] + small_pop_exist = False + dummy, PARC_labels_leiden = np.unique(list(PARC_labels_leiden.flatten()), return_inverse=True) + for cluster in set(PARC_labels_leiden): + population = len(np.where(PARC_labels_leiden == cluster)[0]) + if population < 10: + small_pop_exist = True + small_pop_list.append(list(np.where(PARC_labels_leiden == cluster)[0])) + small_cluster_list.append(cluster) + + for small_cluster in small_pop_list: + for single_cell in small_cluster: + old_neighbors = neighbor_array[single_cell] + group_of_old_neighbors = PARC_labels_leiden[old_neighbors] + group_of_old_neighbors = list(group_of_old_neighbors.flatten()) + available_neighbours = 
set(group_of_old_neighbors) - set(small_cluster_list) + if len(available_neighbours) > 0: + available_neighbours_list = [value for value in group_of_old_neighbors if + value in list(available_neighbours)] + best_group = max(available_neighbours_list, key=available_neighbours_list.count) + PARC_labels_leiden[single_cell] = best_group + + time_smallpop_start = time.time() + print('handling fragments') + while (small_pop_exist) == True & (time.time() - time_smallpop_start < self.time_smallpop): + small_pop_list = [] + small_pop_exist = False + for cluster in set(list(PARC_labels_leiden.flatten())): + population = len(np.where(PARC_labels_leiden == cluster)[0]) + if population < 10: + small_pop_exist = True + + small_pop_list.append(np.where(PARC_labels_leiden == cluster)[0]) + for small_cluster in small_pop_list: + for single_cell in small_cluster: + old_neighbors = neighbor_array[single_cell, :] + group_of_old_neighbors = PARC_labels_leiden[old_neighbors] + group_of_old_neighbors = list(group_of_old_neighbors.flatten()) + best_group = max(set(group_of_old_neighbors), key=group_of_old_neighbors.count) + PARC_labels_leiden[single_cell] = best_group + + dummy, PARC_labels_leiden = np.unique(list(PARC_labels_leiden.flatten()), return_inverse=True) + + return PARC_labels_leiden + + def run_subPARC(self): + + X_data = self.data + too_big_factor = self.too_big_factor + small_pop = self.small_pop + jac_std_global = self.jac_std_global + jac_weighted_edges = self.jac_weighted_edges + knn = self.knn + n_elements = X_data.shape[0] + + if self.neighbor_graph is not None: + csr_array = self.neighbor_graph + neighbor_array = np.split(csr_array.indices, csr_array.indptr)[1:-1] + else: + if self.knn_struct is None: + print('knn struct was not available, so making one') + self.knn_struct = self.make_knn_struct() + else: + print('knn struct already exists') + neighbor_array, distance_array = self.knn_struct.knn_query(X_data, k=knn) + csr_array = self.make_csrmatrix_noselfloop(neighbor_array, distance_array) + + sources, targets = csr_array.nonzero() + + edgelist = list(zip(sources, targets)) + + edgelist_copy = edgelist.copy() + + G = ig.Graph(edgelist, edge_attrs={'weight': csr_array.data.tolist()}) + # print('average degree of prejacard graph is %.1f'% (np.mean(G.degree()))) + # print('computing Jaccard metric') + sim_list = G.similarity_jaccard(pairs=edgelist_copy) + + print('commencing global pruning') + + sim_list_array = np.asarray(sim_list) + edge_list_copy_array = np.asarray(edgelist_copy) + + if jac_std_global == 'median': + threshold = np.median(sim_list) + else: + threshold = np.mean(sim_list) - jac_std_global * np.std(sim_list) + strong_locs = np.where(sim_list_array > threshold)[0] + # print('Share of edges kept after Global Pruning %.2f' % (len(strong_locs) / len(sim_list)), '%') + new_edgelist = list(edge_list_copy_array[strong_locs]) + sim_list_new = list(sim_list_array[strong_locs]) + + G_sim = ig.Graph(n=n_elements, edges=list(new_edgelist), edge_attrs={'weight': sim_list_new}) + # print('average degree of graph is %.1f' % (np.mean(G_sim.degree()))) + G_sim.simplify(combine_edges='sum') # "first" + # print('average degree of SIMPLE graph is %.1f' % (np.mean(G_sim.degree()))) + print('commencing community detection') + import random + random.seed(self.random_seed) + if jac_weighted_edges == True: + start_leiden = time.time() + if self.partition_type == 'ModularityVP': + print('partition type MVP') + partition = leidenalg.find_partition(G_sim, leidenalg.ModularityVertexPartition, 
weights='weight', + n_iterations=self.n_iter_leiden, seed=self.random_seed) + else: + print('partition type RBC') + # partition = leidenalg.find_partition(G_sim, leidenalg.RBConfigurationVertexPartition, weights='weight', + # n_iterations=self.n_iter_leiden, seed=self.random_seed, + # resolution_parameter=self.resolution_parameter) + print("custom partition") + print(self.resolution_parameter) + print(self.n_iter_leiden) + partition = G_sim.community_leiden(objective_function='modularity',weights='weight', n_iterations=self.n_iter_leiden, + resolution=self.resolution_parameter) + # print(time.time() - start_leiden) + else: + start_leiden = time.time() + if self.partition_type == 'ModularityVP': + partition = leidenalg.find_partition(G_sim, leidenalg.ModularityVertexPartition, + n_iterations=self.n_iter_leiden, seed=self.random_seed) + print('partition type MVP') + else: + # partition = leidenalg.find_partition(G_sim, leidenalg.RBConfigurationVertexPartition, + # n_iterations=self.n_iter_leiden, seed=self.random_seed, + # resolution_parameter=self.resolution_parameter) + print("custom partition") + print(self.resolution_parameter) + print(self.n_iter_leiden) + partition = G_sim.community_leiden(objective_function='modularity', n_iterations=self.n_iter_leiden, + resolution=self.resolution_parameter) + + #print('partition type RBC') + # print(time.time() - start_leiden) + time_end_PARC = time.time() + # print('Q= %.1f' % (partition.quality())) + PARC_labels_leiden = np.asarray(partition.membership) + PARC_labels_leiden = np.reshape(PARC_labels_leiden, (n_elements, 1)) + + too_big = False + + # print('labels found after Leiden', set(list(PARC_labels_leiden.T)[0])) will have some outlier clusters that need to be added to a cluster if a cluster has members that are KNN + + cluster_i_loc = np.where(PARC_labels_leiden == 0)[ + 0] # the 0th cluster is the largest one. 
so if cluster 0 is not too big, then the others wont be too big either + pop_i = len(cluster_i_loc) + if pop_i > too_big_factor * n_elements: # 0.4 + too_big = True + cluster_big_loc = cluster_i_loc + list_pop_too_bigs = [pop_i] + cluster_too_big = 0 + + while too_big == True: + + X_data_big = X_data[cluster_big_loc, :] + PARC_labels_leiden_big = self.run_toobig_subPARC(X_data_big) + # print('set of new big labels ', set(PARC_labels_leiden_big.flatten())) + PARC_labels_leiden_big = PARC_labels_leiden_big + 100000 + # print('set of new big labels +1000 ', set(list(PARC_labels_leiden_big.flatten()))) + pop_list = [] + + for item in set(list(PARC_labels_leiden_big.flatten())): + pop_list.append([item, list(PARC_labels_leiden_big.flatten()).count(item)]) + print('pop of big clusters', pop_list) + jj = 0 + print('shape PARC_labels_leiden', PARC_labels_leiden.shape) + for j in cluster_big_loc: + PARC_labels_leiden[j] = PARC_labels_leiden_big[jj] + jj = jj + 1 + dummy, PARC_labels_leiden = np.unique(list(PARC_labels_leiden.flatten()), return_inverse=True) + print('new set of labels ', set(PARC_labels_leiden)) + too_big = False + set_PARC_labels_leiden = set(PARC_labels_leiden) + + PARC_labels_leiden = np.asarray(PARC_labels_leiden) + for cluster_ii in set_PARC_labels_leiden: + cluster_ii_loc = np.where(PARC_labels_leiden == cluster_ii)[0] + pop_ii = len(cluster_ii_loc) + not_yet_expanded = pop_ii not in list_pop_too_bigs + if pop_ii > too_big_factor * n_elements and not_yet_expanded == True: + too_big = True + print('cluster', cluster_ii, 'is too big and has population', pop_ii) + cluster_big_loc = cluster_ii_loc + cluster_big = cluster_ii + big_pop = pop_ii + if too_big == True: + list_pop_too_bigs.append(big_pop) + print('cluster', cluster_big, 'is too big with population', big_pop, '. 
It will be expanded') + dummy, PARC_labels_leiden = np.unique(list(PARC_labels_leiden.flatten()), return_inverse=True) + small_pop_list = [] + small_cluster_list = [] + small_pop_exist = False + + for cluster in set(PARC_labels_leiden): + population = len(np.where(PARC_labels_leiden == cluster)[0]) + + if population < small_pop: # 10 + small_pop_exist = True + + small_pop_list.append(list(np.where(PARC_labels_leiden == cluster)[0])) + small_cluster_list.append(cluster) + + for small_cluster in small_pop_list: + + for single_cell in small_cluster: + old_neighbors = neighbor_array[single_cell] + group_of_old_neighbors = PARC_labels_leiden[old_neighbors] + group_of_old_neighbors = list(group_of_old_neighbors.flatten()) + available_neighbours = set(group_of_old_neighbors) - set(small_cluster_list) + if len(available_neighbours) > 0: + available_neighbours_list = [value for value in group_of_old_neighbors if + value in list(available_neighbours)] + best_group = max(available_neighbours_list, key=available_neighbours_list.count) + PARC_labels_leiden[single_cell] = best_group + time_smallpop_start = time.time() + while (small_pop_exist == True) & ((time.time() - time_smallpop_start) < self.time_smallpop): + small_pop_list = [] + small_pop_exist = False + for cluster in set(list(PARC_labels_leiden.flatten())): + population = len(np.where(PARC_labels_leiden == cluster)[0]) + if population < small_pop: + small_pop_exist = True + print(cluster, ' has small population of', population, ) + small_pop_list.append(np.where(PARC_labels_leiden == cluster)[0]) + for small_cluster in small_pop_list: + for single_cell in small_cluster: + old_neighbors = neighbor_array[single_cell, :] + group_of_old_neighbors = PARC_labels_leiden[old_neighbors] + group_of_old_neighbors = list(group_of_old_neighbors.flatten()) + best_group = max(set(group_of_old_neighbors), key=group_of_old_neighbors.count) + PARC_labels_leiden[single_cell] = best_group + + dummy, PARC_labels_leiden = np.unique(list(PARC_labels_leiden.flatten()), return_inverse=True) + PARC_labels_leiden = list(PARC_labels_leiden.flatten()) + # print('final labels allocation', set(PARC_labels_leiden)) + pop_list = [] + for item in set(PARC_labels_leiden): + pop_list.append((item, PARC_labels_leiden.count(item))) + print('list of cluster labels and populations', len(pop_list), pop_list) + + self.labels = PARC_labels_leiden # list + return + + def accuracy(self, onevsall=1): + + true_labels = self.true_label + Index_dict = {} + PARC_labels = self.labels + N = len(PARC_labels) + n_cancer = list(true_labels).count(onevsall) + n_pbmc = N - n_cancer + + for k in range(N): + Index_dict.setdefault(PARC_labels[k], []).append(true_labels[k]) + num_groups = len(Index_dict) + sorted_keys = list(sorted(Index_dict.keys())) + error_count = [] + pbmc_labels = [] + thp1_labels = [] + fp, fn, tp, tn, precision, recall, f1_score = 0, 0, 0, 0, 0, 0, 0 + + for kk in sorted_keys: + vals = [t for t in Index_dict[kk]] + majority_val = self.func_mode(vals) + if majority_val == onevsall: print('cluster', kk, ' has majority', onevsall, 'with population', len(vals)) + if kk == -1: + len_unknown = len(vals) + print('len unknown', len_unknown) + if (majority_val == onevsall) and (kk != -1): + thp1_labels.append(kk) + fp = fp + len([e for e in vals if e != onevsall]) + tp = tp + len([e for e in vals if e == onevsall]) + list_error = [e for e in vals if e != majority_val] + e_count = len(list_error) + error_count.append(e_count) + elif (majority_val != onevsall) and (kk != -1): + 
pbmc_labels.append(kk) + tn = tn + len([e for e in vals if e != onevsall]) + fn = fn + len([e for e in vals if e == onevsall]) + error_count.append(len([e for e in vals if e != majority_val])) + + predict_class_array = np.array(PARC_labels) + PARC_labels_array = np.array(PARC_labels) + number_clusters_for_target = len(thp1_labels) + for cancer_class in thp1_labels: + predict_class_array[PARC_labels_array == cancer_class] = 1 + for benign_class in pbmc_labels: + predict_class_array[PARC_labels_array == benign_class] = 0 + predict_class_array.reshape((predict_class_array.shape[0], -1)) + error_rate = sum(error_count) / N + n_target = tp + fn + tnr = tn / n_pbmc + fnr = fn / n_cancer + tpr = tp / n_cancer + fpr = fp / n_pbmc + + if tp != 0 or fn != 0: recall = tp / (tp + fn) # ability to find all positives + if tp != 0 or fp != 0: precision = tp / (tp + fp) # ability to not misclassify negatives as positives + if precision != 0 or recall != 0: + f1_score = precision * recall * 2 / (precision + recall) + majority_truth_labels = np.empty((len(true_labels), 1), dtype=object) + + for cluster_i in set(PARC_labels): + cluster_i_loc = np.where(np.asarray(PARC_labels) == cluster_i)[0] + true_labels = np.asarray(true_labels) + majority_truth = self.func_mode(list(true_labels[cluster_i_loc])) + majority_truth_labels[cluster_i_loc] = majority_truth + + majority_truth_labels = list(majority_truth_labels.flatten()) + accuracy_val = [error_rate, f1_score, tnr, fnr, tpr, fpr, precision, + recall, num_groups, n_target] + + return accuracy_val, predict_class_array, majority_truth_labels, number_clusters_for_target + + def run_PARC(self): + print('input data has shape', self.data.shape[0], '(samples) x', self.data.shape[1], '(features)') + if self.true_label is None: + self.true_label = [1] * self.data.shape[0] + list_roc = [] + + time_start_total = time.time() + + time_start_knn = time.time() + + time_end_knn_struct = time.time() - time_start_knn + # Query dataset, k - number of closest elements (returns 2 numpy arrays) + self.run_subPARC() + run_time = time.time() - time_start_total + print('time elapsed {:.1f} seconds'.format(run_time)) + + targets = list(set(self.true_label)) + N = len(list(self.true_label)) + self.f1_accumulated = 0 + self.f1_mean = 0 + self.stats_df = pd.DataFrame({'jac_std_global': [self.jac_std_global], 'dist_std_local': [self.dist_std_local], + 'runtime(s)': [run_time]}) + self.majority_truth_labels = [] + if len(targets) > 1: + f1_accumulated = 0 + f1_acc_noweighting = 0 + for onevsall_val in targets: + print('target is', onevsall_val) + vals_roc, predict_class_array, majority_truth_labels, numclusters_targetval = self.accuracy( + onevsall=onevsall_val) + f1_current = vals_roc[1] + print('target', onevsall_val, 'has f1-score of %.2f' % (f1_current * 100)) + f1_accumulated = f1_accumulated + f1_current * (list(self.true_label).count(onevsall_val)) / N + f1_acc_noweighting = f1_acc_noweighting + f1_current + + list_roc.append( + [self.jac_std_global, self.dist_std_local, onevsall_val] + vals_roc + [numclusters_targetval] + [ + run_time]) + + f1_mean = f1_acc_noweighting / len(targets) + print("f1-score (unweighted) mean %.2f" % (f1_mean * 100), '%') + print('f1-score weighted (by population) %.2f' % (f1_accumulated * 100), '%') + + df_accuracy = pd.DataFrame(list_roc, + columns=['jac_std_global', 'dist_std_local', 'onevsall-target', 'error rate', + 'f1-score', 'tnr', 'fnr', + 'tpr', 'fpr', 'precision', 'recall', 'num_groups', + 'population of target', 'num clusters', 'clustering 
runtime']) + + self.f1_accumulated = f1_accumulated + self.f1_mean = f1_mean + self.stats_df = df_accuracy + self.majority_truth_labels = majority_truth_labels + return + + def run_umap_hnsw(self, X_input, graph, n_components=2, alpha: float = 1.0, negative_sample_rate: int = 5, + gamma: float = 1.0, spread=1.0, min_dist=0.1, init_pos='spectral', random_state=1, ): + + from umap.umap_ import find_ab_params, simplicial_set_embedding + import matplotlib.pyplot as plt + + a, b = find_ab_params(spread, min_dist) + print('a,b, spread, dist', a, b, spread, min_dist) + t0 = time.time() + X_umap = simplicial_set_embedding(data=X_input, graph=graph, n_components=n_components, initial_alpha=alpha, + a=a, b=b, n_epochs=0, metric_kwds={}, gamma=gamma, + negative_sample_rate=negative_sample_rate, init=init_pos, + random_state=np.random.RandomState(random_state), metric='euclidean', + verbose=1) + return X_umap + +def knngraph_full_2(self): # , neighbor_array, distance_array): + k_umap = 15 + t0 = time.time() + # neighbors in array are not listed in in any order of proximity + self.knn_struct.set_ef(k_umap + 1) + neighbor_array, distance_array = self.knn_struct.knn_query(self.data, k=k_umap) + + row_list = [] + n_neighbors = neighbor_array.shape[1] + n_cells = neighbor_array.shape[0] + + row_list.extend(list(np.transpose(np.ones((n_neighbors, n_cells)) * range(0, n_cells)).flatten())) + + row_min = np.min(distance_array, axis=1) + row_sigma = np.std(distance_array, axis=1) + + distance_array = (distance_array - row_min[:, np.newaxis]) / row_sigma[:, np.newaxis] + + col_list = neighbor_array.flatten().tolist() + distance_array = distance_array.flatten() + distance_array = np.sqrt(distance_array) + distance_array = distance_array * -1 + + weight_list = np.exp(distance_array) + + threshold = np.mean(weight_list) + 2 * np.std(weight_list) + + weight_list[weight_list >= threshold] = threshold + + weight_list = weight_list.tolist() + + graph = csr_matrix((np.array(weight_list), (np.array(row_list), np.array(col_list))), + shape=(n_cells, n_cells)) + + graph_transpose = graph.T + prod_matrix = graph.multiply(graph_transpose) + + #graph = graph_transpose + graph - prod_matrix + return graph , prod_matrix + +def phenocluster__make_adata(df, x_cols, meta_cols, + z_normalize, normalize_total, + log_normalize, select_high_var_features, n_features): + + print(select_high_var_features) + print(n_features) + + mat = df[x_cols] + meta = df[meta_cols] + adata = ad.AnnData(mat) + adata.obs = meta + adata.layers["counts"] = adata.X.copy() + #adata.write("input/clust_dat.h5ad") + if normalize_total: + sc.pp.normalize_total(adata) + if log_normalize: + sc.pp.log1p(adata) + if z_normalize: + sc.pp.scale(adata) + + if select_high_var_features: + sc.pp.highly_variable_genes(adata, n_top_genes=n_features, flavor='cell_ranger') + adata = adata[:, adata.var.highly_variable].copy() + print(adata.shape) + return adata + +# scanpy clustering +def RunNeighbClust(adata, n_neighbors, metric, resolution, random_state, n_principal_components, + n_jobs, n_iterations, fast, transformer): + + if fast == True: + sc.pp.pca(adata, n_comps=n_principal_components) + if transformer == "Annoy": + sc.pp.neighbors(adata, transformer=AnnoyTransformer(n_neighbors=n_neighbors, metric=metric, n_jobs=n_jobs), + n_pcs=n_principal_components, random_state=random_state) + elif transformer == "PNNDescent": + transformer = PyNNDescentTransformer(n_neighbors=n_neighbors, metric=metric, n_jobs=n_jobs, random_state=random_state) + sc.pp.neighbors(adata, 
transformer=transformer, n_pcs=n_principal_components, random_state=random_state) + else: + sc.pp.neighbors(adata, n_neighbors=n_neighbors, metric=metric, n_pcs=n_principal_components, random_state=random_state) + + else: + if n_principal_components > 0: + sc.pp.pca(adata, n_comps=n_principal_components) + sc.pp.neighbors(adata, n_neighbors=n_neighbors, metric=metric, n_pcs=n_principal_components, random_state=random_state) + else: + sc.pp.neighbors(adata, n_neighbors=n_neighbors, metric=metric, n_pcs=0, random_state=random_state) + + sc.tl.leiden(adata,resolution=resolution, random_state=random_state, n_iterations=n_iterations, flavor="igraph") + adata.obs['Cluster'] = adata.obs['leiden'] + #sc.tl.umap(adata) + + adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) + return adata + +# phenograph clustering +def RunPhenographClust(adata, n_neighbors, clustering_algo, min_cluster_size, + primary_metric, resolution_parameter, nn_method, random_seed, n_principal_components, + n_jobs, n_iterations, fast): + print(adata.shape) + if fast == True: + print("fast phenograph selected") + sc.pp.pca(adata, n_comps=n_principal_components) + print("PCA done") + p1 = PARC(adata.obsm["X_pca"], keep_all_local_dist=True, num_threads=n_jobs) # without labels + print("Parc object created") + p1.knn_struct = p1.make_knn_struct() + print("Parc knn struct created") + graph_parc_2, mat_parc_2 = knngraph_full_2(p1) + print("Parc graph prepared") + communities, graph_phen, Q = phenograph.cluster(graph_parc_2, clustering_algo=None, k=n_neighbors, + min_cluster_size=min_cluster_size, primary_metric=primary_metric, + resolution_parameter=resolution_parameter, nn_method=nn_method, + seed=random_seed, n_iterations=n_iterations, n_jobs=n_jobs) + adata.obsp["pheno_jaccard_ig"] = graph_phen.tocsr() + print("start_leiden") + sc.tl.leiden(adata,resolution=resolution_parameter, random_state=random_seed, + n_iterations=n_iterations, flavor="igraph", obsp="pheno_jaccard_ig") + adata.obs['Cluster'] = adata.obs['leiden'].astype(str) + print("Fast phenograph clustering done") + else: + if n_principal_components == 0: + communities, graph, Q = phenograph.cluster(adata.X, clustering_algo=clustering_algo, k=n_neighbors, + min_cluster_size=min_cluster_size, primary_metric=primary_metric, + resolution_parameter=resolution_parameter, nn_method=nn_method, + seed=random_seed, n_iterations=n_iterations, n_jobs=n_jobs) + else: + sc.pp.pca(adata, n_comps=n_principal_components) + communities, graph, Q = phenograph.cluster(adata.obsm['X_pca'], clustering_algo=clustering_algo, k=n_neighbors, + min_cluster_size=min_cluster_size, primary_metric=primary_metric, + resolution_parameter=resolution_parameter, nn_method=nn_method, + seed=random_seed, n_iterations=n_iterations, n_jobs=n_jobs) + adata.obs['Cluster'] = communities + adata.obs['Cluster'] = adata.obs['Cluster'].astype(str) + print("Regular phenograph clustering done") + #sc.tl.umap(adata) + adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) + return adata + +# parc clustering +def run_parc_clust(adata, n_neighbors, dist_std_local, jac_std_global, small_pop, + random_seed, resolution_parameter, hnsw_param_ef_construction, + n_principal_components, n_iterations, n_jobs, fast): + if fast == True: + sc.pp.pca(adata, n_comps=n_principal_components) + parc_results = PARC_2(adata.obsm["X_pca"], dist_std_local=dist_std_local, jac_std_global=jac_std_global, + 
small_pop=small_pop, random_seed=random_seed, knn=n_neighbors, + resolution_parameter=resolution_parameter, + hnsw_param_ef_construction=hnsw_param_ef_construction, + partition_type="RBConfigurationVP", + n_iter_leiden=n_iterations, num_threads=n_jobs) # without labels + parc_results.run_PARC() + adata.obs['Cluster'] = parc_results.labels + adata.obs['Cluster'] = adata.obs['Cluster'].astype(str) + else: + if n_principal_components == 0: + parc_results = parc.PARC(adata.X, dist_std_local=dist_std_local, jac_std_global=jac_std_global, + small_pop=small_pop, random_seed=random_seed, knn=n_neighbors, + resolution_parameter=resolution_parameter, + hnsw_param_ef_construction=hnsw_param_ef_construction, + partition_type="RBConfigurationVP", + n_iter_leiden=n_iterations, num_threads=n_jobs) + else: + sc.pp.pca(adata, n_comps=n_principal_components) + parc_results = parc.PARC(adata.obsm['X_pca'], dist_std_local=dist_std_local, jac_std_global=jac_std_global, + small_pop=small_pop, random_seed=random_seed, knn=n_neighbors, + resolution_parameter=resolution_parameter, + hnsw_param_ef_construction=hnsw_param_ef_construction, + partition_type="RBConfigurationVP", + n_iter_leiden=n_iterations, num_threads=n_jobs) + parc_results.run_PARC() + adata.obs['Cluster'] = parc_results.labels + adata.obs['Cluster'] = adata.obs['Cluster'].astype(str) + #sc.tl.umap(adata) + adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) + return adata + +# utag clustering +# need to make image selection based on the variable +def run_utag_clust(adata, n_neighbors, resolution, clustering_method, max_dist, n_principal_components, + random_state, n_jobs, n_iterations, fast, transformer): + #sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=0) + #sc.tl.umap(adata) + + resolutions = [resolution] + print(resolutions) + adata.obsm['spatial'] = np.array(adata.obs[["Centroid X (µm)_(standardized)", "Centroid Y (µm)_(standardized)"]]) + + if fast == True: + utag_results = utag(adata, + slide_key="Image ID_(standardized)", + max_dist=max_dist, + normalization_mode='l1_norm', + apply_clustering=False, + parallel = True, + processes = n_jobs) + + sc.pp.pca(utag_results, n_comps=n_principal_components) + print("start k graph") + if transformer == "Annoy": + sc.pp.neighbors(utag_results, transformer=AnnoyTransformer(n_neighbors=n_neighbors, n_jobs=n_jobs), + n_pcs=n_principal_components, random_state=random_state) + elif transformer == "PNNDescent": + transformer = PyNNDescentTransformer(n_neighbors=n_neighbors, n_jobs=n_jobs, random_state=random_state) + sc.pp.neighbors(utag_results, transformer=transformer, n_pcs=n_principal_components, random_state=random_state) + else: + sc.pp.neighbors(utag_results, n_neighbors=n_neighbors,n_pcs=n_principal_components, random_state=random_state) + + resolution_parameter = resolution + sc.tl.leiden(utag_results,resolution=resolution_parameter, random_state=random_state, + n_iterations=n_iterations, flavor="igraph") + utag_results.obs['Cluster'] = utag_results.obs['leiden'].copy() + adata.uns["leiden"] = utag_results.uns["leiden"].copy() + + else: + utag_results = utag(adata, + slide_key="Image ID_(standardized)", + max_dist=max_dist, + normalization_mode='l1_norm', + apply_clustering=True, + clustering_method = "leiden", + resolutions = resolutions, + leiden_kwargs={"n_iterations": n_iterations, "random_state": random_state}, + pca_kwargs = {"n_comps": n_principal_components}, + parallel = True, + processes = n_jobs) + + curClusterCol = 
'UTAG Label_leiden_' + str(resolution) + utag_results.obs['Cluster'] = utag_results.obs[curClusterCol].copy() + + cluster_list = list(utag_results.obs['Cluster']) + print(pd.unique(cluster_list)) + adata.obsp["distances"] = utag_results.obsp["distances"].copy() + adata.obsp["connectivities"] = utag_results.obsp["connectivities"].copy() + adata.obsm["X_pca"] = utag_results.obsm["X_pca"].copy() + adata.uns["neighbors"] = utag_results.uns["neighbors"].copy() + adata.varm["PCs"] = utag_results.varm["PCs"].copy() + adata.obs["Cluster"] = cluster_list + #utag_results.X = adata.X + + return adata + +def phenocluster__scanpy_umap(adata, n_neighbors, metric, n_principal_components): + if "X_pca" not in adata.obsm: + if n_principal_components > 0: + sc.pp.pca(adata, n_comps=n_principal_components) + if "neighbors" not in adata.uns: + print("Finding nearest neigbours") + sc.pp.neighbors(adata, n_neighbors=n_neighbors, metric=metric, n_pcs=n_principal_components) + sc.tl.umap(adata) + st.session_state['phenocluster__clustering_adata'] = adata + +# plot umaps +def phenocluster__plotly_umaps(adata, umap_cur_col, umap_cur_groups, umap_color_col, plot_col): + with plot_col: + subcol1, subcol2 = st.columns(2) + for i, umap_cur_group in enumerate(umap_cur_groups): + if umap_cur_group == "All": + subDat = adata + else: + subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] + umap_coords = subDat.obsm['X_umap'] + df = pd.DataFrame(umap_coords, columns=['UMAP_1', 'UMAP_2']) + clustersList = list(subDat.obs[umap_color_col] ) + df[umap_color_col] = clustersList + df[umap_color_col] = df[umap_color_col].astype(str) + # Create the seaborn plot + fig = px.scatter(df, + x="UMAP_1", + y="UMAP_2", + color=umap_color_col, + title="UMAP " + umap_cur_group + #color_discrete_sequence=px.colors.sequential.Plasma + ) + fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots + fig.update_layout( + title=dict( + text="UMAP " + umap_cur_group, + x=0.5, # Center the title + xanchor='center', + yanchor='top' + ), + legend=dict( + orientation="h", + yanchor="top", + y=-0.2, + xanchor="right", + x=1 + ) + ) + if i % 2 == 0: + subcol1.plotly_chart(fig, use_container_width=True) + else: + subcol2.plotly_chart(fig, use_container_width=True) + +# plot spatial +def spatial_plots_cust_2(adata, umap_cur_col, umap_cur_groups, umap_color_col, plot_col): + with plot_col: + subcol3, subcol4 = st.columns(2) + for i, umap_cur_group in enumerate(umap_cur_groups): + if umap_cur_group == "All": + subDat = adata + else: + subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] + umap_coords = subDat.obs[['Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)']] + df = pd.DataFrame(umap_coords).reset_index().drop('index', axis = 1) + clustersList = list(subDat.obs[umap_color_col] ) + df[umap_color_col] = clustersList + df[umap_color_col] = df[umap_color_col].astype(str) + fig = px.scatter(df, + x="Centroid X (µm)_(standardized)", + y="Centroid Y (µm)_(standardized)", + color=umap_color_col, + title="Spatial " + umap_cur_group + #color_discrete_sequence=px.colors.sequential.Plasma + ) + fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots + fig.update_layout( + title=dict( + text="Spatial " + umap_cur_group, + x=0.5, # Center the title + xanchor='center', + yanchor='top', + ), + legend=dict( + orientation="h", + yanchor="top", + y=-0.2, + xanchor="right", + x=1 + ), + xaxis=dict( + scaleanchor="y", + scaleratio=1) + ) + if i % 2 == 0: + subcol3.plotly_chart(fig, use_container_width=True) + else: + 
subcol4.plotly_chart(fig, use_container_width=True)
+
+# make Umaps and Spatial Plots
+def make_all_plots():
+    # make umaps
+    phenocluster__plotly_umaps(st.session_state['phenocluster__clustering_adata'],
+                               st.session_state['phenocluster__umap_cur_col'],
+                               st.session_state['phenocluster__umap_cur_groups'],
+                               st.session_state['phenocluster__umap_color_col'])
+    # make spatial plots
+    spatial_plots_cust_2(st.session_state['phenocluster__clustering_adata'],
+                         st.session_state['phenocluster__umap_cur_col'],
+                         st.session_state['phenocluster__umap_cur_groups'],
+                         st.session_state['phenocluster__umap_color_col'])
+
+
+# default session state values
+def phenocluster__default_session_state():
+
+    if 'phenocluster__subset_data' not in st.session_state:
+        st.session_state['phenocluster__subset_data'] = False
+
+    if 'phenocluster__cluster_method' not in st.session_state:
+        st.session_state['phenocluster__cluster_method'] = "phenograph"
+
+    if 'phenocluster__resolution' not in st.session_state:
+        st.session_state['phenocluster__resolution'] = 1.0
+
+    if 'phenocluster__n_jobs' not in st.session_state:
+        st.session_state['phenocluster__n_jobs'] = 7
+
+    if 'phenocluster__n_iterations' not in st.session_state:
+        st.session_state['phenocluster__n_iterations'] = 5
+
+    if 'phenocluster__n_features' not in st.session_state:
+        st.session_state['phenocluster__n_features'] = 0
+
+    # phenograph options
+    if 'phenocluster__n_neighbors_state' not in st.session_state:
+        st.session_state['phenocluster__n_neighbors_state'] = 10
+
+    if 'phenocluster__phenograph_clustering_algo' not in st.session_state:
+        st.session_state['phenocluster__phenograph_clustering_algo'] = 'louvain'
+
+    if 'phenocluster__phenograph_min_cluster_size' not in st.session_state:
+        st.session_state['phenocluster__phenograph_min_cluster_size'] = 10
+
+    if 'phenocluster__metric' not in st.session_state:
+        st.session_state['phenocluster__metric'] = 'euclidean'
+
+    if 'phenocluster__phenograph_nn_method' not in st.session_state:
+        st.session_state['phenocluster__phenograph_nn_method'] = 'kdtree'
+
+    if 'phenocluster__n_principal_components' not in st.session_state:
+        st.session_state['phenocluster__n_principal_components'] = 10
+
+    # parc options
+    # dist_std_local, jac_std_global, small_pop, random_seed, resolution_parameter, hnsw_param_ef_construction
+    if 'phenocluster__parc_dist_std_local' not in st.session_state:
+        st.session_state['phenocluster__parc_dist_std_local'] = 3
+
+    if 'phenocluster__parc_jac_std_global' not in st.session_state:
+        st.session_state['phenocluster__parc_jac_std_global'] = 0.15
+
+    if 'phenocluster__parc_small_pop' not in st.session_state:
+        st.session_state['phenocluster__parc_small_pop'] = 50
+
+    if 'phenocluster__random_seed' not in st.session_state:
+        st.session_state['phenocluster__random_seed'] = 42
+
+    if 'phenocluster__hnsw_param_ef_construction' not in st.session_state:
+        st.session_state['phenocluster__hnsw_param_ef_construction'] = 150
+
+    # utag options
+    #clustering_method ["leiden", "parc"]; resolutions; max_dist = 20
+    if 'phenocluster__utag_clustering_method' not in st.session_state:
+        st.session_state['phenocluster__utag_clustering_method'] = 'leiden'
+
+    if 'phenocluster__utag_max_dist' not in st.session_state:
+        st.session_state['phenocluster__utag_max_dist'] = 20
+
+    # umap options
+    #if 'phenocluster__umap_cur_col' not in st.session_state:
+    #st.session_state['phenocluster__umap_cur_col'] = "Image"
+
+    if 'phenocluster__umap_color_col' not in st.session_state:
+        st.session_state['phenocluster__umap_color_col'] = "Cluster"
+
+    if 'phenocluster__umap_cur_groups' not in st.session_state:
+        st.session_state['phenocluster__umap_cur_groups'] = ["All"]
+
+    # differential intensity options
+    if 'phenocluster__de_col' not in st.session_state:
+        st.session_state['phenocluster__de_col'] = "Cluster"
+
+    if 'phenocluster__de_sel_groups' not in st.session_state:
+        st.session_state['phenocluster__de_sel_groups'] = ["All"]
+
+    if 'phenocluster__plot_diff_intensity_method' not in st.session_state:
+        st.session_state['phenocluster__plot_diff_intensity_method'] = "Rank Plot"
+
+    if 'phenocluster__plot_diff_intensity_n_genes' not in st.session_state:
+        st.session_state['phenocluster__plot_diff_intensity_n_genes'] = 10
+
+
+# subset data set
+def phenocluster__subset_data(adata, subset_col, subset_vals):
+    adata_subset = adata[adata.obs[subset_col].isin(subset_vals)]
+    st.session_state['phenocluster__clustering_adata'] = adata_subset
+
+def phenocluster_select_high_var_features(adata, n_features):
+    sc.pp.highly_variable_genes(adata, n_top_genes=n_features)
+    adata = adata[:, adata.var.highly_variable]
+    print(adata)
+    return adata
+
+def phenocluster__add_clusters_to_input_df():
+    # drop any cluster columns left over from a previous run before appending the new ones
+    if "phenocluster__phenotype_cluster_cols" in st.session_state:
+        cur_df = st.session_state['input_dataset'].data
+        cur_df = cur_df.drop(columns=st.session_state["phenocluster__phenotype_cluster_cols"])
+        st.session_state['input_dataset'].data = cur_df
+    print(pd.unique(st.session_state['phenocluster__clustering_adata'].obs['Cluster']))
+    st.session_state['input_dataset'].data["Phenotype_Cluster"] = 'Phenotype ' + st.session_state['phenocluster__clustering_adata'].obs["Cluster"].astype(str)
+    print(st.session_state['input_dataset'].data["Phenotype_Cluster"])
+    dummies = pd.get_dummies(st.session_state['phenocluster__clustering_adata'].obs["Cluster"], prefix='Phenotype Cluster').astype(int)
+    #dummies = dummies.replace({1: '+', 0: '-'})
+    cur_df = pd.concat([st.session_state['input_dataset'].data, dummies], axis=1)
+    st.session_state['input_dataset'].data = cur_df
+    new_cluster_cols = list(dummies.columns)
+    st.session_state["phenocluster__phenotype_cluster_cols"] = new_cluster_cols
+    print(st.session_state["phenocluster__phenotype_cluster_cols"])
+
+# check that only numeric columns are included in the adata.X
+def phenocluster__check_input_dat(input_dat, numeric_cols):
+    for cur_col in numeric_cols:
+        if pd.api.types.is_numeric_dtype(input_dat[cur_col]):
+            pass
+        else:
+            st.error("Column " + cur_col + " is not numeric. Only numeric columns can be included in the matrix",
+                     icon="🚨")
+
+
+# main
+def main():
+    """
+    Main function for the page.
+ """ + #st.write(st.session_state['unifier__df'].head()) + phenocluster__col_0, phenocluster__col_0a = st.columns([10,1]) + + phenocluster__col1, phenocluster__col2 = st.columns([2, 6]) + # set default values + phenocluster__default_session_state() + + + # make layout with columns + # options + + try: + with phenocluster__col_0: + st.multiselect('Select numeric columns for clustering:', options = st.session_state['input_dataset'].data.columns, + key='phenocluster__X_cols') + + numeric_cols = st.session_state['phenocluster__X_cols'] + phenocluster__check_input_dat(input_dat=st.session_state['input_dataset'].data, numeric_cols=numeric_cols) + + st.multiselect('Select columns for metadata:', options = st.session_state['input_dataset'].data.columns, + key='phenocluster__meta_cols') + + + meta_columns = st.session_state['phenocluster__meta_cols'] + #Add the new items if they don't already exist in the list + items_to_add = ['Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)'] + for item in items_to_add: + if item not in st.session_state['phenocluster__meta_cols']: + st.session_state['phenocluster__meta_cols'].append(item) + + st.toggle("Z-score normalize columns", key='phenocluster__zscore_normalize') + st.toggle("Normalize total intensity", key='phenocluster__normalize_total_intensity') + st.toggle("Log normalize", key='phenocluster__log_normalize') + st.toggle("Select high variance features", key='phenocluster__select_high_var_features') + if st.session_state['phenocluster__select_high_var_features'] == True: + st.number_input(label = "Number of features", key='phenocluster__n_features', step = 1) + + + if st.button('Submit columns'): + st.session_state['phenocluster__clustering_adata'] = phenocluster__make_adata(st.session_state['input_dataset'].data, + numeric_cols, + meta_columns, + z_normalize = st.session_state['phenocluster__zscore_normalize'], + normalize_total = st.session_state['phenocluster__normalize_total_intensity'], + log_normalize = st.session_state['phenocluster__log_normalize'], + select_high_var_features = st.session_state['phenocluster__select_high_var_features'], + n_features = st.session_state['phenocluster__n_features'] + ) + except: + st.warning("container issue") + + try: + + if 'phenocluster__clustering_adata' in st.session_state: + + with phenocluster__col1: + + # subset data + st.checkbox('Subset Data', key='phenocluster__subset_data', help = '''Subset data based on a variable''') + if st.session_state['phenocluster__subset_data'] == True: + st.session_state['phenocluster__subset_options'] = list(st.session_state['phenocluster__clustering_adata'].obs.columns) + phenocluster__subset_col = st.selectbox('Select column for subsetting:', st.session_state['phenocluster__subset_options']) + st.session_state["phenocluster__subset_col"] = phenocluster__subset_col + st.session_state['phenocluster__subset_values_options'] = list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state["phenocluster__subset_col"]])) + phenocluster__subset_vals = st.multiselect('Select a group for subsetting:', options = st.session_state['phenocluster__subset_values_options'], key='phenocluster__subset_vals_1') + st.session_state["phenocluster__subset_vals"] = phenocluster__subset_vals + if st.button('Subset Data'): + phenocluster__subset_data(st.session_state['phenocluster__clustering_adata'], + st.session_state["phenocluster__subset_col"], + st.session_state["phenocluster__subset_vals"]) + + + clusteringMethods = ['phenograph', 'scanpy', 'parc', 'utag'] 
+ selected_clusteringMethod = st.selectbox('Select Clustering method:', clusteringMethods, + key='clusteringMethods_dropdown') + + # Update session state on every change + st.session_state['phenocluster__cluster_method'] = selected_clusteringMethod + + # default widgets + + st.number_input(label = "Number of Principal Components", key='phenocluster__n_principal_components', step = 1, + help='''Number of principal components to use for clustering. + If 0, Clustering will be performed on a numeric matrx (0 cannot be used for UTAG clustering)''') + + #st.session_state['phenocluster__n_neighbors_state'] = st.number_input(label = "K Nearest Neighbors", + # value=st.session_state['phenocluster__n_neighbors_state']) + st.number_input(label = "K Nearest Neighbors", + key='phenocluster__n_neighbors_state', step = 1, + help = '''The size of local neighborhood (in terms of number of neighboring data points) used for manifold approximation. + Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. + In general values should be in the range 2 to 100''') + + st.number_input(label = "Clustering resolution", key='phenocluster__resolution', step = 0.1,format="%.1f", + help = '''A parameter value controlling the coarseness of the clustering. + Higher values lead to more clusters''') + + st.number_input(label = "n_jobs", key='phenocluster__n_jobs', step=1, + help = '''N threads to use''') + + st.number_input(label = "n_iterations", key='phenocluster__n_iterations', step=1, + help = '''N iterations to use for leiden clustering''') + + if st.session_state['phenocluster__cluster_method'] == "phenograph": + st.selectbox('Phenograph clustering algorithm:', ['louvain', 'leiden'], key='phenocluster__phenograph_clustering_algo') + st.number_input(label = "Phenograph min cluster size", key='phenocluster__phenograph_min_cluster_size', step = 1, + help = ''' + Cells that end up in a cluster smaller than min_cluster_size are considered + outliers and are assigned to -1 in the cluster labels + ''') + st.selectbox('Distance metric:', ['euclidean', 'manhattan', 'correlation', 'cosine'], key='phenocluster__metric', + help='''Distance metric to define nearest neighbors.''') + st.selectbox('Phenograph nn method:', ['kdtree', 'brute'], key='phenocluster__phenograph_nn_method', + help = '''Whether to use brute force or kdtree for nearest neighbor search.''') + st.checkbox('Fast', key='phenocluster__fast', help = '''Use aproximate nearest neigbour search''') + + elif st.session_state['phenocluster__cluster_method'] == "scanpy": + st.selectbox('Distance metric:', ['euclidean', 'manhattan', 'correlation', 'cosine'], key='phenocluster__metric', + help='''Distance metric to define nearest neighbors.''') + st.checkbox('Fast', key='phenocluster__scanpy_fast', help = '''Use aproximate nearest neigbour search''') + if st.session_state['phenocluster__scanpy_fast'] == True: + st.selectbox('Transformer:', ['Annoy', 'PNNDescent'], key='phenocluster__scanpy_transformer', + help = '''Transformer for the approximate nearest neigbours search''') + else: + st.session_state["phenocluster__scanpy_transformer"] = None + + elif st.session_state['phenocluster__cluster_method'] == "parc": + # make parc specific widgets + st.number_input(label = "Parc dist std local", key='phenocluster__parc_dist_std_local', step = 1, + help = '''local pruning threshold: the number of standard deviations above the mean minkowski + distance between neighbors of a given node. 
+ The higher the parameter, the more edges are retained.''') + st.number_input(label = "Parc jac std global", key='phenocluster__parc_jac_std_global', step = 0.01, + help = '''Global level graph pruning. This threshold can also be set as the number of standard deviations below the network's + mean-jaccard-weighted edges. 0.1-1 provide reasonable pruning. higher value means less pruning. + e.g. a value of 0.15 means all edges that are above mean(edgeweight)-0.15*std(edge-weights) are retained.''') + st.number_input(label = "Minimum cluster size to be considered a separate population", + key='phenocluster__parc_small_pop', step = 1, + help = '''Smallest cluster population to be considered a community.''') + st.number_input(label = "Random seed", key='phenocluster__random_seed', step = 1, + help = '''enable reproducible Leiden clustering''') + st.number_input(label = "HNSW exploration factor for construction", + key='phenocluster__hnsw_param_ef_construction', step = 1, + help = '''Higher value increases accuracy of index construction. + Even for several 100,000s of cells 150-200 is adequate''') + st.checkbox('Fast', key='phenocluster__fast', help = '''Use aproximate nearest neigbour search''') + elif st.session_state['phenocluster__cluster_method'] == "utag": + # make utag specific widgets + #st.selectbox('UTAG clustering method:', ['leiden', 'parc'], key='phenocluster__utag_clustering_method') + st.number_input(label = "UTAG max dist", key='phenocluster__utag_max_dist', step = 1, + help = '''Threshold euclidean distance to determine whether a pair of cell is adjacent in graph structure. + Recommended values are between 10 to 100 depending on magnification.''') + st.checkbox('Fast', key='phenocluster__utag_fast', help = '''Use aproximate nearest neigbour search''') + if st.session_state['phenocluster__utag_fast'] == True: + st.selectbox('Transformer:', ['Annoy', 'PNNDescent'], key='phenocluster__utag_transformer', + help = '''Transformer for the approximate nearest neigbours search''') + else: + st.session_state["phenocluster__utag_transformer"] = None + + # add options if clustering has been run + # add options if clustering has been run + if st.button('Run Clustering'): + start_time = time.time() + if st.session_state['phenocluster__cluster_method'] == "phenograph": + with st.spinner('Wait for it...'): + st.session_state['phenocluster__clustering_adata'] = RunPhenographClust(adata=st.session_state['phenocluster__clustering_adata'], + n_neighbors=st.session_state['phenocluster__n_neighbors_state'], + clustering_algo=st.session_state['phenocluster__phenograph_clustering_algo'], + min_cluster_size=st.session_state['phenocluster__phenograph_min_cluster_size'], + primary_metric=st.session_state['phenocluster__metric'], + resolution_parameter=st.session_state['phenocluster__resolution'], + nn_method=st.session_state['phenocluster__phenograph_nn_method'], + random_seed=st.session_state['phenocluster__random_seed'], + n_principal_components=st.session_state['phenocluster__n_principal_components'], + n_jobs=st.session_state['phenocluster__n_jobs'], + n_iterations= st.session_state['phenocluster__n_iterations'], + fast=st.session_state["phenocluster__fast"] + ) + elif st.session_state['phenocluster__cluster_method'] == "scanpy": + with st.spinner('Wait for it...'): + st.session_state['phenocluster__clustering_adata'] = RunNeighbClust(adata=st.session_state['phenocluster__clustering_adata'], + n_neighbors=st.session_state['phenocluster__n_neighbors_state'], + 
metric=st.session_state['phenocluster__metric'], + resolution=st.session_state['phenocluster__resolution'], + random_state=st.session_state['phenocluster__random_seed'], + n_principal_components=st.session_state['phenocluster__n_principal_components'], + n_jobs=st.session_state['phenocluster__n_jobs'], + n_iterations= st.session_state['phenocluster__n_iterations'], + fast=st.session_state["phenocluster__scanpy_fast"], + transformer = st.session_state["phenocluster__scanpy_transformer"] + ) + #st.session_state['phenocluster__clustering_adata'] = adata + elif st.session_state['phenocluster__cluster_method'] == "parc": + with st.spinner('Wait for it...'): + st.session_state['phenocluster__clustering_adata'] = run_parc_clust(adata=st.session_state['phenocluster__clustering_adata'], + n_neighbors=st.session_state['phenocluster__n_neighbors_state'], + dist_std_local=st.session_state['phenocluster__parc_dist_std_local'], + jac_std_global= st.session_state['phenocluster__parc_jac_std_global'], + small_pop=st.session_state['phenocluster__parc_small_pop'], + random_seed=st.session_state['phenocluster__random_seed'], + resolution_parameter=st.session_state['phenocluster__resolution'], + hnsw_param_ef_construction=st.session_state['phenocluster__hnsw_param_ef_construction'], + n_principal_components=st.session_state['phenocluster__n_principal_components'], + n_jobs=st.session_state['phenocluster__n_jobs'], + n_iterations= st.session_state['phenocluster__n_iterations'], + fast=st.session_state["phenocluster__fast"] + ) + elif st.session_state['phenocluster__cluster_method'] == "utag": + #phenocluster__utag_resolutions = [st.session_state['phenocluster__resolution']] + with st.spinner('Wait for it...'): + st.session_state['phenocluster__clustering_adata'] = run_utag_clust(adata=st.session_state['phenocluster__clustering_adata'], + n_neighbors=st.session_state['phenocluster__n_neighbors_state'], + resolution=st.session_state['phenocluster__resolution'], + clustering_method=st.session_state['phenocluster__utag_clustering_method'], + max_dist=st.session_state['phenocluster__utag_max_dist'], + n_principal_components=st.session_state['phenocluster__n_principal_components'], + random_state=st.session_state['phenocluster__random_seed'], + n_jobs=st.session_state['phenocluster__n_jobs'], + n_iterations= st.session_state['phenocluster__n_iterations'], + fast=st.session_state["phenocluster__utag_fast"], + transformer = st.session_state["phenocluster__utag_transformer"] + ) + # save clustering result + #st.session_state['phenocluster__clustering_adata'].write("input/clust_dat.h5ad") + end_time = time.time() + execution_time = end_time - start_time + rounded_time = round(execution_time, 2) + st.write('Execution time: ', rounded_time, 'seconds') + + + # umap + if 'Cluster' in st.session_state['phenocluster__clustering_adata'].obs.columns: + + st.session_state['phenocluster__umeta_columns'] = list(st.session_state['phenocluster__clustering_adata'].obs.columns) + st.session_state['phenocluster__umap_color_col_index'] = st.session_state['phenocluster__umeta_columns'].index(st.session_state['phenocluster__umap_color_col']) + #st.write(st.session_state['phenocluster__umap_color_col_index']) + + # select column for umap coloring + st.session_state['phenocluster__umap_color_col'] = st.selectbox('Select column for groups coloring:', + st.session_state['phenocluster__umeta_columns'], + index=st.session_state['phenocluster__umap_color_col_index'] + ) + + # select column for umap subsetting + 
st.session_state['phenocluster__umap_cur_col'] = st.selectbox('Select column to subset plots:', + st.session_state['phenocluster__umeta_columns'], key='phenocluster__umap_col_dropdown_subset' + ) + + # list of available subsetting options + umap_cur_groups= ["All"] + list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state['phenocluster__umap_cur_col']])) + umap_sel_groups = st.multiselect('Select groups to be plotted', + options = umap_cur_groups) + st.session_state['phenocluster__umap_cur_groups'] = umap_sel_groups + + st.button('Make Spatial Plots' , on_click=spatial_plots_cust_2, args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__umap_cur_col'], + st.session_state['phenocluster__umap_cur_groups'], + st.session_state['phenocluster__umap_color_col'], + phenocluster__col2 + ] + ) + + st.button("Compute UMAP", on_click=phenocluster__scanpy_umap, args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__n_neighbors_state'], + st.session_state['phenocluster__metric'], + st.session_state['phenocluster__n_principal_components'] + ] + ) + if 'X_umap' in st.session_state['phenocluster__clustering_adata'].obsm.keys(): + st.button('Plot UMAPs' , on_click=phenocluster__plotly_umaps, + args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__umap_cur_col'], + st.session_state['phenocluster__umap_cur_groups'], + st.session_state['phenocluster__umap_color_col'], + phenocluster__col2 + ] + ) + + st.button('Add Clusters to Input Data' , on_click=phenocluster__add_clusters_to_input_df) + + + except: + st.warning("container issue") + + + + + +# Run the main function +if __name__ == '__main__': + main() + +# need to make differential expression on another page diff --git a/pages2/Pheno_Cluster_b.py b/pages2/Pheno_Cluster_b.py new file mode 100644 index 0000000..f04f0bb --- /dev/null +++ b/pages2/Pheno_Cluster_b.py @@ -0,0 +1,471 @@ +# Import relevant libraries +from ast import arg +from pyparsing import col +import streamlit as st +import streamlit_dataframe_editor as sde +import streamlit as st +import pandas as pd +import anndata as ad +import scanpy as sc +import seaborn as sns +import os +import matplotlib.pyplot as plt +import phenograph +#import parc +import numpy as np +import scanpy.external as sce +import plotly.express as px +import streamlit_dataframe_editor as sde +import basic_phenotyper_lib as bpl +import nidap_dashboard_lib as ndl +import plotnine +from plotnine import * +import math + + +# Functions + +# clusters differential expression +def phenocluster__diff_expr(adata, phenocluster__de_col, phenocluster__de_sel_groups, only_positive): + sc.tl.rank_genes_groups(adata, groupby = phenocluster__de_col, method="wilcoxon", layer="counts") + + if "All" in phenocluster__de_sel_groups: + phenocluster__de_results = sc.get.rank_genes_groups_df(adata, group=None) + else: + phenocluster__de_results = sc.get.rank_genes_groups_df(adata, group=phenocluster__de_sel_groups) + + phenocluster__de_results[['pvals', 'pvals_adj']] = phenocluster__de_results[['pvals', 'pvals_adj']].applymap('{:.1e}'.format) + #st.dataframe(phenocluster__de_results, use_container_width=True) + if only_positive: + phenocluster__de_results_filt = phenocluster__de_results[phenocluster__de_results['logfoldchanges'] > 0].reset_index(drop=True) + st.session_state['phenocluster__de_results'] = phenocluster__de_results_filt + else: + st.session_state['phenocluster__de_results'] = 
phenocluster__de_results + + #st.session_state['phenocluster__de_markers'] = pd.unique(st.session_state['phenocluster__de_results']["names"]) + + +# change cluster names +def phenocluster__edit_cluster_names(adata, edit_names_result): + adata.obs['Edit_Cluster'] = adata.obs['Cluster'].map(edit_names_result.set_index('Cluster')['New_Name']) + st.session_state['phenocluster__clustering_adata'] = adata + +def phenocluster__edit_cluster_names_2(adata, edit_names_result): + edit_names_result_2 = edit_names_result.reconstruct_edited_dataframe() + adata.obs['Edit_Cluster'] = adata.obs['Cluster'].map(dict(zip(edit_names_result_2['Cluster'].to_list(), edit_names_result_2['New_Name'].to_list()))) + st.session_state['phenocluster__clustering_adata'] = adata + +# make differential intensity plots +def phenocluster__plot_diff_intensity(adata, groups, method, n_genes, cur_col): + if "All" in groups: + cur_groups = None + else: + cur_groups = groups + + if method == "Rank Plot": + cur_fig = sc.pl.rank_genes_groups(adata, n_genes=n_genes, + groups=cur_groups, sharey=False, fontsize=20, show=False) + elif method == "Dot Plot": + cur_fig = sc.pl.rank_genes_groups_dotplot(adata, n_genes=n_genes, + groups=cur_groups) + elif method == "Heat Map": + # cur_fig = sc.pl.rank_genes_groups_heatmap(adata, n_genes=n_genes, + # groups=cur_groups) + #sc.pp.normalize_total(adata) + #sc.pp.log1p(adata) + #sc.pp.scale(adata) + + if "Edit_Cluster" in adata.obs.columns: + adata_sub = adata[adata.obs['Edit_Cluster'].isin(cur_groups)] + top_names = pd.unique(st.session_state['phenocluster__de_results'].groupby('group')['names'].apply(lambda x: x.head(n_genes))) + cluster_group = "Edit_Cluster" + cur_fig = sc.pl.heatmap(adata_sub, top_names, groupby=cluster_group, swap_axes=False) + else: + adata_sub = adata[adata.obs['Cluster'].isin(cur_groups)] + top_names = pd.unique(st.session_state['phenocluster__de_results'].groupby('group')['names'].apply(lambda x: x.head(n_genes))) + cluster_group = "Cluster" + cur_fig = sc.pl.heatmap(adata_sub, top_names, groupby=cluster_group, swap_axes=False) + elif method == "Violin Plot": + cur_fig = sc.pl.rank_genes_groups_stacked_violin(adata, n_genes=n_genes, + groups=cur_groups, split = False) + # elif method == "UMAP" and "X_umap" in adata.obsm.keys(): + # adata_sub = adata[adata.obs['Cluster'].isin(cur_groups)] + # top_names = pd.unique(st.session_state['phenocluster__de_results'].groupby('group')['names'].apply(lambda x: x.head(n_genes))) + # with cur_col: + # cur_fig = st.pyplot(sc.pl.umap(adata, color=[*top_names], legend_loc="on data",frameon=True, + # ncols=3, show=True, + # wspace = 0.2 ,save = False), use_container_width = True , clear_figure = True) + +def phenocluster__plot_diff_intensity_2(adata, groups, method, n_genes, plot_column): + if "All" in groups: + cur_groups = list(pd.unique(st.session_state['phenocluster__de_results']["group"])) + else: + cur_groups = groups + adata_norm = adata.copy() + sc.pp.normalize_total(adata_norm) + sc.pp.log1p(adata_norm) + sc.pp.scale(adata_norm) + if "Edit_Cluster" in adata.obs.columns: + cluster_group = "Edit_Cluster" + else: + cluster_group = "Cluster" + + adata_sub = adata_norm[adata_norm.obs[cluster_group].isin(cur_groups)] + top_names = pd.unique(st.session_state['phenocluster__de_results'].groupby('group')['names'].apply(lambda x: x.head(n_genes))) + + with plot_column: + if method == "Heat Map": + obs_df = adata_sub.to_df().fillna(0) + obs_df[cluster_group] = adata_sub.obs[cluster_group] + mean_per_group = 
obs_df.groupby(cluster_group).mean() + matrix_avg = mean_per_group.stack().reset_index() + matrix_avg.columns = [cluster_group, "Marker", "Intensity"] + plot_mat = matrix_avg[matrix_avg["Marker"].isin(top_names)].reset_index(drop=True) + plot_mat["Marker"] = pd.Categorical(plot_mat["Marker"], categories=top_names[::-1], ordered=True) + plotnine.options.figure_size = (10, 10) + plot = ( + ggplot(plot_mat, aes(cluster_group, "Marker")) + + geom_tile(mapping = aes(fill = "Intensity")) + + scale_fill_distiller(type = 'div', palette = 'RdYlBu') + + theme(axis_text_x=element_text(rotation=0, hjust=0.5, size=28)) + + theme(axis_text_y=element_text(rotation=0, hjust=1, size=16)) + + theme(axis_title_x = element_blank(), axis_title_y = element_text(angle=90)) + + theme(text=element_text(size=16)) + ) + + st.pyplot(ggplot.draw(plot), use_container_width=True) + + elif method == "UMAP": + obs_df = adata_norm[:, top_names].to_df().reset_index(drop=True) + umap_coords = adata_norm.obsm['X_umap'] + umap_df = pd.DataFrame(umap_coords, columns=['UMAP_1', 'UMAP_2']).reset_index(drop=True) + umap_df["Cells_Id"] = umap_df.index + obs_df = pd.concat([obs_df, umap_df], axis=1) + # Get the column names excluding 'UMAP_1' and 'UMAP_2' + columns = top_names + # Calculate the number of rows needed for the subplots + n_rows = int(np.ceil(len(columns) / 4)) + # Create a figure with subplots + fig, axs = plt.subplots(n_rows, 4, figsize=(20, 5*n_rows)) + # Flatten the axes array to make it easier to iterate over + axs = axs.flatten() + for ax, col in zip(axs, columns): + plot = sns.scatterplot(data=obs_df, x='UMAP_1', y='UMAP_2', hue=col, + palette='viridis', ax=ax, s=8) + ax.set_title(col) + plot.legend(loc='upper left', bbox_to_anchor=(1, 1)) + # Remove any unused subplots + for ax in axs[len(columns):]: + ax.remove() + plt.tight_layout() + st.pyplot(fig, use_container_width=True) + elif method == "Rank Plot": + #obs_df = st.session_state['phenocluster__de_results'][st.session_state['phenocluster__de_results']['group'].isin(cur_groups)].reset_index(drop=True) + top_n_df = st.session_state['phenocluster__de_results'].groupby('group').head(n_genes).reset_index(drop=True) + top_n_df["ranking"]= top_n_df.groupby('group')['names'].cumcount() + + # Get the unique groups + groups = top_n_df['group'].unique() + + # Calculate the number of rows needed for the subplots + rows = math.ceil(len(groups) / 2) + cols = 2 + + # Create a figure with subplots + fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows)) + + # Flatten the axes array for easier indexing + axes = axes.flatten() + + # For each group, create a scatter plot + for i, group in enumerate(groups): + df_group = top_n_df[top_n_df['group'] == group] + + # Create scatter plot + sns.scatterplot(x='ranking', y='scores', data=df_group, ax=axes[i], alpha=0) + + # Add text labels + for _, row in df_group.iterrows(): + axes[i].text(row['ranking'], row['scores'], row['names'], + ha='center', va='bottom', rotation='vertical', fontsize=12) # Increase fontsize here + + # Set the x-axis range and ticks + axes[i].set_xlim(df_group['ranking'].min()-2, df_group['ranking'].max()+2) + axes[i].set_xticks(range(int(df_group['ranking'].min()-2), int(df_group['ranking'].max()+2), 2)) + axes[i].set_ylim(df_group['scores'].min()-2, df_group['scores'].max()+10) + + # Set titles and labels + axes[i].set_title(group, fontsize=16) # Increase fontsize here + axes[i].set_xlabel('ranking', fontsize=14) # Increase fontsize here + axes[i].set_ylabel('scores', fontsize=14) # Increase fontsize here 
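+            # The scatter points above are drawn fully transparent (alpha=0), so only the
+            # vertically rotated marker-name text labels are visible; the padded x/y limits
+            # presumably just leave room for those labels.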
+ axes[i].tick_params(axis='x', labelsize=14) # Increase fontsize here + axes[i].tick_params(axis='y', labelsize=14) + + # Remove any unused subplots + for j in range(i+1, len(axes)): + fig.delaxes(axes[j]) + + # Adjust the layout and show the plot + plt.tight_layout() + # Show the plot + st.pyplot(fig, use_container_width=True) + +def data_editor_change_callback(): + ''' + data_editor_change_callback is a callback function for the streamlit data_editor widget + which updates the saved value of the user-created changes after every instance of the + data_editor on_change method. This ensures the dashboard can remake the edited data_editor + when the user navigates to a different page. + ''' + + st.session_state.df = bpl.assign_phenotype_custom(st.session_state.df, st.session_state['phenocluster__edit_names_result_2a'].reconstruct_edited_dataframe()) + + # Create Phenotypes Summary Table based on 'phenotype' column in df + st.session_state.pheno_summ = bpl.init_pheno_summ(st.session_state.df) + + # Perform filtering + st.session_state.df_filt = ndl.perform_filtering(st.session_state) + + # Set Figure Objects based on updated df + st.session_state = ndl.setFigureObjs(st.session_state, st.session_state.pointstSliderVal_Sel) + +def phenocluster__plotly_umaps_b(adata, umap_cur_col, umap_cur_groups, umap_color_col, plot_column): + with plot_column: + subcol1, subcol2 = st.columns(2) + for i, umap_cur_group in enumerate(umap_cur_groups): + if umap_cur_group == "All": + subDat = adata + else: + subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] + umap_coords = subDat.obsm['X_umap'] + df = pd.DataFrame(umap_coords, columns=['UMAP_1', 'UMAP_2']) + clustersList = list(subDat.obs[umap_color_col] ) + df[umap_color_col] = clustersList + df[umap_color_col] = df[umap_color_col].astype(str) + # Create the seaborn plot + fig = px.scatter(df, + x="UMAP_1", + y="UMAP_2", + color=umap_color_col, + title="UMAP " + umap_cur_group + #color_discrete_sequence=px.colors.sequential.Plasma + ) + fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots + fig.update_layout( + title=dict( + text="UMAP " + umap_cur_group, + x=0.5, # Center the title + xanchor='center', + yanchor='top' + ), + legend=dict( + orientation="h", + yanchor="top", + y=-0.2, + xanchor="right", + x=1 + ) + ) + if i % 2 == 0: + subcol1.plotly_chart(fig, use_container_width=True) + else: + subcol2.plotly_chart(fig, use_container_width=True) + +def spatial_plots_cust_2b(adata, umap_cur_col, umap_cur_groups, umap_color_col, plot_column): + with plot_column: + subcol3, subcol4 = st.columns(2) + for i, umap_cur_group in enumerate(umap_cur_groups): + if umap_cur_group == "All": + subDat = adata + else: + subDat = adata[adata.obs[umap_cur_col] == umap_cur_group] + umap_coords = subDat.obs[['Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)']] + df = pd.DataFrame(umap_coords).reset_index().drop('index', axis = 1) + clustersList = list(subDat.obs[umap_color_col] ) + df[umap_color_col] = clustersList + df[umap_color_col] = df[umap_color_col].astype(str) + fig = px.scatter(df, + x="Centroid X (µm)_(standardized)", + y="Centroid Y (µm)_(standardized)", + color=umap_color_col, + title="Spatial " + umap_cur_group + #color_discrete_sequence=px.colors.sequential.Plasma + ) + fig.update_traces(marker=dict(size=3)) # Adjust the size of the dots + fig.update_layout( + title=dict( + text="Spatial " + umap_cur_group, + x=0.5, # Center the title + xanchor='center', + yanchor='top' + ), + legend=dict( + orientation="h", + yanchor="top", + 
y=-0.2, + xanchor="right", + x=1 + ) + ) + if i % 2 == 0: + subcol3.plotly_chart(fig, use_container_width=True) + else: + subcol4.plotly_chart(fig, use_container_width=True) + +def phenocluster__add_edit_clusters_to_input_df(): + if "phenocluster__phenotype_cluster_cols" in st.session_state: + cur_df = st.session_state['input_dataset'].data + cur_df = cur_df.drop(columns=st.session_state["phenocluster__phenotype_cluster_cols"]) + st.session_state['input_dataset'].data = cur_df + + st.session_state['input_dataset'].data["Phenotype_Cluster"] = 'Phenotype ' + st.session_state['phenocluster__clustering_adata'].obs['Edit_Cluster'].astype(str) + dummies = pd.get_dummies(st.session_state['phenocluster__clustering_adata'].obs['Edit_Cluster'], prefix='Phenotype Cluster').astype(int) + #dummies = dummies.replace({1: '+', 0: '-'}) + cur_df = pd.concat([st.session_state['input_dataset'].data, dummies], axis=1) + st.session_state['input_dataset'].data = cur_df + new_cluster_cols = list(dummies.columns) + st.session_state["phenocluster__phenotype_cluster_cols"] = new_cluster_cols + +def main(): + if 'phenocluster__clustering_adata' not in st.session_state: + st.error("Please run the clustering step first", + icon="🚨") + return + + if "Cluster" not in st.session_state['phenocluster__clustering_adata'].obs.columns: + st.error("Please run the clustering step first", + icon="🚨") + return + + phenocluster__col1b, phenocluster__col2b = st.columns([2, 6]) + phenocluster__col3b, phenocluster__col4b = st.columns([2, 6]) + sc.set_figure_params(figsize=(10, 10), fontsize = 16) + if 'phenocluster__dif_int_plot_methods' not in st.session_state: + st.session_state['phenocluster__dif_int_plot_methods'] = ["Rank Plot", "Heat Map", "UMAP"] + + with phenocluster__col1b: + # differential expression + phenocluster__de_col_options = list(st.session_state['phenocluster__clustering_adata'].obs.columns) + st.selectbox('Select column for differential intensity:', phenocluster__de_col_options, key='phenocluster__de_col') + phenocluster__de_groups = ["All"] + list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state['phenocluster__de_col']])) + phenocluster__de_selected_groups = st.multiselect('Select group for differential intensity table:', options = phenocluster__de_groups) + st.session_state['phenocluster__de_sel_groups'] = phenocluster__de_selected_groups + st.checkbox('Only Positive Markers', key='phenocluster__de_only_positive') + # Differential expression + + st.button('Run Differential Intensity', on_click=phenocluster__diff_expr, args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__de_col'], + st.session_state['phenocluster__de_sel_groups'], + st.session_state['phenocluster__de_only_positive'] + ]) + if 'phenocluster__de_results' in st.session_state: + with phenocluster__col2b: + st.dataframe(st.session_state['phenocluster__de_results'], use_container_width=True) + + with phenocluster__col3b: + # Plot differential intensity + st.selectbox('Select Plot Type:', st.session_state['phenocluster__dif_int_plot_methods'], key='phenocluster__plot_diff_intensity_method') + st.number_input(label = "Number of markers to plot", + key = 'phenocluster__plot_diff_intensity_n_genes', + step = 1) + + # phenocluster__plot_diff_intensity(st.session_state['phenocluster__clustering_adata'], + # st.session_state['phenocluster__de_sel_groups'], + # st.session_state['phenocluster__plot_diff_intensity_method'], + # st.session_state['phenocluster__plot_diff_intensity_n_genes'], 
+ # phenocluster__col4b) + st.button('Plot Markers', on_click=phenocluster__plot_diff_intensity_2, args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__de_sel_groups'], + st.session_state['phenocluster__plot_diff_intensity_method'], + st.session_state['phenocluster__plot_diff_intensity_n_genes'], + phenocluster__col4b + ]) + + # make plots for differential intensity markers + # if 'phenocluster__diff_intensity_plot' in st.session_state: + # with phenocluster__col4b: + # cur_fig = st.session_state['phenocluster__diff_intensity_plot'] + #st.pyplot(fig = cur_fig, use_container_width = True, clear_figure = False) + + phenocluster__col5b, phenocluster__col6b = st.columns([2, 6]) + + cur_clusters = list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs["Cluster"])) + edit_names_df = pd.DataFrame({"Cluster": cur_clusters, "New_Name": cur_clusters}) + st.session_state['phenocluster__edit_names_df'] = edit_names_df + + with phenocluster__col6b: + #st.table(st.session_state['phenocluster__edit_names_df']) + #edit_clustering_names = st.data_editor(edit_names_df) + #st.session_state['phenocluster__edit_names_result'] = edit_clustering_names + if 'phenocluster__edit_names_result_2' not in st.session_state: + st.session_state['phenocluster__edit_names_result_2'] = sde.DataframeEditor(df_name='phenocluster__edit_names_result_2a', default_df_contents=st.session_state['phenocluster__edit_names_df']) + + #st.session_state['phenocluster__edit_names_result_2'].dataframe_editor(on_change=data_editor_change_callback, reset_data_editor_button_text='Reset New Clusters Names') + st.session_state['phenocluster__edit_names_result_2'].dataframe_editor(reset_data_editor_button_text='Reset New Clusters Names') + + with phenocluster__col5b: + #Edit cluster names + st.button('Edit Clusters Names', on_click=phenocluster__edit_cluster_names_2, args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__edit_names_result_2'] + ]) + phenocluster__col7b, phenocluster__col8b= st.columns([2, 6]) + def make_all_plots_2(): + spatial_plots_cust_2b(st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__umap_cur_col'], + st.session_state['phenocluster__umap_cur_groups'], + st.session_state['phenocluster__umap_color_col_2'], + phenocluster__col8b + ) + # make umaps plots + if 'X_umap' in st.session_state['phenocluster__clustering_adata'].obsm.keys(): + phenocluster__plotly_umaps_b(st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__umap_cur_col'], + st.session_state['phenocluster__umap_cur_groups'], + st.session_state['phenocluster__umap_color_col_2'], + phenocluster__col8b + ) + with phenocluster__col7b: + # select column for umap coloring + st.session_state['phenocluster__umeta_columns'] = list(st.session_state['phenocluster__clustering_adata'].obs.columns) + if 'Edit_Cluster' in st.session_state['phenocluster__umeta_columns']: + st.session_state['phenocluster__umap_color_col_index'] = st.session_state['phenocluster__umeta_columns'].index('Edit_Cluster') + else: + st.session_state['phenocluster__umap_color_col_index'] = st.session_state['phenocluster__umeta_columns'].index(st.session_state['phenocluster__umap_color_col']) + + st.session_state['phenocluster__umap_color_col_2'] = st.selectbox('Select column to color groups:', + st.session_state['phenocluster__umeta_columns'], + index=st.session_state['phenocluster__umap_color_col_index'] + ) + # select column for umap subsetting + 
st.session_state['phenocluster__umap_cur_col'] = st.selectbox('Select column to subset plots:', + st.session_state['phenocluster__umeta_columns'], key='phenocluster__umap_col_dropdown_subset' + ) + # list of available subsetting options + umap_cur_groups= ["All"] + list(pd.unique(st.session_state['phenocluster__clustering_adata'].obs[st.session_state['phenocluster__umap_cur_col']])) + umap_sel_groups = st.multiselect('Select groups to be plotted', + options = umap_cur_groups) + st.session_state['phenocluster__umap_cur_groups'] = umap_sel_groups + + st.button('Make Spatial Plots' , on_click=spatial_plots_cust_2b, + args = [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__umap_cur_col'], + st.session_state['phenocluster__umap_cur_groups'], + st.session_state['phenocluster__umap_color_col_2'], + phenocluster__col8b] + ) + + if 'X_umap' in st.session_state['phenocluster__clustering_adata'].obsm.keys(): + st.button("Make UMAPS", on_click=phenocluster__plotly_umaps_b, + args= [st.session_state['phenocluster__clustering_adata'], + st.session_state['phenocluster__umap_cur_col'], + st.session_state['phenocluster__umap_cur_groups'], + st.session_state['phenocluster__umap_color_col_2'], + phenocluster__col8b] + ) + + st.button('Add Edited Clusters to Input Data', on_click=phenocluster__add_edit_clusters_to_input_df) + + +# Run the main function +if __name__ == '__main__': + main() diff --git a/pages/03b_Run_workflow.py b/pages2/Run_workflow.py similarity index 95% rename from pages/03b_Run_workflow.py rename to pages2/Run_workflow.py index 7d43e6a..9705373 100644 --- a/pages/03b_Run_workflow.py +++ b/pages2/Run_workflow.py @@ -5,8 +5,7 @@ import time_cell_interaction_lib as tci # import the TIME library stored in time_cell_interaction_lib.py import time import streamlit_utils -import app_top_of_page as top -import streamlit_dataframe_editor as sde + def main(): ''' @@ -54,6 +53,8 @@ def main(): # Determine whether we should employ threading use_multiprocessing = st.checkbox('Should we use multiple logical CPUs to speed up the calculations?', key='use_multiprocessing') + # This isn't actually a good fix because it's only the Squidpy enrichment that shouldn't have multiprocessing, not the entire workflow, but we need to implement that in the future + # use_multiprocessing = st.checkbox('Should we use multiple logical CPUs to speed up the calculations?', key='use_multiprocessing', disabled=(st.session_state['settings__analysis__significance_calculation_method'] != 'Poisson (radius)')) # Get the number of threads to use for the calculations num_workers = st.number_input('Select number of threads for calculations:', min_value=1, max_value=os.cpu_count(), step=1, key='num_workers', disabled=(not use_multiprocessing)) @@ -213,18 +214,4 @@ def main(): # Call the main function if __name__ == '__main__': - - # Set a wide layout and display the page heading - st.set_page_config(layout="wide") - st.title('Run workflow') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/03a_Tool_parameter_selection.py b/pages2/Tool_parameter_selection.py similarity index 98% rename from 
pages/03a_Tool_parameter_selection.py rename to pages2/Tool_parameter_selection.py index b37d352..f37c173 100644 --- a/pages/03a_Tool_parameter_selection.py +++ b/pages2/Tool_parameter_selection.py @@ -6,10 +6,7 @@ import streamlit_utils import pprint import platform_io -import time_cell_interaction_lib as tci # import the TIME library stored in time_cell_interaction_lib.py import utils -import app_top_of_page as top -import streamlit_dataframe_editor as sde import dataset_formats import copy @@ -160,6 +157,9 @@ def update_dependencies_of_analysis_significance_calculation_method(): else: st.session_state['analysis_neighbor_radius_is_disabled'] = False st.session_state['analysis_n_neighs_is_disabled'] = True + # This isn't actually a good fix because it's only the Squidpy enrichment that shouldn't have multiprocessing, not the entire workflow, but we need to implement that in the future + # if st.session_state['settings__analysis__significance_calculation_method'] != 'Poisson (radius)': + # st.session_state['use_multiprocessing'] = False def set_session_state_key(settings, str1, str2): @@ -556,18 +556,4 @@ def main(): # Call the main function if __name__ == '__main__': - - # Set page settings - st.set_page_config(layout="wide", page_title='Tool parameter selection') - st.title('Tool parameter selection') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/04b_UMAP Analyzer.py b/pages2/UMAP_Analyzer.py similarity index 80% rename from pages/04b_UMAP Analyzer.py rename to pages2/UMAP_Analyzer.py index ef9c25e..d87f612 100644 --- a/pages/04b_UMAP Analyzer.py +++ b/pages2/UMAP_Analyzer.py @@ -3,11 +3,7 @@ ''' import streamlit as st from streamlit_extras.add_vertical_space import add_vertical_space - -# Import relevant libraries import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP -import app_top_of_page as top -import streamlit_dataframe_editor as sde def reset_phenotype_selection(): ''' @@ -21,6 +17,9 @@ def main(): Main function for running the page ''' + # Make a generic check to avoid neeeding to hunt down individual checks + rdy_to_plot = st.session_state.cluster_completed + # Toggles for different figures fig_toggle = st.columns([1, 1, 2]) with fig_toggle[0]: @@ -41,7 +40,7 @@ def main(): elif st.session_state.lineageDisplayToggle == 'Markers': st.session_state.umaplineages = st.session_state.umapMarks - if st.session_state.umap_completed: + if rdy_to_plot: st.session_state = ndl.setFigureObjs_UMAPDifferences(st.session_state) else: st.warning('No spatial UMAP analysis detected. 
Please complete Neighborhood Profiles') @@ -71,12 +70,12 @@ def main(): # FULL UMAP with umap_viz[0]: - if st.session_state.umap_completed: + if rdy_to_plot: st.pyplot(st.session_state.UMAPFig) # Inspection UMAP with umap_viz[1]: - if st.session_state.umap_completed: + if rdy_to_plot: st.pyplot(st.session_state.UMAPFigInsp) # Difference Measures @@ -100,12 +99,12 @@ def main(): diff_umap_col = st.columns(3) with diff_umap_col[0]: st.header('UMAP A') - if st.session_state.umap_completed: + if rdy_to_plot: st.pyplot(st.session_state.UMAPFigDiff0_Dens) st.pyplot(st.session_state.UMAPFigDiff0_Clus) with diff_umap_col[1]: st.header('UMAP B') - if st.session_state.umap_completed: + if rdy_to_plot: st.pyplot(st.session_state.UMAPFigDiff1_Dens) st.pyplot(st.session_state.UMAPFigDiff1_Clus) with diff_umap_col[2]: @@ -113,23 +112,8 @@ def main(): st.write('###') st.write('###') st.header('UMAP A - UMAP B') - if st.session_state.umap_completed: + if rdy_to_plot: st.pyplot(st.session_state.UMAPFigDiff2_Dens) if __name__ == '__main__': - - #Set a wide layout - st.set_page_config(page_title="UMAP Differences Analyzer", - layout="wide") - st.title('UMAP Differences Analyzer') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/__init__.py b/pages2/__init__.py similarity index 100% rename from pages/__init__.py rename to pages2/__init__.py diff --git a/pages2/adaptive_phenotyping.py b/pages2/adaptive_phenotyping.py new file mode 100644 index 0000000..1e34692 --- /dev/null +++ b/pages2/adaptive_phenotyping.py @@ -0,0 +1,462 @@ +# Import relevant libraries +import streamlit as st +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from pages2 import multiaxial_gating +import utils + +# Global variable +st_key_prefix = 'adaptive_phenotyping__' + + +def plotly_mean_and_sem(dfs, df_names): + + # Create a Plotly figure + fig = go.Figure() + + # For each dataframe and its name... + for df, df_name in zip(dfs, df_names): + + # Calculate mean and SEM over the groups + mean_values = df.mean(axis='columns') + sem_values = df.sem(axis='columns') + + # Add a scatter plot + fig.add_trace(go.Scatter( + x=mean_values.index, + y=mean_values, + error_y=dict( + type='data', # or 'percent' for percentage-based error bars + array=sem_values, + visible=True + ), + mode='markers+lines', # Use 'markers' or 'lines' or 'markers+lines' + name=df_name + )) + + # Customize layout + fig.update_layout( + title=f'Average Positive Percentage Mean with SEM as Error Bars', + xaxis_title='Z score', + yaxis_title='Average positive percentage mean', + showlegend=True + ) + + # Return the plotly figure + return fig + + +def get_box_and_whisker_data(df_grouped, df_thresholds, apply_thresh_to_selected_group, df, channel_for_phenotyping, column_identifying_baseline_signal, value_identifying_baseline, value_identifying_signal, row_selection, return_figure_and_summary=True): + + # If apply_thresh_to_selected_group (and not average_over_all_groups), the input into the function generate_box_and_whisker() corresponds to the selected group. 
The mean and std used are the defaults (None) for the function, i.e., those corresponding to the baseline group of the selected group. + # If not apply_thresh_to_selected_group (and not average_over_all_groups), the input into the function generate_box_and_whisker() corresponds to the entire dataset. The mean and std (which, when corresponding to a selection from df_thresholds, correspond to the baseline group) used are those from the selected group. --> this is the original phenotyping method (like a single T=0 threshold)! + # The thresholds (as in df_thresholds) are calculated from just baseline groups, whereas the groups of dataframes correspond to both the baseline and signal groups. + # If average_over_all_groups, then it's as if we're selecting the groups in turn via a for loop. We don't want to transform the entire dataset every time; we only want to transform each part of the full dataset once. Thus, it generally doesn't make sense to average_over_all_groups and (not apply_thresh_to_selected_group), but it does make sense to average_over_all_groups and apply_thresh_to_selected_group, which essentially thresholds each group using its own baseline group, i.e., corresponds to the new phenotyping method. + # Thus, here are the settings for the old and new phenotyping methods: + # Old: apply_thresh_to_selected_group=False, average_over_all_groups=False, DO have a particular group selected --> "Selected threshold applied to entire dataset" + # New: apply_thresh_to_selected_group=True, average_over_all_groups=True, DO NOT have a particular group selected --> "Group-specific threshold applied to each group" + # Note that there are no other combinations of these settings that will transform each part of the full dataset once, so the above is complete! + + # Obtain the index and dataframe of the group identified by row_selection + if isinstance(df_grouped, list): + current_index = None + df_selected = df_grouped[0][1] + else: + current_index = df_thresholds.iloc[row_selection].name + df_selected = df_grouped.get_group(current_index) + + # Obtain the dataframe to actually transform as well as the mean and std to use for the transformation + if apply_thresh_to_selected_group: + df_transform = df_selected + mean_for_zscore_calc = None + std_for_zscore_calc = None + else: + df_transform = df + ser_selected = df_thresholds.iloc[row_selection] + mean_for_zscore_calc = ser_selected.loc['z score = 0'] + std_for_zscore_calc = ser_selected.loc['z score = 1'] - mean_for_zscore_calc + + # Obtain the data that one would get by generating a box and whisker plot + return_values = multiaxial_gating.generate_box_and_whisker( + df=df_transform, + column_for_filtering=channel_for_phenotyping, + apply_another_filter=False, + another_filter_column=None, + values_on_which_to_filter=None, + images_in_plotting_group_1=df_transform.loc[df_transform[column_identifying_baseline_signal] == value_identifying_baseline, 'Slide ID'].unique(), + images_in_plotting_group_2=df_transform.loc[df_transform[column_identifying_baseline_signal] == value_identifying_signal, 'Slide ID'].unique(), + all_cells=False, + mean_for_zscore_calc=mean_for_zscore_calc, + std_for_zscore_calc=std_for_zscore_calc, + return_figure_and_summary=return_figure_and_summary + ) + + # Return the desired values + if return_figure_and_summary: + return return_values + else: + return return_values, current_index + + +def main(): + """ + Main function for the page. 
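+
+    The page is laid out in three columns: threshold calculation, percent-positives
+    plotting, and phenotype generation.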
+ """ + + # Ensure a dataset has been opened in the first place + if 'input_dataset' not in st.session_state: + st.warning('Please open a dataset from the Open File page at left.') + return + + # Get some necessary variables from the session state + df = st.session_state['input_dataset'].data + + # Store columns of certain types + if st_key_prefix + 'categorical_columns' not in st.session_state: + st.session_state[st_key_prefix + 'categorical_columns'] = utils.get_categorical_columns_including_numeric(df, max_num_unique_values=1000) + categorical_columns = st.session_state[st_key_prefix + 'categorical_columns'] + + # Initialize three columns + columns = st.columns(3) + + # In the first column... + with columns[0]: + + st.subheader(':one: Threshold calculation') + + # Select columns to use for grouping the threshold calculations + key = st_key_prefix + 'columns_for_phenotype_grouping' + if key not in st.session_state: + st.session_state[key] = [] + columns_for_phenotype_grouping = st.multiselect('Columns for phenotype grouping:', categorical_columns, key=key) + + # Optionally force-update the list of categorical columns + st.button('Update phenotype grouping columns 💡', help='If you don\'t see the column you want to group, click this button to update the list of potential phenotype grouping columns.', on_click=lambda: st.session_state.pop(st_key_prefix + 'categorical_columns', None)) + + # Set the column name that describes the baseline field such as cell type + key = st_key_prefix + 'column_identifying_baseline_signal' + if key not in st.session_state: + st.session_state[key] = categorical_columns[0] + column_identifying_baseline_signal = st.selectbox('Column identifying baseline/signal:', categorical_columns, key=key, on_change=lambda: st.session_state.pop(st_key_prefix + 'value_identifying_baseline', None)) + + # Extract the baseline field value (such as a specific cell type) to be used for determining the thresholds + available_baseline_signal_values = df[column_identifying_baseline_signal].unique() + key = st_key_prefix + 'value_identifying_baseline' + if key not in st.session_state: + st.session_state[key] = available_baseline_signal_values[0] + value_identifying_baseline = st.selectbox('Value identifying baseline:', available_baseline_signal_values, key=key) + + # Extract the available channels for performing phenotyping + key = st_key_prefix + 'channel_for_phenotyping' + if key not in st.session_state: + st.session_state[key] = df.columns[0] + channel_for_phenotyping = st.selectbox('Channel for phenotyping:', df.columns, key=key) + + # If adaptive thresholding is desired... + if st.button('Calculate thresholds for phenotyping'): + + # Group the relevant subset of the dataframe by the selected variables + if len(columns_for_phenotype_grouping) > 0: + df_grouped = df[columns_for_phenotype_grouping + [column_identifying_baseline_signal, channel_for_phenotyping, 'Slide ID']].groupby(by=columns_for_phenotype_grouping) + else: + df_grouped = [(None, df)] + + # Define various z scores of interest to use for determining the thresholds + z_scores = np.arange(-1, 11) + + # For every group of variables for which you want a different threshold... 
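+            # As an illustrative sketch with hypothetical numbers (not from any dataset): if a
+            # group's baseline cells have mean intensity 5.0 and standard deviation 2.0, the loop
+            # below computes threshold(z) = 5.0 + z * 2.0, i.e. 3.0 at z = -1, 5.0 at z = 0,
+            # 7.0 at z = 1, and so on up to z = 10.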
+ thresholds_outer = [] + for _, df_group in df_grouped: + + # Get the selected intensity data for just the baseline field value + ser_baseline = df_group.loc[df_group[column_identifying_baseline_signal] == value_identifying_baseline, channel_for_phenotyping] + + # From that, calculate the mean and std + mean_to_use = ser_baseline.mean() + std_to_use = ser_baseline.std() + + # Determine the corresponding threshold for each z score + thresholds_inner = [] + for z_score in z_scores: + thresholds_inner.append(mean_to_use + z_score * std_to_use) + + # Add to the main thresholds holder + thresholds_outer.append(thresholds_inner) + + # Determine the multi-index for the dataframe + if len(columns_for_phenotype_grouping) == 0: + index = pd.Index([-1]) + elif len(columns_for_phenotype_grouping) == 1: + index = pd.Index(list(df_grouped.groups.keys()), name=columns_for_phenotype_grouping[0]) + elif len(columns_for_phenotype_grouping) > 1: + index = pd.MultiIndex.from_tuples(list(df_grouped.groups.keys()), names=columns_for_phenotype_grouping) + index.name = 'Grouping' + + # Determine the columns index for the dataframe + columns_index = pd.Index([f'z score = {z_score}' for z_score in z_scores]) + columns_index.name = 'Thresholds' + + # Set the dataframe of thresholds + df_thresholds = pd.DataFrame(thresholds_outer, columns=columns_index, index=index).sort_index() + st.session_state[st_key_prefix + 'df_thresholds'] = df_thresholds + + # Save the grouped data as well for plotting and phenotyping + st.session_state[st_key_prefix + 'df_grouped'] = df_grouped + + # Make sure the phenotyping thresholds have been calculated + key = st_key_prefix + 'df_thresholds' + if key not in st.session_state: + st.warning('Please calculate thresholds for phenotyping') + return + + # Set a shortcut to the relevant data + df_thresholds = st.session_state[key] + df_grouped = st.session_state[st_key_prefix + 'df_grouped'] + + # Display the thresholds, allowing the user to select a single row + st.write('Calculated thresholds:') + group_selection = st.dataframe(df_thresholds, on_select='rerun', selection_mode='single-row') + + # Create a new figure + fig = go.Figure() + + # Loop through each column in df_thresholds to add a trace for each one + for column in df_thresholds.columns: + fig.add_trace(go.Scatter( + # x=df_thresholds.index, # Use the DataFrame index for the x-axis + x=np.array(range(len(df_thresholds))), + y=df_thresholds[column], # Column values for the y-axis + mode='markers+lines', # Line plot + name=column # Use the column name as the trace name + )) + + # Update the layout to add titles and adjust other aesthetics if needed + fig.update_layout( + title='Line Plot of Thresholds Dataframe Above', + xaxis_title='Index', + yaxis_title='Phenotyping channel', + legend_title='Column' + ) + + # Plot the line plots + st.plotly_chart(fig) + + # In the second column... + with columns[1]: + + st.subheader(':two: Percent positives plotting') + + # Extract the value identifying the signal (such as a specific cell type) to be used for testing the thresholds + key = st_key_prefix + 'value_identifying_signal' + if key not in st.session_state: + st.session_state[key] = available_baseline_signal_values[0] + value_identifying_signal = st.selectbox('Value identifying signal:', available_baseline_signal_values, key=key) + + # Whether to apply the threshold to just the selected group + # In the future, this should be made a radio button probably, individual group vs. 
entire dataset + key = st_key_prefix + 'apply_thresh_to_selected_group' + if key not in st.session_state: + st.session_state[key] = True + apply_thresh_to_selected_group = st.checkbox('Apply threshold to each individual group (instead of to the entire dataset)', key=key) + + # Whether to average over all groups + key = st_key_prefix + 'average_over_all_groups' + if key not in st.session_state: + st.session_state[key] = False + average_over_all_groups = st.checkbox('Average over all groups', key=key) + + # If we want to generate a plot based on the currently selected group... + if not average_over_all_groups: + + # Obtain the current selection + row_selection_list = group_selection['selection']['rows'] + + # If something is actually selected... + if row_selection_list: + + # Get the box and whisker data + fig, _ = get_box_and_whisker_data(df_grouped, df_thresholds, apply_thresh_to_selected_group, df, channel_for_phenotyping, column_identifying_baseline_signal, value_identifying_baseline, value_identifying_signal, row_selection_list[0], return_figure_and_summary=True) + + # Plot the box and whisker chart + st.plotly_chart(fig) + + # If we want to generate a plot from averaging over all the groups... + else: + + # Since this takes a non-trivial amount of time, hide the calculation behind a button + if st.button('Calculate average positive percentages over all groups'): + + # Initialize the holders of the group indices and the average positive percentages + index_holder = [] + ser_holder_baseline = [] + ser_holder_signal = [] + + # For every group... + for curr_row in range(len(df_thresholds)): + + # Get the box and whisker data + df_summary, curr_index = get_box_and_whisker_data(df_grouped, df_thresholds, apply_thresh_to_selected_group, df, channel_for_phenotyping, column_identifying_baseline_signal, value_identifying_baseline, value_identifying_signal, curr_row, return_figure_and_summary=False) + + # Append the desired data to the holders + df_summary = df_summary.drop('Threshold', axis='columns').set_index('Z score') # the drop isn't necessary but it may make the operations marginally faster + index_holder.append(curr_index) + ser_holder_baseline.append(df_summary['Positive % (avg. over images) for baseline group']) + ser_holder_signal.append(df_summary['Positive % (avg. 
over images) for signal group']) + + # Calculate and save the average positive percentages for the baseline group + df_baseline = pd.concat(ser_holder_baseline, axis='columns', keys=index_holder) + df_baseline.columns.names = columns_for_phenotype_grouping if len(columns_for_phenotype_grouping) > 0 else ['No group'] + st.session_state[st_key_prefix + 'df_baseline'] = df_baseline + + # Calculate and save the average positive percentages for the signal group + df_signal = pd.concat(ser_holder_signal, axis='columns', keys=index_holder) + df_signal.columns.names = columns_for_phenotype_grouping if len(columns_for_phenotype_grouping) > 0 else ['No group'] + st.session_state[st_key_prefix + 'df_signal'] = df_signal + + # Ensure the average positive percentages have been calculated + if st_key_prefix + 'df_baseline' not in st.session_state: + st.warning('Please calculate average positive percentages over all groups') + return + + # Get the difference in the average positive percentages between the baseline and signal groups + df_baseline = st.session_state[st_key_prefix + 'df_baseline'] + df_signal = st.session_state[st_key_prefix + 'df_signal'] + df_diff = df_signal - df_baseline + + # Render some charts in Streamlit + st.plotly_chart(plotly_mean_and_sem([df_baseline, df_signal], ['Baseline', 'Signal'])) + st.plotly_chart(plotly_mean_and_sem([df_diff], ['signal - baseline'])) + + # In the third column... + with columns[2]: + + st.subheader(':three: Phenotype generation') + + # Set the phenotype name + key = st_key_prefix + 'phenotype_name' + if key not in st.session_state: + st.session_state[key] = '' + phenotype_name = st.text_input('Phenotype name:', key=key) + + # Set the desired Z score + key = st_key_prefix + 'desired_z_score' + if key not in st.session_state: + st.session_state[key] = 2.0 + desired_z_score = st.number_input('Desired Z score:', key=key) + + # Set the phenotyping method + key = st_key_prefix + 'phenotyping_method' + phenotyping_method_options = ["Selected threshold applied to entire dataset", "Group-specific threshold applied to each group"] + if key not in st.session_state: + st.session_state[key] = phenotyping_method_options[1] + phenotyping_method = st.selectbox('Phenotyping method:', phenotyping_method_options, key=key) + + # Obtain the current selection + row_selection_list = group_selection['selection']['rows'] + + # Determine whether the phenotyping button should be disabled + phenotyping_button_disabled = False + if (phenotyping_method == phenotyping_method_options[0]) and (not row_selection_list): + st.warning(f'Phenotyping method "{phenotyping_method}" requires a group to be selected in the first column') + phenotyping_button_disabled = True + if (phenotyping_method == phenotyping_method_options[1]) and (row_selection_list): + st.warning(f'Phenotyping method "{phenotyping_method}" does not actually require that no group be selected in the first column, but for clarity (since that method performs calculations for every group), we are enacting that requirement; please de-select the group selection in the first column') + phenotyping_button_disabled = True + + # Render the button to perform the phenotyping + if st.button('Perform phenotyping', disabled=phenotyping_button_disabled): + + # Perform the phenotyping method that applies the selected threshold to the entire dataset + # Here: apply_thresh_to_selected_group=False, average_over_all_groups=False, DO have a particular group selected --> "Selected threshold applied to entire dataset" + if phenotyping_method == 
phenotyping_method_options[0]:
+
+            # Get the selected row of df_thresholds, which is a series
+            ser_selected = df_thresholds.iloc[row_selection_list[0]]
+
+            # Calculate the threshold from the mean/std from the selected group
+            mean_for_zscore_calc = ser_selected.loc['z score = 0']
+            std_for_zscore_calc = ser_selected.loc['z score = 1'] - mean_for_zscore_calc
+            threshold = mean_for_zscore_calc + desired_z_score * std_for_zscore_calc
+
+            # If the threshold for our desired Z score has already been calculated, enforce agreement
+            if desired_z_score in np.arange(-1, 11, 1):
+                assert np.abs(threshold - ser_selected.loc[f'z score = {int(desired_z_score)}']) < 1e-8, 'The threshold calculated for the selected group does not match the threshold in the thresholds DataFrame'
+
+            # Get the locations where the selected filtering column is at least the current threshold value and assign the integer version of this to a new series
+            positive_loc = df[channel_for_phenotyping] >= threshold
+            ser_phenotype = positive_loc.astype(int)
+
+            # Output the actually used threshold
+            st.write(f'Threshold used for entire dataset: {threshold}')
+
+        # Perform the phenotyping method that applies a group-specific threshold to each group
+        # Here: apply_thresh_to_selected_group=True, average_over_all_groups=True, DO NOT have a particular group selected --> "Group-specific threshold applied to each group"
+        elif phenotyping_method == phenotyping_method_options[1]:
+
+            # Initialize the phenotype column to all-negative
+            ser_phenotype = pd.Series(-1, index=df.index)
+
+            # For every group...
+            thresholds = []
+            for curr_row in range(len(df_thresholds)):
+
+                # Obtain the index and dataframe of the group identified by curr_row
+                if isinstance(df_grouped, list):
+                    curr_df = df_grouped[0][1]
+                    curr_integer_indices_into_df = np.array(range(len(curr_df)))
+                else:
+                    curr_index = df_thresholds.iloc[curr_row].name
+                    curr_df = df_grouped.get_group(curr_index)
+                    curr_integer_indices_into_df = df_grouped.indices[curr_index]
+
+                # Get the locations of the images in the baseline group
+                images_in_plotting_group_1 = curr_df.loc[curr_df[column_identifying_baseline_signal] == value_identifying_baseline, 'Slide ID'].unique()
+                image_loc_group_1 = curr_df['Slide ID'].isin(images_in_plotting_group_1)
+
+                # Get only the data for the column of interest for the baseline images
+                ser_for_z_score = curr_df.loc[image_loc_group_1, channel_for_phenotyping]
+
+                # Use the mean and std from those data to calculate the desired threshold
+                mean_for_zscore_calc = ser_for_z_score.mean()
+                std_for_zscore_calc = ser_for_z_score.std()
+                threshold = mean_for_zscore_calc + desired_z_score * std_for_zscore_calc
+
+                # If the threshold for our desired Z score has already been calculated, enforce agreement
+                if desired_z_score in np.arange(-1, 11, 1):
+                    assert np.abs(threshold - df_thresholds.iloc[curr_row].loc[f'z score = {int(desired_z_score)}']) < 1e-8, 'The threshold calculated for the current group does not match the threshold in the thresholds DataFrame'
+
+                # Get the locations where the selected filtering column is at least the current threshold value
+                positive_loc = curr_df[channel_for_phenotyping] >= threshold  # boolean series fitting the current group
+
+                # Update the phenotype assignments for the current group
+                ser_phenotype.iloc[curr_integer_indices_into_df] = positive_loc.astype(int)  # the LHS is the slice of ser_phenotype corresponding to the current group
+
+                # Store the threshold for the current group
+                thresholds.append(threshold)
+
+            # Check that there are no values of -1 remaining in ser_phenotype
+            assert (ser_phenotype != -1).all(), 'There are still cells that have not been assigned positivity, which shouldn\'t have happened'
+
+            # Output the actually used thresholds
+            thresholds = pd.Series(thresholds, index=df_thresholds.index)
+            thresholds.name = 'Threshold'
+            st.write('Thresholds used for each group in turn:')
+            st.write(thresholds)
+
+        # Add the phenotype column to the dataframe
+        pheno_colname = f'Phenotype {phenotype_name}'
+        df[pheno_colname] = utils.downcast_series_dtype(ser_phenotype)
+        st.success(f'Phenotype column "{pheno_colname}" has been appended to (or modified in) the dataset')
+        st.write(f'Number of cells in each phenotype group (0 = negative, 1 = positive):')
+        st.write(df[pheno_colname].value_counts())
+
+        # Ensure the main dataframe is updated per the operations above
+        st.session_state['input_dataset'].data = df
+
+
+# Run the main function
+if __name__ == '__main__':
+
+    # Call the main function
+    main()
diff --git a/pages/child_process_killer.py b/pages2/child_process_killer.py
similarity index 100%
rename from pages/child_process_killer.py
rename to pages2/child_process_killer.py
diff --git a/pages/01_data_import_and_export.py b/pages2/data_import_and_export.py
similarity index 78%
rename from pages/01_data_import_and_export.py
rename to pages2/data_import_and_export.py
index e61ab61..74beb96 100644
--- a/pages/01_data_import_and_export.py
+++ b/pages2/data_import_and_export.py
@@ -2,29 +2,15 @@
 This is the python script which produces the PHENOTYPING PAGE
 '''
 import streamlit as st
-
-# Import relevant libraries
 import streamlit_utils
-import app_top_of_page as top
-import streamlit_dataframe_editor as sde
+from streamlit_extras.add_vertical_space import add_vertical_space
+
 def main():
     '''
     Main function for running the page
     '''
-    # Use the whole page width
-    st.set_page_config(page_title="Data Import and Export",
-                       layout="wide")
-
-    # Run streamlit-dataframe-editor library initialization tasks at the top of the page
-    st.session_state = sde.initialize_session_state(st.session_state)
-
-    # Run Top of Page (TOP) functions
-    st.session_state = top.top_of_page_reqs(st.session_state)
-
-    st.title('Data Import and Export')
-
     # Store a copy (not a link) of the platform object for clarity below
     platform = st.session_state['platform']
@@ -45,6 +31,7 @@
     # In the second column...
with cols[1]: + add_vertical_space(5) platform.load_selected_inputs() # st.divider() platform.save_selected_input() @@ -92,8 +79,6 @@ def main(): # and it may have been modified, save it back to Streamlit st.session_state['platform'] = platform - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) if __name__ == '__main__': main() diff --git a/pages/datafile_format_unifier.py b/pages2/datafile_format_unifier.py similarity index 88% rename from pages/datafile_format_unifier.py rename to pages2/datafile_format_unifier.py index 0722da1..71bd5b7 100644 --- a/pages/datafile_format_unifier.py +++ b/pages2/datafile_format_unifier.py @@ -2,12 +2,22 @@ import os import streamlit as st import pandas as pd -import app_top_of_page as top import streamlit_dataframe_editor as sde import re import utils +def callback_for_combining_datafiles(filenames): + + # Clear all keys in the session state starting with "unifier__" and not applicable to the selections above the callback button + keys_to_delete = [key for key in st.session_state.keys() if (key.startswith("unifier__")) and (key not in ['unifier__input_files', 'unifier__de_datafile_selection', 'unifier__df_datafile_selection', 'unifier__df_datafile_selection_changes_dict', 'unifier__df_datafile_selection_key'])] + for key in keys_to_delete: + if key in st.session_state: + del st.session_state[key] + + generate_guess_for_basename_of_mawa_unified_file(filenames) + + def generate_guess_for_basename_of_mawa_unified_file(filenames): # generate_guess_for_basename_of_mawa_unified_file(df_reconstructed.loc[selected_rows, 'Filename']) @@ -52,19 +62,9 @@ def main(): Main function for the Datafile Unifier page. """ - # Set page settings - st.set_page_config(layout='wide', page_title='Datafile Unifier') - st.title('Datafile Unifier') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - # Constants directory = os.path.join('.', 'input') - extensions = ('.csv', '.tsv') + valid_extensions = ('.csv', '.tsv', '.txt') # Initialization show_dataframe_updates = False @@ -91,20 +91,20 @@ def main(): st.header(':one: Select datafile(s)') # Retrieve list of files with the given extensions in the requested directory - files = list_files(directory, extensions) + files = list_files(directory, valid_extensions) # If no files are found, write a message to the user if len(files) == 0: - st.warning('No ".csv" or ".tsv" files found in the `input` directory.') + st.warning(f'No files with extensions {valid_extensions} found in the `input` directory.') # If files are found, display them in a dataframe editor else: # Write messages to the user num_files = len(files) if num_files == 1: - st.write('Detected 1 ".csv" or ".tsv" file in the `input` directory.') + st.write(f'Detected 1 file with any of the extensions {valid_extensions} in the `input` directory.') else: - st.write(f'Detected {num_files} ".csv" and ".tsv" files in the `input` directory.') + st.write(f'Detected {num_files} files with any of the extensions {valid_extensions} in the `input` directory.') st.write('Select 1 or more files to load into the MAWA Datafile Unifier.') st.write('Note: Double-click any cell to see the full filename.') @@ -153,42 +153,54 @@ def main(): else: load_msg = 'Loading and Combining Files...' 
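+            # (The on_click callback below first clears stale "unifier__" session-state keys and
+            # then regenerates the guessed basename for the unified output file before the
+            # selected files are combined.)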
# Create a button to concatenate the selected files - if st.button(button_text, help=button_help_message, disabled=load_button_disabled, on_click=generate_guess_for_basename_of_mawa_unified_file, args=(df_reconstructed.loc[selected_rows, 'Filename'],)): + if st.button(button_text, help=button_help_message, disabled=load_button_disabled, on_click=callback_for_combining_datafiles, args=(df_reconstructed.loc[selected_rows, 'Filename'],)): # Render a progress spinner while the files are being combined with st.spinner(load_msg): # Efficiently check if the columns are equal for all input files - columns_equal = True if len(input_files) > 1: - sep = (',' if input_files[0].split('.')[-1] == 'csv' else '\t') - first_file_columns = pd.read_csv(os.path.join(directory, input_files[0]), nrows=0, sep=sep).columns - for input_file in input_files[1:]: + columns_holder = [] + for input_file in input_files: sep = (',' if input_file.split('.')[-1] == 'csv' else '\t') current_file_columns = pd.read_csv(os.path.join(directory, input_file), nrows=0, sep=sep).columns - if not first_file_columns.equals(current_file_columns): - st.error('Columns are not equal for files: {} and {}'.format(input_files[0], input_file)) - columns_equal = False - break + columns_holder.append(current_file_columns) + + # Check if the columns are equal for all input files + columns_equal = all([columns_holder[0].equals(columns) for columns in columns_holder]) - # If the columns are equal for all input files, concatenate all files into a single dataframe - if columns_equal: + else: + columns_equal = True + + # If the columns are not equal for all input files, display a warning and obtain the common columns + if not columns_equal: + st.warning('The selected input files have different columns. We will take the intersection of the columns for all files.') + common_columns = list(set.intersection(*[set(columns) for columns in columns_holder])) + unique_columns = list(set.union(*[set(columns) for columns in columns_holder]) - set(common_columns)) + st.write('Columns excluded from the file combination:', unique_columns) + else: sep = (',' if input_files[0].split('.')[-1] == 'csv' else '\t') - # st.session_state['unifier__df'] = utils.downcast_dataframe_dtypes(pd.concat([pd.read_csv(os.path.join(directory, input_file), sep=sep) for input_file in input_files], ignore_index=True)) - df_holder = [] - for input_file in input_files: - curr_df = pd.read_csv(os.path.join(directory, input_file), sep=sep) - assert 'input_filename' not in curr_df.columns, 'ERROR: "input_filename" is one of the columns but we want to overwrite it' + common_columns = pd.read_csv(os.path.join(directory, input_files[0]), nrows=0, sep=sep).columns + + if len(common_columns) == 0: + st.warning('No common columns found. 
Please select files with common columns.') + return + + # Concatenate all files into a single dataframe using the common set of columns + sep = (',' if input_files[0].split('.')[-1] == 'csv' else '\t') + df_holder = [] + for input_file in input_files: + curr_df = pd.read_csv(os.path.join(directory, input_file), sep=sep, usecols=common_columns) + if 'input_filename' not in curr_df.columns: curr_df['input_filename'] = input_file - df_holder.append(curr_df) - st.session_state['unifier__df'] = utils.downcast_dataframe_dtypes(pd.concat(df_holder, ignore_index=True)) + df_holder.append(curr_df) + st.session_state['unifier__df'] = utils.downcast_dataframe_dtypes(pd.concat(df_holder, ignore_index=True)) - # Save the setting used for this operation - st.session_state['unifier__input_files_actual'] = input_files + # Save the setting used for this operation + st.session_state['unifier__input_files_actual'] = input_files # Display a success message - if columns_equal: - st.success(f'{len(input_files)} files combined') + st.success(f'{len(input_files)} files combined') # Set a flag to update the dataframe sample at the bottom of the page show_dataframe_updates = True @@ -206,10 +218,28 @@ def main(): # In the first column... with main_columns[0]: - # ---- 2. (Optional) Drop null rows from the dataset -------------------------------------------------------------------------------------------------------------------------------- + # ---- 2. Drop null rows from the dataset -------------------------------------------------------------------------------------------------------------------------------- # Display a header for the null row deletion section - st.markdown('## :two: Delete null rows (optional) ') + st.markdown('## :two: Delete null rows ') + + st.write('**Instructions:** You must press the button below to make sure there are no null data in columns you will want to use downstream! If you see null data in the columns you want to use, you must delete the rows with null data in those columns by expanding the "Click to expand:" dropdown below and following the directions therein. Once you\'ve done this, feel free to press this button again to make sure you\'ve deleted all null rows in the columns you care about.') + + # Allow user to detect null rows + if 'unifier__null_detection_button_has_been_pressed' not in st.session_state: + st.session_state['unifier__null_detection_button_has_been_pressed'] = False + if st.button('Detect null rows in each column'): + ser_num_of_null_rows_in_each_column = df.isnull().sum() + if ser_num_of_null_rows_in_each_column.sum() == 0: + st.success('No null rows detected in the dataset.') + else: + st.write('Null values have been detected. Here are the numbers of null rows found in the columns containing them. Note they may not matter depending on the column. 
See instructions above:') + ser_num_of_null_rows_in_each_column.name = 'Number of null rows' + st.write(ser_num_of_null_rows_in_each_column[ser_num_of_null_rows_in_each_column != 0]) + st.session_state['unifier__null_detection_button_has_been_pressed'] = True + + if not st.session_state['unifier__null_detection_button_has_been_pressed']: + st.warning('You must press the "Detect null rows in each column" button above (and delete any relevant null data; see instructions above the button) before proceeding to the next steps!') # Create an expander for the null row deletion section with st.expander('Click to expand:', expanded=False): @@ -365,6 +395,9 @@ def main(): # Set a flag to update the dataframe sample at the bottom of the page show_dataframe_updates = True + if 'unifier__num_roi_columns_actual' not in st.session_state: + st.warning('You must press the "Assign ROIs" button even if ROIs are not explicitly defined in the dataset.') + # If the selected columns to define ROIs have changed since the last time ROIs were defined, display a warning if st.session_state['unifier__roi_explicitly_defined']: if ('unifier__roi_explicitly_defined_actual' in st.session_state) and ((st.session_state['unifier__roi_explicitly_defined_actual'] != st.session_state['unifier__roi_explicitly_defined']) or (st.session_state['unifier__roi_column_actual'] != st.session_state['unifier__roi_column'])): @@ -672,14 +705,12 @@ def main(): st.header('Sample of unified dataframe') resample_dataframe = st.button('Refresh dataframe sample') if ('sampled_df' not in st.session_state) or resample_dataframe or show_dataframe_updates: - sampled_df = df.sample(100).sort_index() + sampled_df = utils.sample_df_without_replacement_by_number(df=df, n=100).sort_index() st.session_state['sampled_df'] = sampled_df sampled_df = st.session_state['sampled_df'] st.write(sampled_df) st.dataframe(pd.DataFrame(st.session_state['unifier__input_files_actual'], columns=["Input files included in the combined dataset"]), hide_index=True) - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) # Run the main function if __name__ == '__main__': diff --git a/pages/dummy_editor.py b/pages2/dummy_editor.py similarity index 100% rename from pages/dummy_editor.py rename to pages2/dummy_editor.py diff --git a/pages2/forking_test.py b/pages2/forking_test.py new file mode 100644 index 0000000..1adc3a7 --- /dev/null +++ b/pages2/forking_test.py @@ -0,0 +1,35 @@ +import streamlit as st +import utils +import time + + +def sleep_task(tuple_of_args): + sleep_time_sec = tuple_of_args[0] + print(f'Single task running, waiting for {sleep_time_sec} seconds') + time.sleep(sleep_time_sec) + + +def main(): + + if 'nworkers' not in st.session_state: + st.session_state['nworkers'] = 1 + st.number_input('Number of workers', min_value=1, max_value=10, key='nworkers') + + if 'num_tasks' not in st.session_state: + st.session_state['num_tasks'] = 10 + st.number_input('Number of tasks', min_value=1, max_value=100, key='num_tasks') + + start_time = time.time() + if st.button('Run forking test'): + + st.write('Forking test button pressed') + + utils.execute_data_parallelism_potentially(sleep_task, [(1,)] * st.session_state['num_tasks'], nworkers=st.session_state['nworkers'], task_description='Sleeping tasks') + + st.write('Done') + + st.write(f'Total time: {time.time() - start_time:.2f} seconds') + + +if __name__ == '__main__': + main() diff --git a/pages2/macro_radial_density.py 
b/pages2/macro_radial_density.py new file mode 100644 index 0000000..858decf --- /dev/null +++ b/pages2/macro_radial_density.py @@ -0,0 +1,374 @@ +# Import relevant libraries +import streamlit as st +import app_top_of_page as top +import streamlit_dataframe_editor as sde +import plotly.graph_objects as go +import plotly.express as px +import pandas as pd +from itertools import cycle, islice +import numpy as np + + +def update_color_for_value(value_to_change_color): + st.session_state['mrd__color_dict'][value_to_change_color] = st.session_state['mrd__new_picked_color'] + + +def reset_color_dict(ser_to_plot): + # Create a color sequence based on the frequency of the values to plot in the entire dataset + values_to_plot = ser_to_plot.value_counts().index + colors = list(islice(cycle(px.colors.qualitative.Plotly), len(values_to_plot))) + st.session_state['mrd__values_to_plot'] = values_to_plot + st.session_state['mrd__color_dict'] = dict(zip(values_to_plot, colors)) # map values to colors + + +def go_to_previous_image(unique_images): + """ + Go to the previous image in the numpy array. + + Parameters: + unique_images (numpy.ndarray): The unique images. + + Returns: + None + """ + + # Get the current index in the unique images + current_index = list(unique_images).index(st.session_state['mrd__image_to_view']) + + # If we're not already at the first image, go to the previous image + if current_index > 0: + current_index -= 1 + st.session_state['mrd__image_to_view'] = unique_images[current_index] + + +def go_to_next_image(unique_images): + """ + Go to the next image in the numpy array. + + Parameters: + unique_images (numpy.ndarray): The unique images. + + Returns: + None + """ + + # Get the current index in the unique images + current_index = list(unique_images).index(st.session_state['mrd__image_to_view']) + + # If we're not already at the last image, go to the next image + if current_index < len(unique_images) - 1: + current_index += 1 + st.session_state['mrd__image_to_view'] = unique_images[current_index] + + +def main(): + """ + Main function for the page. + """ + + # Define the main settings columns + settings_columns_main = st.columns(3) + + # In the first column... 
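The go_to_previous_image / go_to_next_image callbacks above drive the Previous/Next navigation by mutating the selectbox's session-state key inside an `on_click` callback, which runs before Streamlit reruns the script, so the selectbox shows the new value on the next run. A minimal, self-contained sketch of the same pattern; the key and image names here are illustrative, not the app's own:

```python
import streamlit as st

IMAGES = ["slide_01", "slide_02", "slide_03"]  # placeholder image list
KEY = "demo__image_to_view"                    # placeholder session-state key

def step_image(images, delta):
    # Move the selected image by `delta`, clamped to the list bounds.
    idx = list(images).index(st.session_state[KEY]) + delta
    if 0 <= idx < len(images):
        st.session_state[KEY] = images[idx]

if KEY not in st.session_state:
    st.session_state[KEY] = IMAGES[0]

st.selectbox("Select image to view:", IMAGES, key=KEY)

cols = st.columns(2)
cols[0].button("Previous image", on_click=step_image, args=(IMAGES, -1),
               disabled=(st.session_state[KEY] == IMAGES[0]), use_container_width=True)
cols[1].button("Next image", on_click=step_image, args=(IMAGES, 1),
               disabled=(st.session_state[KEY] == IMAGES[-1]), use_container_width=True)
```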
+ with settings_columns_main[0]: + + # Allow user to select the dataframe containing the data to plot + if 'mrd__data_to_plot' not in st.session_state: + st.session_state['mrd__data_to_plot'] = 'Input data' + data_to_plot = st.selectbox('Dataset containing plotting data:', ['Input data', 'Phenotyped data'], key='mrd__data_to_plot') + input_dataset_has_changed = ('mrd__data_to_plot_prev' not in st.session_state) or (st.session_state['mrd__data_to_plot_prev'] != data_to_plot) + st.session_state['mrd__data_to_plot_prev'] = data_to_plot + + # If they want to plot phenotyped data, ensure they've performed phenotyping + if (data_to_plot == 'Phenotyped data') and (len(st.session_state['df']) == 1): + st.warning('If you\'d like to plot the phenotyped data, please perform phenotyping first.') + return + + # Set the shortcut to the dataframe of interest + if data_to_plot == 'Input data': + df = st.session_state['input_dataset'].data + else: + df = st.session_state['df'] + + # Store columns of certain types + if ('mrd__categorical_columns' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__categorical_columns'] = df.select_dtypes(include=('category', 'object')).columns + if ('mrd__numeric_columns' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__numeric_columns'] = df.select_dtypes(include='number').columns + if ('mrd__all_columns' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__all_columns'] = df.columns + categorical_columns = st.session_state['mrd__categorical_columns'] + numeric_columns = st.session_state['mrd__numeric_columns'] + all_columns = st.session_state['mrd__all_columns'] + + # Choose a column to plot + if ('mrd__column_to_plot' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__column_to_plot'] = categorical_columns[0] + column_to_plot = st.selectbox('Select a column by which to color the points:', categorical_columns, key='mrd__column_to_plot') + column_to_plot_has_changed = ('mrd__column_to_plot_prev' not in st.session_state) or (st.session_state['mrd__column_to_plot_prev'] != column_to_plot) or input_dataset_has_changed + st.session_state['mrd__column_to_plot_prev'] = column_to_plot + + # Get some information about the images in the input dataset + if input_dataset_has_changed: + st.session_state['mrd__unique_images'] = df['Slide ID'].unique() # get the unique images in the dataset + st.session_state['mrd__ser_size_of_each_image'] = df['Slide ID'].value_counts() # calculate the number of objects in each image + unique_images = st.session_state['mrd__unique_images'] + ser_size_of_each_image = st.session_state['mrd__ser_size_of_each_image'] + + # Create an image selection selectbox + if 'mrd__image_to_view' not in st.session_state: + st.session_state['mrd__image_to_view'] = unique_images[0] + image_to_view = st.selectbox('Select image to view:', unique_images, key='mrd__image_to_view') + + # Display the number of cells in the selected image + st.write(f'Number of cells in image: {ser_size_of_each_image.loc[image_to_view]}') + + # Optionally navigate through the images using Previous and Next buttons + cols = st.columns(2) + with cols[0]: + st.button('Previous image', on_click=go_to_previous_image, args=(unique_images,), disabled=(image_to_view == unique_images[0]), use_container_width=True) + with cols[1]: + st.button('Next image', on_click=go_to_next_image, args=(unique_images, ), disabled=(image_to_view == unique_images[-1]), use_container_width=True) + + # In the 
second column... + with settings_columns_main[1]: + + # Optionally plot minimum and maximum coordinate fields + if 'mrd__use_coordinate_mins_and_maxs' not in st.session_state: + st.session_state['mrd__use_coordinate_mins_and_maxs'] = False + use_coordinate_mins_and_maxs = st.checkbox('Use coordinate mins and maxs', key='mrd__use_coordinate_mins_and_maxs') + settings_columns_refined = st.columns(2) + if 'mrd__x_min_coordinate_column' not in st.session_state: + st.session_state['mrd__x_min_coordinate_column'] = numeric_columns[0] + if 'mrd__y_min_coordinate_column' not in st.session_state: + st.session_state['mrd__y_min_coordinate_column'] = numeric_columns[0] + if 'mrd__x_max_coordinate_column' not in st.session_state: + st.session_state['mrd__x_max_coordinate_column'] = numeric_columns[0] + if 'mrd__y_max_coordinate_column' not in st.session_state: + st.session_state['mrd__y_max_coordinate_column'] = numeric_columns[0] + with settings_columns_refined[0]: + xmin_col = st.selectbox('Select a column for the minimum x-coordinate:', numeric_columns, key='mrd__x_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + with settings_columns_refined[1]: + xmax_col = st.selectbox('Select a column for the maximum x-coordinate:', numeric_columns, key='mrd__x_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + with settings_columns_refined[0]: + ymin_col = st.selectbox('Select a column for the minimum y-coordinate:', numeric_columns, key='mrd__y_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + with settings_columns_refined[1]: + ymax_col = st.selectbox('Select a column for the maximum y-coordinate:', numeric_columns, key='mrd__y_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + units = ('coordinate units' if use_coordinate_mins_and_maxs else 'microns') + + # Optionally add another filter + if ('mrd__add_another_filter' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__add_another_filter'] = False + if ('mrd__column_to_filter_by' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__column_to_filter_by'] = categorical_columns[0] + if ('mrd__values_to_filter_by' not in st.session_state) or input_dataset_has_changed: + st.session_state['mrd__values_to_filter_by'] = [] + st.checkbox('Add filter', key='mrd__add_another_filter') + st.selectbox('Select a column to filter by:', categorical_columns, key='mrd__column_to_filter_by', disabled=(not st.session_state['mrd__add_another_filter'])) + st.multiselect('Select values to filter by:', df[st.session_state['mrd__column_to_filter_by']].unique(), key='mrd__values_to_filter_by', disabled=(not st.session_state['mrd__add_another_filter'])) + add_another_filter = st.session_state['mrd__add_another_filter'] + column_to_filter_by = st.session_state['mrd__column_to_filter_by'] + values_to_filter_by = st.session_state['mrd__values_to_filter_by'] + + # In the third column... 
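The "Add filter" widgets above feed the optional row filter that the scatter-plot code applies further down with `isin`. A small standalone pandas sketch of that logic, using made-up column names and values:

```python
import pandas as pd

# Toy data standing in for the cell table.
df = pd.DataFrame({
    "Slide ID": ["img1", "img1", "img2", "img2"],
    "Phenotype": ["T cell", "B cell", "T cell", "Tumor"],
})

add_another_filter = True
column_to_filter_by = "Phenotype"
values_to_filter_by = ["T cell", "Tumor"]

# When the extra filter is enabled, keep only rows whose value is selected;
# otherwise keep every row (an all-True mask).
if add_another_filter:
    filter_loc = df[column_to_filter_by].isin(values_to_filter_by)
else:
    filter_loc = pd.Series(True, index=df.index)

# Combine with the image selection, as the plotting code does.
df_selected = df[(df["Slide ID"] == "img2") & filter_loc]
print(df_selected)
```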
+ with settings_columns_main[2]: + + # Add an option to invert the y-axis + if 'mrd__invert_y_axis' not in st.session_state: + st.session_state['mrd__invert_y_axis'] = False + invert_y_axis = st.checkbox('Invert y-axis', key='mrd__invert_y_axis') + + # Choose the opacity of objects + if 'mrd__opacity' not in st.session_state: + st.session_state['mrd__opacity'] = 0.7 + opacity = st.number_input('Opacity:', min_value=0.0, max_value=1.0, step=0.1, key='mrd__opacity') + + # Define the colors for the values to plot + if ('mrd__color_dict' not in st.session_state) or column_to_plot_has_changed: + reset_color_dict(df[column_to_plot]) + values_to_plot = st.session_state['mrd__values_to_plot'] + color_dict = st.session_state['mrd__color_dict'] + + # Select a value whose color we want to modify + if ('mrd__value_to_change_color' not in st.session_state) or column_to_plot_has_changed: + st.session_state['mrd__value_to_change_color'] = values_to_plot[0] + value_to_change_color = st.selectbox('Value whose color to change:', values_to_plot, key='mrd__value_to_change_color') + + # Create a color picker widget for the selected value + st.session_state['mrd__new_picked_color'] = color_dict[value_to_change_color] + st.color_picker('Pick a new color:', key='mrd__new_picked_color', on_change=update_color_for_value, args=(value_to_change_color,)) + + # Add a button to reset the colors to their default values + st.button('Reset plotting colors to defaults', on_click=reset_color_dict, args=(df[column_to_plot],)) + color_dict = st.session_state['mrd__color_dict'] + + # Draw a divider + st.divider() + + # Get columns in the dataframe specifying the coordinates of the circle centers + if 'mrd__reec_center_x' not in st.session_state: + st.session_state['mrd__reec_center_x'] = all_columns[0] + st.selectbox('Select a column identifying the center x-coordinate of the chamber:', all_columns, key='mrd__reec_center_x') + if 'mrd__reec_center_y' not in st.session_state: + st.session_state['mrd__reec_center_y'] = all_columns[1] + st.selectbox('Select a column identifying the center y-coordinate of the chamber:', all_columns, key='mrd__reec_center_y') + + # Allow the user to set a conversion factor for these coordinates + if 'mrd__conversion_factor' not in st.session_state: + st.session_state['mrd__conversion_factor'] = 1.0 + conversion_factor = st.number_input('Conversion factor for chamber center coordinates:', min_value=0.0, step=0.1, key='mrd__conversion_factor') + + # Get the coordinate column names + if not use_coordinate_mins_and_maxs: + xcolname = 'Cell X Position' + ycolname = 'Cell Y Position' + else: + xcolname = 'mrd__centroid_x' + ycolname = 'mrd__centroid_y' + df[xcolname] = (df[xmin_col] + df[xmax_col]) / 2 + df[ycolname] = (df[ymin_col] + df[ymax_col]) / 2 + + # Draw a divider + st.divider() + + # If the user wants to display the scatter plot, indicated by a toggle... 
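In the plotting code below, points are drawn either as ordinary scatter markers ("works but doesn't scale the shapes") or as `go.Bar` traces ("works really well"). The bar trick works because bars live in data coordinates: giving each cell a bar with `base=ymin`, height `ymax - ymin`, and width `xmax - xmin` renders a rectangle that scales with zoom, unlike fixed-pixel markers. A minimal sketch of that trick with made-up coordinates:

```python
import pandas as pd
import plotly.graph_objects as go

# Two fake cells with bounding-box coordinates.
cells = pd.DataFrame({
    "XMin": [0, 30], "XMax": [20, 60],
    "YMin": [0, 40], "YMax": [25, 80],
})

fig = go.Figure(go.Bar(
    x=(cells["XMin"] + cells["XMax"]) / 2,   # bar centers
    y=cells["YMax"] - cells["YMin"],         # bar heights
    width=cells["XMax"] - cells["XMin"],     # bar widths, one per cell
    base=cells["YMin"],                      # bar bases
    marker=dict(color="rgba(0, 0, 255, 0.5)"),
))
fig.update_layout(xaxis=dict(scaleanchor="y", scaleratio=1))  # keep square aspect ratio
fig.show()
```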
+ if 'mrd__show_scatter_plot' not in st.session_state: + st.session_state['mrd__show_scatter_plot'] = False + if st.toggle('Show scatter plot', key='mrd__show_scatter_plot'): + + # Optionally set up another filter + if add_another_filter: + filter_loc = df[column_to_filter_by].isin(values_to_filter_by) + else: + filter_loc = pd.Series(True, index=df.index) + + # Filter the DataFrame to include only the selected image and filter + df_selected_image_and_filter = df[(df['Slide ID'] == image_to_view) & filter_loc] + + # Group the DataFrame for the selected image by unique value of the column to plot + selected_image_grouped_by_value = df_selected_image_and_filter.groupby(column_to_plot) + + # Create the scatter plot + fig = go.Figure() + + # Loop over the unique values in the column whose values to plot, in order of their frequency + for value_to_plot in values_to_plot: + + # If the value exists in the selected image... + if (value_to_plot in selected_image_grouped_by_value.groups) and (len(selected_image_grouped_by_value.groups[value_to_plot]) > 0): + + # Store the dataframe for the current value for the selected image + df_group = selected_image_grouped_by_value.get_group(value_to_plot) + + # If value is a string, replace '(plus)' with '+' and '(dash)' with '-', since it could likely be a phenotype with those substitutions + if isinstance(value_to_plot, str): + value_str_cleaned = value_to_plot.replace('(plus)', '+').replace('(dash)', '-') + else: + value_str_cleaned = value_to_plot + + # Add the object index to the label + ser_hover_label = 'Index: ' + df_group.index.astype(str) + + # Works but doesn't scale the shapes + if not use_coordinate_mins_and_maxs: + fig.add_trace(go.Scatter(x=df_group['Cell X Position'], y=df_group['Cell Y Position'], mode='markers', name=value_str_cleaned, marker_color=color_dict[value_to_plot], hovertemplate=ser_hover_label)) + + # Works really well + else: + fig.add_trace(go.Bar( + x=((df_group[xmin_col] + df_group[xmax_col]) / 2), + y=df_group[ymax_col] - df_group[ymin_col], + width=df_group[xmax_col] - df_group[xmin_col], + base=df_group[ymin_col], + name=value_str_cleaned, + marker=dict( + color=color_dict[value_to_plot], + opacity=opacity, + ), + hovertemplate=ser_hover_label + )) + + # Update the layout + fig.update_layout( + xaxis=dict( + scaleanchor="y", + scaleratio=1, + ), + yaxis=dict( + autorange=('reversed' if invert_y_axis else True), + ), + title=f'Scatter plot for {image_to_view}', + xaxis_title=f'Cell X Position ({units})', + yaxis_title=f'Cell Y Position ({units})', + legend_title=column_to_plot, + height=800, # Set the height of the figure + width=800, # Set the width of the figure + ) + + # Get the center coordinates of the chamber + reec_center_x = df_selected_image_and_filter[st.session_state['mrd__reec_center_x']] + reec_center_y = df_selected_image_and_filter[st.session_state['mrd__reec_center_y']] + assert (reec_center_x.nunique() == 1) and (reec_center_y.nunique() == 1), f'There should be only one unique value for the center coordinates of the chamber (unique x values: {reec_center_x.unique()}, unique y values: {reec_center_y.unique()})' + reec_center_coords = np.array([reec_center_x.iloc[0], reec_center_y.iloc[0]]) * conversion_factor + + # Get the maximum possible radius of a circle in the chamber + extreme_radii = [ + df_selected_image_and_filter[xcolname].max() - reec_center_coords[0], + df_selected_image_and_filter[ycolname].max() - reec_center_coords[1], + reec_center_coords[0] - df_selected_image_and_filter[xcolname].min(), + 
reec_center_coords[1] - df_selected_image_and_filter[ycolname].min() + ] + st.write(extreme_radii) + maximum_radius = min(extreme_radii) + + tol = 1e-8 + step = 250 + max_value = maximum_radius + + if max_value % step < tol: + radii = np.arange(step, max_value + step, step) + else: + radii = np.arange(step, max_value, step) + + for radius in radii: + fig.add_shape( + type="circle", + xref="x", + yref="y", + x0=reec_center_coords[0] - radius, + y0=reec_center_coords[1] - radius, + x1=reec_center_coords[0] + radius, + y1=reec_center_coords[1] + radius, + line_color="green", + line_width=3, # Increase line_width for thicker circles + opacity=1, + ) + # Plot the plotly chart in Streamlit + st.plotly_chart(fig, use_container_width=True) + + if 'mrd__get_percent_frequencies' not in st.session_state: + st.session_state['mrd__get_percent_frequencies'] = False + if st.toggle('Get percent frequencies of coloring column for entire dataset', key='mrd__get_percent_frequencies'): + vc = df[column_to_plot].value_counts() + st.dataframe((df[column_to_plot].value_counts() / vc.sum() * 100).astype(int).reset_index()) + + +# Run the main function +if __name__ == '__main__': + + # Set page settings + page_name = 'Macro Radial Density' + st.set_page_config(layout='wide', page_title=page_name) + st.title(page_name) + + # Run streamlit-dataframe-editor library initialization tasks at the top of the page + st.session_state = sde.initialize_session_state(st.session_state) + + # Run Top of Page (TOP) functions + st.session_state = top.top_of_page_reqs(st.session_state) + + # Call the main function + main() + + # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page + st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/memory_analyzer.py b/pages2/memory_analyzer.py similarity index 97% rename from pages/memory_analyzer.py rename to pages2/memory_analyzer.py index 15e1e9a..4056fe5 100644 --- a/pages/memory_analyzer.py +++ b/pages2/memory_analyzer.py @@ -1,9 +1,6 @@ # Import relevant libraries import streamlit as st -import app_top_of_page as top -import streamlit_dataframe_editor as sde import pandas as pd -# from pympler.asizeof import asizeof as deep_mem_usage_in_bytes from objsize import get_deep_size as deep_mem_usage_in_bytes import numpy as np import pickle @@ -486,20 +483,4 @@ def main(): # Run the main function if __name__ == '__main__': - - # Set page settings - page_name = 'Memory Analyzer' - st.set_page_config(layout='wide', page_title=page_name) - st.title(page_name) - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Call the main function main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/multiaxial_gating.py b/pages2/multiaxial_gating.py similarity index 85% rename from pages/multiaxial_gating.py rename to pages2/multiaxial_gating.py index 3cf22fd..81af68d 100644 --- a/pages/multiaxial_gating.py +++ b/pages2/multiaxial_gating.py @@ -7,11 +7,62 @@ import numpy as np import utils from scipy.stats import gaussian_kde -import app_top_of_page as top import streamlit_dataframe_editor as sde +import random +import string +import image_filter -def generate_box_and_whisker(apply_another_filter, df, 
column_for_filtering, another_filter_column, values_on_which_to_filter, images_in_plotting_group_1, images_in_plotting_group_2, all_cells=True): +def generate_random_string(length=10): + # Define the characters that will be used + characters = string.ascii_letters + string.digits + # Generate a random string of the specified length + random_string = ''.join(random.choice(characters) for i in range(length)) + return random_string + + +def reset_x_axis_range(use_groups_for_plotting, kde_or_hist_to_plot_full): + if not use_groups_for_plotting: + st.session_state['mg__histogram_x_range'] = [kde_or_hist_to_plot_full['Value'].min(), kde_or_hist_to_plot_full['Value'].max()] + else: + st.session_state['mg__histogram_x_range'] = [min(kde_or_hist_to_plot_full[0]['Value'].min(), kde_or_hist_to_plot_full[1]['Value'].min()), max(kde_or_hist_to_plot_full[0]['Value'].max(), kde_or_hist_to_plot_full[1]['Value'].max())] + st.session_state['mg__random_string'] = generate_random_string() + + +def plotly_chart_histogram_callback(): + plotly_chart_key = ('mg__plotly_chart_histogram_' + st.session_state['mg__random_string'] + '__do_not_persist') + if plotly_chart_key in st.session_state: + if st.session_state[plotly_chart_key]['selection']['box']: + x_range = sorted(st.session_state[plotly_chart_key]['selection']['box'][0]['x']) + if st.session_state['mg__histogram_box_selection_function'] == 'Positivity identification': + st.session_state['mg__selected_value_range'] = tuple(x_range) + st.session_state['mg__min_selection_value'] = x_range[0] + else: + st.session_state['mg__histogram_x_range'] = x_range + st.session_state['mg__random_string'] = generate_random_string() + + +def plotly_chart_summary_callback(): + if 'mg__plotly_chart_summary__do_not_persist' in st.session_state: + if st.session_state['mg__plotly_chart_summary__do_not_persist']['selection']['points']: + selected_z_score = st.session_state['mg__plotly_chart_summary__do_not_persist']['selection']['points'][0]['x'] + df_summary_contents = st.session_state['mg__df_summary_contents'].set_index('Z score') + df_selected_threshold = df_summary_contents.loc[selected_z_score, 'Threshold'] + selected_value_range = st.session_state['mg__selected_value_range'] + st.session_state['mg__selected_value_range'] = (df_selected_threshold, selected_value_range[1]) + st.session_state['mg__min_selection_value'] = df_selected_threshold + + +def df_summary_callback(): + if 'mg__df_summary__do_not_persist' in st.session_state: + if st.session_state['mg__df_summary__do_not_persist']['selection']['rows']: + df_selected_threshold = st.session_state['mg__df_summary_contents'].iloc[st.session_state['mg__df_summary__do_not_persist']['selection']['rows'][0]]['Threshold'] + selected_value_range = st.session_state['mg__selected_value_range'] + st.session_state['mg__selected_value_range'] = (df_selected_threshold, selected_value_range[1]) + st.session_state['mg__min_selection_value'] = df_selected_threshold + + +def generate_box_and_whisker(apply_another_filter, df, column_for_filtering, another_filter_column, values_on_which_to_filter, images_in_plotting_group_1, images_in_plotting_group_2, all_cells=True, mean_for_zscore_calc=None, std_for_zscore_calc=None, return_figure_and_summary=True): # If we're ready to apply a filter, then create it if not apply_another_filter: @@ -28,12 +79,17 @@ def generate_box_and_whisker(apply_another_filter, df, column_for_filtering, ano # From those data, get the values corresponding to -1 to 10 standard deviations above the mean z_scores = 
np.arange(-1, 11) - thresholds = ser_for_z_score.mean() + z_scores * ser_for_z_score.std() + if mean_for_zscore_calc is None: + mean_for_zscore_calc = ser_for_z_score.mean() + if std_for_zscore_calc is None: + std_for_zscore_calc = ser_for_z_score.std() + thresholds = mean_for_zscore_calc + z_scores * std_for_zscore_calc # Initialize the positive percentages holders group_1_holder = [] group_2_holder = [] - data_for_box_plot_holder = [] + if return_figure_and_summary: + data_for_box_plot_holder = [] # For each threshold... for threshold, z_score in zip(thresholds, z_scores): @@ -68,8 +124,9 @@ def generate_box_and_whisker(apply_another_filter, df, column_for_filtering, ano df_group_2_pos_perc['Z score'] = z_score df_group_1_pos_perc['Group'] = 'Baseline' df_group_2_pos_perc['Group'] = 'Signal' - data_for_box_plot_holder.append(df_group_1_pos_perc) - data_for_box_plot_holder.append(df_group_2_pos_perc) + if return_figure_and_summary: + data_for_box_plot_holder.append(df_group_1_pos_perc) + data_for_box_plot_holder.append(df_group_2_pos_perc) # Name each series the value of the current threshold ser_group_1_pos_perc.name = threshold @@ -81,15 +138,17 @@ def generate_box_and_whisker(apply_another_filter, df, column_for_filtering, ano # If we want the positive percentage of all the cells in each group... if all_cells: + + if return_figure_and_summary: - # Create a plotly figure - fig = go.Figure() + # Create a plotly figure + fig = go.Figure() - # Plot positive percentage in whole dataset vs. threshold for group 1 - fig.add_trace(go.Scatter(x=z_scores, y=group_1_holder, mode='lines+markers', name='Baseline')) + # Plot positive percentage in whole dataset vs. threshold for group 1 + fig.add_trace(go.Scatter(x=z_scores, y=group_1_holder, mode='lines+markers', name='Baseline')) - # Plot positive percentage in whole dataset vs. threshold for group 2 - fig.add_trace(go.Scatter(x=z_scores, y=group_2_holder, mode='lines+markers', name='Signal')) + # Plot positive percentage in whole dataset vs. threshold for group 2 + fig.add_trace(go.Scatter(x=z_scores, y=group_2_holder, mode='lines+markers', name='Signal')) # Create the summary dataframe df_summary = pd.DataFrame({'Z score': z_scores, 'Threshold': thresholds, 'Positive percentage (all cells) for baseline group': group_1_holder, 'Positive percentage (all cells) for signal group': group_2_holder}) @@ -98,7 +157,8 @@ def generate_box_and_whisker(apply_another_filter, df, column_for_filtering, ano else: # Create the dataframe holding all the data for the desired box plot - df_box_plot = pd.concat(data_for_box_plot_holder, axis='rows') + if return_figure_and_summary: + df_box_plot = pd.concat(data_for_box_plot_holder, axis='rows') # Create a dataframe from the positive percentages for each image in each group df_group_1_pos_perc = pd.concat(group_1_holder, axis='columns') @@ -108,25 +168,25 @@ def generate_box_and_whisker(apply_another_filter, df, column_for_filtering, ano avg_group_1 = df_group_1_pos_perc.mean() avg_group_2 = df_group_2_pos_perc.mean() - # Plot the positive percentage in each image vs. 
threshold for both groups - # fig = go.Figure() - # fig.add_trace(go.Scatter(x=avg_group_1.index, y=avg_group_1.values, mode='lines+markers', name='Baseline')) - # fig.add_trace(go.Scatter(x=avg_group_2.index, y=avg_group_2.values, mode='lines+markers', name='Signal')) - # Create the desired box plot - fig = px.box(df_box_plot, x='Z score', y='Positive %', color='Group', points='all') + if return_figure_and_summary: + fig = px.box(df_box_plot, x='Z score', y='Positive %', color='Group', points='all') # Create the summary dataframe df_summary = pd.DataFrame({'Z score': z_scores, 'Threshold': thresholds, 'Positive % (avg. over images) for baseline group': avg_group_1.values, 'Positive % (avg. over images) for signal group': avg_group_2.values}) # Update the layout of the plot - fig.update_layout(title='Positive percentage vs. baseline Z score', - xaxis_title='Baseline Z score', - yaxis_title='Positive percentage', - legend_title='Group') + if return_figure_and_summary: + fig.update_layout(title='Positive percentage vs. baseline Z score', + xaxis_title='Baseline Z score', + yaxis_title='Positive percentage', + legend_title='Group') # Return the plot, the Z scores, and the thresholds - return fig, df_summary + if return_figure_and_summary: + return fig, df_summary + else: + return df_summary def reset_values_on_which_to_filter_another_column(): @@ -202,6 +262,8 @@ def update_dependencies_of_filtering_widgets(): st.session_state['mg__curr_column_range'] = (curr_series.min(), curr_series.max()) st.session_state['mg__selected_value_range'] = st.session_state['mg__curr_column_range'] # initialize the selected range to the entire range st.session_state['mg__min_selection_value'] = st.session_state['mg__curr_column_range'][0] # initialize the minimum selection value to the minimum of the range + st.session_state['mg__histogram_x_range'] = list(st.session_state['mg__curr_column_range']) + st.session_state['mg__random_string'] = generate_random_string() else: st.session_state['mg__selected_column_type'] = 'categorical' st.session_state['mg__curr_column_unique_values'] = curr_series.unique() @@ -451,6 +513,9 @@ def main(): Main function for running the page ''' + # Global variable + st_key_prefix = 'mg__' + # If 'input_dataset' isn't in the session state, print an error message and return if 'input_dataset' not in st.session_state: st.error('An input dataset has not yet been opened. 
Please do so using the "Open File" page in the sidebar.') @@ -601,15 +666,26 @@ def main(): # If extra settings are to be displayed, create widgets for selecting particular images to define two groups if st.session_state['mg__extra_settings']: - st.multiselect('Images in baseline group:', options=df['Slide ID'].unique().tolist(), key='mg__images_in_plotting_group_1') - st.multiselect('Images in signal group:', options=df['Slide ID'].unique().tolist(), key='mg__images_in_plotting_group_2') + + # Instantiate the object + image_selector = image_filter.ImageFilter(df, image_colname='Slide ID', st_key_prefix=st_key_prefix) + + # If the image filter is not ready (which means the filtering dataframe was not generated), return + if not image_selector.ready: + return + + # Create two image filters + st.session_state['mg__images_in_plotting_group_1'] = image_selector.select_images(key='baseline', color='blue') + st.session_state['mg__images_in_plotting_group_2'] = image_selector.select_images(key='signal', color='red') + + # Define other keys which are no longer relevant but we are leaving in so the rest of the code does not break if 'mg__filter_on_another_column' not in st.session_state: st.session_state['mg__filter_on_another_column'] = False - st.checkbox('Filter on another column', key='mg__filter_on_another_column') - st.selectbox('Select another column on which to filter the plots:', df.select_dtypes('category').columns, key='mg__another_filter_column', disabled=(not st.session_state['mg__filter_on_another_column']), on_change=reset_values_on_which_to_filter_another_column) + if 'mg__another_filter_column' not in st.session_state: + st.session_state['mg__another_filter_column'] = None # otherwise, df.select_dtypes('category').columns[0] if 'mg__values_on_which_to_filter' not in st.session_state: reset_values_on_which_to_filter_another_column() - st.multiselect('Values on which to filter:', df[st.session_state['mg__another_filter_column']].unique(), key='mg__values_on_which_to_filter', disabled=(not st.session_state['mg__filter_on_another_column'])) + else: st.session_state['mg__images_in_plotting_group_1'] = [] st.session_state['mg__images_in_plotting_group_2'] = [] @@ -749,27 +825,39 @@ def main(): # Get the lowest-intensity "positive" intensity/marker intensity_cutoff = srs_marker_column_values[positive_loc].index[0] + if 'mg__histogram_box_selection_function' not in st.session_state: + st.session_state['mg__histogram_box_selection_function'] = 'Positivity identification' + st.radio('Histogram box selection function:', ['Positivity identification', 'Zoom'], key='mg__histogram_box_selection_function') + + st.button('Reset x-axis zoom', on_click=reset_x_axis_range, args=(use_groups_for_plotting, kde_or_hist_to_plot_full)) + # Plot the Plotly figure in Streamlit fig = go.Figure() + if not use_groups_for_plotting: - fig.add_trace(go.Scatter(x=kde_or_hist_to_plot_full['Value'], y=kde_or_hist_to_plot_full['Density'], fill='tozeroy', mode='none', fillcolor='rgba(255, 0, 0, 0.25)', name='All selected images', hovertemplate=' ')) + fig.add_trace(go.Scatter(x=kde_or_hist_to_plot_full['Value'], y=kde_or_hist_to_plot_full['Density'], fill='tozeroy', mode='markers', marker=dict(color='rgba(255, 0, 0, 0.25)', size=1), fillcolor='rgba(255, 0, 0, 0.25)', name='All selected images', hovertemplate=' ')) fig.add_trace(go.Scatter(x=df_to_plot_selected['Value'], y=df_to_plot_selected['Density'], fill='tozeroy', mode='none', fillcolor='rgba(255, 0, 0, 0.5)', name='Selection', hoverinfo='skip')) if 
intensity_cutoff is not None: fig.add_vline(x=intensity_cutoff, line_color='green', line_width=3, line_dash="dash", annotation_text="Previous threshold: ~{}".format((intensity_cutoff)), annotation_font_size=18, annotation_font_color="green") fig.update_layout(hovermode='x unified', xaxis_title='Column value', yaxis_title='Density') fig.update_layout(legend=dict(yanchor="top", y=1.2, xanchor="left", x=0.01, orientation="h")) else: - fig.add_trace(go.Scatter(x=kde_or_hist_to_plot_full[0]['Value'], y=kde_or_hist_to_plot_full[0]['Density'], fill='tozeroy', mode='none', fillcolor='rgba(0, 255, 255, 0.25)', name='Baseline group', hovertemplate=' ')) + fig.add_trace(go.Scatter(x=kde_or_hist_to_plot_full[0]['Value'], y=kde_or_hist_to_plot_full[0]['Density'], fill='tozeroy', mode='markers', marker=dict(color='rgba(0, 255, 255, 0.25)', size=1), fillcolor='rgba(0, 255, 255, 0.25)', name='Baseline group', hovertemplate=' ')) fig.add_trace(go.Scatter(x=df_to_plot_selected[0]['Value'], y=df_to_plot_selected[0]['Density'], fill='tozeroy', mode='none', fillcolor='rgba(0, 255, 255, 0.5)', name='Baseline selection', hoverinfo='skip')) - fig.add_trace(go.Scatter(x=kde_or_hist_to_plot_full[1]['Value'], y=kde_or_hist_to_plot_full[1]['Density'], fill='tozeroy', mode='none', fillcolor='rgba(255, 0, 0, 0.25)', name='Signal group', hovertemplate=' ')) + fig.add_trace(go.Scatter(x=kde_or_hist_to_plot_full[1]['Value'], y=kde_or_hist_to_plot_full[1]['Density'], fill='tozeroy', mode='markers', marker=dict(color='rgba(255, 0, 0, 0.25)', size=1), fillcolor='rgba(255, 0, 0, 0.25)', name='Signal group', hovertemplate=' ')) fig.add_trace(go.Scatter(x=df_to_plot_selected[1]['Value'], y=df_to_plot_selected[1]['Density'], fill='tozeroy', mode='none', fillcolor='rgba(255, 0, 0, 0.5)', name='Signal selection', hoverinfo='skip')) if intensity_cutoff is not None: fig.add_vline(x=intensity_cutoff, line_color='green', line_width=3, line_dash="dash", annotation_text="Previous threshold: ~{}".format((intensity_cutoff)), annotation_font_size=18, annotation_font_color="green") fig.update_layout(hovermode='x unified', xaxis_title='Column value', yaxis_title='Density') fig.update_layout(legend=dict(yanchor="top", y=1.2, xanchor="left", x=0.01, orientation="h")) + if 'mg__histogram_x_range' not in st.session_state: + reset_x_axis_range(use_groups_for_plotting, kde_or_hist_to_plot_full) + + fig.update_xaxes(range=st.session_state['mg__histogram_x_range']) + # Set Plotly chart in streamlit - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, on_select=plotly_chart_histogram_callback, key=('mg__plotly_chart_histogram_' + st.session_state['mg__random_string'] + '__do_not_persist'), selection_mode='box') # Set the selection dictionary for the current filter to pass on to the current phenotype definition selection_dict = {'column_for_filtering': column_for_filtering, 'selected_min_val': selected_min_val, 'selected_max_val': selected_max_val, 'selected_column_values': None} @@ -822,9 +910,28 @@ def main(): if 'mg__positive_percentage_per_image' not in st.session_state: st.session_state['mg__positive_percentage_per_image'] = True st.checkbox('Calculate positive percentages separately for each image', key='mg__positive_percentage_per_image') - fig, df_summary = generate_box_and_whisker(apply_another_filter, df_batch_normalized, column_for_filtering, st.session_state['mg__another_filter_column'], st.session_state['mg__values_on_which_to_filter'], st.session_state['mg__images_in_plotting_group_1'], 
st.session_state['mg__images_in_plotting_group_2'], all_cells=(not st.session_state['mg__positive_percentage_per_image'])) - st.plotly_chart(fig) - st.dataframe(df_summary, hide_index=True) + + if 'mg__specify_mean_for_zscore_calc' not in st.session_state: + st.session_state['mg__specify_mean_for_zscore_calc'] = False + if st.checkbox('Specify mean for Z-score calculation', key='mg__specify_mean_for_zscore_calc'): + if 'mg__mean_for_zscore_calc' not in st.session_state: + st.session_state['mg__mean_for_zscore_calc'] = df_batch_normalized[column_for_filtering].mean() + st.number_input('Mean for Z-score calculation:', key='mg__mean_for_zscore_calc') + if 'mg__specify_std_for_zscore_calc' not in st.session_state: + st.session_state['mg__specify_std_for_zscore_calc'] = False + if st.checkbox('Specify standard deviation for Z-score calculation', key='mg__specify_std_for_zscore_calc'): + if 'mg__std_for_zscore_calc' not in st.session_state: + st.session_state['mg__std_for_zscore_calc'] = df_batch_normalized[column_for_filtering].std() + st.number_input('Standard deviation for Z-score calculation:', min_value=0.0, key='mg__std_for_zscore_calc') + if not st.session_state['mg__specify_mean_for_zscore_calc']: + st.session_state['mg__mean_for_zscore_calc'] = None + if not st.session_state['mg__specify_std_for_zscore_calc']: + st.session_state['mg__std_for_zscore_calc'] = None + + fig, df_summary = generate_box_and_whisker(apply_another_filter, df_batch_normalized, column_for_filtering, st.session_state['mg__another_filter_column'], st.session_state['mg__values_on_which_to_filter'], st.session_state['mg__images_in_plotting_group_1'], st.session_state['mg__images_in_plotting_group_2'], all_cells=(not st.session_state['mg__positive_percentage_per_image']), mean_for_zscore_calc=st.session_state['mg__mean_for_zscore_calc'], std_for_zscore_calc=st.session_state['mg__std_for_zscore_calc']) + st.session_state['mg__df_summary_contents'] = df_summary + st.plotly_chart(fig, on_select=plotly_chart_summary_callback, key='mg__plotly_chart_summary__do_not_persist') + st.dataframe(df_summary, hide_index=True, key="mg__df_summary__do_not_persist", on_select=df_summary_callback, selection_mode=["single-row"]) # Add the current column filter to the current phenotype assignment st.button(':star2: Add column filter to current phenotype :star2:', use_container_width=True, on_click=update_dependencies_of_button_for_adding_column_filter_to_current_phenotype, kwargs=selection_dict, disabled=add_column_button_disabled) @@ -873,7 +980,7 @@ def main(): # Print out a five-row sample of the main dataframe st.write('Augmented dataset sample:') - st.dataframe(st.session_state['mg__df'].sample(5), hide_index=True) + st.dataframe(utils.sample_df_without_replacement_by_number(df=st.session_state['mg__df'], n=5), hide_index=True) # if st.button('Save dataset to `output` folder'): # st.session_state['mg__df'].to_csv('./output/saved_dataset.csv') @@ -909,18 +1016,4 @@ def main(): # Call the main function if __name__ == '__main__': - - # Set page settings - st.set_page_config(layout='wide', page_title='Manual Phenotyping on Raw Intensities') - st.title('Manual Phenotyping on Raw Intensities') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the 
page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/open_file.py b/pages2/open_file.py similarity index 81% rename from pages/open_file.py rename to pages2/open_file.py index c90a08c..6e43ecb 100644 --- a/pages/open_file.py +++ b/pages2/open_file.py @@ -5,10 +5,6 @@ import os import streamlit as st import streamlit_utils - -# Import relevant libraries -import app_top_of_page as top -import streamlit_dataframe_editor as sde import utils def clear_session_state(): @@ -75,14 +71,9 @@ def main(): st.session_state['opener__selected_input_file'] = None st.selectbox('Select an available input file to load:', options=available_input_files, key='opener__selected_input_file', disabled=st.session_state['opener__load_from_datafile_unifier']) - # Create a number input for the number of microns per coordinate unit + # Set the number of microns per coordinate unit if 'opener__microns_per_coordinate_unit' not in st.session_state: st.session_state['opener__microns_per_coordinate_unit'] = 1.0 - if st.session_state['opener__load_from_datafile_unifier']: - help_message = 'Remember that the dataset coordinates were converted to microns in the Datafile Unifier.' - else: - help_message = None - st.number_input('Enter the number of microns per coordinate unit in the input file:', min_value=0.0, key='opener__microns_per_coordinate_unit', format='%.4f', step=0.0001, disabled=st.session_state['opener__load_from_datafile_unifier'], help=help_message) # Determine the input datafile or input dataframe if st.session_state['opener__load_from_datafile_unifier']: @@ -117,9 +108,10 @@ def main(): st.session_state['opener__load_input_dataset'] = False with st.spinner('Loading the input dataset...'): streamlit_utils.load_input_dataset(input_file_or_df, st.session_state['opener__microns_per_coordinate_unit']) # this assigns the input dataset to st.session_state['input_dataset'] and the metadata to st.session_state['input_metadata'] + if st.session_state['input_dataset'] is not None: + st.session_state['input_dataset'].data, st.session_state['input_dataframe_memory_usage_bytes'] = utils.downcast_dataframe_dtypes(st.session_state['input_dataset'].data, also_return_final_size=True) + # st.session_state['adata'] = utils.create_anndata_from_dataframe(st.session_state['input_dataset'].data, columns_for_data_matrix='float') if st.session_state['input_dataset'] is not None: - st.session_state['input_dataset'].data, st.session_state['input_dataframe_memory_usage_bytes'] = utils.downcast_dataframe_dtypes(st.session_state['input_dataset'].data, also_return_final_size=True) - # st.session_state['adata'] = utils.create_anndata_from_dataframe(st.session_state['input_dataset'].data, columns_for_data_matrix='float') st.info('The input data have been successfully loaded and validated.') show_dataframe_updates = True else: @@ -158,33 +150,22 @@ def main(): :small_orange_diamond: Number of rows: `{df.shape[0]}` :small_orange_diamond: Number of columns: `{df.shape[1]}` :small_orange_diamond: Minimum coordinate spacing: `{dataset_obj.min_coord_spacing_:.4f} microns` - :small_orange_diamond: Loaded memory usage: `{st.session_state['input_dataframe_memory_usage_bytes'] / 1024 ** 2:.2f} MB` + :small_orange_diamond: Loaded memory usage: `{st.session_state['input_dataframe_memory_usage_bytes'] / 1024 ** 2:.2f} MB` ''' + # Display more information if the dataset has been preprocessed inside MAWA + if metadata['preprocessing'] is not None: + information += f':small_orange_diamond: Preprocessing: 
`{metadata["preprocessing"]}`' + # Display the information and the sampled dataframe st.markdown(information) st.header('Dataframe sample') resample_dataframe = st.button('Refresh dataframe sample') if ('opener__sampled_df' not in st.session_state) or resample_dataframe or show_dataframe_updates: - st.session_state['opener__sampled_df'] = df.sample(min(num_rows_to_sample, len(df))).sort_index() + st.session_state['opener__sampled_df'] = utils.sample_df_without_replacement_by_number(df=df, n=num_rows_to_sample).sort_index() st.write(st.session_state['opener__sampled_df']) # Run the main function if __name__ == '__main__': - - # Set page settings - page_name = 'Open File' - st.set_page_config(layout='wide', page_title=page_name) - st.title(page_name) - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages2/preprocessing.py b/pages2/preprocessing.py new file mode 100644 index 0000000..380a00a --- /dev/null +++ b/pages2/preprocessing.py @@ -0,0 +1,122 @@ +# Import relevant libraries +import streamlit as st +import radial_profiles +import time +import numpy as np + +# Global variable +st_key_prefix = 'preprocessing__' + + +# Function to initialize the preprocessing section +def initialize_radial_profiles_preprocessing(df): + + # Checkbox for whether to run checks + key = st_key_prefix + 'run_checks' + if key not in st.session_state: + st.session_state[key] = False + run_checks = st.checkbox('Run checks', key=key) + + # Number input for the threshold for the RawIntNorm check + key = st_key_prefix + 'perc_thresh_rawintnorm_column_check' + if key not in st.session_state: + st.session_state[key] = 0.01 + if run_checks: + st.number_input('Threshold for the RawIntNorm column check (%):', min_value=0.0, max_value=100.0, key=key) + perc_thresh_rawintnorm_column_check = st.session_state[key] + + # Number input to select the nuclear intensity channel + key = st_key_prefix + 'nuclear_channel' + if key not in st.session_state: + st.session_state[key] = 1 + nuclear_channel = st.number_input('Nuclear channel:', min_value=1, key=key) + + # Checkbox for whether to apply the z-score filter + key = st_key_prefix + 'do_z_score_filter' + if key not in st.session_state: + st.session_state[key] = True + do_z_score_filter = st.checkbox('Do z-score filter', key=key) + + # Number input for the z-score filter threshold + key = st_key_prefix + 'z_score_filter_threshold' + if key not in st.session_state: + st.session_state[key] = 3 + if do_z_score_filter: + st.number_input('z-score filter threshold:', min_value=0.0, key=key) + z_score_filter_threshold = st.session_state[key] + + # If dataset preprocessing is desired... 
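Each widget on this preprocessing page follows the same idiom: seed a namespaced default into `st.session_state` before creating the keyed widget, then read the value back from session state so it persists across reruns and page switches. A stripped-down sketch of the idiom; the prefix and key names are illustrative:

```python
import streamlit as st

st_key_prefix = "demo__"

# Seed a default only on the first visit, then let the keyed widget own the value.
key = st_key_prefix + "nuclear_channel"
if key not in st.session_state:
    st.session_state[key] = 1
nuclear_channel = st.number_input("Nuclear channel:", min_value=1, key=key)

key = st_key_prefix + "do_z_score_filter"
if key not in st.session_state:
    st.session_state[key] = True
do_z_score_filter = st.checkbox("Do z-score filter", key=key)

st.write(nuclear_channel, do_z_score_filter)
```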
+ if st.button('Preprocess dataset'): + + # Record the start time + start_time = time.time() + + # Preprocess the dataset + df = radial_profiles.preprocess_dataset( + df, + perc_thresh_rawintnorm_column_check=perc_thresh_rawintnorm_column_check, + image_col='Slide ID', + nuclear_channel=nuclear_channel, + do_z_score_filter=do_z_score_filter, + z_score_filter_threshold=z_score_filter_threshold, + run_checks=run_checks + ) + + # If the dataset is None, it's likely preprocessing has already been performed, so display a warning and return + if df is None: + st.warning('It appears that the dataset has already been preprocessed because there is no "Label" column. If you would like to re-preprocess the dataset, please reload it from the Open File page.') + return + + # Output the time taken + st.write(f'Preprocessing took {int(np.round(time.time() - start_time))} seconds') + + # Calculate the memory usage of the transformed dataframe + st.session_state['input_dataframe_memory_usage_bytes'] = df.memory_usage(deep=True).sum() + + # Update the preprocessing parameters + st.session_state['input_metadata']['preprocessing'] = { + 'location': 'Radial Profiles app', + 'nuclear_channel': nuclear_channel, + 'do_z_score_filter': do_z_score_filter, + } + if do_z_score_filter: + st.session_state['input_metadata']['preprocessing']['z_score_filter_threshold'] = z_score_filter_threshold + + # Display information about the new dataframe + df.info() + + # In case df has been modified not-in-place in any way, reassign the input dataset as the modified df + st.session_state['input_dataset'].data = df + + # Return the modified dataframe + return df + + +def main(): + """ + Main function for the page. + """ + + st.header('ImageJ Output for Radial Profiles Analysis') + + # Ensure a dataset has been opened in the first place + if 'input_dataset' not in st.session_state: + st.warning('Please open a dataset from the Open File page at left.') + return + + # Save a shortcut to the dataframe + df = st.session_state['input_dataset'].data + + # Set up preprocessing + with st.columns(3)[0]: + df = initialize_radial_profiles_preprocessing(df) + + # Ensure the main dataframe is updated per the operations above + st.session_state['input_dataset'].data = df + + +# Run the main function +if __name__ == '__main__': + + # Call the main function + main() diff --git a/pages2/radial_bins_plots.py b/pages2/radial_bins_plots.py new file mode 100644 index 0000000..412649f --- /dev/null +++ b/pages2/radial_bins_plots.py @@ -0,0 +1,517 @@ +# Import relevant libraries +import streamlit as st +import time +import numpy as np +import plotly.express as px +from itertools import cycle, islice +import plotly.graph_objects as go +import pandas as pd +import scipy.spatial +import utils + +# Global variable +st_key_prefix = 'radial_bins_plots__' + + +def remove_keys(key_suffixes): + for key_suffix in key_suffixes: + st.session_state.pop(st_key_prefix + key_suffix, None) + + +def draw_single_image_scatter_plot(df, image_to_view, column_to_plot, values_to_plot, color_dict, xy_position_columns=['Cell X Position', 'Cell Y Position'], coordinate_scale_factor=1, annulus_spacing_um=250, use_coordinate_mins_and_maxs=False, xmin_col='Cell X Position', xmax_col='Cell X Position', ymin_col='Cell Y Position', ymax_col='Cell Y Position', units='microns', invert_y_axis=False, opacity=0.7): + + # Draw a header + st.header('Single image scatter plot') + + # If the user wants to display the scatter plot, indicated by a toggle... 
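The z-score filter itself lives in `radial_profiles.preprocess_dataset`, which is not part of this diff. Purely as an illustration of what the `do_z_score_filter` / `z_score_filter_threshold` settings control, a per-image z-score filter on a nuclear-intensity column might look like the sketch below; the column names and grouping behavior are assumptions, not the app's actual implementation:

```python
import pandas as pd

def z_score_filter(df, value_col, image_col="Slide ID", threshold=3.0):
    # Keep rows whose value lies within `threshold` standard deviations of the
    # mean, computed separately for each image.
    grouped = df.groupby(image_col)[value_col]
    z = (df[value_col] - grouped.transform("mean")) / grouped.transform("std")
    return df[z.abs() <= threshold]

# 20 typical cells plus one extreme outlier in a single image.
values = [1.0] * 20 + [50.0]
df = pd.DataFrame({"Slide ID": ["img1"] * 21, "Nuclear intensity": values})
print(z_score_filter(df, "Nuclear intensity"))  # the outlier row is dropped
```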
+ if st_key_prefix + 'show_scatter_plot' not in st.session_state: + st.session_state[st_key_prefix + 'show_scatter_plot'] = False + if st.toggle('Show scatter plot', key=st_key_prefix + 'show_scatter_plot'): + + # Filter the DataFrame to include only the selected image + df_selected_image_and_filter = df.loc[(df['Slide ID'] == image_to_view), xy_position_columns + [column_to_plot]] + + # Optionally scale the coordinates (probably not; should have been done in Datafile Unifier + if coordinate_scale_factor != 1: + df_selected_image_and_filter[xy_position_columns] = df_selected_image_and_filter[xy_position_columns] * coordinate_scale_factor + + # Calculate the radius edges of the annuli + radius_edges, xy_mid = calculate_annuli_radius_edges(df_selected_image_and_filter, annulus_spacing_um=annulus_spacing_um, xy_position_columns=xy_position_columns) + + # Group the DataFrame for the selected image by unique value of the column to plot + selected_image_grouped_by_value = df_selected_image_and_filter.groupby(column_to_plot) + + # Create the scatter plot + fig = go.Figure() + + # Loop over the unique values in the column whose values to plot, in order of their frequency + for value_to_plot in values_to_plot: + + # If the value exists in the selected image... + if (value_to_plot in selected_image_grouped_by_value.groups) and (len(selected_image_grouped_by_value.groups[value_to_plot]) > 0): + + # Store the dataframe for the current value for the selected image + df_group = selected_image_grouped_by_value.get_group(value_to_plot) + + # If value is a string, replace '(plus)' with '+' and '(dash)' with '-', since it could likely be a phenotype with those substitutions + if isinstance(value_to_plot, str): + value_str_cleaned = value_to_plot.replace('(plus)', '+').replace('(dash)', '-') + else: + value_str_cleaned = value_to_plot + + # Add the object index to the label + df_group['hover_label'] = 'Index: ' + df_group.index.astype(str) + + # Works but doesn't scale the shapes + if not use_coordinate_mins_and_maxs: + fig.add_trace(go.Scatter(x=df_group[xy_position_columns[0]], y=df_group[xy_position_columns[1]], mode='markers', name=value_str_cleaned, marker_color=color_dict[value_to_plot], hovertemplate=df_group['hover_label'])) + + # Works really well + else: + fig.add_trace(go.Bar( + x=((df_group[xmin_col] + df_group[xmax_col]) / 2), + y=df_group[ymax_col] - df_group[ymin_col], + width=df_group[xmax_col] - df_group[xmin_col], + base=df_group[ymin_col], + name=value_str_cleaned, + marker=dict( + color=color_dict[value_to_plot], + opacity=opacity, + ), + hovertemplate=df_group['hover_label'] + )) + + # Plot circles of radii radius_edges[1:] centered at the midpoint of the coordinates + for radius in radius_edges[1:]: + fig.add_shape( + type='circle', + xref='x', + yref='y', + x0=xy_mid[xy_position_columns[0]] - radius, + y0=xy_mid[xy_position_columns[1]] - radius, + x1=xy_mid[xy_position_columns[0]] + radius, + y1=xy_mid[xy_position_columns[1]] + radius, + line=dict( + color='lime', + width=4, + ), + opacity=0.75, + ) + + # Update the layout + fig.update_layout( + xaxis=dict( + scaleanchor="y", + scaleratio=1, + ), + yaxis=dict( + autorange=('reversed' if invert_y_axis else True), + ), + title=f'Scatter plot for {image_to_view}', + xaxis_title=f'Cell X Position ({units})', + yaxis_title=f'Cell Y Position ({units})', + legend_title=column_to_plot, + height=800, # Set the height of the figure + width=800, # Set the width of the figure + ) + + # Plot the plotly chart in Streamlit + st.plotly_chart(fig, 
use_container_width=True) + + # Write the analysis results for the selected image + if st_key_prefix + 'df_analysis_results' in st.session_state: + st.write(st.session_state[st_key_prefix + 'df_analysis_results'].loc[image_to_view]) + + # We seem to need to render something on the page after rendering a plotly figure in order for the page to not automatically scroll back to the top when you go to the Previous or Next image... doesn't seem to always work at least in the scatter plotter + st.write(' ') + + +def initialize_main_settings(df, unique_images): + + # Main settings section + st.header('Main settings') + + # Define the main settings columns + settings_columns_main = st.columns(3) + + # In the first column... + with settings_columns_main[0]: + + # Store columns of certain types + if st_key_prefix + 'categorical_columns' not in st.session_state: + st.session_state[st_key_prefix + 'categorical_columns'] = utils.get_categorical_columns_including_numeric(df, max_num_unique_values=1000) + if st_key_prefix + 'numeric_columns' not in st.session_state: + st.session_state[st_key_prefix + 'numeric_columns'] = df.select_dtypes(include='number').columns + categorical_columns = st.session_state[st_key_prefix + 'categorical_columns'] + numeric_columns = st.session_state[st_key_prefix + 'numeric_columns'] + + # Choose a column to plot + if st_key_prefix + 'column_to_plot' not in st.session_state: + st.session_state[st_key_prefix + 'column_to_plot'] = categorical_columns[0] + column_to_plot = st.selectbox('Select a column by which to color the points:', categorical_columns, key=st_key_prefix + 'column_to_plot') + column_to_plot_has_changed = (st_key_prefix + 'column_to_plot_prev' not in st.session_state) or (st.session_state[st_key_prefix + 'column_to_plot_prev'] != column_to_plot) + st.session_state[st_key_prefix + 'column_to_plot_prev'] = column_to_plot + + # Optionally force-update the list of categorical columns + st.button('Update categorical columns', help='If you don\'t see the column you want to plot, click this button to update the list of categorical columns.', on_click=lambda: st.session_state.pop(st_key_prefix + 'categorical_columns', None)) + + # Get some information about the images in the input dataset + if st_key_prefix + 'ser_size_of_each_image' not in st.session_state: + st.session_state[st_key_prefix + 'ser_size_of_each_image'] = df['Slide ID'].value_counts() # calculate the number of objects in each image + ser_size_of_each_image = st.session_state[st_key_prefix + 'ser_size_of_each_image'] + + # Create an image selection selectbox + if st_key_prefix + 'image_to_view' not in st.session_state: + st.session_state[st_key_prefix + 'image_to_view'] = unique_images[0] + image_to_view = st.selectbox('Select image to view:', unique_images, key=st_key_prefix + 'image_to_view') + + # Display the number of cells in the selected image + st.write(f'Number of cells in image: {ser_size_of_each_image.loc[image_to_view]}') + + # Optionally navigate through the images using Previous and Next buttons + cols = st.columns(2) + with cols[0]: + st.button('Previous image', on_click=go_to_previous_image, args=(unique_images,), disabled=(image_to_view == unique_images[0]), use_container_width=True) + with cols[1]: + st.button('Next image', on_click=go_to_next_image, args=(unique_images, ), disabled=(image_to_view == unique_images[-1]), use_container_width=True) + + # In the second column... 
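`utils.get_categorical_columns_including_numeric` is called below but its implementation is not shown in this diff. Judging from its name and the `max_num_unique_values` argument, it plausibly returns object/category columns plus any numeric column with few enough distinct values to treat as categorical (such as radial bins). The version here is purely hypothetical:

```python
import numpy as np
import pandas as pd

def get_categorical_columns_including_numeric(df, max_num_unique_values=1000):
    # Columns that are already categorical or string-like.
    categorical = list(df.select_dtypes(include=("category", "object")).columns)
    # Numeric columns with a small number of distinct values.
    for col in df.select_dtypes(include="number").columns:
        if df[col].nunique() <= max_num_unique_values:
            categorical.append(col)
    return categorical

df = pd.DataFrame({
    "Phenotype": ["A", "B"] * 6,
    "Radial bin": [0, 250, 500] * 4,
    "Cell X Position": np.linspace(0.0, 110.0, 12),  # continuous, many unique values
})
print(get_categorical_columns_including_numeric(df, max_num_unique_values=10))
# -> ['Phenotype', 'Radial bin']
```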
+ with settings_columns_main[1]: + + # Optionally plot minimum and maximum coordinate fields + if st_key_prefix + 'use_coordinate_mins_and_maxs' not in st.session_state: + st.session_state[st_key_prefix + 'use_coordinate_mins_and_maxs'] = False + use_coordinate_mins_and_maxs = st.checkbox('Use coordinate mins and maxs', key=st_key_prefix + 'use_coordinate_mins_and_maxs') + settings_columns_refined = st.columns(2) + if st_key_prefix + 'x_min_coordinate_column' not in st.session_state: + st.session_state[st_key_prefix + 'x_min_coordinate_column'] = numeric_columns[0] + if st_key_prefix + 'y_min_coordinate_column' not in st.session_state: + st.session_state[st_key_prefix + 'y_min_coordinate_column'] = numeric_columns[0] + if st_key_prefix + 'x_max_coordinate_column' not in st.session_state: + st.session_state[st_key_prefix + 'x_max_coordinate_column'] = numeric_columns[0] + if st_key_prefix + 'y_max_coordinate_column' not in st.session_state: + st.session_state[st_key_prefix + 'y_max_coordinate_column'] = numeric_columns[0] + with settings_columns_refined[0]: + xmin_col = st.selectbox('Select a column for the minimum x-coordinate:', numeric_columns, key=st_key_prefix + 'x_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + with settings_columns_refined[1]: + xmax_col = st.selectbox('Select a column for the maximum x-coordinate:', numeric_columns, key=st_key_prefix + 'x_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + with settings_columns_refined[0]: + ymin_col = st.selectbox('Select a column for the minimum y-coordinate:', numeric_columns, key=st_key_prefix + 'y_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + with settings_columns_refined[1]: + ymax_col = st.selectbox('Select a column for the maximum y-coordinate:', numeric_columns, key=st_key_prefix + 'y_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + units = ('coordinate units' if use_coordinate_mins_and_maxs else 'microns') + + # In the third column... 
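The third-column settings that follow rebuild the color mapping whenever the coloring column changes, using the `_prev` session-state idiom: store the last selection under a companion key and compare it on every rerun, so dependent state is reset only when the selection actually changes. A small sketch of that change-detection idiom with illustrative keys:

```python
import streamlit as st

column_to_plot = st.selectbox("Column to color by:", ["Phenotype", "Radial bin"],
                              key="demo__column_to_plot")

# Compare against the value remembered from the previous rerun.
column_has_changed = (
    "demo__column_to_plot_prev" not in st.session_state
    or st.session_state["demo__column_to_plot_prev"] != column_to_plot
)
st.session_state["demo__column_to_plot_prev"] = column_to_plot

if "demo__color_dict" not in st.session_state or column_has_changed:
    # Rebuild dependent state only when needed (placeholder mapping here).
    st.session_state["demo__color_dict"] = {column_to_plot: "#1f77b4"}

st.write(st.session_state["demo__color_dict"])
```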
+ with settings_columns_main[2]: + + # Add an option to invert the y-axis + if st_key_prefix + 'invert_y_axis' not in st.session_state: + st.session_state[st_key_prefix + 'invert_y_axis'] = False + invert_y_axis = st.checkbox('Invert y-axis', key=st_key_prefix + 'invert_y_axis') + + # Choose the opacity of objects + if st_key_prefix + 'opacity' not in st.session_state: + st.session_state[st_key_prefix + 'opacity'] = 0.7 + opacity = st.number_input('Opacity:', min_value=0.0, max_value=1.0, step=0.1, key=st_key_prefix + 'opacity') + + # Define the colors for the values to plot + if (st_key_prefix + 'color_dict' not in st.session_state) or column_to_plot_has_changed: + reset_color_dict(df[column_to_plot]) + values_to_plot = st.session_state[st_key_prefix + 'values_to_plot'] + color_dict = st.session_state[st_key_prefix + 'color_dict'] + + # Select a value whose color we want to modify + if (st_key_prefix + 'value_to_change_color' not in st.session_state) or column_to_plot_has_changed: + st.session_state[st_key_prefix + 'value_to_change_color'] = values_to_plot[0] + value_to_change_color = st.selectbox('Value whose color to change:', values_to_plot, key=st_key_prefix + 'value_to_change_color') + + # Create a color picker widget for the selected value + st.session_state[st_key_prefix + 'new_picked_color'] = color_dict[value_to_change_color] + st.color_picker('Pick a new color:', key=st_key_prefix + 'new_picked_color', on_change=update_color_for_value, args=(value_to_change_color,)) + + # Add a button to reset the colors to their default values + st.button('Reset plotting colors to defaults', on_click=reset_color_dict, args=(df[column_to_plot],)) + color_dict = st.session_state[st_key_prefix + 'color_dict'] + + # Return assigned variables + return column_to_plot, image_to_view, use_coordinate_mins_and_maxs, xmin_col, xmax_col, ymin_col, ymax_col, units, invert_y_axis, opacity, color_dict, values_to_plot, categorical_columns + + +def initialize_radial_bin_calculation(df): + + # Radial bins calculation section + st.header('Radial bins') + + # st.warning('Note, currently the image centroid is hardcoded!') + + # Get some information about the images in the input dataset + if st_key_prefix + 'unique_images' not in st.session_state: + st.session_state[st_key_prefix + 'unique_images'] = df['Slide ID'].unique() # get the unique images in the dataset + unique_images = st.session_state[st_key_prefix + 'unique_images'] + + # Number input for the coordinate scale factor + key = st_key_prefix + 'coordinate_scale_factor' + if key not in st.session_state: + st.session_state[key] = 1 + # coordinate_scale_factor = st.number_input('Coordinate scale factor:', min_value=0.0, key=key, help='This is only necessary if you have forgotten to scale the coordinates in the Datafile Unifier.') + coordinate_scale_factor = st.session_state[key] + + # Number input for the annulus spacing + key = st_key_prefix + 'annulus_spacing_um' + if key not in st.session_state: + st.session_state[key] = 250 + annulus_spacing_um = st.number_input('Annulus spacing (um):', min_value=0.0, key=key) + + # Multiselect for selection of coordinate columns + xy_position_columns = ['Cell X Position', 'Cell Y Position'] # this is set in Open File so should always be the same, no need for a widget + + # Calculate radial bins + if st.button('Calculate radial bins', on_click=remove_keys, args=(['color_dict', 'value_to_change_color'],)): + start_time = time.time() + df = add_radial_bin_to_dataset(df, unique_images, 
coordinate_scale_factor=coordinate_scale_factor, annulus_spacing_um=annulus_spacing_um, xy_position_columns=xy_position_columns) + st.session_state['input_dataset'].data = df + st.write(f'Calculation of radial bins took {int(np.round(time.time() - start_time))} seconds') + if st_key_prefix + 'categorical_columns' in st.session_state: + del st.session_state[st_key_prefix + 'categorical_columns'] # force the categorical columns to be recalculated since we just added one to the dataset + + # Return the necessary variables + return df, unique_images, coordinate_scale_factor, annulus_spacing_um, xy_position_columns + + +def calculate_annuli_radius_edges(df, annulus_spacing_um=250, xy_position_columns=['Cell X Position', 'Cell Y Position']): + + # Get the x-y midpoint of the coordinates in df[xy_position_columns] + xy_min = df[xy_position_columns].min() + xy_max = df[xy_position_columns].max() + xy_mid = (xy_min + xy_max) / 2 + + + # st.warning('Hardcoding image centers!') + # xy_mid.iloc[0] = 10130 / 2 * 0.32 + # xy_mid.iloc[1] = 10130 / 2 * 0.32 + + + # Get the radius edges that fit within the largest possible radius + largest_possible_radius = np.linalg.norm(xy_max - xy_mid) + annulus_spacing_um + num_intervals = largest_possible_radius // annulus_spacing_um # calculate the number of intervals that fit within the largest_possible_radius + end_value = (num_intervals + 1) * annulus_spacing_um # calculate the end value for np.arange to ensure it does not exceed largest_possible_radius + radius_edges = np.arange(0, end_value, annulus_spacing_um) # generate steps + + # Return the necessary variables + return radius_edges, xy_mid + + +# Function to determine the radial bin for every cell in the dataset +def add_radial_bin_to_dataset(df, unique_images, coordinate_scale_factor=1, annulus_spacing_um=250, xy_position_columns=['Cell X Position', 'Cell Y Position']): + + # For every image in the dataset... + for current_image in unique_images: + + # Filter the DataFrame to the current image + df_selected_image_and_filter = df.loc[(df['Slide ID'] == current_image), xy_position_columns] + + # Scale the coordinates. Generally unnecessary as this should have been done (converted to microns) in the unifier + if coordinate_scale_factor != 1: + df_selected_image_and_filter[xy_position_columns] = df_selected_image_and_filter[xy_position_columns] * coordinate_scale_factor + + # Calculate the radius edges of the annuli + radius_edges, xy_mid = calculate_annuli_radius_edges(df_selected_image_and_filter, annulus_spacing_um=annulus_spacing_um, xy_position_columns=xy_position_columns) + + # Construct a KDTree for the current image + kdtree = scipy.spatial.KDTree(df_selected_image_and_filter[xy_position_columns]) + + # For every outer radius... 
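+        # Each pass of the loop below grows the KDTree ball query by one annulus width and assigns the cells newly captured (the set difference with the previous query's indices) the current outer radius as their radial bin.
+        # Illustrative arithmetic with hypothetical numbers: for annulus_spacing_um=250 and a farthest cell about 707 um from the image midpoint, calculate_annuli_radius_edges() returns radius_edges = [0, 250, 500, 750], so cells end up in annuli with outer radii of 250, 500, or 750 um.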
+ prev_indices = None + for radius in radius_edges[1:]: + + # Get the indices of the points in the image within the current radius + curr_indices = kdtree.query_ball_point(xy_mid, radius) + + # Get the indices of the points in the current annulus (defined by the outer radius) + if prev_indices is not None: + annulus_indices = np.setdiff1d(curr_indices, prev_indices) + # annulus_indices = curr_indices[~np.isin(curr_indices, prev_indices)] # note copilot said this would be faster though may need to ensure curr_indices is a numpy array or else will get "TypeError: only integer scalar arrays can be converted to a scalar index" + else: + annulus_indices = curr_indices + + # Store the outer radius for all the cells in the current annulus in the current image + if len(annulus_indices) > 0: + df.loc[df_selected_image_and_filter.iloc[annulus_indices].index, 'Outer radius'] = radius + + # Store the current indices for the next iteration + prev_indices = curr_indices + + # Make sure there are no NaNs in the "Outer radius" column using an assertion + assert df['Outer radius'].isna().sum() == 0, 'There are NaNs in the "Outer radius" column but every cell should have been assigned a radial bin' + + # Return the dataframe with the new "Outer radius" column + return df + + +# Function to calculate the percent positives in each annulus in each image +def calculate_percent_positives_for_entire_dataset(df, column_to_plot, unique_images, coordinate_scale_factor=1, annulus_spacing_um=250, xy_position_columns=['Cell X Position', 'Cell Y Position']): + + # For every image in the dataset... + analysis_results_holder = [] + for current_image in unique_images: + + # Filter the DataFrame to the current image + df_selected_image_and_filter = df.loc[(df['Slide ID'] == current_image), xy_position_columns + [column_to_plot]] + + # Scale the coordinates. Generally unnecessary as this should have been done (converted to microns) in the unifier + if coordinate_scale_factor != 1: + df_selected_image_and_filter[xy_position_columns] = df_selected_image_and_filter[xy_position_columns] * coordinate_scale_factor + + # Get the x-y midpoint of the coordinates in df_selected_image_and_filter[xy_position_columns] + xy_min = df_selected_image_and_filter[xy_position_columns].min() + xy_max = df_selected_image_and_filter[xy_position_columns].max() + xy_mid = (xy_min + xy_max) / 2 + + # Get the radius edges that fit within the largest possible radius + largest_possible_radius = (xy_max - xy_mid).min() + spacing_um = annulus_spacing_um + num_intervals = largest_possible_radius // spacing_um # calculate the number of intervals that fit within the largest_possible_radius + end_value = (num_intervals + 1) * spacing_um # calculate the end value for np.arange to ensure it does not exceed largest_possible_radius + radius_edges = np.arange(0, end_value, spacing_um) # generate steps + + # Construct a KDTree for the current image + kdtree = scipy.spatial.KDTree(df_selected_image_and_filter[xy_position_columns]) + + # For every outer radius... 
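+        # Note that, unlike add_radial_bin_to_dataset(), the outermost radius here is capped at the smaller half-extent of the image, so the number of annuli can differ per image; annuli that contain no cells are recorded as None below.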
+ prev_indices = None + percent_positives = [] + annulus_radius_strings = [] + for radius in radius_edges[1:]: + + # Get the indices of the points in the image within the current radius + curr_indices = kdtree.query_ball_point(xy_mid, radius) + + # Get the indices of the points in the current annulus (defined by the outer radius) + if prev_indices is not None: + annulus_indices = np.setdiff1d(curr_indices, prev_indices) + # annulus_indices = curr_indices[~np.isin(curr_indices, prev_indices)] # note copilot said this would be faster though may need to ensure curr_indices is a numpy array or else will get "TypeError: only integer scalar arrays can be converted to a scalar index" + else: + annulus_indices = curr_indices + + # Get the series in the current image corresponding to the current annulus and positivity column + ser_positivity_annulus = df_selected_image_and_filter.iloc[annulus_indices][column_to_plot] + + # Calculate the percent positive, denoted by "+" + full_size = len(ser_positivity_annulus) + if full_size != 0: + percent_positives.append((ser_positivity_annulus == '+').sum() / full_size * 100) + else: + percent_positives.append(None) + + # Store the annulus radii as a string for the current annulus + annulus_radius_strings.append(f'Annulus from {radius - annulus_spacing_um} to {radius} um') + + # Store the current indices for the next iteration + prev_indices = curr_indices + + # Get a dictionary containing the calculation results for the current image + percent_positives_dict = dict(zip(annulus_radius_strings, percent_positives)) + percent_positives_dict[f'Number of annuli of width {annulus_spacing_um} um'] = len(percent_positives) + + # Store the dictionary in the analysis results holder + analysis_results_holder.append(percent_positives_dict) + + # Return a dataframe of the results + return pd.DataFrame(analysis_results_holder, index=unique_images) + + +# Function to update the color for a value +def update_color_for_value(value_to_change_color): + st.session_state[st_key_prefix + 'color_dict'][value_to_change_color] = st.session_state[st_key_prefix + 'new_picked_color'] + + +# Function to reset the color dictionary +def reset_color_dict(ser_to_plot): + # Create a color sequence based on the frequency of the values to plot in the entire dataset + values_to_plot = ser_to_plot.value_counts().index + colors = list(islice(cycle(px.colors.qualitative.Plotly), len(values_to_plot))) + st.session_state[st_key_prefix + 'values_to_plot'] = values_to_plot + st.session_state[st_key_prefix + 'color_dict'] = dict(zip(values_to_plot, colors)) # map values to colors + + +def go_to_previous_image(unique_images): + """ + Go to the previous image in the numpy array. + + Parameters: + unique_images (numpy.ndarray): The unique images. + + Returns: + None + """ + + # Get the current index in the unique images + key = st_key_prefix + 'image_to_view' + current_index = list(unique_images).index(st.session_state[key]) + + # If we're not already at the first image, go to the previous image + if current_index > 0: + current_index -= 1 + st.session_state[key] = unique_images[current_index] + + +def go_to_next_image(unique_images): + """ + Go to the next image in the numpy array. + + Parameters: + unique_images (numpy.ndarray): The unique images. 
+ + Returns: + None + """ + + # Get the current index in the unique images + key = st_key_prefix + 'image_to_view' + current_index = list(unique_images).index(st.session_state[key]) + + # If we're not already at the last image, go to the next image + if current_index < len(unique_images) - 1: + current_index += 1 + st.session_state[key] = unique_images[current_index] + + +def main(): + """ + Main function for the page. + """ + + # Ensure a dataset has been opened in the first place + if 'input_dataset' not in st.session_state: + st.warning('Please open a dataset from the Open File page at left.') + return + + # Save a shortcut to the dataframe + df = st.session_state['input_dataset'].data + + # Set up some columns + columns = st.columns(3) + + # Set up calculation of radial bins + with columns[1]: + df, unique_images, coordinate_scale_factor, annulus_spacing_um, xy_position_columns = initialize_radial_bin_calculation(df) + + # Draw a divider + st.divider() + + # Main settings section + column_to_plot, image_to_view, use_coordinate_mins_and_maxs, xmin_col, xmax_col, ymin_col, ymax_col, units, invert_y_axis, opacity, color_dict, values_to_plot, _ = initialize_main_settings(df, unique_images) + + # Output/plotting section + st.divider() + + # Draw a single image scatter plot + draw_single_image_scatter_plot(df, image_to_view, column_to_plot, values_to_plot, color_dict, xy_position_columns=xy_position_columns, coordinate_scale_factor=coordinate_scale_factor, annulus_spacing_um=annulus_spacing_um, use_coordinate_mins_and_maxs=use_coordinate_mins_and_maxs, xmin_col=xmin_col, xmax_col=xmax_col, ymin_col=ymin_col, ymax_col=ymax_col, units=units, invert_y_axis=invert_y_axis, opacity=opacity) + + # Ensure the main dataframe is updated per the operations above + st.session_state['input_dataset'].data = df + + +# Run the main function +if __name__ == '__main__': + main() diff --git a/pages2/radial_profiles_analysis.py b/pages2/radial_profiles_analysis.py new file mode 100644 index 0000000..2ab4c58 --- /dev/null +++ b/pages2/radial_profiles_analysis.py @@ -0,0 +1,415 @@ +# Import relevant libraries +import streamlit as st +import pandas as pd +import scipy.stats +import numpy as np +import plotly.graph_objects as go +import plotly.express as px + +# Global variable +st_key_prefix = 'radial_profiles_plotting__' + + +def calculate_significant_differences_between_groups(percent_positives, well_id_indices1, well_id_indices2, unique_time_vals, unique_outer_radii, confidence_level=0.95): + + # Initialize the flags array + flags = np.ones((len(unique_time_vals), len(unique_outer_radii))) * np.nan + + # For every unique time and outer radius...
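+    # Flag convention used below: 1 means the bootstrap confidence interval of the difference in group means (group2 - group1) lies entirely above 0, -1 means it lies entirely below 0, 0 means the interval straddles 0, and NaN means fewer than two non-NaN wells were available in one of the groups for that (timepoint, outer radius) combination.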
+ for itime in range(len(unique_time_vals)): + for iradius in range(len(unique_outer_radii)): + + # Get the percent positives for the two groups + percent_positives_group1 = percent_positives[well_id_indices1, itime, iradius] + percent_positives_group2 = percent_positives[well_id_indices2, itime, iradius] + + # Remove any NaNs + percent_positives_group1 = percent_positives_group1[~np.isnan(percent_positives_group1)] + percent_positives_group2 = percent_positives_group2[~np.isnan(percent_positives_group2)] + + # If there are fewer than two non-NaN values in either group, skip this iteration + if (len(percent_positives_group1) < 2) or (len(percent_positives_group2) < 2): + continue + + # Calculate the confidence intervals for the two groups + res = scipy.stats.bootstrap((percent_positives_group1, percent_positives_group2), lambda arr1, arr2: arr2.mean(axis=0) - arr1.mean(axis=0), confidence_level=confidence_level) + + # Determine the flag + if res.confidence_interval.low > 0: + flags[itime, iradius] = 1 + elif res.confidence_interval.high < 0: + flags[itime, iradius] = -1 + else: + flags[itime, iradius] = 0 + + # Create the heatmap + heatmap = go.Heatmap( + x=unique_outer_radii, + y=unique_time_vals, + z=flags, + colorscale='Jet', + zmin=-1, + zmax=1 + ) + + # Create a figure and add the heatmap + fig = go.Figure(data=[heatmap]) + + # Customize layout + fig.update_layout( + title='Flags Indicating Significant Differences (group2 - group1)', + xaxis_title='Outer radius (um)', + yaxis_title='Timepoint' + ) + + # Return the figure + return fig + + +def get_heatmap(percent_positives, well_id_indices, unique_time_vals, unique_outer_radii): + + # Create the heatmap + heatmap = go.Heatmap( + x=unique_outer_radii, + y=unique_time_vals, + z=np.nanmean(percent_positives[well_id_indices, :, :], axis=0), + colorscale='Jet', + zmin=0, + zmax=100 + ) + + # Create a figure and add the heatmap + fig = go.Figure(data=[heatmap]) + + # Customize layout + fig.update_layout( + title='Average Percent Positive Over Selected Wells', + xaxis_title='Outer radius (um)', + yaxis_title='Timepoint' + ) + + # Return the figure + return fig + + +def calculate_percent_positives(df, phenotype_column_for_analysis, unique_well_ids, unique_time_vals, unique_outer_radii): + + # Create a 3D array to hold the percent positives + percent_positives = np.ones((len(unique_well_ids), len(unique_time_vals), len(unique_outer_radii))) * np.nan + + # For each well... + for well_id, df_group in df.groupby('well_id'): + + # Get the location of the well_id in the unique_well_ids list + well_id_loc = unique_well_ids.index(well_id) + + # For each unique combination of time and outer radius... 
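+        # The resulting array is indexed as (well, timepoint, outer radius); positivity is assumed to be encoded as the integer 1 in the selected phenotype column, and combinations with no cells are left as NaN.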
+ for (time_val, outer_radius), df_group2 in df_group.groupby(by=['T', 'Outer radius']): + + # Get the location of the time_val and outer_radius in their respective lists + time_val_loc = unique_time_vals.index(time_val) + outer_radius_loc = unique_outer_radii.index(outer_radius) + + # Calculate the percent positive for the current group + ser = df_group2[phenotype_column_for_analysis] + percent_positives[well_id_loc, time_val_loc, outer_radius_loc] = (ser == 1).sum() / len(ser) * 100 + + # Return the percent positives + return percent_positives + + +def get_line_plots(percent_positives, well_id_indices, plot_confidence_intervals, unique_vals_for_series, position_in_percent_positives, series_name, unique_vals_for_x, xaxis_title, alpha=0.1, ci_type='bootstrap'): + + # Potentially permute the axes of percent_positives + if position_in_percent_positives == 1: + percent_positives_transposed = percent_positives + else: + percent_positives_transposed = np.transpose(percent_positives, axes=(0, 2, 1)) + + # Get the default colors to use for both the main lines and the shaded confidence intervals + colors_to_use = get_default_colors(len(unique_vals_for_series)) + colors_to_use_with_alpha = get_default_colors(len(unique_vals_for_series), alpha=alpha) + + # Initialize the plotly figure + fig = go.Figure() + + # For each series... + method_used_holder = [] + bootstrap_flag_holder = [] + for series_val_index, series_val in enumerate(unique_vals_for_series): + + # Get the percent positives for the current series + curr_percent_positives = percent_positives_transposed[well_id_indices, series_val_index, :] + + # Calculate the confidence intervals for the current series + confidence_intervals, bootstrap_flag, method_used = get_confidence_intervals(curr_percent_positives, ci_type=ci_type) + method_used_holder.append(method_used) + bootstrap_flag_holder.append(bootstrap_flag) + + # Optionally plot the confidence intervals + if plot_confidence_intervals: + + # Plot the lower bound of the confidence interval + fig.add_trace(go.Scatter(x=unique_vals_for_x, y=confidence_intervals[0, :], mode='lines', line=dict(width=0), showlegend=False)) + + # Plot the upper bound of the confidence interval with filling to the next Y (the lower bound) + fig.add_trace(go.Scatter(x=unique_vals_for_x, y=confidence_intervals[1, :], mode='lines', fill='tonexty', fillcolor=colors_to_use_with_alpha[series_val_index], line=dict(width=0), showlegend=True, name=f'{series_name} {series_val} CI')) # color format example: 'rgba(0,100,80,0.2)' + + # For each series... 
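+    # The mean lines are added in this second pass, after all of the confidence-interval traces, so that Plotly draws them on top of the shaded bands.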
+ for series_val_index, series_val in enumerate(unique_vals_for_series): + + # Get the percent positives for the current series + curr_percent_positives = percent_positives_transposed[well_id_indices, series_val_index, :] + + # Calculate the means for the current series + y = np.nanmean(curr_percent_positives, axis=0) + + # Plot the means + fig.add_trace(go.Scatter(x=unique_vals_for_x, y=y, mode='lines+markers', name=f'{series_name} {series_val}', line=dict(color=colors_to_use[series_val_index]), marker=dict(color=colors_to_use[series_val_index]))) + + # Update the layout of the figure + fig.update_layout(title=f'Percent positive averaged over all selected wells', xaxis_title=xaxis_title, yaxis_title='Percent positive (%)') + + # Display a warning if we couldn't calculate potentially desired bootstrap confidence intervals for some of the data + if np.any(bootstrap_flag_holder): + st.write('⚠️ Normal confidence intervals were calculated for some of the data instead of the selected bootstrap method.') + with st.expander('Expand to see which data used normal confidence intervals due to there being fewer than two non-NaN well_ids:', expanded=False): + st.dataframe(pd.DataFrame(method_used_holder, index=unique_vals_for_series, columns=unique_vals_for_x)) + + # Return the plotly figure + return fig + + +def hex_to_rgb(hex_color): + # Remove the '#' character and convert the remaining string to an integer using base 16 + # Then extract each color component + hex_color = hex_color.lstrip('#') + r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16) + return (r, g, b) + + +def get_default_colors(num_colors, alpha=None): + + # Get the color sequence + color_sequence = px.colors.qualitative.Plotly + + # Return a list of num_colors colors, cycling through color_sequence if necessary + hex_colors = [color_sequence[i % len(color_sequence)] for i in range(num_colors)] + + # Optionally add an alpha value to each color and if so return the rgba values; otherwise return the hex values + if alpha is not None: + return [f'rgba{hex_to_rgb(hex_color) + (alpha,)}' for hex_color in hex_colors] + else: + return hex_colors + + +def get_confidence_intervals(array2d, ci_type='bootstrap'): + size_of_second_dim = array2d.shape[1] + ci = np.ones((2, size_of_second_dim)) * np.nan + method_used = [] + bootstrap_flag = False + for i in range(size_of_second_dim): + curr_set_of_data = array2d[:, i] + curr_set_of_data = curr_set_of_data[~np.isnan(curr_set_of_data)] # select out just the non-nan values in curr_set_of_data + if (ci_type == 'bootstrap') and (len(curr_set_of_data) < 2): + bootstrap_flag = True + ci_type_to_use = 'normal' + else: + ci_type_to_use = ci_type + method_used.append(ci_type_to_use) + curr_ci = get_confidence_interval(pd.Series(curr_set_of_data), ci_type=ci_type_to_use) + ci[:, i] = curr_ci + return ci, bootstrap_flag, method_used + + +# Calculate the 95% confidence interval of a series +def get_confidence_interval(ser, ci_type='bootstrap'): + assert ci_type in ['normal', 'bootstrap'], 'ci_type must be either "normal" or "bootstrap"' + if ci_type == 'normal': + # This is a common approach but works well primarily when the sample size is large (usually n > 30) and the data distribution is not heavily skewed.
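+        # 1.96 is the two-sided z critical value for a 95% confidence level, applied to the standard error of the mean.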
+ mean = ser.mean() + margin_of_error = ser.sem() * 1.96 + return mean - margin_of_error, mean + margin_of_error + elif ci_type == 'bootstrap': + # Largely distribution-independent but sample sizes less than 10 should be interpreted with caution + res = scipy.stats.bootstrap((ser,), np.mean, confidence_level=0.95) # calculate the bootstrap confidence interval + confidence_interval = res.confidence_interval # extract the confidence interval + return confidence_interval.low, confidence_interval.high + + +def main(): + """ + Main function for the page. + """ + + # Ensure the user has loaded a dataset + if 'input_dataset' not in st.session_state: + st.warning('Please open a dataset from the Open File page at left.') + return + + # Save a shortcut to the dataframe + df = st.session_state['input_dataset'].data + + # Obtain the phenotype columns + phenotype_columns = [column for column in df.columns if column.startswith('Phenotype ')] + + # Ensure phenotype columns exist + if len(phenotype_columns) == 0: + st.warning('No phenotype columns found in the dataset. Please run the Adaptive Phenotyping page at left.') + return + + # Make sure all of the columns in ['T', 'REEC', 'well_id', 'Outer radius'] are present in df + if not all([col in df.columns for col in ['T', 'REEC', 'well_id', 'Outer radius']]): + st.warning('The columns "T", "REEC", "well_id", and "Outer radius" must be present in the dataset.') + return + + # Keep only the necessary columns + df = df[['Slide ID', 'T', 'REEC', 'well_id', 'Outer radius'] + phenotype_columns] + + # Get the unique values of particular columns of interest + unique_well_ids = sorted(df['well_id'].unique()) + unique_time_vals = sorted(df['T'].unique()) + unique_outer_radii = sorted(df['Outer radius'].unique()) + + with st.columns(3)[0]: + + # Select a phenotype column on which to perform the analysis + key = st_key_prefix + 'phenotype_column_for_analysis' + if key not in st.session_state: + st.session_state[key] = phenotype_columns[0] + phenotype_column_for_analysis = st.selectbox('Select a phenotype column on which to perform the analysis:', phenotype_columns, key=key) + + # Calculate the percent positives + if st.button('Calculate the percent positives'): + + # Save the percent_positives array to the session state + st.session_state[st_key_prefix + 'percent_positives'] = calculate_percent_positives(df, phenotype_column_for_analysis, unique_well_ids, unique_time_vals, unique_outer_radii) + + # Ensure the percent positives have been calculated + key = st_key_prefix + 'percent_positives' + if key not in st.session_state: + st.warning('Please calculate the percent positives.') + return + + # Save a shortcut to the percent positives + percent_positives = st.session_state[key] + + # Get the number of nans in percent_positives and if there are any, print out where they are + if np.isnan(percent_positives).sum() > 0: + st.write('⚠️ There are NaNs in the percent_positives array.') + with st.expander('Expand to see where the NaNs are located (well ID, time, outer radius):', expanded=False): + for i in range(len(unique_well_ids)): + for j in range(len(unique_time_vals)): + for k in range(len(unique_outer_radii)): + if np.isnan(percent_positives[i, j, k]): + st.write(unique_well_ids[i], unique_time_vals[j], unique_outer_radii[k]) + + # Obtain the wells and their properties (just the REEC for now). 
This takes 0.10 to 0.15 seconds + df_to_select = df[['well_id', 'REEC']].drop_duplicates().sort_values(['well_id', 'REEC']) + + # Allow the user to select more than one set of wells + key = st_key_prefix + 'select_two_groups_of_wells' + if key not in st.session_state: + st.session_state[key] = False + select_two_groups_of_wells = st.checkbox('Select two groups of wells', key=key) + + # Initialize columns appropriately + if select_two_groups_of_wells: + group1_column, group2_column = st.columns(2) + else: + group1_column = st.columns(1)[0] + + # Get the user selection from this dataframe + with group1_column: + if not select_two_groups_of_wells: + st.write('Select the well(s) whose average percent positives we will plot:') + else: + st.write('Select the first group of wells:') + selected_rows = st.dataframe(df_to_select, on_select='rerun', hide_index=True, key='group1_well_selection__do_not_persist')['selection']['rows'] + ser_selected_well_ids = df_to_select.iloc[selected_rows]['well_id'] + + # If the user wants to select two groups of wells, allow them to select the second group + if select_two_groups_of_wells: + with group2_column: + st.write('Select the second group of wells:') + selected_rows = st.dataframe(df_to_select, on_select='rerun', hide_index=True, key='group2_well_selection__do_not_persist')['selection']['rows'] + ser_selected_well_ids2 = df_to_select.iloc[selected_rows]['well_id'] + + # Ensure at least one well is selected + if len(ser_selected_well_ids) == 0: + if not select_two_groups_of_wells: + st.warning('No wells selected. Please select them from the left side of the table above.') + else: + st.warning('No wells selected for the first group. Please select them from the left side of the left table above.') + return + + # If the user wants to select two groups of wells, ensure at least one well is selected for the second group + if select_two_groups_of_wells and len(ser_selected_well_ids2) == 0: + st.warning('No wells selected for the second group. 
Please select them from the left side of the right table above.') + return + + # Obtain the indices of the selected wells in unique_well_ids + well_id_indices = [unique_well_ids.index(selected_well_id) for selected_well_id in ser_selected_well_ids] + + # If the user wants to select two groups of wells, obtain the indices of the selected wells in unique_well_ids + if select_two_groups_of_wells: + well_id_indices2 = [unique_well_ids.index(selected_well_id) for selected_well_id in ser_selected_well_ids2] + + # Checkbox for whether to plot confidence intervals + key = st_key_prefix + 'plot_confidence_intervals' + if key not in st.session_state: + st.session_state[key] = True + plot_confidence_intervals = st.checkbox('Plot confidence intervals (at least two wells must be selected)', key=key) + + # Set some widget defaults if they don't exist + key = st_key_prefix + 'ci_type' + if key not in st.session_state: + st.session_state[key] = 'bootstrap' + ci_type = st.session_state[key] + key = st_key_prefix + 'alpha' + if key not in st.session_state: + st.session_state[key] = 0.2 + alpha = st.session_state[key] + + # Allow the user to customize their values if they matter + if plot_confidence_intervals: + + # Radio button to select the type of confidence interval + ci_type = st.radio('Select the type of confidence interval:', ['bootstrap', 'normal'], key=st_key_prefix + 'ci_type') + + # Number input for the alpha value + alpha = st.number_input('Alpha value:', min_value=0.0, max_value=1.0, key=st_key_prefix + 'alpha') + + # Button to generate line plots + if st.button('Generate line plots'): + + # For a given list of well_id indices, create a plotly lineplot on the percent_positives array using the radii on the x-axis and the percent positives on the y-axis + st.plotly_chart(get_line_plots(percent_positives, well_id_indices, plot_confidence_intervals, unique_time_vals, 1, 'Time', unique_outer_radii, 'Outer radius (um)', alpha=alpha, ci_type=ci_type)) + + # For a given list of well_id indices, create a plotly lineplot on the percent_positives array using the time values on the x-axis and the percent positives on the y-axis + st.plotly_chart(get_line_plots(percent_positives, well_id_indices, plot_confidence_intervals, unique_outer_radii, 2, 'Bin', unique_time_vals, 'Time', alpha=alpha, ci_type=ci_type)) + + # Button to generate heatmaps + if st.button('Generate heatmaps'): + + # Display the heatmap + st.plotly_chart(get_heatmap(percent_positives, well_id_indices, unique_time_vals, unique_outer_radii)) + + # If the user wants to compare two groups of wells, allow them to set the confidence level + if select_two_groups_of_wells: + with st.columns(3)[0]: + key = st_key_prefix + 'confidence_level' + if key not in st.session_state: + st.session_state[key] = 0.95 + confidence_level = st.number_input('Confidence level:', min_value=0.0, max_value=1.0, key=key, format='%.2f') + + # Button to compare differences between the two groups of wells + if select_two_groups_of_wells and st.button('Assess whether the two group means are significantly different'): + + # Calculate the flags and display the heatmap + st.plotly_chart(calculate_significant_differences_between_groups(percent_positives, well_id_indices, well_id_indices2, unique_time_vals, unique_outer_radii, confidence_level=confidence_level)) + + +# Run the main function +if __name__ == '__main__': + main() diff --git a/pages2/results_transfer.py b/pages2/results_transfer.py new file mode 100644 index 0000000..04dc9a8 --- /dev/null +++ 
b/pages2/results_transfer.py @@ -0,0 +1,126 @@ +# Import relevant libraries +import streamlit as st +import os +import pandas as pd +from datetime import datetime +import pytz +import zipfile +import nidap_io + +# Global variable +st_key_prefix = 'results_transfer__' + + +def zip_files(file_paths, output_zip): + """ + Zips the given list of file paths into a single zip file. + + :param file_paths: List of file paths to include in the zip file. + :param output_zip: The path to the output zip file. + """ + with zipfile.ZipFile(output_zip, 'w') as zipf: + for file in file_paths: + # Add file to the zip file + zipf.write(file, os.path.basename(file)) + + +def main(): + """ + Main function for the page. + """ + + # Output directory + output_dir = 'output' + + # Get the list of files in the directory + dir_listing = os.listdir(output_dir) + + # Create a list to hold file information + file_info = [] + + # Define the EDT timezone + edt = pytz.timezone('US/Eastern') + + # For each file in the directory... + for file_name in dir_listing: + + # Get the full file path + file_path = os.path.join(output_dir, file_name) + + # Get the last modification time + mod_time = os.path.getmtime(file_path) + + # Convert the timestamp to a datetime object + mod_time_dt = datetime.fromtimestamp(mod_time) + + # Localize the datetime object to UTC and then convert to EDT + mod_time_edt = mod_time_dt.astimezone(edt) + + # Format the datetime object to a human-readable string + mod_time_readable = mod_time_edt.strftime('%Y-%m-%d %H:%M:%S %Z') + + # Append the file information to the list + file_info.append({'File Name': file_name, 'Last Modified': mod_time_readable}) + + # Create a DataFrame from the file information + st.write('Files in the output directory:') + df = pd.DataFrame(file_info) + + # Display the DataFrame in Streamlit + selected_files = st.dataframe(df, hide_index=True, on_select='rerun') + + # On the first of three columns... + with st.columns(3)[0]: + + # Add a text input for the zip file name + key = st_key_prefix + 'zipfile_name' + if key not in st.session_state: + st.session_state[key] = 'output.zip' + st.text_input('Zip file name:', key=key) + zipfile_path = os.path.join(output_dir, st.session_state[key]) + + # Check if we're on NIDAP + on_nidap = st.session_state['platform'].platform == 'nidap' + + # Add a button to zip (and transfer, if on NIDAP) the selected files + if st.button('Zip and transfer to NIDAP' if on_nidap else 'Zip'): + + # Ensure selection exists + if 'selection' in selected_files and 'rows' in selected_files['selection']: + + # Get the selected files + selected_files_list = df.iloc[selected_files['selection']['rows']]['File Name'].tolist() + + # Check if any files were selected + if selected_files_list: + + # Zip the selected files + zip_files([os.path.join(output_dir, file) for file in selected_files_list], zipfile_path) + + # Display a success message + st.success('Files zipped') + + # If we're on NIDAP... 
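+                    # The zip file has already been written to the local output directory above; on NIDAP it is additionally uploaded to the output dataset via nidap_io.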
+ if on_nidap: + + # Get the output dataset + dataset = nidap_io.get_foundry_dataset(alias='output') + + # Upload the zip file to the dataset + nidap_io.upload_file_to_dataset(dataset, zipfile_path) + + # Display a success message + st.success('Zip file transferred to NIDAP') + + # Otherwise, display an error message + else: + st.warning('No files selected to zip!') + else: + st.warning('No files selected to zip!') + + +# Run the main function +if __name__ == '__main__': + + # Call the main function + main() diff --git a/pages/robust_scatter_plotter.py b/pages2/robust_scatter_plotter.py similarity index 75% rename from pages/robust_scatter_plotter.py rename to pages2/robust_scatter_plotter.py index 20fdb1a..1f0959e 100644 --- a/pages/robust_scatter_plotter.py +++ b/pages2/robust_scatter_plotter.py @@ -1,13 +1,15 @@ # Import relevant libraries import streamlit as st -import app_top_of_page as top -import streamlit_dataframe_editor as sde import plotly.graph_objects as go import plotly.express as px import pandas as pd from itertools import cycle, islice +def turn_off_plotting(): + st.session_state['rsp__show_scatter_plot'] = False + + def update_color_for_value(value_to_change_color): st.session_state['rsp__color_dict'][value_to_change_color] = st.session_state['rsp__new_picked_color'] @@ -60,10 +62,7 @@ def go_to_next_image(unique_images): st.session_state['rsp__image_to_view'] = unique_images[current_index] -def main(): - """ - Main function for the page. - """ +def draw_scatter_plot_with_options(): # Define the main settings columns settings_columns_main = st.columns(3) @@ -78,6 +77,11 @@ def main(): input_dataset_has_changed = ('rsp__data_to_plot_prev' not in st.session_state) or (st.session_state['rsp__data_to_plot_prev'] != data_to_plot) st.session_state['rsp__data_to_plot_prev'] = data_to_plot + # If they want to plot phenotyped data, ensure they've performed phenotyping + if (data_to_plot == 'Input data') and ('input_dataset' not in st.session_state): + st.warning('If you\'d like to plot the input data, please open a file first.') + return + # If they want to plot phenotyped data, ensure they've performed phenotyping if (data_to_plot == 'Phenotyped data') and (len(st.session_state['df']) == 1): st.warning('If you\'d like to plot the phenotyped data, please perform phenotyping first.') @@ -85,13 +89,22 @@ def main(): # Set the shortcut to the dataframe of interest if data_to_plot == 'Input data': + if 'input_dataset' not in st.session_state: + st.warning('Please open a dataset first using the Open File page at left.') + return df = st.session_state['input_dataset'].data else: df = st.session_state['df'] # Store columns of certain types - if ('rsp__categorical_columns' not in st.session_state) or input_dataset_has_changed: - st.session_state['rsp__categorical_columns'] = df.select_dtypes(include=('category', 'object')).columns + if st.button('Re-extract columns from dataset') or ('rsp__categorical_columns' not in st.session_state) or input_dataset_has_changed: + max_num_unique_values = 1000 + categorical_columns = [] + for col in df.select_dtypes(include=('category', 'object')).columns: + if not isinstance(df[col].iloc[0], list): + if df[col].nunique() <= max_num_unique_values: + categorical_columns.append(col) + st.session_state['rsp__categorical_columns'] = categorical_columns if ('rsp__numeric_columns' not in st.session_state) or input_dataset_has_changed: st.session_state['rsp__numeric_columns'] = df.select_dtypes(include='number').columns categorical_columns = 
st.session_state['rsp__categorical_columns'] @@ -100,7 +113,7 @@ def main(): # Choose a column to plot if ('rsp__column_to_plot' not in st.session_state) or input_dataset_has_changed: st.session_state['rsp__column_to_plot'] = categorical_columns[0] - column_to_plot = st.selectbox('Select a column by which to color the points:', categorical_columns, key='rsp__column_to_plot') + column_to_plot = st.selectbox('Select a column by which to color the points:', categorical_columns, key='rsp__column_to_plot', help='If you don\'t see the column you want, you may need to re-extract the columns from the dataset using the button above.') column_to_plot_has_changed = ('rsp__column_to_plot_prev' not in st.session_state) or (st.session_state['rsp__column_to_plot_prev'] != column_to_plot) or input_dataset_has_changed st.session_state['rsp__column_to_plot_prev'] = column_to_plot @@ -119,6 +132,11 @@ def main(): # Display the number of cells in the selected image st.write(f'Number of cells in image: {ser_size_of_each_image.loc[image_to_view]}') + # Allow sampling of the scatter plot + if 'rsp__sample_percent' not in st.session_state: + st.session_state['rsp__sample_percent'] = 100 + sample_percent = st.number_input('Sample percent:', min_value=1, max_value=100, step=1, key='rsp__sample_percent') + # Optionally navigate through the images using Previous and Next buttons cols = st.columns(2) with cols[0]: @@ -132,7 +150,7 @@ def main(): # Optionally plot minimum and maximum coordinate fields if 'rsp__use_coordinate_mins_and_maxs' not in st.session_state: st.session_state['rsp__use_coordinate_mins_and_maxs'] = False - use_coordinate_mins_and_maxs = st.checkbox('Use coordinate mins and maxs', key='rsp__use_coordinate_mins_and_maxs') + use_coordinate_mins_and_maxs = st.checkbox('Use coordinate mins and maxs', key='rsp__use_coordinate_mins_and_maxs', on_change=turn_off_plotting) settings_columns_refined = st.columns(2) if 'rsp__x_min_coordinate_column' not in st.session_state: st.session_state['rsp__x_min_coordinate_column'] = numeric_columns[0] @@ -143,13 +161,13 @@ def main(): if 'rsp__y_max_coordinate_column' not in st.session_state: st.session_state['rsp__y_max_coordinate_column'] = numeric_columns[0] with settings_columns_refined[0]: - xmin_col = st.selectbox('Select a column for the minimum x-coordinate:', numeric_columns, key='rsp__x_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + xmin_col = st.selectbox('Select a column for the minimum x-coordinate:', numeric_columns, key='rsp__x_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs), on_change=turn_off_plotting) with settings_columns_refined[1]: - xmax_col = st.selectbox('Select a column for the maximum x-coordinate:', numeric_columns, key='rsp__x_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + xmax_col = st.selectbox('Select a column for the maximum x-coordinate:', numeric_columns, key='rsp__x_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs), on_change=turn_off_plotting) with settings_columns_refined[0]: - ymin_col = st.selectbox('Select a column for the minimum y-coordinate:', numeric_columns, key='rsp__y_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + ymin_col = st.selectbox('Select a column for the minimum y-coordinate:', numeric_columns, key='rsp__y_min_coordinate_column', disabled=(not use_coordinate_mins_and_maxs), on_change=turn_off_plotting) with settings_columns_refined[1]: - ymax_col = st.selectbox('Select a column for the maximum 
y-coordinate:', numeric_columns, key='rsp__y_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs)) + ymax_col = st.selectbox('Select a column for the maximum y-coordinate:', numeric_columns, key='rsp__y_max_coordinate_column', disabled=(not use_coordinate_mins_and_maxs), on_change=turn_off_plotting) units = ('coordinate units' if use_coordinate_mins_and_maxs else 'microns') # Optionally add another filter @@ -213,7 +231,7 @@ def main(): filter_loc = pd.Series(True, index=df.index) # Filter the DataFrame to include only the selected image and filter - df_selected_image_and_filter = df[(df['Slide ID'] == image_to_view) & filter_loc] + df_selected_image_and_filter = df[(df['Slide ID'] == image_to_view) & filter_loc].sample(frac=sample_percent / 100) # Group the DataFrame for the selected image by unique value of the column to plot selected_image_grouped_by_value = df_selected_image_and_filter.groupby(column_to_plot) @@ -233,6 +251,8 @@ def main(): # If value is a string, replace '(plus)' with '+' and '(dash)' with '-', since it could likely be a phenotype with those substitutions if isinstance(value_to_plot, str): value_str_cleaned = value_to_plot.replace('(plus)', '+').replace('(dash)', '-') + else: + value_str_cleaned = value_to_plot # Add the object index to the label df_group['hover_label'] = 'Index: ' + df_group.index.astype(str) @@ -276,29 +296,53 @@ def main(): # Plot the plotly chart in Streamlit st.plotly_chart(fig, use_container_width=True) + # Attempt to get page to not scroll up to the top after the plot is drawn... doesn't seem to work here, though note that toggling on the box and whisker plot does prevent this snapping to the top + st.write(' ') + + # Return necessary variables + return df, column_to_plot, values_to_plot, categorical_columns, unique_images + + +def main(): + """ + Main function for the page. 
+ """ + + if (return_values := draw_scatter_plot_with_options()) is None: + return + + df, column_to_plot, values_to_plot, categorical_columns, unique_images = return_values + if 'rsp__get_percent_frequencies' not in st.session_state: st.session_state['rsp__get_percent_frequencies'] = False if st.toggle('Get percent frequencies of coloring column for entire dataset', key='rsp__get_percent_frequencies'): vc = df[column_to_plot].value_counts() - st.dataframe((df[column_to_plot].value_counts() / vc.sum() * 100).astype(int).reset_index()) + st.dataframe((vc / vc.sum() * 100).astype(int).reset_index()) + + if 'rsp__generate_box_and_whisker_plot' not in st.session_state: + st.session_state['rsp__generate_box_and_whisker_plot'] = False + if st.toggle('Generate box and whisker plot', key='rsp__generate_box_and_whisker_plot'): + with st.columns(3)[0]: + if 'rsp__box_and_whisker_plot_value' not in st.session_state: + st.session_state['rsp__box_and_whisker_plot_value'] = values_to_plot[0] + box_and_whisker_plot_value = st.selectbox('Select a value in the selected coloring column to analyze:', values_to_plot, key='rsp__box_and_whisker_plot_value') + if 'rsp__column_identifying_trace' not in st.session_state: + st.session_state['rsp__column_identifying_trace'] = categorical_columns[0] + column_identifying_trace = st.selectbox('Select a column to identify different values/traces to plot:', categorical_columns, key='rsp__column_identifying_trace') + match_loc = df[column_to_plot] == box_and_whisker_plot_value + percent_match_holder = [] + trace_value_holder = [] + for image in unique_images: + image_loc = df['Slide ID'] == image + set_trace_values = set(df.loc[image_loc, column_identifying_trace]) + assert len(set_trace_values) == 1, 'There should only be one value for the column identifying the trace' + trace_value_holder.append(set_trace_values.pop()) + percent_match_holder.append((image_loc & match_loc).sum() / image_loc.sum() * 100) + df_boxplot = pd.DataFrame({'Image': unique_images, 'Percent': percent_match_holder, 'Trace': trace_value_holder}) + fig = px.box(df_boxplot, x='Trace', y='Percent', title=f'Box and whisker plot for {box_and_whisker_plot_value}', points='all') + st.plotly_chart(fig, use_container_width=True) + - # Run the main function if __name__ == '__main__': - - # Set page settings - page_name = 'Coordinate Scatter Plotter' - st.set_page_config(layout='wide', page_title=page_name) - st.title(page_name) - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - - # Call the main function main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/pages/skeleton.py b/pages2/skeleton.py similarity index 100% rename from pages/skeleton.py rename to pages2/skeleton.py diff --git a/pages2/skeleton2.py b/pages2/skeleton2.py new file mode 100644 index 0000000..85d7c90 --- /dev/null +++ b/pages2/skeleton2.py @@ -0,0 +1,19 @@ +# Much simpler now (vs. skeleton.py) as the top and bottom matter is now located in Multiplex_Analysis_Web_Apps.py using the new Streamlit multipage functionality + +# Import relevant libraries +import streamlit as st + + +def main(): + """ + Main function for the page. 
+ """ + + st.write('Insert your code here.') + + +# Run the main function +if __name__ == '__main__': + + # Call the main function + main() diff --git a/pages/spatial_umap_prediction_app.py b/pages2/spatial_umap_prediction_app.py similarity index 99% rename from pages/spatial_umap_prediction_app.py rename to pages2/spatial_umap_prediction_app.py index 5b99e4c..6c91cff 100644 --- a/pages/spatial_umap_prediction_app.py +++ b/pages2/spatial_umap_prediction_app.py @@ -7,7 +7,7 @@ import streamlit as st import streamlit_dataframe_editor as sde import app_top_of_page as top -from pages import sit_03a_Tool_parameter_selection as sit +from pages2 import sit_03a_Tool_parameter_selection as sit import new_phenotyping_lib import utils diff --git a/pages/02_phenotyping.py b/pages2/thresholded_phenotyping.py similarity index 90% rename from pages/02_phenotyping.py rename to pages2/thresholded_phenotyping.py index ef7b6f2..7c7bad6 100644 --- a/pages/02_phenotyping.py +++ b/pages2/thresholded_phenotyping.py @@ -5,11 +5,8 @@ import pandas as pd import streamlit as st from streamlit_extras.add_vertical_space import add_vertical_space - -# Import relevant libraries import nidap_dashboard_lib as ndl # Useful functions for dashboards connected to NIDAP import basic_phenotyper_lib as bpl # Useful functions for phenotyping collections of cells -import app_top_of_page as top import streamlit_dataframe_editor as sde def data_editor_change_callback(): @@ -25,11 +22,7 @@ def data_editor_change_callback(): # Create Phenotypes Summary Table based on 'phenotype' column in df st.session_state.pheno_summ = bpl.init_pheno_summ(st.session_state.df) - # Perform filtering - st.session_state.df_filt = ndl.perform_filtering(st.session_state) - - # Set Figure Objects based on updated df - st.session_state = ndl.setFigureObjs(st.session_state, st.session_state.pointstSliderVal_Sel) + filter_and_plot(plot_by_slider = True) def slide_id_prog_left_callback(): ''' @@ -59,7 +52,7 @@ def slide_id_callback(): st.session_state['selSlide ID'] = st.session_state['uniSlide ID'][st.session_state['idxSlide ID']] filter_and_plot() -def filter_and_plot(): +def filter_and_plot(plot_by_slider = False): ''' function to update the filtering and the figure plotting ''' @@ -72,11 +65,18 @@ def filter_and_plot(): if st.session_state['idxSlide ID'] == st.session_state['numSlide ID']-1: st.session_state.prog_right_disabeled = True + if plot_by_slider: + slider_val = st.session_state.point_slider_val + else: + slider_val = None + # Filtered dataset df_filt = ndl.perform_filtering(st.session_state) # Update and reset Figure Objects - st.session_state = ndl.setFigureObjs(st.session_state, df_filt) + st.session_state = ndl.set_figure_objs(session_state = st.session_state, + df_plot = df_filt, + slider_val = slider_val) def marker_multiselect_callback(): ''' @@ -263,12 +263,9 @@ def main(): plot_slide = st.columns(2) with plot_slide[0]: - with st.form('Plotting Num'): - st.slider('How many points to plot (%)', 0, 100, key = 'pointstSliderVal_Sel') - update_pixels_button = st.form_submit_button('Update Scatterplot') - if update_pixels_button: - st.session_state = ndl.setFigureObjs(st.session_state, - st.session_state.pointstSliderVal_Sel) + st.slider('How many points to plot (%)', 0, 100, key = 'point_slider_val', + on_change = filter_and_plot, kwargs = {"plot_by_slider": True}) + with plot_slide[1]: if st.session_state.calcSliderVal < 100: @@ -279,7 +276,8 @@ def main(): st.write(f'Drawing {st.session_state.drawnPoints} points') st.checkbox('Omit drawing 
cells with all negative markers', key = 'selhas_pos_mark', - on_change=filter_and_plot) + on_change=filter_and_plot, + kwargs = {"plot_by_slider": True}) image_prog_col = st.columns([3, 1, 1, 2]) with image_prog_col[0]: @@ -311,19 +309,4 @@ def main(): st.toast(f'Added {st.session_state.imgFileSuffixText} to export list ') if __name__ == '__main__': - - # Set a wide layout - st.set_page_config(page_title="Manual Phenotyping on Thresholded Intensities", - layout="wide") - st.title('Manual Phenotyping on Thresholded Intensities') - - # Run streamlit-dataframe-editor library initialization tasks at the top of the page - st.session_state = sde.initialize_session_state(st.session_state) - - # Run Top of Page (TOP) functions - st.session_state = top.top_of_page_reqs(st.session_state) - main() - - # Run streamlit-dataframe-editor library finalization tasks at the bottom of the page - st.session_state = sde.finalize_session_state(st.session_state) diff --git a/platform_io.py b/platform_io.py index 3d52160..ddb37e8 100644 --- a/platform_io.py +++ b/platform_io.py @@ -9,9 +9,8 @@ import pandas as pd import streamlit as st import streamlit_dataframe_editor as sde -import streamlit_session_state_management import utils -from pages import memory_analyzer +from pages2 import memory_analyzer # Constant local_input_dir = os.path.join('.', 'input') @@ -274,10 +273,10 @@ def load_selected_inputs(self): # If on NIDAP... elif self.platform == 'nidap': - st.subheader(':tractor: Load input data') + st.subheader(':tractor: Load input data into MAWA') # If a load button is clicked... - if st.button('Load selected (at left) input data :arrow_right:'): + if st.button('Load selected NIDAP input data :arrow_right:'): # Import relevant libraries import nidap_io @@ -313,9 +312,9 @@ def load_selected_inputs(self): if selected_input_filename.endswith('.zip'): splitted = selected_input_filename.split('.') # should be of length 2 or 3 (for, e.g., asdf.csv.zip) num_periods = len(splitted) - 1 # should be 1 or 2 - if (num_periods < 1) or (num_periods > 2): - st.error('Available .zip input filename {} has a bad number of periods ({}... it should have 1-2 periods); please fix this.'.format(selected_input_filename, num_periods)) - sys.exit() + # if (num_periods < 1) or (num_periods > 2): + # st.error('Available .zip input filename {} has a bad number of periods ({}... 
it should have 1-2 periods); please fix this.'.format(selected_input_filename, num_periods)) + # sys.exit() if num_periods == 1: # it's a zipped directory, by specification if '--' not in selected_input_filename: dirpath = os.path.join(local_input_dir, selected_input_filename.rstrip('.zip')) @@ -323,7 +322,7 @@ def load_selected_inputs(self): dirpath = os.path.join(local_input_dir, selected_input_filename.split('--')[0]) ensure_empty_directory(dirpath) shutil.unpack_archive(local_download_path, dirpath) - elif num_periods == 2: # it's a zipped datafile + else: # it's a zipped datafile shutil.unpack_archive(local_download_path, local_input_dir) else: shutil.copy(local_download_path, local_input_dir) @@ -374,7 +373,7 @@ def get_local_inputs_listing(self): # Write a dataframe of the local input files, which we don't want to be editable because we don't want to mess with the local inputs (for now), even though they're basically a local copy def display_local_inputs_df(self): - st.subheader(':open_file_folder: Input data available to the tool') + st.subheader(':open_file_folder: Input data in MAWA') local_inputs = self.get_local_inputs_listing() if self.platform == 'local': # not editable locally because deletion is disabled anyway so there'd be nothing to do with selected files make_complex_dataframe_from_file_listing(dirpath=local_input_dir, item_names=local_inputs, editable=False) @@ -501,7 +500,7 @@ def load_selected_archive(self): st.selectbox('Select available results archive to load:', self.available_archives, key='archive_to_load') # If the user wants to load the selected archive... - if st.button('Load selected (above) results archive :arrow_right:', help='WARNING: This will copy the contents of the selected archive to the results directory and will overwrite currently loaded results; please ensure they are backed up (you can just use the functions on this page)!'): + if st.button('Load selected results archive :arrow_right:', help='WARNING: This will copy the contents of the selected archive to the results directory and will overwrite currently loaded results; please ensure they are backed up (you can just use the functions on this page)!'): # First delete everything in currently in the output results directory (i.e., all currently loaded data) that's not an output archive delete_selected_files_and_dirs(local_output_dir, self.get_local_results_listing()) @@ -617,7 +616,7 @@ def get_local_results_listing(self): # Write a dataframe of the results in the local output directory, also obviously platform-independent def display_local_results_df(self): - st.subheader(':open_file_folder: Results loaded in the tool') + st.subheader(':open_file_folder: Results in MAWA') make_complex_dataframe_from_file_listing(local_output_dir, self.get_local_results_listing(), df_session_state_key_basename='local_results', editable=True) # Delete selected items from the output results directory @@ -641,7 +640,7 @@ def add_delete_local_results_button(self): # Write a YAML file of the current tool parameters to the loaded results directory def write_settings_to_local_results(self): st.subheader(':tractor: Write current tool parameters to loaded results') - if st.button(':pencil2: Write current tool settings to the results directory', help='Note you can subsequently load these parameters from the "Tool parameter selection" tab at left'): + if st.button(':pencil2: Write current tool settings to the results directory'): write_current_tool_parameters_to_disk(local_output_dir) st.rerun() # rerun since this 
potentially changes outputs diff --git a/radial_profiles.py b/radial_profiles.py new file mode 100644 index 0000000..3558d4b --- /dev/null +++ b/radial_profiles.py @@ -0,0 +1,345 @@ +# Import relevant libraries +import pandas as pd +import os +import utils +import time +import scipy.stats + + +# Define a function to determine the common suffix in a series +def common_suffix(ser): + strings = ser.tolist() + reversed_strings = [s[::-1] for s in strings] + reversed_lcs = os.path.commonprefix(reversed_strings) + suffix = reversed_lcs[::-1] + print(f'The common suffix is "{suffix}".') + return suffix + + +# Create a function to check for duplicate columns +def has_duplicate_columns(df, sample_size=1000): + print('NOTE: Sampling is being performed, so if this returns True, consider increasing the sample_size parameter (or manually check for duplicates). If it returns False, you know that there are no duplicate columns in the DataFrame and do not need to stress about the sampling.') + df = utils.sample_df_without_replacement_by_number(df=df, n=sample_size) + df_transposed = df.T + df_deduplicated = df_transposed.drop_duplicates().T + return len(df.columns) != len(df_deduplicated.columns) # if they're the same length (has_duplicate_columns() returns False), you know there are no duplicate columns + + +# Define a function to appropriately transform the dataframe +def transform_dataframe(df1, index, columns, values, repeated_columns, verbose=False): + + # Sample call + # df2.equals(transform_dataframe(df1, index='a', columns='c', values=['d', 'e'], repeated_columns=['b', 'f'], verbose=True)) + + # Pivot the specified values columns based on the specified index and columns + if verbose: + print(f'Pivoting the DataFrame based on the index "{index}", the columns "{columns}", and the values "{values}"...') + start_time = time.time() + pivot_df = df1.pivot_table(index=index, columns=columns, values=values, aggfunc='first', observed=False) + if verbose: + print(f'Pivoting the DataFrame took {time.time() - start_time:.2f} seconds.') + + # Flatten the multi-index columns + pivot_df.columns = [f'{x[0]}{x[1]}' for x in pivot_df.columns] + + # Reset the index to make the index column a column again + pivot_df.reset_index(inplace=True) + + # Merge with the repeated_columns from the original DataFrame + if verbose: + print(f'Merging the DataFrame with the repeated columns "{repeated_columns}"...') + start_time = time.time() + df2 = pd.merge(pivot_df, df1[[index] + repeated_columns].drop_duplicates(), on=index) + if verbose: + print(f'Merging the DataFrame took {time.time() - start_time:.2f} seconds.') + + # Return the transformed DataFrame + return df2 + + +# Transform the dataframes image-by-image to save both memory and time +def transform_dataframes_in_chunks(df, image_col, new_index_col, distinct_old_row_identifier_col, cols_with_unique_rows, cols_with_repeated_rows, verbose=False): + for unique_image in df[image_col].unique(): + df_image = df[df[image_col] == unique_image] + df_image2 = transform_dataframe(df_image, index=new_index_col, columns=distinct_old_row_identifier_col, values=cols_with_unique_rows, repeated_columns=cols_with_repeated_rows, verbose=verbose) + yield df_image2 + + +# Efficiently turn a series of strings (in the format of the ImageJ "Labe" column) into three series of the relevant parts +def process_label_column(ser_label): + + # Regular expression to capture the three parts + pattern = r'^(.*?) - (T=.*?:c:.*/.*? t:.*/.*?) 
- (.*)$' + + # Extract parts into separate columns + extracted_df = ser_label.str.extract(pattern) + + # Return the desired parts + return extracted_df[0], extracted_df[1], extracted_df[2] + + +def preprocess_dataset(df, perc_thresh_rawintnorm_column_check=0.01, image_col='Slide ID', nuclear_channel=2, do_z_score_filter=True, z_score_filter_threshold=3, run_checks=False): + + # Note the input dataframe df should likely be that loaded from the Open File page + # This function is essentially the same as the preprocess_radial_profile_data.ipynb Jupyter notebook + + # Constant + run_repeated_rows_check = False + + # Output the initial dataframe shape + print('Initial dataframe shape:', df.shape) + + # Determine whether preprocessing has already been performed + if 'Label' not in df.columns: + print('It appears that the dataset has already been preprocessed because there is no "Label" column. If you would like to re-preprocess the dataset, please reload it from the Open File page.') + return + + # Efficiently save some data from the "Label" column + tif_name, middle_data, last_name = process_label_column(df['Label']) + + # Show that the basename of the image name is always the same as the last name + if run_checks: + assert (tif_name.apply(lambda x: os.path.splitext(x)[0]) == last_name).all(), 'The basename of the image name is not always the same as the last name.' + + # Add the TIF image names to the dataframe + df['tif_name'] = tif_name.astype(str) + print('tif_name:', df['tif_name']) + + # Show that there are always five fields of the middle data when split using ":" + if run_checks: + assert list(middle_data.apply(lambda x: len(x.split(':'))).unique()) == [5], 'There are not always five fields of the middle data when split using ":".' + + # Add the time ("T") field to the observations dataframe + df['T'] = middle_data.apply(lambda x: x.split(':')[0].removeprefix('T=')).astype(int) + print('T:', df['T']) + + # Add the cell ID field to the observations dataframe + df['cell_id'] = middle_data.apply(lambda x: x.split(':')[1]).astype(str) + print('cell_id:', df['cell_id']) + + # Check that the "c" field is the exact same as the "Ch" column + if run_checks: + assert (df['Ch'] == middle_data.apply(lambda x: x.split(':')[3].split('/')[0]).astype(int)).all(), 'The "c" field is not the exact same as the "Ch" column.' + + # Check that the last field is exactly the same as one plus the "T" field + if run_checks: + assert (df['T'] == middle_data.apply(lambda x: x.split(':')[4].split('/')[0]).astype(int) - 1).all(), 'The last field is not exactly the same as one plus the "T" field.' + + # Check that the .tif basename is completely contained in the actual input filename + df_small = df[['input_filename', 'tif_name']].drop_duplicates() + if run_checks: + assert df_small.apply(lambda x: os.path.splitext(x['tif_name'])[0].replace('REEEC', 'REEC') in x['input_filename'], axis='columns').all(), 'The .tif basename is not completely contained in the actual input filename.' + + # Determine (and remove from the small df) the common suffix in the input_filename field + ser = df_small['input_filename'] + suffix1 = common_suffix(ser) + df_small['input_filename'] = ser.str.removesuffix(suffix1) + print('df_small:', df_small) + + # Ensure that the "T=X" part of the input filename is the same as the "T" field + if run_checks: + assert df['input_filename'].apply(lambda x: x.removesuffix(suffix1).split('=')[-1]).astype(int).equals(df['T']), 'The "T=X" part of the input filename is not the same as the "T" field.' 
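As a reading aid for the new radial_profiles.py module (not part of the diff): a minimal sketch of what process_label_column() is expected to produce, using a purely hypothetical ImageJ "Label" value. Real labels may differ, but the five colon-separated fields of the middle group are what the checks in preprocess_dataset() rely on.

# Sketch only; the label string below is illustrative, not taken from real data.
import pandas as pd

ser_label = pd.Series(['MyImage.tif - T=0:0012:c:2/3 t:1/10 - MyImage'])
pattern = r'^(.*?) - (T=.*?:c:.*/.*? t:.*/.*?) - (.*)$'
extracted_df = ser_label.str.extract(pattern)

tif_name, middle_data, last_name = extracted_df[0], extracted_df[1], extracted_df[2]
print(tif_name.iloc[0])     # 'MyImage.tif'
print(middle_data.iloc[0])  # 'T=0:0012:c:2/3 t:1/10' -> five ':'-separated fields (T, cell ID, and the channel/frame counters)
print(last_name.iloc[0])    # 'MyImage' (basename of the TIF name)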
+ + # Determine the next common suffix aside from the T value + ser = df_small['input_filename'].apply(lambda x: x.removesuffix(suffix1).split('=')[0]) + suffix2 = common_suffix(ser) + df_small['input_filename'] = ser.str.removesuffix(suffix2) + print('df_small:', df_small) + + # Get a series of the data to process from the "input_filename" field + ser_remaining_data = df['input_filename'].apply(lambda x: x.removesuffix(suffix1).split('=')[0].removesuffix(suffix2)) + print('ser_remaining_data:', ser_remaining_data) + + # Add a column to identify if the cell was processed using the REEC + df['REEC'] = ser_remaining_data.str.endswith('_REEC').astype(bool) + print('REEC:', df['REEC']) + + # Remove the just-processed suffix + ser_remaining_data = ser_remaining_data.str.removesuffix('_REEC') + print('ser_remaining_data:', ser_remaining_data) + + # Add columns identifying the well and cell type + df['well_id'] = ser_remaining_data.apply(lambda x: x.split('_')[0]).astype(str) + df['cell_type'] = ser_remaining_data.apply(lambda x: x.split('_')[1]).astype(str) + print('well_id and cell_type:', df[['well_id', 'cell_type']]) + + # Normalize the total intensity using the cell areas + df['RawIntNorm'] = df['RawIntDen'] / df['Area'] + print('RawIntNorm:', df['RawIntNorm']) + + # Check that this yields the same result already included in the datafile + if run_checks: + min_true_val = df['Mean'].min() + thresh = min_true_val * perc_thresh_rawintnorm_column_check / 100 + mad = (df['RawIntNorm'] - df['Mean']).abs().max() + if mad < thresh: + print(f'The new and existing columns ARE equal to within {perc_thresh_rawintnorm_column_check}% of the minimum existing value (MAD: {mad}).') + else: + print(f'The new and existing columns are NOT equal to within {perc_thresh_rawintnorm_column_check}% of the minimum existing value (MAD: {mad}).') + + # Set a row ID as a combination of the input filename and the extracted cell ID + # Note then that for each row ID there should be the same number of rows as there are channels, unless some cell IDs are duplicated, which is what the following cells test. + # As may often be necessary, we are doing this on a per-image basis to save memory. + unique_images = df[image_col].unique() + for image in unique_images: + image_loc = df[image_col] == image + df.loc[image_loc, 'row_id'] = df.loc[image_loc, 'input_filename'].astype(str) + ' - ' + df.loc[image_loc, 'cell_id'] + + # Display the rows corresponding to duplicated cell IDs + num_channels = df['Ch'].nunique() + vc = df['row_id'].value_counts() + df_dupes = df[df['row_id'].isin(vc[vc > num_channels].index)] + print('df_dupes:', df_dupes) + + # Create a dataframe to aid in transforming the duplicated cell IDs to something unique + df_assist_deduping = df_dupes.groupby(by=['row_id', 'Ch'], observed=False)[' '].aggregate(sorted).to_frame().reset_index() + df_assist_deduping['index'] = df_assist_deduping[' '].apply(lambda x: list(range(len(x)))) + df_assist_deduping_expanded = df_assist_deduping.explode(' ') + df_assist_deduping_expanded['index'] = df_assist_deduping.explode('index')['index'].apply(lambda x: f'{x:04d}') + print('df_assist_deduping_expanded:', df_assist_deduping_expanded) + + # Check that the row ID in combination with the "blank index" uniquely identifies all rows (the duplicated cell IDs are not used here) + if run_checks: + assert (df['row_id'] + ' - ' + df[' '].astype(str)).nunique() == len(df), 'The row ID in combination with the "blank index" does not uniquely identify all rows.' 
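The duplicate-cell-ID detection above is easier to see on a toy table. A minimal sketch (not from the diff) with two channels and one hypothetical duplicated cell ID: every row_id should appear exactly num_channels times, so anything above that is flagged.

# Sketch only; row_id values are hypothetical.
import pandas as pd

df = pd.DataFrame({
    'row_id': ['fileA - 0001', 'fileA - 0001',
               'fileA - 0003', 'fileA - 0003', 'fileA - 0003', 'fileA - 0003'],
    'Ch':     [1, 2, 1, 2, 1, 2],
})

num_channels = df['Ch'].nunique()          # 2
vc = df['row_id'].value_counts()           # occurrences per row_id
df_dupes = df[df['row_id'].isin(vc[vc > num_channels].index)]
print(df_dupes)                            # only the four 'fileA - 0003' rows are flagged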
+ + # Append string indices to the cell IDs in order to de-duplicate them + for _, ser in df_assist_deduping_expanded.iterrows(): + row_id = ser['row_id'] + blank_index = ser[' '] + index_to_append = ser['index'] + loc = (df['row_id'] == row_id) & (df[' '] == blank_index) + assert loc.sum() == 1 + df.loc[loc, 'cell_id'] = df.loc[loc, 'cell_id'] + ':' + index_to_append + + # Output the de-duplicated versions of the previously duplicated rows (just `cell_id` is modified so far) + if run_checks: + print('df at the locations of the duplicates:', df.loc[df_dupes.index]) + + # Now confirm that there are no more duplicated cell IDs + for image in df[image_col].unique(): + image_loc = df[image_col] == image + df.loc[image_loc, 'row_id'] = df.loc[image_loc, 'input_filename'].astype(str) + ' - ' + df.loc[image_loc, 'cell_id'] + vc = df['row_id'].value_counts() + if run_checks: + df_dupes = df[df['row_id'].isin(vc[vc > num_channels].index)] + print('df_dupes:', df_dupes) + + # Really confirm there are no duplicated rows anymore + if run_checks: + assert list(vc.unique()) == [num_channels], 'There are still duplicated rows.' + + # Check for duplicate columns (ignoring MAWA-created columns) + if run_checks: + print('has_duplicate_columns:', has_duplicate_columns(df.drop(columns=['input_filename']))) + + # Check whether suspected columns are equal + if run_checks: + assert (df['IntDen'] == df['RawIntDen']).all(), 'The "IntDen" and "RawIntDen" columns are not equal.' + + # Drop the ostensibly more processed one + df.drop(columns=['IntDen'], inplace=True) + + # Check again for duplicate columns + # This should return False if just loading the data in a Jupyter notebook but not when dataset_formats.py is used (as in Open File) because it probably does add duplicate columns + if run_checks: + print('has_duplicate_columns:', has_duplicate_columns(df.drop(columns=['input_filename']))) + + # Show that in general, the cells are not strictly in sorted order + if run_checks: + is_sorted = df['cell_id'].iloc[::2].reset_index(drop=True).equals(df['cell_id'].iloc[1::2].reset_index(drop=True)) + print('is_sorted:', is_sorted) + + # Sort them appropriately and then check again + df.sort_values(by=['well_id', 'cell_type', 'REEC', 'T', 'cell_id'], inplace=True) + if run_checks: + is_sorted = df['cell_id'].iloc[::2].reset_index(drop=True).equals(df['cell_id'].iloc[1::2].reset_index(drop=True)) + print('is_sorted:', is_sorted) + + # Sample usage of the transformation function + if run_checks: + df1 = pd.DataFrame( + { + 'a': [1, 1, 2, 2, 3, 3], + 'c': [1, 2, 1, 2, 1, 2], + 'd': ['AA', 'BB', 'CC', 'DD', 'EE', 'FF'], + 'e': ['AAA', 'BBB', 'CCC', 'DDD', 'EEE', 'FFF'], + 'b': ['b1', 'b1', 'b2', 'b2', 'b3', 'b3'], + 'f': ['f1', 'f1', 'f2', 'f2', 'f3', 'f3'], + } + ) + df2 = pd.DataFrame( + { + 'a': [1, 2, 3], + 'd1': ['AA', 'CC', 'EE'], + 'd2': ['BB', 'DD', 'FF'], + 'e1': ['AAA', 'CCC', 'EEE'], + 'e2': ['BBB', 'DDD', 'FFF'], + 'b': ['b1', 'b2', 'b3'], + 'f': ['f1', 'f2', 'f3'], + } + ) + assert df2.equals(transform_dataframe(df1, 'a', 'c', ['d', 'e'], ['b', 'f'], verbose=True)), 'The transformation function does not work as expected.' 
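The df1 -> df2 sample usage above passes because of the pivot-and-flatten step inside transform_dataframe(). A minimal standalone sketch (not from the diff, with hypothetical column names) showing why the wide table ends up with channel-suffixed names such as 'Mean1'/'Mean2', the same naming convention the later 'RawIntNorm<channel>' columns follow:

# Sketch only; a tiny two-channel table pivoted to one row per cell.
import pandas as pd

df1 = pd.DataFrame({
    'row_id': ['cell1', 'cell1', 'cell2', 'cell2'],
    'Ch':     [1, 2, 1, 2],
    'Mean':   [10.0, 99.0, 12.0, 87.0],
})

pivot_df = df1.pivot_table(index='row_id', columns='Ch', values=['Mean'], aggfunc='first', observed=False)
pivot_df.columns = [f'{x[0]}{x[1]}' for x in pivot_df.columns]  # ('Mean', 1) -> 'Mean1'
print(pivot_df.reset_index())
#   row_id  Mean1  Mean2
# 0  cell1   10.0   99.0
# 1  cell2   12.0   87.0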
+ + # Compress the dataframe prior to transforming it + df = utils.downcast_dataframe_dtypes(df) + print('Shape of dataframe prior to transforming it:', df.shape) + + # Programmatically get the columns with duplicated data in sequential rows and manually confirm the result + distinct_old_row_identifier_col = 'Ch' + new_index_col = 'row_id' + cols_with_repeated_rows2 = [] + cols_with_unique_rows2 = [] + for column in df.columns: + ser = df[column] + if ser.iloc[::2].reset_index(drop=True).equals(ser.iloc[1::2].reset_index(drop=True)): + cols_with_repeated_rows2.append(column) + else: + cols_with_unique_rows2.append(column) + cols_with_repeated_rows2 = [col for col in cols_with_repeated_rows2 if col not in [distinct_old_row_identifier_col, new_index_col]] + cols_with_unique_rows2 = [col for col in cols_with_unique_rows2 if col not in [distinct_old_row_identifier_col, new_index_col]] + + # These two definitions are part of a manual check that, if passing consistently, can likely be removed in the future + if run_checks and run_repeated_rows_check: + # These sample rows are expected a unified dataframe but not for one processed by dataset_formats.py (which adds more columns), so in general we'll force this check to be skipped + cols_with_repeated_rows = ['Image ID_(standardized)', 'Centroid X (µm)_(standardized)', 'Centroid Y (µm)_(standardized)', 'Area', 'X', 'Y', 'Perim.', 'Circ.', 'AR', 'Round', 'Solidity', 'input_filename', 'tif_name', 'T', 'cell_id', 'REEC', 'well_id', 'cell_type'] + cols_with_unique_rows = [' ', 'Label', 'Mean', 'StdDev', 'RawIntDen', 'RawIntNorm'] + print('Repeated:', cols_with_repeated_rows2) + print('Unique:', cols_with_unique_rows2) + assert cols_with_repeated_rows2 == cols_with_repeated_rows, 'The columns with repeated rows are not as expected.' + assert cols_with_unique_rows2 == cols_with_unique_rows, 'The columns with unique rows are not as expected.' 
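A minimal sketch (not from the diff) of the even-row/odd-row comparison used just above to classify columns as "repeated" or "unique", assuming a hypothetical two-channel table already sorted so that each cell occupies two consecutive rows (the pairwise check only makes sense for exactly two channels):

# Sketch only; columns and values are illustrative.
import pandas as pd

df = pd.DataFrame({
    'row_id': ['cell1', 'cell1', 'cell2', 'cell2'],
    'Ch':     [1, 2, 1, 2],
    'Area':   [50.0, 50.0, 42.0, 42.0],   # same value in both channel rows -> "repeated rows" column
    'Mean':   [10.0, 99.0, 12.0, 87.0],   # differs per channel -> "unique rows" column
})

for column in df.columns:
    ser = df[column]
    repeated = ser.iloc[::2].reset_index(drop=True).equals(ser.iloc[1::2].reset_index(drop=True))
    print(column, 'repeated' if repeated else 'unique')
# row_id repeated, Ch unique, Area repeated, Mean unique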
+ + # Create a generator that transforms the dataframes image-by-image + df_generator = transform_dataframes_in_chunks(df, image_col, new_index_col=new_index_col, distinct_old_row_identifier_col=distinct_old_row_identifier_col, cols_with_unique_rows=cols_with_unique_rows2, cols_with_repeated_rows=cols_with_repeated_rows2, verbose=False) + + # Concatenate the DataFrames generated by the generator into a single DataFrame + df_transformed = pd.concat(df_generator, ignore_index=True) + + # Print the memory usage of the new dataframe + if run_checks: + df_transformed.info(memory_usage='deep') + print('Shape of df_transformed:', df_transformed.shape) + + # Get the signal intensity and signal Z score column names + signal_channels = set(df['Ch'].unique()) - {nuclear_channel} + signal_intensity_columns = [f'RawIntNorm{channel}' for channel in signal_channels] + signal_zscore_columns = [f'{col}_zscore' for col in signal_intensity_columns] + print('signal_intensity_columns:', signal_intensity_columns) + print('signal_zscore_columns:', signal_zscore_columns) + + # Calculate the Z scores on each area-normalized signal channel for each image + for curr_image in unique_images: + for curr_column in signal_intensity_columns: + curr_image_loc = df_transformed[image_col] == curr_image + df_transformed.loc[curr_image_loc, curr_column + '_zscore'] = scipy.stats.zscore(df_transformed.loc[curr_image_loc, curr_column]) + print('Columns of df_transformed that end with "_zscore":', df_transformed.loc[:, df_transformed.columns.str.endswith('_zscore')]) + + # Filter out cells for each image that are more than three standard deviations from the mean for any channel + if do_z_score_filter: + num_cells_before = df_transformed[image_col].value_counts() + df_transformed = df_transformed[(df_transformed[signal_zscore_columns].abs() < z_score_filter_threshold).all(axis='columns')] + num_cells_removed = num_cells_before - df_transformed[image_col].value_counts() + print('Number of cells removed by Z score threshold:') + print(num_cells_removed) + + # Return the transformed dataframe + return df_transformed diff --git a/requirements.txt b/requirements.txt index 79150dd..0a76416 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,28 +2,46 @@ altair==5.3.0 anndata==0.10.7 dill==0.3.5.1 foundry_ml==0.5.1 +holoviews==1.18.1 +imageio==2.31.1 +imantics==0.1.12 matplotlib==3.8.2 natsort==8.4.0 +networkx==3.1 numpy==1.24.4 -palantir==1.3.0 +objsize==0.7.0 +palantir==1.3.3 pandas==2.2.2 +parc==0.40 +parmap==1.7.0 +PhenoGraph==1.5.7 plotly==5.18.0 psutil==5.9.8 -Pympler==1.0.1 +pyparsing==3.1.2 pytz==2022.7 PyYAML==6.0.1 PyYAML==6.0.1 -Requests==2.31.0 -scikit_learn==1.3.0 -scipy==1.13.0 +Requests==2.32.3 +scanpy==1.10.1 +scikit_learn==1.5.0 +scipy==1.13.1 seaborn==0.13.2 +setuptools_scm==8.1.0 +Shapely==2.0.4 skimage==0.0 split_file_reader==0.1.4 -squidpy==1.4.1 +squidpy==1.5.0 st_pages==0.4.5 -streamlit==1.34.0 -streamlit_extras==0.4.2 +streamlit==1.35.0 +streamlit_extras==0.4.3 streamlit_javascript==0.1.5 -tqdm==4.66.1 +tensorflow==2.16.1 +tifffile==2023.8.25 +tqdm==4.66.4 umap==0.1.1 -umap_learn==0.5.4 +umap_learn==0.5.6 +setuptools-scm==8.1.0 +annoy==1.17.3 +sklearn-ann==0.1.2 +pynndescent==0.5.13 +plotnine==0.13.6 diff --git a/streamlit_dataframe_editor.py b/streamlit_dataframe_editor.py index 664a522..7ed2039 100644 --- a/streamlit_dataframe_editor.py +++ b/streamlit_dataframe_editor.py @@ -2,7 +2,7 @@ import streamlit as st import random import pandas as pd -from st_pages import get_pages, get_script_run_ctx +# from 
st_pages import get_pages, get_script_run_ctx def get_random_integer(stop=1000000): ''' diff --git a/streamlit_session_state_management.py b/streamlit_session_state_management.py index 9f06e38..76ef112 100644 --- a/streamlit_session_state_management.py +++ b/streamlit_session_state_management.py @@ -9,7 +9,7 @@ # from pympler.asizeof import asizeof as deep_mem_usage_in_bytes from objsize import get_deep_size as deep_mem_usage_in_bytes import time -from pages import memory_analyzer +from pages2 import memory_analyzer def load_session_state_preprocessing(saved_streamlit_session_states_dir, saved_streamlit_session_state_prefix='streamlit_session_state-', saved_streamlit_session_state_key='session_selection', selected_session=None): diff --git a/streamlit_utils.py b/streamlit_utils.py index 6ab3b88..57b3323 100644 --- a/streamlit_utils.py +++ b/streamlit_utils.py @@ -101,6 +101,7 @@ def load_input_dataset(datafile_path_or_df, coord_units_in_microns, input_datase 'datafile_path': None, 'coord_units_in_microns': coord_units_in_microns } + metadata['preprocessing'] = None # Load and standardize the input datafile into the session state dataset_obj = utils.load_and_standardize_input_datafile(datafile_path_or_df, coord_units_in_microns) diff --git a/tci_squidpy_supp_lib.py b/tci_squidpy_supp_lib.py index a39bf34..98049fa 100644 --- a/tci_squidpy_supp_lib.py +++ b/tci_squidpy_supp_lib.py @@ -109,7 +109,7 @@ def squidpy_enrichment(adata, radius=3.0, n_neighs=6, radius_instead_of_knn=True sq.gr.spatial_neighbors(adata, n_neighs=n_neighs, coord_type='generic') # Calculate the neighborhood enrichment - sq.gr.nhood_enrichment(adata, cluster_key=label_name, n_jobs=n_jobs) + sq.gr.nhood_enrichment(adata, cluster_key=label_name) # Generate the heatmap sq.pl.nhood_enrichment(adata, cluster_key=label_name, annotate=annotate) diff --git a/templateConfig.json b/templateConfig.json index baaf519..7789c54 100644 --- a/templateConfig.json +++ b/templateConfig.json @@ -1,6 +1,6 @@ { "parentTemplateId" : "jupyter-workspaces", "childTemplatesByPath" : { }, - "parentTemplateVersion" : "0.146.0", + "parentTemplateVersion" : "0.176.0", "childTemplateVersionsByPath" : { } } \ No newline at end of file diff --git a/time_cell_interaction_lib.py b/time_cell_interaction_lib.py index f43c9d8..a47fe2c 100644 --- a/time_cell_interaction_lib.py +++ b/time_cell_interaction_lib.py @@ -1824,7 +1824,7 @@ def average_dens_pvals_over_rois_for_each_slide(self, figsize=(10, 4), dpi=100, dpi=dpi, plots_dir=savedir, plot_real_data=plot_real_data, - entity_name=slide_name, + entity_name=slide_name.replace(' ', '_'), img_file_suffix=img_file_suffix, entity=entity, entity_index=-1, @@ -2242,7 +2242,7 @@ def average_over_rois_per_annotation_region(self, annotations_csv_files, phenoty dpi=settings__plotting__pval_dpi, plots_dir=plots_dir, plot_real_data=plot_real_data, - entity_name=entity_name, + entity_name=entity_name.replace(' ', '_'), img_file_suffix='', entity=entity, entity_index=-1, @@ -2814,7 +2814,7 @@ def roi_checks_and_output(x_roi, y_roi, do_printing=True, do_main_printing=True) return(x_range, y_range, min_coordinate_spacing) -def calculate_metrics_from_coords(min_coord_spacing, input_coords=None, neighbors_eq_centers=False, ncenters_roi=1300, nneighbors_roi=220, nbootstrap_resamplings=0, rad_range=(2.2, 5.1), use_theoretical_counts=False, roi_edge_buffer_mult=1, roi_x_range=(1.0, 100.0), roi_y_range=(0.5, 50.0), silent=False, log_file_data=None, keep_unnecessary_calculations=False, old_method=False): +def 
calculate_metrics_from_coords(min_coord_spacing, input_coords=None, neighbors_eq_centers=False, ncenters_roi=1300, nneighbors_roi=220, nbootstrap_resamplings=0, rad_range=(2.2, 5.1), use_theoretical_counts=False, roi_edge_buffer_mult=1, roi_x_range=(1.0, 100.0), roi_y_range=(0.5, 50.0), silent=False, log_file_data=None, keep_unnecessary_calculations=False, neighbor_counts_method='cdist avoiding oom'): ''' Given a set of coordinates (whether actual coordinates or ones to be simulated), calculate the P values and Z scores. @@ -2925,11 +2925,13 @@ def calculate_metrics_from_coords(min_coord_spacing, input_coords=None, neighbor print('NOTE: Using artificial distribution') nneighbors = scipy.stats.poisson.rvs(nexpected, size=(nvalid_centers,)) else: - if old_method: + if neighbor_counts_method == 'pure cdist': dist_mat = scipy.spatial.distance.cdist(coords_centers[valid_centers, :], coords_neighbors, 'euclidean') # calculate the distances between the valid centers and all the neighbors nneighbors = ((dist_mat >= rad_range[0]) & (dist_mat < rad_range[1])).sum(axis=1) # count the number of neighbors in the slice around every valid center - else: + elif neighbor_counts_method == 'cdist avoiding oom': # returns number of points in [0, radius) nneighbors = utils.calculate_neighbor_counts_with_possible_chunking(center_coords=coords_centers[valid_centers, :], neighbor_coords=coords_neighbors, radii=radii, single_dist_mat_cutoff_in_mb=200, verbose=False)[:, 0] # (num_centers,) + elif neighbor_counts_method == 'kdtree': # returns number of points in [0, radius) (with the "tol" correction within) + nneighbors = utils.calculate_neighbor_counts_with_kdtree(center_coords=coords_centers[valid_centers, :], neighbor_coords=coords_neighbors, radius=radii[1]) if (neighbors_eq_centers) and (rad_range[0] < tol): nneighbors = nneighbors - 1 # we're always going to count the center as a neighbor of itself in this case, so account for this; see also physical notebook notes on 1/7/21 @@ -4002,7 +4004,7 @@ def plot_just_rois(df, plotting_map, num_colors, webpage_dir, mapping_dict, coor boxes_to_plot = boxes_for_slide.iloc[iroi_box].to_frame().T # this is a one-row dataframe tag = boxes_to_plot.index[0] list_of_tuple_arguments.append((roi_figsize, spec2plot_roi, species_roi, x_roi, y_roi, plotting_map, colors, x_range, y_range, uroi, marker_size_step, default_marker_size, roi_dpi, mapping_dict, coord_units_in_microns, alpha, edgecolors, yaxis_dir, boxes_to_plot, pval_params, title_suffix, tag, filename_suffix, savedir)) - utils.execute_data_parallelism_potentially(function=plot_and_save_roi, list_of_tuple_arguments=list_of_tuple_arguments, nworkers=(0 if not use_multiprocessing else nworkers), task_description='plotting of the single ROI outlines on whole slide images') + utils.execute_data_parallelism_potentially(function=plot_and_save_roi, list_of_tuple_arguments=list_of_tuple_arguments, nworkers=(0 if not use_multiprocessing else nworkers), task_description='plotting of the single ROI outlines on whole slide images') # note this takes plenty of overhead... the individual plot_and_save_roi() function takes about 0.2s per image, but over six seconds seem to be taken for setting up the forkserver or other multiprocessing overhead. So it seems like the plotting of the single ROI outlines on whole slide images step takes long on simple datasets, but that's okay! This step in particular is all execute_data_parallelism_potentially() which is all (for small datasets) overhead! 
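For the neighbor_counts_method options added above: a minimal, self-contained sketch (not the actual utils helpers) contrasting the 'pure cdist' annulus count with a KD-tree based count on random points. The radii and array sizes are illustrative; note that cKDTree.query_ball_point uses a closed ball (distance <= r), which is the kind of boundary difference the real code's "tol" correction is meant to handle.

# Sketch only; compares two ways of counting neighbors in the annulus [rad_range[0], rad_range[1]).
import numpy as np
import scipy.spatial

rng = np.random.default_rng(0)
centers = rng.uniform(0, 100, size=(500, 2))
neighbors = rng.uniform(0, 100, size=(2000, 2))
rad_range = (2.2, 5.1)

# 'pure cdist': full (n_centers x n_neighbors) distance matrix; simple but memory-hungry
dist_mat = scipy.spatial.distance.cdist(centers, neighbors, 'euclidean')
nneighbors_cdist = ((dist_mat >= rad_range[0]) & (dist_mat < rad_range[1])).sum(axis=1)

# KD-tree: count points within each radius and difference the counts to get the annulus
tree = scipy.spatial.cKDTree(neighbors)
counts_outer = np.array([len(ix) for ix in tree.query_ball_point(centers, r=rad_range[1])])
counts_inner = np.array([len(ix) for ix in tree.query_ball_point(centers, r=rad_range[0])])
nneighbors_kdtree = counts_outer - counts_inner

print(np.abs(nneighbors_cdist - nneighbors_kdtree).max())  # 0 unless a point sits exactly on a radius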
def get_integer_index(val_to_map, vmin=-200, vmax=0, N=2**8): @@ -4100,7 +4102,7 @@ def plot_single_density_pvals(argument_tuple): # Define the ROI-specific variables from the main df_density_pvals_arrays parameter log_dens_pvals_arr = df_density_pvals_arrays.loc[entity_index, 'log_dens_pvals_arr'] - entity_name = df_density_pvals_arrays.loc[entity_index, 'roi_name'] + entity_name = df_density_pvals_arrays.loc[entity_index, 'roi_name'].replace(' ', '_') # this is probably the place to replace spaces with underscores entity = 'roi' # Plot the heatmaps for the current set of data @@ -4141,7 +4143,7 @@ def plot_single_roi(args_as_single_tuple): # Assign the data for the current ROI to useful variables roi_name = df_data_by_roi.loc[roi_index, 'unique_roi'] # df_data_by_roi.loc[roi_index, 'roi_name'] - roi_fig_filename = 'roi_plot_{}_{}.{}'.format(roi_name, roi_index, save_image_ext) + roi_fig_filename = 'roi_plot_{}_{}.{}'.format(roi_name.replace(' ', '_'), roi_index, save_image_ext) roi_fig_pathname = os.path.join(savedir, roi_fig_filename) # If the image file doesn't already exist... diff --git a/utag/__init__.py b/utag/__init__.py deleted file mode 100644 index 08c26ca..0000000 --- a/utag/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -import warnings - -from scipy.sparse import SparseEfficiencyWarning - -from utag.segmentation import utag - - -try: - # Even though there is no "imc/_version" file, - # it should be generated by - # setuptools_scm when building the package - from utag._version import version - - __version__ = version -except ImportError: - from setuptools_scm import get_version as _get_version - - version = __version__ = _get_version(root="..", relative_to=__file__) - -warnings.simplefilter("ignore", FutureWarning) -warnings.simplefilter("ignore", SparseEfficiencyWarning) diff --git a/utag/segmentation.py b/utag/segmentation.py deleted file mode 100644 index 0a372cf..0000000 --- a/utag/segmentation.py +++ /dev/null @@ -1,285 +0,0 @@ -import typing as tp -import warnings -import os - -import scanpy as sc -import squidpy as sq -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -from tqdm import tqdm -import anndata -import parmap - -from utag.types import Path, Array, AnnData -from utag.utils import sparse_matrix_dstack - - -def utag( - adata: AnnData, - channels_to_use: tp.Sequence[str] = None, - slide_key: tp.Optional[str] = "Slide", - save_key: str = "UTAG Label", - filter_by_variance: bool = False, - max_dist: float = 20.0, - normalization_mode: str = "l1_norm", - keep_spatial_connectivity: bool = False, - pca_kwargs: tp.Dict[str, tp.Any] = dict(n_comps=10), - apply_umap: bool = False, - umap_kwargs: tp.Dict[str, tp.Any] = dict(), - apply_clustering: bool = True, - clustering_method: tp.Sequence[str] = ["leiden", "parc", "kmeans"], - resolutions: tp.Sequence[float] = [0.05, 0.1, 0.3, 1.0], - leiden_kwargs: tp.Dict[str, tp.Any] = None, - parc_kwargs: tp.Dict[str, tp.Any] = None, - parallel: bool = True, - processes: int = None, -) -> AnnData: - """ - Discover tissue architechture in single-cell imaging data - by combining phenotypes and positional information of cells. - - Parameters - ---------- - adata: AnnData - AnnData object with spatial positioning of cells in obsm 'spatial' slot. - channels_to_use: Optional[Sequence[str]] - An optional sequence of strings used to subset variables to use. - Default (None) is to use all variables. - max_dist: float - Maximum distance to cut edges within a graph. - Should be adjusted depending on resolution of images. 
- For imaging mass cytometry, where resolution is 1um, 20 often gives good results. - Default is 20. - slide_key: {str, None} - Key of adata.obs containing information on the batch structure of the data. - In general, for image data this will often be a variable indicating the image - so image-specific effects are removed from data. - Default is "Slide". - save_key: str - Key to be added to adata object holding the UTAG clusters. - Depending on the values of `clustering_method` and `resolutions`, - the final keys will be of the form: {save_key}_{method}_{resolution}". - Default is "UTAG Label". - filter_by_variance: bool - Whether to filter vairiables by variance. - Default is False, which keeps all variables. - max_dist: float - Recommended values are between 20 to 50 depending on magnification. - Default is 20. - normalization_mode: str - Method to normalize adjacency matrix. - Default is "l1_norm", any other value will not use normalization. - keep_spatial_connectivity: bool - Whether to keep sparse matrices of spatial connectivity and distance in the obsp attribute of the - resulting anndata object. This could be useful in downstream applications. - Default is not to (False). - pca_kwargs: Dict[str, Any] - Keyword arguments to be passed to scanpy.pp.pca for dimensionality reduction after message passing. - Default is to pass n_comps=10, which uses 10 Principal Components. - apply_umap: bool - Whether to build a UMAP representation after message passing. - Default is False. - umap_kwargs: Dict[str, Any] - Keyword arguments to be passed to scanpy.tl.umap for dimensionality reduction after message passing. - Default is 10.0. - apply_clustering: bool - Whether to cluster the message passed matrix. - Default is True. - clustering_method: Sequence[str] - Which clustering method(s) to use for clustering of the message passed matrix. - Default is ["leiden", "parc"]. - resolutions: Sequence[float] - What resolutions should the methods in `clustering_method` be run at. - Default is [0.05, 0.1, 0.3, 1.0]. - leiden_kwargs: dict[str, Any] - Keyword arguments to pass to scanpy.tl.leiden. - parc_kwargs: dict[str, Any] - Keyword arguments to pass to parc.PARC. - parallel: bool - Whether to run message passing part of algorithm in parallel. - Will accelerate the process but consume more memory. - Default is True. - processes: int - Number of processes to use in parallel. - Default is to use all available (-1). - - Returns - ------- - adata: AnnData - AnnData object with UTAG domain predictions for each cell in adata.obs, column `save_key`. - """ - ad = adata.copy() - - if channels_to_use: - ad = ad[:, channels_to_use] - - if filter_by_variance: - ad = low_variance_filter(ad) - - if isinstance(clustering_method, list): - clustering_method = [m.upper() for m in clustering_method] - elif isinstance(clustering_method, str): - clustering_method = [clustering_method.upper()] - else: - print( - "Invalid Clustering Method. 
Clustering Method Should Either be a string or a list" - ) - return - assert all(m in ["LEIDEN", "PARC", "KMEANS"] for m in clustering_method) - - if "PARC" in clustering_method: - from parc import PARC # early fail if not available - if "KMEANS" in clustering_method: - from sklearn.cluster import KMeans - - print("Applying UTAG Algorithm...") - if slide_key: - ads = [ - ad[ad.obs[slide_key] == slide].copy() for slide in ad.obs[slide_key].unique() - ] - ad_list = parmap.map( - _parallel_message_pass, - ads, - radius=max_dist, - coord_type="generic", - set_diag=True, - mode=normalization_mode, - pm_pbar=True, - pm_parallel=parallel, - pm_processes=processes, - ) - ad_result = anndata.concat(ad_list) - if keep_spatial_connectivity: - ad_result.obsp["spatial_connectivities"] = sparse_matrix_dstack( - [x.obsp["spatial_connectivities"] for x in ad_list] - ) - ad_result.obsp["spatial_distances"] = sparse_matrix_dstack( - [x.obsp["spatial_distances"] for x in ad_list] - ) - else: - sq.gr.spatial_neighbors(ad, radius=max_dist, coord_type="generic", set_diag=True) - ad_result = custom_message_passing(ad, mode=normalization_mode) - - if apply_clustering: - if "n_comps" in pca_kwargs: - if pca_kwargs["n_comps"] > ad_result.shape[1]: - pca_kwargs["n_comps"] = ad_result.shape[1] - 1 - print( - f"Overwriding provided number of PCA dimensions to match number of features: {pca_kwargs['n_comps']}" - ) - sc.tl.pca(ad_result, **pca_kwargs) - sc.pp.neighbors(ad_result) - - if apply_umap: - print("Running UMAP on Input Dataset...") - sc.tl.umap(ad_result, **umap_kwargs) - - for resolution in tqdm(resolutions): - - res_key1 = save_key + "_leiden_" + str(resolution) - res_key2 = save_key + "_parc_" + str(resolution) - res_key3 = save_key + "_kmeans_" + str(resolution) - if "LEIDEN" in clustering_method: - print(f"Applying Leiden Clustering at Resolution: {resolution}...") - kwargs = dict() - kwargs.update(leiden_kwargs or {}) - sc.tl.leiden( - ad_result, resolution=resolution, key_added=res_key1, **kwargs - ) - add_probabilities_to_centroid(ad_result, res_key1) - - if "PARC" in clustering_method: - from parc import PARC - - print(f"Applying PARC Clustering at Resolution: {resolution}...") - - kwargs = dict(random_seed=1, small_pop=1000) - kwargs.update(parc_kwargs or {}) - model = PARC( - ad_result.obsm["X_pca"], - neighbor_graph=ad_result.obsp["connectivities"], - resolution_parameter=resolution, - **kwargs, - ) - model.run_PARC() - ad_result.obs[res_key2] = pd.Categorical(model.labels) - ad_result.obs[res_key2] = ad_result.obs[res_key2].astype("category") - add_probabilities_to_centroid(ad_result, res_key2) - - if "KMEANS" in clustering_method: - print(f"Applying K-means Clustering at Resolution: {resolution}...") - k = int(np.ceil(resolution * 10)) - kmeans = KMeans(n_clusters=k, random_state=1).fit(ad_result.obsm["X_pca"]) - ad_result.obs[res_key3] = pd.Categorical(kmeans.labels_.astype(str)) - add_probabilities_to_centroid(ad_result, res_key3) - - return ad_result - - -def _parallel_message_pass( - ad: AnnData, - radius: int, - coord_type: str, - set_diag: bool, - mode: str, -): - sq.gr.spatial_neighbors(ad, radius=radius, coord_type=coord_type, set_diag=set_diag) - ad = custom_message_passing(ad, mode=mode) - return ad - - -def custom_message_passing(adata: AnnData, mode: str = "l1_norm") -> AnnData: - # from scipy.linalg import sqrtm - # import logging - if mode == "l1_norm": - A = adata.obsp["spatial_connectivities"] - from sklearn.preprocessing import normalize - affinity = normalize(A, axis=1, 
norm="l1") - else: - # Plain A_mod multiplication - A = adata.obsp["spatial_connectivities"] - affinity = A - # logging.info(type(affinity)) - adata.X = affinity @ adata.X - return adata - - -def low_variance_filter(adata: AnnData) -> AnnData: - return adata[:, adata.var["std"] > adata.var["std"].median()] - - -def add_probabilities_to_centroid( - adata: AnnData, col: str, name_to_output: str = None -) -> AnnData: - from utag.utils import z_score - from scipy.special import softmax - - if name_to_output is None: - name_to_output = col + "_probabilities" - - mean = z_score(adata.to_df()).groupby(adata.obs[col]).mean() - probs = softmax(adata.to_df() @ mean.T, axis=1) - adata.obsm[name_to_output] = probs - return adata - - -def evaluate_performance( - adata: AnnData, - batch_key: str = "Slide", - truth_key: str = "DOM_argmax", - pred_key: str = "cluster", - method: str = "rand", -) -> Array: - assert method in ["rand", "homogeneity"] - from sklearn.metrics import rand_score, homogeneity_score - - score_list = [] - for key in adata.obs[batch_key].unique(): - batch = adata[adata.obs[batch_key] == key] - if method == "rand": - score = rand_score(batch.obs[truth_key], batch.obs[pred_key]) - elif method == "homogeneity": - score = homogeneity_score(batch.obs[truth_key], batch.obs[pred_key]) - score_list.append(score) - return score_list diff --git a/utag/types.py b/utag/types.py deleted file mode 100644 index 9622226..0000000 --- a/utag/types.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python - -""" -Specific data types used for type annotations in the package. -""" - -from __future__ import annotations -import os -import typing as tp -import pathlib - - -import numpy -import pandas -import anndata -import networkx -import matplotlib -from matplotlib.figure import Figure as _Figure - - -__all__ = [ - "Array", - "Graph", - "DataFrame", - "Figure", - "Axis", - "Path", - "AnnData", -] - - -class Path(pathlib.Path): - """ - A pathlib.Path child class that allows concatenation with strings - by overloading the addition operator. - - In addition, it implements the ``startswith`` and ``endswith`` methods - just like in the base :obj:`str` type. - - The ``replace_`` implementation is meant to be an implementation closer - to the :obj:`str` type. - - Iterating over a directory with ``iterdir`` that does not exists - will return an empty iterator instead of throwing an error. - - Creating a directory with ``mkdir`` allows existing directory and - creates parents by default. 
- """ - - _flavour = ( - pathlib._windows_flavour # type: ignore[attr-defined] # pylint: disable=W0212 - if os.name == "nt" - else pathlib._posix_flavour # type: ignore[attr-defined] # pylint: disable=W0212 - ) - - def __add__(self, string: str) -> Path: - return Path(str(self) + string) - - def startswith(self, string: str) -> bool: - return str(self).startswith(string) - - def endswith(self, string: str) -> bool: - return str(self).endswith(string) - - def replace_(self, patt: str, repl: str) -> Path: - return Path(str(self).replace(patt, repl)) - - def iterdir(self) -> tp.Generator: - if self.exists(): - yield from [Path(x) for x in pathlib.Path(str(self)).iterdir()] - yield from [] - - def mkdir(self, mode=0o777, parents: bool = True, exist_ok: bool = True) -> Path: - super().mkdir(mode=mode, parents=parents, exist_ok=exist_ok) - return self - - -Array = tp.Union[numpy.ndarray] -Graph = tp.Union[networkx.Graph] - -DataFrame = tp.Union[pandas.DataFrame] -AnnData = tp.Union[anndata.AnnData] - -Figure = tp.Union[_Figure] -Axis = tp.Union[matplotlib.axis.Axis] diff --git a/utag/utils.py b/utag/utils.py deleted file mode 100644 index 9d53b48..0000000 --- a/utag/utils.py +++ /dev/null @@ -1,472 +0,0 @@ -#!/usr/bin/env python - -""" -Helper functions used throughout the package. -""" - -import typing as tp - -import numpy as np -import scipy -import pandas as pd -import networkx as nx - -from utag.types import Array, Graph, DataFrame, Path, AnnData - - -def domain_connectivity( - adata: AnnData, - slide_key: str = 'Slide', - domain_key: str = 'UTAG Label', -) -> AnnData: - import squidpy as sq - import numpy as np - from tqdm import tqdm - - order = sorted(adata.obs[domain_key].unique().tolist()) - - global_pairwise_connection = pd.DataFrame(np.zeros(shape = (len(order),len(order))), index = order, columns = order) - for slide in tqdm(adata.obs[slide_key].unique()): - adata_batch = adata[adata.obs[slide_key] == slide].copy() - - sq.gr.spatial_neighbors(adata_batch, radius = 40, coord_type = 'generic') - - pairwise_connection = pd.DataFrame(index = order, columns = order) - for label in adata_batch.obs[domain_key].unique(): - self_connection = adata_batch[adata_batch.obs[domain_key] == label].obsp['spatial_connectivities'].todense().sum()/2 - self_connection = self_connection.round() - - pairwise_connection.loc[label, label] = self_connection - - for label in adata_batch.obs[domain_key].unique(): - for label2 in adata_batch.obs[domain_key].unique(): - if label != label2: - pairwise = adata_batch[adata_batch.obs[domain_key].isin([label, label2])].obsp['spatial_connectivities'].todense().sum()/2 - pairwise = pairwise.round() - pairwise_connection.loc[label, label2] = pairwise - pairwise_connection.loc[label, label] - pairwise_connection.loc[label2, label2] - pairwise_connection.loc[label2, label] = pairwise_connection.loc[label, label2] - - pairwise_connection = pairwise_connection.fillna(0) - global_pairwise_connection = global_pairwise_connection + pairwise_connection - adata.uns[f'{domain_key}_domain_adjacency_matrix'] = global_pairwise_connection - return adata - -def celltype_connectivity( - adata: AnnData, - slide_key: str = 'Slide', - domain_key: str = 'UTAG Label', - celltype_key: str = 'cluster_0.5_label', -) -> AnnData: - import squidpy as sq - import numpy as np - from tqdm import tqdm - - global_pairwise_utag = dict() - for label in adata.obs[domain_key].unique(): - cell_types = adata.obs[celltype_key].unique().tolist() - global_pairwise_utag[label] = pd.DataFrame(np.zeros(shape = 
(len(cell_types),len(cell_types))), index = cell_types, columns = cell_types) - - for slide in tqdm(adata.obs[slide_key].unique()): - adata_batch = adata[adata.obs[slide_key] == slide].copy() - sq.gr.spatial_neighbors(adata_batch, radius = 40, coord_type = 'generic') - - for label in adata.obs[domain_key].unique(): - adata_batch2 = adata_batch[adata_batch.obs[domain_key] == label].copy() - pairwise_connection = pd.DataFrame(index = cell_types, columns = cell_types) - - for cell_type1 in adata_batch2.obs[celltype_key].unique(): - self_connection = adata_batch2[adata_batch2.obs[celltype_key] == cell_type1].obsp['spatial_connectivities'].todense().sum()/2 - self_connection = self_connection.round() - - pairwise_connection.loc[cell_type1, cell_type1] = self_connection - - for cell_type1 in adata_batch.obs[celltype_key].unique(): - for cell_type2 in adata_batch2.obs[celltype_key].unique(): - if cell_type1 != cell_type2: - pairwise = adata_batch2[adata_batch2.obs[celltype_key].isin([cell_type1, cell_type2])].obsp['spatial_connectivities'].todense().sum()/2 - pairwise = pairwise.round() - pairwise_connection.loc[cell_type1, cell_type2] = pairwise - pairwise_connection.loc[cell_type1, cell_type1] - pairwise_connection.loc[cell_type2, cell_type2] - pairwise_connection.loc[cell_type2, cell_type1] = pairwise_connection.loc[cell_type1, cell_type2] - - pairwise_connection = pairwise_connection.fillna(0) - global_pairwise_utag[label] = global_pairwise_utag[label] + pairwise_connection - - adata.uns[f'{domain_key}_celltype_adjacency_matrix'] = global_pairwise_utag - return adata - - -def slide_connectivity( - adata: AnnData, - slide_key: str = 'roi', - domain_key: str = 'UTAG Label', -) -> dict(): - import squidpy as sq - import numpy as np - from tqdm import tqdm - - order = sorted(adata.obs[domain_key].unique().tolist()) - slide_connection = dict() - - for slide in tqdm(adata.obs[slide_key].unique()): - adata_batch = adata[adata.obs[slide_key] == slide].copy() - - sq.gr.spatial_neighbors(adata_batch, radius = 40, coord_type = 'generic') - - pairwise_connection = pd.DataFrame(index = order, columns = order) - for label in adata_batch.obs[domain_key].unique(): - self_connection = adata_batch[adata_batch.obs[domain_key] == label].obsp['spatial_connectivities'].todense().sum()/2 - self_connection = self_connection.round() - - pairwise_connection.loc[label, label] = self_connection - - for label in adata_batch.obs[domain_key].unique(): - for label2 in adata_batch.obs[domain_key].unique(): - if label != label2: - pairwise = adata_batch[adata_batch.obs[domain_key].isin([label, label2])].obsp['spatial_connectivities'].todense().sum()/2 - pairwise = pairwise.round() - pairwise_connection.loc[label, label2] = pairwise - pairwise_connection.loc[label, label] - pairwise_connection.loc[label2, label2] - pairwise_connection.loc[label2, label] = pairwise_connection.loc[label, label2] - - pairwise_connection = pairwise_connection.fillna(0) - pairwise_connection = pairwise_connection.loc[(pairwise_connection!=0).any(1), (pairwise_connection!=0).any(0)] - #pairwise_connection = pairwise_connection.dropna(axis = 1) - slide_connection[slide] = pairwise_connection - - return slide_connection - -def measure_per_domain_cell_type_colocalization( - adata: AnnData, - utag_key: str = "UTAG Label", - max_dist: int = 40, - n_iterations: int = 100, -): - import squidpy as sq - a_ = adata.copy() - sq.gr.spatial_neighbors(a_, radius=max_dist, coord_type="generic") - - G = nx.from_scipy_sparse_matrix(a_.obsp["spatial_connectivities"]) 
- - utag_map = {i: x for i, x in enumerate(adata.obs[utag_key])} - nx.set_node_attributes(G, utag_map, name=utag_key) - - adj, order = nx.linalg.attrmatrix.attr_matrix(G, node_attr=utag_key) - order = pd.Series(order).astype(adata.obs[utag_key].dtype) - freqs = pd.DataFrame(adj, order, order).fillna(0) + 1 - - norm_freqs = correct_interaction_background_random(G, freqs, utag_key, n_iterations) - return norm_freqs - - -def correct_interaction_background_random( - graph: nx.Graph, freqs: pd.DataFrame, attribute: str, n_iterations: int = 100 -): - values = {x: graph.nodes[x][attribute] for x in graph.nodes} - shuffled_freqs = list() - for _ in range(n_iterations): - g2 = graph.copy() - shuffled_attr = pd.Series(values).sample(frac=1) - shuffled_attr.index = values - nx.set_node_attributes(g2, shuffled_attr.to_dict(), name=attribute) - rf, rl = nx.linalg.attrmatrix.attr_matrix(g2, node_attr=attribute) - rl = pd.Series(rl, dtype=freqs.index.dtype) - shuffled_freqs.append(pd.DataFrame(rf, index=rl, columns=rl))#.fillna(0) + 1) - shuffled_freq = pd.concat(shuffled_freqs) - shuffled_freq = shuffled_freq.groupby(level=0).sum() - shuffled_freq = shuffled_freq.fillna(0) + 1 - - fl = np.log((freqs / freqs.values.sum())) - sl = np.log((shuffled_freq / shuffled_freq.values.sum())) - # make sure both contain all edges/nodes - fl = fl.reindex(sl.index, axis=0).reindex(sl.index, axis=1) - sl = sl.reindex(fl.index, axis=0).reindex(fl.index, axis=1) - return fl - sl - - -def evaluate_clustering( - adata: AnnData, - cluster_keys: Array, - celltype_label: str = 'celltype', - slide_key: str = 'roi', - metrics: Array = ['entropy', 'cluster_number', 'silhouette_score', 'connectivity'] -) -> DataFrame: - - if type(cluster_keys) == str: - cluster_keys = [cluster_keys] - if type(metrics) == str: - metrics = [metrics] - - cluster_loss = pd.DataFrame(index = metrics, columns = cluster_keys) - from tqdm import tqdm - - for metric in metrics: - print(f'Evaluating Cluster {metric}') - for cluster in tqdm(cluster_keys): - assert(metric in ['entropy', 'cluster_number', 'silhouette_score', 'connectivity']) - - if metric == 'entropy': - from scipy.stats import entropy - distribution = adata.obs.groupby([celltype_label, cluster]).count()[slide_key].reset_index().pivot(index = cluster, columns = celltype_label, values = slide_key) - cluster_entropy = distribution.apply(entropy, axis = 1).sort_values().mean() - - cluster_loss.loc[metric, cluster] = cluster_entropy - elif metric == 'cluster_number': - - cluster_loss.loc[metric, cluster] = len(adata.obs[cluster].unique()) - elif metric == 'silhouette_score': - - from sklearn.metrics import silhouette_score - cluster_loss.loc[metric, cluster] = silhouette_score(adata.X, labels = adata.obs[cluster]) - elif metric == 'connectivity': - global_pairwise_connection = domain_connectivity(adata = adata, slide_key = slide_key, domain_key = cluster) - inter_spatial_connectivity = np.log(np.diag(global_pairwise_connection).sum() / (global_pairwise_connection.sum().sum() - np.diag(global_pairwise_connection).sum())) - - cluster_loss.loc[metric, cluster] = inter_spatial_connectivity - return cluster_loss - -def to_uint(x: Array, base: int = 8) -> Array: - return (x * (2 ** base - 1)).astype(f"uint{base}") - - -def to_float(x: Array, base: int = 32) -> Array: - return (x / x.max()).astype(f"float{base}") - - -def open_image_with_tf(filename: str, file_type="png"): - import tensorflow as tf - - img = tf.io.read_file(filename) - return tf.io.decode_image(img, file_type) - - -def filter_kwargs( - 
kwargs: tp.Dict[str, tp.Any], callabl: tp.Callable, exclude: bool = None -) -> tp.Dict[str, tp.Any]: - from inspect import signature - - args = signature(callabl).parameters.keys() - if "kwargs" in args: - return kwargs - return {k: v for k, v in kwargs.items() if (k in args) and k not in (exclude or [])} - - -def array_to_graph( - arr: Array, - max_dist: int = 5, - node_attrs: tp.Mapping[int, tp.Mapping[str, tp.Union[str, int, float]]] = None, -) -> Graph: - """ - Generate a Graph of object distance-based connectivity in euclidean space. - - Parameters - ---------- - arr: np.ndarray - Labeled array. - """ - mask = arr > 0 - idx = arr[mask] - xx, yy = np.mgrid[: arr.shape[0], : arr.shape[1]] - arri = np.stack([xx[mask], yy[mask]]).T - dists = pd.DataFrame(scipy.spatial.distance.cdist(arri, arri), index=idx, columns=idx) - np.fill_diagonal(dists.values, np.nan) - - attrs = dists[dists <= max_dist].reset_index().melt(id_vars="index").dropna() - attrs.index = attrs.iloc[:, :2].apply(tuple, axis=1).tolist() - value = attrs["value"] - g = nx.from_edgelist(attrs.index) - nx.set_edge_attributes(g, value.to_dict(), "distance") - nx.set_edge_attributes(g, (1 / value).to_dict(), "connectivity") - - if node_attrs is not None: - nx.set_node_attributes(g, node_attrs) - - return g - -def compute_and_draw_network( - adata, - slide_key: str = 'roi', - node_key: str = 'UTAG Label', - figsize: tuple = (11,11), - dpi: int = 100, - font_size: int = 12, - node_size_min: int = 1000, - node_size_max: int = 3000, - edge_weight: float = 10, - log_transform: bool = True, - ax = None -) -> nx.Graph: - from utag.utils import domain_connectivity - import networkx as nx - import matplotlib.pyplot as plt - - adjacency_matrix = domain_connectivity(adata = adata, slide_key = slide_key, domain_key = node_key) - s1 = adata.obs.groupby(node_key).count() - s1 = s1[s1.columns[0]] - node_size = s1.values - node_size = (node_size - node_size.min()) / (node_size.max() - node_size.min()) * (node_size_max - node_size_min) + node_size_min - - if ax == None: - fig = plt.figure(figsize = figsize, dpi = dpi) - G = nx.from_numpy_matrix(np.matrix(adjacency_matrix), create_using=nx.Graph) - G = nx.relabel.relabel_nodes(G, {i: label for i, label in enumerate(adjacency_matrix.index)}) - pos = nx.circular_layout(G) - - edges, weights = zip(*nx.get_edge_attributes(G,'weight').items()) - if log_transform: - weights = np.log(np.array(list(weights))+1) - else: - weights = np.array(list(weights)) - weights = (weights - weights.min()) / (weights.max() - weights.min()) * edge_weight + 0.2 - weights = tuple(weights.tolist()) - - if ax: - nx.draw(G, pos, node_color='w', edgelist=edges, edge_color=weights, width=weights, edge_cmap=plt.cm.YlOrRd, with_labels=True, font_size = font_size, node_size = node_size, ax = ax) - else: - nx.draw(G, pos, node_color='w', edgelist=edges, edge_color=weights, width=weights, edge_cmap=plt.cm.YlOrRd, with_labels=True, font_size = font_size, node_size = node_size) - #nx.draw(G, pos, cmap = plt.cm.tab10, node_color = range(8), edgelist=edges, edge_color=weights, width=3, edge_cmap=plt.cm.coolwarm, with_labels=True, font_size = 14, node_size = 1000) - - if ax == None: - ax = plt.gca() - - color_key = node_key + '_colors' - if color_key in adata.uns: - ax.collections[0].set_edgecolor(adata.uns[color_key]) - else: - ax.collections[0].set_edgecolor('lightgray') - ax.collections[0].set_linewidth(3) - ax.set_xlim([1.1*x for x in ax.get_xlim()]) - ax.set_ylim([1.1*y for y in ax.get_ylim()]) - - return G - -def 
get_adjacency_matrix(g: Graph) -> Array: - return nx.adjacency_matrix(g, weight="connectivity").todense() - - -def get_feature_matrix(g: Graph) -> DataFrame: - return pd.DataFrame({n: g.nodes[n] for n in g.nodes}).T - - -def message_pass_graph(adj: Array, feat: DataFrame) -> DataFrame: - return (adj @ feat).set_index(feat.index) - - -def pad_feature_matrix(df: DataFrame, size: int) -> DataFrame: - index = df.index.tolist() + (df.index.max() + np.arange(size - df.shape[0])).tolist() - return pd.DataFrame( - np.pad( - df.values, - [(0, size - df.shape[0]), (0, 0)], - ), - index=index, - columns=df.columns, - ) - - -def pad_adjacency_matrix(mat: DataFrame, size: int) -> DataFrame: - return np.pad( - mat, - [ - (0, size - mat.shape[0]), - (0, size - mat.shape[0]), - ], - ) - - -def message_pass_graphs(gs: tp.Sequence[Graph]) -> Array: - n = max([len(g) for g in gs]) - _adjs = list() - _feats = list() - for g in gs: - adj = get_adjacency_matrix(g) - adj = pad_adjacency_matrix(adj, n) - _adjs.append(adj) - feat = get_feature_matrix(g) - feat = pad_feature_matrix(feat, n) - _feats.append(feat) - adjs = np.stack(_adjs) - feats = np.stack(_feats).astype(float) - - return adjs @ feats - - -def mask_to_labelme( - labeled_image: Array, - filename: Path, - overwrite: bool = False, - simplify: bool = True, - simplification_threshold: float = 5.0, -) -> None: - import io - import base64 - import json - - import imageio - import tifffile - from imantics import Mask - from shapely.geometry import Polygon - - output_file = filename.replace_(".tif", ".json") - if overwrite or output_file.exists(): - return - polygons = Mask(labeled_image).polygons() - shapes = list() - for point in polygons.points: - - if not simplify: - poly = np.asarray(point).tolist() - else: - poly = np.asarray( - Polygon(point).simplify(simplification_threshold).exterior.coords.xy - ).T.tolist() - shape = { - "label": "A", - "points": poly, - "group_id": None, - "shape_type": "polygon", - "flags": {}, - } - shapes.append(shape) - - f = io.BytesIO() - imageio.imwrite(f, tifffile.imread(filename), format="PNG") - f.seek(0) - encoded = base64.encodebytes(f.read()) - - payload = { - "version": "4.5.6", - "flags": {}, - "shapes": shapes, - "imagePath": filename.name, - "imageData": encoded.decode("ascii"), - "imageHeight": labeled_image.shape[0], - "imageWidth": labeled_image.shape[1], - } - with open(output_file.as_posix(), "w") as fp: - json.dump(payload, fp, indent=2) - - -def z_score(x: Array) -> Array: - """ - Scale (divide by standard deviation) and center (subtract mean) array-like objects. - """ - return (x - x.min()) / (x.max() - x.min()) - - -def sparse_matrix_dstack( - matrices: tp.Sequence[scipy.sparse.csr_matrix], -) -> scipy.sparse.csr_matrix: - """ - Diagonally stack sparse matrices. 
- """ - import scipy - from tqdm import tqdm - - n = sum([x.shape[0] for x in matrices]) - _res = list() - i = 0 - for x in tqdm(matrices): - v = scipy.sparse.csr_matrix((x.shape[0], n)) - v[:, i : i + x.shape[0]] = x - _res.append(v) - i += x.shape[0] - return scipy.sparse.vstack(_res) diff --git a/utag/vizualize.py b/utag/vizualize.py deleted file mode 100644 index 4466018..0000000 --- a/utag/vizualize.py +++ /dev/null @@ -1,200 +0,0 @@ -import os - -import scanpy as sc -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import anndata -import holoviews as hv -from holoviews import opts, dim - -from utag.types import Path, Array, AnnData, DataFrame - - -def add_spatial_image( - adata: AnnData, - image_path: Path, - rgb_channels = [19, 9, 14], - log_transform: bool = False, - median_filter: bool = False, - scale_method: str = 'adjust_gamma', - contrast_percentile = (0, 90), - gamma: float = 0.2, - gain: float = 0.5 -): - - adata.obsm['spatial'] = adata.obs[['Y_centroid', 'X_centroid']].to_numpy() - adata.uns["spatial"] = {'image': {}} - adata.uns["spatial"]['image']["images"] = {} - - img = rgbfy_multiplexed_image( - image_path = image_path, - rgb_channels = rgb_channels, - contrast_percentile = contrast_percentile, - log_transform = log_transform, - median_filter = median_filter, - scale_method = scale_method, - gamma = gamma, - gain = gain - ) - - - adata.uns["spatial"]['image']["images"] = {"hires": img} - adata.uns["spatial"]['image']["scalefactors"] = {"tissue_hires_scalef": 1, "spot_diameter_fullres": 1} - return adata - -def add_scale_box_to_fig( - img: Array, - ax, - box_width: int = 100, - box_height: float = 3, - color: str = 'white' -) -> Array: - import matplotlib.patches as patches - x = img.shape[1] - y = img.shape[0] - - # Create a Rectangle patch - rect = patches.Rectangle((x - box_width, y * (1-box_height/100)), box_width, y * (box_height/100), linewidth=0.1, edgecolor='black', facecolor=color) - - # Add the patch to the Axes - ax.add_patch(rect) - return ax - -def rgbfy_multiplexed_image( - image_path: Path, - rgb_channels = [19, 9, 14], - log_transform: bool = True, - median_filter: bool = True, - scale_method: str = 'adjust_gamma', - contrast_percentile = (10, 90), - gamma: float = 0.4, - gain: float = 1 -) -> Array: - from skimage.exposure import rescale_intensity, adjust_gamma, equalize_hist - from scipy.ndimage import median_filter as mf - import tifffile - - def rescale(img, contrast_percentile): - r1, r2 = np.percentile(img, contrast_percentile) - img = rescale_intensity(img, in_range = (r1, r2), out_range = (0,1)) - return img - #assert(len(rgb_channels) == 3 or len(rgb_channels) == 1) - - img = tifffile.imread(image_path) - img = img.astype(np.float32) - if median_filter == True: - img = mf(img, size = 3) - - image_to_save = np.stack([img[x] for x in rgb_channels], axis = 2) - - for i in range(len(rgb_channels)): - if log_transform == True: - image_to_save[:,:,i] = np.log(image_to_save[:,:,i] + 1) - else: - image_to_save[:,:,i] = image_to_save[:,:,i] - - output_img = image_to_save - - for i in range(3): - if scale_method == 'contrast_stretch': - output_img[:,:,i] = rescale(output_img[:,:,i], contrast_percentile) - elif scale_method == 'adjust_gamma': - output_img[:,:,i] = adjust_gamma(output_img[:,:,i], gamma=gamma, gain=gain) - #output_img[:,:,i] = rescale(output_img[:,:,i], contrast_percentile) - elif scale_method == 'equalize_hist': - output_img[:,:,i] = equalize_hist(output_img[:,:,i]) - - output_img[:,:,i] = np.clip(output_img[:,:,i], 
0, 1) - return output_img - - -def draw_network( - adata: AnnData, - node_key: str = 'UTAG Label', - adjacency_matrix_key: str = 'UTAG Label_domain_adjacency_matrix', - figsize: tuple = (11,11), - dpi: int = 200, - font_size: int = 12, - node_size_min: int = 1000, - node_size_max: int = 3000, - edge_weight: float = 5, - edge_weight_baseline: float = 1, - log_transform: bool = True, - ax = None -): - import networkx as nx - s1 = adata.obs.groupby(node_key).count() - s1 = s1[s1.columns[0]] - node_size = s1.values - node_size = (node_size - node_size.min()) / (node_size.max() - node_size.min()) * (node_size_max - node_size_min) + node_size_min - - if ax == None: - fig = plt.figure(figsize = figsize, dpi = dpi) - G = nx.from_numpy_matrix(np.matrix(adata.uns[adjacency_matrix_key]), create_using=nx.Graph) - G = nx.relabel.relabel_nodes(G, {i: label for i, label in enumerate(adata.uns[adjacency_matrix_key].index)}) - - edges, weights = zip(*nx.get_edge_attributes(G,'weight').items()) - - if log_transform: - weights = np.log(np.array(list(weights))+1) - else: - weights = np.array(list(weights)) - - weights = (weights - weights.min()) / (weights.max() - weights.min()) * edge_weight + edge_weight_baseline - weights = tuple(weights.tolist()) - - #pos = nx.spectral_layout(G, weight = 'weight') - pos = nx.spring_layout(G, weight = 'weight', seed = 42, k = 1) - - if ax: - nx.draw(G, pos, node_color='w', edgelist=edges, edge_color=weights, width=weights, edge_cmap=plt.cm.YlOrRd, with_labels=True, font_size = font_size, node_size = node_size, ax = ax) - else: - nx.draw(G, pos, node_color='w', edgelist=edges, edge_color=weights, width=weights, edge_cmap=plt.cm.YlOrRd, with_labels=True, font_size = font_size, node_size = node_size) - - if ax == None: - ax = plt.gca() - - color_key = node_key + '_colors' - if color_key in adata.uns: - ax.collections[0].set_edgecolor(adata.uns[color_key]) - ax.collections[0].set_facecolor(adata.uns[color_key]) - else: - ax.collections[0].set_edgecolor('lightgray') - ax.collections[0].set_linewidth(3) - ax.set_xlim([1.3*x for x in ax.get_xlim()]) - ax.set_ylim([1*y for y in ax.get_ylim()]) - - if ax == None: - return fig - -def adj2chord( - adjacency_matrix: Array, - size:int = 300 -): - hv.output(fig='svg', size=size) - - links = adjacency_matrix.stack().reset_index().rename(columns = {'level_0': 'source', 'level_1': 'target', 0: 'value'}).dropna() - order2ind = {k:i for i, k in enumerate(adjacency_matrix.index.tolist())} - - links['source'] = links['source'].replace(order2ind) - links['target'] = links['target'].replace(order2ind) - links['value'] = links['value'].astype(int) - - nodes = pd.DataFrame(order2ind.keys(), index = order2ind.values(), columns = ['name']).reset_index() - nodes['group'] = nodes['index'] - del nodes['index'] - nodes = hv.Dataset(nodes, 'index') - - chord = hv.Chord((links, nodes)).select(value=(5, None)) - chord.opts( - opts.Chord( - cmap='tab10', - edge_cmap='tab10', - edge_color=dim('source').str(), - labels='name', - node_color=dim('index').str() - ) - ) - - return chord \ No newline at end of file diff --git a/utils.py b/utils.py index 2a8fd95..3e3c939 100644 --- a/utils.py +++ b/utils.py @@ -12,6 +12,7 @@ from datetime import datetime import anndata import time +import pickle def set_filename_corresp_to_roi(df_paths, roi_name, curr_colname, curr_dir, curr_dir_listing): """Update the path in a main paths-holding dataframe corresponding to a particular ROI in a particular directory. 
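The hunks that follow change how the slide ID is obtained for each ROI: instead of parsing it out of the ROI name, a checkpoint pickle is loaded and a Series mapping each unique ROI to its parent slide is built, and the case/condition are then parsed from the slide name. A minimal sketch of that lookup pattern under hypothetical data (the ROI names, slide-name format, and checkpoint contents here are illustrative, not taken from the repository):

import pandas as pd

# Hypothetical checkpoint contents: one row per ROI with its parent slide
df_data_by_roi = pd.DataFrame({
    'unique_roi': ['roi 1', 'roi 2', 'roi 3'],
    'unique_slide': ['12A-slideX', '12A-slideX', '7B-slideY'],
})

# Normalize ROI names (spaces -> underscores) and build the ROI -> slide lookup,
# mirroring the get_paths_for_rois() hunk below
df_data_by_roi['unique_roi'] = df_data_by_roi['unique_roi'].replace(' ', '_', regex=True)
ser_slide_per_roi = df_data_by_roi.set_index('unique_roi')['unique_slide']

# Case/condition parsing now starts from the slide name rather than the ROI name,
# e.g. a hypothetical slide ID '12A' -> case 12, condition 'A'
slide_id = ser_slide_per_roi['roi_1'].split('-')[0]
case, condition = int(slide_id[:-1]), slide_id[-1]
print(case, condition)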
@@ -63,6 +64,8 @@ def get_paths_for_rois(): # Obtain the directory holding the subdirectories containing various types of plots (in this case, three types) # plots_dir = os.path.join(os.getcwd(), '..', 'results', 'webpage', 'slices_1x{}'.format(radius_in_microns), 'real') plots_dir = os.path.join('.', 'output', 'images') + pickle_dir = os.path.join('.', 'output', 'checkpoints') + pickle_file = 'initial_data.pkl' # Obtain the paths to the subdirectories outlines_dir = os.path.join(plots_dir, 'single_roi_outlines_on_whole_slides') @@ -86,11 +89,17 @@ def get_paths_for_rois(): df_paths = set_filename_corresp_to_roi(df_paths=df_paths, roi_name=roi_name, curr_colname='heatmap', curr_dir=heatmaps_dir, curr_dir_listing=heatmaps_dir_listing) df_paths = set_filename_corresp_to_roi(df_paths=df_paths, roi_name=roi_name, curr_colname='outline', curr_dir=outlines_dir, curr_dir_listing=outlines_dir_listing) + with open(os.path.join(pickle_dir, pickle_file), 'rb') as f: + initial_data = pickle.load(f) + df_data_by_roi = initial_data['df_data_by_roi'] + df_data_by_roi['unique_roi'] = df_data_by_roi['unique_roi'].replace(' ', '_', regex=True) + ser_slide_per_roi = df_data_by_roi.set_index('unique_roi')['unique_slide'] + # Add columns containing the patient "case" ID and the slide "condition", in order to aid in sorting the data cases = [] conditions = [] for roi_name in df_paths.index: - slide_id = roi_name.split('-')[0] + slide_id = ser_slide_per_roi[roi_name].split('-')[0] # this is a more reliable way to get the slide ID cases.append(int(slide_id[:-1])) conditions.append(slide_id[-1]) df_paths['case'] = cases @@ -135,21 +144,16 @@ def get_paths_for_slides(): df_paths = pd.DataFrame([os.path.splitext(x)[0] for x in slides_listing], columns=['slide_name']).set_index('slide_name') # Determine the filenames of each of the image types corresponding to each slide name - corresp_slide_filename = [] - corresp_slide_filename_patched = [] - corresp_heatmap_filename = [] for slide_name in df_paths.index: - for slide_filename, heatmap_filename in zip(slides_listing, heatmaps_listing): + for slide_filename in slides_listing: if slide_name in slide_filename: - corresp_slide_filename.append(os.path.join(plots_dir, 'whole_slide_patches', slide_filename)) - corresp_slide_filename_patched.append(os.path.join(plots_dir, 'whole_slide_patches', '{}-patched{}'.format(os.path.splitext(slide_filename)[0], file_extension))) + df_paths.loc[slide_name, 'slide'] = os.path.join(plots_dir, 'whole_slide_patches', slide_filename) + df_paths.loc[slide_name, 'slide_patched'] = os.path.join(plots_dir, 'whole_slide_patches', '{}-patched{}'.format(os.path.splitext(slide_filename)[0], file_extension)) + break + for heatmap_filename in heatmaps_listing: if slide_name in heatmap_filename: - corresp_heatmap_filename.append(os.path.join(plots_dir, 'dens_pvals_per_slide', heatmap_filename)) - - # Add these paths to the main paths dataframe - df_paths['slide'] = corresp_slide_filename - df_paths['slide_patched'] = corresp_slide_filename_patched - df_paths['heatmap'] = corresp_heatmap_filename + df_paths.loc[slide_name, 'heatmap'] = os.path.join(plots_dir, 'dens_pvals_per_slide', heatmap_filename) + break # Add columns containing the patient "case" ID and the slide "condition", in order to aid in sorting the data cases = [] @@ -164,6 +168,14 @@ def get_paths_for_slides(): # Sort the data by case, then by condition, then by the slide string df_paths = df_paths.sort_values(by=['case', 'condition', 'slide_name']) + # # Delete rows in df_paths where 
'slide', 'slide_patched', or 'heatmap' is None + # print(df_paths) + # num_rows_before = len(df_paths) + # df_paths = df_paths.dropna(subset=['slide', 'slide_patched', 'heatmap']) + # print(df_paths) + # num_rows_after = len(df_paths) + # print(f'Deleted {num_rows_before - num_rows_after} rows from df_paths where "slide", "slide_patched", or "heatmap" was None.') + # Return the paths dataframe return df_paths @@ -760,6 +772,23 @@ def calculate_neighbor_counts_with_possible_chunking(center_coords=None, neighbo # Return the neighbor counts return neighbor_counts + +def calculate_neighbor_counts_with_kdtree(center_coords, neighbor_coords, radius, tol=1e-9): + # NOTE FOR FUTURE: Probably reconsider using scipy.spatial.KDTree.count_neighbors() since that may align with the statistic of interest in both Poisson and permutation methods, i.e., the sum of the neighbor counts over all centers. Perhaps force that to work because that may really perfectly match the statistic of interest. E.g., note in calculate_density_metrics() that the full output of this function (num_centers,) is not used but is rather summed, which I believe would be the output of count_neighbors(). I.e., query_ball_tree() returns extra, unused information. I would need to make sure there would never be memory issues though, e.g., if an entire large slide were to run count_neighbors() at once. If there are P phenotypes, we'd still have to build 2P trees and call count_neighbors P^2 times per ROI, so both timing and memory usage should be tested thoroughly. + radius = radius - tol # to essentially make the check [0, radius) instead of [0, radius] + center_tree = scipy.spatial.KDTree(center_coords) + neighbor_tree = scipy.spatial.KDTree(neighbor_coords) + indexes = center_tree.query_ball_tree(neighbor_tree, r=radius) + return np.array([len(neighbors_list) for neighbors_list in indexes]) # (num_centers,) + + # Using this matches the cdist method but is not elegant, but using "tol" above gets the same result more efficiently + # neighbor_counts = [] + # for i, neighbors_list in enumerate(indexes): + # count = sum(np.linalg.norm(neighbor_coords[idx] - center_coords[i]) < radius for idx in neighbors_list) + # neighbor_counts.append(count) + # return np.array(neighbor_counts) # (num_centers,) + + def dataframe_insert_possibly_existing_column(df, column_position, column_name, srs_column_values): """ Alternative to df.insert() that replaces the column values if the column already exists, but otherwise uses df.insert() to add the column to the dataframe. 
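The calculate_neighbor_counts_with_kdtree() helper added above counts, for each center cell, the neighbor cells within a radius by querying one KD-tree against another, shrinking the radius by a small tolerance so the test behaves like [0, radius) rather than [0, radius]. A minimal usage sketch with toy coordinates; it also illustrates the count_neighbors() alternative raised in the NOTE, which returns only the summed pair count (the coordinates and sizes here are made up for illustration):

import numpy as np
import scipy.spatial

# Toy center/neighbor coordinates
rng = np.random.default_rng(0)
center_coords = rng.uniform(0, 100, size=(500, 2))
neighbor_coords = rng.uniform(0, 100, size=(800, 2))
radius, tol = 10.0, 1e-9

# Same idea as the new helper: shrink the radius slightly so the membership
# test is effectively [0, radius) instead of [0, radius]
center_tree = scipy.spatial.KDTree(center_coords)
neighbor_tree = scipy.spatial.KDTree(neighbor_coords)
indexes = center_tree.query_ball_tree(neighbor_tree, r=radius - tol)
counts = np.array([len(lst) for lst in indexes])  # shape (num_centers,)

# The NOTE suggests count_neighbors() when only the total is needed;
# it returns the summed pair count without materializing per-center index lists
total = center_tree.count_neighbors(neighbor_tree, r=radius - tol)
assert total == counts.sum()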
@@ -837,21 +866,24 @@ def downcast_series_dtype(ser, frac_cutoff=0.05, number_cutoff=10): # Get the initial dtype initial_dtype = ser.dtype - # Check if the series dtype is 'object' - if ser.dtype == 'object': - cutoff = frac_cutoff * len(ser) # Calculate the cutoff based on the fraction of unique values - else: - cutoff = number_cutoff # Use the number cutoff for non-object dtypes + # Don't do anything if the series is boolean + if initial_dtype != 'bool': - # If the number of unique values is less than or equal to the cutoff, convert the series to the category data type - if ser.nunique() <= cutoff: - ser = ser.astype('category') + # Check if the series dtype is 'object' + if ser.dtype == 'object': + cutoff = frac_cutoff * len(ser) # Calculate the cutoff based on the fraction of unique values + else: + cutoff = number_cutoff # Use the number cutoff for non-object dtypes + + # If the number of unique values is less than or equal to the cutoff, convert the series to the category data type + if ser.nunique() <= cutoff: + ser = ser.astype('category') - # Halve the precision of integers and floats - if ser.dtype == 'int64': - ser = ser.astype('int32') - elif ser.dtype == 'float64': - ser = ser.astype('float32') + # Halve the precision of integers and floats + if ser.dtype == 'int64': + ser = ser.astype('int32') + elif ser.dtype == 'float64': + ser = ser.astype('float32') # Get the final dtype final_dtype = ser.dtype @@ -885,13 +917,12 @@ def downcast_dataframe_dtypes(df, also_return_final_size=False, frac_cutoff=0.05 print('----') print('Memory usage before conversion: {:.2f} MB'.format(original_memory / 1024 ** 2)) - # Potentially convert the columns to more efficient formats, ignoring boolean columns + # Potentially convert the columns to more efficient formats for col in df.columns: - if df[col].dtype != 'bool': - if no_categorical: - df[col] = downcast_series_dtype_no_categorical(df[col]) - else: - df[col] = downcast_series_dtype(df[col], frac_cutoff=frac_cutoff, number_cutoff=number_cutoff) + if no_categorical: + df[col] = downcast_series_dtype_no_categorical(df[col]) + else: + df[col] = downcast_series_dtype(df[col], frac_cutoff=frac_cutoff, number_cutoff=number_cutoff) # Print memory usage after conversion new_memory = df.memory_usage(deep=True).sum() @@ -930,12 +961,15 @@ def downcast_series_dtype_no_categorical(ser): # Get the initial dtype initial_dtype = ser.dtype - # Halve the precision of integers and floats - if initial_dtype == 'int64': - # ser = ser.astype('int32') - ser = downcast_int_series(ser) - elif initial_dtype == 'float64': - ser = ser.astype('float32') + # Don't do anything if the series is boolean + if initial_dtype != 'bool': + + # Halve the precision of integers and floats + if initial_dtype == 'int64': + # ser = ser.astype('int32') + ser = downcast_int_series(ser) + elif initial_dtype == 'float64': + ser = ser.astype('float32') # Get the final dtype final_dtype = ser.dtype @@ -1233,3 +1267,16 @@ def fast_neighbors_counts_for_block2(df_image, image_name, coord_column_names, p # Return the final dataframe of neighbor counts for the current image return df_curr_counts + + +def get_categorical_columns_including_numeric(df, max_num_unique_values=1000): + categorical_columns = [] + for col in df.columns: + if df[col].nunique() <= max_num_unique_values: + categorical_columns.append(col) + return categorical_columns + + +def sample_df_without_replacement_by_number(df, n, seed=None): + n = min(n, len(df)) + return df.sample(n=n, replace=False, random_state=seed)
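The two small helpers appended at the end of utils.py select "categorical-like" columns by unique-value count (including numeric columns) and draw a without-replacement sample capped at the dataframe length. A brief usage sketch, assuming the helpers are in scope; the column names and sizes are hypothetical:

import numpy as np
import pandas as pd

# Toy dataframe: one low-cardinality string column, one low-cardinality numeric
# column, and one effectively continuous column
rng = np.random.default_rng(1)
df = pd.DataFrame({
    'phenotype': rng.choice(['A', 'B', 'C'], size=5000),
    'case_id': rng.integers(1, 20, size=5000),
    'intensity': rng.random(size=5000),
})

# Numeric columns with few unique values (e.g., case_id) are kept;
# 'intensity' has thousands of unique values and is excluded
cat_cols = get_categorical_columns_including_numeric(df, max_num_unique_values=100)

# n is capped at len(df), so requesting more rows than exist never raises
df_sample = sample_df_without_replacement_by_number(df, n=10000, seed=42)
print(cat_cols, len(df_sample))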