From 18f35c9e686f01dee14b050b670a5e3ee3afdc76 Mon Sep 17 00:00:00 2001 From: unclecode Date: Sun, 17 Mar 2024 17:35:24 +0800 Subject: [PATCH] Updates: - chains.py: Supporting "tool_choice". - Update cookbook examples. - Improve promopts, for "force" mode. --- app/libs/chains.py | 35 ++++-- app/prompts.py | 33 +++++- cookbook/function_call_force_schema.py | 137 +++++++++++++++++++++++ cookbook/function_call_with_schema.py | 3 +- cookbook/function_call_without_schema.py | 4 +- 5 files changed, 201 insertions(+), 11 deletions(-) create mode 100644 cookbook/function_call_force_schema.py diff --git a/app/libs/chains.py b/app/libs/chains.py index 2386458..f44215b 100644 --- a/app/libs/chains.py +++ b/app/libs/chains.py @@ -3,10 +3,11 @@ from importlib import import_module import json import uuid +import traceback from fastapi import Request from fastapi.responses import JSONResponse from providers import BaseProvider -from prompts import SYSTEM_MESSAGE, SUFFIX, CLEAN_UP_MESSAGE, get_func_result_guide +from prompts import SYSTEM_MESSAGE, ENFORCED_SYSTAME_MESSAE, SUFFIX, FORCE_CALL_SUFFIX, CLEAN_UP_MESSAGE, get_func_result_guide, get_forced_tool_suffix from providers import GroqProvider import importlib from utils import get_tool_call_response, create_logger @@ -19,8 +20,10 @@ def __init__(self, request: Request, provider: str, body: Dict[str, Any]): self.provider = provider self.body = body self.response = None + # extract all keys from body except messages and tools and set in params self.params = {k: v for k, v in body.items() if k not in ["messages", "tools"]} + # self.no_tool_behaviour = self.params.get("no_tool_behaviour", "return") self.no_tool_behaviour = self.params.get("no_tool_behaviour", "forward") self.params.pop("no_tool_behaviour", None) @@ -50,8 +53,6 @@ def __init__(self, request: Request, provider: str, body: Dict[str, Any]): bt['extra'] = self.params.get("extra", {}) self.params.pop("extra", None) - - self.client : BaseProvider = None @property @@ -60,7 +61,7 @@ def last_message(self): @property def is_tool_call(self): - return bool(self.last_message["role"] == "user" and self.tools) + return bool(self.last_message["role"] == "user" and self.tools and self.params.get("tool_choice", "none") != "none") @property def is_tool_response(self): @@ -88,6 +89,7 @@ async def handle(self, context: Context): return await self._next_handler.handle(context) except Exception as e: _exception_handler: "Handler" = ExceptionHandler() + # Extract the stack trace and log the exception return await _exception_handler.handle(context, e) @@ -130,19 +132,35 @@ class ToolExtractionHandler(Handler): async def handle(self, context: Context): body = context.body if context.is_tool_call: + + # Prepare the messages and tools for the tool extraction messages = [ f"{m['role'].title()}: {m['content']}" for m in context.messages if m["role"] != "system" ] - tools_json = json.dumps([t["function"] for t in context.tools], indent=4) + # Process the tool_choice + tool_choice = context.params.get("tool_choice", "auto") + forced_mode = False + if type(tool_choice) == dict and tool_choice.get("type", None) == "function": + tool_choice = tool_choice["function"].get("name", None) + if not tool_choice: + raise ValueError("Invalid tool choice. 'tool_choice' is set to a dictionary with 'type' as 'function', but 'function' does not have a 'name' key.") + forced_mode = True + + # Regenerate the string tool_json and keep only the forced tool + tools_json = json.dumps([t["function"] for t in context.tools if t["function"]["name"] == tool_choice], indent=4) + + system_message = SYSTEM_MESSAGE if not forced_mode else ENFORCED_SYSTAME_MESSAE + suffix = SUFFIX if not forced_mode else get_forced_tool_suffix(tool_choice) + new_messages = [ - {"role": "system", "content": SYSTEM_MESSAGE}, + {"role": "system", "content": system_message}, { "role": "system", - "content": f"Conversation History:\n{''.join(messages)}\n\nTools: \n{tools_json}\n\n{SUFFIX}", + "content": f"Conversation History:\n{''.join(messages)}\n\nTools: \n{tools_json}\n\n{suffix}", }, ] @@ -309,4 +327,5 @@ async def handle(self, context: Context): class ExceptionHandler(Handler): async def handle(self, context: Context, exception: Exception): print(f"Error processing the request: {exception}") - return JSONResponse(content={"error": "An unexpected error occurred. " + str(exception)}, status_code=500) \ No newline at end of file + print(traceback.format_exc()) + return JSONResponse(content={"error": "An unexpected error occurred. " + str(exception)}, status_code=500) \ No newline at end of file diff --git a/app/prompts.py b/app/prompts.py index c2de2a4..8ed9997 100644 --- a/app/prompts.py +++ b/app/prompts.py @@ -33,13 +33,44 @@ ** If no tools are required, then return an empty list for "tool_calls". ** -**Wrap the JSON response between ```json and ```**. +**Wrap the JSON response between ```json and ```, and rememebr "tool_calls" is a list.**. **Whenever a message starts with 'SYSTEM MESSAGE', that is a guide and help information for you to generate your next response, do not consider them a message from the user, and do not reply to them at all. Just use the information and continue your conversation with the user.**""" + +ENFORCED_SYSTAME_MESSAE = """A history of conversations between an AI assistant and the user, plus the last user's message, is given to you. + +You have access to a specific tool that the AI assistant must use to provide a proper answer. The tool is a function that requires a set of parameters, which are provided in a JSON schema to explain what parameters the tool needs. Your task is to extract the values for these parameters from the user's last message and the conversation history. + +Your job is to closely examine the user's last message and the history of the conversation, then extract the necessary parameter values for the given tool based on the provided JSON schema. Remember that you must use the specified tool to generate the response. + +You should think step by step, provide your reasoning for your response, then add the JSON response at the end following the below schema: + + +{ + "tool_calls": [{ + "name": "function_name", + "arguments": { + "arg1": "value1", + "arg2": "value2", + ... + }] + } +} + + +**Wrap the JSON response between ```json and ```, and rememebr "tool_calls" is a list.**. + +Whenever a message starts with 'SYSTEM MESSAGE', that is a guide and help information for you to generate your next response. Do not consider them a message from the user, and do not reply to them at all. Just use the information and continue your conversation with the user.""" + CLEAN_UP_MESSAGE = "When I tried to extract the content between ```json and ``` and parse the content to valid JSON object, I faced with the abovr error. Remember, you are supposed to wrap the schema between ```json and ```, and do this only one time. First find out what went wrong, that I couldn't extract the JSON between ```json and ```, and also faced error when trying to parse it, then regenerate the your last message and fix the issue." + SUFFIX = """Think step by step and justify your response. Make sure to not miss in case to answer user query we need multiple tools, in that case detect all that we need, then generate a JSON response wrapped between "```json" and "```". Remember to USE THIS JSON WRAPPER ONLY ONE TIME.""" +FORCE_CALL_SUFFIX = """For this task, you HAVE to choose the tool (function) {tool_name}, and ignore other rools. Therefore think step by step and justify your response, then closely examine the user's last message and the history of the conversation, then extract the necessary parameter values for the given tool based on the provided JSON schema. Remember that you must use the specified tool to generate the response. Finally generate a JSON response wrapped between "```json" and "```". Remember to USE THIS JSON WRAPPER ONLY ONE TIME.""" + +def get_forced_tool_suffix(tool_name : str) -> str: + return FORCE_CALL_SUFFIX.format(tool_name=tool_name) def get_func_result_guide(function_call_result : str) -> str: return f"SYSTEM MESSAGE: \n```json\n{function_call_result}\n```\n\nThe above is the result after functions are called. Use the result to answer the user's last question.\n\n" \ No newline at end of file diff --git a/cookbook/function_call_force_schema.py b/cookbook/function_call_force_schema.py new file mode 100644 index 0000000..1f3de5e --- /dev/null +++ b/cookbook/function_call_force_schema.py @@ -0,0 +1,137 @@ + +from duckduckgo_search import DDGS +import requests, os +import json + +api_key=os.environ["GROQ_API_KEY"] +header = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" +} +proxy_url = "https://groqcall.ai/proxy/groq/v1/chat/completions" + +# or "http://localhost:8000/proxy/groq/v1/chat/completions" if running locally +# proxy_url = "http://localhost:8000/proxy/groq/v1/chat/completions" + + +def duckduckgo_search(query, max_results=None): + """ + Use this function to search DuckDuckGo for a query. + """ + with DDGS() as ddgs: + return [r for r in ddgs.text(query, safesearch='off', max_results=max_results)] + +def duckduckgo_news(query, max_results=None): + """ + Use this function to get the latest news from DuckDuckGo. + """ + with DDGS() as ddgs: + return [r for r in ddgs.news(query, safesearch='off', max_results=max_results)] + +function_map = { + "duckduckgo_search": duckduckgo_search, + "duckduckgo_news": duckduckgo_news, +} + +request = { + "messages": [ + { + "role": "system", + "content": "YOU MUST FOLLOW THESE INSTRUCTIONS CAREFULLY.\n\n1. Use markdown to format your answers.\n" + }, + { + "role": "user", + "content": "Whats happening in France? Summarize top stories with sources, very short and concise." + } + ], + "model": "mixtral-8x7b-32768", + # "tool_choice": "auto", + # "tool_choice": "none", + "tool_choice": {"type": "function", "function": {"name": "duckduckgo_search"}}, + "tools": [ + { + "type": "function", + "function": { + "name": "duckduckgo_search", + "description": "Use this function to search DuckDuckGo for a query.\n\nArgs:\n query(str): The query to search for.\n max_results (optional, default=5): The maximum number of results to return.\n\nReturns:\n The result from DuckDuckGo.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string" + }, + "max_results": { + "type": [ + "number", + "null" + ] + } + } + } + } + }, + { + "type": "function", + "function": { + "name": "duckduckgo_news", + "description": "Use this function to get the latest news from DuckDuckGo.\n\nArgs:\n query(str): The query to search for.\n max_results (optional, default=5): The maximum number of results to return.\n\nReturns:\n The latest news from DuckDuckGo.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string" + }, + "max_results": { + "type": [ + "number", + "null" + ] + } + } + } + } + } + ] +} + +response = requests.post( + proxy_url, + headers=header, + json=request +) +# Check if the request was successful +if response.status_code == 200: + # Process the response data (if needed) + res = response.json() + message = res['choices'][0]['message'] + tools_response_messages = [] + if not message['content'] and 'tool_calls' in message: + for tool_call in message['tool_calls']: + tool_name = tool_call['function']['name'] + tool_args = tool_call['function']['arguments'] + tool_args = json.loads(tool_args) + if tool_name not in function_map: + print(f"Error: {tool_name} is not a valid function name.") + continue + tool_func = function_map[tool_name] + tool_response = tool_func(**tool_args) + tools_response_messages.append({ + "role": "tool", "content": json.dumps(tool_response) + }) + + if tools_response_messages: + request['messages'] += tools_response_messages + response = requests.post( + proxy_url, + headers=header, + json=request + ) + if response.status_code == 200: + res = response.json() + print(res['choices'][0]['message']['content']) + else: + print("Error:", response.status_code, response.text) + else: + print(message['content']) +else: + print("Error:", response.status_code, response.text) diff --git a/cookbook/function_call_with_schema.py b/cookbook/function_call_with_schema.py index 8b92118..c2c568e 100644 --- a/cookbook/function_call_with_schema.py +++ b/cookbook/function_call_with_schema.py @@ -1,8 +1,9 @@ from duckduckgo_search import DDGS import requests, os -api_key=os.environ["GROQ_API_KEY"] import json + +api_key=os.environ["GROQ_API_KEY"] header = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" diff --git a/cookbook/function_call_without_schema.py b/cookbook/function_call_without_schema.py index 39396a8..6eea454 100644 --- a/cookbook/function_call_without_schema.py +++ b/cookbook/function_call_without_schema.py @@ -1,6 +1,8 @@ import requests +import json +import os -api_key = "YOUR_GROQ_API_KEY" +api_key=os.environ["GROQ_API_KEY"], header = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json"