Skip to content

Commit

Permalink
Merge pull request #48 from aertslab/stranded_dataloader
Browse files Browse the repository at this point in the history
Add strand support to AnnDataset and fix deterministic_shift
  • Loading branch information
casblaauw authored Nov 13, 2024
2 parents 24c0a3c + 60c7169 commit 904fa13
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 78 deletions.
20 changes: 7 additions & 13 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def get_embeddings(
"""
Extract embeddings from a specified layer in the model for all regions in the dataset.
If anndata is provided, it will add the embeddings to anndata.obsm[layer_name].
If anndata is provided, it will add the embeddings to anndata.varm[layer_name].
Parameters
----------
Expand All @@ -591,7 +591,7 @@ def get_embeddings(
Returns
-------
Embeddings of shape (N, D), where D is the size of the embedding layer.
Embeddings of shape (N, D), where N is the number of regions in the dataset and D is the size of the embedding layer.
"""
if layer_name not in [layer.name for layer in self.model.layers]:
raise ValueError(f"Layer '{layer_name}' not found in model.")
Expand All @@ -607,7 +607,7 @@ def get_embeddings(
embeddings = embedding_model.predict(predict_loader.data, steps=n_predict_steps)

if anndata is not None:
anndata.obsm[layer_name] = embeddings
anndata.varm[layer_name] = embeddings
return embeddings

def predict(
Expand Down Expand Up @@ -662,7 +662,7 @@ def predict_regions(
Parameters
----------
region_idx
List of regions for which to make predictions in the format "chr:start-end".
List of regions for which to make predictions in the format of your original data, either "chr:start-end" or "chr:start-end:strand".
Returns
-------
Expand Down Expand Up @@ -973,7 +973,7 @@ def calculate_contribution_scores_regions(
Parameters
----------
region_idx
Region(s) for which to calculate the contribution scores in the format "chr:start-end".
Region(s) for which to calculate the contribution scores in the format "chr:start-end" or "chr:start-end:strand".
class_names
List of class names to calculate the contribution scores for (should match anndata.obs_names)
If the list is empty, the contribution scores for the 'combined' class will be calculated.
Expand Down Expand Up @@ -1410,10 +1410,7 @@ def enhancer_design_motif_implementation(
enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference)

# get input sequence length of the model
seq_len = (
self.anndatamodule.adata.var.iloc[0]["end"]
- self.anndatamodule.adata.var.iloc[0]["start"]
)
seq_len = self.model.input_shape[1]

# determine the flanks without changes
if no_mutation_flanks is not None and target_len is not None:
Expand Down Expand Up @@ -1616,10 +1613,7 @@ def enhancer_design_in_silico_evolution(
enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference)

# get input sequence length of the model
seq_len = (
self.anndatamodule.adata.var.iloc[0]["end"]
- self.anndatamodule.adata.var.iloc[0]["start"]
)
seq_len = self.model.input_shape[1]

# determine the flanks without changes
if no_mutation_flanks is not None and target_len is not None:
Expand Down
211 changes: 146 additions & 65 deletions src/crested/tl/data/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
import re
from os import PathLike

import numpy as np
Expand All @@ -24,6 +25,39 @@ def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]:
chromsizes_dict = chromsizes.set_index("chr")["size"].to_dict()
return chromsizes_dict

def _flip_region_strand(region: str) -> str:
"""Reverse the strand of a region."""
strand_reverser = {'+': '-', '-': '+'}
return region[:-1]+strand_reverser[region[-1]]

def _check_strandedness(region: str) -> bool:
"""Check the strandedness of a region, raising an error if the formatting isn't recognised."""
if re.fullmatch(r".+:\d+-\d+:[-+]", region):
return True
elif re.fullmatch(r".+:\d+-\d+", region):
return False
else:
raise ValueError(
f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)."
"If provided, strand must be + or -.")

def _deterministic_shift_region(
region: str, stride: int = 50, n_shifts: int = 2
) -> list[str]:
"""
Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two.
This is a legacy function, it's recommended to use stochastic shifting instead.
"""
new_regions = []
chrom, start_end, strand = region.split(":")
start, end = map(int, start_end.split("-"))
for i in range(-n_shifts, n_shifts + 1):
new_start = start + i * stride
new_end = end + i * stride
new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}")
return new_regions


class SequenceLoader:
"""
Expand Down Expand Up @@ -54,6 +88,7 @@ def __init__(
chromsizes: dict | None,
in_memory: bool = False,
always_reverse_complement: bool = False,
deterministic_shift: bool = False,
max_stochastic_shift: int = 0,
regions: list[str] | None = None,
):
Expand All @@ -62,6 +97,7 @@ def __init__(
self.chromsizes = chromsizes
self.in_memory = in_memory
self.always_reverse_complement = always_reverse_complement
self.deterministic_shift = deterministic_shift
self.max_stochastic_shift = max_stochastic_shift
self.sequences = {}
self.complement = str.maketrans("ACGT", "TGCA")
Expand All @@ -72,19 +108,40 @@ def __init__(
def _load_sequences_into_memory(self, regions: list[str]):
"""Load all sequences into memory (dict)."""
logger.info("Loading sequences into memory...")
# Check region formatting
stranded = _check_strandedness(regions[0])

for region in tqdm(regions):
extended_sequence = self._get_extended_sequence(region)
self.sequences[f"{region}:+"] = extended_sequence
if self.always_reverse_complement:
self.sequences[f"{region}:-"] = self._reverse_complement(
extended_sequence
)
# Make region stranded if not
if not stranded:
strand = "+"
region = f"{region}:{strand}"
if region[-4] == ":":
raise ValueError(f"You are double-adding strand ids to your region {region}. Check if all regions are stranded or unstranded.")

# Add deterministic shift regions
if self.deterministic_shift:
regions = _deterministic_shift_region(region)
else:
regions = [region]

def _get_extended_sequence(self, region: str) -> str:
"""Get sequence from genome file, extended for stochastic shifting."""
chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
for region in regions:
# Parse region
chrom, start_end, strand = region.split(":")
start, end = map(int, start_end.split("-"))

# Add region to self.sequences
extended_sequence = self._get_extended_sequence(chrom, start, end, strand)
self.sequences[region] = extended_sequence

# Add reverse-complemented region to self.sequences if always_reverse_complement
if self.always_reverse_complement:
self.sequences[_flip_region_strand(region)] = self._reverse_complement(
extended_sequence
)

def _get_extended_sequence(self, chrom: str, start: int, end: int, strand: str) -> str:
"""Get sequence from genome file, extended for stochastic shifting."""
extended_start = max(0, start - self.max_stochastic_shift)
extended_end = extended_start + (end - start) + (self.max_stochastic_shift * 2)

Expand All @@ -96,32 +153,60 @@ def _get_extended_sequence(self, region: str) -> str:
)
extended_end = chrom_size

return self.genome.fetch(chrom, extended_start, extended_end).upper()
seq = self.genome.fetch(chrom, extended_start, extended_end).upper()
if strand == "-":
seq = self._reverse_complement(seq)
return seq

def _reverse_complement(self, sequence: str) -> str:
"""Reverse complement a sequence."""
return sequence.translate(self.complement)[::-1]

def get_sequence(self, region: str, strand: str = "+", shift: int = 0) -> str:
"""Get sequence for a region, strand, and shift from memory or fasta."""
key = f"{region}:{strand}"
def get_sequence(self, region: str, stranded: bool | None = None, shift: int = 0) -> str:
"""
Get sequence for a region, strand, and shift from memory or fasta.
If no strand is given in region or strand, assumes positive strand.
Parameters
----------
region
Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand).
stranded
Whether the input data is stranded. Default (None) infers from sequence (at a computational cost).
If not stranded, positive strand is assumed.
shift:
Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism.
Returns
-------
The DNA sequence, as a string.
"""
if stranded is None:
stranded = _check_strandedness(region)
if not stranded:
region = f"{region}:+"
# Parse region
chrom, start_end, strand = region.split(":")
start, end = map(int, start_end.split("-"))

# Get extended sequence
if self.in_memory:
sequence = self.sequences[key]
sequence = self.sequences[region]
else:
sequence = self._get_extended_sequence(region)
chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
sequence = self._get_extended_sequence(chrom, start, end, strand)

# Extract from extended sequence
start_idx = self.max_stochastic_shift + shift
end_idx = start_idx + (end - start)
sub_sequence = sequence[start_idx:end_idx]

# handle reverse complement on the go if not loaded into memory
if (strand == "-") and (not self.in_memory):
sub_sequence = self._reverse_complement(sub_sequence)

# pad with Ns if sequence is shorter than expected
# Pad with Ns if sequence is shorter than expected
if len(sub_sequence) < (end - start):
sub_sequence = sub_sequence.ljust(end - start, "N")
if strand == "+":
sub_sequence = sub_sequence.ljust(end - start, "N")
else:
sub_sequence = sub_sequence.rjust(end - start, "N")

return sub_sequence

Expand All @@ -135,7 +220,7 @@ class IndexManager:
Parameters
----------
indices
List of indices in format "chrom:start-end".
List of indices in format "chr:start-end" or "chr:start-end:strand".
always_reverse_complement
If True, all sequences will be augmented with their reverse complement.
deterministic_shift
Expand All @@ -157,47 +242,35 @@ def __init__(
)

def shuffle_indices(self):
"""Shuffle indices. Managed by subclass AnnDataLoader."""
"""Shuffle indices. Managed by wrapping class AnnDataLoader."""
np.random.shuffle(self.augmented_indices)

def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str]]:
"""Augment indices with strand information. Necessary if always reverse complement to map sequences back to targets."""
augmented_indices = []
augmented_indices_map = {}
for region in indices:
if not _check_strandedness(region): # If slow, can use AnnDataset stranded argument - but this validates every region's formatting as well
stranded_region = f"{region}:+"
else:
stranded_region = region

if self.deterministic_shift:
shifted_regions = self._deterministic_shift_region(region)
shifted_regions = _deterministic_shift_region(stranded_region)
for shifted_region in shifted_regions:
augmented_indices.append(f"{shifted_region}:+")
augmented_indices_map[f"{shifted_region}:+"] = region
augmented_indices.append(shifted_region)
augmented_indices_map[shifted_region] = region
if self.always_reverse_complement:
augmented_indices.append(f"{shifted_region}:-")
augmented_indices_map[f"{shifted_region}:-"] = region
augmented_indices.append(_flip_region_strand(shifted_region))
augmented_indices_map[_flip_region_strand(shifted_region)] = region
else:
augmented_indices.append(f"{region}:+")
augmented_indices_map[f"{region}:+"] = region
augmented_indices.append(stranded_region)
augmented_indices_map[stranded_region] = region
if self.always_reverse_complement:
augmented_indices.append(f"{region}:-")
augmented_indices_map[f"{region}:-"] = region
augmented_indices.append(_flip_region_strand(stranded_region))
augmented_indices_map[_flip_region_strand(stranded_region)] = region
return augmented_indices, augmented_indices_map

def _deterministic_shift_region(
self, region: str, stride: int = 50, n_shifts: int = 2
) -> list[str]:
"""
Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two.
This is a legacy function, it's recommended to use stochastic shifting instead.
"""
new_regions = []
chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
for i in range(-n_shifts, n_shifts + 1):
new_start = start + i * stride
new_end = end + i * stride
new_regions.append(f"{chrom}:{new_start}-{new_end}")
return new_regions


if os.environ["KERAS_BACKEND"] == "pytorch":
import torch
Expand Down Expand Up @@ -259,22 +332,31 @@ def __init__(
self.num_outputs = self.anndata.X.shape[0]
self.random_reverse_complement = random_reverse_complement
self.max_stochastic_shift = max_stochastic_shift
self.shuffle = False # managed by subclass AnnDataLoader
self.shuffle = False # managed by wrapping class AnnDataLoader

# Check region formatting
stranded = _check_strandedness(self.indices[0])
if stranded and (always_reverse_complement or random_reverse_complement):
logger.info(
"Setting always_reverse_complement=True or random_reverse_complement=True with stranded data.",
"This means both strands are used when training and the strand information is effectively disregarded."
)

self.sequence_loader = SequenceLoader(
genome_file,
self.chromsizes,
in_memory,
always_reverse_complement,
max_stochastic_shift,
self.indices,
chromsizes=self.chromsizes,
in_memory=in_memory,
always_reverse_complement=always_reverse_complement,
deterministic_shift=deterministic_shift,
max_stochastic_shift=max_stochastic_shift,
regions=self.indices,
)
self.index_manager = IndexManager(
self.indices,
always_reverse_complement=always_reverse_complement,
deterministic_shift=deterministic_shift,
)
self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0]))
self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0], stranded = stranded))

@staticmethod
def _split_anndata(anndata: AnnData, split: str) -> AnnData:
Expand Down Expand Up @@ -308,19 +390,18 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]:
"""Return sequence and target for a given index."""
augmented_index = self.index_manager.augmented_indices[idx]
original_index = self.index_manager.augmented_indices_map[augmented_index]

strand = "-" if augmented_index.endswith(":-") else "+"

# stochastic shift
if self.max_stochastic_shift > 0:
shift = np.random.randint(
-self.max_stochastic_shift, self.max_stochastic_shift + 1
)
x = self.sequence_loader.get_sequence(original_index, strand, shift)
else:
x = self.sequence_loader.get_sequence(original_index, strand)
shift = 0

# Get sequence
x = self.sequence_loader.get_sequence(augmented_index, stranded = True, shift = shift)

# random reverse complement (always is done in the sequence loader)
# random reverse complement (always_reverse_complement is done in the sequence loader)
if self.random_reverse_complement and np.random.rand() < 0.5:
x = self.sequence_loader._reverse_complement(x)

Expand Down

0 comments on commit 904fa13

Please sign in to comment.