Skip to content

Commit

Permalink
add chain library
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 8, 2025
1 parent f5b514a commit 327f24a
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
38 changes: 38 additions & 0 deletions agixt/Chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,44 @@ def get_chain(self, chain_name):
session.close()
return chain_data

def get_global_chains(self):
session = get_session()
user_data = session.query(User).filter(User.email == DEFAULT_USER).first()
global_chains = (
session.query(ChainDB).filter(ChainDB.user_id == user_data.id).all()
)
chains = session.query(ChainDB).filter(ChainDB.user_id == self.user_id).all()
chain_list = []
for chain in global_chains:
if chain in chains:
continue
chain_list.append(
{
"name": chain.name,
"description": chain.description,
"steps": chain.steps,
"runs": chain.runs,
}
)
session.close()
return chain_list

def get_user_chains(self):
session = get_session()
chains = session.query(ChainDB).filter(ChainDB.user_id == self.user_id).all()
chain_list = []
for chain in chains:
chain_list.append(
{
"name": chain.name,
"description": chain.description,
"steps": chain.steps,
"runs": chain.runs,
}
)
session.close()
return chain_list

def get_chains(self):
session = get_session()
user_data = session.query(User).filter(User.email == DEFAULT_USER).first()
Expand Down
91 changes: 87 additions & 4 deletions agixt/endpoints/GQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,37 @@ class ChainStep:
prompt: ChainPrompt


class ChainDetails:
"""Represents a chain's full details"""

id: str
name: str
description: Optional[str]
steps: List[ChainStep]
created_at: datetime
updated_at: datetime
user_id: str


@strawberry.type
class ChainRun:
"""Represents a single execution run of a chain"""

id: str
timestamp: datetime
completed: bool


@strawberry.type
class DetailedChain:
"""Represents a chain with full details"""

name: str
description: Optional[str]
steps: List[ChainStep]
runs: List[ChainRun]


@strawberry.type
class ChainConfig:
"""Represents a chain's complete configuration"""
Expand Down Expand Up @@ -1761,12 +1792,64 @@ async def agent_extensions(self, info, agent_name: str) -> List[Extension]:
return [convert_extension(ext) for ext in extension_list]

@strawberry.field
async def chains(self, info) -> List[str]:
async def chain_library(self, info) -> List[DetailedChain]:
"""Get all global chains"""
user, auth = await get_user_from_context(info)
chain_manager = Chain(user=user)
global_chains = chain_manager.get_global_chains()

return [
DetailedChain(
name=chain["name"],
description=chain["description"],
steps=[
ChainStep(
step_number=step.step_number,
agent_name=step.agent_name,
prompt_type=step.prompt_type,
prompt_content=step.prompt,
)
for step in chain["steps"]
],
runs=[
ChainRun(
id=str(run.id),
timestamp=run.timestamp,
completed=True, # You may want to add a completed field to your DB model
)
for run in chain["runs"]
],
)
for chain in global_chains
]

@strawberry.field
async def chains(self, info) -> List[DetailedChain]:
"""Get all user-specific chains"""
user, auth = await get_user_from_context(info)
if not is_admin(email=user, api_key=auth):
raise Exception("Access Denied")
chain_manager = Chain(user=user)
return chain_manager.get_chains()
user_chains = chain_manager.get_user_chains()

return [
DetailedChain(
name=chain["name"],
description=chain["description"],
steps=[
ChainStep(
step_number=step.step_number,
agent_name=step.agent_name,
prompt_type=step.prompt_type,
prompt_content=step.prompt,
)
for step in chain["steps"]
],
runs=[
ChainRun(id=str(run.id), timestamp=run.timestamp, completed=True)
for run in chain["runs"]
],
)
for chain in user_chains
]

@strawberry.field
async def chain(self, info, chain_name: str) -> ChainConfig:
Expand Down

0 comments on commit 327f24a

Please sign in to comment.