Skip to content

Commit

Permalink
fix issues after rebase and ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
SecretiveShell committed Sep 11, 2024
1 parent b2620d1 commit 6d28fdd
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 60 deletions.
77 changes: 48 additions & 29 deletions common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,61 @@ def argument_with_auto(value):
) from ex


def map_pydantic_type_to_argparse(pydantic_type):
"""
Maps Pydantic types to argparse compatible types.
"""
origin = get_origin(pydantic_type)

# Handle optional types (Union with NoneType)
if origin is Union:
# Filter out NoneType, leaving the actual type
pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None))

# Handle lists (argparse can take nargs for lists)
elif origin is List:
pydantic_type = get_args(pydantic_type)[0] # Get the list item type

# Map basic types (int, float, str, bool) or default to str
if isinstance(pydantic_type, type) and issubclass(pydantic_type, (int, float, str, bool)):
return pydantic_type
return str # Default to string for unknown types


def add_field_to_group(group, field_name, field_type, field):
"""
Adds a Pydantic field to an argparse argument group.
"""
arg_type = map_pydantic_type_to_argparse(field_type)
help_text = field.description if field.description else "No description available"

group.add_argument(f"--{field_name}", type=arg_type, help=help_text)


def init_argparser():
"""
Initializes an argparse parser based on a Pydantic config schema.
"""
parser = argparse.ArgumentParser(description="TabbyAPI server")

# Loop through each top-level field in the config
for field_name, field_type in config.__annotations__.items():
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)

# Loop through each field in the sub-model
for sub_field_name, sub_field_type in field_type.__annotations__.items():
field = field_type.__fields__[sub_field_name]
help_text = (
field.description if field.description else "No description available"
)

origin = get_origin(sub_field_type)
if origin is Union:
sub_field_type = next(
t for t in get_args(sub_field_type) if t is not type(None)
)
elif origin is List:
sub_field_type = get_args(sub_field_type)[0]

# Map Pydantic types to argparse types
if isinstance(sub_field_type, type) and issubclass(
sub_field_type, (int, float, str, bool)
):
arg_type = sub_field_type
else:
arg_type = str # Default to string for unknown types

group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text)
group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}")

# Check if the field_type is a Pydantic model or something with __annotations__
if hasattr(field_type, "__annotations__"):
# Loop through each sub-field in the model
for sub_field_name, sub_field_type in field_type.__annotations__.items():
field = field_type.__fields__[sub_field_name]
add_field_to_group(group, sub_field_name, sub_field_type, field)
else:
# Handle cases where the field_type is not a Pydantic model
print(f"{field_name=}")
arg_type = map_pydantic_type_to_argparse(field_type)
group.add_argument(f"--{field_name}", type=arg_type, help=f"Argument for {field_name}")

return parser


def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Broad conversion of surface level arg groups to dictionaries"""

Expand Down
19 changes: 0 additions & 19 deletions common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,25 +153,6 @@ async def unload_embedding_model():
embeddings_container = None


# FIXME: Maybe make this a one-time function instead of a dynamic default
def get_config_default(key: str, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""

default_keys = unwrap(config.model.use_as_default, [])

# Add extra keys to defaults
default_keys.append("embeddings_device")

if key in default_keys:
# Is this a draft model load parameter?
if model_type == "draft":
return config.draft_model.get(key)
elif model_type == "embedding":
return config.embeddings.get(key)
else:
return config.model.get(key)


async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""

Expand Down
14 changes: 8 additions & 6 deletions common/tabby_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TabbyConfig(tabby_config_model):
"""Common config class for TabbyAPI."""

# Persistent defaults
# TODO: make this not a dict
model_defaults: dict = {}

def load(self, arguments: Optional[dict] = None):
Expand All @@ -34,12 +35,13 @@ def load(self, arguments: Optional[dict] = None):
setattr(self, field, model.parse_obj(value))

# Set model defaults dict once to prevent on-demand reconstruction
default_keys = unwrap(self.model.get("use_as_default"), [])
for key in default_keys:
if key in self.model:
self.model_defaults[key] = config.model[key]
elif key in self.draft_model:
self.model_defaults[key] = config.draft_model[key]
# TODO: refactor this use a pydantic model
default_fields = self.model.use_as_default
for field in default_fields:
if hasattr(self.model, field):
self.model_defaults[field] = getattr(config.model, field)
elif hasattr(self.draft_model, field):
self.model_defaults[field] = getattr(config.draft_model, field)

def _from_file(self, config_path: pathlib.Path):
"""loads config from a given file path"""
Expand Down
1 change: 0 additions & 1 deletion endpoints/core/types/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from common.tabby_config import config
from common.utils import unwrap
from common.config_models import logging_config_model
from common.model import get_config_default

class ModelCardParameters(BaseModel):
"""Represents model card parameters."""
Expand Down
6 changes: 1 addition & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ async def entrypoint_async():
port = fallback_port

# Initialize auth keys
load_auth_keys(config.network.disable_auth)

# Override the generation log options if given
if config.logging:
gen_logging.update_from_dict(config.logging)
await load_auth_keys(config.network.disable_auth)

gen_logging.broadcast_status()

Expand Down

0 comments on commit 6d28fdd

Please sign in to comment.