Skip to content

Commit

Permalink
feat: Add canary word support (#2)
Browse files Browse the repository at this point in the history
* ignore deprecation warnings when running tests

* add canary word support

* fix linting

* drop python 3.8

* minimal documentation
  • Loading branch information
masci authored May 2, 2024
1 parent 884d79f commit 2b1e924
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v3
Expand Down
5 changes: 5 additions & 0 deletions docs/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ Banks supports the following ones, specific for prompt engineering.
options:
show_root_heading: false

### `{{canary_word}}`

Insert into the prompt a canary word that can be checked later with `Prompt.canary_leaked()`
to ensure the original prompt was not leaked.

## Macros

Macros are a way to implement complex logic in the template itself, think about defining functions but using Jinja
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "banks"
dynamic = ["version"]
description = 'A prompt programming language'
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = "MIT"
keywords = []
authors = [
Expand All @@ -16,7 +16,6 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -56,7 +55,7 @@ cov = [
]

[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
python = ["3.9", "3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need
Expand Down
5 changes: 5 additions & 0 deletions src/banks/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os

from .utils import strtobool

async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))
33 changes: 12 additions & 21 deletions src/banks/env.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os

from jinja2 import Environment, select_autoescape

from banks.extensions import GenerateExtension, HFInferenceEndpointsExtension
from banks.filters import lemmatize
from banks.loader import MultiLoader
from .config import async_enabled
from .filters import lemmatize
from .loader import MultiLoader


def strtobool(val: str) -> bool:
"""Convert a string representation of truth to True or False.
def _add_extensions(env):
"""
We lazily add extensions so that we can use the env in the extensions themselves if needed.
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
For example, we use banks to manage the system prompt in `GenerateExtension`
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
msg = f"invalid truth value {val}"
raise ValueError(msg)
from .extensions import GenerateExtension, HFInferenceEndpointsExtension

env.add_extension(GenerateExtension)
env.add_extension(HFInferenceEndpointsExtension)

async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))

# Init the Jinja env
env = Environment(
loader=MultiLoader(),
extensions=[GenerateExtension, HFInferenceEndpointsExtension],
autoescape=select_autoescape(
enabled_extensions=("html", "xml"),
default_for_string=False,
Expand All @@ -42,8 +32,9 @@ def strtobool(val: str) -> bool:
enable_async=bool(async_enabled),
)

# Setup custom filters
# Setup custom filters and default extensions
env.filters["lemmatize"] = lemmatize
_add_extensions(env)


def with_env(cls):
Expand Down
4 changes: 4 additions & 0 deletions src/banks/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class MissingDependencyError(Exception):

class AsyncError(Exception):
pass


class CanaryWordError(Exception):
pass
13 changes: 10 additions & 3 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from jinja2.ext import Extension
from litellm import ModelResponse, acompletion, completion

from banks.errors import CanaryWordError
from banks.prompt import Prompt

DEFAULT_MODEL = "gpt-3.5-turbo"
SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.")


class GenerateExtension(Extension):
Expand Down Expand Up @@ -52,11 +56,14 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
To tweak the prompt used to generate content, change the variable `messages` .
"""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
return response["choices"][0]["message"]["content"]
content: str = response["choices"][0]["message"]["content"]
if SYSTEM_PROMPT.canary_leaked(content):
msg = "The system prompt has leaked into the response, possible prompt injection!"
raise CanaryWordError(msg)

async def _agenerate(self, text, model_name=DEFAULT_MODEL):
"""
Expand All @@ -65,7 +72,7 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL):
To tweak the prompt used to generate content, change the variable `messages` .
"""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
Expand Down
21 changes: 16 additions & 5 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
# SPDX-License-Identifier: MIT
from typing import Optional

from banks.env import async_enabled, env
from banks.errors import AsyncError
from .config import async_enabled
from .env import env
from .errors import AsyncError
from .utils import generate_canary_word


class BasePrompt:
def __init__(self, text: str) -> None:
def __init__(self, text: str, canary_word: Optional[str] = None) -> None:
self._template = env.from_string(text)
self.defaults = {"canary_word": canary_word or generate_canary_word()}

def _get_context(self, data: Optional[dict]) -> dict:
if data is None:
return self.defaults
return data | self.defaults

def canary_leaked(self, text: str) -> bool:
return self.defaults["canary_word"] in text

@classmethod
def from_template(cls, name: str) -> "BasePrompt":
Expand All @@ -20,7 +31,7 @@ def from_template(cls, name: str) -> "BasePrompt":

class Prompt(BasePrompt):
def text(self, data: Optional[dict] = None) -> str:
data = data or {}
data = self._get_context(data)
return self._template.render(data)


Expand All @@ -33,6 +44,6 @@ def __init__(self, text: str) -> None:
raise AsyncError(msg)

async def text(self, data: Optional[dict] = None) -> str:
data = data or {}
data = self._get_context(data)
result: str = await self._template.render_async(data)
return result
23 changes: 23 additions & 0 deletions src/banks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import secrets


def strtobool(val: str) -> bool:
"""
Convert a string representation of truth to True or False.
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
msg = f"invalid truth value {val}"
raise ValueError(msg)


def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str:
return f"{prefix}{secrets.token_hex(token_length // 2)}{suffix}"
5 changes: 5 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import warnings

# banks depends on modules producing loads of deprecation warnings, let's just ignore them,
# nothing we can do anyways
warnings.simplefilter("ignore", category=DeprecationWarning)
16 changes: 16 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import regex as re

from banks import Prompt


def test_canary_word_generation():
p = Prompt("{{canary_word}}This is my prompt")
assert re.match(r"BANKS\[.{8}\]This is my prompt", p.text())


def test_canary_word_leaked():
p = Prompt("{{canary_word}}This is my prompt")
assert p.canary_leaked(p.text())

p = Prompt("This is my prompt")
assert not p.canary_leaked(p.text())
47 changes: 47 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import regex as re

from banks.utils import generate_canary_word, strtobool


def test_generate_canary_word_defaults():
default = generate_canary_word()
assert re.match(r"BANKS\[.{8}\]", default)


def test_generate_canary_word_params():
only_token = generate_canary_word(prefix="", suffix="", token_length=16)
assert re.match(r".{16}", only_token)

only_prefix = generate_canary_word(prefix="foo", suffix="")
assert re.match(r"foo.{8}", only_prefix)

only_suffix = generate_canary_word(prefix="", suffix="foo")
assert re.match(r".{8}foo", only_suffix)


def test_strtobool_error():
with pytest.raises(ValueError):
strtobool("42")


@pytest.mark.parametrize(
"test_input,expected",
[
("y", True),
("yes", True),
("t", True),
("true", True),
("on", True),
("1", True),
("n", False),
("no", False),
("f", False),
("false", False),
("off", False),
("0", False),
pytest.param("42", True, marks=pytest.mark.xfail),
],
)
def test_strtobool(test_input, expected):
assert strtobool(test_input) == expected

0 comments on commit 2b1e924

Please sign in to comment.