Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into refactor-sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
SecretiveShell committed Sep 26, 2024
2 parents cef9a03 + 56ce82e commit 4c17b87
Show file tree
Hide file tree
Showing 16 changed files with 204 additions and 38 deletions.
19 changes: 19 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

from ruamel.yaml import YAML

from common.health import HealthManager

from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
Expand Down Expand Up @@ -956,6 +958,13 @@ def check_unsupported_settings(self, **kwargs):
Meant for dev wheels!
"""

if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "xtc_probability"
):
logger.warning(
"XTC is not supported by the currently " "installed ExLlamaV2 version."
)

return kwargs

async def generate_gen(
Expand Down Expand Up @@ -1001,6 +1010,14 @@ async def generate_gen(
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
gen_settings.skew = unwrap(kwargs.get("skew"), 0)

# XTC
xtc_probability = unwrap(kwargs.get("xtc_probability"), 0.0)
if xtc_probability > 0.0:
gen_settings.xtc_probability = xtc_probability

# 0.1 is the default for this value
gen_settings.xtc_threshold = unwrap(kwargs.get("xtc_threshold", 0.1))

# DynaTemp settings
max_temp = unwrap(kwargs.get("max_temp"), 1.0)
min_temp = unwrap(kwargs.get("min_temp"), 1.0)
Expand Down Expand Up @@ -1373,6 +1390,8 @@ async def generate_gen(
)
asyncio.ensure_future(self.create_generator())

await HealthManager.add_unhealthy_event(ex)

raise ex
finally:
# Log generation options to console
Expand Down
2 changes: 1 addition & 1 deletion backends/exllamav2/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def check_exllama_version():
"pip install --upgrade .[amd]\n\n"
)

if not dependencies.exl2:
if not dependencies.exllamav2:
raise SystemExit(("Exllamav2 is not installed.\n" + install_message))

required_version = version.parse("0.2.2")
Expand Down
13 changes: 10 additions & 3 deletions common/args.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Argparser for overriding config values"""

import argparse
from typing import Optional
from pydantic import BaseModel

from common.config_models import TabbyConfigModel
from common.utils import is_list_type, unwrap_optional_type
from common.utils import is_list_type, unwrap, unwrap_optional_type


def add_field_to_group(group, field_name, field_type, field) -> None:
Expand All @@ -23,12 +24,18 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
group.add_argument(f"--{field_name}", **kwargs)


def init_argparser() -> argparse.ArgumentParser:
def init_argparser(
existing_parser: Optional[argparse.ArgumentParser] = None,
) -> argparse.ArgumentParser:
"""
Initializes an argparse parser based on a Pydantic config schema.
If an existing parser is given, use that.
"""

parser = argparse.ArgumentParser(description="TabbyAPI server")
parser = unwrap(
existing_parser, argparse.ArgumentParser(description="TabbyAPI server")
)

# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
Expand Down
42 changes: 42 additions & 0 deletions common/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
from collections import deque
from datetime import datetime, timezone
from functools import partial
from pydantic import BaseModel, Field
from typing import Union


class UnhealthyEvent(BaseModel):
    """Represents an error that makes the system unhealthy"""

    # default_factory is called once per instance, so each event is stamped
    # with the moment it is created; passing timezone.utc to datetime.now
    # yields a timezone-aware value.
    time: datetime = Field(
        default_factory=partial(datetime.now, timezone.utc),
        description="Time the error occurred in UTC time",
    )
    # Human-readable message; falls back to "Unknown error" when not supplied.
    description: str = Field("Unknown error", description="The error message")


class HealthManagerClass:
    """Class to manage the health global state

    Keeps a bounded history of UnhealthyEvent records. The service is
    reported healthy only while that history is empty.
    """

    def __init__(self):
        # limit the max stored errors to 100 to avoid a memory leak
        # (a deque with maxlen discards the oldest entry when full)
        self.issues: deque[UnhealthyEvent] = deque(maxlen=100)
        # Guards reads and writes of `issues` across coroutines
        self._lock = asyncio.Lock()

    async def add_unhealthy_event(self, error: Union[str, Exception]):
        """Add a new unhealthy event

        `error` may be a ready-made message or an exception. Exceptions are
        rendered as "<ClassName>: <message>" before being stored.
        """
        async with self._lock:
            if isinstance(error, Exception):
                error = f"{error.__class__.__name__}: {str(error)}"
            self.issues.append(UnhealthyEvent(description=error))

    async def is_service_healthy(self) -> tuple[bool, list[UnhealthyEvent]]:
        """Check if the service is healthy

        Returns a `(healthy, issues)` pair: `healthy` is True when no events
        are recorded, and `issues` is a list copy of the recorded events.
        """
        async with self._lock:
            healthy = len(self.issues) == 0
            # Copy so callers do not hold a reference to the live deque
            return healthy, list(self.issues)


# Create the module-level instance of the global state manager.
# Other modules import this name (not the class) so that they all record to
# and read from the same health state.
HealthManager = HealthManagerClass()
2 changes: 1 addition & 1 deletion common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from common.tabby_config import config
from common.optional_dependencies import dependencies

if dependencies.exl2:
if dependencies.exllamav2:
from backends.exllamav2.model import ExllamaV2Container

# Global model container
Expand Down
2 changes: 1 addition & 1 deletion common/optional_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def extras(self) -> bool:

@computed_field
@property
def exl2(self) -> bool:
def inference(self) -> bool:
return self.torch and self.exllamav2 and self.flash_attn


Expand Down
12 changes: 10 additions & 2 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from typing import Dict, List, Optional, Union

from common.utils import unwrap, prune_dict
from common.utils import filter_none_values, unwrap


# Common class for sampler params
Expand Down Expand Up @@ -129,6 +129,14 @@ class BaseSamplerRequest(BaseModel):
examples=[0.0],
)

xtc_probability: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("xtc_probability", 0.0),
)

xtc_threshold: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("xtc_threshold", 0.1)
)

frequency_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0),
ge=0,
Expand Down Expand Up @@ -337,7 +345,7 @@ def overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""

if isinstance(new_overrides, dict):
overrides_container.overrides = prune_dict(new_overrides)
overrides_container.overrides = filter_none_values(new_overrides)
else:
raise TypeError("New sampler overrides must be a dict!")

Expand Down
7 changes: 5 additions & 2 deletions common/tabby_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ruamel.yaml.scalarstring import PreservedScalarString

from common.config_models import BaseConfigModel, TabbyConfigModel
from common.utils import merge_dicts, unwrap
from common.utils import merge_dicts, filter_none_values, unwrap

yaml = YAML(typ=["rt", "safe"])

Expand All @@ -33,6 +33,9 @@ def load(self, arguments: Optional[dict] = None):
if not arguments_dict.get("actions"):
configs.insert(0, self._from_file(pathlib.Path("config.yml")))

# Remove None (aka unset) values from the configs and merge them together
# This should be less expensive than pruning the entire merged dictionary
configs = filter_none_values(configs)
merged_config = merge_dicts(*configs)

# validate and update config
Expand Down Expand Up @@ -135,7 +138,7 @@ def _from_args(self, args: dict):
"""loads config from the provided arguments"""
config = {}

config_override = args.get("options", {}).get("config", None)
config_override = args.get("config", {}).get("config", None)
if config_override:
logger.info("Config file override detected in args.")
config = self._from_file(pathlib.Path(config_override))
Expand Down
15 changes: 11 additions & 4 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ def coalesce(*args):
return next((arg for arg in args if arg is not None), None)


def prune_dict(input_dict: Dict) -> Dict:
    """Trim out instances of None from a dictionary.

    Only top-level entries are examined; nested containers are kept as-is.
    Returns a new dict and leaves `input_dict` untouched.
    """

    return {k: v for k, v in input_dict.items() if v is not None}
def filter_none_values(collection: Union[dict, list]) -> Union[dict, list]:
    """Remove None values from a collection.

    Dicts and lists are rebuilt recursively without their None entries, so
    None values nested at any depth are dropped. Any other value (including
    falsy ones such as 0, "" or False) is returned unchanged. The input is
    not mutated; new containers are returned.
    """

    if isinstance(collection, dict):
        return {
            k: filter_none_values(v) for k, v in collection.items() if v is not None
        }
    elif isinstance(collection, list):
        return [filter_none_values(i) for i in collection if i is not None]
    else:
        # Scalars and unsupported container types pass through untouched
        return collection


def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
Expand Down
2 changes: 1 addition & 1 deletion endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ async def format_prompt_with_template(
# Deal with list in messages.content
# Just replace the content list with the very first text message
for message in data.messages:
if message["role"] == "user" and isinstance(message["content"], list):
if isinstance(message["content"], list):
message["content"] = next(
(
content["text"]
Expand Down
15 changes: 12 additions & 3 deletions endpoints/core/router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import pathlib
from sys import maxsize
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette import EventSourceResponse

from common import model, sampling
Expand All @@ -12,6 +12,7 @@
from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from common.health import HealthManager
from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
Expand All @@ -22,6 +23,7 @@
ModelLoadRequest,
ModelLoadResponse,
)
from endpoints.core.types.health import HealthCheckResponse
from endpoints.core.types.sampler_overrides import (
SamplerOverrideListResponse,
SamplerOverrideSwitchRequest,
Expand All @@ -47,9 +49,16 @@

# Healthcheck endpoint
@router.get("/health")
async def healthcheck(response: Response) -> HealthCheckResponse:
    """Get the current service health status"""
    healthy, issues = await HealthManager.is_service_healthy()

    # Reflect an unhealthy state in the HTTP status as well as the body
    if not healthy:
        response.status_code = 503

    return HealthCheckResponse(
        status="healthy" if healthy else "unhealthy", issues=issues
    )


# Model list endpoint
Expand Down
15 changes: 15 additions & 0 deletions endpoints/core/types/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Literal
from pydantic import BaseModel, Field

from common.health import UnhealthyEvent


class HealthCheckResponse(BaseModel):
    """System health status"""

    # Restricted to the two literal states; defaults to "healthy"
    status: Literal["healthy", "unhealthy"] = Field(
        "healthy", description="System health status"
    )
    # default_factory gives every response its own empty list
    issues: list[UnhealthyEvent] = Field(
        default_factory=list, description="List of issues"
    )
6 changes: 6 additions & 0 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ typical:
skew:
override: 0.0
force: false
xtc_probability:
override: 0.0
force: false
xtc_threshold:
override: 0.1
force: false

# MARK: Penalty settings
frequency_penalty:
Expand Down
5 changes: 5 additions & 0 deletions start.bat
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ if exist "%CONDA_PREFIX%" (
if not exist "venv\" (
echo Venv doesn't exist! Creating one for you.
python -m venv venv

if exist "start_options.json" (
echo Removing old start_options.json
del start_options.json
)
)

call .\venv\Scripts\activate.bat
Expand Down
Loading

0 comments on commit 4c17b87

Please sign in to comment.