Skip to content

Commit

Permalink
Add ability to load config from drive
Browse files Browse the repository at this point in the history
Internal-tag: [#56628]
Signed-off-by: Robert Winkler <[email protected]>
  • Loading branch information
rw1nkler committed Apr 15, 2024
1 parent e73a479 commit 1d6784f
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"PyYAML",
"click",
"amaranth==0.4.0.*",
"marshmallow_dataclass",
"numexpr",
"typing_extensions",
"pipeline_manager_backend_communication @ git+https://github.com/antmicro/kenning-pipeline-manager-backend-communication@eb690cfb7766bfbd85a4eff2a1e809573b8b72d0",
Expand Down
134 changes: 134 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2024 Antmicro <www.antmicro.com>
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path

import pytest
import yaml

from topwrap.config import Config, ConfigManager, RepositoryEntry


class TestConfigManager:
@pytest.fixture
def config_dict(self):
return Config.Schema().dump(
Config(
force_interface_compliance=True,
repositories=[
RepositoryEntry("My topwrap repo", "~/custom/repo/path"),
],
)
)

@pytest.fixture
def custom_config_dicts(self):
return [
(
"custom/path/cfg.yml",
Config.Schema().dump(Config(repositories=[RepositoryEntry("repo1", "path1")])),
),
(
"/global/path/mycfg.yaml",
Config.Schema().dump(Config(repositories=[RepositoryEntry("repo2", "path2")])),
),
]

@pytest.fixture
def incorrect_config_dicts(self):
return [
{
"force_interface_compliance": True,
"repositories": [
{
"name": "My topwrap repo",
"path": "~/custom/repo/path",
"info": "Info should not be here",
}
],
},
{
"force_interface_compliance": True,
"meta": "A missing 'repositories' entry is correct, an additional custom entry is not",
},
]

@staticmethod
def contains_warnings_in_log(caplog):
for name, level, msg in caplog.record_tuples:
if name == "topwrap.config" and level == logging.WARNING:
return True
return False

def test_adding_repo_duplicates(self, fs, config_dict, caplog):
(repo_dict,) = config_dict["repositories"]

manager = ConfigManager()
for path in manager.search_paths:
config_str = yaml.dump(config_dict)
fs.create_file(path, contents=config_str)

config = manager.load()
assert len(config.repositories) == 1
assert not self.contains_warnings_in_log(caplog)

def test_loading_order(self, fs, config_dict, caplog):
(repo_dict,) = config_dict["repositories"]

manager = ConfigManager()
for i, path in enumerate(manager.search_paths):
repo_dict["name"] = str(i)
repo_dict["path"] = str(path)
config_str = yaml.dump(config_dict)
fs.create_file(path, contents=config_str)

config = manager.load()
assert config.repositories == [
RepositoryEntry(name=str(i), path=str(manager.search_paths[i]))
for i in reversed(range(len(manager.search_paths)))
]
assert not self.contains_warnings_in_log(caplog)

def test_custom_search_patchs(self, fs, custom_config_dicts, caplog):
for path, config_dict in custom_config_dicts:
config_str = yaml.dump(config_dict)
fs.create_file(path, contents=config_str)

paths, config_dicts = zip(*custom_config_dicts)
config = ConfigManager(paths).load()
assert len(config.repositories) == len(config_dicts)
assert not self.contains_warnings_in_log(caplog)

def test_config_override(self, fs, config_dict, caplog):
config_path = Path(ConfigManager.DEFAULT_SEARCH_PATHS[0]).expanduser()
config_str = yaml.dump(config_dict)
fs.create_file(config_path, contents=config_str)

manager = ConfigManager()

(repo_dict,) = config_dict["repositories"]

config = manager.load()
assert config.force_interface_compliance is True
assert config.repositories == [RepositoryEntry(repo_dict["name"], repo_dict["path"])]

override_config = Config(
force_interface_compliance=False,
repositories=None,
)

config2 = manager.load(override_config)
assert config2.force_interface_compliance is False
assert config2.repositories == [RepositoryEntry(repo_dict["name"], repo_dict["path"])]
assert not self.contains_warnings_in_log(caplog)

def test_loading_incorrect_configs(self, fs, incorrect_config_dicts, caplog):
config_path = Path(ConfigManager.DEFAULT_SEARCH_PATHS[0]).expanduser()
for incorrect_config in incorrect_config_dicts:
manager = ConfigManager()
config_str = yaml.dump(incorrect_config)
fs.create_file(config_path, contents=config_str)
manager.load()
assert self.contains_warnings_in_log(caplog)
config_path.unlink()
105 changes: 100 additions & 5 deletions topwrap/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,109 @@
# Copyright (c) 2021-2024 Antmicro <www.antmicro.com>
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import field
from os import PathLike
from pathlib import Path
from typing import List, Optional

import marshmallow
import marshmallow_dataclass
import yaml

logger = logging.getLogger(__name__)


class InvalidConfigError(Exception):
"""Raised when the provided configuration is incorrect"""


@marshmallow_dataclass.dataclass
class RepositoryEntry:
"""Contains information about topwrap repository"""

name: str
path: str


@marshmallow_dataclass.dataclass
class Config:
"""Configuration class used to store global choices
for behavior of Topwrap.
"""Global topwrap configuration"""

force_interface_compliance: Optional[bool] = field(
default=False, metadata={"load_default": None}
)
repositories: Optional[List[RepositoryEntry]] = field(
default_factory=list, metadata={"load_default": None}
)

def update(self, config: "Config"):
if config.force_interface_compliance is not None:
self.force_interface_compliance = config.force_interface_compliance

if config.repositories is not None:
if self.repositories is None:
self.repositories = config.repositories
else:
for repo in config.repositories:
if repo not in self.repositories:
self.repositories.append(repo)


class ConfigManager:
"""Manager used to load topwrap's configuration from files.
The configuration files are loaded in a specific order, which also
determines the priority of settings that are defined differently
in the files. The list of default search paths is defined in
the `DEFAULT_SEARCH_PATH` class variable. Configuration files that
are specified earlier in the list have higher priority and can
overwrite the settings from the files that follow. The default list of
search paths can be changed by passing a different list to
the ConfigManager constructor.
"""

def __init__(self, force_interface_compliance=False):
self.force_interface_compliance = force_interface_compliance
DEFAULT_SEARCH_PATHS = [
"topwrap.yaml",
"~/.config/topwrap/topwrap.yaml",
"~/.config/topwrap/config.yaml",
]

def __init__(self, search_paths: Optional[List[PathLike]] = None):
if search_paths is None:
search_paths = self.DEFAULT_SEARCH_PATHS

self.search_paths = []
for path in search_paths:
self.search_paths += [Path(path).expanduser()]

def load(self, overrides: Optional[Config] = None, default: Optional[Config] = None):
config = Config() if default is None else default

for path in reversed(self.search_paths):
if not path.is_file():
continue

with open(path) as f:
try:
yaml_dict = yaml.safe_load(f)
except yaml.YAMLError:
logger.warning(f"{path} configuration file is not a valid YAML")
continue

try:
new_config = Config.Schema().load(yaml_dict)
config.update(new_config)
except marshmallow.ValidationError as e:
logger.warning(f"{path} configuration file is not valid ({e.messages})")
continue

if overrides is not None:
config.update(overrides)

logger.debug(f"Final configuration used by topwrap: {config}")

return config


config = Config()
config = ConfigManager().load()

0 comments on commit 1d6784f

Please sign in to comment.