Skip to content

Commit

Permalink
Tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 11, 2023
1 parent 4783ebd commit 80649a6
Showing 1 changed file with 39 additions and 27 deletions.
66 changes: 39 additions & 27 deletions python/tests/test_tree_positioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ class EdgeRange:
order: typing.List | None


# TODO deal with direction change and calling next()/prev on the null
# tree
class TreePosition:
def __init__(self, ts):
self.ts = ts
Expand All @@ -70,8 +68,6 @@ def __init__(self, ts):
self.interval = Interval(0, 0)
self.in_range = EdgeRange(0, 0, None)
self.out_range = EdgeRange(0, 0, None)
self.left_current_index = 0
self.right_current_index = 0

def __str__(self):
s = f"index: {self.index}\ninterval: {self.interval}\n"
Expand All @@ -81,11 +77,20 @@ def __str__(self):
return s

def seek_forward(self, index):
assert index >= self.index and index < self.ts.num_trees
if self.index == -1:
self.direction = FORWARD
self.interval.left = 0
left_current_index = 0
right_current_index = 0
else:
if self.direction == FORWARD:
right_current_index = self.out_range.stop
left_current_index = self.in_range.stop
else:
right_current_index = self.in_range.stop + 1
left_current_index = self.out_range.stop + 1

assert index >= self.index and index < self.ts.num_trees
direction_change = int(self.direction != FORWARD)
self.direction = FORWARD

M = self.ts.num_edges
Expand All @@ -101,25 +106,23 @@ def seek_forward(self, index):
# at the current right index and ends at the first edge
# where the right coordinate is equal to the new tree's
# left coordinate.
j = self.right_current_index + direction_change
j = right_current_index
self.out_range.start = j
# TODO This could be done with binary search
while j < M and right_coords[right_order[j]] <= left:
j += 1
self.out_range.stop = j
self.right_current_index = j

# The range of edges we need to consider for the new tree
# must have right coordinate > left
j = self.left_current_index + direction_change
j = left_current_index
while j < M and right_coords[left_order[j]] <= left:
j += 1
self.in_range.start = j
# TODO this could be done with a binary search
while j < M and left_coords[left_order[j]] <= left:
j += 1
self.in_range.stop = j
self.left_current_index = j

self.interval.left = left
self.interval.right = breakpoints[index + 1]
Expand All @@ -134,10 +137,16 @@ def set_null(self):

def next(self): # NOQA: A003
if self.index == -1:
self.direction = FORWARD

direction_change = int(self.direction != FORWARD)
self.direction = FORWARD
self.interval.left = 0
left_current_index = 0
right_current_index = 0
else:
if self.direction == FORWARD:
right_current_index = self.out_range.stop
left_current_index = self.in_range.stop
else:
right_current_index = self.in_range.stop + 1
left_current_index = self.out_range.stop + 1

M = self.ts.num_edges
breakpoints = self.ts.breakpoints(as_array=True)
Expand All @@ -147,22 +156,21 @@ def next(self): # NOQA: A003
right_order = self.ts.indexes_edge_removal_order
x = self.interval.right

j = self.right_current_index + direction_change
j = right_current_index
self.out_range.start = j
while j < M and right_coords[right_order[j]] == x:
j += 1
self.out_range.stop = j
self.right_current_index = j
self.out_range.order = right_order

j = self.left_current_index + direction_change
j = left_current_index
self.in_range.start = j
while j < M and left_coords[left_order[j]] == x:
j += 1
self.in_range.stop = j
self.left_current_index = j

self.out_range.order = right_order
self.in_range.order = left_order

self.direction = FORWARD
self.index += 1
if self.index == self.ts.num_trees:
self.set_null()
Expand All @@ -176,11 +184,17 @@ def prev(self):
if self.index == -1:
self.index = self.ts.num_trees
self.interval.left = self.ts.sequence_length
self.right_current_index = M - 1
self.left_current_index = M - 1
right_current_index = M - 1
left_current_index = M - 1
self.direction = REVERSE
else:
if self.direction == REVERSE:
right_current_index = self.in_range.stop
left_current_index = self.out_range.stop
else:
right_current_index = self.out_range.stop - 1
left_current_index = self.in_range.stop - 1

direction_change = int(self.direction != REVERSE)
self.direction = REVERSE

breakpoints = self.ts.breakpoints(as_array=True)
Expand All @@ -190,21 +204,19 @@ def prev(self):
left_order = self.ts.indexes_edge_insertion_order
x = self.interval.left

j = self.left_current_index - direction_change
j = left_current_index
self.out_range.start = j
while j >= 0 and left_coords[left_order[j]] == x:
j -= 1
self.out_range.stop = j
self.out_range.order = left_order
self.left_current_index = j

j = self.right_current_index - direction_change
j = right_current_index
self.in_range.start = j
while j >= 0 and right_coords[right_order[j]] == x:
j -= 1
self.in_range.stop = j
self.in_range.order = right_order
self.right_current_index = j

self.index -= 1
if self.index == -1:
Expand Down

0 comments on commit 80649a6

Please sign in to comment.