Skip to content

Commit

Permalink
refact: new registry api (#14)
Browse files Browse the repository at this point in the history
* refactor registry api and file implementation

* fix linting

* directory registry refactoring

* drop 3.10
  • Loading branch information
masci authored Sep 29, 2024
1 parent 4ed002c commit 3f770f1
Show file tree
Hide file tree
Showing 15 changed files with 216 additions and 328 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.10', '3.11', '3.12']
python-version: ['3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
14 changes: 11 additions & 3 deletions src/banks/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
#
# SPDX-License-Identifier: MIT
class MissingDependencyError(Exception):
pass
"""Some optional dependencies are missing."""


class AsyncError(Exception):
pass
"""An error related to asyncio support."""


class CanaryWordError(Exception):
pass
"""The canary word has leaked."""


class PromptNotFoundError(Exception):
"""The prompt was now found in the registry."""


class InvalidPromptError(Exception):
"""The prompt is not valid."""
4 changes: 3 additions & 1 deletion src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from .errors import AsyncError
from .utils import generate_canary_word

DEFAULT_VERSION = "0"


class BasePrompt:
def __init__(
Expand Down Expand Up @@ -36,7 +38,7 @@ def __init__(
self._raw: str = text
self._render_cache = render_cache or DefaultCache()
self._template = env.from_string(text)
self._version = version
self._version = version or DEFAULT_VERSION

self.defaults = {"canary_word": canary_word or generate_canary_word()}

Expand Down
4 changes: 2 additions & 2 deletions src/banks/registries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
#
# SPDX-License-Identifier: MIT
from .directory import DirectoryTemplateRegistry
from .file import FileTemplateRegistry
from .file import FilePromptRegistry

__all__ = ("FileTemplateRegistry", "DirectoryTemplateRegistry")
__all__ = ("FilePromptRegistry", "DirectoryTemplateRegistry")
117 changes: 53 additions & 64 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,29 @@
# SPDX-License-Identifier: MIT
import time
from pathlib import Path
from typing import Self

from pydantic import BaseModel, Field

from banks import Prompt
from banks.registry import TemplateNotFoundError
from banks.errors import InvalidPromptError, PromptNotFoundError
from banks.prompt import DEFAULT_VERSION
from banks.types import PromptModel

# Constants
DEFAULT_VERSION = "0"
DEFAULT_INDEX_NAME = "index.json"


class PromptFile(BaseModel):
name: str
version: str
path: Path
meta: dict
class PromptFile(PromptModel):
path: Path = Field(exclude=True)

@classmethod
def from_prompt(cls: type[Self], prompt: Prompt, path: Path) -> Self:
prompt_file = path / f"{prompt.name}.{prompt.version}.jinja"
prompt_file.write_text(prompt.raw)
return cls(
text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata, path=prompt_file
)


class PromptFileIndex(BaseModel):
Expand All @@ -38,71 +45,53 @@ def __init__(self, directory_path: Path, *, force_reindex: bool = False):
else:
self._load()

@property
def path(self) -> Path:
return self._path

def get(self, *, name: str, version: str | None = None) -> Prompt:
version = version or DEFAULT_VERSION
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return Prompt(**pf.model_dump())
raise PromptNotFoundError

def set(self, *, prompt: Prompt, overwrite: bool = False):
try:
version = prompt.version or DEFAULT_VERSION
idx, pf = self._get_prompt_file(name=prompt.name, version=version)
if overwrite:
prompt.metadata["created_at"] = time.ctime()
self._index.files[idx] = PromptFile.from_prompt(prompt, self._path)
self._save()
else:
msg = f"Prompt with name '{prompt.name}' already exists. Use overwrite=True to overwrite"
raise InvalidPromptError(msg)
except PromptNotFoundError:
prompt.metadata["created_at"] = time.ctime()
pf = PromptFile.from_prompt(prompt, self._path)
self._index.files.append(pf)
self._save()

def _load(self):
self._index = PromptFileIndex.model_validate_json(self._index_path.read_text())

def _save(self):
self._index_path.write_text(self._index.model_dump_json())

def _scan(self):
self._index: PromptFileIndex = PromptFileIndex()
for path in self._path.glob("*.jinja*"):
name, version = path.stem.rsplit(".", 1) if "." in path.stem else (path.stem, DEFAULT_VERSION)
pf = PromptFile(name=name, version=version, path=path, meta={})
self._index.files.append(pf)
with path.open("r") as f:
pf = PromptFile(text=f.read(), name=name, version=version, path=path, metadata={})
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get(self, *, name: str, version: str | None = None) -> "PromptFile":
version = version or DEFAULT_VERSION
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return pf
raise TemplateNotFoundError

def get_prompt(self, *, name: str, version: str = DEFAULT_VERSION) -> Prompt:
return Prompt(self.get(name=name, version=version).path.read_text())

def _get_prompt_file(self, *, name: str, version: str) -> PromptFile | None:
for pf in self._index.files:
def _get_prompt_file(self, *, name: str | None, version: str) -> tuple[int, PromptFile]:
for i, pf in enumerate(self._index.files):
if pf.name == name and pf.version == version:
return pf
return None

def _create_pf(self, *, name: str, prompt: Prompt, version: str, overwrite: bool, meta: dict) -> "PromptFile":
pf = self._get_prompt_file(name=name, version=version)
if pf:
if not overwrite:
msg = f"Prompt {name}.{version}.jinja already exists. Use overwrite=True to overwrite."
raise ValueError(msg)
pf.path.write_text(prompt.raw)
pf.meta = meta
return pf
new_prompt_file = self._path / f"{name}.{version}.jinja"
new_prompt_file.write_text(prompt.raw)
pf = PromptFile(name=name, version=version, path=new_prompt_file, meta=meta)
return pf

def set(
self,
*,
name: str,
prompt: Prompt,
meta: dict | None = None,
version: str | None = None,
overwrite: bool = False,
):
version = version or DEFAULT_VERSION
meta = {**(meta or {}), "created_at": time.ctime()}
return i, pf

pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta)
if pf not in self._index.files:
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get_meta(self, *, name: str, version: str = DEFAULT_VERSION) -> dict:
return self.get(name=name, version=version).meta

def update_meta(self, *, meta: dict, name: str, version: str = DEFAULT_VERSION):
pf = self._get_prompt_file(name=name, version=version)
if not pf:
unk_err = f"Prompt {name}.{version} not found in the index. Cannot set meta for a non-existing prompt."
raise ValueError(unk_err)
pf.meta = meta
self._index_path.write_text(self._index.model_dump_json())
msg = f"cannot find template with name '{name}' and version '{version}'"
raise PromptNotFoundError(msg)
82 changes: 36 additions & 46 deletions src/banks/registries/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,55 @@

from pydantic import BaseModel

from banks.errors import PromptNotFoundError
from banks.prompt import Prompt
from banks.registry import InvalidTemplateError, TemplateNotFoundError
from banks.types import PromptModel


class PromptTemplate(BaseModel):
id: str | None
name: str
version: str
prompt: str
class PromptRegistryIndex(BaseModel):
prompts: list[PromptModel] = []


class PromptTemplateIndex(BaseModel):
templates: list[PromptTemplate]
class FilePromptRegistry:
"""A prompt registry storing all the prompt data in a single JSON file."""

def __init__(self, registry_index: Path) -> None:
"""Creates an instance of the File Prompt Registry.
class FileTemplateRegistry:
def __init__(self, user_data_path: Path) -> None:
self._index_fpath: Path = user_data_path / "index.json"
self._index: PromptTemplateIndex = PromptTemplateIndex(templates=[])
Args:
registry_index: The path to the index file.
"""
self._index_fpath: Path = registry_index
self._index: PromptRegistryIndex = PromptRegistryIndex(prompts=[])
try:
self._index = PromptTemplateIndex.model_validate_json(self._index_fpath.read_text())
self._index = PromptRegistryIndex.model_validate_json(self._index_fpath.read_text())
except FileNotFoundError:
# init the user data folder
user_data_path.mkdir(parents=True, exist_ok=True)
self._index_fpath.parent.mkdir(parents=True, exist_ok=True)

@staticmethod
def _make_id(name: str, version: str | None):
if ":" in name:
msg = "Template name cannot contain ':'"
raise InvalidTemplateError(msg)
if version:
return f"{name}:{version}"
return name
def get(self, *, name: str, version: str | None = None) -> Prompt:
_, model = self._get_prompt_model(name, version)
return Prompt(**model.model_dump())

def save(self) -> None:
def set(self, *, prompt: Prompt, overwrite: bool = False) -> None:
try:
idx, p_model = self._get_prompt_model(prompt.name, prompt.version)
if overwrite:
self._index.prompts[idx] = PromptModel.from_prompt(prompt)
self._save()
except PromptNotFoundError:
p_model = PromptModel.from_prompt(prompt)
self._index.prompts.append(p_model)
self._save()

def _save(self) -> None:
with open(self._index_fpath, "w", encoding="locale") as f:
f.write(self._index.model_dump_json())

def get(self, name: str, version: str | None = None) -> "Prompt":
tpl_id = self._make_id(name, version)
tpl = self._get_template(tpl_id)
return Prompt(tpl.prompt)

def _get_template(self, tpl_id: str) -> "PromptTemplate":
for tpl in self._index.templates:
if tpl_id == tpl.id:
return tpl
def _get_prompt_model(self, name: str | None, version: str | None) -> tuple[int, PromptModel]:
for i, model in enumerate(self._index.prompts):
if model.name == name and model.version == version:
return i, model

msg = f"cannot find template '{id}'"
raise TemplateNotFoundError(msg)

def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False):
tpl_id = self._make_id(name, version)
try:
tpl = self._get_template(tpl_id)
if overwrite:
tpl.prompt = prompt.raw
self.save()
except TemplateNotFoundError:
tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt.raw)
self._index.templates.append(tpl)
self.save()
msg = f"cannot find template with name '{name}' and version '{version}'"
raise PromptNotFoundError(msg)
25 changes: 0 additions & 25 deletions src/banks/registry.py

This file was deleted.

29 changes: 29 additions & 0 deletions src/banks/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from typing import Any, Protocol, Self

from pydantic import BaseModel

from .prompt import Prompt


class PromptRegistry(Protocol):
"""Interface to be implemented by concrete prompt registries."""

def get(self, *, name: str, version: str | None = None) -> Prompt: ...

def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: ...


class PromptModel(BaseModel):
"""Serializable representation of a Prompt."""

text: str
name: str | None = None
version: str | None = None
metadata: dict[str, Any] | None = None

@classmethod
def from_prompt(cls: type[Self], prompt: Prompt) -> Self:
return cls(text=prompt.raw, name=prompt.name, version=prompt.version, metadata=prompt.metadata)
Loading

0 comments on commit 3f770f1

Please sign in to comment.