From 26c45be83dacb48c0665c9bbeac5ff72fc4552f0 Mon Sep 17 00:00:00 2001 From: Lemon Rose <78662983+japandotorg@users.noreply.github.com> Date: Mon, 23 Sep 2024 09:18:57 +0530 Subject: [PATCH] [Screenshot] 0.1.1 released with better support. --- screenshot/__init__.py | 8 + screenshot/common/__init__.py | 36 ++-- screenshot/common/downloader.py | 331 ++++++++++++++++++++++++-------- screenshot/common/exceptions.py | 2 +- screenshot/common/filter.py | 63 ++++-- screenshot/common/firefox.py | 6 +- screenshot/common/utils.py | 4 +- screenshot/core.py | 102 ++++------ screenshot/info.json | 8 +- 9 files changed, 379 insertions(+), 181 deletions(-) diff --git a/screenshot/__init__.py b/screenshot/__init__.py index d9cb03d..789688d 100644 --- a/screenshot/__init__.py +++ b/screenshot/__init__.py @@ -22,15 +22,23 @@ SOFTWARE. """ +# isort: off +import logging import platform +from urllib3.connectionpool import log as urllib_logger from redbot.core.bot import Red from redbot.core.errors import CogLoadError +from selenium.webdriver.remote.remote_connection import LOGGER as selenium_logger + from .core import Screenshot +# isort: on async def setup(bot: Red) -> None: + urllib_logger.setLevel(logging.DEBUG) + selenium_logger.setLevel(logging.DEBUG) if platform.system().lower() not in ["windows", "linux"]: raise CogLoadError("This cog is only available for linux and windows devices right now.") cog: Screenshot = Screenshot(bot) diff --git a/screenshot/common/__init__.py b/screenshot/common/__init__.py index c93602f..3b07a94 100644 --- a/screenshot/common/__init__.py +++ b/screenshot/common/__init__.py @@ -22,22 +22,24 @@ SOFTWARE. """ +# isort: off +import logging import asyncio import contextlib -import logging from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, AsyncGenerator, Dict, Literal, TypeVar from redbot.core import commands -from selenium.common.exceptions import ScreenshotException, TimeoutException, WebDriverException -from selenium.webdriver.common.proxy import Proxy, ProxyType -from selenium.webdriver.firefox.firefox_profile import FirefoxProfile + from selenium.webdriver.firefox.options import Options from selenium.webdriver.firefox.service import Service +from selenium.webdriver.common.proxy import Proxy, ProxyType +from selenium.webdriver.firefox.firefox_profile import FirefoxProfile +from selenium.common.exceptions import TimeoutException, WebDriverException, ScreenshotException -from .exceptions import ProxyConnectFailedError from .firefox import Firefox +from .exceptions import ProxyConnectFailedError if TYPE_CHECKING: from ..core import Screenshot @@ -46,6 +48,7 @@ import regex as re except (ImportError, ModuleNotFoundError): import re as re +# isort: on T = TypeVar("T") @@ -62,9 +65,9 @@ def __init__(self, cog: "Screenshot") -> None: def get_service(self) -> Service: return Service( - executable_path=str(self.cog.manager.location), + executable_path=str(self.cog.manager.driver_location), service_args=["--log", "debug"], - log_output=str(self.cog.manager.data_directory / "gecko.log"), + log_output=str(self.cog.manager.logs_directory / "gecko.log"), ) def get_options(self) -> Options: @@ -72,15 +75,14 @@ def get_options(self) -> Options: options: Options = Options() options.add_argument("--headless") options.page_load_strategy = "normal" - options.binary_location = str(cog.manager.firefox) - if cog.CACHE["toggle"]: - options.proxy = Proxy( - { - "proxyType": ProxyType.MANUAL, - "socksProxy": "127.0.0.1:{}".format(cog.CACHE["port"]), - "socksVersion": 5, - } - ) + options.binary_location = str(cog.manager.firefox_location) + options.proxy = Proxy( + { + "proxyType": ProxyType.MANUAL, + "socksProxy": "127.0.0.1:21666", + "socksVersion": 5, + } + ) options.profile = FirefoxProfile() options.profile.set_preference("browser.cache.disk.enable", False) options.profile.set_preference("browser.cache.memory.enable", False) @@ -119,7 +121,7 @@ def take_screenshot_with_url( except TimeoutException: raise commands.UserFeedbackCheckFailure("Timed out opening the website.") except WebDriverException as error: - if re.search(pattern="neterror", string=str(error.msg), flags=re.IGNORECASE): + if re.search(pattern="about:neterror", string=str(error.msg), flags=re.IGNORECASE): log.exception("Something went wrong connecting to the internet.", exc_info=error) raise ProxyConnectFailedError() raise commands.UserFeedbackCheckFailure( diff --git a/screenshot/common/downloader.py b/screenshot/common/downloader.py index 6f24bf7..db4302d 100644 --- a/screenshot/common/downloader.py +++ b/screenshot/common/downloader.py @@ -22,56 +22,71 @@ SOFTWARE. """ -import asyncio -import concurrent.futures +# isort: off import io -import logging import os -import pathlib +import sys +import asyncio +import logging import platform -import tarfile -import zipfile -from typing import ClassVar, Dict, Final, Optional +from typing import ClassVar, Dict, Final, Optional, Tuple +import discord import aiohttp -from mozdownload.factory import FactoryScraper + +import pathlib +import zipfile +import tarfile +import rich.progress + from redbot.core import data_manager -from .exceptions import DriverDownloadFailed +from .exceptions import DownloadFailed + +# isort: on + log: logging.Logger = logging.getLogger("red.seina.screenshot.downloader") class DriverManager: - SYSTEM: Final[Dict[str, str]] = {"linux": "linux", "windows": "win", "arm": "linux-aarch"} - DOWNLOAD_URL: Final[str] = "https://github.com/mozilla/geckodriver/releases/download" - LATEST_RELEASE_URL: Final[str] = ( + DRIVER_LATEST_RELEASE_URL: Final[str] = ( "https://api.github.com/repos/mozilla/geckodriver/releases/latest" ) - RELEASE_TAG_URL: ClassVar[str] = ( + DRIVER_RELEASE_TAG_URL: ClassVar[str] = ( "https://api.github.com/repos/mozilla/geckodriver/releases/tags/{version}" ) - EXTRAS: Final[str] = ( - "/DesktopShortcut=false /StartMenuShortcut=false /PrivateBrowsingShortcut=false" + FIREFOX_DOWNLOAD_URL: ClassVar[str] = ( + "https://archive.mozilla.org/pub/firefox/nightly/latest-mozilla-central/firefox-{version}.en-US.{system}.{ext}" + ) + TOR_EXPERT_BUNDLE_URL: ClassVar[str] = ( + "https://archive.org/download/seina-tor-ext/tor-expert-bundle-{system}.{ext}" ) def __init__(self, session: Optional[aiohttp.ClientSession] = None) -> None: + self._tor_process: asyncio.subprocess.Process = discord.utils.MISSING self.__session: aiohttp.ClientSession = session or aiohttp.ClientSession() self.__event: asyncio.Event = asyncio.Event() - @staticmethod - def get_os_architecture() -> int: - if platform.machine().endswith("64"): - return 64 - else: - return 32 + @property + def __environ(self) -> Dict[str, str]: + environ: Dict[str, str] = os.environ.copy() + if self.tor_location and self.get_os() != "linux-aarch64": + environ["LD_LIBRARY_PATH"] = ( + os.getenv("LD_LIBRARY_PATH", "") + ":" + str(self.tor_location / "tor") + ) + return environ @property def data_directory(self) -> pathlib.Path: - return data_manager.cog_data_path(raw_name="Screenshot") + return data_manager.cog_data_path(raw_name="Screenshot") / "data" @property - def location(self) -> Optional[pathlib.Path]: + def logs_directory(self) -> pathlib.Path: + return self.data_directory / "logs" + + @property + def driver_location(self) -> Optional[pathlib.Path]: return ( loc[0] if ( @@ -82,104 +97,266 @@ def location(self) -> Optional[pathlib.Path]: ) @property - def firefox(self) -> Optional[pathlib.Path]: - return loc[0] if (loc := list(self.data_directory.glob("firefox/firefox*"))) else None + def firefox_location(self) -> Optional[pathlib.Path]: + return ( + loc[0] + if (loc := list(self.data_directory.glob("firefox-{}/firefox*".format(self.get_os())))) + else None + ) - def get_os_name(self) -> str: - if platform.machine().lower() == "aarch64": - return self.SYSTEM["arm"] + @property + def tor_location(self) -> Optional[pathlib.Path]: + return ( + loc[0] + if (loc := list(self.data_directory.glob("tor-{}*".format(self.get_os())))) + else None + ) + + @staticmethod + def get_os_name() -> str: + if platform.machine().lower().startswith("aarch"): + return "linux-aarch" if platform.system().lower() == "linux": - return self.SYSTEM["linux"] + return "linux" elif platform.system().lower() == "windows": - return self.SYSTEM["windows"] - else: - raise RuntimeError() + return "win" + raise RuntimeError() + + @staticmethod + def get_firefox_system() -> str: + if platform.machine().lower() == "aarch64": + return "linux-aarch64" + elif platform.system().lower() == "linux": + if platform.machine().endswith("64"): + return "linux-x86_64" + else: + return "linux-i686" + elif platform.system().lower() == "machine": + if platform.machine().endswith("64"): + return "win64" + else: + return "win32" + raise RuntimeError("Not a supported device.") def get_os(self) -> str: - return "{}{}".format(self.get_os_name(), self.get_os_architecture()) + return "{}{}".format(self.get_os_name(), 64 if platform.machine().endswith("64") else 32) def set_driver_downloaded(self) -> None: self.__event.set() + def get_firefox_download_url(self, version: str) -> str: + return self.FIREFOX_DOWNLOAD_URL.format( + version=version, + system=self.get_firefox_system(), + ext="zip" if self.get_os().startswith("win") else "tar.bz2", + ) + + def get_tor_download_url(self) -> str: + return self.TOR_EXPERT_BUNDLE_URL.format( + system=self.get_os(), + ext="tar.bz2" if self.get_os().startswith("linux-aarch64") else "tar.gz", + ) + async def wait_until_driver_downloaded(self) -> None: await self.__event.wait() + async def execute_tor_binary(self) -> Optional[asyncio.subprocess.Process]: + if self.tor_location is not None: + process: asyncio.subprocess.Process = await asyncio.subprocess.create_subprocess_shell( + ( + "{0}/tor/tor -f {0}/torrc" + if not self.get_os().startswith("win") + else "{0}/tor/tor.exe -f {0}/torrc" + ).format(self.tor_location), + env=self.__environ, + stdin=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + ) + self._tor_process: asyncio.subprocess.Process = process + log.info("Connected to tor successfully with ip:port 127.0.0.1:21666") + async def get_latest_release_version(self) -> str: - async with self.__session.get(url=self.LATEST_RELEASE_URL) as response: + async with self.__session.get(url=self.DRIVER_LATEST_RELEASE_URL) as response: json: Dict[str, str] = await response.json() version: str = json.get("tag_name", "v0.35.0") return version async def get_driver_download_url(self) -> str: version: str = await self.get_latest_release_version() - async with self.__session.get(self.RELEASE_TAG_URL.format(version=version)) as response: + async with self.__session.get( + self.DRIVER_RELEASE_TAG_URL.format(version=version) + ) as response: json = await response.json() assets = json["assets"] name: str = "{}-{}-{}.".format("geckodriver", version, self.get_os()) output_dict = [asset for asset in assets if asset["name"].startswith(name)] url: str = output_dict[0]["browser_download_url"] - log.debug("Downloading driver (%s) for [%s]" % (url, self.get_os())) + log.debug("Downloading driver - %s" % url) return url - async def download_and_extract_driver(self) -> None: + async def get_firefox_archive(self) -> Tuple[str, bytes]: + url: str = self.get_firefox_download_url("132.0a1") + log.info("Downloading firefox - %s" % url) + async with self.__session.get(url=url) as response: + if response.status == 404: + raise DownloadFailed("Could not find firefox with url: '%s'" % url, retry=False) + elif 400 <= response.status < 600: + raise DownloadFailed(retry=True) + response.raise_for_status() + byte: bytearray = bytearray() + byte_num: int = 0 + with rich.progress.Progress( + rich.progress.SpinnerColumn(), + rich.progress.TextColumn("[progress.description]{task.description}"), + rich.progress.BarColumn(), + rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + rich.progress.TimeRemainingColumn(), + rich.progress.TimeElapsedColumn(), + ) as progress: + task = progress.add_task( + "[red.seina.screenshot.downloader] Downloading Firefox", + total=response.content_length, + ) + chunk: bytes = await response.content.read(1024) + while chunk: + byte.extend(chunk) + size: int = sys.getsizeof(chunk) + byte_num += size + progress.update(task, advance=size) + chunk: bytes = await response.content.read(1024) + return url, bytes(byte) + + async def get_driver_archive(self) -> Tuple[str, bytes]: url: str = await self.get_driver_download_url() - async with self.__session.get(url=url, timeout=aiohttp.ClientTimeout(120)) as response: + log.info("Downloading driver - %s" % url) + async with self.__session.get(url=url) as response: if response.status == 404: - raise DriverDownloadFailed( - "Could not find a driver with url: '%s" % url, retry=False + raise DownloadFailed("Could not find a driver with url: '%s" % url, retry=False) + elif 400 <= response.status < 600: + raise DownloadFailed(retry=True) + response.raise_for_status() + byte: bytearray = bytearray() + byte_num: int = 0 + with rich.progress.Progress( + rich.progress.SpinnerColumn(), + rich.progress.TextColumn("[progress.description]{task.description}"), + rich.progress.BarColumn(), + rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + rich.progress.TimeRemainingColumn(), + rich.progress.TimeElapsedColumn(), + ) as progress: + task: rich.progress.TaskID = progress.add_task( + "[red.seina.screenshot.downloader] Downloading driver", + total=response.content_length, + ) + chunk: bytes = await response.content.read(1024) + while chunk: + byte.extend(chunk) + size: int = sys.getsizeof(chunk) + byte_num += size + progress.update(task, advance=size) + chunk: bytes = await response.content.read(1024) + return url, bytes(byte) + + async def get_tor_archive(self) -> Tuple[str, bytes]: + url: str = self.get_tor_download_url() + log.info("Downloading tor - %s" % url) + async with self.__session.get(url=url) as response: + if response.status == 404: + raise DownloadFailed( + "Could not find tor expert bundle with url: '%s'" % url, retry=False ) elif 400 <= response.status < 600: - raise DriverDownloadFailed(retry=True) + raise DownloadFailed(retry=True) response.raise_for_status() - byte: bytes = await response.read() + byte: bytearray = bytearray() + byte_num: int = 0 + with rich.progress.Progress( + rich.progress.SpinnerColumn(), + rich.progress.TextColumn("[progress.description]{task.description}"), + rich.progress.BarColumn(), + rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + rich.progress.TimeRemainingColumn(), + rich.progress.TimeElapsedColumn(), + ) as progress: + task = progress.add_task( + "[red.seina.screenshot.downloader] Downloading Tor", + total=response.content_length, + ) + chunk: bytes = await response.content.read(1024) + while chunk: + byte.extend(chunk) + size: int = sys.getsizeof(chunk) + byte_num += size + progress.update(task, advance=size) + chunk: bytes = await response.content.read(1024) + return url, bytes(byte) + + async def download_and_extract_firefox(self) -> None: + url, byte = await self.get_firefox_archive() + if url.endswith(".zip"): + with zipfile.ZipFile(io.BytesIO(byte), mode="r") as zip: + await asyncio.to_thread(lambda: zip.extractall(path=self.data_directory)) + elif url.endswith("tar.bz2"): + tar: tarfile.TarFile = tarfile.TarFile.open(fileobj=io.BytesIO(byte), mode="r:bz2") + await asyncio.to_thread(lambda: tar.extractall(path=self.data_directory)) + else: + raise DownloadFailed("Failed to download firefox.") + path: pathlib.Path = list(self.data_directory.glob("firefox*"))[0] + name: str = path.name + "-{}".format(self.get_os()) + os.rename(path, self.data_directory / name) + log.info("Downloaded firefox successfully with location: {}".format(self.firefox_location)) + + async def download_and_extract_driver(self) -> None: + url, byte = await self.get_driver_archive() if url.endswith(".zip"): with zipfile.ZipFile(file=io.BytesIO(byte), mode="r") as zip: - zip.extractall(path=self.data_directory) + await asyncio.to_thread(lambda: zip.extractall(path=self.data_directory)) elif url.endswith(".tar.gz"): tar: tarfile.TarFile = tarfile.TarFile.open(fileobj=io.BytesIO(byte), mode="r:gz") - tar.extractall(path=self.data_directory) + await asyncio.to_thread(lambda: tar.extractall(path=self.data_directory)) else: - raise DriverDownloadFailed("Failed to download the driver.") + raise DownloadFailed("Failed to download the driver.") path: pathlib.Path = list(self.data_directory.glob("geckodriver*"))[0] idx: int = path.name.rfind(".") - name: str = path.name[:idx] + "-{}".format( - "linux-aarch64" if platform.machine() == "aarch64" else self.get_os() - ) + name: str = path.name[:idx] + "-{}".format(self.get_os()) os.rename( path, ( self.data_directory / "{}.exe".format(name) - if self.get_os().startswith(self.SYSTEM["windows"]) + if self.get_os().startswith("win") else self.data_directory / name ), ) - log.info("Downloaded driver successfully with location: {}".format(self.location)) - - async def download_firefox(self) -> str: - log.info("Downloading firefox in {}".format(self.data_directory)) - with concurrent.futures.ThreadPoolExecutor() as executor: - scraper: FactoryScraper = FactoryScraper( - scraper_type="release", - version="130.0", - platform="linux-arm64" if platform.machine() == "aarch64" else None, - destination=str(self.data_directory), - ) - res: str = await asyncio.get_running_loop().run_in_executor( - executor=executor, func=scraper.download + log.info("Downloaded driver successfully with location: {}".format(self.driver_location)) + + async def download_and_extract_tor(self) -> None: + url, byte = await self.get_tor_archive() + if url.endswith("bz2"): + tar: tarfile.TarFile = tarfile.TarFile.open(fileobj=io.BytesIO(byte), mode="r:bz2") + await asyncio.to_thread(lambda: tar.extractall(path=self.data_directory)) + elif url.endswith("gz"): + tar: tarfile.TarFile = tarfile.TarFile.open(fileobj=io.BytesIO(byte), mode="r:gz") + await asyncio.to_thread( + lambda: tar.extractall(path=self.data_directory / "tor-{}".format(self.get_os())) ) - if self.get_os() == self.SYSTEM["windows"]: - process: asyncio.subprocess.Process = await asyncio.create_subprocess_shell( - "{} /InstallDirectoryPath={} {}".format( - res, str(self.data_directory / "firefox"), self.EXTRAS - ), - shell=True, + else: + raise DownloadFailed("Failed to download tor.") + if self.tor_location is not None: + file: pathlib.Path = self.tor_location / "torrc" + with file.open("w", encoding="utf-8") as t: + t.write( + """ + # Ports + SOCKSPort 21666 + ControlPort 27666 + + # Logs + Log debug file {} + DataDirectory {} + """.format( + self.logs_directory, self.tor_location / "teb-data" + ) ) - await process.wait() - elif res.endswith("tar.bz2"): - tar: tarfile.TarFile = tarfile.TarFile.open(name=res, mode="r:bz2") - tar.extractall(path=self.data_directory) - else: - raise RuntimeError() - os.remove(res) - log.info("Successfully downloaded firefox.") - return res + log.info("Downloaded tor successfully with location: {}".format(self.tor_location)) diff --git a/screenshot/common/exceptions.py b/screenshot/common/exceptions.py index 6471bee..a522bcd 100644 --- a/screenshot/common/exceptions.py +++ b/screenshot/common/exceptions.py @@ -29,7 +29,7 @@ class DownloaderError(Exception): """Base exception for web-driver downloader.""" -class DriverDownloadFailed(DownloaderError): +class DownloadFailed(DownloaderError): """Downloading the web driver failed.""" def __init__( diff --git a/screenshot/common/filter.py b/screenshot/common/filter.py index 606a54a..22903f3 100644 --- a/screenshot/common/filter.py +++ b/screenshot/common/filter.py @@ -22,28 +22,54 @@ SOFTWARE. """ -import asyncio +# isort: off +import os import io +import asyncio import logging -from typing import Any, Dict, Optional, cast +from PIL import Image +from typing import TYPE_CHECKING, Dict, Literal, Optional + +import discord +import torch import transformers -from PIL import Image + +if TYPE_CHECKING: + from ..core import Screenshot try: import regex as re except ModuleNotFoundError: import re as re +# isort: on log: logging.Logger = logging.getLogger("red.seina.screenshot.filter") class Filter: - def __init__(self) -> None: - self.model: transformers.Pipeline = transformers.pipeline( - "image-classification", model="Falconsai/nsfw_image_detection" - ) + def __init__(self, cog: "Screenshot") -> None: + self.cog: "Screenshot" = cog + self.models: Dict[Literal["small", "large"], transformers.Pipeline] = discord.utils.MISSING + self.__task: asyncio.Task[None] = asyncio.create_task(self.__models()) + + async def __models(self) -> None: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + await self.cog.manager.wait_until_driver_downloaded() + if self.models is discord.utils.MISSING: + self.models: Dict[Literal["small", "large"], transformers.Pipeline] = { + "small": transformers.pipeline( + "image-classification", + model="Falconsai/nsfw_image_detection", + device=torch.device("cpu"), + ), + "large": transformers.pipeline( + "image-classification", + model="MichalMlodawski/nsfw-image-detection-large", + device=torch.device("cpu"), + ), + } @staticmethod def is_valid_url(url: str) -> bool: @@ -59,10 +85,25 @@ def is_valid_url(url: str) -> bool: match: Optional[re.Match[str]] = regex.search(url) return match is not None - async def read(self, image: bytes) -> bool: + def close(self) -> None: + self.__task.cancel() + + async def read(self, image: bytes, model: Literal["small", "large"]) -> bool: + if model.lower() == "small": + return await asyncio.to_thread(lambda: self.small(image)) + elif model.lower() == "large": + return await asyncio.to_thread(lambda: self.large(image)) + else: + raise RuntimeError("No model named '{}' found.".format(model.lower())) + + def small(self, image: bytes) -> bool: img: Image.ImageFile.ImageFile = Image.open(io.BytesIO(image)) - response: Dict[str, Any] = cast( - Dict[str, Any], await asyncio.to_thread(lambda: self.model(img)) - ) + response: Dict[str, float] = self.models["small"](img) # type: ignore pred: str = max(response, key=lambda x: x["score"]) return pred["label"] == "nsfw" + + def large(self, image: bytes) -> bool: + img: Image.ImageFile.ImageFile = Image.open(io.BytesIO(image)) + response: Dict[str, float] = self.models["large"](img) # type: ignore + pred: str = max(response, key=lambda x: x["score"]) + return pred["label"].lower() != "safe" diff --git a/screenshot/common/firefox.py b/screenshot/common/firefox.py index 4a73d5d..afaeb8a 100644 --- a/screenshot/common/firefox.py +++ b/screenshot/common/firefox.py @@ -22,14 +22,16 @@ SOFTWARE. """ +# isort: off import atexit import weakref from typing import Any -from selenium.webdriver.common.options import BaseOptions from selenium.webdriver.common.service import Service -from selenium.webdriver.firefox.webdriver import WebDriver as _Firefox +from selenium.webdriver.common.options import BaseOptions from selenium.webdriver.remote.webdriver import WebDriver as _Driver +from selenium.webdriver.firefox.webdriver import WebDriver as _Firefox +# isort: on class Driver(_Driver): diff --git a/screenshot/common/utils.py b/screenshot/common/utils.py index a6b36d7..e0fed06 100644 --- a/screenshot/common/utils.py +++ b/screenshot/common/utils.py @@ -22,9 +22,10 @@ SOFTWARE. """ -import functools +# isort: off import hashlib import logging +import functools from typing import TYPE_CHECKING, Callable import discord @@ -38,6 +39,7 @@ import regex as re except ModuleNotFoundError: import re as re +# isort: on log: logging.Logger = logging.getLogger("red.seina.screenshot.core") diff --git a/screenshot/core.py b/screenshot/core.py index f6a37bd..7d3cde3 100644 --- a/screenshot/core.py +++ b/screenshot/core.py @@ -22,26 +22,30 @@ SOFTWARE. """ -import asyncio.sslproto -import contextlib +# isort: off import io +import aiohttp import logging -from typing import Dict, Final, List, Literal, Optional, Union +import contextlib +import asyncio.sslproto +from typing import Final, List, Literal -import aiohttp import discord from discord.ext import tasks -from redbot.core import commands from redbot.core.bot import Red +from redbot.core import commands from redbot.core.config import Config from redbot.core.utils.chat_formatting import humanize_list + from selenium.common.exceptions import NoSuchDriverException +from .common.filter import Filter from .common import FirefoxManager from .common.downloader import DriverManager from .common.exceptions import ProxyConnectFailedError -from .common.filter import Filter -from .common.utils import URLConverter, send_notification, counter as counter_api +from .common.utils import counter as counter_api, URLConverter +# isort: on + log: logging.Logger = logging.getLogger("red.seina.screenshot.core") @@ -53,15 +57,9 @@ class Screenshot(commands.Cog): setattr(asyncio.sslproto._SSLProtocolTransport, "_start_tls_compatible", True) - __version__: Final[str] = "0.1.0" + __version__: Final[str] = "0.1.1" __author__: Final[List[str]] = ["inthedark.org"] - CACHE: Dict[str, Union[bool, int, Optional[str]]] = { - "toggle": False, - "port": 9050, - "updated": False, - } - def format_help_for_context(self, ctx: commands.Context) -> str: pre_processed = super().format_help_for_context(ctx) or "" n = "\n" if "\n\n" not in pre_processed else "" @@ -75,30 +73,34 @@ def format_help_for_context(self, ctx: commands.Context) -> str: def __init__(self, bot: Red) -> None: self.bot: Red = bot self.config: Config = Config.get_conf(self, identifier=69_420_666, force_registration=True) - self.config.register_global(**self.CACHE) + self.config.register_global(**{"updated": False, "nsfw": "normal"}) self.session: aiohttp.ClientSession = aiohttp.ClientSession(trust_env=True) self.manager: DriverManager = DriverManager(session=self.session) self.driver: FirefoxManager = FirefoxManager(self) - self.filter: Filter = Filter() + self.filter: Filter = Filter(self) self.__task: asyncio.Task[None] = asyncio.create_task(self.update_counter_api()) async def cog_load(self) -> None: - if self.manager.firefox is None: - await self.manager.download_firefox() - if self.manager.location is None: + if self.manager.tor_location is None: + await self.manager.download_and_extract_tor() + await self.manager.execute_tor_binary() + if self.manager.firefox_location is None: + await self.manager.download_and_extract_firefox() + if self.manager.driver_location is None: await self.manager.download_and_extract_driver() - self.CACHE: Dict[str, Union[bool, int, Optional[str]]] = await self.config.all() - if not self.CACHE["toggle"]: - await send_notification(self) self.manager.set_driver_downloaded() - self.bg_task.start() # type: ignore + self.bg_task.start() async def cog_unload(self) -> None: + self.filter.close() self.__task.cancel() - self.bg_task.cancel() # type: ignore + self.bg_task.cancel() await self.session.close() + if self.manager._tor_process is not discord.utils.MISSING: + self.manager._tor_process.terminate() + self.manager._tor_process.kill() with contextlib.suppress(BaseException): self.driver.clear_all_drivers() @@ -116,9 +118,10 @@ async def update_counter_api(self) -> None: @tasks.loop(minutes=5.0, reconnect=True, name="red:seina:screenshot") async def bg_task(self) -> None: - self.driver.remove_drivers_if_time_has_passed(minutes=10.0) + with contextlib.suppress(RuntimeError): + self.driver.remove_drivers_if_time_has_passed(minutes=10.0) - @bg_task.before_loop # type: ignore + @bg_task.before_loop async def bg_task_before_loop(self) -> None: await self.manager.wait_until_driver_downloaded() @@ -127,34 +130,6 @@ async def bg_task_before_loop(self) -> None: async def screenshot_set(self, _: commands.Context): """Configuration commands for screenshot.""" - @screenshot_set.group(name="tor", invoke_without_command=True) # type: ignore - async def screenshot_set_tor(self, ctx: commands.Context, toggle: bool): - """ - Enable or disable tor proxy when taking screenshots. - """ - if not ctx.invoked_subcommand: - await self.config.toggle.set(toggle) - self.CACHE["toggle"] = toggle - await ctx.tick() - - @screenshot_set_tor.command(name="port") # type: ignore - async def screenshot_set_tor_port( - self, ctx: commands.Context, port: commands.Range[int, 1, 5] - ): - """ - Change the default port of the tor protocol. - """ - if port > 65535: - await ctx.send( - "The maximum supported port is '65535' got '{}' instead.".format(port), - reference=ctx.message.to_reference(fail_if_not_exists=False), - allowed_mentions=discord.AllowedMentions(replied_user=False), - ) - raise commands.CheckFailure() - await self.config.port(port) - self.CACHE["port"] = port - await ctx.tick() - @commands.command() @commands.cooldown(1, 60, commands.BucketType.user) @commands.has_permissions(attach_files=True, embed_links=True) @@ -195,21 +170,14 @@ async def screenshot( else "try again later." ) raise commands.CheckFailure() - except ProxyConnectFailedError: - if self.CACHE["toggle"]: - log.info( - "Failed connecting to the proxy, disabling proxy config...", - ) - await self.config.toggle.set(False) - self.CACHE["toggle"] = False + except ProxyConnectFailedError as error: + log.exception( + "Failed connecting to the proxy.", + exc_info=error, + ) await self.bot.send_to_owners( "Something went wrong with the screenshot cog, check logs for more details." ) - await ctx.send( - "Something went wrong with the screenshot cog, try again later.", - reference=ctx.message.to_reference(fail_if_not_exists=False), - allowed_mentions=discord.AllowedMentions(replied_user=False), - ) raise commands.CheckFailure() except commands.UserFeedbackCheckFailure as error: if message := error.message: @@ -230,7 +198,7 @@ async def screenshot( ), ) and not ctx.channel.is_nsfw() - and await self.filter.read(image) + and await self.filter.read(image=image, model="small") ): await ctx.send( "This image contains nsfw content, and cannot be sent on this channel.", diff --git a/screenshot/info.json b/screenshot/info.json index 245d637..85c33c0 100644 --- a/screenshot/info.json +++ b/screenshot/info.json @@ -2,11 +2,11 @@ "author": [ "inthedark.org" ], - "install_msg": "Thanks for installing the screenshot cog, check your dms for more instructions.", + "install_msg": "Thanks for installing the screenshot cog, check your dms for more instructions. This cog may take some time to load for the first time depending on your internet.", "name": "Screenshot", "disabled": false, "short": "Take web page screenshots with your bot without compromising privacy of your machine.", - "description": "Take web page screenshots with your bot without compromising privacy of your machine.", + "description": "Take web page screenshots with your bot without compromising privacy of your machine. Atleast 2GB RAM and 10GB storage disk is recommended to use this cog.", "tags": [ "screenshot", "scraping", @@ -25,14 +25,12 @@ "required_cogs": {}, "requirements": [ "selenium", - "git+https://github.com/mozilla/mozdownload.git@master", "Pillow", "transformers", "torch", "tensorflow", "tf-keras", - "google", - "protobuf" + "cloud-tpu-client" ], "type": "COG", "end_user_data_statement": "This cog does not store End User Data."