Skip to content

Commit

Permalink
[Screenshot] fixes and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
japandotorg committed Sep 25, 2024
1 parent fba53af commit aeb6d99
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 166 deletions.
6 changes: 3 additions & 3 deletions screenshot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@
from redbot.core.bot import Red
from redbot.core.errors import CogLoadError

from selenium.webdriver.remote.remote_connection import LOGGER as selenium_logger

from .core import Screenshot

# isort: on


async def setup(bot: Red) -> None:
urllib_logger.setLevel(logging.WARNING)
selenium_logger.setLevel(logging.WARNING)
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("h5py").setLevel(logging.WARNING)
logging.getLogger("selenium").setLevel(logging.WARNING)
if platform.system().lower() not in ["windows", "linux"]:
raise CogLoadError("This cog is only available for linux and windows devices right now.")
cog: Screenshot = Screenshot(bot)
Expand Down
33 changes: 27 additions & 6 deletions screenshot/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
"""

# isort: off
import os
import logging
import asyncio
import contextlib
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Literal, TypeVar

Expand Down Expand Up @@ -114,8 +114,18 @@ def has_time_passed(started: datetime, minutes: float) -> bool:
del self.drivers[date]

def take_screenshot_with_url(
self, driver: Firefox, *, url: str, mode: Literal["normal", "full"], wait: int = 10
self,
driver: Firefox,
*,
url: str,
size: Literal["normal", "full"],
mode: Literal["light", "dark"],
wait: int = 10,
) -> bytes:
if mode.lower() == "dark" and (
location := self.cog.manager.get_extension_location(mode.lower())
):
driver.install_addon(os.fspath(location), temporary=True)
try:
driver.get(url)
except TimeoutException:
Expand All @@ -133,9 +143,9 @@ def take_screenshot_with_url(
)
driver.implicitly_wait(wait)
try:
if mode.lower() == "normal":
if size.lower() == "normal":
byte: bytes = driver.get_screenshot_as_png()
elif mode.lower() == "full":
elif size.lower() == "full":
byte: bytes = driver.get_full_page_screenshot_as_png()
else:
raise commands.UserFeedbackCheckFailure(
Expand All @@ -149,13 +159,18 @@ def take_screenshot_with_url(
)
return byte

@asynccontextmanager
@contextlib.asynccontextmanager
async def driver(self) -> AsyncGenerator[Firefox, None]:
await self.cog.manager.wait_until_driver_downloaded()
await self.lock.acquire()
now: datetime = datetime.now(timezone.utc)
try:
driver: Firefox = await self.launcher()
(
driver.install_addon(os.fspath(location), temporary=True)
if (location := self.cog.manager.get_extension_location("cookies"))
else None
)
driver.set_page_load_timeout(time_to_wait=230.0)
driver.fullscreen_window()
try:
Expand All @@ -177,13 +192,19 @@ async def launcher(self) -> Firefox:
)

async def get_screenshot_bytes_from_url(
self, *, url: str, mode: Literal["normal", "full"], wait: int = 10
self,
*,
url: str,
size: Literal["normal", "full"],
mode: Literal["light", "dark"],
wait: int = 10,
) -> bytes:
async with self.driver() as driver:
return await asyncio.to_thread(
lambda: self.take_screenshot_with_url(
driver,
url=url,
size=size,
mode=mode,
wait=wait,
)
Expand Down
146 changes: 63 additions & 83 deletions screenshot/common/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import io
import os
import sys
import signal
import asyncio
import logging
import platform
from typing import ClassVar, Dict, Final, Optional, Tuple
import contextlib
from typing import ClassVar, Dict, Final, Optional

import discord
import aiohttp
Expand Down Expand Up @@ -62,9 +64,13 @@ class DriverManager:
TOR_EXPERT_BUNDLE_URL: ClassVar[str] = (
"https://archive.org/download/tor-expert-bundle/tor-expert-bundle-{system}.{ext}"
)
FIREFOX_ADDONS = {
"dark": "https://addons.mozilla.org/firefox/downloads/file/4351387/darkreader-4.9.92.xpi",
"cookies": "https://addons.mozilla.org/firefox/downloads/file/3625855/ninja_cookie-0.2.7.xpi",
}

def __init__(self, session: Optional[aiohttp.ClientSession] = None) -> None:
self._tor_process: asyncio.subprocess.Process = discord.utils.MISSING
self._tor_process: int = discord.utils.MISSING
self.__session: aiohttp.ClientSession = session or aiohttp.ClientSession()
self.__event: asyncio.Event = asyncio.Event()

Expand All @@ -79,12 +85,18 @@ def __environ(self) -> Dict[str, str]:
def data_directory(self) -> pathlib.Path:
return data_manager.cog_data_path(raw_name="Screenshot") / "data"

@property
def extensions_directory(self) -> pathlib.Path:
    """Return the folder that holds the Firefox add-on files.

    The folder is the ``extensions`` child of ``data_directory``.  It is
    created (together with any missing parents) each time the property is
    read, so the returned path always exists.
    """
    extensions: pathlib.Path = self.data_directory.joinpath("extensions")
    extensions.mkdir(parents=True, exist_ok=True)
    return extensions

@property
def driver_location(self) -> Optional[pathlib.Path]:
return (
loc[0]
if (
loc := list(self.data_directory.glob("geckodriver-{}*".format(self.get_os())))
(loc := list(self.data_directory.glob("geckodriver-{}*".format(self.get_os()))))
or (loc := list(self.data_directory.glob("geckodrive-{}*".format(self.get_os()))))
)
else None
Expand Down Expand Up @@ -132,6 +144,13 @@ def get_firefox_system() -> str:
return "win32"
raise RuntimeError("Not a supported device.")

def get_extension_location(self, name: str) -> Optional[pathlib.Path]:
    """Return the path of the stored add-on called *name*, if it exists.

    The add-on is expected at ``<extensions_directory>/<name>.xpi``.  The
    file is looked up by its exact name rather than through ``glob`` so that
    glob metacharacters in *name* (``*``, ``?``, ``[``) can never match a
    different add-on's file.

    Args:
        name: Base name of the add-on file, without the ``.xpi`` suffix.

    Returns:
        The path of the ``.xpi`` file when it exists, otherwise ``None``.
    """
    candidate: pathlib.Path = self.extensions_directory / "{}.xpi".format(name)
    return candidate if candidate.is_file() else None

def get_os(self) -> str:
return "{}{}".format(self.get_os_name(), 64 if platform.machine().endswith("64") else 32)

Expand All @@ -151,23 +170,40 @@ def get_tor_download_url(self) -> str:
ext="tar.bz2" if self.get_os().startswith("linux-aarch64") else "tar.gz",
)

async def initialize(self) -> None:
    """Fetch missing components, start tor, then call ``set_driver_downloaded``.

    Tor, the driver and Firefox are downloaded (in that order) only when the
    matching ``*_location`` property is falsy.  Every entry of
    ``FIREFOX_ADDONS`` is downloaded only when ``get_extension_location``
    finds nothing for it.  The tor binary is executed when ``_tor_process``
    is falsy.
    """
    # (location attribute, downloader) pairs; each location is read lazily,
    # right before deciding whether its download is needed.
    components = (
        ("tor_location", self.download_and_extract_tor),
        ("driver_location", self.download_and_extract_driver),
        ("firefox_location", self.download_and_extract_firefox),
    )
    for attribute, download in components:
        if getattr(self, attribute):
            continue
        await download()

    for addon_name, addon_url in self.FIREFOX_ADDONS.items():
        if self.get_extension_location(addon_name):
            continue
        await self.download_aand_save_addon(addon_name, addon_url)

    if not self._tor_process:
        await self.execute_tor_binary()
    self.set_driver_downloaded()

async def close(self) -> None:
    """Terminate the tor process recorded by this manager, if there is one.

    Waits on ``wait_until_driver_downloaded`` first, then sends ``SIGTERM``
    to the recorded tor pid from a worker thread.

    If ``_tor_process`` does not hold a positive integer pid (for example it
    is still the placeholder assigned in ``__init__`` because tor was never
    launched), there is nothing to stop and the method returns quietly
    instead of passing an invalid value to ``os.kill``.  A pid whose process
    has already exited is ignored.
    """
    await self.wait_until_driver_downloaded()
    pid = self._tor_process
    # Non-positive pids address process groups on POSIX; never signal those.
    if not isinstance(pid, int) or pid <= 0:
        return
    with contextlib.suppress(ProcessLookupError):
        await asyncio.to_thread(os.kill, pid, signal.SIGTERM)

async def wait_until_driver_downloaded(self) -> None:
    """Suspend the caller until the manager's internal event has been set."""
    ready: asyncio.Event = self.__event
    await ready.wait()

async def execute_tor_binary(self) -> Optional[asyncio.subprocess.Process]:
if self.tor_location is not None:
if self.tor_location and self._tor_process is discord.utils.MISSING:
process: asyncio.subprocess.Process = await asyncio.subprocess.create_subprocess_shell(
(
"{0}/tor/tor -f {0}/torrc"
if not self.get_os().startswith("win")
else "{0}/tor/tor.exe -f {0}/torrc"
).format(self.tor_location),
env=self.__environ,
stdin=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
)
self._tor_process: asyncio.subprocess.Process = process
self._tor_process: int = process.pid
log.info("Connected to tor successfully with ip:port 127.0.0.1:21666")

async def get_latest_release_version(self) -> str:
Expand All @@ -188,44 +224,11 @@ async def get_driver_download_url(self) -> str:
url: str = output_dict[0]["browser_download_url"]
return url

async def get_firefox_archive(self) -> Tuple[str, bytes]:
url: str = self.get_firefox_download_url("132.0a1")
log.info("Downloading firefox - %s" % url)
async with self.__session.get(url=url) as response:
if response.status == 404:
raise DownloadFailed("Could not find firefox with url: '%s'" % url, retry=False)
elif 400 <= response.status < 600:
raise DownloadFailed(retry=True)
response.raise_for_status()
byte: bytearray = bytearray()
byte_num: int = 0
with rich.progress.Progress(
rich.progress.SpinnerColumn(),
rich.progress.TextColumn("[progress.description]{task.description}"),
rich.progress.BarColumn(),
rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
rich.progress.TimeRemainingColumn(),
rich.progress.TimeElapsedColumn(),
) as progress:
task = progress.add_task(
"[red.seina.screenshot.downloader] Downloading Firefox",
total=response.content_length,
)
chunk: bytes = await response.content.read(1024)
while chunk:
byte.extend(chunk)
size: int = sys.getsizeof(chunk)
byte_num += size
progress.update(task, advance=size)
chunk: bytes = await response.content.read(1024)
return url, bytes(byte)

async def get_driver_archive(self) -> Tuple[str, bytes]:
url: str = await self.get_driver_download_url()
log.info("Downloading driver - %s" % url)
async def download_with_progress_bar(self, *, url: str, name: str) -> bytes:
log.info("Downloading %s - %s" % (name.lower(), url))
async with self.__session.get(url=url) as response:
if response.status == 404:
raise DownloadFailed("Could not find a driver with url: '%s" % url, retry=False)
raise DownloadFailed("Could not find %s with url '%s'" % (name.lower(), url))
elif 400 <= response.status < 600:
raise DownloadFailed(retry=True)
response.raise_for_status()
Expand All @@ -240,41 +243,7 @@ async def get_driver_archive(self) -> Tuple[str, bytes]:
rich.progress.TimeElapsedColumn(),
) as progress:
task: rich.progress.TaskID = progress.add_task(
"[red.seina.screenshot.downloader] Downloading driver",
total=response.content_length,
)
chunk: bytes = await response.content.read(1024)
while chunk:
byte.extend(chunk)
size: int = sys.getsizeof(chunk)
byte_num += size
progress.update(task, advance=size)
chunk: bytes = await response.content.read(1024)
return url, bytes(byte)

async def get_tor_archive(self) -> Tuple[str, bytes]:
url: str = self.get_tor_download_url()
log.info("Downloading tor - %s" % url)
async with self.__session.get(url=url) as response:
if response.status == 404:
raise DownloadFailed(
"Could not find tor expert bundle with url: '%s'" % url, retry=False
)
elif 400 <= response.status < 600:
raise DownloadFailed(retry=True)
response.raise_for_status()
byte: bytearray = bytearray()
byte_num: int = 0
with rich.progress.Progress(
rich.progress.SpinnerColumn(),
rich.progress.TextColumn("[progress.description]{task.description}"),
rich.progress.BarColumn(),
rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
rich.progress.TimeRemainingColumn(),
rich.progress.TimeElapsedColumn(),
) as progress:
task = progress.add_task(
"[red.seina.screenshot.downloader] Downloading Tor",
"[red.seina.screenshot.downloader] Downloading {}".format(name),
total=response.content_length,
)
chunk: bytes = await response.content.read(1024)
Expand All @@ -284,10 +253,11 @@ async def get_tor_archive(self) -> Tuple[str, bytes]:
byte_num += size
progress.update(task, advance=size)
chunk: bytes = await response.content.read(1024)
return url, bytes(byte)
return bytes(byte)

async def download_and_extract_firefox(self) -> None:
url, byte = await self.get_firefox_archive()
url: str = self.get_firefox_download_url("132.0a1")
byte: bytes = await self.download_with_progress_bar(url=url, name="firefox")
if url.endswith(".zip"):
with zipfile.ZipFile(io.BytesIO(byte), mode="r") as zip:
await asyncio.to_thread(lambda: zip.extractall(path=self.data_directory))
Expand All @@ -302,7 +272,8 @@ async def download_and_extract_firefox(self) -> None:
log.info("Downloaded firefox successfully with location: {}".format(self.firefox_location))

async def download_and_extract_driver(self) -> None:
url, byte = await self.get_driver_archive()
url: str = await self.get_driver_download_url()
byte: bytes = await self.download_with_progress_bar(url=url, name="driver")
if url.endswith(".zip"):
with zipfile.ZipFile(file=io.BytesIO(byte), mode="r") as zip:
await asyncio.to_thread(lambda: zip.extractall(path=self.data_directory))
Expand All @@ -325,7 +296,8 @@ async def download_and_extract_driver(self) -> None:
log.info("Downloaded driver successfully with location: {}".format(self.driver_location))

async def download_and_extract_tor(self) -> None:
url, byte = await self.get_tor_archive()
url: str = self.get_tor_download_url()
byte: bytes = await self.download_with_progress_bar(url=url, name="tor")
if url.endswith("bz2"):
tar: tarfile.TarFile = tarfile.TarFile.open(fileobj=io.BytesIO(byte), mode="r:bz2")
await asyncio.to_thread(lambda: tar.extractall(path=self.data_directory))
Expand All @@ -343,7 +315,6 @@ async def download_and_extract_tor(self) -> None:
"""
# Ports
SOCKSPort 21666
ControlPort 27666
# Logs
Log debug file {}
Expand All @@ -353,3 +324,12 @@ async def download_and_extract_tor(self) -> None:
)
)
log.info("Downloaded tor successfully with location: {}".format(self.tor_location))

async def download_aand_save_addon(self, name: str, url: str) -> None:
    """Download the add-on at *url* and store it as ``<name>.xpi``.

    The payload is fetched through ``download_with_progress_bar`` and written
    into ``extensions_directory``; a log line is emitted afterwards.

    Args:
        name: Base name for the saved file and for the progress/log labels.
        url: Location the add-on is downloaded from.
    """
    label: str = "{} extension".format(name)
    payload: bytes = await self.download_with_progress_bar(url=url, name=label)
    target: pathlib.Path = self.extensions_directory / "{}.xpi".format(name)
    target.write_bytes(payload)
    log.info("Downloaded %s extension for firefox." % name)
Loading

0 comments on commit aeb6d99

Please sign in to comment.