diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index 138e7167..d484a077 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -23,7 +23,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -r requirements-dev.txt
+          pip install .[dev]
       - name: Format and show diff with ruff
         run: |
           ruff format --diff
diff --git a/.gitignore b/.gitignore
index f77b7f9e..1c597fad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -196,3 +196,6 @@ templates/*
 # Sampler overrides folder
 sampler_overrides/*
 !sampler_overrides/sample_preset.yml
+
+# Gpu lib preferences file
+gpu_lib.txt
diff --git a/.ruff.toml b/.ruff.toml
deleted file mode 100644
index 53c7bd44..00000000
--- a/.ruff.toml
+++ /dev/null
@@ -1,111 +0,0 @@
-# Exclude a variety of commonly ignored directories.
-exclude = [
-    ".git",
-    ".git-rewrite",
-    ".mypy_cache",
-    ".pyenv",
-    ".pytest_cache",
-    ".ruff_cache",
-    ".venv",
-    ".vscode",
-    "__pypackages__",
-    "_build",
-    "build",
-    "dist",
-    "node_modules",
-    "site-packages",
-    "venv",
-]
-
-# Same as Black.
-line-length = 88
-indent-width = 4
-
-# Assume Python 3.10
-target-version = "py310"
-
-[lint]
-# Enable preview
-preview = true
-
-# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
-# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
-# McCabe complexity (`C901`) by default.
-# Enable flake8-bugbear (`B`) rules, in addition to the defaults.
-select = ["E4", "E7", "E9", "F", "B"]
-extend-select = [
-    "D419", # empty-docstring
-    "PLC2401", # non-ascii-name
-    "E501", # line-too-long
-    "W291", # trailing-whitespace
-    "PLC0414", # useless-import-alias
-    "E999", # syntax-error
-    "PLE0101", # return-in-init
-    "F706", # return-outside-function
-    "F704", # yield-outside-function
-    "PLE0116", # continue-in-finally
-    "PLE0117", # nonlocal-without-binding
-    "PLE0241", # duplicate-bases
-    "PLE0302", # unexpected-special-method-signature
-    "PLE0604", # invalid-all-object
-    "PLE0704", # misplaced-bare-raise
-    "PLE1205", # logging-too-many-args
-    "PLE1206", # logging-too-few-args
-    "PLE1307", # bad-string-format-type
-    "PLE1310", # bad-str-strip-call
-    "PLE1507", # invalid-envvar-value
-    "PLR0124", # comparison-with-itself
-    "PLR0202", # no-classmethod-decorator
-    "PLR0203", # no-staticmethod-decorator
-    "PLR0206", # property-with-parameters
-    "PLR1704", # redefined-argument-from-local
-    "PLR1711", # useless-return
-    "C416", # unnecessary-comprehension
-    "PLW0108", # unnecessary-lambda
-    "PLW0127", # self-assigning-variable
-    "PLW0129", # assert-on-string-literal
-    "PLW0602", # global-variable-not-assigned
-    "PLW0604", # global-at-module-level
-    "F401", # unused-import
-    "F841", # unused-variable
-    "E722", # bare-except
-    "PLW0711", # binary-op-exception
-    "PLW1501", # bad-open-mode
-    "PLW1508", # invalid-envvar-default
-    "PLW1509", # subprocess-popen-preexec-fn
-]
-ignore = [
-    "PLR6301", # no-self-use
-    "UP004", # useless-object-inheritance
-    "PLR0904", # too-many-public-methods
-    "PLR0911", # too-many-return-statements
-    "PLR0912", # too-many-branches
-    "PLR0913", # too-many-arguments
-    "PLR0914", # too-many-locals
-    "PLR0915", # too-many-statements
-    "PLR0916", # too-many-boolean-expressions
-    "PLW0120", # useless-else-on-loop
-    "PLW0406", # import-self
-    "PLW0603", # global-statement
-    "PLW1641", # eq-without-hash
-]
-
-# Allow fix for all enabled rules (when `--fix`) is provided.
-fixable = ["ALL"]
-unfixable = ["B"]
-
-# Allow unused variables when underscore-prefixed.
-dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
-
-[format]
-# Like Black, use double quotes for strings.
-quote-style = "double"
-
-# Like Black, indent with spaces, rather than tabs.
-indent-style = "space"
-
-# Like Black, respect magic trailing commas.
-skip-magic-trailing-comma = false
-
-# Like Black, automatically detect the appropriate line ending.
-line-ending = "auto"
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 1e05d041..c008b8de 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -2,6 +2,7 @@

 import gc
 import pathlib
+import threading
 import time

 import torch
@@ -389,7 +390,12 @@ def progress(loaded_modules: int, total_modules: int)
         # Notify that the model is being loaded
         self.model_is_loading = True

-        # Load tokenizer
+        # Reset tokenizer namespace vars and create a tokenizer
+        ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
+        ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
+        ExLlamaV2Tokenizer.extended_id_to_piece = {}
+        ExLlamaV2Tokenizer.extended_piece_to_id = {}
+
         self.tokenizer = ExLlamaV2Tokenizer(self.config)

         # Calculate autosplit reserve for all GPUs
@@ -623,14 +629,18 @@ def check_unsupported_settings(self, **kwargs):

         return kwargs

-    async def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(
+        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
+    ):
         """Basic async wrapper for completion generator"""

-        sync_generator = self.generate_gen_sync(prompt, **kwargs)
+        sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
         async for value in iterate_in_threadpool(sync_generator):
             yield value

-    def generate_gen_sync(self, prompt: str, **kwargs):
+    def generate_gen_sync(
+        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
+    ):
         """
         Create generator function for prompt completion.
@@ -893,6 +903,7 @@ def generate_gen_sync(self, prompt: str, **kwargs):
                 return_probabilities=request_logprobs > 0,
                 return_top_tokens=request_logprobs,
                 return_logits=request_logprobs > 0,
+                abort_event=abort_event,
             )
         else:
             self.generator.begin_stream_ex(
@@ -903,6 +914,7 @@ def generate_gen_sync(self, prompt: str, **kwargs):
                 return_probabilities=request_logprobs > 0,
                 return_top_tokens=request_logprobs,
                 return_logits=request_logprobs > 0,
+                abort_event=abort_event,
             )

         # Reset offsets for subsequent passes if the context is truncated
diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py
index 6d01ef7e..9b874a8a 100644
--- a/backends/exllamav2/utils.py
+++ b/backends/exllamav2/utils.py
@@ -6,7 +6,7 @@ def check_exllama_version():
     """Verifies the exllama version"""

-    required_version = version.parse("0.0.15")
+    required_version = version.parse("0.0.16")
     current_version = version.parse(package_version("exllamav2").split("+")[0])

     if current_version < required_version:
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index e33ec7f6..562e7365 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -2,6 +2,7 @@

 from asyncio import CancelledError
 import pathlib
+import threading
 from typing import Optional
 from uuid import uuid4

@@ -161,8 +162,11 @@ async def stream_generate_chat_completion(
     """Generator for the generation process."""

     try:
         const_id = f"chatcmpl-{uuid4().hex}"
+        abort_event = threading.Event()

-        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
+        new_generation = model.container.generate_gen(
+            prompt, abort_event, **data.to_gen_params()
+        )
         async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)
@@ -174,6 +178,7 @@

     except CancelledError:
         # Get out if the request gets disconnected
+        abort_event.set()
         handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index a98d1b76..02b7852b 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -2,6 +2,7 @@

 import pathlib
 from asyncio import CancelledError
+import threading

 from fastapi import HTTPException
 from typing import Optional
@@ -64,8 +65,10 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
     """Streaming generation for completions."""

     try:
+        abort_event = threading.Event()
+
         new_generation = model.container.generate_gen(
-            data.prompt, **data.to_gen_params()
+            data.prompt, abort_event, **data.to_gen_params()
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)
@@ -78,6 +81,7 @@

     except CancelledError:
         # Get out if the request gets disconnected
+        abort_event.set()
         handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..2cf961cc
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,197 @@
+[build-system]
+requires = [
+    "packaging",
+    "setuptools",
+    "wheel",
+]
+build-backend = "setuptools.build_meta"
+
+# We're not building the project itself
+[tool.setuptools]
+py-modules = []
+
+[project]
+name = "tabbyAPI"
+version = "0.0.1"
+description = "An OAI compatible exllamav2 API that's both lightweight and fast"
+requires-python = ">=3.10"
+dependencies = [
+    "fastapi >= 0.110.0",
+    "pydantic >= 2.0.0",
+    "PyYAML",
+    "rich",
+    "uvicorn >= 0.28.1",
+    "jinja2 >= 3.0.0",
+    "loguru",
+    "sse-starlette",
+]
+
+[project.urls]
+"Homepage" = "https://github.com/theroyallab/tabbyAPI"
+
+[project.optional-dependencies]
+dev = [
+    "ruff == 0.3.2"
+]
+cu121 = [
+    # Torch (Extra index URLs not supported in pyproject.toml)
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+
+    # Exl2
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+
+    # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+
+    # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+]
+cu118 = [
+    # Torch
+    "torch @ https://download.pytorch.org/whl/cu118/torch-2.2.1%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu118/torch-2.2.1%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/cu118/torch-2.2.1%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu118/torch-2.2.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+
+    # Exl2
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+cu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+
+    # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu118torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+]
+amd = [
+    # Torch
+    "torch @ https://download.pytorch.org/whl/rocm5.6/torch-2.2.1%2Brocm5.6-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/rocm5.6/torch-2.2.1%2Brocm5.6-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+
+    # Exl2
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+rocm5.6-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.16/exllamav2-0.0.16+rocm5.6-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+]
+
+# MARK: Ruff options
+
+[tool.ruff]
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".git",
+    ".git-rewrite",
+    ".mypy_cache",
+    ".pyenv",
+    ".pytest_cache",
+    ".ruff_cache",
+    ".venv",
+    ".vscode",
+    "__pypackages__",
+    "_build",
+    "build",
+    "dist",
+    "node_modules",
+    "site-packages",
+    "venv",
+]
+
+# Same as Black.
+line-length = 88
+indent-width = 4
+
+# Assume Python 3.10
+target-version = "py310"
+
+[tool.ruff.lint]
+# Enable preview
+preview = true
+
+# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
+# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
+# McCabe complexity (`C901`) by default.
+# Enable flake8-bugbear (`B`) rules, in addition to the defaults.
+select = ["E4", "E7", "E9", "F", "B"]
+extend-select = [
+    "D419", # empty-docstring
+    "PLC2401", # non-ascii-name
+    "E501", # line-too-long
+    "W291", # trailing-whitespace
+    "PLC0414", # useless-import-alias
+    "E999", # syntax-error
+    "PLE0101", # return-in-init
+    "F706", # return-outside-function
+    "F704", # yield-outside-function
+    "PLE0116", # continue-in-finally
+    "PLE0117", # nonlocal-without-binding
+    "PLE0241", # duplicate-bases
+    "PLE0302", # unexpected-special-method-signature
+    "PLE0604", # invalid-all-object
+    "PLE0704", # misplaced-bare-raise
+    "PLE1205", # logging-too-many-args
+    "PLE1206", # logging-too-few-args
+    "PLE1307", # bad-string-format-type
+    "PLE1310", # bad-str-strip-call
+    "PLE1507", # invalid-envvar-value
+    "PLR0124", # comparison-with-itself
+    "PLR0202", # no-classmethod-decorator
+    "PLR0203", # no-staticmethod-decorator
+    "PLR0206", # property-with-parameters
+    "PLR1704", # redefined-argument-from-local
+    "PLR1711", # useless-return
+    "C416", # unnecessary-comprehension
+    "PLW0108", # unnecessary-lambda
+    "PLW0127", # self-assigning-variable
+    "PLW0129", # assert-on-string-literal
+    "PLW0602", # global-variable-not-assigned
+    "PLW0604", # global-at-module-level
+    "F401", # unused-import
+    "F841", # unused-variable
+    "E722", # bare-except
+    "PLW0711", # binary-op-exception
+    "PLW1501", # bad-open-mode
+    "PLW1508", # invalid-envvar-default
+    "PLW1509", # subprocess-popen-preexec-fn
+]
+ignore = [
+    "PLR6301", # no-self-use
+    "UP004", # useless-object-inheritance
+    "PLR0904", # too-many-public-methods
+    "PLR0911", # too-many-return-statements
+    "PLR0912", # too-many-branches
+    "PLR0913", # too-many-arguments
+    "PLR0914", # too-many-locals
+    "PLR0915", # too-many-statements
+    "PLR0916", # too-many-boolean-expressions
+    "PLW0120", # useless-else-on-loop
+    "PLW0406", # import-self
+    "PLW0603", # global-statement
+    "PLW1641", # eq-without-hash
+]
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = ["B"]
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto" diff --git a/requirements-amd.txt b/requirements-amd.txt deleted file mode 100644 index fe4d2d14..00000000 --- a/requirements-amd.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Torch ---extra-index-url https://download.pytorch.org/whl/rocm5.6 -torch ~= 2.2 - -# Exllamav2 -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+rocm5.6-cp311-cp311-linux_x86_64.whl; python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+rocm5.6-cp310-cp310-linux_x86_64.whl; python_version == "3.10" - -# Pip dependencies -fastapi -pydantic >= 2.0.0 -PyYAML -rich -uvicorn -jinja2 >= 3.0.0 -loguru -sse-starlette diff --git a/requirements-cu118.txt b/requirements-cu118.txt deleted file mode 100644 index 3c3330bd..00000000 --- a/requirements-cu118.txt +++ /dev/null @@ -1,27 +0,0 @@ -# Torch ---extra-index-url https://download.pytorch.org/whl/cu118 -torch ~= 2.2 - -# Exllamav2 - -# Windows -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu118-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" - -# Linux -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu118-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu118-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" - -# Pip dependencies -fastapi -pydantic >= 2.0.0 -PyYAML -rich -uvicorn -jinja2 >= 3.0.0 -loguru -sse-starlette - -# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases -https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu118torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 336378e9..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,15 +0,0 @@ -# formatting -ruff==0.3.2 - -## Implement below dependencies when support is added - -# type checking -# mypy==0.991 -# types-PyYAML -# types-requests -# types-setuptools - -# testing -# pytest -# pytest-forked -# pytest-asyncio \ No newline at end of file diff --git a/requirements-nowheel.txt b/requirements-nowheel.txt deleted file mode 100644 index 5b8acd52..00000000 --- a/requirements-nowheel.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Pip dependencies -fastapi -pydantic >= 2.0.0 -PyYAML -rich -uvicorn -jinja2 >= 3.0.0 -loguru -sse-starlette diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 80c148c0..00000000 --- a/requirements.txt +++ /dev/null @@ -1,33 +0,0 @@ -# Torch ---extra-index-url https://download.pytorch.org/whl/cu121 -torch ~= 2.2 - -# Exllamav2 - -# Windows -https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" 
-https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
-
-# Linux
-https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.15+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
-
-# Pip dependencies
-fastapi
-pydantic >= 2.0.0
-PyYAML
-rich
-uvicorn
-jinja2 >= 3.0.0
-loguru
-sse-starlette
-
-# Flash attention v2
-
-# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
-https://github.com/bdashore3/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2.0cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/bdashore3/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
-
-# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
-https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
diff --git a/start.py b/start.py
index 82655c78..24ad5d5a 100644
--- a/start.py
+++ b/start.py
@@ -4,32 +4,96 @@
 import argparse
 import os
 import pathlib
+import platform
 import subprocess
+import sys

 from common.args import convert_args_to_dict, init_argparser


-def get_requirements_file():
+def get_user_choice(question, options_dict):
+    """
+    Gets user input in a commandline script.
+
+    Originally from: https://github.com/oobabooga/text-generation-webui/blob/main/one_click.py#L213
+    """
+
+    print()
+    print(question)
+    print()
+
+    for key, value in options_dict.items():
+        print(f"{key}) {value.get('pretty')}")
+
+    print()
+
+    choice = input("Input> ").upper()
+    while choice not in options_dict.keys():
+        print("Invalid choice. Please try again.")
+        choice = input("Input> ").upper()
+
+    return choice
+
+
+def get_install_features():
     """Fetches the appropriate requirements file depending on the GPU"""
-    requirements_name = "requirements-nowheel"
-    ROCM_PATH = os.environ.get("ROCM_PATH")
-    CUDA_PATH = os.environ.get("CUDA_PATH")
+    install_features = None
+    possible_features = ["cu121", "cu118", "amd"]
+
+    # Try getting the GPU lib from a file
+    saved_lib_path = pathlib.Path("gpu_lib.txt")
+    if saved_lib_path.exists():
+        with open(saved_lib_path.resolve(), "r") as f:
+            lib = f.readline()

-    # TODO: Check if the user has an AMD gpu on windows
-    if ROCM_PATH:
-        requirements_name = "requirements-amd"
+            # Assume default if the file is invalid
+            if lib not in possible_features:
+                print(
+                    f"WARN: GPU library {lib} not found. "
+                    "Skipping GPU-specific dependencies.\n"
+                    "WARN: Please delete gpu_lib.txt and restart "
+                    "if you want to change your selection."
+                )
+                return

-        # Also override env vars for ROCm support on non-supported GPUs
+        print(f"Using {lib} dependencies from your preferences.")
+        install_features = lib
+    else:
+        # Ask the user for the GPU lib
+        gpu_lib_choices = {
+            "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
+            "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
+            "C": {"pretty": "AMD", "internal": "amd"},
+        }
+        user_input = get_user_choice(
+            "Select your GPU. If you don't know, select Cuda 12.x (A)",
+            gpu_lib_choices,
+        )
+
+        install_features = gpu_lib_choices.get(user_input, {}).get("internal")
+
+        # Write to a file for subsequent runs
+        with open(saved_lib_path.resolve(), "w") as f:
+            f.write(install_features)
+            print(
+                "Saving your choice to gpu_lib.txt. "
+                "Delete this file and restart if you want to change your selection."
+            )
+
+    if install_features == "amd":
+        # Exit if using AMD and Windows
+        if platform.system() == "Windows":
+            print(
+                "ERROR: TabbyAPI does not support AMD and Windows. "
+                "Please use Linux and ROCm 5.6. Exiting."
+            )
+            sys.exit(0)
+
+        # Override env vars for ROCm support on non-supported GPUs
         os.environ["ROCM_PATH"] = "/opt/rocm"
         os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
         os.environ["HCC_AMDGPU_TARGET"] = "gfx1030"
-    elif CUDA_PATH:
-        cuda_version = pathlib.Path(CUDA_PATH).name
-        if "12" in cuda_version:
-            requirements_name = "requirements"
-        elif "11" in cuda_version:
-            requirements_name = "requirements-cu118"
-    return requirements_name
+    return install_features


 def add_start_args(parser: argparse.ArgumentParser):
@@ -60,10 +124,11 @@
     if args.ignore_upgrade:
         print("Ignoring pip dependency upgrade due to user request.")
     else:
-        requirements_file = (
-            "requirements-nowheel" if args.nowheel else get_requirements_file()
-        )
-        subprocess.run(["pip", "install", "-U", "-r", f"{requirements_file}.txt"])
+        install_features = None if args.nowheel else get_install_features()
+        features = f"[{install_features}]" if install_features else ""
+
+        # pip install .[features]
+        subprocess.run(["pip", "install", "-U", f".{features}"])

     # Import entrypoint after installing all requirements
     from main import entrypoint