Skip to content

Commit

Permalink
feat & bugfix: Added metadata features and fixed naming bug (#11)
Browse files Browse the repository at this point in the history
* Added metadata features and fixed naming bug

* Added prompt exists check before setting meta

* fix lint issues

* fix lint issues

* fix test lint issues

* fix more lint issues

* fix fmt issues

* remove blanket

* fixed overwrite and add tests

* fixed lint locally

* make the linters happy

* piggyback on existing index infra, braking changes

* fix lint issues

* unified prompt details and meta into single index json

* fixed tests from other test files

* addressed all comments

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
  • Loading branch information
mayankjobanputra and masci authored Sep 7, 2024
1 parent c15f849 commit 373ea6f
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 25 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,5 @@ disable = [
"missing-class-docstring",
"missing-function-docstring",
"cyclic-import",
]
]
max-args = 10
2 changes: 1 addition & 1 deletion src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


# This function exists for documentation purpose.
def generate(model_name: str): # noqa # pylint: disable=W0613
def generate(model_name: str): # pylint: disable=W0613
"""
`generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content.
Expand Down
71 changes: 58 additions & 13 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import time
from pathlib import Path

from pydantic import BaseModel, Field

from banks import Prompt
from banks.registry import TemplateNotFoundError

# Constants
DEFAULT_VERSION = "0"
DEFAULT_INDEX_NAME = "index.json"


class PromptFile(BaseModel):
name: str
version: str
path: Path
meta: dict


class PromptFileIndex(BaseModel):
Expand All @@ -26,7 +32,7 @@ def __init__(self, directory_path: Path, *, force_reindex: bool = False):
raise ValueError(msg)

self._path = directory_path
self._index_path = self._path / "index.json"
self._index_path = self._path / DEFAULT_INDEX_NAME
if not self._index_path.exists() or force_reindex:
self._scan()
else:
Expand All @@ -38,24 +44,63 @@ def _load(self):
def _scan(self):
self._index: PromptFileIndex = PromptFileIndex()
for path in self._path.glob("*.jinja*"):
pf = PromptFile(name=path.stem, version="0", path=path)
name, version = path.stem.rsplit(".", 1) if "." in path.stem else (path.stem, DEFAULT_VERSION)
pf = PromptFile(name=name, version=version, path=path, meta={})
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get(self, *, name: str, version: str | None = None) -> "Prompt":
version = version or "0"
def get(self, *, name: str, version: str = DEFAULT_VERSION) -> "PromptFile":
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return Prompt(pf.path.read_text())
return pf
raise TemplateNotFoundError

def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False):
version = version or "0"
def get_prompt(self, *, name: str, version: str = DEFAULT_VERSION) -> Prompt:
return Prompt(self.get(name=name, version=version).path.read_text())

def _get_prompt_file(self, *, name: str, version: str) -> PromptFile | None:
for pf in self._index.files:
if pf.name == name and pf.version == version and overwrite:
pf.path.write_text(prompt.raw)
return
new_prompt_file = self._path / "{name}.{version}.jinja"
if pf.name == name and pf.version == version:
return pf
return None

def _create_pf(self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict) -> "PromptFile":
pf = self._get_prompt_file(name=name, version=version)
if pf:
if not overwrite:
msg = f"Prompt {name}.{version}.jinja already exists. Use overwrite=True to overwrite."
raise ValueError(msg)
pf.path.write_text(prompt.raw)
pf.meta = meta
return pf
new_prompt_file = self._path / f"{name}.{version}.jinja"
new_prompt_file.write_text(prompt.raw)
pf = PromptFile(name=name, version=version, path=new_prompt_file)
self._index.files.append(pf)
pf = PromptFile(name=name, version=version, path=new_prompt_file, meta=meta)
return pf

def set(
self,
*,
name: str,
prompt: Prompt,
meta: dict | None = None,
version: str = DEFAULT_VERSION,
overwrite: bool = False,
):
meta = {**(meta or {}), "created_at": time.ctime()}

pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta)
if pf not in self._index.files:
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get_meta(self, *, name: str, version: str = DEFAULT_VERSION) -> dict:
return self.get(name=name, version=version).meta

def update_meta(self, *, meta: dict, name: str, version: str = DEFAULT_VERSION):
pf = self._get_prompt_file(name=name, version=version)
if not pf:
unk_err = f"Prompt {name}.{version} not found in the index. Cannot set meta for a non-existing prompt."
raise ValueError(unk_err)
pf.meta = meta
self._index_path.write_text(self._index.model_dump_json())
8 changes: 4 additions & 4 deletions tests/test_default_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def _get_data(name):


def test_blog(registry):
p = registry.get(name="blog")
p = registry.get_prompt(name="blog")
assert _get_data("blog.jinja.out") == p.text({"topic": "climate change"})


def test_summarize(registry):
p = registry.get(name="summarize")
p = registry.get_prompt(name="summarize")
documents = [
"A first paragraph talking about AI",
"A second paragraph talking about climate change",
Expand All @@ -44,12 +44,12 @@ def test_summarize(registry):
def test_summarize_lemma(registry):
pytest.importorskip("simplemma")

p = registry.get(name="summarize_lemma")
p = registry.get_prompt(name="summarize_lemma")
assert _get_data("summarize_lemma.jinja.out") == p.text({"document": "The cats are running"})


def test_generate_tweet(registry):
p = registry.get(name="generate_tweet")
p = registry.get_prompt(name="generate_tweet")
ext_name = "banks.extensions.generate.GenerateExtension"
env.extensions[ext_name]._generate = mock.MagicMock(return_value="foo") # type:ignore

Expand Down
56 changes: 50 additions & 6 deletions tests/test_directory_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
from pathlib import Path

import pytest
Expand All @@ -20,8 +21,9 @@ def populated_dir(tmp_path):

def test_init_from_scratch(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
p = r.get(name="blog")
p = r.get_prompt(name="blog")
assert p.raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}")
assert r.get_meta(name="blog") == {}


def test_init_from_existing_index(populated_dir):
Expand All @@ -32,12 +34,14 @@ def test_init_from_existing_index(populated_dir):


def test_init_from_existing_index_force(populated_dir):
r = DirectoryTemplateRegistry(populated_dir) # creates the index
_ = DirectoryTemplateRegistry(populated_dir) # creates the index
# change the directory structure
f = populated_dir / "blog.jinja"
os.remove(f)
# force recreation, the renamed file should be updated in the index
r = DirectoryTemplateRegistry(populated_dir, force_reindex=True)
with pytest.raises(TemplateNotFoundError):
r.get_prompt(name="blog")
with pytest.raises(TemplateNotFoundError):
r.get(name="blog")

Expand All @@ -50,18 +54,58 @@ def test_init_invalid_dir():
def test_get_not_found(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
with pytest.raises(TemplateNotFoundError):
r.get(name="FOO")
r.get_prompt(name="FOO")
with pytest.raises(TemplateNotFoundError):
r.get_meta(name="FOO")


def test_set_existing_no_overwrite(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
new_prompt = Prompt("a new prompt!")
r.set(name="blog", prompt=new_prompt) # template already exists, expected to be no-op
assert r.get(name="blog").raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}")
with pytest.raises(ValueError) as e:
r.set(name="blog", prompt=new_prompt)
assert "already exists. Use overwrite=True to overwrite." in str(e.value)


def test_set_existing_overwrite(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
new_prompt = Prompt("a new prompt!")
current_time = time.ctime()
r.set(name="blog", prompt=new_prompt, overwrite=True)
assert r.get(name="blog").raw.startswith("a new prompt!")
assert r.get(name="blog").path.read_text() == "a new prompt!"
assert r.get_prompt(name="blog").raw.startswith("a new prompt!")
assert r.get(name="blog").meta == {"created_at": current_time}
assert r.get_meta(name="blog") == {"created_at": current_time} # created_at changes because it's overwritten


def test_set_multiple_templates(populated_dir):
r = DirectoryTemplateRegistry(Path(populated_dir))
current_time = time.ctime()
new_prompt = Prompt("a very new prompt!")
old_prompt = Prompt("an old prompt!")
r.set(name="new", version="2", prompt=new_prompt)
r.set(name="old", version="1", prompt=old_prompt)
assert r.get_prompt(name="old", version="1").raw == "an old prompt!"
assert r.get_meta(name="old", version="1") == {"created_at": current_time}
assert r.get_prompt(name="new", version="2").raw == "a very new prompt!"
assert r.get_meta(name="new", version="2") == {"created_at": current_time}


def test_update_meta(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)

# test metadata for initial set
new_prompt = Prompt("a very new prompt!")
current_time = time.ctime()
r.set(name="new", version="3", prompt=new_prompt, meta={"accuracy": 91.2})
assert r.get_meta(name="new", version="3") == {"accuracy": 91.2, "created_at": current_time}

# test metadata error update for non-existing prompt
with pytest.raises(ValueError) as e:
_ = r.update_meta(name="foo", version="bar", meta={"accuracy": 91.2, "created_at": current_time})
assert "Cannot set meta for a non-existing prompt." in str(e.value)

# test metadata update for existing prompt
created_time = r.get_meta(name="new", version="3")["created_at"]
_ = r.update_meta(name="new", version="3", meta={"accuracy": 94.3, "created_at": created_time})
assert r.get_meta(name="new", version="3") == {"accuracy": 94.3, "created_at": created_time}

0 comments on commit 373ea6f

Please sign in to comment.