Skip to content

Commit

Permalink
Upgrade to Pydantic V2 (#2348)
Browse files Browse the repository at this point in the history
Co-authored-by: Fangchen Li <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Amit Kumar <[email protected]>
  • Loading branch information
4 people authored Apr 22, 2024
1 parent 0c3adc8 commit 85d3a75
Show file tree
Hide file tree
Showing 29 changed files with 484 additions and 506 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ dependencies = [
"kubernetes==27.2.0",
"pluggy==1.3.0",
"prompt-toolkit==3.0.36",
"pydantic==1.10.12",
"pydantic==2.4.2",
"pynacl==1.5.0",
"python-keycloak>=3.9.0",
"questionary==2.0.0",
Expand Down
25 changes: 18 additions & 7 deletions src/_nebari/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@
import pathlib
import re
import sys
import typing
from typing import Any, Dict, List, Union

import pydantic

from _nebari.utils import yaml


def set_nested_attribute(data: typing.Any, attrs: typing.List[str], value: typing.Any):
def set_nested_attribute(data: Any, attrs: List[str], value: Any):
"""Takes an arbitrary set of attributes and accesses the deep
nested object config to set value
"""

def _get_attr(d: typing.Any, attr: str):
def _get_attr(d: Any, attr: str):
if isinstance(d, list) and re.fullmatch(r"\d+", attr):
return d[int(attr)]
elif hasattr(d, "__getitem__"):
return d[attr]
else:
return getattr(d, attr)

def _set_attr(d: typing.Any, attr: str, value: typing.Any):
def _set_attr(d: Any, attr: str, value: Any):
if isinstance(d, list) and re.fullmatch(r"\d+", attr):
d[int(attr)] = value
elif hasattr(d, "__getitem__"):
Expand Down Expand Up @@ -63,6 +63,15 @@ def set_config_from_environment_variables(
return config


def dump_nested_model(model_dict: Dict[str, Union[pydantic.BaseModel, str]]):
result = {}
for key, value in model_dict.items():
result[key] = (
value.model_dump() if isinstance(value, pydantic.BaseModel) else value
)
return result


def read_configuration(
config_filename: pathlib.Path,
config_schema: pydantic.BaseModel,
Expand All @@ -77,7 +86,8 @@ def read_configuration(
)

with filename.open() as f:
config = config_schema(**yaml.load(f.read()))
config_dict = yaml.load(f)
config = config_schema(**config_dict)

if read_environment:
config = set_config_from_environment_variables(config)
Expand All @@ -87,14 +97,15 @@ def read_configuration(

def write_configuration(
config_filename: pathlib.Path,
config: typing.Union[pydantic.BaseModel, typing.Dict],
config: Union[pydantic.BaseModel, Dict],
mode: str = "w",
):
"""Write the nebari configuration file to disk"""
with config_filename.open(mode) as f:
if isinstance(config, pydantic.BaseModel):
yaml.dump(config.dict(), f)
yaml.dump(config.model_dump(), f)
else:
config = dump_nested_model(config)
yaml.dump(config, f)


Expand Down
7 changes: 4 additions & 3 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import tempfile
from pathlib import Path
from typing import Any, Dict

import pydantic
import requests
Expand Down Expand Up @@ -52,7 +53,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
):
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
"namespace": namespace,
Expand Down Expand Up @@ -119,7 +120,7 @@ def render_config(
if cloud_provider == ProviderEnum.do:
do_region = region or constants.DO_DEFAULT_REGION
do_kubernetes_versions = kubernetes_version or get_latest_kubernetes_version(
digital_ocean.kubernetes_versions(do_region)
digital_ocean.kubernetes_versions()
)
config["digital_ocean"] = {
"kubernetes_version": do_kubernetes_versions,
Expand Down Expand Up @@ -200,7 +201,7 @@ def render_config(
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.parse_obj(config)
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))

Expand Down
69 changes: 22 additions & 47 deletions src/_nebari/provider/cicd/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import requests
from nacl import encoding, public
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, RootModel

from _nebari.constants import LATEST_SUPPORTED_PYTHON_VERSION
from _nebari.provider.cicd.common import pip_install_nebari
Expand Down Expand Up @@ -143,49 +143,34 @@ class GHA_on_extras(BaseModel):
paths: List[str]


class GHA_on(BaseModel):
# to allow for dynamic key names
__root__: Dict[str, GHA_on_extras]

# TODO: validate __root__ values
# `push`, `pull_request`, etc.


class GHA_job_steps_extras(BaseModel):
# to allow for dynamic key names
__root__: Union[str, float, int]
GHA_on = RootModel[Dict[str, GHA_on_extras]]
GHA_job_steps_extras = RootModel[Union[str, float, int]]


class GHA_job_step(BaseModel):
name: str
uses: Optional[str]
with_: Optional[Dict[str, GHA_job_steps_extras]] = Field(alias="with")
run: Optional[str]
env: Optional[Dict[str, GHA_job_steps_extras]]

class Config:
allow_population_by_field_name = True
uses: Optional[str] = None
with_: Optional[Dict[str, GHA_job_steps_extras]] = Field(alias="with", default=None)
run: Optional[str] = None
env: Optional[Dict[str, GHA_job_steps_extras]] = None
model_config = ConfigDict(populate_by_name=True)


class GHA_job_id(BaseModel):
name: str
runs_on_: str = Field(alias="runs-on")
permissions: Optional[Dict[str, str]]
permissions: Optional[Dict[str, str]] = None
steps: List[GHA_job_step]

class Config:
allow_population_by_field_name = True
model_config = ConfigDict(populate_by_name=True)


class GHA_jobs(BaseModel):
# to allow for dynamic key names
__root__: Dict[str, GHA_job_id]
GHA_jobs = RootModel[Dict[str, GHA_job_id]]


class GHA(BaseModel):
name: str
on: GHA_on
env: Optional[Dict[str, str]]
env: Optional[Dict[str, str]] = None
jobs: GHA_jobs


Expand All @@ -204,23 +189,15 @@ def checkout_image_step():
return GHA_job_step(
name="Checkout Image",
uses="actions/checkout@v3",
with_={
"token": GHA_job_steps_extras(
__root__="${{ secrets.REPOSITORY_ACCESS_TOKEN }}"
)
},
with_={"token": GHA_job_steps_extras("${{ secrets.REPOSITORY_ACCESS_TOKEN }}")},
)


def setup_python_step():
return GHA_job_step(
name="Set up Python",
uses="actions/setup-python@v4",
with_={
"python-version": GHA_job_steps_extras(
__root__=LATEST_SUPPORTED_PYTHON_VERSION
)
},
with_={"python-version": GHA_job_steps_extras(LATEST_SUPPORTED_PYTHON_VERSION)},
)


Expand All @@ -242,7 +219,7 @@ def gen_nebari_ops(config):
env_vars = gha_env_vars(config)

push = GHA_on_extras(branches=[config.ci_cd.branch], paths=["nebari-config.yaml"])
on = GHA_on(__root__={"push": push})
on = GHA_on({"push": push})

step1 = checkout_image_step()
step2 = setup_python_step()
Expand Down Expand Up @@ -272,7 +249,7 @@ def gen_nebari_ops(config):
),
env={
"COMMIT_MSG": GHA_job_steps_extras(
__root__="nebari-config.yaml automated commit: ${{ github.sha }}"
"nebari-config.yaml automated commit: ${{ github.sha }}"
)
},
)
Expand All @@ -291,7 +268,7 @@ def gen_nebari_ops(config):
},
steps=gha_steps,
)
jobs = GHA_jobs(__root__={"build": job1})
jobs = GHA_jobs({"build": job1})

return NebariOps(
name="nebari auto update",
Expand All @@ -312,18 +289,16 @@ def gen_nebari_linter(config):
pull_request = GHA_on_extras(
branches=[config.ci_cd.branch], paths=["nebari-config.yaml"]
)
on = GHA_on(__root__={"pull_request": pull_request})
on = GHA_on({"pull_request": pull_request})

step1 = checkout_image_step()
step2 = setup_python_step()
step3 = install_nebari_step(config.nebari_version)

step4_envs = {
"PR_NUMBER": GHA_job_steps_extras(__root__="${{ github.event.number }}"),
"REPO_NAME": GHA_job_steps_extras(__root__="${{ github.repository }}"),
"GITHUB_TOKEN": GHA_job_steps_extras(
__root__="${{ secrets.REPOSITORY_ACCESS_TOKEN }}"
),
"PR_NUMBER": GHA_job_steps_extras("${{ github.event.number }}"),
"REPO_NAME": GHA_job_steps_extras("${{ github.repository }}"),
"GITHUB_TOKEN": GHA_job_steps_extras("${{ secrets.REPOSITORY_ACCESS_TOKEN }}"),
}

step4 = GHA_job_step(
Expand All @@ -336,7 +311,7 @@ def gen_nebari_linter(config):
name="nebari", runs_on_="ubuntu-latest", steps=[step1, step2, step3, step4]
)
jobs = GHA_jobs(
__root__={
{
"nebari-validate": job1,
}
)
Expand Down
30 changes: 12 additions & 18 deletions src/_nebari/provider/cicd/gitlab.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,34 @@
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, RootModel

from _nebari.constants import LATEST_SUPPORTED_PYTHON_VERSION
from _nebari.provider.cicd.common import pip_install_nebari


class GLCI_extras(BaseModel):
# to allow for dynamic key names
__root__: Union[str, float, int]
GLCI_extras = RootModel[Union[str, float, int]]


class GLCI_image(BaseModel):
name: str
entrypoint: Optional[str]
entrypoint: Optional[str] = None


class GLCI_rules(BaseModel):
if_: Optional[str] = Field(alias="if")
changes: Optional[List[str]]

class Config:
allow_population_by_field_name = True
changes: Optional[List[str]] = None
model_config = ConfigDict(populate_by_name=True)


class GLCI_job(BaseModel):
image: Optional[Union[str, GLCI_image]]
variables: Optional[Dict[str, str]]
before_script: Optional[List[str]]
after_script: Optional[List[str]]
image: Optional[Union[str, GLCI_image]] = None
variables: Optional[Dict[str, str]] = None
before_script: Optional[List[str]] = None
after_script: Optional[List[str]] = None
script: List[str]
rules: Optional[List[GLCI_rules]]
rules: Optional[List[GLCI_rules]] = None


class GLCI(BaseModel):
__root__: Dict[str, GLCI_job]
GLCI = RootModel[Dict[str, GLCI_job]]


def gen_gitlab_ci(config):
Expand Down Expand Up @@ -76,7 +70,7 @@ def gen_gitlab_ci(config):
)

return GLCI(
__root__={
{
"render-nebari": render_nebari,
}
)
17 changes: 5 additions & 12 deletions src/_nebari/provider/cloud/amazon_web_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,18 @@
import boto3
from botocore.exceptions import ClientError, EndpointConnectionError

from _nebari import constants
from _nebari.constants import AWS_ENV_DOCS
from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
from _nebari.utils import check_environment_variables
from nebari import schema

MAX_RETRIES = 5
DELAY = 5


def check_credentials():
"""Check for AWS credentials are set in the environment."""
for variable in {
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
}:
if variable not in os.environ:
raise ValueError(
f"""Missing the following required environment variable: {variable}\n
Please see the documentation for more information: {constants.AWS_ENV_DOCS}"""
)
def check_credentials() -> None:
required_variables = {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"}
check_environment_variables(required_variables, AWS_ENV_DOCS)


@functools.lru_cache()
Expand Down
Loading

0 comments on commit 85d3a75

Please sign in to comment.