diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java index b73e9eb16..a66e81630 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java @@ -4,19 +4,15 @@ import com.fasterxml.jackson.databind.ObjectMapper; import ee.carlrobert.codegpt.events.CodeGPTEvent; import ee.carlrobert.codegpt.settings.GeneralSettings; -import ee.carlrobert.codegpt.settings.GeneralSettingsState; import ee.carlrobert.codegpt.telemetry.TelemetryAction; import ee.carlrobert.llm.client.openai.completion.ErrorDetails; import ee.carlrobert.llm.completion.CompletionEventListener; -import java.util.List; -import javax.swing.SwingWorker; import okhttp3.sse.EventSource; public class CompletionRequestHandler { private final StringBuilder messageBuilder = new StringBuilder(); private final CompletionResponseEventListener completionResponseEventListener; - private SwingWorker swingWorker; private EventSource eventSource; public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) { @@ -24,15 +20,21 @@ public CompletionRequestHandler(CompletionResponseEventListener completionRespon } public void call(CallParameters callParameters) { - swingWorker = new CompletionRequestWorker(callParameters); - swingWorker.execute(); + try { + eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters)); + } catch (TotalUsageExceededException e) { + completionResponseEventListener.handleTokensExceeded( + callParameters.getConversation(), + callParameters.getMessage()); + } finally { + sendInfo(callParameters); + } } public void cancel() { if (eventSource != null) { eventSource.cancel(); } - swingWorker.cancel(true); } private EventSource startCall( @@ -57,79 +59,48 @@ private void handleCallException(Throwable ex) { completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex); } - private class CompletionRequestWorker extends SwingWorker { + class RequestCompletionEventListener implements CompletionEventListener { private final CallParameters callParameters; - public CompletionRequestWorker(CallParameters callParameters) { + public RequestCompletionEventListener(CallParameters callParameters) { this.callParameters = callParameters; } - protected Void doInBackground() { - var settings = GeneralSettings.getCurrentState(); + @Override + public void onEvent(String data) { try { - eventSource = startCall(callParameters, new RequestCompletionEventListener()); - } catch (TotalUsageExceededException e) { - completionResponseEventListener.handleTokensExceeded( - callParameters.getConversation(), - callParameters.getMessage()); - } finally { - sendInfo(settings); + var event = new ObjectMapper().readValue(data, CodeGPTEvent.class); + completionResponseEventListener.handleCodeGPTEvent(event); + } catch (JsonProcessingException e) { + // ignore } - return null; } - protected void process(List chunks) { + @Override + public void onMessage(String message, EventSource eventSource) { + messageBuilder.append(message); callParameters.getMessage().setResponse(messageBuilder.toString()); - for (String text : chunks) { - messageBuilder.append(text); - completionResponseEventListener.handleMessage(text); - } + completionResponseEventListener.handleMessage(message); } - class RequestCompletionEventListener implements CompletionEventListener { - - @Override - public void onEvent(String data) { - try { - var event = new ObjectMapper().readValue(data, CodeGPTEvent.class); - completionResponseEventListener.handleCodeGPTEvent(event); - } catch (JsonProcessingException e) { - // ignore - } - } - - @Override - public void onMessage(String message, EventSource eventSource) { - publish(message); - } - - @Override - public void onComplete(StringBuilder messageBuilder) { - completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters); - } - - @Override - public void onCancelled(StringBuilder messageBuilder) { - completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters); - } + @Override + public void onComplete(StringBuilder messageBuilder) { + completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters); + } - @Override - public void onError(ErrorDetails error, Throwable ex) { - try { - completionResponseEventListener.handleError(error, ex); - } finally { - sendError(error, ex); - } - } + @Override + public void onCancelled(StringBuilder messageBuilder) { + completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters); } - private void sendInfo(GeneralSettingsState settings) { - TelemetryAction.COMPLETION.createActionMessage() - .property("conversationId", callParameters.getConversation().getId().toString()) - .property("model", callParameters.getConversation().getModel()) - .property("service", settings.getSelectedService().getCode().toLowerCase()) - .send(); + @Override + public void onError(ErrorDetails error, Throwable ex) { + try { + completionResponseEventListener.handleError(error, ex); + } finally { + sendError(error, ex); + } } private void sendError(ErrorDetails error, Throwable ex) { @@ -147,4 +118,12 @@ private void sendError(ErrorDetails error, Throwable ex) { telemetryMessage.send(); } } + + private void sendInfo(CallParameters callParameters) { + TelemetryAction.COMPLETION.createActionMessage() + .property("conversationId", callParameters.getConversation().getId().toString()) + .property("model", callParameters.getConversation().getModel()) + .property("service", GeneralSettings.getSelectedService().getCode().toLowerCase()) + .send(); + } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index f9f2a5a84..b324025e8 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -5,12 +5,12 @@ import static java.lang.String.format; import com.intellij.openapi.Disposable; +import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.project.Project; import com.intellij.ui.JBColor; import com.intellij.util.ui.JBUI; import ee.carlrobert.codegpt.CodeGPTKeys; -import ee.carlrobert.codegpt.EncodingManager; import ee.carlrobert.codegpt.ReferencedFile; import ee.carlrobert.codegpt.actions.ActionType; import ee.carlrobert.codegpt.completions.CallParameters; @@ -41,7 +41,6 @@ import java.util.UUID; import javax.swing.JComponent; import javax.swing.JPanel; -import javax.swing.SwingUtilities; import kotlin.Unit; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -111,7 +110,7 @@ public void sendMessage(Message message) { } public void sendMessage(Message message, ConversationType conversationType) { - SwingUtilities.invokeLater(() -> { + ApplicationManager.getApplication().invokeLater(() -> { var referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES); var chatToolWindowPanel = project.getService(ChatToolWindowContentManager.class) .tryFindChatToolWindowPanel(); @@ -127,6 +126,7 @@ public void sendMessage(Message message, ConversationType conversationType) { chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project)); } + totalTokensPanel.updateConversationTokens(conversation); var userMessagePanel = new UserMessagePanel(project, message, this); var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project); @@ -142,7 +142,6 @@ public void sendMessage(Message message, ConversationType conversationType) { var responsePanel = createResponsePanel(message, conversationType); messagePanel.add(responsePanel); - updateTotalTokens(message); call(callParameters, responsePanel); }); } @@ -163,12 +162,6 @@ private CallParameters getCallParameters( return callParameters; } - private void updateTotalTokens(Message message) { - int userPromptTokens = EncodingManager.getInstance().countTokens(message.getPrompt()); - int conversationTokens = EncodingManager.getInstance().countConversationTokens(conversation); - totalTokensPanel.updateConversationTokens(conversationTokens + userPromptTokens); - } - private ResponsePanel createResponsePanel(Message message, ConversationType conversationType) { return new ResponsePanel() .withReloadAction(() -> reloadMessage(message, conversation, conversationType)) diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java index e909da3d3..0d6edc6ac 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java @@ -18,7 +18,6 @@ import ee.carlrobert.codegpt.ui.OverlayUtil; import ee.carlrobert.codegpt.ui.textarea.UserInputPanel; import ee.carlrobert.llm.client.openai.completion.ErrorDetails; -import javax.swing.SwingUtilities; abstract class ToolWindowCompletionResponseEventListener implements CompletionResponseEventListener { @@ -54,17 +53,16 @@ public ToolWindowCompletionResponseEventListener( @Override public void handleMessage(String partialMessage) { try { - ApplicationManager.getApplication() - .invokeLater(() -> { - responseContainer.update(partialMessage); - messageBuilder.append(partialMessage); - - if (!completed) { - var ongoingTokens = encodingManager.countTokens(messageBuilder.toString()); - totalTokensPanel.update( - totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens); - } - }); + responseContainer.update(partialMessage); + messageBuilder.append(partialMessage); + + if (!completed) { + var ongoingTokens = encodingManager.countTokens(messageBuilder.toString()); + ApplicationManager.getApplication().invokeLater(() -> { + totalTokensPanel.update( + totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens); + }); + } } catch (Exception e) { responseContainer.displayError("Something went wrong."); throw new RuntimeException("Error while updating the content", e); @@ -73,7 +71,7 @@ public void handleMessage(String partialMessage) { @Override public void handleError(ErrorDetails error, Throwable ex) { - SwingUtilities.invokeLater(() -> { + ApplicationManager.getApplication().invokeLater(() -> { try { if ("insufficient_quota".equals(error.getCode())) { responseContainer.displayQuotaExceeded(); @@ -90,7 +88,7 @@ public void handleError(ErrorDetails error, Throwable ex) { @Override public void handleTokensExceeded(Conversation conversation, Message message) { - SwingUtilities.invokeLater(() -> { + ApplicationManager.getApplication().invokeLater(() -> { var answer = OverlayUtil.showTokenLimitExceededDialog(); if (answer == OK) { TelemetryAction.IDE_ACTION.createActionMessage() @@ -110,7 +108,7 @@ public void handleTokensExceeded(Conversation conversation, Message message) { public void handleCompleted(String fullMessage, CallParameters callParameters) { conversationService.saveMessage(fullMessage, callParameters); - SwingUtilities.invokeLater(() -> { + ApplicationManager.getApplication().invokeLater(() -> { try { responsePanel.enableActions(); totalTokensPanel.updateUserPromptTokens(textArea.getText()); @@ -123,7 +121,8 @@ public void handleCompleted(String fullMessage, CallParameters callParameters) { @Override public void handleCodeGPTEvent(CodeGPTEvent event) { - responseContainer.displayWebSearchItem(event.getEvent().getDetails()); + ApplicationManager.getApplication().invokeLater(() -> + responseContainer.displayWebSearchItem(event.getEvent().getDetails())); } private void stopStreaming(ChatMessageResponseBody responseContainer) { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java index 8131f186e..0695f192a 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java @@ -6,6 +6,7 @@ import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED; import com.intellij.openapi.Disposable; +import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.fileEditor.FileEditorManager; import com.intellij.openapi.options.ShowSettingsUtil; import com.intellij.openapi.project.Project; @@ -139,19 +140,21 @@ public void displayQuotaExceeded() { } public void displayError(String message) { - var errorText = format( - "

%s

", - message); - if (responseReceived) { - add(createTextPane(errorText, false)); - } else { - currentlyProcessedTextPane.setText(errorText); - } - hideCaret(); + ApplicationManager.getApplication().invokeLater(() -> { + var errorText = format( + "

%s

", + message); + if (responseReceived) { + add(createTextPane(errorText, false)); + } else { + currentlyProcessedTextPane.setText(errorText); + } + hideCaret(); - if (webpageListPanel != null) { - webpageListPanel.setVisible(false); - } + if (webpageListPanel != null) { + webpageListPanel.setVisible(false); + } + }); } public void displayWebSearchItem(Details details) { @@ -196,19 +199,23 @@ private void processCode(String markdownCode) { var codeBlock = ((FencedCodeBlock) child); var code = codeBlock.getContentChars().unescape(); if (!code.isEmpty()) { - if (currentlyProcessedEditorPanel == null) { - prepareProcessingCode(code, codeBlock.getInfo().unescape()); - } - EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code); + ApplicationManager.getApplication().invokeLater(() -> { + if (currentlyProcessedEditorPanel == null) { + prepareProcessingCode(code, codeBlock.getInfo().unescape()); + } + EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code); + }); } } } private void processText(String markdownText, boolean caretVisible) { - if (currentlyProcessedTextPane == null) { - prepareProcessingText(caretVisible); - } - currentlyProcessedTextPane.setText(convertMdToHtml(markdownText)); + ApplicationManager.getApplication().invokeLater(() -> { + if (currentlyProcessedTextPane == null) { + prepareProcessingText(caretVisible); + } + currentlyProcessedTextPane.setText(convertMdToHtml(markdownText)); + }); } private void prepareProcessingText(boolean caretVisible) { @@ -244,19 +251,17 @@ private JTextPane createTextPane(String text, boolean caretVisible) { } private static JPanel createWebpageListPanel(WebpageList webpageList) { - var panel = new JPanel(new BorderLayout()); - var title = new JPanel(new BorderLayout()); title.setOpaque(false); title.setBorder(JBUI.Borders.empty(8, 0)); title.add(new JBLabel(CodeGPTBundle.get("chatMessageResponseBody.webPagesTitle")) .withFont(JBUI.Fonts.miniFont()), BorderLayout.LINE_START); - panel.add(title); - var listPanel = new JPanel(new BorderLayout()); listPanel.add(webpageList, BorderLayout.LINE_START); - panel.add(listPanel); + var panel = new JPanel(new BorderLayout()); + panel.add(title); + panel.add(listPanel); return panel; } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index e111d4ae4..efcc96cfc 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -1,10 +1,13 @@ package ee.carlrobert.codegpt.completions import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.ConversationService import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.settings.GeneralSettings import ee.carlrobert.codegpt.settings.persona.DEFAULT_PROMPT import ee.carlrobert.codegpt.settings.persona.PersonaSettings +import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel import org.assertj.core.api.Assertions.assertThat import org.assertj.core.groups.Tuple @@ -13,7 +16,7 @@ import testsupport.IntegrationTest class CompletionRequestProviderTest : IntegrationTest() { fun testChatCompletionRequestWithSystemPromptOverride() { - useOpenAIService() + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" val conversation = ConversationService.getInstance().startConversation() val firstMessage = createDummyMessage(500) @@ -42,7 +45,7 @@ class CompletionRequestProviderTest : IntegrationTest() { } fun testChatCompletionRequestWithoutSystemPromptOverride() { - useOpenAIService() + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) service().state.selectedPersona.instructions = DEFAULT_PROMPT val conversation = ConversationService.getInstance().startConversation() val firstMessage = createDummyMessage(500) @@ -71,7 +74,7 @@ class CompletionRequestProviderTest : IntegrationTest() { } fun testChatCompletionRequestRetry() { - useOpenAIService() + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" val conversation = ConversationService.getInstance().startConversation() val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500) @@ -98,8 +101,9 @@ class CompletionRequestProviderTest : IntegrationTest() { } fun testReducedChatCompletionRequest() { + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) service().state.selectedPersona.instructions = DEFAULT_PROMPT - val conversation = ConversationService.getInstance().startConversation() + val conversation = Conversation() conversation.addMessage(createDummyMessage(50)) conversation.addMessage(createDummyMessage(100)) conversation.addMessage(createDummyMessage(150)) @@ -127,7 +131,7 @@ class CompletionRequestProviderTest : IntegrationTest() { } fun testTotalUsageExceededException() { - useOpenAIService() + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) val conversation = ConversationService.getInstance().startConversation() conversation.addMessage(createDummyMessage(1500)) conversation.addMessage(createDummyMessage(1500)) diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt index 1e034ea42..f282f6bd0 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt @@ -61,6 +61,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { waitExpecting { val messages = conversation.messages messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 } val encodingManager = EncodingManager.getInstance() assertThat(panel.tokenDetails).extracting( @@ -70,7 +71,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { "highlightedTokens") .containsExactly( encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countTokens(message.prompt), + encodingManager.countConversationTokens(conversation), 0, 0) assertThat(panel.conversation) @@ -146,6 +147,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { waitExpecting { val messages = conversation.messages messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 } val encodingManager = EncodingManager.getInstance() assertThat(panel.tokenDetails).extracting( @@ -155,7 +157,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { "highlightedTokens") .containsExactly( encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countTokens(message.prompt), + encodingManager.countConversationTokens(conversation), 0, 0) assertThat(panel.conversation) @@ -219,6 +221,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { waitExpecting { val messages = conversation.messages messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 } val encodingManager = EncodingManager.getInstance() assertThat(panel.tokenDetails).extracting( @@ -228,7 +231,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { "highlightedTokens") .containsExactly( encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countTokens(message.prompt), + encodingManager.countConversationTokens(conversation), 0, 0) assertThat(panel.conversation) @@ -309,6 +312,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { waitExpecting { val messages = conversation.messages messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 } val encodingManager = EncodingManager.getInstance() assertThat(panel.tokenDetails).extracting( @@ -318,7 +322,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { "highlightedTokens") .containsExactly( encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countTokens(message.prompt), + encodingManager.countConversationTokens(conversation), 0, 0) assertThat(panel.conversation) @@ -393,6 +397,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { waitExpecting { val messages = conversation.messages messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 } assertThat(panel.conversation) .isNotNull()