diff --git a/screenshot/common/__init__.py b/screenshot/common/__init__.py index 0f8225c..57dc2aa 100644 --- a/screenshot/common/__init__.py +++ b/screenshot/common/__init__.py @@ -139,17 +139,21 @@ def take_screenshot_with_url( @contextlib.asynccontextmanager async def driver(self) -> AsyncGenerator[Firefox, None]: await self.cog.manager.wait_until_driver_downloaded() - driver: Firefox = await self.launcher() - driver.set_page_load_timeout(time_to_wait=230.0) - driver.fullscreen_window() + await self.lock.acquire() try: - yield driver - except BaseException as error: - with contextlib.suppress(BaseException): - if isinstance(driver, Firefox): - driver.delete_all_cookies() - driver.quit() - raise error + driver: Firefox = await self.launcher() + driver.set_page_load_timeout(time_to_wait=230.0) + driver.fullscreen_window() + try: + yield driver + except BaseException as error: + with contextlib.suppress(BaseException): + if isinstance(driver, Firefox): + driver.delete_all_cookies() + driver.quit() + raise error + finally: + self.lock.release() async def launcher(self) -> Firefox: return await asyncio.to_thread( @@ -164,17 +168,13 @@ async def get_screenshot_bytes_from_url( mode: Literal["light", "dark"], wait: int = 10, ) -> bytes: - await self.lock.acquire() - try: - async with self.driver() as driver: - return await asyncio.to_thread( - lambda: self.take_screenshot_with_url( - driver, - url=url, - size=size, - mode=mode, - wait=wait, - ) + async with self.driver() as driver: + return await asyncio.to_thread( + lambda: self.take_screenshot_with_url( + driver, + url=url, + size=size, + mode=mode, + wait=wait, ) - finally: - self.lock.release() + ) diff --git a/screenshot/common/downloader.py b/screenshot/common/downloader.py index 6f7d25e..ee18238 100644 --- a/screenshot/common/downloader.py +++ b/screenshot/common/downloader.py @@ -150,7 +150,7 @@ def get_extension_location(self, name: str) -> Optional[pathlib.Path]: else None ) - def get_os(self) -> str: + def get_os(self) -> str: return "{}{}".format(self.get_os_name(), 64 if platform.machine().endswith("64") else 32) def set_driver_downloaded(self) -> None: @@ -169,7 +169,7 @@ def get_tor_download_url(self) -> str: ext="tar.bz2" if self.get_os().startswith("linux-aarch64") else "tar.gz", ) - async def initialize(self) -> None: + async def __initialize(self) -> None: if not self.tor_location: await self.download_and_extract_tor() if not self.driver_location: @@ -183,6 +183,14 @@ async def initialize(self) -> None: await self.execute_tor_binary() self.set_driver_downloaded() + async def initialize(self) -> None: + try: + await self.__initialize() + except DownloadFailed as error: + if error.retry: + return await self.initialize() + log.exception(error.message, exc_info=error) + async def close(self) -> None: await self.wait_until_driver_downloaded() with contextlib.suppress(ProcessLookupError): diff --git a/screenshot/core.py b/screenshot/core.py index 6359d2a..7deebdf 100644 --- a/screenshot/core.py +++ b/screenshot/core.py @@ -193,7 +193,7 @@ async def screenshot(self, ctx: commands.Context, url: URLConverter, *, flags: s allowed_mentions=discord.AllowedMentions(replied_user=False), ) raise commands.CheckFailure() - self.filter.maybe_setup_models() + await asyncio.to_thread(lambda: self.filter.maybe_setup_models()) if ( isinstance( ctx.channel,