Skip to content

Commit

Permalink
unified prompt details and meta into single index json
Browse files Browse the repository at this point in the history
  • Loading branch information
mayankjobanputra committed Sep 4, 2024
1 parent 09f18c2 commit 8ac0b66
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 56 deletions.
76 changes: 35 additions & 41 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import json
import os
import time
from pathlib import Path

from pydantic import BaseModel, Field

from banks import Prompt
from banks.registry import TemplateNotFoundError
from banks.registry import MetaNotFoundError, TemplateNotFoundError

# Constants
DEFAULT_VERSION = "0"
DEFAULT_INDEX_NAME = "index.json"
DEFAULT_META_PATH = "meta"


class PromptFile(BaseModel):
name: str
version: str
path: Path
meta_path: Path
meta: dict


class PromptFileIndex(BaseModel):
Expand All @@ -35,7 +32,6 @@ def __init__(self, directory_path: Path, *, force_reindex: bool = False):
raise ValueError(msg)

self._path = directory_path
os.makedirs(self._path / DEFAULT_META_PATH, exist_ok=True)
self._index_path = self._path / DEFAULT_INDEX_NAME
if not self._index_path.exists() or force_reindex:
self._scan()
Expand All @@ -48,38 +44,39 @@ def _load(self):
def _scan(self):
self._index: PromptFileIndex = PromptFileIndex()
for path in self._path.glob("*.jinja*"):
meta_file = self._path / DEFAULT_META_PATH / f"{path.stem}.json"
pf = PromptFile(name=path.stem, version="0", path=path, meta_path=meta_file)
name, version = path.stem.rsplit(".", 1) if "." in path.stem else (path.stem, DEFAULT_VERSION)
pf = PromptFile(name=name, version=version, path=path, meta={})
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get(self, *, name: str, version: str | None = None) -> "Prompt":
def get_prompt(self, *, name: str, version: str | None = None) -> "Prompt":
version = version or DEFAULT_VERSION
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return Prompt(pf.path.read_text())
raise TemplateNotFoundError

def _create_new_prompt_and_meta(self, *, name: str, prompt: Prompt, meta: dict, version: str | None = None):
new_prompt_file = self._path / f"{name}.{version}.jinja"
new_prompt_file.write_text(prompt.raw)
new_meta_file = self._path / DEFAULT_META_PATH / f"{name}.{version}.json"
new_meta_file.write_text(json.dumps({**meta, "created_at": time.ctime()}))
return new_prompt_file, new_meta_file

def _set_prompt_and_meta( # pylint: disable=too-many-arguments
self, *, name: str, prompt: Prompt, meta: dict, version: str | None = None, overwrite: bool = False
):
def _get_prompt_file(self, *, name: str, version: str) -> PromptFile | None:
for pf in self._index.files:
if pf.name == name and pf.version == version:
if not overwrite:
msg = f"Prompt {name}.{version}.jinja already exists. Use overwrite=True to overwrite."
raise ValueError(msg)
pf.path.write_text(prompt.raw)
current_meta = json.loads(pf.meta_path.read_text())
pf.meta_path.write_text(json.dumps({**current_meta, **meta}))
return pf.path, pf.meta_path
return self._create_new_prompt_and_meta(name=name, prompt=prompt, meta=meta, version=version)
return pf
return None

def _create_pf( # pylint: disable=too-many-arguments
self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict
) -> "PromptFile":
pf = self._get_prompt_file(name=name, version=version)
if pf:
if not overwrite:
msg = f"Prompt {name}.{version}.jinja already exists. Use overwrite=True to overwrite."
raise ValueError(msg)
pf.path.write_text(prompt.raw)
pf.meta = meta
return pf
new_prompt_file = self._path / f"{name}.{version}.jinja"
new_prompt_file.write_text(prompt.raw)
pf = PromptFile(name=name, version=version, path=new_prompt_file, meta=meta)
return pf

def set( # pylint: disable=too-many-arguments
self,
Expand All @@ -93,26 +90,23 @@ def set( # pylint: disable=too-many-arguments
version = version or DEFAULT_VERSION
meta = {**(meta or {}), "created_at": time.ctime()}

prompt_file, meta_file = self._set_prompt_and_meta(
name=name, prompt=prompt, meta=meta, version=version, overwrite=overwrite
)
pf = PromptFile(name=name, version=version, path=prompt_file, meta_path=meta_file)
pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta)
if pf not in self._index.files:
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get_meta(self, *, name: str, version: str | None = None) -> dict:
version = version or DEFAULT_VERSION
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.meta_path.exists():
return json.loads(open(pf.meta_path, encoding="utf-8").read())
return {}
if pf.name == name and pf.version == version:
return pf.meta
raise MetaNotFoundError

def update_meta(self, *, meta: dict, name: str, version: str | None = None):
version = version or DEFAULT_VERSION
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.meta_path.exists():
current_meta = self.get_meta(name=name, version=version)
pf.meta_path.write_text(json.dumps({**current_meta, **meta}))
return pf.meta_path
unk_err = f"Unknown prompt {name}.{version}.jinja, Cannot set meta for a non-existing prompt."
raise ValueError(unk_err)
pf = self._get_prompt_file(name=name, version=version)
if not pf:
unk_err = f"Prompt {name}.{version} not found in the index. Cannot set meta for a non-existing prompt."
raise ValueError(unk_err)
pf.meta = meta
self._index_path.write_text(self._index.model_dump_json())
3 changes: 3 additions & 0 deletions src/banks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class TemplateNotFoundError(Exception): ...
class InvalidTemplateError(Exception): ...


class MetaNotFoundError(Exception): ...


class TemplateRegistry(Protocol):
def get(self, *, name: str, version: str | None = None) -> "Prompt": ...

Expand Down
27 changes: 12 additions & 15 deletions tests/test_directory_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import time
from pathlib import Path
Expand All @@ -7,27 +6,22 @@

from banks.prompt import Prompt
from banks.registries.directory import DirectoryTemplateRegistry
from banks.registry import TemplateNotFoundError
from banks.registry import MetaNotFoundError, TemplateNotFoundError


@pytest.fixture
def populated_dir(tmp_path):
d = tmp_path / "templates"
m = d / "meta"
d.mkdir()
m.mkdir()
for fp in (Path(__file__).parent / "templates").iterdir():
with open(d / fp.name, "w") as f:
f.write(fp.read_text())
with open(m / f"{fp.stem}.json", "w") as f:
meta = {"created_at": time.ctime()}
f.write(json.dumps(meta))
return d


def test_init_from_scratch(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
p = r.get(name="blog")
p = r.get_prompt(name="blog")
assert p.raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}")


Expand All @@ -46,7 +40,7 @@ def test_init_from_existing_index_force(populated_dir):
# force recreation, the renamed file should be updated in the index
r = DirectoryTemplateRegistry(populated_dir, force_reindex=True)
with pytest.raises(TemplateNotFoundError):
r.get(name="blog")
r.get_prompt(name="blog")


def test_init_invalid_dir():
Expand All @@ -57,7 +51,9 @@ def test_init_invalid_dir():
def test_get_not_found(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
with pytest.raises(TemplateNotFoundError):
r.get(name="FOO")
r.get_prompt(name="FOO")
with pytest.raises(MetaNotFoundError):
r.get_meta(name="FOO")


def test_set_existing_no_overwrite(populated_dir):
Expand All @@ -73,7 +69,7 @@ def test_set_existing_overwrite(populated_dir):
new_prompt = Prompt("a new prompt!")
current_time = time.ctime()
r.set(name="blog", prompt=new_prompt, overwrite=True)
assert r.get(name="blog").raw.startswith("a new prompt!")
assert r.get_prompt(name="blog").raw.startswith("a new prompt!")
assert r.get_meta(name="blog") == {"created_at": current_time} # created_at changes because it's overwritten


Expand All @@ -84,9 +80,9 @@ def test_set_multiple_templates(populated_dir):
old_prompt = Prompt("an old prompt!")
r.set(name="new", version="2", prompt=new_prompt)
r.set(name="old", version="1", prompt=old_prompt)
assert r.get(name="old", version="1").raw == "an old prompt!"
assert r.get_prompt(name="old", version="1").raw == "an old prompt!"
assert r.get_meta(name="old", version="1") == {"created_at": current_time}
assert r.get(name="new", version="2").raw == "a very new prompt!"
assert r.get_prompt(name="new", version="2").raw == "a very new prompt!"
assert r.get_meta(name="new", version="2") == {"created_at": current_time}


Expand All @@ -105,5 +101,6 @@ def test_update_meta(populated_dir):
assert "Cannot set meta for a non-existing prompt." in str(e.value)

# test metadata update for existing prompt
_ = r.update_meta(name="new", version="3", meta={"accuracy": 94.3})
assert r.get_meta(name="new", version="3") == {"accuracy": 94.3, "created_at": current_time}
created_time = r.get_meta(name="new", version="3")["created_at"]
_ = r.update_meta(name="new", version="3", meta={"accuracy": 94.3, "created_at": created_time})
assert r.get_meta(name="new", version="3") == {"accuracy": 94.3, "created_at": created_time}

0 comments on commit 8ac0b66

Please sign in to comment.