diff --git a/common/args.py b/common/args.py index 42f795a4..c3340afc 100644 --- a/common/args.py +++ b/common/args.py @@ -34,42 +34,61 @@ def argument_with_auto(value): ) from ex +def map_pydantic_type_to_argparse(pydantic_type): + """ + Maps Pydantic types to argparse compatible types. + """ + origin = get_origin(pydantic_type) + + # Handle optional types (Union with NoneType) + if origin is Union: + # Filter out NoneType, leaving the actual type + pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None)) + + # Handle lists (argparse can take nargs for lists) + elif origin is List: + pydantic_type = get_args(pydantic_type)[0] # Get the list item type + + # Map basic types (int, float, str, bool) or default to str + if isinstance(pydantic_type, type) and issubclass(pydantic_type, (int, float, str, bool)): + return pydantic_type + return str # Default to string for unknown types + + +def add_field_to_group(group, field_name, field_type, field): + """ + Adds a Pydantic field to an argparse argument group. + """ + arg_type = map_pydantic_type_to_argparse(field_type) + help_text = field.description if field.description else "No description available" + + group.add_argument(f"--{field_name}", type=arg_type, help=help_text) + + def init_argparser(): + """ + Initializes an argparse parser based on a Pydantic config schema. + """ parser = argparse.ArgumentParser(description="TabbyAPI server") + # Loop through each top-level field in the config for field_name, field_type in config.__annotations__.items(): - group = parser.add_argument_group( - field_name, description=f"Arguments for {field_name}" - ) - - # Loop through each field in the sub-model - for sub_field_name, sub_field_type in field_type.__annotations__.items(): - field = field_type.__fields__[sub_field_name] - help_text = ( - field.description if field.description else "No description available" - ) - - origin = get_origin(sub_field_type) - if origin is Union: - sub_field_type = next( - t for t in get_args(sub_field_type) if t is not type(None) - ) - elif origin is List: - sub_field_type = get_args(sub_field_type)[0] - - # Map Pydantic types to argparse types - if isinstance(sub_field_type, type) and issubclass( - sub_field_type, (int, float, str, bool) - ): - arg_type = sub_field_type - else: - arg_type = str # Default to string for unknown types - - group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text) + group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}") + + # Check if the field_type is a Pydantic model or something with __annotations__ + if hasattr(field_type, "__annotations__"): + # Loop through each sub-field in the model + for sub_field_name, sub_field_type in field_type.__annotations__.items(): + field = field_type.__fields__[sub_field_name] + add_field_to_group(group, sub_field_name, sub_field_type, field) + else: + # Handle cases where the field_type is not a Pydantic model + print(f"{field_name=}") + arg_type = map_pydantic_type_to_argparse(field_type) + group.add_argument(f"--{field_name}", type=arg_type, help=f"Argument for {field_name}") return parser - def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser): """Broad conversion of surface level arg groups to dictionaries""" diff --git a/common/model.py b/common/model.py index ffe0bc0d..56e545b7 100644 --- a/common/model.py +++ b/common/model.py @@ -153,25 +153,6 @@ async def unload_embedding_model(): embeddings_container = None -# FIXME: Maybe make this a one-time function instead of a dynamic default -def get_config_default(key: str, model_type: str = "model"): - """Fetches a default value from model config if allowed by the user.""" - - default_keys = unwrap(config.model.use_as_default, []) - - # Add extra keys to defaults - default_keys.append("embeddings_device") - - if key in default_keys: - # Is this a draft model load parameter? - if model_type == "draft": - return config.draft_model.get(key) - elif model_type == "embedding": - return config.embeddings.get(key) - else: - return config.model.get(key) - - async def check_model_container(): """FastAPI depends that checks if a model isn't loaded or currently loading.""" diff --git a/common/tabby_config.py b/common/tabby_config.py index ffd68b31..5e38c6e9 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -13,6 +13,7 @@ class TabbyConfig(tabby_config_model): """Common config class for TabbyAPI.""" # Persistent defaults + # TODO: make this not a dict model_defaults: dict = {} def load(self, arguments: Optional[dict] = None): @@ -34,12 +35,13 @@ def load(self, arguments: Optional[dict] = None): setattr(self, field, model.parse_obj(value)) # Set model defaults dict once to prevent on-demand reconstruction - default_keys = unwrap(self.model.get("use_as_default"), []) - for key in default_keys: - if key in self.model: - self.model_defaults[key] = config.model[key] - elif key in self.draft_model: - self.model_defaults[key] = config.draft_model[key] + # TODO: refactor this use a pydantic model + default_fields = self.model.use_as_default + for field in default_fields: + if hasattr(self.model, field): + self.model_defaults[field] = getattr(config.model, field) + elif hasattr(self.draft_model, field): + self.model_defaults[field] = getattr(config.draft_model, field) def _from_file(self, config_path: pathlib.Path): """loads config from a given file path""" diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index d996299a..8a994223 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -8,7 +8,6 @@ from common.tabby_config import config from common.utils import unwrap from common.config_models import logging_config_model -from common.model import get_config_default class ModelCardParameters(BaseModel): """Represents model card parameters.""" diff --git a/main.py b/main.py index 196eeb86..6e7943bf 100644 --- a/main.py +++ b/main.py @@ -50,11 +50,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(config.network.disable_auth) - - # Override the generation log options if given - if config.logging: - gen_logging.update_from_dict(config.logging) + await load_auth_keys(config.network.disable_auth) gen_logging.broadcast_status()