Skip to content

Commit

Permalink
Factor _cached_snp_calls(). Add sample_query_options arg to snp_calls.
Browse files: browse the repository at this point in the history.
  • Loading branch information
leehart committed Sep 26, 2024
1 parent e483138 commit ee67687
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions malariagen_data/anoph/snp_data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -940,6 +940,7 @@ def snp_calls(
region: base_params.regions,
sample_sets: Optional[base_params.sample_sets] = None,
sample_query: Optional[base_params.sample_query] = None,
sample_query_options: Optional[base_params.sample_query_options] = None,
sample_indices: Optional[base_params.sample_indices] = None,
site_mask: Optional[base_params.site_mask] = None,
site_class: Optional[base_params.site_class] = None,
Expand Down Expand Up @@ -974,6 +975,7 @@ def snp_calls(
regions=regions,
sample_sets=sample_sets_prepped,
sample_query=sample_query,
sample_query_options=sample_query_options,
sample_indices=sample_indices_prepped,
site_mask=site_mask_prepped,
site_class=site_class,
Expand All @@ -994,19 +996,13 @@ def snp_calls(
# We only cache up to 2 items because otherwise we can see
# high memory usage.
@lru_cache(maxsize=2)
def _snp_calls(
def _cached_snp_calls(
self,
*,
regions: Tuple[Region, ...],
sample_sets,
sample_query,
sample_indices,
site_mask,
site_class,
cohort_size,
min_cohort_size,
max_cohort_size,
random_seed,
inline_array,
chunks,
):
Expand Down Expand Up @@ -1067,10 +1063,39 @@ def _snp_calls(
# Add call_genotype_mask.
ds["call_genotype_mask"] = ds["call_genotype"] < 0

return ds

def _snp_calls(
self,
*,
regions: Tuple[Region, ...],
sample_sets,
sample_query,
sample_query_options,
sample_indices,
site_mask,
site_class,
cohort_size,
min_cohort_size,
max_cohort_size,
random_seed,
inline_array,
chunks,
):
# Get SNP calls and concatenate multiple sample sets and/or regions.
ds = self._cached_snp_calls(
regions=regions,
sample_sets=sample_sets,
site_mask=site_mask,
site_class=site_class,
inline_array=inline_array,
chunks=chunks,
)

# Handle sample selection.
if sample_query is not None:
df_samples = self.sample_metadata(sample_sets=sample_sets)
loc_samples = df_samples.eval(sample_query).values
loc_samples = df_samples.eval(sample_query, **sample_query_options).values
if np.count_nonzero(loc_samples) == 0:
raise ValueError(f"No samples found for query {sample_query!r}")
ds = ds.isel(samples=loc_samples)
Expand Down

0 comments on commit ee67687

Please sign in to comment.