Skip to content

Commit

Permalink
add test for AnnGenoDataset GIS computation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed Feb 1, 2025
1 parent d3f454e commit 6c15e19
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 18 deletions.
11 changes: 7 additions & 4 deletions deeprvat/data/anngeno_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,18 +235,19 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
region_idx = idx % self.n_regions
regions = [self.regions[region_idx]]
sample_idx = idx // self.n_regions
result["region_idx"] = region_idx
result["region"] = self.regions[region_idx]

sample_slice = slice(
sample_idx * self.sample_batch_size,
min((sample_idx + 1) * self.sample_batch_size, self.n_samples),
)

if self.cache_genotypes:
# BUG: This doesn't work. Should modify get_region to use cached genotypes/annotations
slice_cache = self.anngeno.get_cached_regions(sample_slice=sample_slice)
genotypes = torch.tensor(slice_cache["genotypes"], dtype=self.dtype)
genotypes = torch.tensor(slice_cache["genotypes"][:], dtype=self.dtype)
annotations = torch.tensor(
slice_cache["annotations"], dtype=self.dtype
slice_cache["annotations"][:], dtype=self.dtype
) # TODO: these actually only need to be fetched once
else:
by_gene = [
Expand Down Expand Up @@ -293,7 +294,9 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
return result

def cache_regions(self, compress: bool = False):
self.anngeno.cache_regions(self.regions, compress=compress)
raise NotImplementedError # TODO: A correct implementation of this

self.anngeno.cache_regions(self.regions, compress=compress, dtype=self.dtype)
self.cache_genotypes = True


Expand Down
5 changes: 4 additions & 1 deletion deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,10 @@ def compute_burdens_(
) # TODO: Use AnnGenoDataModule, stage="associate"

logger.info("Caching genotypes to memory")
ds.anngeno.cache_genotypes() # TODO: Parametrize whether to do this

# TODO: Decide whether to do this.
# Current implementation has a bug; also, does it help with overall execution time?
# ds.anngeno.cache_genotypes()

logger.info("Loading models")
# agg_models = load_models(model_config, checkpoint_files, device=device)
Expand Down
87 changes: 74 additions & 13 deletions tests/deeprvat/dataloaders/test_anngeno_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
# pass


# TODO: Implement
# TODO: Check that regions are correct
# Check that all samples are iterated over
# Check output of __getitem__
# Sometimes use sample_set
# Sometimes use cache_regions
Expand Down Expand Up @@ -139,17 +140,77 @@ def test_getitem_training(
)


# TODO: Implement
# TODO: check that all samples and regions are iterated over
# Check output of __getitem__
# Sometimes use sample_set
# Sometimes use cache_regions
# @given(
# anngeno_args_and_genotypes=anngeno_args_and_genotypes(),
# batch_proportion=st.floats(min_value=0, max_value=1, exclude_min=True),
# cache_genotypes=st.booleans(),
# )
# @settings(phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target])
# def test_getitem_testing():
# # use __getitem__
# # compare to results from using AnnGeno.get_region(), AnnGeno.phenotypes, AnnGeno.annotations
# pass
# Sometimes use cache_regions - but not yet, this has a BUG
@given(
    anngeno_args_and_genotypes=anngeno_args_and_genotypes(min_annotations=1),
    batch_proportion=st.floats(min_value=0, max_value=1, exclude_min=True),
    # cache_genotypes=st.booleans(),  # disabled: AnnGenoDataset.cache_regions currently raises NotImplementedError
)
@settings(phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target])
def test_getitem_gis_computation(anngeno_args_and_genotypes, batch_proportion):
    """Test AnnGenoDataset.__getitem__ in GIS-computation (testing) mode.

    Writes hypothesis-generated genotypes/annotations to a temporary AnnGeno
    file, wraps it in an AnnGenoDataset, iterates every batch through a
    DataLoader, and checks each batch's genotypes and annotations against
    the values returned directly by AnnGeno.get_region().

    min_annotations=1 guarantees at least one annotation column so the
    annotation comparison below is non-trivial.  batch_proportion in (0, 1]
    is scaled to an absolute sample batch size.
    """
    anngeno_args = anngeno_args_and_genotypes["anngeno_args"]
    genotypes = anngeno_args_and_genotypes["genotypes"]

    variant_ids = anngeno_args["variant_metadata"]["id"]
    with tempfile.TemporaryDirectory() as tmpdirname:
        filename = Path(tmpdirname) / anngeno_args["filename"]
        anngeno_args["filename"] = filename
        ag = AnnGeno(**anngeno_args)

        # Populate the file with all generated samples in one shot.
        ag.set_samples(
            slice(None),
            genotypes,
            variant_ids=variant_ids,
        )

        # subset_samples is only permitted in read-only mode, so close the
        # writable handle and reopen the file read-only.
        del ag
        ag = AnnGeno(filename=filename)

        if sample_subset := anngeno_args_and_genotypes.get("sample_subset", None):
            ag.subset_samples(sample_subset)

        ag.subset_annotations(
            annotation_columns=anngeno_args_and_genotypes.get(
                "annotation_columns", None
            ),
            variant_set=anngeno_args_and_genotypes.get("variant_set", None),
        )

        # Construct the dataset and iterate through it, mirroring the
        # subsetting applied to the reference AnnGeno object above.
        batch_size = math.ceil(batch_proportion * ag.sample_count)
        agd = AnnGenoDataset(
            filename=filename,
            sample_batch_size=batch_size,
            mask_type="sum",  # TODO: Test max
            quantile_transform_phenotypes=False,  # TODO: test this function
            annotation_columns=anngeno_args_and_genotypes.get(
                "annotation_columns", None
            ),
            variant_set=anngeno_args_and_genotypes.get("variant_set", None),
            sample_set=anngeno_args_and_genotypes.get("sample_subset", None),
        )

        # if cache_genotypes:
        #     agd.cache_regions(compress=True)  # re-enable once cache_regions is implemented

        dl = DataLoader(
            agd,
            batch_size=None,  # No automatic batching
            batch_sampler=None,  # No automatic batching
        )

        for batch in dl:
            # Reference values fetched directly, bypassing the dataset.
            reference = ag.get_region(batch["region"], batch["sample_slice"])

            # Genotypes are exact small integers, so exact equality is safe;
            # annotations go through float32 conversion, so compare with
            # tolerance via allclose.
            assert np.array_equal(
                batch["genotypes"], reference["genotypes"].astype(np.float32)
            )
            assert np.allclose(
                batch["annotations"], reference["annotations"].astype(np.float32)
            )

0 comments on commit 6c15e19

Please sign in to comment.