Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strand support to AnnDataset and fix deterministic_shift #48

Merged
merged 14 commits into from
Nov 13, 2024
Merged
20 changes: 7 additions & 13 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def get_embeddings(
"""
Extract embeddings from a specified layer in the model for all regions in the dataset.

If anndata is provided, it will add the embeddings to anndata.obsm[layer_name].
If anndata is provided, it will add the embeddings to anndata.varm[layer_name].

Parameters
----------
Expand All @@ -581,7 +581,7 @@ def get_embeddings(

Returns
-------
Embeddings of shape (N, D), where D is the size of the embedding layer.
Embeddings of shape (N, D), where N is the number of regions in the dataset and D is the size of the embedding layer.
"""
if layer_name not in [layer.name for layer in self.model.layers]:
raise ValueError(f"Layer '{layer_name}' not found in model.")
Expand All @@ -597,7 +597,7 @@ def get_embeddings(
embeddings = embedding_model.predict(predict_loader.data, steps=n_predict_steps)

if anndata is not None:
anndata.obsm[layer_name] = embeddings
anndata.varm[layer_name] = embeddings
return embeddings

def predict(
Expand Down Expand Up @@ -652,7 +652,7 @@ def predict_regions(
Parameters
----------
region_idx
List of regions for which to make predictions in the format "chr:start-end".
List of regions for which to make predictions in the format of your original data, either "chr:start-end" or "chr:start-end:strand".

Returns
-------
Expand Down Expand Up @@ -959,7 +959,7 @@ def calculate_contribution_scores_regions(
Parameters
----------
region_idx
Region(s) for which to calculate the contribution scores in the format "chr:start-end".
Region(s) for which to calculate the contribution scores in the format "chr:start-end" or "chr:start-end:strand".
class_names
List of class names to calculate the contribution scores for (should match anndata.obs_names)
If the list is empty, the contribution scores for the 'combined' class will be calculated.
Expand Down Expand Up @@ -1392,10 +1392,7 @@ def enhancer_design_motif_implementation(
enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference)

# get input sequence length of the model
seq_len = (
self.anndatamodule.adata.var.iloc[0]["end"]
- self.anndatamodule.adata.var.iloc[0]["start"]
)
seq_len = self.model.input_shape[1]

# determine the flanks without changes
if no_mutation_flanks is not None and target_len is not None:
Expand Down Expand Up @@ -1590,10 +1587,7 @@ def enhancer_design_in_silico_evolution(
enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference)

# get input sequence length of the model
seq_len = (
self.anndatamodule.adata.var.iloc[0]["end"]
- self.anndatamodule.adata.var.iloc[0]["start"]
)
seq_len = self.model.input_shape[1]

# determine the flanks without changes
if no_mutation_flanks is not None and target_len is not None:
Expand Down
211 changes: 146 additions & 65 deletions src/crested/tl/data/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
import re
from os import PathLike

import numpy as np
Expand All @@ -24,6 +25,39 @@ def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]:
chromsizes_dict = chromsizes.set_index("chr")["size"].to_dict()
return chromsizes_dict

def _flip_region_strand(region: str) -> str:
"""Reverse the strand of a region."""
strand_reverser = {'+': '-', '-': '+'}
return region[:-1]+strand_reverser[region[-1]]

def _check_strandedness(region: str) -> bool:
"""Check the strandedness of a region, raising an error if the formatting isn't recognised."""
if re.fullmatch(r".+:\d+-\d+:[-+]", region):
return True
elif re.fullmatch(r".+:\d+-\d+", region):
return False
else:
raise ValueError(
f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)."
"If provided, strand must be + or -.")

def _deterministic_shift_region(
region: str, stride: int = 50, n_shifts: int = 2
) -> list[str]:
"""
Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two.

This is a legacy function, it's recommended to use stochastic shifting instead.
"""
new_regions = []
chrom, start_end, strand = region.split(":")
start, end = map(int, start_end.split("-"))
for i in range(-n_shifts, n_shifts + 1):
new_start = start + i * stride
new_end = end + i * stride
new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}")
return new_regions


class SequenceLoader:
"""
Expand Down Expand Up @@ -54,6 +88,7 @@ def __init__(
chromsizes: dict | None,
in_memory: bool = False,
always_reverse_complement: bool = False,
deterministic_shift: bool = False,
max_stochastic_shift: int = 0,
regions: list[str] | None = None,
):
Expand All @@ -62,6 +97,7 @@ def __init__(
self.chromsizes = chromsizes
self.in_memory = in_memory
self.always_reverse_complement = always_reverse_complement
self.deterministic_shift = deterministic_shift
self.max_stochastic_shift = max_stochastic_shift
self.sequences = {}
self.complement = str.maketrans("ACGT", "TGCA")
Expand All @@ -72,19 +108,40 @@ def __init__(
def _load_sequences_into_memory(self, regions: list[str]):
"""Load all sequences into memory (dict)."""
logger.info("Loading sequences into memory...")
# Check region formatting
stranded = _check_strandedness(regions[0])

for region in tqdm(regions):
extended_sequence = self._get_extended_sequence(region)
self.sequences[f"{region}:+"] = extended_sequence
if self.always_reverse_complement:
self.sequences[f"{region}:-"] = self._reverse_complement(
extended_sequence
)
# Make region stranded if not
if not stranded:
strand = "+"
region = f"{region}:{strand}"
if region[-4] == ":":
raise ValueError(f"You are double-adding strand ids to your region {region}. Check if all regions are stranded or unstranded.")

# Add deterministic shift regions
if self.deterministic_shift:
regions = _deterministic_shift_region(region)
else:
regions = [region]

def _get_extended_sequence(self, region: str) -> str:
"""Get sequence from genome file, extended for stochastic shifting."""
chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
for region in regions:
# Parse region
chrom, start_end, strand = region.split(":")
start, end = map(int, start_end.split("-"))

# Add region to self.sequences
extended_sequence = self._get_extended_sequence(chrom, start, end, strand)
self.sequences[region] = extended_sequence

# Add reverse-complemented region to self.sequences if always_reverse_complement
if self.always_reverse_complement:
self.sequences[_flip_region_strand(region)] = self._reverse_complement(
extended_sequence
)

def _get_extended_sequence(self, chrom: str, start: int, end: int, strand: str) -> str:
"""Get sequence from genome file, extended for stochastic shifting."""
extended_start = max(0, start - self.max_stochastic_shift)
extended_end = extended_start + (end - start) + (self.max_stochastic_shift * 2)

Expand All @@ -96,32 +153,60 @@ def _get_extended_sequence(self, region: str) -> str:
)
extended_end = chrom_size

return self.genome.fetch(chrom, extended_start, extended_end).upper()
seq = self.genome.fetch(chrom, extended_start, extended_end).upper()
if strand == "-":
seq = self._reverse_complement(seq)
return seq

def _reverse_complement(self, sequence: str) -> str:
"""Reverse complement a sequence."""
return sequence.translate(self.complement)[::-1]

def get_sequence(self, region: str, strand: str = "+", shift: int = 0) -> str:
"""Get sequence for a region, strand, and shift from memory or fasta."""
key = f"{region}:{strand}"
def get_sequence(self, region: str, stranded: bool | None = None, shift: int = 0) -> str:
"""
Get sequence for a region, strand, and shift from memory or fasta.

If no strand is given in region or strand, assumes positive strand.

Parameters
----------
region
Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand).
stranded
Whether the input data is stranded. Default (None) infers from sequence (at a computational cost).
If not stranded, positive strand is assumed.
shift:
Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism.

Returns
-------
The DNA sequence, as a string.
"""
if stranded is None:
stranded = _check_strandedness(region)
if not stranded:
region = f"{region}:+"
# Parse region
chrom, start_end, strand = region.split(":")
start, end = map(int, start_end.split("-"))

# Get extended sequence
if self.in_memory:
sequence = self.sequences[key]
sequence = self.sequences[region]
else:
sequence = self._get_extended_sequence(region)
chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
sequence = self._get_extended_sequence(chrom, start, end, strand)

# Extract from extended sequence
start_idx = self.max_stochastic_shift + shift
end_idx = start_idx + (end - start)
sub_sequence = sequence[start_idx:end_idx]

# handle reverse complement on the go if not loaded into memory
if (strand == "-") and (not self.in_memory):
sub_sequence = self._reverse_complement(sub_sequence)

# pad with Ns if sequence is shorter than expected
# Pad with Ns if sequence is shorter than expected
if len(sub_sequence) < (end - start):
sub_sequence = sub_sequence.ljust(end - start, "N")
if strand == "+":
sub_sequence = sub_sequence.ljust(end - start, "N")
else:
sub_sequence = sub_sequence.rjust(end - start, "N")

return sub_sequence

Expand All @@ -135,7 +220,7 @@ class IndexManager:
Parameters
----------
indices
List of indices in format "chrom:start-end".
List of indices in format "chr:start-end" or "chr:start-end:strand".
always_reverse_complement
If True, all sequences will be augmented with their reverse complement.
deterministic_shift
Expand All @@ -157,47 +242,35 @@ def __init__(
)

def shuffle_indices(self):
"""Shuffle indices. Managed by subclass AnnDataLoader."""
"""Shuffle indices. Managed by wrapping class AnnDataLoader."""
np.random.shuffle(self.augmented_indices)

def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str]]:
"""Augment indices with strand information. Necessary if always reverse complement to map sequences back to targets."""
augmented_indices = []
augmented_indices_map = {}
for region in indices:
if not _check_strandedness(region): # If slow, can use AnnDataset stranded argument - but this validates every region's formatting as well
stranded_region = f"{region}:+"
else:
stranded_region = region

if self.deterministic_shift:
shifted_regions = self._deterministic_shift_region(region)
shifted_regions = _deterministic_shift_region(stranded_region)
for shifted_region in shifted_regions:
augmented_indices.append(f"{shifted_region}:+")
augmented_indices_map[f"{shifted_region}:+"] = region
augmented_indices.append(shifted_region)
augmented_indices_map[shifted_region] = region
if self.always_reverse_complement:
augmented_indices.append(f"{shifted_region}:-")
augmented_indices_map[f"{shifted_region}:-"] = region
augmented_indices.append(_flip_region_strand(shifted_region))
augmented_indices_map[_flip_region_strand(shifted_region)] = region
else:
augmented_indices.append(f"{region}:+")
augmented_indices_map[f"{region}:+"] = region
augmented_indices.append(stranded_region)
augmented_indices_map[stranded_region] = region
if self.always_reverse_complement:
augmented_indices.append(f"{region}:-")
augmented_indices_map[f"{region}:-"] = region
augmented_indices.append(_flip_region_strand(stranded_region))
augmented_indices_map[_flip_region_strand(stranded_region)] = region
return augmented_indices, augmented_indices_map

def _deterministic_shift_region(
self, region: str, stride: int = 50, n_shifts: int = 2
) -> list[str]:
"""
Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two.

This is a legacy function, it's recommended to use stochastic shifting instead.
"""
new_regions = []
chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
for i in range(-n_shifts, n_shifts + 1):
new_start = start + i * stride
new_end = end + i * stride
new_regions.append(f"{chrom}:{new_start}-{new_end}")
return new_regions


if os.environ["KERAS_BACKEND"] == "pytorch":
import torch
Expand Down Expand Up @@ -259,22 +332,31 @@ def __init__(
self.num_outputs = self.anndata.X.shape[0]
self.random_reverse_complement = random_reverse_complement
self.max_stochastic_shift = max_stochastic_shift
self.shuffle = False # managed by subclass AnnDataLoader
self.shuffle = False # managed by wrapping class AnnDataLoader

# Check region formatting
stranded = _check_strandedness(self.indices[0])
if stranded and (always_reverse_complement or random_reverse_complement):
logger.info(
"Setting always_reverse_complement=True or random_reverse_complement=True with stranded data.",
"This means both strands are used when training and the strand information is effectively disregarded."
)

self.sequence_loader = SequenceLoader(
genome_file,
self.chromsizes,
in_memory,
always_reverse_complement,
max_stochastic_shift,
self.indices,
chromsizes=self.chromsizes,
in_memory=in_memory,
always_reverse_complement=always_reverse_complement,
deterministic_shift=deterministic_shift,
max_stochastic_shift=max_stochastic_shift,
regions=self.indices,
)
self.index_manager = IndexManager(
self.indices,
always_reverse_complement=always_reverse_complement,
deterministic_shift=deterministic_shift,
)
self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0]))
self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0], stranded = stranded))

@staticmethod
def _split_anndata(anndata: AnnData, split: str) -> AnnData:
Expand Down Expand Up @@ -308,19 +390,18 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]:
"""Return sequence and target for a given index."""
augmented_index = self.index_manager.augmented_indices[idx]
original_index = self.index_manager.augmented_indices_map[augmented_index]

strand = "-" if augmented_index.endswith(":-") else "+"

# stochastic shift
if self.max_stochastic_shift > 0:
shift = np.random.randint(
-self.max_stochastic_shift, self.max_stochastic_shift + 1
)
x = self.sequence_loader.get_sequence(original_index, strand, shift)
else:
x = self.sequence_loader.get_sequence(original_index, strand)
shift = 0

# Get sequence
x = self.sequence_loader.get_sequence(augmented_index, stranded = True, shift = shift)

# random reverse complement (always is done in the sequence loader)
# random reverse complement (always_reverse_complement is done in the sequence loader)
if self.random_reverse_complement and np.random.rand() < 0.5:
x = self.sequence_loader._reverse_complement(x)

Expand Down
Loading