
Commit

Merge branch 'main' of https://github.com/theroyallab/tabbyapi into sampling-rework
bdashore3 committed Oct 25, 2024
2 parents 402898b + 6e48bb4 commit a4e03ed
Showing 11 changed files with 85 additions and 77 deletions.
55 changes: 26 additions & 29 deletions backends/exllamav2/model.py
@@ -129,8 +129,27 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name

@@ -154,25 +173,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# MARK: User configuration

# Get cache mode
@@ -251,9 +251,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
else:
self.config.scale_alpha_value = rope_alpha

# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)

# Set max batch size to the config override
self.max_batch_size = unwrap(kwargs.get("max_batch_size"))

@@ -341,9 +338,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):

# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})

self.draft_config.max_seq_len = self.config.max_seq_len

self.draft_config.scale_pos_emb = unwrap(
@@ -387,9 +381,12 @@ async def set_model_overrides(self, **kwargs):
override_args = unwrap(yaml.load(contents), {})

# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
draft_override_args = unwrap(override_args.get("draft_model"), {})
if draft_override_args:
kwargs["draft_model"] = {
**draft_override_args,
**unwrap(kwargs.get("draft_model"), {}),
}

# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
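The merge in set_model_overrides follows plain dict-unpacking precedence: keys from the later mapping win, so user-supplied kwargs override values from the model's override file, and the draft settings are merged first so neither side's keys are dropped. A standalone sketch of that ordering with hypothetical values (not the repository's code):

```python
# Hypothetical override-file values and user-supplied kwargs
override_args = {"max_seq_len": 4096, "draft_model": {"draft_rope_alpha": 1.0}}
kwargs = {"max_seq_len": 8192, "draft_model": {"draft_model_name": "tiny-draft"}}

# Merge draft overrides beforehand so neither side's keys are lost
draft_override_args = override_args.get("draft_model", {})
if draft_override_args:
    kwargs["draft_model"] = {
        **draft_override_args,
        **(kwargs.get("draft_model") or {}),
    }

# Later unpacking wins, so user settings take precedence
merged_kwargs = {**override_args, **kwargs}
print(merged_kwargs["max_seq_len"])  # 8192 -- the user value wins
print(merged_kwargs["draft_model"])
# {'draft_rope_alpha': 1.0, 'draft_model_name': 'tiny-draft'}
```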
2 changes: 1 addition & 1 deletion backends/exllamav2/version.py
@@ -25,7 +25,7 @@ def check_exllama_version():
if not dependencies.exllamav2:
raise SystemExit(("Exllamav2 is not installed.\n" + install_message))

required_version = version.parse("0.2.2")
required_version = version.parse("0.2.3")
current_version = version.parse(package_version("exllamav2").split("+")[0])

unsupported_message = (
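For reference, the bumped floor is enforced by comparing the installed exllamav2 version (with any local "+..." build segment stripped) against 0.2.3. A minimal sketch of that comparison using packaging, with a hypothetical installed version string:

```python
from packaging import version

required_version = version.parse("0.2.3")

# Hypothetical installed build string; split("+")[0] drops the local
# segment so only the release part is compared, as in the check above.
installed = "0.2.3+cu121.torch2.4.0"
current_version = version.parse(installed.split("+")[0])

assert current_version >= required_version
```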
7 changes: 0 additions & 7 deletions common/config_models.py
@@ -290,13 +290,6 @@ class ModelConfig(BaseConfigModel):
),
ge=1,
)
fasttensors: Optional[bool] = Field(
False,
description=(
"Enables fasttensors to possibly increase model loading speeds "
"(default: False)."
),
)

_metadata: Metadata = PrivateAttr(Metadata())
model_config = ConfigDict(protected_namespaces=())
3 changes: 0 additions & 3 deletions config_sample.yml
@@ -135,9 +135,6 @@ model:
# WARNING: Don't set this unless you know what you're doing!
num_experts_per_token:

# Enables fasttensors to possibly increase model loading speeds (default: False).
fasttensors: false

# Options for draft models (speculative decoding)
# This will use more VRAM!
draft_model:
7 changes: 5 additions & 2 deletions endpoints/OAI/utils/completion.py
@@ -148,8 +148,11 @@ async def load_inline_model(model_name: str, request: Request):

return

# Load the model
await model.load_model(model_path)
# Load the model and also add draft dir
await model.load_model(
model_path,
draft_model=config.draft_model.model_dump(include={"draft_model_dir"}),
)


async def stream_generate_completion(
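Passing include= to model_dump limits the dump to the listed fields, so inline loads forward only the configured draft directory instead of the whole draft config. A self-contained Pydantic v2 sketch of that behavior with a made-up config model (not tabbyAPI's):

```python
from typing import Optional
from pydantic import BaseModel


class DraftModelConfig(BaseModel):
    # Hypothetical stand-in for config.draft_model
    draft_model_dir: str = "models"
    draft_model_name: Optional[str] = None
    draft_rope_alpha: Optional[float] = None


cfg = DraftModelConfig()
print(cfg.model_dump(include={"draft_model_dir"}))
# {'draft_model_dir': 'models'} -- the other fields are omitted entirely
```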
19 changes: 10 additions & 9 deletions endpoints/core/router.py
@@ -123,7 +123,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""

# Verify request parameters
if not data.name:
if not data.model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
@@ -132,11 +132,11 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)

model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / data.name
model_path = model_path / data.model_name

draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
if data.draft_model:
if not data.draft_model.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
@@ -301,7 +301,7 @@ async def load_embedding_model(
request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
# Verify request parameters
if not data.name:
if not data.embedding_model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
@@ -310,7 +310,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)

embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_dir / data.name
embedding_model_path = embedding_model_dir / data.embedding_model_name

if not embedding_model_path.exists():
error_message = handle_request_error(
@@ -441,7 +441,7 @@ async def list_templates(request: Request) -> TemplateList:
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template."""

if not data.name:
if not data.prompt_template_name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
@@ -450,11 +450,12 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)

try:
template_path = pathlib.Path("templates") / data.name
template_path = pathlib.Path("templates") / data.prompt_template_name
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
f"The template name {data.prompt_template_name} doesn't exist. "
+ "Check the spelling?",
exc_info=False,
).error.message

26 changes: 20 additions & 6 deletions endpoints/core/types/model.py
@@ -1,6 +1,6 @@
"""Contains model card types."""

from pydantic import BaseModel, Field, ConfigDict
from pydantic import AliasChoices, BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union

@@ -48,7 +48,10 @@ class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""

# Required
draft_model_name: str
draft_model_name: str = Field(
alias=AliasChoices("draft_model_name", "name"),
description="Aliases: name",
)

# Config arguments
draft_rope_scale: Optional[float] = None
@@ -63,8 +66,14 @@ class ModelLoadRequest(BaseModel):
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""

# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces=[])

# Required
name: str
model_name: str = Field(
alias=AliasChoices("model_name", "name"),
description="Aliases: name",
)

# Config arguments

@@ -106,15 +115,20 @@ class ModelLoadRequest(BaseModel):
chunk_size: Optional[int] = None
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
fasttensors: Optional[bool] = None

# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
draft_model: Optional[DraftModelLoadRequest] = Field(
default=None,
alias=AliasChoices("draft_model", "draft"),
)
skip_queue: Optional[bool] = False


class EmbeddingModelLoadRequest(BaseModel):
name: str
embedding_model_name: str = Field(
alias=AliasChoices("embedding_model_name", "name"),
description="Aliases: name",
)

# Set default from the config
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
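The alias choices above let request bodies use either the new explicit field names or the legacy short names. A self-contained Pydantic v2 sketch of the same pattern, using validation_alias (the documented parameter for accepting multiple inbound names) rather than the repository's exact field definitions:

```python
from pydantic import AliasChoices, BaseModel, ConfigDict, Field


class LoadRequestSketch(BaseModel):
    # protected_namespaces=() silences the warning for "model_"-prefixed fields
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = Field(
        validation_alias=AliasChoices("model_name", "name"),
        description="Aliases: name",
    )


# Both the new and the legacy key populate the same field
print(LoadRequestSketch.model_validate({"model_name": "llama"}).model_name)
print(LoadRequestSketch.model_validate({"name": "llama"}).model_name)
```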
7 changes: 5 additions & 2 deletions endpoints/core/types/template.py
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field
from typing import List


@@ -12,4 +12,7 @@ class TemplateList(BaseModel):
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""

name: str
prompt_template_name: str = Field(
alias=AliasChoices("prompt_template_name", "name"),
description="Aliases: name",
)
2 changes: 1 addition & 1 deletion endpoints/core/utils/model.py
@@ -104,7 +104,7 @@ async def stream_model_load(

# Set the draft model path if it exists
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_data["draft_model"]["draft_model_dir"] = draft_model_path

load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data
4 changes: 2 additions & 2 deletions main.py
@@ -69,8 +69,8 @@ async def entrypoint_async():
# TODO: remove model_dump()
await model.load_model(
model_path.resolve(),
**config.model.model_dump(),
draft=config.draft_model.model_dump(),
**config.model.model_dump(exclude_none=True),
draft_model=config.draft_model.model_dump(exclude_none=True),
)

# Load loras after loading the model
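exclude_none=True keeps unset optional fields out of the dumped dict, so None values from the config no longer override the loader's own defaults. A small illustration with a hypothetical config model:

```python
from typing import Optional
from pydantic import BaseModel, ConfigDict


class ModelConfigSketch(BaseModel):
    # Hypothetical subset of the real model config
    model_config = ConfigDict(protected_namespaces=())

    model_dir: str = "models"
    max_seq_len: Optional[int] = None
    rope_alpha: Optional[float] = None


cfg = ModelConfigSketch()
print(cfg.model_dump())
# {'model_dir': 'models', 'max_seq_len': None, 'rope_alpha': None}
print(cfg.model_dump(exclude_none=True))
# {'model_dir': 'models'}
```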
30 changes: 15 additions & 15 deletions pyproject.toml
@@ -68,12 +68,12 @@ cu121 = [
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu118.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu118.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu118.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu118.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu118.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu118.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",

# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
]

# MARK: Ruff options
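The `; platform_system == ...` suffixes on each wheel URL are PEP 508 environment markers; pip evaluates them against the running interpreter and installs only the matching exllamav2 build. A quick sketch of how such a marker evaluates, using packaging:

```python
from packaging.markers import Marker

marker = Marker(
    "platform_system == 'Linux' and platform_machine == 'x86_64' "
    "and python_version == '3.11'"
)

# Evaluates against the current interpreter by default
print(marker.evaluate())

# An explicit environment can be supplied to test other platforms
print(marker.evaluate({
    "platform_system": "Windows",
    "platform_machine": "AMD64",
    "python_version": "3.11",
}))  # False
```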
