diff --git a/pheno_utils/age_reference_plots.py b/pheno_utils/age_reference_plots.py index fba13d5..d7a77c3 100644 --- a/pheno_utils/age_reference_plots.py +++ b/pheno_utils/age_reference_plots.py @@ -17,6 +17,8 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt +from matplotlib.colors import to_rgba +import matplotlib.patches as mpatches import importlib from scipy.stats import linregress from sklearn.linear_model import HuberRegressor @@ -81,6 +83,7 @@ def __init__( val_bins: Optional[np.ndarray] = None, linear_fit: bool = True, expectiles: Optional[List] = [0.03, 0.1, 0.5, 0.9, 0.97], + thresholds: Optional[List] = None, top_disp_perc: float = 99, bottom_disp_perc: float = 1, robust: bool = True, @@ -290,6 +293,61 @@ def plot_ornaments(self): # font="Roboto Condensed", ) + def plot_thresholds(self, thresholds: List[float], labels: Optional[List[str]] = None, + cmap: str = 'viridis', legend: bool = True) -> None: + """ + Add threshold-based fill_between patches to the main axis of the plot. + + Parameters: + ----------- + thresholds : List[float] + List of N threshold values for plotting N+1 patches. + + labels : Optional[List[str]], default=None + List of N+1 labels for the patches. If None, automatic labels will be generated. + + cmap : str, default='viridis' + Colormap to use for the patches. + + legend : bool, default=True + Whether to show the legend. + + Returns: + -------- + None + """ + cm = plt.get_cmap(cmap) + n_patches = max(2, len(thresholds) - 1) + colors = [to_rgba(cm(i / (n_patches - 1))) for i in range(n_patches)] + ylim = self.ax_main.get_ylim() + legend_patches = [] + + for i, (lo, hi) in enumerate(zip(thresholds[:-1], thresholds[1:])): + lo = max(lo, ylim[0]) + hi = min(hi, ylim[1]) + if labels is None: + if lo == ylim[0]: + label = f'<{hi:.4g}' + elif hi == ylim[1]: + label = f'>{lo:.4g}' + else: + label = f'[{lo:.4g}, {hi:.4g})' + else: + label = labels[i] + + self.ax_main.fill_between( + self.ax_main.get_xlim(), lo, hi, color=colors[i], + zorder=-1, alpha=0.2) + self.ax_valhist.fill_between( + self.ax_valhist.get_xlim(), lo, hi, color=colors[i], + zorder=-1, alpha=0.2) + legend_patches.append(mpatches.Patch(color=colors[i], alpha=0.2, label=label)) + + # Add legend + if legend: + plt.legend(handles=legend_patches, loc='upper left', bbox_to_anchor=(1, 1)) + + def plot(self): ax = self.ax_main ax.spines["right"].set_visible(False) @@ -380,6 +438,33 @@ def __init__( make_fig=False ) + def plot_thresholds(self, thresholds: List[float], labels: Optional[List[str]] = None, + cmap: str = 'viridis', legend: bool = True) -> None: + """ + Add threshold-based fill_between patches to the main axis of the plot. + + Parameters: + ----------- + thresholds : List[float] + List of N threshold values for plotting N+1 patches. + + labels : Optional[List[str]], default=None + List of N+1 labels for the patches. If None, automatic labels will be generated. + + cmap : str, default='viridis' + Colormap to use for the patches. + + legend : bool, default=True + Whether to show the legend. 
+ + Returns: + -------- + None + """ + self.female_refplot.plot_thresholds(thresholds, labels, cmap, legend) + self.male_refplot.plot_thresholds(thresholds, labels, cmap, legend) + + def plot(self) -> None: """Plots the data for both genders in separate panels.""" ax_dict = plt.figure(constrained_layout=True, figsize=(16, 6)).subplot_mosaic( diff --git a/pheno_utils/cohort_selector.py b/pheno_utils/cohort_selector.py index ecd1572..aa84365 100644 --- a/pheno_utils/cohort_selector.py +++ b/pheno_utils/cohort_selector.py @@ -20,6 +20,9 @@ from .meta_loader import MetaLoader # %% ../nbs/12_cohort_selector.ipynb 5 +from typing import Optional + + class CohortSelector: """ Class for selecting a subset of a cohort's data based on a query. @@ -69,13 +72,14 @@ def __init__( flexible_field_search=False, errors=self.errors, **self.kwargs) - def select(self, query: str) -> pd.DataFrame: + def select(self, query: str, add_fields: Optional[List] = []) -> pd.DataFrame: """ Select a subset of the cohort's data based on the given query. Args: query (str): Query string to filter the data. + add_fields (List, optional): Additional fields to load. Defaults to []. Returns: @@ -87,17 +91,24 @@ def select(self, query: str) -> pd.DataFrame: ValueError: If column names in the query do not match the column names in the metadata. """ - column_names = re.findall(r'([a-zA-Z][a-zA-Z0-9_]*)\b', query) + column_names = re.findall(r'([a-zA-Z][a-zA-Z0-9_/]*)\b', query) if not column_names: raise ValueError('No column names found in query') - test_cols = self.ml.get(column_names) + test_cols = self.ml.get(column_names) missing_cols = [col for col in column_names - if col not in test_cols.columns.str.split('/').str[1]] + if col not in test_cols.columns and + col not in test_cols.columns.str.split('/').str[1]] if len(missing_cols): raise ValueError(f'Column names {missing_cols} in query do not match column names in metadata') - df = self.ml.load(column_names) + if type(add_fields) == str: + add_fields = [add_fields] + df = self.ml.load(column_names + add_fields) - return df.query(query) + try: + return df.query(re.sub(r'\w+/', '', query)) + except: + print(re.sub(r'\w+(/)', '__', query)) + return df.query(re.sub(r'(\w+)/', r'\1__', query)) diff --git a/pheno_utils/meta_loader.py b/pheno_utils/meta_loader.py index f234b27..e5acd50 100644 --- a/pheno_utils/meta_loader.py +++ b/pheno_utils/meta_loader.py @@ -77,17 +77,31 @@ def load(self, fields: Union[str,List[str]], flexible: bool=None, prop: str='tab pd.DataFrame: Dataframe containing the fields from the respective datasets. 
""" found_fields = self.get(fields, flexible, prop) + if found_fields.empty: + return pd.DataFrame() + found_fields.columns = found_fields.columns.str.split('/').str[1] dup_fields = found_fields.columns.value_counts()\ .to_frame('count').query('count > 1').index + n_datasets = found_fields.loc['dataset'].nunique() loaded_fields = [] for ds, f in found_fields.T.groupby('dataset'): df = PhenoLoader(ds, base_path=self.base_path, cohort=self.cohort, age_sex_dataset=None, **self.kwargs)\ [f.index.tolist()] + if df.empty: + continue + + if 'array_index' in df.index.names and n_datasets > 1: + if df.index.get_level_values('array_index').nunique() > 1: + df = df.reset_index('array_index', drop=False)\ + .rename(columns={'array_index': f'{ds}__array_index'}) + else: + df = df.reset_index('array_index', drop=True) + # rename duplicate fields - df = df.rename(columns=pd.Series(dup_fields + f'_{ds}', index=dup_fields)) + df = df.rename(columns=pd.Series(f'{ds}__' + dup_fields, index=dup_fields)) if not len(loaded_fields): loaded_fields = df @@ -113,20 +127,29 @@ def get(self, fields: Union[str,List[str]], flexible: bool=None, prop='tabular_f flexible = self.flexible_field_search if isinstance(fields, str): fields = [fields] - fields = [f.lower() for f in fields] + fields = pd.DataFrame({'field': [f.lower() for f in fields]}).assign(dataset=None) + + if prop == 'tabular_field_name': + ind = fields['field'].str.contains('/') + fields.loc[ind, 'dataset'] = fields.loc[ind, 'field'].str.split('/').str[0] + fields.loc[ind, 'field'] = fields.loc[ind, 'field'].str.split('/').str[1] data = pd.DataFrame() for dataset, df in self.dicts.items(): + keep = (fields['dataset'] == dataset) | fields['dataset'].isnull() + fields_in_dataset = fields.loc[keep, 'field'] + if prop == 'tabular_field_name': search_in = pd.Series(df.columns, index=df.columns).str.lower() else: search_in = df.loc[prop].dropna().str.lower() if flexible: # use fuzzy matching including regex to find fields - fields_in_col = np.unique([col for f in fields for col, text in search_in.items() + fields_in_col = np.unique([col for f in fields_in_dataset for col, text in search_in.items() if type(text) is str and re.search(f, text)]) else: - fields_in_col = search_in[search_in.isin(fields)].index + fields_in_col = search_in[search_in.isin(fields_in_dataset)].index + if len(fields_in_col): this_data = df[fields_in_col] this_data.columns = dataset + '/' + this_data.columns diff --git a/pheno_utils/pheno_loader.py b/pheno_utils/pheno_loader.py index 207c148..bedcf87 100644 --- a/pheno_utils/pheno_loader.py +++ b/pheno_utils/pheno_loader.py @@ -76,6 +76,7 @@ def __init__( valid_dates: bool = False, valid_stage: bool = False, flexible_field_search: bool = False, + squeeze: bool = False, errors: str = ERROR_ACTION, read_parquet_kwargs: Dict[str, Any] = {} ) -> None: @@ -92,6 +93,7 @@ def __init__( self.valid_dates = valid_dates self.valid_stage = valid_stage self.flexible_field_search = flexible_field_search + self.squeeze = squeeze self.errors = errors self.read_parquet_kwargs = read_parquet_kwargs @@ -103,7 +105,7 @@ def __init__( def load_sample_data( self, field_name: str, - participant_id: Union[int, List[int]], + participant_id: Union[None, int, List[int]] = None, research_stage: Union[None, str, List[str]] = None, array_index: Union[None, int, List[int]] = None, load_func: callable = pd.read_parquet, @@ -115,39 +117,46 @@ def load_sample_data( Args: field_name (str): The name of the field to load. 
- participant_id (str or list): The participant ID or IDs to load data for. + participant_id (str or list, optional): The participant ID or IDs to load data for. research_stage (str or list, optional): The research stage or stages to load data for. array_index (int or list, optional): The array index or indices to load data for. load_func (callable, optional): The function to use to load the data. Defaults to pd.read_parquet concat (bool, optional): Whether to concatenate the data into a single DataFrame. Automatically ignored if data is not a DataFrame. Defaults to True. pivot (str, optional): The name of the field to pivot the data on (if DataFrame). Defaults to None. """ - query_str = 'participant_id in @participant_id' - if not isinstance(participant_id, list): - participant_id = [participant_id] + query_str = [] + if participant_id is not None: + query_str.append('participant_id in @participant_id') + if not isinstance(participant_id, list): + participant_id = [participant_id] if research_stage is not None: if not isinstance(research_stage, list): research_stage = [research_stage] - query_str += ' and research_stage in @research_stage' + query_str.append('research_stage in @research_stage') if array_index is not None: if not isinstance(array_index, list): array_index = [array_index] - query_str += ' and array_index in @array_index' + query_str.append('array_index in @array_index') + query_str = ' and '.join(query_str) - sample = self[[field_name] + ['participant_id']].query(query_str) + sample = self[[field_name] + ['participant_id']] + if query_str: + sample = sample.query(query_str) col = sample.columns[0] # can be different from field_name is a parent_dataframe is implied - sample = sample.astype({col: str}) - missing_participants = np.setdiff1d(participant_id, sample['participant_id'].unique()) - sample = sample.loc[:, col] - + sample = sample.astype({col: str}) + if participant_id is not None: + missing_participants = np.setdiff1d(participant_id, sample['participant_id'].unique()) + else: + missing_participants = [] if len(missing_participants): if self.errors == 'raise': raise ValueError(f'Missing samples: {missing_participants}') elif self.errors == 'warn': warnings.warn(f'Missing samples: {missing_participants}') - if len(sample) == 0: - return None + if len(sample) == 0: + return None + sample = sample.loc[:, col] # Load data data = [] @@ -155,6 +164,9 @@ def load_sample_data( try: data.append(load_func(p, **kwargs)) if isinstance(data[-1], pd.DataFrame): + data[-1] = self.__add_missing_levels__(data[-1], sample.loc[sample == p].to_frame()) + if query_str: + data[-1] = data[-1].query(query_str) data[-1].sort_index(inplace=True) except Exception as e: if self.errors == 'raise': @@ -188,7 +200,7 @@ def __str__(self): Returns: str: String representation of object """ - return f'DataLoader for {self.dataset} with' +\ + return f'PhenoLoader for {self.dataset} with' +\ f'\n{len(self.fields)} fields\n{len(self.dfs)} tables: {list(self.dfs.keys())}' def __getitem__(self, fields: Union[str,List[str]]): @@ -203,19 +215,21 @@ def __getitem__(self, fields: Union[str,List[str]]): """ return self.get(fields) - def get(self, fields: Union[str,List[str]], flexible: bool=None): + def get(self, fields: Union[str,List[str]], flexible: bool=None, squeeze: bool=None): """ Return data for the specified fields from all tables Args: fields (List[str]): Fields to return - flexible (bool, optional): Whether to use fuzzy matching to find fields. 
Defaults to None, which uses the DataLoader's flexible_field_search attribute. + flexible (bool, optional): Whether to use fuzzy matching to find fields. Defaults to None, which uses the PhenoLoader's flexible_field_search attribute. Returns: pd.DataFrame: Data for the specified fields from all tables """ if flexible is None: flexible = self.flexible_field_search + if squeeze is None: + squeeze = self.squeeze if isinstance(fields, str): fields = [fields] @@ -262,9 +276,11 @@ def get(self, fields: Union[str,List[str]], flexible: bool=None): cols_order = [field for field in fields if field in data.columns] cols_order += [field for field in flexi_fields if (field in data.columns and field not in fields)] + if squeeze and len(cols_order) == 1: + return data[cols_order[0]] + return data[cols_order] - def replace_bulk_data_path(self, data, fields): bulk_fields = self.dict.loc[self.dict.index.isin(fields)].query('item_type == "Bulk"') cols = [col for col in bulk_fields.index.to_list() if col in data.columns] @@ -280,8 +296,6 @@ def replace_bulk_data_path(self, data, fields): return data - - def __concat__(self, df1, df2): if df1.empty: return df2 @@ -486,6 +500,67 @@ def __get_dataset_path__(self, dataset): return os.path.join(self.base_path, dataset, self.cohort) return os.path.join(self.base_path, dataset) + def __add_missing_levels__(self, data: pd.DataFrame, more_levels: pd.DataFrame) -> pd.DataFrame: + """ + Extends the index levels of the given DataFrame ('data') by appending missing levels + found in another DataFrame ('more_levels'). + + This method performs a left join on the common index levels between 'data' and 'more_levels'. + The extra index level from 'more_levels' is appended at the end of the index levels in 'data'. + Rows present in 'data' are retained, and their indices are potentially extended. + Rows from 'more_levels' that do not exist in 'data' are not included in the output. + + Parameters: + ----------- + data : pd.DataFrame + The DataFrame whose index levels you want to extend. + This DataFrame's rows will all be present in the output DataFrame. + + more_levels : pd.DataFrame + The DataFrame used as a reference for adding extra index levels to 'data'. + + Returns: + -------- + pd.DataFrame + A new DataFrame with the same rows as 'data' but potentially extended index levels. 
+ + Example: + -------- + # Given the following DataFrames: + data: index(levels: A, B), columns: [value] + more_levels: index(levels: A, B, C), columns: [value] + + # The output will have: + index(levels: A, B, C), columns: [value] + """ + # Identify common index levels + common_index_levels = set(data.index.names).intersection(set(more_levels.index.names)) + + # Identify the extra index level in more_levels + extra_index_levels = list(set(more_levels.index.names) - common_index_levels) + + if not extra_index_levels: + # If there are no additional index levels in more_levels, return data as is + return data + + extra_index_level = extra_index_levels[0] + + # Reset the index of more_levels to convert all index levels to columns + more_levels_reset = more_levels.reset_index() + + # Select only the common and extra index levels + more_levels_subset = more_levels_reset[list(common_index_levels) + [extra_index_level]].drop_duplicates() + + # Merge the dataframes on the common index levels using a left join + new_data = pd.merge(data.reset_index(), more_levels_subset, how='left', on=list(common_index_levels)) + + # Explicitly set the order of the new index levels to maintain the original order in 'data' and append the new level + new_index_order = data.index.names + [extra_index_level] + + new_data.set_index(new_index_order, inplace=True) + + return new_data + def describe_field(self, fields: Union[str,List[str]], return_summary: bool=False): """ Display a summary dataframe for the specified fields from all tables
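
Usage notes for the changes above. The examples that follow are sketches, not part of the patch: instance names, field names, dataset names and threshold values that do not appear in the diff are illustrative assumptions.

plot_thresholds: as implemented, consecutive pairs of threshold values define the shaded bands (values outside the current y-limits are clipped), so a list of N thresholds yields N-1 patches and expects N-1 labels. A minimal sketch, assuming refplot is a GenderAgeRefPlot whose panels have already been drawn (e.g. via refplot.plot()); the thresholds and labels are illustrative:

    # Shade three bands between four illustrative threshold values,
    # one label per band; the legend is placed outside the axes.
    refplot.plot_thresholds(
        thresholds=[0, 5.7, 6.5, 20],
        labels=['normal', 'intermediate', 'high'],
        cmap='RdYlGn_r',
        legend=True,
    )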
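CohortSelector.select: queries may now reference dataset-qualified columns as dataset/field (the slash is rewritten before pandas' query runs), and add_fields loads extra columns alongside the ones named in the query (a single string is accepted as well as a list). A sketch assuming an existing CohortSelector named cs; the field and dataset names are illustrative:

    # Filter on a dataset-qualified column and bring one extra column along.
    df = cs.select(
        '(bmi > 30) & (sleep/ahi < 15)',
        add_fields='age_at_research_stage',
    )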
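MetaLoader: get and load accept the same dataset/field notation to restrict a field lookup to a single dataset, and load now prefixes fields that occur in several datasets with the dataset name. A sketch assuming an existing MetaLoader named ml; dataset and field names are illustrative:

    # 'ahi' is looked up only in the 'sleep' dataset, 'hba1c' in all datasets.
    fields = ml.get(['sleep/ahi', 'hba1c'])
    # Fields found in more than one dataset come back as '<dataset>__<field>'.
    df = ml.load(['sleep/ahi', 'hba1c'])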
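PhenoLoader.load_sample_data: participant_id is now optional, so bulk data can be loaded for a whole cohort slice by filtering on research_stage and/or array_index alone; the same filters are re-applied to each loaded DataFrame after missing index levels are filled in. A sketch assuming an existing PhenoLoader named pl; the bulk field name and stage value are illustrative:

    # Load a bulk field for all participants in one (illustrative) research stage.
    ts = pl.load_sample_data('sleep_channels', research_stage='baseline')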
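squeeze: when a loader is constructed with squeeze=True (or get is called with squeeze=True), a request that resolves to exactly one column returns a Series instead of a one-column DataFrame; the per-call argument overrides the loader-level setting. A sketch with illustrative dataset and field names:

    pl = PhenoLoader('sleep', squeeze=True)   # dataset name illustrative
    ahi = pl['ahi']                           # single matching column -> pandas Series
    df = pl.get('ahi', squeeze=False)         # per-call override -> one-column DataFrame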