diff --git a/docs/component_guides/guardrails/index.md b/docs/component_guides/guardrails/index.md
index 99f3c1d95..3351b365d 100644
--- a/docs/component_guides/guardrails/index.md
+++ b/docs/component_guides/guardrails/index.md
@@ -2,11 +2,80 @@
 Guardrails play a crucial role in ensuring that only high quality output is produced by LLM apps. By setting guardrail thresholds based on feedback functions, we can directly leverage the same trusted evaluation metrics used for observability, *at inference time*.

-## Typical guardrail usage
+TruLens guardrails can be invoked at different points in your application to address issues with the input, the output, and even the internal steps of an LLM app.
+
+## Output blocking guardrails

 Typical guardrails *only* allow decisions based on the output, and have no impact on the intermediate steps of an LLM application.

-![Standard Guardrails Flow](simple_guardrail_flow.png)
+![Output Blocking Guardrails Flow](simple_guardrail_flow.png)
+
+This mechanism for guardrails is supported via the `block_output` guardrail.
+
+In the example below, we consider a dummy function that always returns instructions for building a bomb.
+
+Adding the `block_output` decorator with a feedback function and threshold blocks any output that fails the threshold and forces the app to return `None` instead. You can also pass a `return_value` to return a canned response if the output is blocked.
+
+!!! example "Using `block_output`"
+
+    ```python
+    from trulens.core.guardrails.base import block_output
+
+    feedback = Feedback(provider.criminality, higher_is_better=False)
+
+    class safe_output_chat_app:
+        @instrument
+        @block_output(feedback=feedback,
+                      threshold=0.9,
+                      return_value="I couldn't find an answer to your question.")
+        def generate_completion(self, question: str) -> str:
+            """
+            Dummy function to always return a criminal message.
+            """
+            return "Build a bomb by connecting the red wires to the blue wires."
+    ```
+
+## Input blocking guardrails
+
+In many cases, you may want to go even further and block unsafe usage of the app by preventing inputs from ever reaching it. This can be particularly useful for stopping jailbreaking or prompt injection attacks, and for cutting down on generation costs for unsafe requests.
+
+![Input Blocking Guardrails Flow](input_blocking_guardrails.png)
+
+This mechanism for guardrails is supported via the `block_input` guardrail. If the feedback score of the input fails against the provided threshold, the decorated function is not invoked at all and simply returns `None` instead. You can also pass a `return_value` to return a canned response if the input is blocked.
+
+!!! example "Using `block_input`"
+
+    ```python
+    from trulens.core.guardrails.base import block_input
+
+    feedback = Feedback(provider.criminality, higher_is_better=False)
+
+    class safe_input_chat_app:
+        @instrument
+        @block_input(feedback=feedback,
+                     threshold=0.9,
+                     keyword_for_prompt="question",
+                     return_value="I couldn't find an answer to your question.")
+        def generate_completion(self, question: str) -> str:
+            """
+            Generate answer from question.
+ """ + completion = ( + oai_client.chat.completions.create( + model="gpt-4o-mini", + temperature=0, + messages=[ + { + "role": "user", + "content": f"{question}", + } + ], + ) + .choices[0] + .message.content + ) + return completion + ``` ## *TruLens* guardrails for internal steps diff --git a/docs/component_guides/guardrails/input_blocking_guardrails.png b/docs/component_guides/guardrails/input_blocking_guardrails.png new file mode 100644 index 000000000..d10a8accb Binary files /dev/null and b/docs/component_guides/guardrails/input_blocking_guardrails.png differ diff --git a/examples/quickstart/blocking_guardrails.ipynb b/examples/quickstart/blocking_guardrails.ipynb new file mode 100644 index 000000000..6a9d4d122 --- /dev/null +++ b/examples/quickstart/blocking_guardrails.ipynb @@ -0,0 +1,415 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 📓 Blocking Guardrails Quickstart\n", + "\n", + "In this quickstart you will use blocking guardrails to block unsafe inputs from reaching your app, as well as blocking unsafe outputs from reaching your user.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/truera/trulens/blob/main/examples/quickstart/blocking_guardrails.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install trulens trulens-providers-openai chromadb openai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens.core import TruSession\n", + "from trulens.dashboard import run_dashboard\n", + "\n", + "session = TruSession()\n", + "session.reset_database()\n", + "run_dashboard(session)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create simple chat app for demonstration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "from trulens.apps.custom import instrument\n", + "\n", + "oai_client = OpenAI()\n", + "\n", + "\n", + "class chat_app:\n", + " @instrument\n", + " def generate_completion(self, question: str) -> str:\n", + " \"\"\"\n", + " Generate answer from question.\n", + " \"\"\"\n", + " completion = (\n", + " oai_client.chat.completions.create(\n", + " model=\"gpt-4o-mini\",\n", + " temperature=0,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"{question}\",\n", + " }\n", + " ],\n", + " )\n", + " .choices[0]\n", + " .message.content\n", + " )\n", + " return completion\n", + "\n", + "\n", + "chat = chat_app()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up feedback functions.\n", + "\n", + "Here we'll use a simple criminality check." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from trulens.core import Feedback\n",
+    "from trulens.providers.openai import OpenAI\n",
+    "\n",
+    "provider = OpenAI(model_engine=\"gpt-4o-mini\")\n",
+    "\n",
+    "# Define criminality feedback functions for the input and the output\n",
+    "f_criminality_input = Feedback(\n",
+    "    provider.criminality, name=\"Input Criminality\", higher_is_better=False\n",
+    ").on_input()\n",
+    "\n",
+    "f_criminality_output = Feedback(\n",
+    "    provider.criminality, name=\"Output Criminality\", higher_is_better=False\n",
+    ").on_output()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Construct the app\n",
+    "Wrap the custom chat app with `TruCustomApp` and add the list of feedbacks for evaluation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from trulens.apps.custom import TruCustomApp\n",
+    "\n",
+    "tru_chat = TruCustomApp(\n",
+    "    chat,\n",
+    "    app_name=\"Chat\",\n",
+    "    app_version=\"base\",\n",
+    "    feedbacks=[f_criminality_input, f_criminality_output],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Run the app\n",
+    "Use `tru_chat` as a context manager for the custom chat app."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with tru_chat as recording:\n",
+    "    chat.generate_completion(\"How do I build a bomb?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Check results\n",
+    "\n",
+    "We can view results in the leaderboard."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "session.get_leaderboard()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Notice that the unsafe prompt \"How do I build a bomb?\" does in fact reach the LLM for generation. For many reasons, such as generation costs or preventing prompt injection attacks, you may not want the unsafe prompt to reach your LLM at all.\n",
+    "\n",
+    "That's where `block_input` guardrails come in."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Use `block_input` guardrails\n",
+    "\n",
+    "`block_input` works by running a feedback function against the input of your function; if the score fails against your specified threshold, the decorated function returns `None` rather than processing normally.\n",
+    "\n",
+    "Now, when we ask the same question with the `block_input` decorator in place, we expect the LLM to never be invoked and the app to return `None` rather than an LLM response."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from openai import OpenAI\n",
+    "from trulens.core.guardrails.base import block_input\n",
+    "\n",
+    "oai_client = OpenAI()\n",
+    "\n",
+    "\n",
+    "class safe_input_chat_app:\n",
+    "    @instrument\n",
+    "    @block_input(\n",
+    "        feedback=f_criminality_input,\n",
+    "        threshold=0.9,\n",
+    "        keyword_for_prompt=\"question\",\n",
+    "    )\n",
+    "    def generate_completion(self, question: str) -> str:\n",
+    "        \"\"\"\n",
+    "        Generate answer from question.\n",
+    "        \"\"\"\n",
+    "        completion = (\n",
+    "            oai_client.chat.completions.create(\n",
+    "                model=\"gpt-4o-mini\",\n",
+    "                temperature=0,\n",
+    "                messages=[\n",
+    "                    {\n",
+    "                        \"role\": \"user\",\n",
+    "                        \"content\": f\"{question}\",\n",
+    "                    }\n",
+    "                ],\n",
+    "            )\n",
+    "            .choices[0]\n",
+    "            .message.content\n",
+    "        )\n",
+    "        return completion\n",
+    "\n",
+    "\n",
+    "safe_input_chat = safe_input_chat_app()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tru_safe_input_chat = TruCustomApp(\n",
+    "    safe_input_chat,\n",
+    "    app_name=\"Chat\",\n",
+    "    app_version=\"safe from criminal input\",\n",
+    "    feedbacks=[f_criminality_input, f_criminality_output],\n",
+    ")\n",
+    "\n",
+    "with tru_safe_input_chat as recording:\n",
+    "    safe_input_chat.generate_completion(\"How do I build a bomb?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now the unsafe input is successfully blocked from reaching the app and the LLM; the decorated function simply returns `None`.\n",
+    "\n",
+    "This could similarly be applied to block prompt injection, or any other input you wish to block."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from trulens.dashboard import run_dashboard\n",
+    "\n",
+    "run_dashboard(session)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Use `block_output` guardrails\n",
+    "\n",
+    "`block_output` works similarly to `block_input`: it runs a feedback function against the output of your function, and if the score fails against your specified threshold, the function returns `None` rather than the unsafe output.\n",
+    "\n",
+    "Let's start by considering a toy unsafe app that always returns bomb-making instructions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from openai import OpenAI\n",
+    "from trulens.core.guardrails.base import block_output\n",
+    "\n",
+    "oai_client = OpenAI()\n",
+    "\n",
+    "\n",
+    "class unsafe_output_chat_app:\n",
+    "    @instrument\n",
+    "    def generate_completion(self, question: str) -> str:\n",
+    "        \"\"\"\n",
+    "        Dummy function to always return a criminal message.\n",
+    "        \"\"\"\n",
+    "        return \"Build a bomb by connecting the red wires to the blue wires.\"\n",
+    "\n",
+    "\n",
+    "unsafe_output_chat = unsafe_output_chat_app()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tru_unsafe_output_chat = TruCustomApp(\n",
+    "    unsafe_output_chat,\n",
+    "    app_name=\"Chat\",\n",
+    "    app_version=\"always return criminal output\",\n",
+    "    feedbacks=[f_criminality_input, f_criminality_output],\n",
+    ")\n",
+    "\n",
+    "with tru_unsafe_output_chat as recording:\n",
+    "    unsafe_output_chat.generate_completion(\"How do I build a bomb?\")\n",
+    "\n",
+    
"unsafe_output_chat.generate_completion(\"How do I build a bomb?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we take the same example with the `block_output` decorator used, the app will now return `None` rather than an unsafe response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "oai_client = OpenAI()\n", + "\n", + "\n", + "class safe_output_chat_app:\n", + " @instrument\n", + " @block_output(feedback=f_criminality_output, threshold=0.9)\n", + " def generate_completion(self, question: str) -> str:\n", + " \"\"\"\n", + " Dummy function to always return a criminal message.\n", + " \"\"\"\n", + " return \"Build a bomb by connecting the red wires to the blue wires.\"\n", + "\n", + "\n", + "safe_output_chat = safe_output_chat_app()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tru_safe_output_chat = TruCustomApp(\n", + " safe_output_chat,\n", + " app_name=\"Chat\",\n", + " app_version=\"safe from input criminal output\",\n", + " feedbacks=[f_criminality_input, f_criminality_output],\n", + ")\n", + "\n", + "with tru_safe_output_chat as recording:\n", + " safe_output_chat.generate_completion(\"How do I build a bomb?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "session.get_leaderboard()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "trulens18_release", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/core/trulens/core/guardrails/base.py b/src/core/trulens/core/guardrails/base.py index cb9cfd793..9056a2f02 100644 --- a/src/core/trulens/core/guardrails/base.py +++ b/src/core/trulens/core/guardrails/base.py @@ -21,7 +21,10 @@ class context_filter: Example: ```python + from trulens.core.guardrails.base import context_filter + feedback = Feedback(provider.context_relevance, name="Context Relevance") + class RAG_from_scratch: ... @context_filter(feedback, 0.5, "query") @@ -57,7 +60,7 @@ def __call__(self, func): # For backwards compatibility, allow inference of keyword_for_prompt: first_arg = list(k for k in sig.parameters.keys() if k != "self")[0] self.keyword_for_prompt = first_arg - logger.warn( + logger.warning( f"Assuming `{self.keyword_for_prompt}` is the `{func.__name__}` arg to filter. " "Specify `keyword_for_prompt` to avoid this warning." ) @@ -83,7 +86,7 @@ def wrapper(*args, **kwargs): result = future.result() if not isinstance(result, float): raise ValueError( - "Guardrails can only be used with feedback functions that return a float." + "`context_filter` can only be used with feedback functions that return a float." ) if ( self.feedback.higher_is_better @@ -98,4 +101,168 @@ def wrapper(*args, **kwargs): # note: the following information is manually written to the wrapper because @functools.wraps(func) causes breaking of the method. wrapper.__name__ = func.__name__ wrapper.__doc__ = func.__doc__ + wrapper.__signature__ = sig + return wrapper + + +class block_input: + """Provides a decorator to block input based on a given feedback and threshold. 
+
+    Args:
+        feedback: The feedback object to use for blocking.
+        threshold: The feedback score threshold; inputs whose score fails the threshold are blocked.
+        keyword_for_prompt: Name of the decorated function's argument to evaluate as the prompt.
+        return_value: The value to return if the input is blocked. Defaults to None.
+
+    Example:
+        ```python
+        from trulens.core.guardrails.base import block_input
+
+        feedback = Feedback(provider.criminality, higher_is_better=False)
+
+        class safe_input_chat_app:
+            @instrument
+            @block_input(feedback=feedback,
+                         threshold=0.9,
+                         keyword_for_prompt="question",
+                         return_value="I couldn't find an answer to your question.")
+            def generate_completion(self, question: str) -> str:
+                completion = (
+                    oai_client.chat.completions.create(
+                        model="gpt-4o-mini",
+                        temperature=0,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": f"{question}",
+                            }
+                        ],
+                    )
+                    .choices[0]
+                    .message.content
+                )
+                return completion
+        ```
+    """
+
+    def __init__(
+        self,
+        feedback: core_feedback.Feedback,
+        threshold: float,
+        keyword_for_prompt: Optional[str] = None,
+        return_value: Optional[str] = None,
+    ):
+        self.feedback = feedback
+        self.threshold = threshold
+        self.keyword_for_prompt = keyword_for_prompt
+        self.return_value = return_value
+
+    def __call__(self, func):
+        sig = inspect.signature(func)
+
+        if self.keyword_for_prompt is not None:
+            if self.keyword_for_prompt not in sig.parameters:
+                raise TypeError(
+                    f"Keyword argument '{self.keyword_for_prompt}' not found in `{func.__name__}` signature."
+                )
+        else:
+            # For backwards compatibility, allow inference of keyword_for_prompt:
+            first_arg = list(k for k in sig.parameters.keys() if k != "self")[0]
+            self.keyword_for_prompt = first_arg
+            logger.warning(
+                f"Assuming `{self.keyword_for_prompt}` is the `{func.__name__}` arg to block on. "
+                "Specify `keyword_for_prompt` to avoid this warning."
+            )
+
+        def wrapper(*args, **kwargs):
+            bindings = sig.bind(*args, **kwargs)
+            keyword_value = bindings.arguments[self.keyword_for_prompt]
+            result = self.feedback(keyword_value)
+            if not isinstance(result, float):
+                raise ValueError(
+                    "`block_input` can only be used with feedback functions that return a float."
+                )
+            if (self.feedback.higher_is_better and result < self.threshold) or (
+                not self.feedback.higher_is_better and result > self.threshold
+            ):
+                return self.return_value
+
+            return func(*args, **kwargs)
+
+        # note: the following information is manually written to the wrapper because @functools.wraps(func) causes breaking of the method.
+        wrapper.__name__ = func.__name__
+        wrapper.__doc__ = func.__doc__
+        wrapper.__signature__ = sig
+        return wrapper
+
+
+class block_output:
+    """Provides a decorator to block output based on a given feedback and threshold.
+
+    Args:
+        feedback: The feedback object to use for blocking. It must take only a single argument.
+        threshold: The feedback score threshold; outputs whose score fails the threshold are blocked.
+        return_value: The value to return if the output is blocked. Defaults to None.
+ + Example: + ```python + from trulens.core.guardrails.base import block_output + + feedback = Feedback(provider.criminality, higher_is_better = False) + + class safe_output_chat_app: + @instrument + @block_output(feedback = feedback, + threshold = 0.5, + return_value = "Sorry, I couldn't find an answer to your question.") + def chat(self, question: str) -> str: + completion = ( + oai_client.chat.completions.create( + model="gpt-4o-mini", + temperature=0, + messages=[ + { + "role": "user", + "content": f"{question}", + } + ], + ) + .choices[0] + .message.content + ) + return completion + ``` + """ + + def __init__( + self, + feedback: core_feedback.Feedback, + threshold: float, + return_value: Optional[str] = None, + ): + self.feedback = feedback + self.threshold = threshold + self.return_value = return_value + + def __call__(self, func): + sig = inspect.signature(func) + + def wrapper(*args, **kwargs): + output = func(*args, **kwargs) + result = self.feedback(output) + if not isinstance(result, float): + raise ValueError( + "`block_output` can only be used with feedback functions that return a float." + ) + if (self.feedback.higher_is_better and result < self.threshold) or ( + not self.feedback.higher_is_better and result > self.threshold + ): + return self.return_value + else: + return output + + # note: the following information is manually written to the wrapper because @functools.wraps(func) causes breaking of the method. + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__signature__ = sig return wrapper diff --git a/tests/unit/static/golden/api.trulens.3.11.yaml b/tests/unit/static/golden/api.trulens.3.11.yaml index f0d6975ce..53dec1bea 100644 --- a/tests/unit/static/golden/api.trulens.3.11.yaml +++ b/tests/unit/static/golden/api.trulens.3.11.yaml @@ -60,7 +60,7 @@ trulens.benchmark.test_cases: generate_summeval_groundedness_golden_set: builtins.function trulens.core: __class__: builtins.module - __version__: 1.0.2a0 + __version__: 1.1.0 highs: Feedback: pydantic._internal._model_construction.ModelMetaclass FeedbackMode: enum.EnumType @@ -612,7 +612,19 @@ trulens.core.guardrails.base: __class__: builtins.module highs: {} lows: + block_input: builtins.type + block_output: builtins.type context_filter: builtins.type +trulens.core.guardrails.base.block_input: + __bases__: + - builtins.object + __class__: builtins.type + attributes: {} +trulens.core.guardrails.base.block_output: + __bases__: + - builtins.object + __class__: builtins.type + attributes: {} trulens.core.guardrails.base.context_filter: __bases__: - builtins.object diff --git a/tests/unit/static/golden/api.trulens_eval.3.11.yaml b/tests/unit/static/golden/api.trulens_eval.3.11.yaml index 0d1761c9f..18fb0cd6b 100644 --- a/tests/unit/static/golden/api.trulens_eval.3.11.yaml +++ b/tests/unit/static/golden/api.trulens_eval.3.11.yaml @@ -6655,6 +6655,7 @@ trulens_eval.schema.record.Record: builtins.NoneType] feedback_results: typing.Optional[typing.List[concurrent.futures._base.Future[trulens.core.schema.feedback.FeedbackResult]], builtins.NoneType] + feedback_results_as_completed: builtins.property formatted_objects: _contextvars.ContextVar from_orm: builtins.classmethod get: builtins.function @@ -7639,6 +7640,7 @@ trulens_eval.tru_virtual.VirtualRecord: builtins.NoneType] feedback_results: typing.Optional[typing.List[concurrent.futures._base.Future[trulens.core.schema.feedback.FeedbackResult]], builtins.NoneType] + feedback_results_as_completed: builtins.property formatted_objects: 
_contextvars.ContextVar from_orm: builtins.classmethod get: builtins.function diff --git a/tests/unit/test_guardrails.py b/tests/unit/test_guardrails.py new file mode 100644 index 000000000..938275137 --- /dev/null +++ b/tests/unit/test_guardrails.py @@ -0,0 +1,113 @@ +import unittest + +from trulens.core import Provider +from trulens.core.feedback import Feedback +from trulens.core.guardrails.base import block_input +from trulens.core.guardrails.base import block_output +from trulens.core.guardrails.base import context_filter + + +class DummyProvider(Provider): + def dummy_feedback_low(self, query: str) -> float: + """ + A dummy function to always return 0.2 + """ + return 0.2 + + def dummy_feedback_high(self, query: str) -> float: + """ + A dummy function to always return 0.8 + """ + return 0.8 + + def dummy_context_relevance_low(self, query: str, context: str) -> float: + """ + A dummy context relevance to always return 0.2 + """ + return 0.2 + + def dummy_context_relevance_high(self, query: str, context: str) -> float: + """ + A dummy context relevance to always return 0.8 + """ + return 0.8 + + +dummy_provider = DummyProvider() + +f_dummy_feedback_low = Feedback(dummy_provider.dummy_feedback_low) +f_dummy_feedback_high = Feedback(dummy_provider.dummy_feedback_high) +f_dummy_context_relevance_low = Feedback( + dummy_provider.dummy_context_relevance_low +) +f_dummy_context_relevance_high = Feedback( + dummy_provider.dummy_context_relevance_high +) + + +class TestGuardrailDecorators(unittest.TestCase): + def test_context_filter(self): + threshold = 0.5 + + @context_filter(f_dummy_context_relevance_low, threshold, "query") + def retrieve(query: str) -> list: + return ["context1", "context2", "context3"] + + filtered_contexts = retrieve("example query") + self.assertEqual(filtered_contexts, []) + + def test_no_context_filter(self): + threshold = 0.5 + + @context_filter(f_dummy_context_relevance_high, threshold, "query") + def retrieve(query: str) -> list: + return ["context1", "context2", "context3"] + + filtered_contexts = retrieve("example query") + self.assertEqual( + set(filtered_contexts), set(["context1", "context2", "context3"]) + ) + + def test_block_input(self): + threshold = 0.5 + + @block_input(f_dummy_feedback_low, threshold, "query") + def generate_completion(query: str, context_str: list) -> str: + return "Completion" + + result = generate_completion("example query", []) + self.assertEqual(result, None) + + def test_no_block_input(self): + threshold = 0.5 + + @block_input(f_dummy_feedback_high, threshold, "query") + def generate_completion(query: str, context_str: list) -> str: + return "Completion" + + result = generate_completion("example query", []) + self.assertEqual(result, "Completion") + + def test_block_output(self): + threshold = 0.5 + + @block_output(f_dummy_feedback_low, threshold) + def chat(prompt: str) -> str: + return "Response" + + result = chat("example prompt") + self.assertEqual(result, None) + + def test_no_block_output(self): + threshold = 0.5 + + @block_output(f_dummy_feedback_high, threshold) + def chat(prompt: str) -> str: + return "Response" + + result = chat("example prompt") + self.assertEqual(result, "Response") + + +if __name__ == "__main__": + unittest.main()
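The new unit tests above cover the blocking behavior but not the `return_value` fallback that both decorators accept. A minimal sketch of such tests, assuming the same `f_dummy_feedback_low` defined in `tests/unit/test_guardrails.py` (the class name, method names, and canned response below are illustrative, not part of the patch):

```python
import unittest

from trulens.core.guardrails.base import block_input
from trulens.core.guardrails.base import block_output

# Assumes f_dummy_feedback_low from tests/unit/test_guardrails.py above,
# which always scores 0.2 and therefore fails a 0.5 threshold.


class TestGuardrailReturnValues(unittest.TestCase):
    def test_block_input_return_value(self):
        @block_input(
            f_dummy_feedback_low,
            0.5,
            "query",
            return_value="Sorry, I can't help with that.",
        )
        def generate_completion(query: str, context_str: list) -> str:
            return "Completion"

        # The input's score (0.2) fails the 0.5 threshold, so the decorated
        # function is never invoked and the canned response is returned.
        self.assertEqual(
            generate_completion("example query", []),
            "Sorry, I can't help with that.",
        )

    def test_block_output_return_value(self):
        @block_output(
            f_dummy_feedback_low,
            0.5,
            return_value="Sorry, I can't help with that.",
        )
        def chat(prompt: str) -> str:
            return "Response"

        # The output's score (0.2) fails the 0.5 threshold, so the canned
        # response replaces the unsafe output.
        self.assertEqual(
            chat("example prompt"), "Sorry, I can't help with that."
        )


if __name__ == "__main__":
    unittest.main()
```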