Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow human approval before skill use for sequential workflows #49

Merged
merged 8 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ async def stream(
member.skills = member.skills

return StreamingResponse(
generator(team, members, team_chat.messages, thread_id),
generator(
team, members, team_chat.messages, thread_id, team_chat.interrupt_decision
),
media_type="text/event-stream",
)
41 changes: 35 additions & 6 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from functools import partial
from typing import Any

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
Expand All @@ -27,7 +28,7 @@
WorkerNode,
)
from app.core.graph.skills import all_skills
from app.models import ChatMessage, Member, Team
from app.models import ChatMessage, InterruptDecision, Member, Team


def convert_hierarchical_team_to_dict(
Expand Down Expand Up @@ -184,7 +185,6 @@ def exit_chain(state: TeamState) -> dict[str, list[AnyMessage]]:
Pass the final response back to the top-level graph's state.
"""
answer = state["messages"][-1]
# Add human message at the end to prevent consecutive AI message which will cause error for some models
return {"messages": [answer]}


Expand Down Expand Up @@ -397,7 +397,11 @@ def convert_messages_and_tasks_to_dict(data: Any) -> Any:


async def generator(
team: Team, members: list[Member], messages: list[ChatMessage], thread_id: str
team: Team,
members: list[Member],
messages: list[ChatMessage],
thread_id: str,
interrupt_decision: InterruptDecision | None = None,
) -> AsyncGenerator[Any, Any]:
"""Create the graph and stream responses as JSON."""
formatted_messages = [
Expand All @@ -417,7 +421,7 @@ async def generator(
root = create_hierarchical_graph(
teams, leader_name=team_leader, memory=memory
)
state = {
state: dict[str, Any] | None = {
"messages": formatted_messages,
"team": teams[team_leader],
"main_task": formatted_messages,
Expand All @@ -439,11 +443,32 @@ async def generator(
),
"next": first_member.name,
}

config: RunnableConfig = {
"configurable": {"thread_id": thread_id},
"recursion_limit": 25,
}
# Handle interrupt logic by orriding state
if interrupt_decision == InterruptDecision.APPROVED:
state = None
elif interrupt_decision == InterruptDecision.REJECTED:
current_values = await root.aget_state(config)
tool_calls = current_values.values["messages"][-1].tool_calls
state = {
"messages": [
ToolMessage(
tool_call_id=tool_call["id"],
content="API call denied by user. Continue assisting.",
)
for tool_call in tool_calls
]
}

async for output in root.astream_events(
state,
version="v1",
include_names=["work", "delegate", "summarise"],
config={"configurable": {"thread_id": thread_id}, "recursion_limit": 25},
config=config,
):
if output["event"] == "on_chain_end":
output_data = output["data"]["output"]
Expand All @@ -453,6 +478,10 @@ async def generator(
formatted_output = f"data: {json.dumps(transformed_output_data)}\n\n"
if formatted_output != "data: null\n\n":
yield formatted_output
snapshot = await root.aget_state(config)
if snapshot.next:
# Interrupt occured
yield f"data: {json.dumps({'interrupt': True})}\n\n"
except Exception as e:
error_message = {
"error": str(e),
Expand Down
10 changes: 10 additions & 0 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,18 @@ class ChatMessage(BaseModel):
content: str


class InterruptDecision(Enum):
APPROVED = "approved"
REJECTED = "rejected"


class Interrupt(BaseModel):
decision: InterruptDecision


class TeamChat(BaseModel):
messages: list[ChatMessage]
interrupt_decision: InterruptDecision | None = None


class Team(TeamBase, table=True):
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export type { ChatMessageType } from './models/ChatMessageType';
export type { CheckpointOut } from './models/CheckpointOut';
export type { CreateThreadOut } from './models/CreateThreadOut';
export type { HTTPValidationError } from './models/HTTPValidationError';
export type { InterruptDecision } from './models/InterruptDecision';
export type { MemberCreate } from './models/MemberCreate';
export type { MemberOut } from './models/MemberOut';
export type { MembersOut } from './models/MembersOut';
Expand Down Expand Up @@ -47,6 +48,7 @@ export { $ChatMessageType } from './schemas/$ChatMessageType';
export { $CheckpointOut } from './schemas/$CheckpointOut';
export { $CreateThreadOut } from './schemas/$CreateThreadOut';
export { $HTTPValidationError } from './schemas/$HTTPValidationError';
export { $InterruptDecision } from './schemas/$InterruptDecision';
export { $MemberCreate } from './schemas/$MemberCreate';
export { $MemberOut } from './schemas/$MemberOut';
export { $MembersOut } from './schemas/$MembersOut';
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/client/models/CreateThreadOut.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ export type CreateThreadOut = {
id: string;
query: string;
updated_at: string;
last_checkpoint: CheckpointOut;
last_checkpoint: (CheckpointOut | null);
};

6 changes: 6 additions & 0 deletions frontend/src/client/models/InterruptDecision.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

export type InterruptDecision = 'approved' | 'rejected';
2 changes: 2 additions & 0 deletions frontend/src/client/models/TeamChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
/* eslint-disable */

import type { ChatMessage } from './ChatMessage';
import type { InterruptDecision } from './InterruptDecision';

export type TeamChat = {
messages: Array<ChatMessage>;
interrupt_decision?: (InterruptDecision | null);
};

7 changes: 6 additions & 1 deletion frontend/src/client/schemas/$CreateThreadOut.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ export const $CreateThreadOut = {
format: 'date-time',
},
last_checkpoint: {
type: 'CheckpointOut',
type: 'any-of',
contains: [{
type: 'CheckpointOut',
}, {
type: 'null',
}],
isRequired: true,
},
},
Expand Down
7 changes: 7 additions & 0 deletions frontend/src/client/schemas/$InterruptDecision.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $InterruptDecision = {
type: 'Enum',
} as const;
8 changes: 8 additions & 0 deletions frontend/src/client/schemas/$TeamChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,13 @@ export const $TeamChat = {
},
isRequired: true,
},
interrupt_decision: {
type: 'any-of',
contains: [{
type: 'InterruptDecision',
}, {
type: 'null',
}],
},
},
} as const;
15 changes: 14 additions & 1 deletion frontend/src/components/Members/EditMember.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
Button,
Checkbox,
FormControl,
FormErrorMessage,
FormLabel,
Expand Down Expand Up @@ -51,7 +52,11 @@ const customSelectOption = {
// TODO: Place this somewhere else.
const AVAILABLE_MODELS = {
ChatOpenAI: ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"],
ChatAnthropic: ["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"],
ChatAnthropic: [
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
],
// ChatCohere: ["command"],
// ChatGoogleGenerativeAI: ["gemini-pro"],
}
Expand Down Expand Up @@ -238,6 +243,14 @@ export function EditMember({
</FormControl>
)}
/>
{member.type.startsWith("freelancer") ? (
<FormControl mt={4}>
<FormLabel htmlFor="interrupt">Human In The Loop</FormLabel>
<Checkbox {...register("interrupt")}>
Require approval before executing skills.
</Checkbox>
</FormControl>
) : null}
<FormControl mt={4} isRequired isInvalid={!!errors.role}>
<FormLabel htmlFor="provider">Provider</FormLabel>
<Select
Expand Down
Loading