From 6210a8d393fb30f67db30d01e064e9c7a75973b8 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sat, 28 Sep 2024 12:12:43 +0200 Subject: [PATCH] make registry protocol compatible --- pyproject.toml | 1 + src/banks/registries/directory.py | 6 ++++-- src/banks/registry.py | 10 +++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 10b144c..ed8d82f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ path = "src/banks/__about__.py" dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-cov", "mkdocs-material", "mkdocstrings[python]", "simplemma", diff --git a/src/banks/registries/directory.py b/src/banks/registries/directory.py index 549e59c..90a78b2 100644 --- a/src/banks/registries/directory.py +++ b/src/banks/registries/directory.py @@ -49,7 +49,8 @@ def _scan(self): self._index.files.append(pf) self._index_path.write_text(self._index.model_dump_json()) - def get(self, *, name: str, version: str = DEFAULT_VERSION) -> "PromptFile": + def get(self, *, name: str, version: str | None = None) -> "PromptFile": + version = version or DEFAULT_VERSION for pf in self._index.files: if pf.name == name and pf.version == version and pf.path.exists(): return pf @@ -84,9 +85,10 @@ def set( name: str, prompt: Prompt, meta: dict | None = None, - version: str = DEFAULT_VERSION, + version: str | None = None, overwrite: bool = False, ): + version = version or DEFAULT_VERSION meta = {**(meta or {}), "created_at": time.ctime()} pf = self._create_pf(name=name, prompt=prompt, version=version, overwrite=overwrite, meta=meta) diff --git a/src/banks/registry.py b/src/banks/registry.py index e330c18..9f40aba 100644 --- a/src/banks/registry.py +++ b/src/banks/registry.py @@ -15,4 +15,12 @@ class InvalidTemplateError(Exception): ... class TemplateRegistry(Protocol): def get(self, *, name: str, version: str | None = None) -> "Prompt": ... - def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False): ... + def set( + self, + *, + name: str, + prompt: Prompt, + meta: dict | None = None, + version: str | None = None, + overwrite: bool = False, + ): ...