diff --git a/src/sctools/__init__.py b/src/sctools/__init__.py index cea1b9f..8e37f63 100644 --- a/src/sctools/__init__.py +++ b/src/sctools/__init__.py @@ -7,6 +7,7 @@ from . import reader from . import metrics from . import platform +from . import consts from pkg_resources import get_distribution, DistributionNotFound diff --git a/src/sctools/bam.py b/src/sctools/bam.py index f37492c..9c6194a 100644 --- a/src/sctools/bam.py +++ b/src/sctools/bam.py @@ -13,15 +13,18 @@ Methods ------- -iter_tag_groups function to iterate over reads by an arbitrary tag -iter_cell_barcodes wrapper for iter_tag_groups that iterates over cell barcode tags -iter_genes wrapper for iter_tag_groups that iterates over gene tags -iter_molecules wrapper for iter_tag_groups that iterates over molecule tags +iter_tag_groups function to iterate over reads by an arbitrary tag +iter_cell_barcodes wrapper for iter_tag_groups that iterates over cell barcode tags +iter_genes wrapper for iter_tag_groups that iterates over gene tags +iter_molecules wrapper for iter_tag_groups that iterates over molecule tags Classes ------- -SubsetAlignments class to extract reads specific to requested chromosome(s) -Tagger class to add tags to sam/bam records from paired fastq records +SubsetAlignments class to extract reads specific to requested chromosome(s) +Tagger class to add tags to sam/bam records from paired fastq records +AlignmentSortOrder abstract class to represent alignment sort orders +QueryNameSortOrder alignment sort order by query name +CellMoleculeGeneQueryNameSortOrder alignment sort order hierarchically cell > molecule > gene > query name References ---------- @@ -29,14 +32,17 @@ """ -import warnings -import os import math +import os +import warnings +from abc import abstractmethod from itertools import cycle -from typing import Iterator, Generator, List, Union, Tuple +from typing import Iterator, Generator, List, Dict, Union, Tuple, Callable, Any, Optional import pysam +from . 
import consts + class SubsetAlignments: """Wrapper for pysam/htslib that extracts reads corresponding to requested chromosome(s) @@ -74,9 +80,8 @@ def __init__(self, alignment_file: str, open_mode: str=None): open_mode = 'r' else: raise ValueError( - 'could not autodetect file type for alignment_file %s (detectible suffixes: ' - '.sam, .bam)' % alignment_file - ) + f'Could not autodetect file type for alignment_file {alignment_file} (detectable suffixes: ' + f'.sam, .bam)') self._file: str = alignment_file self._open_mode: str = open_mode @@ -164,7 +169,7 @@ class Tagger: def __init__(self, bam_file: str) -> None: if not isinstance(bam_file, str): - raise TypeError('bam_file must be type str, not %s' % type(bam_file)) + raise TypeError(f'The argument "bam_file" must be of type str, not {type(bam_file)}') self.bam_file = bam_file # todo add type to tag_generators (make sure it doesn't introduce import issues @@ -193,7 +198,7 @@ def tag(self, output_bam_name: str, tag_generators) -> None: outbam.write(sam_record) -def split(in_bam, out_prefix, *tags, approx_mb_per_split=1000, raise_missing=True) -> List[str]: +def split(in_bam: str, out_prefix: str, *tags, approx_mb_per_split=1000, raise_missing=True) -> List[str]: """split `in_bam` by tag into files of `approx_mb_per_split` Parameters @@ -203,7 +208,7 @@ def split(in_bam, out_prefix, *tags, approx_mb_per_split=1000, raise_missing=Tru out_prefix : str Prefix for all output files; output will be named as prefix_n where n is an integer equal to the chunk number. - tags : list + tags : tuple The bam tags to split on. The tags are checked in order, and sorting is done based on the first identified tag. Further tags are only checked if the first tag is missing. 
This is useful in cases where sorting is executed over a corrected barcode, but some records only @@ -231,44 +236,47 @@ def split(in_bam, out_prefix, *tags, approx_mb_per_split=1000, raise_missing=Tru if len(tags) == 0: raise ValueError('At least one tag must be passed') - def _cleanup(files_to_counts, files_to_names, rm_all=False) -> None: + def _cleanup( + _files_to_counts: Dict[pysam.AlignmentFile, int], _files_to_names: Dict[pysam.AlignmentFile, str], + rm_all: bool=False) -> None: """Closes file handles and remove any empty files. Parameters ---------- - files_to_counts : dict + _files_to_counts : dict Dictionary of file objects to the number of reads each contains. - files_to_names : dict + _files_to_names : dict Dictionary of file objects to file handles. rm_all : bool, optional If True, indicates all files should be removed, regardless of count number, else only empty files without counts are removed (default = False) """ - for bamfile, count in files_to_counts.items(): + for bamfile, count in _files_to_counts.items(): # corner case: clean up files that were created, but didn't get data because # n_cell < n_barcode bamfile.close() if count == 0 or rm_all: - os.remove(files_to_names[bamfile]) - del files_to_names[bamfile] + os.remove(_files_to_names[bamfile]) + del _files_to_names[bamfile] # find correct number of subfiles to spawn bam_mb = os.path.getsize(in_bam) * 1e-6 n_subfiles = int(math.ceil(bam_mb / approx_mb_per_split)) - if n_subfiles > 500: - warnings.warn('Number of requested subfiles (%d) exceeds 500; this may cause OS errors by ' - 'exceeding fid limits' % n_subfiles) - if n_subfiles > 1000: - raise ValueError('Number of requested subfiles (%d) exceeds 1000; this will usually cause ' - 'OS errors, think about increasing max_mb_per_split.' 
% n_subfiles) + if n_subfiles > consts.MAX_BAM_SPLIT_SUBFILES_TO_WARN: + warnings.warn(f'Number of requested subfiles ({n_subfiles}) exceeds ' + f'{consts.MAX_BAM_SPLIT_SUBFILES_TO_WARN}; this may cause OS errors by exceeding fid limits') + if n_subfiles > consts.MAX_BAM_SPLIT_SUBFILES_TO_RAISE: + raise ValueError(f'Number of requested subfiles ({n_subfiles}) exceeds ' + f'{consts.MAX_BAM_SPLIT_SUBFILES_TO_RAISE}; this will usually cause OS errors, ' + f'think about increasing max_mb_per_split.') # create all the output files with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments: # map files to counts - files_to_counts = {} - files_to_names = {} + files_to_counts: Dict[pysam.AlignmentFile, int] = {} + files_to_names: Dict[pysam.AlignmentFile, str] = {} for i in range(n_subfiles): out_bam_name = out_prefix + '_%d.bam' % i open_bam = pysam.AlignmentFile(out_bam_name, 'wb', template=input_alignments) @@ -297,8 +305,7 @@ def _cleanup(files_to_counts, files_to_names, rm_all=False) -> None: if tag_content is None: if raise_missing: _cleanup(files_to_counts, files_to_names, rm_all=True) - raise RuntimeError( - 'alignment encountered that is missing %s tag(s).' 
% repr(tags)) + raise RuntimeError(f'Alignment encountered that is missing {repr(tags)} tag(s).') else: continue # move on to next alignment @@ -379,12 +386,12 @@ def iter_molecule_barcodes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Gene Yields ------ grouped_by_tag : Iterator[pysam.AlignedSegment] - reads sharing a unique molecule barcode (``UB`` tag) + reads sharing a unique molecule barcode tag current_tag : str the molecule barcode that records in the group all share """ - return iter_tag_groups(tag='UB', bam_iterator=bam_iterator) + return iter_tag_groups(tag=consts.MOLECULE_BARCODE_TAG_KEY, bam_iterator=bam_iterator) def iter_cell_barcodes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generator: @@ -398,12 +405,12 @@ def iter_cell_barcodes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generato Yields ------ grouped_by_tag : Iterator[pysam.AlignedSegment] - reads sharing a unique cell barcode (``CB`` tag) + reads sharing a unique cell barcode tag current_tag : str the cell barcode that reads in the group all share """ - return iter_tag_groups(tag='CB', bam_iterator=bam_iterator) + return iter_tag_groups(tag=consts.CELL_BARCODE_TAG_KEY, bam_iterator=bam_iterator) def iter_genes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generator: @@ -417,9 +424,69 @@ def iter_genes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generator: Yields ------ grouped_by_tag : Iterator[pysam.AlignedSegment] - reads sharing a unique gene id (``GE`` tag) + reads sharing a unique gene name tag current_tag : str the gene id that reads in the group all share """ - return iter_tag_groups(tag='GE', bam_iterator=bam_iterator) + return iter_tag_groups(tag=consts.GENE_NAME_TAG_KEY, bam_iterator=bam_iterator) + + +def get_tag_or_default(alignment: pysam.AlignedSegment, tag_key: str, default: Optional[str] = None) -> Optional[str]: + """Extracts the value associated to `tag_key` from `alignment`, and returns a default value + if the tag is not present.""" + try: + return 
alignment.get_tag(tag_key) + except KeyError: + return default + + +class AlignmentSortOrder: + """The base class of alignment sort orders.""" + @property + @abstractmethod + def key_generator(self) -> Callable[[pysam.AlignedSegment], Any]: + """Returns a callable function that calculates a sort key from given pysam.AlignedSegment.""" + raise NotImplementedError + + +class QueryNameSortOrder(AlignmentSortOrder): + """Alignment record sort order by query name.""" + @staticmethod + def get_sort_key(alignment: pysam.AlignedSegment) -> str: + return alignment.query_name + + @property + def key_generator(self): + return QueryNameSortOrder.get_sort_key + + def __repr__(self) -> str: + return 'query_name' + + +class CellMoleculeGeneQueryNameSortOrder(AlignmentSortOrder): + """Hierarchical alignment record sort order (cell barcode >= molecule barcode >= gene name >= query name).""" + def __init__( + self, + cell_barcode_tag_key: str = consts.CELL_BARCODE_TAG_KEY, + molecule_barcode_tag_key: str = consts.MOLECULE_BARCODE_TAG_KEY, + gene_name_tag_key: str = consts.GENE_NAME_TAG_KEY) -> None: + assert cell_barcode_tag_key, "Cell barcode tag key can not be None" + assert molecule_barcode_tag_key, "Molecule barcode tag key can not be None" + assert gene_name_tag_key, "Gene name tag key can not be None" + self.cell_barcode_tag_key = cell_barcode_tag_key + self.molecule_barcode_tag_key = molecule_barcode_tag_key + self.gene_name_tag_key = gene_name_tag_key + + def _get_sort_key(self, alignment: pysam.AlignedSegment) -> Tuple[str, str, str, str]: + return (get_tag_or_default(alignment, self.cell_barcode_tag_key, default='N'), + get_tag_or_default(alignment, self.molecule_barcode_tag_key, default='N'), + get_tag_or_default(alignment, self.gene_name_tag_key, default='N'), + alignment.query_name) + + @property + def key_generator(self) -> Callable[[pysam.AlignedSegment], Tuple[str, str, str, str]]: + return self._get_sort_key + + def __repr__(self) -> str: + return 
'hierarchical__cell_molecule_gene_query_name' diff --git a/src/sctools/barcode.py b/src/sctools/barcode.py index 383771b..35547d1 100644 --- a/src/sctools/barcode.py +++ b/src/sctools/barcode.py @@ -22,6 +22,7 @@ import numpy as np import pysam +from . import consts from .encodings import TwoBit from .stats import base4_entropy @@ -47,11 +48,11 @@ class can optionally be constructed from an iterable where barcodes can be prese """ def __init__(self, barcodes: Mapping[str, int], barcode_length: int): if not isinstance(barcodes, Mapping): - raise TypeError('barcode set must be a dict-like object mapping barcodes to counts') + raise TypeError('The argument "barcodes" must be a dict-like object mapping barcodes to counts') self._mapping: Mapping[str, int] = barcodes if not isinstance(barcode_length, int) and barcode_length > 0: - raise ValueError('barcode length must be a positive integer') + raise ValueError('The argument "barcode_length" must be a positive integer') self._barcode_length: int = barcode_length def __contains__(self, item) -> bool: @@ -257,10 +258,9 @@ class ErrorsToCorrectBarcodesMap: """ def __init__(self, errors_to_barcodes: Mapping[str, str]): - if not isinstance(errors_to_barcodes, Mapping): - raise TypeError('errors_to_barcodes must be a mapping of erroneous barcodes to correct ' - 'barcodes, not %a' % type(errors_to_barcodes)) + raise TypeError(f'The argument "errors_to_barcodes" must be a mapping of erroneous barcodes to correct ' + f'barcodes, not {type(errors_to_barcodes)}') self._map = errors_to_barcodes def get_corrected_barcode(self, barcode: str) -> str: @@ -349,6 +349,6 @@ def correct_bam(self, bam_file: str, output_bam_file: str) -> None: try: tag = self.get_corrected_barcode(alignment.get_tag('CR')) except KeyError: # pass through the uncorrected barcode. 
- tag = alignment.get_tag('CR') - alignment.set_tag(tag='CB', value=tag, value_type='Z') + tag = alignment.get_tag(consts.RAW_CELL_BARCODE_TAG_KEY) + alignment.set_tag(tag=consts.CELL_BARCODE_TAG_KEY, value=tag, value_type='Z') fout.write(alignment) diff --git a/src/sctools/consts.py b/src/sctools/consts.py new file mode 100644 index 0000000..8c2c27b --- /dev/null +++ b/src/sctools/consts.py @@ -0,0 +1,36 @@ +""" +Global constants +================ + +.. currentmodule:: sctools + +This module contains global constants, such as various barcoded BAM tags, and sctools-specific +constants. +""" + +# BAM tag constants + +RAW_SAMPLE_BARCODE_TAG_KEY = 'SR' +QUALITY_SAMPLE_BARCODE_TAG_KEY = 'SY' + +MOLECULE_BARCODE_TAG_KEY = 'UB' +RAW_MOLECULE_BARCODE_TAG_KEY = 'UR' +QUALITY_MOLECULE_BARCODE_TAG_KEY = 'UY' + +CELL_BARCODE_TAG_KEY = 'CB' +RAW_CELL_BARCODE_TAG_KEY = 'CR' +QUALITY_CELL_BARCODE_TAG_KEY = 'CY' + +GENE_NAME_TAG_KEY = 'GE' +NUMBER_OF_HITS_TAG_KEY = 'NH' + +ALIGNMENT_LOCATION_TAG_KEY = 'XF' +INTRONIC_ALIGNMENT_LOCATION_TAG_VALUE = 'INTRONIC' +CODING_ALIGNMENT_LOCATION_TAG_VALUE = 'CODING' +UTR_ALIGNMENT_LOCATION_TAG_VALUE = 'UTR' +INTERGENIC_ALIGNMENT_LOCATION_TAG_VALUE = 'INTERGENIC' + +# bam.py constants + +MAX_BAM_SPLIT_SUBFILES_TO_WARN = 500 +MAX_BAM_SPLIT_SUBFILES_TO_RAISE = 1000 diff --git a/src/sctools/count.py b/src/sctools/count.py index d30ffb8..0b5d080 100644 --- a/src/sctools/count.py +++ b/src/sctools/count.py @@ -8,7 +8,11 @@ Methods ------- -bam_to_count(bam_file, cell_barcode_tag: str='CB', molecule_barcode_tag='UB', gene_id_tag='GE') +from_sorted_tagged_bam( + bam_file: str, gene_name_to_index: Dict[str, int], cell_barcode_tag: str = consts.CELL_BARCODE_TAG_KEY, + molecule_barcode_tag: str=consts.MOLECULE_BARCODE_TAG_KEY, + gene_name_tag: str=consts.GENE_NAME_TAG_KEY, open_mode: str='rb') +from_mtx(matrix_mtx: str, row_index_file: str, col_index_file: str) Notes ----- @@ -16,17 +20,16 @@ The memory usage is equal to approximately 6*8 bytes per molecules in the 
file. """ -from typing import List, Dict, Tuple -import tempfile +import itertools import operator +from typing import List, Dict, Tuple, Set, Optional, Generator import numpy as np +import pysam import scipy.sparse as sp from scipy.io import mmread -import pysam -import gffutils -from sctools import gtf +from sctools import consts, bam class CountMatrix: @@ -40,20 +43,80 @@ def __init__(self, matrix: sp.csr_matrix, row_index: np.ndarray, col_index: np.n def matrix(self): return self._matrix + @property + def row_index(self): + return self._row_index + + @property + def col_index(self): + return self._col_index + + @staticmethod + def _get_alignments_grouped_by_query_name_generator( + bam_file: str, cell_barcode_tag: str, molecule_barcode_tag: str, open_mode: str = 'rb') -> \ + Generator[Tuple[str, Optional[str], Optional[str], List[pysam.AlignedSegment]], None, None]: + """Iterates through a query_name-sorted BAM file, groups all alignments with the same query name + + Parameters + ---------- + bam_file : str + input bam file marked by cell barcode, molecule barcode, and gene ID tags sorted in that + order + cell_barcode_tag : str + Tag that specifies the cell barcode for each read. + molecule_barcode_tag : str + Tag that specifies the molecule barcode for each read. 
+ + Returns + ------- + a generator for tuples (query_name, cell_barcode, molecule_barcode, alignments) + """ + with pysam.AlignmentFile(bam_file, mode=open_mode) as bam_records: + for (query_name, grouper) in itertools.groupby(bam_records, key=lambda record: record.query_name): + alignments: List[pysam.AlignedSegment] = list(grouper) + cell_barcode: Optional[str] = bam.get_tag_or_default(alignments[0], cell_barcode_tag) + molecule_barcode: Optional[str] = bam.get_tag_or_default(alignments[0], molecule_barcode_tag) + yield query_name, cell_barcode, molecule_barcode, alignments + + # todo add support for generating a matrix of invalid barcodes + # todo add support for splitting spliced and unspliced reads + # todo add support for generating a map of cell barcodes + # todo add the option for stringent checks on the input (e.g. BAM sort order) + # todo once the stringent checks are in place, safely move on to the hashset-free implementation @classmethod - def from_bam( + def from_sorted_tagged_bam( cls, bam_file: str, - annotation_file: str, - cell_barcode_tag: str='CB', - molecule_barcode_tag: str='UB', - gene_id_tag: str='GE', - open_mode: str='rb', - ): + gene_name_to_index: Dict[str, int], + cell_barcode_tag: str=consts.CELL_BARCODE_TAG_KEY, + molecule_barcode_tag: str=consts.MOLECULE_BARCODE_TAG_KEY, + gene_name_tag: str=consts.GENE_NAME_TAG_KEY, + open_mode: str='rb') -> 'CountMatrix': """Generate a count matrix from a sorted, tagged bam file - Input bam file must be sorted by cell, molecule, and gene (where the gene tag varies fastest). - This module returns reads that correspond to both spliced and unspliced reads. + Notes + ----- + - Input bam file must be sorted by query name. + + - The sort order of the input BAM file is not strictly checked. If the input BAM file is not sorted + by query_name, the output counts will be wrong without any warnings being issued. + + This method returns counts that correspond to both spliced and unspliced reads. 
+ + Description of the algorithm + ---------------------------- + The implemented counting strategy is intended to closely match that of CellRanger 2.1.1 + (see the references). The following pseudo-code describes the counting algorithm: + + for each query_name (i.e. unique sequenced read): + - if only a single alignment exists, _consider_ the read + - if multiple alignments exist, + - if a unique gene name is associated to all alignments that have a gene name tag, + _consider_ the read; otherwise, the read is useless and neglect it + - if the read is to be _considered_, + - if the triple (cell barcode, molecule barcode, gene name) is not encountered before, + count it as evidence for a unique transcript; otherwise, consider the read as duplicate + and neglect it Parameters ---------- @@ -62,15 +125,15 @@ def from_bam( order cell_barcode_tag : str, optional Tag that specifies the cell barcode for each read. Reads without this tag will be ignored - (default = 'CB') + (default = consts.CELL_BARCODE_TAG_KEY) molecule_barcode_tag : str, optional Tag that specifies the molecule barcode for each read. Reads without this tag will be - ignored (default = 'UB') - gene_id_tag - Tag that specifies the gene for each read. Reads without this tag will be ignored - (default = 'GE') - annotation_file : str - gtf annotation file that was used to create gene ID tags. Used to map genes to indices + ignored (default = consts.MOLECULE_BARCODE_TAG_KEY) + gene_name_tag + Tag that specifies the gene name for each read. Reads without this tag will be ignored + (default = consts.GENE_NAME_TAG_KEY) + gene_name_to_index : dict + A map from gene names to their counts matrix column index open_mode : {'r', 'rb'}, optional indicates that the passed file is a bam file ('rb') or sam file ('r') (default = 'rb'). 
@@ -81,8 +144,9 @@ def from_bam( Notes ----- - Any matrices produced by this function that share the same annotation file can be concatenated - using the scipy sparse vstack function, for example: + All matrices produced by this function called on different BAM chunks that share the same annotation + file can be concatenated using the scipy sparse vstack function, since by definition, the cell barcodes + contained in different BAM chunks are mutually exclusive. For example: >>> import scipy.sparse as sp >>> A = sp.coo_matrix([[1, 2], [3, 4]]).tocsr() @@ -97,105 +161,102 @@ def from_bam( samtools sort (-t parameter): C library that can sort files as required. http://www.htslib.org/doc/samtools.html#COMMANDS_AND_OPTIONS + TagSortBam.CellSortBam: WDL task that accomplishes the sorting necessary for this module. https://github.com/HumanCellAtlas/skylab/blob/master/library/tasks/TagSortBam.wdl + Relevant permalinks to the counting algorithm in CellRanger: + [1] https://github.com/10XGenomics/cellranger/blob/aba5d379169ff0d4bee60e3d100df35752b90383/mro/stages/counter/ + attach_bcs_and_umis/__init__.py + [2] https://github.com/10XGenomics/cellranger/blob/aba5d379169ff0d4bee60e3d100df35752b90383/lib/rust/ + annotate_reads/src/main.rs """ - # create input arrays + # map the gene from each record to an index in the sparse matrix + n_genes = len(gene_name_to_index) + + # track which tuples (cell_barcode, molecule_barcode, gene_name) we've encountered so far + observed_cell_molecule_gene_set: Set[Tuple[str, str, str]] = set() + + # COO sparse matrix entries data: List[int] = [] cell_indices: List[int] = [] gene_indices: List[int] = [] - gene_id_to_index: Dict[str, int] = {} - gtf_reader = gtf.Reader(annotation_file) - - # map the gene from reach record to an index in the sparse matrix - for gene_index, record in enumerate(gtf_reader.filter(retain_types=['gene'])): - gene_id = record.get_attribute('gene_name') - if gene_id is None: - raise ValueError( - 'malformed GTF file 
detected. Record is of type gene but does not have a ' - '"gene_name" field: %s' % repr(record)) - gene_id_to_index[gene_id] = gene_index - # track which cells we've seen, and what the current cell number is n_cells = 0 - cell_id_to_index: Dict[str, int] = {} - - # process the data - current_molecule: Tuple[str, str, str] = tuple() - - with pysam.AlignmentFile(bam_file, mode=open_mode) as f: - - for sam_record in f: + cell_barcode_to_index: Dict[str, int] = {} + + grouped_records_generator = cls._get_alignments_grouped_by_query_name_generator( + bam_file, cell_barcode_tag, molecule_barcode_tag, open_mode=open_mode) + + for query_name, cell_barcode, molecule_barcode, alignments in grouped_records_generator: + + if cell_barcode is None or molecule_barcode is None: # only keep queries w/ well-formed UMIs + continue + + if len(alignments) == 1: + primary_alignment = alignments[0] + if primary_alignment.has_tag(gene_name_tag): + gene_name = primary_alignment.get_tag(gene_name_tag) + else: + continue # drop query + else: # multi-map + implicated_gene_names: Set[str] = set() + for alignment in alignments: + if alignment.has_tag(gene_name_tag): + implicated_gene_names.add(alignment.get_tag(gene_name_tag)) + if len(implicated_gene_names) == 1: # only one gene + gene_name = implicated_gene_names.__iter__().__next__() + else: + continue # drop query + + if (cell_barcode, molecule_barcode, gene_name) in observed_cell_molecule_gene_set: + continue # optical/PCR duplicate -> drop query + else: + observed_cell_molecule_gene_set.add((cell_barcode, molecule_barcode, gene_name)) + + # find the indices that this molecule should correspond to + gene_index = gene_name_to_index[gene_name] + + # if we've seen this cell before, get its index, else set it + try: + cell_index = cell_barcode_to_index[cell_barcode] + except KeyError: + cell_index = n_cells + cell_barcode_to_index[cell_barcode] = n_cells + n_cells += 1 + + # record the molecule data + data.append(1) # one count of this 
molecule + cell_indices.append(cell_index) + gene_indices.append(gene_index) - # get the tags that define the record's molecular identity - try: - gene: str = sam_record.get_tag(gene_id_tag) - cell: str = sam_record.get_tag(cell_barcode_tag) - molecule: str = sam_record.get_tag(molecule_barcode_tag) - except KeyError: # if a record is missing any of these, just drop it. - continue - - # each molecule is counted only once - if current_molecule == (gene, cell, molecule): - continue - - # find the indices that this molecule should correspond to - gene_index = gene_id_to_index[gene] - - # if we've seen this cell before, get its index, else set it - try: - cell_index = cell_id_to_index[cell] - except KeyError: - cell_index = n_cells - cell_id_to_index[cell] = n_cells - n_cells += 1 - - # record the molecule data - data.append(1) # one count of this molecule - cell_indices.append(cell_index) - gene_indices.append(gene_index) - - # set the current molecule - current_molecule = (gene, cell, molecule) + # convert into coo_matrix + coordinate_matrix = sp.coo_matrix( + (data, (cell_indices, gene_indices)), shape=(n_cells, n_genes), dtype=np.uint32) - # get shape - gene_number = len(gene_id_to_index) - cell_number = len(cell_indices) - shape = (cell_number, gene_number) + # convert to a csr sparse matrix and return + col_index = np.asarray([k for k, v in sorted(gene_name_to_index.items(), key=operator.itemgetter(1))]) + row_index = np.asarray([k for k, v in sorted(cell_barcode_to_index.items(), key=operator.itemgetter(1))]) - # convert into coo_matrix - coordinate_matrix = sp.coo_matrix((data, (cell_indices, gene_indices)), - shape=shape, dtype=np.uint32) - - # convert into csr matrix and return - col_iterable = [k for k, v in sorted(gene_id_to_index.items(), key=operator.itemgetter(1))] - row_iterable = [k for k, v in sorted(cell_id_to_index.items(), key=operator.itemgetter(1))] - col_index = np.array(col_iterable) - row_index = np.array(row_iterable) return 
cls(coordinate_matrix.tocsr(), row_index, col_index) - # todo add support for generating a matrix of invalid barcodes - # todo add support for splitting spliced and unspliced reads - # todo add support for generating a map of cell barcodes - - def save(self, prefix: str): + def save(self, prefix: str) -> None: sp.save_npz(prefix + '.npz', self._matrix, compressed=True) np.save(prefix + '_row_index.npy', self._row_index) np.save(prefix + '_col_index.npy', self._col_index) @classmethod - def load(cls, prefix: str): + def load(cls, prefix: str) -> 'CountMatrix': matrix = sp.load_npz(prefix + '.npz') row_index = np.load(prefix + '_row_index.npy') col_index = np.load(prefix + '_col_index.npy') return cls(matrix, row_index, col_index) @classmethod - def merge_matrices(cls, input_prefixes: str): + def merge_matrices(cls, input_prefixes: str) -> 'CountMatrix': col_indices = [np.load(p + '_col_index.npy') for p in input_prefixes] row_indices = [np.load(p + '_row_index.npy') for p in input_prefixes] matrices = [sp.load_npz(p + '.npz') for p in input_prefixes] @@ -207,7 +268,7 @@ def merge_matrices(cls, input_prefixes: str): return cls(matrix, row_index, col_index) @classmethod - def from_mtx(cls, matrix_mtx: str, row_index_file: str, col_index_file: str): + def from_mtx(cls, matrix_mtx: str, row_index_file: str, col_index_file: str) -> 'CountMatrix': """ Parameters diff --git a/src/sctools/encodings.py b/src/sctools/encodings.py index 1491310..4ff29c9 100644 --- a/src/sctools/encodings.py +++ b/src/sctools/encodings.py @@ -160,7 +160,7 @@ def __getitem__(self, byte: int) -> int: return self.map_[byte] except KeyError: if byte not in self.iupac_ambiguous: - raise KeyError('%s is not a valid IUPAC nucleotide code' % chr(byte)) + raise KeyError(f'{chr(byte)} is not a valid IUPAC nucleotide code') return random.randint(0, 3) encoding_map: TwoBitEncodingMap = TwoBitEncodingMap() diff --git a/src/sctools/fastq.py b/src/sctools/fastq.py index 3bb470b..7c6172e 100644 --- 
a/src/sctools/fastq.py +++ b/src/sctools/fastq.py @@ -29,7 +29,7 @@ from collections import namedtuple from typing import Iterable, AnyStr, Iterator, Union, Tuple -from . import reader +from . import reader, consts from .barcode import ErrorsToCorrectBarcodesMap @@ -75,9 +75,9 @@ def name(self) -> AnyStr: def name(self, value): """fastq record name""" if not isinstance(value, (bytes, str)): - raise TypeError('fastq name must be bytes') + raise TypeError('FASTQ name must be bytes') elif not value.startswith(b'@'): - raise ValueError('fastq name must start with @') + raise ValueError('FASTQ name must start with @') else: self._name = value @@ -87,9 +87,9 @@ def sequence(self) -> AnyStr: @sequence.setter def sequence(self, value): - """fastq nucleotide sequence""" + """FASTQ nucleotide sequence""" if not isinstance(value, (bytes, str)): - raise TypeError('fastq sequence must be str or bytes') + raise TypeError('FASTQ sequence must be str or bytes') else: self._sequence = value @@ -99,9 +99,9 @@ def name2(self) -> AnyStr: @name2.setter def name2(self, value): - """second fastq record name field (rarely used)""" + """second FASTQ record name field (rarely used)""" if not isinstance(value, (bytes, str)): - raise TypeError('fastq name2 must be str or bytes') + raise TypeError('FASTQ name2 must be str or bytes') else: self._name2 = value @@ -111,9 +111,9 @@ def quality(self) -> AnyStr: @quality.setter def quality(self, value): - """fastq record base call quality scores""" + """FASTQ record base call quality scores""" if not isinstance(value, (bytes, str)): - raise TypeError('fastq quality must be str or bytes') + raise TypeError('FASTQ quality must be str or bytes') else: self._quality = value @@ -142,23 +142,23 @@ class StrRecord(Record): Parameters ---------- record : Iterable[str] - Iterable of 4 bytes strings that comprise a fastq record + Iterable of 4 bytes strings that comprise a FASTQ record Attributes ---------- name : str - fastq record name + FASTQ record name 
sequence : str - fastq nucleotide sequence + FASTQ nucleotide sequence name2 : str - second fastq record name field (rarely used) + second FASTQ record name field (rarely used) quality : str base call quality for each nucleotide in sequence Methods ------- average_quality() - The average quality of the fastq record + The average quality of the FASTQ record """ @@ -175,11 +175,11 @@ def name(self) -> str: @name.setter def name(self, value): - """fastq record name""" + """FASTQ record name""" if not isinstance(value, (bytes, str)): - raise TypeError('fastq name must be str or bytes') + raise TypeError('FASTQ name must be str or bytes') if not value.startswith('@'): - raise ValueError('fastq name must start with @') + raise ValueError('FASTQ name must start with @') else: self._name = value @@ -190,7 +190,7 @@ def average_quality(self) -> float: class Reader(reader.Reader): - """Fastq Reader that defines some special methods for reading and summarizing fastq data. + """Fastq Reader that defines some special methods for reading and summarizing FASTQ data. Simple reader class that exposes an __iter__ and __len__ method @@ -225,12 +225,12 @@ def _record_grouper(iterable): return zip(*args) def __iter__(self) -> Iterator[Tuple[str]]: - """Iterate over a fastq file, returning records + """Iterate over a FASTQ file, returning records Yields ------ fastq_record : Tuple[str] - tuple of length 4 containing the name, sequence, name2, and quality for a fastq record + tuple of length 4 containing the name, sequence, name2, and quality for a FASTQ record """ record_type = StrRecord if self._mode == 'r' else Record @@ -244,7 +244,7 @@ def __iter__(self) -> Iterator[Tuple[str]]: def extract_barcode(record, embedded_barcode) -> Tuple[Tuple[str, str, str], Tuple[str, str, str]]: - """Extracts barcodes from a fastq record at positions defined by an EmbeddedBarcode object. + """Extracts barcodes from a FASTQ record at positions defined by an EmbeddedBarcode object. 
Parameters ---------- @@ -269,7 +269,7 @@ def extract_barcode(record, embedded_barcode) -> Tuple[Tuple[str, str, str], Tup # todo the reader subclasses need better docs class EmbeddedBarcodeGenerator(Reader): - """Generate barcodes from a fastq file(s) from positions defined by EmbeddedBarcode(s) + """Generate barcodes from a FASTQ file(s) from positions defined by EmbeddedBarcode(s) Extracted barcode objects are produced in a form that is consumable by pysam's bam and sam set_tag methods. @@ -280,9 +280,9 @@ class EmbeddedBarcodeGenerator(Reader): tag objects defining start and end of the sequence containing the tag, and the tag identifiers for sequence and quality tags fastq_files : str | List, optional - fastq file or files to be read. (default = sys.stdin) + FASTQ file or files to be read. (default = sys.stdin) mode : {'r', 'rb'}, optional - open mode for fastq files. If 'r', return string. If 'rb', return bytes (default = 'r') + open mode for FASTQ files. If 'r', return string. If 'rb', return bytes (default = 'r') """ @@ -291,7 +291,7 @@ def __init__(self, fastq_files, embedded_barcodes, *args, **kwargs): self.embedded_barcodes = embedded_barcodes def __iter__(self): - """iterates over barcodes extracted from fastq""" + """iterates over barcodes extracted from FASTQ""" for record in super().__iter__(): # iterates records; we extract barcodes. barcodes = [] for barcode in self.embedded_barcodes: @@ -301,7 +301,7 @@ def __iter__(self): # todo the reader subclasses need better docs class BarcodeGeneratorWithCorrectedCellBarcodes(Reader): - """Generate barcodes from fastq file(s) from positions defined by EmbeddedBarcode(s) + """Generate barcodes from FASTQ file(s) from positions defined by EmbeddedBarcode(s) Extracted barcode objects are produced in a form that is consumable by pysam's bam and sam set_tag methods. 
In this class, one EmbeddedBarcode must be defined as an @@ -311,7 +311,7 @@ class BarcodeGeneratorWithCorrectedCellBarcodes(Reader): Parameters ---------- fastq_files : str | List, optional - fastq file or files to be read. (default = sys.stdin) + FASTQ file or files to be read. (default = sys.stdin) mode : {'r', 'rb'}, optional open mode for fastq files. If 'r', return string. If 'rb', return bytes (default = 'r') whitelist : str @@ -381,6 +381,6 @@ def extract_cell_barcode(self, record: Tuple[str], cb: EmbeddedBarcode): seq_tag, qual_tag = extract_barcode(record, cb) try: corrected_cb = self._error_mapping.get_corrected_barcode(seq_tag[1]) - return seq_tag, qual_tag, ('CB', corrected_cb, 'Z') + return seq_tag, qual_tag, (consts.CELL_BARCODE_TAG_KEY, corrected_cb, 'Z') except KeyError: return seq_tag, qual_tag diff --git a/src/sctools/gtf.py b/src/sctools/gtf.py index 7da3023..9c5d71e 100644 --- a/src/sctools/gtf.py +++ b/src/sctools/gtf.py @@ -16,17 +16,20 @@ https://useast.ensembl.org/info/website/upload/gff.html """ +import logging import string -from typing import List, Dict, Generator, Iterable +from typing import List, Dict, Generator, Iterable, Union from . import reader +_logger = logging.getLogger(__name__) -class Record: + +class GTFRecord: """Data class for storing and interacting with GTF records Subclassed to produce exon, transcript, and gene-specific record types. - A gtf record has 8 fixed fields which are followed by optional fields separated by ;\t, which + A GTF record has 8 fixed fields which are followed by optional fields separated by ;\t, which are stored by this class in the attributes field and accessible by get_attribute. Fixed fields are accessible by name. 
@@ -84,11 +87,8 @@ def __init__(self, record: str): try: key, _, value = field.strip().partition(' ') self._attributes[key] = value.strip('"') - except: - print(field) - print(field.strip().split()) - print(len(field.strip().split())) - raise + except Exception: + raise RuntimeError(f'Error parsing field "{field}" of GTF record "{record}"') def __repr__(self): return '' % self.__str__() @@ -145,7 +145,7 @@ def frame(self) -> str: def size(self) -> int: size = self.end - self.start if size < 0: - raise ValueError('invalid record: negative size %d (start > end)' % size) + raise ValueError(f'Invalid record: negative size {size} (start > end)') else: return size @@ -207,9 +207,9 @@ class Reader(reader.Reader): Methods ------- filter(retain_types: Iterable[str]) - Iterate over a gtf file, only yielding records in `retain_types`. + Iterate over a GTF file, only yielding records in `retain_types`. __iter__() - iterate over gtf records in file, yielding `Record` objects + iterate over GTF records in file, yielding `GTFRecord` objects See Also -------- @@ -222,10 +222,10 @@ def __init__(self, files='-', mode='r', header_comment_char='#'): def __iter__(self): for line in super().__iter__(): - yield Record(line) + yield GTFRecord(line) def filter(self, retain_types: Iterable[str]) -> Generator: - """Iterate over a gtf file, returning only record whose feature type is in retain_types. + """Iterate over a GTF file, returning only records whose feature type is in retain_types. Features are stored in GTF field 2. @@ -244,3 +244,45 @@ def filter(self, retain_types: Iterable[str]) -> Generator: for record in self: if record.feature in retain_types: yield record + + +# todo this lenient behavior is expected to change in the future (warning -> exception) +def _resolve_multiple_gene_names(gene_name: str): + _logger.warning(f'Multiple entries encountered for "{gene_name}". Please validate the input GTF file(s).
' + f'Skipping the record for now; in the future, this will be considered as a ' + f'malformed GTF file.') + + +def extract_gene_names( + files: Union[str, List[str]]='-', mode: str='r', header_comment_char: str='#') -> Dict[str, int]: + """Extract gene names from GTF file(s) and returns a map from gene names to their corresponding + occurrence orders in the given file(s). + + Parameters + ---------- + files : Union[str, List], optional + File(s) to read. If '-', read sys.stdin (default = '-') + mode : {'r', 'rb'}, optional + Open mode. If 'r', read strings. If 'rb', read bytes (default = 'r'). + header_comment_char : str, optional + lines beginning with this character are skipped (default = '#') + + Returns + ------- + Dict[str, int] + A map from gene names to their linear index + """ + gene_name_to_index: Dict[str, int] = dict() + gene_index = 0 + for record in Reader(files, mode, header_comment_char).filter(retain_types=['gene']): + gene_name = record.get_attribute('gene_name') + if gene_name is None: + raise ValueError( + f'Malformed GTF file detected. 
Record is of type gene but does not have a ' + f'"gene_name" field: {record}') + if gene_name in gene_name_to_index: + _resolve_multiple_gene_names(gene_name) + continue + gene_name_to_index[gene_name] = gene_index + gene_index += 1 + return gene_name_to_index diff --git a/src/sctools/metrics/aggregator.py b/src/sctools/metrics/aggregator.py index ca8ec6c..1991261 100644 --- a/src/sctools/metrics/aggregator.py +++ b/src/sctools/metrics/aggregator.py @@ -33,12 +33,13 @@ """ -from typing import Iterable, Tuple, Counter, List, Sequence from collections import Counter +from typing import Iterable, Tuple, Counter, List, Sequence -import pysam import numpy as np +import pysam +from sctools import consts from sctools.stats import OnlineGaussianSufficientStatistic @@ -59,8 +60,7 @@ class is subclassed by ``GeneMetrics`` and ``CellMetrics``, which define data-sp Number of reads that are categorized by 10x genomics cellranger as "noise". Refers to long polymers, or reads with high numbers of N (ambiguous) nucleotides perfect_molecule_barcodes : int - The number of reads with molecule barcodes that have no errors (cell barcode tag ``CB`` == - raw barcode tag ``UB``) + The number of reads with molecule barcodes that have no errors (molecule barcode tag == raw molecule barcode tag) reads_mapped_exonic : int The number of reads for this entity that are mapped to exons reads_mapped_intronic : int @@ -254,9 +254,10 @@ def parse_molecule( self._molecule_barcode_fraction_bases_above_30.update( self._quality_above_threshold( - 30, self._quality_string_to_numeric(record.get_tag('UY')))) + 30, self._quality_string_to_numeric(record.get_tag(consts.QUALITY_MOLECULE_BARCODE_TAG_KEY)))) - self.perfect_molecule_barcodes += record.get_tag('UR') == record.get_tag('UB') + self.perfect_molecule_barcodes += ( + record.get_tag(consts.RAW_MOLECULE_BARCODE_TAG_KEY) == record.get_tag(consts.MOLECULE_BARCODE_TAG_KEY)) self._genomic_reads_fraction_bases_quality_above_30.update( self._quality_above_threshold(30,
record.query_alignment_qualities)) @@ -275,19 +276,19 @@ def parse_molecule( reference: int = record.reference_id self._fragment_histogram[reference, position, strand, tags] += 1 - alignment_location = record.get_tag('XF') - if alignment_location == 'CODING': + alignment_location = record.get_tag(consts.ALIGNMENT_LOCATION_TAG_KEY) + if alignment_location == consts.CODING_ALIGNMENT_LOCATION_TAG_VALUE: self.reads_mapped_exonic += 1 - elif alignment_location == 'INTRONIC': + elif alignment_location == consts.INTRONIC_ALIGNMENT_LOCATION_TAG_VALUE: self.reads_mapped_intronic += 1 - elif alignment_location == 'UTR': + elif alignment_location == consts.UTR_ALIGNMENT_LOCATION_TAG_VALUE: self.reads_mapped_utr += 1 # todo check if read maps outside window (needs gene model) # todo create distances from terminate side (needs gene model) # uniqueness - number_mappings = record.get_tag('NH') + number_mappings = record.get_tag(consts.NUMBER_OF_HITS_TAG_KEY) if number_mappings == 1: self.reads_mapped_uniquely += 1 else: @@ -452,13 +453,14 @@ def parse_extra_fields(self, tags: Sequence[str], record: pysam.AlignedSegment) """ self._cell_barcode_fraction_bases_above_30.update( self._quality_above_threshold( - 30, self._quality_string_to_numeric(record.get_tag('CY')))) + 30, self._quality_string_to_numeric(record.get_tag(consts.QUALITY_CELL_BARCODE_TAG_KEY)))) - self.perfect_cell_barcodes += record.get_tag('UR') == record.get_tag('UB') + self.perfect_cell_barcodes += ( + record.get_tag(consts.RAW_CELL_BARCODE_TAG_KEY) == record.get_tag(consts.CELL_BARCODE_TAG_KEY)) try: - alignment_location = record.get_tag('XF') - if alignment_location == 'INTERGENIC': + alignment_location = record.get_tag(consts.ALIGNMENT_LOCATION_TAG_KEY) + if alignment_location == consts.INTERGENIC_ALIGNMENT_LOCATION_TAG_VALUE: self.reads_mapped_intergenic += 1 except KeyError: self.reads_unmapped += 1 diff --git a/src/sctools/platform.py b/src/sctools/platform.py index 87d1a34..6b01e81 100644 --- 
a/src/sctools/platform.py +++ b/src/sctools/platform.py @@ -18,11 +18,9 @@ """ import argparse -from typing import Iterable, List +from typing import Iterable, List, Dict -import scipy.sparse as sp - -from sctools import fastq, bam, metrics, count +from sctools import fastq, bam, metrics, count, consts, gtf class GenericPlatform: @@ -246,9 +244,9 @@ def bam_to_count_matrix(cls, args: Iterable[str]=None) -> int: """ parser = argparse.ArgumentParser() parser.set_defaults( - cell_barcode_tag='CB', - molecule_barcode_tag='UB', - gene_id_tag='GE' + cell_barcode_tag=consts.CELL_BARCODE_TAG_KEY, + molecule_barcode_tag=consts.MOLECULE_BARCODE_TAG_KEY, + gene_name_tag=consts.GENE_NAME_TAG_KEY ) parser.add_argument('-b', '--bam-file', help='input_bam_file', required=True) parser.add_argument( @@ -258,13 +256,13 @@ def bam_to_count_matrix(cls, args: Iterable[str]=None) -> int: help='gtf annotation file that bam_file was aligned against') parser.add_argument( '-c', '--cell-barcode-tag', - help='tag that identifies the cell barcode (default = "CB")') + help=f'tag that identifies the cell barcode (default = {consts.CELL_BARCODE_TAG_KEY})') parser.add_argument( '-m', '--molecule-barcode-tag', - help='tag that identifies the molecule barcode (default = "UB")') + help=f'tag that identifies the molecule barcode (default = {consts.MOLECULE_BARCODE_TAG_KEY})') parser.add_argument( '-g', '--gene-id-tag', - help='tag that identifies the gene id (default = "GE")') + help=f'tag that identifies the gene name (default = {consts.GENE_NAME_TAG_KEY})') if args is not None: args = parser.parse_args(args) @@ -274,12 +272,15 @@ def bam_to_count_matrix(cls, args: Iterable[str]=None) -> int: # assume bam file unless the file explicitly has a sam suffix open_mode = 'r' if args.bam_file.endswith('.sam') else 'rb' - matrix = count.CountMatrix.from_bam( + # load gene names from the annotation file + gene_name_to_index: Dict[str, int] = gtf.extract_gene_names(args.gtf_annotation_file) + + matrix = 
count.CountMatrix.from_sorted_tagged_bam( bam_file=args.bam_file, - annotation_file=args.gtf_annotation_file, + gene_name_to_index=gene_name_to_index, cell_barcode_tag=args.cell_barcode_tag, molecule_barcode_tag=args.molecule_barcode_tag, - gene_id_tag=args.gene_id_tag, + gene_name_tag=args.gene_name_tag, open_mode=open_mode ) matrix.save(args.output_prefix) @@ -353,9 +354,21 @@ class TenXV2(GenericPlatform): # 10x contains three barcodes embedded within sequencing reads. The below objects define the # start and end points of those barcodes relative to the start of the sequence, and the # GA4GH standard tags that the extracted barcodes should be labeled with in the BAM file. - cell_barcode = fastq.EmbeddedBarcode(start=0, end=16, quality_tag='CY', sequence_tag='CR') - molecule_barcode = fastq.EmbeddedBarcode(start=16, end=26, quality_tag='UY', sequence_tag='UR') - sample_barcode = fastq.EmbeddedBarcode(start=0, end=8, quality_tag='SY', sequence_tag='SR') + cell_barcode = fastq.EmbeddedBarcode( + start=0, + end=16, + quality_tag=consts.QUALITY_CELL_BARCODE_TAG_KEY, + sequence_tag=consts.RAW_CELL_BARCODE_TAG_KEY) + molecule_barcode = fastq.EmbeddedBarcode( + start=16, + end=26, + quality_tag=consts.QUALITY_MOLECULE_BARCODE_TAG_KEY, + sequence_tag=consts.RAW_MOLECULE_BARCODE_TAG_KEY) + sample_barcode = fastq.EmbeddedBarcode( + start=0, + end=8, + quality_tag=consts.QUALITY_SAMPLE_BARCODE_TAG_KEY, + sequence_tag=consts.RAW_SAMPLE_BARCODE_TAG_KEY) @classmethod def _tag_bamfile( diff --git a/src/sctools/reader.py b/src/sctools/reader.py index 08998be..588119c 100644 --- a/src/sctools/reader.py +++ b/src/sctools/reader.py @@ -97,13 +97,13 @@ def __init__(self, files='-', mode='r', header_comment_char=None): if all(isinstance(f, str) for f in files): self._files = files else: - raise TypeError('all passed files must be type str') + raise TypeError('All passed files must be type str') else: - raise TypeError('files must be a string filename or a list of such names.') + 
raise TypeError('Files must be a string filename or a list of such names.') # set open mode: if mode not in {'r', 'rb'}: - raise ValueError('mode must be one of r, rb') + raise ValueError("Mode must be one of 'r', 'rb'") self._mode = mode if isinstance(header_comment_char, str) and mode == 'rb': diff --git a/src/sctools/test/data/cell-sorted2.bam b/src/sctools/test/data/cell-sorted2.bam deleted file mode 100644 index a0189fb..0000000 Binary files a/src/sctools/test/data/cell-sorted2.bam and /dev/null differ diff --git a/src/sctools/test/data/chr1.30k_records.gtf.gz b/src/sctools/test/data/chr1.30k_records.gtf.gz new file mode 100644 index 0000000..36e6f0f Binary files /dev/null and b/src/sctools/test/data/chr1.30k_records.gtf.gz differ diff --git a/src/sctools/test/data/chr1.gtf.gz b/src/sctools/test/data/chr1.gtf.gz deleted file mode 100644 index 0c7be5c..0000000 Binary files a/src/sctools/test/data/chr1.gtf.gz and /dev/null differ diff --git a/src/sctools/test/test_bam.py b/src/sctools/test/test_bam.py index 3fca08d..84b1502 100644 --- a/src/sctools/test/test_bam.py +++ b/src/sctools/test/test_bam.py @@ -1,11 +1,10 @@ -import os -import pytest import glob +import os import pysam +import pytest -from .. import bam, platform - +from .. import bam, platform, consts data_dir = os.path.split(__file__)[0] + '/data/' @@ -86,7 +85,7 @@ def bamfile(request): def test_split_bam_raises_value_error_when_passed_bam_without_barcodes(bamfile): split_size = 0.02 # our test data is very small, 0.01mb = ~10kb, which should yield 5 files. 
with pytest.raises(RuntimeError): - bam.split(bamfile, 'test_output', 'CB', approx_mb_per_split=split_size) + bam.split(bamfile, 'test_output', consts.CELL_BARCODE_TAG_KEY, approx_mb_per_split=split_size) @pytest.fixture @@ -103,7 +102,9 @@ def tagged_bam(): def test_split_on_tagged_bam(tagged_bam): split_size = 0.005 # our test data is very small, this value should yield 3 files - outputs = bam.split(tagged_bam, 'test_output', 'CB', 'CR', approx_mb_per_split=split_size) + outputs = bam.split( + tagged_bam, 'test_output', consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY, + approx_mb_per_split=split_size) assert len(outputs) == 3 # cleanup @@ -114,7 +115,9 @@ def test_split_on_tagged_bam(tagged_bam): def test_split_with_large_chunk_size_generates_one_file(tagged_bam): split_size = 1024 # our test data is very small, this value should yield 1 file - outputs = bam.split(tagged_bam, 'test_output', 'CB', 'CR', approx_mb_per_split=split_size) + outputs = bam.split( + tagged_bam, 'test_output', consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY, + approx_mb_per_split=split_size) assert len(outputs) == 1 # the file should be full size @@ -130,8 +133,8 @@ def test_split_with_large_chunk_size_generates_one_file(tagged_bam): def test_split_with_raise_missing_true_raises_warning_without_cr_barcode_passed(tagged_bam): split_size = 1024 # our test data is very small, this value should yield 1 file with pytest.raises(RuntimeError): - outputs = bam.split(tagged_bam, 'test_output', 'CB', approx_mb_per_split=split_size, - raise_missing=True) + outputs = bam.split( + tagged_bam, 'test_output', consts.CELL_BARCODE_TAG_KEY, approx_mb_per_split=split_size, raise_missing=True) # cleanup os.remove(tagged_bam) # clean up @@ -141,8 +144,8 @@ def test_split_with_raise_missing_true_raises_warning_without_cr_barcode_passed( def test_split_succeeds_with_raise_missing_false_and_no_cr_barcode_passed(tagged_bam): split_size = 1024 # our test data is very small, this value 
should yield 1 file - outputs = bam.split(tagged_bam, 'test_output', 'CB', approx_mb_per_split=split_size, - raise_missing=False) + outputs = bam.split( + tagged_bam, 'test_output', consts.CELL_BARCODE_TAG_KEY, approx_mb_per_split=split_size, raise_missing=False) assert len(outputs) == 1 diff --git a/src/sctools/test/test_barcode.py b/src/sctools/test/test_barcode.py index 246593b..6c04a87 100644 --- a/src/sctools/test/test_barcode.py +++ b/src/sctools/test/test_barcode.py @@ -1,9 +1,10 @@ import os + import numpy as np -import pytest import pysam -from .. import barcode, encodings, platform +import pytest +from .. import barcode, encodings, platform, consts data_dir = os.path.split(__file__)[0] + '/data/' @@ -136,7 +137,7 @@ def test_correct_bam_produces_cb_tags(tagged_bamfile, truncated_whitelist_from_1 with pysam.AlignmentFile(outbam, 'rb') as f: for record in f: try: - success = record.get_tag('CB') + success = record.get_tag(consts.CELL_BARCODE_TAG_KEY) except KeyError: continue assert success diff --git a/src/sctools/test/test_count.py b/src/sctools/test/test_count.py index bafa07f..05f5ccc 100644 --- a/src/sctools/test/test_count.py +++ b/src/sctools/test/test_count.py @@ -2,43 +2,629 @@ Testing for Count Matrix Construction ===================================== - bam_file: str, - annotation_file: str, - cell_barcode_tag: str='CB', - molecule_barcode_tag: str='UB', - gene_id_tag: str='GE', - open_mode: str='rb', +The test generates (1) a random count matrix, and (2) corresponding alignment records, and writes them to disk +(a BAM file, count matrix, row and column indices). The alignment records are expected to produce the same count +matrix according to the counting algorithm implemented in `sctools:count.CountMatrix.from_sorted_tagged_bam`. Gene names are +fetched from an annotations GTF file that is a subset of GENCODE annotations (see `_test_annotation_file` below).
+Notes +----- + +- The agreement between the synthetic count matrix and the synthetic BAM file is contingent on the + agreement between the counting algorithm implemented in `sctools:count.CountMatrix.from_sorted_tagged_bam` and + the test data generator (see SyntheticTaggedBAMGenerator below). Therefore, future changes in the + counting algorithm must be accompanied by a corresponding change in the test data generation class. + Otherwise, the tests will fail. + +- We have adopted a minimal test suite design strategy, in the sense that the synthetic test data is only complete + to the degree that is required by `sctools:count.CountMatrix.from_sorted_tagged_bam`. As such, the synthetic BAM file lacks + the following features: + + * flag, + * query_sequence, + * query_quality, + * CIGAR string, + * cell barcode quality tag, + * molecule barcode quality tag, + * raw cell and molecule barcodes, + + At the time of writing, the counting algorithm **only** relies on the BAM tags. + +- SyntheticTaggedBAMGenerator generates four types of alignment records: + + * necessary alignments -- these records contain one unique cell/molecule/gene tag for each cell/gene count + unit, according to the randomly generated count matrix. Necessary alignments are also sufficient + in the sense that they are expected to reproduce the count matrix in the absence of any other alignment + record. + + * redundant alignments -- these records are expected to be ignored by the counting algorithm and have three + subtypes: + + - duplicate alignments -- these are randomly picked from necessary alignments, though, they are given a new + query name (to mimic PCR and optical duplicates). + + - incomplete alignments -- these records miss at least one necessary tag, e.g. cell barcode, molecule + barcode, or gene name. + + - multi-gene alignments -- these records have the same tags and query_name, though, at least two such + records per query_name exist that point to different genes.
""" + +import operator import os import tempfile +from typing import Optional, List, Set, Tuple, Dict, Generator import numpy as np -import scipy.sparse as sp +import pysam +import pytest +from sctools import gtf, bam, consts from sctools.count import CountMatrix -# set the input and output directories, using a tempdir to automatically clean up generated files -_data_dir = os.path.split(__file__)[0] + '/data' -_test_dir = tempfile.mkdtemp() -os.makedirs(_test_dir, exist_ok=True) +# set the input and output directories +_test_data_dir = os.path.join(os.path.split(__file__)[0], 'data') +_test_annotation_file = os.path.join(_test_data_dir, 'chr1.30k_records.gtf.gz') + +# constants +_test_num_cells = 50 +_test_max_genes = 20 +_test_gene_expression_rate = 5.0 +_test_num_duplicates = 20 +_test_num_missing_some_tags = 20 +_test_num_multiple_gene_alignments = 20 +_test_max_gene_hits_per_multiple_gene_alignments = 5 + + +@pytest.fixture(scope='module') +def gene_name_to_index() -> Dict[str, int]: + return gtf.extract_gene_names(_test_annotation_file) + + +class AlignmentRecordTags: + """Represents the bundle of cell barcode, molecule barcode, and gene name.""" + def __init__(self, + cell_barcode: Optional[str], + molecule_barcode: Optional[str], + gene_name: Optional[str]) -> None: + self.cell_barcode = cell_barcode + self.molecule_barcode = molecule_barcode + self.gene_name = gene_name + + def __hash__(self): + return hash((self.cell_barcode, self.molecule_barcode, self.gene_name)) + + def __repr__(self): + return f'{consts.CELL_BARCODE_TAG_KEY}: {self.cell_barcode}, ' \ + f'{consts.MOLECULE_BARCODE_TAG_KEY}: {self.molecule_barcode}, ' \ + f'{consts.GENE_NAME_TAG_KEY}: {self.gene_name}' + + +class SyntheticTaggedBAMGenerator: + """This class generates a synthetic count matrix and an accompanying synthetic tagged BAM file as + described in the preamble documentation block. 
+ + Parameters + ---------- + num_cells : int + number of real cells + max_genes : int + maximum number of genes to use to generate synthetic counts + gene_name_to_index : dict + a map from gene names to their count matrix indices + gene_expression_rate : float + poisson rate at which each gene is expressed + rng_seed : int + random number generator seed + + Methods + ------- + generate_synthetic_bam_and_counts_matrix + generates synthetic test data and writes the output to disk + + See Also + -------- + count.CountMatrix.from_sorted_tagged_bam + """ + + OUTPUT_PREFIX = 'synthetic_' + SYNTHETIC_SEQUENCE_NAME = "SYNTHETIC_SEQUENCE" + SYNTHETIC_SEQUENCE_LENGTH = 100 + NECESSARY_QUERY_NAME_PREFIX = "NECESSARY_QUERY_" + DUPLICATE_QUERY_NAME_PREFIX = "DUPLICATE_QUERY_" + INCOMPLETE_QUERY_NAME_PREFIX = "INCOMPLETE_QUERY_" + MULTI_GENE_QUERY_NAME_PREFIX = "MULTI_GENE_QUERY_" + + bam_output_filename = OUTPUT_PREFIX + 'records.bam' + count_matrix_output_filename = OUTPUT_PREFIX + 'count_matrix.npy' + row_index_output_filename = OUTPUT_PREFIX + '_row_index.npy' + col_index_output_filename = OUTPUT_PREFIX + '_col_index.npy' + + def __init__( + self, num_cells: int, max_genes: int, gene_name_to_index: Dict[str, int], + gene_expression_rate: float, rng_seed: int = 777) -> None: + self.num_cells = num_cells + self.gene_expression_rate = gene_expression_rate + + # initialize the random number generator + self.rng: np.random.RandomState = np.random.RandomState(seed=rng_seed) + + # generate gene names + self.all_gene_names = [k for k, v in sorted(gene_name_to_index.items(), key=operator.itemgetter(1))] + self.num_genes = len(self.all_gene_names) + self.max_genes = max_genes + assert max_genes <= self.num_genes, \ + f"Max genes ({self.max_genes}) must be <= to all annotated genes ({self.num_genes})" + self.to_be_used_gene_indices: List[int] = self.rng.choice( + np.arange(0, self.num_genes, dtype=np.int), size=self.max_genes, replace=False).tolist() + self.to_be_used_gene_names =
[self.all_gene_names[j] for j in self.to_be_used_gene_indices] + + def generate_synthetic_bam_and_counts_matrix( + self, output_path: str, num_duplicates: int, num_missing_some_tags: int, + num_multiple_gene_alignments: int, max_gene_hits_per_multiple_gene_alignments: int, + alignment_sort_order: bam.AlignmentSortOrder = bam.CellMoleculeGeneQueryNameSortOrder()): + """Generates synthetic count matrix and BAM file and writes them to disk. + + Parameters + ---------- + output_path : str + output path + num_duplicates : int + number of duplicate records + num_missing_some_tags : int + number of records that miss at least one crucial tag + num_multiple_gene_alignments : int + number of records that have at least two different gene tags + max_gene_hits_per_multiple_gene_alignments : int + maximum number of unique gene names to use for multiple-gene records + alignment_sort_order : bam.AlignmentSortOrder + sort order of BAM alignment records; if 'None', random sort order is implied + + Returns + ------- + None + """ + assert 2 <= max_gene_hits_per_multiple_gene_alignments <= self.max_genes, \ + f"The parameter `max_gene_hits_per_multiple_gene_alignments` must >= 2 and < maximum annotated " \ + f"genes ({self.max_genes})" + assert num_duplicates >= 0, "Number of duplicate queries must be non-negative" + assert num_missing_some_tags >= 0, "Number of queries with missing tags must be non-negative" + assert num_multiple_gene_alignments >= 0, "Number of queries with multiple gene alignments must be non-negative" + + # generate synthetic count matrix and corresponding simulated records + synthetic_data_bundle = self._generate_synthetic_counts_and_alignment_tags( + num_duplicates, + num_missing_some_tags, + num_multiple_gene_alignments, + max_gene_hits_per_multiple_gene_alignments) + records = list(SyntheticTaggedBAMGenerator._get_bam_records_generator(synthetic_data_bundle)) + + if not alignment_sort_order: # random + # shuffle records + self.rng.shuffle(records) + + else: + 
records = sorted(records, key=alignment_sort_order.key_generator) + + # write BAM file + with pysam.AlignmentFile( + os.path.join(output_path, self.bam_output_filename), + mode='wb', + reference_names=[self.SYNTHETIC_SEQUENCE_NAME], + reference_lengths=[self.SYNTHETIC_SEQUENCE_LENGTH]) as bo: + for record in records: + bo.write(record) + + # write count matrix, row index, and col index + np.save(os.path.join(output_path, self.count_matrix_output_filename), synthetic_data_bundle.count_matrix) + np.save(os.path.join(output_path, self.row_index_output_filename), synthetic_data_bundle.row_index) + np.save(os.path.join(output_path, self.col_index_output_filename), synthetic_data_bundle.col_index) + + def _generate_synthetic_counts_and_alignment_tags( + self, num_duplicates: int, num_missing_some_tags: int, num_multiple_gene_alignments: int, + max_gene_hits_per_multiple_gene_alignments: int) -> 'SyntheticDataBundle': + + # generate count matrix + count_matrix: np.ndarray = self._generate_random_count_matrix() + + # generate necessary alignment tags that produce count_matrix + (necessary_alignment_record_tags_set, row_index, + col_index) = self._generate_necessary_alignment_record_bundle(count_matrix) + necessary_alignment_record_tags_list = list(necessary_alignment_record_tags_set) + + # sanity check -- we require as many necessary alignment records as the total counts + assert len(necessary_alignment_record_tags_set) == np.sum(count_matrix), \ + "There is an inconsistency between synthetic counts and necessary tags: we require as " \ + "many necessary alignment tags as the total counts" + + # add duplicate records + duplicate_alignment_tags_list = self._generate_duplicate_alignment_tags( + num_duplicates, necessary_alignment_record_tags_list) + + # add records with missing tags + incomplete_alignment_tags_list: List[AlignmentRecordTags] = self._generate_incomplete_alignment_tags( + num_missing_some_tags) + + # add records with multiple gene alignments + 
multiple_alignment_tags_list: List[List[AlignmentRecordTags]] = self._generate_multiple_gene_alignment_tags( + num_multiple_gene_alignments, + max_gene_hits_per_multiple_gene_alignments, + necessary_alignment_record_tags_set) + + return SyntheticDataBundle( + count_matrix, row_index, col_index, + necessary_alignment_record_tags_list, + duplicate_alignment_tags_list, + incomplete_alignment_tags_list, + multiple_alignment_tags_list) + + def _generate_random_count_matrix(self) -> np.ndarray: + """Generates a random count matrix. + + This method selects `self.max_genes` out of all genes (`self.num_genes`) and populates the selected genes + with Poisson counts with rate `self.gene_expression_rate`. The count matrix entries corresponding to the + rest of the genes are set to zero. + + Returns + ------- + np.ndarray + an ndarray of shape (`self.num_cells`, `self.num_genes`) + """ + non_zero_count_matrix = self.rng.poisson( + lam=self.gene_expression_rate, size=(self.num_cells, self.max_genes)) + count_matrix = np.zeros((self.num_cells, self.num_genes), dtype=np.int) + for i, i_gene in enumerate(self.to_be_used_gene_indices): + count_matrix[:, i_gene] = non_zero_count_matrix[:, i] + return count_matrix + + @staticmethod + def _get_bam_records_generator( + synthetic_data_bundle: 'SyntheticDataBundle', + rng_seed: int = 777) -> Generator[pysam.AlignedSegment, None, None]: + """Returns a generator of pysam.AlignedSegment instances created from the alignment tags + contained in `synthetic_data_bundle`. + + Parameters + ---------- + synthetic_data_bundle : SyntheticDataBundle + a bundle of synthetic alignment tags + rng_seed : int + random number generator seed; it is used for generating random reference_start positions. + + See Also + -------- + - The preamble documentation block for a description of the meaning of different alignment records + (necessary, duplicate, incomplete, etc.)
+ - SyntheticTaggedBAMGenerator._generate_aligned_segment_from_tags + """ + rng = np.random.RandomState(rng_seed) + + num_queries = synthetic_data_bundle.num_queries + i_query = 0 + + # necessary, duplicate, and incomplete alignments + for alignment_tags_list, query_name_prefix in zip( + [synthetic_data_bundle.necessary_alignment_record_tags_list, + synthetic_data_bundle.duplicate_alignment_tags_list, + synthetic_data_bundle.incomplete_alignment_tags_list], + [SyntheticTaggedBAMGenerator.NECESSARY_QUERY_NAME_PREFIX, + SyntheticTaggedBAMGenerator.DUPLICATE_QUERY_NAME_PREFIX, + SyntheticTaggedBAMGenerator.INCOMPLETE_QUERY_NAME_PREFIX]): + for alignment_tags in alignment_tags_list: + yield SyntheticTaggedBAMGenerator._generate_aligned_segment_from_tags( + alignment_tags, query_name_prefix, i_query, num_queries, rng) + i_query += 1 + + # multi-gene alignments + for alignment_tags_list in synthetic_data_bundle.multiple_alignment_tags_list: + # multiple alignments have the same query name (by definition) + for alignment_tags in alignment_tags_list: + yield SyntheticTaggedBAMGenerator._generate_aligned_segment_from_tags( + alignment_tags, SyntheticTaggedBAMGenerator.MULTI_GENE_QUERY_NAME_PREFIX, + i_query, num_queries, rng) + i_query += 1 + + @staticmethod + def _generate_aligned_segment_from_tags( + alignment_tags: AlignmentRecordTags, query_prefix: str, i_query: int, num_queries: int, + rng: np.random.RandomState) -> pysam.AlignedSegment: + """Generates pysam.AlignedSegment instances from alignment_tags. 
+ + Parameters + ---------- + alignment_tags : AlignmentRecordTags + tags to attach to the instantiated pysam.AlignedSegment + query_prefix : str + prefix to use for query name + i_query : int + query index + num_queries: int + maximum number of queries (only used for pretty-printing the query index) + rng: np.random.RandomState + a random number generator + + Notes + ----- + The query_sequence and query_quality are both empty as these query features are not used for generating + the counts matrix. Likewise, the flag is currently unset. In the future, once we add a filtering + policy based on BAM record flags (such as duplicates), this method must be updated accordingly. + + Returns + ------- + pysam.AlignedSegment + an instance of pysam.AlignedSegment + + """ + tags = [] + if alignment_tags.cell_barcode: + tags.append((consts.CELL_BARCODE_TAG_KEY, + alignment_tags.cell_barcode, 'Z')) + if alignment_tags.molecule_barcode: + tags.append((consts.MOLECULE_BARCODE_TAG_KEY, + alignment_tags.molecule_barcode, 'Z')) + if alignment_tags.gene_name: + tags.append((consts.GENE_NAME_TAG_KEY, + alignment_tags.gene_name, 'Z')) + record = pysam.AlignedSegment() + record.query_name = SyntheticTaggedBAMGenerator._generate_query_name(query_prefix, i_query, num_queries) + record.reference_start = rng.randint(low=0, high=SyntheticTaggedBAMGenerator.SYNTHETIC_SEQUENCE_LENGTH) + record.reference_id = 0 # note: we only use one synthetic sequence + if len(tags) > 0: + record.set_tags(tags) + return record + + @staticmethod + def _generate_query_name(query_prefix: str, i_query: int, num_queries: int) -> str: + """Returns query name string from query index. We zero-pad the string representation of query + indices merely for pretty-printing, e.g. 
0000, 0001, ..., 9999.""" + num_digits = len(str(num_queries - 1)) + return query_prefix + str(i_query).zfill(num_digits) + + def _generate_necessary_alignment_record_bundle( + self, count_matrix: np.ndarray) -> Tuple[Set[AlignmentRecordTags], List[str], List[str]]: + alignments: Set[AlignmentRecordTags] = set() + used_cell_barcodes: Set[str] = set() + + row_index: List[str] = [] + col_index = self.all_gene_names + + for i_cell in range(self.num_cells): + # generate a unique cell barcode + while True: + cell_barcode = self._generate_random_cell_barcode() + if cell_barcode not in used_cell_barcodes: + break + row_index.append(cell_barcode) + + for i_gene in self.to_be_used_gene_indices: + for i_molecule in range(count_matrix[i_cell, i_gene]): + # generate a unique alignment tag + unique_alignment_tag = self._generate_unique_random_alignment_tag( + alignments, + gene_name=self.all_gene_names[i_gene], + cell_barcode=cell_barcode) + alignments.add(unique_alignment_tag) + + return alignments, row_index, col_index + + def _generate_unique_random_alignment_tag( + self, + existing_alignment_tags: Set[AlignmentRecordTags], + gene_name: str, + cell_barcode: Optional[str] = None, + molecule_barcode: Optional[str] = None) -> AlignmentRecordTags: + assert gene_name in self.to_be_used_gene_names, \ + f"{gene_name} is not an allowed gene for generating synthetic data" + + while True: + alignment = AlignmentRecordTags( + cell_barcode=cell_barcode if cell_barcode else self._generate_random_cell_barcode(), + molecule_barcode=molecule_barcode if molecule_barcode else self._generate_random_molecule_barcode(), + gene_name=gene_name) + if alignment not in existing_alignment_tags: + return alignment + + def _generate_duplicate_alignment_tags( + self, num_duplicates: int, necessary_alignments_list: List[AlignmentRecordTags]) \ + -> List[AlignmentRecordTags]: + return self.rng.choice(necessary_alignments_list, size=num_duplicates).tolist() + + def _generate_incomplete_alignment_tags(self, 
def _generate_multiple_gene_alignment_tags(
        self, num_multiple_gene_alignments: int, max_gene_hits_per_multiple_gene_alignments: int,
        necessary_alignment_record_tags_set: Set[AlignmentRecordTags]) -> List[List[AlignmentRecordTags]]:
    """Generate groups of alignment tags that share a cell and molecule barcode but name different genes.

    Parameters
    ----------
    num_multiple_gene_alignments : int
        number of groups to generate
    max_gene_hits_per_multiple_gene_alignments : int
        inclusive upper bound on the number of gene names in a group; the group size is drawn with
        `self.rng.randint(low=2, high=max_gene_hits_per_multiple_gene_alignments + 1)`
    necessary_alignment_record_tags_set : Set[AlignmentRecordTags]
        existing alignment tags; each group borrows the cell barcode of a randomly chosen member and a
        molecule barcode taken from a tag generated to be absent from this set

    Returns
    -------
    List[List[AlignmentRecordTags]]
        one list of alignment tags per group
    """
    candidate_tags = list(necessary_alignment_record_tags_set)

    def _draw_group() -> List[AlignmentRecordTags]:
        # the order of the random draws below is kept exactly as: template, molecule, size, gene names
        template = self.rng.choice(candidate_tags)
        cell_barcode: str = template.cell_barcode
        fresh_tag = self._generate_unique_random_alignment_tag(
            necessary_alignment_record_tags_set,
            gene_name=template.gene_name,
            cell_barcode=cell_barcode)
        molecule_barcode: str = fresh_tag.molecule_barcode
        group_size = self.rng.randint(low=2, high=max_gene_hits_per_multiple_gene_alignments + 1)
        hit_gene_names = self.rng.choice(self.to_be_used_gene_names, replace=False, size=group_size)
        return [AlignmentRecordTags(cell_barcode, molecule_barcode, hit_gene_name)
                for hit_gene_name in hit_gene_names]

    return [_draw_group() for _ in range(num_multiple_gene_alignments)]
def _get_sorted_count_matrix(
        count_matrix: np.ndarray, row_index: np.ndarray, col_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reorder `count_matrix` so that its row and column labels are in ascending order.

    Parameters
    ----------
    count_matrix : np.ndarray
        a cell x gene count matrix
    row_index : np.ndarray
        labels of the rows of `count_matrix` (i.e. cell barcodes)
    col_index : np.ndarray
        labels of the columns of `count_matrix` (i.e. gene names)

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        the matrix with rows and columns permuted into label order, the sorted row index and the sorted
        column index
    """

    def _label_order(labels):
        # positions of `labels` in ascending label order; `sorted` is stable, so equal labels keep
        # their original relative order
        return sorted(range(len(labels)), key=lambda position: labels[position])

    row_order = _label_order(row_index)
    col_order = _label_order(col_index)

    # permute rows first, then columns
    row_sorted_matrix = count_matrix[row_order, :]
    fully_sorted_matrix = row_sorted_matrix[:, col_order]
    return fully_sorted_matrix, row_index[row_order], col_index[col_order]
_data_dir + '/cell-sorted2.bam' -_gtf_file = _data_dir + '/chr1.gtf.gz' + # load the test counts matrix + count_matrix_data_expected = np.load(test_count_matrix_path) + row_index_expected = np.load(test_row_index_path) + col_index_expected = np.load(test_col_index_path) + finally: + _test_temp_dir.cleanup() -# todo this test should be made faster -def test_count_matrix(): - count_matrix = CountMatrix.from_bam(_bam_file, _gtf_file) - assert count_matrix._matrix.shape == (734, 3132) + count_matrix_data_from_bam = count_matrix_from_bam.matrix.todense() + row_index_from_bam = count_matrix_from_bam.row_index + col_index_from_bam = count_matrix_from_bam.col_index - # check save + load - count_matrix.save(_test_dir + '/test_count_matrix') - count_matrix_2 = CountMatrix.load(_test_dir + '/test_count_matrix') + # sort expected and from_bam results by their respective row and column indices, since their sort order + # is not part of the design specs and is considered arbitrary + (sorted_count_matrix_data_from_bam, + sorted_row_index_from_bam, + sorted_col_index_from_bam) = _get_sorted_count_matrix( + count_matrix_data_from_bam, row_index_from_bam, col_index_from_bam) + (sorted_count_matrix_data_expected, + sorted_row_index_expected, + sorted_col_index_expected) = _get_sorted_count_matrix( + count_matrix_data_expected, row_index_expected, col_index_expected) - # check matrices are the same - csr1: sp.csr_matrix = count_matrix._matrix - csr2: sp.csr_matrix = count_matrix_2._matrix - assert np.allclose(csr1.indices, csr2.indices) - assert np.allclose(csr1.indptr, csr2.indptr) - assert np.allclose(csr1.data, csr2.data) + # assert equality of sorted count matrices and sorted row/col indices + assert np.allclose(sorted_count_matrix_data_from_bam, sorted_count_matrix_data_expected) + assert all([row_name_from_bam == row_name_expected + for row_name_from_bam, row_name_expected in zip(sorted_row_index_from_bam, sorted_row_index_expected)]) + assert all([col_name_from_bam == 
col_name_expected + for col_name_from_bam, col_name_expected in zip(sorted_col_index_from_bam, sorted_col_index_expected)]) diff --git a/src/sctools/test/test_entrypoints.py b/src/sctools/test/test_entrypoints.py index b9107e8..e4b82a4 100644 --- a/src/sctools/test/test_entrypoints.py +++ b/src/sctools/test/test_entrypoints.py @@ -1,11 +1,12 @@ -import os import glob +import os import tempfile -import scipy.sparse as sp -import numpy as np +import numpy as np import pysam -from sctools import platform, count +import scipy.sparse as sp + +from sctools import platform, count, consts data_dir = os.path.split(__file__)[0] + '/data/' @@ -22,12 +23,12 @@ def test_Attach10XBarcodes_entrypoint(): with pysam.AlignmentFile('test_tagged_bam.bam', 'rb', check_sq=False) as f: for alignment in f: # each alignment should now have a tag, and that tag should be a string - assert isinstance(alignment.get_tag('CY'), str) - assert isinstance(alignment.get_tag('CR'), str) - assert isinstance(alignment.get_tag('UY'), str) - assert isinstance(alignment.get_tag('UR'), str) - assert isinstance(alignment.get_tag('SY'), str) - assert isinstance(alignment.get_tag('SR'), str) + assert isinstance(alignment.get_tag(consts.QUALITY_CELL_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.RAW_CELL_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.QUALITY_MOLECULE_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.RAW_MOLECULE_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.RAW_SAMPLE_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.QUALITY_SAMPLE_BARCODE_TAG_KEY), str) os.remove('test_tagged_bam.bam') # clean up @@ -44,15 +45,15 @@ def test_Attach10XBarcodes_entrypoint_with_whitelist(): success = False with pysam.AlignmentFile('test_tagged_bam.bam', 'rb', check_sq=False) as f: for alignment in f: - if alignment.has_tag('CB'): + if alignment.has_tag(consts.CELL_BARCODE_TAG_KEY): success = True # each alignment 
should now have a tag, and that tag should be a string - assert isinstance(alignment.get_tag('CY'), str) - assert isinstance(alignment.get_tag('CR'), str) - assert isinstance(alignment.get_tag('UY'), str) - assert isinstance(alignment.get_tag('UR'), str) - assert isinstance(alignment.get_tag('SY'), str) - assert isinstance(alignment.get_tag('SR'), str) + assert isinstance(alignment.get_tag(consts.RAW_CELL_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.QUALITY_CELL_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.RAW_MOLECULE_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.QUALITY_MOLECULE_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.RAW_SAMPLE_BARCODE_TAG_KEY), str) + assert isinstance(alignment.get_tag(consts.QUALITY_SAMPLE_BARCODE_TAG_KEY), str) assert success os.remove('test_tagged_bam.bam') # clean up @@ -71,7 +72,7 @@ def test_split_bam(): '--bamfile', 'test_tagged_bam.bam', '--output-prefix', 'test_tagged', '--subfile-size', '0.005', - '--tags', 'CB', 'CR'] + '--tags', consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY] return_call = platform.GenericPlatform.split_bam(split_args) assert return_call == 0 diff --git a/src/sctools/test/test_fastq.py b/src/sctools/test/test_fastq.py index 73ae737..c1bb8ac 100644 --- a/src/sctools/test/test_fastq.py +++ b/src/sctools/test/test_fastq.py @@ -1,11 +1,12 @@ -from itertools import product -from functools import partial -import string import os +import string +from functools import partial +from itertools import product + import pytest -from .. import fastq -from ..reader import zip_readers +from .. 
@pytest.fixture(scope='function')
def embedded_barcode_generator():
    """Return an `EmbeddedBarcodeGenerator` over `test_r1.fastq.gz` with a cell and a molecule barcode.

    The cell barcode spans positions 0-16 and the molecule barcode spans positions 16-26; each is tagged
    with the quality/raw tag keys taken from `consts`.
    """
    # (start, end, quality tag key, sequence tag key) for each embedded barcode, cell barcode first
    barcode_layout = (
        (0, 16, consts.QUALITY_CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY),
        (16, 26, consts.QUALITY_MOLECULE_BARCODE_TAG_KEY, consts.RAW_MOLECULE_BARCODE_TAG_KEY),
    )
    embedded_barcodes = [
        fastq.EmbeddedBarcode(start=start, end=end, quality_tag=quality_tag, sequence_tag=sequence_tag)
        for start, end, quality_tag, sequence_tag in barcode_layout
    ]
    return fastq.EmbeddedBarcodeGenerator(data_dir + 'test_r1.fastq.gz', embedded_barcodes)
@@ -218,21 +234,21 @@ def test_embedded_barcode_generator_produces_outputs_of_expected_size(embedded_b # note that all barcodes are strings and therefore should get 'Z' values # test cell tags - assert cell_seq[0] == 'CR' + assert cell_seq[0] == consts.RAW_CELL_BARCODE_TAG_KEY assert len(cell_seq[1]) == correct_cell_barcode_length assert all(v in 'ACGTN' for v in cell_seq[1]) assert cell_seq[2] == 'Z' - assert cell_qual[0] == 'CY' + assert cell_qual[0] == consts.QUALITY_CELL_BARCODE_TAG_KEY assert len(cell_qual[1]) == correct_cell_barcode_length assert all(v in string.printable for v in cell_qual[1]) assert cell_seq[2] == 'Z' # test umi tags - assert umi_seq[0] == 'UR' + assert umi_seq[0] == consts.RAW_MOLECULE_BARCODE_TAG_KEY assert len(umi_seq[1]) == correct_umi_length assert all(v in 'ACGTN' for v in umi_seq[1]) assert umi_seq[2] == 'Z' - assert umi_qual[0] == 'UY' + assert umi_qual[0] == consts.QUALITY_MOLECULE_BARCODE_TAG_KEY assert len(umi_qual[1]) == correct_umi_length assert all(v in string.printable for v in umi_qual[1]) assert umi_seq[2] == 'Z' @@ -244,7 +260,7 @@ def test_corrects_barcodes(barcode_generator_with_corrected_cell_barcodes): success = False for barcode_sets in barcode_generator_with_corrected_cell_barcodes: for barcode_set in barcode_sets: - if barcode_set[0] == 'CB': + if barcode_set[0] == consts.CELL_BARCODE_TAG_KEY: success = True break assert success diff --git a/src/sctools/test/test_gtf.py b/src/sctools/test/test_gtf.py index 0dccac6..225bdd4 100644 --- a/src/sctools/test/test_gtf.py +++ b/src/sctools/test/test_gtf.py @@ -16,7 +16,7 @@ def files(request): def test_opens_file_reads_first_line(files): rd = gtf.Reader(files, 'r', header_comment_char='#') record = next(iter(rd)) - assert isinstance(record, gtf.Record) + assert isinstance(record, gtf.GTFRecord) def test_opens_file_populates_fields_properly(files): diff --git a/src/sctools/test/test_metrics.py b/src/sctools/test/test_metrics.py index 7c128c7..f9d9469 100644 --- 
a/src/sctools/test/test_metrics.py +++ b/src/sctools/test/test_metrics.py @@ -1,13 +1,12 @@ +import fileinput +import math import os import tempfile -import math -import fileinput from typing import Callable -import pytest -import pandas as pd import numpy as np - +import pandas as pd +import pytest from sctools.metrics.gatherer import GatherGeneMetrics, GatherCellMetrics, MetricGatherer from sctools.metrics.merge import MergeCellMetrics, MergeGeneMetrics from sctools.platform import TenXV2 @@ -378,8 +377,6 @@ def test_single_read_evidence(metrics, key, expected_value): assert observed == expected_value - - def split_metrics_file(metrics_file): """ produces two mergeable on-disk metric files from a single file that contain the first 3/4