New API overrides and formatting changes #54

Merged · 7 commits · Jan 25, 2024
Showing changes from 5 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -192,3 +192,7 @@ templates/*
 !templates/place_your_templates_here.txt
 !templates/alpaca.jinja
 !templates/chatml.jinja
+
+# Sampler overrides folder
+sampler_overrides/*
+!sampler_overrides/sample_preset.yml
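
The whitelisted sample_preset.yml itself is not part of this diff. As a purely hypothetical sketch, a preset file would presumably mirror the per-key override shape documented by SamplerOverrideSwitchRequest further down:

```yaml
# Hypothetical contents; the real sampler_overrides/sample_preset.yml is not
# shown in this PR. Each top-level key names a sampler parameter; "override"
# is the value to apply, and "force" presumably decides whether it wins over
# a client-supplied value.
top_p:
  override: 1.5
  force: false
```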
5 changes: 3 additions & 2 deletions OAI/types/chat_completion.py
@@ -1,7 +1,8 @@
-from uuid import uuid4
-from time import time
 from pydantic import BaseModel, Field
+from time import time
 from typing import Union, List, Optional, Dict
+from uuid import uuid4
+
 from OAI.types.common import UsageStats, CommonCompletionRequest
 
 
92 changes: 6 additions & 86 deletions OAI/types/common.py
@@ -1,6 +1,8 @@
 """ Common types for OAI. """
-from pydantic import BaseModel, Field, AliasChoices
-from typing import List, Dict, Optional, Union
+from pydantic import BaseModel, Field
+from typing import List, Dict, Optional
+
+from common.sampling import SamplerParams
 
 
 class LogProbs(BaseModel):
@@ -20,7 +22,7 @@ class UsageStats(BaseModel):
     total_tokens: int
 
 
-class CommonCompletionRequest(BaseModel):
+class CommonCompletionRequest(SamplerParams):
     """Represents a common completion request."""
 
     # Model information
@@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel):
         description="Not parsed. Only used for OAI compliance.", default=None
     )
 
-    # Generation info
-    # seed: Optional[int] = -1
+    # Generation info (remainder is in SamplerParams superclass)
     stream: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = []
-
-    # Default to 150 as 16 makes no sense as a default
-    max_tokens: Optional[int] = 150
-
-    # Sampling params
-    token_healing: Optional[bool] = False
-    temperature: Optional[float] = 1.0
-    temperature_last: Optional[bool] = False
-    top_k: Optional[int] = 0
-    top_p: Optional[float] = 1.0
-    top_a: Optional[float] = 0.0
-    min_p: Optional[float] = 0.0
-    tfs: Optional[float] = 1.0
-    frequency_penalty: Optional[float] = 0.0
-    presence_penalty: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = 1.0
-    repetition_decay: Optional[int] = 0
-    mirostat_mode: Optional[int] = 0
-    mirostat_tau: Optional[float] = 1.5
-    mirostat_eta: Optional[float] = 0.1
-    add_bos_token: Optional[bool] = True
-    ban_eos_token: Optional[bool] = False
-    logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
-    negative_prompt: Optional[str] = None
-
-    # Aliased variables
-    typical: Optional[float] = Field(
-        default=1.0,
-        validation_alias=AliasChoices("typical", "typical_p"),
-        description="Aliases: typical_p",
-    )
-
-    penalty_range: Optional[int] = Field(
-        default=-1,
-        validation_alias=AliasChoices(
-            "penalty_range",
-            "repetition_range",
-            "repetition_penalty_range",
-        ),
-        description="Aliases: repetition_range, repetition_penalty_range",
-    )
-
-    cfg_scale: Optional[float] = Field(
-        default=1.0,
-        validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
-        description="Aliases: guidance_scale",
-    )
-
-    def to_gen_params(self):
-        """Converts to internal generation parameters."""
-        # Convert stop to an array of strings
-        if isinstance(self.stop, str):
-            self.stop = [self.stop]
-
-        return {
-            "stop": self.stop,
-            "max_tokens": self.max_tokens,
-            "add_bos_token": self.add_bos_token,
-            "ban_eos_token": self.ban_eos_token,
-            "token_healing": self.token_healing,
-            "logit_bias": self.logit_bias,
-            "temperature": self.temperature,
-            "temperature_last": self.temperature_last,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "top_a": self.top_a,
-            "typical": self.typical,
-            "min_p": self.min_p,
-            "tfs": self.tfs,
-            "frequency_penalty": self.frequency_penalty,
-            "presence_penalty": self.presence_penalty,
-            "repetition_penalty": self.repetition_penalty,
-            "penalty_range": self.penalty_range,
-            "repetition_decay": self.repetition_decay,
-            "mirostat": self.mirostat_mode == 2,
-            "mirostat_tau": self.mirostat_tau,
-            "mirostat_eta": self.mirostat_eta,
-            "cfg_scale": self.cfg_scale,
-            "negative_prompt": self.negative_prompt,
-        }
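
For orientation: the fields deleted above move into the new SamplerParams model in common/sampling.py, which this diff does not include. A minimal sketch of the refactor's shape, assuming the fields keep their old names and defaults:

```python
from typing import Optional

from pydantic import BaseModel


class SamplerParams(BaseModel):
    """Sketch only; the real class lives in common.sampling."""

    max_tokens: Optional[int] = 150
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_p: Optional[float] = 1.0
    # ...the remaining sampler fields deleted above, plus to_gen_params()


class CommonCompletionRequest(SamplerParams):
    """Request-level fields stay here; sampling comes from the superclass."""

    stream: Optional[bool] = False
```

This keeps the OAI request models thin and gives the sampler override presets below a single source of truth for sampler defaults.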
3 changes: 1 addition & 2 deletions OAI/types/completion.py
@@ -1,10 +1,9 @@
 """ Completion API protocols """
-from pydantic import BaseModel, Field
 from time import time
 from typing import List, Optional, Union
 from uuid import uuid4
 
+from pydantic import BaseModel, Field
+
 from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
 
 
3 changes: 1 addition & 2 deletions OAI/types/lora.py
@@ -1,9 +1,8 @@
 """ Lora types """
-from pydantic import BaseModel, Field
 from time import time
 from typing import Optional, List
 
+from pydantic import BaseModel, Field
+
 
 class LoraCard(BaseModel):
     """Represents a single Lora card."""
5 changes: 2 additions & 3 deletions OAI/types/model.py
@@ -1,10 +1,9 @@
 """ Contains model card types. """
-from pydantic import BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Optional
 
+from pydantic import BaseModel, Field, ConfigDict
 
-from gen_logging import LogPreferences
+from common.gen_logging import LogPreferences
 
 
 class ModelCardParameters(BaseModel):
26 changes: 26 additions & 0 deletions OAI/types/sampler_overrides.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional
+
+
+class SamplerOverrideSwitchRequest(BaseModel):
+    """Sampler override switch request"""
+
+    preset: Optional[str] = Field(
+        default=None, description="Pass a sampler override preset name"
+    )
+
+    overrides: Optional[dict] = Field(
+        default=None,
+        description=(
+            "Sampling override parent takes in individual keys and overrides."
+            + " Ignored if preset is provided."
+        ),
+        examples=[
+            {
+                "top_p": {
+                    "override": 1.5,
+                    "force": False,
+                }
+            }
+        ],
+    )
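
As a usage sketch (the endpoint that consumes this model is outside this diff), these are the two request shapes the model accepts:

```python
from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest

# Switch by preset name (presumably resolved from the sampler_overrides/ folder)
by_preset = SamplerOverrideSwitchRequest(preset="sample_preset")

# Or pass raw overrides inline; these are ignored when a preset is given
inline = SamplerOverrideSwitchRequest(
    overrides={"top_p": {"override": 1.5, "force": False}}
)
```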
6 changes: 6 additions & 0 deletions OAI/types/template.py
@@ -7,3 +7,9 @@ class TemplateList(BaseModel):
 
     object: str = "list"
     data: List[str] = Field(default_factory=list)
+
+
+class TemplateSwitchRequest(BaseModel):
+    """Request to switch a template."""
+
+    name: str
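
For reference, a switch request is as small as it looks; assuming the name is resolved against the templates/ folder whitelisted in .gitignore above:

```python
from OAI.types.template import TemplateSwitchRequest

req = TemplateSwitchRequest(name="chatml")  # e.g. templates/chatml.jinja
```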
3 changes: 1 addition & 2 deletions OAI/types/token.py
@@ -1,7 +1,6 @@
 """ Tokenization types """
+from typing import List
 
 from pydantic import BaseModel
-from typing import List
 
 
 class CommonTokenRequest(BaseModel):
3 changes: 1 addition & 2 deletions OAI/utils_oai.py
@@ -2,6 +2,7 @@
 import pathlib
 from typing import Optional
 
+from common.utils import unwrap
 from OAI.types.chat_completion import (
     ChatCompletionMessage,
     ChatCompletionRespChoice,
@@ -14,8 +15,6 @@
 from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 
-from utils import unwrap
-
 
 def create_completion_response(
     text: str,
47 changes: 23 additions & 24 deletions model.py → backends/exllamav2/model.py
@@ -13,25 +13,25 @@
     ExLlamaV2Lora,
 )
 from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
 
-from gen_logging import log_generation_params, log_prompt, log_response
-from typing import List, Optional, Union
-from templating import (
+
+from common.gen_logging import log_generation_params, log_prompt, log_response
+from common.templating import (
     PromptTemplate,
     find_template_from_model,
     get_template_from_model_json,
     get_template_from_file,
 )
-from utils import coalesce, unwrap
-from logger import init_logger
+from common.utils import coalesce, unwrap
+from common.logger import init_logger
 
 logger = init_logger(__name__)
 
 # Bytes to reserve on first device when loading with auto split
 AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
 
 
-class ModelContainer:
+class ExllamaV2Container:
     """The model container class for ExLlamaV2 models."""
 
     config: Optional[ExLlamaV2Config] = None
@@ -163,30 +163,27 @@ def progress(loaded_modules: int, total_modules: int,
         if prompt_template_name:
             logger.info("Loading prompt template with name " f"{prompt_template_name}")
             # Read the template
-            self.prompt_template = get_template_from_file(prompt_template_name)
-        else:
-            # Then try finding the template from the tokenizer_config.json
-            self.prompt_template = get_template_from_model_json(
-                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                "chat_template",
-                "from_tokenizer_config",
-            )
+            try:
+                self.prompt_template = get_template_from_file(prompt_template_name)
+            except FileNotFoundError:
+                self.prompt_template = None
 
-        # Try finding the chat template from the model's config.json
-        # TODO: This may not even be used with huggingface models,
-        #       mark for removal.
         if self.prompt_template is None:
+            # Then try finding the template from the tokenizer_config.json
+            try:
> Review thread on the `try:` line above
> Contributor: This implementation means I need to set a (dummy) prompt_template in the config even if it should use the template from tokenizer_config.json. This is probably just wrong indentation?
> Contributor: Oh, and even that does not work, because it gets overridden in the next try block below that.
> Member (author): Fixed in the latest commit.
                 self.prompt_template = get_template_from_model_json(
-                    pathlib.Path(self.config.model_config),
+                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                     "chat_template",
-                    "from_model_config",
+                    "from_tokenizer_config",
                 )
+            except FileNotFoundError:
+                self.prompt_template = None
 
         # If that fails, attempt fetching from model name
         if self.prompt_template is None:
             try:
                 template_match = find_template_from_model(model_directory)
-                if template_match:
-                    self.prompt_template = get_template_from_file(template_match)
+                self.prompt_template = get_template_from_file(template_match)
             except (LookupError, FileNotFoundError):
                 self.prompt_template = None
 
         # Catch all for template lookup errors
         if self.prompt_template:
@@ -557,7 +554,9 @@ def generate_gen(self, prompt: str, **kwargs):
         token_healing = unwrap(kwargs.get("token_healing"), False)
         max_tokens = unwrap(kwargs.get("max_tokens"), 150)
         stream_interval = unwrap(kwargs.get("stream_interval"), 0)
-        generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)
+        generate_window = max(
+            unwrap(kwargs.get("generate_window"), 512), max_tokens // 8
+        )
 
         # Sampler settings
         gen_settings = ExLlamaV2Sampler.Settings()
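
Two notes on this file, offered as sketches rather than readings of the final code. First, template lookup is now a three-step fallback chain; per the review thread above, the revision shown here still lets step 2 clobber step 1, which the author says was fixed in a later commit. The intended order, condensed:

```python
# Condensed sketch of the intended lookup order (not the literal final code):
# 1) explicit template file named in the config
# 2) chat_template from the model's tokenizer_config.json
# 3) a bundled template matched from the model name
self.prompt_template = None

if prompt_template_name:
    try:
        self.prompt_template = get_template_from_file(prompt_template_name)
    except FileNotFoundError:
        pass

if self.prompt_template is None:
    try:
        self.prompt_template = get_template_from_model_json(
            pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
            "chat_template",
            "from_tokenizer_config",
        )
    except FileNotFoundError:
        pass

if self.prompt_template is None:
    try:
        template_match = find_template_from_model(model_directory)
        self.prompt_template = get_template_from_file(template_match)
    except (LookupError, FileNotFoundError):
        pass
```

Second, the generate_window change flips a min into a max. Before, min(512, max_tokens) shrank the window for short generations (the default max_tokens of 150 gave a window of 150); now max(512, max_tokens // 8) keeps a 512-token floor and scales up for long ones, e.g. max_tokens=8192 yields a window of 1024.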
File renamed without changes.
File renamed without changes.
7 changes: 3 additions & 4 deletions auth.py → common/auth.py
@@ -3,13 +3,12 @@
 application, it should be fine.
 """
 import secrets
+from typing import Optional
 
+import yaml
 from fastapi import Header, HTTPException
 from pydantic import BaseModel
-import yaml
-from typing import Optional
 
-from logger import init_logger
+from common.logger import init_logger
 
 logger = init_logger(__name__)
 
9 changes: 7 additions & 2 deletions config.py → common/config.py
@@ -1,8 +1,8 @@
 import yaml
 import pathlib
 
-from logger import init_logger
-from utils import unwrap
+from common.logger import init_logger
+from common.utils import unwrap
 
 logger = init_logger(__name__)
 
@@ -56,6 +56,11 @@ def override_config_from_args(args: dict):
     }
 
 
+def get_sampling_config():
+    """Returns the sampling parameter config from the global config"""
+    return unwrap(GLOBAL_CONFIG.get("sampling"), {})
+
+
 def get_model_config():
     """Returns the model config from the global config"""
     return unwrap(GLOBAL_CONFIG.get("model"), {})
2 changes: 1 addition & 1 deletion gen_logging.py → common/gen_logging.py
@@ -4,7 +4,7 @@
 from pydantic import BaseModel
 from typing import Dict, Optional
 
-from logger import init_logger
+from common.logger import init_logger
 
 logger = init_logger(__name__)
 
File renamed without changes.
File renamed without changes.