diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index 853a7ce8..1339c70c 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -580,7 +580,7 @@ def get_embeddings( """ Extract embeddings from a specified layer in the model for all regions in the dataset. - If anndata is provided, it will add the embeddings to anndata.obsm[layer_name]. + If anndata is provided, it will add the embeddings to anndata.varm[layer_name]. Parameters ---------- @@ -591,7 +591,7 @@ def get_embeddings( Returns ------- - Embeddings of shape (N, D), where D is the size of the embedding layer. + Embeddings of shape (N, D), where N is the number of regions in the dataset and D is the size of the embedding layer. """ if layer_name not in [layer.name for layer in self.model.layers]: raise ValueError(f"Layer '{layer_name}' not found in model.") @@ -607,7 +607,7 @@ def get_embeddings( embeddings = embedding_model.predict(predict_loader.data, steps=n_predict_steps) if anndata is not None: - anndata.obsm[layer_name] = embeddings + anndata.varm[layer_name] = embeddings return embeddings def predict( @@ -662,7 +662,7 @@ def predict_regions( Parameters ---------- region_idx - List of regions for which to make predictions in the format "chr:start-end". + List of regions for which to make predictions in the format of your original data, either "chr:start-end" or "chr:start-end:strand". Returns ------- @@ -973,7 +973,7 @@ def calculate_contribution_scores_regions( Parameters ---------- region_idx - Region(s) for which to calculate the contribution scores in the format "chr:start-end". + Region(s) for which to calculate the contribution scores in the format "chr:start-end" or "chr:start-end:strand". class_names List of class names to calculate the contribution scores for (should match anndata.obs_names) If the list is empty, the contribution scores for the 'combined' class will be calculated. @@ -1410,10 +1410,7 @@ def enhancer_design_motif_implementation( enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference) # get input sequence length of the model - seq_len = ( - self.anndatamodule.adata.var.iloc[0]["end"] - - self.anndatamodule.adata.var.iloc[0]["start"] - ) + seq_len = self.model.input_shape[1] # determine the flanks without changes if no_mutation_flanks is not None and target_len is not None: @@ -1616,10 +1613,7 @@ def enhancer_design_in_silico_evolution( enhancer_optimizer = EnhancerOptimizer(optimize_func=_weighted_difference) # get input sequence length of the model - seq_len = ( - self.anndatamodule.adata.var.iloc[0]["end"] - - self.anndatamodule.adata.var.iloc[0]["start"] - ) + seq_len = self.model.input_shape[1] # determine the flanks without changes if no_mutation_flanks is not None and target_len is not None: diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 941233d3..4de4f69a 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import re from os import PathLike import numpy as np @@ -24,6 +25,39 @@ def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]: chromsizes_dict = chromsizes.set_index("chr")["size"].to_dict() return chromsizes_dict +def _flip_region_strand(region: str) -> str: + """Reverse the strand of a region.""" + strand_reverser = {'+': '-', '-': '+'} + return region[:-1]+strand_reverser[region[-1]] + +def _check_strandedness(region: str) -> bool: + """Check the strandedness of a region, raising an error if the formatting isn't recognised.""" + if re.fullmatch(r".+:\d+-\d+:[-+]", region): + return True + elif re.fullmatch(r".+:\d+-\d+", region): + return False + else: + raise ValueError( + f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)." + "If provided, strand must be + or -.") + +def _deterministic_shift_region( + region: str, stride: int = 50, n_shifts: int = 2 +) -> list[str]: + """ + Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. + + This is a legacy function, it's recommended to use stochastic shifting instead. + """ + new_regions = [] + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + for i in range(-n_shifts, n_shifts + 1): + new_start = start + i * stride + new_end = end + i * stride + new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") + return new_regions + class SequenceLoader: """ @@ -54,6 +88,7 @@ def __init__( chromsizes: dict | None, in_memory: bool = False, always_reverse_complement: bool = False, + deterministic_shift: bool = False, max_stochastic_shift: int = 0, regions: list[str] | None = None, ): @@ -62,6 +97,7 @@ def __init__( self.chromsizes = chromsizes self.in_memory = in_memory self.always_reverse_complement = always_reverse_complement + self.deterministic_shift = deterministic_shift self.max_stochastic_shift = max_stochastic_shift self.sequences = {} self.complement = str.maketrans("ACGT", "TGCA") @@ -72,19 +108,40 @@ def __init__( def _load_sequences_into_memory(self, regions: list[str]): """Load all sequences into memory (dict).""" logger.info("Loading sequences into memory...") + # Check region formatting + stranded = _check_strandedness(regions[0]) + for region in tqdm(regions): - extended_sequence = self._get_extended_sequence(region) - self.sequences[f"{region}:+"] = extended_sequence - if self.always_reverse_complement: - self.sequences[f"{region}:-"] = self._reverse_complement( - extended_sequence - ) + # Make region stranded if not + if not stranded: + strand = "+" + region = f"{region}:{strand}" + if region[-4] == ":": + raise ValueError(f"You are double-adding strand ids to your region {region}. Check if all regions are stranded or unstranded.") + + # Add deterministic shift regions + if self.deterministic_shift: + regions = _deterministic_shift_region(region) + else: + regions = [region] - def _get_extended_sequence(self, region: str) -> str: - """Get sequence from genome file, extended for stochastic shifting.""" - chrom, start_end = region.split(":") - start, end = map(int, start_end.split("-")) + for region in regions: + # Parse region + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + + # Add region to self.sequences + extended_sequence = self._get_extended_sequence(chrom, start, end, strand) + self.sequences[region] = extended_sequence + + # Add reverse-complemented region to self.sequences if always_reverse_complement + if self.always_reverse_complement: + self.sequences[_flip_region_strand(region)] = self._reverse_complement( + extended_sequence + ) + def _get_extended_sequence(self, chrom: str, start: int, end: int, strand: str) -> str: + """Get sequence from genome file, extended for stochastic shifting.""" extended_start = max(0, start - self.max_stochastic_shift) extended_end = extended_start + (end - start) + (self.max_stochastic_shift * 2) @@ -96,32 +153,60 @@ def _get_extended_sequence(self, region: str) -> str: ) extended_end = chrom_size - return self.genome.fetch(chrom, extended_start, extended_end).upper() + seq = self.genome.fetch(chrom, extended_start, extended_end).upper() + if strand == "-": + seq = self._reverse_complement(seq) + return seq def _reverse_complement(self, sequence: str) -> str: """Reverse complement a sequence.""" return sequence.translate(self.complement)[::-1] - def get_sequence(self, region: str, strand: str = "+", shift: int = 0) -> str: - """Get sequence for a region, strand, and shift from memory or fasta.""" - key = f"{region}:{strand}" + def get_sequence(self, region: str, stranded: bool | None = None, shift: int = 0) -> str: + """ + Get sequence for a region, strand, and shift from memory or fasta. + + If no strand is given in region or strand, assumes positive strand. + + Parameters + ---------- + region + Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand). + stranded + Whether the input data is stranded. Default (None) infers from sequence (at a computational cost). + If not stranded, positive strand is assumed. + shift: + Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism. + + Returns + ------- + The DNA sequence, as a string. + """ + if stranded is None: + stranded = _check_strandedness(region) + if not stranded: + region = f"{region}:+" + # Parse region + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + + # Get extended sequence if self.in_memory: - sequence = self.sequences[key] + sequence = self.sequences[region] else: - sequence = self._get_extended_sequence(region) - chrom, start_end = region.split(":") - start, end = map(int, start_end.split("-")) + sequence = self._get_extended_sequence(chrom, start, end, strand) + + # Extract from extended sequence start_idx = self.max_stochastic_shift + shift end_idx = start_idx + (end - start) sub_sequence = sequence[start_idx:end_idx] - # handle reverse complement on the go if not loaded into memory - if (strand == "-") and (not self.in_memory): - sub_sequence = self._reverse_complement(sub_sequence) - - # pad with Ns if sequence is shorter than expected + # Pad with Ns if sequence is shorter than expected if len(sub_sequence) < (end - start): - sub_sequence = sub_sequence.ljust(end - start, "N") + if strand == "+": + sub_sequence = sub_sequence.ljust(end - start, "N") + else: + sub_sequence = sub_sequence.rjust(end - start, "N") return sub_sequence @@ -135,7 +220,7 @@ class IndexManager: Parameters ---------- indices - List of indices in format "chrom:start-end". + List of indices in format "chr:start-end" or "chr:start-end:strand". always_reverse_complement If True, all sequences will be augmented with their reverse complement. deterministic_shift @@ -157,7 +242,7 @@ def __init__( ) def shuffle_indices(self): - """Shuffle indices. Managed by subclass AnnDataLoader.""" + """Shuffle indices. Managed by wrapping class AnnDataLoader.""" np.random.shuffle(self.augmented_indices) def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str]]: @@ -165,39 +250,27 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str augmented_indices = [] augmented_indices_map = {} for region in indices: + if not _check_strandedness(region): # If slow, can use AnnDataset stranded argument - but this validates every region's formatting as well + stranded_region = f"{region}:+" + else: + stranded_region = region + if self.deterministic_shift: - shifted_regions = self._deterministic_shift_region(region) + shifted_regions = _deterministic_shift_region(stranded_region) for shifted_region in shifted_regions: - augmented_indices.append(f"{shifted_region}:+") - augmented_indices_map[f"{shifted_region}:+"] = region + augmented_indices.append(shifted_region) + augmented_indices_map[shifted_region] = region if self.always_reverse_complement: - augmented_indices.append(f"{shifted_region}:-") - augmented_indices_map[f"{shifted_region}:-"] = region + augmented_indices.append(_flip_region_strand(shifted_region)) + augmented_indices_map[_flip_region_strand(shifted_region)] = region else: - augmented_indices.append(f"{region}:+") - augmented_indices_map[f"{region}:+"] = region + augmented_indices.append(stranded_region) + augmented_indices_map[stranded_region] = region if self.always_reverse_complement: - augmented_indices.append(f"{region}:-") - augmented_indices_map[f"{region}:-"] = region + augmented_indices.append(_flip_region_strand(stranded_region)) + augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map - def _deterministic_shift_region( - self, region: str, stride: int = 50, n_shifts: int = 2 - ) -> list[str]: - """ - Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. - - This is a legacy function, it's recommended to use stochastic shifting instead. - """ - new_regions = [] - chrom, start_end = region.split(":") - start, end = map(int, start_end.split("-")) - for i in range(-n_shifts, n_shifts + 1): - new_start = start + i * stride - new_end = end + i * stride - new_regions.append(f"{chrom}:{new_start}-{new_end}") - return new_regions - if os.environ["KERAS_BACKEND"] == "pytorch": import torch @@ -259,22 +332,31 @@ def __init__( self.num_outputs = self.anndata.X.shape[0] self.random_reverse_complement = random_reverse_complement self.max_stochastic_shift = max_stochastic_shift - self.shuffle = False # managed by subclass AnnDataLoader + self.shuffle = False # managed by wrapping class AnnDataLoader + + # Check region formatting + stranded = _check_strandedness(self.indices[0]) + if stranded and (always_reverse_complement or random_reverse_complement): + logger.info( + "Setting always_reverse_complement=True or random_reverse_complement=True with stranded data.", + "This means both strands are used when training and the strand information is effectively disregarded." + ) self.sequence_loader = SequenceLoader( genome_file, - self.chromsizes, - in_memory, - always_reverse_complement, - max_stochastic_shift, - self.indices, + chromsizes=self.chromsizes, + in_memory=in_memory, + always_reverse_complement=always_reverse_complement, + deterministic_shift=deterministic_shift, + max_stochastic_shift=max_stochastic_shift, + regions=self.indices, ) self.index_manager = IndexManager( self.indices, always_reverse_complement=always_reverse_complement, deterministic_shift=deterministic_shift, ) - self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0])) + self.seq_len = len(self.sequence_loader.get_sequence(self.indices[0], stranded = stranded)) @staticmethod def _split_anndata(anndata: AnnData, split: str) -> AnnData: @@ -308,19 +390,18 @@ def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: """Return sequence and target for a given index.""" augmented_index = self.index_manager.augmented_indices[idx] original_index = self.index_manager.augmented_indices_map[augmented_index] - - strand = "-" if augmented_index.endswith(":-") else "+" - # stochastic shift if self.max_stochastic_shift > 0: shift = np.random.randint( -self.max_stochastic_shift, self.max_stochastic_shift + 1 ) - x = self.sequence_loader.get_sequence(original_index, strand, shift) else: - x = self.sequence_loader.get_sequence(original_index, strand) + shift = 0 + + # Get sequence + x = self.sequence_loader.get_sequence(augmented_index, stranded = True, shift = shift) - # random reverse complement (always is done in the sequence loader) + # random reverse complement (always_reverse_complement is done in the sequence loader) if self.random_reverse_complement and np.random.rand() < 0.5: x = self.sequence_loader._reverse_complement(x)