Merge pull request #185 from SecretiveShell/refactor-config-loading
Refactor config loading
kingbri1 authored Sep 5, 2024
2 parents 98768bf + 1c9991f commit ec7f64d
Showing 10 changed files with 160 additions and 176 deletions.
107 changes: 0 additions & 107 deletions common/config.py

This file was deleted.

6 changes: 3 additions & 3 deletions common/downloader.py
@@ -10,8 +10,8 @@
 from rich.progress import Progress
 from typing import List, Optional

-from common.config import lora_config, model_config
 from common.logger import get_progress_bar
+from common.tabby_config import config
 from common.utils import unwrap

@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
     """Gets the download folder for the repo."""

     if repo_type == "lora":
-        download_path = pathlib.Path(lora_config().get("lora_dir") or "loras")
+        download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
     else:
-        download_path = pathlib.Path(model_config().get("model_dir") or "models")
+        download_path = pathlib.Path(config.model.get("model_dir") or "models")

     download_path = download_path / (folder_name or repo_id.split("/")[-1])
     return download_path
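This hunk shows the shape of the whole refactor: per-section getter functions such as lora_config() and model_config() are replaced by plain attribute access on a shared config object. A minimal before/after sketch; the key and fallback mirror the hunk above:

# Old style: call a module-level getter, then index into the returned dict
model_dir = model_config().get("model_dir") or "models"

# New style: each section is a plain dict attribute on the config singleton
model_dir = config.model.get("model_dir") or "models"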
13 changes: 5 additions & 8 deletions common/model.py
@@ -10,9 +10,9 @@
 from loguru import logger
 from typing import Optional

-from common import config
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.utils import do_export_openapi

@@ -153,22 +153,19 @@ async def unload_embedding_model():
 def get_config_default(key: str, model_type: str = "model"):
     """Fetches a default value from model config if allowed by the user."""

-    model_config = config.model_config()
-    default_keys = unwrap(model_config.get("use_as_default"), [])
+    default_keys = unwrap(config.model.get("use_as_default"), [])

     # Add extra keys to defaults
     default_keys.append("embeddings_device")

     if key in default_keys:
         # Is this a draft model load parameter?
         if model_type == "draft":
-            draft_config = config.draft_model_config()
-            return draft_config.get(key)
+            return config.draft_model.get(key)
         elif model_type == "embedding":
-            embeddings_config = config.embeddings_config()
-            return embeddings_config.get(key)
+            return config.embeddings.get(key)
         else:
-            return model_config.get(key)
+            return config.model.get(key)


 async def check_model_container():
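For orientation, hypothetical call sites for get_config_default; the keys and values below are illustrative assumptions, not part of this diff:

# Assuming config.model = {"max_seq_len": 4096, "use_as_default": ["max_seq_len"]}
get_config_default("max_seq_len")                      # 4096, read from config.model
get_config_default("max_seq_len", model_type="draft")  # read from config.draft_model instead
get_config_default("embeddings_device", model_type="embedding")  # always eligible, since
# "embeddings_device" is appended to default_keys unconditionally; keys outside
# use_as_default fall through and the function returns None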
6 changes: 3 additions & 3 deletions common/networking.py
@@ -10,7 +10,7 @@
 from typing import Optional
 from uuid import uuid4

-from common import config
+from common.tabby_config import config
 from common.utils import unwrap

@@ -39,7 +39,7 @@ def handle_request_error(message: str, exc_info: bool = True):
     """Log a request error to the console."""

     trace = traceback.format_exc()
-    send_trace = unwrap(config.network_config().get("send_tracebacks"), False)
+    send_trace = unwrap(config.network.get("send_tracebacks"), False)

     error_message = TabbyRequestErrorMessage(
         message=message, trace=trace if send_trace else None
@@ -134,7 +134,7 @@ def get_global_depends():

     depends = [Depends(add_request_id)]

-    if config.logging_config().get("requests"):
+    if config.logging.get("requests"):
         depends.append(Depends(log_request))

     return depends
88 changes: 88 additions & 0 deletions common/tabby_config.py
@@ -0,0 +1,88 @@
import yaml
import pathlib
from loguru import logger
from typing import Optional

from common.utils import unwrap, merge_dicts


class TabbyConfig:
    network: dict = {}
    logging: dict = {}
    model: dict = {}
    draft_model: dict = {}
    lora: dict = {}
    sampling: dict = {}
    developer: dict = {}
    embeddings: dict = {}

    def load(self, arguments: Optional[dict] = None):
        """Loads the global application config."""

        # config is applied in order of items in the list
        configs = [
            self._from_file(pathlib.Path("config.yml")),
            self._from_args(unwrap(arguments, {})),
        ]

        merged_config = merge_dicts(*configs)

        self.network = unwrap(merged_config.get("network"), {})
        self.logging = unwrap(merged_config.get("logging"), {})
        self.model = unwrap(merged_config.get("model"), {})
        self.draft_model = unwrap(merged_config.get("draft"), {})
        self.lora = unwrap(merged_config.get("lora"), {})
        self.sampling = unwrap(merged_config.get("sampling"), {})
        self.developer = unwrap(merged_config.get("developer"), {})
        self.embeddings = unwrap(merged_config.get("embeddings"), {})

    def _from_file(self, config_path: pathlib.Path):
        """Loads config from a given file path."""

        # try loading from file
        try:
            with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
                return unwrap(yaml.safe_load(config_file), {})
        except FileNotFoundError:
            logger.info(f"The '{config_path.name}' file cannot be found")
        except Exception as exc:
            logger.error(
                f"The YAML config from '{config_path.name}' couldn't load because of "
                f"the following error:\n\n{exc}"
            )

        # if no config file was loaded
        return {}

    def _from_args(self, args: dict):
        """Loads config from the provided arguments."""

        config = {}

        config_override = unwrap(args.get("options", {}).get("config"))
        if config_override:
            logger.info("Config file override detected in args.")
            config = self._from_file(pathlib.Path(config_override))
            return config  # Return early if loading from file

        for key in ["network", "model", "logging", "developer", "embeddings"]:
            override = args.get(key)
            if override:
                if key == "logging":
                    # Strip the "log_" prefix from logging keys if present
                    override = {k.replace("log_", ""): v for k, v in override.items()}
                config[key] = override

        return config

    def _from_environment(self):
        """Loads configuration from environment variables."""

        # TODO: load config from environment variables
        # this means that we can have host default to 0.0.0.0 in docker for example
        # this would also mean that docker containers no longer require a non
        # default config file to be used
        pass


# Create an empty instance of the config class
config: TabbyConfig = TabbyConfig()
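For context, a minimal sketch of how the new singleton might be driven at startup. The argument dict below is a hypothetical example shaped the way _from_args expects it; the real entrypoint wiring is not part of this diff:

from common.tabby_config import config

# Hypothetical CLI-style arguments; the "log_" prefix on logging keys
# is stripped by _from_args during loading
args = {
    "network": {"host": "0.0.0.0", "port": 5000},
    "logging": {"log_requests": True},
}

config.load(args)  # reads config.yml first, then merges args on top

config.network.get("host")      # "0.0.0.0"
config.logging.get("requests")  # True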
19 changes: 19 additions & 0 deletions common/utils.py
@@ -20,6 +20,25 @@ def prune_dict(input_dict):
     return {k: v for k, v in input_dict.items() if v is not None}


+def merge_dict(dict1, dict2):
+    """Merge 2 dictionaries"""
+    for key, value in dict2.items():
+        if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
+            merge_dict(dict1[key], value)
+        else:
+            dict1[key] = value
+    return dict1
+
+
+def merge_dicts(*dicts):
+    """Merge an arbitrary amount of dictionaries"""
+    result = {}
+    for dictionary in dicts:
+        result = merge_dict(result, dictionary)
+
+    return result
+
+
 def flat_map(input_list):
     """Flattens a list of lists into a single list."""
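A worked example of the recursive merge these helpers implement, using hypothetical keys that mirror the config sections above. Nested dicts are merged key by key, and later dictionaries win on conflicting leaves:

base = {"model": {"model_dir": "models", "max_seq_len": 4096}}
override = {"model": {"max_seq_len": 8192}}

merge_dicts(base, override)
# -> {"model": {"model_dir": "models", "max_seq_len": 8192}}

# This ordering is why TabbyConfig.load lists the config.yml result before the
# argument overrides: later entries take precedence. Note that merge_dict
# assigns nested dicts by reference, so the inputs can be mutated through the
# merged result.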
7 changes: 4 additions & 3 deletions endpoints/OAI/router.py
@@ -3,10 +3,11 @@
 from sse_starlette import EventSourceResponse
 from sys import maxsize

-from common import config, model
+from common import model
 from common.auth import check_api_key
 from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
 from endpoints.OAI.types.chat_completion import (
@@ -64,7 +65,7 @@ async def completion_request(
     data.prompt = "\n".join(data.prompt)

     disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
+        config.developer.get("disable_request_streaming"), False
     )

     # Set an empty JSON schema if the request wants a JSON response
@@ -128,7 +129,7 @@ async def chat_completion_request(
     data.json_schema = {"type": "object"}

     disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
+        config.developer.get("disable_request_streaming"), False
     )

     if data.stream and not disable_request_streaming: