From 09007e21df79e20c3fefb9c58396d62016d1af8a Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 15:49:30 +0100 Subject: [PATCH 01/14] Add support for stranded data in AnnDataset --- src/crested/tl/_crested.py | 4 +- src/crested/tl/data/_dataset.py | 178 +++++++++++++++++++++++--------- 2 files changed, 134 insertions(+), 48 deletions(-) diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index 3015bec6..b6d686d1 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -652,7 +652,7 @@ def predict_regions( Parameters ---------- region_idx - List of regions for which to make predictions in the format "chr:start-end". + List of regions for which to make predictions in the format of your original data, either "chr:start-end" or "chr:start-end:strand". Returns ------- @@ -959,7 +959,7 @@ def calculate_contribution_scores_regions( Parameters ---------- region_idx - Region(s) for which to calculate the contribution scores in the format "chr:start-end". + Region(s) for which to calculate the contribution scores in the format "chr:start-end" or "chr:start-end:strand". class_names List of class names to calculate the contribution scores for (should match anndata.obs_names) If the list is empty, the contribution scores for the 'combined' class will be calculated. diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 941233d3..80b33cf9 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import re from os import PathLike import numpy as np @@ -24,6 +25,20 @@ def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]: chromsizes_dict = chromsizes.set_index("chr")["size"].to_dict() return chromsizes_dict +def _flip_region_strand(region: str) -> str: + """Reverse the strand of a region.""" + strand_reverser = {'+': '-', '-': '+'} + return region[:-1]+strand_reverser[region[-1]] + +def _check_strandedness(region: str) -> bool: + """Check the strandedness of a region, raising an error if the formatting isn't recognised.""" + if re.fullmatch(r".+:\d+-\d+:[-+]", region): + return True + elif re.fullmatch(r".+:\d+-\d+", region): + return False + else: + raise ValueError(f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand).") + class SequenceLoader: """ @@ -39,6 +54,8 @@ class SequenceLoader: Dictionary with chromosome sizes. Required if max_stochastic_shift > 0. in_memory If True, the sequences of supplied regions will be loaded into memory. + stranded + Whether the dataset regions have strand information. If None, inferred from regions. always_reverse_complement If True, all sequences will be augmented with their reverse complement. Doubles the dataset size. @@ -53,6 +70,7 @@ def __init__( genome_file: PathLike, chromsizes: dict | None, in_memory: bool = False, + stranded: bool | None = None, always_reverse_complement: bool = False, max_stochastic_shift: int = 0, regions: list[str] | None = None, @@ -61,6 +79,7 @@ def __init__( self.genome = FastaFile(genome_file) self.chromsizes = chromsizes self.in_memory = in_memory + self.stranded = stranded self.always_reverse_complement = always_reverse_complement self.max_stochastic_shift = max_stochastic_shift self.sequences = {} @@ -72,19 +91,32 @@ def __init__( def _load_sequences_into_memory(self, regions: list[str]): """Load all sequences into memory (dict).""" logger.info("Loading sequences into memory...") + strand_reverser = {'+': '-', '-': '+'} + # Check region formatting + if self.stranded is None: + self.stranded = _check_strandedness(regions[0]) + for region in tqdm(regions): - extended_sequence = self._get_extended_sequence(region) - self.sequences[f"{region}:+"] = extended_sequence + # Parse region + if self.stranded: + chrom, start_end, strand = region.split(":") + else: + chrom, start_end = region.split(":") + strand = "+" + start, end = map(int, start_end.split("-")) + + # Add region to self.sequences + extended_sequence = self._get_extended_sequence(chrom, start, end, strand) + self.sequences[f"{chrom}:{start}-{end}:{strand}"] = extended_sequence + + # Add reverse-complemented region to self.sequences if always_reverse_complement if self.always_reverse_complement: - self.sequences[f"{region}:-"] = self._reverse_complement( + self.sequences[f"{chrom}:{start}-{end}:{strand_reverser[strand]}"] = self._reverse_complement( extended_sequence ) - def _get_extended_sequence(self, region: str) -> str: + def _get_extended_sequence(self, chrom: str, start: int, end: int, strand: str) -> str: """Get sequence from genome file, extended for stochastic shifting.""" - chrom, start_end = region.split(":") - start, end = map(int, start_end.split("-")) - extended_start = max(0, start - self.max_stochastic_shift) extended_end = extended_start + (end - start) + (self.max_stochastic_shift * 2) @@ -96,32 +128,80 @@ def _get_extended_sequence(self, region: str) -> str: ) extended_end = chrom_size - return self.genome.fetch(chrom, extended_start, extended_end).upper() + seq = self.genome.fetch(chrom, extended_start, extended_end).upper() + if strand == "-": + seq = self._reverse_complement(seq) + return seq def _reverse_complement(self, sequence: str) -> str: """Reverse complement a sequence.""" return sequence.translate(self.complement)[::-1] - def get_sequence(self, region: str, strand: str = "+", shift: int = 0) -> str: - """Get sequence for a region, strand, and shift from memory or fasta.""" - key = f"{region}:{strand}" - if self.in_memory: - sequence = self.sequences[key] + def get_sequence(self, region: str, strand: str | None = None, shift: int = 0) -> str: + """ + Get sequence for a region, strand, and shift from memory or fasta. + + If no strand is given in region or strand, assumes positive strand. + + Parameters + ---------- + region + Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand). + strand + Strand to extract sequence for. Default uses region info if available and positive strand if not. + shift: + Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism. + + Returns + ------- + The DNA sequence, as a string. + """ + # If strand status is unknown (because not provided at init + # and not inferred by _load_sequences_into_memory), infer: + # Common case: with 'predict' sequenceloader + if self.stranded is None: + stranded_region = _check_strandedness(region) + else: + stranded_region = self.stranded + + # Add strand if not provided + if not stranded_region: + if strand is None: + region = f"{region}:+" + else: + region = f"{region}:{strand}" else: - sequence = self._get_extended_sequence(region) - chrom, start_end = region.split(":") + if strand is not None: + # Check whether actually stranded or just SequenceLoader setting + if _check_strandedness(region): + logger.warning( + f"Argument 'strand' provided while region {region} already had strand information. Using provided strand {strand}.", + ) + region = f"{region[:-2]}:{strand}" + else: + region = f"{region}:{strand}" + + # Parse region + chrom, start_end, strand = region.split(":") start, end = map(int, start_end.split("-")) + + # Get extended sequence + if self.in_memory: + sequence = self.sequences[region] + else: + sequence = self._get_extended_sequence(chrom, start, end, strand) + + # Extract from extended sequence start_idx = self.max_stochastic_shift + shift end_idx = start_idx + (end - start) sub_sequence = sequence[start_idx:end_idx] - # handle reverse complement on the go if not loaded into memory - if (strand == "-") and (not self.in_memory): - sub_sequence = self._reverse_complement(sub_sequence) - - # pad with Ns if sequence is shorter than expected + # Pad with Ns if sequence is shorter than expected if len(sub_sequence) < (end - start): - sub_sequence = sub_sequence.ljust(end - start, "N") + if strand == "+": + sub_sequence = sub_sequence.ljust(end - start, "N") + else: + sub_sequence = sub_sequence.rjust(end - start, "N") return sub_sequence @@ -135,7 +215,7 @@ class IndexManager: Parameters ---------- indices - List of indices in format "chrom:start-end". + List of indices in format "chr:start-end" or "chr:start-end:strand". always_reverse_complement If True, all sequences will be augmented with their reverse complement. deterministic_shift @@ -157,7 +237,7 @@ def __init__( ) def shuffle_indices(self): - """Shuffle indices. Managed by subclass AnnDataLoader.""" + """Shuffle indices. Managed by wrapping class AnnDataLoader.""" np.random.shuffle(self.augmented_indices) def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str]]: @@ -165,20 +245,25 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str augmented_indices = [] augmented_indices_map = {} for region in indices: + if not _check_strandedness(region): # If slow, can use AnnDataset stranded argument - but this validates every region's formatting as well + stranded_region = f"{region}:+" + else: + stranded_region = region + if self.deterministic_shift: - shifted_regions = self._deterministic_shift_region(region) + shifted_regions = self._deterministic_shift_region(stranded_region) for shifted_region in shifted_regions: - augmented_indices.append(f"{shifted_region}:+") - augmented_indices_map[f"{shifted_region}:+"] = region + augmented_indices.append(shifted_region) + augmented_indices_map[shifted_region] = region if self.always_reverse_complement: - augmented_indices.append(f"{shifted_region}:-") - augmented_indices_map[f"{shifted_region}:-"] = region + augmented_indices.append(_flip_region_strand(shifted_region)) + augmented_indices_map[_flip_region_strand(shifted_region)] = region else: - augmented_indices.append(f"{region}:+") - augmented_indices_map[f"{region}:+"] = region + augmented_indices.append(stranded_region) + augmented_indices_map[stranded_region] = region if self.always_reverse_complement: - augmented_indices.append(f"{region}:-") - augmented_indices_map[f"{region}:-"] = region + augmented_indices.append(_flip_region_strand(stranded_region)) + augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map def _deterministic_shift_region( @@ -190,12 +275,12 @@ def _deterministic_shift_region( This is a legacy function, it's recommended to use stochastic shifting instead. """ new_regions = [] - chrom, start_end = region.split(":") + chrom, start_end, strand = region.split(":") start, end = map(int, start_end.split("-")) for i in range(-n_shifts, n_shifts + 1): new_start = start + i * stride new_end = end + i * stride - new_regions.append(f"{chrom}:{new_start}-{new_end}") + new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") return new_regions @@ -259,15 +344,19 @@ def __init__( self.num_outputs = self.anndata.X.shape[0] self.random_reverse_complement = random_reverse_complement self.max_stochastic_shift = max_stochastic_shift - self.shuffle = False # managed by subclass AnnDataLoader + self.shuffle = False # managed by wrapping class AnnDataLoader + + # Check region formatting + stranded = _check_strandedness(self.indices[0]) self.sequence_loader = SequenceLoader( genome_file, - self.chromsizes, - in_memory, - always_reverse_complement, - max_stochastic_shift, - self.indices, + chromsizes=self.chromsizes, + in_memory=in_memory, + stranded=stranded, + always_reverse_complement=always_reverse_complement, + max_stochastic_shift=max_stochastic_shift, + regions=self.indices, ) self.index_manager = IndexManager( self.indices, @@ -308,19 +397,16 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: """Return sequence and target for a given index.""" augmented_index = self.index_manager.augmented_indices[idx] original_index = self.index_manager.augmented_indices_map[augmented_index] - - strand = "-" if augmented_index.endswith(":-") else "+" - # stochastic shift if self.max_stochastic_shift > 0: shift = np.random.randint( -self.max_stochastic_shift, self.max_stochastic_shift + 1 ) - x = self.sequence_loader.get_sequence(original_index, strand, shift) + x = self.sequence_loader.get_sequence(original_index, shift) else: - x = self.sequence_loader.get_sequence(original_index, strand) + x = self.sequence_loader.get_sequence(original_index) - # random reverse complement (always is done in the sequence loader) + # random reverse complement (always_reverse_complement is done in the sequence loader) if self.random_reverse_complement and np.random.rand() < 0.5: x = self.sequence_loader._reverse_complement(x) From 308b9de19284934f7d3b6ff0b690376a029e24cb Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 16:41:40 +0100 Subject: [PATCH 02/14] Save get_embeddings data in correct slot (.varm) --- src/crested/tl/_crested.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index b6d686d1..7e147cfb 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -570,7 +570,7 @@ def get_embeddings( """ Extract embeddings from a specified layer in the model for all regions in the dataset. - If anndata is provided, it will add the embeddings to anndata.obsm[layer_name]. + If anndata is provided, it will add the embeddings to anndata.varm[layer_name]. Parameters ---------- @@ -581,7 +581,7 @@ def get_embeddings( Returns ------- - Embeddings of shape (N, D), where D is the size of the embedding layer. + Embeddings of shape (N, D), where N is the number of regions in the dataset and D is the size of the embedding layer. """ if layer_name not in [layer.name for layer in self.model.layers]: raise ValueError(f"Layer '{layer_name}' not found in model.") @@ -597,7 +597,7 @@ def get_embeddings( embeddings = embedding_model.predict(predict_loader.data, steps=n_predict_steps) if anndata is not None: - anndata.obsm[layer_name] = embeddings + anndata.varm[layer_name] = embeddings return embeddings def predict( From 11876083b87671c1a2ece6113fb50cdf91510233 Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 16:48:15 +0100 Subject: [PATCH 03/14] Make enhancer design seq_len use model instead of var start/end columns --- src/crested/tl/_crested.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index 7e147cfb..005a40e9 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -1392,10 +1392,7 @@ def enhancer_design_motif_implementation( enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference) # get input sequence length of the model - seq_len = ( - self.anndatamodule.adata.var.iloc[0]["end"] - - self.anndatamodule.adata.var.iloc[0]["start"] - ) + seq_len = self.model.input_shape[1] # determine the flanks without changes if no_mutation_flanks is not None and target_len is not None: @@ -1590,10 +1587,7 @@ def enhancer_design_in_silico_evolution( enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference) # get input sequence length of the model - seq_len = ( - self.anndatamodule.adata.var.iloc[0]["end"] - - self.anndatamodule.adata.var.iloc[0]["start"] - ) + seq_len = self.model.input_shape[1] # determine the flanks without changes if no_mutation_flanks is not None and target_len is not None: From 0bf9f1ee56c6ffb568d052edb010e0bd1fdc93bd Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 17:19:51 +0100 Subject: [PATCH 04/14] Tell user if reverse complementing stranded data --- src/crested/tl/data/_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 80b33cf9..405755c5 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -348,6 +348,11 @@ def __init__( # Check region formatting stranded = _check_strandedness(self.indices[0]) + if stranded and (always_reverse_complement or random_reverse_complement): + logger.info( + "Setting always_reverse_complement=True or random_reverse_complement=True with stranded data.", + "This means both strands are used when training and the strand information is effectively disregarded." + ) self.sequence_loader = SequenceLoader( genome_file, From 0d289da15a3c56977a95fa738da6a9521a3c79df Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 17:26:46 +0100 Subject: [PATCH 05/14] Clarify that strand should be - or + --- src/crested/tl/data/_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 405755c5..ec036b84 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -37,7 +37,9 @@ def _check_strandedness(region: str) -> bool: elif re.fullmatch(r".+:\d+-\d+", region): return False else: - raise ValueError(f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand).") + raise ValueError( + f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)." + "If provided, strand must be + or -.") class SequenceLoader: From c0b94f7fdf0b412dfee7e6bead637719a344366f Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 17:53:03 +0100 Subject: [PATCH 06/14] Fix test: missing argument name in get_sequence call --- src/crested/tl/data/_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index ec036b84..782bfe31 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -409,7 +409,7 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: shift = np.random.randint( -self.max_stochastic_shift, self.max_stochastic_shift + 1 ) - x = self.sequence_loader.get_sequence(original_index, shift) + x = self.sequence_loader.get_sequence(original_index, shift = shift) else: x = self.sequence_loader.get_sequence(original_index) From dba8e7d2289c1179af151e6351c2496c09543995 Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 19:26:18 +0100 Subject: [PATCH 07/14] Fix not actually using augmented index when getting sequence --- src/crested/tl/data/_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 782bfe31..b4538132 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -409,9 +409,9 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: shift = np.random.randint( -self.max_stochastic_shift, self.max_stochastic_shift + 1 ) - x = self.sequence_loader.get_sequence(original_index, shift = shift) + x = self.sequence_loader.get_sequence(augmented_index, shift = shift) else: - x = self.sequence_loader.get_sequence(original_index) + x = self.sequence_loader.get_sequence(augmented_index) # random reverse complement (always_reverse_complement is done in the sequence loader) if self.random_reverse_complement and np.random.rand() < 0.5: From a150bbc4291c61092aad62081a5138b8ae6bf990 Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 19:36:55 +0100 Subject: [PATCH 08/14] Remove deterministic shift --- src/crested/tl/data/_anndatamodule.py | 14 ++++++------ src/crested/tl/data/_dataset.py | 31 ++++++--------------------- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/src/crested/tl/data/_anndatamodule.py b/src/crested/tl/data/_anndatamodule.py index b59054ad..988f1bae 100644 --- a/src/crested/tl/data/_anndatamodule.py +++ b/src/crested/tl/data/_anndatamodule.py @@ -49,9 +49,6 @@ class AnnDataModule: If True, the sequences will be randomly reverse complemented during training. Default is False. max_stochastic_shift Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. Default is 0. - deterministic_shift - If true, each region will be shifted twice with stride 50bp to each side. Default is False. - This is our legacy shifting, we recommend using max_stochastic_shift instead. shuffle If True, the data will be shuffled at the end of each epoch during training. Default is True. batch_size @@ -67,11 +64,18 @@ def __init__( always_reverse_complement=True, random_reverse_complement: bool = False, max_stochastic_shift: int = 0, - deterministic_shift: bool = False, shuffle: bool = True, batch_size: int = 256, + deterministic_shift = None ): """Initialize the DataModule with the provided dataset and options.""" + if deterministic_shift is not None: + determ_shift_warning = "Argument `deterministic_shift` is deprecated and is no longer functional. Use max_stochastic_shift instead." + if max_stochastic_shift == 0: + determ_shift_warning += " Setting max_stochastic_shift to 3." + max_stochastic_shift = 3 + logger.warning(determ_shift_warning) + self.adata = adata self.genome_file = genome_file self.chromsizes_file = chromsizes_file @@ -79,7 +83,6 @@ def __init__( self.in_memory = in_memory self.random_reverse_complement = random_reverse_complement self.max_stochastic_shift = max_stochastic_shift - self.deterministic_shift = deterministic_shift self.shuffle = shuffle self.batch_size = batch_size @@ -130,7 +133,6 @@ def setup(self, stage: str) -> None: always_reverse_complement=self.always_reverse_complement, random_reverse_complement=self.random_reverse_complement, max_stochastic_shift=self.max_stochastic_shift, - deterministic_shift=self.deterministic_shift, ) self.val_dataset = AnnDataset( self.adata, diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index b4538132..501eaa2d 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -220,20 +220,16 @@ class IndexManager: List of indices in format "chr:start-end" or "chr:start-end:strand". always_reverse_complement If True, all sequences will be augmented with their reverse complement. - deterministic_shift - If True, each region will be shifted twice with stride 50bp to each side. """ def __init__( self, indices: list[str], always_reverse_complement: bool, - deterministic_shift: bool = False, ): """Initialize the IndexManager with the provided indices.""" self.indices = indices self.always_reverse_complement = always_reverse_complement - self.deterministic_shift = deterministic_shift self.augmented_indices, self.augmented_indices_map = self._augment_indices( indices ) @@ -251,21 +247,11 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str stranded_region = f"{region}:+" else: stranded_region = region - - if self.deterministic_shift: - shifted_regions = self._deterministic_shift_region(stranded_region) - for shifted_region in shifted_regions: - augmented_indices.append(shifted_region) - augmented_indices_map[shifted_region] = region - if self.always_reverse_complement: - augmented_indices.append(_flip_region_strand(shifted_region)) - augmented_indices_map[_flip_region_strand(shifted_region)] = region - else: - augmented_indices.append(stranded_region) - augmented_indices_map[stranded_region] = region - if self.always_reverse_complement: - augmented_indices.append(_flip_region_strand(stranded_region)) - augmented_indices_map[_flip_region_strand(stranded_region)] = region + augmented_indices.append(stranded_region) + augmented_indices_map[stranded_region] = region + if self.always_reverse_complement: + augmented_indices.append(_flip_region_strand(stranded_region)) + augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map def _deterministic_shift_region( @@ -318,9 +304,6 @@ class AnnDataset(BaseClass): If True, all sequences will be augmented with their reverse complement during training. max_stochastic_shift Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. - deterministic_shift - If true, each region will be shifted twice with stride 50bp to each side. - This is our legacy shifting, we recommend using max_stochastic_shift instead. """ def __init__( @@ -333,7 +316,6 @@ def __init__( random_reverse_complement: bool = False, always_reverse_complement: bool = False, max_stochastic_shift: int = 0, - deterministic_shift: bool = False, ): """Initialize the dataset with the provided AnnData object and options.""" self.anndata = self._split_anndata(anndata, split) @@ -367,8 +349,7 @@ def __init__( ) self.index_manager = IndexManager( self.indices, - always_reverse_complement=always_reverse_complement, - deterministic_shift=deterministic_shift, + always_reverse_complement=always_reverse_complement ) self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0])) From bd27269226a695d33b9299fce208aef6fd2d2c2d Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 19:48:19 +0100 Subject: [PATCH 09/14] Remove leftover deterministic shift function --- src/crested/tl/data/_dataset.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 501eaa2d..be5d63fe 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -254,23 +254,6 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map - def _deterministic_shift_region( - self, region: str, stride: int = 50, n_shifts: int = 2 - ) -> list[str]: - """ - Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. - - This is a legacy function, it's recommended to use stochastic shifting instead. - """ - new_regions = [] - chrom, start_end, strand = region.split(":") - start, end = map(int, start_end.split("-")) - for i in range(-n_shifts, n_shifts + 1): - new_start = start + i * stride - new_end = end + i * stride - new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") - return new_regions - if os.environ["KERAS_BACKEND"] == "pytorch": import torch From d9250f4830cc37fe509b30f167a35f171cfecbe1 Mon Sep 17 00:00:00 2001 From: cblaauw Date: Sun, 3 Nov 2024 23:56:27 +0100 Subject: [PATCH 10/14] Simplify stranded handling now __getitem__ always gives stranded --- src/crested/tl/data/_dataset.py | 54 +++++++++------------------------ 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index be5d63fe..6303230a 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -56,8 +56,6 @@ class SequenceLoader: Dictionary with chromosome sizes. Required if max_stochastic_shift > 0. in_memory If True, the sequences of supplied regions will be loaded into memory. - stranded - Whether the dataset regions have strand information. If None, inferred from regions. always_reverse_complement If True, all sequences will be augmented with their reverse complement. Doubles the dataset size. @@ -72,7 +70,6 @@ def __init__( genome_file: PathLike, chromsizes: dict | None, in_memory: bool = False, - stranded: bool | None = None, always_reverse_complement: bool = False, max_stochastic_shift: int = 0, regions: list[str] | None = None, @@ -81,7 +78,6 @@ def __init__( self.genome = FastaFile(genome_file) self.chromsizes = chromsizes self.in_memory = in_memory - self.stranded = stranded self.always_reverse_complement = always_reverse_complement self.max_stochastic_shift = max_stochastic_shift self.sequences = {} @@ -95,12 +91,11 @@ def _load_sequences_into_memory(self, regions: list[str]): logger.info("Loading sequences into memory...") strand_reverser = {'+': '-', '-': '+'} # Check region formatting - if self.stranded is None: - self.stranded = _check_strandedness(regions[0]) + stranded = _check_strandedness(regions[0]) for region in tqdm(regions): # Parse region - if self.stranded: + if stranded: chrom, start_end, strand = region.split(":") else: chrom, start_end = region.split(":") @@ -139,7 +134,7 @@ def _reverse_complement(self, sequence: str) -> str: """Reverse complement a sequence.""" return sequence.translate(self.complement)[::-1] - def get_sequence(self, region: str, strand: str | None = None, shift: int = 0) -> str: + def get_sequence(self, region: str, stranded: bool | None = None, shift: int = 0) -> str: """ Get sequence for a region, strand, and shift from memory or fasta. @@ -149,8 +144,9 @@ def get_sequence(self, region: str, strand: str | None = None, shift: int = 0) - ---------- region Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand). - strand - Strand to extract sequence for. Default uses region info if available and positive strand if not. + stranded + Whether the input data is stranded. Default (None) infers from sequence (at a computational cost). + If not stranded, positive strand is assumed. shift: Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism. @@ -158,31 +154,10 @@ def get_sequence(self, region: str, strand: str | None = None, shift: int = 0) - ------- The DNA sequence, as a string. """ - # If strand status is unknown (because not provided at init - # and not inferred by _load_sequences_into_memory), infer: - # Common case: with 'predict' sequenceloader - if self.stranded is None: - stranded_region = _check_strandedness(region) - else: - stranded_region = self.stranded - - # Add strand if not provided - if not stranded_region: - if strand is None: - region = f"{region}:+" - else: - region = f"{region}:{strand}" - else: - if strand is not None: - # Check whether actually stranded or just SequenceLoader setting - if _check_strandedness(region): - logger.warning( - f"Argument 'strand' provided while region {region} already had strand information. Using provided strand {strand}.", - ) - region = f"{region[:-2]}:{strand}" - else: - region = f"{region}:{strand}" - + if stranded is None: + stranded = _check_strandedness(region) + if not stranded: + region = f"{region}:+" # Parse region chrom, start_end, strand = region.split(":") start, end = map(int, start_end.split("-")) @@ -325,7 +300,6 @@ def __init__( genome_file, chromsizes=self.chromsizes, in_memory=in_memory, - stranded=stranded, always_reverse_complement=always_reverse_complement, max_stochastic_shift=max_stochastic_shift, regions=self.indices, @@ -334,7 +308,7 @@ def __init__( self.indices, always_reverse_complement=always_reverse_complement ) - self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0])) + self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0], stranded = stranded)) @staticmethod def _split_anndata(anndata: AnnData, split: str) -> AnnData: @@ -373,9 +347,11 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: shift = np.random.randint( -self.max_stochastic_shift, self.max_stochastic_shift + 1 ) - x = self.sequence_loader.get_sequence(augmented_index, shift = shift) else: - x = self.sequence_loader.get_sequence(augmented_index) + shift = 0 + + # Get sequence + x = self.sequence_loader.get_sequence(augmented_index, stranded = True, shift = shift) # random reverse complement (always_reverse_complement is done in the sequence loader) if self.random_reverse_complement and np.random.rand() < 0.5: From 2f8a4c68958fdada65b9b9c0140140b563c43155 Mon Sep 17 00:00:00 2001 From: Cas Blaauw <38132585+casblaauw@users.noreply.github.com> Date: Mon, 4 Nov 2024 10:07:54 +0100 Subject: [PATCH 11/14] Revert "Remove deterministic shift" This reverts commit a150bbc4291c61092aad62081a5138b8ae6bf990. --- src/crested/tl/data/_anndatamodule.py | 14 ++++++------ src/crested/tl/data/_dataset.py | 31 +++++++++++++++++++++------ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/crested/tl/data/_anndatamodule.py b/src/crested/tl/data/_anndatamodule.py index 988f1bae..b59054ad 100644 --- a/src/crested/tl/data/_anndatamodule.py +++ b/src/crested/tl/data/_anndatamodule.py @@ -49,6 +49,9 @@ class AnnDataModule: If True, the sequences will be randomly reverse complemented during training. Default is False. max_stochastic_shift Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. Default is 0. + deterministic_shift + If true, each region will be shifted twice with stride 50bp to each side. Default is False. + This is our legacy shifting, we recommend using max_stochastic_shift instead. shuffle If True, the data will be shuffled at the end of each epoch during training. Default is True. batch_size @@ -64,18 +67,11 @@ def __init__( always_reverse_complement=True, random_reverse_complement: bool = False, max_stochastic_shift: int = 0, + deterministic_shift: bool = False, shuffle: bool = True, batch_size: int = 256, - deterministic_shift = None ): """Initialize the DataModule with the provided dataset and options.""" - if deterministic_shift is not None: - determ_shift_warning = "Argument `deterministic_shift` is deprecated and is no longer functional. Use max_stochastic_shift instead." - if max_stochastic_shift == 0: - determ_shift_warning += " Setting max_stochastic_shift to 3." - max_stochastic_shift = 3 - logger.warning(determ_shift_warning) - self.adata = adata self.genome_file = genome_file self.chromsizes_file = chromsizes_file @@ -83,6 +79,7 @@ def __init__( self.in_memory = in_memory self.random_reverse_complement = random_reverse_complement self.max_stochastic_shift = max_stochastic_shift + self.deterministic_shift = deterministic_shift self.shuffle = shuffle self.batch_size = batch_size @@ -133,6 +130,7 @@ def setup(self, stage: str) -> None: always_reverse_complement=self.always_reverse_complement, random_reverse_complement=self.random_reverse_complement, max_stochastic_shift=self.max_stochastic_shift, + deterministic_shift=self.deterministic_shift, ) self.val_dataset = AnnDataset( self.adata, diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 6303230a..5d798932 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -195,16 +195,20 @@ class IndexManager: List of indices in format "chr:start-end" or "chr:start-end:strand". always_reverse_complement If True, all sequences will be augmented with their reverse complement. + deterministic_shift + If True, each region will be shifted twice with stride 50bp to each side. """ def __init__( self, indices: list[str], always_reverse_complement: bool, + deterministic_shift: bool = False, ): """Initialize the IndexManager with the provided indices.""" self.indices = indices self.always_reverse_complement = always_reverse_complement + self.deterministic_shift = deterministic_shift self.augmented_indices, self.augmented_indices_map = self._augment_indices( indices ) @@ -222,11 +226,21 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str stranded_region = f"{region}:+" else: stranded_region = region - augmented_indices.append(stranded_region) - augmented_indices_map[stranded_region] = region - if self.always_reverse_complement: - augmented_indices.append(_flip_region_strand(stranded_region)) - augmented_indices_map[_flip_region_strand(stranded_region)] = region + + if self.deterministic_shift: + shifted_regions = self._deterministic_shift_region(stranded_region) + for shifted_region in shifted_regions: + augmented_indices.append(shifted_region) + augmented_indices_map[shifted_region] = region + if self.always_reverse_complement: + augmented_indices.append(_flip_region_strand(shifted_region)) + augmented_indices_map[_flip_region_strand(shifted_region)] = region + else: + augmented_indices.append(stranded_region) + augmented_indices_map[stranded_region] = region + if self.always_reverse_complement: + augmented_indices.append(_flip_region_strand(stranded_region)) + augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map @@ -262,6 +276,9 @@ class AnnDataset(BaseClass): If True, all sequences will be augmented with their reverse complement during training. max_stochastic_shift Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. + deterministic_shift + If true, each region will be shifted twice with stride 50bp to each side. + This is our legacy shifting, we recommend using max_stochastic_shift instead. """ def __init__( @@ -274,6 +291,7 @@ def __init__( random_reverse_complement: bool = False, always_reverse_complement: bool = False, max_stochastic_shift: int = 0, + deterministic_shift: bool = False, ): """Initialize the dataset with the provided AnnData object and options.""" self.anndata = self._split_anndata(anndata, split) @@ -306,7 +324,8 @@ def __init__( ) self.index_manager = IndexManager( self.indices, - always_reverse_complement=always_reverse_complement + always_reverse_complement=always_reverse_complement, + deterministic_shift=deterministic_shift, ) self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0], stranded = stranded)) From 57a910e4449f1b151ed29b4dd7ae040e07373b6c Mon Sep 17 00:00:00 2001 From: Cas Blaauw <38132585+casblaauw@users.noreply.github.com> Date: Mon, 4 Nov 2024 10:08:00 +0100 Subject: [PATCH 12/14] Revert "Remove leftover deterministic shift function" This reverts commit bd27269226a695d33b9299fce208aef6fd2d2c2d. --- src/crested/tl/data/_dataset.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 5d798932..7c825026 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -243,6 +243,23 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map + def _deterministic_shift_region( + self, region: str, stride: int = 50, n_shifts: int = 2 + ) -> list[str]: + """ + Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. + + This is a legacy function, it's recommended to use stochastic shifting instead. + """ + new_regions = [] + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + for i in range(-n_shifts, n_shifts + 1): + new_start = start + i * stride + new_end = end + i * stride + new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") + return new_regions + if os.environ["KERAS_BACKEND"] == "pytorch": import torch From 87ae02717d9f5ea24ca8ba61f6fbd42f11b91c36 Mon Sep 17 00:00:00 2001 From: cblaauw Date: Mon, 4 Nov 2024 11:40:39 +0100 Subject: [PATCH 13/14] Fix deterministic shift --- src/crested/tl/data/_dataset.py | 77 +++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 7c825026..a815999a 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -41,6 +41,23 @@ def _check_strandedness(region: str) -> bool: f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)." "If provided, strand must be + or -.") +def _deterministic_shift_region( + region: str, stride: int = 50, n_shifts: int = 2 +) -> list[str]: + """ + Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. + + This is a legacy function, it's recommended to use stochastic shifting instead. + """ + new_regions = [] + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + for i in range(-n_shifts, n_shifts + 1): + new_start = start + i * stride + new_end = end + i * stride + new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") + return new_regions + class SequenceLoader: """ @@ -71,6 +88,7 @@ def __init__( chromsizes: dict | None, in_memory: bool = False, always_reverse_complement: bool = False, + deterministic_shift: bool = False, max_stochastic_shift: int = 0, regions: list[str] | None = None, ): @@ -79,6 +97,7 @@ def __init__( self.chromsizes = chromsizes self.in_memory = in_memory self.always_reverse_complement = always_reverse_complement + self.deterministic_shift = deterministic_shift self.max_stochastic_shift = max_stochastic_shift self.sequences = {} self.complement = str.maketrans("ACGT", "TGCA") @@ -94,23 +113,33 @@ def _load_sequences_into_memory(self, regions: list[str]): stranded = _check_strandedness(regions[0]) for region in tqdm(regions): - # Parse region - if stranded: - chrom, start_end, strand = region.split(":") - else: - chrom, start_end = region.split(":") + # Make region stranded if not + if not stranded: strand = "+" - start, end = map(int, start_end.split("-")) + region = f"{region}:{strand}" + if region[-4] == ":": + raise ValueError(f"You are double-adding strand ids to your region {region}. Check if all regions are stranded or unstranded.") - # Add region to self.sequences - extended_sequence = self._get_extended_sequence(chrom, start, end, strand) - self.sequences[f"{chrom}:{start}-{end}:{strand}"] = extended_sequence + # Add deterministic shift regions + if self.deterministic_shift: + regions = _deterministic_shift_region(region) + else: + regions = [region] - # Add reverse-complemented region to self.sequences if always_reverse_complement - if self.always_reverse_complement: - self.sequences[f"{chrom}:{start}-{end}:{strand_reverser[strand]}"] = self._reverse_complement( - extended_sequence - ) + for region in regions: + # Parse region + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + + # Add region to self.sequences + extended_sequence = self._get_extended_sequence(chrom, start, end, strand) + self.sequences[f"{chrom}:{start}-{end}:{strand}"] = extended_sequence + + # Add reverse-complemented region to self.sequences if always_reverse_complement + if self.always_reverse_complement: + self.sequences[f"{chrom}:{start}-{end}:{strand_reverser[strand]}"] = self._reverse_complement( + extended_sequence + ) def _get_extended_sequence(self, chrom: str, start: int, end: int, strand: str) -> str: """Get sequence from genome file, extended for stochastic shifting.""" @@ -228,7 +257,7 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str stranded_region = region if self.deterministic_shift: - shifted_regions = self._deterministic_shift_region(stranded_region) + shifted_regions = _deterministic_shift_region(stranded_region) for shifted_region in shifted_regions: augmented_indices.append(shifted_region) augmented_indices_map[shifted_region] = region @@ -243,23 +272,6 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map - def _deterministic_shift_region( - self, region: str, stride: int = 50, n_shifts: int = 2 - ) -> list[str]: - """ - Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. - - This is a legacy function, it's recommended to use stochastic shifting instead. - """ - new_regions = [] - chrom, start_end, strand = region.split(":") - start, end = map(int, start_end.split("-")) - for i in range(-n_shifts, n_shifts + 1): - new_start = start + i * stride - new_end = end + i * stride - new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") - return new_regions - if os.environ["KERAS_BACKEND"] == "pytorch": import torch @@ -336,6 +348,7 @@ def __init__( chromsizes=self.chromsizes, in_memory=in_memory, always_reverse_complement=always_reverse_complement, + deterministic_shift=deterministic_shift, max_stochastic_shift=max_stochastic_shift, regions=self.indices, ) From 60c71698e5ff3d6cd0d02c591c49d3e84cff834b Mon Sep 17 00:00:00 2001 From: cblaauw Date: Mon, 4 Nov 2024 13:26:55 +0100 Subject: [PATCH 14/14] Slight clean up _load_sequences_into_memory --- src/crested/tl/data/_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index a815999a..4de4f69a 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -108,7 +108,6 @@ def __init__( def _load_sequences_into_memory(self, regions: list[str]): """Load all sequences into memory (dict).""" logger.info("Loading sequences into memory...") - strand_reverser = {'+': '-', '-': '+'} # Check region formatting stranded = _check_strandedness(regions[0]) @@ -133,11 +132,11 @@ def _load_sequences_into_memory(self, regions: list[str]): # Add region to self.sequences extended_sequence = self._get_extended_sequence(chrom, start, end, strand) - self.sequences[f"{chrom}:{start}-{end}:{strand}"] = extended_sequence + self.sequences[region] = extended_sequence # Add reverse-complemented region to self.sequences if always_reverse_complement if self.always_reverse_complement: - self.sequences[f"{chrom}:{start}-{end}:{strand_reverser[strand]}"] = self._reverse_complement( + self.sequences[_flip_region_strand(region)] = self._reverse_complement( extended_sequence )