Skip to content

Commit

Permalink
improve prompts output
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 8, 2025
1 parent 8c54df0 commit 35a16b0
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
25 changes: 25 additions & 0 deletions agixt/Prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,31 @@ def get_global_prompts(self):
session.close()
return prompts

def get_user_prompts(self):
session = get_session()
user_prompts = (
session.query(Prompt).filter(Prompt.user_id == self.user_id).all()
)
prompts = []
for prompt in user_prompts:
try:
prompt_args = [
arg.name for arg in prompt.arguments if arg.prompt_id == prompt.id
]
except:
prompt_args = []
prompts.append(
{
"name": prompt.name,
"category": prompt.prompt_category.name,
"content": prompt.content,
"description": prompt.description,
"arguments": prompt_args,
}
)
session.close()
return prompts

def get_prompts(self, prompt_category="Default"):
if not prompt_category:
prompt_category = "Default"
Expand Down
31 changes: 29 additions & 2 deletions agixt/endpoints/GQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,11 +1469,38 @@ async def prompt(self, info, name: str, category: str = "Default") -> PromptType
)

@strawberry.field
async def prompts(self, info, category: str = "Default") -> List[PromptType]:
async def prompts(self, info) -> List[PromptType]:
"""Get all prompts in a category"""
user, auth = await get_user_from_context(info)
prompt_manager = Prompts(user=user)
return prompt_manager.get_prompts(prompt_category=category)
result = prompt_manager.get_user_prompts()
return [
PromptType(
name=prompt["name"],
content=prompt["content"],
category=prompt["category"],
description=prompt["description"],
arguments=[PromptArgument(name=arg) for arg in prompt["arguments"]],
)
for prompt in result
]

@strawberry.field
async def promptLibrary(self, info) -> List[PromptType]:
"""Get all prompts in a category"""
user, auth = await get_user_from_context(info)
prompt_manager = Prompts(user=user)
result = prompt_manager.get_global_prompts()
return [
PromptType(
name=prompt["name"],
content=prompt["content"],
category=prompt["category"],
description=prompt["description"],
arguments=[PromptArgument(name=arg) for arg in prompt["arguments"]],
)
for prompt in result
]

@strawberry.field
async def prompt_categories(self, info) -> List[str]:
Expand Down

0 comments on commit 35a16b0

Please sign in to comment.