Skip to content

Commit

Permalink
feat: cache prompt rendering (#4)
Browse files Browse the repository at this point in the history
* cache prompt rendering

* add tests

* fix lint

* make cache pluggable
  • Loading branch information
masci authored May 4, 2024
1 parent bd6c087 commit 1478d13
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/banks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from banks.env import env
from banks.prompt import AsyncPrompt, Prompt
from .cache import RenderCache
from .env import env
from .prompt import AsyncPrompt, Prompt

__all__ = (
"env",
"Prompt",
"AsyncPrompt",
"RenderCache",
)
28 changes: 28 additions & 0 deletions src/banks/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import pickle
from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class RenderCache(Protocol):
    """Interface for pluggable prompt-render caches.

    Implementations map a template context (a dict) to the prompt text that
    was rendered from it, so rendering the same context twice can be skipped.
    Marked `runtime_checkable` so backends can be validated with `isinstance`.
    """

    def get(self, context: dict) -> Optional[str]:
        """Return the cached rendered prompt for *context*, or `None` on a miss."""
        ...

    def set(self, context: dict, prompt: str) -> None:
        """Store *prompt* as the rendered result for *context*."""
        ...

    def clear(self) -> None:
        """Drop every cached entry."""
        ...


class DefaultCache:
    """In-memory render-cache backend keyed by the pickled context.

    Context dicts are unhashable, so each one is serialized with
    `pickle.dumps` to obtain a hashable `bytes` key. The key computation is
    centralized in `_key` so `get` and `set` cannot drift apart.
    """

    def __init__(self) -> None:
        # Maps pickled context -> rendered prompt text.
        self._cache: dict[bytes, str] = {}

    @staticmethod
    def _key(context: dict) -> bytes:
        """Serialize *context* into a hashable cache key."""
        return pickle.dumps(context, pickle.HIGHEST_PROTOCOL)

    def get(self, context: dict) -> Optional[str]:
        """Return the prompt cached for *context*, or `None` if absent."""
        return self._cache.get(self._key(context))

    def set(self, context: dict, prompt: str) -> None:
        """Cache *prompt* as the rendered output for *context*."""
        self._cache[self._key(context)] = prompt

    def clear(self) -> None:
        """Forget all cached renders."""
        self._cache = {}
36 changes: 29 additions & 7 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,37 @@
# SPDX-License-Identifier: MIT
from typing import Optional

from .cache import DefaultCache, RenderCache
from .config import async_enabled
from .env import env
from .errors import AsyncError
from .utils import generate_canary_word


class BasePrompt:
def __init__(self, text: str, canary_word: Optional[str] = None) -> None:
def __init__(
self, text: str, canary_word: Optional[str] = None, render_cache: Optional[RenderCache] = None
) -> None:
"""
Prompt constructor.
Parameters:
text: The template text
text: The template text.
canary_word: The string to use for the `{{canary_word}}` extension. If `None`, a default string will be
generated
generated.
render_cache: The caching backend to store rendered prompts. If `None`, the default in-memory backend will
be used.
"""
self._render_cache = render_cache or DefaultCache()
self._template = env.from_string(text)
self.defaults = {"canary_word": canary_word or generate_canary_word()}

    def _cache_get(self, data: dict) -> Optional[str]:
        # Delegate the lookup to the configured cache backend; None means miss.
        return self._render_cache.get(data)

    def _cache_set(self, data: dict, text: str) -> None:
        # Store the rendered text keyed by the full render context.
        self._render_cache.set(data, text)

def _get_context(self, data: Optional[dict]) -> dict:
if data is None:
return self.defaults
Expand Down Expand Up @@ -100,7 +111,13 @@ def text(self, data: Optional[dict] = None) -> str:
data: A dictionary containing the context variables.
"""
data = self._get_context(data)
return self._template.render(data)
cached = self._cache_get(data)
if cached:
return cached

rendered: str = self._template.render(data)
self._cache_set(data, rendered)
return rendered


class AsyncPrompt(BasePrompt):
Expand Down Expand Up @@ -173,5 +190,10 @@ def __init__(self, text: str) -> None:

async def text(self, data: Optional[dict] = None) -> str:
data = self._get_context(data)
result: str = await self._template.render_async(data)
return result
cached = self._cache_get(data)
if cached:
return cached

rendered: str = await self._template.render_async(data)
self._cache_set(data, rendered)
return rendered
30 changes: 30 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from banks.cache import DefaultCache


@pytest.fixture
def cache():
    """Return a fresh, empty DefaultCache for each test."""
    return DefaultCache()


@pytest.mark.parametrize(
    "context_value",
    [{}, {"foo": "bar"}],
)
def test_default_cache_set(context_value, cache):
    """Storing one entry must leave exactly that prompt in the backing dict."""
    cache.set(context_value, "My prompt")
    # Single check covering both the entry count and the stored value.
    assert list(cache._cache.values()) == ["My prompt"]


def test_default_cache_get(cache):
    """A stored context is retrievable; an unknown context yields None."""
    cache.set({"foo": "bar"}, "My prompt")
    assert cache.get({"foo": "bar"}) == "My prompt"
    # Was `cache.get({"bar"})` — a set literal, not a dict. The cache API takes
    # dict contexts, so probe a genuinely different dict for the miss case.
    assert cache.get({"bar": "foo"}) is None


def test_default_cache_clear(cache):
    """clear() must leave the backing store empty."""
    cache.set({"foo": "bar"}, "My prompt")
    cache.clear()
    assert cache._cache == {}
11 changes: 11 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from unittest import mock

import regex as re

from banks import Prompt
from banks.cache import DefaultCache


def test_canary_word_generation():
Expand All @@ -14,3 +17,11 @@ def test_canary_word_leaked():

p = Prompt("This is my prompt")
assert not p.canary_leaked(p.text())


def test_prompt_cache():
    """Rendering a prompt must write the result through the injected cache."""
    render_cache = DefaultCache()
    # Only `set` is mocked so `get` still reports a miss and rendering runs.
    render_cache.set = mock.Mock()
    prompt = Prompt("This is my prompt", render_cache=render_cache)
    prompt.text()
    render_cache.set.assert_called_once()

0 comments on commit 1478d13

Please sign in to comment.