diff --git a/pyproject.toml b/pyproject.toml index 48fe595..10b144c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,4 +185,5 @@ disable = [ "missing-class-docstring", "missing-function-docstring", "cyclic-import", -] \ No newline at end of file +] +max-args = 10 diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 3582691..6d693af 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -16,7 +16,7 @@ # This function exists for documentation purpose. -def generate(model_name: str): # noqa # pylint: disable=W0613 +def generate(model_name: str): # pylint: disable=W0613 """ `generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content. diff --git a/src/banks/registries/directory.py b/src/banks/registries/directory.py index c902dec..549e59c 100644 --- a/src/banks/registries/directory.py +++ b/src/banks/registries/directory.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import time from pathlib import Path from pydantic import BaseModel, Field @@ -8,11 +9,16 @@ from banks import Prompt from banks.registry import TemplateNotFoundError +# Constants +DEFAULT_VERSION = "0" +DEFAULT_INDEX_NAME = "index.json" + class PromptFile(BaseModel): name: str version: str path: Path + meta: dict class PromptFileIndex(BaseModel): @@ -26,7 +32,7 @@ def __init__(self, directory_path: Path, *, force_reindex: bool = False): raise ValueError(msg) self._path = directory_path - self._index_path = self._path / "index.json" + self._index_path = self._path / DEFAULT_INDEX_NAME if not self._index_path.exists() or force_reindex: self._scan() else: @@ -38,24 +44,63 @@ def _load(self): def _scan(self): self._index: PromptFileIndex = PromptFileIndex() for path in self._path.glob("*.jinja*"): - pf = PromptFile(name=path.stem, version="0", path=path) + name, version = path.stem.rsplit(".", 1) if "." in path.stem else (path.stem, DEFAULT_VERSION) + pf = PromptFile(name=name, version=version, path=path, meta={}) self._index.files.append(pf) self._index_path.write_text(self._index.model_dump_json()) - def get(self, *, name: str, version: str | None = None) -> "Prompt": - version = version or "0" + def get(self, *, name: str, version: str = DEFAULT_VERSION) -> "PromptFile": for pf in self._index.files: if pf.name == name and pf.version == version and pf.path.exists(): - return Prompt(pf.path.read_text()) + return pf raise TemplateNotFoundError - def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False): - version = version or "0" + def get_prompt(self, *, name: str, version: str = DEFAULT_VERSION) -> Prompt: + return Prompt(self.get(name=name, version=version).path.read_text()) + + def _get_prompt_file(self, *, name: str, version: str) -> PromptFile | None: for pf in self._index.files: - if pf.name == name and pf.version == version and overwrite: - pf.path.write_text(prompt.raw) - return - new_prompt_file = self._path / "{name}.{version}.jinja" + if pf.name == name and pf.version == version: + return pf + return None + + def _create_pf(self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict) -> "PromptFile": + pf = self._get_prompt_file(name=name, version=version) + if pf: + if not overwrite: + msg = f"Prompt {name}.{version}.jinja already exists. Use overwrite=True to overwrite." + raise ValueError(msg) + pf.path.write_text(prompt.raw) + pf.meta = meta + return pf + new_prompt_file = self._path / f"{name}.{version}.jinja" new_prompt_file.write_text(prompt.raw) - pf = PromptFile(name=name, version=version, path=new_prompt_file) - self._index.files.append(pf) + pf = PromptFile(name=name, version=version, path=new_prompt_file, meta=meta) + return pf + + def set( + self, + *, + name: str, + prompt: Prompt, + meta: dict | None = None, + version: str = DEFAULT_VERSION, + overwrite: bool = False, + ): + meta = {**(meta or {}), "created_at": time.ctime()} + + pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta) + if pf not in self._index.files: + self._index.files.append(pf) + self._index_path.write_text(self._index.model_dump_json()) + + def get_meta(self, *, name: str, version: str = DEFAULT_VERSION) -> dict: + return self.get(name=name, version=version).meta + + def update_meta(self, *, meta: dict, name: str, version: str = DEFAULT_VERSION): + pf = self._get_prompt_file(name=name, version=version) + if not pf: + unk_err = f"Prompt {name}.{version} not found in the index. Cannot set meta for a non-existing prompt." + raise ValueError(unk_err) + pf.meta = meta + self._index_path.write_text(self._index.model_dump_json()) diff --git a/tests/test_default_templates.py b/tests/test_default_templates.py index 846dae1..6710591 100644 --- a/tests/test_default_templates.py +++ b/tests/test_default_templates.py @@ -27,12 +27,12 @@ def _get_data(name): def test_blog(registry): - p = registry.get(name="blog") + p = registry.get_prompt(name="blog") assert _get_data("blog.jinja.out") == p.text({"topic": "climate change"}) def test_summarize(registry): - p = registry.get(name="summarize") + p = registry.get_prompt(name="summarize") documents = [ "A first paragraph talking about AI", "A second paragraph talking about climate change", @@ -44,12 +44,12 @@ def test_summarize(registry): def test_summarize_lemma(registry): pytest.importorskip("simplemma") - p = registry.get(name="summarize_lemma") + p = registry.get_prompt(name="summarize_lemma") assert _get_data("summarize_lemma.jinja.out") == p.text({"document": "The cats are running"}) def test_generate_tweet(registry): - p = registry.get(name="generate_tweet") + p = registry.get_prompt(name="generate_tweet") ext_name = "banks.extensions.generate.GenerateExtension" env.extensions[ext_name]._generate = mock.MagicMock(return_value="foo") # type:ignore diff --git a/tests/test_directory_registry.py b/tests/test_directory_registry.py index b92adf0..32879dd 100644 --- a/tests/test_directory_registry.py +++ b/tests/test_directory_registry.py @@ -1,4 +1,5 @@ import os +import time from pathlib import Path import pytest @@ -20,8 +21,9 @@ def populated_dir(tmp_path): def test_init_from_scratch(populated_dir): r = DirectoryTemplateRegistry(populated_dir) - p = r.get(name="blog") + p = r.get_prompt(name="blog") assert p.raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}") + assert r.get_meta(name="blog") == {} def test_init_from_existing_index(populated_dir): @@ -32,12 +34,14 @@ def test_init_from_existing_index(populated_dir): def test_init_from_existing_index_force(populated_dir): - r = DirectoryTemplateRegistry(populated_dir) # creates the index + _ = DirectoryTemplateRegistry(populated_dir) # creates the index # change the directory structure f = populated_dir / "blog.jinja" os.remove(f) # force recreation, the renamed file should be updated in the index r = DirectoryTemplateRegistry(populated_dir, force_reindex=True) + with pytest.raises(TemplateNotFoundError): + r.get_prompt(name="blog") with pytest.raises(TemplateNotFoundError): r.get(name="blog") @@ -50,18 +54,58 @@ def test_init_invalid_dir(): def test_get_not_found(populated_dir): r = DirectoryTemplateRegistry(populated_dir) with pytest.raises(TemplateNotFoundError): - r.get(name="FOO") + r.get_prompt(name="FOO") + with pytest.raises(TemplateNotFoundError): + r.get_meta(name="FOO") def test_set_existing_no_overwrite(populated_dir): r = DirectoryTemplateRegistry(populated_dir) new_prompt = Prompt("a new prompt!") - r.set(name="blog", prompt=new_prompt) # template already exists, expected to be no-op - assert r.get(name="blog").raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}") + with pytest.raises(ValueError) as e: + r.set(name="blog", prompt=new_prompt) + assert "already exists. Use overwrite=True to overwrite." in str(e.value) def test_set_existing_overwrite(populated_dir): r = DirectoryTemplateRegistry(populated_dir) new_prompt = Prompt("a new prompt!") + current_time = time.ctime() r.set(name="blog", prompt=new_prompt, overwrite=True) - assert r.get(name="blog").raw.startswith("a new prompt!") + assert r.get(name="blog").path.read_text() == "a new prompt!" + assert r.get_prompt(name="blog").raw.startswith("a new prompt!") + assert r.get(name="blog").meta == {"created_at": current_time} + assert r.get_meta(name="blog") == {"created_at": current_time} # created_at changes because it's overwritten + + +def test_set_multiple_templates(populated_dir): + r = DirectoryTemplateRegistry(Path(populated_dir)) + current_time = time.ctime() + new_prompt = Prompt("a very new prompt!") + old_prompt = Prompt("an old prompt!") + r.set(name="new", version="2", prompt=new_prompt) + r.set(name="old", version="1", prompt=old_prompt) + assert r.get_prompt(name="old", version="1").raw == "an old prompt!" + assert r.get_meta(name="old", version="1") == {"created_at": current_time} + assert r.get_prompt(name="new", version="2").raw == "a very new prompt!" + assert r.get_meta(name="new", version="2") == {"created_at": current_time} + + +def test_update_meta(populated_dir): + r = DirectoryTemplateRegistry(populated_dir) + + # test metadata for initial set + new_prompt = Prompt("a very new prompt!") + current_time = time.ctime() + r.set(name="new", version="3", prompt=new_prompt, meta={"accuracy": 91.2}) + assert r.get_meta(name="new", version="3") == {"accuracy": 91.2, "created_at": current_time} + + # test metadata error update for non-existing prompt + with pytest.raises(ValueError) as e: + _ = r.update_meta(name="foo", version="bar", meta={"accuracy": 91.2, "created_at": current_time}) + assert "Cannot set meta for a non-existing prompt." in str(e.value) + + # test metadata update for existing prompt + created_time = r.get_meta(name="new", version="3")["created_at"] + _ = r.update_meta(name="new", version="3", meta={"accuracy": 94.3, "created_at": created_time}) + assert r.get_meta(name="new", version="3") == {"accuracy": 94.3, "created_at": created_time}