diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8706aad..59c92d5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index 6898a12..16b6855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ authors = [ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -26,6 +25,7 @@ classifiers = [ dependencies = [ "jinja2", "litellm", + "pydantic", ] [project.urls] @@ -49,7 +49,7 @@ test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" cov-report = [ "- coverage combine", - "coverage report", + "coverage report -m", ] cov = [ "test-cov", @@ -58,7 +58,7 @@ cov = [ docs = "mkdocs build" [[tool.hatch.envs.all.matrix]] -python = ["3.9", "3.10", "3.11", "3.12"] +python = ["3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need @@ -128,6 +128,8 @@ select = [ "YTT", ] ignore = [ + # Allow methods like 'set' + "A003", # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` @@ -136,6 +138,8 @@ ignore = [ "S105", "S106", "S107", # Ignore complexity "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + # Avoid conflicts with the formatter + "ISC001", ] unfixable = [ # Don't touch unused imports diff --git a/src/banks/__init__.py b/src/banks/__init__.py index 8a0228d..681c3f0 100644 --- a/src/banks/__init__.py +++ b/src/banks/__init__.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -from .cache import RenderCache +from .config import config from .env import env from .prompt import AsyncPrompt, Prompt @@ -9,5 +9,5 @@ "env", "Prompt", "AsyncPrompt", - "RenderCache", + "config", ) diff --git a/src/banks/config.py b/src/banks/config.py index 2547f41..faf196e 100644 --- a/src/banks/config.py +++ b/src/banks/config.py @@ -1,5 +1,13 @@ import os +from platformdirs import user_data_path + from .utils import strtobool -async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) + +class BanksConfig: + ASYNC_ENABLED = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) + USER_DATA_PATH = user_data_path(os.environ.get("BANKS_USER_DATA_PATH", "banks")) + + +config = BanksConfig() diff --git a/src/banks/env.py b/src/banks/env.py index 477210f..202a4ae 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -1,11 +1,16 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import os +from pathlib import Path + from jinja2 import Environment, select_autoescape -from .config import async_enabled +from .config import config from .filters import lemmatize from .loader import MultiLoader +from .registries import FileTemplateRegistry +from .registry import TemplateRegistry def _add_extensions(env): @@ -20,6 +25,13 @@ def _add_extensions(env): env.add_extension(HFInferenceEndpointsExtension) +def _add_default_templates(r: TemplateRegistry): + templates_dir = Path(os.path.dirname(__file__)) / "templates" + for tpl_file in templates_dir.glob("*.jinja"): + r.set(name=tpl_file.name, prompt=tpl_file.read_text()) + r.save() + + # Init the Jinja env env = Environment( loader=MultiLoader(), @@ -29,17 +41,14 @@ def _add_extensions(env): ), trim_blocks=True, lstrip_blocks=True, - enable_async=bool(async_enabled), + enable_async=bool(config.ASYNC_ENABLED), ) -# Setup custom filters and default extensions -env.filters["lemmatize"] = lemmatize -_add_extensions(env) +# Init the Template registry +registry = FileTemplateRegistry(config.USER_DATA_PATH) -def with_env(cls): - """ - A decorator that adds an `env` attribute to the decorated class - """ - cls.env = env - return cls +# Setup custom filters and defaults +env.filters["lemmatize"] = lemmatize +_add_extensions(env) +_add_default_templates(registry) diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 37c99d8..6c44b52 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -14,7 +14,7 @@ SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.") -def generate(model_name: str): # noqa +def generate(model_name: str): # noqa # This function exists for documentation purpose. """ `generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content. diff --git a/src/banks/filters/lemmatize.py b/src/banks/filters/lemmatize.py index f603560..4c196d6 100644 --- a/src/banks/filters/lemmatize.py +++ b/src/banks/filters/lemmatize.py @@ -18,8 +18,8 @@ def lemmatize(text: str) -> str: Example: ``` - {{ 'The dog is running' | lemmatize }} - 'the dog be run' + {{"The dog is running" | lemmatize}} + "the dog be run" ``` Note: diff --git a/src/banks/prompt.py b/src/banks/prompt.py index 625e4b2..fc1fb11 100644 --- a/src/banks/prompt.py +++ b/src/banks/prompt.py @@ -4,8 +4,8 @@ from typing import Optional from .cache import DefaultCache, RenderCache -from .config import async_enabled -from .env import env +from .config import config +from .env import env, registry from .errors import AsyncError from .utils import generate_canary_word @@ -46,14 +46,16 @@ def canary_leaked(self, text: str) -> bool: return self.defaults["canary_word"] in text @classmethod - def from_template(cls, name: str) -> "BasePrompt": + def from_template(cls, name: str, version: str | None = None) -> "BasePrompt": """ Create a prompt instance from a template. Prompt templates can be really long and at some point you might want to store them on files. To avoid the boilerplate code to read a file and pass the content as strings to the constructor, `Prompt`s can be - initialized by just passing the name of the template file, provided that the file is stored in a folder called - `templates` in the current path: + initialized by just passing the name of the template file, provided that the file is available to the + loaders that were configured (see `Multiloader`). + + One of the default loaders can load templates stored in a folder called `templates` in the current path: ``` . @@ -81,9 +83,8 @@ def from_template(cls, name: str) -> "BasePrompt": Returns: A new `Prompt` instance. """ - p = cls("") - p._template = env.get_template(name) - return p + tpl = registry.get(name, version) + return cls(tpl.prompt) class Prompt(BasePrompt): @@ -184,7 +185,7 @@ async def main(): def __init__(self, text: str) -> None: super().__init__(text) - if not async_enabled: + if not config.ASYNC_ENABLED: msg = "Async is not enabled. Please set the environment variable 'BANKS_ASYNC_ENABLED=on' and try again." raise AsyncError(msg) diff --git a/src/banks/registries/__init__.py b/src/banks/registries/__init__.py new file mode 100644 index 0000000..adb4838 --- /dev/null +++ b/src/banks/registries/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from .file import FileTemplateRegistry + +__all__ = ("FileTemplateRegistry",) diff --git a/src/banks/registries/file.py b/src/banks/registries/file.py new file mode 100644 index 0000000..e3a62d6 --- /dev/null +++ b/src/banks/registries/file.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from pathlib import Path + +from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError + + +class FileTemplateRegistry: + def __init__(self, user_data_path: Path) -> None: + self._index_fpath: Path = user_data_path / "index.json" + self._index: PromptTemplateIndex = PromptTemplateIndex(templates=[]) + try: + self._index = PromptTemplateIndex.model_validate_json(self._index_fpath.read_text()) + except FileNotFoundError: + # init the user data folder + user_data_path.mkdir(parents=True, exist_ok=True) + + @staticmethod + def _make_id(name: str, version: str | None): + if version: + return f"{name}:{version}" + return name + + def save(self) -> None: + with open(self._index_fpath, "w") as f: + f.write(self._index.model_dump_json()) + + def get(self, name: str, version: str | None = None) -> "PromptTemplate": + tpl_id = self._make_id(name, version) + for tpl in self._index.templates: + if tpl_id == tpl.id: + return tpl + + msg = f"cannot find template '{tpl_id}'" + raise TemplateNotFoundError(msg) + + def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False): + try: + tpl = self.get(name, version) + if overwrite: + tpl.prompt = prompt + return + except TemplateNotFoundError: + tpl_id = self._make_id(name, version) + tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt) + self._index.templates.append(tpl) diff --git a/src/banks/registry.py b/src/banks/registry.py new file mode 100644 index 0000000..1501a90 --- /dev/null +++ b/src/banks/registry.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from typing import Protocol + +from pydantic import BaseModel + + +class TemplateNotFoundError(Exception): ... + + +class PromptTemplate(BaseModel): + id: str + name: str + version: str + prompt: str + + +class PromptTemplateIndex(BaseModel): + templates: list[PromptTemplate] + + +class TemplateRegistry(Protocol): + def save(self) -> None: ... + + def get(self, *, name: str, version: str | None = None) -> "PromptTemplate": ... + + def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False): ... diff --git a/tests/test_file_registry.py b/tests/test_file_registry.py new file mode 100644 index 0000000..5647db4 --- /dev/null +++ b/tests/test_file_registry.py @@ -0,0 +1,70 @@ +import pytest + +from banks.registries.file import FileTemplateRegistry +from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError + + +@pytest.fixture +def populated_index_dir(tmp_path): + tpls = [PromptTemplate(id="name:version", name="name", version="version", prompt="prompt")] + idx = PromptTemplateIndex(templates=tpls) + with open(tmp_path / "index.json", "w") as f: + f.write(idx.model_dump_json()) + return tmp_path + + +def test_init_from_scratch(tmp_path): + index_dir = tmp_path / "test" + r = FileTemplateRegistry(index_dir) + assert r._index_fpath == index_dir / "index.json" + assert index_dir.exists() + + +def test_init_from_existing_dir(tmp_path): + r = FileTemplateRegistry(tmp_path) + r.save() + assert r._index_fpath.exists() + + +def test_init_from_existing_index(populated_index_dir): + r = FileTemplateRegistry(populated_index_dir) + r.get("name", "version") + + +def test__make_id(): + assert FileTemplateRegistry._make_id("name", "version") == "name:version" + assert FileTemplateRegistry._make_id("name", None) == "name" + + +def test_get(populated_index_dir): + r = FileTemplateRegistry(populated_index_dir) + tpl = r.get("name", "version") + assert tpl.id == "name:version" + + +def test_get_not_found(populated_index_dir): + r = FileTemplateRegistry(populated_index_dir) + with pytest.raises(TemplateNotFoundError): + r.get("name", "nonexisting_version") + + +def test_set_existing_no_overwrite(populated_index_dir): + r = FileTemplateRegistry(populated_index_dir) + new_prompt = "a new prompt!" + r.set(name="name", prompt=new_prompt, version="version") # template already exists, expected to be no-op + assert r.get("name", "version").prompt == "prompt" + + +def test_set_existing_overwrite(populated_index_dir): + r = FileTemplateRegistry(populated_index_dir) + new_prompt = "a new prompt!" + r.set(name="name", prompt=new_prompt, version="version", overwrite=True) + assert r.get("name", "version").prompt == new_prompt + + +def test_set_new(populated_index_dir): + r = FileTemplateRegistry(populated_index_dir) + new_prompt = "a new prompt!" + r.set(name="name", prompt=new_prompt, version="version2") + assert r.get("name", "version").prompt == "prompt" + assert r.get("name", "version2").prompt == new_prompt