diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d2d6a8..2eaede8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v3 diff --git a/docs/prompt.md b/docs/prompt.md index 466b262..81b8b11 100644 --- a/docs/prompt.md +++ b/docs/prompt.md @@ -32,6 +32,11 @@ Banks supports the following ones, specific for prompt engineering. options: show_root_heading: false +### `{{canary_word}}` + +Insert into the prompt a canary word that can be checked later with `Prompt.canary_leaked()` +to ensure the original prompt was not leaked. + ## Macros Macros are a way to implement complex logic in the template itself, think about defining functions but using Jinja diff --git a/pyproject.toml b/pyproject.toml index e7edecf..9a84997 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "banks" dynamic = ["version"] description = 'A prompt programming language' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "MIT" keywords = [] authors = [ @@ -16,7 +16,6 @@ authors = [ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -56,7 +55,7 @@ cov = [ ] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +python = ["3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need diff --git a/src/banks/config.py b/src/banks/config.py new file mode 100644 index 0000000..2547f41 --- /dev/null +++ b/src/banks/config.py @@ -0,0 +1,5 @@ +import os + +from .utils import strtobool + +async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) diff --git a/src/banks/env.py b/src/banks/env.py index 028e871..477210f 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -1,38 +1,28 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -import os - from jinja2 import Environment, select_autoescape -from banks.extensions import GenerateExtension, HFInferenceEndpointsExtension -from banks.filters import lemmatize -from banks.loader import MultiLoader +from .config import async_enabled +from .filters import lemmatize +from .loader import MultiLoader -def strtobool(val: str) -> bool: - """Convert a string representation of truth to True or False. +def _add_extensions(env): + """ + We lazily add extensions so that we can use the env in the extensions themselves if needed. - True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values - are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if - 'val' is anything else. + For example, we use banks to manage the system prompt in `GenerateExtension` """ - val = val.lower() - if val in ("y", "yes", "t", "true", "on", "1"): - return True - elif val in ("n", "no", "f", "false", "off", "0"): - return False - else: - msg = f"invalid truth value {val}" - raise ValueError(msg) + from .extensions import GenerateExtension, HFInferenceEndpointsExtension + env.add_extension(GenerateExtension) + env.add_extension(HFInferenceEndpointsExtension) -async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) # Init the Jinja env env = Environment( loader=MultiLoader(), - extensions=[GenerateExtension, HFInferenceEndpointsExtension], autoescape=select_autoescape( enabled_extensions=("html", "xml"), default_for_string=False, @@ -42,8 +32,9 @@ def strtobool(val: str) -> bool: enable_async=bool(async_enabled), ) -# Setup custom filters +# Setup custom filters and default extensions env.filters["lemmatize"] = lemmatize +_add_extensions(env) def with_env(cls): diff --git a/src/banks/errors.py b/src/banks/errors.py index f950d17..cf4c87b 100644 --- a/src/banks/errors.py +++ b/src/banks/errors.py @@ -7,3 +7,7 @@ class MissingDependencyError(Exception): class AsyncError(Exception): pass + + +class CanaryWordError(Exception): + pass diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 8275da8..c1a14c9 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -7,7 +7,11 @@ from jinja2.ext import Extension from litellm import ModelResponse, acompletion, completion +from banks.errors import CanaryWordError +from banks.prompt import Prompt + DEFAULT_MODEL = "gpt-3.5-turbo" +SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.") class GenerateExtension(Extension): @@ -52,11 +56,14 @@ def _generate(self, text, model_name=DEFAULT_MODEL): To tweak the prompt used to generate content, change the variable `messages` . """ messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": SYSTEM_PROMPT.text()}, {"role": "user", "content": text}, ] response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages)) - return response["choices"][0]["message"]["content"] + content: str = response["choices"][0]["message"]["content"] + if SYSTEM_PROMPT.canary_leaked(content): + msg = "The system prompt has leaked into the response, possible prompt injection!" + raise CanaryWordError(msg) async def _agenerate(self, text, model_name=DEFAULT_MODEL): """ @@ -65,7 +72,7 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL): To tweak the prompt used to generate content, change the variable `messages` . """ messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": SYSTEM_PROMPT.text()}, {"role": "user", "content": text}, ] response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) diff --git a/src/banks/prompt.py b/src/banks/prompt.py index 3930ac1..14974f0 100644 --- a/src/banks/prompt.py +++ b/src/banks/prompt.py @@ -3,13 +3,24 @@ # SPDX-License-Identifier: MIT from typing import Optional -from banks.env import async_enabled, env -from banks.errors import AsyncError +from .config import async_enabled +from .env import env +from .errors import AsyncError +from .utils import generate_canary_word class BasePrompt: - def __init__(self, text: str) -> None: + def __init__(self, text: str, canary_word: Optional[str] = None) -> None: self._template = env.from_string(text) + self.defaults = {"canary_word": canary_word or generate_canary_word()} + + def _get_context(self, data: Optional[dict]) -> dict: + if data is None: + return self.defaults + return data | self.defaults + + def canary_leaked(self, text: str) -> bool: + return self.defaults["canary_word"] in text @classmethod def from_template(cls, name: str) -> "BasePrompt": @@ -20,7 +31,7 @@ def from_template(cls, name: str) -> "BasePrompt": class Prompt(BasePrompt): def text(self, data: Optional[dict] = None) -> str: - data = data or {} + data = self._get_context(data) return self._template.render(data) @@ -33,6 +44,6 @@ def __init__(self, text: str) -> None: raise AsyncError(msg) async def text(self, data: Optional[dict] = None) -> str: - data = data or {} + data = self._get_context(data) result: str = await self._template.render_async(data) return result diff --git a/src/banks/utils.py b/src/banks/utils.py new file mode 100644 index 0000000..48f17ef --- /dev/null +++ b/src/banks/utils.py @@ -0,0 +1,23 @@ +import secrets + + +def strtobool(val: str) -> bool: + """ + Convert a string representation of truth to True or False. + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + msg = f"invalid truth value {val}" + raise ValueError(msg) + + +def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str: + return f"{prefix}{secrets.token_hex(token_length // 2)}{suffix}" diff --git a/tests/__init__.py b/tests/__init__.py index bd78f6a..0a47ebb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,8 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import warnings + +# banks depends on modules producing loads of deprecation warnings, let's just ignore them, +# nothing we can do anyways +warnings.simplefilter("ignore", category=DeprecationWarning) diff --git a/tests/test_prompt.py b/tests/test_prompt.py new file mode 100644 index 0000000..c57ad07 --- /dev/null +++ b/tests/test_prompt.py @@ -0,0 +1,16 @@ +import regex as re + +from banks import Prompt + + +def test_canary_word_generation(): + p = Prompt("{{canary_word}}This is my prompt") + assert re.match(r"BANKS\[.{8}\]This is my prompt", p.text()) + + +def test_canary_word_leaked(): + p = Prompt("{{canary_word}}This is my prompt") + assert p.canary_leaked(p.text()) + + p = Prompt("This is my prompt") + assert not p.canary_leaked(p.text()) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d3f9b32 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,47 @@ +import pytest +import regex as re + +from banks.utils import generate_canary_word, strtobool + + +def test_generate_canary_word_defaults(): + default = generate_canary_word() + assert re.match(r"BANKS\[.{8}\]", default) + + +def test_generate_canary_word_params(): + only_token = generate_canary_word(prefix="", suffix="", token_length=16) + assert re.match(r".{16}", only_token) + + only_prefix = generate_canary_word(prefix="foo", suffix="") + assert re.match(r"foo.{8}", only_prefix) + + only_suffix = generate_canary_word(prefix="", suffix="foo") + assert re.match(r".{8}foo", only_suffix) + + +def test_strtobool_error(): + with pytest.raises(ValueError): + strtobool("42") + + +@pytest.mark.parametrize( + "test_input,expected", + [ + ("y", True), + ("yes", True), + ("t", True), + ("true", True), + ("on", True), + ("1", True), + ("n", False), + ("no", False), + ("f", False), + ("false", False), + ("off", False), + ("0", False), + pytest.param("42", True, marks=pytest.mark.xfail), + ], +) +def test_strtobool(test_input, expected): + assert strtobool(test_input) == expected