Skip to content

Commit

Permalink
fix: Fix init of context managers and context handling in `BasicCrawler` (apify#714)
Browse files Browse the repository at this point in the history

### Problems

- The `EventManager` could be initialized multiple times, such as when
running a `Crawler` wrapped inside an `Actor`.
- In `crawler.run`, the async context was entered and exited directly,
which could lead to issues if the caller had already entered it. This
scenario might occur when users provide their own instances of
`BrowserPool`, `SessionPool`, `EventManager`, or others.

### Solution

- Address these issues by introducing an `active` flag to the existing
context managers to prevent multiple initializations.
- Implement an `ensure_context` helper and apply it to methods where
context management is required.
- Fix & improve tests to ensure these cases are covered.

### Others

- I added missing names to asyncio tasks.
  • Loading branch information
vdusek authored Nov 25, 2024
1 parent dcf2485 commit 486fe6d
Show file tree
Hide file tree
Showing 18 changed files with 392 additions and 87 deletions.
30 changes: 29 additions & 1 deletion src/crawlee/_autoscaling/snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Snapshot,
)
from crawlee._utils.byte_size import ByteSize
from crawlee._utils.context import ensure_context
from crawlee._utils.docs import docs_group
from crawlee._utils.recurring_task import RecurringTask
from crawlee.events._types import Event, EventSystemInfoData
Expand Down Expand Up @@ -114,6 +115,9 @@ def __init__(

self._timestamp_of_last_memory_warning: datetime = datetime.now(timezone.utc) - timedelta(hours=1)

# Flag to indicate the context state.
self._active = False

@staticmethod
def _get_sorted_list_by_created_at(input_list: list[T]) -> SortedList[T]:
return SortedList(input_list, key=attrgetter('created_at'))
Expand All @@ -126,8 +130,21 @@ def _get_default_max_memory_size(available_memory_ratio: float) -> ByteSize:
logger.info(f'Setting max_memory_size of this run to {max_memory_size}.')
return max_memory_size

@property
def active(self) -> bool:
"""Indicates whether the context is active."""
return self._active

async def __aenter__(self) -> Snapshotter:
"""Starts capturing snapshots at configured intervals."""
"""Starts capturing snapshots at configured intervals.
Raises:
RuntimeError: If the context manager is already active.
"""
if self._active:
raise RuntimeError(f'The {self.__class__.__name__} is already active.')

self._active = True
self._event_manager.on(event=Event.SYSTEM_INFO, listener=self._snapshot_cpu)
self._event_manager.on(event=Event.SYSTEM_INFO, listener=self._snapshot_memory)
self._snapshot_event_loop_task.start()
Expand All @@ -144,12 +161,20 @@ async def __aexit__(
This method stops capturing snapshots of system resources (CPU, memory, event loop, and client information).
It should be called to terminate resource capturing when it is no longer needed.
Raises:
RuntimeError: If the context manager is not active.
"""
if not self._active:
raise RuntimeError(f'The {self.__class__.__name__} is not active.')

self._event_manager.off(event=Event.SYSTEM_INFO, listener=self._snapshot_cpu)
self._event_manager.off(event=Event.SYSTEM_INFO, listener=self._snapshot_memory)
await self._snapshot_event_loop_task.stop()
await self._snapshot_client_task.stop()
self._active = False

@ensure_context
def get_memory_sample(self, duration: timedelta | None = None) -> list[Snapshot]:
"""Returns a sample of the latest memory snapshots.
Expand All @@ -162,6 +187,7 @@ def get_memory_sample(self, duration: timedelta | None = None) -> list[Snapshot]
snapshots = cast(list[Snapshot], self._memory_snapshots)
return self._get_sample(snapshots, duration)

@ensure_context
def get_event_loop_sample(self, duration: timedelta | None = None) -> list[Snapshot]:
"""Returns a sample of the latest event loop snapshots.
Expand All @@ -174,6 +200,7 @@ def get_event_loop_sample(self, duration: timedelta | None = None) -> list[Snaps
snapshots = cast(list[Snapshot], self._event_loop_snapshots)
return self._get_sample(snapshots, duration)

@ensure_context
def get_cpu_sample(self, duration: timedelta | None = None) -> list[Snapshot]:
"""Returns a sample of the latest CPU snapshots.
Expand All @@ -186,6 +213,7 @@ def get_cpu_sample(self, duration: timedelta | None = None) -> list[Snapshot]:
snapshots = cast(list[Snapshot], self._cpu_snapshots)
return self._get_sample(snapshots, duration)

@ensure_context
def get_client_sample(self, duration: timedelta | None = None) -> list[Snapshot]:
"""Returns a sample of the latest client snapshots.
Expand Down
40 changes: 40 additions & 0 deletions src/crawlee/_utils/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import asyncio
from functools import wraps
from typing import Any, Callable, TypeVar

T = TypeVar('T', bound=Callable[..., Any])


def ensure_context(method: T) -> T:
    """Decorator to ensure the (async) context manager is initialized before calling the method.

    The instance the decorated method is bound to must expose an `active` attribute (or property). The check
    is performed on every call, right before the wrapped method is invoked; for `async def` methods it runs
    when the returned coroutine is awaited.

    Args:
        method: The method to wrap. Both regular and `async def` methods are supported.

    Returns:
        The wrapped method with context checking applied.

    Raises:
        RuntimeError: Raised by the wrapped method (not by the decorator itself) when the instance has no
            `active` attribute, or when `active` is falsy.
    """
    # Local import keeps this function self-contained. `inspect.iscoroutinefunction` is used instead of
    # `asyncio.iscoroutinefunction`, which is deprecated since Python 3.14.
    import inspect

    def check_active(instance: Any, context_kind: str) -> None:
        # Shared validation for both wrappers; `context_kind` only affects the wording of the error message.
        if not hasattr(instance, 'active'):
            raise RuntimeError(f'The {instance.__class__.__name__} does not have the "active" attribute.')

        if not instance.active:
            raise RuntimeError(f'The {instance.__class__.__name__} is not active. Use it within the {context_kind}.')

    @wraps(method)
    def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        check_active(self, 'context')
        return method(self, *args, **kwargs)

    @wraps(method)
    async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        check_active(self, 'async context')
        return await method(self, *args, **kwargs)

    # Pick the wrapper matching the wrapped method so that coroutine methods stay awaitable.
    return async_wrapper if inspect.iscoroutinefunction(method) else sync_wrapper  # type: ignore[return-value]
31 changes: 19 additions & 12 deletions src/crawlee/basic_crawler/_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def __init__(
else None,
available_memory_ratio=self._configuration.available_memory_ratio,
)
self._pool = AutoscaledPool(
self._autoscaled_pool = AutoscaledPool(
system_status=SystemStatus(self._snapshotter),
is_finished_function=self.__is_finished_function,
is_task_ready_function=self.__is_task_ready_function,
Expand Down Expand Up @@ -442,7 +442,7 @@ def sigint_handler() -> None:

run_task.cancel()

run_task = asyncio.create_task(self._run_crawler())
run_task = asyncio.create_task(self._run_crawler(), name='run_crawler_task')

with suppress(NotImplementedError): # event loop signal handlers are not supported on Windows
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, sigint_handler)
Expand Down Expand Up @@ -476,18 +476,25 @@ def sigint_handler() -> None:
return final_statistics

async def _run_crawler(self) -> None:
async with AsyncExitStack() as exit_stack:
await exit_stack.enter_async_context(self._event_manager)
await exit_stack.enter_async_context(self._snapshotter)
await exit_stack.enter_async_context(self._statistics)

if self._use_session_pool:
await exit_stack.enter_async_context(self._session_pool)
# Collect the context managers to be entered. Context managers that are already active are excluded,
# as they were likely entered by the caller, who will also be responsible for exiting them.
contexts_to_enter = [
cm
for cm in (
self._event_manager,
self._snapshotter,
self._statistics,
self._session_pool if self._use_session_pool else None,
*self._additional_context_managers,
)
if cm and getattr(cm, 'active', False) is False
]

for context_manager in self._additional_context_managers:
await exit_stack.enter_async_context(context_manager)
async with AsyncExitStack() as exit_stack:
for context in contexts_to_enter:
await exit_stack.enter_async_context(context)

await self._pool.run()
await self._autoscaled_pool.run()

async def add_requests(
self,
Expand Down
17 changes: 15 additions & 2 deletions src/crawlee/browsers/_base_browser_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ class BaseBrowserPlugin(ABC):
AUTOMATION_LIBRARY: str | None = None
"""The name of the automation library that the plugin is managing."""

@property
@abstractmethod
def active(self) -> bool:
"""Indicates whether the context is active."""

@property
@abstractmethod
def browser_type(self) -> BrowserType:
Expand All @@ -45,7 +50,11 @@ def max_open_pages_per_browser(self) -> int:

@abstractmethod
async def __aenter__(self) -> BaseBrowserPlugin:
"""Enter the context manager and initialize the browser plugin."""
"""Enter the context manager and initialize the browser plugin.
Raises:
RuntimeError: If the context manager is already active.
"""

@abstractmethod
async def __aexit__(
Expand All @@ -54,7 +63,11 @@ async def __aexit__(
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
"""Exit the context manager and close the browser plugin."""
"""Exit the context manager and close the browser plugin.
Raises:
RuntimeError: If the context manager is not active.
"""

@abstractmethod
async def new_browser(self) -> BaseBrowserController:
Expand Down
32 changes: 28 additions & 4 deletions src/crawlee/browsers/_browser_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary

from crawlee._utils.context import ensure_context
from crawlee._utils.crypto import crypto_random_object_id
from crawlee._utils.docs import docs_group
from crawlee._utils.recurring_task import RecurringTask
Expand Down Expand Up @@ -91,6 +92,9 @@ def __init__(
self._pages = WeakValueDictionary[str, CrawleePage]() # Track the pages in the pool
self._plugins_cycle = itertools.cycle(self._plugins) # Cycle through the plugins

# Flag to indicate the context state.
self._active = False

@classmethod
def with_default_plugin(
cls,
Expand Down Expand Up @@ -148,10 +152,21 @@ def total_pages_count(self) -> int:
"""Returns the total number of pages opened since the browser pool was launched."""
return self._total_pages_count

@property
def active(self) -> bool:
"""Indicates whether the context is active."""
return self._active

async def __aenter__(self) -> BrowserPool:
"""Enter the context manager and initialize all browser plugins."""
logger.debug('Initializing browser pool.')
"""Enter the context manager and initialize all browser plugins.
Raises:
RuntimeError: If the context manager is already active.
"""
if self._active:
raise RuntimeError(f'The {self.__class__.__name__} is already active.')

self._active = True
# Start the recurring tasks for identifying and closing inactive browsers
self._identify_inactive_browsers_task.start()
self._close_inactive_browsers_task.start()
Expand All @@ -172,8 +187,13 @@ async def __aexit__(
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
"""Exit the context manager and close all browser plugins."""
logger.debug('Closing browser pool.')
"""Exit the context manager and close all browser plugins.
Raises:
RuntimeError: If the context manager is not active.
"""
if not self._active:
raise RuntimeError(f'The {self.__class__.__name__} is not active.')

await self._identify_inactive_browsers_task.stop()
await self._close_inactive_browsers_task.stop()
Expand All @@ -184,6 +204,9 @@ async def __aexit__(
for plugin in self._plugins:
await plugin.__aexit__(exc_type, exc_value, exc_traceback)

self._active = False

@ensure_context
async def new_page(
self,
*,
Expand Down Expand Up @@ -213,6 +236,7 @@ async def new_page(

return await self._get_new_page(page_id, plugin, proxy_info)

@ensure_context
async def new_page_with_each_plugin(self) -> Sequence[CrawleePage]:
"""Create a new page with each browser plugin in the pool.
Expand Down
20 changes: 18 additions & 2 deletions src/crawlee/browsers/_playwright_browser_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from playwright.async_api import Playwright, async_playwright
from typing_extensions import override

from crawlee._utils.context import ensure_context
from crawlee._utils.docs import docs_group
from crawlee.browsers._base_browser_plugin import BaseBrowserPlugin
from crawlee.browsers._playwright_browser_controller import PlaywrightBrowserController
Expand Down Expand Up @@ -55,6 +56,14 @@ def __init__(
self._playwright_context_manager = async_playwright()
self._playwright: Playwright | None = None

# Flag to indicate the context state.
self._active = False

@property
@override
def active(self) -> bool:
return self._active

@property
@override
def browser_type(self) -> BrowserType:
Expand All @@ -77,7 +86,10 @@ def max_open_pages_per_browser(self) -> int:

@override
async def __aenter__(self) -> PlaywrightBrowserPlugin:
logger.debug('Initializing Playwright browser plugin.')
if self._active:
raise RuntimeError(f'The {self.__class__.__name__} is already active.')

self._active = True
self._playwright = await self._playwright_context_manager.__aenter__()
return self

Expand All @@ -88,10 +100,14 @@ async def __aexit__(
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
logger.debug('Closing Playwright browser plugin.')
if not self._active:
raise RuntimeError(f'The {self.__class__.__name__} is not active.')

await self._playwright_context_manager.__aexit__(exc_type, exc_value, exc_traceback)
self._active = False

@override
@ensure_context
async def new_browser(self) -> PlaywrightBrowserController:
if not self._playwright:
raise RuntimeError('Playwright browser plugin is not initialized.')
Expand Down
Loading

0 comments on commit 486fe6d

Please sign in to comment.