-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ignore deprecation warnings when running tests
* add canary word support
* fix linting
* drop python 3.8
* minimal documentation
- Loading branch information
Showing
12 changed files
with
146 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
|
||
from .utils import strtobool | ||
|
||
async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,28 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
import os | ||
|
||
from jinja2 import Environment, select_autoescape | ||
|
||
from banks.extensions import GenerateExtension, HFInferenceEndpointsExtension | ||
from banks.filters import lemmatize | ||
from banks.loader import MultiLoader | ||
from .config import async_enabled | ||
from .filters import lemmatize | ||
from .loader import MultiLoader | ||
|
||
|
||
def strtobool(val: str) -> bool: | ||
"""Convert a string representation of truth to True or False. | ||
def _add_extensions(env): | ||
""" | ||
We lazily add extensions so that we can use the env in the extensions themselves if needed. | ||
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values | ||
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if | ||
'val' is anything else. | ||
For example, we use banks to manage the system prompt in `GenerateExtension` | ||
""" | ||
val = val.lower() | ||
if val in ("y", "yes", "t", "true", "on", "1"): | ||
return True | ||
elif val in ("n", "no", "f", "false", "off", "0"): | ||
return False | ||
else: | ||
msg = f"invalid truth value {val}" | ||
raise ValueError(msg) | ||
from .extensions import GenerateExtension, HFInferenceEndpointsExtension | ||
|
||
env.add_extension(GenerateExtension) | ||
env.add_extension(HFInferenceEndpointsExtension) | ||
|
||
async_enabled = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) | ||
|
||
# Init the Jinja env | ||
env = Environment( | ||
loader=MultiLoader(), | ||
extensions=[GenerateExtension, HFInferenceEndpointsExtension], | ||
autoescape=select_autoescape( | ||
enabled_extensions=("html", "xml"), | ||
default_for_string=False, | ||
|
@@ -42,8 +32,9 @@ def strtobool(val: str) -> bool: | |
enable_async=bool(async_enabled), | ||
) | ||
|
||
# Setup custom filters | ||
# Setup custom filters and default extensions | ||
env.filters["lemmatize"] = lemmatize | ||
_add_extensions(env) | ||
|
||
|
||
def with_env(cls): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,7 @@ class MissingDependencyError(Exception): | |
|
||
class AsyncError(Exception):
    """Error type for async-related failures.

    NOTE(review): the raise sites are not visible here -- confirm the exact
    conditions under which this is raised.
    """
|
||
|
||
class CanaryWordError(Exception):
    """Error type for canary-word handling.

    NOTE(review): the raise sites are not visible here -- confirm the exact
    conditions under which this is raised.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import secrets | ||
|
||
|
||
def strtobool(val: str) -> bool:
    """
    Convert a string representation of truth to True or False.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'. Matching is case-insensitive
    and ignores surrounding whitespace. Raises ValueError if 'val' is
    anything else.
    """
    normalized = val.strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    # Report the caller's original input (repr'd) rather than the normalized
    # form, so empty or oddly-cased values are unambiguous in the message.
    msg = f"invalid truth value {val!r}"
    raise ValueError(msg)
|
||
|
||
def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str:
    """
    Build a random canary word of the form ``<prefix><token><suffix>``.

    The token is made of exactly ``token_length`` hexadecimal characters
    produced by ``secrets.token_hex``.

    Raises ValueError if ``token_length`` is negative.
    """
    if token_length < 0:
        msg = f"token_length must be non-negative, got {token_length}"
        raise ValueError(msg)
    # token_hex(n) yields 2 * n characters: round the byte count up and trim,
    # so odd lengths are honoured instead of coming out one character short.
    token = secrets.token_hex((token_length + 1) // 2)[:token_length]
    return f"{prefix}{token}{suffix}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
import warnings | ||
|
||
# banks depends on modules producing loads of deprecation warnings, let's just ignore them, | ||
# nothing we can do anyways | ||
warnings.simplefilter("ignore", category=DeprecationWarning) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import regex as re | ||
|
||
from banks import Prompt | ||
|
||
|
||
def test_canary_word_generation():
    """A prompt using ``{{canary_word}}`` renders with a ``BANKS[...]`` word in its place."""
    rendered = Prompt("{{canary_word}}This is my prompt").text()
    assert re.match(r"BANKS\[.{8}\]This is my prompt", rendered)
|
||
|
||
def test_canary_word_leaked():
    """``canary_leaked`` flags text rendered with the canary word, and not text rendered without it."""
    with_canary = Prompt("{{canary_word}}This is my prompt")
    assert with_canary.canary_leaked(with_canary.text())

    without_canary = Prompt("This is my prompt")
    assert not without_canary.canary_leaked(without_canary.text())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pytest | ||
import regex as re | ||
|
||
from banks.utils import generate_canary_word, strtobool | ||
|
||
|
||
def test_generate_canary_word_defaults():
    """With no arguments the word is ``BANKS[`` + 8 hex characters + ``]``."""
    default = generate_canary_word()
    # fullmatch + a hex class: `match` with `.{8}` only anchors the start and
    # would also accept non-hex tokens or trailing garbage.
    assert re.fullmatch(r"BANKS\[[0-9a-f]{8}\]", default)
|
||
|
||
def test_generate_canary_word_params():
    """Custom prefix, suffix and token length are each reflected in the generated word."""
    # fullmatch pins the exact length; `match` would pass for any longer token.
    only_token = generate_canary_word(prefix="", suffix="", token_length=16)
    assert re.fullmatch(r"[0-9a-f]{16}", only_token)

    only_prefix = generate_canary_word(prefix="foo", suffix="")
    assert re.fullmatch(r"foo[0-9a-f]{8}", only_prefix)

    only_suffix = generate_canary_word(prefix="", suffix="foo")
    assert re.fullmatch(r"[0-9a-f]{8}foo", only_suffix)
|
||
|
||
def test_strtobool_error():
    """A string outside the recognised truth values must raise ``ValueError``."""
    unrecognised = "42"
    with pytest.raises(ValueError):
        strtobool(unrecognised)
|
||
|
||
@pytest.mark.parametrize(
    "test_input,expected",
    [
        ("y", True),
        ("yes", True),
        ("t", True),
        ("true", True),
        ("on", True),
        ("1", True),
        ("n", False),
        ("no", False),
        ("f", False),
        ("false", False),
        ("off", False),
        ("0", False),
        # Strict and tied to ValueError: if "42" ever started converting
        # silently, a bare xfail would XPASS without failing the suite.
        pytest.param("42", True, marks=pytest.mark.xfail(raises=ValueError, strict=True)),
    ],
)
def test_strtobool(test_input, expected):
    """Every documented truth string converts to the expected boolean."""
    assert strtobool(test_input) == expected