Skip to content

Commit

Permalink
Config: Allow existing values to get included in generated file
Browse files Browse the repository at this point in the history
Allows for generation from an existing config file. Primarily used
for migration purposes.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Sep 16, 2024
1 parent 7f03003 commit 81ae461
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
4 changes: 2 additions & 2 deletions common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel

from common.config_models import TabbyConfigModel
from common.utils import is_list_type, unwrap_optional
from common.utils import is_list_type, unwrap_optional_type


def add_field_to_group(group, field_name, field_type, field) -> None:
Expand Down Expand Up @@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:

# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional(field_info.annotation)
field_type = unwrap_optional_type(field_info.annotation)
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)
Expand Down
17 changes: 13 additions & 4 deletions common/config_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from inspect import getdoc
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic_core import PydanticUndefined
from textwrap import dedent
from typing import List, Literal, Optional, Union

from pydantic_core import PydanticUndefined
from common.utils import unwrap

CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]

Expand Down Expand Up @@ -488,12 +489,17 @@ def generate_config_file(
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
""")

schema = model if model else TabbyConfigModel()
schema = unwrap(model, TabbyConfigModel())

# TODO: Make the disordered iteration look cleaner
iter_once = False
for field, field_data in schema.model_fields.items():
subfield_model = field_data.default_factory()
# Fetch from the existing model class if it's passed
# Probably can use this on schema too, but play it safe
if model:
subfield_model = getattr(model, field, None)
else:
subfield_model = field_data.default_factory()

if not subfield_model._metadata.include_in_config:
continue
Expand All @@ -519,7 +525,10 @@ def generate_config_file(
else:
sub_iter_once = True

if subfield_data.default_factory:
# If a value already exists, use it
if hasattr(subfield_model, subfield):
value = getattr(subfield_model, subfield)
elif subfield_data.default_factory:
value = subfield_data.default_factory()
else:
value = subfield_data.default
Expand Down
7 changes: 5 additions & 2 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ def is_list_type(type_hint) -> bool:
return False


def unwrap_optional(type_hint) -> Type:
"""unwrap Optional[type] annotations"""
def unwrap_optional_type(type_hint) -> Type:
"""
Unwrap Optional[type] annotations.
This is not the same as unwrap.
"""

if get_origin(type_hint) is Union:
args = get_args(type_hint)
Expand Down

0 comments on commit 81ae461

Please sign in to comment.