Support pending/loading message while waiting for response (#821)
* support pending message draft

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* styling + pending message for /fix

* change default pending message

* remove persona groups

* inline styling

* single timestamp

* use message id as component key

Co-authored-by: david qiu <[email protected]>

* fix conditional useEffect

* prefer MUI Typography in PendingMessageElement to match font size

* merge 2 outer div elements into 1

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: david qiu <[email protected]>
3 people authored Jun 18, 2024
1 parent ff022fe commit 02d1966
Showing 9 changed files with 262 additions and 34 deletions.
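
At a high level, this commit adds a `pending()` context manager (plus lower-level `start_pending()`/`close_pending()` methods) to `BaseChatHandler`: while the wrapped block runs, connected clients render an animated placeholder message, which is removed when the block exits. A minimal sketch of how a handler adopts the API — `MyChatHandler` and its `slow_call()` helper are hypothetical, not part of this commit:

from jupyter_ai.chat_handlers.base import BaseChatHandler
from jupyter_ai.models import HumanChatMessage


class MyChatHandler(BaseChatHandler):  # hypothetical subclass
    async def process_message(self, message: HumanChatMessage):
        # While this block runs, clients show a pending placeholder with an
        # animated ellipsis ("Thinking...").
        with self.pending("Thinking"):
            response = await self.slow_call(message.body)  # hypothetical helper
        self.reply(response, message)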
5 changes: 3 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -71,8 +71,9 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()

         try:
-            result = await self.llm_chain.acall({"question": query})
-            response = result["answer"]
+            with self.pending("Searching learned documents"):
+                result = await self.llm_chain.acall({"question": query})
+                response = result["answer"]
             self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)
60 changes: 59 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -1,4 +1,5 @@
 import argparse
+import contextlib
 import os
 import time
 import traceback
@@ -17,7 +18,13 @@

 from dask.distributed import Client as DaskClient
 from jupyter_ai.config_manager import ConfigManager, Logger
-from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
+from jupyter_ai.models import (
+    AgentChatMessage,
+    ChatMessage,
+    ClosePendingMessage,
+    HumanChatMessage,
+    PendingMessage,
+)
 from jupyter_ai_magics import Persona
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.pydantic_v1 import BaseModel
@@ -193,6 +200,57 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
             handler.broadcast_message(agent_msg)
             break

+    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
+        """
+        Sends a pending message to the client.
+
+        Returns the pending message.
+        """
+        persona = self.config_manager.persona
+
+        pending_msg = PendingMessage(
+            id=uuid4().hex,
+            time=time.time(),
+            body=text,
+            persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
+            ellipsis=ellipsis,
+        )
+
+        for handler in self._root_chat_handlers.values():
+            if not handler:
+                continue
+
+            handler.broadcast_message(pending_msg)
+            break
+        return pending_msg
+
+    def close_pending(self, pending_msg: PendingMessage):
+        """
+        Closes a pending message.
+        """
+        close_pending_msg = ClosePendingMessage(
+            id=pending_msg.id,
+        )
+
+        for handler in self._root_chat_handlers.values():
+            if not handler:
+                continue
+
+            handler.broadcast_message(close_pending_msg)
+            break
+
+    @contextlib.contextmanager
+    def pending(self, text: str, ellipsis: bool = True):
+        """
+        Context manager that sends a pending message to the client, and closes
+        it after the block is executed.
+        """
+        pending_msg = self.start_pending(text, ellipsis=ellipsis)
+        try:
+            yield
+        finally:
+            self.close_pending(pending_msg)
+
     def get_llm_chain(self):
         lm_provider = self.config_manager.lm_provider
         lm_provider_params = self.config_manager.lm_provider_params
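
Because `close_pending()` runs in the `finally` clause of `pending()`, the placeholder is removed even when the wrapped call raises. For illustration, a sketch of the manual `start_pending()`/`close_pending()` pairing that the context manager wraps — `respond()` and its `handler` parameter are illustrative stand-ins for a `BaseChatHandler` instance:

async def respond(handler, prompt: str) -> str:
    # Manual equivalent of `with handler.pending("Generating response"): ...`
    pending_msg = handler.start_pending("Generating response")
    try:
        # The LLM call may raise; the placeholder must still be closed.
        return await handler.llm_chain.apredict(input=prompt, stop=["\nHuman:"])
    finally:
        handler.close_pending(pending_msg)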
5 changes: 4 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -45,5 +45,8 @@ def create_llm_chain(

     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
-        response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"])
+        with self.pending("Generating response"):
+            response = await self.llm_chain.apredict(
+                input=message.body, stop=["\nHuman:"]
+            )
         self.reply(response, message)
17 changes: 9 additions & 8 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -92,12 +92,13 @@ async def process_message(self, message: HumanChatMessage):
         extra_instructions = message.body[4:].strip() or "None."

         self.get_llm_chain()
-        response = await self.llm_chain.apredict(
-            extra_instructions=extra_instructions,
-            stop=["\nHuman:"],
-            cell_content=selection.source,
-            error_name=selection.error.name,
-            error_value=selection.error.value,
-            traceback="\n".join(selection.error.traceback),
-        )
+        with self.pending("Analyzing error"):
+            response = await self.llm_chain.apredict(
+                extra_instructions=extra_instructions,
+                stop=["\nHuman:"],
+                cell_content=selection.source,
+                error_name=selection.error.name,
+                error_value=selection.error.value,
+                traceback="\n".join(selection.error.traceback),
+            )
         self.reply(response, message)
24 changes: 11 additions & 13 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -151,19 +151,17 @@ async def process_message(self, message: HumanChatMessage):
         # delete and relearn index if embedding model was changed
         await self.delete_and_relearn()

-        if args.verbose:
-            self.reply(f"Loading and splitting files for {load_path}", message)
-
-        try:
-            await self.learn_dir(
-                load_path, args.chunk_size, args.chunk_overlap, args.all_files
-            )
-        except Exception as e:
-            response = f"""Learn documents in **{load_path}** failed. {str(e)}."""
-        else:
-            self.save()
-            response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
-        You can ask questions about these docs by prefixing your message with **/ask**."""
+        with self.pending(f"Loading and splitting files for {load_path}"):
+            try:
+                await self.learn_dir(
+                    load_path, args.chunk_size, args.chunk_overlap, args.all_files
+                )
+            except Exception as e:
+                response = f"""Learn documents in **{load_path}** failed. {str(e)}."""
+            else:
+                self.save()
+                response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
+            You can ask questions about these docs by prefixing your message with **/ask**."""
         self.reply(response, message)

     def _build_list_response(self):
23 changes: 22 additions & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
@@ -87,13 +87,34 @@ class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"


+class PendingMessage(BaseModel):
+    type: Literal["pending"] = "pending"
+    id: str
+    time: float
+    body: str
+    persona: Persona
+    ellipsis: bool = True
+
+
+class ClosePendingMessage(BaseModel):
+    type: Literal["close-pending"] = "close-pending"
+    id: str
+
+
 # the type of messages being broadcast to clients
 ChatMessage = Union[
     AgentChatMessage,
     HumanChatMessage,
 ]

-Message = Union[AgentChatMessage, HumanChatMessage, ConnectionMessage, ClearMessage]
+Message = Union[
+    AgentChatMessage,
+    HumanChatMessage,
+    ConnectionMessage,
+    ClearMessage,
+    PendingMessage,
+    ClosePendingMessage,
+]


 class ChatHistory(BaseModel):
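
For reference, a sketch of what the two new message types look like on the wire when serialized with pydantic v1's `.json()` — the id, timestamp, and avatar route below are made-up values:

from jupyter_ai.models import ClosePendingMessage, PendingMessage
from jupyter_ai_magics import Persona

pending = PendingMessage(
    id="a1b2c3",        # made-up id; the server uses uuid4().hex
    time=1718700000.0,  # POSIX seconds; the frontend multiplies by 1000
    body="Generating response",
    persona=Persona(name="Jupyternaut", avatar_route="avatar.svg"),  # made-up route
)
print(pending.json())
# -> {"type": "pending", "id": "a1b2c3", ..., "ellipsis": true}

print(ClosePendingMessage(id=pending.id).json())
# -> {"type": "close-pending", "id": "a1b2c3"}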
29 changes: 22 additions & 7 deletions packages/jupyter-ai/src/components/chat.tsx
@@ -9,6 +9,7 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';

 import { JlThemeProvider } from './jl-theme-provider';
 import { ChatMessages } from './chat-messages';
+import { PendingMessages } from './pending-messages';
 import { ChatInput } from './chat-input';
 import { ChatSettings } from './chat-settings';
 import { AiService } from '../handler';
@@ -38,6 +39,9 @@ function ChatBody({
   rmRegistry: renderMimeRegistry
 }: ChatBodyProps): JSX.Element {
   const [messages, setMessages] = useState<AiService.ChatMessage[]>([]);
+  const [pendingMessages, setPendingMessages] = useState<
+    AiService.PendingMessage[]
+  >([]);
   const [showWelcomeMessage, setShowWelcomeMessage] = useState<boolean>(false);
   const [includeSelection, setIncludeSelection] = useState(true);
   const [replaceSelection, setReplaceSelection] = useState(false);
@@ -73,14 +77,24 @@ function ChatBody({
    */
   useEffect(() => {
     function handleChatEvents(message: AiService.Message) {
-      if (message.type === 'connection') {
-        return;
-      } else if (message.type === 'clear') {
-        setMessages([]);
-        return;
+      switch (message.type) {
+        case 'connection':
+          return;
+        case 'clear':
+          setMessages([]);
+          return;
+        case 'pending':
+          setPendingMessages(pendingMessages => [...pendingMessages, message]);
+          return;
+        case 'close-pending':
+          setPendingMessages(pendingMessages =>
+            pendingMessages.filter(p => p.id !== message.id)
+          );
+          return;
+        default:
+          setMessages(messageGroups => [...messageGroups, message]);
+          return;
       }
-
-      setMessages(messageGroups => [...messageGroups, message]);
     }

     chatHandler.addListener(handleChatEvents);
@@ -157,6 +171,7 @@ function ChatBody({
     <>
       <ScrollContainer sx={{ flexGrow: 1 }}>
         <ChatMessages messages={messages} rmRegistry={renderMimeRegistry} />
+        <PendingMessages messages={pendingMessages} />
       </ScrollContainer>
       <ChatInput
         value={input}
115 changes: 115 additions & 0 deletions packages/jupyter-ai/src/components/pending-messages.tsx
@@ -0,0 +1,115 @@
+import React, { useState, useEffect } from 'react';
+
+import { Box, Typography } from '@mui/material';
+import { AiService } from '../handler';
+import { ChatMessageHeader } from './chat-messages';
+
+type PendingMessagesProps = {
+  messages: AiService.PendingMessage[];
+};
+
+type PendingMessageElementProps = {
+  text: string;
+  ellipsis: boolean;
+};
+
+function PendingMessageElement(props: PendingMessageElementProps): JSX.Element {
+  const [dots, setDots] = useState('');
+
+  useEffect(() => {
+    const interval = setInterval(() => {
+      setDots(dots => (dots.length < 3 ? dots + '.' : ''));
+    }, 500);
+
+    return () => clearInterval(interval);
+  }, []);
+
+  let text = props.text;
+  if (props.ellipsis) {
+    text = props.text + dots;
+  }
+
+  return (
+    <Box>
+      {text.split('\n').map((line, index) => (
+        <Typography key={index} sx={{ lineHeight: 0.6 }}>
+          {line}
+        </Typography>
+      ))}
+    </Box>
+  );
+}
+
+export function PendingMessages(
+  props: PendingMessagesProps
+): JSX.Element | null {
+  const [timestamp, setTimestamp] = useState<string>('');
+  const [agentMessage, setAgentMessage] =
+    useState<AiService.AgentChatMessage | null>(null);
+
+  useEffect(() => {
+    if (props.messages.length === 0) {
+      setAgentMessage(null);
+      setTimestamp('');
+      return;
+    }
+    const lastMessage = props.messages[props.messages.length - 1];
+    setAgentMessage({
+      type: 'agent',
+      id: lastMessage.id,
+      time: lastMessage.time,
+      body: '',
+      reply_to: '',
+      persona: lastMessage.persona
+    });
+
+    // timestamp format copied from ChatMessage
+    const newTimestamp = new Date(lastMessage.time * 1000).toLocaleTimeString(
+      [],
+      {
+        hour: 'numeric',
+        minute: '2-digit'
+      }
+    );
+    setTimestamp(newTimestamp);
+  }, [props.messages]);
+
+  if (!agentMessage) {
+    return null;
+  }
+
+  return (
+    <Box
+      sx={{
+        padding: 4,
+        borderTop: '1px solid var(--jp-border-color2)'
+      }}
+    >
+      <ChatMessageHeader
+        message={agentMessage}
+        timestamp={timestamp}
+        sx={{
+          marginBottom: 4
+        }}
+      />
+      <Box
+        sx={{
+          marginBottom: 1,
+          paddingRight: 0,
+          color: 'var(--jp-ui-font-color2)',
+          '& > :not(:last-child)': {
+            marginBottom: '2em'
+          }
+        }}
+      >
+        {props.messages.map(message => (
+          <PendingMessageElement
+            key={message.id}
+            text={message.body}
+            ellipsis={message.ellipsis}
+          />
+        ))}
+      </Box>
+    </Box>
+  );
+}
18 changes: 17 additions & 1 deletion packages/jupyter-ai/src/handler.ts
@@ -115,12 +115,28 @@ export namespace AiService {
     type: 'clear';
   };

+  export type PendingMessage = {
+    type: 'pending';
+    id: string;
+    time: number;
+    body: string;
+    persona: Persona;
+    ellipsis: boolean;
+  };
+
+  export type ClosePendingMessage = {
+    type: 'close-pending';
+    id: string;
+  };
+
   export type ChatMessage = AgentChatMessage | HumanChatMessage;
   export type Message =
     | AgentChatMessage
     | HumanChatMessage
     | ConnectionMessage
-    | ClearMessage;
+    | ClearMessage
+    | PendingMessage
+    | ClosePendingMessage;

   export type ChatHistory = {
     messages: ChatMessage[];
