diff --git a/Mikado/loci/locus.py b/Mikado/loci/locus.py index f25f355bd..1a823e2be 100644 --- a/Mikado/loci/locus.py +++ b/Mikado/loci/locus.py @@ -823,28 +823,6 @@ def pad_transcripts(self) -> set: self.exons.update(self[tid].exons) return templates - def define_graph(self, objects: dict, inters=None, three_prime=False): - - graph = nx.DiGraph() - graph.add_nodes_from(objects.keys()) - - if inters is None: - inters = self._share_extreme - - for obj, other_obj in combinations(sorted(objects.keys()), 2): - self.logger.debug("Comparing %s to %s (%s')", obj, other_obj, "5" if not three_prime else "3") - if obj == other_obj: - continue - else: - edge = inters(objects[obj], objects[other_obj], three_prime=three_prime) - if edge: - assert edge[0].id in self - assert edge[1].id in self - # assert edge[1].id in self.scores - graph.add_edge(edge[0].id, edge[1].id) - - return graph - def _find_communities_boundaries(self, five_graph, three_graph): five_found = set() @@ -891,6 +869,42 @@ def _find_communities_boundaries(self, five_graph, three_graph): return __to_modify + def define_graph(self, objects: dict, inters=None, three_prime=False): + + graph = nx.DiGraph() + graph.add_nodes_from(objects.keys()) + if inters is None: + inters = self._share_extreme + + if len(objects) >= 2: + if (three_prime is True and self.strand != "-") or (three_prime is False and self.strand == "-"): + reverse = True + else: + reverse = False + order = sorted([(objects[tid].start, objects[tid].end, tid) for tid in objects], reverse=reverse) + + for pos in range(len(order) - 1): + obj = order[pos] + self.logger.warning("Checking %s", obj[2]) + for other_obj in order[pos + 1:]: + if obj == other_obj: + continue + elif self.overlap(obj[:2], obj[:2], positive=False, flank=0) == 0: + break + else: + self.logger.warning("Comparing %s to %s (%s')", obj[2], other_obj[2], + "5" if not three_prime else "3") + edge = inters(objects[obj[2]], objects[other_obj[2]], three_prime=three_prime) + if edge: + assert edge[0].id in self + assert edge[1].id in self + # assert edge[1].id in self.scores + graph.add_edge(edge[0].id, edge[1].id) + else: + self.logger.warning("No comparison to be made (objects: %s)", objects) + + return graph + def _share_extreme(self, first: Transcript, second: Transcript, three_prime=False): """ diff --git a/Mikado/loci/superlocus.py b/Mikado/loci/superlocus.py index 26965d979..c6baa6820 100644 --- a/Mikado/loci/superlocus.py +++ b/Mikado/loci/superlocus.py @@ -29,6 +29,8 @@ from ..utilities import dbutils, grouper from ..scales.assigner import Assigner import bisect +import operator +import functools import numpy as np if version_info.minor < 5: from sortedcontainers import SortedDict @@ -1270,13 +1272,12 @@ def define_alternative_splicing(self): cdna_overlap = self.json_conf["pick"]["alternative_splicing"]["min_cdna_overlap"] self.logger.debug("Defining the transcript graph") - t_graph = super().define_graph(self.transcripts, - inters=MonosublocusHolder.is_intersecting, - cds_only=cds_only, - logger=self.logger, - min_cdna_overlap=cdna_overlap, - min_cds_overlap=cds_overlap, - simple_overlap_for_monoexonic=False) + t_graph = self.define_as_graph(inters=MonosublocusHolder.is_intersecting, + cds_only=cds_only, + min_cdna_overlap=cdna_overlap, + min_cds_overlap=cds_overlap, + simple_overlap_for_monoexonic=False) + self.logger.debug("Defined the transcript graph") loci_cliques = dict() @@ -1285,8 +1286,8 @@ def define_alternative_splicing(self): neighbors = set(t_graph.neighbors(locus_instance.primary_transcript_id)) except networkx.exception.NetworkXError: raise networkx.exception.NetworkXError( - "{} {} {}".format( - locus_instance.primary_transcript.attributes["Alias"], + "{} {}".format( + # locus_instance.primary_transcript.attributes["Alias"], locus_instance.primary_transcript_id, list(t_graph.nodes) )) @@ -1423,10 +1424,48 @@ def define_graph(self, cds_only=False) -> networkx.Graph: for key in intronic: graph.add_edges_from(itertools.combinations(intronic[key], 2)) - # This will be quadratic. Hopefully it will not break the program. - for one, other in itertools.combinations(monoexonic, 2): - if self.overlap((one[0], one[1]), (other[0], other[1]), positive=False) > 0: - graph.add_edge(one[2], other[2]) + if len(monoexonic) > 1: + for pos in range(len(monoexonic) - 1): + one = monoexonic[pos] + for other in monoexonic[pos + 1:]: + if self.overlap((one[0], one[1]), (other[0], other[1]), positive=False) > 0: + graph.add_edge(one[2], other[2]) + else: + break + + return graph + + def define_as_graph(self, + inters=MonosublocusHolder.is_intersecting, + cds_only=False, + min_cdna_overlap=0.2, + min_cds_overlap=0.2, + simple_overlap_for_monoexonic=True): + + """This method will try to build the AS graph using a O(nlogn) rather than O(n^2) algorithm.""" + + method = functools.partial(inters, cds_only=cds_only, min_cdna_overlap=min_cdna_overlap, + min_cds_overlap=min_cds_overlap, + simple_overlap_for_monoexonic=simple_overlap_for_monoexonic, + logger=self.logger) + + graph = networkx.Graph() + graph.add_nodes_from(self.transcripts.keys()) + if len(self.transcripts) >= 2: + if cds_only: + order = sorted([(transcript.selected_cds_start, transcript.selected_cds_end, transcript.id) + for transcript in self.transcripts.values()], key=operator.itemgetter(0, 1)) + else: + order = sorted([(transcript.start, transcript.end, transcript.id) + for transcript in self.transcripts.values()], key=operator.itemgetter(0, 1)) + + for pos in range(len(order) -1 ): + one = order[pos] + for other in order[pos + 1:]: + if self.overlap((one[0], one[1]), (other[0], other[1]), positive=True) <= 0: + break + elif method(self[one[2]], self[other[2]]) is True: + graph.add_edge(one[2], other[2]) return graph