Skip to content

Commit

Permalink
Tree: Switch to asynchronous file handling
Browse files Browse the repository at this point in the history
Using aiofiles, there's no longer a possibility of blocking file operations
that can hang up the event loop. In addition, partially migrate
classes to use asynchronous init instead of the normal python magic method.

The only exception is config, since that's handled in the synchronous
init before the event loop starts.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Sep 10, 2024
1 parent 54bfb77 commit 2c3bc71
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 36 deletions.
33 changes: 22 additions & 11 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""

import aiofiles
import asyncio
import gc
import math
Expand Down Expand Up @@ -106,13 +107,17 @@ class ExllamaV2Container:
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()

def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
@classmethod
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Primary initializer for model container.
Primary asynchronous initializer for model container.
Kwargs are located in config_sample.yml
"""

# Create a new instance as a "fake self"
self = cls()

self.quiet = quiet

# Initialize config
Expand Down Expand Up @@ -155,13 +160,13 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.prepare()

# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
Expand All @@ -171,7 +176,7 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
)

# Apply a model's config overrides while respecting user settings
kwargs = self.set_model_overrides(**kwargs)
kwargs = await self.set_model_overrides(**kwargs)

# MARK: User configuration

Expand Down Expand Up @@ -320,7 +325,7 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
self.cache_size = self.config.max_seq_len

# Try to set prompt template
self.prompt_template = self.find_prompt_template(
self.prompt_template = await self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)

Expand Down Expand Up @@ -373,16 +378,22 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.max_input_len = chunk_size
self.draft_config.max_attention_size = chunk_size**2

def set_model_overrides(self, **kwargs):
# Return the created instance
return self

async def set_model_overrides(self, **kwargs):
"""Sets overrides from a model folder's config yaml."""

override_config_path = self.model_dir / "tabby_config.yml"

if not override_config_path.exists():
return kwargs

with open(override_config_path, "r", encoding="utf8") as override_config_file:
override_args = unwrap(yaml.safe_load(override_config_file), {})
async with aiofiles.open(
override_config_path, "r", encoding="utf8"
) as override_config_file:
contents = await override_config_file.read()
override_args = unwrap(yaml.safe_load(contents), {})

# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
Expand All @@ -393,7 +404,7 @@ def set_model_overrides(self, **kwargs):
merged_kwargs = {**override_args, **kwargs}
return merged_kwargs

def find_prompt_template(self, prompt_template_name, model_directory):
async def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods."""

logger.info("Attempting to load a prompt template if present.")
Expand Down Expand Up @@ -431,7 +442,7 @@ def find_prompt_template(self, prompt_template_name, model_directory):
# Continue on exception since functions are tried as they fail
for template_func in find_template_functions:
try:
prompt_template = template_func()
prompt_template = await template_func()
if prompt_template is not None:
return prompt_template
except TemplateLoadError as e:
Expand Down
15 changes: 10 additions & 5 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
application, it should be fine.
"""

import aiofiles
import secrets
import yaml
from fastapi import Header, HTTPException, Request
Expand Down Expand Up @@ -40,7 +41,7 @@ def verify_key(self, test_key: str, key_type: str):
DISABLE_AUTH: bool = False


def load_auth_keys(disable_from_config: bool):
async def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
Expand All @@ -57,17 +58,21 @@ def load_auth_keys(disable_from_config: bool):
return

try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
contents = await auth_file.read()
auth_keys_dict = yaml.safe_load(contents)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
)
AUTH_KEYS = new_auth_keys

with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
new_auth_yaml = yaml.safe_dump(
AUTH_KEYS.model_dump(), default_flow_style=False
)
await auth_file.write(new_auth_yaml)

logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
Expand Down
2 changes: 1 addition & 1 deletion common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing model.")
await unload_model()

container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

model_type = "draft" if container.draft_config else "model"
load_status = container.load_gen(load_progress, **kwargs)
Expand Down
8 changes: 5 additions & 3 deletions common/sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Common functions for sampling parameters"""

import aiofiles
import json
import pathlib
import yaml
Expand Down Expand Up @@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
raise TypeError("New sampler overrides must be a dict!")


def overrides_from_file(preset_name: str):
async def overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""

preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
overrides_container.selected_preset = preset_path.stem
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
contents = await raw_preset.read()
preset = yaml.safe_load(contents)
overrides_from_dict(preset)

logger.info("Applied sampler overrides from file.")
Expand Down
2 changes: 1 addition & 1 deletion common/tabby_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TabbyConfig:
embeddings: dict = {}

def load(self, arguments: Optional[dict] = None):
"""load the global application config"""
"""Synchronously loads the global application config"""

# config is applied in order of items in the list
configs = [
Expand Down
17 changes: 11 additions & 6 deletions common/templating.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""

import aiofiles
import json
import pathlib
from importlib.metadata import version as package_version
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(self, name: str, raw_template: str):
self.template = self.compile(raw_template)

@classmethod
def from_file(self, template_path: pathlib.Path):
async def from_file(self, template_path: pathlib.Path):
"""Get a template from a jinja file."""

# Add the jinja extension if it isn't provided
Expand All @@ -121,10 +122,13 @@ def from_file(self, template_path: pathlib.Path):
template_path = template_path.with_suffix(".jinja")

if template_path.exists():
with open(template_path, "r", encoding="utf8") as raw_template_stream:
async with aiofiles.open(
template_path, "r", encoding="utf8"
) as raw_template_stream:
contents = await raw_template_stream.read()
return PromptTemplate(
name=template_name,
raw_template=raw_template_stream.read(),
raw_template=contents,
)
else:
# Let the user know if the template file isn't found
Expand All @@ -133,15 +137,16 @@ def from_file(self, template_path: pathlib.Path):
)

@classmethod
def from_model_json(
async def from_model_json(
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
contents = await config_file.read()
model_config = json.loads(contents)
chat_template = model_config.get(key)

if not chat_template:
Expand Down
14 changes: 9 additions & 5 deletions common/transformers_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import aiofiles
import json
import pathlib
from typing import List, Optional, Union
Expand All @@ -15,11 +16,11 @@ class GenerationConfig(BaseModel):
bad_words_ids: Optional[List[List[int]]] = None

@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(self, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""

generation_config_path = model_directory / "generation_config.json"
with open(
async with aiofiles.open(
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
generation_config_dict = json.load(generation_config_json)
Expand All @@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel):
badwordsids: Optional[str] = None

@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(self, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""

hf_config_path = model_directory / "config.json"
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
hf_config_dict = json.load(hf_config_json)
async with aiofiles.open(
hf_config_path, "r", encoding="utf8"
) as hf_config_json:
contents = await hf_config_json.read()
hf_config_dict = json.loads(contents)
return self.model_validate(hf_config_dict)

def get_badwordsids(self):
Expand Down
4 changes: 2 additions & 2 deletions endpoints/core/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest):

try:
template_path = pathlib.Path("templates") / data.name
model.container.prompt_template = PromptTemplate.from_file(template_path)
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
Expand Down Expand Up @@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):

if data.preset:
try:
sampling.overrides_from_file(data.preset)
await sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
error_message = handle_request_error(
f"Sampler override preset with name {data.preset} does not exist. "
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def entrypoint_async():
port = fallback_port

# Initialize auth keys
load_auth_keys(unwrap(config.network.get("disable_auth"), False))
await load_auth_keys(unwrap(config.network.get("disable_auth"), False))

# Override the generation log options if given
if config.logging:
Expand All @@ -62,7 +62,7 @@ async def entrypoint_async():
sampling_override_preset = config.sampling.get("override_preset")
if sampling_override_preset:
try:
sampling.overrides_from_file(sampling_override_preset)
await sampling.overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))

Expand Down

0 comments on commit 2c3bc71

Please sign in to comment.