-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Introduce prompt registry, support versioning in `Prompt.from_t…
…emplate` (#6) * first draft * add json roundtrip * use registry in prompt * almost there * remove env dependency * add tests * better config * remove positional args from registry * drop 3.9 * no need to use .parent() * mkdir -p
- Loading branch information
Showing
12 changed files
with
203 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,13 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from .cache import RenderCache | ||
from .config import config | ||
from .env import env | ||
from .prompt import AsyncPrompt, Prompt | ||
|
||
__all__ = ( | ||
"env", | ||
"Prompt", | ||
"AsyncPrompt", | ||
"RenderCache", | ||
"config", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
import os | ||
|
||
from platformdirs import user_data_path | ||
|
||
from .utils import strtobool | ||
|
||
async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) | ||
|
||
class BanksConfig: | ||
ASYNC_ENABLED = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) | ||
USER_DATA_PATH = user_data_path(os.environ.get("BANKS_USER_DATA_PATH", "banks")) | ||
|
||
|
||
config = BanksConfig() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,16 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
import os | ||
from pathlib import Path | ||
|
||
from jinja2 import Environment, select_autoescape | ||
|
||
from .config import async_enabled | ||
from .config import config | ||
from .filters import lemmatize | ||
from .loader import MultiLoader | ||
from .registries import FileTemplateRegistry | ||
from .registry import TemplateRegistry | ||
|
||
|
||
def _add_extensions(env): | ||
|
@@ -20,6 +25,13 @@ def _add_extensions(env): | |
env.add_extension(HFInferenceEndpointsExtension) | ||
|
||
|
||
def _add_default_templates(r: TemplateRegistry): | ||
templates_dir = Path(os.path.dirname(__file__)) / "templates" | ||
for tpl_file in templates_dir.glob("*.jinja"): | ||
r.set(name=tpl_file.name, prompt=tpl_file.read_text()) | ||
r.save() | ||
|
||
|
||
# Init the Jinja env | ||
env = Environment( | ||
loader=MultiLoader(), | ||
|
@@ -29,17 +41,14 @@ def _add_extensions(env): | |
), | ||
trim_blocks=True, | ||
lstrip_blocks=True, | ||
enable_async=bool(async_enabled), | ||
enable_async=bool(config.ASYNC_ENABLED), | ||
) | ||
|
||
# Setup custom filters and default extensions | ||
env.filters["lemmatize"] = lemmatize | ||
_add_extensions(env) | ||
# Init the Template registry | ||
registry = FileTemplateRegistry(config.USER_DATA_PATH) | ||
|
||
|
||
def with_env(cls): | ||
""" | ||
A decorator that adds an `env` attribute to the decorated class | ||
""" | ||
cls.env = env | ||
return cls | ||
# Setup custom filters and defaults | ||
env.filters["lemmatize"] = lemmatize | ||
_add_extensions(env) | ||
_add_default_templates(registry) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from .file import FileTemplateRegistry | ||
|
||
__all__ = ("FileTemplateRegistry",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from pathlib import Path | ||
|
||
from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError | ||
|
||
|
||
class FileTemplateRegistry: | ||
def __init__(self, user_data_path: Path) -> None: | ||
self._index_fpath: Path = user_data_path / "index.json" | ||
self._index: PromptTemplateIndex = PromptTemplateIndex(templates=[]) | ||
try: | ||
self._index = PromptTemplateIndex.model_validate_json(self._index_fpath.read_text()) | ||
except FileNotFoundError: | ||
# init the user data folder | ||
user_data_path.mkdir(parents=True, exist_ok=True) | ||
|
||
@staticmethod | ||
def _make_id(name: str, version: str | None): | ||
if version: | ||
return f"{name}:{version}" | ||
return name | ||
|
||
def save(self) -> None: | ||
with open(self._index_fpath, "w") as f: | ||
f.write(self._index.model_dump_json()) | ||
|
||
def get(self, name: str, version: str | None = None) -> "PromptTemplate": | ||
tpl_id = self._make_id(name, version) | ||
for tpl in self._index.templates: | ||
if tpl_id == tpl.id: | ||
return tpl | ||
|
||
msg = f"cannot find template '{tpl_id}'" | ||
raise TemplateNotFoundError(msg) | ||
|
||
def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False): | ||
try: | ||
tpl = self.get(name, version) | ||
if overwrite: | ||
tpl.prompt = prompt | ||
return | ||
except TemplateNotFoundError: | ||
tpl_id = self._make_id(name, version) | ||
tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt) | ||
self._index.templates.append(tpl) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from typing import Protocol | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class TemplateNotFoundError(Exception): ... | ||
|
||
|
||
class PromptTemplate(BaseModel): | ||
id: str | ||
name: str | ||
version: str | ||
prompt: str | ||
|
||
|
||
class PromptTemplateIndex(BaseModel): | ||
templates: list[PromptTemplate] | ||
|
||
|
||
class TemplateRegistry(Protocol): | ||
def save(self) -> None: ... | ||
|
||
def get(self, *, name: str, version: str | None = None) -> "PromptTemplate": ... | ||
|
||
def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False): ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pytest | ||
|
||
from banks.registries.file import FileTemplateRegistry | ||
from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError | ||
|
||
|
||
@pytest.fixture | ||
def populated_index_dir(tmp_path): | ||
tpls = [PromptTemplate(id="name:version", name="name", version="version", prompt="prompt")] | ||
idx = PromptTemplateIndex(templates=tpls) | ||
with open(tmp_path / "index.json", "w") as f: | ||
f.write(idx.model_dump_json()) | ||
return tmp_path | ||
|
||
|
||
def test_init_from_scratch(tmp_path): | ||
index_dir = tmp_path / "test" | ||
r = FileTemplateRegistry(index_dir) | ||
assert r._index_fpath == index_dir / "index.json" | ||
assert index_dir.exists() | ||
|
||
|
||
def test_init_from_existing_dir(tmp_path): | ||
r = FileTemplateRegistry(tmp_path) | ||
r.save() | ||
assert r._index_fpath.exists() | ||
|
||
|
||
def test_init_from_existing_index(populated_index_dir): | ||
r = FileTemplateRegistry(populated_index_dir) | ||
r.get("name", "version") | ||
|
||
|
||
def test__make_id(): | ||
assert FileTemplateRegistry._make_id("name", "version") == "name:version" | ||
assert FileTemplateRegistry._make_id("name", None) == "name" | ||
|
||
|
||
def test_get(populated_index_dir): | ||
r = FileTemplateRegistry(populated_index_dir) | ||
tpl = r.get("name", "version") | ||
assert tpl.id == "name:version" | ||
|
||
|
||
def test_get_not_found(populated_index_dir): | ||
r = FileTemplateRegistry(populated_index_dir) | ||
with pytest.raises(TemplateNotFoundError): | ||
r.get("name", "nonexisting_version") | ||
|
||
|
||
def test_set_existing_no_overwrite(populated_index_dir): | ||
r = FileTemplateRegistry(populated_index_dir) | ||
new_prompt = "a new prompt!" | ||
r.set(name="name", prompt=new_prompt, version="version") # template already exists, expected to be no-op | ||
assert r.get("name", "version").prompt == "prompt" | ||
|
||
|
||
def test_set_existing_overwrite(populated_index_dir): | ||
r = FileTemplateRegistry(populated_index_dir) | ||
new_prompt = "a new prompt!" | ||
r.set(name="name", prompt=new_prompt, version="version", overwrite=True) | ||
assert r.get("name", "version").prompt == new_prompt | ||
|
||
|
||
def test_set_new(populated_index_dir): | ||
r = FileTemplateRegistry(populated_index_dir) | ||
new_prompt = "a new prompt!" | ||
r.set(name="name", prompt=new_prompt, version="version2") | ||
assert r.get("name", "version").prompt == "prompt" | ||
assert r.get("name", "version2").prompt == new_prompt |