Skip to content

Commit

Permalink
fix some graphql issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 8, 2025
1 parent 3fbde33 commit a70ff96
Showing 1 changed file with 59 additions and 35 deletions.
94 changes: 59 additions & 35 deletions agixt/endpoints/GQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,17 +1220,17 @@ def convert_provider_details(details: Dict[str, str]) -> ProviderDetail:
def convert_extension_command(raw_command: dict) -> ExtensionCommand:
"""Helper to convert raw command data to ExtensionCommand type"""
command_args = ExtensionCommandArgs(
required=raw_command["command_args"].get("required", []),
optional=raw_command["command_args"].get("optional", []),
description=raw_command["command_args"].get("description", ""),
required=raw_command.get("command_args", {}).get("required", []),
optional=raw_command.get("command_args", {}).get("optional", []),
description=raw_command.get("command_args", {}).get("description", ""),
)

return ExtensionCommand(
friendly_name=raw_command["friendly_name"],
description=raw_command["description"],
enabled=raw_command["enabled"],
friendly_name=raw_command.get("friendly_name", ""),
description=raw_command.get("description", ""),
enabled=raw_command.get("enabled", False),
command_args=command_args,
extension_name=raw_command["extension_name"],
extension_name=raw_command.get("extension_name", ""),
)


Expand Down Expand Up @@ -1510,25 +1510,46 @@ async def agents(self, info) -> List[AgentType]:
)
agents = get_agents(user=user) # Refresh agent list

return [
AgentType(
id=agent["id"],
name=agent["name"],
status=agent["status"],
company_id=agent.get("company_id"),
settings=[], # These would need to be populated if needed
commands=[], # These would need to be populated if needed
result = []
for agent in agents:
agent_instance = Agent(agent_name=agent["name"], user=user)
config = agent_instance.get_agent_config()

settings = [
AgentSetting(
name=k,
value=(
v
if not any(
x in k.upper()
for x in ["KEY", "SECRET", "PASSWORD", "TOKEN"]
)
else "HIDDEN"
),
)
for k, v in config["settings"].items()
]

commands = [
AgentCommand(name=k, enabled=v) for k, v in config["commands"].items()
]

result.append(
AgentType(
id=agent["id"],
name=agent["name"],
status=agent["status"],
company_id=agent.get("company_id"),
settings=settings,
commands=commands,
)
)
for agent in agents
]

return result

@strawberry.field
async def agent(self, info, name: str) -> AgentType:
"""Get a specific agent's configuration"""
user, auth = await get_user_from_context(info)
if not await is_admin(email=user, api_key=auth):
raise Exception("Access Denied")

agent = Agent(agent_name=name, user=user)
config = agent.get_agent_config()

Expand All @@ -1553,7 +1574,7 @@ async def agent(self, info, name: str) -> AgentType:
return AgentType(
id=agent.agent_id,
name=name,
status=False, # This could be updated if there's a status to track
status=False,
company_id=config["settings"].get("company_id"),
settings=settings,
commands=commands,
Expand Down Expand Up @@ -1604,7 +1625,8 @@ async def memories(
) -> List[Memory]:
"""Query agent memories from a specific collection"""
user, auth = await get_user_from_context(info)

if not auth:
raise Exception("Authorization required")
agent = Agent(agent_name=agent_name, user=user)
memories = Memories(
agent_name=agent_name,
Expand Down Expand Up @@ -1636,21 +1658,27 @@ async def memories(

@strawberry.field
async def memory_collections(self, info, agent_name: str) -> List[str]:
"""Get all memory collections for an agent"""
user, auth = await get_user_from_context(info)

memories = Memories(agent_name=agent_name, user=user)
agent = Agent(agent_name=agent_name, user=user)
memories = Memories(
agent_name=agent_name,
agent_config=agent.get_agent_config(),
collection_number="0",
user=user,
)
return await memories.get_collections()

@strawberry.field
async def external_sources(
self, info, agent_name: str, collection_number: str = "0"
) -> List[str]:
"""Get unique external sources in a collection"""
user, auth = await get_user_from_context(info)

agent = Agent(agent_name=agent_name, user=user)
memories = Memories(
agent_name=agent_name, collection_number=collection_number, user=user
agent_name=agent_name,
agent_config=agent.get_agent_config(),
collection_number=collection_number,
user=user,
)
return await memories.get_external_data_sources()

Expand Down Expand Up @@ -1713,8 +1741,9 @@ async def extensions(self, info) -> List[Extension]:

@strawberry.field
async def agent_extensions(self, info, agent_name: str) -> List[Extension]:
"""Get extensions for a specific agent"""
user, auth = await get_user_from_context(info)
if not auth:
raise Exception("Authorization required")

agent = Agent(agent_name=agent_name, user=user)
extension_list = agent.get_agent_extensions()
Expand All @@ -1741,12 +1770,7 @@ async def agent_extensions(self, info, agent_name: str) -> List[Extension]:

@strawberry.field
async def chains(self, info) -> List[str]:
"""Get all available chains"""
user, auth = await get_user_from_context(info)

if not await is_admin(email=user, api_key=auth):
raise Exception("Access Denied")

chain_manager = Chain(user=user)
return chain_manager.get_chains()

Expand Down

0 comments on commit a70ff96

Please sign in to comment.