From 561c537c32c47330b374f180da7a8b4251827daa Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 16 Jan 2024 09:59:43 -0800 Subject: [PATCH 1/5] provide tools for multiple function calling --- chatlab/registry.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chatlab/registry.py b/chatlab/registry.py index e52696e..7a2f25e 100644 --- a/chatlab/registry.py +++ b/chatlab/registry.py @@ -428,6 +428,10 @@ def api_manifest(self, function_call_option: FunctionCall = "auto") -> APIManife "function_call": function_call_option, } + @property + def tools(self): + return [{"type": "function", "function": adapt_function_definition(f)} for f in self.__schemas.values()] + async def call(self, name: str, arguments: Optional[str] = None) -> Any: """Call a function by name with the given parameters.""" if name is None: From 71f64ccbad8c919d89e8f9aa9b147b9f4a7b9a13 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 16 Jan 2024 10:59:07 -0800 Subject: [PATCH 2/5] include helper for tool_result --- chatlab/__init__.py | 2 ++ chatlab/messaging.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/chatlab/__init__.py b/chatlab/__init__.py index 6305159..1959e02 100644 --- a/chatlab/__init__.py +++ b/chatlab/__init__.py @@ -30,6 +30,7 @@ narrate, system, user, + tool_result ) from .registry import FunctionRegistry from .views.markdown import Markdown @@ -85,6 +86,7 @@ def __init__(self, *args, **kwargs): "assistant", "assistant_function_call", "function_result", + "tool_result", "models", "Session", "Chat", diff --git a/chatlab/messaging.py b/chatlab/messaging.py index 142dfa7..84e3315 100644 --- a/chatlab/messaging.py +++ b/chatlab/messaging.py @@ -100,6 +100,24 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam: } +def tool_result(tool_call_id: str, name: str, content: str) -> ChatCompletionMessageParam: + """Create a tool result message. + + Args: + tool_call_id: The ID of the tool call. + name: The name of the tool. + content: The content of the message. + + Returns: + A dictionary representing a tool result message. + """ + return { + "role": "tool", + "content": content, + "name": name, + "tool_call_id": tool_call_id, + } + # Aliases narrate = system human = user From 282d22f0d502111c57b24fc20e439ab3e0940ba8 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 16 Jan 2024 11:00:25 -0800 Subject: [PATCH 3/5] show an example of using parallel function calling Some `SQLModel` stuff while we're at it. --- CHANGELOG.md | 4 + chatlab/messaging.py | 4 +- notebooks/parallel-function-calling.ipynb | 402 ++++++++++++++++++++++ 3 files changed, 408 insertions(+), 2 deletions(-) create mode 100644 notebooks/parallel-function-calling.ipynb diff --git a/CHANGELOG.md b/CHANGELOG.md index 31284ec..bb9a52e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.3.0] + +- Support tool call format from `FunctionRegistry`. Enables parallel function calling (note: not in `Chat` yet). https://github.com/rgbkrk/chatlab/pull/122 + ## [1.2.1] - Drop Noteable builtin diff --git a/chatlab/messaging.py b/chatlab/messaging.py index 84e3315..ba6e4ab 100644 --- a/chatlab/messaging.py +++ b/chatlab/messaging.py @@ -12,7 +12,7 @@ from typing import Optional -from openai.types.chat import ChatCompletionMessageParam +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam def assistant(content: str) -> ChatCompletionMessageParam: @@ -100,7 +100,7 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam: } -def tool_result(tool_call_id: str, name: str, content: str) -> ChatCompletionMessageParam: +def tool_result(tool_call_id: str, name: str, content: str) -> ChatCompletionToolMessageParam: """Create a tool result message. Args: diff --git a/notebooks/parallel-function-calling.ipynb b/notebooks/parallel-function-calling.ipynb new file mode 100644 index 0000000..57abdbf --- /dev/null +++ b/notebooks/parallel-function-calling.ipynb @@ -0,0 +1,402 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install sqlmodel -q" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional\n", + "from sqlmodel import SQLModel, Field, create_engine, Session\n", + "\n", + "# SQL Model uses Pydantic Models under the hood\n", + "\n", + "class Character(SQLModel, table=True):\n", + " id: Optional[int] = Field(default=None, primary_key=True)\n", + " name: str\n", + " race: str\n", + " character_class: str\n", + " level: int\n", + " background: str\n", + " player_name: Optional[str] = None\n", + " experience_points: int = 0\n", + " strength: int\n", + " dexterity: int\n", + " constitution: int\n", + " intelligence: int\n", + " wisdom: int\n", + " charisma: int\n", + " hit_points: int\n", + " armor_class: int\n", + " alignment: str\n", + " skills: str # Storing as comma-separated string\n", + " languages: str # Storing as comma-separated string\n", + " equipment: str # Storing as comma-separated string\n", + " spells: Optional[str] = None # Storing as comma-separated string\n", + "\n", + " def _repr_llm_(self):\n", + " return f\"\"\n", + " \n", + " def __repr__(self):\n", + " return f\"\"\n", + "\n", + "# SQLite Database URL\n", + "DATABASE_URL = \"sqlite:///:memory:\"\n", + "engine = create_engine(DATABASE_URL)\n", + "\n", + "# Create the database tables\n", + "SQLModel.metadata.create_all(engine)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 13\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "13" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import random\n", + "from IPython.display import SVG\n", + "\n", + "def d20(roll):\n", + " return SVG(f\"\"\"\n", + " \n", + " \n", + " {roll}\n", + " \n", + "\"\"\")\n", + "\n", + "def roll_die(sides: int = 6):\n", + " \"\"\"Roll a die with the given number of sides.\"\"\"\n", + " roll = random.randint(1, sides)\n", + "\n", + " if(sides == 20):\n", + " display(d20(roll))\n", + " \n", + " return roll\n", + "\n", + "# Function to add a new character\n", + "def add_character(character: Character):\n", + " \"\"\"Adds a character to our characters database\"\"\"\n", + " with Session(engine) as session:\n", + " session.add(character)\n", + " session.commit\n", + "\n", + " return character\n", + "\n", + "roll_die(20)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from chatlab import FunctionRegistry\n", + "\n", + "fr = FunctionRegistry()\n", + "fr.register_functions([roll_die, add_character])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "client = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from chatlab import tool_result\n", + "\n", + "async def chatloop(initial_messages):\n", + " \"\"\"Emit messages encountered as well as tool results, making sure to autorun tools and respond to the model.\"\"\"\n", + " buffer = initial_messages.copy()\n", + "\n", + " resp = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-1106\",\n", + " messages=initial_messages,\n", + "\n", + " # Pass in the tools from the function registry. The model will choose\n", + " # whether it uses 0, 1, 2, or N many tools.\n", + " tools=fr.tools,\n", + " tool_choice=\"auto\"\n", + " )\n", + "\n", + " message = resp.choices[0].message\n", + " buffer.append(message)\n", + "\n", + " yield message\n", + "\n", + " # call each of the tools\n", + " if message.tool_calls is not None:\n", + " for tool in message.tool_calls:\n", + " result = await fr.call(tool.function.name, tool.function.arguments)\n", + "\n", + " # An assistant message with 'tool_calls' must be followed by tool messages responding to each 'tool_call_id'.\n", + " tool_call_response = tool_result(tool.id, content=str(result))\n", + " yield tool_call_response\n", + " buffer.append(tool_call_response)\n", + " \n", + " # Once all tools have been called, call the model again\n", + " async for m in chatloop(buffer):\n", + " yield m\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 15\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 11\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 9\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 17\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 4\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 17\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " 13\n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "> Here are the rolled stats:\n", + "> - Strength: 15\n", + "> - Perception: 11\n", + "> - Endurance: 9\n", + "> - Charisma: 17\n", + "> - Intelligence: 4\n", + "> - Agility: 17\n", + "> - Luck: 13" + ], + "text/plain": [ + "> Here are the rolled stats:\n", + "> - Strength: 15\n", + "> - Perception: 11\n", + "> - Endurance: 9\n", + "> - Charisma: 17\n", + "> - Intelligence: 4\n", + "> - Agility: 17\n", + "> - Luck: 13" + ] + }, + "metadata": { + "text/markdown": { + "chatlab": { + "default": true + } + } + }, + "output_type": "display_data" + } + ], + "source": [ + "from pydantic import BaseModel\n", + "from chatlab import system, user, Markdown\n", + "\n", + "async for message in chatloop([\n", + " system(\"Create your character for the Fallout RPG. The user is the DM.\"),\n", + " user(\"Roll for the following stats: Strength, Perception, Endurance, Charisma, Intelligence, Agility, and Luck.\")\n", + " ]):\n", + " # When message is a pydantic model, convert to a dict\n", + "\n", + " if isinstance(message, BaseModel):\n", + " message = message.model_dump()\n", + "\n", + " role = message['role']\n", + " content = message.get('content')\n", + "\n", + " if(role == \"assistant\" and content is not None):\n", + " display(Markdown(\"> \" + content.replace(\"\\n\", \"\\n> \")))\n", + " if(role == \"tool\"):\n", + " pass " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chatlab-3PJ-KiVK-py3.12", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 3142a2a5adc389f8d11b71affe4697a8845dbd87 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 16 Jan 2024 12:02:53 -0800 Subject: [PATCH 4/5] no name due to https://github.com/openai/openai-python/issues/1078 --- chatlab/messaging.py | 4 +- notebooks/parallel-function-calling.ipynb | 55 +++++++++++++---------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/chatlab/messaging.py b/chatlab/messaging.py index ba6e4ab..547e658 100644 --- a/chatlab/messaging.py +++ b/chatlab/messaging.py @@ -100,12 +100,11 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam: } -def tool_result(tool_call_id: str, name: str, content: str) -> ChatCompletionToolMessageParam: +def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam: """Create a tool result message. Args: tool_call_id: The ID of the tool call. - name: The name of the tool. content: The content of the message. Returns: @@ -114,7 +113,6 @@ def tool_result(tool_call_id: str, name: str, content: str) -> ChatCompletionToo return { "role": "tool", "content": content, - "name": name, "tool_call_id": tool_call_id, } diff --git a/notebooks/parallel-function-calling.ipynb b/notebooks/parallel-function-calling.ipynb index 57abdbf..6bfe63a 100644 --- a/notebooks/parallel-function-calling.ipynb +++ b/notebooks/parallel-function-calling.ipynb @@ -209,7 +209,7 @@ "\n", " \n", " \n", - " 15\n", + " 6\n", " \n", "" ], @@ -226,7 +226,7 @@ "\n", " \n", " \n", - " 11\n", + " 4\n", " \n", "" ], @@ -243,7 +243,7 @@ "\n", " \n", " \n", - " 9\n", + " 8\n", " \n", "" ], @@ -260,7 +260,7 @@ "\n", " \n", " \n", - " 17\n", + " 10\n", " \n", "" ], @@ -277,7 +277,7 @@ "\n", " \n", " \n", - " 4\n", + " 10\n", " \n", "" ], @@ -294,7 +294,7 @@ "\n", " \n", " \n", - " 17\n", + " 2\n", " \n", "" ], @@ -311,7 +311,7 @@ "\n", " \n", " \n", - " 13\n", + " 18\n", " \n", "" ], @@ -325,24 +325,26 @@ { "data": { "text/markdown": [ - "> Here are the rolled stats:\n", - "> - Strength: 15\n", - "> - Perception: 11\n", - "> - Endurance: 9\n", - "> - Charisma: 17\n", - "> - Intelligence: 4\n", - "> - Agility: 17\n", - "> - Luck: 13" + "> Here are the results for your character's stats:\n", + "> \n", + "> - Strength: 6\n", + "> - Perception: 4\n", + "> - Endurance: 8\n", + "> - Charisma: 10\n", + "> - Intelligence: 10\n", + "> - Agility: 2\n", + "> - Luck: 18" ], "text/plain": [ - "> Here are the rolled stats:\n", - "> - Strength: 15\n", - "> - Perception: 11\n", - "> - Endurance: 9\n", - "> - Charisma: 17\n", - "> - Intelligence: 4\n", - "> - Agility: 17\n", - "> - Luck: 13" + "> Here are the results for your character's stats:\n", + "> \n", + "> - Strength: 6\n", + "> - Perception: 4\n", + "> - Endurance: 8\n", + "> - Charisma: 10\n", + "> - Intelligence: 10\n", + "> - Agility: 2\n", + "> - Luck: 18" ] }, "metadata": { @@ -376,6 +378,13 @@ " if(role == \"tool\"):\n", " pass " ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From a6d73740b9566fc9717db518f63ae4e2afd3fe6d Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 16 Jan 2024 12:06:26 -0800 Subject: [PATCH 5/5] include test for tool calling format --- tests/test_registry.py | 61 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/test_registry.py b/tests/test_registry.py index d936058..e95bc0f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -376,3 +376,64 @@ async def test_function_registry_call_edge_cases(): with pytest.raises(UnknownFunctionError): await registry.call(None) # type: ignore + + +# Test the tool calling format +@pytest.mark.asyncio +async def test_function_registry_call_tool(): + registry = FunctionRegistry() + registry.register(simple_func, SimpleModel) + + registry.register(simple_func_with_model_arg) + + tools = registry.tools + + assert tools == [{ + "type": "function", + "function": { + "name": "simple_func", + "description": "A simple test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "string"}, + "z": {"type": "boolean", "default": False, "description": "A simple boolean field"}, + }, + "required": ["x", "y"], + }, + } + }, { + "type": "function", + "function": { + "name": "simple_func_with_model_arg", + "description": "A simple test function with a model argument", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "string"}, + "z": {"default": False, "type": "boolean"}, + "model": {"allOf": [{"$ref": "#/$defs/SimpleModel"}], "default": None}, + }, + "required": ["x", "y"], + "$defs": { + "SimpleModel": { + "title": "SimpleModel", + "type": "object", + "properties": { + "x": {"title": "X", "type": "integer"}, + "y": {"title": "Y", "type": "string"}, + "z": { + "title": "Z", + "description": "A simple boolean field", + "default": False, + "type": "boolean", + }, + }, + "required": ["x", "y"], + } + }, + }, + } + }] \ No newline at end of file