Skip to content

Commit

Permalink
feat(GoodConf): ability to generate TOML configuration for complex types
Browse files Browse the repository at this point in the history
In function generate_toml recursively traverse the various input fields, generating TOML list and tables to contain nested classes and list of other elements, until a base type is found.

lincolnloop#32
  • Loading branch information
mion00 committed Mar 4, 2023
1 parent 0040cbe commit 9fafbcd
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 13 deletions.
89 changes: 77 additions & 12 deletions goodconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from io import StringIO
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, cast

from pydantic import BaseSettings, PrivateAttr
from pydantic import BaseModel, BaseSettings, PrivateAttr
from pydantic.env_settings import SettingsSourceCallable
from pydantic.fields import Field, FieldInfo, ModelField, Undefined # noqa
from pydantic.fields import ModelField, Undefined

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,18 +53,32 @@ def _find_file(filename: str, require: bool = True) -> Optional[str]:

def initial_for_field(name: str, field: ModelField) -> Any:
info = field.field_info
initial = "" # Default value
try:
if not callable(info.extra["initial"]):
raise ValueError(f"Initial value for `{name}` must be a callable.")
return info.extra["initial"]()
initial = info.extra["initial"]()
except KeyError:
if info.default is not Undefined and info.default is not ...:
return info.default
initial = info.default
if info.default_factory is not None:
return info.default_factory()
initial = info.default_factory()

# If initial is a BaseModel generate the dictionary representation using pydantic
# built-in method
if isinstance(initial, BaseModel):
return initial.dict()
# If initial is a list, concatenate the result in an output list
elif isinstance(initial, list):
# If it contains a list of BaseModel, invoke dict on each of them
if any(isinstance(element, BaseModel) for element in initial):
return [element.dict() for element in initial]
else:
# If they are basic types, simply concatenate them
return [inner for inner in initial]
if field.allow_none:
return None
return ""
return initial


def file_config_settings_source(settings: BaseSettings) -> Dict[str, Any]:
Expand Down Expand Up @@ -144,7 +158,7 @@ def load(self, filename: Optional[str] = None) -> None:
super().__init__()

@classmethod
def get_initial(cls, **override) -> dict:
def get_initial(cls, **override) -> dict[str, Any]:
return {
k: override.get(k, initial_for_field(k, v))
for k, v in cls.__fields__.items()
Expand Down Expand Up @@ -199,11 +213,62 @@ def generate_toml(cls, **override) -> str:
document = tomlkit.document()
if cls.__doc__:
document.add(tomlkit.comment(cls.__doc__))
for k, v in dict_from_toml.unwrap().items():
document.add(k, v)
if cls.__fields__[k].field_info.description:
description = cast(str, cls.__fields__[k].field_info.description)
cast(Item, document[k]).comment(description)

def create_item(field: ModelField, initial_value: Any) -> Item:
"""Recursively traverse the input field,
building the appropriate TOML Item while descending the hierarchy.
Stop when find a basic type is encountered, created as a basic TOML Item"""
# Check to see if the initial_value is a complex type
if isinstance(initial_value, dict):
# If this field contains sub-fields inside,
# create them inside a TOML table
table = tomlkit.table()
# Invoke recursively on each subfield
for name, field in field.type_.__fields__.items():
item = create_item(field, initial_value[name])
# Add the item to the table
table[name] = item
return table
# Che if the initial_value is a list of object
elif isinstance(initial_value, list):
# Check to see if the list of sub-fields contains any complex type.
# In that case, an array of table (aot) is required
if getattr(field, "sub_fields") and any(
sub_field.is_complex() for sub_field in field.sub_fields
):
array = tomlkit.aot()
else:
# The sub-fields are basic types
array = tomlkit.array()

for index, _ in enumerate(initial_value):
# Invoke recursively on each element
if getattr(field, "sub_fields"):
# We have a complex type in the sub_fields
item = create_item(field.sub_fields[0], initial_value[index])
else:
# We have a simple type
item = create_item(field, initial_value[index])
# Append each item to the array
array.append(item)

return array
# Base of the recursion: the initial_value is a simple type
else:
# Create a base TOML item
item = tomlkit.item(initial_value)

# Add description to the item, if present
if field.field_info.description:
description = cast(str, field.field_info.description)
item.comment(description)

return item

for k, initial_value in dict_from_toml.unwrap().items():
item = create_item(cls.__fields__[k], initial_value)
document.add(k, item)

return tomlkit.dumps(document)

@classmethod
Expand Down
30 changes: 29 additions & 1 deletion tests/test_goodconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

import pytest
from pydantic import Field, ValidationError
from pydantic import BaseModel, Field, ValidationError

from goodconf import GoodConf
from tests.utils import env_var
Expand Down Expand Up @@ -54,6 +54,34 @@ class TestConf(GoodConf):
assert 'b = ""' in output


def test_dump_complex_toml():
"""Dump a complex configuration class, with inner classes and lists"""
pytest.importorskip("tomlkit")
import tomlkit

class TestConf(GoodConf):
class A(BaseModel):
inner: bool = False
index: int

outer = A(index=0)
simple_list: list[int] = [1, 2]
complex_list: list[A] = [A(index=0)]

output = TestConf.generate_toml()
assert "[outer]" in output
assert "inner = false" in output

# Check that generated toml is valid
doc = tomlkit.parse(output)
assert doc["outer"]["inner"] is False

# Check the lists
assert len(doc["simple_list"]) == 2
assert doc["simple_list"][0] == 1
assert doc["complex_list"][0]["index"] == 0


def test_dump_yaml():
pytest.importorskip("ruamel.yaml")

Expand Down
38 changes: 38 additions & 0 deletions tests/test_initial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import pytest
from pydantic import BaseModel

from goodconf import Field, GoodConf, initial_for_field

Expand Down Expand Up @@ -59,3 +60,40 @@ class G(GoodConf):

initial = G().get_initial()
assert initial["a"] is None


def test_complex_initial():
"""Test a nested inner BaseModel"""

class G(GoodConf):
class A(BaseModel):
inner: str = "test A"

outer_a = A()

initial = G().get_initial()
assert initial["outer_a"]["inner"] == "test A"


def test_list_initial():
"""Test a list of basic types"""

class G(GoodConf):
list = [0, 1, 2]

initial = G().get_initial()
assert len(initial["list"]) == 3


def test_list_complex_initial():
"""Test a list of nested inner BaseModel"""

class G(GoodConf):
class A(BaseModel):
inner: str = "test A"

list = [A()]

initial = G().get_initial()
assert len(initial["list"]) == 1
assert initial["list"][0]["inner"] == "test A"

0 comments on commit 9fafbcd

Please sign in to comment.