diff --git a/README.md b/README.md index aed1b76..1fab93d 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,33 @@ -pita improves transcript annotation -=================================== - -Pipeline to improve transcript annotation based on RNA-seq and ChIP-seq data. - -The current version has been used to annotate the Xenopus laevis genome based on experimental data. - -However, it is not yet easy to install and use as the documentation is incomplete. -In addition the tools have not been thoroughly tested on a clean installation, -which means I'm not sure all dependencies have been correctly specified. - -Prerequisites ------------- -The following Python modules are required: - -* GFF parser - http://github.com/chapmanb/bcbb/tree/master/gff -* Biopython - http://biopython.org/ -* pysam (>= 0.7.4) -* pyyaml -* networkx (>= 1.9) -* GimmeMotifs - http://github.com/simonvh/gimmemotifs -* HTSeq - http://www-huber.embl.de/users/anders/HTSeq/doc/overview.html -* numpy - -Installation ------------- - - # install prerequisites - git clone git@bitbucket.org:simonvh/pita.git - cd pita - python setup.py test - sudo python setup.py install \ No newline at end of file +pita improves transcript annotation +=================================== + +Pipeline to improve transcript annotation based on RNA-seq and ChIP-seq data. + +The current version has been used to annotate the Xenopus laevis genome based on experimental data. + +However, it is not yet easy to install and use, as the documentation is incomplete. +In addition, the tools have not been thoroughly tested on a clean installation, +which means that some dependencies may not be specified correctly. + +Prerequisites +------------ +The following Python modules are required: + +* GFF parser - http://github.com/chapmanb/bcbb/tree/master/gff +* Biopython - http://biopython.org/ +* pysam (>= 0.7.4) +* pyyaml +* networkx (>= 1.9) +* GimmeMotifs - http://github.com/simonvh/gimmemotifs +* HTSeq - http://www-huber.embl.de/users/anders/HTSeq/doc/overview.html +* numpy + +Installation +------------ + + # install prerequisites + git clone git@bitbucket.org:simonvh/pita.git + cd pita + python setup.py test + sudo python setup.py install + diff --git a/config/example.yaml b/config/example.yaml index a833b26..0be4714 100644 --- a/config/example.yaml +++ b/config/example.yaml @@ -1,5 +1,17 @@ # Path to data files. 
Data files can be specified relative to this path data_path: /home/simon/prj/laevis/annotation/XENLA_JGI7b/pita/data +genome: /usr/share/genome/XENLA_JGIv7b + +# Database connection details +# +# For file-based sqlite connection: +# database: sqlite:///database.db +# +# For mysql: mysql://user:password@server/database +# For example: +# database: mysql://pita:@localhost/pita + +database: sqlite:///pita_database.db # Comment to process all chromosomes chromosomes: diff --git a/pita/__init__.py b/pita/__init__.py index e69de29..bef6b24 100644 --- a/pita/__init__.py +++ b/pita/__init__.py @@ -0,0 +1,20 @@ +import sys +import atexit +from pita.db_backend import * + +def db_session(conn, new=False): + if not hasattr(db_session, 'session') or not db_session.session: + engine = create_engine(conn) + engine.raw_connection().connection.text_factory = str + db_session.engine = engine + if new: + Base.metadata.drop_all(db_session.engine) + Base.metadata.create_all(engine) + Base.metadata.bind = engine + db_session.session = scoped_session(sessionmaker(bind=engine)) + elif new: + db_session.session.commit() + Base.metadata.drop_all(db_session.engine) + Base.metadata.create_all(db_session.engine) + + return db_session.session diff --git a/pita/annotationdb.py b/pita/annotationdb.py new file mode 100644 index 0000000..0178e61 --- /dev/null +++ b/pita/annotationdb.py @@ -0,0 +1,623 @@ +import os +import sys +import logging +from gimmemotifs.genome_index import GenomeIndex +from sqlalchemy import and_,or_,func +from sqlalchemy.orm import joinedload,aliased +from pita import db_session +from pita.db_backend import * +from pita.util import read_statistics, get_splice_score +import yaml +from pita.io import exons_to_tabix_bed, tabix_overlap + +class AnnotationDb(): + def __init__(self, session=None, conn='mysql://pita:@localhost/pita', new=False, index=None): + self.logger = logging.getLogger("pita") + if session: + self.session = session + else: + #if conn.startswith("sqlite"): + # self.Session = db_session(conn, new) + # self.session = self.Session() + # self.engine = db_session.engine + #else: + self._init_session(conn, new) + + if index: + self.index = GenomeIndex(index) + else: + self.index = None + + self.cache_splice_stats = {} + self.cache_feature_stats = {} + #def __destroy__(self): + # self.session.close() + + def _init_session(self, conn, new=False): + self.engine = create_engine(conn) + self.engine.raw_connection().connection.text_factory = str + if new: + Base.metadata.drop_all(self.engine) + Base.metadata.create_all(self.engine) + Base.metadata.bind =self.engine + Session = scoped_session(sessionmaker(bind=self.engine)) + self.session = Session() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.session.close() + + def dump_yaml(self): + dump_dict = {} + dump_dict['feature'] = [[f.id, f.chrom.encode('ascii','ignore'), f.start, f.end, f.strand.encode('ascii','ignore'), f.ftype.encode('ascii','ignore'), f.seq.encode('ascii','ignore')] for f in self.session.query(Feature)] + + dump_dict['read_source'] = [[r.id, r.name.encode('ascii','ignore'), r.source.encode('ascii','ignore'), r.nreads] for r in self.session.query(ReadSource)] + + dump_dict['read_count'] = [[r.read_source_id, r.feature_id, r.count,r.span.encode('ascii','ignore'),r.extend_up, r.extend_down] for r in self.session.query(FeatureReadCount)] + + dump_dict['evidence'] = [[r.id, r.name.encode('ascii','ignore'), r.source.encode('ascii','ignore')] for r in self.session.query(Evidence)] + + 
dump_dict['feature_evidence'] = [[r.feature_id, r.evidence_id] for r in self.session.query(FeatureEvidence)] + + return yaml.dump(dump_dict) + + def load_yaml(self, fname): + data = yaml.load(open(fname)) + source_map = {} + + if not data['feature']: + return + + for old_id,name,fname,nreads in data['read_source']: + r = get_or_create(self.session, ReadSource, + name=name, source=fname, nreads=nreads) + self.session.commit() + source_map[old_id] = r.id + + + t = ["chrom","start","end","strand","ftype","seq"] + result = self.engine.execute( + Feature.__table__.insert(), + [dict(zip(t, row[1:])) for row in data['feature']] + ) + + self.session.commit() + + # print data['feature'][0][1:] + first = self.fetch_feature(data['feature'][0][1:]) + last = self.fetch_feature(data['feature'][-1][1:]) + + f_map = dict(zip([x[0] for x in data['feature']], range(first.id, last.id + 1))) + data['read_count'] = [ + [source_map[row[0]]] + [f_map[row[1]]] + row[2:] for row in data['read_count'] + ] + t = ["read_source_id", "feature_id", "count", "span", "extend_up", "extend_down"] + + result = self.engine.execute( + FeatureReadCount.__table__.insert(), + [dict(zip(t, row)) for row in data['read_count']] + ) + + if data['evidence']: + t = ["name","source"] + result = self.engine.execute( + Evidence.__table__.insert(), + [dict(zip(t, row[1:])) for row in data['evidence']] + ) + + self.session.commit() + first = self.fetch_evidence(data['evidence'][0][1:]) + last = self.fetch_evidence(data['evidence'][-1][1:]) + + ev_map = dict(zip([x[0] for x in data['evidence']], range(first.id, last.id + 1))) + + data['feature_evidence'] = [ + [f_map[row[0]], ev_map[row[1]]] for row in data['feature_evidence'] + ] + + t = ["feature_id", "evidence_id"] + result = self.engine.execute( + FeatureEvidence.__table__.insert(), + [dict(zip(t, row)) for row in data['feature_evidence']] + ) + + def add_transcript(self, name, source, exons): + """ + Add a transcript to the database + """ + + # Sanity checks + for e1, e2 in zip(exons[:-1], exons[1:]): + if e1[0] != e2[0]: + sys.stderr.write("{0} - {1}\n".format(e1, e2)) + raise ValueError, "Different chromosomes!" 
+ if e2[1] <= e1[2]: + sys.stderr.write("{0} - {1}\n".format(e1, e2)) + raise ValueError, "exons overlap, or in wrong order" + if e1[3] != e2[3]: + sys.stderr.write("{0} - {1}\n".format(e1, e2)) + raise ValueError, "strands don't match" + + chrom = exons[0][0] + strand = exons[0][-1] + + evidence = get_or_create(self.session, Evidence, + name = name, + source=source) + + seqs = [] + for exon in exons: + seq = "" + real_seq = "" + if self.index: + seq = "" + try: + seq = self.index.get_sequence(chrom, exon[1] - 20, exon[2] + 20, strand) + real_seq = seq[20:-20] + except: + real_seq = self.index.get_sequence(chrom, exon[1], exon[2], strand) + seqs.append(seq) + + exon = get_or_create(self.session, Feature, + chrom = chrom, + start = exon[1], + end = exon[2], + strand = strand, + ftype = "exon", + seq = real_seq + ) + exon.evidences.append(evidence) + + splice_donors = [] + splice_acceptors = [] + bla = [] + for i,(start,end) in enumerate([(e1[2], e2[1]) for e1, e2 in zip(exons[0:-1], exons[1:])]): + self.logger.debug("{} {} {} {}".format(chrom, start, end, strand)) + sj = get_or_create(self.session, Feature, + chrom = chrom, + start = start, + end = end, + strand = strand, + ftype = "splice_junction" + ) + sj.evidences.append(evidence) + + if strand == "+": + if len(seqs) > (i + 1) and len(seqs[i]) > 46: + splice_donors.append(["{}_{}".format(name, i + 1), seqs[i][-23:-14]]) + if len(seqs) > (i + 2) and len(seqs[i + 1]) > 46: + f = ["{}_{}".format(name, i + 1), seqs[i + 1][:23]] + splice_acceptors.append(f) + else: + if len(seqs) > (i + 2) and len(seqs[i + 1]) > 46: + f = ["{}_{}".format(name, i + 1), seqs[i + 1][-23:-14]] + splice_donors.append(f) + + if len(seqs) > (i + 1) and len(seqs[i]) > 46: + f = ["{}_{}".format(name, i + 1), seqs[i][:23]] + splice_acceptors.append(f) + + donor_score = get_splice_score(splice_donors, 5) + acceptor_score = get_splice_score(splice_acceptors, 3) + if donor_score + acceptor_score < 0: + self.logger.warning("Skipping {}, splicing not OK!".format(name)) + self.session.rollback() + else: + self.session.commit() + #for sj in bla: + # self.logger.debug("{} {} {} {}".format(sj.id, sj.chrom, sj.start, sj.end)) + def get_features(self, ftype=None, chrom=None): + #self.session.query(Feature)#.options( + # joinedload('read_counts')).all() + + query = self.session.query(Feature) + query = query.filter(Feature.flag.op("IS NOT")(True)) + #query = query.filter(Feature.flag == Fal) + + if chrom: + query = query.filter(Feature.chrom == chrom) + if ftype: + query = query.filter(Feature.ftype == ftype) + features = [f for f in query] + return features + + def get_exons(self, chrom=None): + return self.get_features(ftype="exon", chrom=chrom) + + def get_splice_junctions(self, chrom=None, ev_count=None, read_count=None, max_reads=None): + + features = [] + if ev_count and read_count: + # All splices with no read, but more than one evidence source + fs = self.session.query(Feature).\ + filter(Feature.flag.op("IS NOT")(True)).\ + filter(Feature.ftype == "splice_junction").\ + filter(Feature.chrom == chrom).\ + outerjoin(FeatureReadCount).\ + group_by(Feature).\ + having(func.sum(FeatureReadCount.count) < read_count) + + for splice in fs: + self.logger.debug("Considering {}".format(splice)) + if len(splice.evidences) >= ev_count: + features.append(splice) + + fs = self.session.query(Feature).\ + filter(Feature.flag.op("IS NOT")(True)).\ + filter(Feature.ftype == "splice_junction").\ + filter(Feature.chrom == chrom).\ + outerjoin(FeatureReadCount).\ + group_by(Feature).\ + 
having(func.sum(FeatureReadCount.count) == None) + + for splice in fs: + self.logger.debug("Considering {} (no reads)".format(splice)) + if len(splice.evidences) >= ev_count: + features.append(splice) + + # All splcies with more than x reads + fs = self.session.query(Feature).\ + filter(Feature.flag.op("IS NOT")(True)).\ + filter(Feature.ftype == "splice_junction").\ + filter(Feature.chrom == chrom).\ + outerjoin(FeatureReadCount).\ + group_by(Feature).\ + having(func.sum(FeatureReadCount.count) >= read_count) + for f in fs: + self.logger.debug("Considering {} (reads)".format(f)) + features.append(f) + #features += [f for f in fs] + + elif max_reads: + fs = self.session.query(Feature).\ + filter(Feature.flag.op("IS NOT")(True)).\ + filter(Feature.ftype == "splice_junction").\ + filter(Feature.chrom == chrom).\ + outerjoin(FeatureReadCount).\ + group_by(Feature).\ + having(func.sum(FeatureReadCount.count) == None) + + features += [f for f in fs if len(f.evidences) > 0] + fs = self.session.query(Feature).\ + filter(Feature.flag.op("IS NOT")(True)).\ + filter(Feature.ftype == "splice_junction").\ + filter(Feature.chrom == chrom).\ + outerjoin(FeatureReadCount).\ + group_by(Feature).\ + having(func.sum(FeatureReadCount.count) < max_reads) + features += [f for f in fs if len(f.evidences) > 0] + else: + features = self.get_features(ftype="splice_junction", chrom=chrom) + return features + + def get_longest_3prime_exon(self, chrom, start5, strand): + if strand == "+": + + q = self.session.query(Feature).\ + filter(Feature.flag.op("IS NOT")(True)).\ + filter(Feature.ftype == "exon").\ + filter(Feature.chrom == chrom).\ + filter(Feature.strand == strand).\ + filter(Feature.start == start5).\ + order_by(Feature.end) + return q.all()[-1] + else: + q = self.session.query(Feature).\ + filter(Feature.ftype == "exon").\ + filter(Feature.chrom == chrom).\ + filter(Feature.strand == strand).\ + filter(Feature.end == start5).\ + order_by(Feature.end) + return q.all()[0] + + + def get_long_exons(self, chrom, l, evidence): + query = self.session.query(Feature) + query = query.filter(Feature.flag.op("IS NOT")(True)) + query = query.filter(Feature.ftype == 'exon') + query = query.filter(Feature.chrom == chrom) + query = query.filter(Feature.end - Feature.start >= l) + return [e for e in query if len(e.evidences) <= evidence] + + def filter_repeats(self, chrom, rep): + """ Flag all exons that overlap with a specified fraction + with a repeat track + """ + + self.logger.warn("Filtering repeats: {} with fraction {}".format(os.path.basename(rep["path"]), rep["fraction"])) + + exons = self.get_features("exon", chrom) + exon_tabix = exons_to_tabix_bed(exons) + + overlap_it = tabix_overlap(exon_tabix, rep["tabix"], chrom, rep["fraction"]) + exon_ids = [int(iv[3]) for iv in overlap_it] + + chunk = 20 + for i in range(0, len(exon_ids), chunk): + self.logger.warn("Filtering {}".format(exon_ids[i:i + chunk])) + query = self.session.query(Feature).\ + filter(Feature.id.in_(exon_ids[i:i + chunk])).\ + update({Feature.flag:True}, synchronize_session=False) + self.session.commit() + + + #fobj = TabixIteratorAsFile(tabixfile.fetch(chrom)) + #for line in fobj: + # print line + + def filter_evidence(self, chrom, source, experimental): + self.logger.debug("Filtering {}".format(source)) + #query = self.session.query(Feature).\ + # update({Feature.flag:False}, synchronize_session=False) + #self.session.commit() + + # Select all features that are supported by other evidence + n = self.session.query(Feature.id).\ + 
join(FeatureEvidence).\ + join(Evidence).\ + filter(Evidence.source != source).\ + filter(Evidence.source not in experimental).\ + filter(Feature.chrom == chrom).\ + subquery("n") + + # Select the total number of transcript from this source + # per feature + s = self.session.query(Feature.id, func.count('*').label('total')).\ + join(FeatureEvidence).\ + join(Evidence).\ + filter(Evidence.source == source).\ + filter(Feature.chrom == chrom).\ + group_by(Feature.id).\ + subquery("s") + + # Select all features where support from this source + # is only 1 transcript and which is not supported by any + # other sources + a = self.session.query(Feature.id).filter(and_( + Feature.id == s.c.id, + s.c.total == 1)).\ + filter(Feature.id.notin_(n)).\ + subquery("a") + + #ids = [i[0] for i in query] + + # Flag features + query = self.session.query(Feature).\ + filter(Feature.id.in_(a)).\ + update({Feature.flag:True}, synchronize_session=False) + self.session.commit() + + def get_read_statistics(self, chrom, fnames, name, span="all", extend=(0,0), nreads=None): + from fluff.fluffio import get_binned_stats + from tempfile import NamedTemporaryFile + + if span not in ["all", "start", "end"]: + raise Exception("Incorrect span: {}".format(span)) + + tmp = NamedTemporaryFile(delete=False) + estore = {} + self.logger.debug("Writing exons to file {}".format(tmp.name)) + exons = self.get_exons(chrom) + if len(exons) == 0: + return + + for exon in exons: + start = exon.start + end = exon.end + if span == "start": + if exon.strand == "+": + end = start + elif exon.strand == "-": + start = end + if span == "end": + if exon.strand == "+": + start = end + elif exon.strand == "-": + end = start + + if exon.strand == "-": + start -= extend[1] + end += extend[0] + else: + start -= extend[0] + end += extend[1] + if start < 0: + start = 0 + + estr = "{}:{}-{}".format(exon.chrom, start, end) + + if estore.has_key(estr): + estore[estr].append(exon) + else: + estore[estr] = [exon] + tmp.write("{}\t{}\t{}\t{}\t{}\t{}\n".format( + exon.chrom, + start, + end, + str(exon), + 0, + exon.strand + )) + tmp.flush() + + if type("") == type(fnames): + fnames = [fnames] + + for i, fname in enumerate(fnames): + self.logger.debug("Creating read_source for{} {}".format(name, fname)) + read_source = get_or_create(self.session, ReadSource, name=name, source=fname) + self.session.commit() + if fname.endswith("bam") and (not nreads or not nreads[i]): + rmrepeats = True + self.logger.debug("Counting reads in {0}".format(fname)) + read_source.nreads = read_statistics(fname) + else: + rmrepeats = False + + self.logger.debug("Getting overlap from {0}".format(fname)) + result = get_binned_stats(tmp.name, fname, 1, rpkm=False, rmdup=False, rmrepeats=False) + + self.logger.debug("Reading results, save to exon stats") + + insert_vals = [] + for row in result: + vals = row.strip().split("\t") + e = "%s:%s-%s" % (vals[0], vals[1], vals[2]) + c = float(vals[3]) + for exon in estore[e]: + insert_vals.append([read_source.id, exon.id, c, span, extend[0], extend[1]]) + + t = ["read_source_id", "feature_id", "count", "span", "extend_up", "extend_down"] + result = self.engine.execute( + FeatureReadCount.__table__.insert(), + [dict(zip(t,row)) for row in insert_vals] + ) + + tmp.close() + + def get_splice_statistics(self, chrom, fnames, name): + if type("") == type(fnames): + fnames = [fnames] + + nrsplice = {} + for fname in fnames: + self.logger.debug("Getting splicing data from {0}".format(fname)) + read_source = get_or_create(self.session, ReadSource, 
name=name, source=fname) + self.session.commit() + for line in open(fname): + vals = line.strip().split("\t") + if vals[0] == chrom: + start, end, c = [int(x) for x in vals[1:4]] + strand = vals[5] + + splice = get_or_create(self.session, Feature, + chrom = chrom, + start = start, + end = end, + strand = strand, + ftype = "splice_junction" + ) + self.session.commit() + + count = get_or_create(self.session, FeatureReadCount, + feature_id = splice.id, + read_source_id = read_source.id) + + if not count.count: + count.count = c + else: + count.count += c + + self.session.commit() + + def get_junction_exons(self, junction): + + left = self.session.query(Feature).filter(and_( + Feature.chrom == junction.chrom, + Feature.strand == junction.strand, + Feature.end == junction.start, + Feature.ftype == "exon" + )).\ + filter(Feature.flag.op("IS NOT")(True)) + + + right = self.session.query(Feature).filter(and_( + Feature.chrom == junction.chrom, + Feature.strand == junction.strand, + Feature.start == junction.end, + Feature.ftype == "exon" + )).\ + filter(Feature.flag.op("IS NOT")(True)) + + exon_pairs = [] + for e1 in left: + for e2 in right: + exon_pairs.append((e1, e2)) + return exon_pairs + + def clear_stats_cache(self): + self.cache_feature_stats = {} + self.cache_splice_stats = {} + + def feature_stats(self, feature, identifier): + if not self.cache_feature_stats.has_key("{}{}".format(feature, identifier)): + + q = self.session.query(FeatureReadCount, ReadSource).join(ReadSource) + q = q.filter(FeatureReadCount.feature_id == feature.id) + q = q.filter(ReadSource.name == identifier) + self.cache_feature_stats["{}{}".format(feature, identifier)] = sum([x[0].count for x in q.all()]) + + return self.cache_feature_stats["{}{}".format(feature, identifier)] + + def splice_stats(self, exon1, exon2, identifier): + if not self.cache_splice_stats.has_key("{}{}{}".format(self, exon1, exon2)): + + q = self.session.query(Feature) + q = q.filter(Feature.ftype == "splice_junction") + q = q.filter(Feature.chrom == exon1.chrom) + q = q.filter(Feature.strand == exon1.strand) + q = q.filter(Feature.start == exon1.end) + q = q.filter(Feature.end == exon2.start) + + splice = q.first() + + self.cache_splice_stats["{}{}{}".format(self, exon1, exon2)] = self.feature_stats(splice, identifier) + + return self.cache_splice_stats["{}{}{}".format(self, exon1, exon2)] + + def nreads(self, identifier): + q = self.session.query(ReadSource) + q = q.filter(ReadSource.name == identifier) + return sum([s.nreads for s in q.all()]) + + def get_splice_count(self, e1, e2): + counts = self.session.query(func.sum(FeatureReadCount.count)).\ + join(Feature).\ + filter(Feature.chrom == e1.chrom).\ + filter(Feature.start == e1.end).\ + filter(Feature.end == e2.start).\ + filter(Feature.ftype == "splice_junction").\ + group_by(Feature.id).all() + return sum([int(x[0]) for x in counts]) + + def fetch_feature(self, f): + """ Feature as list """ + chrom, start, end, strand, ftype, seq = f + + feature = self.session.query(Feature).\ + filter(Feature.chrom == chrom).\ + filter(Feature.start == start).\ + filter(Feature.end == end).\ + filter(Feature.strand == strand).\ + filter(Feature.ftype == ftype).\ + filter(Feature.seq == seq) + result = feature.first() + return result + + def fetch_evidence(self, f): + """ Feature as list """ + name, source = f + evidence = self.session.query(Evidence).\ + filter(Evidence.name == name).\ + filter(Evidence.source == source) + + result = evidence.first() + return result + + def 
get_transcript_statistics(self, exons): + stats = [] + for exon in exons: + q = self.session.query(ReadSource, + func.sum(FeatureReadCount.count)).\ + join(FeatureReadCount).\ + join(Feature).\ + filter(Feature.chrom == exon[0]).\ + filter(Feature.start == exon[1]).\ + filter(Feature.end == exon[2]).\ + filter(Feature.strand == exon[3]).\ + group_by(ReadSource.name) + stats.append(dict([(row[0].name, int(row[1])) for row in q.all()])) + return stats + diff --git a/pita/config.py b/pita/config.py index 63bf544..15cbab4 100644 --- a/pita/config.py +++ b/pita/config.py @@ -1,6 +1,160 @@ +import logging +import yaml +import os +import sys +import pysam +import subprocess +from tempfile import NamedTemporaryFile +from pita.io import _create_tabix + SAMTOOLS = "samtools" TSS_FOUND = "v" TSS_UPSTREAM = "u" TSS_DOWNSTREAM = "a" TSS_NOTFOUND = "x" SEP = ":::" +VALID_TYPES = ["bed", "gff", "gff3", "gtf"] +DEBUG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] + +class PitaConfig: + def __init__(self, fname, reannotate=False): + """ fname: name of yaml configuration file + """ + + self.logger = logging.getLogger("pita") + + # Parse YAML config file + f = open(fname, "r") + self.config = yaml.load(f) + f.close() + + self.db_conn = "sqlite:///pita_database.db" + if self.config.has_key("database"): + self.db_conn = self.config["database"] + + # Data directory + self.base = "." + if self.config.has_key("data_path"): + self.base = self.config["data_path"] + + # Prune overlaps + self.prune = None + if self.config.has_key("prune"): + self.prune = self.config["prune"] + + self.keep = [] + self.filter = [] + self.experimental = [] + + # Scoring weight + self.weight = {} + if self.config.has_key("scoring"): + self.weight = self.config["scoring"] + + self._parse_repeats() + + # load annotation files + self._parse_annotation(reannotate) + + # only use chromosome specified in config file + self.chroms = self.chroms.keys() + if self.config.has_key("chromosomes") and self.config["chromosomes"]: + if type(self.config["chromosomes"]) == type([]): + self.chroms = self.config["chromosomes"] + else: + self.chroms = [self.config["chromosomes"]] + + # check the data files + self._check_data_files() + + # output option + self.min_protein_size = 20 + + def _parse_repeats(self): + self.repeats = [] + if self.config.has_key("repeats"): + for d in self.config["repeats"]: + fname = os.path.join(self.base, d["path"]) + tabix_fname = _create_tabix(fname, "bed") + d["path"] = fname + d["tabix"] = tabix_fname + + self.repeats.append(d) + + def _parse_annotation(self, reannotate=False): + + if not self.config.has_key("annotation") or len(self.config["annotation"]) == 0: + self.logger.error("No annotation files specified.") + sys.exit(1) + + self.anno_files = [] + self.chroms = {} + for d in self.config["annotation"]: + self.logger.debug("annotation: {0}".format(d)) + fname = os.path.join(self.base, d["path"]) + t = d["type"].lower() + min_exons = 2 + if d.has_key("min_exons"): + min_exons = d["min_exons"] + + if d.has_key("keep") and d["keep"]: + self.keep.append(fname) + + if d.has_key("filter") and d["filter"]: + self.filter.append(fname) + + if d.has_key("experimental") and d["experimental"]: + self.experimental.append(fname) + + if not t in VALID_TYPES: + self.logger.error("Invalid type: {0}".format(t)) + sys.exit(1) + if not os.path.exists(fname): + self.logger.error("File does not exist: {0}".format(fname)) + sys.exit(1) + else: + tabix_file = "" + if not reannotate: + tabix_file = _create_tabix(fname, t) + + # Save 
chromosome names + for chrom in pysam.Tabixfile(tabix_file).contigs: + self.chroms[chrom] = 1 + + # Add file info + self.anno_files.append([d["name"], fname, tabix_file, t, min_exons]) + + def _check_data_files(self): + # data config + self.logger.info("Checking data files") + self.data = [] + if self.config.has_key("data") and self.config["data"]: + for d in self.config["data"]: + self.logger.debug("data: {0}".format(d)) + d.setdefault("up", 0) + d.setdefault("down", 0) + if type("") == type(d["path"]): + d["path"] = [d["path"]] + + d.setdefault("feature", "all") + if d["feature"] not in ["all", "start", "end", "splice"]: + self.logger.error("Incorrect span: {}".format(d["feature"])) + sys.exit(1) + + names_and_stats = [] + fnames = [os.path.join(self.base, x) for x in d["path"]] + for fname in fnames: + if not os.path.exists(fname): + self.logger.error("File does not exist: {0}".format(fname)) + sys.exit(1) + + if fname.endswith("bam") and not os.path.exists(fname + ".bai"): + self.logger.error("BAM file {0} needs to be indexed!".format(fname)) + sys.exit(1) + + #if fname.endswith("bam"): + # names_and_stats.append((fname, read_statistics(fname))) + #else: + # names_and_stats.append((fname, None)) + row = [d["name"], fnames, d["feature"], (int(d["up"]), int(d["down"]))] + self.data.append(row) diff --git a/pita/db_backend.py b/pita/db_backend.py new file mode 100644 index 0000000..b77e1f3 --- /dev/null +++ b/pita/db_backend.py @@ -0,0 +1,101 @@ +from sqlalchemy import Column, ForeignKey, Integer, String, Text, UniqueConstraint, Boolean +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship,sessionmaker,mapper,scoped_session +from sqlalchemy import create_engine, and_, event +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.inspection import inspect +from pita.exon import Exon + +Base = declarative_base() + +class FeatureEvidence(Base): + __tablename__ = 'feature_evidence' + feature_id = Column(Integer, ForeignKey('feature.id'), primary_key=True) + evidence_id = Column(Integer, ForeignKey('evidence.id'), primary_key=True) + evidence = relationship("Evidence") + feature = relationship("Feature") + +class Feature(Base): + __tablename__ = 'feature' + __table_args__ = ( + UniqueConstraint( + 'chrom', + 'start', + 'end', + 'strand', + 'ftype', + name='uix_1'), + ) + id = Column(Integer, primary_key=True) + chrom = Column(String(250), nullable=False) + start = Column(Integer, nullable=False) + end = Column(Integer, nullable=False) + strand = Column(String(1), nullable=False) + ftype = Column(String(250), nullable=False) + seq = Column(Text(), default="") + flag = Column(Boolean(), default=False) + _evidences = relationship('FeatureEvidence') + evidences = association_proxy('_evidences', 'evidence', + creator=lambda _i: FeatureEvidence(evidence=_i), + ) + read_counts = relationship('FeatureReadCount') + + def __str__(self): + return self.to_loc() + + def to_loc(self): + return "{}:{}{}{}".format( + self.chrom, self.start, self.strand, self.end + ) + + def to_sloc(self): + return "{}{}{}".format( + self.start, self.strand, self.end + ) + + def to_flat_exon(self): + e = Exon(self.chrom, self.start, self.end, self.strand) + e.seq = self.seq + return e + + def overlap(self, exon, strand=True, fraction=False): + if strand and self.strand != exon.strand: + return 0 + if exon.start >= self.start and exon.start <= self.end: + return self.end - exon.start + if exon.end >= self.start and exon.end <= self.end: + return exon.end - 
self.start + return 0 + +class Evidence(Base): + __tablename__ = "evidence" + id = Column(Integer, primary_key=True) + name = Column(String(250)) + source = Column(String(250)) + +class ReadSource(Base): + __tablename__ = "read_source" + id = Column(Integer, primary_key=True) + name = Column(String(250)) + source = Column(String(250)) + nreads = Column(Integer) + +class FeatureReadCount(Base): + __tablename__ = "read_count" + read_source_id = Column(Integer, ForeignKey('read_source.id'), primary_key=True) + feature_id = Column(Integer, ForeignKey('feature.id'), primary_key=True) + read_source = relationship("ReadSource") + feature = relationship("Feature") + count = Column(Integer, default=0) + span = Column(String(50), default="all", primary_key=True) + extend_up = Column(Integer, default=0, primary_key=True) + extend_down = Column(Integer, default=0, primary_key=True) + +def get_or_create(session, model, **kwargs): + instance = session.query(model).filter_by(**kwargs).first() + if instance: + return instance + else: + instance = model(**kwargs) + session.add(instance) + return instance diff --git a/pita/dbcollection.py b/pita/dbcollection.py new file mode 100644 index 0000000..630269f --- /dev/null +++ b/pita/dbcollection.py @@ -0,0 +1,434 @@ +from pita.exon import * +from pita.util import read_statistics +from pita.util import longest_orf,exons_to_seq,model_to_bed +from pita.config import SEP +import numpy as np +import sys +import logging +import pickle +from networkx.algorithms.components.connected import connected_components +import networkx as nx +from networkx.algorithms.connectivity import minimum_st_node_cut +from networkx.algorithms.flow import ford_fulkerson +from itertools import izip, count +from gimmemotifs.genome_index import GenomeIndex +import random + +def connected_models(graph): + for u, v in graph.edges(): + graph[u][v]['weight'] = -1 + + for c in nx.weakly_connected_components(graph): + starts = [k for k,v in graph.in_degree(c).items() if v == 0] + ends = [k for k,v in graph.out_degree(c).items() if v == 0] + paths = [] + + for i,s in enumerate(starts): + order,d = nx.bellman_ford(graph,s, weight='weight') + + for e in ends: + if d.has_key(e): + path = [e] + x = e + while order[x]: + path.append(order[x]) + x = order[x] + + paths.append(path[::-1]) + yield paths + +def recursive_neighbors(graph, node_list): + result = [] + for node in node_list: + result += graph.neighbors(node) + new_list = list(set(node_list + result)) + for node in result: + if not node in node_list: + return recursive_neighbors(graph, new_list) + return new_list + + +class DbCollection: + def __init__(self, db, chrom=None): + # dict with chrom as key + self.logger = logging.getLogger("pita") + + self.db = db + self.chrom = chrom + + # All transcript models will be stored as a directed (acyclic) graph + self.graph = nx.DiGraph() + + # Store read counts of BAM files + self.nreads = {} + + # Store extension used in BAM statistics + self.extend = {} + + self.logger.debug("Loading exons in graph") + for exon in self.db.get_exons(chrom): + self.add_feature(exon) + + self.logger.debug("Loading introns in graph") + n = 0 + for junction in self.db.get_splice_junctions(chrom, ev_count=1, read_count=20): + #for junction in self.db.get_splice_junctions(chrom): + #print junction + n += 1 + self.add_feature(junction) + self.logger.debug("{} introns were loaded".format(n)) + + + def add_feature(self, feature): + """ + """ + + if feature.ftype == "exon": + exon = feature + + # Add chromosome to keys + 
self.graph.add_node(feature) + self.graph.node[feature]['weight'] = 1 + #if self.index: + # e.seq = self.index.get_sequence(chrom, start, end, strand) + + elif feature.ftype == "splice_junction": + + # Add transcript model to the graph + for exons in self.db.get_junction_exons(feature): + self.graph.add_path(exons) + +# def remove_exon(self, e): +# if e in self.graph: +# self.logger.info("Removing exon {0}".format(e)) +# self.graph.remove_node(e) +# del self.exons[e.chrom][to_sloc(e.start, e.end, e.strand)] + + def get_initial_exons(self, chrom=None): + """ Return all leftmost exons + """ + in_degree = self.graph.in_degree(self.get_exons(chrom)).items() + return [k for k,v in in_degree if v == 0] + + def get_connected_models(self): + for paths in connected_models(self.graph): + if len(paths) > 0: + self.logger.debug("yielding {0} paths".format(len(paths))) + yield paths + + def get_node_cuts(self, model): + node_cuts = [] + cuts = list(minimum_st_node_cut(self.graph, model[0], model[-1], flow_func=ford_fulkerson)) + while len(cuts) == 1: + node_cuts = cuts + node_cuts + cuts = list(minimum_st_node_cut(self.graph, model[0], cuts[0], flow_func=ford_fulkerson)) + return node_cuts + + def get_best_variant(self, model, weight): + + if len(model) == 1: + return model + + nodeset = self.get_node_cuts(model) + if len(list(nodeset)) > 0: + self.logger.debug("option 1") + self.logger.debug("{}".format(str(nodeset))) + nodeset = [model[0]] + list(nodeset) + [model[-1]] + self.logger.debug("got nodeset") + best_variant = [model[0]] + for n1,n2 in zip(nodeset[:-1], nodeset[1:]): + self.logger.debug("{} {}".format(str(n1), str(n2))) + variants = [m for m in self.all_simple_paths(n1, n2)] + self.logger.debug("Got {} variants".format(len(variants))) + best_variant += self.max_weight(variants, weight)[1:] + self.logger.debug("Best variant".format(best_variant)) + else: + variants = [m for m in self.all_simple_paths(model[0], model[-1])] + best_variant = self.max_weight(variants, weight) + + e = best_variant[-1] + if e.strand == "+": + best_variant[-1] = self.db.get_longest_3prime_exon(e.chrom, e.start, e.strand) + else: + e = best_variant[0] + best_variant[0] = self.db.get_longest_3prime_exon(e.chrom, e.end, e.strand) + return best_variant + + def prune(self): + pruned = [] + + for i,cluster in enumerate(self.get_connected_models()): + self.logger.debug("Pruning {0} models".format(len(cluster))) + #print i + 1 + + discard = [] + new_cluster = [m for m in cluster] + + while len(new_cluster) > 0: + #print len(new_cluster) + #c_min = min([m[0].start for m in new_cluster]) + #c_max = max([m[-1].end for m in new_cluster]) + #print c_min, c_max + #selection = [m for m in new_cluster if m[0].start == c_min or m[-1].end == c_max] + + longest = sorted(new_cluster, cmp=lambda x,y: cmp(x[-1].end - x[0].start, y[-1].end - y[0].start))[-1] + discard.append(longest) + new_cluster = [m for m in new_cluster if m != longest] + if len(new_cluster) != 0: + + graph = nx.DiGraph() + for m in new_cluster: + graph.add_path(m) + + result = [x for x in connected_models(graph) if len([y for y in x if len(y) > 2]) > 1] + if len(result) > 1: + break + + if len(new_cluster) != 0: + #print len(new_cluster) + discard_edges = [] + for m in discard: + for e1, e2 in zip(m[:-1], m[1:]): + discard_edges.append((e1, e2)) + + keep_edges = [] + for m in new_cluster: + for e1, e2 in zip(m[:-1], m[1:]): + keep_edges.append((e1, e2)) + + for x in set(discard_edges) - set(keep_edges): + self.graph.remove_edge(x[0], x[1]) + + 
pruned.append([x[0].chrom, x[0].end, x[1].start]) + + return pruned + + def is_weak_splice(self, splice, evidence=1): + exons = [] + my = [] + for e1,e2 in self.db.get_junction_exons(splice): + if e1 in self.graph and e2 in self.graph and (e1, e2) in self.graph.edges(): + my.append((e1.end, e2.start)) + for e in [e1, e2]: + if not e in exons: + exons.append(e) + + if len(exons) == 0: + return False + + splices = self.graph.edges(recursive_neighbors(self.graph, exons)) + if len(splices) == 1: + return False + + counts = {} + for s in splices: + if not counts.has_key((s[0].end, s[1].start)): + counts[(s[0].end, s[1].start)] = self.db.get_splice_count(s[0], s[1]) + + bla = [v for k,v in counts.items() if k not in my] + if len(bla) == 0: + return False + self.logger.debug("{} {} {}".format(counts[my[0]], np.mean(bla), np.std(bla))) + return counts[my[0]] < 0.1 * np.mean(bla)# - np.std(bla)) + + def prune_splice_junctions(self, max_reads=10, evidence=2, keep=[]): + keep = set(keep) + for splice in self.db.get_splice_junctions(self.chrom, max_reads=max_reads): + self.logger.debug("Splice {}, evidence {}".format(splice, len(splice.evidences))) + ev_sources = [e.source for e in splice.evidences] + if keep.intersection(ev_sources): + self.logger.debug("Keeping this splice {}".format(splice)) + continue + if len(splice.evidences) <= evidence: + self.logger.debug("Checking splice {}".format(splice)) + if self.is_weak_splice(splice, evidence): + self.logger.debug("Removing splice {}".format(splice)) + for e1,e2 in self.db.get_junction_exons(splice): + if (e1,e2) in self.graph.edges(): + self.graph.remove_edge(e1, e2) + for node in e1, e2: + if len(self.graph.edges(node)) == 0: + self.logger.debug("Removing lonely exon {}".format(node)) + self.graph.remove_node(node) + + def filter_long(self, l=1000, evidence=2): + #print "HOIE" + for exon in self.db.get_long_exons(self.chrom, l, evidence): + out_edges = len(self.graph.out_edges([exon])) + in_edges = len(self.graph.in_edges([exon])) + self.logger.debug("Filter long: {}, in {} out {}".format(exon, in_edges,out_edges)) + + #print exon, exon.strand, in_edges, out_edges + if in_edges >= 0 and out_edges >= 1 and exon.strand == "+" or in_edges >= 1 and out_edges >= 0 and exon.strand == "-": + #print "Removing", exon + self.logger.info("Removing long exon {0}".format(exon)) + self.graph.remove_node(exon) + + def filter_and_merge(self, nodes, l): + for e1, e2 in self.graph.edges_iter(nodes): + if e2.start - e1.end <= l: + new_exon = self.add_exon(e1.chrom, e1.start, e2.end, e1.strand) + self.logger.info("Adding {0}".format(new_exon)) + for e_in in [e[0] for e in self.graph.in_edges([e1])]: + self.graph.add_edge(e_in, new_exon) + for e_out in [e[1] for e in self.graph.out_edges([e2])]: + self.graph.add_edge(new_exon, e2) + + for e in (e1, e2): + self.remove_exon(e) + + return new_exon + return None + + def filter_short_introns(self, l=10, mode='merge'): + filter_nodes = [] + for intron in self.graph.edges_iter(): + e1,e2 = intron + if e2.start - e1.end <= l: + filter_nodes += [e1, e2] + + if mode == "merge": + exon = self.filter_and_merge(filter_nodes, l) + while exon: + exon = self.filter_and_merge(filter_nodes + [exon], l) + else: + for e in filter_nodes: + self.remove_exon(e) + + def all_simple_paths(self, exon1, exon2): + return nx.all_simple_paths(self.graph, exon1, exon2) + + def get_alt_splicing_exons(self): + for exon in self.get_exons(): + out_exons = [e[1] for e in self.graph.out_edges([exon]) if len(self.graph.out_edges([e[1]])) > 0] + if 
len(out_exons) > 1: + + out_exon = out_exons[0] + + self.logger.info("ALT SPLICING {0} {1}".format(exon, out_exon)) + + #in_exons = [e[0] for e in self.graph.in_edges([exon])] + #for in_exon in in_exons: + # my_in_exons = [e[0] for e in self.graph.in_edges([in_exon])] + # for my_in_exon in my_in_exons: + # if my_in_exon in in_exons: + # self.logger.info("{0} is alternative exon".format(in_exon)) + + def get_weight(self, transcript, identifier, idtype): + signal = [self.db.feature_stats(e, identifier) for e in transcript] + exon_lengths = [e.end - e.start for e in transcript] + + if idtype in ["all", "rpkm", "weighted"]: + total_signal = float(sum(signal)) + total_exon_length = sum(exon_lengths) + + if idtype == "all": + return total_signal + elif idtype == "rpkm": + if total_signal == 0: + return 0 + return float(total_signal) / (self.db.nreads(identifier) / 1e6) / total_exon_length * 1000.0 + + elif idtype == "weighted": + return float(total_signal) / total_exon_length * len(transcript) + + elif idtype in ["mean_exon", "total_rpkm"]: + all_exons = [s/float(l) for s,l in zip(signal, exon_lengths)] + nreads = self.db.nreads(identifier) + if not nreads: + nreads = 1000000 + self.logger.warn("Number of reads in db is 0 for {}".format(identifier)) + rpkms = [s * 1000.0 / nreads * 1e6 for s in all_exons] + if idtype == "mean_exon": + if len(rpkms) == 0: + self.logger.warning("Empty score array for mean_exon") + return 0 + else: + return np.mean(rpkms) + + if idtype == "total_rpkm": + return sum(rpkms) + + elif idtype == "first": + if transcript[0].strand == "+": + return signal[0] + else: + return signal[-1] + + elif idtype == "first_rpkm": + if transcript[0].strand == "+": + exon = transcript[0] + count = signal[0] + else: + exon = transcript[-1] + count = signal[-1] + + size = exon.end - exon.start + extend = self.extend.setdefault(identifier, (0,0)) + size += extend[0] + extend[1] + if count == 0: + return 0 + + rpkm = count / (self.db.nreads(identifier)/ 1e6) / size * 1000.0 + + return rpkm + + elif idtype == "splice": + if len(transcript) == 1: + return 0 + w = [] + for e1, e2 in zip(transcript[:-1], transcript[1:]): + w.append(self.db.splice_stats(e1, e2, identifier)) + if len(w) == 0: + self.logger.warning("Empty score array for splice") + return 0 + else: + return np.sum(w) + + elif idtype == "orf": + start, end = longest_orf(exons_to_seq(transcript)) + return end - start + + elif idtype == "evidence": + #return 1 + evidences = [len(exon.evidences) for exon in transcript] + if len(evidences) == 0: + self.logger.warning("Empty score array for evidence") + return 0 + else: + return np.mean(evidences) + + elif idtype == "length": + return np.sum([e.end - e.start for e in transcript]) + + else: + raise Exception, "Unknown idtype" + + def max_weight(self, transcripts, identifier_weight): + max_transcripts = 10000 + if len(transcripts) > max_transcripts: + self.logger.warn("More than {} transcripts, random sampling to a managable number".format(max_transcripts)) + transcripts = random.sample(transcripts, max_transcripts) + + if not identifier_weight or len(identifier_weight) == 0: + w = [len(t) for t in transcripts] + else: + w = np.array([0] * len(transcripts)) + pseudo = 1e-10 + for iw in identifier_weight: + weight = iw["weight"] + idtype = iw["type"] + identifier = iw["name"] + + idw = [] + for i, transcript in enumerate(transcripts): + #if i % 10000 == 0: + # self.logger.debug("{} transcripts processed, weight {}".format(i, identifier)) + tw = self.get_weight(transcript, identifier, 
idtype) + idw.append(pseudo + tw) + + idw = np.array(idw) + idw = idw / max(idw) * weight + w = w + idw + self.db.clear_stats_cache() + return transcripts[np.argmax(w)] diff --git a/pita/io.py b/pita/io.py index 81fe015..6aab8bc 100644 --- a/pita/io.py +++ b/pita/io.py @@ -1,8 +1,86 @@ from BCBio import GFF -from pita.collection import * +#from pita.collection import * import pprint import sys +import os import logging +from tempfile import NamedTemporaryFile +import subprocess as sp +import pysam +import pybedtools + +def _create_tabix(fname, ftype): + logger = logging.getLogger("pita") + tabix_file = "" + logger.info("Creating tabix index for {0}".format(os.path.basename(fname))) + logger.debug("Preparing {0} for tabix".format(fname)) + tmp = NamedTemporaryFile(prefix="pita", delete=False) + preset = "gff" + if ftype == "bed": + cmd = "sort -k1,1 -k2g,2 {0} | grep -v track | grep -v \"^#\" > {1}" + preset = "bed" + elif ftype in ["gff", "gff3", "gtf"]: + cmd = "sort -k1,1 -k4g,4 {0} | grep -v \"^#\" > {1}" + + # Sort the input file + logger.debug(cmd.format(fname, tmp.name)) + sp.call(cmd.format(fname, tmp.name), shell=True) + # Compress using bgzip + logger.debug("compressing {0}".format(tmp.name)) + tabix_file = tmp.name + ".gz" + pysam.tabix_compress(tmp.name, tabix_file) + tmp.close() + # Index (using the tabix command line, as pysam.index results in a Segmentation fault) + logger.debug("indexing {0}".format(tabix_file)) + sp.call("tabix {0} -p {1}".format(tabix_file, preset), shell=True) + return tabix_file + +def exons_to_tabix_bed(exons): + logger = logging.getLogger("pita") + logger.debug("Converting {} exons to tabix bed".format(len(exons))) + tmp = NamedTemporaryFile(prefix="pita", delete=False) + logger.debug("Temp name {}".format(tmp.name)) + for exon in sorted(exons, cmp=lambda x,y: cmp([x.chrom, x.start], [y.chrom, y.start])): + tmp.write("{}\t{}\t{}\t{}\t{}\t{}\n".format( + exon.chrom, exon.start, exon.end, exon.id, "0", exon.strand)) + + tmp.close() + tabix_fname = _create_tabix(tmp.name, "bed") + return tabix_fname + +def tabix_overlap(fname1, fname2, chrom, fraction): + logger = logging.getLogger("pita") + logger.debug("TABIX overlap between {} and {}, {}".format(fname1, fname2, fraction)) + + tab1 = pysam.Tabixfile(fname1) + tab2 = pysam.Tabixfile(fname2) + + if not ((chrom in tab1.contigs) and (chrom in tab2.contigs)): + return + + fobj1 = TabixIteratorAsFile(tab1.fetch(chrom)) + tmp1 = NamedTemporaryFile(prefix="pita.", delete=False) + for line in fobj1.readlines(): + tmp1.write("{}\n".format(line.strip())) + + fobj2 = TabixIteratorAsFile(tab2.fetch(chrom)) + tmp2 = NamedTemporaryFile(prefix="pita.", delete=False) + for line in fobj2.readlines(): + tmp2.write("{}\n".format(line.strip())) + + tmp1.flush() + tmp2.flush() + + b1 = pybedtools.BedTool(tmp1.name) + b2 = pybedtools.BedTool(tmp2.name) + + intersect = b1.intersect(b2, f=fraction) + + tmp1.close() + tmp2.close() + + for f in intersect: + yield f def merge_exons(starts, sizes, l=0): merge = [] @@ -73,9 +151,6 @@ def read_gff_transcripts(fobj, fname="", min_exons=1, merge=0): return transcripts - - - def read_bed_transcripts(fobj, fname="", min_exons=1, merge=0): # Setup logging @@ -139,3 +214,12 @@ def readline(self): except StopIteration: return None + def readlines(self): + line = self.readline() + while line: + yield line + line = self.readline() + + + + diff --git a/pita/log.py b/pita/log.py index d3cb0a4..79fa3c3 100644 --- a/pita/log.py +++ b/pita/log.py @@ -1,65 +1,26 @@ -from pita.collection import 
get_updated_exons -#import itertools import logging +import sys +from pita.config import DEBUG_LEVELS -class AnnotationLog: - header = "Model\tNr. Exons\tExons in best model\tExons in other models\tUpdated 5'\tUpdated 3'\tOriginal models\n" - log_str = "{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n" - def __init__(self, append): - self.files = {} - self.logger = logging.getLogger('pita') - self.append = append - - def add(self, name): - mode = "w" - if self.append: - mode = "a" - self.files[name] = open("pita.{0}.log".format(name), mode) - self.files[name].write(self.header) - - def log_to_file(self, genename, model, ev, best_ev, other_ev): -# best_exons = [e for e in model] -# best_ev = {} -# for e in best_exons: -# for ev in set([x.split(":")[0] for x in e.evidence]): -# best_ev[ev] = best_ev.setdefault(ev, 0) + 1 -# other_exons = [] -# other_ev = {} -# #self.logger.debug("cluster {0}".format(len(cluster))) -# -# # Fast way to collapse -# other_exons = [e for e in set(itertools.chain.from_iterable(cluster)) if not e in best_exons] -# for e in other_exons: -# for ev in set([x.split(":")[0] for x in e.evidence]): -# other_ev[ev] = other_ev.setdefault(ev, 0) + 1 -# -# ev = [] -# for e in best_exons + other_exons: -# for evidence in e.evidence: -# ev.append(evidence.split(":")) - - # ev, model, best_exons - for name, f in self.files.items(): - orig_models = {} - for (origin,orig_name) in ev: - if origin == name: - orig_models[orig_name] = orig_models.setdefault(orig_name, 0) + 1 +def setup_logging(basename, debug_level): + debug_level = debug_level.upper() - u5, u3 = get_updated_exons(model, name) - f.write(self.log_str.format( - genename, - len(model), - best_ev.setdefault(name, 0), - other_ev.setdefault(name, 0), - u5, - u3, - ",".join(orig_models.keys()) - ) - ) - f.flush() + if not debug_level in DEBUG_LEVELS: + sys.stderr.write("Invalid debug level {0}\n".format(debug_level)) + sys.stderr.write("Valid values are {0}\n".format(",".join(DEBUG_LEVELS))) + sys.exit(1) - def __del__(self): - for f in self.files.values(): - f.close() + logger = logging.getLogger("pita") + logger.setLevel(getattr(logging, debug_level)) + formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s') + handler = logging.StreamHandler() + handler.setFormatter(formatter) + handler.setLevel(getattr(logging, debug_level)) + fh = logging.FileHandler("{0}.log".format(basename)) + fh.setFormatter(formatter) + fh.setLevel(getattr(logging, debug_level)) + logger.addHandler(handler) + logger.addHandler(fh) + return logger diff --git a/pita/model.py b/pita/model.py index 03ee583..4a90e60 100644 --- a/pita/model.py +++ b/pita/model.py @@ -1,4 +1,5 @@ -from pita.collection import Collection +from pita.dbcollection import DbCollection +from pita.annotationdb import AnnotationDb from pita.io import TabixIteratorAsFile, read_gff_transcripts, read_bed_transcripts from pita.util import get_overlapping_models,to_genomic_orf,longest_orf import logging @@ -8,17 +9,17 @@ from tempfile import NamedTemporaryFile from pita.config import SEP -def get_chrom_models(chrom, anno_files, data, weight, prune=None, index=None): - +def load_chrom_data(conn, new, chrom, anno_files, data, index=None): logger = logging.getLogger("pita") try: # Read annotation files - mc = Collection(index) + db = AnnotationDb(index=index, conn=conn, new=new) + logger.debug("{} {}".format(chrom, id(db.session))) logger.info("Reading annotation for {0}".format(chrom)) - for name, fname, ftype, min_exons in anno_files: - logger.debug("Reading 
annotation from {0}".format(fname)) - tabixfile = pysam.Tabixfile(fname) + for name, fname, tabix_file, ftype, min_exons in anno_files: + logger.info("Reading annotation from {0}".format(fname)) + tabixfile = pysam.Tabixfile(tabix_file) #tabixfile = fname if chrom in tabixfile.contigs: fobj = TabixIteratorAsFile(tabixfile.fetch(chrom)) @@ -27,94 +28,93 @@ def get_chrom_models(chrom, anno_files, data, weight, prune=None, index=None): elif ftype in ["gff", "gtf", "gff3"]: it = read_gff_transcripts(fobj, fname, min_exons=min_exons, merge=10) for tname, source, exons in it: - mc.add_transcript("{0}{1}{2}".format(name, SEP, tname), source, exons) + db.add_transcript("{0}{1}{2}".format(name, SEP, tname), source, exons) del fobj tabixfile.close() del tabixfile - # Prune spurious exon linkages - #for p in mc.prune(): - # logger.debug("Pruning {0}:{1}-{2}".format(*p)) - - # Remove long exons with only one evidence source - mc.filter_long(l=2000) - # Remove short introns - #mc.filter_short_introns() logger.info("Loading data for {0}".format(chrom)) for name, fname, span, extend in data: if span == "splice": - logger.debug("Reading splice data {0} from {1}".format(name, fname)) - mc.get_splice_statistics(fname, name=name) + logger.info("Reading splice data {0} from {1}".format(name, fname)) + db.get_splice_statistics(chrom, fname, name) else: - logger.debug("Reading BAM data {0} from {1}".format(name, fname)) - mc.get_read_statistics(fname, name=name, span=span, extend=extend, nreads=None) + logger.info("Reading BAM data {0} from {1}".format(name, fname)) + db.get_read_statistics(chrom, fname, name=name, span=span, extend=extend, nreads=None) + + except: + logger.exception("Error on {0}".format(chrom)) + raise + +def get_chrom_models(conn, chrom, weight, repeats=None, prune=None, keep=[], filter=[], experimental=[]): + + logger = logging.getLogger("pita") + logger.debug(str(weight)) + try: + db = AnnotationDb(conn=conn) + + # Filter repeats + if repeats: + for x in repeats: + db.filter_repeats(chrom, x) + + for ev in filter: + db.filter_evidence(chrom, ev, experimental) + + mc = DbCollection(db, chrom) + # Remove long exons with 2 or less evidence sources + if prune and prune.has_key("exons"): + l = prune["exons"]["length"] + ev = prune["exons"]["evidence"] + logger.debug("EXON PRUNE {} {}".format(l, ev)) + mc.filter_long(l=l, evidence=ev) + + if prune and prune.has_key("introns"): + max_reads = prune["introns"]["max_reads"] + ev = prune["introns"]["evidence"] + logger.debug("EXON PRUNE {} {}".format(max_reads, ev)) + mc.prune_splice_junctions(evidence=3, max_reads=10, keep=keep) + + # Remove short introns + #mc.filter_short_introns() + models = {} exons = {} logger.info("Calling transcripts for {0}".format(chrom)) for cluster in mc.get_connected_models(): + logger.debug("{}: got cluster".format(chrom)) while len(cluster) > 0: - logger.debug("best model") + #logger.debug("best model") best_model = mc.max_weight(cluster, weight) - logger.debug("best variant") - best_model = mc.get_best_variant(best_model, weight) - + logger.debug("{}: got best model".format(chrom)) + best_model = mc.get_best_variant(best_model, weight) + logger.debug("{}: got best variant".format(chrom)) genename = "{0}:{1}-{2}_".format( best_model[0].chrom, best_model[0].start, best_model[-1].end, ) - - logger.debug("get weight RNAseq") - rpkm = mc.get_weight(best_model, "RNAseq", "rpkm") - if rpkm >= 0.2: - genename += "V" - else: - genename += "X" - - rpkm = mc.get_weight(best_model, "H3K4me3", "first_rpkm") - logger.debug("{0}: 
H3K4me3: {1}".format(genename, rpkm)) - if rpkm >= 1: - genename += "V" - else: - genename += "X" - - other_exons = [e for e in set(itertools.chain.from_iterable(cluster)) if not e in best_model] - for i in range(len(cluster) - 1, -1, -1): - if cluster[i][0].start <= best_model[-1].end and cluster[i][-1].end >= best_model[0].start: - del cluster[i] - - ### Ugly logging stuff - best_ev = {} - other_ev = {} - for e in best_model: - for ev in set([x.split(SEP)[0] for x in e.evidence]): - best_ev[ev] = best_ev.setdefault(ev, 0) + 1 - - # Fast way to collapse - for e in other_exons: - for ev in set([x.split(SEP)[0] for x in e.evidence]): - other_ev[ev] = other_ev.setdefault(ev, 0) + 1 - ev = [] - for e in best_model + other_exons: - for evidence in e.evidence: - ev.append(evidence.split(SEP)) - - ### End ugly logging stuff - logger.debug("Best model: {0} with {1} exons".format(genename, len(best_model))) - models[genename] = [genename, best_model, ev, best_ev, other_ev] + + + logger.info("Best model: {0} with {1} exons".format(genename, len(best_model))) + models[genename] = [genename, best_model] for exon in best_model: exons[str(exon)] = [exon, genename] - + for i in range(len(cluster) - 1, -1, -1): + if cluster[i][0].start <= best_model[-1].end and cluster[i][-1].end >= best_model[0].start: + del cluster[i] discard = {} if prune: #logger.debug("Prune: {0}".format(prune)) overlap = get_overlapping_models([x[0] for x in exons.values()]) - #logger.debug("{0} overlapping exons".format(len(overlap))) - + if len(overlap) > 1: + logger.info("{0} overlapping exons".format(len(overlap))) +# logger.warn("Overlap: {0}".format(overlap)) + gene_count = {} for e1, e2 in overlap: gene1 = exons[str(e1)][1] @@ -138,20 +138,20 @@ def get_chrom_models(chrom, anno_files, data, weight, prune=None, index=None): overlap = l2 #logger.info("Pruning {} vs. 
{}".format(str(m1),str(m2))) - logger.info("1: {}, 2: {}, overlap: {}".format( - l1, l2, overlap)) - logger.info("Gene {} count {}, gene {} count {}".format( - str(gene1), gene_count[gene1], str(gene2), gene_count[gene2] - )) - - prune_overlap = 0.1 + #logger.info("1: {}, 2: {}, overlap: {}".format( + # l1, l2, overlap)) + #logger.info("Gene {} count {}, gene {} count {}".format( + # str(gene1), gene_count[gene1], str(gene2), gene_count[gene2] + # )) +# + prune_overlap = prune["overlap"]["fraction"] if overlap / l1 < prune_overlap and overlap / l2 < prune_overlap: - logger.info("Not pruning!") + logger.debug("Not pruning because fraction of overlap is too small!") continue w1 = 0.0 w2 = 0.0 - for d in prune: + for d in prune["overlap"]["weights"]: logger.debug("Pruning overlap: {0}".format(d)) tmp_w1 = mc.get_weight(m1, d["name"], d["type"]) tmp_w2 = mc.get_weight(m2, d["name"], d["type"]) @@ -161,13 +161,15 @@ def get_chrom_models(chrom, anno_files, data, weight, prune=None, index=None): w2 += tmp_w2 / max((tmp_w1, tmp_w2)) if w1 >= w2: + logger.info("Discarding {}".format(gene2)) discard[gene2] = 1 else: + logger.info("Discarding {}".format(gene1)) discard[gene1] = 1 - del mc - - return [v for m,v in models.items() if not m in discard] + logger.info("Done calling transcripts for {0}".format(chrom)) + result = [v for m,v in models.items() if not m in discard] + return [[name, [e.to_flat_exon() for e in exons]] for name, exons in result] except: logger.exception("Error on {0}".format(chrom)) diff --git a/pita/r_cpt.py b/pita/r_cpt.py new file mode 100755 index 0000000..3eb9e1d --- /dev/null +++ b/pita/r_cpt.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +from rpy2.robjects.packages import importr +from rpy2.rinterface import RRuntimeError +import rpy2.robjects as robjects + +# Try to load changepoint library, install it if it's not installed +try: + cp = importr('changepoint') +except RRuntimeError: + utils = importr('utils') + # Choose first available mirror from the list + utils.chooseCRANmirror(ind=1) + utils.install_packages('changepoint') + cp = importr('changepoint') + +def cpt(l): + """ + Call the mean.cpt function from the R changepoint package and + return the changepoint as a float. + """ + + vector = robjects.FloatVector([x for x in l]) + result = cp.cpt_mean(vector) + + try: + return float(cp.cpts(result)[0]) + except: + return None diff --git a/pita/util.py b/pita/util.py index 8c9cdbb..767198e 100644 --- a/pita/util.py +++ b/pita/util.py @@ -4,6 +4,11 @@ from Bio.Alphabet import IUPAC import re import sys +import subprocess as sp +from tempfile import NamedTemporaryFile +import logging + +logger = logging.getLogger('pita') def read_statistics(fname, rmrepeat=False, rmdup=False, mapped=False): """ Count number of reads in BAM file. 
@@ -44,9 +49,11 @@ def exons_to_seq(exons): exons = exons[::-1] for e in exons: - if not e.seq: - raise Exception, "exon has no sequence" - seq = "".join((seq, e.seq)) + if e.seq: + seq = "".join((seq, e.seq)) + else: + logger.error("exon {} has no sequence".format(e)) + return seq def longest_orf(seq, do_prot=False): @@ -143,6 +150,36 @@ def model_to_bed(exons, genename=None): +def get_splice_score(a, s_type=5): + if not s_type in [3,5]: + raise Exception("Invalid splice type {}, should be 3 or 5".format(s_type)) + maxent = "/home/simon/dwn/fordownload" + tmp = NamedTemporaryFile() + for name,seq in a: + tmp.write(">{}\n{}\n".format(name,seq)) + tmp.flush() + cmd = "perl score{}.pl {}".format(s_type, tmp.name) + p = sp.Popen(cmd, shell=True, cwd=maxent, stdout=sp.PIPE) + score = 0 + for line in p.stdout.readlines(): + vals = line.strip().split("\t") + if len(vals) > 1: + score += float(vals[-1]) + return score + +def bed2exonbed(inbed, outbed): + with open(outbed, "w") as out: + for line in open(inbed): + if line.startswith("#") or line.startswith("track"): + out.write(line) + else: + vals = line.strip().split("\t") + + exonsizes = [int(x) for x in vals[10].split(",") if x] + exonstarts = [int(x) for x in vals[11].split(",") if x] + for exon_start, exon_size in zip(exonstarts, exonsizes): + out.write("%s\t%s\t%s\t%s\t%s\t%s\n" % (vals[0], int(vals[1]) + exon_start, int(vals[1]) + exon_start + exon_size, vals[3], vals[4], vals[5])) + if __name__ == "__main__": print longest_orf("CAGGAAGTCACGGAGCGCGGGATTTTTCAATCAGACTGATGAACAGATGAATACGACGAAGAGCATGGAGGCAATTCTGGAATTTTTTGTGCTGTGTGATCCAAAGAAGCGGCCAGTCAGACTGAACCGGTTGCCTTCTGTACCAAAGGATGCACTGTGTTATTCTGCCCTGCTGCCATCTCCTCTACCATCCCAGCTGTTGATCTTTGGCTTAGGTGACTGGTCAGGGTTATCTGGAGGAAGCACAGTAGAAGTGAAATTGGAAGGAAGTGGAACCAAAGAGCACAGACTGGGAACGCTGACTCCTGAGTCAAGATGCTTCCTGTGGGAATCTGACCAAAACCCCGACACCAGCATAATGTTACAAGAGGGAAAGCTGCATATCTGCATGTCGGTTAAAGGGCAGGTCAATATTAATTCTACTAACAGGAAAAAAGAGCATGGAAAGCGCAAGAGAATTAAAGAGGAAGAGGAAAATGTTTGTCCAAATAGTGGACATGTAAAAGTGCCTGCTCAAAAACAGAAGAACAGTAGTCCTAAGAGTCCAGCACCAGCAAAGCAACTTGCTCATTCTAAGGCCTTTTTAGCAGCACCAGCTGTGCCAACTGCACGCTGGGGTCAAGCGCTCTGTCCTGTCAACTCTGAGACAGTAATCTTGATTGGTGGACAGGGAACACGTATGCAGTTCTGTAAGGATTCCATGTGGAAACTGAATACAGATAGGAGCACATGGACTCCAGCTGAGGCATTGGCAGATGGCCTTTCACCAGAAGCTCGTACTGGGCACACAGCAACCTTCGATCCTGAGAACAACCGTATTTATGTGTTTGGAGGTTCTAAGAACAGAAAATGGTTCAATGATGTACATATTTTGGACATTGAGGCCTGGCGATGGAGGAGCGTGGAAGTAAGTAAACTAAGTAGTTGA") diff --git a/pita/utr.py b/pita/utr.py new file mode 100755 index 0000000..1dfb9f2 --- /dev/null +++ b/pita/utr.py @@ -0,0 +1,220 @@ +from pita.r_cpt import cpt +from pita.util import bed2exonbed +from pita.io import read_bed_transcripts +from tempfile import NamedTemporaryFile +import pybedtools +import subprocess as sp +import sys +import os +import numpy as np + +def call_cpt(start, end, strand, data, min_reads=5, min_log2_ratio=1.5, upstream=False): + """ + Determine UTR location from basepair-resolution coverage vector of reads. + Return tuple of utr_start and utr_end if changepoint is found. 
+ """ + + sys.stderr.write("{} {} {}\n".format(start, end, strand)) + pt_cutoff = 5 + counts = np.array(data) + + # Do calculations on reverse array if gene is on the - strand or + # when predicting 5' UTR + if (upstream an strand == "+") or strand == "-": + counts = counts[::-1] + + pt = len(counts) + ratio = 0 + while pt > pt_cutoff and ratio < 1: + sys.stderr.write("cpt {}\n".format(pt)) + pt = cpt(counts[:pt]) + if pt > pt_cutoff: + ratio = np.log2(counts[:pt].mean() / (counts[pt:].mean()) + 0.1) + + if pt > pt_cutoff: + # Add to the changepoint while the number of reads is above min_reads + while pt < len(counts) and counts[pt] >= min_reads: + pt += 1 + + m = counts[:pt].mean() + if m >= min_reads and ratio >= min_log2_ratio: + if (upstream and strand == "+") or strand == "-": + utr_start = int(end) - pt + utr_end = int(end) + else: + utr_start = int(start) + utr_end = int(start) + pt + + return int(utr_start), int(utr_end) + +def call_utr(inbed, bamfiles, utr5=False, utr3=True): + """ + Call 3' UTR for all genes in a BED12 file based on RNA-seq reads + in BAM files. + """ + + # Load genes in BED file + transcripts = read_bed_transcripts(open(inbed)) + + # No genes + if len(transcripts) == 0: + return + + td = dict([(t[0].split("_")[1] + "_", t[2]) for t in transcripts]) + + # Create a BED6 file with exons, used to determine UTR boundaries + sys.stderr.write("Preparing temporary BED files\n") + exonbed = NamedTemporaryFile(prefix="pita.", suffix=".bed") + bed2exonbed(inbed, exonbed.name) + + # Determine boundaries using bedtools + genes = pybedtools.BedTool(inbed) + exons = pybedtools.BedTool(exonbed.name) + + tmp = NamedTemporaryFile(prefix="pita.", suffix=".bed") + + EXTEND = 10000 + sys.stderr.write("Determining gene boundaries determined by closest gene\n") + for x in genes.closest(exons, D="a", io=True, iu=True): + transcript = td[x[3]] + + # Extend to closest exon or EXTEND, whichever is closer + extend = EXTEND + if (int(x[-1]) >= 0) and (int(x[-1]) < extend): + extend = int(x[-1]) + + if transcript[0][-1] == "+": + first = transcript[-1] + first[2] += extend + else: + first = transcript[-0] + first[1] -= extend + + if first[1] < 0: + first[1] = 0 + + tmp.write("{}\t{}\t{}\t{}\t{}\t{}\n".format( + first[0], + first[1], + first[2], + x[3], + 0, + first[3] + )) + + tmp.flush() + + tmpsam = NamedTemporaryFile(prefix="pita.", suffix=".sam") + tmpbam = NamedTemporaryFile(prefix="pita.") + + # Retrieve header from first BAM file + sp.call("samtools view -H {} > {}".format(bamfiles[0], tmpsam.name), shell=True) + + # Filter all BAM files for the specific regions. 
+    # than running bedtools coverage on all individual BAM files
+    tmp_check = NamedTemporaryFile(prefix="pita.", suffix=".bam")
+    cmd = "samtools view -L {} {} > {}"
+    sys.stderr.write("Merging bam files\n")
+    for bamfile in bamfiles:
+        try:
+            sp.check_call(cmd.format(tmp.name, bamfile, tmp_check.name), shell=True)
+            sp.call("cat {} >> {}".format(tmp_check.name, tmpsam.name), shell=True)
+        except sp.CalledProcessError as e:
+            sys.stderr.write("Error in file {}, skipping:\n".format(bamfile))
+            sys.stderr.write("{}\n".format(e))
+
+    tmp_check.close()
+
+    # Create a sorted and indexed BAM file
+    cmd = "samtools view -Sb {} | samtools sort -m 6G - {}"
+    sp.call(cmd.format(tmpsam.name, tmpbam.name), shell=True)
+    sp.call("samtools index {}.bam".format(tmpbam.name), shell=True)
+
+    # Close and remove temporary SAM file
+    tmpsam.close()
+
+    sys.stderr.write("Calculating coverage\n")
+    cmd = "bedtools coverage -abam {} -b {} -d -split "
+
+    p = sp.Popen(cmd.format(tmpbam.name + ".bam", tmp.name), shell=True, stdout=sp.PIPE, bufsize=1)
+
+    sys.stderr.write("Calling UTRs\n")
+
+    data = []
+    current = [None]
+    utr = {}
+    for line in iter(p.stdout.readline, b''):
+        vals = line.strip().split("\t")
+        if vals[3] != current[0]:
+            if len(data) > 0:
+                result = call_cpt(current[1], current[2], current[3], data, len(bamfiles))
+                #print result
+                if result:
+                    utr[current[0]] = result
+            data = []
+            current = [vals[3], int(vals[1]), int(vals[2]), vals[5]]
+        data.append(int(vals[7]))
+    if current[0]:
+        result = call_cpt(current[1], current[2], current[3], data, len(bamfiles))
+        if result:
+            utr[current[0]] = result
+
+    for fname in [tmpbam.name + ".bam", tmpbam.name + ".bam.bai"]:
+        if os.path.exists(fname):
+            os.unlink(fname)
+
+    tmpbam.close()
+    tmp.close()
+
+    return utr
+
+def print_updated_bed(bedfile, bamfiles):
+    utr = call_utr(bedfile, bamfiles)
+    for line in open(bedfile):
+        if line.startswith("track") or line[0] == "#":
+            print line.strip()
+            continue
+
+        vals = line.strip().split("\t")
+        start,end = int(vals[1]), int(vals[2])
+        strand = vals[5]
+        name = vals[3]
+        thickstart, thickend = int(vals[6]), int(vals[7])
+        exonsizes = [int(x) for x in vals[10].split(",") if x]
+        exonstarts = [int(x) for x in vals[11].split(",") if x]
+
+        if utr.has_key(name):
+            sys.stderr.write("Updating {}\n".format(name))
+            utr_start, utr_end = utr[name]
+
+            if strand == "+":
+                if utr_end < thickend:
+                    sys.stderr.write("Setting end of {} to CDS end\n".format(name))
+                    utr_end = thickend
+                diff = exonsizes[-1] - (utr_end - utr_start)
+                end -= diff
+
+                exonsizes[-1] -= diff
+
+                vals[2] = end
+                vals[10] = ",".join([str(x) for x in exonsizes] + [""])
+            else:
+                if utr_start > thickstart:
+                    sys.stderr.write("Setting start of {} to CDS start\n".format(name))
+                    utr_start = thickstart
+                diff = exonsizes[0] - (utr_end - utr_start)
+                sys.stderr.write("{} {} {} diff: {}\n".format(utr_start, utr_end, exonsizes[0], diff))
+                start += diff
+
+                exonstarts = [0] + [x - diff for x in exonstarts[1:]]
+                exonsizes[0] -= diff
+
+                vals[1] = start
+                vals[10] = ",".join([str(x) for x in exonsizes] + [""])
+                vals[11] = ",".join([str(x) for x in exonstarts] + [""])
+
+            print "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format(*vals)
+        else:
+            print line.strip()
+
diff --git a/pita_test.db b/pita_test.db
new file mode 100644
index 0000000..bb5fc8d
Binary files /dev/null and b/pita_test.db differ
diff --git a/scripts/bam2splicecount b/scripts/bam2splicecount
new file mode 100755
index 0000000..767cd61
--- /dev/null
+++ b/scripts/bam2splicecount
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+#JGIv7b.000209344 345095 349191 JGIv7b.000209344:345095-349191_VV 600 - 345095 349191 0,0,0 3 1061,183,186, 0,1544,3910,
+import sys
+import subprocess as sp
+
+if len(sys.argv) != 2:
+    sys.stderr.write("Usage:\nbam2splicecount <bamfile>\n\n")
+    sys.stderr.write("Extract number of splice junctions from a BAM file.\n")
+    sys.stderr.write("Can consume a lot of memory depending on the size of the BAM file.\n")
+    sys.exit(1)
+
+cmd = "samtools view -h {} | awk '($6 ~ /N/ || $1 ~ /^@/)' | samtools view -Sbu - | bedtools bamtobed -bed12"
+
+p = sp.Popen(cmd.format(sys.argv[1]), shell=True, stdout=sp.PIPE)
+splice_count = {}
+for line in p.stdout.readlines():
+    if line.startswith("track") or line[0] == "#":
+        continue
+    vals = line.strip().split("\t")
+    start,end = int(vals[1]), int(vals[2])
+    strand = vals[5]
+    exon_sizes = [int(x) for x in vals[10].strip(",").split(",")[:]]
+    exon_starts = [int(x) for x in vals[11].strip(",").split(",")[:]]
+
+    for estart1, estart2, esize in zip(exon_starts[:-1], exon_starts[1:], exon_sizes[:-1]):
+        sstart = start + estart1 + esize
+        send = start + estart2
+        if sstart > send:
+            sys.stderr.write("Skipping {0}: {1} - {2}\n".format(vals[3], sstart, send))
+        else:
+            splice_count.setdefault((vals[0],sstart, send,strand), 0)
+            splice_count[(vals[0],sstart, send,strand)] += 1
+
+for (chrom, start, end, strand), count in splice_count.items():
+    print "{0}\t{1}\t{2}\t{3}\t{4}\t{5}".format(
+        chrom,
+        start,
+        end,
+        count,
+        0,
+        strand,
+    )
+
diff --git a/scripts/flatbread b/scripts/flatbread
index ea65d0a..8b796da 100755
--- a/scripts/flatbread
+++ b/scripts/flatbread
@@ -55,7 +55,7 @@ for f in a:
 config["chromosomes"] = chroms.keys()

 # Do all intersections with BED file
-for f in config["annotation"]:
+for f in config["annotation"] + config['repeats']:
     path = f['path']
     sys.stderr.write("Filtering {0}\n".format(path))
     if not os.path.exists(path):
@@ -69,7 +69,9 @@ for f in config["annotation"]:
     f['path'] = os.path.basename(path)

 for f in config['data']:
-    for p in f['path']:
+    remove = []
+    for rel_p in f['path']:
+        p = rel_p
         if not os.path.exists(p):
             p = os.path.join(base, p)
         out = os.path.join(outname, os.path.basename(p))
@@ -84,13 +86,27 @@ for f in config['data']:
                 loc = "{0}:{1}-{2}".format(vals[0], vals[1], vals[2])
                 cmd = "samtools view {0} {1} >> {2}".format(p, loc, tmp.name)
                 sp.call(cmd, shell=True)
-            cmd = "samtools view -bS {0} > {1} && samtools index {1}"
-            sp.call(cmd.format(tmp.name, out), shell=True)
-            tmp.close()
+
+            cmd = "samtools view -S {0} | head -n 1".format(tmp.name)
+            pop = sp.Popen(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
+            stdout, stderr = pop.communicate()
+            if stdout.strip():
+
+                cmd = "samtools view -bS {0} > {1} && mv {1} {0}"
+                sp.call(cmd.format(tmp.name, out), shell=True)
+                cmd = "samtools sort {0} {1} && samtools index {2}"
+                sp.call(cmd.format(tmp.name, out.replace(".bam", ""), out), shell=True)
+                tmp.close()
+            else:
+                sys.stderr.write("{} does not contain reads in these regions, leaving it out.\n".format(rel_p))
+
+                remove.append(rel_p)
         else:
             a = pybedtools.BedTool(p)
             b = pybedtools.BedTool(bedfile)
             a.intersect(b, wa=True).saveas(out)
+    for p in remove:
+        f['path'].remove(p)

 config["data_path"] = os.path.abspath(outname)
diff --git a/scripts/pita b/scripts/pita
index b0e5301..03fe620 100755
--- a/scripts/pita
+++ b/scripts/pita
@@ -1,253 +1,173 @@
 #!/usr/bin/env python
-from pita.model import get_chrom_models
+from pita.model import get_chrom_models, load_chrom_data
from pita.util import model_to_bed, read_statistics, exons_to_seq, longest_orf -from pita.log import AnnotationLog +from pita.log import setup_logging +from pita.config import PitaConfig +from pita.annotationdb import AnnotationDb import os import sys -import logging -import yaml import argparse -import pysam -import subprocess -import pp -from tempfile import NamedTemporaryFile import multiprocessing as mp from functools import partial import signal DEFAULT_CONFIG = "pita.yaml" DEFAULT_THREADS = 4 -VALID_TYPES = ["bed", "gff", "gff3", "gtf"] -DEBUG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] p = argparse.ArgumentParser() p.add_argument("-c", dest= "configfile", default = DEFAULT_CONFIG, - help="Configuration file (default: {0})".format(DEFAULT_CONFIG) + help="configuration file (default: {0})".format(DEFAULT_CONFIG) ) p.add_argument("-t", dest= "threads", default = DEFAULT_THREADS, type = int, - help="Number of threads (default: {0})".format(DEFAULT_THREADS) + help="number of threads (default: {0})".format(DEFAULT_THREADS) ) p.add_argument("-i", dest= "index_dir", default = None, - help="Genome index dir" + help="genome index dir" + ) +p.add_argument("-y", + dest= "yaml_file", + default = None, + help="dump database to yaml file" ) p.add_argument("-d", dest= "debug_level", default = "INFO", help="Debug level" ) +p.add_argument("-r", + dest ="reannotate", + default = False, + action = "store_true", + help="reannotate using existing database" + ) + args = p.parse_args() configfile = args.configfile + +if not os.path.exists(configfile): + print "Missing config file {}".format(configfile) + print + p.print_help() + sys.exit() + threads = args.threads index = args.index_dir -debug_level = args.debug_level.upper() - -if not debug_level in DEBUG_LEVELS: - sys.stderr.write("Invalid debug level {0}\n".format(debug_level)) - sys.stderr.write("Valid values are {0}\n".format(",".join(DEBUG_LEVELS))) - sys.exit(1) +debug_level = args.debug_level basename = os.path.splitext(configfile)[0] # Setup logging -logger = logging.getLogger("pita") -logger.setLevel(getattr(logging, debug_level)) -formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s') -handler = logging.StreamHandler() -handler.setFormatter(formatter) -handler.setLevel(getattr(logging, debug_level)) -fh = logging.FileHandler("{0}.log".format(basename)) -fh.setFormatter(formatter) -fh.setLevel(getattr(logging, debug_level)) -logger.addHandler(handler) -logger.addHandler(fh) - -# Parse YAML config file -f = open(configfile, "r") -config = yaml.load(f) - -# Data directory -base = "." 
-if config.has_key("data_path"): - base = config["data_path"] +logger = setup_logging(basename, debug_level) + +# Load config file +config = PitaConfig(configfile, args.reannotate) # FASTA output protein_fh = open("{0}.protein.fa".format(basename), "w") cdna_fh = open("{0}.cdna.fa".format(basename), "w") -prune = None -if config.has_key("prune_overlap"): - prune = config["prune_overlap"] - -if not config.has_key("annotation") or len(config["annotation"]) == 0: - logger.error("No annotation files specified.") - sys.exit(1) - -anno_files = [] -chroms = {} -for d in config["annotation"]: - logger.debug("annotation: {0}".format(d)) - fname = os.path.join(base, d["path"]) - t = d["type"].lower() - min_exons = 2 - if d.has_key("min_exons"): - min_exons = d["min_exons"] - if not t in VALID_TYPES: - logger.error("Invalid type: {0}".format(t)) - sys.exit(1) - if not os.path.exists(fname): - logger.error("File does not exist: {0}".format(fname)) - sys.exit(1) - else: - logger.info("Creating tabix index for {0}".format(os.path.basename(fname))) - logger.debug("Preparing {0} for tabix".format(fname)) - tmp = NamedTemporaryFile(prefix="pita") - preset = "gff" - if t == "bed": - cmd = "sort -k1,1 -k2g,2 {0} | grep -v track | grep -v \"^#\" > {1}" - preset = "bed" - elif t in ["gff", "gff3", "gtf3"]: - cmd = "sort -k1,1 -k4g,4 {0} | grep -v \"^#\" > {1}" - - # Sort the input file - logger.debug(cmd.format(fname, tmp.name)) - subprocess.call(cmd.format(fname, tmp.name), shell=True) - # Compress using bgzip - logger.debug("compressing {0}".format(tmp.name)) - tabix_file = tmp.name + ".gz" - pysam.tabix_compress(tmp.name, tabix_file) - tmp.close() - # Index (using tabix command line, as pysam.index results in a Segmentation fault - logger.debug("indexing {0}".format(tabix_file)) - subprocess.call("tabix {0} -p {1}".format(tabix_file, preset), shell=True) - - #fobj = pysam.Tabixfile(tabix_file) - # Add file info - anno_files.append([d["name"], tabix_file, t, min_exons]) - # Save chromosome names - for chrom in pysam.Tabixfile(tabix_file).contigs: - chroms[chrom] = 1 - -# data config -logger.info("Checking data files") -data = [] -if config.has_key("data") and config["data"]: - for d in config["data"]: - logger.debug("data: {0}".format(d)) - d.setdefault("up", 0) - d.setdefault("down", 0) - if type("") == type(d["path"]): - d["path"] = [d["path"]] - - - names_and_stats = [] - fnames = [os.path.join(base, x) for x in d["path"]] - for fname in fnames: - if not os.path.exists(fname): - logger.error("File does not exist: {0}".format(fname)) - sys.exit(1) - - if fname.endswith("bam") and not os.path.exists(fname + ".bai"): - logger.error("BAM file {0} needs to be indexed!".format(fname)) - sys.exit(1) - - #if fname.endswith("bam"): - # names_and_stats.append((fname, read_statistics(fname))) - #else: - # names_and_stats.append((fname, None)) - row = [d["name"], fnames, d["feature"], (d["up"], d["down"])] - data.append(row) - -weight = {} -if config.has_key("scoring"): - weight = config["scoring"] - line_format = "{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\t{7}\t{8}\t{9}\t{10}\t{11}" print 'track name="{0}"'.format(basename) -chroms = chroms.keys() - -if config.has_key("chromosomes") and config["chromosomes"]: - if type(config["chromosomes"]) == type([]): - chroms = config["chromosomes"] - else: - chroms = [config["chromosomes"]] +chroms = config.chroms def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN) -def print_output(alog, genename, exons, ev, best_ev, other_ev, lock=None): - print model_to_bed(exons, 
genename) - alog.log_to_file(genename, exons, ev, best_ev, other_ev ) +def print_output(genename, exons, lock=None): + #alog.log_to_file(genename, exons) if lock: lock.acquire() + sys.stdout.write("{}\n".format(model_to_bed(exons, genename))) + sys.stdout.flush() # Print sequences cdna_fh.write(">{0}\n{1}\n".format(genename, exons_to_seq(exons))) - protein_fh.write(">{0}\n{1}\n".format(genename, longest_orf(exons, do_prot=True))) + + pep = longest_orf(exons, do_prot=True) + if len(pep) >= config.min_protein_size: + protein_fh.write(">{0}\n{1}\n".format(genename, pep)) cdna_fh.flush() protein_fh.flush() if lock: lock.release() -def listener(q, names, append=False, lock=None): +def listener(q, lock=None): '''listens for messages on the q, writes to file. ''' - - alog = AnnotationLog(append) - for name in names: - alog.add(name) while 1: m = q.get() if m == 'kill': break - - genename, exons, ev, best_ev, other_ev = m + + genename, exons = m logger.debug("calling print_output for {0}".format(genename)) - print_output(alog, genename, exons, ev, best_ev, other_ev, lock) - -def annotate_chrom(chrom, q, anno_files, data, weight, prune, index): + print_output(genename, exons, lock) + +def annotate_chrom(chrom, conn, q, anno_files, data, repeats, weight, prune, keep, filter, experimental, index, reannotate): + new = False + if conn.startswith("sqlite"): + conn += ".{}".format(chrom) + if not reannotate: + new = True logger.info("Chromosome {0} started".format(chrom)) - for genename, best_exons, ev, best_ev, other_ev in get_chrom_models(chrom, anno_files, data, weight, prune, index): + if not reannotate: + load_chrom_data(conn, new, chrom, anno_files, data, index) + for genename, best_exons in get_chrom_models(conn, chrom, weight, repeats, prune, keep, filter, experimental): + #results.append([genename, best_exons]) logger.debug("Putting {0} in print queue".format(genename)) - q.put([genename, best_exons, ev, best_ev, other_ev]) + q.put([genename, best_exons]) logger.info("Chromosome {0} finished".format(chrom)) +# Initialize database +if not args.reannotate: + db = AnnotationDb(new=True, conn=config.db_conn) + if threads > 1: logger.info("Starting threaded work") manager = mp.Manager() lock = manager.Lock() - q = manager.Queue() - - pool = mp.Pool(threads, init_worker) + q = manager.Queue() + pool = mp.Pool(threads, init_worker, maxtasksperchild=1) - try: - partialAnnotate = partial(annotate_chrom, q=q, anno_files=anno_files, data=data, weight=weight, prune=prune, index=index) + try: + #put listener to work first - watcher = pool.apply_async(listener, (q,[x[0] for x in anno_files], False, lock)) + watcher = pool.apply_async(listener, args=(q, lock) ) + + # do the main work + partialAnnotate = partial(annotate_chrom, conn=config.db_conn, q=q, anno_files=config.anno_files, data=config.data, repeats=config.repeats, weight=config.weight, prune=config.prune, keep=config.keep, filter=config.filter, experimental=config.experimental, index=index, reannotate=args.reannotate) pool.map(partialAnnotate, chroms) + + # kill the queue! 
q.put('kill') pool.close() pool.join() + except KeyboardInterrupt: logger.exception("Caught KeyboardInterrupt, terminating workers") pool.terminate() pool.join() else: - alog = AnnotationLog(append=False) - for name in [x[0] for x in anno_files]: - alog.add(name) - for chrom in chroms: - for genename,best_exons, ev, best_ev, other_ev in get_chrom_models(chrom, anno_files, data, weight, prune, index): - print_output(alog, genename,best_exons, ev, best_ev, other_ev) + if not args.reannotate: + load_chrom_data(config.db_conn, True, chrom, config.anno_files, config.data, index) + for genename, best_exons in get_chrom_models(config.db_conn, chrom, config.weight, repeats=config.repeats, prune=config.prune, keep=config.keep, filter=config.filter, experimental=config.experimental): + print_output(genename, best_exons) cdna_fh.close() protein_fh.close() + +# dump database +if args.yaml_file: + with AnnotationDb(new=False) as db: + with open(args.yaml_file, "w") as f: + f.write(db.dump_yaml()) diff --git a/scripts/pita_utr b/scripts/pita_utr new file mode 100755 index 0000000..0d6bf82 --- /dev/null +++ b/scripts/pita_utr @@ -0,0 +1,37 @@ +#!/usr/bin/env python +import sys +import os +import argparse +from pita.utr import * + +p = argparse.ArgumentParser() +p.add_argument("-i", + dest= "bedfile", + help="genes in BED12 format", + ) +p.add_argument("-b", + dest= "bamfiles", + help="list of RNA-seq BAM files (seperated by comma)", + ) +args = p.parse_args() + +if not args.bedfile or not args.bamfiles: + p.print_help() + sys.exit() + +bedfile = args.bedfile +bamfiles = args.bamfiles.split(",") + +for bamfile in bamfiles: + if not os.path.exists(bamfile): + sys.stderr.write("BAM file {} does not exist.\n".format(bamfile)) + sys.exit(1) + if not os.path.exists(bamfile + ".bai"): + sys.stderr.write("index file {}.bai does not exist.\n".format(bamfile)) + sys.exit(1) + +if not os.path.exists(bedfile): + sys.stderr.write("BED file {} does not exist.\n".format(bedfile)) + sys.exit(1) + +print_updated_bed(bedfile, bamfiles) diff --git a/setup.py b/setup.py index a3c5846..c4d4dee 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools.command.test import test as TestCommand import sys -VERSION = "1.7" +VERSION = "1.72" DESCRIPTION = """ pita - pita improves transcript annotation """ @@ -29,6 +29,7 @@ def run_tests(self): ], scripts=[ "scripts/pita", + "scripts/pita_utr", "scripts/bed12togff3", "scripts/gff3tobed12", "scripts/flatbread", @@ -37,8 +38,9 @@ def run_tests(self): data_files=[], tests_require=['pytest'], install_requires=[ + "SQLAlchemy", "gimmemotifs", - "pysam >= 0.7.4", + "pysam < 0.8", "pyyaml", "HTSeq", "bcbio-gff", diff --git a/tests/data/merge1.yaml b/tests/data/merge1.yaml new file mode 100644 index 0000000..594cc0c --- /dev/null +++ b/tests/data/merge1.yaml @@ -0,0 +1,25 @@ +read_source: + - [1,H3K4me3,/home/simon/prj/laevis/pita/data/H3K4me3_stage14_300.PE.bam,244989632] + - [2,RNAseq,/home/simon/prj/laevis/pita/data/C022YABXX_8_2170.sorted.bam,139030053] + +feature: + - [1, chr1,181322,181786,-,exon,""] + - [2, chr1,187341,187455,-,exon,""] + - [3, chr1,188581,188699,-,exon,""] + +read_count: + - [1,1,1,all,0,0] + - [1,2,10,all,0,0] + - [1,3,100,all,0,0] + - [2,1,5,all,0,0] + - [2,2,25,all,0,0] + - [2,3,125,all,0,0] + +evidence: + - [1, "est1", "file1"] + +feature_evidence: + - [1,1] + - [2,1] + - [3,1] + diff --git a/tests/data/merge2.yaml b/tests/data/merge2.yaml new file mode 100644 index 0000000..f7d472f --- /dev/null +++ b/tests/data/merge2.yaml @@ -0,0 +1,27 @@ 
+read_source: + - [2,H3K4me3,/home/simon/prj/laevis/pita/data/H3K4me3_stage14_300.PE.bam,244989632] + - [1,RNAseq,/home/simon/prj/laevis/pita/data/C022YABXX_8_2170.sorted.bam,139030053] + +feature: + - [1, chr2,181322,181786,-,exon,""] + - [2, chr2,187341,187455,-,exon,""] + - [3, chr2,188581,188699,-,exon,""] + +read_count: + - [1,1,1,all,0,0] + - [1,2,10,all,0,0] + - [1,3,100,all,0,0] + - [2,1,5,all,0,0] + - [2,2,25,all,0,0] + - [2,3,125,all,0,0] + +evidence: + - [1, "est2.1", "file1"] + - [3, "est2.2", "file1"] + +feature_evidence: + - [1,1] + - [2,1] + - [3,1] + - [2,3] + - [3,3] diff --git a/tests/data/read_stats.bed b/tests/data/read_stats.bed new file mode 100644 index 0000000..b92defb --- /dev/null +++ b/tests/data/read_stats.bed @@ -0,0 +1,5 @@ +scaffold_1 18070000 18080000 +scaffold_1 18080000 18200000 +scaffold_1 18200000 18200100 +scaffold_1 18200100 18250000 +scaffold_1 18250000 18300000 diff --git a/tests/data/splice_data.bed b/tests/data/splice_data.bed new file mode 100644 index 0000000..4ff5870 --- /dev/null +++ b/tests/data/splice_data.bed @@ -0,0 +1,2 @@ +scaffold_1 18080000 18200000 4 0 + +scaffold_1 18200100 18250000 20 0 + diff --git a/tests/test_collection.py b/tests/test_collection.py index 7a677a7..2f1c111 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -143,7 +143,7 @@ def test_long_exon_filter(t1, t2): def short_intron_track(): return "tests/data/short_introns.bed" -def test_long_exon_filter(short_intron_track): +def test_short_intron_filter(short_intron_track): from pita.collection import Collection from pita.io import read_bed_transcripts from pita.util import model_to_bed diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..ac167f7 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,177 @@ +import pytest + +@pytest.fixture +def three_exons(): + exons = [ + ["chr1", 100, 200, "+"], + ["chr1", 300, 400, "+"], + ["chr2", 100, 200, "-"], + ] + return exons + +@pytest.fixture +def three_transcripts(): + transcripts = [ + ["t1", "annotation", + [ + ["chr1", 100, 200, "+"], + ["chr1", 300, 400, "+"], + ["chr1", 500, 600, "+"], + ] + ], + ["t2", "annotation", + [ + ["chr1", 50, 200, "+"], + ["chr1", 300, 400, "+"], + ] + ], + ["t3", "annotation", + [ + ["chr1", 300, 400, "+"], + ["chr1", 500, 800, "+"], + ] + ], + ] + return transcripts + +@pytest.fixture +def two_transcripts(): + transcripts = [ + ["t1", "annotation", + [ + ["chr1", 700, 900, "+"], + ["chr1", 1100, 1200, "+"], + ] + ], + ["t2", "annotation", + [ + ["chr1", 1100, 1200, "+"], + ["chr1", 1400, 1600, "+"], + ] + ], + ] + return transcripts + +@pytest.fixture +def bam_file(): + return "tests/data/H3K4me3.bam" + +@pytest.fixture +def transcripts(): + transcripts = [ + ["t1", "annotation", + [ + ["scaffold_1", 18070000, 18080000, "+"], + ["scaffold_1", 18200000, 18200100, "+"], + ["scaffold_1", 18250000, 18300000, "+"], + ] + ] + + ] + return transcripts + +@pytest.yield_fixture +def db(transcripts): + from pita.annotationdb import AnnotationDb + with AnnotationDb(conn="sqlite:////tmp/pita_test.db", new=True) as d: + for name, source, exons in transcripts: + d.add_transcript(name, source, exons) + yield d + +@pytest.yield_fixture +def empty_db(): + from pita.annotationdb import AnnotationDb + with AnnotationDb(conn="sqlite:////tmp/pita_test.db", new=True) as d: + yield d +#scaffold_1 18070000 18080000 64 +#scaffold_1 18080000 18200000 1092 +#scaffold_1 18200000 18200100 1 +#scaffold_1 18200100 18250000 318 +#scaffold_1 18250000 18300000 300 + +def 
test_read_statistics(bam_file, db): + db.get_read_statistics("scaffold_1", bam_file, "test") + exons = db.get_exons() + counts = [e.read_counts[0].count for e in exons] + assert [1,64,300] == sorted(counts) + assert 5218 == db.nreads("test") + +@pytest.fixture +def splice_file(): + return "tests/data/splice_data.bed" + +def test_splice_statistics(db, splice_file): + db.get_splice_statistics("scaffold_1", splice_file, "test") + splices = db.get_splice_junctions() + counts = [s.read_counts[0].count for s in splices] + assert 2 == len(splices) + assert [4,20] == counts + +def test_get_weight(db, bam_file, splice_file): + db.get_read_statistics("scaffold_1", bam_file, "H3K4me3") + db.get_splice_statistics("scaffold_1", splice_file, "RNAseq") + from pita.dbcollection import DbCollection + c = DbCollection(db) + + model = list(c.get_connected_models())[0][0] + w = c.get_weight(model, "H3K4me3", "all") + assert 365 == w + w = c.get_weight(model, None, "length") + assert 60100 == w + w = c.get_weight(model, "H3K4me3", "rpkm") + assert abs(1163.9 - w) < 0.1 + w = c.get_weight(model, "H3K4me3", "weighted") + assert abs(0.01821963394342762 - w) < 0.0001 + w = c.get_weight(model, "H3K4me3", "total_rpkm") + assert abs(4292.832 - w) < 0.1 + w = c.get_weight(model, "H3K4me3", "mean_exon") + assert abs(1430.944 - w) < 0.1 + w = c.get_weight(model, "RNAseq", "splice") + assert 24 == w + w = c.get_weight(model, "H3K4me3", "first") + assert 64 == w + w = c.get_weight(model, None, "evidence") + assert 1 == w + +def test_get_junction_exons(db): + splices = db.get_splice_junctions() + splice = [s for s in splices if s.start == 18080000][0] + + exon_pairs = db.get_junction_exons(splice) + assert 1 == len(exon_pairs) + e1,e2 = exon_pairs[0] + assert e1.start == 18070000 + assert e1.end == 18080000 + assert e2.start == 18200000 + assert e2.end == 18200100 + +def test_db_collection(db): + from pita.dbcollection import DbCollection + c = DbCollection(db) + + for model in c.get_connected_models(): + print model + +def test_get_long_exons(db): + assert 0 == len(db.get_long_exons("scaffold_1", 100000, 2)) + assert 1 == len(db.get_long_exons("scaffold_1", 50000, 2)) + assert 2 == len(db.get_long_exons("scaffold_1", 10000, 2)) + assert 3 == len(db.get_long_exons("scaffold_1", 50, 2)) + +def test_load_yaml(empty_db): + db = empty_db + db.load_yaml("tests/data/merge1.yaml") + db.load_yaml("tests/data/merge2.yaml") + + for e in db.get_exons(): + print str(e) + + assert 6 == len([e for e in db.get_exons()]) + + l = [len(e.evidences) for e in db.get_exons()] + print l + assert sorted(l) == [1,1,1,1,2,2] + + + +
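
The tests in tests/test_db.py double as the most compact reference for the new database-backed API. The following is a minimal usage sketch pieced together from those tests, not documented usage: the interfaces (AnnotationDb, add_transcript, get_read_statistics, get_splice_statistics, DbCollection, get_connected_models, get_weight) are taken directly from the tests above, the BAM and splice-junction BED files are the bundled test data, and the sqlite path is a throwaway placeholder.

    from pita.annotationdb import AnnotationDb
    from pita.dbcollection import DbCollection

    # Bundled test data; replace with real files
    bam_file = "tests/data/H3K4me3.bam"
    splice_file = "tests/data/splice_data.bed"

    # Throwaway sqlite database (placeholder path)
    with AnnotationDb(conn="sqlite:////tmp/pita_example.db", new=True) as db:
        # Register a transcript as a list of [chrom, start, end, strand] exons
        db.add_transcript("t1", "annotation", [
            ["scaffold_1", 18070000, 18080000, "+"],
            ["scaffold_1", 18200000, 18200100, "+"],
            ["scaffold_1", 18250000, 18300000, "+"],
        ])

        # Attach read counts and splice-junction counts to the stored features
        db.get_read_statistics("scaffold_1", bam_file, "H3K4me3")
        db.get_splice_statistics("scaffold_1", splice_file, "RNAseq")

        # Build connected models and score them
        c = DbCollection(db)
        model = list(c.get_connected_models())[0][0]
        print c.get_weight(model, "H3K4me3", "all")
        print c.get_weight(model, "RNAseq", "splice")

The same pattern should work with either a per-chromosome sqlite file or the mysql connection string that scripts/pita builds from the config, since only the conn argument changes.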