From 72bdfb32710022eec420f4919f3b194f46a259ef Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 12:52:27 -0600 Subject: [PATCH 01/15] save progress --- src/_nebari/config_set.py | 14 ++++++++++++++ src/_nebari/initialize.py | 7 +++++++ src/_nebari/subcommands/init.py | 9 +++++++++ 3 files changed, 30 insertions(+) create mode 100644 src/_nebari/config_set.py diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py new file mode 100644 index 000000000..2ff2ef0b2 --- /dev/null +++ b/src/_nebari/config_set.py @@ -0,0 +1,14 @@ +import pathlib + +import yaml + + +def read_config_set(config_set_filepath: str): + """Read a config set from a config file.""" + + filename = pathlib.Path(config_set_filepath) + + with filename.open() as f: + config_dict = yaml.load(f) + + return config_dict diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 4b41f2c5a..ba624cab1 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -9,6 +9,7 @@ import requests from _nebari import constants +from _nebari.config_set import read_config_set from _nebari.provider import git from _nebari.provider.cicd import github from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud @@ -47,6 +48,7 @@ def render_config( region: str = None, disable_prompt: bool = False, ssl_cert_email: str = None, + config_set: str = None, ) -> Dict[str, Any]: config = { "provider": cloud_provider, @@ -176,6 +178,11 @@ def render_config( config["certificate"] = {"type": CertificateEnum.letsencrypt.value} config["certificate"]["acme_email"] = ssl_cert_email + if config_set: + # read the config set, validate, merge/clobber with existing config + config_set_config = read_config_set(config_set) + config.update(config_set_config) + # validate configuration and convert to model from nebari.plugins import nebari_plugin_manager diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py index e794841ea..c2f8d416e 100644 --- a/src/_nebari/subcommands/init.py +++ b/src/_nebari/subcommands/init.py @@ -93,6 +93,7 @@ class InitInputs(schema.Base): region: Optional[str] = None ssl_cert_email: Optional[schema.email_pydantic] = None disable_prompt: bool = False + config_set: Optional[str] = None output: pathlib.Path = pathlib.Path("nebari-config.yaml") explicit: int = 0 @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel): terraform_state=inputs.terraform_state, ssl_cert_email=inputs.ssl_cert_email, disable_prompt=inputs.disable_prompt, + config_set=inputs.config_set, ) try: @@ -496,6 +498,12 @@ def init( False, is_eager=True, ), + config_set: str = typer.Option( + None, + "--config-set", + "-s", + help="Apply a pre-defined set of nebari configuration options.", + ), output: str = typer.Option( pathlib.Path("nebari-config.yaml"), "--output", @@ -554,6 +562,7 @@ def init( inputs.terraform_state = terraform_state inputs.ssl_cert_email = ssl_cert_email inputs.disable_prompt = disable_prompt + inputs.config_set = config_set inputs.output = output inputs.explicit = explicit From 17e1fc468f3242cd88fee9d12d44db5edad03bb6 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:49:25 -0600 Subject: [PATCH 02/15] add metadata to configset --- src/_nebari/config_set.py | 22 ++++++++++++++++++++-- src/_nebari/initialize.py | 6 +++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py index 2ff2ef0b2..eaa7b0181 100644 --- a/src/_nebari/config_set.py +++ b/src/_nebari/config_set.py @@ -1,6 +1,22 @@ import pathlib import yaml +from pydantic import BaseModel, ConfigDict +from pytest import Config + + +class ConfigSetMetadata(BaseModel): + model_config: ConfigDict = ConfigDict( + extra=Config.extra.allow, + ) + name: str = None + description: str = None + nebari_version: str = None + + +class ConfigSet(BaseModel): + metadata: ConfigSetMetadata + config: dict def read_config_set(config_set_filepath: str): @@ -9,6 +25,8 @@ def read_config_set(config_set_filepath: str): filename = pathlib.Path(config_set_filepath) with filename.open() as f: - config_dict = yaml.load(f) + config_set = yaml.load(f) + + # TODO: Validation e.g. Check Nebari version - return config_dict + return config_set diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index ba624cab1..f00a6cb0c 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -8,7 +8,7 @@ import pydantic import requests -from _nebari import constants +from _nebari import constants, utils from _nebari.config_set import read_config_set from _nebari.provider import git from _nebari.provider.cicd import github @@ -180,8 +180,8 @@ def render_config( if config_set: # read the config set, validate, merge/clobber with existing config - config_set_config = read_config_set(config_set) - config.update(config_set_config) + config_set = read_config_set(config_set) + config = utils.deep_merge(config_set.config, config) # validate configuration and convert to model from nebari.plugins import nebari_plugin_manager From 36cddc6a823539579f055c84134172c441b571d3 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:52:35 -0600 Subject: [PATCH 03/15] updates --- src/_nebari/config_set.py | 5 +++-- src/_nebari/initialize.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py index eaa7b0181..9e725f57a 100644 --- a/src/_nebari/config_set.py +++ b/src/_nebari/config_set.py @@ -9,7 +9,7 @@ class ConfigSetMetadata(BaseModel): model_config: ConfigDict = ConfigDict( extra=Config.extra.allow, ) - name: str = None + name: str = None # for use with guided init description: str = None nebari_version: str = None @@ -25,8 +25,9 @@ def read_config_set(config_set_filepath: str): filename = pathlib.Path(config_set_filepath) with filename.open() as f: - config_set = yaml.load(f) + config_set_yaml = yaml.load(f) # TODO: Validation e.g. Check Nebari version + config_set = ConfigSet(**config_set_yaml) return config_set diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index f00a6cb0c..3df598bc6 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -179,7 +179,6 @@ def render_config( config["certificate"]["acme_email"] = ssl_cert_email if config_set: - # read the config set, validate, merge/clobber with existing config config_set = read_config_set(config_set) config = utils.deep_merge(config_set.config, config) From 82322a289017656927b4e1ee7f165f346fbcad1a Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:55:00 -0600 Subject: [PATCH 04/15] update --- src/_nebari/config_set.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py index 9e725f57a..7f9071485 100644 --- a/src/_nebari/config_set.py +++ b/src/_nebari/config_set.py @@ -2,12 +2,11 @@ import yaml from pydantic import BaseModel, ConfigDict -from pytest import Config class ConfigSetMetadata(BaseModel): model_config: ConfigDict = ConfigDict( - extra=Config.extra.allow, + extra="allow", ) name: str = None # for use with guided init description: str = None From b3e83feefea9b2b6a5c064341ad7d35b03cae19f Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:45:09 -0600 Subject: [PATCH 05/15] update --- src/_nebari/config_set.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py index 7f9071485..a9c5fff2a 100644 --- a/src/_nebari/config_set.py +++ b/src/_nebari/config_set.py @@ -1,8 +1,9 @@ import pathlib -import yaml from pydantic import BaseModel, ConfigDict +from _nebari.utils import yaml + class ConfigSetMetadata(BaseModel): model_config: ConfigDict = ConfigDict( From 4cc6546f6a18647fc8c210f9cafcdc4f59ea7db1 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:26:10 -0600 Subject: [PATCH 06/15] update --- src/_nebari/initialize.py | 4 +-- src/_nebari/stages/infrastructure/__init__.py | 29 ++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 3df598bc6..7566fe7b4 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -180,7 +180,7 @@ def render_config( if config_set: config_set = read_config_set(config_set) - config = utils.deep_merge(config_set.config, config) + config = utils.deep_merge(config, config_set.config) # validate configuration and convert to model from nebari.plugins import nebari_plugin_manager @@ -188,7 +188,7 @@ def render_config( try: config_model = nebari_plugin_manager.config_schema.model_validate(config) except pydantic.ValidationError as e: - print(str(e)) + raise e if repository_auto_provision: match = re.search(github_url_regex, repository) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 243abd160..703493373 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -618,19 +618,22 @@ def check_provider(cls, data: Any) -> Any: raise ValueError( f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure" ) - else: - set_providers = [ - provider - for provider in provider_name_abbreviation_map.keys() - if provider in data - ] - num_providers = len(set_providers) - if num_providers > 1: - raise ValueError(f"Multiple providers set: {set_providers}") - elif num_providers == 1: - data["provider"] = provider_name_abbreviation_map[set_providers[0]] - elif num_providers == 0: - data["provider"] = schema.ProviderEnum.local.value + + set_providers = [ + provider + for provider in provider_name_abbreviation_map.keys() + if provider in data + ] + num_providers = len(set_providers) + if num_providers > 1: + raise ValueError( + f"Only a single provider may be set. Multiple providers are set: {set_providers}" + ) + elif num_providers == 1: + data["provider"] = provider_name_abbreviation_map[set_providers[0]] + elif num_providers == 0: + data["provider"] = schema.ProviderEnum.local.value + return data From 917c666e571d690500d93e000922d345b9204996 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:26:23 -0600 Subject: [PATCH 07/15] add test --- tests/tests_unit/test_schema.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index 5c21aef8d..029cdc25e 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -122,13 +122,25 @@ def test_no_provider(config_schema, provider, full_name, default_fields): assert full_name in config.model_dump() -def test_multiple_providers(config_schema): +@pytest.mark.parametrize( + "providers", + [ + { + "local": {}, + "existing": {}, + }, + { + "local": {}, + "google_cloud_platform": {}, + }, + ], +) +def test_multiple_providers(config_schema, providers): config_dict = { "project_name": "test", - "local": {}, - "existing": {}, + **providers, } - msg = r"Multiple providers set: \['local', 'existing'\]" + msg = r"Only a single provider may be set. Multiple providers are set: " with pytest.raises(ValidationError, match=msg): config_schema(**config_dict) From 8fea4dba315c6159626be8402d0fa1d0e27e5230 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:55:22 -0600 Subject: [PATCH 08/15] update --- tests/tests_unit/test_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index 029cdc25e..b4147e9cd 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -130,7 +130,7 @@ def test_no_provider(config_schema, provider, full_name, default_fields): "existing": {}, }, { - "local": {}, + "provider": "local", "google_cloud_platform": {}, }, ], From 6ebdbdb0d36628d15ea4f0e9ac210053fd5793cd Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:55:11 -0600 Subject: [PATCH 09/15] preserve order --- src/_nebari/utils.py | 4 +- tests/tests_unit/test_utils.py | 74 +++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/src/_nebari/utils.py b/src/_nebari/utils.py index f3d62f353..48b8a91e9 100644 --- a/src/_nebari/utils.py +++ b/src/_nebari/utils.py @@ -160,7 +160,7 @@ def modified_environ(*remove: List[str], **update: Dict[str, str]): def deep_merge(*args): - """Deep merge multiple dictionaries. + """Deep merge multiple dictionaries. Preserves order in dicts and lists. >>> value_1 = { 'a': [1, 2], @@ -190,7 +190,7 @@ def deep_merge(*args): if isinstance(d1, dict) and isinstance(d2, dict): d3 = {} - for key in d1.keys() | d2.keys(): + for key in tuple(d1.keys()) + tuple(d2.keys()): if key in d1 and key in d2: d3[key] = deep_merge(d1[key], d2[key]) elif key in d1: diff --git a/tests/tests_unit/test_utils.py b/tests/tests_unit/test_utils.py index 678cd1f23..88b911ff6 100644 --- a/tests/tests_unit/test_utils.py +++ b/tests/tests_unit/test_utils.py @@ -1,6 +1,6 @@ import pytest -from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion +from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge @pytest.mark.parametrize( @@ -64,3 +64,75 @@ def test_JsonDiff_modified(): diff = JsonDiff(obj1, obj2) modifieds = diff.modified() assert sorted(modifieds) == sorted([(["b", "!"], 2, 3), (["+"], 4, 5)]) + + +def test_deep_merge_order_preservation_dict(): + value_1 = { + "a": [1, 2], + "b": {"c": 1, "z": [5, 6]}, + "e": {"f": {"g": {}}}, + "m": 1, + } + + value_2 = { + "a": [3, 4], + "b": {"d": 2, "z": [7]}, + "e": {"f": {"h": 1}}, + "m": [1], + } + + expected_result = { + "a": [1, 2, 3, 4], + "b": {"c": 1, "z": [5, 6, 7], "d": 2}, + "e": {"f": {"g": {}, "h": 1}}, + "m": 1, + } + + result = deep_merge(value_1, value_2) + assert result == expected_result + assert list(result.keys()) == list(expected_result.keys()) + assert list(result["b"].keys()) == list(expected_result["b"].keys()) + assert list(result["e"]["f"].keys()) == list(expected_result["e"]["f"].keys()) + + +def test_deep_merge_order_preservation_list(): + value_1 = { + "a": [1, 2], + "b": {"c": 1, "z": [5, 6]}, + } + + value_2 = { + "a": [3, 4], + "b": {"d": 2, "z": [7]}, + } + + expected_result = { + "a": [1, 2, 3, 4], + "b": {"c": 1, "z": [5, 6, 7], "d": 2}, + } + + result = deep_merge(value_1, value_2) + assert result == expected_result + assert result["a"] == expected_result["a"] + assert result["b"]["z"] == expected_result["b"]["z"] + + +def test_deep_merge_single_dict(): + value_1 = { + "a": [1, 2], + "b": {"c": 1, "z": [5, 6]}, + } + + expected_result = value_1 + + result = deep_merge(value_1) + assert result == expected_result + assert list(result.keys()) == list(expected_result.keys()) + assert list(result["b"].keys()) == list(expected_result["b"].keys()) + + +def test_deep_merge_empty(): + expected_result = {} + + result = deep_merge() + assert result == expected_result From 5c86415507d3599e15a4fe52a802fd8dfc834e9f Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 15:10:46 -0600 Subject: [PATCH 10/15] allow explicit to still work --- src/_nebari/stages/infrastructure/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 703493373..fce5bd60b 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -622,7 +622,7 @@ def check_provider(cls, data: Any) -> Any: set_providers = [ provider for provider in provider_name_abbreviation_map.keys() - if provider in data + if provider in data and data[provider] ] num_providers = len(set_providers) if num_providers > 1: From b6ffec8f3c0567be991f5a722a18df69fab39ade Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:43:52 -0600 Subject: [PATCH 11/15] add version checking --- src/_nebari/config_set.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py index a9c5fff2a..7c4b94cd9 100644 --- a/src/_nebari/config_set.py +++ b/src/_nebari/config_set.py @@ -1,17 +1,33 @@ import pathlib -from pydantic import BaseModel, ConfigDict +from packaging.requirements import Requirement +from pydantic import BaseModel, ConfigDict, field_validator +from _nebari._version import __version__ from _nebari.utils import yaml class ConfigSetMetadata(BaseModel): - model_config: ConfigDict = ConfigDict( - extra="allow", - ) - name: str = None # for use with guided init + model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True) + name: str # for use with guided init description: str = None - nebari_version: str = None + nebari_version: str | Requirement + + @field_validator("nebari_version") + @classmethod + def validate_version_requirement(cls, version_req): + if isinstance(version_req, str): + version_req = Requirement(f"nebari{version_req}") + version_req.specifier.prereleases = True + + return version_req + + def check_version(self, version): + if version not in self.nebari_version.specifier: + raise ValueError( + f"Current Nebari version {__version__} is not compatible with " + f'required version {self.nebari_version.specifier} for "{self.name}" config set.' + ) class ConfigSet(BaseModel): @@ -27,7 +43,9 @@ def read_config_set(config_set_filepath: str): with filename.open() as f: config_set_yaml = yaml.load(f) - # TODO: Validation e.g. Check Nebari version config_set = ConfigSet(**config_set_yaml) + # validation + config_set.metadata.check_version(__version__) + return config_set From 237f73b2508ba4db2c0d8562d001178754a011e4 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:04:10 -0600 Subject: [PATCH 12/15] add tests --- tests/tests_unit/test_config_set.py | 73 +++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/tests_unit/test_config_set.py diff --git a/tests/tests_unit/test_config_set.py b/tests/tests_unit/test_config_set.py new file mode 100644 index 000000000..e14200994 --- /dev/null +++ b/tests/tests_unit/test_config_set.py @@ -0,0 +1,73 @@ +from unittest.mock import patch + +import pytest +from packaging.requirements import Requirement + +from _nebari.config_set import ConfigSetMetadata, read_config_set + +test_version = "2024.12.2" + + +def test_valid_version_requirement(): + metadata = ConfigSetMetadata( + name="test-config", nebari_version=">=2024.12.0,<2025.0.0" + ) + assert metadata.nebari_version.specifier.contains(test_version) + + +def test_invalid_version_requirement(): + with pytest.raises(ValueError) as exc_info: + csm = ConfigSetMetadata(name="test-config", nebari_version=">=2025.0.0") + csm.check_version(test_version) + assert "Current Nebari version" in str(exc_info.value) + + +def test_valid_version_requirement_with_requirement_object(): + requirement = Requirement("nebari>=2024.12.0") + metadata = ConfigSetMetadata(name="test-config", nebari_version=requirement) + assert metadata.nebari_version.specifier.contains(test_version) + + +def test_invalid_version_requirement_with_requirement_object(): + requirement = Requirement("nebari>=2025.0.0") + with pytest.raises(ValueError) as exc_info: + csm = ConfigSetMetadata(name="test-config", nebari_version=requirement) + csm.check_version(test_version) + assert "Current Nebari version" in str(exc_info.value) + + +def test_read_config_set_valid(tmp_path): + config_set_yaml = """ + metadata: + name: test-config + nebari_version: ">=2024.12.0" + config: + key: value + """ + config_set_filepath = tmp_path / "config_set.yaml" + config_set_filepath.write_text(config_set_yaml) + with patch("_nebari.config_set.__version__", "2024.12.2"): + config_set = read_config_set(str(config_set_filepath)) + assert config_set.metadata.name == "test-config" + assert config_set.config["key"] == "value" + + +def test_read_config_set_invalid_version(tmp_path): + config_set_yaml = """ + metadata: + name: test-config + nebari_version: ">=2025.0.0" + config: + key: value + """ + config_set_filepath = tmp_path / "config_set.yaml" + config_set_filepath.write_text(config_set_yaml) + + with patch("_nebari.config_set.__version__", "2024.12.2"): + with pytest.raises(ValueError) as exc_info: + read_config_set(str(config_set_filepath)) + assert "Current Nebari version" in str(exc_info.value) + + +if __name__ == "__main__": + pytest.main() From 40d71d64cfbe26264601903138f5da5e9f99108d Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:40:24 -0600 Subject: [PATCH 13/15] revert some behavior --- src/_nebari/stages/infrastructure/__init__.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index fce5bd60b..a5894687a 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -618,21 +618,19 @@ def check_provider(cls, data: Any) -> Any: raise ValueError( f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure" ) - - set_providers = [ - provider - for provider in provider_name_abbreviation_map.keys() - if provider in data and data[provider] - ] - num_providers = len(set_providers) - if num_providers > 1: - raise ValueError( - f"Only a single provider may be set. Multiple providers are set: {set_providers}" - ) - elif num_providers == 1: - data["provider"] = provider_name_abbreviation_map[set_providers[0]] - elif num_providers == 0: - data["provider"] = schema.ProviderEnum.local.value + else: + set_providers = [ + provider + for provider in provider_name_abbreviation_map.keys() + if provider in data + ] + num_providers = len(set_providers) + if num_providers > 1: + raise ValueError(f"Multiple providers set: {set_providers}") + elif num_providers == 1: + data["provider"] = provider_name_abbreviation_map[set_providers[0]] + elif num_providers == 0: + data["provider"] = schema.ProviderEnum.local.value return data From 2682982943370bcf7263a4a4ec9ab4cacb119112 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:46:40 -0600 Subject: [PATCH 14/15] revert prior test changes --- tests/tests_unit/test_schema.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index b4147e9cd..5c21aef8d 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -122,25 +122,13 @@ def test_no_provider(config_schema, provider, full_name, default_fields): assert full_name in config.model_dump() -@pytest.mark.parametrize( - "providers", - [ - { - "local": {}, - "existing": {}, - }, - { - "provider": "local", - "google_cloud_platform": {}, - }, - ], -) -def test_multiple_providers(config_schema, providers): +def test_multiple_providers(config_schema): config_dict = { "project_name": "test", - **providers, + "local": {}, + "existing": {}, } - msg = r"Only a single provider may be set. Multiple providers are set: " + msg = r"Multiple providers set: \['local', 'existing'\]" with pytest.raises(ValidationError, match=msg): config_schema(**config_dict) From 68b506477803cdecd3aecc0859d3c6e7f7553a36 Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:46:06 -0600 Subject: [PATCH 15/15] raise warning if extra provider config given --- src/_nebari/stages/infrastructure/__init__.py | 15 ++++++++++++++- tests/tests_unit/test_schema.py | 10 ++++++++++ tests/tests_unit/test_stages.py | 1 + 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index a5894687a..b75412bd6 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -6,6 +6,7 @@ import re import sys import tempfile +import warnings from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union from pydantic import ConfigDict, Field, field_validator, model_validator @@ -613,11 +614,23 @@ def check_provider(cls, data: Any) -> Any: data[provider] = provider_enum_model_map[provider]() else: # if the provider field is invalid, it won't be set when this validator is called - # so we need to check for it explicitly here, and set the `pre` to True + # so we need to check for it explicitly here, and set mode to "before" # TODO: this is a workaround, check if there is a better way to do this in Pydantic v2 raise ValueError( f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure" ) + set_providers = { + provider + for provider in provider_name_abbreviation_map.keys() + if provider in data and data[provider] + } + expected_provider_config = provider_enum_name_map[provider] + extra_provider_config = set_providers - {expected_provider_config} + if extra_provider_config: + warnings.warn( + f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}" + ) + else: set_providers = [ provider diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index 5c21aef8d..e445ba37d 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider): result_config_dict = config.model_dump() assert provider in result_config_dict assert result_config_dict[provider]["kube_context"] == "some_context" + + +def test_provider_config_mismatch_warning(config_schema): + config_dict = { + "project_name": "test", + "provider": "local", + "existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider + } + with pytest.warns(UserWarning, match="configuration defined for other providers"): + config_schema(**config_dict) diff --git a/tests/tests_unit/test_stages.py b/tests/tests_unit/test_stages.py index c716d9303..c15aa6d9f 100644 --- a/tests/tests_unit/test_stages.py +++ b/tests/tests_unit/test_stages.py @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change( mock_model_fields, mock_get_state, terraform_state_stage, mock_config ): old_config = mock_config.model_copy(deep=True) + old_config.local = None old_config.provider = schema.ProviderEnum.gcp mock_get_state.return_value = old_config.model_dump()