Allow chat items to be deleted #120

Merged 2 commits on Oct 18, 2023
12 changes: 1 addition & 11 deletions api/src/stampy_chat/chat.py
@@ -175,7 +175,7 @@ def make_memory(settings, history, callbacks):
return_messages=True,
callbacks=callbacks
)
-    memory.set_messages(history)
+    memory.set_messages([i for i in history if i.get('role') != 'deleted'])
Review comment (author): Make sure the LLM only sees non-deleted items.

return memory


@@ -193,16 +193,6 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
callbacks += [BroadcastCallbackHandler(callback)]
chat_model = get_model(streaming=True, callbacks=callbacks, max_tokens=settings.max_response_tokens)

-    memory = LimitedConversationSummaryBufferMemory(
-        llm=get_model(),
-        max_token_limit=settings.history_tokens,
-        max_history=settings.maxHistory,
-        chat_memory=ChatMessageHistory(),
-        return_messages=True,
-        callbacks=callbacks
-    )
-    memory.set_messages(history)

chain = LLMChain(
llm=chat_model,
verbose=False,
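For illustration, a minimal, self-contained sketch of the filter introduced above (the `history` values and the `visible` name are invented for the example; the real code passes the filtered list straight to `memory.set_messages`):

```python
# Toy data; entries marked role 'deleted' must never reach the LLM's memory.
history = [
    {'role': 'user', 'content': 'first question'},
    {'role': 'deleted', 'content': 'a message the user removed'},
    {'role': 'assistant', 'content': 'an answer'},
]

# The same list comprehension the diff adds to make_memory.
visible = [i for i in history if i.get('role') != 'deleted']

assert all(i['role'] != 'deleted' for i in visible)
print(visible)  # only the 'user' and 'assistant' entries remain
```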
16 changes: 6 additions & 10 deletions api/src/stampy_chat/db/session.py
@@ -48,17 +48,15 @@ def __init__(self, engine=None, batch_size=100, save_every=1):
self._last_save = time.time()

def commit(self):
-    with Session(self.engine) as session:
-        try:
+    try:
+        with Session(self.engine) as session:
             session.add_all(self.batch)
             session.commit()
             logger.debug('added %s items', len(self.batch))
-            self.batch = []
-        except SQLAlchemyError as e:
-            logger.warn('Got error when trying to commit to database: %s', e)
-            session.rollback()
-            raise e
-    self._last_save = time.time()
+        self.batch = []
+        self._last_save = time.time()
+    except SQLAlchemyError as e:
+        logger.warn('Got error when trying to commit to database: %s', e)

Review comment (author): There's an interesting Heisenbug with sessions being dropped. This change won't fix it, but it at least won't break the chat when it happens. It only occurs when the chatbot sits unused for a long period, which hopefully shouldn't be an issue once things start being used.
One could object that a swallowed error means those interactions never get logged. That is an issue, but not a pressing one: the session exists purely for logging, and data is only lost if the error occurs and the worker is then shut down. Otherwise the batched items are simply committed the next time something happens.

def add(self, *items):
"""Add the provided items to the database, commiting them if needed."""
@@ -69,6 +67,4 @@ def add(self, *items):

def __del__(self):
logger.debug('cleaning up session')
-    if self.session:
-        self.commit()
-        self.session.close()
+    self.commit()
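To make the reasoning in the comment above concrete, here is a minimal sketch (with invented names such as `BatchLogger` and `write_to_db`, not the project's actual classes) of why a swallowed commit error only loses data if the worker is shut down before the next write:

```python
import logging

logger = logging.getLogger(__name__)


def write_to_db(items):
    """Stand-in for the real SQLAlchemy session work; may raise."""
    ...


class BatchLogger:
    """Toy version of the batched session wrapper."""

    def __init__(self, save_every=1):
        self.batch = []
        self.save_every = save_every

    def commit(self):
        try:
            write_to_db(self.batch)
            self.batch = []  # cleared only after a successful write
        except Exception as e:
            # Swallow the error so the chat keeps working; the still-populated
            # batch is retried the next time commit() runs.
            logger.warning('commit failed, will retry on next write: %s', e)

    def add(self, *items):
        """Queue items and flush once the batch is large enough."""
        self.batch.extend(items)
        if len(self.batch) >= self.save_every:
            self.commit()
```

Because `self.batch` is only cleared on success, a failed write just defers logging until the next `add()` call.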
22 changes: 21 additions & 1 deletion api/tests/stampy_chat/test_chat.py
@@ -1,13 +1,16 @@
+from unittest.mock import patch
+from langchain.llms.fake import FakeListLLM
+from langchain.memory import ChatMessageHistory
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema import ChatMessage, HumanMessage, SystemMessage

 from stampy_chat.settings import Settings
 from stampy_chat.callbacks import StampyCallbackHandler
 from stampy_chat.chat import (
     LimitedConversationSummaryBufferMemory,
     MessageBufferPromptTemplate,
-    PrefixedPrompt
+    PrefixedPrompt,
+    make_memory,
 )


@@ -140,3 +143,20 @@ def on_memory_set_end(self, messages):
'start': history,
'end': memory.chat_memory,
}


+def test_make_memory_skips_deleted():
+    history = [
+        {'content': 'this should be kept', 'role': 'system'},
+        {'content': 'as should this', 'role': 'human'},
+        {'content': 'this will be ignored', 'role': 'deleted'},
+        {'content': 'bla bla bla', 'role': 'assistant'},
+        {'content': 'remove me!!', 'role': 'deleted'},
+    ]
+    with patch('stampy_chat.chat.get_model', return_value=FakeListLLM(responses=[])):
+        mem = make_memory(Settings(), history, [])
+    assert mem.chat_memory == ChatMessageHistory(messages=[
+        ChatMessage(content='this should be kept', role='system'),
+        ChatMessage(content='as should this', role='human'),
+        ChatMessage(content='bla bla bla', role='assistant'),
+    ])
23 changes: 20 additions & 3 deletions web/src/components/chat.tsx
@@ -164,9 +164,26 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => {

return (
<ul className="flex-auto">
-      {entries.map((entry, i) => (
-        <EntryTag entry={entry} key={i} />
-      ))}
+      {entries.map(
+        (entry, i) =>
+          !entry.deleted && (
+            <li className="group relative flex" key={i}>
+              <EntryTag entry={entry} />
+              <span
+                className="delete-item absolute right-5 hidden cursor-pointer group-hover:block"
+                onClick={() => {
+                  const entry = entries[i];
+                  if (entry !== undefined) {
+                    entry.deleted = true;
+                    setEntries([...entries]);
+                  }
+                }}
+              >
+              </span>
+            </li>
+          )
+      )}
<SearchBox search={search} onQuery={onQuery} />

{last_entry}
70 changes: 30 additions & 40 deletions web/src/components/entry.tsx
@@ -13,60 +13,50 @@ import TextareaAutosize from "react-textarea-autosize";

export const User = ({ entry }: { entry: UserEntry }) => {
return (
<li className="mt-1 mb-2 flex">
<TextareaAutosize
className="flex-1 resize-none border border-gray-300 px-1"
value={entry.content}
/>
</li>
<TextareaAutosize
className="flex-1 resize-none border border-gray-300 px-1"
value={entry.content}
/>
);
};

export const Error = ({ entry }: { entry: ErrorMessage }) => {
return (
-    <li>
-      <p className="border border-red-500 bg-red-100 px-1 text-red-800">
-        {" "}
-        {entry.content}{" "}
-      </p>
-    </li>
+    <p className="border border-red-500 bg-red-100 px-1 text-red-800">
+      {" "}
+      {entry.content}{" "}
+    </p>
);
};

export const Assistant = ({ entry }: { entry: AssistantEntryType }) => {
-  return (
-    <li>
-      <AssistantEntry entry={entry} />
-    </li>
-  );
+  return <AssistantEntry entry={entry} />;
};

export const Stampy = ({ entry }: { entry: StampyMessage }) => {
return (
-    <li>
-      <div
-        className="my-7 rounded bg-slate-500 px-4 py-0.5 text-slate-50"
-        style={{
-          marginLeft: "auto",
-          marginRight: "auto",
-          maxWidth: "99.8%",
-        }}
-      >
-        <div>
-          <GlossarySpan content={entry.content} />
-        </div>
-        <div className="mb-3 flex justify-end">
-          <a
-            href={entry.url}
-            target="_blank"
-            className="flex items-center space-x-1"
-          >
-            <span>aisafety.info</span>
-            <Image src={logo} alt="aisafety.info logo" width={19} />
-          </a>
-        </div>
-      </div>
-    </li>
+    <div
+      className="my-7 rounded bg-slate-500 px-4 py-0.5 text-slate-50"
+      style={{
+        marginLeft: "auto",
+        marginRight: "auto",
+        maxWidth: "99.8%",
+      }}
+    >
+      <div>
+        <GlossarySpan content={entry.content} />
+      </div>
+      <div className="mb-3 flex justify-end">
+        <a
+          href={entry.url}
+          target="_blank"
+          className="flex items-center space-x-1"
+        >
+          <span>aisafety.info</span>
+          <Image src={logo} alt="aisafety.info logo" width={19} />
+        </a>
+      </div>
+    </div>
);
};

5 changes: 3 additions & 2 deletions web/src/hooks/useSearch.ts
@@ -16,8 +16,9 @@ const MAX_FOLLOWUPS = 4;
const DATA_HEADER = "data: ";
const EVENT_END_HEADER = "event: close";

+type EntryRole = "error" | "stampy" | "assistant" | "user" | "deleted";
 type HistoryEntry = {
-  role: "error" | "stampy" | "assistant" | "user";
+  role: EntryRole;
content: string;
};

@@ -223,7 +224,7 @@ export const runSearch = async (
const history = entries
.filter((entry) => entry.role !== "error")
.map((entry) => ({
-      role: entry.role,
+      role: (entry.deleted ? "deleted" : entry.role) as EntryRole,
content: entry.content.trim(),
}));

4 changes: 4 additions & 0 deletions web/src/types.ts
@@ -17,24 +17,28 @@ export type Entry = UserEntry | AssistantEntry | ErrorMessage | StampyMessage;
export type UserEntry = {
role: "user";
content: string;
+  deleted?: boolean;
};

export type AssistantEntry = {
role: "assistant";
content: string;
citations: Citation[];
citationsMap: Map<string, Citation>;
+  deleted?: boolean;
};

export type ErrorMessage = {
role: "error";
content: string;
+  deleted?: boolean;
};

export type StampyMessage = {
role: "stampy";
content: string;
url: string;
+  deleted?: boolean;
};

export type SearchResult = {