Skip to content

Commit

Permalink
fix issues with optional dependencies (#204)
Browse files Browse the repository at this point in the history
* fix issues with optional dependencies

* format document

* Tree: Format and comment
  • Loading branch information
SecretiveShell authored Sep 20, 2024
1 parent 75af974 commit 3aeddc5
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 53 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ jobs:
npm install @redocly/cli -g
- name: Export OpenAPI docs
run: |
EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-oai.json" --api-servers OAI
EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-kobold.json" --api-servers kobold
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
- name: Build and store Redocly site
run: |
mkdir static
Expand Down
28 changes: 0 additions & 28 deletions backends/exllamav2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,6 @@
from loguru import logger


def check_exllama_version():
    """Verifies the exllama version"""

    # Minimum ExLlamaV2 release this code is known to work with.
    required_version = version.parse("0.2.2")
    # Strip any local-version suffix (e.g. "0.2.2+cu121") before parsing so
    # the comparison is done purely on the release segment.
    current_version = version.parse(package_version("exllamav2").split("+")[0])

    unsupported_message = (
        f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
        f"or greater. Your current version is {current_version}.\n"
        "Please update your environment by running an update script "
        "(update_scripts/"
        f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
        "Or you can manually run a requirements update "
        "using the following command:\n\n"
        "For CUDA 12.1:\n"
        "pip install --upgrade .[cu121]\n\n"
        "For CUDA 11.8:\n"
        "pip install --upgrade .[cu118]\n\n"
        "For ROCm:\n"
        "pip install --upgrade .[amd]\n\n"
    )

    # Hard-exit on an unsupported version; otherwise log the detected
    # version for startup diagnostics.
    if current_version < required_version:
        raise SystemExit(unsupported_message)
    else:
        logger.info(f"ExllamaV2 version: {current_version}")


def hardware_supports_flash_attn(gpu_device_list: list[int]):
"""
Check whether all GPUs in list support FA2
Expand Down
39 changes: 39 additions & 0 deletions backends/exllamav2/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import platform
from packaging import version
from importlib.metadata import version as package_version
from loguru import logger
from common.optional_dependencies import dependencies


def check_exllama_version():
    """Verifies the exllama version"""

    # Pick the update-script extension that matches the host platform.
    script_ext = "bat" if platform.system() == "Windows" else "sh"

    install_message = (
        "Please update your environment by running an update script "
        "(update_scripts/"
        f"update_deps.{script_ext})\n\n"
        "Or you can manually run a requirements update "
        "using the following command:\n\n"
        "For CUDA 12.1:\n"
        "pip install --upgrade .[cu121]\n\n"
        "For CUDA 11.8:\n"
        "pip install --upgrade .[cu118]\n\n"
        "For ROCm:\n"
        "pip install --upgrade .[amd]\n\n"
    )

    # Guard clause: the backend stack isn't importable at all.
    if not dependencies.exl2:
        raise SystemExit("Exllamav2 is not installed.\n" + install_message)

    # Strip any local-version suffix (e.g. "+cu121") before comparing.
    minimum = version.parse("0.2.2")
    installed = version.parse(package_version("exllamav2").split("+")[0])

    if installed < minimum:
        raise SystemExit(
            f"ERROR: TabbyAPI requires ExLlamaV2 {minimum} "
            f"or greater. Your current version is {installed}.\n" + install_message
        )

    logger.info(f"ExllamaV2 version: {installed}")
10 changes: 3 additions & 7 deletions backends/infinity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@
from typing import List, Optional

from common.utils import unwrap
from common.optional_dependencies import dependencies

# Conditionally import infinity to sidestep its logger
has_infinity_emb: bool = False
try:
if dependencies.extras:
from infinity_emb import EngineArgs, AsyncEmbeddingEngine

has_infinity_emb = True
except ImportError:
pass


class InfinityContainer:
model_dir: pathlib.Path
Expand All @@ -23,7 +19,7 @@ class InfinityContainer:

# Conditionally set the type hint based on importablity
# TODO: Clean this up
if has_infinity_emb:
if dependencies.extras:
engine: Optional[AsyncEmbeddingEngine] = None
else:
engine = None
Expand Down
3 changes: 1 addition & 2 deletions common/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

from common.tabby_config import config, generate_config_file
from endpoints.server import export_openapi
from endpoints.utils import do_export_openapi


def branch_to_actions() -> bool:
"""Checks if a optional action needs to be run."""

if config.actions.export_openapi or do_export_openapi:
if config.actions.export_openapi:
openapi_json = export_openapi()

with open(config.actions.openapi_export_path, "w") as f:
Expand Down
14 changes: 6 additions & 8 deletions common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from endpoints.utils import do_export_openapi
from common.optional_dependencies import dependencies

if not do_export_openapi:
if dependencies.exl2:
from backends.exllamav2.model import ExllamaV2Container

# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None

# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb

if has_infinity_emb:
from backends.infinity.model import InfinityContainer
if dependencies.extras:
from backends.infinity.model import InfinityContainer

embeddings_container: Optional[InfinityContainer] = None
embeddings_container: Optional[InfinityContainer] = None


class ModelType(Enum):
Expand Down Expand Up @@ -121,7 +119,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
global embeddings_container

# Break out if infinity isn't installed
if not has_infinity_emb:
if not dependencies.extras:
raise ImportError(
"Skipping embeddings because infinity-emb is not installed.\n"
"Please run the following command in your environment "
Expand Down
52 changes: 52 additions & 0 deletions common/optional_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Construct a model of all optional dependencies"""

import importlib.util
from pydantic import BaseModel, computed_field


# Declare the exported parts of this module
__all__ = ["dependencies"]


class DependenciesModel(BaseModel):
    """Model of which optional dependencies are installed."""

    # Each field name doubles as the importable package name it tracks.
    torch: bool
    exllamav2: bool
    flash_attn: bool
    outlines: bool
    infinity_emb: bool
    sentence_transformers: bool

    @computed_field
    @property
    def extras(self) -> bool:
        """True when every extra-features dependency is available."""
        return all((self.outlines, self.infinity_emb, self.sentence_transformers))

    @computed_field
    @property
    def exl2(self) -> bool:
        """True when the full exllamav2 backend stack is available."""
        return all((self.torch, self.exllamav2, self.flash_attn))


def is_installed(package_name: str) -> bool:
    """Utility function to check if a package is installed.

    Args:
        package_name: Importable module name to probe.

    Returns:
        True if an import spec can be located for the package.
    """

    try:
        spec = importlib.util.find_spec(package_name)
    except (ModuleNotFoundError, ValueError):
        # find_spec raises ModuleNotFoundError for dotted names whose
        # parent package is missing, and ValueError when a module is in
        # sys.modules with __spec__ set to None. In either case the
        # dependency is not usable, so report it as absent.
        return False

    return spec is not None


def get_installed_deps() -> DependenciesModel:
    """Check if optional dependencies are installed by looping over the fields."""

    # Field names double as import names, so probe each one directly.
    statuses = {
        field_name: is_installed(field_name)
        for field_name in DependenciesModel.model_fields
    }

    return DependenciesModel(**statuses)


# Module-level singleton: availability of optional dependencies, computed
# once at import time and shared by all consumers of this module.
dependencies = get_installed_deps()
3 changes: 0 additions & 3 deletions endpoints/utils.py

This file was deleted.

4 changes: 1 addition & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
from common.signals import signal_handler
from common.tabby_config import config
from endpoints.server import start_api
from endpoints.utils import do_export_openapi

if not do_export_openapi:
from backends.exllamav2.utils import check_exllama_version
from backends.exllamav2.version import check_exllama_version


async def entrypoint_async():
Expand Down

0 comments on commit 3aeddc5

Please sign in to comment.