API: Add template switching and unload endpoints
Templates can be switched and unloaded without reloading the entire
model.

Signed-off-by: kingbri <[email protected]>
kingbri1 committed Jan 23, 2024
1 parent 1a8198d commit 4cf231d
Showing 4 changed files with 61 additions and 25 deletions.
6 changes: 6 additions & 0 deletions OAI/types/template.py
@@ -7,3 +7,9 @@ class TemplateList(BaseModel):

     object: str = "list"
     data: List[str] = Field(default_factory=list)
+
+
+class TemplateSwitchRequest(BaseModel):
+    """Request to switch a template."""
+
+    name: str
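
Not part of the diff, just an illustration: the new model carries only the template name, so the body posted to the switch endpoint is a single-field JSON object. A minimal sketch ("alpaca" is a placeholder template name):

from OAI.types.template import TemplateSwitchRequest

# Equivalent to posting the JSON body {"name": "alpaca"}
payload = TemplateSwitchRequest(name="alpaca")
print(payload.name)  # -> alpaca
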
31 changes: 14 additions & 17 deletions backends/exllamav2/model.py
@@ -163,30 +163,27 @@ def progress(loaded_modules: int, total_modules: int,
if prompt_template_name:
logger.info("Loading prompt template with name " f"{prompt_template_name}")
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
# Then try finding the template from the tokenizer_config.json
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
"from_tokenizer_config",
)
try:
self.prompt_template = get_template_from_file(prompt_template_name)
except FileNotFoundError:
self.prompt_template = None

# Try finding the chat template from the model's config.json
# TODO: This may not even be used with huggingface models,
# mark for removal.
if self.prompt_template is None:
# Then try finding the template from the tokenizer_config.json
try:
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_config),
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
"from_model_config",
"from_tokenizer_config",
)
except FileNotFoundError:
self.prompt_template = None

# If that fails, attempt fetching from model name
if self.prompt_template is None:
try:
template_match = find_template_from_model(model_directory)
if template_match:
self.prompt_template = get_template_from_file(template_match)
self.prompt_template = get_template_from_file(template_match)
except (LookupError, FileNotFoundError):
self.prompt_template = None

# Catch all for template lookup errors
if self.prompt_template:
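
Taken together, the loader now resolves a template through a chain of guarded lookups. Below is a standalone sketch of that order; the resolve_prompt_template helper is invented here for illustration, the templating helpers are the ones from common/templating.py in this diff, and model_directory is assumed to be a pathlib.Path.

import pathlib

from common.templating import (
    find_template_from_model,
    get_template_from_file,
    get_template_from_model_json,
)


def resolve_prompt_template(prompt_template_name, model_directory: pathlib.Path):
    """Return a PromptTemplate, or None if every lookup fails."""

    # 1. An explicitly named file under templates/
    if prompt_template_name:
        try:
            return get_template_from_file(prompt_template_name)
        except FileNotFoundError:
            pass

    # 2. The chat_template key inside the model's tokenizer_config.json
    try:
        return get_template_from_model_json(
            model_directory / "tokenizer_config.json",
            "chat_template",
            "from_tokenizer_config",
        )
    except FileNotFoundError:
        pass

    # 3. A template file whose name appears in the model folder name
    try:
        return get_template_from_file(find_template_from_model(model_directory))
    except (LookupError, FileNotFoundError):
        return None
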
15 changes: 9 additions & 6 deletions common/templating.py
@@ -68,26 +68,29 @@ def find_template_from_model(model_path: pathlib.Path):
"""Find a matching template name from a model path."""
model_name = model_path.name
template_files = get_all_templates()

for filepath in template_files:
template_name = filepath.stem.lower()

# Check if the template name is present in the model name
if template_name in model_name.lower():
return template_name

return None
else:
raise LookupError("Could not find template from model name.")


def get_template_from_file(prompt_template_name: str):
"""Get a template from a jinja file."""

template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
if template_path.exists():
with open(template_path, "r", encoding="utf8") as raw_template:
return PromptTemplate(
name=prompt_template_name, template=raw_template.read()
)

return None
else:
# Let the user know if the template file isn't found
raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')


# Get a template from a JSON file
@@ -100,5 +103,5 @@ def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
chat_template = model_config.get(key)
if chat_template:
return PromptTemplate(name=name, template=chat_template)

return None
else:
raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
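
With these changes the templating helpers raise instead of returning None, so callers choose their own fallback explicitly. A small usage sketch (the template name and model path are placeholder values):

import pathlib

from common.templating import find_template_from_model, get_template_from_file

try:
    template = get_template_from_file("alpaca")
except FileNotFoundError:
    template = None  # fall back to another lookup

try:
    match = find_template_from_model(pathlib.Path("models/MyModel-7B"))
except LookupError:
    match = None  # no template file name appears in the model folder name
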
34 changes: 32 additions & 2 deletions main.py
@@ -27,7 +27,11 @@
 )
 from common.generators import call_with_semaphore, generate_with_semaphore
 from common.sampling import get_overrides_from_file
-from common.templating import get_all_templates, get_prompt_from_template
+from common.templating import (
+    get_all_templates,
+    get_prompt_from_template,
+    get_template_from_file,
+)
 from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
 from common.logger import init_logger
 from OAI.types.completion import CompletionRequest
@@ -39,7 +43,7 @@
     ModelLoadResponse,
     ModelCardParameters,
 )
-from OAI.types.template import TemplateList
+from OAI.types.template import TemplateList, TemplateSwitchRequest
 from OAI.types.token import (
     TokenEncodeRequest,
     TokenEncodeResponse,
@@ -258,6 +262,32 @@ async def get_templates():
     return TemplateList(data=template_strings)
 
 
+@app.post(
+    "/v1/template/switch",
+    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+)
+async def switch_template(data: TemplateSwitchRequest):
+    """Switch the currently loaded template"""
+    if not data.name:
+        raise HTTPException(400, "New template name not found.")
+
+    try:
+        template = get_template_from_file(data.name)
+        MODEL_CONTAINER.prompt_template = template
+    except FileNotFoundError as e:
+        raise HTTPException(400, "Template does not exist. Check the name?") from e
+
+
+@app.post(
+    "/v1/template/unload",
+    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+)
+async def unload_template():
+    """Unloads the currently selected template"""
+
+    MODEL_CONTAINER.prompt_template = None
+
+
 # Lora list endpoint
 @app.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
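
A client-side sketch of exercising the two new routes. The port, admin key value, and the "x-admin-key" header name are assumptions; match them to the server's auth configuration.

import requests

BASE_URL = "http://localhost:5000"
HEADERS = {"x-admin-key": "your-admin-key"}

# Swap the active prompt template without touching the loaded model
response = requests.post(
    f"{BASE_URL}/v1/template/switch",
    headers=HEADERS,
    json={"name": "alpaca"},
)
response.raise_for_status()  # a 400 means the named template file was not found

# Drop the currently selected template entirely
requests.post(f"{BASE_URL}/v1/template/unload", headers=HEADERS)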
