Bulk index fix (#35)
* added: pl.load_bulk_data: extend_bulk_index arg

* improved: pl.load_bulk_data: missing levels join on collection_date

* fixed: call to get with not_bulk_field where no bulk dictionaries exist

* fixed: load_bulk_data: regression in showing parent_bulk warning

* improved: comments
alondmnt authored Dec 9, 2024
1 parent 5ce0ada commit bba8acb
Showing 2 changed files with 94 additions and 22 deletions.
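For orientation, a minimal usage sketch of the changed API. The `extend_bulk_index` and `parent_bulk` arguments come from the diff below; the loader instance `pl` and the field and bulk-file names are hypothetical, for illustration only.

```python
# Sketch only: `pl` is assumed to be an initialized pheno_utils.PhenoLoader;
# 'ecg_signal' and 'ecg' are hypothetical field / bulk-file names.

# New in this commit: extend_bulk_index (default True) extends the index of
# the loaded bulk DataFrame with levels taken from the main table.
df = pl.load_bulk_data('ecg_signal', extend_bulk_index=True)

# Opting out keeps the bulk file's index exactly as stored on disk.
df_raw = pl.load_bulk_data('ecg_signal', extend_bulk_index=False)

# If a field appears in more than one bulk file, the loader raises or warns
# (depending on pl.errors) unless parent_bulk names the intended file.
df_ecg = pl.load_bulk_data('ecg_signal', parent_bulk='ecg')
```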
58 changes: 47 additions & 11 deletions nbs/05_pheno_loader.ipynb
@@ -181,6 +181,7 @@
" pivot=None, \n",
" keep_undefined_research_stage: Union[None, str] = None,\n",
" join_non_overlapping: Union[None, bool] = None,\n",
" extend_bulk_index: bool = True,\n",
" **kwargs\n",
" ) -> Union[pd.DataFrame, None]:\n",
" \"\"\"\n",
@@ -198,6 +199,7 @@
" pivot=pivot,\n",
" keep_undefined_research_stage=keep_undefined_research_stage,\n",
" join_non_overlapping=join_non_overlapping,\n",
" extend_bulk_index=extend_bulk_index\n",
" **kwargs)\n",
"\n",
" def load_bulk_data(\n",
@@ -212,6 +214,7 @@
" pivot=None,\n",
" keep_undefined_research_stage: Union[None, str] = None,\n",
" join_non_overlapping: Union[None, bool] = None,\n",
" extend_bulk_index: bool = True,\n",
" **kwargs\n",
" ) -> Union[pd.DataFrame, None]:\n",
" \"\"\"\n",
@@ -228,6 +231,7 @@
" pivot (str, optional): The name of the field to pivot the data on (if DataFrame). Defaults to None.\n",
" keep_undefined_research_stage (bool, optional): Whether to keep samples with undefined research stage. Defaults to None.\n",
" join_non_overlapping (bool, optional): Whether to join tables with non-overlapping indices. Defaults to None.\n",
" extend_bulk_index (bool, optional): Whether to extend the bulk index to match the main table index. Defaults to False.\n",
" \"\"\"\n",
" if keep_undefined_research_stage is None:\n",
" keep_undefined_research_stage = self.keep_undefined_research_stage\n",
@@ -236,20 +240,42 @@
" # get path to bulk file\n",
" if type(field_name) is str:\n",
" field_name = [field_name]\n",
" sample, fields = self.get(field_name + ['participant_id'], return_fields=True, keep_undefined_research_stage=keep_undefined_research_stage, join_non_overlapping=join_non_overlapping)\n",
" sample, fields = self.get(\n",
" field_name,\n",
" squeeze=False,\n",
" return_fields=True,\n",
" keep_undefined_research_stage=keep_undefined_research_stage,\n",
" join_non_overlapping=join_non_overlapping\n",
" )\n",
" # TODO: slice bulk data based on field_type\n",
" if sample.shape[1] > 2:\n",
" if sample.shape[1] > 1:\n",
" # requested fields appear in more than one bulk file\n",
" if parent_bulk is not None:\n",
" # get the field_name associated with parent_bulk\n",
" sample, fields = self.get(field_name + ['participant_id'], return_fields=True, keep_undefined_research_stage=keep_undefined_research_stage, join_non_overlapping=join_non_overlapping, parent_dataframe=parent_bulk)\n",
" sample, fields = self.get(\n",
" field_name,\n",
" squeeze=False,\n",
" return_fields=True,\n",
" keep_undefined_research_stage=keep_undefined_research_stage,\n",
" join_non_overlapping=join_non_overlapping,\n",
" parent_dataframe=parent_bulk\n",
" )\n",
" else:\n",
" if self.errors == 'raise':\n",
" raise ValueError(f'More than one field found for {field_name}. Specify parent_bulk')\n",
" elif self.errors == 'warn':\n",
" warnings.warn(f'More than one field found for {field_name}. Specify parent_bulk')\n",
" fields = [f for f in fields if f != 'participant_id'] # these are fields, as opposed to parent_bulk\n",
" col = sample.columns.drop('participant_id')[0] # can be different from field_name if parent_dataframe is implied\n",
" sample = sample.astype({col: str})\n",
" col = sample.columns[0] # can be different from field_name if parent_dataframe is implied\n",
" # add participant_id and collection_date to the sample, ensure it's from the main tables\n",
" sample = sample \\\n",
" .join(\n",
" self.get(\n",
" ['participant_id', 'collection_date'],\n",
" keep_undefined_research_stage=keep_undefined_research_stage,\n",
" join_non_overlapping=join_non_overlapping,\n",
" not_bulk_field=True\n",
" )\n",
" ).astype({col: str})\n",
"\n",
" # filter by participant_id, research_stage and array_index\n",
" query_str = []\n",
@@ -290,15 +316,22 @@
" else:\n",
" field_type = self.dict.loc[fields, 'field_type'].values[0]\n",
" load_func = get_function_for_field_type(field_type)\n",
" sample = sample.loc[:, col]\n",
" sample = self.__slice_bulk_partition__(fields, sample)\n",
" sample_path = sample.loc[:, col]\n",
" sample_path = self.__slice_bulk_partition__(fields, sample_path)\n",
" kwargs.update(self.__slice_bulk_data__(fields))\n",
" data = []\n",
" for p in sample.unique():\n",
" for p in sample_path.unique():\n",
" try:\n",
" data.append(load_func(p, **kwargs))\n",
" if isinstance(data[-1], pd.DataFrame):\n",
" data[-1] = self.__add_missing_levels__(data[-1], sample.loc[sample == p].to_frame())\n",
" if extend_bulk_index:\n",
" data[-1] = self.__add_missing_levels__(\n",
" data[-1],\n",
" sample.loc[sample[col] == p, :].drop(\n",
" columns=['participant_id'],\n",
" errors='ignore'\n",
" )\n",
" )\n",
" if query_str:\n",
" data[-1] = data[-1].query(query_str)\n",
" data[-1].sort_index(inplace=True)\n",
@@ -450,7 +483,7 @@
" fields = [fields]\n",
"\n",
" search_dict = self.dict.copy()\n",
" if not_bulk_field:\n",
" if not_bulk_field and 'parent_dataframe' in search_dict.columns:\n",
" search_dict = search_dict.loc[search_dict['parent_dataframe'].isnull()]\n",
" for k, v in kwargs.items():\n",
" if k in search_dict.columns:\n",
@@ -895,6 +928,9 @@
" \"\"\"\n",
" # Identify common index levels\n",
" common_index_levels = list(set(data.index.names).intersection(set(more_levels.index.names)))\n",
" if 'collection_date' in data.columns.union(data.index.names) and \\\n",
" 'collection_date' in more_levels.columns.union(more_levels.index.names):\n",
" common_index_levels.append('collection_date')\n",
" if len(common_index_levels) == 0:\n",
" return data\n",
" \n",
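The `not_bulk_field` fix above guards the dictionary filter for datasets that contain no bulk fields at all, in which case the data dictionary has no `parent_dataframe` column. A standalone pandas sketch of the failure mode and the guard, with hypothetical dictionary contents:

```python
import pandas as pd

# A data dictionary without any bulk fields: no parent_dataframe column.
search_dict = pd.DataFrame(
    {'field_type': ['int', 'date']},
    index=['participant_id', 'collection_date'],
)

# The unguarded filter raises KeyError('parent_dataframe') on such a dictionary:
# search_dict = search_dict.loc[search_dict['parent_dataframe'].isnull()]

# Guarded version, as in this commit: filter only when the column exists.
not_bulk_field = True
if not_bulk_field and 'parent_dataframe' in search_dict.columns:
    search_dict = search_dict.loc[search_dict['parent_dataframe'].isnull()]
```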
58 changes: 47 additions & 11 deletions pheno_utils/pheno_loader.py
@@ -130,6 +130,7 @@ def load_sample_data(
pivot=None,
keep_undefined_research_stage: Union[None, str] = None,
join_non_overlapping: Union[None, bool] = None,
extend_bulk_index: bool = True,
**kwargs
) -> Union[pd.DataFrame, None]:
"""
@@ -147,6 +148,7 @@ def load_sample_data(
pivot=pivot,
keep_undefined_research_stage=keep_undefined_research_stage,
join_non_overlapping=join_non_overlapping,
extend_bulk_index=extend_bulk_index,
**kwargs)

def load_bulk_data(
@@ -161,6 +163,7 @@ def load_bulk_data(
pivot=None,
keep_undefined_research_stage: Union[None, str] = None,
join_non_overlapping: Union[None, bool] = None,
extend_bulk_index: bool = True,
**kwargs
) -> Union[pd.DataFrame, None]:
"""
@@ -177,6 +180,7 @@ def load_bulk_data(
pivot (str, optional): The name of the field to pivot the data on (if DataFrame). Defaults to None.
keep_undefined_research_stage (bool, optional): Whether to keep samples with undefined research stage. Defaults to None.
join_non_overlapping (bool, optional): Whether to join tables with non-overlapping indices. Defaults to None.
extend_bulk_index (bool, optional): Whether to extend the bulk index to match the main table index. Defaults to True.
"""
if keep_undefined_research_stage is None:
keep_undefined_research_stage = self.keep_undefined_research_stage
@@ -185,20 +189,42 @@ def load_bulk_data(
# get path to bulk file
if type(field_name) is str:
field_name = [field_name]
sample, fields = self.get(field_name + ['participant_id'], return_fields=True, keep_undefined_research_stage=keep_undefined_research_stage, join_non_overlapping=join_non_overlapping)
sample, fields = self.get(
field_name,
squeeze=False,
return_fields=True,
keep_undefined_research_stage=keep_undefined_research_stage,
join_non_overlapping=join_non_overlapping
)
# TODO: slice bulk data based on field_type
if sample.shape[1] > 2:
if sample.shape[1] > 1:
# requested fields appear in more than one bulk file
if parent_bulk is not None:
# get the field_name associated with parent_bulk
sample, fields = self.get(field_name + ['participant_id'], return_fields=True, keep_undefined_research_stage=keep_undefined_research_stage, join_non_overlapping=join_non_overlapping, parent_dataframe=parent_bulk)
sample, fields = self.get(
field_name,
squeeze=False,
return_fields=True,
keep_undefined_research_stage=keep_undefined_research_stage,
join_non_overlapping=join_non_overlapping,
parent_dataframe=parent_bulk
)
else:
if self.errors == 'raise':
raise ValueError(f'More than one field found for {field_name}. Specify parent_bulk')
elif self.errors == 'warn':
warnings.warn(f'More than one field found for {field_name}. Specify parent_bulk')
fields = [f for f in fields if f != 'participant_id'] # these are fields, as opposed to parent_bulk
col = sample.columns.drop('participant_id')[0] # can be different from field_name if parent_dataframe is implied
sample = sample.astype({col: str})
col = sample.columns[0] # can be different from field_name if parent_dataframe is implied
# add participant_id and collection_date to the sample, ensure it's from the main tables
sample = sample \
.join(
self.get(
['participant_id', 'collection_date'],
keep_undefined_research_stage=keep_undefined_research_stage,
join_non_overlapping=join_non_overlapping,
not_bulk_field=True
)
).astype({col: str})

# filter by participant_id, research_stage and array_index
query_str = []
@@ -239,15 +265,22 @@
else:
field_type = self.dict.loc[fields, 'field_type'].values[0]
load_func = get_function_for_field_type(field_type)
sample = sample.loc[:, col]
sample = self.__slice_bulk_partition__(fields, sample)
sample_path = sample.loc[:, col]
sample_path = self.__slice_bulk_partition__(fields, sample_path)
kwargs.update(self.__slice_bulk_data__(fields))
data = []
for p in sample.unique():
for p in sample_path.unique():
try:
data.append(load_func(p, **kwargs))
if isinstance(data[-1], pd.DataFrame):
data[-1] = self.__add_missing_levels__(data[-1], sample.loc[sample == p].to_frame())
if extend_bulk_index:
data[-1] = self.__add_missing_levels__(
data[-1],
sample.loc[sample[col] == p, :].drop(
columns=['participant_id'],
errors='ignore'
)
)
if query_str:
data[-1] = data[-1].query(query_str)
data[-1].sort_index(inplace=True)
@@ -399,7 +432,7 @@ def get(self, fields: Union[str,List[str]], flexible: bool=None, not_bulk_field=
fields = [fields]

search_dict = self.dict.copy()
if not_bulk_field:
if not_bulk_field and 'parent_dataframe' in search_dict.columns:
search_dict = search_dict.loc[search_dict['parent_dataframe'].isnull()]
for k, v in kwargs.items():
if k in search_dict.columns:
@@ -844,6 +877,9 @@ def __add_missing_levels__(self, data: pd.DataFrame, more_levels: pd.DataFrame)
"""
# Identify common index levels
common_index_levels = list(set(data.index.names).intersection(set(more_levels.index.names)))
if 'collection_date' in data.columns.union(data.index.names) and \
'collection_date' in more_levels.columns.union(more_levels.index.names):
common_index_levels.append('collection_date')
if len(common_index_levels) == 0:
return data

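The `__add_missing_levels__` change makes `collection_date` an additional alignment key whenever both frames carry it, on top of their shared index levels. Below is a standalone sketch of the alignment idea with hypothetical data; it is not the method itself, which also handles the column-versus-index placement of `collection_date`:

```python
import pandas as pd

# Bulk data as loaded from disk, indexed by participant_id only.
data = pd.DataFrame(
    {'value': [1.0, 2.0]},
    index=pd.Index(['p1', 'p2'], name='participant_id'),
)

# Main-table sample carrying extra index levels plus collection_date.
more_levels = pd.DataFrame({
    'participant_id': ['p1', 'p2'],
    'research_stage': ['baseline', 'baseline'],
    'collection_date': pd.to_datetime(['2024-01-01', '2024-02-01']),
}).set_index(['participant_id', 'research_stage'])

# Shared index levels; collection_date joins the keys only when both sides
# have it (here only more_levels does, so it is merged in as a new level).
keys = list(set(data.index.names) & set(more_levels.index.names))

# Extend data's index with the missing levels via a join on the shared keys.
extended = (
    data.join(more_levels.reset_index().set_index(keys), on=keys)
        .set_index(['research_stage', 'collection_date'], append=True)
)
print(extended)
```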
