Skip to content

Commit

Permalink
Remove from_template and add Directory Registry (#9)
Browse files Browse the repository at this point in the history
* directory registry

* move test templates under test folder, stop using MultiLoader

* fix simplemma

* re-enable tests

* linting

* fix macro docs

* fix readme
  • Loading branch information
masci authored Jun 25, 2024
1 parent cb1db67 commit a3bdf8f
Show file tree
Hide file tree
Showing 23 changed files with 266 additions and 152 deletions.
21 changes: 12 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Docs are available [here](https://masci.github.io/banks/).
- [Use a LLM to generate a text while rendering a prompt](#use-a-llm-to-generate-a-text-while-rendering-a-prompt)
- [Go meta: create a prompt and `generate` its response](#go-meta-create-a-prompt-and-generate-its-response)
- [Go meta(meta): process a LLM response](#go-metameta-process-a-llm-response)
- [Reuse templates from files](#reuse-templates-from-files)
- [Reuse templates from registries](#reuse-templates-from-registries)
- [Async support](#async-support)
- [License](#license)

Expand Down Expand Up @@ -266,19 +266,22 @@ print(p.text({"topic": "climate change"}))

The final answer from the LLM will be printed, this time all in uppercase.

### Reuse templates from files
### Reuse templates from registries

We can get the same result as the previous example loading the prompt template from file
instead of hardcoding it into the Python code. For convenience, Banks comes with a few
default templates distributed the package. We can load those templates from file like this:
We can get the same result as the previous example loading the prompt template from a registry
instead of hardcoding it into the Python code. For convenience, Banks comes with a few registry types
you can use to store your templates. For example, the `DirectoryTemplateRegistry` can load templates
from a directory in the file system. Suppose you have a folder called `templates` in the current path,
and the folder contains a file called `blog.jinja`. You can load the prompt template like this:

```py
from banks import Prompt
from banks.registries import DirectoryTemplateRegistry

registry = DirectoryTemplateRegistry(populated_dir)
prompt = registry.get(name="blog")

p = Prompt.from_template("blog.jinja")
topic = "retrogame computing"
print(p.text({"topic": topic}))
print(prompt.text({"topic": "retrogame computing"}))
```

### Async support
Expand All @@ -292,7 +295,7 @@ Example:
from banks import AsyncPrompt

async def main():
p = AsyncPrompt.from_template("blog.jinja")
p = AsyncPrompt("Write a blog article about the topic {{ topic }}")
result = await p.text({"topic": "AI frameworks"})
print(result)

Expand Down
50 changes: 40 additions & 10 deletions docs/python.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,46 @@

::: banks.prompt.AsyncPrompt

## Default templates
## Default macros

Banks' package comes with the following prompt templates ready to be used:
Banks' package comes with default template macros you can use in your prompts.

- `banks_macros.jinja`
- `generate_tweet.jinja`
- `run_prompt_process.jinja`
- `summarize_lemma.jinja`
- `blog.jinja`
- `run_prompt.jinja`
- `summarize.jinja`

If Banks is properly installed, something like `Prompt.from_template("blog.jinja")` should always work out of the box.
### `run_prompt`


We can use `run_prompt` in our templates to generate a prompt, send the result to the LLM and get a response.
Take this prompt for example:

```py
from banks import Prompt

prompt_template = """
{% from "banks_macros.jinja" import run_prompt with context %}
{%- call run_prompt() -%}
Write a 500-word blog post on {{ topic }}
Blog post:
{%- endcall -%}
"""

p = Prompt(prompt_template)
print(p.text({"topic": "climate change"}))
```

In this case, Banks will generate internally the prompt text

```
Write a 500-word blog post on climate change
Blog post:
```

but instead of returning it, will send it to the LLM using the `generate` extension under the hood, eventually
returning the final response:

```
Climate change is a phenomenon that has been gaining attention in recent years...
...
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"pytest",
"mkdocs-material",
"mkdocstrings[python]",
"simplemma",
]

[tool.hatch.envs.default.scripts]
Expand Down
4 changes: 3 additions & 1 deletion src/banks/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os
from pathlib import Path
from typing import Any
Expand Down
20 changes: 2 additions & 18 deletions src/banks/env.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os
from pathlib import Path

from jinja2 import Environment, select_autoescape
from jinja2 import Environment, PackageLoader, select_autoescape

from .config import config
from .filters import lemmatize
from .loader import MultiLoader
from .registries import FileTemplateRegistry
from .registry import TemplateRegistry


def _add_extensions(env):
Expand All @@ -25,15 +19,9 @@ def _add_extensions(env):
env.add_extension(HFInferenceEndpointsExtension)


def _add_default_templates(r: TemplateRegistry):
templates_dir = Path(os.path.dirname(__file__)) / "templates"
for tpl_file in templates_dir.glob("*.jinja"):
r.set(name=tpl_file.name, prompt=tpl_file.read_text())


# Init the Jinja env
env = Environment(
loader=MultiLoader(),
loader=PackageLoader("banks", "templates"),
autoescape=select_autoescape(
enabled_extensions=("html", "xml"),
default_for_string=False,
Expand All @@ -43,11 +31,7 @@ def _add_default_templates(r: TemplateRegistry):
enable_async=bool(config.ASYNC_ENABLED),
)

# Init the Template registry
registry = FileTemplateRegistry(config.USER_DATA_PATH)


# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
_add_extensions(env)
_add_default_templates(registry)
2 changes: 1 addition & 1 deletion src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from banks.errors import MissingDependencyError

try:
from simplemma.simplemma import text_lemmatizer
from simplemma import text_lemmatizer

simplemma_avail = True
except ImportError:
Expand Down
48 changes: 6 additions & 42 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .cache import DefaultCache, RenderCache
from .config import config
from .env import env, registry
from .env import env
from .errors import AsyncError
from .utils import generate_canary_word

Expand All @@ -25,6 +25,7 @@ def __init__(
be used.
"""
self._render_cache = render_cache or DefaultCache()
self._raw: str = text
self._template = env.from_string(text)
self.defaults = {"canary_word": canary_word or generate_canary_word()}

Expand All @@ -39,53 +40,16 @@ def _get_context(self, data: Optional[dict]) -> dict:
return self.defaults
return data | self.defaults

@property
def raw(self) -> str:
return self._raw

def canary_leaked(self, text: str) -> bool:
"""
Returns whether the canary word is present in `text`, signalling the prompt might have leaked.
"""
return self.defaults["canary_word"] in text

@classmethod
def from_template(cls, name: str, version: str | None = None) -> "BasePrompt":
"""
Create a prompt instance from a template.
Prompt templates can be really long and at some point you might want to store them on files. To avoid the
boilerplate code to read a file and pass the content as strings to the constructor, `Prompt`s can be
initialized by just passing the name of the template file, provided that the file is available to the
loaders that were configured (see `Multiloader`).
One of the default loaders can load templates stored in a folder called `templates` in the current path:
```
.
└── templates
└── foo.jinja
```
The code would be the following:
```py
from banks import Prompt
p = Prompt.from_template("foo.jinja")
prompt_text = p.text(data={"foo": "bar"})
```
!!! warning
Banks comes with its own set of default templates (see below) which takes precedence over the
ones loaded from the filesystem, so be sure to use different names for your custom
templates
Parameters:
name: The name of the template.
Returns:
A new `Prompt` instance.
"""
tpl = registry.get(name, version)
return cls(tpl.prompt)


class Prompt(BasePrompt):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/banks/registries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from .directory import DirectoryTemplateRegistry
from .file import FileTemplateRegistry

__all__ = ("FileTemplateRegistry",)
__all__ = ("FileTemplateRegistry", "DirectoryTemplateRegistry")
61 changes: 61 additions & 0 deletions src/banks/registries/directory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from pathlib import Path

from pydantic import BaseModel, Field

from banks import Prompt
from banks.registry import TemplateNotFoundError


class PromptFile(BaseModel):
name: str
version: str
path: Path


class PromptFileIndex(BaseModel):
files: list[PromptFile] = Field(default=[])


class DirectoryTemplateRegistry:
def __init__(self, directory_path: Path, *, force_reindex: bool = False):
if not directory_path.is_dir():
msg = "{directory_path} must be a directory."
raise ValueError(msg)

self._path = directory_path
self._index_path = self._path / "index.json"
if not self._index_path.exists() or force_reindex:
self._scan()
else:
self._load()

def _load(self):
self._index = PromptFileIndex.model_validate_json(self._index_path.read_text())

def _scan(self):
self._index: PromptFileIndex = PromptFileIndex()
for path in self._path.glob("*.jinja*"):
pf = PromptFile(name=path.stem, version="0", path=path)
self._index.files.append(pf)
self._index_path.write_text(self._index.model_dump_json())

def get(self, *, name: str, version: str | None = None) -> "Prompt":
version = version or "0"
for pf in self._index.files:
if pf.name == name and pf.version == version and pf.path.exists():
return Prompt(pf.path.read_text())
raise TemplateNotFoundError

def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False):
version = version or "0"
for pf in self._index.files:
if pf.name == name and pf.version == version and overwrite:
pf.path.write_text(prompt.raw)
return
new_prompt_file = self._path / "{name}.{version}.jinja"
new_prompt_file.write_text(prompt.raw)
pf = PromptFile(name=name, version=version, path=new_prompt_file)
self._index.files.append(pf)
37 changes: 28 additions & 9 deletions src/banks/registries/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
# SPDX-License-Identifier: MIT
from pathlib import Path

from banks.registry import PromptTemplate, PromptTemplateIndex, TemplateNotFoundError, InvalidTemplateError
from pydantic import BaseModel

from banks.prompt import Prompt
from banks.registry import InvalidTemplateError, TemplateNotFoundError


class PromptTemplate(BaseModel):
id: str | None
name: str
version: str
prompt: str


class PromptTemplateIndex(BaseModel):
templates: list[PromptTemplate]


class FileTemplateRegistry:
Expand All @@ -19,7 +33,8 @@ def __init__(self, user_data_path: Path) -> None:
@staticmethod
def _make_id(name: str, version: str | None):
if ":" in name:
raise InvalidTemplateError("Template name cannot contain ':'")
msg = "Template name cannot contain ':'"
raise InvalidTemplateError(msg)
if version:
return f"{name}:{version}"
return name
Expand All @@ -28,23 +43,27 @@ def save(self) -> None:
with open(self._index_fpath, "w") as f:
f.write(self._index.model_dump_json())

def get(self, name: str, version: str | None = None) -> "PromptTemplate":
def get(self, name: str, version: str | None = None) -> "Prompt":
tpl_id = self._make_id(name, version)
tpl = self._get_template(tpl_id)
return Prompt(tpl.prompt)

def _get_template(self, tpl_id: str) -> "PromptTemplate":
for tpl in self._index.templates:
if tpl_id == tpl.id:
return tpl

msg = f"cannot find template '{tpl_id}'"
msg = f"cannot find template '{id}'"
raise TemplateNotFoundError(msg)

def set(self, *, name: str, prompt: str, version: str | None = None, overwrite: bool = False):
def set(self, *, name: str, prompt: Prompt, version: str | None = None, overwrite: bool = False):
tpl_id = self._make_id(name, version)
try:
tpl = self.get(name, version)
tpl = self._get_template(tpl_id)
if overwrite:
tpl.prompt = prompt
tpl.prompt = prompt.raw
self.save()
except TemplateNotFoundError:
tpl_id = self._make_id(name, version)
tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt)
tpl = PromptTemplate(id=tpl_id, name=name, version=version or "", prompt=prompt.raw)
self._index.templates.append(tpl)
self.save()
Loading

0 comments on commit a3bdf8f

Please sign in to comment.