From ddb42c9e9b4157e4b38d0fb020938de6c5c1a1dd Mon Sep 17 00:00:00 2001 From: Goldy <153996346+g0ldyy@users.noreply.github.com> Date: Thu, 4 Jul 2024 14:52:37 +0200 Subject: [PATCH] pydantic user config validation --- comet/api/stream.py | 4 ++-- comet/debrid/alldebrid.py | 3 +++ comet/debrid/manager.py | 9 +++++++-- comet/templates/index.html | 4 ++-- comet/utils/general.py | 40 ++++++++++++++------------------------ comet/utils/models.py | 35 ++++++++++++++++++++++++++++++++- 6 files changed, 63 insertions(+), 32 deletions(-) create mode 100644 comet/debrid/alldebrid.py diff --git a/comet/api/stream.py b/comet/api/stream.py index 5efa58b..8f44b6e 100644 --- a/comet/api/stream.py +++ b/comet/api/stream.py @@ -193,7 +193,7 @@ async def stream(request: Request, b64config: str, type: str, id: str): tasks = [] filtered = 0 - filter_title = config["filterTitles"] if "filterTitles" in config else True # not needed when pydantic config validation system implemented + filter_title = config["filterTitles"] for torrent in torrents: if filter_title: parsed_torrent = parse( @@ -207,7 +207,7 @@ async def stream(request: Request, b64config: str, type: str, id: str): tasks.append(get_torrent_hash(session, indexer_manager_type, torrent)) - logger.info(f"{filtered} filtered torrents from Zilean API for {logName}") + logger.info(f"{filtered} filtered torrents for {logName}") torrent_hashes = await asyncio.gather(*tasks) torrent_hashes = list(set([hash for hash in torrent_hashes if hash])) diff --git a/comet/debrid/alldebrid.py b/comet/debrid/alldebrid.py new file mode 100644 index 0000000..4b7206b --- /dev/null +++ b/comet/debrid/alldebrid.py @@ -0,0 +1,3 @@ +class AllDebrid: + def __init__(self): + pass \ No newline at end of file diff --git a/comet/debrid/manager.py b/comet/debrid/manager.py index 7b50905..4c8fbab 100644 --- a/comet/debrid/manager.py +++ b/comet/debrid/manager.py @@ -1,8 +1,13 @@ import aiohttp from .realdebrid import RealDebrid +from .alldebrid import AllDebrid def getDebrid(session: aiohttp.ClientSession, config: dict): - if config["debridService"] == "realdebrid": - return RealDebrid(session, config["debridApiKey"]) + debrid_service = config["debridService"] + debrid_api_key = config["debridApiKey"] + if debrid_service == "realdebrid": + return RealDebrid(session, debrid_api_key) + elif debrid_service == "alldebrid": + return AllDebrid(session, debrid_api_key) \ No newline at end of file diff --git a/comet/templates/index.html b/comet/templates/index.html index 34f744f..78c88a9 100644 --- a/comet/templates/index.html +++ b/comet/templates/index.html @@ -571,13 +571,13 @@ const button = document.querySelector("sl-button"); const alert = document.querySelector('sl-alert[variant="neutral"]'); button.addEventListener("click", () => { - const debridService = document.getElementById("debridService").value; - const debridApiKey = document.getElementById("debridApiKey").value; const indexers = Array.from(document.getElementById("indexers").selectedOptions).map(option => option.value); const languages = Array.from(document.getElementById("languages").selectedOptions).map(option => option.value); const resolutions = Array.from(document.getElementById("resolutions").selectedOptions).map(option => option.value); const maxResults = document.getElementById("maxResults").value; const filterTitles = document.getElementById("filterTitles").checked; + const debridService = document.getElementById("debridService").value; + const debridApiKey = document.getElementById("debridApiKey").value; const selectedLanguages = languages.length === defaultLanguages.length && languages.every((val, index) => val === defaultLanguages[index]) ? ["All"] : languages; const selectedResolutions = resolutions.length === defaultResolutions.length && resolutions.every((val, index) => val === defaultResolutions[index]) ? ["All"] : resolutions; diff --git a/comet/utils/general.py b/comet/utils/general.py index f0b6502..00de527 100644 --- a/comet/utils/general.py +++ b/comet/utils/general.py @@ -7,7 +7,7 @@ import bencodepy from comet.utils.logger import logger -from comet.utils.models import settings +from comet.utils.models import settings, ConfigModel translation_table = { "ā": "a", @@ -181,26 +181,8 @@ def bytes_to_size(bytes: int): def config_check(b64config: str): try: config = json.loads(base64.b64decode(b64config).decode()) - - if not isinstance(config["debridService"], str) or config[ - "debridService" - ] not in ["realdebrid"]: - return False - if not isinstance(config["debridApiKey"], str): - return False - if not isinstance(config["indexers"], list): - return False - if not isinstance(config["maxResults"], int) or config["maxResults"] < 0: - return False - if ( - not isinstance(config["resolutions"], list) - or len(config["resolutions"]) == 0 - ): - return False - if not isinstance(config["languages"], list) or len(config["languages"]) == 0: - return False - - return config + validated_config = ConfigModel(**config) + return validated_config.model_dump() except: return False @@ -302,15 +284,23 @@ async def get_torrent_hash( async def get_balanced_hashes(hashes: dict, config: dict): max_results = config["maxResults"] config_resolutions = config["resolutions"] - config_languages = {language.replace("_", " ").capitalize() for language in config["languages"]} + config_languages = { + language.replace("_", " ").capitalize() for language in config["languages"] + } include_all_languages = "All" in config_languages include_all_resolutions = "All" in config_resolutions - include_unknown_resolution = include_all_resolutions or "Unknown" in config_resolutions + include_unknown_resolution = ( + include_all_resolutions or "Unknown" in config_resolutions + ) hashes_by_resolution = {} for hash, hash_data in hashes.items(): hash_info = hash_data["data"] - if not include_all_languages and not hash_info["is_multi_audio"] and not any(lang in hash_info["language"] for lang in config_languages): + if ( + not include_all_languages + and not hash_info["is_multi_audio"] + and not any(lang in hash_info["language"] for lang in config_languages) + ): continue resolution = hash_info["resolution"] @@ -348,7 +338,7 @@ async def get_balanced_hashes(hashes: dict, config: dict): if missing_hashes <= 0: break current_count = len(balanced_hashes[resolution]) - available_hashes = hash_list[current_count:current_count + missing_hashes] + available_hashes = hash_list[current_count : current_count + missing_hashes] balanced_hashes[resolution].extend(available_hashes) missing_hashes -= len(available_hashes) diff --git a/comet/utils/models.py b/comet/utils/models.py index fefe62a..823bf54 100644 --- a/comet/utils/models.py +++ b/comet/utils/models.py @@ -2,6 +2,7 @@ from typing import List, Optional from databases import Database +from pydantic import BaseModel, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from RTN import RTN, BaseRankingModel, SettingsModel @@ -27,6 +28,39 @@ class AppSettings(BaseSettings): CUSTOM_HEADER_HTML: Optional[str] = None +settings = AppSettings() + + +class ConfigModel(BaseModel): + indexers: List[str] + languages: Optional[List[str]] = ["All"] + resolutions: Optional[List[str]] = ["All"] + maxResults: Optional[int] = 0 + filterTitles: Optional[bool] = True + debridService: str + debridApiKey: str + + @field_validator("indexers") + def check_indexers(cls, v, values): + if not any(indexer in settings.INDEXER_MANAGER_INDEXERS for indexer in v): + raise ValueError( + f"At least one indexer must be from {settings.INDEXER_MANAGER_INDEXERS}" + ) + return v + + @field_validator("maxResults") + def check_max_results(cls, v): + if v < 0: + raise ValueError("maxResults cannot be less than 0") + return v + + @field_validator("debridService") + def check_debrid_service(cls, v): + if v not in ["realdebrid", "realdebrid"]: + raise ValueError("Invalid debridService") + return v + + class BestOverallRanking(BaseRankingModel): uhd: int = 100 fhd: int = 90 @@ -53,5 +87,4 @@ class BestOverallRanking(BaseRankingModel): # For use anywhere rtn = RTN(settings=rtn_settings, ranking_model=rtn_ranking) -settings = AppSettings() database = Database(f"sqlite:///{settings.DATABASE_PATH}")