diff --git a/api/tests/stampy_chat/test_chat.py b/api/tests/stampy_chat/test_chat.py index ccc9d6c..f7488ef 100644 --- a/api/tests/stampy_chat/test_chat.py +++ b/api/tests/stampy_chat/test_chat.py @@ -255,14 +255,14 @@ def test_check_openai_moderation_not_flagged(): @pytest.mark.parametrize('prompt, remaining', ( - ([{'role': 'system', 'content': 'bla'}], 4043), + ([{'role': 'system', 'content': 'bla'}], 4045), ( [ {'role': 'system', 'content': 'bla'}, {'role': 'user', 'content': 'message 1'}, {'role': 'assistant', 'content': 'response 1'}, ], - 4035 + 4037 ), ( [ diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx index ea49417..b54d265 100644 --- a/web/src/components/chat.tsx +++ b/web/src/components/chat.tsx @@ -94,16 +94,14 @@ const Chat = ({sessionId, settings, onQuery, onNewEntry}: ChatParams) => { const search = async ( query: string, query_source: 'search' | 'followups', - disable: () => void, - enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void + enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void, + controller: AbortController ) => { // clear the query box, append to entries const userEntry: Entry = { role: 'user', content: query_source === 'search' ? query : query.split('\n', 2)[1]!, } - addEntry(userEntry) - disable() const {result, followups} = await runSearch( query, @@ -111,13 +109,18 @@ const Chat = ({sessionId, settings, onQuery, onNewEntry}: ChatParams) => { settings, entries, updateCurrent, - sessionId + sessionId, + controller ) + if (result.content !== 'aborted') { + addEntry(userEntry) + addEntry(result) + enable(followups || []) + scroll30() + } else { + enable([]) + } setCurrent(undefined) - - addEntry(result) - enable(followups || []) - scroll30() } var last_entry = <> diff --git a/web/src/components/searchbox.tsx b/web/src/components/searchbox.tsx index ba1fec1..07b827c 100644 --- a/web/src/components/searchbox.tsx +++ b/web/src/components/searchbox.tsx @@ -32,8 +32,8 @@ const SearchBoxInternal: React.FC<{ search: ( query: string, query_source: 'search' | 'followups', - disable: () => void, - enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void + enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void, + controller: AbortController ) => void onQuery?: (q: string) => any }> = ({search, onQuery}) => { @@ -42,20 +42,20 @@ const SearchBoxInternal: React.FC<{ const [query, setQuery] = useState(initial_query) const [loading, setLoading] = useState(false) const [followups, setFollowups] = useState([]) + const [controller, setController] = useState(new AbortController()) const inputRef = React.useRef(null) // because everything is async, I can't just manually set state at the // point we do a search. Instead it needs to be passed into the search // method, for some reason. - const enable = (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => { - setLoading(false) - setFollowups(f_set) - } - const disable = () => { - setLoading(true) - setQuery('') - } + const enable = + (controller: AbortController) => (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => { + if (!controller.signal.aborted) setQuery('') + + setLoading(false) + setFollowups(f_set) + } useEffect(() => { // set focus on the input box @@ -70,7 +70,16 @@ const SearchBoxInternal: React.FC<{ inputRef.current.selectionEnd = inputRef.current.textLength }, []) - if (loading) return <> + const runSearch = (query: string, searchtype: 'search' | 'followups') => () => { + if (loading || query.trim() === '') return + + setLoading(true) + const controller = new AbortController() + setController(controller) + search(query, 'search', enable(controller), controller) + } + const cancelSearch = () => controller.abort() + return ( <>
@@ -80,9 +89,7 @@ const SearchBoxInternal: React.FC<{
  • @@ -91,13 +98,7 @@ const SearchBoxInternal: React.FC<{ })}
  • -
    { - e.preventDefault() - search(query, 'search', disable, enable) - }} - > +
    without , submit the form (if it's not empty) if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault() - if (query.trim() !== '') search(query, 'search', disable, enable) + runSearch(query, 'search') } }} /> - - +
    ) } diff --git a/web/src/hooks/useSearch.ts b/web/src/hooks/useSearch.ts index 3eb164b..697a328 100644 --- a/web/src/hooks/useSearch.ts +++ b/web/src/hooks/useSearch.ts @@ -21,6 +21,12 @@ type HistoryEntry = { content: string } +const ignoreAbort = (error: Error) => { + if (error.name !== 'AbortError') { + throw error + } +} + export async function* iterateData(res: Response) { const reader = res.body!.getReader() var message = '' @@ -104,9 +110,11 @@ const fetchLLM = async ( sessionId: string, query: string, settings: LLMSettings, - history: HistoryEntry[] -): Promise => + history: HistoryEntry[], + controller: AbortController +): Promise => fetch(API_URL + '/chat', { + signal: controller.signal, method: 'POST', cache: 'no-cache', keepalive: true, @@ -116,25 +124,31 @@ const fetchLLM = async ( }, body: JSON.stringify({sessionId, query, history, settings}), - }) + }).catch(ignoreAbort) export const queryLLM = async ( query: string, settings: LLMSettings, history: HistoryEntry[], setCurrent: (e?: CurrentSearch) => void, - sessionId: string + sessionId: string, + controller: AbortController ): Promise => { // do SSE on a POST request. - const res = await fetchLLM(sessionId, query, settings, history) + const res = await fetchLLM(sessionId, query, settings, history, controller) - if (!res.ok) { + if (!res) { + return {result: {role: 'error', content: 'No response from server'}} + } else if (!res.ok) { return {result: {role: 'error', content: 'POST Error: ' + res.status}} } try { return await extractAnswer(res, setCurrent) } catch (e) { + if ((e as Error)?.name === 'AbortError') { + return {result: {role: 'error', content: 'aborted'}} + } return { result: {role: 'error', content: e ? e.toString() : 'unknown error'}, } @@ -147,16 +161,22 @@ const cleanStampyContent = (contents: string) => (_, pre, linkParts, post) => `` ) -export const getStampyContent = async (questionId: string): Promise => { +export const getStampyContent = async ( + questionId: string, + controller: AbortController +): Promise => { const res = await fetch(`${STAMPY_CONTENT_URL}/${questionId}`, { method: 'GET', + signal: controller.signal, headers: { 'Content-Type': 'application/json', Accept: 'application/json', }, - }) + }).catch(ignoreAbort) - if (!res.ok) { + if (!res) { + return {result: {role: 'error', content: 'No response from server'}} + } else if (!res.ok) { return {result: {role: 'error', content: 'POST Error: ' + res.status}} } @@ -193,7 +213,8 @@ export const runSearch = async ( settings: LLMSettings, entries: Entry[], setCurrent: (c: CurrentSearch) => void, - sessionId: string + sessionId: string, + controller: AbortController ): Promise => { if (query_source === 'search') { const history = entries @@ -203,12 +224,12 @@ export const runSearch = async ( content: entry.content.trim(), })) - return await queryLLM(query, settings, history, setCurrent, sessionId) + return await queryLLM(query, settings, history, setCurrent, sessionId, controller) } else { // ----------------- HUMAN AUTHORED CONTENT RETRIEVAL ------------------ const [questionId] = query.split('\n', 2) if (questionId) { - return await getStampyContent(questionId) + return await getStampyContent(questionId, controller) } const result = { role: 'error', diff --git a/web/src/pages/semantic.tsx b/web/src/pages/semantic.tsx index ca3d77d..3ecea03 100644 --- a/web/src/pages/semantic.tsx +++ b/web/src/pages/semantic.tsx @@ -5,35 +5,42 @@ import type {Followup} from '../types' import Page from '../components/page' import {SearchBox} from '../components/searchbox' +const ignoreAbort = (error: Error) => { + if (error.name !== 'AbortError') { + throw error + } +} + const Semantic: NextPage = () => { const [results, setResults] = useState([]) const semantic_search = async ( query: string, _query_source: 'search' | 'followups', - disable: () => void, - enable: (f_set: Followup[]) => void + enable: (f_set: Followup[]) => void, + controller: AbortController ) => { - disable() - const res = await fetch(API_URL + '/semantic', { method: 'POST', + signal: controller.signal, headers: { 'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*', }, body: JSON.stringify({query: query}), - }) + }).catch(ignoreAbort) - if (!res.ok) { + if (!res) { enable([]) + return + } else if (!res.ok) { console.error('load failure: ' + res.status) } + enable([]) const data = await res.json() setResults(data) - enable([]) } return (