Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yaml config sets #2876

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
51 changes: 51 additions & 0 deletions src/_nebari/config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pathlib

from packaging.requirements import Requirement
from pydantic import BaseModel, ConfigDict, field_validator

from _nebari._version import __version__
from _nebari.utils import yaml


class ConfigSetMetadata(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True)
name: str # for use with guided init
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If none is allowed here, shouldn't the type be Optional[str] = None?

description: str = None
nebari_version: str | Requirement

@field_validator("nebari_version")
@classmethod
def validate_version_requirement(cls, version_req):
if isinstance(version_req, str):
version_req = Requirement(f"nebari{version_req}")
version_req.specifier.prereleases = True

return version_req

def check_version(self, version):
if version not in self.nebari_version.specifier:
raise ValueError(
f"Current Nebari version {__version__} is not compatible with "
f'required version {self.nebari_version.specifier} for "{self.name}" config set.'
)


class ConfigSet(BaseModel):
metadata: ConfigSetMetadata
config: dict


def read_config_set(config_set_filepath: str):
"""Read a config set from a config file."""

filename = pathlib.Path(config_set_filepath)

with filename.open() as f:
config_set_yaml = yaml.load(f)

config_set = ConfigSet(**config_set_yaml)

# validation
config_set.metadata.check_version(__version__)

return config_set
10 changes: 8 additions & 2 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pydantic
import requests

from _nebari import constants
from _nebari import constants, utils
from _nebari.config_set import read_config_set
from _nebari.provider import git
from _nebari.provider.cicd import github
from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud
Expand Down Expand Up @@ -47,6 +48,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
config_set: str = None,
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
Expand Down Expand Up @@ -176,13 +178,17 @@ def render_config(
config["certificate"] = {"type": CertificateEnum.letsencrypt.value}
config["certificate"]["acme_email"] = ssl_cert_email

if config_set:
config_set = read_config_set(config_set)
config = utils.deep_merge(config, config_set.config)

# validate configuration and convert to model
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))
Adam-D-Lewis marked this conversation as resolved.
Show resolved Hide resolved
raise e

if repository_auto_provision:
match = re.search(github_url_regex, repository)
Expand Down
16 changes: 15 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -613,11 +614,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand All @@ -631,6 +644,7 @@ def check_provider(cls, data: Any) -> Any:
data["provider"] = provider_name_abbreviation_map[set_providers[0]]
elif num_providers == 0:
data["provider"] = schema.ProviderEnum.local.value

return data


Expand Down
9 changes: 9 additions & 0 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class InitInputs(schema.Base):
region: Optional[str] = None
ssl_cert_email: Optional[schema.email_pydantic] = None
disable_prompt: bool = False
config_set: Optional[str] = None
output: pathlib.Path = pathlib.Path("nebari-config.yaml")
explicit: int = 0

Expand Down Expand Up @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel):
terraform_state=inputs.terraform_state,
ssl_cert_email=inputs.ssl_cert_email,
disable_prompt=inputs.disable_prompt,
config_set=inputs.config_set,
)

try:
Expand Down Expand Up @@ -496,6 +498,12 @@ def init(
False,
is_eager=True,
),
config_set: str = typer.Option(
None,
"--config-set",
"-s",
help="Apply a pre-defined set of nebari configuration options.",
),
output: str = typer.Option(
pathlib.Path("nebari-config.yaml"),
"--output",
Expand Down Expand Up @@ -554,6 +562,7 @@ def init(
inputs.terraform_state = terraform_state
inputs.ssl_cert_email = ssl_cert_email
inputs.disable_prompt = disable_prompt
inputs.config_set = config_set
inputs.output = output
inputs.explicit = explicit

Expand Down
4 changes: 2 additions & 2 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def modified_environ(*remove: List[str], **update: Dict[str, str]):


def deep_merge(*args):
"""Deep merge multiple dictionaries.
"""Deep merge multiple dictionaries. Preserves order in dicts and lists.

>>> value_1 = {
'a': [1, 2],
Expand Down Expand Up @@ -190,7 +190,7 @@ def deep_merge(*args):

if isinstance(d1, dict) and isinstance(d2, dict):
d3 = {}
for key in d1.keys() | d2.keys():
for key in tuple(d1.keys()) + tuple(d2.keys()):
if key in d1 and key in d2:
d3[key] = deep_merge(d1[key], d2[key])
elif key in d1:
Expand Down
73 changes: 73 additions & 0 deletions tests/tests_unit/test_config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from unittest.mock import patch

import pytest
from packaging.requirements import Requirement

from _nebari.config_set import ConfigSetMetadata, read_config_set

test_version = "2024.12.2"


def test_valid_version_requirement():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we parameterize this test with a few different version strings to ensure it works in all the allowed formats? ideally including pre-release versions.

metadata = ConfigSetMetadata(
name="test-config", nebari_version=">=2024.12.0,<2025.0.0"
)
assert metadata.nebari_version.specifier.contains(test_version)


def test_invalid_version_requirement():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be parameterized too to validate multiple invalid types

with pytest.raises(ValueError) as exc_info:
csm = ConfigSetMetadata(name="test-config", nebari_version=">=2025.0.0")
csm.check_version(test_version)
assert "Current Nebari version" in str(exc_info.value)


def test_valid_version_requirement_with_requirement_object():
requirement = Requirement("nebari>=2024.12.0")
metadata = ConfigSetMetadata(name="test-config", nebari_version=requirement)
assert metadata.nebari_version.specifier.contains(test_version)


def test_invalid_version_requirement_with_requirement_object():
requirement = Requirement("nebari>=2025.0.0")
with pytest.raises(ValueError) as exc_info:
csm = ConfigSetMetadata(name="test-config", nebari_version=requirement)
csm.check_version(test_version)
assert "Current Nebari version" in str(exc_info.value)


def test_read_config_set_valid(tmp_path):
config_set_yaml = """
metadata:
name: test-config
nebari_version: ">=2024.12.0"
config:
key: value
"""
config_set_filepath = tmp_path / "config_set.yaml"
config_set_filepath.write_text(config_set_yaml)
with patch("_nebari.config_set.__version__", "2024.12.2"):
config_set = read_config_set(str(config_set_filepath))
assert config_set.metadata.name == "test-config"
assert config_set.config["key"] == "value"


def test_read_config_set_invalid_version(tmp_path):
config_set_yaml = """
metadata:
name: test-config
nebari_version: ">=2025.0.0"
config:
key: value
"""
config_set_filepath = tmp_path / "config_set.yaml"
config_set_filepath.write_text(config_set_yaml)

with patch("_nebari.config_set.__version__", "2024.12.2"):
with pytest.raises(ValueError) as exc_info:
read_config_set(str(config_set_filepath))
assert "Current Nebari version" in str(exc_info.value)


if __name__ == "__main__":
pytest.main()
10 changes: 10 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider):
result_config_dict = config.model_dump()
assert provider in result_config_dict
assert result_config_dict[provider]["kube_context"] == "some_context"


def test_provider_config_mismatch_warning(config_schema):
config_dict = {
"project_name": "test",
"provider": "local",
"existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider
}
with pytest.warns(UserWarning, match="configuration defined for other providers"):
config_schema(**config_dict)
1 change: 1 addition & 0 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change(
mock_model_fields, mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy(deep=True)
old_config.local = None
old_config.provider = schema.ProviderEnum.gcp
mock_get_state.return_value = old_config.model_dump()

Expand Down
74 changes: 73 additions & 1 deletion tests/tests_unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion
from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge


@pytest.mark.parametrize(
Expand Down Expand Up @@ -64,3 +64,75 @@ def test_JsonDiff_modified():
diff = JsonDiff(obj1, obj2)
modifieds = diff.modified()
assert sorted(modifieds) == sorted([(["b", "!"], 2, 3), (["+"], 4, 5)])


def test_deep_merge_order_preservation_dict():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
"e": {"f": {"g": {}}},
"m": 1,
}

value_2 = {
"a": [3, 4],
"b": {"d": 2, "z": [7]},
"e": {"f": {"h": 1}},
"m": [1],
}

expected_result = {
"a": [1, 2, 3, 4],
"b": {"c": 1, "z": [5, 6, 7], "d": 2},
"e": {"f": {"g": {}, "h": 1}},
"m": 1,
}

result = deep_merge(value_1, value_2)
assert result == expected_result
assert list(result.keys()) == list(expected_result.keys())
assert list(result["b"].keys()) == list(expected_result["b"].keys())
assert list(result["e"]["f"].keys()) == list(expected_result["e"]["f"].keys())


def test_deep_merge_order_preservation_list():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
}

value_2 = {
"a": [3, 4],
"b": {"d": 2, "z": [7]},
}

expected_result = {
"a": [1, 2, 3, 4],
"b": {"c": 1, "z": [5, 6, 7], "d": 2},
}

result = deep_merge(value_1, value_2)
assert result == expected_result
assert result["a"] == expected_result["a"]
assert result["b"]["z"] == expected_result["b"]["z"]


def test_deep_merge_single_dict():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
}

expected_result = value_1

result = deep_merge(value_1)
assert result == expected_result
assert list(result.keys()) == list(expected_result.keys())
assert list(result["b"].keys()) == list(expected_result["b"].keys())


def test_deep_merge_empty():
expected_result = {}

result = deep_merge()
assert result == expected_result
Loading