Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start using a generic retry function #1251

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 11 additions & 32 deletions splinter/driver/webdriver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __repr__(self):
)


def _find(self, finder, finder_kwargs=None):
def _safe_find(finder, finder_kwargs=None):
"""Search for elements. Returns a list of results.

Arguments:
Expand All @@ -209,12 +209,10 @@ def _find(self, finder, finder_kwargs=None):

Returns:
list

"""
finder_kwargs = finder_kwargs or {}

elements = None
elem_list = []
elements = []

try:
elements = finder(**finder_kwargs)
Expand All @@ -225,14 +223,10 @@ def _find(self, finder, finder_kwargs=None):
NoSuchElementException,
StaleElementReferenceException,
):
# This exception is sometimes thrown if the page changes
# quickly
# Exception which can be thrown if the page isn't ready.
pass

if elements:
elem_list = [self.element_class(element, self, finder_kwargs) for element in elements]

return elem_list
return elements


def find_by(
Expand All @@ -249,26 +243,18 @@ def find_by(

Returns:
ElementList

"""
elem_list = []

find_by = original_find or finder_kwargs["by"]
query = original_query or finder_kwargs.get("value")

# Zero second wait time means only check once
if wait_time == 0:
elem_list = _find(self, finder, finder_kwargs)
else:
wait_time = wait_time or self.wait_time
end_time = time.time() + wait_time

while time.time() < end_time:
elem_list = _find(self, finder, finder_kwargs)

if elem_list:
break
elements = _retry(
_safe_find,
[finder],
{"finder_kwargs": finder_kwargs},
timeout=self.wait_time,
)

elem_list = [self.element_class(elem, self, finder_kwargs) for elem in elements]
return ElementList(elem_list, find_by=find_by, query=query)


Expand Down Expand Up @@ -466,7 +452,6 @@ def find_by_css(self, css_selector, wait_time=None):
self.driver.find_elements,
finder_kwargs={"by": By.CSS_SELECTOR, "value": css_selector},
original_find="css",
original_query=css_selector,
wait_time=wait_time,
)

Expand All @@ -490,15 +475,13 @@ def find_by_name(self, name, wait_time=None):
return self.find_by(
self.driver.find_elements,
finder_kwargs={"by": By.NAME, "value": name},
original_find="name",
wait_time=wait_time,
)

def find_by_tag(self, tag, wait_time=None):
return self.find_by(
self.driver.find_elements,
finder_kwargs={"by": By.TAG_NAME, "value": tag},
original_find="tag_name",
wait_time=wait_time,
)

Expand Down Expand Up @@ -526,7 +509,6 @@ def find_by_id(self, id, wait_time=None): # NOQA: A002
return self.find_by(
self.driver.find_element,
finder_kwargs={"by": By.ID, "value": id},
original_find="id",
wait_time=wait_time,
)

Expand Down Expand Up @@ -951,15 +933,13 @@ def find_by_name(self, selector, wait_time=None):
return self.find_by(
self._element.find_elements,
finder_kwargs={"by": By.NAME, "value": selector},
original_find="name",
wait_time=wait_time,
)

def find_by_tag(self, selector, wait_time=None):
return self.find_by(
self._element.find_elements,
finder_kwargs={"by": By.TAG_NAME, "value": selector},
original_find="tag",
wait_time=wait_time,
)

Expand Down Expand Up @@ -990,7 +970,6 @@ def find_by_id(self, selector, wait_time=None):
return self.find_by(
self._element.find_elements,
finder_kwargs={"by": By.ID, "value": selector},
original_find="id",
wait_time=wait_time,
)

Expand Down
33 changes: 22 additions & 11 deletions splinter/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,35 @@ def _retry(
fn_args: Optional[list] = None,
fn_kwargs: Optional[dict] = None,
timeout: int = 0,
) -> bool:
"""Retry a function that should return a truthy value until a timeout is hit.
) -> Any:
"""Retry a function until it returns a non-falsey result or timeout is hit.

If timeout is set to 0, the function will only be run once.

This will not wrap Exceptions, only falsey values.

Arguments:
fn: Function to retry
timeout: Number of seconds to retry.
fn: A function to retry.
timeout: How long, in seconds, to retry the function.

Returns:
bool - True if the function returns a truthy value before the timeout, else False.

The final return value of func.
"""
fn_args = fn_args or []
fn_kwargs = fn_kwargs or {}

end_time = time.time() + timeout
result = None

while time.time() < end_time:
# Zero second wait time means only check once
if timeout == 0:
result = fn(*fn_args, **fn_kwargs)
if result:
return True
return False
else:
end_time = time.time() + timeout

while time.time() < end_time:
result = fn(*fn_args, **fn_kwargs)

if result:
break

return result
Loading