Skip to content

Commit

Permalink
addressed all comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mayankjobanputra committed Sep 6, 2024
1 parent ea7fb82 commit 19b33f6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 24 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,5 @@ disable = [
"missing-class-docstring",
"missing-function-docstring",
"cyclic-import",
]
]
max-args = 10
30 changes: 12 additions & 18 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel, Field

from banks import Prompt
from banks.registry import MetaNotFoundError, TemplateNotFoundError
from banks.registry import TemplateNotFoundError

# Constants
DEFAULT_VERSION = "0"
Expand Down Expand Up @@ -49,22 +49,22 @@ def _scan(self):
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get_prompt(self, *, name: str, version: str | None = None) -> "Prompt":
version = version or DEFAULT_VERSION
def get(self, *, name: str, version: str = DEFAULT_VERSION) -> "PromptFile":
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return Prompt(pf.path.read_text())
return pf
raise TemplateNotFoundError

def get_prompt(self, *, name: str, version: str = DEFAULT_VERSION) -> Prompt:
return Prompt(self.get(name=name, version=version).path.read_text())

def _get_prompt_file(self, *, name: str, version: str) -> PromptFile | None:
for pf in self._index.files:
if pf.name == name and pf.version == version:
return pf
return None

def _create_pf( # pylint: disable=too-many-arguments
self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict
) -> "PromptFile":
def _create_pf(self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict) -> "PromptFile":
pf = self._get_prompt_file(name=name, version=version)
if pf:
if not overwrite:
Expand All @@ -78,32 +78,26 @@ def _create_pf( # pylint: disable=too-many-arguments
pf = PromptFile(name=name, version=version, path=new_prompt_file, meta=meta)
return pf

def set( # pylint: disable=too-many-arguments
def set(
self,
*,
name: str,
prompt: Prompt,
meta: dict | None = None,
version: str | None = None,
version: str = DEFAULT_VERSION,
overwrite: bool = False,
):
version = version or DEFAULT_VERSION
meta = {**(meta or {}), "created_at": time.ctime()}

pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta)
if pf not in self._index.files:
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get_meta(self, *, name: str, version: str | None = None) -> dict:
version = version or DEFAULT_VERSION
for pf in self._index.files:
if pf.name == name and pf.version == version:
return pf.meta
raise MetaNotFoundError
def get_meta(self, *, name: str, version: str = DEFAULT_VERSION) -> dict:
return self.get(name=name, version=version).meta

def update_meta(self, *, meta: dict, name: str, version: str | None = None):
version = version or DEFAULT_VERSION
def update_meta(self, *, meta: dict, name: str, version: str = DEFAULT_VERSION):
pf = self._get_prompt_file(name=name, version=version)
if not pf:
unk_err = f"Prompt {name}.{version} not found in the index. Cannot set meta for a non-existing prompt."
Expand Down
3 changes: 0 additions & 3 deletions src/banks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ class TemplateNotFoundError(Exception): ...
class InvalidTemplateError(Exception): ...


class MetaNotFoundError(Exception): ...


class TemplateRegistry(Protocol):
def get(self, *, name: str, version: str | None = None) -> "Prompt": ...

Expand Down
9 changes: 7 additions & 2 deletions tests/test_directory_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from banks.prompt import Prompt
from banks.registries.directory import DirectoryTemplateRegistry
from banks.registry import MetaNotFoundError, TemplateNotFoundError
from banks.registry import TemplateNotFoundError


@pytest.fixture
Expand All @@ -23,6 +23,7 @@ def test_init_from_scratch(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
p = r.get_prompt(name="blog")
assert p.raw.startswith("{# Zero-shot, this is already enough for most topics in english -#}")
assert r.get_meta(name="blog") == {}


def test_init_from_existing_index(populated_dir):
Expand All @@ -41,6 +42,8 @@ def test_init_from_existing_index_force(populated_dir):
r = DirectoryTemplateRegistry(populated_dir, force_reindex=True)
with pytest.raises(TemplateNotFoundError):
r.get_prompt(name="blog")
with pytest.raises(TemplateNotFoundError):
r.get(name="blog")


def test_init_invalid_dir():
Expand All @@ -52,7 +55,7 @@ def test_get_not_found(populated_dir):
r = DirectoryTemplateRegistry(populated_dir)
with pytest.raises(TemplateNotFoundError):
r.get_prompt(name="FOO")
with pytest.raises(MetaNotFoundError):
with pytest.raises(TemplateNotFoundError):
r.get_meta(name="FOO")


Expand All @@ -69,7 +72,9 @@ def test_set_existing_overwrite(populated_dir):
new_prompt = Prompt("a new prompt!")
current_time = time.ctime()
r.set(name="blog", prompt=new_prompt, overwrite=True)
assert r.get(name="blog").path.read_text() == "a new prompt!"
assert r.get_prompt(name="blog").raw.startswith("a new prompt!")
assert r.get(name="blog").meta == {"created_at": current_time}
assert r.get_meta(name="blog") == {"created_at": current_time} # created_at changes because it's overwritten


Expand Down

0 comments on commit 19b33f6

Please sign in to comment.