-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* cache prompt rendering * add tests * fix lint * make cache pluggable
- Loading branch information
Showing
5 changed files
with
102 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from banks.env import env | ||
from banks.prompt import AsyncPrompt, Prompt | ||
from .cache import RenderCache | ||
from .env import env | ||
from .prompt import AsyncPrompt, Prompt | ||
|
||
__all__ = ( | ||
"env", | ||
"Prompt", | ||
"AsyncPrompt", | ||
"RenderCache", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
import pickle | ||
from typing import Optional, Protocol, runtime_checkable | ||
|
||
|
||
@runtime_checkable | ||
class RenderCache(Protocol): | ||
def get(self, context: dict) -> Optional[str]: ... | ||
|
||
def set(self, context: dict, prompt: str) -> None: ... | ||
|
||
def clear(self) -> None: ... | ||
|
||
|
||
class DefaultCache: | ||
def __init__(self) -> None: | ||
self._cache: dict[bytes, str] = {} | ||
|
||
def get(self, context: dict) -> Optional[str]: | ||
return self._cache.get(pickle.dumps(context, pickle.HIGHEST_PROTOCOL)) | ||
|
||
def set(self, context: dict, prompt: str) -> None: | ||
self._cache[pickle.dumps(context, pickle.HIGHEST_PROTOCOL)] = prompt | ||
|
||
def clear(self) -> None: | ||
self._cache = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pytest | ||
|
||
from banks.cache import DefaultCache | ||
|
||
|
||
@pytest.fixture | ||
def cache(): | ||
return DefaultCache() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"context_value", | ||
[{}, {"foo": "bar"}], | ||
) | ||
def test_default_cache_set(context_value, cache): | ||
cache.set(context_value, "My prompt") | ||
assert len(cache._cache) == 1 | ||
assert next(iter(cache._cache.values())) == "My prompt" | ||
|
||
|
||
def test_default_cache_get(cache): | ||
cache.set({"foo": "bar"}, "My prompt") | ||
assert cache.get({"foo": "bar"}) == "My prompt" | ||
assert cache.get({"bar"}) is None | ||
|
||
|
||
def test_default_cache_clear(cache): | ||
cache.set({"foo": "bar"}, "My prompt") | ||
cache.clear() | ||
assert not len(cache._cache) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters