Skip to content

Commit

Permalink
rename MST 'keys' to 'paths', for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Feb 27, 2024
1 parent ee7a5e3 commit dbb8329
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
10 changes: 5 additions & 5 deletions src/atmst/mst/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,30 +127,30 @@ def _mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b:
deleted.add(a.frame.node.cid)

while True:
while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other
while a.rpath != b.rpath: # we need a loop because they might "leapfrog" each other
# "catch up" cursor a, if it's behind
while a.rkey < b.rkey and not a.is_final:
while a.rpath < b.rpath and not a.is_final:
if a.subtree: # recurse down every subtree
a.down()
deleted.add(a.frame.node.cid)
else:
a.right()

# catch up cursor b, likewise
while b.rkey < a.rkey and not b.is_final:
while b.rpath < a.rpath and not b.is_final:
if b.subtree: # recurse down every subtree
b.down()
created.add(b.frame.node.cid)
else:
b.right()

# the rkeys now match, but the subrees below us might not
# the rpaths now match, but the subrees below us might not

_mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker())

# check if we can still go right XXX: do we need to care about the case where one can, but the other can't?
# To consider: maybe if I just step a, b will catch up automagically
if a.rkey == a.stack[0].rkey and b.rkey == b.stack[0].rkey:
if a.rpath == a.stack[0].rpath and b.rpath == b.stack[0].rpath:
break

a.right()
Expand Down
46 changes: 23 additions & 23 deletions src/atmst/mst/node_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,43 @@ class NodeWalker:
Recall MSTNode layout: ::
keys: (lkey) (0, 1, 2, 3) (rkey)
vals: (0, 1, 2, 3)
subtrees: (0, 1, 2, 3, 4)
keys: (lpath) (0, 1, 2, 3) (rpath)
vals: (0, 1, 2, 3)
subtrees: (0, 1, 2, 3, 4)
"""
KEY_MIN = "" # string that compares less than all legal key strings
KEY_MAX = "\xff" # string that compares greater than all legal key strings
PATH_MIN = "" # string that compares less than all legal path strings
PATH_MAX = "\xff" # string that compares greater than all legal path strings

@dataclass
class StackFrame:
node: MSTNode # could store CIDs only to save memory, in theory, but not much point
lkey: str
rkey: str
lpath: str
rpath: str
idx: int

ns: NodeStore
stack: List[StackFrame]

def __init__(self, ns: NodeStore, root_cid: Optional[CID], lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None:
def __init__(self, ns: NodeStore, root_cid: Optional[CID], lpath: Optional[str]=PATH_MIN, rpath: Optional[str]=PATH_MAX) -> None:
self.ns = ns
self.stack = [self.StackFrame(
node=MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid),
lkey=lkey,
rkey=rkey,
lpath=lpath,
rpath=rpath,
idx=0
)]

def subtree_walker(self) -> Self:
return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey)
return NodeWalker(self.ns, self.subtree, self.lpath, self.rpath)

@property
def frame(self) -> StackFrame:
return self.stack[-1]

@property
def lkey(self) -> str:
return self.frame.lkey if self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1]
def lpath(self) -> str:
return self.frame.lpath if self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1]

@property
def lval(self) -> Optional[CID]:
Expand All @@ -64,18 +64,17 @@ def lval(self) -> Optional[CID]:
def subtree(self) -> Optional[CID]:
return self.frame.node.subtrees[self.frame.idx]

# hmmmm rkey is overloaded here... "right key" not "record key"...
@property
def rkey(self) -> str:
return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx]
def rpath(self) -> str:
return self.frame.rpath if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx]

@property
def rval(self) -> Optional[CID]:
return None if self.frame.idx == len(self.frame.node.vals) else self.frame.node.vals[self.frame.idx]

@property
def is_final(self) -> bool:
return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey)
return (not self.stack) or (self.subtree is None and self.rpath == self.stack[0].rpath)

def right(self) -> None:
if (self.frame.idx + 1) >= len(self.frame.node.subtrees):
Expand All @@ -93,8 +92,8 @@ def down(self) -> None:

self.stack.append(self.StackFrame(
node=self.ns.get_node(subtree),
lkey=self.lkey,
rkey=self.rkey,
lpath=self.lpath,
rpath=self.rpath,
idx=0
))

Expand All @@ -105,9 +104,10 @@ def next_kv(self) -> Tuple[str, CID]:
while self.subtree: # recurse down every subtree
self.down()
self.right()
return self.lkey, self.lval # the kv pair we just jumped over
return self.lpath, self.lval # the kv pair we just jumped over

# iterate over every k/v pair in key-sorted order
# NB: should really be p/v standing for path/value
def iter_kv(self) -> Iterable[Tuple[str, CID]]:
while not self.is_final:
yield self.next_kv()
Expand All @@ -128,7 +128,7 @@ def iter_node_cids(self) -> Iterable[CID]:
# start inclusive
def iter_kv_range(self, start: str, end: str, end_inclusive: bool=False) -> Iterable[Tuple[str, CID]]:
while True:
while self.rkey < start:
while self.rpath < start:
self.right()
if not self.subtree:
break
Expand All @@ -141,11 +141,11 @@ def iter_kv_range(self, start: str, end: str, end_inclusive: bool=False) -> Iter

def find_value(self, key: str) -> Optional[CID]:
while True:
while self.rkey < key:
while self.rpath < key:
self.right()
if not self.subtree:
break
self.down()
if self.rkey != key:
if self.rpath != key:
return None
return self.rval
2 changes: 1 addition & 1 deletion tests/test_varint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_varint_decode(self):
self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7f')) # too big
self.assertRaises(ValueError, decode_varint, io.BytesIO(b"")) # too short
self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff')) # truncated
self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # too minimally encoded
self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # not minimally encoded

if __name__ == '__main__':
unittest.main(module="tests.test_varint")

0 comments on commit dbb8329

Please sign in to comment.