feat: implements on topic guardrail #144

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions app/llm/conversation.py
@@ -122,6 +122,7 @@ async def make_conversation(
},
config,
)
+
last_message = result[MEMORY_KEY][-1].content
add_messages(chat_history, input, last_message)
return {"output": last_message}
5 changes: 3 additions & 2 deletions app/llm/graph.py
@@ -18,7 +18,9 @@

from .nodes import (
CONV_TYPE_DATA_OWNER,
CONV_TYPE_FAILED_GUARD,
CONV_TYPE_INFO,
CONVERSATION_TYPE_KEY,
DATA_OWNER_AGENT_NODE,
INFORMATION_AGENT_NAME,
RECOMMENDER_AGENT_NAME,
@@ -29,8 +31,6 @@
)
from .prompts import MEMORY_KEY, WS_ID_KEY

-CONVERSATION_TYPE_KEY = "conv_type"
-

class GraphState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
@@ -171,6 +171,7 @@ def base_edges() -> list[_Condition_Edge]:
conditional_edge_mapping={
CONV_TYPE_INFO: INFORMATION_AGENT_NAME,
CONV_TYPE_DATA_OWNER: DATA_OWNER_AGENT_NODE,
CONV_TYPE_FAILED_GUARD: END,
},
),
_Condition_Edge(
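The one-line mapping addition above is what actually short-circuits failed guardrails. A rough sketch of the conditional wiring it feeds — hypothetical shape only, since the real graph is built with the repo's _Condition_Edge helper in the collapsed parts of graph.py, and the node name here is made up:

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)  # GraphState as defined above
workflow.add_conditional_edges(
    "entry_point",  # hypothetical node name
    lambda state: state[CONVERSATION_TYPE_KEY],
    {
        CONV_TYPE_INFO: INFORMATION_AGENT_NAME,
        CONV_TYPE_DATA_OWNER: DATA_OWNER_AGENT_NODE,
        CONV_TYPE_FAILED_GUARD: END,  # a failed guard ends the run with no agent call
    },
)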
55 changes: 55 additions & 0 deletions app/llm/guardrails/on_topic.py
@@ -0,0 +1,55 @@
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field

from app.llm.model import create_model
from app.settings import settings


class Valid(BaseModel):
"""Is the user message valid for topics"""

is_valid: bool = Field(description="user message is valid")


_parser = PydanticOutputParser(pydantic_object=Valid)


def on_topic_guard():
model = create_model(
model=settings.GUARDRAILS_LLM_MODEL,
temperature=0.2,
streaming=False,
name="guardrails",
)
allowed_topics = [
"greeting",
"yes or no answers",
"access request",
"recommending access",
"information about applications",
]
template = f"""
Your job is to determine if the user's input is on topic
allowed topics are: {','.join(allowed_topics)}
"""
prompt = PromptTemplate(
template="{template}.\n{format_instructions}\n{query}\n",
input_variables=["query"],
partial_variables={
"format_instructions": _parser.get_format_instructions(),
"template": template,
},
)

prompt_and_model = prompt | model

return prompt_and_model


async def topical_guardrail(user_request) -> Valid:
guard = on_topic_guard()
output = guard.invoke({"query": user_request})
result = _parser.invoke(output)

return result
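For reviewers who want to poke at the guardrail in isolation, a minimal sketch — it assumes app settings are loaded and GUARDRAILS_LLM_MODEL points at a reachable model, and the sample prompts are illustrative, not from the PR:

import asyncio

from app.llm.guardrails.on_topic import topical_guardrail


async def main():
    on = await topical_guardrail("can I get access to the billing app?")
    off = await topical_guardrail("write me a poem about the sea")
    print(on.is_valid, off.is_valid)  # expected: True False


asyncio.run(main())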
7 changes: 5 additions & 2 deletions app/llm/model.py
@@ -1,3 +1,4 @@
from typing import Optional
from urllib.parse import parse_qs, urlparse

from langchain_core.language_models.chat_models import (
@@ -22,9 +23,11 @@ def get_model_from_uri(uri: str) -> tuple[str, str, dict[str, str]]:
return [protocol, parsed_url.netloc, qp]


-def create_model(**args) -> BaseChatModel:
+def create_model(model: Optional[str] = None, **args) -> BaseChatModel:
if model is None:
model = settings.LLM_MODEL
# TODO convert qp as model args
-    type, model, qp = get_model_from_uri(settings.LLM_MODEL)
+    type, model, qp = get_model_from_uri(model)
if type == "fake":
filtered_args = {
key: value for key, value in args.items() if key in LLM_ALLOWED_ARGS["fake"]
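With the new optional argument, each call site can pick its own model while existing callers keep the previous default; a minimal sketch, assuming valid provider credentials in settings:

from app.llm.model import create_model
from app.settings import settings

default_llm = create_model()  # falls back to settings.LLM_MODEL, as before
guard_llm = create_model(
    model=settings.GUARDRAILS_LLM_MODEL,  # "openai://gpt-4o" by default
    temperature=0.2,
    streaming=False,
)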
45 changes: 42 additions & 3 deletions app/llm/nodes.py
@@ -1,9 +1,11 @@
import asyncio
import functools
from types import coroutine

-from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field

from app.llm.guardrails.on_topic import topical_guardrail
from app.llm.tools.deny_access_tool import create_deny_provision_tool
from app.models import ConversationTypes

@@ -27,13 +29,15 @@
INFORMATION_AGENT_NAME = "Information"
RECOMMENDER_AGENT_NAME = "Recommender"
DATA_OWNER_AGENT_NODE = "DataOwner"
CONVERSATION_TYPE_KEY = "conv_type"
CONV_TYPE_DATA_OWNER = ConversationTypes.data_owner.value
CONV_TYPE_INFO = ConversationTypes.recommendation.value
CONV_TYPE_FAILED_GUARD = "FAILED_GUARD"


def agent_node(state, agent_creator, name):
agent = agent_creator(state)
-    result = asyncio.new_event_loop().run_until_complete(agent.ainvoke(state))
+    result = asyncio.run(agent.ainvoke(state))
# We convert the agent output into a format that is suitable to append to the global state
if isinstance(result, ToolMessage):
pass
@@ -125,6 +129,27 @@ class IGNOutput(BaseModel):
app_name: str = Field(description="the app name")


async def execute_chat_with_guardrail(runnable: coroutine, input):
    topical_guardrail_task = asyncio.create_task(topical_guardrail(input))
    chat_task = asyncio.create_task(runnable)
    while True:
        done, _ = await asyncio.wait(
            [topical_guardrail_task, chat_task], return_when=asyncio.FIRST_COMPLETED
        )
        if topical_guardrail_task in done:
            guardrail_response = topical_guardrail_task.result()
            if not guardrail_response.is_valid:
                chat_task.cancel()
                return {
                    "output": "I'm sorry, I can only talk about access requests"
                }, False
            elif chat_task in done:
                result = chat_task.result()
                return result, True
        else:
            await asyncio.sleep(0.2)
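execute_chat_with_guardrail races the topical guardrail against the main chat coroutine and cancels the chat the moment the guardrail fails, so a blocked request never waits for (or pays for) a full completion. A self-contained toy version of the same race-and-cancel pattern, with dummy tasks standing in for the real guardrail and chat:

import asyncio


async def guard() -> bool:
    await asyncio.sleep(0.1)
    return False  # pretend the topic check failed


async def chat() -> dict:
    await asyncio.sleep(5)
    return {"output": "full model answer"}


async def main():
    guard_task = asyncio.create_task(guard())
    chat_task = asyncio.create_task(chat())
    done, _ = await asyncio.wait(
        {guard_task, chat_task}, return_when=asyncio.FIRST_COMPLETED
    )
    if guard_task in done and not guard_task.result():
        chat_task.cancel()  # stop the expensive call early
        return {"output": "blocked"}, False
    return await chat_task, True


print(asyncio.run(main()))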


def entry_point_node(data_context):
def _epn(state):
agent = create_agent(
@@ -134,8 +159,22 @@ def _epn(state):
streaming=False,
)

-        result = agent.invoke(state)
+        corou = agent.ainvoke(state)
result, ok = asyncio.run(
execute_chat_with_guardrail(
runnable=corou, input=state[MEMORY_KEY][-1].content
)
)

output = result["output"]

if not ok:
return {
"sender": "entry_point",
MEMORY_KEY: [AIMessage(content=output)],
CONVERSATION_TYPE_KEY: CONV_TYPE_FAILED_GUARD,
}

if not isinstance(output, dict):
return {
"sender": "entry_point",
4 changes: 1 addition & 3 deletions app/llm/tests/test_data_owner.py
@@ -70,8 +70,6 @@ def test_get_data_owner(self):
name="okta", config={}, workspace_id="foo", created_by="[email protected]"
)

-        owner = asyncio.new_event_loop().run_until_complete(
-            get_data_owner(ws=ws, directory=dir, app_name="foo")
-        )
+        owner = asyncio.run(get_data_owner(ws=ws, directory=dir, app_name="foo"))

self.assertEqual(owner.email, expected_email)
File renamed without changes.
2 changes: 1 addition & 1 deletion app/llm/tests/test_tools.py
@@ -34,7 +34,7 @@ def test_make_request(self):
owner = User(id="123", email="[email protected]")
output = "making request"

-        _ = asyncio.new_event_loop().run_until_complete(
+        _ = asyncio.run(
make_request(
ws=ws,
owner=owner,
3 changes: 2 additions & 1 deletion app/llm/tools/app_extra_intructions_tool.py
@@ -15,7 +15,7 @@ class FindAppInput(BaseModel):
app_name: Optional[str] = Field(description="should be a the app name")


-def _find_app(workspace_id: str, app_name: Optional[str] = "") -> str:
+async def _find_app(workspace_id: str, app_name: Optional[str] = "") -> str:
app_store = factory_app_store()
empty_res = {
APP_ID_KEY: None,
@@ -44,6 +44,7 @@ def _find_app(workspace_id: str, app_name: Optional[str] = "") -> str:

find_app_extra_inst_tool = StructuredTool.from_function(
func=_find_app,
coroutine=_find_app,
name="find_app",
description="finds the app name in the messages",
args_schema=FindAppInput,
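Passing the function as coroutine= (in addition to func=) is what lets the agent await this tool. A hypothetical direct call — the workspace id and app name below are made up:

import asyncio

from app.llm.tools.app_extra_intructions_tool import find_app_extra_inst_tool

result = asyncio.run(
    find_app_extra_inst_tool.ainvoke({"workspace_id": "ws-123", "app_name": "okta"})
)
print(result)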
1 change: 1 addition & 0 deletions app/settings.py
@@ -14,6 +14,7 @@ class Settings(BaseSettings):
OAUTH2_AUDIENCE: str
# llm
LLM_MODEL: str = "openai://gpt-4-turbo-preview"
GUARDRAILS_LLM_MODEL: str = "openai://gpt-4o"
# messaging
SLACK_CLIENT_ID: str
SLACK_CLIENT_SECRET: str
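The new setting uses the same provider://model URI scheme as LLM_MODEL, so it flows through the existing parser unchanged; a quick parse-only check (no network involved):

from app.llm.model import get_model_from_uri

provider, model_name, params = get_model_from_uri("openai://gpt-4o")
print(provider, model_name, params)  # "openai", "gpt-4o", {}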
2 changes: 1 addition & 1 deletion app/slack/event_message.py
@@ -107,7 +107,7 @@ def answer(client: WebClient, event, logger, say, context):
input=event["text"],
tx_context=tx_context,
)
-    result = asyncio.new_event_loop().run_until_complete(co_routine)
+    result = asyncio.run(co_routine)
say(
blocks=[
{
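The same swap from new_event_loop().run_until_complete(...) to asyncio.run(...) appears throughout this PR; the behaviour is equivalent for a one-shot coroutine, but asyncio.run also closes the loop it creates. A minimal equivalence sketch:

import asyncio


async def co() -> int:
    return 42


# before: the caller owns the loop and should close it
loop = asyncio.new_event_loop()
try:
    assert loop.run_until_complete(co()) == 42
finally:
    loop.close()

# after: asyncio.run creates, runs, and closes the loop
assert asyncio.run(co()) == 42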