Skip to content

Commit

Permalink
update for changes to AnnGeno
Browse files (browse the repository at this point in the history)
  • Loading branch information
bfclarke committed Jan 27, 2025
1 parent 70c4ca6 commit cb74c3e
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions deeprvat/data/anngeno_dl.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
training_regions: Optional[Dict[int, np.ndarray]] = None,
covariates: Optional[List[str]] = None,
standardize_covariates: bool = True,
- phenotypes: Optional[List[str]] = None,
+ # phenotypes: Optional[List[str]] = None,
quantile_transform_phenotypes: bool = True, # TODO: This is different from current default
annotation_columns: Optional[List[str]] = None,
variant_set: Optional[Set[int]] = None,
Expand Down Expand Up @@ -119,12 +119,14 @@ def __init__(
self.dtype = dtype
self.mask_type = mask_type
self.standardize_covariates = standardize_covariates
+ # TODO: Implement this
self.quantile_transform_phenotypes = quantile_transform_phenotypes

if self.training_mode:
- if training_regions is None or covariates is None or phenotypes is None:
+ if training_regions is None or covariates is None: # or phenotypes is None:
raise ValueError(
- "training_regions, covariate and phenotypes must be provided if training_mode=True"
+ "training_regions and covariates "
+ "must be provided if training_mode=True"
)

# Store regions
Expand All @@ -136,12 +138,7 @@ def __init__(
region_sizes = [self.anngeno.masked_region_sizes[k] for k in self.regions]
region_boundaries = np.concatenate([[0], np.cumsum(region_sizes)])
region_indices = zip(region_boundaries[:-1], region_boundaries[1:])
- self.gene_indices = dict(
- zip(
- self.training_regions.keys(),
- region_indices,
- )
- )
+ self.gene_indices = dict(zip(self.training_regions.keys(), region_indices))

n_variants = region_boundaries[-1]
n_genes = self.regions.shape[0]
Expand All @@ -152,7 +149,7 @@ def __init__(
self.variant_gene_mask[start:stop, i] = 1

self.covariate_cols = covariates
- self.phenotype_cols = phenotypes
+ self.phenotype_cols = list(self.training_regions.keys())

# Build gene-to-phenotype mask for MaskedLinear layer
n_phenos = len(self.training_regions)
Expand Down Expand Up @@ -185,13 +182,6 @@ def set_samples(self, sample_set: Optional[Set[str]]):
self.sample_batch_size = min(self.sample_batch_size, self.n_samples)

if self.training_mode:
- # # TODO: Use anngeno.get_phenotypes()
- # phenotype_df = pd.read_parquet(
- # self.anngeno.attrs["phenotype_filename"],
- # columns=covariates + phenotypes,
- # )
- # if sample_set is not None:
- # phenotype_df = phenotype_df.query("sample in @sample_set")
self.phenotype_df = self.anngeno.get_phenotypes(
columns=self.covariate_cols + self.phenotype_cols
)
Expand Down

0 comments on commit cb74c3e

Please sign in to comment.