diff --git a/api/main.py b/api/main.py index c24e81b..51b396c 100644 --- a/api/main.py +++ b/api/main.py @@ -44,11 +44,15 @@ def semantic(): @app.route('/chat', methods=['POST']) @cross_origin() def chat(): - query = request.json.get('query') + query = request.json.get('query', None) session_id = request.json.get('sessionId') history = request.json.get('history', []) settings = request.json.get('settings', {}) + if query is None: + query = history[-1].get('content') + history = history[:-1] + def formatter(item): if isinstance(item, Exception): item = {'state': 'error', 'error': str(item)} diff --git a/web/src/components/assistant.tsx b/web/src/components/assistant.tsx index 510b12a..5405656 100644 --- a/web/src/components/assistant.tsx +++ b/web/src/components/assistant.tsx @@ -5,27 +5,24 @@ import type { Citation, AssistantEntry as AssistantType } from "../types"; export const AssistantEntry: React.FC<{ entry: AssistantType }> = ({ entry, -}) => { - return ( -
- {entry.content.split("\n").map((paragraph, i) => ( - } - /> - ))} -
    - { - // show citations - Array.from(entry.citationsMap.values()).map((citation) => ( -
  • - -
  • - )) - } -
-
- ); -}; +}) => ( +
+ {entry.content.split("\n").map((paragraph, i) => ( + } + /> + ))} +
    + {entry.citationsMap && + // show citations + Array.from(entry.citationsMap.values()).map((citation) => ( +
  • + +
  • + ))} +
+
+); diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx index e6fd490..8038c32 100644 --- a/web/src/components/chat.tsx +++ b/web/src/components/chat.tsx @@ -1,5 +1,11 @@ import { useState, useEffect } from "react"; -import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; +import { + queryLLM, + getStampyContent, + EntryRole, + HistoryEntry, +} from "../hooks/useSearch"; +import { initialQuestions } from "../settings"; import type { CurrentSearch, @@ -8,6 +14,7 @@ import type { AssistantEntry as AssistantEntryType, LLMSettings, Followup, + SearchResult, } from "../types"; import useCitations from "../hooks/useCitations"; import { SearchBox } from "../components/searchbox"; @@ -41,6 +48,9 @@ function scroll30() { window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" }); } +const randomQuestion = () => + initialQuestions[Math.floor(Math.random() * initialQuestions.length)] || ""; + export const ChatResponse = ({ current, defaultElem, @@ -73,6 +83,22 @@ export const ChatResponse = ({ } }; +const makeHistory = (query: string, entries: Entry[]): HistoryEntry[] => { + const getRole = (entry: Entry): EntryRole => { + if (entry.deleted) return "deleted"; + if (entry.role === "stampy") return "assistant"; + return entry.role; + }; + + const history = entries + .filter((entry) => entry.role !== "error") + .map((entry) => ({ + role: getRole(entry), + content: entry.content.trim(), + })); + return [...history, { role: "user", content: query }]; +}; + type ChatParams = { sessionId: string; settings: LLMSettings; @@ -82,7 +108,11 @@ type ChatParams = { const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => { const [entries, setEntries] = useState([]); + + const [query, setQuery] = useState(() => randomQuestion()); const [current, setCurrent] = useState(); + const [followups, setFollowups] = useState([]); + const [controller, setController] = useState(() => new AbortController()); const { citations, setEntryCitations } = useCitations(); const updateCurrent = (current: CurrentSearch) => { @@ -94,48 +124,79 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => { } }; - const addEntry = (entry: Entry) => { + const addResult = (query: string, { result, followups }: SearchResult) => { + const userEntry = { role: "user", content: query }; setEntries((prev) => { - const entries = [...prev, entry]; + const entries = [...prev, userEntry, result] as Entry[]; if (onNewEntry) { onNewEntry(entries); } return entries; }); + setFollowups(followups || []); + setQuery(""); + scroll30(); }; - const search = async ( - query: string, - query_source: "search" | "followups", - enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void, - controller: AbortController - ) => { - // clear the query box, append to entries - const userEntry: Entry = { - role: "user", - content: query_source === "search" ? query : query.split("\n", 2)[1]!, + const abortable = + (f: any) => + (...args: any) => { + controller.abort(); + const newController = new AbortController(); + setController(newController); + return f(newController, ...args); }; - const { result, followups } = await runSearch( - query, - query_source, + const search = async (controller: AbortController, query: string) => { + // clear the query box, append to entries + setFollowups([]); + + const history = makeHistory(query, entries); + + const result = await queryLLM( settings, - entries, + history, updateCurrent, sessionId, controller ); - if (result.content !== "aborted") { - addEntry(userEntry); - addEntry(result); - enable(followups || []); - scroll30(); - } else { - enable([]); + + if (result.result.content !== "aborted") { + addResult(query, result); } setCurrent(undefined); }; + const fetchFollowup = async ( + controller: AbortController, + followup: Followup + ) => { + setCurrent({ role: "assistant", content: "", phase: "started" }); + const result = await getStampyContent(followup.pageid, controller); + if (!controller.signal.aborted) { + addResult(followup.text, result); + } + setCurrent(undefined); + }; + + const deleteEntry = (i: number) => { + const entry = entries[i]; + if (entry === undefined) { + return; + } else if ( + i === entries.length - 1 && + ["assistant", "stampy"].includes(entry.role) + ) { + const prev = entries[i - 1]; + if (prev !== undefined) setQuery(prev.content); + setEntries(entries.slice(0, i - 1)); + setFollowups([]); + } else { + entry.deleted = true; + setEntries([...entries]); + } + }; + return (
    {entries.map( @@ -145,20 +206,24 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => { { - const entry = entries[i]; - if (entry !== undefined) { - entry.deleted = true; - setEntries([...entries]); - } - }} + onClick={() => deleteEntry(i)} > ✕ ) )} - + + + { + setQuery(v); + onQuery && onQuery(v); + }} + abortSearch={() => controller.abort()} + /> { }; export default Chat; + +const Followups = ({ + followups, + onClick, +}: { + followups: Followup[]; + onClick: (f: Followup) => void; +}) => ( +
    + {followups.map((followup: Followup, i: number) => ( +
  • + +
  • + ))} +
    +); diff --git a/web/src/components/searchbox.tsx b/web/src/components/searchbox.tsx index 0b7bd61..4bae393 100644 --- a/web/src/components/searchbox.tsx +++ b/web/src/components/searchbox.tsx @@ -6,36 +6,15 @@ import TextareaAutosize from "react-textarea-autosize"; import dynamic from "next/dynamic"; const SearchBoxInternal: React.FC<{ - search: ( - query: string, - query_source: "search" | "followups", - enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void, - controller: AbortController - ) => void; - onQuery?: (q: string) => any; -}> = ({ search, onQuery }) => { - const initial_query = - initialQuestions[Math.floor(Math.random() * initialQuestions.length)] || ""; - - const [query, setQuery] = useState(initial_query); + query: string; + search: (query: string) => void; + abortSearch: () => void; + onQuery: (q: string) => any; +}> = ({ query, search, onQuery, abortSearch }) => { const [loading, setLoading] = useState(false); - const [followups, setFollowups] = useState([]); - const [controller, setController] = useState(new AbortController()); const inputRef = React.useRef(null); - // because everything is async, I can't just manually set state at the - // point we do a search. Instead it needs to be passed into the search - // method, for some reason. - const enable = - (controller: AbortController) => - (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => { - if (!controller.signal.aborted) setQuery(""); - - setLoading(false); - setFollowups(f_set); - }; - useEffect(() => { // set focus on the input box if (!loading) inputRef.current?.focus(); @@ -49,61 +28,41 @@ const SearchBoxInternal: React.FC<{ inputRef.current.selectionEnd = inputRef.current.textLength; }, []); - const runSearch = - (query: string, searchtype: "search" | "followups") => () => { - if (loading || query.trim() === "") return; + const runSearch = (query: string) => async () => { + if (loading || query.trim() === "") return; + + setLoading(true); + await search(query); + setLoading(false); + }; - setLoading(true); - const controller = new AbortController(); - setController(controller); - search(query, searchtype, enable(controller), controller); - }; - const cancelSearch = () => controller.abort(); + const cancelSearch = () => { + abortSearch(); + setLoading(false); + }; return ( <> -
    - {" "} - {followups.map((followup, i) => { - return ( -
  • - -
  • - ); - })} -
    -
    { - setQuery(e.target.value); - onQuery && onQuery(e.target.value); - }} + onChange={(e) => onQuery(e.target.value)} onKeyDown={(e) => { // if , blur the input box if (e.key === "Escape") e.currentTarget.blur(); // if without , submit the form (if it's not empty) if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); - runSearch(query, "search")(); + runSearch(query)(); } }} /> diff --git a/web/src/hooks/useSearch.ts b/web/src/hooks/useSearch.ts index 24e9c72..ff8a352 100644 --- a/web/src/hooks/useSearch.ts +++ b/web/src/hooks/useSearch.ts @@ -16,8 +16,8 @@ const MAX_FOLLOWUPS = 4; const DATA_HEADER = "data: "; const EVENT_END_HEADER = "event: close"; -type EntryRole = "error" | "stampy" | "assistant" | "user" | "deleted"; -type HistoryEntry = { +export type EntryRole = "error" | "stampy" | "assistant" | "user" | "deleted"; +export type HistoryEntry = { role: EntryRole; content: string; }; @@ -112,7 +112,6 @@ export const extractAnswer = async ( const fetchLLM = async ( sessionId: string | undefined, - query: string, settings: LLMSettings, history: HistoryEntry[], controller: AbortController @@ -127,11 +126,10 @@ const fetchLLM = async ( Accept: "text/event-stream", }, - body: JSON.stringify({ sessionId, query, history, settings }), + body: JSON.stringify({ sessionId, history, settings }), }).catch(ignoreAbort); export const queryLLM = async ( - query: string, settings: LLMSettings, history: HistoryEntry[], setCurrent: (e?: CurrentSearch) => void, @@ -140,7 +138,7 @@ export const queryLLM = async ( ): Promise => { setCurrent({ ...makeEntry(), phase: "started" }); // do SSE on a POST request. - const res = await fetchLLM(sessionId, query, settings, history, controller); + const res = await fetchLLM(sessionId, settings, history, controller); if (!res) { return { result: { role: "error", content: "No response from server" } }; @@ -214,42 +212,3 @@ export const getStampyContent = async ( return { followups, result }; }; - -export const runSearch = async ( - query: string, - query_source: "search" | "followups", - settings: LLMSettings, - entries: Entry[], - setCurrent: (c: CurrentSearch) => void, - sessionId: string, - controller: AbortController -): Promise => { - if (query_source === "search") { - const history = entries - .filter((entry) => entry.role !== "error") - .map((entry) => ({ - role: (entry.deleted ? "deleted" : entry.role) as EntryRole, - content: entry.content.trim(), - })); - - return await queryLLM( - query, - settings, - history, - setCurrent, - sessionId, - controller - ); - } else { - // ----------------- HUMAN AUTHORED CONTENT RETRIEVAL ------------------ - const [questionId] = query.split("\n", 2); - if (questionId) { - return await getStampyContent(questionId, controller); - } - const result = { - role: "error", - content: "Could not extract Stampy id from " + query, - }; - return { result } as SearchResult; - } -}; diff --git a/web/src/pages/index.tsx b/web/src/pages/index.tsx index d1f9071..17865f0 100644 --- a/web/src/pages/index.tsx +++ b/web/src/pages/index.tsx @@ -2,9 +2,7 @@ import { type NextPage } from "next"; import { useState, useEffect } from "react"; import Link from "next/link"; -import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; import useSettings from "../hooks/useSettings"; -import type { Mode } from "../types"; import Page from "../components/page"; import Chat from "../components/chat"; import { Controls } from "../components/controls"; diff --git a/web/src/pages/semantic.tsx b/web/src/pages/semantic.tsx index ac23f21..e792227 100644 --- a/web/src/pages/semantic.tsx +++ b/web/src/pages/semantic.tsx @@ -1,10 +1,13 @@ import { type NextPage } from "next"; import React, { useState } from "react"; -import { API_URL } from "../settings"; +import { API_URL, initialQuestions } from "../settings"; import type { Followup } from "../types"; import Page from "../components/page"; import { SearchBox } from "../components/searchbox"; +const randomQuestion = () => + initialQuestions[Math.floor(Math.random() * initialQuestions.length)] || ""; + const ignoreAbort = (error: Error) => { if (error.name !== "AbortError") { throw error; @@ -12,14 +15,13 @@ const ignoreAbort = (error: Error) => { }; const Semantic: NextPage = () => { + const [query, setQuery] = useState(() => randomQuestion()); + const [controller, setController] = useState(() => new AbortController()); const [results, setResults] = useState([]); - const semantic_search = async ( - query: string, - _query_source: "search" | "followups", - enable: (f_set: Followup[]) => void, - controller: AbortController - ) => { + const semantic_search = async (query: string) => { + const controller = new AbortController(); + setController(controller); const res = await fetch(API_URL + "/semantic", { method: "POST", signal: controller.signal, @@ -31,12 +33,10 @@ const Semantic: NextPage = () => { }).catch(ignoreAbort); if (!res) { - enable([]); return; } else if (!res.ok) { console.error("load failure: " + res.status); } - enable([]); const data = await res.json(); @@ -46,7 +46,12 @@ const Semantic: NextPage = () => { return (

    Retrieve relevant data sources from alignment research

    - + controller.abort()} + />
      {results.map((entry, i) => (
    • diff --git a/web/src/pages/tester.tsx b/web/src/pages/tester.tsx index 4752682..dca8ec0 100644 --- a/web/src/pages/tester.tsx +++ b/web/src/pages/tester.tsx @@ -4,7 +4,7 @@ import { useState, useEffect, useCallback } from "react"; import Page from "../components/page"; import useCitations from "../hooks/useCitations"; -import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; +import { queryLLM, getStampyContent } from "../hooks/useSearch"; import useSettings from "../hooks/useSettings"; import { initialQuestions } from "../settings"; import type { @@ -63,9 +63,8 @@ const Tester: NextPage = () => { selected, index, query: queryLLM( - question, settings, - [], + [{ role: "user", content: question }], updater(index), sessionId, controller diff --git a/web/src/types.ts b/web/src/types.ts index 537d40c..6e24438 100644 --- a/web/src/types.ts +++ b/web/src/types.ts @@ -24,7 +24,7 @@ export type AssistantEntry = { role: "assistant"; content: string; citations?: Citation[]; - citationsMap: Map; + citationsMap?: Map; deleted?: boolean; };