From 81ae461eb84d1f1128580647aad1a73d1f4d68fd Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 16 Sep 2024 12:19:58 -0400 Subject: [PATCH] Config: Allow existing values to get included in generated file Allows for generation from an existing config file. Primarily used for migration purposes. Signed-off-by: kingbri --- common/args.py | 4 ++-- common/config_models.py | 17 +++++++++++++---- common/utils.py | 7 +++++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/common/args.py b/common/args.py index 7c09646..3863580 100644 --- a/common/args.py +++ b/common/args.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from common.config_models import TabbyConfigModel -from common.utils import is_list_type, unwrap_optional +from common.utils import is_list_type, unwrap_optional_type def add_field_to_group(group, field_name, field_type, field) -> None: @@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser: # Loop through each top-level field in the config for field_name, field_info in TabbyConfigModel.model_fields.items(): - field_type = unwrap_optional(field_info.annotation) + field_type = unwrap_optional_type(field_info.annotation) group = parser.add_argument_group( field_name, description=f"Arguments for {field_name}" ) diff --git a/common/config_models.py b/common/config_models.py index 653280b..52f4f43 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -1,10 +1,11 @@ from inspect import getdoc from pathlib import Path from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic_core import PydanticUndefined from textwrap import dedent from typing import List, Literal, Optional, Union -from pydantic_core import PydanticUndefined +from common.utils import unwrap CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] @@ -488,12 +489,17 @@ def generate_config_file( # You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n """) - schema = model if model else TabbyConfigModel() + schema = unwrap(model, TabbyConfigModel()) # TODO: Make the disordered iteration look cleaner iter_once = False for field, field_data in schema.model_fields.items(): - subfield_model = field_data.default_factory() + # Fetch from the existing model class if it's passed + # Probably can use this on schema too, but play it safe + if model: + subfield_model = getattr(model, field, None) + else: + subfield_model = field_data.default_factory() if not subfield_model._metadata.include_in_config: continue @@ -519,7 +525,10 @@ def generate_config_file( else: sub_iter_once = True - if subfield_data.default_factory: + # If a value already exists, use it + if hasattr(subfield_model, subfield): + value = getattr(subfield_model, subfield) + elif subfield_data.default_factory: value = subfield_data.default_factory() else: value = subfield_data.default diff --git a/common/utils.py b/common/utils.py index acc0fc9..77958ce 100644 --- a/common/utils.py +++ b/common/utils.py @@ -62,8 +62,11 @@ def is_list_type(type_hint) -> bool: return False -def unwrap_optional(type_hint) -> Type: - """unwrap Optional[type] annotations""" +def unwrap_optional_type(type_hint) -> Type: + """ + Unwrap Optional[type] annotations. + This is not the same as unwrap. + """ if get_origin(type_hint) is Union: args = get_args(type_hint)