Skip to content

Commit

Permalink
Allow abortion of current search
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Oct 3, 2023
1 parent 3f39eac commit ac25ce7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 55 deletions.
4 changes: 2 additions & 2 deletions api/tests/stampy_chat/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,14 @@ def test_check_openai_moderation_not_flagged():


@pytest.mark.parametrize('prompt, remaining', (
([{'role': 'system', 'content': 'bla'}], 4043),
([{'role': 'system', 'content': 'bla'}], 4045),
(
[
{'role': 'system', 'content': 'bla'},
{'role': 'user', 'content': 'message 1'},
{'role': 'assistant', 'content': 'response 1'},
],
4035
4037
),
(
[
Expand Down
21 changes: 12 additions & 9 deletions web/src/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,33 @@ const Chat = ({sessionId, settings, onQuery, onNewEntry}: ChatParams) => {
const search = async (
query: string,
query_source: 'search' | 'followups',
disable: () => void,
enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void
enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void,
controller: AbortController
) => {
// clear the query box, append to entries
const userEntry: Entry = {
role: 'user',
content: query_source === 'search' ? query : query.split('\n', 2)[1]!,
}
addEntry(userEntry)
disable()

const {result, followups} = await runSearch(
query,
query_source,
settings,
entries,
updateCurrent,
sessionId
sessionId,
controller
)
if (result.content !== 'aborted') {
addEntry(userEntry)
addEntry(result)
enable(followups || [])
scroll30()
} else {
enable([])
}
setCurrent(undefined)

addEntry(result)
enable(followups || [])
scroll30()
}

var last_entry = <></>
Expand Down
55 changes: 30 additions & 25 deletions web/src/components/searchbox.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ const SearchBoxInternal: React.FC<{
search: (
query: string,
query_source: 'search' | 'followups',
disable: () => void,
enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void
enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void,
controller: AbortController
) => void
onQuery?: (q: string) => any
}> = ({search, onQuery}) => {
Expand All @@ -42,20 +42,20 @@ const SearchBoxInternal: React.FC<{
const [query, setQuery] = useState(initial_query)
const [loading, setLoading] = useState(false)
const [followups, setFollowups] = useState<Followup[]>([])
const [controller, setController] = useState(new AbortController())

const inputRef = React.useRef<HTMLTextAreaElement>(null)

// because everything is async, I can't just manually set state at the
// point we do a search. Instead it needs to be passed into the search
// method, for some reason.
const enable = (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => {
setLoading(false)
setFollowups(f_set)
}
const disable = () => {
setLoading(true)
setQuery('')
}
const enable =
(controller: AbortController) => (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => {
if (!controller.signal.aborted) setQuery('')

setLoading(false)
setFollowups(f_set)
}

useEffect(() => {
// set focus on the input box
Expand All @@ -70,7 +70,16 @@ const SearchBoxInternal: React.FC<{
inputRef.current.selectionEnd = inputRef.current.textLength
}, [])

if (loading) return <></>
const runSearch = (query: string, searchtype: 'search' | 'followups') => () => {
if (loading || query.trim() === '') return

setLoading(true)
const controller = new AbortController()
setController(controller)
search(query, 'search', enable(controller), controller)
}
const cancelSearch = () => controller.abort()

return (
<>
<div className="mt-1 flex flex-col items-end">
Expand All @@ -80,9 +89,7 @@ const SearchBoxInternal: React.FC<{
<li key={i}>
<button
className="my-1 border border-gray-300 px-1"
onClick={() => {
search(followup.pageid + '\n' + followup.text, 'followups', disable, enable)
}}
onClick={runSearch(followup.pageid + '\n' + followup.text, 'followups')}
>
<span> {followup.text} </span>
</button>
Expand All @@ -91,13 +98,7 @@ const SearchBoxInternal: React.FC<{
})}
</div>

<form
className="mt-1 mb-2 flex"
onSubmit={(e) => {
e.preventDefault()
search(query, 'search', disable, enable)
}}
>
<div className="mt-1 mb-2 flex">
<TextareaAutosize
className="flex-1 resize-none border border-gray-300 px-1"
ref={inputRef}
Expand All @@ -112,14 +113,18 @@ const SearchBoxInternal: React.FC<{
// if <enter> without <shift>, submit the form (if it's not empty)
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault()
if (query.trim() !== '') search(query, 'search', disable, enable)
runSearch(query, 'search')
}
}}
/>
<button className="ml-2" type="submit" disabled={loading}>
{loading ? 'Loading...' : 'Search'}
<button
className="ml-2"
type="button"
onClick={loading ? cancelSearch : runSearch(query, 'search')}
>
{loading ? 'Cancel' : 'Search'}
</button>
</form>
</div>
</>
)
}
Expand Down
45 changes: 33 additions & 12 deletions web/src/hooks/useSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ type HistoryEntry = {
content: string
}

const ignoreAbort = (error: Error) => {
if (error.name !== 'AbortError') {
throw error
}
}

export async function* iterateData(res: Response) {
const reader = res.body!.getReader()
var message = ''
Expand Down Expand Up @@ -104,9 +110,11 @@ const fetchLLM = async (
sessionId: string,
query: string,
settings: LLMSettings,
history: HistoryEntry[]
): Promise<Response> =>
history: HistoryEntry[],
controller: AbortController
): Promise<Response | void> =>
fetch(API_URL + '/chat', {
signal: controller.signal,
method: 'POST',
cache: 'no-cache',
keepalive: true,
Expand All @@ -116,25 +124,31 @@ const fetchLLM = async (
},

body: JSON.stringify({sessionId, query, history, settings}),
})
}).catch(ignoreAbort)

export const queryLLM = async (
query: string,
settings: LLMSettings,
history: HistoryEntry[],
setCurrent: (e?: CurrentSearch) => void,
sessionId: string
sessionId: string,
controller: AbortController
): Promise<SearchResult> => {
// do SSE on a POST request.
const res = await fetchLLM(sessionId, query, settings, history)
const res = await fetchLLM(sessionId, query, settings, history, controller)

if (!res.ok) {
if (!res) {
return {result: {role: 'error', content: 'No response from server'}}
} else if (!res.ok) {
return {result: {role: 'error', content: 'POST Error: ' + res.status}}
}

try {
return await extractAnswer(res, setCurrent)
} catch (e) {
if ((e as Error)?.name === 'AbortError') {
return {result: {role: 'error', content: 'aborted'}}
}
return {
result: {role: 'error', content: e ? e.toString() : 'unknown error'},
}
Expand All @@ -147,16 +161,22 @@ const cleanStampyContent = (contents: string) =>
(_, pre, linkParts, post) => `<a${pre}href="${STAMPY_URL}/?state=${linkParts}"${post}</a>`
)

export const getStampyContent = async (questionId: string): Promise<SearchResult> => {
export const getStampyContent = async (
questionId: string,
controller: AbortController
): Promise<SearchResult> => {
const res = await fetch(`${STAMPY_CONTENT_URL}/${questionId}`, {
method: 'GET',
signal: controller.signal,
headers: {
'Content-Type': 'application/json',
Accept: 'application/json',
},
})
}).catch(ignoreAbort)

if (!res.ok) {
if (!res) {
return {result: {role: 'error', content: 'No response from server'}}
} else if (!res.ok) {
return {result: {role: 'error', content: 'POST Error: ' + res.status}}
}

Expand Down Expand Up @@ -193,7 +213,8 @@ export const runSearch = async (
settings: LLMSettings,
entries: Entry[],
setCurrent: (c: CurrentSearch) => void,
sessionId: string
sessionId: string,
controller: AbortController
): Promise<SearchResult> => {
if (query_source === 'search') {
const history = entries
Expand All @@ -203,12 +224,12 @@ export const runSearch = async (
content: entry.content.trim(),
}))

return await queryLLM(query, settings, history, setCurrent, sessionId)
return await queryLLM(query, settings, history, setCurrent, sessionId, controller)
} else {
// ----------------- HUMAN AUTHORED CONTENT RETRIEVAL ------------------
const [questionId] = query.split('\n', 2)
if (questionId) {
return await getStampyContent(questionId)
return await getStampyContent(questionId, controller)
}
const result = {
role: 'error',
Expand Down
21 changes: 14 additions & 7 deletions web/src/pages/semantic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,42 @@ import type {Followup} from '../types'
import Page from '../components/page'
import {SearchBox} from '../components/searchbox'

const ignoreAbort = (error: Error) => {
if (error.name !== 'AbortError') {
throw error
}
}

const Semantic: NextPage = () => {
const [results, setResults] = useState<SemanticEntry[]>([])

const semantic_search = async (
query: string,
_query_source: 'search' | 'followups',
disable: () => void,
enable: (f_set: Followup[]) => void
enable: (f_set: Followup[]) => void,
controller: AbortController
) => {
disable()

const res = await fetch(API_URL + '/semantic', {
method: 'POST',
signal: controller.signal,
headers: {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*',
},
body: JSON.stringify({query: query}),
})
}).catch(ignoreAbort)

if (!res.ok) {
if (!res) {
enable([])
return
} else if (!res.ok) {
console.error('load failure: ' + res.status)
}
enable([])

const data = await res.json()

setResults(data)
enable([])
}

return (
Expand Down

0 comments on commit ac25ce7

Please sign in to comment.