Skip to content

Commit

Permalink
PR review requested changes
Browse files Browse the repository at this point in the history
- refactoring, rearrangement, and cleanup of test_count.py
- documentation in test_count.py
- AlignmentSortOrder class in bam.py
- refactoring further instances of constants to consts.py
  • Loading branch information
mbabadi committed Jul 9, 2018
1 parent a2117d5 commit 394fb21
Show file tree
Hide file tree
Showing 3 changed files with 488 additions and 384 deletions.
122 changes: 93 additions & 29 deletions src/sctools/bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,36 @@
Methods
-------
iter_tag_groups function to iterate over reads by an arbitrary tag
iter_cell_barcodes wrapper for iter_tag_groups that iterates over cell barcode tags
iter_genes wrapper for iter_tag_groups that iterates over gene tags
iter_molecules wrapper for iter_tag_groups that iterates over molecule tags
iter_tag_groups function to iterate over reads by an arbitrary tag
iter_cell_barcodes wrapper for iter_tag_groups that iterates over cell barcode tags
iter_genes wrapper for iter_tag_groups that iterates over gene tags
iter_molecules wrapper for iter_tag_groups that iterates over molecule tags
Classes
-------
SubsetAlignments class to extract reads specific to requested chromosome(s)
Tagger class to add tags to sam/bam records from paired fastq records
SubsetAlignments class to extract reads specific to requested chromosome(s)
Tagger class to add tags to sam/bam records from paired fastq records
AlignmentSortOrder abstract class to represent alignment sort orders
QueryNameSortOrder alignment sort order by query name
CellMoleculeGeneQueryNameSortOrder alignment sort order hierarchically cell > molecule > gene > query name
References
----------
htslib : https://github.com/samtools/htslib
"""

import warnings
import os
import math
import os
import warnings
from abc import ABC, abstractmethod
from itertools import cycle
from typing import Iterator, Generator, List, Union, Tuple
from typing import Iterator, Generator, List, Union, Tuple, Callable, Any

import pysam

from . import consts


class SubsetAlignments:
"""Wrapper for pysam/htslib that extracts reads corresponding to requested chromosome(s)
Expand Down Expand Up @@ -74,9 +80,8 @@ def __init__(self, alignment_file: str, open_mode: str=None):
open_mode = 'r'
else:
raise ValueError(
'could not autodetect file type for alignment_file %s (detectible suffixes: '
'.sam, .bam)' % alignment_file
)
f'Could not autodetect file type for alignment_file {alignment_file} (detectable suffixes: '
f'.sam, .bam)')
self._file: str = alignment_file
self._open_mode: str = open_mode

Expand Down Expand Up @@ -164,7 +169,7 @@ class Tagger:

def __init__(self, bam_file: str) -> None:
if not isinstance(bam_file, str):
raise TypeError('bam_file must be type str, not %s' % type(bam_file))
raise TypeError(f'The argument "bam_file" must be of type str, not {type(bam_file)}')
self.bam_file = bam_file

# todo add type to tag_generators (make sure it doesn't introduce import issues
Expand Down Expand Up @@ -231,27 +236,27 @@ def split(in_bam, out_prefix, *tags, approx_mb_per_split=1000, raise_missing=Tru
if len(tags) == 0:
raise ValueError('At least one tag must be passed')

def _cleanup(files_to_counts, files_to_names, rm_all=False) -> None:
def _cleanup(_files_to_counts, _files_to_names, rm_all=False) -> None:
"""Closes file handles and remove any empty files.
Parameters
----------
files_to_counts : dict
_files_to_counts : dict
Dictionary of file objects to the number of reads each contains.
files_to_names : dict
_files_to_names : dict
Dictionary of file objects to file handles.
rm_all : bool, optional
If True, indicates all files should be removed, regardless of count number, else only
empty files without counts are removed (default = False)
"""
for bamfile, count in files_to_counts.items():
for bamfile, count in _files_to_counts.items():
# corner case: clean up files that were created, but didn't get data because
# n_cell < n_barcode
bamfile.close()
if count == 0 or rm_all:
os.remove(files_to_names[bamfile])
del files_to_names[bamfile]
os.remove(_files_to_names[bamfile])
del _files_to_names[bamfile]

# find correct number of subfiles to spawn
bam_mb = os.path.getsize(in_bam) * 1e-6
Expand All @@ -260,8 +265,8 @@ def _cleanup(files_to_counts, files_to_names, rm_all=False) -> None:
warnings.warn('Number of requested subfiles (%d) exceeds 500; this may cause OS errors by '
'exceeding fid limits' % n_subfiles)
if n_subfiles > 1000:
raise ValueError('Number of requested subfiles (%d) exceeds 1000; this will usually cause '
'OS errors, think about increasing max_mb_per_split.' % n_subfiles)
raise ValueError(f'Number of requested subfiles ({ n_subfiles}) exceeds 1000; this will usually cause '
f'OS errors, think about increasing max_mb_per_split.')

# create all the output files
with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:
Expand Down Expand Up @@ -297,8 +302,7 @@ def _cleanup(files_to_counts, files_to_names, rm_all=False) -> None:
if tag_content is None:
if raise_missing:
_cleanup(files_to_counts, files_to_names, rm_all=True)
raise RuntimeError(
'alignment encountered that is missing %s tag(s).' % repr(tags))
raise RuntimeError('Alignment encountered that is missing {repr(tags)} tag(s).')
else:
continue # move on to next alignment

Expand Down Expand Up @@ -379,12 +383,12 @@ def iter_molecule_barcodes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Gene
Yields
------
grouped_by_tag : Iterator[pysam.AlignedSegment]
reads sharing a unique molecule barcode (``UB`` tag)
reads sharing a unique molecule barcode tag
current_tag : str
the molecule barcode that records in the group all share
"""
return iter_tag_groups(tag='UB', bam_iterator=bam_iterator)
return iter_tag_groups(tag=consts.MOLECULE_BARCODE_TAG_KEY, bam_iterator=bam_iterator)


def iter_cell_barcodes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generator:
Expand All @@ -398,12 +402,12 @@ def iter_cell_barcodes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generato
Yields
------
grouped_by_tag : Iterator[pysam.AlignedSegment]
reads sharing a unique cell barcode (``CB`` tag)
reads sharing a unique cell barcode tag
current_tag : str
the cell barcode that reads in the group all share
"""
return iter_tag_groups(tag='CB', bam_iterator=bam_iterator)
return iter_tag_groups(tag=consts.CELL_BARCODE_TAG_KEY, bam_iterator=bam_iterator)


def iter_genes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generator:
Expand All @@ -417,9 +421,69 @@ def iter_genes(bam_iterator: Iterator[pysam.AlignedSegment]) -> Generator:
Yields
------
grouped_by_tag : Iterator[pysam.AlignedSegment]
reads sharing a unique gene id (``GE`` tag)
reads sharing a unique gene name tag
current_tag : str
the gene id that reads in the group all share
"""
return iter_tag_groups(tag='GE', bam_iterator=bam_iterator)
return iter_tag_groups(tag=consts.GENE_NAME_TAG_KEY, bam_iterator=bam_iterator)


def get_tag_or_default(alignment: pysam.AlignedSegment, tag_key: str, default: str = 'N') -> str:
    """Looks up the tag `tag_key` on `alignment`, falling back to `default` when it is absent.

    Parameters
    ----------
    alignment : pysam.AlignedSegment
        alignment record to read the tag from
    tag_key : str
        key of the tag to extract
    default : str, optional
        value returned when the tag lookup raises KeyError (default = 'N')

    Returns
    -------
    tag_value : str
        the value returned by `alignment.get_tag(tag_key)`, or `default`

    """
    try:
        tag_value = alignment.get_tag(tag_key)
    except KeyError:
        # a KeyError from the lookup is treated as "tag not present"
        tag_value = default
    return tag_value


class AlignmentSortOrder(ABC):
    """The base class of alignment sort orders.

    Concrete sort orders subclass this class and implement the ``key_generator`` property.

    The class derives from ``ABC`` because ``@abstractmethod`` only takes effect when the
    metaclass is ``ABCMeta``; with a plain class the decorator is ignored and the base class can
    be instantiated, failing only later when ``key_generator`` is read.
    """

    @property
    @abstractmethod
    def key_generator(self) -> Callable[[pysam.AlignedSegment], Any]:
        """Returns a callable function that calculates a sort key from given pysam.AlignedSegment.

        Raises
        ------
        NotImplementedError
            if a subclass delegates to this base implementation

        """
        raise NotImplementedError


class QueryNameSortOrder(AlignmentSortOrder):
    """Sort order that arranges alignment records by their query name."""

    def __repr__(self) -> str:
        # short identifier of this sort order
        return 'query_name'

    @property
    def key_generator(self):
        """Returns the function that maps an alignment to its sort key (see `get_sort_key`)."""
        return QueryNameSortOrder.get_sort_key

    @staticmethod
    def get_sort_key(alignment: pysam.AlignedSegment) -> str:
        """Returns the query name of `alignment`, which is used as the sort key."""
        return alignment.query_name


class CellMoleculeGeneQueryNameSortOrder(AlignmentSortOrder):
    """Hierarchical alignment record sort order.

    The sort key of a record is the 4-tuple (cell barcode, molecule barcode, gene name,
    query name), i.e. the hierarchy is cell barcode > molecule barcode > gene name > query name.
    """

    def __init__(
            self,
            cell_barcode_tag_key: str = consts.CELL_BARCODE_TAG_KEY,
            molecule_barcode_tag_key: str = consts.MOLECULE_BARCODE_TAG_KEY,
            gene_name_tag_key: str = consts.GENE_NAME_TAG_KEY) -> None:
        """Stores the tag keys used to build the sort key.

        Parameters
        ----------
        cell_barcode_tag_key : str, optional
            tag key holding the cell barcode (default = consts.CELL_BARCODE_TAG_KEY)
        molecule_barcode_tag_key : str, optional
            tag key holding the molecule barcode (default = consts.MOLECULE_BARCODE_TAG_KEY)
        gene_name_tag_key : str, optional
            tag key holding the gene name (default = consts.GENE_NAME_TAG_KEY)

        """
        # every key is checked for truthiness before any of them is stored
        assert cell_barcode_tag_key, "Cell barcode tag key can not be None"
        assert molecule_barcode_tag_key, "Molecule barcode tag key can not be None"
        assert gene_name_tag_key, "Gene name tag key can not be None"

        self.cell_barcode_tag_key = cell_barcode_tag_key
        self.molecule_barcode_tag_key = molecule_barcode_tag_key
        self.gene_name_tag_key = gene_name_tag_key

    def __repr__(self) -> str:
        # short identifier of this sort order
        return 'hierarchical__cell_molecule_gene_query_name'

    @property
    def key_generator(self) -> Callable[[pysam.AlignedSegment], Tuple[str, str, str, str]]:
        """Returns the bound method that computes the hierarchical sort key of an alignment."""
        return self._get_sort_key

    def _get_sort_key(self, alignment: pysam.AlignedSegment) -> Tuple[str, str, str, str]:
        """Builds the (cell barcode, molecule barcode, gene name, query name) key of `alignment`.

        Tags are read through `get_tag_or_default` without an explicit default, so a missing tag
        contributes that helper's default value to the key.
        """
        tag_keys = (
            self.cell_barcode_tag_key,
            self.molecule_barcode_tag_key,
            self.gene_name_tag_key,
        )
        cell_barcode, molecule_barcode, gene_name = [
            get_tag_or_default(alignment, tag_key) for tag_key in tag_keys]
        return cell_barcode, molecule_barcode, gene_name, alignment.query_name
40 changes: 16 additions & 24 deletions src/sctools/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
Methods
-------
from_sorted_tagged_bam(bam_file: str, annotation_file: str, cell_barcode_tag: str='CB',
molecule_barcode_tag: str='UB', gene_name_tag: str='GE', open_mode: str='rb')
from_sorted_tagged_bam(
bam_file: str, annotation_file: str, cell_barcode_tag: str = consts.CELL_BARCODE_TAG_KEY,
molecule_barcode_tag: str=consts.MOLECULE_BARCODE_TAG_KEY,
gene_name_tag: str=consts.GENE_NAME_TAG_KEY, open_mode: str='rb')
from_mtx(matrix_mtx: str, row_index_file: str, col_index_file: str)
Notes
Expand All @@ -26,7 +28,8 @@
import pysam
import scipy.sparse as sp
from scipy.io import mmread
from sctools import gtf

from sctools import gtf, consts


class CountMatrix:
Expand All @@ -49,10 +52,8 @@ def col_index(self):
return self._col_index

@staticmethod
def _get_alignments_grouped_by_query_name_generator(bam_file: str,
cell_barcode_tag: str,
molecule_barcode_tag: str,
open_mode: str = 'rb') -> \
def _get_alignments_grouped_by_query_name_generator(
bam_file: str, cell_barcode_tag: str, molecule_barcode_tag: str, open_mode: str = 'rb') -> \
Generator[Tuple[str, Optional[str], Optional[str], List[pysam.AlignedSegment]], None, None]:
"""Iterates through a query_name-sorted BAM file, groups all alignments with the same query name
Expand All @@ -71,17 +72,8 @@ def _get_alignments_grouped_by_query_name_generator(bam_file: str,
a generator for tuples (query_name, cell_barcode, molecule_barcode, alignments)
"""
with pysam.AlignmentFile(bam_file, mode=open_mode) as bam_records:
for alignment in itertools.groupby(bam_records, key=lambda record: record.query_name):
query_name: str = alignment[0]
grouper = alignment[1]
alignments: List[pysam.AlignedSegment] = []
try:
while True:
alignment = grouper.__next__()
alignments.append(alignment)
except StopIteration:
pass

for (query_name, grouper) in itertools.groupby(bam_records, key=lambda record: record.query_name):
alignments: List[pysam.AlignedSegment] = list(grouper)
cell_barcode: Optional[str] = None
try:
cell_barcode = alignments[0].get_tag(cell_barcode_tag)
Expand All @@ -106,9 +98,9 @@ def from_sorted_tagged_bam(
cls,
bam_file: str,
annotation_file: str,
cell_barcode_tag: str='CB',
molecule_barcode_tag: str='UB',
gene_name_tag: str='GE',
cell_barcode_tag: str=consts.CELL_BARCODE_TAG_KEY,
molecule_barcode_tag: str=consts.MOLECULE_BARCODE_TAG_KEY,
gene_name_tag: str=consts.GENE_NAME_TAG_KEY,
open_mode: str='rb') -> 'CountMatrix':
"""Generate a count matrix from a sorted, tagged bam file
Expand Down Expand Up @@ -143,13 +135,13 @@ def from_sorted_tagged_bam(
order
cell_barcode_tag : str, optional
Tag that specifies the cell barcode for each read. Reads without this tag will be ignored
(default = 'CB')
(default = consts.CELL_BARCODE_TAG_KEY)
molecule_barcode_tag : str, optional
Tag that specifies the molecule barcode for each read. Reads without this tag will be
ignored (default = 'UB')
ignored (default = consts.MOLECULE_BARCODE_TAG_KEY)
gene_name_tag
Tag that specifies the gene name for each read. Reads without this tag will be ignored
(default = 'GE')
(default = consts.GENE_NAME_TAG_KEY)
annotation_file : str
gtf annotation file that was used to create gene ID tags. Used to map genes to indices
open_mode : {'r', 'rb'}, optional
Expand Down
Loading

0 comments on commit 394fb21

Please sign in to comment.