diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c83a0007..fc48c9dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ env: jobs: test_quetz: # timeout for the whole job - timeout-minutes: 10 + timeout-minutes: 12 runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -79,7 +79,7 @@ jobs: - name: Testing server shell: bash -l -eo pipefail {0} # timeout for the step - timeout-minutes: 5 + timeout-minutes: 8 env: TEST_DB_BACKEND: ${{ matrix.test_database }} QUETZ_TEST_DBINIT: ${{ matrix.db_init }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d364db7a..ebe539eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: (quetz/migrations) +exclude: ^(quetz/migrations|quetz/tests/data/test-server/) repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. @@ -27,4 +27,4 @@ repos: - types-toml - types-ujson - types-aiofiles - args: [--show-error-codes, --implicit-optional] \ No newline at end of file + args: [--show-error-codes, --implicit-optional] diff --git a/docs/source/deploying/configuration.rst b/docs/source/deploying/configuration.rst index 306118c3..f1bd14d2 100644 --- a/docs/source/deploying/configuration.rst +++ b/docs/source/deploying/configuration.rst @@ -245,6 +245,21 @@ the ``redis-server``. For more information, see :ref:`task_workers`. +``compression`` section +^^^^^^^^^^^^^^^^^^^^^^^ + +You can configure which compressions are enabled for the ``repodata.json`` file. + +:gz_enabled: enable gzip compression +:bz2_enabled: enable bzip2 compression +:zst_enabled: enable zstandard compression + +.. note:: + + Compression is an expensive operation for big files. + Updating local channels index is done in the background, so this isnt' an issue. + But for proxy channels, compression is done after downloading the remote ``repodata.json`` and before to serve it. 
+ ``quotas`` section ^^^^^^^^^^^^^^^^^^ diff --git a/environment.yml b/environment.yml index 7c2e78c7..c5b04a3d 100644 --- a/environment.yml +++ b/environment.yml @@ -52,5 +52,6 @@ dependencies: - pytest-asyncio - pytest-timeout - pydantic >=2 + - py-rattler - pip: - git+https://github.com/jupyter-server/jupyter_releaser.git@v2 diff --git a/plugins/quetz_current_repodata/quetz_current_repodata/main.py b/plugins/quetz_current_repodata/quetz_current_repodata/main.py index f00c7ef1..36c57faa 100644 --- a/plugins/quetz_current_repodata/quetz_current_repodata/main.py +++ b/plugins/quetz_current_repodata/quetz_current_repodata/main.py @@ -4,8 +4,12 @@ from conda_index.index import _build_current_repodata import quetz +from quetz.config import Config from quetz.utils import add_temp_static_file +config = Config() +compression = config.get_compression_config() + @quetz.hookimpl def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages): @@ -25,4 +29,5 @@ def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages) "current_repodata.json", tempdir, files, + compression, ) diff --git a/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py b/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py index 4ed2e43f..8e01a28b 100644 --- a/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py +++ b/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py @@ -15,6 +15,7 @@ config = Config() pkgstore = config.get_package_store() +compression = config.get_compression_config() def update_dict(packages, instructions): @@ -147,6 +148,7 @@ def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages) "repodata_from_packages.json", tempdir, files, + compression=compression, ) patch_repodata(repodata, patch_instructions) @@ -162,4 +164,5 @@ def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages) "repodata.json", tempdir, files, + compression=compression, ) diff --git a/plugins/quetz_repodata_patching/tests/test_main.py b/plugins/quetz_repodata_patching/tests/test_main.py index defb3261..3c9b1a0a 100644 --- a/plugins/quetz_repodata_patching/tests/test_main.py +++ b/plugins/quetz_repodata_patching/tests/test_main.py @@ -288,6 +288,7 @@ def test_post_package_indexing( channel_name, package_repodata_patches, db, + config, package_file_name, repodata_stem, compressed_repodata, @@ -301,7 +302,9 @@ def get_db(): yield db with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): - indexing.update_indexes(dao, pkgstore, channel_name) + indexing.update_indexes( + dao, pkgstore, channel_name, compression=config.get_compression_config() + ) ext = "json.bz2" if compressed_repodata else "json" open_ = bz2.open if compressed_repodata else open @@ -372,6 +375,7 @@ def test_index_html( package_file_name, dao, db, + config, remove_instructions, ): @contextmanager @@ -379,7 +383,9 @@ def get_db(): yield db with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): - indexing.update_indexes(dao, pkgstore, channel_name) + indexing.update_indexes( + dao, pkgstore, channel_name, compression=config.get_compression_config() + ) index_path = os.path.join( pkgstore.channels_dir, @@ -412,6 +418,7 @@ def test_patches_for_subdir( package_repodata_patches, dao, db, + config, package_subdir, patches_subdir, ): @@ -420,7 +427,9 @@ def get_db(): yield db with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): - indexing.update_indexes(dao, pkgstore, channel_name) + 
indexing.update_indexes( + dao, pkgstore, channel_name, compression=config.get_compression_config() + ) index_path = os.path.join( pkgstore.channels_dir, @@ -466,13 +475,16 @@ def test_no_repodata_patches_package( package_file_name, dao, db, + config, ): @contextmanager def get_db(): yield db with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): - indexing.update_indexes(dao, pkgstore, channel_name) + indexing.update_indexes( + dao, pkgstore, channel_name, compression=config.get_compression_config() + ) index_path = os.path.join( pkgstore.channels_dir, diff --git a/plugins/quetz_repodata_zchunk/tests/test_main.py b/plugins/quetz_repodata_zchunk/tests/test_main.py index 364abdae..5664a467 100644 --- a/plugins/quetz_repodata_zchunk/tests/test_main.py +++ b/plugins/quetz_repodata_zchunk/tests/test_main.py @@ -112,8 +112,11 @@ def test_repodata_zchunk( package_file_name, dao, db, + config, ): - indexing.update_indexes(dao, pkgstore, channel_name) + indexing.update_indexes( + dao, pkgstore, channel_name, compression=config.get_compression_config() + ) index_path = os.path.join( pkgstore.channels_dir, diff --git a/quetz/config.py b/quetz/config.py index bf26182c..7aecbb06 100644 --- a/quetz/config.py +++ b/quetz/config.py @@ -4,6 +4,7 @@ import logging import logging.config import os +from dataclasses import dataclass from distutils.util import strtobool from secrets import token_bytes from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Type, Union @@ -22,6 +23,24 @@ _user_dir = appdirs.user_config_dir("quetz") PAGINATION_LIMIT = 20 +COMPRESSION_EXTENSIONS = ["bz2", "gz", "zst"] + + +@dataclass +class CompressionConfig: + bz2_enabled: bool + gz_enabled: bool + zst_enabled: bool + + def enabled_extensions(self): + return [ + ext for ext in COMPRESSION_EXTENSIONS if getattr(self, f"{ext}_enabled") + ] + + def disabled_extensions(self): + return [ + ext for ext in COMPRESSION_EXTENSIONS if not getattr(self, f"{ext}_enabled") + ] class ConfigEntry(NamedTuple): @@ -62,6 +81,7 @@ class Config: ConfigEntry("package_unpack_threads", int, 1), ConfigEntry("frontend_dir", str, default=""), ConfigEntry("redirect_http_to_https", bool, False), + ConfigEntry("rattler_cache_dir", str, default="rattler_cache"), ], ), ConfigSection( @@ -232,6 +252,14 @@ class Config: ConfigEntry("soft_delete_package", bool, required=False, default=False), ], ), + ConfigSection( + "compression", + [ + ConfigEntry("gz_enabled", bool, default=True), + ConfigEntry("bz2_enabled", bool, default=True), + ConfigEntry("zst_enabled", bool, default=False), + ], + ), ] _config_dirs = [_site_dir, _user_dir] _config_files = [os.path.join(d, _filename) for d in _config_dirs] @@ -443,6 +471,20 @@ def _get_environ_config(self) -> Dict[str, Any]: return config + def get_compression_config(self) -> CompressionConfig: + """Return the compression configuration. + + Returns + ------- + compression_config : CompressionConfig + Class defining which compressions are enabled (bzip2, gzip and zstandard) + """ + return CompressionConfig( + self.compression_bz2_enabled, + self.compression_gz_enabled, + self.compression_zst_enabled, + ) + def get_package_store(self) -> pkgstores.PackageStore: """Return the appropriate package store as set in the config. 
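For reference, a minimal sketch (not part of the patch) of how the new `[compression]` section surfaces through `Config`; the defaults declared above enable gzip and bzip2 and leave zstandard off, and constructing `Config()` assumes a discoverable Quetz config file:

```python
# Sketch only: exercises Config.get_compression_config() / CompressionConfig
# as added in quetz/config.py. Equivalent TOML with the defaults:
#   [compression]
#   gz_enabled = true
#   bz2_enabled = true
#   zst_enabled = false
from quetz.config import CompressionConfig, Config

config = Config()  # assumes a config.toml can be located (deployment dir or env)
compression = config.get_compression_config()

assert isinstance(compression, CompressionConfig)
print(compression.enabled_extensions())   # ['bz2', 'gz'] with the defaults
print(compression.disabled_extensions())  # ['zst'] with the defaults
```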
diff --git a/quetz/main.py b/quetz/main.py index db1a36be..5b6a92a4 100644 --- a/quetz/main.py +++ b/quetz/main.py @@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from email.utils import formatdate +from pathlib import PurePath from tempfile import SpooledTemporaryFile, TemporaryFile from typing import Awaitable, Callable, List, Optional, Tuple, Type @@ -694,7 +695,7 @@ def post_channel( dao: Dao = Depends(get_dao), auth: authorization.Rules = Depends(get_rules), task: Task = Depends(get_tasks_worker), - config=Depends(get_config), + config: Config = Depends(get_config), session: requests.Session = Depends(get_remote_session), ): user_id = auth.assert_user() @@ -751,7 +752,12 @@ def post_channel( channel = dao.create_channel(new_channel, user_id, authorization.OWNER, size_limit) pkgstore.create_channel(new_channel.name) if not is_proxy: - indexing.update_indexes(dao, pkgstore, new_channel.name) + indexing.update_indexes( + dao, + pkgstore, + new_channel.name, + compression=config.get_compression_config(), + ) # register mirror if is_mirror and register_mirror: @@ -878,6 +884,7 @@ def delete_package( db=Depends(get_db), auth: authorization.Rules = Depends(get_rules), dao: Dao = Depends(get_dao), + config: Config = Depends(get_config), ): auth.assert_package_delete(package) @@ -904,7 +911,14 @@ def delete_package( wrapped_bg_task = background_task_wrapper(indexing.update_indexes, logger) # Background task to update indexes - background_tasks.add_task(wrapped_bg_task, dao, pkgstore, channel_name, platforms) + background_tasks.add_task( + wrapped_bg_task, + dao, + pkgstore, + channel_name, + platforms, + compression=config.get_compression_config(), + ) @api_router.post( @@ -1232,6 +1246,7 @@ def delete_package_version( dao: Dao = Depends(get_dao), db=Depends(get_db), auth: authorization.Rules = Depends(get_rules), + config: Config = Depends(get_config), ): version = dao.get_package_version_by_filename( channel_name, package_name, filename, platform @@ -1262,7 +1277,14 @@ def delete_package_version( wrapped_bg_task = background_task_wrapper(indexing.update_indexes, logger) # Background task to update indexes - background_tasks.add_task(wrapped_bg_task, dao, pkgstore, channel_name, [platform]) + background_tasks.add_task( + wrapped_bg_task, + dao, + pkgstore, + channel_name, + [platform], + compression=config.get_compression_config(), + ) @api_router.get( @@ -1444,13 +1466,20 @@ def post_file_to_package( channel: db_models.Channel = Depends( ChannelChecker(allow_proxy=False, allow_mirror=False), ), + config: Config = Depends(get_config), ): handle_package_files(package.channel, files, dao, auth, force, package=package) dao.update_channel_size(package.channel_name) wrapped_bg_task = background_task_wrapper(indexing.update_indexes, logger) # Background task to update indexes - background_tasks.add_task(wrapped_bg_task, dao, pkgstore, package.channel_name) + background_tasks.add_task( + wrapped_bg_task, + dao, + pkgstore, + package.channel_name, + compression=config.get_compression_config(), + ) @api_router.post( @@ -1465,6 +1494,7 @@ async def post_upload( force: bool = False, dao: Dao = Depends(get_dao), auth: authorization.Rules = Depends(get_rules), + config: Config = Depends(get_config), ): logger.debug( f"Uploading file {filename} with checksum {sha256} to channel {channel_name}" @@ -1535,7 +1565,13 @@ async def post_upload( wrapped_bg_task = background_task_wrapper(indexing.update_indexes, logger) # Background task to update indexes - 
background_tasks.add_task(wrapped_bg_task, dao, pkgstore, channel_name) + background_tasks.add_task( + wrapped_bg_task, + dao, + pkgstore, + channel_name, + compression=config.get_compression_config(), + ) @api_router.post("/channels/{channel_name}/files/", status_code=201, tags=["files"]) @@ -1548,6 +1584,7 @@ def post_file_to_channel( ), dao: Dao = Depends(get_dao), auth: authorization.Rules = Depends(get_rules), + config: Config = Depends(get_config), ): handle_package_files(channel, files, dao, auth, force) @@ -1555,7 +1592,13 @@ def post_file_to_channel( wrapped_bg_task = background_task_wrapper(indexing.update_indexes, logger) # Background task to update indexes - background_tasks.add_task(wrapped_bg_task, dao, pkgstore, channel.name) + background_tasks.add_task( + wrapped_bg_task, + dao, + pkgstore, + channel.name, + compression=config.get_compression_config(), + ) def _assert_filename_package_name_consistent(file_name: str, package_name: str): @@ -1883,9 +1926,22 @@ def serve_path( accept_encoding: Optional[str] = Header(None), session=Depends(get_remote_session), dao: Dao = Depends(get_dao), + config: Config = Depends(get_config), ): chunk_size = 10_000 + # Ensure we don't serve an old compressed file if this compression is now disabled + compression = config.get_compression_config() + disabled_compressed_json_extensions = tuple( + f".json.{ext}" for ext in compression.disabled_extensions() + ) + if path.endswith(disabled_compressed_json_extensions): + suffix = PurePath(path).suffix + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"{channel.name}/{path} not found - {suffix} compression disabled", + ) + is_package_request = path.endswith((".tar.bz2", ".conda")) package_name = None @@ -1913,14 +1969,30 @@ def serve_path( if channel.mirror_channel_url and channel.mirror_mode == "proxy": repository = RemoteRepository(channel.mirror_channel_url, session) - if not pkgstore.file_exists(channel.name, path): - download_remote_file(repository, pkgstore, channel.name, path) - elif path.endswith(".json"): + enabled_compressed_json_extensions = tuple( + f".json.{ext}" for ext in compression.enabled_extensions() + ) + if path.endswith((".json",) + enabled_compressed_json_extensions): # repodata.json and current_repodata.json are cached locally # for channel.ttl seconds - _, fmtime, _ = pkgstore.get_filemetadata(channel.name, path) + # if one of the compressed file is requested, we check and download + # the non compressed version if needed + # (compressed files are created locally and should all have the same fmtime) + suffix = PurePath(path).suffix + if suffix == ".json": + json_path = path + else: + json_path = path[: -len(suffix)] + try: + _, fmtime, _ = pkgstore.get_filemetadata(channel.name, json_path) + except FileNotFoundError: + fmtime = 0 if time.time() - fmtime >= channel.ttl: - download_remote_file(repository, pkgstore, channel.name, path) + download_remote_file( + repository, pkgstore, channel.name, json_path, config + ) + elif not pkgstore.file_exists(channel.name, path): + download_remote_file(repository, pkgstore, channel.name, path, config) if ( is_package_request or pkgstore.kind == "LocalStore" @@ -1981,8 +2053,9 @@ def serve_channel_index( accept_encoding: Optional[str] = Header(None), session=Depends(get_remote_session), dao: Dao = Depends(get_dao), + config: Config = Depends(get_config), ): - return serve_path("index.html", channel, accept_encoding, session, dao) + return serve_path("index.html", channel, accept_encoding, session, dao, config) 
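A sketch of what the new `serve_path` guard means for clients; the host, port and channel name below are hypothetical placeholders, and the behaviour assumes the channel exists and its index has been built:

```python
# Sketch: localhost:8000 and "mychannel" are placeholders, not values from the patch.
import requests

base = "http://localhost:8000/get/mychannel/noarch"

# The plain repodata.json is always eligible to be served (and, for proxy
# channels, re-downloaded once the channel TTL expires).
ok = requests.get(f"{base}/repodata.json")

# A compressed variant whose extension is disabled in [compression] is refused
# up front with 404 ("... .zst compression disabled") instead of risking a
# stale file from the package store.
gone = requests.get(f"{base}/repodata.json.zst")

print(ok.status_code, gone.status_code)  # e.g. 200 404 when zst_enabled = false
```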
@app.get("/health/ready", status_code=status.HTTP_200_OK) diff --git a/quetz/tasks/common.py b/quetz/tasks/common.py index b19a60f5..8dce5081 100644 --- a/quetz/tasks/common.py +++ b/quetz/tasks/common.py @@ -52,7 +52,9 @@ def __init__( self.db = db self.jobs_dao = JobsDao(db) self.dao = dao.Dao(db) - self.pkgstore = get_config().get_package_store() + config = get_config() + self.pkgstore = config.get_package_store() + self.compression = config.get_compression_config() def execute_channel_action( self, @@ -100,7 +102,7 @@ def execute_channel_action( ) elif action == ChannelActionEnum.validate_packages: auth.assert_validate_package_cache(channel_name) - extra_args = dict(channel_name=channel.name) + extra_args = dict(channel_name=channel.name, compression=self.compression) task = self.jobs_dao.create_job( action.encode("ascii"), user_id, diff --git a/quetz/tasks/indexing.py b/quetz/tasks/indexing.py index a54aaa97..b35989c9 100644 --- a/quetz/tasks/indexing.py +++ b/quetz/tasks/indexing.py @@ -9,6 +9,7 @@ import uuid from datetime import datetime, timezone from pathlib import Path +from typing import Any, Dict, Optional from jinja2 import Environment, PackageLoader, select_autoescape from jinja2.exceptions import UndefinedError @@ -16,6 +17,7 @@ import quetz.config from quetz import channel_data, repo_data from quetz.condainfo import MAX_CONDA_TIMESTAMP +from quetz.config import CompressionConfig from quetz.db_models import PackageVersion from quetz.utils import add_static_file, add_temp_static_file @@ -88,7 +90,9 @@ def _subdir_key(dir): return _subdir_order.get(dir, dir) -def validate_packages(dao, pkgstore, channel_name): +def validate_packages( + dao, pkgstore, channel_name, compression: Optional[CompressionConfig] = None +): # for now we're just validating the size of the uploaded file logger.info("Starting package validation") @@ -175,10 +179,21 @@ def validate_packages(dao, pkgstore, channel_name): logger.info(f"Wrong size: {wrong_size}") logger.info(f"Not uploaded: {inexistent}") - update_indexes(dao, pkgstore, channel_name) + update_indexes( + dao, + pkgstore, + channel_name, + compression=compression, + ) -def update_indexes(dao, pkgstore, channel_name, subdirs=None): +def update_indexes( + dao, + pkgstore, + channel_name, + subdirs=None, + compression: Optional[CompressionConfig] = None, +): jinjaenv = _jinjaenv() channeldata = channel_data.export(dao, channel_name) @@ -187,7 +202,14 @@ def update_indexes(dao, pkgstore, channel_name, subdirs=None): # Generate channeldata.json and its compressed version chandata_json = json.dumps(channeldata, indent=2, sort_keys=False) - add_static_file(chandata_json, channel_name, None, "channeldata.json", pkgstore) + add_static_file( + chandata_json, + channel_name, + None, + "channeldata.json", + pkgstore, + compression=compression, + ) # Generate index.html for the "root" directory channel_index = jinjaenv.get_template("channeldata-index.html.j2").render( @@ -196,11 +218,11 @@ def update_indexes(dao, pkgstore, channel_name, subdirs=None): subdirs=subdirs, current_time=datetime.now(timezone.utc), ) - + # No compressed version created add_static_file(channel_index, channel_name, None, "index.html", pkgstore) # NB. 
No rss.xml is being generated here - files = {} + files: Dict[str, Any] = {} packages = {} subdir_template = jinjaenv.get_template("subdir-index.html.j2") @@ -227,7 +249,13 @@ def update_indexes(dao, pkgstore, channel_name, subdirs=None): repodata = json.dumps(raw_repodata, indent=2, sort_keys=False) add_temp_static_file( - repodata, channel_name, sdir, "repodata.json", tempdir_path, files + repodata, + channel_name, + sdir, + "repodata.json", + tempdir_path, + files, + compression=compression, ) try: @@ -251,6 +279,7 @@ def update_indexes(dao, pkgstore, channel_name, subdirs=None): current_time=datetime.now(timezone.utc), add_files=files[sdir], ) + # No compressed version created add_static_file(subdir_index_html, channel_name, sdir, "index.html", pkgstore) # recursively walk through the tree diff --git a/quetz/tasks/mirror.py b/quetz/tasks/mirror.py index ab65e145..08f79f95 100644 --- a/quetz/tasks/mirror.py +++ b/quetz/tasks/mirror.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import json import logging @@ -5,11 +6,15 @@ import shutil from concurrent.futures import ThreadPoolExecutor from http.client import IncompleteRead +from pathlib import Path, PurePath from tempfile import SpooledTemporaryFile from typing import List +import aiofiles +import aiofiles.os import requests from fastapi import HTTPException, status +from rattler import Channel, ChannelConfig, Platform, fetch_repo_data from tenacity import TryAgain, retry from tenacity.after import after_log from tenacity.stop import stop_after_attempt @@ -59,6 +64,11 @@ def __init__(self, host, session): def open(self, path): return RemoteFile(self.host, path, self.session) + @property + def rattler_channel(self): + host_path = PurePath(self.host) + return Channel(host_path.name, ChannelConfig(f"{host_path.parent}/")) + class RemoteServerError(Exception): pass @@ -106,8 +116,42 @@ def json(self): return json.load(self.file) +async def download_repodata(repository: RemoteRepository, channel: str, platform: str): + cache_path = Path(Config().general_rattler_cache_dir) / channel / platform + logger.debug(f"Fetching {platform} repodata from {repository.rattler_channel}") + try: + await fetch_repo_data( + channels=[repository.rattler_channel], + platforms=[Platform(platform)], + cache_path=cache_path, + callback=None, + ) + except Exception as e: + logger.error(f"Failed to fetch repodata: {e}") + raise + files = await aiofiles.os.listdir(cache_path) + try: + json_file = [ + filename + for filename in files + if filename.endswith(".json") and not filename.endswith(".info.json") + ][0] + except IndexError: + logger.error(f"No json file found in rattler cache: {cache_path}") + raise RemoteFileNotFound + else: + async with aiofiles.open(cache_path / json_file, "rb") as f: + contents = await f.read() + logger.debug(f"Retrieved repodata from rattler cache: {cache_path / json_file}") + return contents + + def download_remote_file( - repository: RemoteRepository, pkgstore: PackageStore, channel: str, path: str + repository: RemoteRepository, + pkgstore: PackageStore, + channel: str, + path: str, + config: Config, ): """Download a file from a remote repository to a package store""" @@ -122,13 +166,33 @@ def download_remote_file( # Acquire a lock to prevent multiple concurrent downloads of the same file with pkgstore.create_download_lock(channel, path): logger.debug(f"Downloading {path} from {channel} to pkgstore") - remote_file = repository.open(path) - data_stream = remote_file.file - - if path.endswith(".json"): - 
add_static_file(data_stream.read(), channel, None, path, pkgstore) + if path.endswith("/repodata.json"): + platform = str(PurePath(path).parent) + repodata = asyncio.run(download_repodata(repository, channel, platform)) + add_static_file( + repodata, + channel, + None, + path, + pkgstore, + compression=config.get_compression_config(), + ) else: - pkgstore.add_package(data_stream, channel, path) + remote_file = repository.open(path) + data_stream = remote_file.file + + if path.endswith(".json"): + add_static_file( + data_stream.read(), + channel, + None, + path, + pkgstore, + compression=config.get_compression_config(), + ) + else: + pkgstore.add_package(data_stream, channel, path) + logger.debug(f"Added {path} from {channel} to pkgstore") pkgstore.delete_download_lock(channel, path) @@ -435,7 +499,13 @@ def handle_batch(update_batch): any_updated |= handle_batch(update_batch) if any_updated: - indexing.update_indexes(dao, pkgstore, channel_name, subdirs=[arch]) + indexing.update_indexes( + dao, + pkgstore, + channel_name, + subdirs=[arch], + compression=config.get_compression_config(), + ) def create_packages_from_channeldata( diff --git a/quetz/tasks/reindexing.py b/quetz/tasks/reindexing.py index 7624a660..70781e67 100644 --- a/quetz/tasks/reindexing.py +++ b/quetz/tasks/reindexing.py @@ -172,7 +172,9 @@ def reindex_packages_from_store( ) try: - update_indexes(dao, pkgstore, channel_name) + update_indexes( + dao, pkgstore, channel_name, compression=config.get_compression_config() + ) dao.db.commit() except IntegrityError: dao.rollback() diff --git a/quetz/tests/api/test_main_packages.py b/quetz/tests/api/test_main_packages.py index d7c80d19..8d6b2a1d 100644 --- a/quetz/tests/api/test_main_packages.py +++ b/quetz/tests/api/test_main_packages.py @@ -65,14 +65,23 @@ def test_delete_package_non_member( def test_delete_package_versions_with_package( - auth_client, public_channel, public_package, package_version, dao, db, pkgstore + auth_client, + public_channel, + public_package, + package_version, + dao, + db, + pkgstore, + config, ): assert public_channel.size > 0 assert public_channel.size == package_version.size assert package_version.package_name == public_package.name - update_indexes(dao, pkgstore, public_channel.name) + update_indexes( + dao, pkgstore, public_channel.name, compression=config.get_compression_config() + ) # Get package files package_filenames = [ diff --git a/quetz/tests/data/test-server/README.md b/quetz/tests/data/test-server/README.md new file mode 100644 index 00000000..8ade2a02 --- /dev/null +++ b/quetz/tests/data/test-server/README.md @@ -0,0 +1,7 @@ +# test-server + +A simple server to serve repodata for test purposes. +Originally implemented by [mamba](https://github.com/mamba-org/mamba/tree/a8d595b6ff8ac182e60741c3e8cbd142e7d19905/mamba/tests) +under the BSD-3-Clause license. + +Copied from [rattler](https://github.com/mamba-org/rattler/tree/9a3f2cc92a50fec4f6f7c13488441ca4ac15269b/test-data/test-server). 
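For reference, the vendored test server can be started the same way the `serve_repo_data` fixture in `quetz/tests/test_mirror.py` does it; the port, channel name and sleep below simply mirror that fixture (a sketch, run from the repository root):

```python
# Sketch: launch the bundled reposerver.py against the test repo, as the
# serve_repo_data fixture does, then shut it down again.
import subprocess
import time
from pathlib import Path

test_data_dir = Path("quetz/tests/data/test-server")
proc = subprocess.Popen(
    [
        "python", str(test_data_dir / "reposerver.py"),
        "-d", str(test_data_dir / "repo"),  # directory to serve
        "-n", "test-repo",                  # channel name used in the URL
        "-p", "8912",                       # port to listen on
    ]
)
time.sleep(0.5)  # give the HTTP server a moment to start
# repodata is now served at http://localhost:8912/test-repo/noarch/repodata.json
proc.terminate()
```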
diff --git a/quetz/tests/data/test-server/repo/channeldata.json b/quetz/tests/data/test-server/repo/channeldata.json new file mode 100644 index 00000000..f73272e5 --- /dev/null +++ b/quetz/tests/data/test-server/repo/channeldata.json @@ -0,0 +1,38 @@ +{ + "channeldata_version": 1, + "packages": { + "test-package": { + "activate.d": false, + "binary_prefix": false, + "deactivate.d": false, + "description": null, + "dev_url": null, + "doc_source_url": null, + "doc_url": null, + "home": "https://github.com/mamba-org/mamba", + "icon_hash": null, + "icon_url": null, + "identifiers": null, + "keywords": null, + "license": "BSD", + "post_link": false, + "pre_link": false, + "pre_unlink": false, + "recipe_origin": null, + "run_exports": {}, + "source_git_url": null, + "source_url": null, + "subdirs": [ + "noarch" + ], + "summary": "I am just a test package!", + "tags": null, + "text_prefix": false, + "timestamp": 1613117294, + "version": "0.1" + } + }, + "subdirs": [ + "noarch" + ] +} diff --git a/quetz/tests/data/test-server/repo/index.html b/quetz/tests/data/test-server/repo/index.html new file mode 100644 index 00000000..e4ded9e2 --- /dev/null +++ b/quetz/tests/data/test-server/repo/index.html @@ -0,0 +1,90 @@ + + + repo + + + +

[repo/index.html body: markup lost in extraction; the rendered page shows the title "repo", links to "RSS Feed" and "channeldata.json", a package table (Package / Latest Version / Doc / Dev / License / noarch / Summary) with a single row for test-package 0.1 (license BSD, noarch, summary "I am just a test package!"), and the footer "Updated: 2021-02-12 09:02:37 +0000 - Files: 1"]
+ + diff --git a/quetz/tests/data/test-server/repo/noarch/current_repodata.json b/quetz/tests/data/test-server/repo/noarch/current_repodata.json new file mode 100644 index 00000000..52facdf7 --- /dev/null +++ b/quetz/tests/data/test-server/repo/noarch/current_repodata.json @@ -0,0 +1,25 @@ +{ + "info": { + "subdir": "noarch" + }, + "packages": { + "test-package-0.1-0.tar.bz2": { + "build": "0", + "build_number": 0, + "depends": [], + "license": "BSD", + "license_family": "BSD", + "md5": "2a8595f37faa2950e1b433acbe91d481", + "name": "test-package", + "noarch": "generic", + "sha256": "b908ffce2d26d94c58c968abf286568d4bcf87d1cfe6c994958351724a6f6988", + "size": 5719, + "subdir": "noarch", + "timestamp": 1613117294885, + "version": "0.1" + } + }, + "packages.conda": {}, + "removed": [], + "repodata_version": 1 +} diff --git a/quetz/tests/data/test-server/repo/noarch/current_repodata.json.bz2 b/quetz/tests/data/test-server/repo/noarch/current_repodata.json.bz2 new file mode 100644 index 00000000..76c130c0 Binary files /dev/null and b/quetz/tests/data/test-server/repo/noarch/current_repodata.json.bz2 differ diff --git a/quetz/tests/data/test-server/repo/noarch/index.html b/quetz/tests/data/test-server/repo/noarch/index.html new file mode 100644 index 00000000..6c83d944 --- /dev/null +++ b/quetz/tests/data/test-server/repo/noarch/index.html @@ -0,0 +1,88 @@ + + + repo/noarch + + + +

[repo/noarch/index.html body: markup lost in extraction; the rendered page shows the title "repo/noarch" and a file listing (Filename / Size / Last Modified / SHA256 / MD5):
  repodata.json | 586 B | 2021-02-12 09:01:48 +0000 | cc5f72aaa8d3f508c8adca196fe05cf4b19e1ca1006cfcbb3892d73160bd3b04 | 7501ec77771889b42a39c615158cb9c4
  repodata.json.bz2 | 351 B | 2021-02-12 09:01:48 +0000 | 9a0288ca48c6b8caa348d7cafefd0981c2d25dcb4a5837a5187ab200b8b9fb45 | 0c926155642f0e894d97dc8a5af7007b
  repodata_from_packages.json | 586 B | 2021-02-12 09:01:48 +0000 | cc5f72aaa8d3f508c8adca196fe05cf4b19e1ca1006cfcbb3892d73160bd3b04 | 7501ec77771889b42a39c615158cb9c4
  repodata_from_packages.json.bz2 | 351 B | 2021-02-12 09:01:48 +0000 | 9a0288ca48c6b8caa348d7cafefd0981c2d25dcb4a5837a5187ab200b8b9fb45 | 0c926155642f0e894d97dc8a5af7007b
  test-package-0.1-0.tar.bz2 | 6 KB | 2021-02-12 08:08:14 +0000 | b908ffce2d26d94c58c968abf286568d4bcf87d1cfe6c994958351724a6f6988 | 2a8595f37faa2950e1b433acbe91d481
with the footer "Updated: 2021-02-12 09:02:37 +0000 - Files: 1"]
+ + diff --git a/quetz/tests/data/test-server/repo/noarch/repodata.json b/quetz/tests/data/test-server/repo/noarch/repodata.json new file mode 100644 index 00000000..52facdf7 --- /dev/null +++ b/quetz/tests/data/test-server/repo/noarch/repodata.json @@ -0,0 +1,25 @@ +{ + "info": { + "subdir": "noarch" + }, + "packages": { + "test-package-0.1-0.tar.bz2": { + "build": "0", + "build_number": 0, + "depends": [], + "license": "BSD", + "license_family": "BSD", + "md5": "2a8595f37faa2950e1b433acbe91d481", + "name": "test-package", + "noarch": "generic", + "sha256": "b908ffce2d26d94c58c968abf286568d4bcf87d1cfe6c994958351724a6f6988", + "size": 5719, + "subdir": "noarch", + "timestamp": 1613117294885, + "version": "0.1" + } + }, + "packages.conda": {}, + "removed": [], + "repodata_version": 1 +} diff --git a/quetz/tests/data/test-server/repo/noarch/repodata.json.bz2 b/quetz/tests/data/test-server/repo/noarch/repodata.json.bz2 new file mode 100644 index 00000000..76c130c0 Binary files /dev/null and b/quetz/tests/data/test-server/repo/noarch/repodata.json.bz2 differ diff --git a/quetz/tests/data/test-server/repo/noarch/repodata_from_packages.json b/quetz/tests/data/test-server/repo/noarch/repodata_from_packages.json new file mode 100644 index 00000000..52facdf7 --- /dev/null +++ b/quetz/tests/data/test-server/repo/noarch/repodata_from_packages.json @@ -0,0 +1,25 @@ +{ + "info": { + "subdir": "noarch" + }, + "packages": { + "test-package-0.1-0.tar.bz2": { + "build": "0", + "build_number": 0, + "depends": [], + "license": "BSD", + "license_family": "BSD", + "md5": "2a8595f37faa2950e1b433acbe91d481", + "name": "test-package", + "noarch": "generic", + "sha256": "b908ffce2d26d94c58c968abf286568d4bcf87d1cfe6c994958351724a6f6988", + "size": 5719, + "subdir": "noarch", + "timestamp": 1613117294885, + "version": "0.1" + } + }, + "packages.conda": {}, + "removed": [], + "repodata_version": 1 +} diff --git a/quetz/tests/data/test-server/repo/noarch/repodata_from_packages.json.bz2 b/quetz/tests/data/test-server/repo/noarch/repodata_from_packages.json.bz2 new file mode 100644 index 00000000..76c130c0 Binary files /dev/null and b/quetz/tests/data/test-server/repo/noarch/repodata_from_packages.json.bz2 differ diff --git a/quetz/tests/data/test-server/repo/noarch/test-package-0.1-0.tar.bz2 b/quetz/tests/data/test-server/repo/noarch/test-package-0.1-0.tar.bz2 new file mode 100644 index 00000000..f84ad738 Binary files /dev/null and b/quetz/tests/data/test-server/repo/noarch/test-package-0.1-0.tar.bz2 differ diff --git a/quetz/tests/data/test-server/reposerver.py b/quetz/tests/data/test-server/reposerver.py new file mode 100644 index 00000000..56bd2deb --- /dev/null +++ b/quetz/tests/data/test-server/reposerver.py @@ -0,0 +1,395 @@ +# File taken from https://github.com/mamba-org/mamba/tree/a8d595b6ff8ac182e60741c3e8cbd142e7d19905/mamba/tests +# under BSD-3-Clause license + +import argparse +import base64 +import glob +import os +import re +import shutil +import sys +from http.server import HTTPServer, SimpleHTTPRequestHandler +from pathlib import Path +from typing import Dict, List + +try: + import conda_content_trust.authentication as cct_authentication + import conda_content_trust.common as cct_common + import conda_content_trust.metadata_construction as cct_metadata_construction + import conda_content_trust.root_signing as cct_root_signing + import conda_content_trust.signing as cct_signing + + conda_content_trust_available = True +except ImportError: + conda_content_trust_available = False + + +def 
fatal_error(message: str) -> None: + """Print error and exit.""" + print(message, file=sys.stderr) + exit(1) + + +def get_fingerprint(gpg_output: str) -> str: + lines = gpg_output.splitlines() + fpline = lines[1].strip() + fpline = fpline.replace(" ", "") + return fpline + + +KeySet = Dict[str, List[Dict[str, str]]] + + +def normalize_keys(keys: KeySet) -> KeySet: + out = {} + for ik, iv in keys.items(): + out[ik] = [] + for el in iv: + if isinstance(el, str): + el = el.lower() + keyval = cct_root_signing.fetch_keyval_from_gpg(el) + res = {"fingerprint": el, "public": keyval} + elif isinstance(el, dict): + res = { + "private": el["private"].lower(), + "public": el["public"].lower(), + } + out[ik].append(res) + + return out + + +class RepoSigner: + keys = { + "root": [], + "key_mgr": [ + { + "private": "c9c2060d7e0d93616c2654840b4983d00221d8b6b69c850107da74b42168f937", + "public": "013ddd714962866d12ba5bae273f14d48c89cf0773dee2dbf6d4561e521c83f7", + }, + ], + "pkg_mgr": [ + { + "private": "f3cdab14740066fb277651ec4f96b9f6c3e3eb3f812269797b9656074cd52133", + "public": "f46b5a7caa43640744186564c098955147daa8bac4443887bc64d8bfee3d3569", + } + ], + } + + def __init__(self, in_folder: str) -> None: + self.in_folder = Path(in_folder).resolve() + self.folder = self.in_folder.parent / (str(self.in_folder.name) + "_signed") + self.keys["root"] = [ + get_fingerprint(os.environ["KEY1"]), + get_fingerprint(os.environ["KEY2"]), + ] + self.keys = normalize_keys(self.keys) + + def make_signed_repo(self) -> Path: + print("[reposigner] Using keys:", self.keys) + print("[reposigner] Using folder:", self.folder) + + self.folder.mkdir(exist_ok=True) + self.create_root(self.keys) + self.create_key_mgr(self.keys) + for f in glob.glob(str(self.in_folder / "**" / "repodata.json")): + self.sign_repodata(Path(f), self.keys) + return self.folder + + def create_root(self, keys): + root_keys = keys["root"] + + root_pubkeys = [k["public"] for k in root_keys] + key_mgr_pubkeys = [k["public"] for k in keys["key_mgr"]] + + root_version = 1 + + root_md = cct_metadata_construction.build_root_metadata( + root_pubkeys=root_pubkeys[0:1], + root_threshold=1, + root_version=root_version, + key_mgr_pubkeys=key_mgr_pubkeys, + key_mgr_threshold=1, + ) + + # Wrap the metadata in a signing envelope. + root_md = cct_signing.wrap_as_signable(root_md) + + root_md_serialized_unsigned = cct_common.canonserialize(root_md) + + root_filepath = self.folder / f"{root_version}.root.json" + print("Writing out: ", root_filepath) + # Write unsigned sample root metadata. + with open(root_filepath, "wb") as fout: + fout.write(root_md_serialized_unsigned) + + # This overwrites the file with a signed version of the file. + cct_root_signing.sign_root_metadata_via_gpg( + root_filepath, root_keys[0]["fingerprint"] + ) + + # Load untrusted signed root metadata. 
+ signed_root_md = cct_common.load_metadata_from_file(root_filepath) + + cct_authentication.verify_signable(signed_root_md, root_pubkeys, 1, gpg=True) + + print("[reposigner] Root metadata signed & verified!") + + def create_key_mgr(self, keys): + private_key_key_mgr = cct_common.PrivateKey.from_hex( + keys["key_mgr"][0]["private"] + ) + pkg_mgr_pub_keys = [k["public"] for k in keys["pkg_mgr"]] + key_mgr = cct_metadata_construction.build_delegating_metadata( + metadata_type="key_mgr", # 'root' or 'key_mgr' + delegations={"pkg_mgr": {"pubkeys": pkg_mgr_pub_keys, "threshold": 1}}, + version=1, + # timestamp default: now + # expiration default: now plus root expiration default duration + ) + + key_mgr = cct_signing.wrap_as_signable(key_mgr) + + # sign dictionary in place + cct_signing.sign_signable(key_mgr, private_key_key_mgr) + + key_mgr_serialized = cct_common.canonserialize(key_mgr) + with open(self.folder / "key_mgr.json", "wb") as fobj: + fobj.write(key_mgr_serialized) + + # let's run a verification + root_metadata = cct_common.load_metadata_from_file(self.folder / "1.root.json") + key_mgr_metadata = cct_common.load_metadata_from_file( + self.folder / "key_mgr.json" + ) + + cct_common.checkformat_signable(root_metadata) + + if "delegations" not in root_metadata["signed"]: + raise ValueError('Expected "delegations" entry in root metadata.') + + root_delegations = root_metadata["signed"]["delegations"] # for brevity + cct_common.checkformat_delegations(root_delegations) + if "key_mgr" not in root_delegations: + raise ValueError( + 'Missing expected delegation to "key_mgr" in root metadata.' + ) + cct_common.checkformat_delegation(root_delegations["key_mgr"]) + + # Doing delegation processing. + cct_authentication.verify_delegation("key_mgr", key_mgr_metadata, root_metadata) + + print("[reposigner] success: key mgr metadata verified based on root metadata.") + + return key_mgr + + def sign_repodata(self, repodata_fn, keys): + target_folder = self.folder / repodata_fn.parent.name + if not target_folder.exists(): + target_folder.mkdir() + + final_fn = target_folder / repodata_fn.name + print("copy", repodata_fn, final_fn) + shutil.copyfile(repodata_fn, final_fn) + + pkg_mgr_key = keys["pkg_mgr"][0]["private"] + cct_signing.sign_all_in_repodata(str(final_fn), pkg_mgr_key) + print(f"[reposigner] Signed {final_fn}") + + +class ChannelHandler(SimpleHTTPRequestHandler): + url_pattern = re.compile(r"^/(?:t/[^/]+/)?([^/]+)") + + def do_GET(self) -> None: + # First extract channel name + channel_name = None + if tuple(channels.keys()) != (None,): + match = self.url_pattern.match(self.path) + if match: + channel_name = match.group(1) + # Strip channel for file server + start, end = match.span(1) + self.path = self.path[:start] + self.path[end:] + + # Then dispatch to appropriate auth method + if channel_name in channels: + channel = channels[channel_name] + self.directory = channel["directory"] + auth = channel["auth"] + if auth == "none": + return SimpleHTTPRequestHandler.do_GET(self) + elif auth == "basic": + server_key = base64.b64encode( + bytes(f"{channel['user']}:{channel['password']}", "utf-8") + ).decode("ascii") + return self.basic_do_GET(server_key=server_key) + elif auth == "bearer": + return self.bearer_do_GET(server_key=channel["bearer"]) + elif auth == "token": + return self.token_do_GET(server_token=channel["token"]) + + self.send_response(404) + + def basic_do_HEAD(self) -> None: + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + def 
basic_do_AUTHHEAD(self) -> None: + self.send_response(401) + self.send_header("WWW-Authenticate", 'Basic realm="Test"') + self.send_header("Content-type", "text/html") + self.end_headers() + + def bearer_do_GET(self, server_key: str) -> None: + auth_header = self.headers.get("Authorization", "") + print(auth_header) + print(f"Bearer {server_key}") + if not auth_header or auth_header != f"Bearer {server_key}": + self.send_response(403) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"no valid api key received") + else: + SimpleHTTPRequestHandler.do_GET(self) + + def basic_do_GET(self, server_key: str) -> None: + """Present frontpage with basic user authentication.""" + auth_header = self.headers.get("Authorization", "") + + if not auth_header: + self.basic_do_AUTHHEAD() + self.wfile.write(b"no auth header received") + elif auth_header == "Basic " + server_key: + SimpleHTTPRequestHandler.do_GET(self) + else: + self.basic_do_AUTHHEAD() + self.wfile.write(auth_header.encode("ascii")) + self.wfile.write(b"not authenticated") + + token_pattern = re.compile("^/t/([^/]+?)/") + + def token_do_GET(self, server_token: str) -> None: + """Present frontpage with user authentication.""" + match = self.token_pattern.search(self.path) + if match: + prefix_length = len(match.group(0)) - 1 + new_path = self.path[prefix_length:] + found_token = match.group(1) + if found_token == server_token: + self.path = new_path + return SimpleHTTPRequestHandler.do_GET(self) + + self.send_response(403) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"no valid api key received") + + +global_parser = argparse.ArgumentParser( + description="Start a multi-channel conda package server." +) +global_parser.add_argument("-p", "--port", type=int, default=8000, help="Port to use.") + +channel_parser = argparse.ArgumentParser( + description="Start a simple conda package server." +) +channel_parser.add_argument( + "-d", + "--directory", + type=str, + default=os.getcwd(), + help="Root directory for serving.", +) +channel_parser.add_argument( + "-n", + "--name", + type=str, + default=None, + help="Unique name of the channel used in URL", +) +channel_parser.add_argument( + "-a", + "--auth", + default=None, + type=str, + help="auth method (none, basic, token, or bearer)", +) +channel_parser.add_argument( + "--sign", + action="store_true", + help="Sign repodata (note: run generate_gpg_keys.sh before)", +) +channel_parser.add_argument( + "--token", + type=str, + default=None, + help="Use token as API Key", +) +channel_parser.add_argument( + "--bearer", + type=str, + default=None, + help="Use bearer token as API Key", +) +channel_parser.add_argument( + "--user", + type=str, + default=None, + help="Use token as API Key", +) +channel_parser.add_argument( + "--password", + type=str, + default=None, + help="Use token as API Key", +) + + +# Gobal args can be given anywhere with the first set of args for backward compatibility. +args, argv_remaining = global_parser.parse_known_args() +PORT = args.port + +# Iteratively parse arguments in sets. +# Each argument set, separated by -- in the CLI is for a channel. 
+# Credits: @hpaulj on SO https://stackoverflow.com/a/26271421 +channels = {} +while argv_remaining: + args, argv_remaining = channel_parser.parse_known_args(argv_remaining) + # Drop leading -- to move to next argument set + argv_remaining = argv_remaining[1:] + # Consolidation + if not args.auth: + if args.user and args.password: + args.auth = "basic" + elif args.token: + args.auth = "token" + elif args.bearer: + args.auth = "bearer" + else: + args.auth = "none" + if args.sign: + if not conda_content_trust_available: + fatal_error("Conda content trust not installed!") + args.directory = RepoSigner(args.directory).make_signed_repo() + + # name = args.name if args.name else Path(args.directory).name + # args.name = name + channels[args.name] = vars(args) + +print(channels) + +# Unamed channel in multi-channel case would clash URLs but we want to allow +# a single unamed channel for backward compatibility. +if (len(channels) > 1) and (None in channels): + fatal_error("Cannot use empty channel name when using multiple channels") + +server = HTTPServer(("", PORT), ChannelHandler) +print("Server started at localhost:" + str(PORT)) +try: + server.serve_forever() +except: + # Catch all sorts of interrupts + print("Shutting server down") + server.shutdown() + print("Server shut down") diff --git a/quetz/tests/test_indexing.py b/quetz/tests/test_indexing.py index dd76ecb0..b8541352 100644 --- a/quetz/tests/test_indexing.py +++ b/quetz/tests/test_indexing.py @@ -1,9 +1,13 @@ +import bz2 +import gzip import json from pathlib import Path import pytest +import zstandard from quetz import channel_data +from quetz.config import CompressionConfig from quetz.tasks.indexing import update_indexes @@ -12,10 +16,30 @@ def empty_channeldata(dao): return channel_data.export(dao, "") -def test_update_indexes_empty_channel(config, public_channel, dao, empty_channeldata): +def expected_compressed_files(files, bz2_enabled, gz_enabled, zst_enabled): + args = locals().copy() + return [ + f"{s}.{suffix}" + for s in files + for suffix in ["bz2", "gz", "zst"] + if s.endswith(".json") and args[f"{suffix}_enabled"] + ] + + +@pytest.mark.parametrize("bz2_enabled", [True, False]) +@pytest.mark.parametrize("gz_enabled", [True, False]) +@pytest.mark.parametrize("zst_enabled", [True, False]) +def test_update_indexes_empty_channel( + config, public_channel, dao, empty_channeldata, bz2_enabled, gz_enabled, zst_enabled +): pkgstore = config.get_package_store() - update_indexes(dao, pkgstore, public_channel.name) + update_indexes( + dao, + pkgstore, + public_channel.name, + compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled), + ) files = pkgstore.list_files(public_channel.name) @@ -25,11 +49,10 @@ def test_update_indexes_empty_channel(config, public_channel, dao, empty_channel "noarch/index.html", "noarch/repodata.json", ] - expected_files = base_files.copy() - - for suffix in [".bz2", ".gz"]: - expected_files.extend(s + suffix for s in base_files) + expected_files.extend( + expected_compressed_files(base_files, bz2_enabled, gz_enabled, zst_enabled) + ) assert sorted(files) == sorted(expected_files) @@ -38,12 +61,27 @@ def test_update_indexes_empty_channel(config, public_channel, dao, empty_channel assert json.load(fd) == empty_channeldata +@pytest.mark.parametrize("bz2_enabled", [True, False]) +@pytest.mark.parametrize("gz_enabled", [True, False]) +@pytest.mark.parametrize("zst_enabled", [True, False]) def test_update_indexes_empty_package( - config, public_channel, public_package, dao, empty_channeldata + config, 
+ public_channel, + public_package, + dao, + empty_channeldata, + bz2_enabled, + gz_enabled, + zst_enabled, ): pkgstore = config.get_package_store() - update_indexes(dao, pkgstore, public_channel.name) + update_indexes( + dao, + pkgstore, + public_channel.name, + compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled), + ) files = pkgstore.list_files(public_channel.name) @@ -55,9 +93,9 @@ def test_update_indexes_empty_package( ] expected_files = base_files.copy() - - for suffix in [".bz2", ".gz"]: - expected_files.extend(s + suffix for s in base_files) + expected_files.extend( + expected_compressed_files(base_files, bz2_enabled, gz_enabled, zst_enabled) + ) assert sorted(files) == sorted(expected_files) @@ -71,12 +109,28 @@ def test_update_indexes_empty_package( assert channeldata == empty_channeldata +@pytest.mark.parametrize("bz2_enabled", [True, False]) +@pytest.mark.parametrize("gz_enabled", [True, False]) +@pytest.mark.parametrize("zst_enabled", [True, False]) def test_update_indexes_with_package_version( - config, public_channel, public_package, package_version, dao + config, + public_channel, + public_package, + package_version, + dao, + bz2_enabled, + gz_enabled, + zst_enabled, ): + args = locals().copy() pkgstore = config.get_package_store() - update_indexes(dao, pkgstore, public_channel.name) + update_indexes( + dao, + pkgstore, + public_channel.name, + compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled), + ) files = pkgstore.list_files(public_channel.name) @@ -90,10 +144,9 @@ def test_update_indexes_with_package_version( ] expected_files = base_files.copy() - - for suffix in [".bz2", ".gz"]: - expected_files.extend(s + suffix for s in base_files) - + expected_files.extend( + expected_compressed_files(base_files, bz2_enabled, gz_enabled, zst_enabled) + ) expected_files.append(f"linux-64/{package_version.filename}") assert sorted(files) == sorted(expected_files) @@ -103,3 +156,30 @@ def test_update_indexes_with_package_version( channeldata = json.load(fd) assert public_package.name in channeldata["packages"].keys() + + # Check compressed repodata identical to repodata.json when enabled + # or that it doesn't exist when disabled + extensions = ("bz2", "gz", "zst") + enabled_compression_extensions = [ + ext for ext in extensions if args[f"{ext}_enabled"] + ] + disabled_compression_extensions = set(extensions) - set( + enabled_compression_extensions + ) + for subdir in ("noarch", "linux-64"): + repodata_json_path = channel_dir / subdir / "repodata.json" + with open(repodata_json_path, "r") as fd: + ref_repodata = json.load(fd) + for extension in enabled_compression_extensions: + with open(f"{repodata_json_path}.{extension}", "rb") as fd: + compressed_data = fd.read() + if extension == "bz2": + data = bz2.decompress(compressed_data) + elif extension == "gz": + data = gzip.decompress(compressed_data) + else: + data = zstandard.ZstdDecompressor().decompress(compressed_data) + repodata = json.loads(data) + assert repodata == ref_repodata + for extension in disabled_compression_extensions: + assert not Path(f"{repodata_json_path}.{extension}").exists() diff --git a/quetz/tests/test_mirror.py b/quetz/tests/test_mirror.py index e4a95485..1eaa4d4f 100644 --- a/quetz/tests/test_mirror.py +++ b/quetz/tests/test_mirror.py @@ -1,13 +1,19 @@ +import bz2 import concurrent.futures +import gzip import json import os +import subprocess +import time import uuid from io import BytesIO from pathlib import Path +from unittest import mock from unittest.mock import MagicMock 
from urllib.parse import urlparse import pytest +import zstandard from quetz import hookimpl, rest_models from quetz.authorization import Rules @@ -21,6 +27,7 @@ RemoteServerError, create_packages_from_channeldata, create_versions_from_repodata, + download_repodata, handle_repodata_package, initial_sync_mirror, ) @@ -170,7 +177,7 @@ def close(self): app.dependency_overrides[get_remote_session] = DummySession - return DummySession() + yield DummySession() app.dependency_overrides.pop(get_remote_session) @@ -194,6 +201,26 @@ def mirror_package(mirror_channel, db): db.commit() +@pytest.mark.parametrize( + "host, name", + [ + ("https://conda.anaconda.org/conda-forge", "conda-forge"), + ("https://conda.anaconda.org/conda-forge/", "conda-forge"), + ("https://repod.prefix.dev/conda-forge", "conda-forge"), + ("http://localhost:8000/mychannel", "mychannel"), + ("http://localhost:8000/path/mychannel", "mychannel"), + ], +) +def test_remote_repository_rattler_channel(host, name): + repository = RemoteRepository(host, session=None) + rattler_channel = repository.rattler_channel + assert rattler_channel.name == name + if host.endswith("/"): + assert rattler_channel.base_url == repository.host + else: + assert rattler_channel.base_url == f"{repository.host}/" + + def test_set_mirror_url(db, client, owner): response = client.get("/api/dummylogin/bartosz") assert response.status_code == 200 @@ -1378,3 +1405,170 @@ def test_handle_repodata_package_with_plugin( ) assert plugin.about["conda_version"] == "4.8.4" + + +@pytest.fixture(scope="session") +def serve_repo_data() -> None: + port, repo_name = 8912, "test-repo" + + test_data_dir = Path(__file__).parent / "data" / "test-server" + + with subprocess.Popen( + [ + "python", + str(test_data_dir / "reposerver.py"), + "-d", + str(test_data_dir / "repo"), + "-n", + repo_name, + "-p", + str(port), + ] + ) as proc: + time.sleep(0.5) + yield port, repo_name + proc.terminate() + + +@pytest.fixture(scope="session") +def test_repodata_json() -> bytes: + repodata_json_file = ( + Path(__file__).parent + / "data" + / "test-server" + / "repo" + / "noarch" + / "repodata.json" + ) + return repodata_json_file.read_bytes() + + +@pytest.fixture +def server_proxy_channel(db, serve_repo_data): + port, repo = serve_repo_data + channel = Channel( + name="server-proxy-channel", + mirror_channel_url=f"http://localhost:{port}/{repo}", + mirror_mode="proxy", + ) + db.add(channel) + db.commit() + + yield channel + + db.delete(channel) + db.commit() + + +@pytest.mark.asyncio +async def test_download_repodata( + config, + test_repodata_json, + serve_repo_data, +): + port, repo = serve_repo_data + repository = RemoteRepository(f"http://localhost:{port}/{repo}", session=None) + channel = "test-proxy-channel" + platform = "noarch" + rattler_cache_path = Path(config.general_rattler_cache_dir) / channel / platform + assert not rattler_cache_path.exists() + with mock.patch("quetz.tasks.mirror.Config", return_value=config): + result = await download_repodata(repository, channel, platform) + assert rattler_cache_path.is_dir() + assert result == test_repodata_json + + +@mock.patch("quetz.tasks.mirror.download_repodata") +def test_download_remote_file_repodata( + mock_download_repodata, client, server_proxy_channel, test_repodata_json +): + """Test downloading repodata.json using download_repodata.""" + # download from remote server using download_repodata + mock_download_repodata.return_value = test_repodata_json + assert not mock_download_repodata.called + response = 
client.get(f"/get/{server_proxy_channel.name}/noarch/repodata.json") + assert response.status_code == 200 + assert response.content == test_repodata_json + assert mock_download_repodata.call_count == 1 + # download from cache (download_repodata not called again) + response = client.get(f"/get/{server_proxy_channel.name}/noarch/repodata.json") + assert response.status_code == 200 + assert response.content == test_repodata_json + assert mock_download_repodata.call_count == 1 + + +@mock.patch("quetz.tasks.mirror.download_repodata") +def test_download_remote_file_current_repodata( + mock_download_repodata, client, server_proxy_channel, test_repodata_json +): + """Test downloading current_repodata.json.""" + response = client.get( + f"/get/{server_proxy_channel.name}/noarch/current_repodata.json" + ) + assert response.status_code == 200 + # current_repodata content is the same as test_repodata_json in the test repo + assert response.content == test_repodata_json + # download_repodata isn't used (download done with repository.open) + assert not mock_download_repodata.called + + +@pytest.fixture(scope="function") +def config_extra(request=None) -> str: + if request is None: + return "" + return ( + "[compression]\n" + f"bz2_enabled = {str(request.param[0]).lower()}\n" + f"gz_enabled = {str(request.param[1]).lower()}\n" + f"zst_enabled = {str(request.param[2]).lower()}\n" + ) + + +@pytest.mark.parametrize( + "config_extra", + [ + (False, False, False), + (False, False, True), + (True, True, True), + (True, True, False), + (False, True, True), + ], + indirect=True, +) +@mock.patch("quetz.tasks.mirror.download_repodata") +def test_download_remote_file_repodata_compressed( + mock_download_repodata, + client, + server_proxy_channel, + test_repodata_json, + config, +): + """Test downloading compressed repodata.json.""" + # download from remote server using download_repodata + mock_download_repodata.return_value = test_repodata_json + assert not mock_download_repodata.called + is_at_least_one_extension_enabled = False + for extension in ("bz2", "gz", "zst"): + response = client.get( + f"/get/{server_proxy_channel.name}/noarch/repodata.json.{extension}" + ) + if getattr(config, f"compression_{extension}_enabled"): + is_at_least_one_extension_enabled = True + assert response.status_code == 200 + if extension == "bz2": + data = bz2.decompress(response.content) + elif extension == "gz": + data = gzip.decompress(response.content) + else: + data = zstandard.ZstdDecompressor().decompress(response.content) + assert data == test_repodata_json + else: + assert response.status_code == 404 + if is_at_least_one_extension_enabled: + # Check that download_repodata is called only once + # (json file downloaded and compressed the first time only) + # cache used afterwards + assert mock_download_repodata.call_count == 1 + else: + # No extension enabled - only 404 returned + assert mock_download_repodata.call_count == 0 diff --git a/quetz/utils.py b/quetz/utils.py index eb678769..1c2a86a4 100644 --- a/quetz/utils.py +++ b/quetz/utils.py @@ -14,16 +14,32 @@ import time import traceback import uuid +from dataclasses import dataclass from datetime import datetime, timezone from functools import wraps from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, Optional, Union from urllib.parse import unquote +import zstandard from sqlalchemy import String, and_, cast, collate, not_, or_ +from .config import CompressionConfig from .db_models import Channel, Package, PackageVersion, User +# 
Same values as conda-index +# https://github.com/conda/conda-index/blob/58cfdba8cf37b0aa9f5876665025c5949f046a4b/conda_index/index/__init__.py#L46 +ZSTD_COMPRESS_LEVEL = 16 +ZSTD_COMPRESS_THREADS = -1 # automatic + + +@dataclass +class Compressed: + raw_file: bytes + bz2_file: Optional[bytes] + gz_file: Optional[bytes] + zst_file: Optional[bytes] + def check_package_membership(package_name, includelist, excludelist): if includelist: @@ -39,32 +55,43 @@ def check_package_membership(package_name, includelist, excludelist): return True -def add_static_file(contents, channel_name, subdir, fname, pkgstore, file_index=None): - if not isinstance(contents, bytes): - raw_file = contents.encode("utf-8") - else: - raw_file = contents - bz2_file = bz2.compress(raw_file) - gzp_file = gzip.compress(raw_file) - +def add_static_file( + contents, + channel_name, + subdir, + fname, + pkgstore, + file_index=None, + compression: Optional[CompressionConfig] = None, +): + if compression is None: + compression = CompressionConfig(False, False, False) + compressed = compress_file(contents, compression) path = f"{subdir}/{fname}" if subdir else fname - pkgstore.add_file(bz2_file, channel_name, f"{path}.bz2") - pkgstore.add_file(gzp_file, channel_name, f"{path}.gz") - pkgstore.add_file(raw_file, channel_name, f"{path}") + if compression.bz2_enabled: + pkgstore.add_file(compressed.bz2_file, channel_name, f"{path}.bz2") + if compression.gz_enabled: + pkgstore.add_file(compressed.gz_file, channel_name, f"{path}.gz") + if compression.zst_enabled: + pkgstore.add_file(compressed.zst_file, channel_name, f"{path}.zst") + pkgstore.add_file(compressed.raw_file, channel_name, f"{path}") if file_index: - add_entry_for_index(file_index, subdir, fname, raw_file) - add_entry_for_index(file_index, subdir, f"{fname}.bz2", bz2_file) - add_entry_for_index(file_index, subdir, f"{fname}.gz", gzp_file) + add_compressed_entry_for_index(file_index, subdir, fname, compressed) def add_temp_static_file( - contents, channel_name, subdir, fname, temp_dir, file_index=None + contents, + channel_name, + subdir, + fname, + temp_dir, + file_index=None, + compression: Optional[CompressionConfig] = None, ): - if not isinstance(contents, bytes): - raw_file = contents.encode("utf-8") - else: - raw_file = contents + if compression is None: + compression = CompressionConfig(False, False, False) + compressed = compress_file(contents, compression) temp_dir = Path(temp_dir) @@ -79,21 +106,48 @@ def add_temp_static_file( file_path = path / fname with open(file_path, "wb") as fo: - fo.write(raw_file) + fo.write(compressed.raw_file) + if compressed.bz2_file: + with open(f"{file_path}.bz2", "wb") as fo: + fo.write(compressed.bz2_file) + if compressed.gz_file: + with open(f"{file_path}.gz", "wb") as fo: + fo.write(compressed.gz_file) + if compressed.zst_file: + with open(f"{file_path}.zst", "wb") as fo: + fo.write(compressed.zst_file) - bz2_file = bz2.compress(raw_file) - gzp_file = gzip.compress(raw_file) - - with open(f"{file_path}.bz2", "wb") as fo: - fo.write(bz2_file) + if file_index: + add_compressed_entry_for_index(file_index, subdir, fname, compressed) - with open(f"{file_path}.gz", "wb") as fo: - fo.write(gzp_file) - if file_index: - add_entry_for_index(file_index, subdir, fname, raw_file) - add_entry_for_index(file_index, subdir, f"{fname}.bz2", bz2_file) - add_entry_for_index(file_index, subdir, f"{fname}.gz", gzp_file) +def compress_file( + contents: Union[str, bytes], compression: CompressionConfig +) -> Compressed: + if not isinstance(contents, 
bytes): + raw_file = contents.encode("utf-8") + else: + raw_file = contents + bz2_file = bz2.compress(raw_file) if compression.bz2_enabled else None + gz_file = gzip.compress(raw_file) if compression.gz_enabled else None + zst_file = ( + zstandard.ZstdCompressor( + level=ZSTD_COMPRESS_LEVEL, threads=ZSTD_COMPRESS_THREADS + ).compress(raw_file) + if compression.zst_enabled + else None + ) + return Compressed(raw_file, bz2_file, gz_file, zst_file) + + +def add_compressed_entry_for_index(file_index, subdir, fname, compressed: Compressed): + add_entry_for_index(file_index, subdir, fname, compressed.raw_file) + if compressed.bz2_file: + add_entry_for_index(file_index, subdir, f"{fname}.bz2", compressed.bz2_file) + if compressed.gz_file: + add_entry_for_index(file_index, subdir, f"{fname}.gz", compressed.gz_file) + if compressed.zst_file: + add_entry_for_index(file_index, subdir, f"{fname}.zst", compressed.zst_file) def add_entry_for_index(files, subdir, fname, data_bytes): diff --git a/setup.cfg b/setup.cfg index 9ac47e75..83a9ffdd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ install_requires = python-multipart pydantic>=2.0.0 pyyaml + py-rattler requests sqlalchemy sqlalchemy-utils
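A small sketch of the new helper in `quetz/utils.py`: `compress_file` only produces the variants enabled in the given `CompressionConfig`, and `add_static_file`/`add_temp_static_file` then write exactly those files next to the raw one (the JSON payload below is an arbitrary example):

```python
# Sketch only: exercises compress_file / Compressed from quetz/utils.py.
import bz2
import gzip

from quetz.config import CompressionConfig
from quetz.utils import compress_file

compression = CompressionConfig(bz2_enabled=True, gz_enabled=True, zst_enabled=False)
compressed = compress_file('{"info": {"subdir": "noarch"}}', compression)

assert bz2.decompress(compressed.bz2_file) == compressed.raw_file
assert gzip.decompress(compressed.gz_file) == compressed.raw_file
assert compressed.zst_file is None  # zstd disabled, so no .zst variant is produced
```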