Skip to content

Commit

Permalink
feat: Introduce prompt registry, support versioning in `Prompt.from_t…
Browse files Browse the repository at this point in the history
…emplate` (#6)

* first draft

* add json roundtrip

* use registry in prompt

* almost there

* remove env dependency

* add tests

* better config

* remove positional args from registry

* drop 3.9

* no need to use .parent()

* mkdir -p
  • Loading branch information
masci authored Jun 1, 2024
1 parent af82797 commit fea547f
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -26,6 +25,7 @@ classifiers = [
dependencies = [
"jinja2",
"litellm",
"pydantic",
]

[project.urls]
Expand All @@ -49,7 +49,7 @@ test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
"coverage report",
"coverage report -m",
]
cov = [
"test-cov",
Expand All @@ -58,7 +58,7 @@ cov = [
docs = "mkdocs build"

[[tool.hatch.envs.all.matrix]]
python = ["3.9", "3.10", "3.11", "3.12"]
python = ["3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need
Expand Down Expand Up @@ -128,6 +128,8 @@ select = [
"YTT",
]
ignore = [
# Allow methods like 'set'
"A003",
# Allow non-abstract empty methods in abstract base classes
"B027",
# Allow boolean positional values in function calls, like `dict.get(... True)`
Expand All @@ -136,6 +138,8 @@ ignore = [
"S105", "S106", "S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Avoid conflicts with the formatter
"ISC001",
]
unfixable = [
# Don't touch unused imports
Expand Down
4 changes: 2 additions & 2 deletions src/banks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from .cache import RenderCache
from .config import config
from .env import env
from .prompt import AsyncPrompt, Prompt

__all__ = (
"env",
"Prompt",
"AsyncPrompt",
"RenderCache",
"config",
)
10 changes: 9 additions & 1 deletion src/banks/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import os

from platformdirs import user_data_path

from .utils import strtobool

async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))

class BanksConfig:
ASYNC_ENABLED = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))
USER_DATA_PATH = user_data_path(os.environ.get("BANKS_USER_DATA_PATH", "banks"))


config = BanksConfig()
31 changes: 20 additions & 11 deletions src/banks/env.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os
from pathlib import Path

from jinja2 import Environment, select_autoescape

from .config import async_enabled
from .config import config
from .filters import lemmatize
from .loader import MultiLoader
from .registries import FileTemplateRegistry
from .registry import TemplateRegistry


def _add_extensions(env):
Expand All @@ -20,6 +25,13 @@ def _add_extensions(env):
env.add_extension(HFInferenceEndpointsExtension)


def _add_default_templates(r: TemplateRegistry):
templates_dir = Path(os.path.dirname(__file__)) / "templates"
for tpl_file in templates_dir.glob("*.jinja"):
r.set(name=tpl_file.name, prompt=tpl_file.read_text())
r.save()


# Init the Jinja env
env = Environment(
loader=MultiLoader(),
Expand All @@ -29,17 +41,14 @@ def _add_extensions(env):
),
trim_blocks=True,
lstrip_blocks=True,
enable_async=bool(async_enabled),
enable_async=bool(config.ASYNC_ENABLED),
)

# Setup custom filters and default extensions
env.filters["lemmatize"] = lemmatize
_add_extensions(env)
# Init the Template registry
registry = FileTemplateRegistry(config.USER_DATA_PATH)


def with_env(cls):
"""
A decorator that adds an `env` attribute to the decorated class
"""
cls.env = env
return cls
# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
_add_extensions(env)
_add_default_templates(registry)
2 changes: 1 addition & 1 deletion src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.")


def generate(model_name: str): # noqa
def generate(model_name: str): # noqa # This function exists for documentation purpose.
"""
`generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content.
Expand Down
4 changes: 2 additions & 2 deletions src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def lemmatize(text: str) -> str:
Example:
```
{{ 'The dog is running' | lemmatize }}
'the dog be run'
{{"The dog is running" | lemmatize}}
"the dog be run"
```
Note:
Expand Down
19 changes: 10 additions & 9 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Optional

from .cache import DefaultCache, RenderCache
from .config import async_enabled
from .env import env
from .config import config
from .env import env, registry
from .errors import AsyncError
from .utils import generate_canary_word

Expand Down Expand Up @@ -46,14 +46,16 @@ def canary_leaked(self, text: str) -> bool:
return self.defaults["canary_word"] in text

@classmethod
def from_template(cls, name: str) -> "BasePrompt":
def from_template(cls, name: str, version: str | None = None) -> "BasePrompt":
"""
Create a prompt instance from a template.
Prompt templates can be really long and at some point you might want to store them on files. To avoid the
boilerplate code to read a file and pass the content as strings to the constructor, `Prompt`s can be
initialized by just passing the name of the template file, provided that the file is stored in a folder called
`templates` in the current path:
initialized by just passing the name of the template file, provided that the file is available to the
loaders that were configured (see `Multiloader`).
One of the default loaders can load templates stored in a folder called `templates` in the current path:
```
.
Expand Down Expand Up @@ -81,9 +83,8 @@ def from_template(cls, name: str) -> "BasePrompt":
Returns:
A new `Prompt` instance.
"""
p = cls("")
p._template = env.get_template(name)
return p
tpl = registry.get(name, version)
return cls(tpl.prompt)


class Prompt(BasePrompt):
Expand Down Expand Up @@ -184,7 +185,7 @@ async def main():
def __init__(self, text: str) -> None:
super().__init__(text)

if not async_enabled:
if not config.ASYNC_ENABLED:
msg = "Async is not enabled. Please set the environment variable 'BANKS_ASYNC_ENABLED=on' and try again."
raise AsyncError(msg)

Expand Down
6 changes: 6 additions & 0 deletions src/banks/registries/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from .file import FileTemplateRegistry

__all__ = ("FileTemplateRegistry",)
47 changes: 47 additions & 0 deletions src/banks/registries/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from pathlib import Path

from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError


class FileTemplateRegistry:
def __init__(self, user_data_path: Path) -> None:
self._index_fpath: Path = user_data_path / "index.json"
self._index: PromptTemplateIndex = PromptTemplateIndex(templates=[])
try:
self._index = PromptTemplateIndex.model_validate_json(self._index_fpath.read_text())
except FileNotFoundError:
# init the user data folder
user_data_path.mkdir(parents=True, exist_ok=True)

@staticmethod
def _make_id(name: str, version: str | None):
if version:
return f"{name}:{version}"
return name

def save(self) -> None:
with open(self._index_fpath, "w") as f:
f.write(self._index.model_dump_json())

def get(self, name: str, version: str | None = None) -> "PromptTemplate":
tpl_id = self._make_id(name, version)
for tpl in self._index.templates:
if tpl_id == tpl.id:
return tpl

msg = f"cannot find template '{tpl_id}'"
raise TemplateNotFoundError(msg)

def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False):
try:
tpl = self.get(name, version)
if overwrite:
tpl.prompt = prompt
return
except TemplateNotFoundError:
tpl_id = self._make_id(name, version)
tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt)
self._index.templates.append(tpl)
28 changes: 28 additions & 0 deletions src/banks/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from typing import Protocol

from pydantic import BaseModel


class TemplateNotFoundError(Exception): ...


class PromptTemplate(BaseModel):
id: str
name: str
version: str
prompt: str


class PromptTemplateIndex(BaseModel):
templates: list[PromptTemplate]


class TemplateRegistry(Protocol):
def save(self) -> None: ...

def get(self, *, name: str, version: str | None = None) -> "PromptTemplate": ...

def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False): ...
70 changes: 70 additions & 0 deletions tests/test_file_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest

from banks.registries.file import FileTemplateRegistry
from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError


@pytest.fixture
def populated_index_dir(tmp_path):
tpls = [PromptTemplate(id="name:version", name="name", version="version", prompt="prompt")]
idx = PromptTemplateIndex(templates=tpls)
with open(tmp_path / "index.json", "w") as f:
f.write(idx.model_dump_json())
return tmp_path


def test_init_from_scratch(tmp_path):
index_dir = tmp_path / "test"
r = FileTemplateRegistry(index_dir)
assert r._index_fpath == index_dir / "index.json"
assert index_dir.exists()


def test_init_from_existing_dir(tmp_path):
r = FileTemplateRegistry(tmp_path)
r.save()
assert r._index_fpath.exists()


def test_init_from_existing_index(populated_index_dir):
r = FileTemplateRegistry(populated_index_dir)
r.get("name", "version")


def test__make_id():
assert FileTemplateRegistry._make_id("name", "version") == "name:version"
assert FileTemplateRegistry._make_id("name", None) == "name"


def test_get(populated_index_dir):
r = FileTemplateRegistry(populated_index_dir)
tpl = r.get("name", "version")
assert tpl.id == "name:version"


def test_get_not_found(populated_index_dir):
r = FileTemplateRegistry(populated_index_dir)
with pytest.raises(TemplateNotFoundError):
r.get("name", "nonexisting_version")


def test_set_existing_no_overwrite(populated_index_dir):
r = FileTemplateRegistry(populated_index_dir)
new_prompt = "a new prompt!"
r.set(name="name", prompt=new_prompt, version="version") # template already exists, expected to be no-op
assert r.get("name", "version").prompt == "prompt"


def test_set_existing_overwrite(populated_index_dir):
r = FileTemplateRegistry(populated_index_dir)
new_prompt = "a new prompt!"
r.set(name="name", prompt=new_prompt, version="version", overwrite=True)
assert r.get("name", "version").prompt == new_prompt


def test_set_new(populated_index_dir):
r = FileTemplateRegistry(populated_index_dir)
new_prompt = "a new prompt!"
r.set(name="name", prompt=new_prompt, version="version2")
assert r.get("name", "version").prompt == "prompt"
assert r.get("name", "version2").prompt == new_prompt

0 comments on commit fea547f

Please sign in to comment.