Skip to content

Commit

Permalink
introduce .within filter
Browse files Browse the repository at this point in the history
  • Loading branch information
rodja committed Oct 31, 2023
1 parent 7eafdc4 commit 86af04e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
14 changes: 10 additions & 4 deletions nicegui/get.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Iterator, Optional, Type, TypeVar
from typing import Generic, Iterator, List, Optional, Type, TypeVar

from typing_extensions import Self

Expand All @@ -12,6 +12,7 @@ class elements(Generic[T], Iterator[T]):

def __init__(self, *, type: Optional[Type[T]] = Element):
self.type = type
self._within_types = []

def __iter__(self) -> Iterator[T]:
client = context.get_client()
Expand All @@ -22,18 +23,23 @@ def __next__(self) -> T: # Define __next__ to return the next item from _iterat
raise StopIteration
return next(self._iterator)

def iterate(self, parent: Element) -> Iterator[T]:
def iterate(self, parent: Element, *, visited: List[Element] = []) -> Iterator[T]:
for element in parent:
if self.type is None or isinstance(element, self.type):
yield element
yield from self.iterate(element)
if not self._within_types or any(isinstance(element, type) for type in self._within_types for element in visited):
yield element
yield from self.iterate(element, visited=visited + [element])

def __len__(self) -> int:
return len(list(iter(self)))

def __getitem__(self, index) -> T:
return list(iter(self))[index]

def within(self, *, type: Optional[Type[T]] = Element) -> Self:
self._within_types.append(type)
return self

def classes(self, add: Optional[str] = None, *, remove: Optional[str] = None, replace: Optional[str] = None) -> Self:
for element in self:
element.classes(add, remove=remove, replace=replace)
Expand Down
16 changes: 14 additions & 2 deletions tests/test_get_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,28 @@ def test_get_all(screen: Screen):


def test_get_by_type(screen: Screen):
ui.button('button A')
ui.label('label A')
ui.button('button B')
ui.label('label B')

result = ', '.join(b.text for b in ui.get(type=ui.button))

screen.open('/')
assert result == 'button A, button B'


def test_get_within(screen: Screen):
ui.button('button A')
ui.label('label A')
with ui.row():
ui.button('button B')
ui.label('label B')

ui.label(', '.join(b.text for b in ui.get(type=ui.button)))
result = [b.text for b in ui.get(type=ui.button).within(type=ui.row)]

screen.open('/')
screen.should_contain('button A, button B')
assert result == ['button B']


def test_setting_classes(screen: Screen):
Expand Down

0 comments on commit 86af04e

Please sign in to comment.