Skip to content

Commit

Permalink
Use dataclass to store and pass compression config
Browse files Browse the repository at this point in the history
  • Loading branch information
beenje committed Dec 10, 2023
1 parent 4eb6761 commit 381bcee
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 106 deletions.
47 changes: 33 additions & 14 deletions quetz/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import logging.config
import os
from dataclasses import dataclass
from distutils.util import strtobool
from secrets import token_bytes
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Type, Union
Expand All @@ -22,6 +23,24 @@
_user_dir = appdirs.user_config_dir("quetz")

PAGINATION_LIMIT = 20
COMPRESSION_EXTENSIONS = ["bz2", "gz", "zst"]


@dataclass
class CompressionConfig:
bz2_enabled: bool
gz_enabled: bool
zst_enabled: bool

def enabled_extensions(self):
return [
ext for ext in COMPRESSION_EXTENSIONS if getattr(self, f"{ext}_enabled")
]

def disabled_extensions(self):
return [
ext for ext in COMPRESSION_EXTENSIONS if not getattr(self, f"{ext}_enabled")
]


class ConfigEntry(NamedTuple):
Expand Down Expand Up @@ -444,6 +463,20 @@ def _get_environ_config(self) -> Dict[str, Any]:

return config

def get_compression_config(self) -> CompressionConfig:
"""Return the compression configuration.
Returns
-------
compression_config : CompressionConfig
The compression configuration
"""
return CompressionConfig(
self.compression_bz2_enabled,
self.compression_gz_enabled,
self.compression_zst_enabled,
)

def get_package_store(self) -> pkgstores.PackageStore:
"""Return the appropriate package store as set in the config.
Expand Down Expand Up @@ -495,20 +528,6 @@ def get_package_store(self) -> pkgstores.PackageStore:
}
)

def get_enabled_compression_extensions(self):
return [
ext
for ext in ("bz2", "gz", "zst")
if getattr(self, f"compression_{ext}_enabled")
]

def get_disabled_compression_extensions(self):
return [
ext
for ext in ("bz2", "gz", "zst")
if not getattr(self, f"compression_{ext}_enabled")
]

def configured_section(self, section: str) -> bool:
"""Return if a given section has been configured.
Expand Down
29 changes: 9 additions & 20 deletions quetz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,7 @@ def post_channel(
dao,
pkgstore,
new_channel.name,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)

# register mirror
Expand Down Expand Up @@ -912,9 +910,7 @@ def delete_package(
pkgstore,
channel_name,
platforms,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)


Expand Down Expand Up @@ -1190,9 +1186,7 @@ def delete_package_version(
pkgstore,
channel_name,
[platform],
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)


Expand Down Expand Up @@ -1387,9 +1381,7 @@ def post_file_to_package(
dao,
pkgstore,
package.channel_name,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)


Expand Down Expand Up @@ -1481,9 +1473,7 @@ async def post_upload(
dao,
pkgstore,
channel_name,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)


Expand All @@ -1510,9 +1500,7 @@ def post_file_to_channel(
dao,
pkgstore,
channel.name,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)


Expand Down Expand Up @@ -1846,8 +1834,9 @@ def serve_path(
chunk_size = 10_000

# Ensure we don't serve an old compressed file if this compression is now disabled
compression = config.get_compression_config()
disabled_compressed_json_extensions = tuple(
f".json.{ext}" for ext in config.get_disabled_compression_extensions()
f".json.{ext}" for ext in compression.disabled_extensions()
)
if path.endswith(disabled_compressed_json_extensions):
suffix = PurePath(path).suffix
Expand Down Expand Up @@ -1884,7 +1873,7 @@ def serve_path(
if channel.mirror_channel_url and channel.mirror_mode == "proxy":
repository = RemoteRepository(channel.mirror_channel_url, session)
enabled_compressed_json_extensions = tuple(
f".json.{ext}" for ext in config.get_enabled_compression_extensions()
f".json.{ext}" for ext in compression.enabled_extensions()
)
if path.endswith((".json",) + enabled_compressed_json_extensions):
# repodata.json and current_repodata.json are cached locally
Expand Down
11 changes: 2 additions & 9 deletions quetz/tasks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def __init__(
self.dao = dao.Dao(db)
config = get_config()
self.pkgstore = config.get_package_store()
self.bz2_enabled = config.compression_bz2_enabled
self.gz_enabled = config.compression_gz_enabled
self.zst_enabled = config.compression_zst_enabled
self.compression = config.get_compression_config()

def execute_channel_action(
self,
Expand Down Expand Up @@ -104,12 +102,7 @@ def execute_channel_action(
)
elif action == ChannelActionEnum.validate_packages:
auth.assert_validate_package_cache(channel_name)
extra_args = dict(
channel_name=channel.name,
bz2_enabled=self.bz2_enabled,
gz_enabled=self.gz_enabled,
zst_enabled=self.zst_enabled,
)
extra_args = dict(channel_name=channel.name, compression=self.compression)
task = self.jobs_dao.create_job(
action.encode('ascii'),
user_id,
Expand Down
22 changes: 8 additions & 14 deletions quetz/tasks/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

from jinja2 import Environment, PackageLoader, select_autoescape
from jinja2.exceptions import UndefinedError

import quetz.config
from quetz import channel_data, repo_data
from quetz.condainfo import MAX_CONDA_TIMESTAMP
from quetz.config import CompressionConfig
from quetz.db_models import PackageVersion
from quetz.utils import add_static_file, add_temp_static_file

Expand Down Expand Up @@ -89,7 +91,7 @@ def _subdir_key(dir):


def validate_packages(
dao, pkgstore, channel_name, bz2_enabled, gz_enabled, zst_enabled
dao, pkgstore, channel_name, compression: Optional[CompressionConfig] = None
):
# for now we're just validating the size of the uploaded file
logger.info("Starting package validation")
Expand Down Expand Up @@ -181,9 +183,7 @@ def validate_packages(
dao,
pkgstore,
channel_name,
bz2_enabled=bz2_enabled,
gz_enabled=gz_enabled,
zst_enabled=zst_enabled,
compression=compression,
)


Expand All @@ -192,9 +192,7 @@ def update_indexes(
pkgstore,
channel_name,
subdirs=None,
bz2_enabled=False,
gz_enabled=False,
zst_enabled=False,
compression: Optional[CompressionConfig] = None,
):
jinjaenv = _jinjaenv()
channeldata = channel_data.export(dao, channel_name)
Expand All @@ -210,9 +208,7 @@ def update_indexes(
None,
"channeldata.json",
pkgstore,
bz2_enabled=bz2_enabled,
gz_enabled=gz_enabled,
zst_enabled=zst_enabled,
compression=compression,
)

# Generate index.html for the "root" directory
Expand All @@ -226,7 +222,7 @@ def update_indexes(
add_static_file(channel_index, channel_name, None, "index.html", pkgstore)

# NB. No rss.xml is being generated here
files = {}
files: Dict[str, Any] = {}
packages = {}
subdir_template = jinjaenv.get_template("subdir-index.html.j2")

Expand Down Expand Up @@ -259,9 +255,7 @@ def update_indexes(
"repodata.json",
tempdir_path,
files,
bz2_enabled=bz2_enabled,
gz_enabled=gz_enabled,
zst_enabled=zst_enabled,
compression=compression,
)

try:
Expand Down
12 changes: 3 additions & 9 deletions quetz/tasks/mirror.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ def download_remote_file(
None,
path,
pkgstore,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)
else:
remote_file = repository.open(path)
Expand All @@ -190,9 +188,7 @@ def download_remote_file(
None,
path,
pkgstore,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)
else:
pkgstore.add_package(data_stream, channel, path)
Expand Down Expand Up @@ -508,9 +504,7 @@ def handle_batch(update_batch):
pkgstore,
channel_name,
subdirs=[arch],
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
compression=config.get_compression_config(),
)


Expand Down
7 changes: 1 addition & 6 deletions quetz/tasks/reindexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,7 @@ def reindex_packages_from_store(

try:
update_indexes(
dao,
pkgstore,
channel_name,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
dao, pkgstore, channel_name, compression=config.get_compression_config()
)
dao.db.commit()
except IntegrityError:
Expand Down
7 changes: 1 addition & 6 deletions quetz/tests/api/test_main_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,7 @@ def test_delete_package_versions_with_package(
assert package_version.package_name == public_package.name

update_indexes(
dao,
pkgstore,
public_channel.name,
bz2_enabled=config.compression_bz2_enabled,
gz_enabled=config.compression_gz_enabled,
zst_enabled=config.compression_zst_enabled,
dao, pkgstore, public_channel.name, compression=config.get_compression_config()
)

# Get package files
Expand Down
13 changes: 4 additions & 9 deletions quetz/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import zstandard

from quetz import channel_data
from quetz.config import CompressionConfig
from quetz.tasks.indexing import update_indexes


Expand Down Expand Up @@ -37,9 +38,7 @@ def test_update_indexes_empty_channel(
dao,
pkgstore,
public_channel.name,
bz2_enabled=bz2_enabled,
gz_enabled=gz_enabled,
zst_enabled=zst_enabled,
compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled),
)

files = pkgstore.list_files(public_channel.name)
Expand Down Expand Up @@ -81,9 +80,7 @@ def test_update_indexes_empty_package(
dao,
pkgstore,
public_channel.name,
bz2_enabled=bz2_enabled,
gz_enabled=gz_enabled,
zst_enabled=zst_enabled,
compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled),
)

files = pkgstore.list_files(public_channel.name)
Expand Down Expand Up @@ -132,9 +129,7 @@ def test_update_indexes_with_package_version(
dao,
pkgstore,
public_channel.name,
bz2_enabled=bz2_enabled,
gz_enabled=gz_enabled,
zst_enabled=zst_enabled,
compression=CompressionConfig(bz2_enabled, gz_enabled, zst_enabled),
)

files = pkgstore.list_files(public_channel.name)
Expand Down
4 changes: 3 additions & 1 deletion quetz/tests/test_mirror.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,9 @@ def test_download_remote_file_current_repodata(


@pytest.fixture(scope="function")
def config_extra(request) -> str:
def config_extra(request=None) -> str:
if request is None:
return ""
return (
"[compression]\n"
f"bz2_enabled = {str(request.param[0]).lower()}\n"
Expand Down
Loading

0 comments on commit 381bcee

Please sign in to comment.