diff --git a/core/core.ts b/core/core.ts index d92fdaed1f..332076aae1 100644 --- a/core/core.ts +++ b/core/core.ts @@ -30,6 +30,7 @@ import { GlobalContext } from "./util/GlobalContext"; import historyManager from "./util/history"; import { editConfigJson, setupInitialDotContinueDirectory } from "./util/paths"; import { Telemetry } from "./util/posthog"; +import { getSymbolsForManyFiles } from "./util/treeSitter"; import { TTS } from "./util/tts"; import type { ContextItemId, IDE, IndexingProgressUpdate } from "."; @@ -351,6 +352,11 @@ export class Core { } }); + on("context/getSymbolsForFiles", async (msg) => { + const { uris } = msg.data; + return await getSymbolsForManyFiles(uris, this.ide); + }); + on("config/getSerializedProfileInfo", async (msg) => { return { config: await this.configHandler.getSerializedConfig(), diff --git a/core/index.d.ts b/core/index.d.ts index 45418dc842..ce2e59449b 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -1,3 +1,4 @@ +import Parser from "web-tree-sitter"; import { GetGhTokenArgs } from "./protocol/ide"; declare global { @@ -335,13 +336,19 @@ export interface InputModifiers { noContext: boolean; } +export interface SymbolWithRange extends RangeInFile { + name: string; + type: Parser.SyntaxNode["type"]; +} + +export type FileSymbolMap = Record; + export interface PromptLog { modelTitle: string; completionOptions: CompletionOptions; prompt: string; completion: string; } - export interface ChatHistoryItem { message: ChatMessage; editorState?: any; diff --git a/core/llm/llms/Bedrock.ts b/core/llm/llms/Bedrock.ts index 01a0201eb3..c2f76d0a8b 100644 --- a/core/llm/llms/Bedrock.ts +++ b/core/llm/llms/Bedrock.ts @@ -14,6 +14,11 @@ import { import { stripImages } from "../images.js"; import { BaseLLM } from "../index.js"; +/** + * Bedrock class implements AWS Bedrock LLM integration. + * It handles streaming completions and chat messages using AWS Bedrock runtime. + * Supports both text and image inputs through Claude 3 models. + */ class Bedrock extends BaseLLM { static providerName: ModelProvider = "bedrock"; static defaultOptions: Partial = { @@ -122,7 +127,9 @@ class Bedrock extends BaseLLM { // TODO: Additionally, consider implementing a global exception handler for the providers to give users clearer feedback. // For example, differentiate between client-side errors (4XX status codes) and server-side issues (5XX status codes), // providing meaningful error messages to improve the user experience. - stopSequences: options.stop?.filter((stop) => stop.trim() !== "").slice(0, 4), + stopSequences: options.stop + ?.filter((stop) => stop.trim() !== "") + .slice(0, 4), }, }; } @@ -164,7 +171,7 @@ class Bedrock extends BaseLLM { try { return await fromIni({ profile: this.profile, - ignoreCache: true + ignoreCache: true, })(); } catch (e) { console.warn( diff --git a/core/protocol/core.ts b/core/protocol/core.ts index f6c849963d..a53ef3f2b3 100644 --- a/core/protocol/core.ts +++ b/core/protocol/core.ts @@ -7,6 +7,7 @@ import type { ContextItemWithId, ContextSubmenuItem, DiffLine, + FileSymbolMap, IdeSettings, IndexingStatusMap, LLMFullCompletionOptions, @@ -77,6 +78,7 @@ export type ToCoreFromIdeOrWebviewProtocol = { }, ContextItemWithId[], ]; + "context/getSymbolsForFiles": [{ uris: string[] }, FileSymbolMap]; "context/loadSubmenuItems": [{ title: string }, ContextSubmenuItem[]]; "autocomplete/complete": [AutocompleteInput, string[]]; "context/addDocs": [SiteIndexingConfig, void]; diff --git a/core/protocol/passThrough.ts b/core/protocol/passThrough.ts index 71e35fe658..7a13926175 100644 --- a/core/protocol/passThrough.ts +++ b/core/protocol/passThrough.ts @@ -22,6 +22,7 @@ export const WEBVIEW_TO_CORE_PASS_THROUGH: (keyof ToCoreFromWebviewProtocol)[] = "config/deleteModel", "config/reload", "context/getContextItems", + "context/getSymbolsForFiles", "context/loadSubmenuItems", "context/addDocs", "context/removeDocs", diff --git a/core/util/treeSitter.ts b/core/util/treeSitter.ts index 9118b80ee1..22a59e85c1 100644 --- a/core/util/treeSitter.ts +++ b/core/util/treeSitter.ts @@ -2,6 +2,7 @@ import fs from "node:fs"; import * as path from "node:path"; import Parser, { Language } from "web-tree-sitter"; +import { FileSymbolMap, IDE, SymbolWithRange } from ".."; export enum LanguageName { CPP = "cpp", @@ -205,3 +206,93 @@ async function loadLanguageForFileExt( ); return await Parser.Language.load(wasmPath); } + +// See https://tree-sitter.github.io/tree-sitter/using-parsers +const GET_SYMBOLS_FOR_NODE_TYPES: Parser.SyntaxNode["type"][] = [ + "class_declaration", + "class_definition", + "function_item", // function name = first "identifier" child + "function_definition", + "method_declaration", // method name = first "identifier" child + "method_definition", + "generator_function_declaration", + // property_identifier + // field_declaration + // "arrow_function", +]; + +export async function getSymbolsForFile( + filepath: string, + contents: string, +): Promise { + const parser = await getParserForFile(filepath); + + if (!parser) { + return; + } + + const tree = parser.parse(contents); + // console.log(`file: ${filepath}`); + + // Function to recursively find all named nodes (classes and functions) + const symbols: SymbolWithRange[] = []; + function findNamedNodesRecursive(node: Parser.SyntaxNode) { + // console.log(`node: ${node.type}, ${node.text}`); + if (GET_SYMBOLS_FOR_NODE_TYPES.includes(node.type)) { + // console.log(`parent: ${node.type}, ${node.text.substring(0, 200)}`); + // node.children.forEach((child) => { + // console.log(`child: ${child.type}, ${child.text}`); + // }); + + // Empirically, the actual name is the last identifier in the node + // Especially with languages where return type is declared before the name + // TODO use findLast in newer version of node target + let identifier: Parser.SyntaxNode | undefined = undefined; + for (let i = node.children.length - 1; i >= 0; i--) { + if ( + node.children[i].type === "identifier" || + node.children[i].type === "property_identifier" + ) { + identifier = node.children[i]; + break; + } + } + + if (identifier?.text) { + symbols.push({ + filepath, + type: node.type, + name: identifier.text, + range: { + start: { + character: node.startPosition.column, + line: node.startPosition.row, + }, + end: { + character: node.endPosition.column + 1, + line: node.endPosition.row + 1, + }, + }, + }); + } + } + node.children.forEach(findNamedNodesRecursive); + } + findNamedNodesRecursive(tree.rootNode); + + return symbols; +} + +export async function getSymbolsForManyFiles( + uris: string[], + ide: IDE, +): Promise { + const filesAndSymbols = await Promise.all( + uris.map(async (uri): Promise<[string, SymbolWithRange[]]> => { + const contents = await ide.readFile(uri); + const symbols = await getSymbolsForFile(uri, contents); + return [uri, symbols ?? []]; + }), + ); + return Object.fromEntries(filesAndSymbols); +} diff --git a/gui/src/components/History/HistoryTableRow.tsx b/gui/src/components/History/HistoryTableRow.tsx index 26107c199c..695389c29c 100644 --- a/gui/src/components/History/HistoryTableRow.tsx +++ b/gui/src/components/History/HistoryTableRow.tsx @@ -1,11 +1,12 @@ import { PencilSquareIcon, TrashIcon } from "@heroicons/react/24/outline"; import { SessionInfo } from "core"; import React, { useState } from "react"; -import { useDispatch } from "react-redux"; +import { useDispatch, useSelector } from "react-redux"; import { useNavigate } from "react-router-dom"; import { Input } from ".."; import useHistory from "../../hooks/useHistory"; import HeaderButtonWithToolTip from "../gui/HeaderButtonWithToolTip"; +import { RootState } from "../../redux/store"; function lastPartOfPath(path: string): string { const sep = path.includes("/") ? "/" : "\\"; @@ -29,6 +30,9 @@ export function HistoryTableRow({ const [sessionTitleEditValue, setSessionTitleEditValue] = useState( session.title, ); + const currentSessionId = useSelector( + (state: RootState) => state.state.sessionId, + ); const { saveSession, deleteSession, loadSession, getSession, updateSession } = useHistory(dispatch); @@ -59,8 +63,11 @@ export function HistoryTableRow({ className="hover:bg-vsc-editor-background relative box-border flex max-w-full cursor-pointer overflow-hidden rounded-lg p-3" onClick={async () => { // Save current session - await saveSession(); - await loadSession(session.sessionId); + if (session.sessionId !== currentSessionId) { + await saveSession(); + await loadSession(session.sessionId); + } + navigate("/"); }} > diff --git a/gui/src/components/StepContainer/StepContainer.tsx b/gui/src/components/StepContainer/StepContainer.tsx index f5673e4e8c..4fdd6eda26 100644 --- a/gui/src/components/StepContainer/StepContainer.tsx +++ b/gui/src/components/StepContainer/StepContainer.tsx @@ -93,6 +93,7 @@ export default function StepContainer(props: StepContainerProps) { )} diff --git a/gui/src/components/mainInput/CodeBlockComponent.tsx b/gui/src/components/mainInput/CodeBlockComponent.tsx index 20d700c634..51e57cae10 100644 --- a/gui/src/components/mainInput/CodeBlockComponent.tsx +++ b/gui/src/components/mainInput/CodeBlockComponent.tsx @@ -2,8 +2,6 @@ import { NodeViewWrapper } from "@tiptap/react"; import { ContextItemWithId } from "core"; import { vscBadgeBackground } from ".."; import CodeSnippetPreview from "../markdown/CodeSnippetPreview"; -import { useSelector } from "react-redux"; -import { RootState } from "../../redux/store"; export const CodeBlockComponent = (props: any) => { const { node, deleteNode, selected, editor, updateAttributes } = props; @@ -13,7 +11,7 @@ export const CodeBlockComponent = (props: any) => { // store.state.history[store.state.history.length - 1].contextItems, // ); // const isFirstContextItem = item.id === contextItems[0]?.id; - const isFirstContextItem = false; + const isFirstContextItem = false; // TODO: fix this, decided not worth the insane renders for now return ( diff --git a/gui/src/components/mainInput/ContextItemsPeek.tsx b/gui/src/components/mainInput/ContextItemsPeek.tsx index 51e0989324..56608b837b 100644 --- a/gui/src/components/mainInput/ContextItemsPeek.tsx +++ b/gui/src/components/mainInput/ContextItemsPeek.tsx @@ -9,10 +9,12 @@ import SafeImg from "../SafeImg"; import { INSTRUCTIONS_BASE_ITEM } from "core/context/providers/utils"; import { getIconFromDropdownItem } from "./MentionList"; import { getBasename } from "core/util"; +import { RootState } from "../../redux/store"; +import { useSelector } from "react-redux"; interface ContextItemsPeekProps { contextItems?: ContextItemWithId[]; - isGatheringContext: boolean; + isCurrentContextPeek: boolean; } interface ContextItemsPeekItemProps { @@ -118,7 +120,7 @@ function ContextItemsPeekItem({ contextItem }: ContextItemsPeekItemProps) { function ContextItemsPeek({ contextItems, - isGatheringContext, + isCurrentContextPeek, }: ContextItemsPeekProps) { const [open, setOpen] = useState(false); @@ -126,7 +128,16 @@ function ContextItemsPeek({ (ctxItem) => !ctxItem.name.includes(INSTRUCTIONS_BASE_ITEM.name), ); - if ((!ctxItems || ctxItems.length === 0) && !isGatheringContext) { + const isGatheringContext = useSelector( + (store: RootState) => store.state.context.isGathering, + ); + const gatheringMessage = useSelector( + (store: RootState) => store.state.context.gatheringMessage, + ); + + const indicateIsGathering = isCurrentContextPeek && isGatheringContext; + + if ((!ctxItems || ctxItems.length === 0) && !indicateIsGathering) { return null; } @@ -151,7 +162,7 @@ function ContextItemsPeek({ {isGatheringContext ? ( <> - Gathering context + {gatheringMessage} ) : ( diff --git a/gui/src/components/mainInput/ContinueInputBox.tsx b/gui/src/components/mainInput/ContinueInputBox.tsx index 4a9ea0dfd8..4fccac8feb 100644 --- a/gui/src/components/mainInput/ContinueInputBox.tsx +++ b/gui/src/components/mainInput/ContinueInputBox.tsx @@ -1,6 +1,5 @@ import { Editor, JSONContent } from "@tiptap/react"; import { ContextItemWithId, InputModifiers } from "core"; -import { useEffect, useRef, useState } from "react"; import { useDispatch, useSelector } from "react-redux"; import styled, { keyframes } from "styled-components"; import { defaultBorderRadius, vscBackground } from ".."; @@ -69,12 +68,6 @@ function ContinueInputBox(props: ContinueInputBoxProps) { const availableContextProviders = useSelector( (store: RootState) => store.state.config.contextProviders, ); - const isGatheringContextStore = useSelector( - (store: RootState) => store.state.isGatheringContext, - ); - - const [isGatheringContext, setIsGatheringContext] = useState(false); - const timeoutRef = useRef(null); useWebviewListener( "newSessionWithPrompt", @@ -92,22 +85,6 @@ function ContinueInputBox(props: ContinueInputBoxProps) { [props.isMainInput], ); - useEffect(() => { - if (isGatheringContextStore && !isGatheringContext) { - // 500ms delay when going from false -> true to prevent flashing loading indicator - timeoutRef.current = setTimeout(() => setIsGatheringContext(true), 500); - } else { - // Update immediately otherwise (i.e. true -> false) - setIsGatheringContext(isGatheringContextStore); - } - - return () => { - if (timeoutRef.current) { - clearTimeout(timeoutRef.current); - } - }; - }, [isGatheringContextStore]); - return (
@@ -135,7 +112,7 @@ function ContinueInputBox(props: ContinueInputBoxProps) {
); diff --git a/gui/src/components/mainInput/resolveInput.ts b/gui/src/components/mainInput/resolveInput.ts index 64acba85a1..b12804d724 100644 --- a/gui/src/components/mainInput/resolveInput.ts +++ b/gui/src/components/mainInput/resolveInput.ts @@ -9,6 +9,8 @@ import { } from "core"; import { stripImages } from "core/llm/images"; import { IIdeMessenger } from "../../context/IdeMessenger"; +import { Dispatch } from "@reduxjs/toolkit"; +import { setIsGatheringContext } from "../../redux/slices/stateSlice"; interface MentionAttrs { label: string; @@ -29,6 +31,7 @@ async function resolveEditorContent( modifiers: InputModifiers, ideMessenger: IIdeMessenger, defaultContextProviders: DefaultContextProvider[], + dispatch: Dispatch, ): Promise<[ContextItemWithId[], RangeInFile[], MessageContent]> { let parts: MessagePart[] = []; let contextItemAttrs: MentionAttrs[] = []; @@ -38,8 +41,7 @@ async function resolveEditorContent( if (p.type === "paragraph") { const [text, ctxItems, foundSlashCommand] = resolveParagraph(p); - // Only take the first slash command - + // Only take the first slash command\ if (foundSlashCommand && typeof slashCommand === "undefined") { slashCommand = foundSlashCommand; } @@ -103,6 +105,17 @@ async function resolveEditorContent( } } + const shouldGatherContext = modifiers.useCodebase || slashCommand; + + if (shouldGatherContext) { + dispatch( + setIsGatheringContext({ + isGathering: true, + gatheringMessage: "Gathering context", + }), + ); + } + let contextItemsText = ""; let contextItems: ContextItemWithId[] = []; for (const item of contextItemAttrs) { @@ -172,6 +185,15 @@ async function resolveEditorContent( } } + if (shouldGatherContext) { + dispatch( + setIsGatheringContext({ + isGathering: false, + gatheringMessage: "Gathering context", + }), + ); + } + return [contextItems, selectedCode, parts]; } diff --git a/gui/src/components/markdown/CodeSnippetPreview.tsx b/gui/src/components/markdown/CodeSnippetPreview.tsx index 39c09cfb66..639e14cc83 100644 --- a/gui/src/components/markdown/CodeSnippetPreview.tsx +++ b/gui/src/components/markdown/CodeSnippetPreview.tsx @@ -124,7 +124,6 @@ function CodeSnippetPreview(props: CodeSnippetPreviewProps) { source={`${fence}${getMarkdownLanguageTagForFile( props.item.description.split(" ")[0], )} ${props.item.description}\n${content}\n${fence}`} - contextItems={[props.item]} /> diff --git a/gui/src/components/markdown/FilenameLink.tsx b/gui/src/components/markdown/FilenameLink.tsx index 3674fc4183..161bd0b0af 100644 --- a/gui/src/components/markdown/FilenameLink.tsx +++ b/gui/src/components/markdown/FilenameLink.tsx @@ -20,15 +20,15 @@ function FilenameLink({ rif }: FilenameLinkProps) { } return ( -
- + {getBasename(rif.filepath)} -
+
); } diff --git a/gui/src/components/markdown/StyledMarkdownPreview.tsx b/gui/src/components/markdown/StyledMarkdownPreview.tsx index e783f8499a..ca1ea2afd2 100644 --- a/gui/src/components/markdown/StyledMarkdownPreview.tsx +++ b/gui/src/components/markdown/StyledMarkdownPreview.tsx @@ -1,4 +1,4 @@ -import { memo, useEffect } from "react"; +import { memo, useCallback, useEffect, useMemo, useRef } from "react"; import { useRemark } from "react-remark"; import rehypeHighlight, { Options } from "rehype-highlight"; import rehypeKatex from "rehype-katex"; @@ -20,7 +20,10 @@ import StepContainerPreToolbar from "./StepContainerPreToolbar"; import { SyntaxHighlightedPre } from "./SyntaxHighlightedPre"; import StepContainerPreActionButtons from "./StepContainerPreActionButtons"; import { patchNestedMarkdown } from "./utils/patchNestedMarkdown"; -import { ContextItemWithId } from "core"; +import { RootState } from "../../redux/store"; +import { ContextItemWithId, SymbolWithRange } from "core"; +import SymbolLink from "./SymbolLink"; +import { useSelector } from "react-redux"; const StyledMarkdown = styled.div<{ fontSize?: number; @@ -93,7 +96,7 @@ interface StyledMarkdownPreviewProps { className?: string; isRenderingInStepContainer?: boolean; // Currently only used to control the rendering of codeblocks scrollLocked?: boolean; - contextItems?: ContextItemWithId[]; + itemIndex?: number; } const HLJS_LANGUAGE_CLASSNAME_PREFIX = "language-"; @@ -117,12 +120,10 @@ function getCodeChildrenContent(children: any) { } else if ( Array.isArray(children) && children.length > 0 && - typeof children[0] === "string" && - children[0] !== "" + typeof children[0] === "string" ) { return children[0]; } - return undefined; } @@ -154,6 +155,28 @@ function processCodeBlocks(tree: any) { const StyledMarkdownPreview = memo(function StyledMarkdownPreview( props: StyledMarkdownPreviewProps, ) { + const contextItems = useSelector( + (state: RootState) => + state.state.history[props.itemIndex - 1]?.contextItems, + ); + const symbols = useSelector((state: RootState) => state.state.symbols); + + // The refs are a workaround because rehype options are stored on initiation + // So they won't use the most up-to-date state values + // So in this case we just put them in refs + const symbolsRef = useRef([]); + const contextItemsRef = useRef([]); + + useEffect(() => { + contextItemsRef.current = contextItems || []; + }, [contextItems]); + useEffect(() => { + // Note, before I was only looking for symbols that matched + // Context item files on current history item + // but in practice global symbols for session makes way more sense + symbolsRef.current = Object.values(symbols).flat(); + }, [symbols]); + const [reactContent, setMarkdownSource] = useRemark({ remarkPlugins: [remarkMath, () => processCodeBlocks], rehypePlugins: [ @@ -236,16 +259,30 @@ const StyledMarkdownPreview = memo(function StyledMarkdownPreview( code: ({ node, ...codeProps }) => { const content = getCodeChildrenContent(codeProps.children); - if (props.contextItems) { - const ctxItem = props.contextItems.find((ctxItem) => + if (content && contextItemsRef.current) { + const ctxItem = contextItemsRef.current.find((ctxItem) => ctxItem.uri?.value.includes(content), ); if (ctxItem) { const rif = ctxItemToRifWithContents(ctxItem); return ; } - } + const exactSymbol = symbolsRef.current.find( + (s) => s.name === content, + ); + if (exactSymbol) { + return ; + } + + // PARTIAL - this is the case where the llm returns e.g. `subtract(number)` instead of `subtract` + const partialSymbol = symbolsRef.current.find((s) => + content.startsWith(s.name), + ); + if (partialSymbol) { + return ; + } + } return {codeProps.children}; }, }, diff --git a/gui/src/components/markdown/SymbolLink.tsx b/gui/src/components/markdown/SymbolLink.tsx new file mode 100644 index 0000000000..45d54ba8f2 --- /dev/null +++ b/gui/src/components/markdown/SymbolLink.tsx @@ -0,0 +1,33 @@ +import { SymbolWithRange } from "core"; +import { useContext } from "react"; +import { IdeMessengerContext } from "../../context/IdeMessenger"; + +interface SymbolLinkProps { + symbol: SymbolWithRange; + content: string; +} + +function SymbolLink({ symbol, content }: SymbolLinkProps) { + const ideMessenger = useContext(IdeMessengerContext); + + function onClick() { + ideMessenger.post("showLines", { + filepath: symbol.filepath, + startLine: symbol.range.start.line, + endLine: symbol.range.end.line, + }); + } + + return ( + + + {content} + + + ); +} + +export default SymbolLink; diff --git a/gui/src/hooks/useChatHandler.ts b/gui/src/hooks/useChatHandler.ts index e93cdf8d36..1e8fd5234b 100644 --- a/gui/src/hooks/useChatHandler.ts +++ b/gui/src/hooks/useChatHandler.ts @@ -38,6 +38,7 @@ import { import { resetNextCodeBlockToApplyIndex } from "../redux/slices/stateSlice"; import { RootState } from "../redux/store"; import useHistory from "./useHistory"; +import { updateFileSymbolsFromContextItems } from "../util/symbols"; function useChatHandler(dispatch: Dispatch, ideMessenger: IIdeMessenger) { const posthog = usePostHog(); @@ -197,7 +198,12 @@ function useChatHandler(dispatch: Dispatch, ideMessenger: IIdeMessenger) { modifiers.useCodebase || hasSlashCommandOrContextProvider(editorState); if (shouldGatherContext) { - dispatch(setIsGatheringContext(true)); + dispatch( + setIsGatheringContext({ + isGathering: true, + gatheringMessage: "Gathering Context", + }), + ); } // Resolve context providers and construct new history @@ -207,10 +213,9 @@ function useChatHandler(dispatch: Dispatch, ideMessenger: IIdeMessenger) { modifiers, ideMessenger, defaultContextProviders, + dispatch, ); - dispatch(setIsGatheringContext(false)); - // Automatically use currently open file if (!modifiers.noContext) { const usingFreeTrial = defaultModel?.provider === "free-trial"; @@ -250,7 +255,11 @@ function useChatHandler(dispatch: Dispatch, ideMessenger: IIdeMessenger) { } } - // dispatch(addContextItems(contextItems)); + await updateFileSymbolsFromContextItems( + selectedContextItems, + ideMessenger, + dispatch, + ); const message: ChatMessage = { role: "user", diff --git a/gui/src/hooks/useHistory.tsx b/gui/src/hooks/useHistory.tsx index bb59356530..e6feec628e 100644 --- a/gui/src/hooks/useHistory.tsx +++ b/gui/src/hooks/useHistory.tsx @@ -6,7 +6,6 @@ import { useCallback, useContext } from "react"; import { useSelector } from "react-redux"; import { IdeMessengerContext } from "../context/IdeMessenger"; import { useLastSessionContext } from "../context/LastSessionContext"; -import { defaultModelSelector } from "../redux/selectors/modelSelectors"; import { newSession } from "../redux/slices/stateSlice"; import { RootState } from "../redux/store"; import { getLocalStorage, setLocalStorage } from "../util/localStorage"; @@ -22,7 +21,6 @@ function truncateText(text: string, maxLength: number) { function useHistory(dispatch: Dispatch) { const state = useSelector((state: RootState) => state.state); - const defaultModel = useSelector(defaultModelSelector); const ideMessenger = useContext(IdeMessengerContext); const { lastSessionId, setLastSessionId } = useLastSessionContext(); @@ -122,6 +120,7 @@ function useHistory(dispatch: Dispatch) { if (result.status === "error") { throw new Error(result.error); } + const sessionContent = result.content; dispatch(newSession(sessionContent)); return sessionContent; diff --git a/gui/src/hooks/useSetup.ts b/gui/src/hooks/useSetup.ts index e8aac6a30a..5b663717e0 100644 --- a/gui/src/hooks/useSetup.ts +++ b/gui/src/hooks/useSetup.ts @@ -20,9 +20,11 @@ import { isJetBrains } from "../util"; import { getLocalStorage, setLocalStorage } from "../util/localStorage"; import useChatHandler from "./useChatHandler"; import { useWebviewListener } from "./useWebviewListener"; +import { updateFileSymbolsFromContextItems } from "../util/symbols"; -function useSetup(dispatch: Dispatch) { +function useSetup(dispatch: Dispatch) { const ideMessenger = useContext(IdeMessengerContext); + const history = useSelector((store: RootState) => store.state.history); const initialConfigLoad = useRef(false); const loadConfig = useCallback(async () => { @@ -66,6 +68,20 @@ function useSetup(dispatch: Dispatch) { await loadConfig(); }); + // Load symbols for chat on any session change + const sessionId = useSelector((store: RootState) => store.state.sessionId); + const sessionIdRef = useRef(""); + useEffect(() => { + if (sessionIdRef.current !== sessionId) { + updateFileSymbolsFromContextItems( + history.flatMap((item) => item.contextItems), + ideMessenger, + dispatch, + ); + } + sessionIdRef.current = sessionId; + }, [sessionId, history, ideMessenger, dispatch]); + useEffect(() => { // Override persisted state dispatch(setInactive()); @@ -93,7 +109,6 @@ function useSetup(dispatch: Dispatch) { ); // IDE event listeners - const history = useSelector((store: RootState) => store.state.history); useWebviewListener( "getWebviewHistoryLength", async () => { diff --git a/gui/src/pages/edit/index.tsx b/gui/src/pages/edit/index.tsx index c601c96509..26fbacbb1f 100644 --- a/gui/src/pages/edit/index.tsx +++ b/gui/src/pages/edit/index.tsx @@ -174,8 +174,11 @@ function Edit() { }, ideMessenger, [], + dispatch, ); + // Note, not currently updating file symbols in edit mode + const prompt = [ ...contextItems.map((item) => item.content), stripImages(userInstructions), diff --git a/gui/src/redux/slices/stateSlice.ts b/gui/src/redux/slices/stateSlice.ts index 1e6abe0bfb..644aa9e585 100644 --- a/gui/src/redux/slices/stateSlice.ts +++ b/gui/src/redux/slices/stateSlice.ts @@ -5,6 +5,7 @@ import { ChatMessage, Checkpoint, ContextItemWithId, + FileSymbolMap, IndexingStatus, PersistedSessionInfo, PromptLog, @@ -12,8 +13,8 @@ import { import { BrowserSerializedContinueConfig } from "core/config/load"; import { ConfigValidationError } from "core/config/validation"; import { stripImages } from "core/llm/images"; -import { v4 as uuidv4, v4 } from "uuid"; import { ApplyState } from "core/protocol/ideWebview"; +import { v4 as uuidv4, v4 } from "uuid"; // We need this to handle reorderings (e.g. a mid-array deletion) of the messages array. // The proper fix is adding a UUID to all chat messages, but this is the temp workaround. @@ -22,9 +23,13 @@ type ChatHistoryItemWithMessageId = ChatHistoryItem & { }; type State = { history: ChatHistoryItemWithMessageId[]; + symbols: FileSymbolMap; + context: { + isGathering: boolean; + gatheringMessage: string; + }; ttsActive: boolean; active: boolean; - isGatheringContext: boolean; config: BrowserSerializedContinueConfig; title: string; sessionId: string; @@ -46,9 +51,13 @@ type State = { const initialState: State = { history: [], + symbols: {}, + context: { + isGathering: false, + gatheringMessage: "Gathering Context", + }, ttsActive: false, active: false, - isGatheringContext: false, configError: undefined, config: { slashCommands: [ @@ -123,8 +132,17 @@ export const stateSlice = createSlice({ setActive: (state) => { state.active = true; }, - setIsGatheringContext: (state, { payload }: PayloadAction) => { - state.isGatheringContext = payload; + setIsGatheringContext: ( + state, + { + payload, + }: PayloadAction<{ + isGathering: boolean; + gatheringMessage: string; + }>, + ) => { + state.context.isGathering = payload.isGathering; + state.context.gatheringMessage = payload.gatheringMessage; }, clearLastResponse: (state) => { if (state.history.length < 2) { @@ -137,6 +155,12 @@ export const stateSlice = createSlice({ consumeMainEditorContent: (state) => { state.mainEditorContent = undefined; }, + updateFileSymbols: (state, action: PayloadAction) => { + state.symbols = { + ...state.symbols, + ...action.payload, + }; + }, setContextItemsAtIndex: ( state, { @@ -253,10 +277,13 @@ export const stateSlice = createSlice({ if (!historyItem) { return; } - historyItem.contextItems.push(...payload.contextItems); + historyItem.contextItems = [ + ...historyItem.contextItems, + ...payload.contextItems, + ]; }, setInactive: (state) => { - state.isGatheringContext = false; + state.context.isGathering = false; state.active = false; }, abortStream: (state) => { @@ -277,8 +304,9 @@ export const stateSlice = createSlice({ state.streamAborter = new AbortController(); state.active = false; - state.isGatheringContext = false; + state.context.isGathering = false; state.isMultifileEdit = false; + state.symbols = {}; if (payload) { state.history = payload.history as any; state.title = payload.title; @@ -420,6 +448,7 @@ export const stateSlice = createSlice({ }); export const { + updateFileSymbols, setContextItemsAtIndex, addContextItemsAtIndex, setInactive, diff --git a/gui/src/util/symbols.ts b/gui/src/util/symbols.ts new file mode 100644 index 0000000000..5c4e1be127 --- /dev/null +++ b/gui/src/util/symbols.ts @@ -0,0 +1,35 @@ +import { ContextItemWithId } from "core"; +import { IIdeMessenger } from "../context/IdeMessenger"; +import { updateFileSymbols } from "../redux/slices/stateSlice"; +import { Dispatch } from "@reduxjs/toolkit"; + +export async function updateFileSymbolsFromContextItems( + contextItems: ContextItemWithId[], + ideMessenger: IIdeMessenger, + dispatch: Dispatch, +) { + // Given a list of context items, + // Get unique file uris + try { + const contextUris = Array.from( + new Set( + contextItems + .filter((item) => item.uri?.type === "file" && item?.uri?.value) + .map((item) => item.uri.value), + ), + ); + // And then update symbols for those files + if (contextUris.length > 0) { + const symbolsResult = await ideMessenger.request( + "context/getSymbolsForFiles", + { uris: contextUris }, + ); + if (symbolsResult.status === "success") { + dispatch(updateFileSymbols(symbolsResult.content)); + } + } + } catch (e) { + // Catch all - don't want file symbols to break the chat experience for now + console.error("Error updating file symbols from context items", e); + } +}