diff --git a/src/ruptures/detection/bottomup.py b/src/ruptures/detection/bottomup.py index d00cfb91..ec4656ec 100644 --- a/src/ruptures/detection/bottomup.py +++ b/src/ruptures/detection/bottomup.py @@ -37,11 +37,11 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None def _grow_tree(self): """Grow the entire binary tree.""" - partition = [(0, self.n_samples)] + partition = [(-self.n_samples, (0, self.n_samples))] stop = False while not stop: # recursively divide the signal stop = True - start, end = max(partition, key=lambda t: t[1] - t[0]) + _, (start, end) = partition[0] mid = (start + end) * 0.5 bkps = list() for bkp in range(start, end): @@ -50,15 +50,15 @@ def _grow_tree(self): bkps.append(bkp) if len(bkps) > 0: # at least one admissible breakpoint was found bkp = min(bkps, key=lambda x: abs(x - mid)) - partition.remove((start, end)) - partition.append((start, bkp)) - partition.append((bkp, end)) + heapq.heappop(partition) + heapq.heappush(partition, (-bkp + start, (start, bkp))) + heapq.heappush(partition, (-end + bkp, (bkp, end))) stop = False - partition.sort() + partition.sort(key=lambda x: x[1]) # compute segment costs leaves = list() - for start, end in partition: + for _, (start, end) in partition: val = self.cost.error(start, end) leaf = Bnode(start, end, val) leaves.append(leaf) @@ -87,6 +87,7 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None): dict: partition dict {(start, end): cost value,...} """ leaves = sorted(self.leaves) + keys = [leaf.start for leaf in leaves] removed = set() merged = [] for left, right in pairwise(leaves): @@ -121,10 +122,13 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None): if not stop: # updates the list of leaves (i.e. segments of the partitions) # find the merged segments indexes - keys = [leaf.start for leaf in leaves] left_idx = bisect_left(keys, leaf.left.start) - leaves[left_idx] = leaf # replace leaf.left - del leaves[left_idx + 1] # remove leaf.right + # replace leaf.left + leaves[left_idx] = leaf + keys[left_idx] = leaf.start + # remove leaf.right + del leaves[left_idx + 1] + del keys[left_idx + 1] # add to the set of removed segments. removed.add(leaf.left) removed.add(leaf.right)