Skip to content

Commit

Permalink
fix: on topic guardrails
Browse files Browse the repository at this point in the history
fixes on topic guardrail to accept phrases like "try again"
  • Loading branch information
Erez Sharim authored and asaf committed Aug 14, 2024
1 parent d3011b8 commit 71921a7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
15 changes: 10 additions & 5 deletions app/llm/guardrails/on_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@


class Valid(BaseModel):
"""Is the user message valid for topics"""
"""Is the user message valid for allowed topics"""

is_valid: bool = Field(description="user message is valid")
why: str = Field(description="explanation why is the user message valid or not")


_parser = PydanticOutputParser(pydantic_object=Valid)
Expand All @@ -20,26 +21,30 @@ class Valid(BaseModel):
def on_topic_guard():
model = create_model(
model=settings.SMALL_LLM_MODEL,
temperature=0.2,
temperature=0.3,
streaming=False,
name="guardrails",
)
allowed_topics = [
"greeting",
"access request",
"work related request",
"recommending access",
"assistant capabilities inquiry",
"information gathering about applications",
"information gathering about access",
"action repetition request",
]
template = """
Your job is to determine if the user's input is on topic
allowed topics are: {allowed_topics}
Your job is to determine if the user's input is on topic.
Any of the following topics are allowed, together or individually:
{allowed_topics}
{format_instructions}
"""
pv = {
"format_instructions": _parser.get_format_instructions(),
"allowed_topics": ",".join(allowed_topics),
"allowed_topics": "\n".join(f"- {topic}" for topic in allowed_topics),
}
prompt = PromptTemplate.from_template(template=template, partial_variables=pv)
sys_msg = SystemMessagePromptTemplate(prompt=prompt)
Expand Down
3 changes: 2 additions & 1 deletion app/llm/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,9 @@ def _epn(state):
)

corou = agent.ainvoke(state)
input = state[MEMORY_KEY][-5:]
result, ok = asyncio.new_event_loop().run_until_complete(
execute_chat_with_guardrail(runnable=corou, input=state[MEMORY_KEY][-5:])
execute_chat_with_guardrail(runnable=corou, input=input)
)

output = result["output"]
Expand Down

0 comments on commit 71921a7

Please sign in to comment.