diff --git a/quetz/config.py b/quetz/config.py index da42cace..9d529a45 100644 --- a/quetz/config.py +++ b/quetz/config.py @@ -4,6 +4,7 @@ import logging import logging.config import os +from dataclasses import dataclass from distutils.util import strtobool from secrets import token_bytes from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Type, Union @@ -22,6 +23,24 @@ _user_dir = appdirs.user_config_dir("quetz") PAGINATION_LIMIT = 20 +COMPRESSION_EXTENSIONS = ["bz2", "gz", "zst"] + + +@dataclass +class CompressionConfig: + bz2_enabled: bool + gz_enabled: bool + zst_enabled: bool + + def enabled_extensions(self): + return [ + ext for ext in COMPRESSION_EXTENSIONS if getattr(self, f"{ext}_enabled") + ] + + def disabled_extensions(self): + return [ + ext for ext in COMPRESSION_EXTENSIONS if not getattr(self, f"{ext}_enabled") + ] class ConfigEntry(NamedTuple): @@ -444,6 +463,20 @@ def _get_environ_config(self) -> Dict[str, Any]: return config + def get_compression_config(self) -> CompressionConfig: + """Return the compression configuration. + + Returns + ------- + compression_config : CompressionConfig + The compression configuration + """ + return CompressionConfig( + self.compression_bz2_enabled, + self.compression_gz_enabled, + self.compression_zst_enabled, + ) + def get_package_store(self) -> pkgstores.PackageStore: """Return the appropriate package store as set in the config. @@ -495,20 +528,6 @@ def get_package_store(self) -> pkgstores.PackageStore: } ) - def get_enabled_compression_extensions(self): - return [ - ext - for ext in ("bz2", "gz", "zst") - if getattr(self, f"compression_{ext}_enabled") - ] - - def get_disabled_compression_extensions(self): - return [ - ext - for ext in ("bz2", "gz", "zst") - if not getattr(self, f"compression_{ext}_enabled") - ] - def configured_section(self, section: str) -> bool: """Return if a given section has been configured. diff --git a/quetz/main.py b/quetz/main.py index 9691dcd0..7a753f7f 100644 --- a/quetz/main.py +++ b/quetz/main.py @@ -750,9 +750,7 @@ def post_channel( dao, pkgstore, new_channel.name, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) # register mirror @@ -912,9 +910,7 @@ def delete_package( pkgstore, channel_name, platforms, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) @@ -1190,9 +1186,7 @@ def delete_package_version( pkgstore, channel_name, [platform], - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) @@ -1387,9 +1381,7 @@ def post_file_to_package( dao, pkgstore, package.channel_name, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) @@ -1481,9 +1473,7 @@ async def post_upload( dao, pkgstore, channel_name, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) @@ -1510,9 +1500,7 @@ def post_file_to_channel( dao, pkgstore, channel.name, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) @@ -1846,8 +1834,9 @@ def serve_path( chunk_size = 10_000 # Ensure we don't serve an old compressed file if this compression is now disabled + compression = config.get_compression_config() disabled_compressed_json_extensions = tuple( - f".json.{ext}" for ext in config.get_disabled_compression_extensions() + f".json.{ext}" for ext in compression.disabled_extensions() ) if path.endswith(disabled_compressed_json_extensions): suffix = PurePath(path).suffix @@ -1884,7 +1873,7 @@ def serve_path( if channel.mirror_channel_url and channel.mirror_mode == "proxy": repository = RemoteRepository(channel.mirror_channel_url, session) enabled_compressed_json_extensions = tuple( - f".json.{ext}" for ext in config.get_enabled_compression_extensions() + f".json.{ext}" for ext in compression.enabled_extensions() ) if path.endswith((".json",) + enabled_compressed_json_extensions): # repodata.json and current_repodata.json are cached locally diff --git a/quetz/tasks/common.py b/quetz/tasks/common.py index 563c536b..6f6358f0 100644 --- a/quetz/tasks/common.py +++ b/quetz/tasks/common.py @@ -54,9 +54,7 @@ def __init__( self.dao = dao.Dao(db) config = get_config() self.pkgstore = config.get_package_store() - self.bz2_enabled = config.compression_bz2_enabled - self.gz_enabled = config.compression_gz_enabled - self.zst_enabled = config.compression_zst_enabled + self.compression = config.get_compression_config() def execute_channel_action( self, @@ -104,12 +102,7 @@ def execute_channel_action( ) elif action == ChannelActionEnum.validate_packages: auth.assert_validate_package_cache(channel_name) - extra_args = dict( - channel_name=channel.name, - bz2_enabled=self.bz2_enabled, - gz_enabled=self.gz_enabled, - zst_enabled=self.zst_enabled, - ) + extra_args = dict(channel_name=channel.name, compression=self.compression) task = self.jobs_dao.create_job( action.encode('ascii'), user_id, diff --git a/quetz/tasks/indexing.py b/quetz/tasks/indexing.py index fca9f970..f2f4e130 100644 --- a/quetz/tasks/indexing.py +++ b/quetz/tasks/indexing.py @@ -9,6 +9,7 @@ import uuid from datetime import datetime, timezone from pathlib import Path +from typing import Any, Dict, Optional from jinja2 import Environment, PackageLoader, select_autoescape from jinja2.exceptions import UndefinedError @@ -16,6 +17,7 @@ import quetz.config from quetz import channel_data, repo_data from quetz.condainfo import MAX_CONDA_TIMESTAMP +from quetz.config import CompressionConfig from quetz.db_models import PackageVersion from quetz.utils import add_static_file, add_temp_static_file @@ -89,7 +91,7 @@ def _subdir_key(dir): def validate_packages( - dao, pkgstore, channel_name, bz2_enabled, gz_enabled, zst_enabled + dao, pkgstore, channel_name, compression: Optional[CompressionConfig] = None ): # for now we're just validating the size of the uploaded file logger.info("Starting package validation") @@ -181,9 +183,7 @@ def validate_packages( dao, pkgstore, channel_name, - bz2_enabled=bz2_enabled, - gz_enabled=gz_enabled, - zst_enabled=zst_enabled, + compression=compression, ) @@ -192,9 +192,7 @@ def update_indexes( pkgstore, channel_name, subdirs=None, - bz2_enabled=False, - gz_enabled=False, - zst_enabled=False, + compression: Optional[CompressionConfig] = None, ): jinjaenv = _jinjaenv() channeldata = channel_data.export(dao, channel_name) @@ -210,9 +208,7 @@ def update_indexes( None, "channeldata.json", pkgstore, - bz2_enabled=bz2_enabled, - gz_enabled=gz_enabled, - zst_enabled=zst_enabled, + compression=compression, ) # Generate index.html for the "root" directory @@ -226,7 +222,7 @@ def update_indexes( add_static_file(channel_index, channel_name, None, "index.html", pkgstore) # NB. No rss.xml is being generated here - files = {} + files: Dict[str, Any] = {} packages = {} subdir_template = jinjaenv.get_template("subdir-index.html.j2") @@ -259,9 +255,7 @@ def update_indexes( "repodata.json", tempdir_path, files, - bz2_enabled=bz2_enabled, - gz_enabled=gz_enabled, - zst_enabled=zst_enabled, + compression=compression, ) try: diff --git a/quetz/tasks/mirror.py b/quetz/tasks/mirror.py index e962b8eb..233353ab 100644 --- a/quetz/tasks/mirror.py +++ b/quetz/tasks/mirror.py @@ -175,9 +175,7 @@ def download_remote_file( None, path, pkgstore, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) else: remote_file = repository.open(path) @@ -190,9 +188,7 @@ def download_remote_file( None, path, pkgstore, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) else: pkgstore.add_package(data_stream, channel, path) @@ -508,9 +504,7 @@ def handle_batch(update_batch): pkgstore, channel_name, subdirs=[arch], - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + compression=config.get_compression_config(), ) diff --git a/quetz/tasks/reindexing.py b/quetz/tasks/reindexing.py index 9f9db4c1..70781e67 100644 --- a/quetz/tasks/reindexing.py +++ b/quetz/tasks/reindexing.py @@ -173,12 +173,7 @@ def reindex_packages_from_store( try: update_indexes( - dao, - pkgstore, - channel_name, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + dao, pkgstore, channel_name, compression=config.get_compression_config() ) dao.db.commit() except IntegrityError: diff --git a/quetz/tests/api/test_main_packages.py b/quetz/tests/api/test_main_packages.py index 3601da6b..7fc74d58 100644 --- a/quetz/tests/api/test_main_packages.py +++ b/quetz/tests/api/test_main_packages.py @@ -80,12 +80,7 @@ def test_delete_package_versions_with_package( assert package_version.package_name == public_package.name update_indexes( - dao, - pkgstore, - public_channel.name, - bz2_enabled=config.compression_bz2_enabled, - gz_enabled=config.compression_gz_enabled, - zst_enabled=config.compression_zst_enabled, + dao, pkgstore, public_channel.name, compression=config.get_compression_config() ) # Get package files diff --git a/quetz/tests/test_indexing.py b/quetz/tests/test_indexing.py index 67b041d3..e33d1db1 100644 --- a/quetz/tests/test_indexing.py +++ b/quetz/tests/test_indexing.py @@ -7,6 +7,7 @@ import zstandard from quetz import channel_data +from quetz.config import CompressionConfig from quetz.tasks.indexing import update_indexes @@ -37,9 +38,7 @@ def test_update_indexes_empty_channel( dao, pkgstore, public_channel.name, - bz2_enabled=bz2_enabled, - gz_enabled=gz_enabled, - zst_enabled=zst_enabled, + compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled), ) files = pkgstore.list_files(public_channel.name) @@ -81,9 +80,7 @@ def test_update_indexes_empty_package( dao, pkgstore, public_channel.name, - bz2_enabled=bz2_enabled, - gz_enabled=gz_enabled, - zst_enabled=zst_enabled, + compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled), ) files = pkgstore.list_files(public_channel.name) @@ -132,9 +129,7 @@ def test_update_indexes_with_package_version( dao, pkgstore, public_channel.name, - bz2_enabled=bz2_enabled, - gz_enabled=gz_enabled, - zst_enabled=zst_enabled, + compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled), ) files = pkgstore.list_files(public_channel.name) diff --git a/quetz/tests/test_mirror.py b/quetz/tests/test_mirror.py index 2827fb9b..915f1c74 100644 --- a/quetz/tests/test_mirror.py +++ b/quetz/tests/test_mirror.py @@ -1513,7 +1513,9 @@ def test_download_remote_file_current_repodata( @pytest.fixture(scope="function") -def config_extra(request) -> str: +def config_extra(request=None) -> str: + if request is None: + return "" return ( "[compression]\n" f"bz2_enabled = {str(request.param[0]).lower()}\n" diff --git a/quetz/utils.py b/quetz/utils.py index fa0538ea..64ee72a8 100644 --- a/quetz/utils.py +++ b/quetz/utils.py @@ -24,6 +24,7 @@ import zstandard from sqlalchemy import String, and_, cast, collate, not_, or_ +from .config import CompressionConfig from .db_models import Channel, Package, PackageVersion, User # Same values as conda-index @@ -61,17 +62,17 @@ def add_static_file( fname, pkgstore, file_index=None, - bz2_enabled=False, - gz_enabled=False, - zst_enabled=False, + compression: Optional[CompressionConfig] = None, ): - compressed = compress_file(contents, bz2_enabled, gz_enabled, zst_enabled) + if compression is None: + compression = CompressionConfig(False, False, False) + compressed = compress_file(contents, compression) path = f"{subdir}/{fname}" if subdir else fname - if bz2_enabled: + if compression.bz2_enabled: pkgstore.add_file(compressed.bz2_file, channel_name, f"{path}.bz2") - if gz_enabled: + if compression.gz_enabled: pkgstore.add_file(compressed.gz_file, channel_name, f"{path}.gz") - if zst_enabled: + if compression.zst_enabled: pkgstore.add_file(compressed.zst_file, channel_name, f"{path}.zst") pkgstore.add_file(compressed.raw_file, channel_name, f"{path}") @@ -86,11 +87,11 @@ def add_temp_static_file( fname, temp_dir, file_index=None, - bz2_enabled=False, - gz_enabled=False, - zst_enabled=False, + compression: Optional[CompressionConfig] = None, ): - compressed = compress_file(contents, bz2_enabled, gz_enabled, zst_enabled) + if compression is None: + compression = CompressionConfig(False, False, False) + compressed = compress_file(contents, compression) temp_dir = Path(temp_dir) @@ -106,13 +107,13 @@ def add_temp_static_file( with open(file_path, 'wb') as fo: fo.write(compressed.raw_file) - if bz2_enabled: + if compressed.bz2_file: with open(f"{file_path}.bz2", 'wb') as fo: fo.write(compressed.bz2_file) - if gz_enabled: + if compressed.gz_file: with open(f"{file_path}.gz", 'wb') as fo: fo.write(compressed.gz_file) - if zst_enabled: + if compressed.zst_file: with open(f"{file_path}.zst", 'wb') as fo: fo.write(compressed.zst_file) @@ -121,19 +122,19 @@ def add_temp_static_file( def compress_file( - contents: Union[str, bytes], bz2_enabled=False, gz_enabled=False, zst_enabled=False + contents: Union[str, bytes], compression: CompressionConfig ) -> Compressed: if not isinstance(contents, bytes): raw_file = contents.encode("utf-8") else: raw_file = contents - bz2_file = bz2.compress(raw_file) if bz2_enabled else None - gz_file = gzip.compress(raw_file) if gz_enabled else None + bz2_file = bz2.compress(raw_file) if compression.bz2_enabled else None + gz_file = gzip.compress(raw_file) if compression.gz_enabled else None zst_file = ( zstandard.ZstdCompressor( level=ZSTD_COMPRESS_LEVEL, threads=ZSTD_COMPRESS_THREADS ).compress(raw_file) - if zst_enabled + if compression.zst_enabled else None ) return Compressed(raw_file, bz2_file, gz_file, zst_file)