Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add canary word support #2

Merged
merged 5 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v3
Expand Down
5 changes: 5 additions & 0 deletions docs/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ Banks supports the following ones, specific for prompt engineering.
options:
show_root_heading: false

### `{{canary_word}}`

Insert a canary word into the prompt. The word can be checked later with `Prompt.canary_leaked()`
to detect whether the original prompt was leaked.

## Macros

Macros are a way to implement complex logic in the template itself, think about defining functions but using Jinja
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "banks"
dynamic = ["version"]
description = 'A prompt programming language'
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = "MIT"
keywords = []
authors = [
Expand All @@ -16,7 +16,6 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -56,7 +55,7 @@ cov = [
]

[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
python = ["3.9", "3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need
Expand Down
5 changes: 5 additions & 0 deletions src/banks/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os

from .utils import strtobool

# Read once at import time from the BANKS_ASYNC_ENABLED environment variable; falls back to
# "false" when the variable is unset. The raw string is parsed with `strtobool`.
async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))
33 changes: 12 additions & 21 deletions src/banks/env.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os

from jinja2 import Environment, select_autoescape

from banks.extensions import GenerateExtension, HFInferenceEndpointsExtension
from banks.filters import lemmatize
from banks.loader import MultiLoader
from .config import async_enabled
from .filters import lemmatize
from .loader import MultiLoader


def strtobool(val: str) -> bool:
"""Convert a string representation of truth to True or False.
def _add_extensions(env):
"""
We lazily add extensions so that we can use the env in the extensions themselves if needed.

True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
For example, we use banks to manage the system prompt in `GenerateExtension`
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
msg = f"invalid truth value {val}"
raise ValueError(msg)
from .extensions import GenerateExtension, HFInferenceEndpointsExtension

env.add_extension(GenerateExtension)
env.add_extension(HFInferenceEndpointsExtension)

async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false"))

# Init the Jinja env
env = Environment(
loader=MultiLoader(),
extensions=[GenerateExtension, HFInferenceEndpointsExtension],
autoescape=select_autoescape(
enabled_extensions=("html", "xml"),
default_for_string=False,
Expand All @@ -42,8 +32,9 @@ def strtobool(val: str) -> bool:
enable_async=bool(async_enabled),
)

# Setup custom filters
# Setup custom filters and default extensions
env.filters["lemmatize"] = lemmatize
_add_extensions(env)


def with_env(cls):
Expand Down
4 changes: 4 additions & 0 deletions src/banks/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class MissingDependencyError(Exception):

class AsyncError(Exception):
    """Error related to asynchronous prompt usage."""


class CanaryWordError(Exception):
    """Error signalling that a prompt's canary word was found where it should not be."""
13 changes: 10 additions & 3 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from jinja2.ext import Extension
from litellm import ModelResponse, acompletion, completion

from banks.errors import CanaryWordError
from banks.prompt import Prompt

DEFAULT_MODEL = "gpt-3.5-turbo"
SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.")


class GenerateExtension(Extension):
Expand Down Expand Up @@ -52,11 +56,14 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
To tweak the prompt used to generate content, change the variable `messages` .
"""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
return response["choices"][0]["message"]["content"]
content: str = response["choices"][0]["message"]["content"]
if SYSTEM_PROMPT.canary_leaked(content):
msg = "The system prompt has leaked into the response, possible prompt injection!"
raise CanaryWordError(msg)

async def _agenerate(self, text, model_name=DEFAULT_MODEL):
"""
Expand All @@ -65,7 +72,7 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL):
To tweak the prompt used to generate content, change the variable `messages` .
"""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
Expand Down
21 changes: 16 additions & 5 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
# SPDX-License-Identifier: MIT
from typing import Optional

from banks.env import async_enabled, env
from banks.errors import AsyncError
from .config import async_enabled
from .env import env
from .errors import AsyncError
from .utils import generate_canary_word


class BasePrompt:
def __init__(self, text: str) -> None:
def __init__(self, text: str, canary_word: Optional[str] = None) -> None:
self._template = env.from_string(text)
self.defaults = {"canary_word": canary_word or generate_canary_word()}

def _get_context(self, data: Optional[dict]) -> dict:
if data is None:
return self.defaults
return data | self.defaults

def canary_leaked(self, text: str) -> bool:
    """Tell whether this prompt's canary word appears anywhere in `text`."""
    canary = self.defaults["canary_word"]
    return canary in text

@classmethod
def from_template(cls, name: str) -> "BasePrompt":
Expand All @@ -20,7 +31,7 @@ def from_template(cls, name: str) -> "BasePrompt":

class Prompt(BasePrompt):
def text(self, data: Optional[dict] = None) -> str:
data = data or {}
data = self._get_context(data)
return self._template.render(data)


Expand All @@ -33,6 +44,6 @@ def __init__(self, text: str) -> None:
raise AsyncError(msg)

async def text(self, data: Optional[dict] = None) -> str:
data = data or {}
data = self._get_context(data)
result: str = await self._template.render_async(data)
return result
23 changes: 23 additions & 0 deletions src/banks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import secrets


def strtobool(val: str) -> bool:
    """
    Translate a textual truth value into a bool.

    The comparison is case-insensitive. 'y', 'yes', 't', 'true', 'on' and '1'
    map to True; 'n', 'no', 'f', 'false', 'off' and '0' map to False.
    Any other input raises ValueError.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    lowered = val.lower()
    # Reject unknown words first, so the final return is a plain membership test.
    if lowered not in truthy and lowered not in falsy:
        msg = f"invalid truth value {lowered}"
        raise ValueError(msg)
    return lowered in truthy


def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str:
    """
    Build a random canary word of the form `<prefix><token><suffix>`.

    The token is a string of lowercase hexadecimal digits produced with the
    `secrets` module.

    Args:
        prefix: Text placed before the random token.
        suffix: Text placed after the random token.
        token_length: Exact number of hexadecimal characters in the token.

    Returns:
        The canary word.

    Raises:
        ValueError: If `token_length` is negative.
    """
    if token_length < 0:
        msg = f"token_length must be non-negative, got {token_length}"
        raise ValueError(msg)
    # One random byte yields two hex digits: round the byte count up and trim the result,
    # so odd lengths are honoured instead of coming out one character short.
    token = secrets.token_hex((token_length + 1) // 2)[:token_length]
    return f"{prefix}{token}{suffix}"
5 changes: 5 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import warnings

# banks depends on modules that emit a large number of deprecation warnings.
# There is nothing we can do about them here, so we simply ignore them.
warnings.simplefilter("ignore", category=DeprecationWarning)
16 changes: 16 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import regex as re

from banks import Prompt


def test_canary_word_generation():
    """A rendered prompt starts with a default-format canary word and nothing else is added."""
    p = Prompt("{{canary_word}}This is my prompt")
    # fullmatch + hex class: an unanchored `.{8}` would also accept trailing garbage
    # and tokens that are not hexadecimal.
    assert re.fullmatch(r"BANKS\[[0-9a-f]{8}\]This is my prompt", p.text())


def test_canary_word_leaked():
    """`canary_leaked` is true only when the rendered text contains the canary word."""
    with_canary = Prompt("{{canary_word}}This is my prompt")
    rendered_with_canary = with_canary.text()
    assert with_canary.canary_leaked(rendered_with_canary)

    without_canary = Prompt("This is my prompt")
    rendered_without_canary = without_canary.text()
    assert not without_canary.canary_leaked(rendered_without_canary)
47 changes: 47 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import regex as re

from banks.utils import generate_canary_word, strtobool


def test_generate_canary_word_defaults():
    """With no arguments the canary word is `BANKS[` + 8 hex digits + `]`, and nothing more."""
    default = generate_canary_word()
    # fullmatch so that extra leading or trailing characters make the test fail.
    assert re.fullmatch(r"BANKS\[[0-9a-f]{8}\]", default)


def test_generate_canary_word_params():
    """Custom prefix, suffix and token length are each honoured exactly."""
    # fullmatch pins the total length: an unanchored `.{16}` would pass for any
    # string of 16 or more characters.
    only_token = generate_canary_word(prefix="", suffix="", token_length=16)
    assert re.fullmatch(r"[0-9a-f]{16}", only_token)

    only_prefix = generate_canary_word(prefix="foo", suffix="")
    assert re.fullmatch(r"foo[0-9a-f]{8}", only_prefix)

    only_suffix = generate_canary_word(prefix="", suffix="foo")
    assert re.fullmatch(r"[0-9a-f]{8}foo", only_suffix)


def test_strtobool_error():
    """An unrecognised truth string makes `strtobool` raise ValueError."""
    unrecognised = "42"
    with pytest.raises(ValueError):
        strtobool(unrecognised)


@pytest.mark.parametrize(
    "test_input,expected",
    [
        *((word, True) for word in ("y", "yes", "t", "true", "on", "1")),
        *((word, False) for word in ("n", "no", "f", "false", "off", "0")),
        # Not a valid truth value: the call raises, so this case is expected to fail.
        pytest.param("42", True, marks=pytest.mark.xfail),
    ],
)
def test_strtobool(test_input, expected):
    """Every recognised truth string maps to the expected boolean."""
    result = strtobool(test_input)
    assert result == expected
Loading