Dependencies: Switch to pyproject.toml #88

Merged: 9 commits, Mar 20, 2024

Changes from all commits

2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
@@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
pip install .[dev]
- name: Format and show diff with ruff
run: |
ruff format --diff
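
Note: `pip install .[dev]` only resolves if pyproject.toml declares a `dev` optional-dependency group. The group's real contents are not part of this diff; a minimal sketch of the shape it would take (project metadata and package list are illustrative only):

```toml
# pyproject.toml (sketch) -- the extra that `pip install .[dev]` installs
[project]
name = "my-project"          # placeholder, not the repository's actual metadata
version = "0.0.1"

[project.optional-dependencies]
dev = [
    "ruff",                  # formatter/linter used by the workflow above
]
```
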
3 changes: 3 additions & 0 deletions .gitignore
@@ -196,3 +196,6 @@ templates/*
# Sampler overrides folder
sampler_overrides/*
!sampler_overrides/sample_preset.yml

# Gpu lib preferences file
gpu_lib.txt
111 changes: 0 additions & 111 deletions .ruff.toml

This file was deleted.
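
Note: with .ruff.toml gone, the ruff configuration presumably moves into pyproject.toml under `[tool.ruff]`. The actual settings are not visible in this PR excerpt; a hedged sketch of that section's shape (values are illustrative):

```toml
# pyproject.toml (sketch) -- stand-in for the deleted .ruff.toml
[tool.ruff]
line-length = 88
target-version = "py310"

[tool.ruff.lint]
select = ["E", "F"]
```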

20 changes: 16 additions & 4 deletions backends/exllamav2/model.py
@@ -2,6 +2,7 @@

import gc
import pathlib
import threading
import time

import torch
@@ -389,7 +390,12 @@ def progress(loaded_modules: int, total_modules: int)
# Notify that the model is being loaded
self.model_is_loading = True

# Load tokenizer
# Reset tokenizer namespace vars and create a tokenizer
ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
ExLlamaV2Tokenizer.extended_id_to_piece = {}
ExLlamaV2Tokenizer.extended_piece_to_id = {}

self.tokenizer = ExLlamaV2Tokenizer(self.config)

# Calculate autosplit reserve for all GPUs
@@ -623,14 +629,18 @@ def check_unsupported_settings(self, **kwargs):

return kwargs

async def generate_gen(self, prompt: str, **kwargs):
async def generate_gen(
self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
):
"""Basic async wrapper for completion generator"""

sync_generator = self.generate_gen_sync(prompt, **kwargs)
sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
async for value in iterate_in_threadpool(sync_generator):
yield value

def generate_gen_sync(self, prompt: str, **kwargs):
def generate_gen_sync(
self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
):
"""
Create generator function for prompt completion.

@@ -893,6 +903,7 @@ def generate_gen_sync(self, prompt: str, **kwargs):
return_probabilities=request_logprobs > 0,
return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0,
abort_event=abort_event,
)
else:
self.generator.begin_stream_ex(
@@ -903,6 +914,7 @@ def generate_gen_sync(self, prompt: str, **kwargs):
return_probabilities=request_logprobs > 0,
return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0,
abort_event=abort_event,
)

# Reset offsets for subsequent passes if the context is truncated
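
Note on the model.py changes above. The four assignments before the tokenizer is constructed reset what appear to be class-level caches on ExLlamaV2Tokenizer; class attributes are shared by every instance, so without the reset a newly loaded model could inherit piece/id mappings left over from the previous one. A generic illustration of that sharing (not the exllamav2 API):

```python
class Tok:
    # Class-level cache: a single dict object shared by every instance.
    piece_to_id: dict = {}


first = Tok()
first.piece_to_id["<custom>"] = 42   # mutates the shared class-level dict

second = Tok()                       # "fresh" tokenizer for the next model...
print(second.piece_to_id)            # {'<custom>': 42} -- stale entry leaks through

Tok.piece_to_id = {}                 # an explicit reset, as in the diff above
print(Tok().piece_to_id)             # {}
```

The new abort_event parameter exists because generate_gen_sync runs on a worker thread (generate_gen drains it through iterate_in_threadpool), and cancelling the awaiting asyncio task does not interrupt that thread; a threading.Event is therefore handed down so the in-thread loop (here, exllamav2's begin_stream_ex) can stop early. A simplified, self-contained sketch of that pattern, with every name below illustrative rather than the project's:

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor

_DONE = object()  # sentinel marking generator exhaustion


def generate_sync(prompt: str, abort_event: threading.Event):
    """Stand-in for the worker-thread token loop."""
    for i in range(1_000_000):
        if abort_event.is_set():   # polled each step, like the real streaming loop
            return
        time.sleep(0.001)          # pretend this is a decode step
        yield f"token-{i}"


async def generate(prompt: str, abort_event: threading.Event):
    """Async wrapper: pull items from the sync generator off the event loop."""
    loop = asyncio.get_running_loop()
    sync_gen = generate_sync(prompt, abort_event)
    with ThreadPoolExecutor(max_workers=1) as pool:
        while True:
            item = await loop.run_in_executor(pool, next, sync_gen, _DONE)
            if item is _DONE:
                return
            yield item


async def stream(abort_event: threading.Event):
    """Mirrors the endpoint handlers below: map cancellation to abort_event.set()."""
    try:
        async for token in generate("hello", abort_event):
            print(token)
    except asyncio.CancelledError:
        abort_event.set()          # tell the worker thread to stop generating
        raise


async def main():
    abort_event = threading.Event()
    task = asyncio.create_task(stream(abort_event))
    await asyncio.sleep(0.05)      # pretend the client disconnects early
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    print("abort flag set:", abort_event.is_set())


asyncio.run(main())
```

In this toy the event mostly documents the contract (after cancellation nothing pulls from the generator anyway); in the real backend it gives the in-thread generator a way to stop work that is already underway once the client has disconnected.
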
2 changes: 1 addition & 1 deletion backends/exllamav2/utils.py
@@ -6,7 +6,7 @@
def check_exllama_version():
"""Verifies the exllama version"""

required_version = version.parse("0.0.15")
required_version = version.parse("0.0.16")
current_version = version.parse(package_version("exllamav2").split("+")[0])

if current_version < required_version:
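
Note: the guard above compares the installed exllamav2 release against a minimum after stripping any local version segment (the part after "+", such as a "+cu121" wheel tag). A standalone sketch of that pattern; the import behind package_version and the error branch are outside this hunk, so both are assumptions here:

```python
from importlib.metadata import version as package_version  # assumed to match the module above

from packaging import version


def check_min_version(package: str, minimum: str) -> None:
    """Sketch of the version guard: fail if an older release is installed."""
    # "0.0.16+cu121" -> "0.0.16": drop the local segment before comparing.
    current = version.parse(package_version(package).split("+")[0])
    if current < version.parse(minimum):
        # Hypothetical error handling; the real branch is not shown in the hunk.
        raise RuntimeError(f"{package} {current} is installed, but >= {minimum} is required")


# Example call against a package that is certainly importable here:
check_min_version("packaging", "20.0")
```
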
7 changes: 6 additions & 1 deletion endpoints/OAI/utils/chat_completion.py
@@ -2,6 +2,7 @@

from asyncio import CancelledError
import pathlib
import threading
from typing import Optional
from uuid import uuid4

@@ -161,8 +162,11 @@ async def stream_generate_chat_completion(
"""Generator for the generation process."""
try:
const_id = f"chatcmpl-{uuid4().hex}"
abort_event = threading.Event()

new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
new_generation = model.container.generate_gen(
prompt, abort_event, **data.to_gen_params()
)
async for generation in new_generation:
response = _create_stream_chunk(const_id, generation, model_path.name)

@@ -174,6 +178,7 @@ async def stream_generate_chat_completion(
except CancelledError:
# Get out if the request gets disconnected

abort_event.set()
handle_request_disconnect("Chat completion generation cancelled by user.")
except Exception:
yield get_generator_error(
6 changes: 5 additions & 1 deletion endpoints/OAI/utils/completion.py
@@ -2,6 +2,7 @@

import pathlib
from asyncio import CancelledError
import threading
from fastapi import HTTPException
from typing import Optional

@@ -64,8 +65,10 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
"""Streaming generation for completions."""

try:
abort_event = threading.Event()

new_generation = model.container.generate_gen(
data.prompt, **data.to_gen_params()
data.prompt, abort_event, **data.to_gen_params()
)
async for generation in new_generation:
response = _create_response(generation, model_path.name)
@@ -78,6 +81,7 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
except CancelledError:
# Get out if the request gets disconnected

abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
except Exception:
yield get_generator_error(