diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml
index f1b92462..c3a61f21 100644
--- a/.github/workflows/pages.yml
+++ b/.github/workflows/pages.yml
@@ -48,8 +48,8 @@ jobs:
           npm install @redocly/cli -g
       - name: Export OpenAPI docs
         run: |
-          EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-oai.json" --api-servers OAI
-          EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-kobold.json" --api-servers kobold
+          python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
+          python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
       - name: Build and store Redocly site
         run: |
           mkdir static
diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py
index b7a9f54f..6148b44e 100644
--- a/backends/exllamav2/utils.py
+++ b/backends/exllamav2/utils.py
@@ -5,34 +5,6 @@
 from loguru import logger
 
 
-def check_exllama_version():
-    """Verifies the exllama version"""
-
-    required_version = version.parse("0.2.2")
-    current_version = version.parse(package_version("exllamav2").split("+")[0])
-
-    unsupported_message = (
-        f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
-        f"or greater. Your current version is {current_version}.\n"
-        "Please update your environment by running an update script "
-        "(update_scripts/"
-        f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
-        "Or you can manually run a requirements update "
-        "using the following command:\n\n"
-        "For CUDA 12.1:\n"
-        "pip install --upgrade .[cu121]\n\n"
-        "For CUDA 11.8:\n"
-        "pip install --upgrade .[cu118]\n\n"
-        "For ROCm:\n"
-        "pip install --upgrade .[amd]\n\n"
-    )
-
-    if current_version < required_version:
-        raise SystemExit(unsupported_message)
-    else:
-        logger.info(f"ExllamaV2 version: {current_version}")
-
-
 def hardware_supports_flash_attn(gpu_device_list: list[int]):
     """
     Check whether all GPUs in list support FA2
diff --git a/backends/exllamav2/version.py b/backends/exllamav2/version.py
new file mode 100644
index 00000000..fc4afb1e
--- /dev/null
+++ b/backends/exllamav2/version.py
@@ -0,0 +1,39 @@
+import platform
+from packaging import version
+from importlib.metadata import version as package_version
+from loguru import logger
+from common.optional_dependencies import dependencies
+
+
+def check_exllama_version():
+    """Verifies the exllama version"""
+
+    install_message = (
+        "Please update your environment by running an update script "
+        "(update_scripts/"
+        f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
+        "Or you can manually run a requirements update "
+        "using the following command:\n\n"
+        "For CUDA 12.1:\n"
+        "pip install --upgrade .[cu121]\n\n"
+        "For CUDA 11.8:\n"
+        "pip install --upgrade .[cu118]\n\n"
+        "For ROCm:\n"
+        "pip install --upgrade .[amd]\n\n"
+    )
+
+    if not dependencies.exl2:
+        raise SystemExit(("Exllamav2 is not installed.\n" + install_message))
+
+    required_version = version.parse("0.2.2")
+    current_version = version.parse(package_version("exllamav2").split("+")[0])
+
+    unsupported_message = (
+        f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
+        f"or greater. Your current version is {current_version}.\n" + install_message
+    )
+
+    if current_version < required_version:
+        raise SystemExit(unsupported_message)
+    else:
+        logger.info(f"ExllamaV2 version: {current_version}")
diff --git a/backends/infinity/model.py b/backends/infinity/model.py
index c48a42c2..b593fae3 100644
--- a/backends/infinity/model.py
+++ b/backends/infinity/model.py
@@ -5,16 +5,12 @@
 from typing import List, Optional
 
 from common.utils import unwrap
+from common.optional_dependencies import dependencies
 
 # Conditionally import infinity to sidestep its logger
-has_infinity_emb: bool = False
-try:
+if dependencies.extras:
     from infinity_emb import EngineArgs, AsyncEmbeddingEngine
 
-    has_infinity_emb = True
-except ImportError:
-    pass
-
 
 class InfinityContainer:
     model_dir: pathlib.Path
@@ -23,7 +19,7 @@ class InfinityContainer:
 
     # Conditionally set the type hint based on importablity
     # TODO: Clean this up
-    if has_infinity_emb:
+    if dependencies.extras:
         engine: Optional[AsyncEmbeddingEngine] = None
     else:
         engine = None
diff --git a/common/actions.py b/common/actions.py
index 9dcd5dd4..c7f0a717 100644
--- a/common/actions.py
+++ b/common/actions.py
@@ -3,13 +3,12 @@
 from common.tabby_config import config, generate_config_file
 from endpoints.server import export_openapi
-from endpoints.utils import do_export_openapi
 
 
 def branch_to_actions() -> bool:
     """Checks if a optional action needs to be run."""
 
-    if config.actions.export_openapi or do_export_openapi:
+    if config.actions.export_openapi:
         openapi_json = export_openapi()
 
         with open(config.actions.openapi_export_path, "w") as f:
diff --git a/common/model.py b/common/model.py
index 5fdfc5bb..296f6847 100644
--- a/common/model.py
+++ b/common/model.py
@@ -13,22 +13,20 @@
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
-from endpoints.utils import do_export_openapi
+from common.optional_dependencies import dependencies
 
-if not do_export_openapi:
+if dependencies.exl2:
     from backends.exllamav2.model import ExllamaV2Container
 
     # Global model container
     container: Optional[ExllamaV2Container] = None
     embeddings_container = None
 
-    # Type hint the infinity emb container if it exists
-    from backends.infinity.model import has_infinity_emb
 
-    if has_infinity_emb:
-        from backends.infinity.model import InfinityContainer
+if dependencies.extras:
+    from backends.infinity.model import InfinityContainer
 
-        embeddings_container: Optional[InfinityContainer] = None
+    embeddings_container: Optional[InfinityContainer] = None
 
 
 class ModelType(Enum):
@@ -121,7 +119,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
     global embeddings_container
 
     # Break out if infinity isn't installed
-    if not has_infinity_emb:
+    if not dependencies.extras:
         raise ImportError(
             "Skipping embeddings because infinity-emb is not installed.\n"
             "Please run the following command in your environment "
diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py
new file mode 100644
index 00000000..fd2efe3e
--- /dev/null
+++ b/common/optional_dependencies.py
@@ -0,0 +1,52 @@
+"""Construct a model of all optional dependencies"""
+
+import importlib.util
+from pydantic import BaseModel, computed_field
+
+
+# Declare the exported parts of this module
+__all__ = ["dependencies"]
+
+
+class DependenciesModel(BaseModel):
+    """Model of which optional dependencies are installed."""
+
+    torch: bool
+    exllamav2: bool
+    flash_attn: bool
+    outlines: bool
+    infinity_emb: bool
+    sentence_transformers: bool
+
+    @computed_field
+    @property
+    def extras(self) -> bool:
+        return self.outlines and self.infinity_emb and self.sentence_transformers
+
+    @computed_field
+    @property
+    def exl2(self) -> bool:
+        return self.torch and self.exllamav2 and self.flash_attn
+
+
+def is_installed(package_name: str) -> bool:
+    """Utility function to check if a package is installed."""
+
+    spec = importlib.util.find_spec(package_name)
+    return spec is not None
+
+
+def get_installed_deps() -> DependenciesModel:
+    """Check if optional dependencies are installed by looping over the fields."""
+
+    fields = DependenciesModel.model_fields
+
+    installed_deps = {}
+
+    for field_name in fields.keys():
+        installed_deps[field_name] = is_installed(field_name)
+
+    return DependenciesModel(**installed_deps)
+
+
+dependencies = get_installed_deps()
diff --git a/endpoints/utils.py b/endpoints/utils.py
deleted file mode 100644
index 291fe349..00000000
--- a/endpoints/utils.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import os
-
-do_export_openapi = os.getenv("EXPORT_OPENAPI", "").lower() in ("true", "1")
diff --git a/main.py b/main.py
index b83fda70..06db5d53 100644
--- a/main.py
+++ b/main.py
@@ -17,10 +17,8 @@
 from common.signals import signal_handler
 from common.tabby_config import config
 from endpoints.server import start_api
-from endpoints.utils import do_export_openapi
 
-if not do_export_openapi:
-    from backends.exllamav2.utils import check_exllama_version
+from backends.exllamav2.version import check_exllama_version
 
 
 async def entrypoint_async():
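
For reviewers who want to sanity-check the new dependency model locally, here is a minimal sketch; it is an illustration, not part of the patch. It assumes the repository root is on PYTHONPATH and pydantic v2 is installed, and the printed values are only an example for an environment without the embedding extras.

    # Inspect which optional dependencies were detected at import time.
    # Fields declared with @computed_field (extras, exl2) are included
    # in model_dump() alongside the raw per-package flags.
    from common.optional_dependencies import dependencies

    print(dependencies.model_dump())
    # Example output (machine-dependent):
    # {'torch': True, 'exllamav2': True, 'flash_attn': True,
    #  'outlines': False, 'infinity_emb': False,
    #  'sentence_transformers': False, 'extras': False, 'exl2': True}

    # Gate optional imports the same way the patched modules do:
    if dependencies.exl2:
        from backends.exllamav2.model import ExllamaV2Container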