Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat & bugfix: Added metadata features and fixed naming bug #11

Merged
merged 17 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,5 @@ disable = [
"missing-class-docstring",
"missing-function-docstring",
"cyclic-import",
]
]
max-args = 10
2 changes: 1 addition & 1 deletion src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


# This function exists for documentation purpose.
def generate(model_name: str): # noqa # pylint: disable=W0613
def generate(model_name: str): # pylint: disable=W0613
"""
`generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content.

Expand Down
71 changes: 58 additions & 13 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import time
from pathlib import Path

from pydantic import BaseModel, Field

from banks import Prompt
from banks.registry import TemplateNotFoundError

# Constants
DEFAULT_VERSION = "0"
DEFAULT_INDEX_NAME = "index.json"


class PromptFile(BaseModel):
name: str
version: str
path: Path
meta: dict


class PromptFileIndex(BaseModel):
Expand All @@ -26,7 +32,7 @@ def __init__(self, directory_path: Path, *, force_reindex: bool = False):
raise ValueError(msg)

self._path = directory_path
self._index_path = self._path / "index.json"
self._index_path = self._path / DEFAULT_INDEX_NAME
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noice!

if not self._index_path.exists() or force_reindex:
self._scan()
else:
Expand All @@ -38,24 +44,63 @@ def _load(self):
def _scan(self):
self._index: PromptFileIndex = PromptFileIndex()
for path in self._path.glob("*.jinja*"):
pf = PromptFile(name=path.stem, version="0", path=path)
name, version = path.stem.rsplit(".", 1) if "." in path.stem else (path.stem, DEFAULT_VERSION)
pf = PromptFile(name=name, version=version, path=path, meta={})
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get(self, *, name: str, version: str | None = None) -> "Prompt":
version = version or "0"
def get(self, *, name: str, version: str = DEFAULT_VERSION) -> "PromptFile":
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return Prompt(pf.path.read_text())
return pf
raise TemplateNotFoundError

def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False):
version = version or "0"
def get_prompt(self, *, name: str, version: str = DEFAULT_VERSION) -> Prompt:
return Prompt(self.get(name=name, version=version).path.read_text())

def _get_prompt_file(self, *, name: str, version: str) -> PromptFile | None:
for pf in self._index.files:
if pf.name == name and pf.version == version and overwrite:
pf.path.write_text(prompt.raw)
return
new_prompt_file = self._path / "{name}.{version}.jinja"
if pf.name == name and pf.version == version:
return pf
return None

def _create_pf(self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict) -> "PromptFile":
pf = self._get_prompt_file(name=name, version=version)
if pf:
if not overwrite:
msg = f"Prompt {name}.{version}.jinja already exists. Use overwrite=True to overwrite."
raise ValueError(msg)
pf.path.write_text(prompt.raw)
pf.meta = meta
return pf
new_prompt_file = self._path / f"{name}.{version}.jinja"
new_prompt_file.write_text(prompt.raw)
pf = PromptFile(name=name, version=version, path=new_prompt_file)
self._index.files.append(pf)
pf = PromptFile(name=name, version=version, path=new_prompt_file, meta=meta)
return pf

def set(
self,
*,
name: str,
prompt: Prompt,
meta: dict | None = None,
version: str = DEFAULT_VERSION,
overwrite: bool = False,
):
meta = {**(meta or {}), "created_at": time.ctime()}

pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta)
if pf not in self._index.files:
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get_meta(self, *, name: str, version: str = DEFAULT_VERSION) -> dict:
return self.get(name=name, version=version).meta

def update_meta(self, *, meta: dict, name: str, version: str = DEFAULT_VERSION):
pf = self._get_prompt_file(name=name, version=version)
if not pf:
unk_err = f"Prompt {name}.{version} not found in the index. Cannot set meta for a non-existing prompt."
raise ValueError(unk_err)
pf.meta = meta
self._index_path.write_text(self._index.model_dump_json())
8 changes: 4 additions & 4 deletions tests/test_default_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def _get_data(name):


def test_blog(registry):
p = registry.get(name="blog")
p = registry.get_prompt(name="blog")
assert _get_data("blog.jinja.out") == p.text({"topic": "climate change"})


def test_summarize(registry):
p = registry.get(name="summarize")
p = registry.get_prompt(name="summarize")
documents = [
"A first paragraph talking about AI",
"A second paragraph talking about climate change",
Expand All @@ -44,12 +44,12 @@ def test_summarize(registry):
def test_summarize_lemma(registry):
pytest.importorskip("simplemma")

p = registry.get(name="summarize_lemma")
p = registry.get_prompt(name="summarize_lemma")
assert _get_data("summarize_lemma.jinja.out") == p.text({"document": "The cats are running"})


def test_generate_tweet(registry):
p = registry.get(name="generate_tweet")
p = registry.get_prompt(name="generate_tweet")
ext_name = "banks.extensions.generate.GenerateExtension"
env.extensions[ext_name]._generate = mock.MagicMock(return_value="foo") # type:ignore

Expand Down
56 changes: 50 additions & 6 deletions tests/test_directory_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
from pathlib import Path

import pytest
Expand All @@ -20,8 +21,9 @@ def populated_dir(tmp_path):

def test_init_from_scratch(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
p = r.get(name="blog")
p = r.get_prompt(name="blog")
assert p.raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}")
assert r.get_meta(name="blog") == {}


def test_init_from_existing_index(populated_dir):
Expand All @@ -32,12 +34,14 @@ def test_init_from_existing_index(populated_dir):


def test_init_from_existing_index_force(populated_dir):
r = DirectoryTemplateRegistry(populated_dir) # creates the index
_ = DirectoryTemplateRegistry(populated_dir) # creates the index
# change the directory structure
f = populated_dir / "blog.jinja"
os.remove(f)
# force recreation, the renamed file should be updated in the index
r = DirectoryTemplateRegistry(populated_dir, force_reindex=True)
with pytest.raises(TemplateNotFoundError):
r.get_prompt(name="blog")
with pytest.raises(TemplateNotFoundError):
r.get(name="blog")

Expand All @@ -50,18 +54,58 @@ def test_init_invalid_dir():
def test_get_not_found(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
with pytest.raises(TemplateNotFoundError):
r.get(name="FOO")
r.get_prompt(name="FOO")
with pytest.raises(TemplateNotFoundError):
r.get_meta(name="FOO")


def test_set_existing_no_overwrite(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
new_prompt = Prompt("a new prompt!")
r.set(name="blog", prompt=new_prompt) # template already exists, expected to be no-op
assert r.get(name="blog").raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}")
with pytest.raises(ValueError) as e:
r.set(name="blog", prompt=new_prompt)
assert "already exists. Use overwrite=True to overwrite." in str(e.value)


def test_set_existing_overwrite(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
new_prompt = Prompt("a new prompt!")
current_time = time.ctime()
r.set(name="blog", prompt=new_prompt, overwrite=True)
assert r.get(name="blog").raw.startswith("a new prompt!")
assert r.get(name="blog").path.read_text() == "a new prompt!"
assert r.get_prompt(name="blog").raw.startswith("a new prompt!")
assert r.get(name="blog").meta == {"created_at": current_time}
assert r.get_meta(name="blog") == {"created_at": current_time} # created_at changes because it's overwritten


def test_set_multiple_templates(populated_dir):
r = DirectoryTemplateRegistry(Path(populated_dir))
current_time = time.ctime()
new_prompt = Prompt("a very new prompt!")
old_prompt = Prompt("an old prompt!")
r.set(name="new", version="2", prompt=new_prompt)
r.set(name="old", version="1", prompt=old_prompt)
assert r.get_prompt(name="old", version="1").raw == "an old prompt!"
assert r.get_meta(name="old", version="1") == {"created_at": current_time}
assert r.get_prompt(name="new", version="2").raw == "a very new prompt!"
assert r.get_meta(name="new", version="2") == {"created_at": current_time}


def test_update_meta(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)

# test metadata for initial set
new_prompt = Prompt("a very new prompt!")
current_time = time.ctime()
r.set(name="new", version="3", prompt=new_prompt, meta={"accuracy": 91.2})
assert r.get_meta(name="new", version="3") == {"accuracy": 91.2, "created_at": current_time}

# test metadata error update for non-existing prompt
with pytest.raises(ValueError) as e:
_ = r.update_meta(name="foo", version="bar", meta={"accuracy": 91.2, "created_at": current_time})
assert "Cannot set meta for a non-existing prompt." in str(e.value)

# test metadata update for existing prompt
created_time = r.get_meta(name="new", version="3")["created_at"]
_ = r.update_meta(name="new", version="3", meta={"accuracy": 94.3, "created_at": created_time})
assert r.get_meta(name="new", version="3") == {"accuracy": 94.3, "created_at": created_time}