Skip to content

Commit

Permalink
Removed quadratic bottlenecks from the definition of Superloci and of…
Browse files Browse the repository at this point in the history
… AS graphs. This should speed up Mikado considerably for transcript-dense inputs.
  • Loading branch information
lucventurini committed Aug 18, 2019
1 parent cab1c5f commit f99d003
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 35 deletions.
58 changes: 36 additions & 22 deletions Mikado/loci/locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,28 +823,6 @@ def pad_transcripts(self) -> set:
self.exons.update(self[tid].exons)
return templates

def define_graph(self, objects: dict, inters=None, three_prime=False):

graph = nx.DiGraph()
graph.add_nodes_from(objects.keys())

if inters is None:
inters = self._share_extreme

for obj, other_obj in combinations(sorted(objects.keys()), 2):
self.logger.debug("Comparing %s to %s (%s')", obj, other_obj, "5" if not three_prime else "3")
if obj == other_obj:
continue
else:
edge = inters(objects[obj], objects[other_obj], three_prime=three_prime)
if edge:
assert edge[0].id in self
assert edge[1].id in self
# assert edge[1].id in self.scores
graph.add_edge(edge[0].id, edge[1].id)

return graph

def _find_communities_boundaries(self, five_graph, three_graph):

five_found = set()
Expand Down Expand Up @@ -891,6 +869,42 @@ def _find_communities_boundaries(self, five_graph, three_graph):

return __to_modify

def define_graph(self, objects: dict, inters=None, three_prime=False):

graph = nx.DiGraph()
graph.add_nodes_from(objects.keys())
if inters is None:
inters = self._share_extreme

if len(objects) >= 2:
if (three_prime is True and self.strand != "-") or (three_prime is False and self.strand == "-"):
reverse = True
else:
reverse = False
order = sorted([(objects[tid].start, objects[tid].end, tid) for tid in objects], reverse=reverse)

for pos in range(len(order) - 1):
obj = order[pos]
self.logger.warning("Checking %s", obj[2])
for other_obj in order[pos + 1:]:
if obj == other_obj:
continue
elif self.overlap(obj[:2], obj[:2], positive=False, flank=0) == 0:
break
else:
self.logger.warning("Comparing %s to %s (%s')", obj[2], other_obj[2],
"5" if not three_prime else "3")
edge = inters(objects[obj[2]], objects[other_obj[2]], three_prime=three_prime)
if edge:
assert edge[0].id in self
assert edge[1].id in self
# assert edge[1].id in self.scores
graph.add_edge(edge[0].id, edge[1].id)
else:
self.logger.warning("No comparison to be made (objects: %s)", objects)

return graph

def _share_extreme(self, first: Transcript, second: Transcript, three_prime=False):

"""
Expand Down
65 changes: 52 additions & 13 deletions Mikado/loci/superlocus.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from ..utilities import dbutils, grouper
from ..scales.assigner import Assigner
import bisect
import operator
import functools
import numpy as np
if version_info.minor < 5:
from sortedcontainers import SortedDict
Expand Down Expand Up @@ -1270,13 +1272,12 @@ def define_alternative_splicing(self):
cdna_overlap = self.json_conf["pick"]["alternative_splicing"]["min_cdna_overlap"]

self.logger.debug("Defining the transcript graph")
t_graph = super().define_graph(self.transcripts,
inters=MonosublocusHolder.is_intersecting,
cds_only=cds_only,
logger=self.logger,
min_cdna_overlap=cdna_overlap,
min_cds_overlap=cds_overlap,
simple_overlap_for_monoexonic=False)
t_graph = self.define_as_graph(inters=MonosublocusHolder.is_intersecting,
cds_only=cds_only,
min_cdna_overlap=cdna_overlap,
min_cds_overlap=cds_overlap,
simple_overlap_for_monoexonic=False)

self.logger.debug("Defined the transcript graph")

loci_cliques = dict()
Expand All @@ -1285,8 +1286,8 @@ def define_alternative_splicing(self):
neighbors = set(t_graph.neighbors(locus_instance.primary_transcript_id))
except networkx.exception.NetworkXError:
raise networkx.exception.NetworkXError(
"{} {} {}".format(
locus_instance.primary_transcript.attributes["Alias"],
"{} {}".format(
# locus_instance.primary_transcript.attributes["Alias"],
locus_instance.primary_transcript_id,
list(t_graph.nodes)
))
Expand Down Expand Up @@ -1423,10 +1424,48 @@ def define_graph(self, cds_only=False) -> networkx.Graph:
for key in intronic:
graph.add_edges_from(itertools.combinations(intronic[key], 2))

# This will be quadratic. Hopefully it will not break the program.
for one, other in itertools.combinations(monoexonic, 2):
if self.overlap((one[0], one[1]), (other[0], other[1]), positive=False) > 0:
graph.add_edge(one[2], other[2])
if len(monoexonic) > 1:
for pos in range(len(monoexonic) - 1):
one = monoexonic[pos]
for other in monoexonic[pos + 1:]:
if self.overlap((one[0], one[1]), (other[0], other[1]), positive=False) > 0:
graph.add_edge(one[2], other[2])
else:
break

return graph

def define_as_graph(self,
inters=MonosublocusHolder.is_intersecting,
cds_only=False,
min_cdna_overlap=0.2,
min_cds_overlap=0.2,
simple_overlap_for_monoexonic=True):

"""This method will try to build the AS graph using a O(nlogn) rather than O(n^2) algorithm."""

method = functools.partial(inters, cds_only=cds_only, min_cdna_overlap=min_cdna_overlap,
min_cds_overlap=min_cds_overlap,
simple_overlap_for_monoexonic=simple_overlap_for_monoexonic,
logger=self.logger)

graph = networkx.Graph()
graph.add_nodes_from(self.transcripts.keys())
if len(self.transcripts) >= 2:
if cds_only:
order = sorted([(transcript.selected_cds_start, transcript.selected_cds_end, transcript.id)
for transcript in self.transcripts.values()], key=operator.itemgetter(0, 1))
else:
order = sorted([(transcript.start, transcript.end, transcript.id)
for transcript in self.transcripts.values()], key=operator.itemgetter(0, 1))

for pos in range(len(order) -1 ):
one = order[pos]
for other in order[pos + 1:]:
if self.overlap((one[0], one[1]), (other[0], other[1]), positive=True) <= 0:
break
elif method(self[one[2]], self[other[2]]) is True:
graph.add_edge(one[2], other[2])

return graph

Expand Down

0 comments on commit f99d003

Please sign in to comment.