From 35a16b02bb50c95024da31dc6226dfa8377beef8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 7 Jan 2025 21:08:57 -0500 Subject: [PATCH] improve prompts output --- agixt/Prompts.py | 25 +++++++++++++++++++++++++ agixt/endpoints/GQL.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/agixt/Prompts.py b/agixt/Prompts.py index d82e25387870..80c17ba87213 100644 --- a/agixt/Prompts.py +++ b/agixt/Prompts.py @@ -167,6 +167,31 @@ def get_global_prompts(self): session.close() return prompts + def get_user_prompts(self): + session = get_session() + user_prompts = ( + session.query(Prompt).filter(Prompt.user_id == self.user_id).all() + ) + prompts = [] + for prompt in user_prompts: + try: + prompt_args = [ + arg.name for arg in prompt.arguments if arg.prompt_id == prompt.id + ] + except: + prompt_args = [] + prompts.append( + { + "name": prompt.name, + "category": prompt.prompt_category.name, + "content": prompt.content, + "description": prompt.description, + "arguments": prompt_args, + } + ) + session.close() + return prompts + def get_prompts(self, prompt_category="Default"): if not prompt_category: prompt_category = "Default" diff --git a/agixt/endpoints/GQL.py b/agixt/endpoints/GQL.py index 3547f9853654..b535b452395a 100644 --- a/agixt/endpoints/GQL.py +++ b/agixt/endpoints/GQL.py @@ -1469,11 +1469,38 @@ async def prompt(self, info, name: str, category: str = "Default") -> PromptType ) @strawberry.field - async def prompts(self, info, category: str = "Default") -> List[PromptType]: + async def prompts(self, info) -> List[PromptType]: """Get all prompts in a category""" user, auth = await get_user_from_context(info) prompt_manager = Prompts(user=user) - return prompt_manager.get_prompts(prompt_category=category) + result = prompt_manager.get_user_prompts() + return [ + PromptType( + name=prompt["name"], + content=prompt["content"], + category=prompt["category"], + description=prompt["description"], + arguments=[PromptArgument(name=arg) for arg in prompt["arguments"]], + ) + for prompt in result + ] + + @strawberry.field + async def promptLibrary(self, info) -> List[PromptType]: + """Get all prompts in a category""" + user, auth = await get_user_from_context(info) + prompt_manager = Prompts(user=user) + result = prompt_manager.get_global_prompts() + return [ + PromptType( + name=prompt["name"], + content=prompt["content"], + category=prompt["category"], + description=prompt["description"], + arguments=[PromptArgument(name=arg) for arg in prompt["arguments"]], + ) + for prompt in result + ] @strawberry.field async def prompt_categories(self, info) -> List[str]: