From 1c89447d63c86c6f6dd2a209a5bc34e4e51ca753 Mon Sep 17 00:00:00 2001 From: AT Date: Thu, 19 Dec 2024 16:31:37 -0500 Subject: [PATCH] Code interpreter (#3173) Signed-off-by: Adam Treat --- common/common.cmake | 1 - gpt4all-chat/CHANGELOG.md | 3 + gpt4all-chat/CMakeLists.txt | 8 +- gpt4all-chat/metadata/models3.json | 16 + gpt4all-chat/qml/AddGPT4AllModelView.qml | 46 ++ gpt4all-chat/qml/ChatCollapsibleItem.qml | 160 ++++ gpt4all-chat/qml/ChatItemView.qml | 173 ++--- gpt4all-chat/qml/ChatTextItem.qml | 139 ++++ gpt4all-chat/qml/ChatView.qml | 2 + gpt4all-chat/src/chat.cpp | 68 +- gpt4all-chat/src/chat.h | 4 +- gpt4all-chat/src/chatlistmodel.cpp | 2 +- gpt4all-chat/src/chatllm.cpp | 187 ++--- gpt4all-chat/src/chatllm.h | 12 +- gpt4all-chat/src/chatmodel.cpp | 345 +++++++++ gpt4all-chat/src/chatmodel.h | 907 ++++++++++++++--------- gpt4all-chat/src/codeinterpreter.cpp | 125 ++++ gpt4all-chat/src/codeinterpreter.h | 84 +++ gpt4all-chat/src/jinja_helpers.cpp | 24 +- gpt4all-chat/src/jinja_helpers.h | 6 +- gpt4all-chat/src/main.cpp | 3 + gpt4all-chat/src/modellist.cpp | 14 +- gpt4all-chat/src/modellist.h | 5 + gpt4all-chat/src/server.cpp | 18 +- gpt4all-chat/src/tool.cpp | 74 ++ gpt4all-chat/src/tool.h | 127 ++++ gpt4all-chat/src/toolcallparser.cpp | 111 +++ gpt4all-chat/src/toolcallparser.h | 47 ++ gpt4all-chat/src/toolmodel.cpp | 31 + gpt4all-chat/src/toolmodel.h | 110 +++ 30 files changed, 2236 insertions(+), 616 deletions(-) create mode 100644 gpt4all-chat/qml/ChatCollapsibleItem.qml create mode 100644 gpt4all-chat/qml/ChatTextItem.qml create mode 100644 gpt4all-chat/src/chatmodel.cpp create mode 100644 gpt4all-chat/src/codeinterpreter.cpp create mode 100644 gpt4all-chat/src/codeinterpreter.h create mode 100644 gpt4all-chat/src/tool.cpp create mode 100644 gpt4all-chat/src/tool.h create mode 100644 gpt4all-chat/src/toolcallparser.cpp create mode 100644 gpt4all-chat/src/toolcallparser.h create mode 100644 gpt4all-chat/src/toolmodel.cpp create mode 100644 gpt4all-chat/src/toolmodel.h diff --git a/common/common.cmake b/common/common.cmake index a3d3b1a0c005..b8b6e969a357 100644 --- a/common/common.cmake +++ b/common/common.cmake @@ -11,7 +11,6 @@ function(gpt4all_add_warning_options target) -Wextra-semi -Wformat=2 -Wmissing-include-dirs - -Wstrict-overflow=2 -Wsuggest-override -Wvla # errors diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md index f144d449ed2f..2dd567792e0b 100644 --- a/gpt4all-chat/CHANGELOG.md +++ b/gpt4all-chat/CHANGELOG.md @@ -6,6 +6,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ## [Unreleased] +### Added +- Built-in javascript code interpreter tool plus model ([#3173](https://github.com/nomic-ai/gpt4all/pull/3173)) + ### Fixed - Fix remote model template to allow for XML in messages ([#3318](https://github.com/nomic-ai/gpt4all/pull/3318)) - Fix Jinja2Cpp bug that broke system message detection in chat templates ([#3325](https://github.com/nomic-ai/gpt4all/pull/3325)) diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index b3f98c986243..60d75a560f67 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -193,8 +193,9 @@ qt_add_executable(chat src/chatapi.cpp src/chatapi.h src/chatlistmodel.cpp src/chatlistmodel.h src/chatllm.cpp src/chatllm.h - src/chatmodel.h + src/chatmodel.h src/chatmodel.cpp src/chatviewtextprocessor.cpp src/chatviewtextprocessor.h + src/codeinterpreter.cpp src/codeinterpreter.h src/database.cpp src/database.h src/download.cpp src/download.h src/embllm.cpp src/embllm.h @@ -207,6 +208,9 @@ qt_add_executable(chat src/mysettings.cpp src/mysettings.h src/network.cpp src/network.h src/server.cpp src/server.h + src/tool.cpp src/tool.h + src/toolcallparser.cpp src/toolcallparser.h + src/toolmodel.cpp src/toolmodel.h src/xlsxtomd.cpp src/xlsxtomd.h ${CHAT_EXE_RESOURCES} ${MACOS_SOURCES} @@ -225,8 +229,10 @@ qt_add_qml_module(chat qml/AddHFModelView.qml qml/ApplicationSettings.qml qml/ChatDrawer.qml + qml/ChatCollapsibleItem.qml qml/ChatItemView.qml qml/ChatMessageButton.qml + qml/ChatTextItem.qml qml/ChatView.qml qml/CollectionsDrawer.qml qml/HomeView.qml diff --git a/gpt4all-chat/metadata/models3.json b/gpt4all-chat/metadata/models3.json index 6c2444ec3b08..e93bdcfca56f 100644 --- a/gpt4all-chat/metadata/models3.json +++ b/gpt4all-chat/metadata/models3.json @@ -1,6 +1,22 @@ [ { "order": "a", + "md5sum": "a54c08a7b90e4029a8c2ab5b5dc936aa", + "name": "Reasoner v1", + "filename": "qwen2.5-coder-7b-instruct-q4_0.gguf", + "filesize": "4431390720", + "requires": "3.5.4-dev0", + "ramrequired": "8", + "parameters": "8 billion", + "quant": "q4_0", + "type": "qwen2", + "description": "", + "url": "https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct-GGUF/resolve/main/qwen2.5-coder-7b-instruct-q4_0.gguf", + "chatTemplate": "{{- '<|im_start|>system\\n' }}\n{% if toolList|length > 0 %}You have access to the following functions:\n{% for tool in toolList %}\nUse the function '{{tool.function}}' to: '{{tool.description}}'\n{% if tool.parameters|length > 0 %}\nparameters:\n{% for info in tool.parameters %}\n {{info.name}}:\n type: {{info.type}}\n description: {{info.description}}\n required: {{info.required}}\n{% endfor %}\n{% endif %}\n# Tool Instructions\nIf you CHOOSE to call this function ONLY reply with the following format:\n'{{tool.symbolicFormat}}'\nHere is an example. If the user says, '{{tool.examplePrompt}}', then you reply\n'{{tool.exampleCall}}'\nAfter the result you might reply with, '{{tool.exampleReply}}'\n{% endfor %}\nYou MUST include both the start and end tags when you use a function.\n\nYou are a helpful AI assistant who uses the functions to break down, analyze, perform, and verify complex reasoning tasks. You SHOULD try to verify your answers using the functions where possible.\n{% endif %}\n{{- '<|im_end|>\\n' }}\n{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}\n{% endfor %}\n{% if add_generation_prompt %}\n{{ '<|im_start|>assistant\\n' }}\n{% endif %}\n", + "systemPrompt": "" + }, + { + "order": "aa", "md5sum": "c87ad09e1e4c8f9c35a5fcef52b6f1c9", "name": "Llama 3 8B Instruct", "filename": "Meta-Llama-3-8B-Instruct.Q4_0.gguf", diff --git a/gpt4all-chat/qml/AddGPT4AllModelView.qml b/gpt4all-chat/qml/AddGPT4AllModelView.qml index dd8da3ed90f5..2a1832af14fa 100644 --- a/gpt4all-chat/qml/AddGPT4AllModelView.qml +++ b/gpt4all-chat/qml/AddGPT4AllModelView.qml @@ -56,6 +56,52 @@ ColumnLayout { Accessible.description: qsTr("Displayed when the models request is ongoing") } + RowLayout { + ButtonGroup { + id: buttonGroup + exclusive: true + } + MyButton { + text: qsTr("All") + checked: true + borderWidth: 0 + backgroundColor: checked ? theme.lightButtonBackground : "transparent" + backgroundColorHovered: theme.lighterButtonBackgroundHovered + backgroundRadius: 5 + padding: 15 + topPadding: 8 + bottomPadding: 8 + textColor: theme.lighterButtonForeground + fontPixelSize: theme.fontSizeLarge + fontPixelBold: true + checkable: true + ButtonGroup.group: buttonGroup + onClicked: { + ModelList.gpt4AllDownloadableModels.filter(""); + } + + } + MyButton { + text: qsTr("Reasoning") + borderWidth: 0 + backgroundColor: checked ? theme.lightButtonBackground : "transparent" + backgroundColorHovered: theme.lighterButtonBackgroundHovered + backgroundRadius: 5 + padding: 15 + topPadding: 8 + bottomPadding: 8 + textColor: theme.lighterButtonForeground + fontPixelSize: theme.fontSizeLarge + fontPixelBold: true + checkable: true + ButtonGroup.group: buttonGroup + onClicked: { + ModelList.gpt4AllDownloadableModels.filter("#reasoning"); + } + } + Layout.bottomMargin: 10 + } + ScrollView { id: scrollView ScrollBar.vertical.policy: ScrollBar.AsNeeded diff --git a/gpt4all-chat/qml/ChatCollapsibleItem.qml b/gpt4all-chat/qml/ChatCollapsibleItem.qml new file mode 100644 index 000000000000..4ff01511bf9b --- /dev/null +++ b/gpt4all-chat/qml/ChatCollapsibleItem.qml @@ -0,0 +1,160 @@ +import Qt5Compat.GraphicalEffects +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts + +import gpt4all +import mysettings +import toolenums + +ColumnLayout { + property alias textContent: innerTextItem.textContent + property bool isCurrent: false + property bool isError: false + + Layout.topMargin: 10 + Layout.bottomMargin: 10 + + Item { + Layout.preferredWidth: childrenRect.width + Layout.preferredHeight: 38 + RowLayout { + anchors.left: parent.left + anchors.top: parent.top + anchors.bottom: parent.bottom + + Item { + width: myTextArea.width + height: myTextArea.height + TextArea { + id: myTextArea + text: { + if (isError) + return qsTr("Analysis encountered error"); + if (isCurrent) + return qsTr("Analyzing"); + return qsTr("Analyzed"); + } + padding: 0 + font.pixelSize: theme.fontSizeLarger + enabled: false + focus: false + readOnly: true + color: headerMA.containsMouse ? theme.mutedDarkTextColorHovered : theme.mutedTextColor + hoverEnabled: false + } + + Item { + id: textColorOverlay + anchors.fill: parent + clip: true + visible: false + Rectangle { + id: animationRec + width: myTextArea.width * 0.3 + anchors.top: parent.top + anchors.bottom: parent.bottom + color: theme.textColor + + SequentialAnimation { + running: isCurrent + loops: Animation.Infinite + NumberAnimation { + target: animationRec; + property: "x"; + from: -animationRec.width; + to: myTextArea.width * 3; + duration: 2000 + } + } + } + } + OpacityMask { + visible: isCurrent + anchors.fill: parent + maskSource: myTextArea + source: textColorOverlay + } + } + + Item { + id: caret + Layout.preferredWidth: contentCaret.width + Layout.preferredHeight: contentCaret.height + Image { + id: contentCaret + anchors.centerIn: parent + visible: false + sourceSize.width: theme.fontSizeLarge + sourceSize.height: theme.fontSizeLarge + mipmap: true + source: { + if (contentLayout.state === "collapsed") + return "qrc:/gpt4all/icons/caret_right.svg"; + else + return "qrc:/gpt4all/icons/caret_down.svg"; + } + } + + ColorOverlay { + anchors.fill: contentCaret + source: contentCaret + color: headerMA.containsMouse ? theme.mutedDarkTextColorHovered : theme.mutedTextColor + } + } + } + + MouseArea { + id: headerMA + hoverEnabled: true + anchors.fill: parent + onClicked: { + if (contentLayout.state === "collapsed") + contentLayout.state = "expanded"; + else + contentLayout.state = "collapsed"; + } + } + } + + ColumnLayout { + id: contentLayout + spacing: 0 + state: "collapsed" + clip: true + + states: [ + State { + name: "expanded" + PropertyChanges { target: contentLayout; Layout.preferredHeight: innerContentLayout.height } + }, + State { + name: "collapsed" + PropertyChanges { target: contentLayout; Layout.preferredHeight: 0 } + } + ] + + transitions: [ + Transition { + SequentialAnimation { + PropertyAnimation { + target: contentLayout + property: "Layout.preferredHeight" + duration: 300 + easing.type: Easing.InOutQuad + } + } + } + ] + + ColumnLayout { + id: innerContentLayout + Layout.leftMargin: 30 + ChatTextItem { + id: innerTextItem + } + } + } +} \ No newline at end of file diff --git a/gpt4all-chat/qml/ChatItemView.qml b/gpt4all-chat/qml/ChatItemView.qml index e6a48bbc108e..9cac2fcce157 100644 --- a/gpt4all-chat/qml/ChatItemView.qml +++ b/gpt4all-chat/qml/ChatItemView.qml @@ -4,9 +4,11 @@ import QtQuick import QtQuick.Controls import QtQuick.Controls.Basic import QtQuick.Layouts +import Qt.labs.qmlmodels import gpt4all import mysettings +import toolenums ColumnLayout { @@ -33,6 +35,9 @@ GridLayout { Layout.alignment: Qt.AlignVCenter | Qt.AlignRight Layout.preferredWidth: 32 Layout.preferredHeight: 32 + Layout.topMargin: model.index > 0 ? 25 : 0 + visible: content !== "" || childItems.length > 0 + Image { id: logo sourceSize: Qt.size(32, 32) @@ -65,6 +70,9 @@ GridLayout { Layout.column: 1 Layout.fillWidth: true Layout.preferredHeight: 38 + Layout.topMargin: model.index > 0 ? 25 : 0 + visible: content !== "" || childItems.length > 0 + RowLayout { spacing: 5 anchors.left: parent.left @@ -72,7 +80,11 @@ GridLayout { anchors.bottom: parent.bottom TextArea { - text: name === "Response: " ? qsTr("GPT4All") : qsTr("You") + text: { + if (name === "Response: ") + return qsTr("GPT4All"); + return qsTr("You"); + } padding: 0 font.pixelSize: theme.fontSizeLarger font.bold: true @@ -88,7 +100,7 @@ GridLayout { color: theme.mutedTextColor } RowLayout { - visible: isCurrentResponse && (value === "" && currentChat.responseInProgress) + visible: isCurrentResponse && (content === "" && currentChat.responseInProgress) Text { color: theme.mutedTextColor font.pixelSize: theme.fontSizeLarger @@ -156,131 +168,36 @@ GridLayout { } } - TextArea { - id: myTextArea - Layout.fillWidth: true - padding: 0 - color: { - if (!currentChat.isServer) - return theme.textColor - return theme.white - } - wrapMode: Text.WordWrap - textFormat: TextEdit.PlainText - focus: false - readOnly: true - font.pixelSize: theme.fontSizeLarge - cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false - cursorPosition: text.length - TapHandler { - id: tapHandler - onTapped: function(eventPoint, button) { - var clickedPos = myTextArea.positionAt(eventPoint.position.x, eventPoint.position.y); - var success = textProcessor.tryCopyAtPosition(clickedPos); - if (success) - copyCodeMessage.open(); - } - } - - MouseArea { - id: conversationMouseArea - anchors.fill: parent - acceptedButtons: Qt.RightButton - - onClicked: (mouse) => { - if (mouse.button === Qt.RightButton) { - conversationContextMenu.x = conversationMouseArea.mouseX - conversationContextMenu.y = conversationMouseArea.mouseY - conversationContextMenu.open() - } - } - } - - onLinkActivated: function(link) { - if (!isCurrentResponse || !currentChat.responseInProgress) - Qt.openUrlExternally(link) - } - - onLinkHovered: function (link) { - if (!isCurrentResponse || !currentChat.responseInProgress) - statusBar.externalHoveredLink = link - } - - MyMenu { - id: conversationContextMenu - MyMenuItem { - text: qsTr("Copy") - enabled: myTextArea.selectedText !== "" - height: enabled ? implicitHeight : 0 - onTriggered: myTextArea.copy() - } - MyMenuItem { - text: qsTr("Copy Message") - enabled: myTextArea.selectedText === "" - height: enabled ? implicitHeight : 0 - onTriggered: { - myTextArea.selectAll() - myTextArea.copy() - myTextArea.deselect() + Repeater { + model: childItems + + DelegateChooser { + id: chooser + role: "name" + DelegateChoice { + roleValue: "Text: "; + ChatTextItem { + Layout.fillWidth: true + textContent: modelData.content } } - MyMenuItem { - text: textProcessor.shouldProcessText ? qsTr("Disable markdown") : qsTr("Enable markdown") - height: enabled ? implicitHeight : 0 - onTriggered: { - textProcessor.shouldProcessText = !textProcessor.shouldProcessText; - textProcessor.setValue(value); + DelegateChoice { + roleValue: "ToolCall: "; + ChatCollapsibleItem { + Layout.fillWidth: true + textContent: modelData.content + isCurrent: modelData.isCurrentResponse + isError: modelData.isToolCallError } } } - ChatViewTextProcessor { - id: textProcessor - } - - function resetChatViewTextProcessor() { - textProcessor.fontPixelSize = myTextArea.font.pixelSize - textProcessor.codeColors.defaultColor = theme.codeDefaultColor - textProcessor.codeColors.keywordColor = theme.codeKeywordColor - textProcessor.codeColors.functionColor = theme.codeFunctionColor - textProcessor.codeColors.functionCallColor = theme.codeFunctionCallColor - textProcessor.codeColors.commentColor = theme.codeCommentColor - textProcessor.codeColors.stringColor = theme.codeStringColor - textProcessor.codeColors.numberColor = theme.codeNumberColor - textProcessor.codeColors.headerColor = theme.codeHeaderColor - textProcessor.codeColors.backgroundColor = theme.codeBackgroundColor - textProcessor.textDocument = textDocument - textProcessor.setValue(value); - } - - property bool textProcessorReady: false - - Component.onCompleted: { - resetChatViewTextProcessor(); - textProcessorReady = true; - } - - Connections { - target: chatModel - function onValueChanged(i, value) { - if (myTextArea.textProcessorReady && index === i) - textProcessor.setValue(value); - } - } - - Connections { - target: MySettings - function onFontSizeChanged() { - myTextArea.resetChatViewTextProcessor(); - } - function onChatThemeChanged() { - myTextArea.resetChatViewTextProcessor(); - } - } + delegate: chooser + } - Accessible.role: Accessible.Paragraph - Accessible.name: text - Accessible.description: name === "Response: " ? "The response by the model" : "The prompt by the user" + ChatTextItem { + Layout.fillWidth: true + textContent: content } ThumbsDownDialog { @@ -289,16 +206,16 @@ GridLayout { y: Math.round((parent.height - height) / 2) width: 640 height: 300 - property string text: value + property string text: content response: newResponse === undefined || newResponse === "" ? text : newResponse onAccepted: { var responseHasChanged = response !== text && response !== newResponse if (thumbsDownState && !thumbsUpState && !responseHasChanged) return - chatModel.updateNewResponse(index, response) - chatModel.updateThumbsUpState(index, false) - chatModel.updateThumbsDownState(index, true) + chatModel.updateNewResponse(model.index, response) + chatModel.updateThumbsUpState(model.index, false) + chatModel.updateThumbsDownState(model.index, true) Network.sendConversation(currentChat.id, getConversationJson()); } } @@ -416,7 +333,7 @@ GridLayout { states: [ State { name: "expanded" - PropertyChanges { target: sourcesLayout; Layout.preferredHeight: flow.height } + PropertyChanges { target: sourcesLayout; Layout.preferredHeight: sourcesFlow.height } }, State { name: "collapsed" @@ -438,7 +355,7 @@ GridLayout { ] Flow { - id: flow + id: sourcesFlow Layout.fillWidth: true spacing: 10 visible: consolidatedSources.length !== 0 @@ -617,9 +534,7 @@ GridLayout { name: qsTr("Copy") source: "qrc:/gpt4all/icons/copy.svg" onClicked: { - myTextArea.selectAll(); - myTextArea.copy(); - myTextArea.deselect(); + chatModel.copyToClipboard(index); } } diff --git a/gpt4all-chat/qml/ChatTextItem.qml b/gpt4all-chat/qml/ChatTextItem.qml new file mode 100644 index 000000000000..e316bf1ce7bc --- /dev/null +++ b/gpt4all-chat/qml/ChatTextItem.qml @@ -0,0 +1,139 @@ +import Qt5Compat.GraphicalEffects +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts + +import gpt4all +import mysettings +import toolenums + +TextArea { + id: myTextArea + property string textContent: "" + visible: textContent != "" + Layout.fillWidth: true + padding: 0 + color: { + if (!currentChat.isServer) + return theme.textColor + return theme.white + } + wrapMode: Text.WordWrap + textFormat: TextEdit.PlainText + focus: false + readOnly: true + font.pixelSize: theme.fontSizeLarge + cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false + cursorPosition: text.length + TapHandler { + id: tapHandler + onTapped: function(eventPoint, button) { + var clickedPos = myTextArea.positionAt(eventPoint.position.x, eventPoint.position.y); + var success = textProcessor.tryCopyAtPosition(clickedPos); + if (success) + copyCodeMessage.open(); + } + } + + MouseArea { + id: conversationMouseArea + anchors.fill: parent + acceptedButtons: Qt.RightButton + + onClicked: (mouse) => { + if (mouse.button === Qt.RightButton) { + conversationContextMenu.x = conversationMouseArea.mouseX + conversationContextMenu.y = conversationMouseArea.mouseY + conversationContextMenu.open() + } + } + } + + onLinkActivated: function(link) { + if (!isCurrentResponse || !currentChat.responseInProgress) + Qt.openUrlExternally(link) + } + + onLinkHovered: function (link) { + if (!isCurrentResponse || !currentChat.responseInProgress) + statusBar.externalHoveredLink = link + } + + MyMenu { + id: conversationContextMenu + MyMenuItem { + text: qsTr("Copy") + enabled: myTextArea.selectedText !== "" + height: enabled ? implicitHeight : 0 + onTriggered: myTextArea.copy() + } + MyMenuItem { + text: qsTr("Copy Message") + enabled: myTextArea.selectedText === "" + height: enabled ? implicitHeight : 0 + onTriggered: { + myTextArea.selectAll() + myTextArea.copy() + myTextArea.deselect() + } + } + MyMenuItem { + text: textProcessor.shouldProcessText ? qsTr("Disable markdown") : qsTr("Enable markdown") + height: enabled ? implicitHeight : 0 + onTriggered: { + textProcessor.shouldProcessText = !textProcessor.shouldProcessText; + textProcessor.setValue(textContent); + } + } + } + + ChatViewTextProcessor { + id: textProcessor + } + + function resetChatViewTextProcessor() { + textProcessor.fontPixelSize = myTextArea.font.pixelSize + textProcessor.codeColors.defaultColor = theme.codeDefaultColor + textProcessor.codeColors.keywordColor = theme.codeKeywordColor + textProcessor.codeColors.functionColor = theme.codeFunctionColor + textProcessor.codeColors.functionCallColor = theme.codeFunctionCallColor + textProcessor.codeColors.commentColor = theme.codeCommentColor + textProcessor.codeColors.stringColor = theme.codeStringColor + textProcessor.codeColors.numberColor = theme.codeNumberColor + textProcessor.codeColors.headerColor = theme.codeHeaderColor + textProcessor.codeColors.backgroundColor = theme.codeBackgroundColor + textProcessor.textDocument = textDocument + textProcessor.setValue(textContent); + } + + property bool textProcessorReady: false + + Component.onCompleted: { + resetChatViewTextProcessor(); + textProcessorReady = true; + } + + Connections { + target: myTextArea + function onTextContentChanged() { + if (myTextArea.textProcessorReady) + textProcessor.setValue(textContent); + } + } + + Connections { + target: MySettings + function onFontSizeChanged() { + myTextArea.resetChatViewTextProcessor(); + } + function onChatThemeChanged() { + myTextArea.resetChatViewTextProcessor(); + } + } + + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: name === "Response: " ? "The response by the model" : "The prompt by the user" +} \ No newline at end of file diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 31aaf565ccff..431d63986f23 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -824,6 +824,8 @@ Rectangle { textInput.forceActiveFocus(); textInput.cursorPosition = text.length; } + height: visible ? implicitHeight : 0 + visible: name !== "ToolResponse: " } remove: Transition { diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index d9d4bc1c5751..5780513971d9 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -3,12 +3,19 @@ #include "chatlistmodel.h" #include "network.h" #include "server.h" +#include "tool.h" +#include "toolcallparser.h" +#include "toolmodel.h" #include #include #include +#include +#include +#include #include #include +#include #include #include #include @@ -16,6 +23,8 @@ #include +using namespace ToolEnums; + Chat::Chat(QObject *parent) : QObject(parent) , m_id(Network::globalInstance()->generateUniqueId()) @@ -54,7 +63,6 @@ void Chat::connectLLM() // Should be in different threads connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); @@ -181,23 +189,12 @@ Chat::ResponseState Chat::responseState() const return m_responseState; } -void Chat::handleResponseChanged(const QString &response) +void Chat::handleResponseChanged() { if (m_responseState != Chat::ResponseGeneration) { m_responseState = Chat::ResponseGeneration; emit responseStateChanged(); } - - const int index = m_chatModel->count() - 1; - m_chatModel->updateValue(index, response); -} - -void Chat::handleResponseFailed(const QString &error) -{ - const int index = m_chatModel->count() - 1; - m_chatModel->updateValue(index, error); - m_chatModel->setError(); - responseStopped(0); } void Chat::handleModelLoadingPercentageChanged(float loadingPercentage) @@ -242,9 +239,54 @@ void Chat::responseStopped(qint64 promptResponseMs) m_responseState = Chat::ResponseStopped; emit responseInProgressChanged(); emit responseStateChanged(); + + const QString possibleToolcall = m_chatModel->possibleToolcall(); + + ToolCallParser parser; + parser.update(possibleToolcall); + + if (parser.state() == ToolEnums::ParseState::Complete) { + const QString toolCall = parser.toolCall(); + + // Regex to remove the formatting around the code + static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$"); + QString code = toolCall; + code.remove(regex); + code = code.trimmed(); + + // Right now the code interpreter is the only available tool + Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction); + Q_ASSERT(toolInstance); + + // The param is the code + const ToolParam param = { "code", ToolEnums::ParamType::String, code }; + const QString result = toolInstance->run({param}, 10000 /*msecs to timeout*/); + const ToolEnums::Error error = toolInstance->error(); + const QString errorString = toolInstance->errorString(); + + // Update the current response with meta information about toolcall and re-parent + m_chatModel->updateToolCall({ + ToolCallConstants::CodeInterpreterFunction, + { param }, + result, + error, + errorString + }); + + ++m_consecutiveToolCalls; + + // We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop + if (m_consecutiveToolCalls < 3 || error == ToolEnums::Error::NoError) { + resetResponseState(); + emit promptRequested(m_collections); // triggers a new response + return; + } + } + if (m_generatedName.isEmpty()) emit generateNameRequested(); + m_consecutiveToolCalls = 0; Network::globalInstance()->trackChatEvent("response_complete", { {"first", m_firstResponse}, {"message_count", chatModel()->count()}, diff --git a/gpt4all-chat/src/chat.h b/gpt4all-chat/src/chat.h index 57e413e5873d..dc8f3e180b35 100644 --- a/gpt4all-chat/src/chat.h +++ b/gpt4all-chat/src/chat.h @@ -161,8 +161,7 @@ public Q_SLOTS: void generatedQuestionsChanged(); private Q_SLOTS: - void handleResponseChanged(const QString &response); - void handleResponseFailed(const QString &error); + void handleResponseChanged(); void handleModelLoadingPercentageChanged(float); void promptProcessing(); void generatingQuestions(); @@ -205,6 +204,7 @@ private Q_SLOTS: // - The chat was freshly created during this launch. // - The chat was changed after loading it from disk. bool m_needsSave = true; + int m_consecutiveToolCalls = 0; }; #endif // CHAT_H diff --git a/gpt4all-chat/src/chatlistmodel.cpp b/gpt4all-chat/src/chatlistmodel.cpp index bf76ce4449ae..fd9d450925a2 100644 --- a/gpt4all-chat/src/chatlistmodel.cpp +++ b/gpt4all-chat/src/chatlistmodel.cpp @@ -20,7 +20,7 @@ #include static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC; -static constexpr qint32 CHAT_FORMAT_VERSION = 11; +static constexpr qint32 CHAT_FORMAT_VERSION = 12; class MyChatListModel: public ChatListModel { }; Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index 408f9f3dfdd9..e5a46bf6d0d6 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -7,6 +7,9 @@ #include "localdocs.h" #include "mysettings.h" #include "network.h" +#include "tool.h" +#include "toolmodel.h" +#include "toolcallparser.h" #include @@ -55,6 +58,7 @@ #include using namespace Qt::Literals::StringLiterals; +using namespace ToolEnums; namespace ranges = std::ranges; //#define DEBUG @@ -643,40 +647,16 @@ bool isAllSpace(R &&r) void ChatLLM::regenerateResponse(int index) { Q_ASSERT(m_chatModel); - int promptIdx; - { - auto items = m_chatModel->chatItems(); // holds lock - if (index < 1 || index >= items.size() || items[index].type() != ChatItem::Type::Response) - return; - promptIdx = m_chatModel->getPeerUnlocked(index).value_or(-1); + if (m_chatModel->regenerateResponse(index)) { + emit responseChanged(); + prompt(m_chat->collectionList()); } - - emit responseChanged({}); - m_chatModel->truncate(index + 1); - m_chatModel->updateCurrentResponse(index, true ); - m_chatModel->updateNewResponse (index, {} ); - m_chatModel->updateStopped (index, false); - m_chatModel->updateThumbsUpState (index, false); - m_chatModel->updateThumbsDownState(index, false); - m_chatModel->setError(false); - if (promptIdx >= 0) - m_chatModel->updateSources(promptIdx, {}); - - prompt(m_chat->collectionList()); } std::optional ChatLLM::popPrompt(int index) { Q_ASSERT(m_chatModel); - QString content; - { - auto items = m_chatModel->chatItems(); // holds lock - if (index < 0 || index >= items.size() || items[index].type() != ChatItem::Type::Prompt) - return std::nullopt; - content = items[index].value; - } - m_chatModel->truncate(index); - return content; + return m_chatModel->popPrompt(index); } ModelInfo ChatLLM::modelInfo() const @@ -737,28 +717,28 @@ void ChatLLM::prompt(const QStringList &enabledCollections) promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo)); } catch (const std::exception &e) { // FIXME(jared): this is neither translated nor serialized - emit responseFailed(u"Error: %1"_s.arg(QString::fromUtf8(e.what()))); + m_chatModel->setResponseValue(u"Error: %1"_s.arg(QString::fromUtf8(e.what()))); + m_chatModel->setError(); emit responseStopped(0); } } -// FIXME(jared): We can avoid this potentially expensive copy if we use ChatItem pointers, but this is only safe if we -// hold the lock while generating. We can't do that now because Chat is actually in charge of updating the response, not -// ChatLLM. -std::vector ChatLLM::forkConversation(const QString &prompt) const +std::vector ChatLLM::forkConversation(const QString &prompt) const { Q_ASSERT(m_chatModel); if (m_chatModel->hasError()) throw std::logic_error("cannot continue conversation with an error"); - std::vector conversation; + std::vector conversation; { - auto items = m_chatModel->chatItems(); // holds lock - Q_ASSERT(items.size() >= 2); // should be prompt/response pairs + auto items = m_chatModel->messageItems(); + // It is possible the main thread could have erased the conversation while the llm thread, + // is busy forking the conversatoin but it must have set stop generating first + Q_ASSERT(items.size() >= 2 || m_stopGenerating); // should be prompt/response pairs conversation.reserve(items.size() + 1); conversation.assign(items.begin(), items.end()); } - conversation.emplace_back(ChatItem::prompt_tag, prompt); + conversation.emplace_back(MessageItem::Type::Prompt, prompt.toUtf8()); return conversation; } @@ -793,7 +773,7 @@ std::optional ChatLLM::checkJinjaTemplateError(const std::string &s return std::nullopt; } -std::string ChatLLM::applyJinjaTemplate(std::span items) const +std::string ChatLLM::applyJinjaTemplate(std::span items) const { Q_ASSERT(items.size() >= 1); @@ -820,25 +800,33 @@ std::string ChatLLM::applyJinjaTemplate(std::span items) const uint version = parseJinjaTemplateVersion(chatTemplate); - auto makeMap = [version](const ChatItem &item) { + auto makeMap = [version](const MessageItem &item) { return jinja2::GenericMap([msg = std::make_shared(version, item)] { return msg.get(); }); }; - std::unique_ptr systemItem; + std::unique_ptr systemItem; bool useSystem = !isAllSpace(systemMessage); jinja2::ValuesList messages; messages.reserve(useSystem + items.size()); if (useSystem) { - systemItem = std::make_unique(ChatItem::system_tag, systemMessage); + systemItem = std::make_unique(MessageItem::Type::System, systemMessage.toUtf8()); messages.emplace_back(makeMap(*systemItem)); } for (auto &item : items) messages.emplace_back(makeMap(item)); + jinja2::ValuesList toolList; + const int toolCount = ToolModel::globalInstance()->count(); + for (int i = 0; i < toolCount; ++i) { + Tool *t = ToolModel::globalInstance()->get(i); + toolList.push_back(t->jinjaValue()); + } + jinja2::ValuesMap params { { "messages", std::move(messages) }, { "add_generation_prompt", true }, + { "toolList", toolList }, }; for (auto &[name, token] : model->specialTokens()) params.emplace(std::move(name), std::move(token)); @@ -852,48 +840,44 @@ std::string ChatLLM::applyJinjaTemplate(std::span items) const } auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx, - std::optional> subrange) -> ChatPromptResult + qsizetype startOffset) -> ChatPromptResult { Q_ASSERT(isModelLoaded()); Q_ASSERT(m_chatModel); - // Return a (ChatModelAccessor, std::span) pair where the span represents the relevant messages for this chat. - // "subrange" is used to select only local server messages from the current chat session. + // Return a vector of relevant messages for this chat. + // "startOffset" is used to select only local server messages from the current chat session. auto getChat = [&]() { - auto items = m_chatModel->chatItems(); // holds lock - std::span view(items); - if (subrange) - view = view.subspan(subrange->first, subrange->second); - Q_ASSERT(view.size() >= 2); - return std::pair(std::move(items), view); + auto items = m_chatModel->messageItems(); + if (startOffset > 0) + items.erase(items.begin(), items.begin() + startOffset); + Q_ASSERT(items.size() >= 2); + return items; }; - // copy messages for safety (since we can't hold the lock the whole time) - std::optional> query; - { - // Find the prompt that represents the query. Server chats are flexible and may not have one. - auto [_, view] = getChat(); // holds lock - if (auto peer = m_chatModel->getPeer(view, view.end() - 1)) // peer of response - query = { *peer - view.begin(), (*peer)->value }; - } - QList databaseResults; - if (query && !enabledCollections.isEmpty()) { - auto &[promptIndex, queryStr] = *query; - const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); - emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks - m_chatModel->updateSources(promptIndex, databaseResults); - emit databaseResultsChanged(databaseResults); - } + if (!enabledCollections.isEmpty()) { + std::optional> query; + { + // Find the prompt that represents the query. Server chats are flexible and may not have one. + auto items = getChat(); + if (auto peer = m_chatModel->getPeer(items, items.end() - 1)) // peer of response + query = { *peer - items.begin(), (*peer)->content() }; + } - // copy messages for safety (since we can't hold the lock the whole time) - std::vector chatItems; - { - auto [_, view] = getChat(); // holds lock - chatItems.assign(view.begin(), view.end() - 1); // exclude new response + if (query) { + auto &[promptIndex, queryStr] = *query; + const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); + emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks + m_chatModel->updateSources(promptIndex, databaseResults); + emit databaseResultsChanged(databaseResults); + } } - auto result = promptInternal(chatItems, ctx, !databaseResults.isEmpty()); + auto messageItems = getChat(); + messageItems.pop_back(); // exclude new response + + auto result = promptInternal(messageItems, ctx, !databaseResults.isEmpty()); return { /*PromptResult*/ { .response = std::move(result.response), @@ -905,7 +889,7 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL } auto ChatLLM::promptInternal( - const std::variant, std::string_view> &prompt, + const std::variant, std::string_view> &prompt, const LLModel::PromptContext &ctx, bool usedLocalDocs ) -> PromptResult @@ -915,14 +899,14 @@ auto ChatLLM::promptInternal( auto *mySettings = MySettings::globalInstance(); // unpack prompt argument - const std::span *chatItems = nullptr; + const std::span *messageItems = nullptr; std::string jinjaBuffer; std::string_view conversation; if (auto *nonChat = std::get_if(&prompt)) { conversation = *nonChat; // complete the string without a template } else { - chatItems = &std::get>(prompt); - jinjaBuffer = applyJinjaTemplate(*chatItems); + messageItems = &std::get>(prompt); + jinjaBuffer = applyJinjaTemplate(*messageItems); conversation = jinjaBuffer; } @@ -930,8 +914,8 @@ auto ChatLLM::promptInternal( if (!dynamic_cast(m_llModelInfo.model.get())) { auto nCtx = m_llModelInfo.model->contextLength(); std::string jinjaBuffer2; - auto lastMessageRendered = (chatItems && chatItems->size() > 1) - ? std::string_view(jinjaBuffer2 = applyJinjaTemplate({ &chatItems->back(), 1 })) + auto lastMessageRendered = (messageItems && messageItems->size() > 1) + ? std::string_view(jinjaBuffer2 = applyJinjaTemplate({ &messageItems->back(), 1 })) : conversation; int32_t lastMessageLength = m_llModelInfo.model->countPromptTokens(lastMessageRendered); if (auto limit = nCtx - 4; lastMessageLength > limit) { @@ -951,14 +935,42 @@ auto ChatLLM::promptInternal( return !m_stopGenerating; }; - auto handleResponse = [this, &result](LLModel::Token token, std::string_view piece) -> bool { + ToolCallParser toolCallParser; + auto handleResponse = [this, &result, &toolCallParser](LLModel::Token token, std::string_view piece) -> bool { Q_UNUSED(token) result.responseTokens++; m_timer->inc(); + + // FIXME: This is *not* necessarily fully formed utf data because it can be partial at this point + // handle this like below where we have a QByteArray + toolCallParser.update(QString::fromStdString(piece.data())); + + // Create a toolcall and split the response if needed + if (!toolCallParser.hasSplit() && toolCallParser.state() == ToolEnums::ParseState::Partial) { + const QPair pair = toolCallParser.split(); + m_chatModel->splitToolCall(pair); + } + result.response.append(piece.data(), piece.size()); auto respStr = QString::fromUtf8(result.response); - emit responseChanged(removeLeadingWhitespace(respStr)); - return !m_stopGenerating; + + try { + if (toolCallParser.hasSplit()) + m_chatModel->setResponseValue(toolCallParser.buffer()); + else + m_chatModel->setResponseValue(removeLeadingWhitespace(respStr)); + } catch (const std::exception &e) { + // We have a try/catch here because the main thread might have removed the response from + // the chatmodel by erasing the conversation during the response... the main thread sets + // m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel + Q_ASSERT(m_stopGenerating); + return false; + } + + emit responseChanged(); + + const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete; + return !foundToolCall && !m_stopGenerating; }; QElapsedTimer totalTime; @@ -978,13 +990,20 @@ auto ChatLLM::promptInternal( m_timer->stop(); qint64 elapsed = totalTime.elapsed(); + const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete; + // trim trailing whitespace auto respStr = QString::fromUtf8(result.response); - if (!respStr.isEmpty() && std::as_const(respStr).back().isSpace()) - emit responseChanged(respStr.trimmed()); + if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || foundToolCall)) { + if (toolCallParser.hasSplit()) + m_chatModel->setResponseValue(toolCallParser.buffer()); + else + m_chatModel->setResponseValue(respStr.trimmed()); + emit responseChanged(); + } bool doQuestions = false; - if (!m_isServer && chatItems) { + if (!m_isServer && messageItems && !foundToolCall) { switch (mySettings->suggestionMode()) { case SuggestionMode::On: doQuestions = true; break; case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break; diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h index 0d05de87bf8e..e34d3899b0f0 100644 --- a/gpt4all-chat/src/chatllm.h +++ b/gpt4all-chat/src/chatllm.h @@ -220,8 +220,8 @@ public Q_SLOTS: void modelLoadingPercentageChanged(float); void modelLoadingError(const QString &error); void modelLoadingWarning(const QString &warning); - void responseChanged(const QString &response); - void responseFailed(const QString &error); + void responseChanged(); + void responseFailed(); void promptProcessing(); void generatingQuestions(); void responseStopped(qint64 promptResponseMs); @@ -251,20 +251,20 @@ public Q_SLOTS: }; ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx, - std::optional> subrange = {}); + qsizetype startOffset = 0); // passing a string_view directly skips templating and uses the raw string - PromptResult promptInternal(const std::variant, std::string_view> &prompt, + PromptResult promptInternal(const std::variant, std::string_view> &prompt, const LLModel::PromptContext &ctx, bool usedLocalDocs); private: bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); - std::vector forkConversation(const QString &prompt) const; + std::vector forkConversation(const QString &prompt) const; // Applies the Jinja template. Query mode returns only the last message without special tokens. // Returns a (# of messages, rendered prompt) pair. - std::string applyJinjaTemplate(std::span items) const; + std::string applyJinjaTemplate(std::span items) const; void generateQuestions(qint64 elapsed); diff --git a/gpt4all-chat/src/chatmodel.cpp b/gpt4all-chat/src/chatmodel.cpp new file mode 100644 index 000000000000..54af48d21f38 --- /dev/null +++ b/gpt4all-chat/src/chatmodel.cpp @@ -0,0 +1,345 @@ +#include "chatmodel.h" + +QList ChatItem::consolidateSources(const QList &sources) +{ + QMap groupedData; + for (const ResultInfo &info : sources) { + if (groupedData.contains(info.file)) { + groupedData[info.file].text += "\n---\n" + info.text; + } else { + groupedData[info.file] = info; + } + } + QList consolidatedSources = groupedData.values(); + return consolidatedSources; +} + +void ChatItem::serializeResponse(QDataStream &stream, int version) +{ + stream << value; +} + +void ChatItem::serializeToolCall(QDataStream &stream, int version) +{ + stream << value; + toolCallInfo.serialize(stream, version); +} + +void ChatItem::serializeToolResponse(QDataStream &stream, int version) +{ + stream << value; +} + +void ChatItem::serializeText(QDataStream &stream, int version) +{ + stream << value; +} + +void ChatItem::serializeSubItems(QDataStream &stream, int version) +{ + stream << name; + switch (auto typ = type()) { + using enum ChatItem::Type; + case Response: { serializeResponse(stream, version); break; } + case ToolCall: { serializeToolCall(stream, version); break; } + case ToolResponse: { serializeToolResponse(stream, version); break; } + case Text: { serializeText(stream, version); break; } + case System: + case Prompt: + throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ))); + } + + stream << qsizetype(subItems.size()); + for (ChatItem *item :subItems) + item->serializeSubItems(stream, version); +} + +void ChatItem::serialize(QDataStream &stream, int version) +{ + stream << name; + stream << value; + stream << newResponse; + stream << isCurrentResponse; + stream << stopped; + stream << thumbsUpState; + stream << thumbsDownState; + if (version >= 11 && type() == ChatItem::Type::Response) + stream << isError; + if (version >= 8) { + stream << sources.size(); + for (const ResultInfo &info : sources) { + Q_ASSERT(!info.file.isEmpty()); + stream << info.collection; + stream << info.path; + stream << info.file; + stream << info.title; + stream << info.author; + stream << info.date; + stream << info.text; + stream << info.page; + stream << info.from; + stream << info.to; + } + } else if (version >= 3) { + QList references; + QList referencesContext; + int validReferenceNumber = 1; + for (const ResultInfo &info : sources) { + if (info.file.isEmpty()) + continue; + + QString reference; + { + QTextStream stream(&reference); + stream << (validReferenceNumber++) << ". "; + if (!info.title.isEmpty()) + stream << "\"" << info.title << "\". "; + if (!info.author.isEmpty()) + stream << "By " << info.author << ". "; + if (!info.date.isEmpty()) + stream << "Date: " << info.date << ". "; + stream << "In " << info.file << ". "; + if (info.page != -1) + stream << "Page " << info.page << ". "; + if (info.from != -1) { + stream << "Lines " << info.from; + if (info.to != -1) + stream << "-" << info.to; + stream << ". "; + } + stream << "[Context](context://" << validReferenceNumber - 1 << ")"; + } + references.append(reference); + referencesContext.append(info.text); + } + + stream << references.join("\n"); + stream << referencesContext; + } + if (version >= 10) { + stream << promptAttachments.size(); + for (const PromptAttachment &a : promptAttachments) { + Q_ASSERT(!a.url.isEmpty()); + stream << a.url; + stream << a.content; + } + } + + if (version >= 12) { + stream << qsizetype(subItems.size()); + for (ChatItem *item :subItems) + item->serializeSubItems(stream, version); + } +} + +bool ChatItem::deserializeToolCall(QDataStream &stream, int version) +{ + stream >> value; + return toolCallInfo.deserialize(stream, version);; +} + +bool ChatItem::deserializeToolResponse(QDataStream &stream, int version) +{ + stream >> value; + return true; +} + +bool ChatItem::deserializeText(QDataStream &stream, int version) +{ + stream >> value; + return true; +} + +bool ChatItem::deserializeResponse(QDataStream &stream, int version) +{ + stream >> value; + return true; +} + +bool ChatItem::deserializeSubItems(QDataStream &stream, int version) +{ + stream >> name; + try { + type(); // check name + } catch (const std::exception &e) { + qWarning() << "ChatModel ERROR:" << e.what(); + return false; + } + switch (auto typ = type()) { + using enum ChatItem::Type; + case Response: { deserializeResponse(stream, version); break; } + case ToolCall: { deserializeToolCall(stream, version); break; } + case ToolResponse: { deserializeToolResponse(stream, version); break; } + case Text: { deserializeText(stream, version); break; } + case System: + case Prompt: + throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ))); + } + + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ChatItem *c = new ChatItem(this); + if (!c->deserializeSubItems(stream, version)) { + delete c; + return false; + } + subItems.push_back(c); + } + + return true; +} + +bool ChatItem::deserialize(QDataStream &stream, int version) +{ + if (version < 12) { + int id; + stream >> id; + } + stream >> name; + try { + type(); // check name + } catch (const std::exception &e) { + qWarning() << "ChatModel ERROR:" << e.what(); + return false; + } + stream >> value; + if (version < 10) { + // This is deprecated and no longer used + QString prompt; + stream >> prompt; + } + stream >> newResponse; + stream >> isCurrentResponse; + stream >> stopped; + stream >> thumbsUpState; + stream >> thumbsDownState; + if (version >= 11 && type() == ChatItem::Type::Response) + stream >> isError; + if (version >= 8) { + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ResultInfo info; + stream >> info.collection; + stream >> info.path; + stream >> info.file; + stream >> info.title; + stream >> info.author; + stream >> info.date; + stream >> info.text; + stream >> info.page; + stream >> info.from; + stream >> info.to; + sources.append(info); + } + consolidatedSources = ChatItem::consolidateSources(sources); + } else if (version >= 3) { + QString references; + QList referencesContext; + stream >> references; + stream >> referencesContext; + + if (!references.isEmpty()) { + QList referenceList = references.split("\n"); + + // Ignore empty lines and those that begin with "---" which is no longer used + for (auto it = referenceList.begin(); it != referenceList.end();) { + if (it->trimmed().isEmpty() || it->trimmed().startsWith("---")) + it = referenceList.erase(it); + else + ++it; + } + + Q_ASSERT(referenceList.size() == referencesContext.size()); + for (int j = 0; j < referenceList.size(); ++j) { + QString reference = referenceList[j]; + QString context = referencesContext[j]; + ResultInfo info; + QTextStream refStream(&reference); + QString dummy; + int validReferenceNumber; + refStream >> validReferenceNumber >> dummy; + // Extract title (between quotes) + if (reference.contains("\"")) { + int startIndex = reference.indexOf('"') + 1; + int endIndex = reference.indexOf('"', startIndex); + info.title = reference.mid(startIndex, endIndex - startIndex); + } + + // Extract author (after "By " and before the next period) + if (reference.contains("By ")) { + int startIndex = reference.indexOf("By ") + 3; + int endIndex = reference.indexOf('.', startIndex); + info.author = reference.mid(startIndex, endIndex - startIndex).trimmed(); + } + + // Extract date (after "Date: " and before the next period) + if (reference.contains("Date: ")) { + int startIndex = reference.indexOf("Date: ") + 6; + int endIndex = reference.indexOf('.', startIndex); + info.date = reference.mid(startIndex, endIndex - startIndex).trimmed(); + } + + // Extract file name (after "In " and before the "[Context]") + if (reference.contains("In ") && reference.contains(". [Context]")) { + int startIndex = reference.indexOf("In ") + 3; + int endIndex = reference.indexOf(". [Context]", startIndex); + info.file = reference.mid(startIndex, endIndex - startIndex).trimmed(); + } + + // Extract page number (after "Page " and before the next space) + if (reference.contains("Page ")) { + int startIndex = reference.indexOf("Page ") + 5; + int endIndex = reference.indexOf(' ', startIndex); + if (endIndex == -1) endIndex = reference.length(); + info.page = reference.mid(startIndex, endIndex - startIndex).toInt(); + } + + // Extract lines (after "Lines " and before the next space or hyphen) + if (reference.contains("Lines ")) { + int startIndex = reference.indexOf("Lines ") + 6; + int endIndex = reference.indexOf(' ', startIndex); + if (endIndex == -1) endIndex = reference.length(); + int hyphenIndex = reference.indexOf('-', startIndex); + if (hyphenIndex != -1 && hyphenIndex < endIndex) { + info.from = reference.mid(startIndex, hyphenIndex - startIndex).toInt(); + info.to = reference.mid(hyphenIndex + 1, endIndex - hyphenIndex - 1).toInt(); + } else { + info.from = reference.mid(startIndex, endIndex - startIndex).toInt(); + } + } + info.text = context; + sources.append(info); + } + + consolidatedSources = ChatItem::consolidateSources(sources); + } + } + if (version >= 10) { + qsizetype count; + stream >> count; + QList attachments; + for (int i = 0; i < count; ++i) { + PromptAttachment a; + stream >> a.url; + stream >> a.content; + attachments.append(a); + } + promptAttachments = attachments; + } + + if (version >= 12) { + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ChatItem *c = new ChatItem(this); + if (!c->deserializeSubItems(stream, version)) { + delete c; + return false; + } + subItems.push_back(c); + } + } + return true; +} diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index 25eb8e71808c..82dbc68fdc7f 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -2,15 +2,20 @@ #define CHATMODEL_H #include "database.h" +#include "tool.h" +#include "toolcallparser.h" #include "utils.h" #include "xlsxtomd.h" #include +#include #include #include #include +#include #include +#include #include #include #include @@ -21,13 +26,16 @@ #include #include +#include #include #include #include #include +#include using namespace Qt::Literals::StringLiterals; namespace ranges = std::ranges; +namespace views = std::views; struct PromptAttachment { @@ -69,19 +77,81 @@ struct PromptAttachment { }; Q_DECLARE_METATYPE(PromptAttachment) -struct ChatItem +// Used by Server to represent a message from the client. +struct MessageInput +{ + enum class Type { System, Prompt, Response }; + Type type; + QString content; +}; + +class MessageItem { Q_GADGET + Q_PROPERTY(Type type READ type CONSTANT) + Q_PROPERTY(QString content READ content CONSTANT) + +public: + enum class Type { System, Prompt, Response, ToolResponse }; + + MessageItem(Type type, QString content) + : m_type(type), m_content(std::move(content)) {} + + MessageItem(Type type, QString content, const QList &sources, const QList &promptAttachments) + : m_type(type), m_content(std::move(content)), m_sources(sources), m_promptAttachments(promptAttachments) {} + + Type type() const { return m_type; } + const QString &content() const { return m_content; } + + QList sources() const { return m_sources; } + QList promptAttachments() const { return m_promptAttachments; } + + // used with version 0 Jinja templates + QString bakedPrompt() const + { + if (type() != Type::Prompt) + throw std::logic_error("bakedPrompt() called on non-prompt item"); + QStringList parts; + if (!m_sources.isEmpty()) { + parts << u"### Context:\n"_s; + for (auto &source : std::as_const(m_sources)) + parts << u"Collection: "_s << source.collection + << u"\nPath: "_s << source.path + << u"\nExcerpt: "_s << source.text << u"\n\n"_s; + } + for (auto &attached : std::as_const(m_promptAttachments)) + parts << attached.processedContent() << u"\n\n"_s; + parts << m_content; + return parts.join(QString()); + } + +private: + Type m_type; + QString m_content; + QList m_sources; + QList m_promptAttachments; +}; +Q_DECLARE_METATYPE(MessageItem) + +class ChatItem : public QObject +{ + Q_OBJECT Q_PROPERTY(QString name MEMBER name ) Q_PROPERTY(QString value MEMBER value) + // prompts and responses + Q_PROPERTY(QString content READ content NOTIFY contentChanged) + // prompts Q_PROPERTY(QList promptAttachments MEMBER promptAttachments) - Q_PROPERTY(QString bakedPrompt READ bakedPrompt ) // responses - Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse) - Q_PROPERTY(bool isError MEMBER isError ) + Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse NOTIFY isCurrentResponseChanged) + Q_PROPERTY(bool isError MEMBER isError ) + Q_PROPERTY(QList childItems READ childItems ) + + // toolcall + Q_PROPERTY(bool isToolCallError READ isToolCallError NOTIFY isTooCallErrorChanged) // responses (DataLake) Q_PROPERTY(QString newResponse MEMBER newResponse ) @@ -90,34 +160,65 @@ struct ChatItem Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) public: - enum class Type { System, Prompt, Response }; + enum class Type { System, Prompt, Response, Text, ToolCall, ToolResponse }; // tags for constructing ChatItems - struct prompt_tag_t { explicit prompt_tag_t() = default; }; - static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t(); - struct response_tag_t { explicit response_tag_t() = default; }; - static inline constexpr response_tag_t response_tag = response_tag_t(); - struct system_tag_t { explicit system_tag_t() = default; }; - static inline constexpr system_tag_t system_tag = system_tag_t(); + struct prompt_tag_t { explicit prompt_tag_t () = default; }; + struct response_tag_t { explicit response_tag_t () = default; }; + struct system_tag_t { explicit system_tag_t () = default; }; + struct text_tag_t { explicit text_tag_t () = default; }; + struct tool_call_tag_t { explicit tool_call_tag_t () = default; }; + struct tool_response_tag_t { explicit tool_response_tag_t() = default; }; + static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t {}; + static inline constexpr response_tag_t response_tag = response_tag_t {}; + static inline constexpr system_tag_t system_tag = system_tag_t {}; + static inline constexpr text_tag_t text_tag = text_tag_t {}; + static inline constexpr tool_call_tag_t tool_call_tag = tool_call_tag_t {}; + static inline constexpr tool_response_tag_t tool_response_tag = tool_response_tag_t {}; - // FIXME(jared): This should not be necessary. QML should see null or undefined if it - // tries to access something invalid. - ChatItem() = default; +public: + ChatItem(QObject *parent) + : QObject(nullptr) + { + moveToThread(parent->thread()); + setParent(parent); + } - // NOTE: system messages are currently never stored in the model or serialized - ChatItem(system_tag_t, const QString &value) - : name(u"System: "_s), value(value) {} + // NOTE: System messages are currently never serialized and only *stored* by the local server. + // ChatLLM prepends a system MessageItem on-the-fly. + ChatItem(QObject *parent, system_tag_t, const QString &value) + : ChatItem(parent) + { this->name = u"System: "_s; this->value = value; } - ChatItem(prompt_tag_t, const QString &value, const QList &attachments = {}) - : name(u"Prompt: "_s), value(value), promptAttachments(attachments) {} + ChatItem(QObject *parent, prompt_tag_t, const QString &value, const QList &attachments = {}) + : ChatItem(parent) + { this->name = u"Prompt: "_s; this->value = value; this->promptAttachments = attachments; } +private: + ChatItem(QObject *parent, response_tag_t, bool isCurrentResponse, const QString &value = {}) + : ChatItem(parent) + { this->name = u"Response: "_s; this->value = value; this->isCurrentResponse = isCurrentResponse; } + +public: // A new response, to be filled in - ChatItem(response_tag_t) - : name(u"Response: "_s), isCurrentResponse(true) {} + ChatItem(QObject *parent, response_tag_t) + : ChatItem(parent, response_tag, true) {} // An existing response, from Server - ChatItem(response_tag_t, const QString &value) - : name(u"Response: "_s), value(value) {} + ChatItem(QObject *parent, response_tag_t, const QString &value) + : ChatItem(parent, response_tag, false, value) {} + + ChatItem(QObject *parent, text_tag_t, const QString &value) + : ChatItem(parent) + { this->name = u"Text: "_s; this->value = value; } + + ChatItem(QObject *parent, tool_call_tag_t, const QString &value) + : ChatItem(parent) + { this->name = u"ToolCall: "_s; this->value = value; } + + ChatItem(QObject *parent, tool_response_tag_t, const QString &value) + : ChatItem(parent) + { this->name = u"ToolResponse: "_s; this->value = value; } Type type() const { @@ -127,28 +228,187 @@ struct ChatItem return Type::Prompt; if (name == u"Response: "_s) return Type::Response; + if (name == u"Text: "_s) + return Type::Text; + if (name == u"ToolCall: "_s) + return Type::ToolCall; + if (name == u"ToolResponse: "_s) + return Type::ToolResponse; throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name)); } - // used with version 0 Jinja templates - QString bakedPrompt() const + QString flattenedContent() const { - if (type() != Type::Prompt) - throw std::logic_error("bakedPrompt() called on non-prompt item"); - QStringList parts; - if (!sources.isEmpty()) { - parts << u"### Context:\n"_s; - for (auto &source : std::as_const(sources)) - parts << u"Collection: "_s << source.collection - << u"\nPath: "_s << source.path - << u"\nExcerpt: "_s << source.text << u"\n\n"_s; + if (subItems.empty()) + return value; + + // We only flatten one level + QString content; + for (ChatItem *item : subItems) + content += item->value; + return content; + } + + QString content() const + { + if (type() == Type::Response) { + // We parse if this contains any part of a partial toolcall + ToolCallParser parser; + parser.update(value); + + // If no tool call is detected, return the original value + if (parser.startIndex() < 0) + return value; + + // Otherwise we only return the text before and any partial tool call + const QString beforeToolCall = value.left(parser.startIndex()); + return beforeToolCall; } - for (auto &attached : std::as_const(promptAttachments)) - parts << attached.processedContent() << u"\n\n"_s; - parts << value; - return parts.join(QString()); + + // For tool calls we only return content if it is the code interpreter + if (type() == Type::ToolCall) + return codeInterpreterContent(value); + + // We don't show any of content from the tool response in the GUI + if (type() == Type::ToolResponse) + return QString(); + + return value; + } + + QString codeInterpreterContent(const QString &value) const + { + ToolCallParser parser; + parser.update(value); + + // Extract the code + QString code = parser.toolCall(); + code = code.trimmed(); + + QString result; + + // If we've finished the tool call then extract the result from meta information + if (toolCallInfo.name == ToolCallConstants::CodeInterpreterFunction) + result = "```\n" + toolCallInfo.result + "```"; + + // Return the formatted code and the result if available + return code + result; + } + + QString clipboardContent() const + { + QStringList clipContent; + for (const ChatItem *item : subItems) + clipContent << item->clipboardContent(); + clipContent << content(); + return clipContent.join(""); + } + + QList childItems() const + { + // We currently have leaf nodes at depth 3 with nodes at depth 2 as mere containers we don't + // care about in GUI + QList items; + for (const ChatItem *item : subItems) { + items.reserve(items.size() + item->subItems.size()); + ranges::copy(item->subItems, std::back_inserter(items)); + } + return items; + } + + QString possibleToolCall() const + { + if (!subItems.empty()) + return subItems.back()->possibleToolCall(); + if (type() == Type::ToolCall) + return value; + else + return QString(); + } + + void setCurrentResponse(bool b) + { + if (!subItems.empty()) + subItems.back()->setCurrentResponse(b); + isCurrentResponse = b; + emit isCurrentResponseChanged(); + } + + void setValue(const QString &v) + { + if (!subItems.empty() && subItems.back()->isCurrentResponse) { + subItems.back()->setValue(v); + return; + } + + value = v; + emit contentChanged(); } + void setToolCallInfo(const ToolCallInfo &info) + { + toolCallInfo = info; + emit contentChanged(); + emit isTooCallErrorChanged(); + } + + bool isToolCallError() const + { + return toolCallInfo.error != ToolEnums::Error::NoError; + } + + // NB: Assumes response is not current. + static ChatItem *fromMessageInput(QObject *parent, const MessageInput &message) + { + switch (message.type) { + using enum MessageInput::Type; + case Prompt: return new ChatItem(parent, prompt_tag, message.content); + case Response: return new ChatItem(parent, response_tag, message.content); + case System: return new ChatItem(parent, system_tag, message.content); + } + Q_UNREACHABLE(); + } + + MessageItem asMessageItem() const + { + MessageItem::Type msgType; + switch (auto typ = type()) { + using enum ChatItem::Type; + case System: msgType = MessageItem::Type::System; break; + case Prompt: msgType = MessageItem::Type::Prompt; break; + case Response: msgType = MessageItem::Type::Response; break; + case ToolResponse: msgType = MessageItem::Type::ToolResponse; break; + case Text: + case ToolCall: + throw std::invalid_argument(fmt::format("cannot convert ChatItem type {} to message item", int(typ))); + } + return { msgType, flattenedContent(), sources, promptAttachments }; + } + + static QList consolidateSources(const QList &sources); + + void serializeResponse(QDataStream &stream, int version); + void serializeToolCall(QDataStream &stream, int version); + void serializeToolResponse(QDataStream &stream, int version); + void serializeText(QDataStream &stream, int version); + void serializeSubItems(QDataStream &stream, int version); // recursive + void serialize(QDataStream &stream, int version); + + + bool deserializeResponse(QDataStream &stream, int version); + bool deserializeToolCall(QDataStream &stream, int version); + bool deserializeToolResponse(QDataStream &stream, int version); + bool deserializeText(QDataStream &stream, int version); + bool deserializeSubItems(QDataStream &stream, int version); // recursive + bool deserialize(QDataStream &stream, int version); + +Q_SIGNALS: + void contentChanged(); + void isTooCallErrorChanged(); + void isCurrentResponseChanged(); + +public: + // TODO: Maybe we should include the model name here as well as timestamp? QString name; QString value; @@ -161,6 +421,8 @@ struct ChatItem // responses bool isCurrentResponse = false; bool isError = false; + ToolCallInfo toolCallInfo; + std::list subItems; // responses (DataLake) QString newResponse; @@ -168,20 +430,6 @@ struct ChatItem bool thumbsUpState = false; bool thumbsDownState = false; }; -Q_DECLARE_METATYPE(ChatItem) - -class ChatModelAccessor : public std::span { -private: - using Super = std::span; - -public: - template - ChatModelAccessor(QMutex &mutex, T &&...args) - : Super(std::forward(args)...), m_lock(&mutex) {} - -private: - QMutexLocker m_lock; -}; class ChatModel : public QAbstractListModel { @@ -198,6 +446,9 @@ class ChatModel : public QAbstractListModel NameRole = Qt::UserRole + 1, ValueRole, + // prompts and responses + ContentRole, + // prompts PromptAttachmentsRole, @@ -207,6 +458,7 @@ class ChatModel : public QAbstractListModel ConsolidatedSourcesRole, IsCurrentResponseRole, IsErrorRole, + ChildItemsRole, // responses (DataLake) NewResponseRole, @@ -224,25 +476,47 @@ class ChatModel : public QAbstractListModel /* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs * sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */ - static std::optional getPeer(const ChatItem *arr, qsizetype size, qsizetype index) + template + static std::optional getPeer(const T *arr, qsizetype size, qsizetype index) { Q_ASSERT(index >= 0); Q_ASSERT(index < size); + return getPeerInternal(arr, size, index); + } + +private: + static std::optional getPeerInternal(const ChatItem * const *arr, qsizetype size, qsizetype index) + { qsizetype peer; ChatItem::Type expected; - switch (arr[index].type()) { + switch (arr[index]->type()) { using enum ChatItem::Type; case Prompt: peer = index + 1; expected = Response; break; case Response: peer = index - 1; expected = Prompt; break; default: throw std::invalid_argument("getPeer() called on item that is not a prompt or response"); } + if (peer >= 0 && peer < size && arr[peer]->type() == expected) + return peer; + return std::nullopt; + } + + static std::optional getPeerInternal(const MessageItem *arr, qsizetype size, qsizetype index) + { + qsizetype peer; + MessageItem::Type expected; + switch (arr[index].type()) { + using enum MessageItem::Type; + case Prompt: peer = index + 1; expected = Response; break; + case Response: peer = index - 1; expected = Prompt; break; + default: throw std::invalid_argument("getPeer() called on item that is not a prompt or response"); + } if (peer >= 0 && peer < size && arr[peer].type() == expected) return peer; return std::nullopt; } +public: template - requires std::same_as, ChatItem> static auto getPeer(R &&range, ranges::iterator_t item) -> std::optional> { auto begin = ranges::begin(range); @@ -250,7 +524,7 @@ class ChatModel : public QAbstractListModel .transform([&](auto i) { return begin + i; }); } - auto getPeerUnlocked(QList::const_iterator item) const -> std::optional::const_iterator> + auto getPeerUnlocked(QList::const_iterator item) const -> std::optional::const_iterator> { return getPeer(m_chatItems, item); } std::optional getPeerUnlocked(qsizetype index) const @@ -262,7 +536,8 @@ class ChatModel : public QAbstractListModel if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) return QVariant(); - auto item = m_chatItems.cbegin() + index.row(); + auto itemIt = m_chatItems.cbegin() + index.row(); + auto *item = *itemIt; switch (role) { case NameRole: return item->name; @@ -274,8 +549,8 @@ class ChatModel : public QAbstractListModel { QList data; if (item->type() == ChatItem::Type::Response) { - if (auto prompt = getPeerUnlocked(item)) - data = (*prompt)->sources; + if (auto prompt = getPeerUnlocked(itemIt)) + data = (**prompt)->sources; } return QVariant::fromValue(data); } @@ -283,8 +558,8 @@ class ChatModel : public QAbstractListModel { QList data; if (item->type() == ChatItem::Type::Response) { - if (auto prompt = getPeerUnlocked(item)) - data = (*prompt)->consolidatedSources; + if (auto prompt = getPeerUnlocked(itemIt)) + data = (**prompt)->consolidatedSources; } return QVariant::fromValue(data); } @@ -300,6 +575,10 @@ class ChatModel : public QAbstractListModel return item->thumbsDownState; case IsErrorRole: return item->type() == ChatItem::Type::Response && item->isError; + case ContentRole: + return item->content(); + case ChildItemsRole: + return QVariant::fromValue(item->childItems()); } return QVariant(); @@ -319,6 +598,8 @@ class ChatModel : public QAbstractListModel { StoppedRole, "stopped" }, { ThumbsUpStateRole, "thumbsUpState" }, { ThumbsDownStateRole, "thumbsDownState" }, + { ContentRole, "content" }, + { ChildItemsRole, "childItems" }, }; } @@ -335,7 +616,7 @@ class ChatModel : public QAbstractListModel beginInsertRows(QModelIndex(), count, count); { QMutexLocker locker(&m_mutex); - m_chatItems.emplace_back(ChatItem::prompt_tag, value, attachments); + m_chatItems << new ChatItem(this, ChatItem::prompt_tag, value, attachments); } endInsertRows(); emit countChanged(); @@ -354,46 +635,43 @@ class ChatModel : public QAbstractListModel beginInsertRows(QModelIndex(), count, count); { QMutexLocker locker(&m_mutex); - m_chatItems.emplace_back(ChatItem::response_tag); + m_chatItems << new ChatItem(this, ChatItem::response_tag); } endInsertRows(); emit countChanged(); } // Used by Server to append a new conversation to the chat log. - // Appends a new, blank response to the end of the input list. - // Returns an (offset, count) pair representing the indices of the appended items, including the new response. - std::pair appendResponseWithHistory(QList &history) + // Returns the offset of the appended items. + qsizetype appendResponseWithHistory(std::span history) { if (history.empty()) throw std::invalid_argument("at least one message is required"); - // add an empty response to prepare for generation - history.emplace_back(ChatItem::response_tag); - m_mutex.lock(); qsizetype startIndex = m_chatItems.size(); m_mutex.unlock(); - qsizetype endIndex = startIndex + history.size(); + qsizetype nNewItems = history.size() + 1; + qsizetype endIndex = startIndex + nNewItems; beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/); bool hadError; QList newItems; - std::pair subrange; { QMutexLocker locker(&m_mutex); + startIndex = m_chatItems.size(); // just in case hadError = hasErrorUnlocked(); - subrange = { m_chatItems.size(), history.size() }; - m_chatItems.reserve(m_chatItems.size() + history.size()); - for (auto &item : history) - m_chatItems << item; + m_chatItems.reserve(m_chatItems.count() + nNewItems); + for (auto &message : history) + m_chatItems << ChatItem::fromMessageInput(this, message); + m_chatItems << new ChatItem(this, ChatItem::response_tag); } endInsertRows(); emit countChanged(); // Server can add messages when there is an error because each call is a new conversation if (hadError) emit hasErrorChanged(false); - return subrange; + return startIndex; } void truncate(qsizetype size) @@ -403,7 +681,7 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (size >= (oldSize = m_chatItems.size())) return; - if (size && m_chatItems.at(size - 1).type() != ChatItem::Type::Response) + if (size && m_chatItems.at(size - 1)->type() != ChatItem::Type::Response) throw std::invalid_argument( fmt::format("chat model truncated to {} items would not end in a response", size) ); @@ -423,6 +701,44 @@ class ChatModel : public QAbstractListModel emit hasErrorChanged(false); } + QString popPrompt(int index) + { + QString content; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size() || m_chatItems[index]->type() != ChatItem::Type::Prompt) + throw std::logic_error("attempt to pop a prompt, but this is not a prompt"); + content = m_chatItems[index]->content(); + } + truncate(index); + return content; + } + + bool regenerateResponse(int index) + { + int promptIdx; + { + QMutexLocker locker(&m_mutex); + auto items = m_chatItems; // holds lock + if (index < 1 || index >= items.size() || items[index]->type() != ChatItem::Type::Response) + return false; + promptIdx = getPeerUnlocked(index).value_or(-1); + } + + truncate(index + 1); + clearSubItems(index); + setResponseValue({}); + updateCurrentResponse(index, true ); + updateNewResponse (index, {} ); + updateStopped (index, false); + updateThumbsUpState (index, false); + updateThumbsDownState(index, false); + setError(false); + if (promptIdx >= 0) + updateSources(promptIdx, {}); + return true; + } + Q_INVOKABLE void clear() { { @@ -443,28 +759,24 @@ class ChatModel : public QAbstractListModel emit hasErrorChanged(false); } - Q_INVOKABLE ChatItem get(int index) + Q_INVOKABLE QString possibleToolcall() const { QMutexLocker locker(&m_mutex); - if (index < 0 || index >= m_chatItems.size()) return ChatItem(); - return m_chatItems.at(index); + if (m_chatItems.empty()) return QString(); + return m_chatItems.back()->possibleToolCall(); } Q_INVOKABLE void updateCurrentResponse(int index, bool b) { - bool changed = false; { QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.isCurrentResponse != b) { - item.isCurrentResponse = b; - changed = true; - } + ChatItem *item = m_chatItems[index]; + item->setCurrentResponse(b); } - if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole}); + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole}); } Q_INVOKABLE void updateStopped(int index, bool b) @@ -474,45 +786,28 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.stopped != b) { - item.stopped = b; + ChatItem *item = m_chatItems[index]; + if (item->stopped != b) { + item->stopped = b; changed = true; } } if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); } - Q_INVOKABLE void updateValue(int index, const QString &value) + Q_INVOKABLE void setResponseValue(const QString &value) { - bool changed = false; + qsizetype index; { QMutexLocker locker(&m_mutex); - if (index < 0 || index >= m_chatItems.size()) return; + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("we only set this on a response"); - ChatItem &item = m_chatItems[index]; - if (item.value != value) { - item.value = value; - changed = true; - } - } - if (changed) { - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole}); - emit valueChanged(index, value); - } - } - - static QList consolidateSources(const QList &sources) { - QMap groupedData; - for (const ResultInfo &info : sources) { - if (groupedData.contains(info.file)) { - groupedData[info.file].text += "\n---\n" + info.text; - } else { - groupedData[info.file] = info; - } + index = m_chatItems.count() - 1; + ChatItem *item = m_chatItems.back(); + item->setValue(value); } - QList consolidatedSources = groupedData.values(); - return consolidatedSources; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole, ContentRole}); } Q_INVOKABLE void updateSources(int index, const QList &sources) @@ -523,12 +818,12 @@ class ChatModel : public QAbstractListModel if (index < 0 || index >= m_chatItems.size()) return; auto promptItem = m_chatItems.begin() + index; - if (promptItem->type() != ChatItem::Type::Prompt) + if ((*promptItem)->type() != ChatItem::Type::Prompt) throw std::invalid_argument(fmt::format("item at index {} is not a prompt", index)); if (auto peer = getPeerUnlocked(promptItem)) responseIndex = *peer - m_chatItems.cbegin(); - promptItem->sources = sources; - promptItem->consolidatedSources = consolidateSources(sources); + (*promptItem)->sources = sources; + (*promptItem)->consolidatedSources = ChatItem::consolidateSources(sources); } if (responseIndex >= 0) { emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {SourcesRole}); @@ -543,9 +838,9 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.thumbsUpState != b) { - item.thumbsUpState = b; + ChatItem *item = m_chatItems[index]; + if (item->thumbsUpState != b) { + item->thumbsUpState = b; changed = true; } } @@ -559,9 +854,9 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.thumbsDownState != b) { - item.thumbsDownState = b; + ChatItem *item = m_chatItems[index]; + if (item->thumbsDownState != b) { + item->thumbsDownState = b; changed = true; } } @@ -575,137 +870,172 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.newResponse != newResponse) { - item.newResponse = newResponse; + ChatItem *item = m_chatItems[index]; + if (item->newResponse != newResponse) { + item->newResponse = newResponse; changed = true; } } if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); } + Q_INVOKABLE void splitToolCall(const QPair &split) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only set toolcall on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + ChatItem *currentResponse = m_chatItems.back(); + Q_ASSERT(currentResponse->isCurrentResponse); + + // Create a new response container for any text and the tool call + ChatItem *newResponse = new ChatItem(this, ChatItem::response_tag); + + // Add preceding text if any + if (!split.first.isEmpty()) { + ChatItem *textItem = new ChatItem(this, ChatItem::text_tag, split.first); + newResponse->subItems.push_back(textItem); + } + + // Add the toolcall + Q_ASSERT(!split.second.isEmpty()); + ChatItem *toolCallItem = new ChatItem(this, ChatItem::tool_call_tag, split.second); + toolCallItem->isCurrentResponse = true; + newResponse->subItems.push_back(toolCallItem); + + // Add new response and reset our value + currentResponse->subItems.push_back(newResponse); + currentResponse->value = QString(); + } + + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + + Q_INVOKABLE void updateToolCall(const ToolCallInfo &toolCallInfo) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only set toolcall on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + ChatItem *currentResponse = m_chatItems.back(); + Q_ASSERT(currentResponse->isCurrentResponse); + + ChatItem *subResponse = currentResponse->subItems.back(); + Q_ASSERT(subResponse->type() == ChatItem::Type::Response); + Q_ASSERT(subResponse->isCurrentResponse); + + ChatItem *toolCallItem = subResponse->subItems.back(); + Q_ASSERT(toolCallItem->type() == ChatItem::Type::ToolCall); + toolCallItem->setToolCallInfo(toolCallInfo); + toolCallItem->setCurrentResponse(false); + + // Add tool response + ChatItem *toolResponseItem = new ChatItem(this, ChatItem::tool_response_tag, toolCallInfo.result); + currentResponse->subItems.push_back(toolResponseItem); + } + + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + + void clearSubItems(int index) + { + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; + if (m_chatItems.isEmpty() || m_chatItems[index]->type() != ChatItem::Type::Response) + throw std::logic_error("can only clear subitems on a chat that ends with a response"); + + ChatItem *item = m_chatItems.back(); + if (!item->subItems.empty()) { + item->subItems.clear(); + changed = true; + } + } + if (changed) { + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + } + Q_INVOKABLE void setError(bool value = true) { qsizetype index; { QMutexLocker locker(&m_mutex); - if (m_chatItems.isEmpty() || m_chatItems.cend()[-1].type() != ChatItem::Type::Response) + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) throw std::logic_error("can only set error on a chat that ends with a response"); index = m_chatItems.count() - 1; auto &last = m_chatItems.back(); - if (last.isError == value) + if (last->isError == value) return; // already set - last.isError = value; + last->isError = value; } emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsErrorRole}); emit hasErrorChanged(value); } + Q_INVOKABLE void copyToClipboard(int index) + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) + return; + ChatItem *item = m_chatItems.at(index); + QClipboard *clipboard = QGuiApplication::clipboard(); + clipboard->setText(item->clipboardContent(), QClipboard::Clipboard); + } + qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); } - ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; } + std::vector messageItems() const + { + // A flattened version of the chat item tree used by the backend and jinja + QMutexLocker locker(&m_mutex); + std::vector chatItems; + for (const ChatItem *item : m_chatItems) { + chatItems.reserve(chatItems.size() + item->subItems.size() + 1); + ranges::copy(item->subItems | views::transform(&ChatItem::asMessageItem), std::back_inserter(chatItems)); + chatItems.push_back(item->asMessageItem()); + } + return chatItems; + } bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); } bool serialize(QDataStream &stream, int version) const { + // FIXME: need to serialize new chatitem tree QMutexLocker locker(&m_mutex); stream << int(m_chatItems.size()); for (auto itemIt = m_chatItems.cbegin(); itemIt < m_chatItems.cend(); ++itemIt) { auto c = *itemIt; // NB: copies if (version < 11) { // move sources from their prompt to the next response - switch (c.type()) { + switch (c->type()) { using enum ChatItem::Type; case Prompt: - c.sources.clear(); - c.consolidatedSources.clear(); + c->sources.clear(); + c->consolidatedSources.clear(); break; case Response: // note: we drop sources for responseless prompts if (auto peer = getPeerUnlocked(itemIt)) { - c.sources = (*peer)->sources; - c.consolidatedSources = (*peer)->consolidatedSources; + c->sources = (**peer)->sources; + c->consolidatedSources = (**peer)->consolidatedSources; } default: ; } } - // FIXME: This 'id' should be eliminated the next time we bump serialization version. - // (Jared) This was apparently never used. - int id = 0; - stream << id; - stream << c.name; - stream << c.value; - stream << c.newResponse; - stream << c.isCurrentResponse; - stream << c.stopped; - stream << c.thumbsUpState; - stream << c.thumbsDownState; - if (version >= 11 && c.type() == ChatItem::Type::Response) - stream << c.isError; - if (version >= 8) { - stream << c.sources.size(); - for (const ResultInfo &info : c.sources) { - Q_ASSERT(!info.file.isEmpty()); - stream << info.collection; - stream << info.path; - stream << info.file; - stream << info.title; - stream << info.author; - stream << info.date; - stream << info.text; - stream << info.page; - stream << info.from; - stream << info.to; - } - } else if (version >= 3) { - QList references; - QList referencesContext; - int validReferenceNumber = 1; - for (const ResultInfo &info : c.sources) { - if (info.file.isEmpty()) - continue; - - QString reference; - { - QTextStream stream(&reference); - stream << (validReferenceNumber++) << ". "; - if (!info.title.isEmpty()) - stream << "\"" << info.title << "\". "; - if (!info.author.isEmpty()) - stream << "By " << info.author << ". "; - if (!info.date.isEmpty()) - stream << "Date: " << info.date << ". "; - stream << "In " << info.file << ". "; - if (info.page != -1) - stream << "Page " << info.page << ". "; - if (info.from != -1) { - stream << "Lines " << info.from; - if (info.to != -1) - stream << "-" << info.to; - stream << ". "; - } - stream << "[Context](context://" << validReferenceNumber - 1 << ")"; - } - references.append(reference); - referencesContext.append(info.text); - } - - stream << references.join("\n"); - stream << referencesContext; - } - if (version >= 10) { - stream << c.promptAttachments.size(); - for (const PromptAttachment &a : c.promptAttachments) { - Q_ASSERT(!a.url.isEmpty()); - stream << a.url; - stream << a.content; - } - } + c->serialize(stream, version); } return stream.status() == QDataStream::Ok; } @@ -717,165 +1047,29 @@ class ChatModel : public QAbstractListModel int size; stream >> size; int lastPromptIndex = -1; - QList chatItems; + QList chatItems; for (int i = 0; i < size; ++i) { - ChatItem c; - // FIXME: see comment in serialization about id - int id; - stream >> id; - stream >> c.name; - try { - c.type(); // check name - } catch (const std::exception &e) { - qWarning() << "ChatModel ERROR:" << e.what(); + ChatItem *c = new ChatItem(this); + if (!c->deserialize(stream, version)) { + delete c; return false; } - stream >> c.value; - if (version < 10) { - // This is deprecated and no longer used - QString prompt; - stream >> prompt; - } - stream >> c.newResponse; - stream >> c.isCurrentResponse; - stream >> c.stopped; - stream >> c.thumbsUpState; - stream >> c.thumbsDownState; - if (version >= 11 && c.type() == ChatItem::Type::Response) - stream >> c.isError; - if (version >= 8) { - qsizetype count; - stream >> count; - QList sources; - for (int i = 0; i < count; ++i) { - ResultInfo info; - stream >> info.collection; - stream >> info.path; - stream >> info.file; - stream >> info.title; - stream >> info.author; - stream >> info.date; - stream >> info.text; - stream >> info.page; - stream >> info.from; - stream >> info.to; - sources.append(info); - } - c.sources = sources; - c.consolidatedSources = consolidateSources(sources); - } else if (version >= 3) { - QString references; - QList referencesContext; - stream >> references; - stream >> referencesContext; - - if (!references.isEmpty()) { - QList sources; - QList referenceList = references.split("\n"); - - // Ignore empty lines and those that begin with "---" which is no longer used - for (auto it = referenceList.begin(); it != referenceList.end();) { - if (it->trimmed().isEmpty() || it->trimmed().startsWith("---")) - it = referenceList.erase(it); - else - ++it; - } - - Q_ASSERT(referenceList.size() == referencesContext.size()); - for (int j = 0; j < referenceList.size(); ++j) { - QString reference = referenceList[j]; - QString context = referencesContext[j]; - ResultInfo info; - QTextStream refStream(&reference); - QString dummy; - int validReferenceNumber; - refStream >> validReferenceNumber >> dummy; - // Extract title (between quotes) - if (reference.contains("\"")) { - int startIndex = reference.indexOf('"') + 1; - int endIndex = reference.indexOf('"', startIndex); - info.title = reference.mid(startIndex, endIndex - startIndex); - } - - // Extract author (after "By " and before the next period) - if (reference.contains("By ")) { - int startIndex = reference.indexOf("By ") + 3; - int endIndex = reference.indexOf('.', startIndex); - info.author = reference.mid(startIndex, endIndex - startIndex).trimmed(); - } - - // Extract date (after "Date: " and before the next period) - if (reference.contains("Date: ")) { - int startIndex = reference.indexOf("Date: ") + 6; - int endIndex = reference.indexOf('.', startIndex); - info.date = reference.mid(startIndex, endIndex - startIndex).trimmed(); - } - - // Extract file name (after "In " and before the "[Context]") - if (reference.contains("In ") && reference.contains(". [Context]")) { - int startIndex = reference.indexOf("In ") + 3; - int endIndex = reference.indexOf(". [Context]", startIndex); - info.file = reference.mid(startIndex, endIndex - startIndex).trimmed(); - } - - // Extract page number (after "Page " and before the next space) - if (reference.contains("Page ")) { - int startIndex = reference.indexOf("Page ") + 5; - int endIndex = reference.indexOf(' ', startIndex); - if (endIndex == -1) endIndex = reference.length(); - info.page = reference.mid(startIndex, endIndex - startIndex).toInt(); - } - - // Extract lines (after "Lines " and before the next space or hyphen) - if (reference.contains("Lines ")) { - int startIndex = reference.indexOf("Lines ") + 6; - int endIndex = reference.indexOf(' ', startIndex); - if (endIndex == -1) endIndex = reference.length(); - int hyphenIndex = reference.indexOf('-', startIndex); - if (hyphenIndex != -1 && hyphenIndex < endIndex) { - info.from = reference.mid(startIndex, hyphenIndex - startIndex).toInt(); - info.to = reference.mid(hyphenIndex + 1, endIndex - hyphenIndex - 1).toInt(); - } else { - info.from = reference.mid(startIndex, endIndex - startIndex).toInt(); - } - } - info.text = context; - sources.append(info); - } - - c.sources = sources; - c.consolidatedSources = consolidateSources(sources); - } - } - if (version >= 10) { - qsizetype count; - stream >> count; - QList attachments; - for (int i = 0; i < count; ++i) { - PromptAttachment a; - stream >> a.url; - stream >> a.content; - attachments.append(a); - } - c.promptAttachments = attachments; - } - - if (version < 11 && c.type() == ChatItem::Type::Response) { + if (version < 11 && c->type() == ChatItem::Type::Response) { // move sources from the response to their last prompt if (lastPromptIndex >= 0) { auto &prompt = chatItems[lastPromptIndex]; - prompt.sources = std::move(c.sources ); - prompt.consolidatedSources = std::move(c.consolidatedSources); + prompt->sources = std::move(c->sources ); + prompt->consolidatedSources = std::move(c->consolidatedSources); lastPromptIndex = -1; } else { // drop sources for promptless responses - c.sources.clear(); - c.consolidatedSources.clear(); + c->sources.clear(); + c->consolidatedSources.clear(); } } chatItems << c; - if (c.type() == ChatItem::Type::Prompt) + if (c->type() == ChatItem::Type::Prompt) lastPromptIndex = chatItems.size() - 1; } @@ -895,7 +1089,6 @@ class ChatModel : public QAbstractListModel Q_SIGNALS: void countChanged(); - void valueChanged(int index, const QString &value); void hasErrorChanged(bool value); private: @@ -904,12 +1097,12 @@ class ChatModel : public QAbstractListModel if (m_chatItems.isEmpty()) return false; auto &last = m_chatItems.back(); - return last.type() == ChatItem::Type::Response && last.isError; + return last->type() == ChatItem::Type::Response && last->isError; } private: mutable QMutex m_mutex; - QList m_chatItems; + QList m_chatItems; }; #endif // CHATMODEL_H diff --git a/gpt4all-chat/src/codeinterpreter.cpp b/gpt4all-chat/src/codeinterpreter.cpp new file mode 100644 index 000000000000..027d0249ca0b --- /dev/null +++ b/gpt4all-chat/src/codeinterpreter.cpp @@ -0,0 +1,125 @@ +#include "codeinterpreter.h" + +#include +#include +#include +#include + +QString CodeInterpreter::run(const QList ¶ms, qint64 timeout) +{ + m_error = ToolEnums::Error::NoError; + m_errorString = QString(); + + Q_ASSERT(params.count() == 1 + && params.first().name == "code" + && params.first().type == ToolEnums::ParamType::String); + + const QString code = params.first().value.toString(); + + QThread workerThread; + CodeInterpreterWorker worker; + worker.moveToThread(&workerThread); + connect(&worker, &CodeInterpreterWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); + connect(&workerThread, &QThread::started, [&worker, code]() { + worker.request(code); + }); + workerThread.start(); + bool timedOut = !workerThread.wait(timeout); + if (timedOut) { + worker.interrupt(); // thread safe + m_error = ToolEnums::Error::TimeoutError; + } + workerThread.quit(); + workerThread.wait(); + if (!timedOut) { + m_error = worker.error(); + m_errorString = worker.errorString(); + } + return worker.response(); +} + +QList CodeInterpreter::parameters() const +{ + return {{ + "code", + ToolEnums::ParamType::String, + "javascript code to compute", + true + }}; +} + +QString CodeInterpreter::symbolicFormat() const +{ + return "{human readable plan to complete the task}\n" + ToolCallConstants::CodeInterpreterPrefix + "{code}\n" + ToolCallConstants::CodeInterpreterSuffix; +} + +QString CodeInterpreter::examplePrompt() const +{ + return R"(Write code to check if a number is prime, use that to see if the number 7 is prime)"; +} + +QString CodeInterpreter::exampleCall() const +{ + static const QString example = R"(function isPrime(n) { + if (n <= 1) { + return false; + } + for (let i = 2; i <= Math.sqrt(n); i++) { + if (n % i === 0) { + return false; + } + } + return true; +} + +const number = 7; +console.log(`The number ${number} is prime: ${isPrime(number)}`); +)"; + + return "Certainly! Let's compute the answer to whether the number 7 is prime.\n" + ToolCallConstants::CodeInterpreterPrefix + example + ToolCallConstants::CodeInterpreterSuffix; +} + +QString CodeInterpreter::exampleReply() const +{ + return R"("The computed result shows that 7 is a prime number.)"; +} + +CodeInterpreterWorker::CodeInterpreterWorker() + : QObject(nullptr) +{ +} + +void CodeInterpreterWorker::request(const QString &code) +{ + JavaScriptConsoleCapture consoleCapture; + QJSValue consoleObject = m_engine.newQObject(&consoleCapture); + m_engine.globalObject().setProperty("console", consoleObject); + + const QJSValue result = m_engine.evaluate(code); + QString resultString = result.isUndefined() ? QString() : result.toString(); + + // NOTE: We purposely do not set the m_error or m_errorString for the code interpreter since + // we *want* the model to see the response has an error so it can hopefully correct itself. The + // error member variables are intended for tools that have error conditions that cannot be corrected. + // For instance, a tool depending upon the network might set these error variables if the network + // is not available. + if (result.isError()) { + const QStringList lines = code.split('\n'); + const int line = result.property("lineNumber").toInt(); + const int index = line - 1; + const QString lineContent = (index >= 0 && index < lines.size()) ? lines.at(index) : "Line not found in code."; + resultString = QString("Uncaught exception at line %1: %2\n\t%3") + .arg(line) + .arg(result.toString()) + .arg(lineContent); + m_error = ToolEnums::Error::UnknownError; + m_errorString = resultString; + } + + if (resultString.isEmpty()) + resultString = consoleCapture.output; + else if (!consoleCapture.output.isEmpty()) + resultString += "\n" + consoleCapture.output; + m_response = resultString; + emit finished(); +} diff --git a/gpt4all-chat/src/codeinterpreter.h b/gpt4all-chat/src/codeinterpreter.h new file mode 100644 index 000000000000..41d2f983f4e1 --- /dev/null +++ b/gpt4all-chat/src/codeinterpreter.h @@ -0,0 +1,84 @@ +#ifndef CODEINTERPRETER_H +#define CODEINTERPRETER_H + +#include "tool.h" +#include "toolcallparser.h" + +#include +#include +#include +#include + +class JavaScriptConsoleCapture : public QObject +{ + Q_OBJECT +public: + QString output; + Q_INVOKABLE void log(const QString &message) + { + const int maxLength = 1024; + if (output.length() >= maxLength) + return; + + if (output.length() + message.length() + 1 > maxLength) { + static const QString trunc = "\noutput truncated at " + QString::number(maxLength) + " characters..."; + int remainingLength = maxLength - output.length(); + if (remainingLength > 0) + output.append(message.left(remainingLength)); + output.append(trunc); + Q_ASSERT(output.length() > maxLength); + } else { + output.append(message + "\n"); + } + } +}; + +class CodeInterpreterWorker : public QObject { + Q_OBJECT +public: + CodeInterpreterWorker(); + virtual ~CodeInterpreterWorker() {} + + QString response() const { return m_response; } + + void request(const QString &code); + void interrupt() { m_engine.setInterrupted(true); } + ToolEnums::Error error() const { return m_error; } + QString errorString() const { return m_errorString; } + +Q_SIGNALS: + void finished(); + +private: + QJSEngine m_engine; + QString m_response; + ToolEnums::Error m_error = ToolEnums::Error::NoError; + QString m_errorString; +}; + +class CodeInterpreter : public Tool +{ + Q_OBJECT +public: + explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {} + virtual ~CodeInterpreter() {} + + QString run(const QList ¶ms, qint64 timeout = 2000) override; + ToolEnums::Error error() const override { return m_error; } + QString errorString() const override { return m_errorString; } + + QString name() const override { return tr("Code Interpreter"); } + QString description() const override { return tr("compute javascript code using console.log as output"); } + QString function() const override { return ToolCallConstants::CodeInterpreterFunction; } + QList parameters() const override; + virtual QString symbolicFormat() const override; + QString examplePrompt() const override; + QString exampleCall() const override; + QString exampleReply() const override; + +private: + ToolEnums::Error m_error = ToolEnums::Error::NoError; + QString m_errorString; +}; + +#endif // CODEINTERPRETER_H diff --git a/gpt4all-chat/src/jinja_helpers.cpp b/gpt4all-chat/src/jinja_helpers.cpp index 826dfb01e812..133e58bc95ba 100644 --- a/gpt4all-chat/src/jinja_helpers.cpp +++ b/gpt4all-chat/src/jinja_helpers.cpp @@ -51,12 +51,14 @@ auto JinjaMessage::keys() const -> const std::unordered_set & static const std::unordered_set userKeys { "role", "content", "sources", "prompt_attachments" }; switch (m_item->type()) { - using enum ChatItem::Type; + using enum MessageItem::Type; case System: case Response: + case ToolResponse: return baseKeys; case Prompt: return userKeys; + break; } Q_UNREACHABLE(); } @@ -67,16 +69,18 @@ bool operator==(const JinjaMessage &a, const JinjaMessage &b) return true; const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item); auto type = ia.type(); - if (type != ib.type() || ia.value != ib.value) + if (type != ib.type() || ia.content() != ib.content()) return false; switch (type) { - using enum ChatItem::Type; + using enum MessageItem::Type; case System: case Response: + case ToolResponse: return true; case Prompt: - return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments; + return ia.sources() == ib.sources() && ia.promptAttachments() == ib.promptAttachments(); + break; } Q_UNREACHABLE(); } @@ -84,26 +88,28 @@ bool operator==(const JinjaMessage &a, const JinjaMessage &b) const JinjaFieldMap JinjaMessage::s_fields = { { "role", [](auto &m) { switch (m.item().type()) { - using enum ChatItem::Type; + using enum MessageItem::Type; case System: return "system"sv; case Prompt: return "user"sv; case Response: return "assistant"sv; + case ToolResponse: return "tool"sv; + break; } Q_UNREACHABLE(); } }, { "content", [](auto &m) { - if (m.version() == 0 && m.item().type() == ChatItem::Type::Prompt) + if (m.version() == 0 && m.item().type() == MessageItem::Type::Prompt) return m.item().bakedPrompt().toStdString(); - return m.item().value.toStdString(); + return m.item().content().toStdString(); } }, { "sources", [](auto &m) { - auto sources = m.item().sources | views::transform([](auto &r) { + auto sources = m.item().sources() | views::transform([](auto &r) { return jinja2::GenericMap([map = std::make_shared(r)] { return map.get(); }); }); return jinja2::ValuesList(sources.begin(), sources.end()); } }, { "prompt_attachments", [](auto &m) { - auto attachments = m.item().promptAttachments | views::transform([](auto &pa) { + auto attachments = m.item().promptAttachments() | views::transform([](auto &pa) { return jinja2::GenericMap([map = std::make_shared(pa)] { return map.get(); }); }); return jinja2::ValuesList(attachments.begin(), attachments.end()); diff --git a/gpt4all-chat/src/jinja_helpers.h b/gpt4all-chat/src/jinja_helpers.h index a196b47f8fdf..f7f4ff9b8b61 100644 --- a/gpt4all-chat/src/jinja_helpers.h +++ b/gpt4all-chat/src/jinja_helpers.h @@ -86,12 +86,12 @@ class JinjaPromptAttachment : public JinjaHelper { class JinjaMessage : public JinjaHelper { public: - explicit JinjaMessage(uint version, const ChatItem &item) noexcept + explicit JinjaMessage(uint version, const MessageItem &item) noexcept : m_version(version), m_item(&item) {} const JinjaMessage &value () const { return *this; } uint version() const { return m_version; } - const ChatItem &item () const { return *m_item; } + const MessageItem &item () const { return *m_item; } size_t GetSize() const override { return keys().size(); } bool HasValue(const std::string &name) const override { return keys().contains(name); } @@ -107,7 +107,7 @@ class JinjaMessage : public JinjaHelper { private: static const JinjaFieldMap s_fields; uint m_version; - const ChatItem *m_item; + const MessageItem *m_item; friend class JinjaHelper; friend bool operator==(const JinjaMessage &a, const JinjaMessage &b); diff --git a/gpt4all-chat/src/main.cpp b/gpt4all-chat/src/main.cpp index 0fc23be3c961..1050e590879d 100644 --- a/gpt4all-chat/src/main.cpp +++ b/gpt4all-chat/src/main.cpp @@ -7,6 +7,7 @@ #include "modellist.h" #include "mysettings.h" #include "network.h" +#include "toolmodel.h" #include #include @@ -116,6 +117,8 @@ int main(int argc, char *argv[]) qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance()); qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); + qmlRegisterSingletonInstance("toollist", 1, 0, "ToolList", ToolModel::globalInstance()); + qmlRegisterUncreatableMetaObject(ToolEnums::staticMetaObject, "toolenums", 1, 0, "ToolEnums", "Error: only enums"); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); { diff --git a/gpt4all-chat/src/modellist.cpp b/gpt4all-chat/src/modellist.cpp index 23aa7dc48b0a..93b22fb94d1e 100644 --- a/gpt4all-chat/src/modellist.cpp +++ b/gpt4all-chat/src/modellist.cpp @@ -473,14 +473,24 @@ GPT4AllDownloadableModels::GPT4AllDownloadableModels(QObject *parent) connect(this, &GPT4AllDownloadableModels::modelReset, this, &GPT4AllDownloadableModels::countChanged); } +void GPT4AllDownloadableModels::filter(const QVector &keywords) +{ + m_keywords = keywords; + invalidateFilter(); +} + bool GPT4AllDownloadableModels::filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const { QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); - bool hasDescription = !sourceModel()->data(index, ModelList::DescriptionRole).toString().isEmpty(); + const QString description = sourceModel()->data(index, ModelList::DescriptionRole).toString(); + bool hasDescription = !description.isEmpty(); bool isClone = sourceModel()->data(index, ModelList::IsCloneRole).toBool(); bool isDiscovered = sourceModel()->data(index, ModelList::IsDiscoveredRole).toBool(); - return !isDiscovered && hasDescription && !isClone; + bool satisfiesKeyword = m_keywords.isEmpty(); + for (const QString &k : m_keywords) + satisfiesKeyword = description.contains(k) ? true : satisfiesKeyword; + return !isDiscovered && hasDescription && !isClone && satisfiesKeyword; } int GPT4AllDownloadableModels::count() const diff --git a/gpt4all-chat/src/modellist.h b/gpt4all-chat/src/modellist.h index 0e22b931d934..0bcc97b484ee 100644 --- a/gpt4all-chat/src/modellist.h +++ b/gpt4all-chat/src/modellist.h @@ -302,11 +302,16 @@ class GPT4AllDownloadableModels : public QSortFilterProxyModel explicit GPT4AllDownloadableModels(QObject *parent); int count() const; + Q_INVOKABLE void filter(const QVector &keywords); + Q_SIGNALS: void countChanged(); protected: bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override; + +private: + QVector m_keywords; }; class HuggingFaceDownloadableModels : public QSortFilterProxyModel diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp index 222f793c234b..20a3fa7a4d40 100644 --- a/gpt4all-chat/src/server.cpp +++ b/gpt4all-chat/src/server.cpp @@ -694,7 +694,8 @@ auto Server::handleCompletionRequest(const CompletionRequest &request) promptCtx, /*usedLocalDocs*/ false); } catch (const std::exception &e) { - emit responseChanged(e.what()); + m_chatModel->setResponseValue(e.what()); + m_chatModel->setError(); emit responseStopped(0); return makeError(QHttpServerResponder::StatusCode::InternalServerError); } @@ -772,16 +773,16 @@ auto Server::handleChatRequest(const ChatRequest &request) Q_ASSERT(!request.messages.isEmpty()); // adds prompt/response items to GUI - QList chatItems; + std::vector messages; for (auto &message : request.messages) { using enum ChatRequest::Message::Role; switch (message.role) { - case System: chatItems.emplace_back(ChatItem::system_tag, message.content); break; - case User: chatItems.emplace_back(ChatItem::prompt_tag, message.content); break; - case Assistant: chatItems.emplace_back(ChatItem::response_tag, message.content); break; + case System: messages.push_back({ MessageInput::Type::System, message.content }); break; + case User: messages.push_back({ MessageInput::Type::Prompt, message.content }); break; + case Assistant: messages.push_back({ MessageInput::Type::Response, message.content }); break; } } - auto subrange = m_chatModel->appendResponseWithHistory(chatItems); + auto startOffset = m_chatModel->appendResponseWithHistory(messages); // FIXME(jared): taking parameters from the UI inhibits reproducibility of results LLModel::PromptContext promptCtx { @@ -801,9 +802,10 @@ auto Server::handleChatRequest(const ChatRequest &request) for (int i = 0; i < request.n; ++i) { ChatPromptResult result; try { - result = promptInternalChat(m_collections, promptCtx, subrange); + result = promptInternalChat(m_collections, promptCtx, startOffset); } catch (const std::exception &e) { - emit responseChanged(e.what()); + m_chatModel->setResponseValue(e.what()); + m_chatModel->setError(); emit responseStopped(0); return makeError(QHttpServerResponder::StatusCode::InternalServerError); } diff --git a/gpt4all-chat/src/tool.cpp b/gpt4all-chat/src/tool.cpp new file mode 100644 index 000000000000..74975d2830c8 --- /dev/null +++ b/gpt4all-chat/src/tool.cpp @@ -0,0 +1,74 @@ +#include "tool.h" + +#include + +#include + +jinja2::Value Tool::jinjaValue() const +{ + jinja2::ValuesList paramList; + const QList p = parameters(); + for (auto &info : p) { + std::string typeStr; + switch (info.type) { + using enum ToolEnums::ParamType; + case String: typeStr = "string"; break; + case Number: typeStr = "number"; break; + case Integer: typeStr = "integer"; break; + case Object: typeStr = "object"; break; + case Array: typeStr = "array"; break; + case Boolean: typeStr = "boolean"; break; + case Null: typeStr = "null"; break; + } + jinja2::ValuesMap infoMap { + { "name", info.name.toStdString() }, + { "type", typeStr}, + { "description", info.description.toStdString() }, + { "required", info.required } + }; + paramList.push_back(infoMap); + } + + jinja2::ValuesMap params { + { "name", name().toStdString() }, + { "description", description().toStdString() }, + { "function", function().toStdString() }, + { "parameters", paramList }, + { "symbolicFormat", symbolicFormat().toStdString() }, + { "examplePrompt", examplePrompt().toStdString() }, + { "exampleCall", exampleCall().toStdString() }, + { "exampleReply", exampleReply().toStdString() } + }; + return params; +} + +void ToolCallInfo::serialize(QDataStream &stream, int version) +{ + stream << name; + stream << params.size(); + for (auto param : params) { + stream << param.name; + stream << param.type; + stream << param.value; + } + stream << result; + stream << error; + stream << errorString; +} + +bool ToolCallInfo::deserialize(QDataStream &stream, int version) +{ + stream >> name; + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ToolParam p; + stream >> p.name; + stream >> p.type; + stream >> p.value; + } + stream >> result; + stream >> error; + stream >> errorString; + return true; +} diff --git a/gpt4all-chat/src/tool.h b/gpt4all-chat/src/tool.h new file mode 100644 index 000000000000..08c058eb5e66 --- /dev/null +++ b/gpt4all-chat/src/tool.h @@ -0,0 +1,127 @@ +#ifndef TOOL_H +#define TOOL_H + +#include +#include +#include +#include +#include + +#include + +namespace ToolEnums +{ + Q_NAMESPACE + enum class Error + { + NoError = 0, + TimeoutError = 2, + UnknownError = 499, + }; + Q_ENUM_NS(Error) + + enum class ParamType { String, Number, Integer, Object, Array, Boolean, Null }; // json schema types + Q_ENUM_NS(ParamType) + + enum class ParseState { + None, + InStart, + Partial, + Complete, + }; + Q_ENUM_NS(ParseState) +} + +struct ToolParamInfo +{ + QString name; + ToolEnums::ParamType type; + QString description; + bool required; +}; +Q_DECLARE_METATYPE(ToolParamInfo) + +struct ToolParam +{ + QString name; + ToolEnums::ParamType type; + QVariant value; + bool operator==(const ToolParam& other) const + { + return name == other.name && type == other.type && value == other.value; + } +}; +Q_DECLARE_METATYPE(ToolParam) + +struct ToolCallInfo +{ + QString name; + QList params; + QString result; + ToolEnums::Error error = ToolEnums::Error::NoError; + QString errorString; + + void serialize(QDataStream &stream, int version); + bool deserialize(QDataStream &stream, int version); + + bool operator==(const ToolCallInfo& other) const + { + return name == other.name && result == other.result && params == other.params + && error == other.error && errorString == other.errorString; + } +}; +Q_DECLARE_METATYPE(ToolCallInfo) + +class Tool : public QObject +{ + Q_OBJECT + Q_PROPERTY(QString name READ name CONSTANT) + Q_PROPERTY(QString description READ description CONSTANT) + Q_PROPERTY(QString function READ function CONSTANT) + Q_PROPERTY(QList parameters READ parameters CONSTANT) + Q_PROPERTY(QString examplePrompt READ examplePrompt CONSTANT) + Q_PROPERTY(QString exampleCall READ exampleCall CONSTANT) + Q_PROPERTY(QString exampleReply READ exampleReply CONSTANT) + +public: + Tool() : QObject(nullptr) {} + virtual ~Tool() {} + + virtual QString run(const QList ¶ms, qint64 timeout = 2000) = 0; + + // Tools should set these if they encounter errors. For instance, a tool depending upon the network + // might set these error variables if the network is not available. + virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } + virtual QString errorString() const { return QString(); } + + // [Required] Human readable name of the tool. + virtual QString name() const = 0; + + // [Required] Human readable description of what the tool does. Use this tool to: {{description}} + virtual QString description() const = 0; + + // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + virtual QString function() const = 0; + + // [Optional] List describing the tool's parameters. An empty list specifies no parameters. + virtual QList parameters() const { return {}; } + + // [Optional] The symbolic format of the toolcall. + virtual QString symbolicFormat() const { return QString(); } + + // [Optional] A human generated example of a prompt that could result in this tool being called. + virtual QString examplePrompt() const { return QString(); } + + // [Optional] An example of this tool call that pairs with the example query. It should be the + // complete string that the model must generate. + virtual QString exampleCall() const { return QString(); } + + // [Optional] An example of the reply the model might generate given the result of the tool call. + virtual QString exampleReply() const { return QString(); } + + bool operator==(const Tool &other) const { return function() == other.function(); } + + jinja2::Value jinjaValue() const; +}; + +#endif // TOOL_H diff --git a/gpt4all-chat/src/toolcallparser.cpp b/gpt4all-chat/src/toolcallparser.cpp new file mode 100644 index 000000000000..7649c21d9fe8 --- /dev/null +++ b/gpt4all-chat/src/toolcallparser.cpp @@ -0,0 +1,111 @@ +#include "toolcallparser.h" + +#include +#include +#include + +#include + +static const QString ToolCallStart = ToolCallConstants::CodeInterpreterTag; +static const QString ToolCallEnd = ToolCallConstants::CodeInterpreterEndTag; + +ToolCallParser::ToolCallParser() +{ + reset(); +} + +void ToolCallParser::reset() +{ + // Resets the search state, but not the buffer or global state + resetSearchState(); + + // These are global states maintained between update calls + m_buffer.clear(); + m_hasSplit = false; +} + +void ToolCallParser::resetSearchState() +{ + m_expected = ToolCallStart.at(0); + m_expectedIndex = 0; + m_state = ToolEnums::ParseState::None; + m_toolCall.clear(); + m_endTagBuffer.clear(); + m_startIndex = -1; +} + +// This method is called with an arbitrary string and a current state. This method should take the +// current state into account and then parse through the update character by character to arrive at +// the new state. +void ToolCallParser::update(const QString &update) +{ + Q_ASSERT(m_state != ToolEnums::ParseState::Complete); + if (m_state == ToolEnums::ParseState::Complete) { + qWarning() << "ERROR: ToolCallParser::update already found a complete toolcall!"; + return; + } + + m_buffer.append(update); + + for (size_t i = m_buffer.size() - update.size(); i < m_buffer.size(); ++i) { + const QChar c = m_buffer[i]; + const bool foundMatch = m_expected.isNull() || c == m_expected; + if (!foundMatch) { + resetSearchState(); + continue; + } + + switch (m_state) { + case ToolEnums::ParseState::None: + { + m_expectedIndex = 1; + m_expected = ToolCallStart.at(1); + m_state = ToolEnums::ParseState::InStart; + m_startIndex = i; + break; + } + case ToolEnums::ParseState::InStart: + { + if (m_expectedIndex == ToolCallStart.size() - 1) { + m_expectedIndex = 0; + m_expected = QChar(); + m_state = ToolEnums::ParseState::Partial; + } else { + ++m_expectedIndex; + m_expected = ToolCallStart.at(m_expectedIndex); + } + break; + } + case ToolEnums::ParseState::Partial: + { + m_toolCall.append(c); + m_endTagBuffer.append(c); + if (m_endTagBuffer.size() > ToolCallEnd.size()) + m_endTagBuffer.remove(0, 1); + if (m_endTagBuffer == ToolCallEnd) { + m_toolCall.chop(ToolCallEnd.size()); + m_state = ToolEnums::ParseState::Complete; + m_endTagBuffer.clear(); + } + } + case ToolEnums::ParseState::Complete: + { + // Already complete, do nothing further + break; + } + } + } +} + +QPair ToolCallParser::split() +{ + Q_ASSERT(m_state == ToolEnums::ParseState::Partial + || m_state == ToolEnums::ParseState::Complete); + + Q_ASSERT(m_startIndex >= 0); + m_hasSplit = true; + const QString beforeToolCall = m_buffer.left(m_startIndex); + m_buffer = m_buffer.mid(m_startIndex); + m_startIndex = 0; + return { beforeToolCall, m_buffer }; +} diff --git a/gpt4all-chat/src/toolcallparser.h b/gpt4all-chat/src/toolcallparser.h new file mode 100644 index 000000000000..855cb6b7d099 --- /dev/null +++ b/gpt4all-chat/src/toolcallparser.h @@ -0,0 +1,47 @@ +#ifndef TOOLCALLPARSER_H +#define TOOLCALLPARSER_H + +#include "tool.h" + +#include +#include +#include + +namespace ToolCallConstants +{ + const QString CodeInterpreterFunction = R"(javascript_interpret)"; + const QString CodeInterpreterTag = R"(<)" + CodeInterpreterFunction + R"(>)"; + const QString CodeInterpreterEndTag = R"()"; + const QString CodeInterpreterPrefix = CodeInterpreterTag + "\n```javascript\n"; + const QString CodeInterpreterSuffix = "```\n" + CodeInterpreterEndTag; +} + +class ToolCallParser +{ +public: + ToolCallParser(); + void reset(); + void update(const QString &update); + QString buffer() const { return m_buffer; } + QString toolCall() const { return m_toolCall; } + int startIndex() const { return m_startIndex; } + ToolEnums::ParseState state() const { return m_state; } + + // Splits + QPair split(); + bool hasSplit() const { return m_hasSplit; } + +private: + void resetSearchState(); + + QChar m_expected; + int m_expectedIndex; + ToolEnums::ParseState m_state; + QString m_buffer; + QString m_toolCall; + QString m_endTagBuffer; + int m_startIndex; + bool m_hasSplit; +}; + +#endif // TOOLCALLPARSER_H diff --git a/gpt4all-chat/src/toolmodel.cpp b/gpt4all-chat/src/toolmodel.cpp new file mode 100644 index 000000000000..50d3369cf3a9 --- /dev/null +++ b/gpt4all-chat/src/toolmodel.cpp @@ -0,0 +1,31 @@ +#include "toolmodel.h" + +#include "codeinterpreter.h" + +#include +#include +#include + +class MyToolModel: public ToolModel { }; +Q_GLOBAL_STATIC(MyToolModel, toolModelInstance) +ToolModel *ToolModel::globalInstance() +{ + return toolModelInstance(); +} + +ToolModel::ToolModel() + : QAbstractListModel(nullptr) { + + QCoreApplication::instance()->installEventFilter(this); + + Tool* codeInterpreter = new CodeInterpreter; + m_tools.append(codeInterpreter); + m_toolMap.insert(codeInterpreter->function(), codeInterpreter); +} + +bool ToolModel::eventFilter(QObject *obj, QEvent *ev) +{ + if (obj == QCoreApplication::instance() && ev->type() == QEvent::LanguageChange) + emit dataChanged(index(0, 0), index(m_tools.size() - 1, 0)); + return false; +} diff --git a/gpt4all-chat/src/toolmodel.h b/gpt4all-chat/src/toolmodel.h new file mode 100644 index 000000000000..b20e39ccffdf --- /dev/null +++ b/gpt4all-chat/src/toolmodel.h @@ -0,0 +1,110 @@ +#ifndef TOOLMODEL_H +#define TOOLMODEL_H + +#include "tool.h" + +#include +#include +#include +#include +#include +#include +#include + +class ToolModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + +public: + static ToolModel *globalInstance(); + + enum Roles { + NameRole = Qt::UserRole + 1, + DescriptionRole, + FunctionRole, + ParametersRole, + SymbolicFormatRole, + ExamplePromptRole, + ExampleCallRole, + ExampleReplyRole, + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_tools.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_tools.size()) + return QVariant(); + + const Tool *item = m_tools.at(index.row()); + switch (role) { + case NameRole: + return item->name(); + case DescriptionRole: + return item->description(); + case FunctionRole: + return item->function(); + case ParametersRole: + return QVariant::fromValue(item->parameters()); + case SymbolicFormatRole: + return item->symbolicFormat(); + case ExamplePromptRole: + return item->examplePrompt(); + case ExampleCallRole: + return item->exampleCall(); + case ExampleReplyRole: + return item->exampleReply(); + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[NameRole] = "name"; + roles[DescriptionRole] = "description"; + roles[FunctionRole] = "function"; + roles[ParametersRole] = "parameters"; + roles[SymbolicFormatRole] = "symbolicFormat"; + roles[ExamplePromptRole] = "examplePrompt"; + roles[ExampleCallRole] = "exampleCall"; + roles[ExampleReplyRole] = "exampleReply"; + return roles; + } + + Q_INVOKABLE Tool* get(int index) const + { + if (index < 0 || index >= m_tools.size()) return nullptr; + return m_tools.at(index); + } + + Q_INVOKABLE Tool *get(const QString &id) const + { + if (!m_toolMap.contains(id)) return nullptr; + return m_toolMap.value(id); + } + + int count() const { return m_tools.size(); } + +Q_SIGNALS: + void countChanged(); + void valueChanged(int index, const QString &value); + +protected: + bool eventFilter(QObject *obj, QEvent *ev) override; + +private: + explicit ToolModel(); + ~ToolModel() {} + friend class MyToolModel; + QList m_tools; + QHash m_toolMap; +}; + +#endif // TOOLMODEL_H