Merge pull request #189 from SecretiveShell/pydantic-config
Update the config system to use Pydantic internally, bridging the gap between the YAML config and the CLI args. YAML is still the preferred way to configure TabbyAPI, but the args are no longer maintained separately.
bdashore3 authored Sep 19, 2024
2 parents 2a41910 + 4cf8551 commit 03189bc
Showing 22 changed files with 1,092 additions and 632 deletions.
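
The rewritten common/args.py (diffed below) generates the CLI from the Pydantic config schema instead of hand-maintaining each flag. As a rough standalone sketch of that pattern, using a simplified stand-in model (ExampleConfigModel, NetworkConfig, and build_parser here are illustrative, not the real TabbyConfigModel):

import argparse
from typing import Optional

from pydantic import BaseModel, Field


class NetworkConfig(BaseModel):
    # Illustrative fields; the real schema lives in common/config_models.py
    host: Optional[str] = Field(None, description="The IP to host on")
    port: Optional[int] = Field(None, description="The port to host on")


class ExampleConfigModel(BaseModel):
    network: NetworkConfig = Field(default_factory=NetworkConfig)


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Example server")

    # One argument group per top-level config section
    for section_name, section_field in ExampleConfigModel.model_fields.items():
        group = parser.add_argument_group(section_name)

        # One flag per field, with help text taken from the field description
        for name, field in section_field.annotation.model_fields.items():
            group.add_argument(
                f"--{name.replace('_', '-')}",
                help=field.description or "No description available",
            )

    return parser


print(build_parser().parse_args(["--host", "0.0.0.0", "--port", "5000"]))
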
6 changes: 2 additions & 4 deletions .github/workflows/pages.yml
@@ -48,10 +48,8 @@ jobs:
npm install @redocly/cli -g
- name: Export OpenAPI docs
run: |
EXPORT_OPENAPI=1 python main.py
mv openapi.json openapi-oai.json
EXPORT_OPENAPI=1 python main.py --api-servers kobold
mv openapi.json openapi-kobold.json
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
- name: Build and store Redocly site
run: |
mkdir static
3 changes: 3 additions & 0 deletions .gitignore
@@ -213,3 +213,6 @@ openapi.json

# Infinity-emb cache
.infinity_cache/

# Backup files
*.bak
7 changes: 5 additions & 2 deletions backends/exllamav2/model.py
@@ -30,7 +30,7 @@
from loguru import logger
from typing import List, Optional, Union

import yaml
from ruamel.yaml import YAML

from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
@@ -379,7 +379,10 @@ async def set_model_overrides(self, **kwargs):
override_config_path, "r", encoding="utf8"
) as override_config_file:
contents = await override_config_file.read()
override_args = unwrap(yaml.safe_load(contents), {})

# Create a temporary YAML parser
yaml = YAML(typ="safe")
override_args = unwrap(yaml.load(contents), {})

# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
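
For reference, a minimal standalone comparison of the old and new parse calls (the YAML text and keys below are illustrative, not an actual override file):

from ruamel.yaml import YAML

contents = "draft:\n  draft_rope_alpha: 1.0\n"

# Old (PyYAML):
#   import yaml
#   override_args = yaml.safe_load(contents)

# New (ruamel.yaml): instantiate a safe-mode parser, then call load()
yaml = YAML(typ="safe")
override_args = yaml.load(contents)  # {'draft': {'draft_rope_alpha': 1.0}}
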
27 changes: 27 additions & 0 deletions common/actions.py
@@ -0,0 +1,27 @@
import json
from loguru import logger

from common.tabby_config import config, generate_config_file
from endpoints.server import export_openapi


def branch_to_actions() -> bool:
"""Checks if a optional action needs to be run."""

if config.actions.export_openapi:
openapi_json = export_openapi()

with open(config.actions.openapi_export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info(
"Successfully wrote OpenAPI spec to "
+ f"{config.actions.openapi_export_path}"
)
elif config.actions.export_config:
generate_config_file(filename=config.actions.config_export_path)
else:
# did not branch
return False

# branched and ran an action
return True
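
A hypothetical sketch of how branch_to_actions() could be wired into the server entrypoint; main.py is also changed in this PR but not shown in this excerpt, so the entrypoint() name and structure below are assumptions:

from common.actions import branch_to_actions


def entrypoint():
    # ... parse args and load the config here ...

    # If an action flag (e.g. --export-openapi) was set, run it and exit
    # instead of starting the API server.
    if branch_to_actions():
        return

    # ... otherwise start the API server ...
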
274 changes: 40 additions & 234 deletions common/args.py
@@ -1,56 +1,60 @@
"""Argparser for overriding config values"""

import argparse
from pydantic import BaseModel

from common.config_models import TabbyConfigModel
from common.utils import is_list_type, unwrap_optional_type

def str_to_bool(value):
"""Converts a string into a boolean value"""

if value.lower() in {"false", "f", "0", "no", "n"}:
return False
elif value.lower() in {"true", "t", "1", "yes", "y"}:
return True
raise ValueError(f"{value} is not a valid boolean value")


def argument_with_auto(value):
def add_field_to_group(group, field_name, field_type, field) -> None:
"""
Argparse type wrapper for any argument that has an automatic option.
Ex. rope_alpha
Adds a Pydantic field to an argparse argument group.
"""

if value == "auto":
return "auto"
kwargs = {
"help": field.description if field.description else "No description available",
}

try:
return float(value)
except ValueError as ex:
raise argparse.ArgumentTypeError(
'This argument only takes a type of float or "auto"'
) from ex
# If the inner type contains a list, specify argparse as such
if is_list_type(field_type):
kwargs["nargs"] = "+"

group.add_argument(f"--{field_name}", **kwargs)

def init_argparser():
"""Creates an argument parser that any function can use"""

parser = argparse.ArgumentParser(
epilog="NOTE: These args serve to override parts of the config. "
+ "It's highly recommended to edit config.yml for all options and "
+ "better descriptions!"
)
add_network_args(parser)
add_model_args(parser)
add_embeddings_args(parser)
add_logging_args(parser)
add_developer_args(parser)
add_sampling_args(parser)
add_config_args(parser)
def init_argparser() -> argparse.ArgumentParser:
"""
Initializes an argparse parser based on a Pydantic config schema.
"""

parser = argparse.ArgumentParser(description="TabbyAPI server")

# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional_type(field_info.annotation)
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)

# Check if the field_type is a Pydantic model
if issubclass(field_type, BaseModel):
for sub_field_name, sub_field_info in field_type.model_fields.items():
sub_field_name = sub_field_name.replace("_", "-")
sub_field_type = sub_field_info.annotation
add_field_to_group(
group, sub_field_name, sub_field_type, sub_field_info
)
else:
field_name = field_name.replace("_", "-")
group.add_argument(f"--{field_name}", help=f"Argument for {field_name}")

return parser


def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict:
"""Broad conversion of surface level arg groups to dictionaries"""

arg_groups = {}
@@ -64,201 +68,3 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars
arg_groups[group.title] = group_dict

return arg_groups


def add_config_args(parser: argparse.ArgumentParser):
"""Adds config arguments"""

parser.add_argument(
"--config", type=str, help="Path to an overriding config.yml file"
)


def add_network_args(parser: argparse.ArgumentParser):
"""Adds networking arguments"""

network_group = parser.add_argument_group("network")
network_group.add_argument("--host", type=str, help="The IP to host on")
network_group.add_argument("--port", type=int, help="The port to host on")
network_group.add_argument(
"--disable-auth",
type=str_to_bool,
help="Disable HTTP token authenticaion with requests",
)
network_group.add_argument(
"--send-tracebacks",
type=str_to_bool,
help="Decide whether to send error tracebacks over the API",
)
network_group.add_argument(
"--api-servers",
type=str,
nargs="+",
help="API servers to enable. Options: (OAI, Kobold)",
)


def add_model_args(parser: argparse.ArgumentParser):
"""Adds model arguments"""

model_group = parser.add_argument_group("model")
model_group.add_argument(
"--model-dir", type=str, help="Overrides the directory to look for models"
)
model_group.add_argument("--model-name", type=str, help="An initial model to load")
model_group.add_argument(
"--use-dummy-models",
type=str_to_bool,
help="Add dummy OAI model names for API queries",
)
model_group.add_argument(
"--use-as-default",
type=str,
nargs="+",
help="Names of args to use as a default fallback for API load requests ",
)
model_group.add_argument(
"--max-seq-len", type=int, help="Override the maximum model sequence length"
)
model_group.add_argument(
"--override-base-seq-len",
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--tensor-parallel",
type=str_to_bool,
help="Use tensor parallelism to load models",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
help="Automatically allocate resources to GPUs",
)
model_group.add_argument(
"--autosplit-reserve",
type=int,
nargs="+",
help="Reserve VRAM used for autosplit loading (in MBs) ",
)
model_group.add_argument(
"--gpu-split",
type=float,
nargs="+",
help="An integer array of GBs of vram to split between GPUs. "
+ "Ignored if gpu_split_auto is true",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument(
"--rope-alpha",
type=argument_with_auto,
help="Sets rope_alpha for NTK",
)
model_group.add_argument(
"--cache-mode",
type=str,
help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
)
model_group.add_argument(
"--cache-size",
type=int,
help="The size of the prompt cache (in number of tokens) to allocate",
)
model_group.add_argument(
"--chunk-size",
type=int,
help="Chunk size for prompt ingestion",
)
model_group.add_argument(
"--max-batch-size",
type=int,
help="Maximum amount of prompts to process at one time",
)
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the jinja2 prompt template for chat completions",
)
model_group.add_argument(
"--num-experts-per-token",
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--fasttensors",
type=str_to_bool,
help="Possibly increases model loading speeds",
)


def add_logging_args(parser: argparse.ArgumentParser):
"""Adds logging arguments"""

logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
)
logging_group.add_argument(
"--log-generation-params",
type=str_to_bool,
help="Enable generation parameter logging",
)
logging_group.add_argument(
"--log-requests",
type=str_to_bool,
help="Enable request logging",
)


def add_developer_args(parser: argparse.ArgumentParser):
"""Adds developer-specific arguments"""

developer_group = parser.add_argument_group("developer")
developer_group.add_argument(
"--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check"
)
developer_group.add_argument(
"--disable-request-streaming",
type=str_to_bool,
help="Disables API request streaming",
)
developer_group.add_argument(
"--cuda-malloc-backend",
type=str_to_bool,
help="Runs with the pytorch CUDA malloc backend",
)
developer_group.add_argument(
"--uvloop",
type=str_to_bool,
help="Run asyncio using Uvloop or Winloop",
)


def add_sampling_args(parser: argparse.ArgumentParser):
"""Adds sampling-specific arguments"""

sampling_group = parser.add_argument_group("sampling")
sampling_group.add_argument(
"--override-preset", type=str, help="Select a sampler override preset"
)


def add_embeddings_args(parser: argparse.ArgumentParser):
"""Adds arguments specific to embeddings"""

embeddings_group = parser.add_argument_group("embeddings")
embeddings_group.add_argument(
"--embedding-model-dir",
type=str,
help="Overrides the directory to look for models",
)
embeddings_group.add_argument(
"--embedding-model-name", type=str, help="An initial model to load"
)
embeddings_group.add_argument(
"--embeddings-device",
type=str,
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
)
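
Taken together, a minimal sketch of how the generated parser might be used to collect config-shaped overrides. The flag names assume the config model keeps the existing network.port and model.model_name fields, and the exact call sites are not part of this excerpt:

from common.args import convert_args_to_dict, init_argparser

parser = init_argparser()
args = parser.parse_args(["--port", "5000", "--model-name", "my-model"])

# Regroup the flat namespace into per-section dicts keyed by argument group
# title (e.g. "network", "model"), mirroring the layout of config.yml.
overrides = convert_args_to_dict(args, parser)
print(overrides)
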