Commit 3ee65ed

nbdev_prepare

alondmnt committed Sep 19, 2023
1 parent 334351f commit 3ee65ed

Showing 4 changed files with 224 additions and 30 deletions.
85 changes: 85 additions & 0 deletions pheno_utils/age_reference_plots.py
@@ -17,6 +17,8 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
+from matplotlib.colors import to_rgba
+import matplotlib.patches as mpatches
import importlib
from scipy.stats import linregress
from sklearn.linear_model import HuberRegressor
@@ -81,6 +83,7 @@ def __init__(
val_bins: Optional[np.ndarray] = None,
linear_fit: bool = True,
expectiles: Optional[List] = [0.03, 0.1, 0.5, 0.9, 0.97],
+thresholds: Optional[List] = None,
top_disp_perc: float = 99,
bottom_disp_perc: float = 1,
robust: bool = True,
@@ -290,6 +293,61 @@ def plot_ornaments(self):
# font="Roboto Condensed",
)

+def plot_thresholds(self, thresholds: List[float], labels: Optional[List[str]] = None,
+                    cmap: str = 'viridis', legend: bool = True) -> None:
+    """
+    Add threshold-based fill_between patches to the main axis of the plot.
+
+    Parameters:
+    -----------
+    thresholds : List[float]
+        List of N+1 boundary values; each consecutive pair (lo, hi) is drawn as
+        one of N patches. Values are clipped to the y-axis limits, so ±np.inf
+        can be used for open-ended bands.
+    labels : Optional[List[str]], default=None
+        List of N labels for the patches. If None, automatic labels will be generated.
+    cmap : str, default='viridis'
+        Colormap to use for the patches.
+    legend : bool, default=True
+        Whether to show the legend.
+
+    Returns:
+    --------
+    None
+    """
+    cm = plt.get_cmap(cmap)
+    n_patches = max(2, len(thresholds) - 1)
+    colors = [to_rgba(cm(i / (n_patches - 1))) for i in range(n_patches)]
+    ylim = self.ax_main.get_ylim()
+    legend_patches = []
+
+    for i, (lo, hi) in enumerate(zip(thresholds[:-1], thresholds[1:])):
+        lo = max(lo, ylim[0])
+        hi = min(hi, ylim[1])
+        if labels is None:
+            if lo == ylim[0]:
+                label = f'<{hi:.4g}'
+            elif hi == ylim[1]:
+                label = f'>{lo:.4g}'
+            else:
+                label = f'[{lo:.4g}, {hi:.4g})'
+        else:
+            label = labels[i]
+
+        self.ax_main.fill_between(
+            self.ax_main.get_xlim(), lo, hi, color=colors[i],
+            zorder=-1, alpha=0.2)
+        self.ax_valhist.fill_between(
+            self.ax_valhist.get_xlim(), lo, hi, color=colors[i],
+            zorder=-1, alpha=0.2)
+        legend_patches.append(mpatches.Patch(color=colors[i], alpha=0.2, label=label))
+
+    # Add legend
+    if legend:
+        plt.legend(handles=legend_patches, loc='upper left', bbox_to_anchor=(1, 1))

def plot(self):
ax = self.ax_main
ax.spines["right"].set_visible(False)
@@ -380,6 +438,33 @@ def __init__(
make_fig=False
)

+def plot_thresholds(self, thresholds: List[float], labels: Optional[List[str]] = None,
+                    cmap: str = 'viridis', legend: bool = True) -> None:
+    """
+    Add threshold-based fill_between patches to the main axes of both gender plots.
+
+    Parameters:
+    -----------
+    thresholds : List[float]
+        List of N+1 boundary values; each consecutive pair defines one of N patches.
+    labels : Optional[List[str]], default=None
+        List of N labels for the patches. If None, automatic labels will be generated.
+    cmap : str, default='viridis'
+        Colormap to use for the patches.
+    legend : bool, default=True
+        Whether to show the legend.
+
+    Returns:
+    --------
+    None
+    """
+    self.female_refplot.plot_thresholds(thresholds, labels, cmap, legend)
+    self.male_refplot.plot_thresholds(thresholds, labels, cmap, legend)


def plot(self) -> None:
"""Plots the data for both genders in separate panels."""
ax_dict = plt.figure(constrained_layout=True, figsize=(16, 6)).subplot_mosaic(
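Usage sketch (not part of the diff): one way the new plot_thresholds API could be exercised. The class name AgeRefPlot and its constructor arguments are assumptions for illustration, as is the synthetic data; the threshold logic (N+1 boundaries give N patches, clipped to the y-axis limits) follows the code added above.

    import numpy as np
    import pandas as pd
    from pheno_utils.age_reference_plots import AgeRefPlot  # class name assumed

    # synthetic age/BMI data, for illustration only
    rng = np.random.default_rng(0)
    df = pd.DataFrame({'age': rng.uniform(35, 70, 1000),
                       'bmi': rng.normal(27, 4, 1000)})

    rp = AgeRefPlot(df, val_col='bmi', age_col='age')  # argument names assumed
    rp.plot()
    # 4 boundaries (±inf for the open-ended bands) -> 3 patches
    rp.plot_thresholds(thresholds=[-np.inf, 18.5, 25, np.inf],
                       labels=['underweight', 'normal', 'overweight'],
                       cmap='RdYlGn')

The gender variant added further down exposes the same signature and simply forwards the call to the female and male sub-plots.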
23 changes: 17 additions & 6 deletions pheno_utils/cohort_selector.py
@@ -20,6 +20,9 @@
from .meta_loader import MetaLoader

# %% ../nbs/12_cohort_selector.ipynb 5
+from typing import List, Optional


class CohortSelector:
"""
Class for selecting a subset of a cohort's data based on a query.
@@ -69,13 +72,14 @@ def __init__(
flexible_field_search=False, errors=self.errors,
**self.kwargs)

-def select(self, query: str) -> pd.DataFrame:
+def select(self, query: str, add_fields: Optional[List] = []) -> pd.DataFrame:
"""
Select a subset of the cohort's data based on the given query.
Args:
query (str): Query string to filter the data.
+add_fields (List, optional): Additional fields to load. Defaults to [].
Returns:
@@ -87,17 +91,24 @@ def select(self, query: str) -> pd.DataFrame:
ValueError: If column names in the query do not match the column names in the metadata.
"""
-column_names = re.findall(r'([a-zA-Z][a-zA-Z0-9_]*)\b', query)
+column_names = re.findall(r'([a-zA-Z][a-zA-Z0-9_/]*)\b', query)
if not column_names:
raise ValueError('No column names found in query')

-test_cols = self.ml.get(column_names)
+test_cols = self.ml.get(column_names)
missing_cols = [col for col in column_names
-                if col not in test_cols.columns.str.split('/').str[1]]
+                if col not in test_cols.columns and
+                    col not in test_cols.columns.str.split('/').str[1]]
if len(missing_cols):
raise ValueError(f'Column names {missing_cols} in query do not match column names in metadata')

-df = self.ml.load(column_names)
+if isinstance(add_fields, str):
+    add_fields = [add_fields]
+df = self.ml.load(column_names + add_fields)

-return df.query(query)
+try:
+    return df.query(re.sub(r'\w+/', '', query))
+except Exception:
+    # fall back for fields that were disambiguated as dataset__field on load
+    return df.query(re.sub(r'(\w+)/', r'\1__', query))
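Usage sketch (not part of the diff): how the extended select could be called. The constructor arguments are assumptions, and the dataset and field names are invented. Queries may now reference dataset-qualified fields ('dataset/field'); select first tries the query with the dataset prefix stripped, and falls back to the 'dataset__field' spelling used when a duplicated field was disambiguated on load.

    from pheno_utils.cohort_selector import CohortSelector

    cs = CohortSelector(base_path='/path/to/data')  # arguments assumed

    # plain field names, as before
    df = cs.select('bmi > 30')

    # dataset-qualified field, plus an extra column loaded alongside the query fields
    df = cs.select('sleep/ahi > 15', add_fields='bmi')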

31 changes: 27 additions & 4 deletions pheno_utils/meta_loader.py
@@ -77,17 +77,31 @@ def load(self, fields: Union[str,List[str]], flexible: bool=None, prop: str='tab
pd.DataFrame: Dataframe containing the fields from the respective datasets.
"""
found_fields = self.get(fields, flexible, prop)
+if found_fields.empty:
+    return pd.DataFrame()

found_fields.columns = found_fields.columns.str.split('/').str[1]
dup_fields = found_fields.columns.value_counts()\
.to_frame('count').query('count > 1').index
+n_datasets = found_fields.loc['dataset'].nunique()

loaded_fields = []
for ds, f in found_fields.T.groupby('dataset'):
df = PhenoLoader(ds, base_path=self.base_path, cohort=self.cohort,
age_sex_dataset=None, **self.kwargs)\
[f.index.tolist()]
if df.empty:
continue

+if 'array_index' in df.index.names and n_datasets > 1:
+    if df.index.get_level_values('array_index').nunique() > 1:
+        df = df.reset_index('array_index', drop=False)\
+               .rename(columns={'array_index': f'{ds}__array_index'})
+    else:
+        df = df.reset_index('array_index', drop=True)

# rename duplicate fields
-df = df.rename(columns=pd.Series(dup_fields + f'_{ds}', index=dup_fields))
+df = df.rename(columns=pd.Series(f'{ds}__' + dup_fields, index=dup_fields))

if not len(loaded_fields):
loaded_fields = df
@@ -113,20 +127,29 @@ def get(self, fields: Union[str,List[str]], flexible: bool=None, prop='tabular_f
flexible = self.flexible_field_search
if isinstance(fields, str):
fields = [fields]
-fields = [f.lower() for f in fields]
+fields = pd.DataFrame({'field': [f.lower() for f in fields]}).assign(dataset=None)

+if prop == 'tabular_field_name':
+    ind = fields['field'].str.contains('/')
+    fields.loc[ind, 'dataset'] = fields.loc[ind, 'field'].str.split('/').str[0]
+    fields.loc[ind, 'field'] = fields.loc[ind, 'field'].str.split('/').str[1]

data = pd.DataFrame()
for dataset, df in self.dicts.items():
+keep = (fields['dataset'] == dataset) | fields['dataset'].isnull()
+fields_in_dataset = fields.loc[keep, 'field']

+if prop == 'tabular_field_name':
+    search_in = pd.Series(df.columns, index=df.columns).str.lower()
+else:
search_in = df.loc[prop].dropna().str.lower()
if flexible:
# use fuzzy matching including regex to find fields
-fields_in_col = np.unique([col for f in fields for col, text in search_in.items()
+fields_in_col = np.unique([col for f in fields_in_dataset for col, text in search_in.items()
if type(text) is str and re.search(f, text)])
else:
-fields_in_col = search_in[search_in.isin(fields)].index
+fields_in_col = search_in[search_in.isin(fields_in_dataset)].index

if len(fields_in_col):
this_data = df[fields_in_col]
this_data.columns = dataset + '/' + this_data.columns
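Usage sketch (not part of the diff): the MetaLoader behavior these changes appear to enable, with invented dataset and field names. A bare field that exists in several datasets comes back once per dataset, renamed 'dataset__field'; a 'dataset/field' request restricts the search to that dataset, and clashing array_index levels are likewise disambiguated as 'dataset__array_index'.

    from pheno_utils.meta_loader import MetaLoader

    ml = MetaLoader(base_path='/path/to/data')  # arguments assumed

    df = ml.load('hr')        # duplicated field -> e.g. 'sleep__hr', 'cgm__hr'
    df = ml.load('sleep/hr')  # only the 'sleep' dataset is searched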
[diff for the fourth changed file did not load]
