Skip to content

Commit

Permalink
Add feature to cancel current chat
Browse files Browse the repository at this point in the history
  • Loading branch information
taichimaeda committed Apr 20, 2024
1 parent 6c0246c commit c0c3561
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 48 deletions.
56 changes: 37 additions & 19 deletions src/chat/App.tsx
Original file line number Diff line number Diff line change
@@ -1,39 +1,56 @@
import { useEffect, useState } from 'react';
import { useEffect, useLayoutEffect, useRef, useState } from 'react';
import { ChatHistory, ChatRole } from 'src/api';
import Markpilot from 'src/main';
import { ChatInput } from './components/ChatBox';
import { ChatItem } from './components/ChatItem';
import { ChatView } from './view';
import { ChatFetcher, ChatView } from './view';

const systemPrompt = `
const SYSTEM_PROMPT = `
Welcome, I'm your Copilot and I'm here to help you get things done faster. You can also start an inline chat session.
I'm powered by AI, so surprises and mistakes are possible. Make sure to verify any generated code or suggestions, and share feedback so that we can learn and improve. Check out the Copilot documentation to learn more.
`;

const defaultHistory: ChatHistory = {
messages: [{ role: 'system', content: systemPrompt }],
messages: [{ role: 'system', content: SYSTEM_PROMPT }],
response: '',
};

export function App({ view, plugin }: { view: ChatView; plugin: Markpilot }) {
const [turn, setTurn] = useState<ChatRole>('system');
const [history, setHistory] = useState<ChatHistory>(defaultHistory);

export function App({
view,
fetcher,
cancel,
plugin,
}: {
view: ChatView;
fetcher: ChatFetcher;
cancel: () => void;
plugin: Markpilot;
}) {
const { settings } = plugin;

const [turn, setTurn] = useState<ChatRole>('user');
const [history, setHistory] = useState<ChatHistory>(
settings.chat.history.messages.length > 1
? settings.chat.history
: defaultHistory,
);

const bottomRef = useRef<HTMLDivElement>(null);

// Expose the method to clear history to the view
// so that the plugin command can call it.
useEffect(() => {
// Expose the method to clear history to the view
// so that the plugin command can call it.
view.clear = () => setHistory(defaultHistory);
}, []);

useEffect(() => {
if (settings.chat.history.messages.length > 1) {
setHistory(settings.chat.history);
}
setTurn('user');
}, []);
// Scroll to the bottom when chat history changes.
useLayoutEffect(() => {
bottomRef?.current?.scrollIntoView();
}, [history]);

// Save chat history to settings when it changes.
// There may be a better way to store chat history, but this works for now.
useEffect(() => {
settings.chat.history = history;
plugin.saveSettings();
Expand All @@ -42,8 +59,7 @@ export function App({ view, plugin }: { view: ChatView; plugin: Markpilot }) {
useEffect(() => {
if (turn === 'assistant') {
(async () => {
const chunks = plugin.chatClient.fetchChat(history.messages);
for await (const chunk of chunks) {
for await (const chunk of fetcher(history.messages)) {
setHistory((history) => ({
...history,
response: history.response + chunk,
Expand Down Expand Up @@ -78,12 +94,14 @@ export function App({ view, plugin }: { view: ChatView; plugin: Markpilot }) {
))}
{turn === 'assistant' && (
<ChatItem
active
message={{ role: 'assistant', content: history.response }}
/>
)}
<div ref={bottomRef} />
</div>
<div className="input-container">
<ChatInput disabled={turn === 'assistant'} submit={submit} />
<ChatInput turn={turn} cancel={cancel} submit={submit} />
</div>
</div>
);
Expand Down
33 changes: 27 additions & 6 deletions src/chat/components/ChatBox.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { SendHorizontal } from 'lucide-react';
import { CircleStop, SendHorizontal } from 'lucide-react';
import { useState } from 'react';
import { ChatRole } from 'src/api';

export function ChatInput({
disabled,
turn,
cancel,
submit,
}: {
disabled: boolean;
turn: ChatRole;
cancel: () => void;
submit: (text: string) => void;
}) {
const [value, setValue] = useState('');
Expand All @@ -18,12 +21,13 @@ export function ChatInput({
<textarea
className="input-field"
style={{ height: `${numRows + 1.5}rem` }}
disabled={disabled}
disabled={turn !== 'user'}
placeholder="Type a message..."
value={value}
onChange={(event) => setValue(event.target.value)}
onKeyDown={(event) => {
if (
turn === 'user' &&
value.trim() !== '' &&
event.key === 'Enter' &&
!event.shiftKey && // Allow newline with shift key
Expand All @@ -39,8 +43,25 @@ export function ChatInput({
className="send-button-container"
style={{ height: `${numRows + 1.5}rem` }}
>
<button className="send-button">
<SendHorizontal size={16} />
<button
className="send-button"
disabled={turn === 'user' && value.trim() === ''}
onClick={(event) => {
if (turn === 'user') {
event.preventDefault();
setValue('');
submit(value);
} else if (turn === 'assistant') {
event.preventDefault();
cancel();
}
}}
>
{turn === 'user' ? (
<SendHorizontal size={16} />
) : (
<CircleStop size={16} />
)}
</button>
</div>
</div>
Expand Down
25 changes: 23 additions & 2 deletions src/chat/components/ChatItem.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ import rehypeKatex from 'rehype-katex';
import remarkMath from 'remark-math';
import { ChatMessage } from 'src/api';

export function ChatItem({ message }: { message: ChatMessage }) {
export function ChatItem({
active,
message,
}: {
active?: boolean;
message: ChatMessage;
}) {
return (
<div
className={
Expand All @@ -13,7 +19,11 @@ export function ChatItem({ message }: { message: ChatMessage }) {
}
>
<ChatItemHeader message={message} />
<ChatItemBody message={message} />
{active && message.content === '' ? (
<ChatItemBodyTyping />
) : (
<ChatItemBody message={message} />
)}
</div>
);
}
Expand Down Expand Up @@ -44,9 +54,20 @@ function ChatItemHeader({ message }: { message: ChatMessage }) {
function ChatItemBody({ message }: { message: ChatMessage }) {
return (
<div className="markpilot-chat-item-body">
{/* TODO: Make markdown content selectable. */}
<ReactMarkdown remarkPlugins={[remarkMath]} rehypePlugins={[rehypeKatex]}>
{message.content}
</ReactMarkdown>
</div>
);
}

function ChatItemBodyTyping() {
return (
<div className="markpilot-chat-item-body-typing">
<div className="dot" />
<div className="dot" />
<div className="dot" />
</div>
);
}
14 changes: 13 additions & 1 deletion src/chat/view.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import { ItemView, WorkspaceLeaf } from 'obsidian';
import * as React from 'react';
import { createRoot, Root } from 'react-dom/client';
import { ChatMessage } from 'src/api';
import Markpilot from 'src/main';
import { App } from './App';

export const CHAT_VIEW_TYPE = 'markpilot-chat-view';

export type ChatFetcher = (
messages: ChatMessage[],
) => AsyncGenerator<string | undefined>;

export class ChatView extends ItemView {
private root: Root;
public clear?: () => void;

constructor(
leaf: WorkspaceLeaf,
private fetcher: ChatFetcher,
private cancel: () => void,
private plugin: Markpilot,
) {
super(leaf);
Expand All @@ -38,7 +45,12 @@ export class ChatView extends ItemView {
this.root = createRoot(containerEl);
this.root.render(
<React.StrictMode>
<App view={this} plugin={this.plugin} />
<App
view={this}
fetcher={this.fetcher}
cancel={this.cancel}
plugin={this.plugin}
/>
</React.StrictMode>,
);
}
Expand Down
11 changes: 3 additions & 8 deletions src/editor/extension.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import Markpilot from 'src/main';
import { debounceAsyncFunc } from '../utils';
import {
acceptCompletionsOnKeydown,
rejectCompletionsOnKeydown,
Expand All @@ -18,18 +17,14 @@ export type CompletionsForce = () => void;

export function inlineCompletionsExtension(
fetcher: CompletionsFetcher,
cancel: () => void,
force: () => void,
plugin: Markpilot,
) {
const { settings } = plugin;
const { debounced, cancel, force } = debounceAsyncFunc(
fetcher,
settings.completions.waitTime,
);

return [
completionsStateField,
completionsRenderPlugin,
showCompletionsOnUpdate(debounced, plugin),
showCompletionsOnUpdate(fetcher, plugin),
acceptCompletionsOnKeydown(force, plugin),
rejectCompletionsOnKeydown(cancel, plugin),
];
Expand Down
37 changes: 30 additions & 7 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
setIcon,
WorkspaceLeaf,
} from 'obsidian';
import { APIClient } from './api';
import { APIClient, ChatMessage } from './api';
import { OllamaAPIClient } from './api/clients/ollama';
import { OpenAIAPIClient } from './api/clients/openai';
import { OpenRouterAPIClient } from './api/clients/openrouter';
Expand All @@ -26,6 +26,7 @@ import {
MarkpilotSettingTab,
} from './settings';
import { SettingsMigrationsRunner } from './settings/runner';
import { debounceAsyncFunc, debounceAsyncGenerator } from './utils';

export default class Markpilot extends Plugin {
settings: MarkpilotSettings;
Expand Down Expand Up @@ -225,9 +226,9 @@ export default class Markpilot extends Plugin {
}

createEditorExtension() {
return inlineCompletionsExtension(async (...args) => {
// TODO:
// Extract this logic to somewhere appropriate.
const { settings } = this;

const fetcher = async (prefix: string, suffix: string) => {
const view = this.app.workspace.getActiveViewOfType(MarkdownView);
const file = view?.file;
const content = view?.editor.getValue();
Expand All @@ -244,13 +245,22 @@ export default class Markpilot extends Plugin {
) {
return;
}
return this.completionsClient.fetchCompletions(...args);
}, this);
return this.completionsClient.fetchCompletions(prefix, suffix);
};

const { debounced, cancel, force } = debounceAsyncFunc(
fetcher,
settings.completions.waitTime,
);

return inlineCompletionsExtension(debounced, cancel, force, this);
}

updateEditorExtension() {
const { workspace } = this.app;

// Clear the existing extensions and insert new ones,
// keeping the reference to the same array.
this.extensions.splice(
0,
this.extensions.length,
Expand All @@ -260,13 +270,26 @@ export default class Markpilot extends Plugin {
}

createChatView(leaf: WorkspaceLeaf) {
const view = new ChatView(leaf, this);
const fetcher = (messages: ChatMessage[]) => {
return this.chatClient.fetchChat(messages);
};
const { debounced, cancel } = debounceAsyncGenerator(fetcher, 0);

const view = new ChatView(leaf, debounced, cancel, this);
if (this.settings.chat.enabled) {
this.activateView();
}
return view;
}

updateChatView() {
if (this.settings.chat.enabled) {
this.activateView();
} else {
this.deactivateView();
}
}

async loadSettings() {
const data = await this.loadData();
if (data === null) {
Expand Down
6 changes: 1 addition & 5 deletions src/settings/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,7 @@ export class MarkpilotSettingTab extends PluginSettingTab {
toggle.setValue(settings.chat.enabled).onChange(async (value) => {
settings.chat.enabled = value;
await plugin.saveSettings();
if (value) {
plugin.activateView();
} else {
plugin.deactivateView();
}
plugin.updateChatView();
this.display(); // Re-render settings tab
}),
);
Expand Down
Loading

0 comments on commit c0c3561

Please sign in to comment.