diff --git a/src/atmst/mst/diff.py b/src/atmst/mst/diff.py index 9b841b0..2602e86 100644 --- a/src/atmst/mst/diff.py +++ b/src/atmst/mst/diff.py @@ -127,9 +127,9 @@ def _mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: deleted.add(a.frame.node.cid) while True: - while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other + while a.rpath != b.rpath: # we need a loop because they might "leapfrog" each other # "catch up" cursor a, if it's behind - while a.rkey < b.rkey and not a.is_final: + while a.rpath < b.rpath and not a.is_final: if a.subtree: # recurse down every subtree a.down() deleted.add(a.frame.node.cid) @@ -137,20 +137,20 @@ def _mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: a.right() # catch up cursor b, likewise - while b.rkey < a.rkey and not b.is_final: + while b.rpath < a.rpath and not b.is_final: if b.subtree: # recurse down every subtree b.down() created.add(b.frame.node.cid) else: b.right() - # the rkeys now match, but the subrees below us might not + # the rpaths now match, but the subtrees below us might not _mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker()) # check if we can still go right XXX: do we need to care about the case where one can, but the other can't? 
# To consider: maybe if I just step a, b will catch up automagically - if a.rkey == a.stack[0].rkey and b.rkey == b.stack[0].rkey: + if a.rpath == a.stack[0].rpath and b.rpath == b.stack[0].rpath: break a.right() diff --git a/src/atmst/mst/node_walker.py b/src/atmst/mst/node_walker.py index 753e19b..dfd28c4 100644 --- a/src/atmst/mst/node_walker.py +++ b/src/atmst/mst/node_walker.py @@ -18,43 +18,43 @@ class NodeWalker: Recall MSTNode layout: :: - keys: (lkey) (0, 1, 2, 3) (rkey) - vals: (0, 1, 2, 3) - subtrees: (0, 1, 2, 3, 4) + keys: (lpath) (0, 1, 2, 3) (rpath) + vals: (0, 1, 2, 3) + subtrees: (0, 1, 2, 3, 4) """ - KEY_MIN = "" # string that compares less than all legal key strings - KEY_MAX = "\xff" # string that compares greater than all legal key strings + PATH_MIN = "" # string that compares less than all legal path strings + PATH_MAX = "\xff" # string that compares greater than all legal path strings @dataclass class StackFrame: node: MSTNode # could store CIDs only to save memory, in theory, but not much point - lkey: str - rkey: str + lpath: str + rpath: str idx: int ns: NodeStore stack: List[StackFrame] - def __init__(self, ns: NodeStore, root_cid: Optional[CID], lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None: + def __init__(self, ns: NodeStore, root_cid: Optional[CID], lpath: Optional[str]=PATH_MIN, rpath: Optional[str]=PATH_MAX) -> None: self.ns = ns self.stack = [self.StackFrame( node=MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid), - lkey=lkey, - rkey=rkey, + lpath=lpath, + rpath=rpath, idx=0 )] def subtree_walker(self) -> Self: - return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey) + return NodeWalker(self.ns, self.subtree, self.lpath, self.rpath) @property def frame(self) -> StackFrame: return self.stack[-1] @property - def lkey(self) -> str: - return self.frame.lkey if self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1] + def lpath(self) -> str: + return self.frame.lpath if 
self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1] @property def lval(self) -> Optional[CID]: @@ -64,10 +64,9 @@ def lval(self) -> Optional[CID]: def subtree(self) -> Optional[CID]: return self.frame.node.subtrees[self.frame.idx] - # hmmmm rkey is overloaded here... "right key" not "record key"... @property - def rkey(self) -> str: - return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx] + def rpath(self) -> str: + return self.frame.rpath if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx] @property def rval(self) -> Optional[CID]: @@ -75,7 +74,7 @@ def rval(self) -> Optional[CID]: @property def is_final(self) -> bool: - return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey) + return (not self.stack) or (self.subtree is None and self.rpath == self.stack[0].rpath) def right(self) -> None: if (self.frame.idx + 1) >= len(self.frame.node.subtrees): @@ -93,8 +92,8 @@ def down(self) -> None: self.stack.append(self.StackFrame( node=self.ns.get_node(subtree), - lkey=self.lkey, - rkey=self.rkey, + lpath=self.lpath, + rpath=self.rpath, idx=0 )) @@ -105,9 +104,10 @@ def next_kv(self) -> Tuple[str, CID]: while self.subtree: # recurse down every subtree self.down() self.right() - return self.lkey, self.lval # the kv pair we just jumped over + return self.lpath, self.lval # the kv pair we just jumped over # iterate over every k/v pair in key-sorted order + # NB: should really be p/v standing for path/value def iter_kv(self) -> Iterable[Tuple[str, CID]]: while not self.is_final: yield self.next_kv() @@ -128,7 +128,7 @@ def iter_node_cids(self) -> Iterable[CID]: # start inclusive def iter_kv_range(self, start: str, end: str, end_inclusive: bool=False) -> Iterable[Tuple[str, CID]]: while True: - while self.rkey < start: + while self.rpath < start: self.right() if not self.subtree: break @@ -141,11 +141,11 @@ def iter_kv_range(self, start: str, 
end: str, end_inclusive: bool=False) -> Iter def find_value(self, key: str) -> Optional[CID]: while True: - while self.rkey < key: + while self.rpath < key: self.right() if not self.subtree: break self.down() - if self.rkey != key: + if self.rpath != key: return None return self.rval diff --git a/tests/test_varint.py b/tests/test_varint.py index effe9d2..0c4c6e7 100644 --- a/tests/test_varint.py +++ b/tests/test_varint.py @@ -22,7 +22,7 @@ def test_varint_decode(self): self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7f')) # too big self.assertRaises(ValueError, decode_varint, io.BytesIO(b"")) # too short self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff')) # truncated - self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # too minimally encoded + self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # not minimally encoded if __name__ == '__main__': unittest.main(module="tests.test_varint")