From 1478d137a4a983a7c4643e6c7aa79cbc8df2cc39 Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi
Date: Sat, 4 May 2024 09:32:48 +0200
Subject: [PATCH] feat: cache prompt rendering (#4)

* cache prompt rendering
* add tests
* fix lint
* make cache pluggable
---
Review notes (everything between the "---" line and the first "diff --git"
line is ignored when the patch is applied):

- src/banks/prompt.py: Prompt.text() and AsyncPrompt.text() test the cached
  value with "is not None" instead of truthiness, so a template that renders
  to "" is served from the cache instead of being re-rendered and re-stored
  on every call.
- src/banks/prompt.py: BasePrompt.__init__() falls back to DefaultCache() only
  when render_cache is None, so a caller-supplied backend that happens to be
  falsy (for example one defining __len__ while empty) is no longer replaced.
- tests/test_cache.py: the cache-miss lookup passes a dict rather than a set
  literal, matching the RenderCache signature.
- DefaultCache keys entries on pickle.dumps(context): contexts must be
  picklable, equal dicts with different insertion order are stored under
  different keys, and entries are only dropped by clear().
- The commit id on the first line and the blob ids on the "index" lines
  describe the original commit and were not recomputed for the edits above.

 src/banks/__init__.py |  6 ++++--
 src/banks/cache.py    | 28 ++++++++++++++++++++++++++++
 src/banks/prompt.py   | 36 +++++++++++++++++++++++++++++-------
 tests/test_cache.py   | 30 ++++++++++++++++++++++++++++++
 tests/test_prompt.py  | 11 +++++++++++
 5 files changed, 102 insertions(+), 9 deletions(-)
 create mode 100644 src/banks/cache.py
 create mode 100644 tests/test_cache.py

diff --git a/src/banks/__init__.py b/src/banks/__init__.py
index 3b7dfcb..8a0228d 100644
--- a/src/banks/__init__.py
+++ b/src/banks/__init__.py
@@ -1,11 +1,13 @@
 # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi
 #
 # SPDX-License-Identifier: MIT
-from banks.env import env
-from banks.prompt import AsyncPrompt, Prompt
+from .cache import RenderCache
+from .env import env
+from .prompt import AsyncPrompt, Prompt
 
 __all__ = (
     "env",
     "Prompt",
     "AsyncPrompt",
+    "RenderCache",
 )
diff --git a/src/banks/cache.py b/src/banks/cache.py
new file mode 100644
index 0000000..e749973
--- /dev/null
+++ b/src/banks/cache.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi
+#
+# SPDX-License-Identifier: MIT
+import pickle
+from typing import Optional, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class RenderCache(Protocol):
+    def get(self, context: dict) -> Optional[str]: ...
+
+    def set(self, context: dict, prompt: str) -> None: ...
+
+    def clear(self) -> None: ...
+
+
+class DefaultCache:
+    def __init__(self) -> None:
+        self._cache: dict[bytes, str] = {}
+
+    def get(self, context: dict) -> Optional[str]:
+        return self._cache.get(pickle.dumps(context, pickle.HIGHEST_PROTOCOL))
+
+    def set(self, context: dict, prompt: str) -> None:
+        self._cache[pickle.dumps(context, pickle.HIGHEST_PROTOCOL)] = prompt
+
+    def clear(self) -> None:
+        self._cache = {}
diff --git a/src/banks/prompt.py b/src/banks/prompt.py
index 9e26037..625e4b2 100644
--- a/src/banks/prompt.py
+++ b/src/banks/prompt.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: MIT
 from typing import Optional
 
+from .cache import DefaultCache, RenderCache
 from .config import async_enabled
 from .env import env
 from .errors import AsyncError
@@ -10,19 +11,29 @@
 
 
 class BasePrompt:
-    def __init__(self, text: str, canary_word: Optional[str] = None) -> None:
+    def __init__(
+        self, text: str, canary_word: Optional[str] = None, render_cache: Optional[RenderCache] = None
+    ) -> None:
         """
         Prompt constructor.
 
         Parameters:
-            text: The template text
+            text: The template text.
             canary_word: The string to use for the `{{canary_word}}` extension. If `None`, a default string will be
-                generated
-
+                generated.
+            render_cache: The caching backend to store rendered prompts. If `None`, the default in-memory backend will
+                be used.
         """
+        self._render_cache = DefaultCache() if render_cache is None else render_cache
         self._template = env.from_string(text)
         self.defaults = {"canary_word": canary_word or generate_canary_word()}
 
+    def _cache_get(self, data: dict) -> Optional[str]:
+        return self._render_cache.get(data)
+
+    def _cache_set(self, data: dict, text: str) -> None:
+        self._render_cache.set(data, text)
+
     def _get_context(self, data: Optional[dict]) -> dict:
         if data is None:
             return self.defaults
@@ -100,7 +111,13 @@ def text(self, data: Optional[dict] = None) -> str:
             data: A dictionary containing the context variables.
         """
         data = self._get_context(data)
-        return self._template.render(data)
+        cached = self._cache_get(data)
+        if cached is not None:
+            return cached
+
+        rendered: str = self._template.render(data)
+        self._cache_set(data, rendered)
+        return rendered
 
 
 class AsyncPrompt(BasePrompt):
@@ -173,5 +190,10 @@ def __init__(self, text: str) -> None:
 
     async def text(self, data: Optional[dict] = None) -> str:
         data = self._get_context(data)
-        result: str = await self._template.render_async(data)
-        return result
+        cached = self._cache_get(data)
+        if cached is not None:
+            return cached
+
+        rendered: str = await self._template.render_async(data)
+        self._cache_set(data, rendered)
+        return rendered
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..3d98103
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,30 @@
+import pytest
+
+from banks.cache import DefaultCache
+
+
+@pytest.fixture
+def cache():
+    return DefaultCache()
+
+
+@pytest.mark.parametrize(
+    "context_value",
+    [{}, {"foo": "bar"}],
+)
+def test_default_cache_set(context_value, cache):
+    cache.set(context_value, "My prompt")
+    assert len(cache._cache) == 1
+    assert next(iter(cache._cache.values())) == "My prompt"
+
+
+def test_default_cache_get(cache):
+    cache.set({"foo": "bar"}, "My prompt")
+    assert cache.get({"foo": "bar"}) == "My prompt"
+    assert cache.get({"bar": "foo"}) is None
+
+
+def test_default_cache_clear(cache):
+    cache.set({"foo": "bar"}, "My prompt")
+    cache.clear()
+    assert not len(cache._cache)
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index c57ad07..45b88b6 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -1,6 +1,9 @@
+from unittest import mock
+
 import regex as re
 
 from banks import Prompt
+from banks.cache import DefaultCache
 
 
 def test_canary_word_generation():
@@ -14,3 +17,11 @@ def test_canary_word_leaked():
     p = Prompt("This is my prompt")
 
     assert not p.canary_leaked(p.text())
+
+
+def test_prompt_cache():
+    mock_cache = DefaultCache()
+    mock_cache.set = mock.Mock()
+    p = Prompt("This is my prompt", render_cache=mock_cache)
+    p.text()
+    mock_cache.set.assert_called_once()