Skip to content

Commit

Permalink
Include context in user message
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroliu committed Jan 5, 2025
1 parent 5fa1a2c commit adbdfb8
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 102 deletions.
14 changes: 13 additions & 1 deletion src/components/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,15 @@ const Chat: React.FC<ChatProps> = ({
);
};

const handleSendMessage = async (toolCalls?: string[]) => {
const handleSendMessage = async ({
toolCalls,
urls,
contextNotes,
}: {
toolCalls?: string[];
urls?: string[];
contextNotes?: TFile[];
} = {}) => {
if (!inputMessage && selectedImages.length === 0) return;

const timestamp = formatDateTime(new Date());
Expand Down Expand Up @@ -146,6 +154,10 @@ const Chat: React.FC<ChatProps> = ({
isVisible: true,
timestamp: timestamp,
content: content,
context: {
notes: contextNotes || [],
urls: urls || [],
},
};

// Clear input and images
Expand Down
80 changes: 20 additions & 60 deletions src/components/chat-components/ChatInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ import { TooltipActionButton } from "./TooltipActionButton";
interface ChatInputProps {
inputMessage: string;
setInputMessage: (message: string) => void;
handleSendMessage: (toolCalls?: string[]) => void;
handleSendMessage: (metadata?: {
toolCalls?: string[];
urls?: string[];
contextNotes?: TFile[];
}) => void;
isGenerating: boolean;
onStopGenerating: () => void;
app: App;
Expand Down Expand Up @@ -86,52 +90,18 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
},
}));

const debounce = <T extends (...args: any[]) => any>(
fn: T,
delay: number
): ((...args: Parameters<T>) => void) => {
let timeoutId: NodeJS.Timeout;
return (...args: Parameters<T>) => {
clearTimeout(timeoutId);
timeoutId = setTimeout(() => fn(...args), delay);
};
};
const onSendMessage = (includeVault: boolean) => {
if (currentChain !== ChainType.COPILOT_PLUS_CHAIN) {
handleSendMessage();
return;
}

// Debounce the context update to prevent excessive re-renders
const debouncedUpdateContext = debounce(
async (
inputValue: string,
setContextNotes: React.Dispatch<React.SetStateAction<TFile[]>>,
currentContextNotes: TFile[],
app: App
) => {
const noteTitles = extractNoteTitles(inputValue);

const notesToAdd = await Promise.all(
noteTitles.map(async (title) => {
const files = app.vault.getMarkdownFiles();
const file = files.find((file) => file.basename === title);
if (file) {
return Object.assign(file, { wasAddedViaReference: true }) as TFile & {
wasAddedViaReference: boolean;
};
}
return undefined;
})
);

const validNotes = notesToAdd.filter(
(note): note is TFile & { wasAddedViaReference: boolean } =>
note !== undefined &&
!currentContextNotes.some((existing) => existing.path === note.path)
);

if (validNotes.length > 0) {
setContextNotes((prev) => [...prev, ...validNotes]);
}
},
50
);
handleSendMessage({
toolCalls: includeVault ? ["@vault"] : [],
contextNotes: contextNotes,
urls: contextUrls,
});
};

const handleInputChange = async (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const inputValue = event.target.value;
Expand All @@ -150,9 +120,6 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
setContextUrls((prev) => Array.from(new Set([...prev, ...newUrls])));
}

// Update context with debouncing
debouncedUpdateContext(inputValue, setContextNotes, contextNotes, app);

// Handle other input triggers
if (cursorPos >= 2 && inputValue.slice(cursorPos - 2, cursorPos) === "[[") {
showNoteTitleModal(cursorPos);
Expand Down Expand Up @@ -185,9 +152,6 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
const newInputMessage = `${before}[[${noteTitle}]]${after}`;
setInputMessage(newInputMessage);

// Manually invoke debouncedUpdateContext
debouncedUpdateContext(newInputMessage, setContextNotes, contextNotes, app);

const activeNote = app.workspace.getActiveFile();
const noteFile = app.vault.getMarkdownFiles().find((file) => file.basename === noteTitle);

Expand Down Expand Up @@ -261,19 +225,15 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
e.preventDefault();
e.stopPropagation();

if (currentChain === ChainType.COPILOT_PLUS_CHAIN) {
handleSendMessage(["@vault"]);
} else {
handleSendMessage();
}
onSendMessage(true);
setHistoryIndex(-1);
setTempInput("");
return;
}

if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
handleSendMessage();
onSendMessage(false);
setHistoryIndex(-1);
setTempInput("");
} else if (e.key === "ArrowUp") {
Expand Down Expand Up @@ -491,13 +451,13 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
<StopCircle />
</button>
)}
<button onClick={() => handleSendMessage()} className="submit-button">
<button onClick={() => onSendMessage(false)} className="submit-button">
<CornerDownLeft size={16} />
<span>chat</span>
</button>

{currentChain === "copilot_plus" && (
<button onClick={() => handleSendMessage(["@vault"])} className="submit-button vault">
<button onClick={() => onSendMessage(true)} className="submit-button vault">
<div className="button-content">
{Platform.isMacOS ? (
<>
Expand Down
48 changes: 44 additions & 4 deletions src/components/chat-components/ChatSingleMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,40 @@ import { ChatMessage } from "@/sharedState";
import { Bot, User } from "lucide-react";
import { App, Component, MarkdownRenderer } from "obsidian";
import React, { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { Badge } from "@/components/ui/badge";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";

function MessageContext({ context }: { context: ChatMessage["context"] }) {
if (!context || (context.notes.length === 0 && context.urls.length === 0)) {
return null;
}

return (
<div className="flex gap-2 flex-wrap">
{context.notes.map((note) => (
<Tooltip key={note.path}>
<TooltipTrigger asChild>
<Badge variant="secondary">
<span className="max-w-40 truncate">{note.basename}</span>
</Badge>
</TooltipTrigger>
<TooltipContent>{note.path}</TooltipContent>
</Tooltip>
))}
{context.urls.map((url) => (
<Tooltip key={url}>
<TooltipTrigger asChild>
<Badge variant="secondary">
<span className="max-w-40 truncate">{url}</span>
</Badge>
</TooltipTrigger>
<TooltipContent>{url}</TooltipContent>
</Tooltip>
))}
</div>
);
}

interface ChatSingleMessageProps {
message: ChatMessage;
Expand Down Expand Up @@ -190,10 +224,16 @@ const ChatSingleMessage: React.FC<ChatSingleMessageProps> = ({
};

return (
<div className="chat-message-container">
<div className={`message ${message.sender === USER_SENDER ? "user-message" : "bot-message"}`}>
<div className="message-icon">{message.sender === USER_SENDER ? <User /> : <Bot />}</div>
<div className="message-content-wrapper">
<div className="flex flex-col w-full mb-1">
<div
className={cn(
"flex rounded-md p-2 mx-2 gap-2",
message.sender === USER_SENDER && "bg-primary-alt"
)}
>
<div className="w-6 shrink-0">{message.sender === USER_SENDER ? <User /> : <Bot />}</div>
<div className="flex flex-col flex-grow max-w-full gap-2">
<MessageContext context={message.context} />
<div className="message-content">{renderMessageContent()}</div>

{!isStreaming && (
Expand Down
5 changes: 5 additions & 0 deletions src/sharedState.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { useEffect, useState } from "react";
import { FormattedDateTime } from "./utils";
import { TFile } from "obsidian";

export interface ChatMessage {
message: string;
Expand All @@ -9,6 +10,10 @@ export interface ChatMessage {
isVisible: boolean;
sources?: { title: string; score: number }[];
content?: any[];
context?: {
notes: TFile[];
urls: string[];
};
}

class SharedState {
Expand Down
36 changes: 0 additions & 36 deletions src/styles/tailwind.css
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,6 @@ If your plugin does not need CSS, delete this file.
height: 12px;
}

.chat-message-container {
display: flex;
flex-direction: column;
width: 100%;
margin-bottom: 8px;
}

.message {
display: flex;
padding: 0;
Expand All @@ -479,20 +472,6 @@ If your plugin does not need CSS, delete this file.
margin-bottom: 0;
}

.message-icon {
width: 24px;
margin-right: 8px;
color: var(--inline-title-color);
flex-shrink: 0;
}

.message-content-wrapper {
display: flex;
flex-direction: column;
flex-grow: 1;
max-width: 100%;
}

.message-content {
word-wrap: break-word;
overflow-wrap: break-word;
Expand Down Expand Up @@ -636,21 +615,6 @@ If your plugin does not need CSS, delete this file.
height: 14px;
}

.user-message {
white-space: pre-wrap;
width: 95%;
color: var(--inline-title-color);
background-color: var(--background-primary-alt);
border-radius: 2px;
padding: 14px 10px 10px 10px;
}

.bot-message {
width: 95%;
border-radius: 2px;
padding: 14px 10px 10px 10px;
}

.copilot-command-modal {
display: flex;
flex-direction: column;
Expand Down
3 changes: 2 additions & 1 deletion tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
"types": ["jest"],
"lib": ["DOM", "ES5", "ES6", "ES7", "ES2022"]
},
"include": ["**/*.ts", "src", "typings/**/*.ts"]
"include": ["**/*.ts", "src", "typings/**/*.ts"],
"exclude": ["node_modules"]
}

0 comments on commit adbdfb8

Please sign in to comment.