Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Introduce prompt registry, support versioning in Prompt.from_template #6

Merged
merged 11 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -26,6 +25,7 @@ classifiers = [
dependencies = [
"jinja2",
"litellm",
"pydantic",
]

[project.urls]
Expand All @@ -49,7 +49,7 @@ test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
"coverage report",
"coverage report -m",
]
cov = [
"test-cov",
Expand All @@ -58,7 +58,7 @@ cov = [
docs = "mkdocs build"

[[tool.hatch.envs.all.matrix]]
python = ["3.9", "3.10", "3.11", "3.12"]
python = ["3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need
Expand Down Expand Up @@ -128,6 +128,8 @@ select = [
"YTT",
]
ignore = [
# Allow methods like 'set'
"A003",
# Allow non-abstract empty methods in abstract base classes
"B027",
# Allow boolean positional values in function calls, like `dict.get(... True)`
Expand All @@ -136,6 +138,8 @@ ignore = [
"S105", "S106", "S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Avoid conflicts with the formatter
"ISC001",
]
unfixable = [
# Don't touch unused imports
Expand Down
4 changes: 2 additions & 2 deletions src/banks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from .cache import RenderCache
from .config import config
from .env import env
from .prompt import AsyncPrompt, Prompt

__all__ = (
"env",
"Prompt",
"AsyncPrompt",
"RenderCache",
"config",
)
10 changes: 9 additions & 1 deletion src/banks/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import os

from platformdirs import user_data_path

from .utils import strtobool

async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))

class BanksConfig:
    """Process-wide settings, read from environment variables at import time."""

    # Truthy string in BANKS_ASYNC_ENABLED turns on async rendering; defaults to "false".
    ASYNC_ENABLED = strtobool(os.getenv("BANKS_ASYNC_ENABLED", "false"))
    # Value of BANKS_USER_DATA_PATH (default "banks") is handed to platformdirs'
    # `user_data_path` to resolve the per-user data folder.
    USER_DATA_PATH = user_data_path(os.getenv("BANKS_USER_DATA_PATH", "banks"))


config = BanksConfig()
31 changes: 20 additions & 11 deletions src/banks/env.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os
from pathlib import Path

from jinja2 import Environment, select_autoescape

from .config import async_enabled
from .config import config
from .filters import lemmatize
from .loader import MultiLoader
from .registries import FileTemplateRegistry
from .registry import TemplateRegistry


def _add_extensions(env):
Expand All @@ -20,6 +25,13 @@ def _add_extensions(env):
env.add_extension(HFInferenceEndpointsExtension)


def _add_default_templates(r: TemplateRegistry):
templates_dir = Path(os.path.dirname(__file__)) / "templates"
for tpl_file in templates_dir.glob("*.jinja"):
r.set(name=tpl_file.name, prompt=tpl_file.read_text())
r.save()


# Init the Jinja env
env = Environment(
loader=MultiLoader(),
Expand All @@ -29,17 +41,14 @@ def _add_extensions(env):
),
trim_blocks=True,
lstrip_blocks=True,
enable_async=bool(async_enabled),
enable_async=bool(config.ASYNC_ENABLED),
)

# Setup custom filters and default extensions
env.filters["lemmatize"] = lemmatize
_add_extensions(env)
# Init the Template registry
registry = FileTemplateRegistry(config.USER_DATA_PATH)


def with_env(cls):
"""
A decorator that adds an `env` attribute to the decorated class
"""
cls.env = env
return cls
# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
_add_extensions(env)
_add_default_templates(registry)
2 changes: 1 addition & 1 deletion src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.")


def generate(model_name: str): # noqa
def generate(model_name: str): # noqa # This function exists for documentation purpose.
"""
`generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content.

Expand Down
4 changes: 2 additions & 2 deletions src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def lemmatize(text: str) -> str:

Example:
```
{{ 'The dog is running' | lemmatize }}
'the dog be run'
{{"The dog is running" | lemmatize}}
"the dog be run"
```

Note:
Expand Down
19 changes: 10 additions & 9 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Optional

from .cache import DefaultCache, RenderCache
from .config import async_enabled
from .env import env
from .config import config
from .env import env, registry
from .errors import AsyncError
from .utils import generate_canary_word

Expand Down Expand Up @@ -46,14 +46,16 @@ def canary_leaked(self, text: str) -> bool:
return self.defaults["canary_word"] in text

@classmethod
def from_template(cls, name: str) -> "BasePrompt":
def from_template(cls, name: str, version: str | None = None) -> "BasePrompt":
"""
Create a prompt instance from a template.

Prompt templates can be really long and at some point you might want to store them on files. To avoid the
boilerplate code to read a file and pass the content as strings to the constructor, `Prompt`s can be
initialized by just passing the name of the template file, provided that the file is stored in a folder called
`templates` in the current path:
initialized by just passing the name of the template file, provided that the file is available to the
loaders that were configured (see `Multiloader`).

One of the default loaders can load templates stored in a folder called `templates` in the current path:

```
.
Expand Down Expand Up @@ -81,9 +83,8 @@ def from_template(cls, name: str) -> "BasePrompt":
Returns:
A new `Prompt` instance.
"""
p = cls("")
p._template = env.get_template(name)
return p
tpl = registry.get(name, version)
return cls(tpl.prompt)


class Prompt(BasePrompt):
Expand Down Expand Up @@ -184,7 +185,7 @@ async def main():
def __init__(self, text: str) -> None:
super().__init__(text)

if not async_enabled:
if not config.ASYNC_ENABLED:
msg = "Async is not enabled. Please set the environment variable 'BANKS_ASYNC_ENABLED=on' and try again."
raise AsyncError(msg)

Expand Down
6 changes: 6 additions & 0 deletions src/banks/registries/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from .file import FileTemplateRegistry

__all__ = ("FileTemplateRegistry",)
47 changes: 47 additions & 0 deletions src/banks/registries/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from pathlib import Path

from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError


class FileTemplateRegistry:
def __init__(self, user_data_path: Path) -> None:
self._index_fpath: Path = user_data_path / "index.json"
self._index: PromptTemplateIndex = PromptTemplateIndex(templates=[])
try:
self._index = PromptTemplateIndex.model_validate_json(self._index_fpath.read_text())
except FileNotFoundError:
# init the user data folder
user_data_path.mkdir(parents=True, exist_ok=True)

@staticmethod
def _make_id(name: str, version: str | None):
if version:
return f"{name}:{version}"
return name

def save(self) -> None:
with open(self._index_fpath, "w") as f:
f.write(self._index.model_dump_json())

def get(self, name: str, version: str | None = None) -> "PromptTemplate":
tpl_id = self._make_id(name, version)
for tpl in self._index.templates:
if tpl_id == tpl.id:
return tpl

msg = f"cannot find template '{tpl_id}'"
raise TemplateNotFoundError(msg)

def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False):
try:
tpl = self.get(name, version)
if overwrite:
tpl.prompt = prompt
return
except TemplateNotFoundError:
tpl_id = self._make_id(name, version)
tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt)
self._index.templates.append(tpl)
28 changes: 28 additions & 0 deletions src/banks/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from typing import Protocol

from pydantic import BaseModel


class TemplateNotFoundError(Exception):
    """Raised when a registry has no template with the requested id."""


class PromptTemplate(BaseModel):
    # One prompt template as stored in a registry index.

    # Identifier used for lookups in the registry.
    id: str
    # Template name.
    name: str
    # Template version.
    version: str
    # The template text.
    prompt: str


class PromptTemplateIndex(BaseModel):
    # Serializable container for all the templates known to a registry.

    # Every template in the index.
    templates: list[PromptTemplate]


class TemplateRegistry(Protocol):
def save(self) -> None: ...

def get(self, *, name: str, version: str | None = None) -> "PromptTemplate": ...

def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False): ...
70 changes: 70 additions & 0 deletions tests/test_file_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest

from banks.registries.file import FileTemplateRegistry
from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError


@pytest.fixture
def populated_index_dir(tmp_path):
    """Return a folder whose `index.json` contains the single template `name:version`."""
    template = PromptTemplate(id="name:version", name="name", version="version", prompt="prompt")
    index = PromptTemplateIndex(templates=[template])
    index_file = tmp_path / "index.json"
    index_file.write_text(index.model_dump_json())
    return tmp_path


def test_init_from_scratch(tmp_path):
    """A missing data folder is created and the index path points inside it."""
    target_dir = tmp_path / "test"
    registry = FileTemplateRegistry(target_dir)
    expected_index = target_dir / "index.json"
    assert registry._index_fpath == expected_index
    assert target_dir.exists()


def test_init_from_existing_dir(tmp_path):
    """Saving a registry built on an existing, empty folder creates the index file."""
    registry = FileTemplateRegistry(tmp_path)
    registry.save()
    index_file = registry._index_fpath
    assert index_file.exists()


def test_init_from_existing_index(populated_index_dir):
    """An existing index is loaded: fetching the stored template must not raise."""
    registry = FileTemplateRegistry(populated_index_dir)
    registry.get("name", "version")


def test__make_id():
    """Ids are `name:version` when a version is given, the bare name otherwise."""
    make_id = FileTemplateRegistry._make_id
    assert make_id("name", "version") == "name:version"
    assert make_id("name", None) == "name"


def test_get(populated_index_dir):
    """`get` returns the template matching name and version."""
    registry = FileTemplateRegistry(populated_index_dir)
    found = registry.get("name", "version")
    assert found.id == "name:version"


def test_get_not_found(populated_index_dir):
    """Asking for a version that was never stored raises `TemplateNotFoundError`."""
    registry = FileTemplateRegistry(populated_index_dir)
    with pytest.raises(TemplateNotFoundError):
        registry.get("name", "nonexisting_version")


def test_set_existing_no_overwrite(populated_index_dir):
    """Without `overwrite`, setting an existing template leaves its prompt untouched."""
    registry = FileTemplateRegistry(populated_index_dir)
    replacement = "a new prompt!"
    # template already exists, expected to be no-op
    registry.set(name="name", prompt=replacement, version="version")
    stored = registry.get("name", "version")
    assert stored.prompt == "prompt"


def test_set_existing_overwrite(populated_index_dir):
    """With `overwrite=True`, the prompt of an existing template is replaced."""
    registry = FileTemplateRegistry(populated_index_dir)
    replacement = "a new prompt!"
    registry.set(name="name", prompt=replacement, version="version", overwrite=True)
    stored = registry.get("name", "version")
    assert stored.prompt == replacement


def test_set_new(populated_index_dir):
    """A new version is added next to the existing one, which keeps its prompt."""
    registry = FileTemplateRegistry(populated_index_dir)
    added_prompt = "a new prompt!"
    registry.set(name="name", prompt=added_prompt, version="version2")
    existing = registry.get("name", "version")
    added = registry.get("name", "version2")
    assert existing.prompt == "prompt"
    assert added.prompt == added_prompt