From 5126a8d767524b720f96138cd0ad74639425d1d9 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sat, 30 Sep 2023 19:25:31 +0200 Subject: [PATCH 1/4] Separate citations state --- api/src/stampy_chat/chat.py | 2 +- api/tests/stampy_chat/test_chat.py | 4 ++-- web/src/hooks/useSearch.ts | 6 ++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index 96e67db..a8eb69e 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -177,7 +177,7 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio top_k_blocks = get_top_k_blocks(index, query, k) yield { - "state": "loading", "phase": "semantic", + "state": "citations", "citations": [ {'title': block.title, 'author': block.authors, 'date': block.date, 'url': block.url} for block in top_k_blocks diff --git a/api/tests/stampy_chat/test_chat.py b/api/tests/stampy_chat/test_chat.py index b1b64f5..6ca58a3 100644 --- a/api/tests/stampy_chat/test_chat.py +++ b/api/tests/stampy_chat/test_chat.py @@ -270,7 +270,7 @@ def test_talk_to_robot_internal(history, context): with patch('openai.ChatCompletion.create', return_value=chunks): assert list(talk_to_robot_internal("index", "what is this about?", "default", history, 'session id')) == [ {'phase': 'semantic', 'state': 'loading'}, - {'citations': [], 'phase': 'semantic', 'state': 'loading'}, + {'citations': [], 'state': 'citations'}, {'phase': 'prompt', 'state': 'loading'}, {'phase': 'llm', 'state': 'loading'}, {'content': 'response 1', 'state': 'streaming'}, @@ -300,7 +300,7 @@ def test_talk_to_robot_internal_error(history, context): with patch('openai.ChatCompletion.create', return_value=chunks): assert list(talk_to_robot_internal("index", "what is this about?", "default", history, 'session id')) == [ {'phase': 'semantic', 'state': 'loading'}, - {'citations': [], 'phase': 'semantic', 'state': 'loading'}, + {'citations': [], 'state': 'citations'}, {'phase': 'prompt', 'state': 'loading'}, {'phase': 'llm', 'state': 'loading'}, {'content': 'response 1', 'state': 'streaming'}, diff --git a/web/src/hooks/useSearch.ts b/web/src/hooks/useSearch.ts index 919d62d..90a69e1 100644 --- a/web/src/hooks/useSearch.ts +++ b/web/src/hooks/useSearch.ts @@ -63,8 +63,10 @@ export const extractAnswer = async ( for await (var data of iterateData(res)) { switch (data.state) { case "loading": - // display loading phases, once citations are available toss them - // into the current item. + setCurrent({ phase: data.phase, ...result }); + break; + + case "citations": result = { ...result, citations: data?.citations || result?.citations || [], From 718b177bd0d3c820ea293d7f2429493eaeac5e55 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sat, 30 Sep 2023 23:36:38 +0200 Subject: [PATCH 2/4] Change method of numbering citations --- api/src/stampy_chat/followups.py | 3 ++ web/src/components/assistant.tsx | 58 ++++++++---------------- web/src/components/citations.tsx | 75 +++++++++++++++++++------------- web/src/components/entry.tsx | 8 ++-- web/src/hooks/useSearch.ts | 14 +++--- web/src/pages/index.tsx | 38 ++++++++++------ web/src/pages/semantic.tsx | 2 +- web/src/types.ts | 1 + 8 files changed, 102 insertions(+), 97 deletions(-) diff --git a/api/src/stampy_chat/followups.py b/api/src/stampy_chat/followups.py index 5afce48..cfc065c 100644 --- a/api/src/stampy_chat/followups.py +++ b/api/src/stampy_chat/followups.py @@ -23,6 +23,9 @@ def search_authored(query: str): def get_followups(query): + if not query.strip(): + return [] + url = 'https://nlp.stampy.ai/api/search?query=' + quote(query) response = requests.get(url).json() return [Followup(entry['title'], entry['pageid'], entry['score']) for entry in response] diff --git a/web/src/components/assistant.tsx b/web/src/components/assistant.tsx index 8047d82..dc292f8 100644 --- a/web/src/components/assistant.tsx +++ b/web/src/components/assistant.tsx @@ -1,49 +1,27 @@ -import { ProcessText, ShowCitation, ShowInTextCitation } from "./citations"; +import { useState } from "react"; +import { ShowCitation, CitationsBlock } from "./citations"; import { GlossarySpan } from "./glossary"; -import type { Citation, AssistantEntry } from "../types"; - -export const ShowAssistantEntry: React.FC<{entry: AssistantEntry}> = ({entry}) => { - const in_text_citation_regex = /\[([0-9]+)\]/g; - - let [response, cite_map] = ProcessText(entry.content, entry.base_count); - - // ----------------- create the ordered citation array ----------------- - - const citations = new Map(); - cite_map.forEach((value, key) => { - const index = key.charCodeAt(0) - 'a'.charCodeAt(0); - if (index >= entry.citations.length) { - console.log("invalid citation index: " + index); - } else { - citations.set(value, entry.citations[index]!); - } - }); +import type { Citation, AssistantEntry as AssistantType} from "../types"; +export const AssistantEntry: React.FC<{entry: AssistantType}> = ({entry}) => { return (
- { // split into paragraphs - response.split("\n").map(paragraph => (

{ - paragraph.split(in_text_citation_regex).map((text, i) => { - if (i % 2 === 0) { - return ; - } - i = parseInt(text) - 1; - if (!citations.has(i)) return `[${text}]`; - const citation = citations.get(i)!; - return ( - - ); - }) - }

)) - } -
    - { // show citations - Array.from(citations.entries()).map(([i, citation]) => ( -
  • - -
  • + { entry.content.split("\n").map(paragraph => ( + ()} + /> )) } +
      + { // show citations + Array.from(entry.citationsMap.values()).map(citation => ( +
    • + +
    • + )) + }
); diff --git a/web/src/components/citations.tsx b/web/src/components/citations.tsx index 2b08aa3..3aaaafa 100644 --- a/web/src/components/citations.tsx +++ b/web/src/components/citations.tsx @@ -2,9 +2,7 @@ import type { Citation } from "../types"; import { Colours, A } from "./html"; -// todo: memoize this if too slow. -export const ProcessText: (text: string, base_count: number) => [string, Map] = (text, base_count) => { - +export const formatCitations: (text: string) => string = (text) => { // ---------------------- normalize citation form ---------------------- // the general plan here is just to add parsing cases until we can respond // well to almost everything the LLM emits. We won't ever reach five nines, @@ -41,33 +39,29 @@ export const ProcessText: (text: string, base_count: number) => [string, Map `[${x}]` ) + return response; +} - // -------------- map citations from strings into numbers -------------- - +export const findCitations: (text: string, citations: Citations[]) => Map = (text, citations) => { // figure out what citations are in the response, and map them appropriately - const cite_map = new Map(); - let cite_count = 0; + const cite_map = new Map(); // scan a regex for [x] over the response. If x isn't in the map, add it. // (note: we're actually doing this twice - once on parsing, once on render. // if that looks like a problem, we could swap from strings to custom ropes). const regex = /\[([a-z]+)\]/g; let match; - let response_copy = "" - while ((match = regex.exec(response)) !== null) { - if (!cite_map.has(match[1]!)) { - cite_map.set(match[1]!, base_count + cite_count++); + while ((match = regex.exec(text)) !== null) { + const letter = match[1]; + const citation = citations[letter.charCodeAt(0) - 'a'.charCodeAt(0)] + if (!cite_map.has(letter!)) { + cite_map.set(letter!, citation); } - // replace [x] with [i] - response_copy += response.slice(response_copy.length, match.index) + `[${cite_map.get(match[1]!)! + 1}]`; } - - response = response_copy + response.slice(response_copy.length); - - return [response, cite_map] + return cite_map } -export const ShowCitation: React.FC<{citation: Citation, i: number}> = ({citation, i}) => { +export const ShowCitation: React.FC<{citation: Citation}> = ({citation}) => { var c_str = citation.title; @@ -79,25 +73,44 @@ export const ShowCitation: React.FC<{citation: Citation, i: number}> = ({citatio // if we don't have a url, link to a duckduckgo search for the title instead const url = citation.url && citation.url !== "" ? citation.url - : `https://duckduckgo.com/?q=${encodeURIComponent(citation.title)}`; + : `https://duckduckgo.com/?q=${encodeURIComponent(ndecitation.title)}`; return ( - - [{i + 1}] + [{citation.index}]

{c_str}

); }; -export const ShowInTextCitation: React.FC<{citation: Citation, i: number}> = ({citation, i}) => { - const url = citation.url && citation.url !== "" - ? citation.url - : `https://duckduckgo.com/?q=${encodeURIComponent(citation.title)}`; - return ( - - [{i + 1}] - - ); +export const CitationRef: React.FC<{citation: Citation}> = ({citation}) => { + const url = citation.url && citation.url !== "" + ? citation.url + : `https://duckduckgo.com/?q=${encodeURIComponent(citation.title)}`; + return ( + + [{citation.index}] + + ); }; + + +export const CitationsBlock: React.FC<{text: string, citations: Map, textRenderer: (t: str) => any}> = ({text, citations, textRenderer}) => { + const regex = /\[([a-z]+)\]/g; + return ( +

{ + text.split(regex).map((part, i) => { + // When splitting, the even parts are basic text sections, while the odd ones are + // citations + if (i % 2 == 0) { + return textRenderer(part) + } else { + return () + } + }) + } +

+ ) +} diff --git a/web/src/components/entry.tsx b/web/src/components/entry.tsx index 10e8eb3..d70afdc 100644 --- a/web/src/components/entry.tsx +++ b/web/src/components/entry.tsx @@ -1,11 +1,11 @@ import type { Entry as EntryType, - AssistantEntry, + AssistantEntry as AssistantEntryType, ErrorMessage, StampyMessage, UserEntry, } from "../types"; -import { ShowAssistantEntry } from "./assistant"; +import { AssistantEntry } from "./assistant"; import { GlossarySpan } from "./glossary"; import Image from "next/image"; import logo from "../logo.svg"; @@ -30,10 +30,10 @@ export const Error = ({ entry }: { entry: ErrorMessage }) => { ); }; -export const Assistant = ({ entry }: { entry: AssistantEntry }) => { +export const Assistant = ({ entry }: { entry: AssistantEntryType }) => { return (
  • - +
  • ); }; diff --git a/web/src/hooks/useSearch.ts b/web/src/hooks/useSearch.ts index 90a69e1..d614ae3 100644 --- a/web/src/hooks/useSearch.ts +++ b/web/src/hooks/useSearch.ts @@ -9,6 +9,7 @@ import type { Followup, SearchResult, } from "../types"; +import { formatCitations, findCitations } from '../components/citations'; const MAX_FOLLOWUPS = 4; const DATA_HEADER = "data: " @@ -50,14 +51,13 @@ export async function* iterateData(res: Response) { export const extractAnswer = async ( res: Response, - baseReferencesIndex: number, setCurrent: (e: CurrentSearch) => void ): Promise => { var result: AssistantEntry = { role: "assistant", content: "", citations: [], - base_count: baseReferencesIndex, + citationsMap: Map, }; var followups: Followup[] = []; for await (var data of iterateData(res)) { @@ -76,11 +76,12 @@ export const extractAnswer = async ( case "streaming": // incrementally build up the response + const content = formatCitations((result?.content || "") + data.content); result = { + content, role: "assistant", - content: (result?.content || "") + data.content, citations: result?.citations || [], - base_count: result?.base_count || baseReferencesIndex, + citationsMap: findCitations(content, result?.citations || []), }; setCurrent({ phase: "streaming", ...result }); break; @@ -120,7 +121,6 @@ export const queryLLM = async ( query: string, mode: string, history: HistoryEntry[], - baseReferencesIndex: number, setCurrent: (e?: CurrentSearch) => void, sessionId: string ): Promise => { @@ -132,7 +132,7 @@ export const queryLLM = async ( } try { - return await extractAnswer(res, baseReferencesIndex, setCurrent); + return await extractAnswer(res, setCurrent); } catch (e) { return { result: { role: "error", content: e ? e.toString() : "unknown error" }, @@ -193,7 +193,6 @@ export const runSearch = async ( query: string, query_source: "search" | "followups", mode: string, - baseReferencesIndex: number, entries: Entry[], setCurrent: (c: CurrentSearch) => void, sessionId: string @@ -210,7 +209,6 @@ export const runSearch = async ( query, mode, history, - baseReferencesIndex, setCurrent, sessionId ); diff --git a/web/src/pages/index.tsx b/web/src/pages/index.tsx index 8e01623..c04545b 100644 --- a/web/src/pages/index.tsx +++ b/web/src/pages/index.tsx @@ -6,12 +6,19 @@ import Image from 'next/image'; import Page from "../components/page" import { API_URL } from "../settings" import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; -import type { Citation, Entry, UserEntry, AssistantEntry, ErrorMessage, StampyMessage } from "../types"; +import type { + CurrentSearch, + Citation, + Entry, + UserEntry, + AssistantEntry as AssistantEntryType, + ErrorMessage, + StampyMessage +} from "../types"; import { SearchBox, Followup } from "../components/searchbox"; import { GlossarySpan } from "../components/glossary"; import { Controls, Mode } from "../components/controls"; -import { ShowAssistantEntry } from "../components/assistant"; -import { ProcessText } from "../components/citations"; +import { AssistantEntry } from "../components/assistant"; import { Entry as EntryTag } from "../components/entry"; const MAX_FOLLOWUPS = 4; @@ -24,7 +31,7 @@ type State = { citations: Citation[]; } | { state: "streaming"; - response: AssistantEntry; + response: AssistantEntryType; }; type Mode = "rookie" | "concise" | "default"; @@ -43,7 +50,6 @@ function scroll30() { const Home: NextPage = () => { const [entries, setEntries] = useState([]); - const [runningIndex, setRunningIndex] = useState(0); const [current, setCurrent] = useState(); const [sessionId, setSessionId] = useState() @@ -70,6 +76,15 @@ const Home: NextPage = () => { } }; + const updateCitations = (current: CurrentSearch) => { + const citations = Array.from(current.citationsMap.values()); + if (citations.some(c => !c.index)) { + let index = 1; + citations.forEach((c) => {c.index = index++}); + setCurrent(current) + } + } + const search = async ( query: string, query_source: "search" | "followups", @@ -89,16 +104,12 @@ const Home: NextPage = () => { query, query_source, mode[0], - runningIndex, entries, updateCurrent, sessionId, ); setCurrent(undefined); - if (query_source === "search") { - setRunningIndex(runningIndex + ProcessText(result.content, 0)[1].size); - } setEntries((prev) => [...prev, result]); enable(followups || []); scroll30(); @@ -116,12 +127,13 @@ const Home: NextPage = () => { last_entry =

    Loading: Waiting for LLM...

    ; break; case "streaming": - last_entry = ; + updateCitations(current) + last_entry = ; break; case "followups": last_entry = <> - -

    Checking for followups...

    + +

    Checking for followups...

    ; break; } @@ -135,7 +147,7 @@ const Home: NextPage = () => {