Skip to content

Commit

Permalink
Tree: various speedups
Browse files Browse the repository at this point in the history
- make dataclass non-frozen
- use mutate() for cases where a Map is modified multiple times
- remove asserts for cases that would fail immediately anyway
  • Loading branch information
matthiasdiener committed Nov 27, 2024
1 parent 1af4523 commit 1ff3d54
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
NodeT = TypeVar("NodeT", bound=Hashable)


@dataclass(frozen=True)
# Not frozen because it is slower. Tree objects are immutable, and offer no
# way to mutate the tree.
@dataclass(frozen=False)
class Tree(Generic[NodeT]):
"""
An immutable tree containing nodes of type :class:`NodeT`.
Expand Down Expand Up @@ -95,8 +97,6 @@ def ancestors(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns a :class:`tuple` of nodes that are ancestors of *node*.
"""
assert node in self

if self.is_root(node):
# => root
return ()
Expand All @@ -110,42 +110,33 @@ def parent(self, node: NodeT) -> NodeT | None:
"""
Returns the parent of *node*.
"""
assert node in self

return self._child_to_parent[node]

def children(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns the children of *node*.
"""
assert node in self

return self._parent_to_children[node]

@memoize_method
def depth(self, node: NodeT) -> int:
"""
Returns the depth of *node*, with the root having depth 0.
"""
assert node in self

if self.is_root(node):
# => None
return 0

parent_of_node = self.parent(node)
assert parent_of_node is not None

return 1 + self.depth(parent_of_node)

def is_root(self, node: NodeT) -> bool:
assert node in self

"""Return *True* if *node* is the root of the tree."""
return self.parent(node) is None

def is_leaf(self, node: NodeT) -> bool:
assert node in self

"""Return *True* if *node* has no children."""
return len(self.children(node)) == 0

def __contains__(self, node: NodeT) -> bool:
Expand All @@ -162,9 +153,11 @@ def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]:

siblings = self._parent_to_children[parent]

return Tree((self._parent_to_children
.set(parent, (*siblings, node))
.set(node, ())),
_parent_to_children_mut = self._parent_to_children.mutate()
_parent_to_children_mut[parent] = (*siblings, node)
_parent_to_children_mut[node] = ()

return Tree(_parent_to_children_mut.finish(),
self._child_to_parent.set(node, parent))

def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:
Expand Down Expand Up @@ -234,13 +227,12 @@ def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
parents_new_children = tuple(frozenset(siblings) - frozenset([node]))
new_parents_children = (*self.children(new_parent), node)

new_child_to_parent = self._child_to_parent.set(node, new_parent)
new_parent_to_children = (self._parent_to_children
.set(parent, parents_new_children)
.set(new_parent, new_parents_children))
_parent_to_children_mut = self._parent_to_children.mutate()
_parent_to_children_mut[parent] = parents_new_children
_parent_to_children_mut[new_parent] = new_parents_children

return Tree(new_parent_to_children,
new_child_to_parent)
return Tree(_parent_to_children_mut.finish(),
self._child_to_parent.set(node, new_parent))

def __str__(self) -> str:
"""
Expand Down

0 comments on commit 1ff3d54

Please sign in to comment.