From fd3396367ea1547e87262b74f5465a499337f2bf Mon Sep 17 00:00:00 2001 From: Jules Kreuer Date: Sat, 13 Apr 2024 14:59:49 +0200 Subject: [PATCH 1/7] add tree fixation using newick string --- algorithms.py | 259 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 251 insertions(+), 8 deletions(-) diff --git a/algorithms.py b/algorithms.py index 484ff7488..18b992cb5 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1,6 +1,7 @@ """ Python version of the simulation algorithm. """ + import argparse import heapq import itertools @@ -127,9 +128,10 @@ def __init__(self, index): self.label = 0 self.index = index self.hull = None + self.origin = set() def __repr__(self): - return repr((self.left, self.right, self.node)) + return repr((self.left, self.right, self.node, self.origin)) @staticmethod def show_chain(seg): @@ -435,6 +437,14 @@ def find_indv(self, indv): """ return self._ancestors[indv.label].index(indv) + def get_emerged_from_lineage(self, index, label): + """ + Returns the indices of lineages that emerged from a given one. + """ + return [ + i for i, s in enumerate(self._ancestors[label]) if index.issubset(s.origin) + ] + class Pedigree: """ @@ -840,6 +850,7 @@ def __init__( gene_conversion_rate=0.0, gene_conversion_length=1, discrete_genome=True, + coalescent_events=[], ): # Must be a square matrix. N = len(migration_matrix) @@ -898,6 +909,15 @@ def __init__( pop.set_growth_rate(population_growth_rates[pop.id], 0) self.edge_buffer = [] + self.fixed_coalescent_events = False + self.coalescent_events = [] + + if coalescent_events: + self.fixed_coalescent_events = True + self.parse_nwk(coalescent_events) + self.coalescent_events.sort() + logger.debug(self.coalescent_events) + # set hull_offset for smc_k, deviates from actual pattern # implemented using `ParametricAncestryModel()` self.hull_offset = None @@ -981,6 +1001,7 @@ def initialise(self, ts): # Insert the segment chains into the algorithm state. for node in range(ts.num_nodes): seg = root_segments_head[node] + seg.origin = {seg.node} if seg is not None: left_end = seg.left pop = seg.population @@ -1072,6 +1093,7 @@ def alloc_segment( next=None, # noqa: A002 label=0, hull=None, + origin={}, ): """ Pops a new segment off the stack and sets its properties. @@ -1085,6 +1107,7 @@ def alloc_segment( s.prev = prev s.label = label s.hull = hull + s.origin = origin return s def copy_segment(self, segment): @@ -1096,6 +1119,7 @@ def copy_segment(self, segment): next=segment.next, prev=segment.prev, label=segment.label, + origin=segment.origin, ) def free_segment(self, u): @@ -1337,7 +1361,39 @@ def hudson_simulate(self, end_time): event = "MOD" else: self.t += min_time - if min_time == t_re: + + if ( + self.fixed_coalescent_events + and self.coalescent_events + and self.coalescent_events[0][0] < self.t + ): + # Fixed Coalescent event. + logger.debug("Fixed CA @", self.t, self.coalescent_events[0]) + event = "CA" + ce = self.coalescent_events.pop(0) + + # Reset to fixed time + prev_time = self.t - min_time + # Check if the current event should have happen earlier + # or at the same time as the last event. + if ce[0] <= prev_time: + # Add epsilon as two events can't happen simultaneously. + self.t = prev_time + 0.00000001 + else: + # Reset to time of ce event and add epsilon to avoid colission with leaf nodes. + self.t = ce[0] + 0.00000001 + + self.common_ancestor_event( + ca_population, + 0, + lineage_a=ce[1], + lineage_b=ce[2], + ) + + if self.P[ca_population].get_num_ancestors() == 0: + non_empty_pops.remove(ca_population) + + elif min_time == t_re: event = "RE" self.hudson_recombination_event(0) elif min_time == t_gcin: @@ -2354,9 +2410,29 @@ def get_random_pair(self, pop, label): return (hull1_index, hull2.index) - def common_ancestor_event(self, population_index, label): + def is_blocked_ancestor(self, i, pop, label): + """ + Checks if the ancestor at index i is required for a future common ancestor event. + """ + # If no fixed events happen in the future (the list is empty): + if not self.coalescent_events: + return False + + # Check if ancestor i is required as ancestor in a future event. + _, ceA, ceB = zip(*self.coalescent_events) + return any(pop._ancestors[label][i].origin.issubset(j) for j in ceA + ceB) + + def common_ancestor_event( + self, + population_index, + label, + lineage_a=None, + lineage_b=None, + blocked=False, + ): """ Implements a coancestry event. + If lineage_a and lineage_b are set, only lines emerged from those will be selected. """ pop = self.P[population_index] @@ -2376,11 +2452,36 @@ def common_ancestor_event(self, population_index, label): self.free_hull(hull_j) else: - # Choose two ancestors uniformly. - j = random.randint(0, pop.get_num_ancestors(label) - 1) - x = pop.remove(j, label) - j = random.randint(0, pop.get_num_ancestors(label) - 1) - y = pop.remove(j, label) + if (lineage_a is None) ^ (lineage_b is None): # a xor b + raise RuntimeError( + "For a fixed Common Ancestor event, both lineages must be named." + ) + + if lineage_a is None and lineage_b is None: + # Choose uniformly from all lineages + all_lineage_ids = range(0, pop.get_num_ancestors(label) - 1) + i, j = random.choices(all_lineage_ids, k=2) + # random sampling without replacement + + if ( + blocked + and self.is_blocked_ancestor(i, pop, label) + or self.is_blocked_ancestor(j, pop, label) + ): + return + + x = pop.remove(i, label) + y = pop.remove(j, label) + + else: + # Get all lineages that emerged from lineage_a and lineage_b + lineage_ids = pop.get_emerged_from_lineage(lineage_a, label) + i = random.choice(lineage_ids) + x = pop.remove(i, label) + + lineage_ids = pop.get_emerged_from_lineage(lineage_b, label) + j = random.choice(lineage_ids) + y = pop.remove(j, label) self.merge_two_ancestors(population_index, label, x, y) @@ -2407,6 +2508,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): y = beta if x.right <= y.left: alpha = x + alpha.origin = x.origin | y.origin x = x.next alpha.next = None elif x.left != y.left: @@ -2414,6 +2516,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): alpha.prev = None alpha.next = None alpha.right = y.left + alpha.origin = x.origin | y.origin x.left = y.left else: if not coalescence: @@ -2446,6 +2549,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): node=u, population=population_index, label=label, + origin=x.origin | y.origin, ) if x.node != u: # required for dtwf and fixed_pedigree self.store_edge(left, right, u, x.node) @@ -2503,6 +2607,132 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): hull.right = min(right + self.hull_offset, self.L) pop.add_hull(label, hull) + def parse_nwk(self, nwk: str) -> None: + """ + Parses newick string. + """ + # Remove Tail and whitespace + nwk = nwk.strip().strip(";") + nwk = nwk.replace(" ", "") + + # Remove Root Time + head, tail = nwk.rsplit(":", 1) + if ")" not in tail: + nwk = head + + # Remove Root Label + if not nwk.endswith(")"): + nwk = nwk.rsplit(")", 1)[0] + ")" + + # Remove redundant brackets + if nwk.startswith("(") and nwk.endswith(")"): + nwk = nwk[1:-1] + + # Labels of leaf nodes + leaf_nwk = [ + label[0] + for n in nwk.replace("(", ",").split(",") + if (label := n.split(":"))[0] + ] + + # Create initial mapping for labels (str) to id (int). + self.leaf_mapping = {label: i for i, label in enumerate(leaf_nwk)} + self.label_mapping = {i: label for i, label in enumerate(leaf_nwk)} + + if len(self.tables.nodes) < len(leaf_nwk): + raise ValueError( + "Population in newick string is larger than the provided sample_size. " + f"Please increase to {len(leaf_nwk)} or greater." + ) + + if len(leaf_nwk) < len(self.tables.nodes): + # TODO check if warning library can be imported + print( + "Population in newick string is smaller than the provided sample_size." + ) + + self.parse_nwk_str(nwk, 0) + + def bracket_split(self, nkw: str): + """ + Splits a string at "," if and only if it is not enclosed by brackets. + Input: + nwk: str, string to split. + Returns: + splits: List[str], list of substrings. + """ + splits = [] + level = 0 + next_split = [] + for c in nkw: + if c == "," and level == 0: + splits.append("".join(next_split)) + next_split = [] + else: + if c == "(": + level += 1 + elif c == ")": + level -= 1 + next_split.append(c) + splits.append("".join(next_split)) + return splits + + def parse_nwk_str(self, newick_str: str, time: float): + """ + Parses a newick string recursivly. + """ + # Split nwk intro left and right part. + left_node, right_node = self.bracket_split(newick_str) + + left_time = float(left_node.split(":")[-1].split(")", 1)[0]) + right_time = float(right_node.split(":")[-1].split(")", 1)[0]) + + # Clean left node and remove label + left_node = left_node.rsplit(":", 1)[0] + left_node = left_node[1:] if left_node.startswith("(") else left_node + left_node = left_node[:-1] if left_node.endswith(")") else left_node + + # Clean right node and remove label + right_node = right_node.rsplit(":", 1)[0] + right_node = right_node[1:] if right_node.startswith("(") else right_node + right_node = right_node[:-1] if right_node.endswith(")") else right_node + + # Check if left node is a leaf + if not "," in left_node: + if left_node in self.leaf_mapping: + # Replace leaf name with mapped id + node_id = self.leaf_mapping[left_node] + else: + raise RuntimeError(f"Invalid newick structure: {left_node}") + left_node = [node_id] + else: + left_node, sub_left_time = self.parse_nwk_str(left_node, time + left_time) + left_time = left_time + sub_left_time + + # Check if left node is a leaf + if not "," in right_node: + if right_node in self.leaf_mapping: + # Replace leaf name with mapped id + node_id = self.leaf_mapping[right_node] + else: + raise RuntimeError(f"Invalid newick structure {right_node}") + + right_node = [node_id] + + else: + right_node, sub_right_time = self.parse_nwk_str( + right_node, time + right_time + ) + right_time = right_time + sub_right_time + + # Left and right time should be euqal. + # To handle potential inconsistentcies the maximum time is used. + ce_time = max(left_time, right_time) + self.coalescent_events.append((ce_time, set(left_node), set(right_node))) + node_name = left_node + right_node + + return node_name, ce_time + def print_state(self, verify=False): print("State @ time ", self.t) for label in range(self.num_labels): @@ -2771,6 +3001,9 @@ def run_simulate(args): rates = args.recomb_rates recombination_map = RateMap(positions, rates) num_labels = 1 + + coalescent_events = args.ce_events + sweep_trajectory = None if args.model == "single_sweep": if num_populations > 1: @@ -2818,9 +3051,11 @@ def run_simulate(args): gene_conversion_rate=gc_rate, gene_conversion_length=mean_tract_length, discrete_genome=args.discrete, + coalescent_events=coalescent_events, ) ts = s.simulate(args.end_time) ts.dump(args.output_file) + if args.verbose: s.print_state() @@ -2878,6 +3113,14 @@ def add_simulator_arguments(parser): parser.add_argument( "--census-time", type=float, nargs=1, action="append", default=[] ) + parser.add_argument( + "--ce-events", + default="", + help="""Specify the coalescent events as newick string. + Example: Merge A and B at time 0.5. Merge (A,B) with C at time 0.8. + --> "((A:0.5, B:0.5):0.3, C:0.8);" + """, + ) parser.add_argument( "--trajectory", type=float, From 384b23c5abdb2b4411316b77ea4a37b6c0196246 Mon Sep 17 00:00:00 2001 From: Jules Kreuer Date: Sat, 13 Apr 2024 17:42:48 +0200 Subject: [PATCH 2/7] fix version specific errors --- algorithms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithms.py b/algorithms.py index 18b992cb5..7e2170385 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1001,8 +1001,8 @@ def initialise(self, ts): # Insert the segment chains into the algorithm state. for node in range(ts.num_nodes): seg = root_segments_head[node] - seg.origin = {seg.node} if seg is not None: + seg.origin = {seg.node} left_end = seg.left pop = seg.population label = seg.label @@ -1093,7 +1093,7 @@ def alloc_segment( next=None, # noqa: A002 label=0, hull=None, - origin={}, + origin=set(), ): """ Pops a new segment off the stack and sets its properties. From 81a416989ddaf834acb736970599b42ce3fa9875 Mon Sep 17 00:00:00 2001 From: Jules Kreuer Date: Mon, 15 Apr 2024 11:53:52 +0200 Subject: [PATCH 3/7] improve lineage blocking --- algorithms.py | 55 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/algorithms.py b/algorithms.py index 7e2170385..c935142fa 100644 --- a/algorithms.py +++ b/algorithms.py @@ -2410,17 +2410,46 @@ def get_random_pair(self, pop, label): return (hull1_index, hull2.index) - def is_blocked_ancestor(self, i, pop, label): + def is_blocked_ancestor(self, i: int, pop, label) -> int: """ Checks if the ancestor at index i is required for a future common ancestor event. + + Returns + ------- + 0: If the ancestor is not required or multiple similar ancestor exists. + 1: If blocked. + 2: If exactly two ancestor with the same origin exists. """ # If no fixed events happen in the future (the list is empty): if not self.coalescent_events: - return False + return 0 # Check if ancestor i is required as ancestor in a future event. _, ceA, ceB = zip(*self.coalescent_events) - return any(pop._ancestors[label][i].origin.issubset(j) for j in ceA + ceB) + + ancestor = pop._ancestors[label][i] + # Stricter version + # return any(ancestor.origin.issubset(j) for j in ceA + ceB) + + # Relevant Common Ancestor Events for that the selected ancestor could be required + ca_events = [j for j in ceA + ceB if ancestor.origin.issubset(j)] + + if not ca_events: + return 0 + + # Check if multiple ancestors with this origin exists + matching_origins = 0 + for j in pop._ancestors[label]: + # Has j same origin as selected ancestor. + if ancestor.origin.issubset(j.origin) and any( + j.origin.issubset(k) for k in ca_events + ): + matching_origins += 1 + # If more than two is there -> ancestor can be chosen. + if matching_origins == 3: + return 0 + + return matching_origins # 1 or 2 def common_ancestor_event( self, @@ -2428,7 +2457,6 @@ def common_ancestor_event( label, lineage_a=None, lineage_b=None, - blocked=False, ): """ Implements a coancestry event. @@ -2463,12 +2491,19 @@ def common_ancestor_event( i, j = random.choices(all_lineage_ids, k=2) # random sampling without replacement - if ( - blocked - and self.is_blocked_ancestor(i, pop, label) - or self.is_blocked_ancestor(j, pop, label) - ): - return + blocked_i = self.is_blocked_ancestor(i, pop, label) + blocked_j = self.is_blocked_ancestor(j, pop, label) + if blocked_i or blocked_j: + # Check edge case where two identical lineages were chosen: + if blocked_i == blocked_j == 2: + if ( + pop._ancestors[label][i].origin + == pop._ancestors[label][j].origin + ): + # Skip as they are required. + return + else: + return x = pop.remove(i, label) y = pop.remove(j, label) From 6a10c32e58042882933a4dab70db7607ddaa8e2c Mon Sep 17 00:00:00 2001 From: Jules Kreuer Date: Mon, 15 Apr 2024 13:20:57 +0200 Subject: [PATCH 4/7] linting --- algorithms.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/algorithms.py b/algorithms.py index c935142fa..ec4555ae1 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1,7 +1,6 @@ """ Python version of the simulation algorithm. """ - import argparse import heapq import itertools @@ -850,7 +849,7 @@ def __init__( gene_conversion_rate=0.0, gene_conversion_length=1, discrete_genome=True, - coalescent_events=[], + coalescent_events=None, ): # Must be a square matrix. N = len(migration_matrix) @@ -912,7 +911,7 @@ def __init__( self.fixed_coalescent_events = False self.coalescent_events = [] - if coalescent_events: + if coalescent_events is not None: self.fixed_coalescent_events = True self.parse_nwk(coalescent_events) self.coalescent_events.sort() @@ -1093,7 +1092,7 @@ def alloc_segment( next=None, # noqa: A002 label=0, hull=None, - origin=set(), + origin=None, ): """ Pops a new segment off the stack and sets its properties. @@ -1107,7 +1106,7 @@ def alloc_segment( s.prev = prev s.label = label s.hull = hull - s.origin = origin + s.origin = origin if origin is not None else set() return s def copy_segment(self, segment): @@ -1380,7 +1379,8 @@ def hudson_simulate(self, end_time): # Add epsilon as two events can't happen simultaneously. self.t = prev_time + 0.00000001 else: - # Reset to time of ce event and add epsilon to avoid colission with leaf nodes. + # Reset to time of ce event and add epsilon to avoid + # collision with leaf nodes. self.t = ce[0] + 0.00000001 self.common_ancestor_event( @@ -2431,7 +2431,7 @@ def is_blocked_ancestor(self, i: int, pop, label) -> int: # Stricter version # return any(ancestor.origin.issubset(j) for j in ceA + ceB) - # Relevant Common Ancestor Events for that the selected ancestor could be required + # Relevant Common Ancestor events ca_events = [j for j in ceA + ceB if ancestor.origin.issubset(j)] if not ca_events: @@ -2460,7 +2460,8 @@ def common_ancestor_event( ): """ Implements a coancestry event. - If lineage_a and lineage_b are set, only lines emerged from those will be selected. + If lineage_a and lineage_b are set, only lineages emerged from those + will be selected. Raises an error if only one lineage is provided. """ pop = self.P[population_index] @@ -2733,7 +2734,7 @@ def parse_nwk_str(self, newick_str: str, time: float): right_node = right_node[:-1] if right_node.endswith(")") else right_node # Check if left node is a leaf - if not "," in left_node: + if "," not in left_node: if left_node in self.leaf_mapping: # Replace leaf name with mapped id node_id = self.leaf_mapping[left_node] @@ -2745,7 +2746,7 @@ def parse_nwk_str(self, newick_str: str, time: float): left_time = left_time + sub_left_time # Check if left node is a leaf - if not "," in right_node: + if "," not in right_node: if right_node in self.leaf_mapping: # Replace leaf name with mapped id node_id = self.leaf_mapping[right_node] From c963da129bfcd40a28b3227c97fa3381a40f5a4e Mon Sep 17 00:00:00 2001 From: Jules Kreuer Date: Mon, 15 Apr 2024 13:30:51 +0200 Subject: [PATCH 5/7] fix flake8 compliance for missing argument --- algorithms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithms.py b/algorithms.py index ec4555ae1..345c5fd84 100644 --- a/algorithms.py +++ b/algorithms.py @@ -911,7 +911,7 @@ def __init__( self.fixed_coalescent_events = False self.coalescent_events = [] - if coalescent_events is not None: + if coalescent_events: self.fixed_coalescent_events = True self.parse_nwk(coalescent_events) self.coalescent_events.sort() From 10870754fbcdd0a5dd631d49d11398d3c4c26abd Mon Sep 17 00:00:00 2001 From: Jules Kreuer Date: Wed, 17 Apr 2024 20:57:33 +0200 Subject: [PATCH 6/7] add loading from tree sequence --- algorithms.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/algorithms.py b/algorithms.py index 345c5fd84..9325401c5 100644 --- a/algorithms.py +++ b/algorithms.py @@ -849,7 +849,8 @@ def __init__( gene_conversion_rate=0.0, gene_conversion_length=1, discrete_genome=True, - coalescent_events=None, + coalescent_events_nwk=None, + coalescent_events_ts=None, ): # Must be a square matrix. N = len(migration_matrix) @@ -911,9 +912,14 @@ def __init__( self.fixed_coalescent_events = False self.coalescent_events = [] - if coalescent_events: + if coalescent_events_nwk: self.fixed_coalescent_events = True - self.parse_nwk(coalescent_events) + self.parse_nwk(coalescent_events_nwk) + self.coalescent_events.sort() + logger.debug(self.coalescent_events) + elif coalescent_events_ts: + self.fixed_coalescent_events = True + self.parse_ts(coalescent_events_ts) self.coalescent_events.sort() logger.debug(self.coalescent_events) @@ -2769,6 +2775,61 @@ def parse_nwk_str(self, newick_str: str, time: float): return node_name, ce_time + def parse_ts(self, ts: tskit.TreeSequence) -> None: + """ + Parses coalescent events from tree sequence file. + Does only work for one welldefined tree without gene conversion or recombination. + """ + tables = ts.dump_tables() + + edges = list(tables.edges) + events = [] + while edges: + e1 = edges.pop(0) + # Sweep until sibling edge is found. Usually the next in list. + # Can be simplified if that assumption is made. + for i, e2 in enumerate(edges): + if e1.parent == e2.parent: + e2 = edges.pop(i) + break + # Get time of parent node (coalescent time) + t = [n.time for i, n in enumerate(tables.nodes) if i == e1.parent][0] + events.append((t, e1.child, e2.child, e1.parent)) + events.sort() + + # Add all leafs + node_mapping = {i: {i} for i, n in enumerate(tables.nodes) if n.time == 0} + if len(self.tables.nodes) < len(node_mapping): + raise ValueError( + "Population in tree sequence is larger than the provided sample_size. " + f"Please increase to {len(node_mapping)} or greater." + ) + if len(node_mapping) < len(self.tables.nodes): + # TODO check if warning library can be imported + print( + "Population in tree sequence is smaller than the provided sample_size." + ) + + # While loop is only required if events happen at + # the same time and are not ordered correctly. + # Can be simplified if that assumption is made. + while events: + t, e1, e2, p = events.pop(0) + if e1 not in node_mapping or e2 not in node_mapping: + # Add event to end of the queue. + events.append(t, e1, e2, p) + continue + + e1 = node_mapping[e1] + e2 = node_mapping[e2] + self.coalescent_events.append((t, e1, e2)) + + if p in node_mapping: + p = node_mapping[p] + else: + # Add parent node to mapping + node_mapping[p] = e1 | e2 + def print_state(self, verify=False): print("State @ time ", self.t) for label in range(self.num_labels): @@ -3038,7 +3099,16 @@ def run_simulate(args): recombination_map = RateMap(positions, rates) num_labels = 1 - coalescent_events = args.ce_events + if args.ce_from_nwk and args.ce_from_ts: + raise RuntimeError( + "Can't load coalescent events from newick and tree sequence simultaneously." + ) + + coalescent_events_nwk = args.ce_from_nwk + coalescent_events_ts = args.ce_from_ts + + if coalescent_events_ts: + coalescent_events_ts = tskit.load(coalescent_events_ts) sweep_trajectory = None if args.model == "single_sweep": @@ -3087,7 +3157,8 @@ def run_simulate(args): gene_conversion_rate=gc_rate, gene_conversion_length=mean_tract_length, discrete_genome=args.discrete, - coalescent_events=coalescent_events, + coalescent_events_nwk=coalescent_events_nwk, + coalescent_events_ts=coalescent_events_ts, ) ts = s.simulate(args.end_time) ts.dump(args.output_file) @@ -3095,6 +3166,8 @@ def run_simulate(args): if args.verbose: s.print_state() + return ts + def add_simulator_arguments(parser): parser.add_argument("sample_size", type=int) @@ -3150,13 +3223,18 @@ def add_simulator_arguments(parser): "--census-time", type=float, nargs=1, action="append", default=[] ) parser.add_argument( - "--ce-events", + "--ce-from-nwk", default="", help="""Specify the coalescent events as newick string. Example: Merge A and B at time 0.5. Merge (A,B) with C at time 0.8. --> "((A:0.5, B:0.5):0.3, C:0.8);" """, ) + parser.add_argument( + "--ce-from-ts", + default="", + help="""Load coalescent events from a tree sequence file.""", + ) parser.add_argument( "--trajectory", type=float, From 9fd4538c48f7f161d8761c1df9f7c586f72b77ee Mon Sep 17 00:00:00 2001 From: Jules Kreuer <25013642+not-a-feature@users.noreply.github.com> Date: Sun, 28 Apr 2024 17:54:34 +0200 Subject: [PATCH 7/7] fix rare edge case of identical selected lineages --- algorithms.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/algorithms.py b/algorithms.py index 9325401c5..ee3e734e9 100644 --- a/algorithms.py +++ b/algorithms.py @@ -2517,12 +2517,21 @@ def common_ancestor_event( else: # Get all lineages that emerged from lineage_a and lineage_b - lineage_ids = pop.get_emerged_from_lineage(lineage_a, label) - i = random.choice(lineage_ids) - x = pop.remove(i, label) + lineage_ids_a = pop.get_emerged_from_lineage(lineage_a, label) + lineage_ids_b = pop.get_emerged_from_lineage(lineage_b, label) + + if len(lineage_ids_a) == 1 and lineage_ids_a == lineage_ids_b: + return - lineage_ids = pop.get_emerged_from_lineage(lineage_b, label) - j = random.choice(lineage_ids) + i = None + j = None + while i == j: + i = random.choice(lineage_ids_a) + j = random.choice(lineage_ids_b) + + x = pop.remove(i, label) + if i < j: + j -= 1 y = pop.remove(j, label) self.merge_two_ancestors(population_index, label, x, y)