From b4a072cc8e36bec3f4493f60aed015b2bd31d9ae Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 1 Oct 2023 16:49:07 +0200 Subject: [PATCH 1/4] extract chat component --- web/src/components/chat.tsx | 151 +++++++++++++++++++++++++++++++ web/src/components/controls.tsx | 2 +- web/src/pages/index.tsx | 155 ++------------------------------ web/src/types.ts | 2 + 4 files changed, 159 insertions(+), 151 deletions(-) create mode 100644 web/src/components/chat.tsx diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx new file mode 100644 index 0000000..7701524 --- /dev/null +++ b/web/src/components/chat.tsx @@ -0,0 +1,151 @@ +import { useState, useEffect } from "react"; +import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; +import type { + CurrentSearch, + Citation, + Entry, + AssistantEntry as AssistantEntryType, + Mode, + Followup, +} from "../types"; +import { SearchBox } from "../components/searchbox"; +import { AssistantEntry } from "../components/assistant"; +import { Entry as EntryTag } from "../components/entry"; + +const MAX_FOLLOWUPS = 4; + +type State = + | { + state: "idle"; + } + | { + state: "loading"; + phase: "semantic" | "prompt" | "llm"; + citations: Citation[]; + } + | { + state: "streaming"; + response: AssistantEntryType; + }; + +// smooth-scroll to the bottom of the window if we're already less than 30% a screen away +// note: finicky interaction with "smooth" - maybe fix later. +function scroll30() { + if ( + document.documentElement.scrollHeight - window.scrollY > + window.innerHeight * 1.3 + ) + return; + window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" }); +} + +const Chat = ({ sessionId, mode }: { sessionId: string; mode: Mode }) => { + const [entries, setEntries] = useState([]); + const [current, setCurrent] = useState(); + const [citations, setCitations] = useState([]); + + const updateCurrent = (current: CurrentSearch) => { + setCurrent(current); + if (current?.phase === "streaming") { + scroll30(); + } + }; + + const updateCitations = ( + allCitations: Citation[], + current?: CurrentSearch + ) => { + if (!current) return; + + const entryCitations = Array.from(current.citationsMap.values()); + if (!entryCitations.some((c) => !c.index)) { + // All of the entries citations have indexes, so there weren't any changes since the last check + return; + } + + // Get a mapping of all known citations, so as to reuse them if they appear again + const citationsMapping = Object.fromEntries( + allCitations.map((c) => [c.title + c.url, c.index]) + ); + + entryCitations.forEach((c) => { + const hash = c.title + c.url; + const index = citationsMapping[hash]; + if (!index) { + c.index = allCitations.length + 1; + allCitations.push(c); + } else { + c.index = index; + } + }); + setCitations(allCitations); + setCurrent(current); + }; + + const search = async ( + query: string, + query_source: "search" | "followups", + disable: () => void, + enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void + ) => { + // clear the query box, append to entries + const userEntry: Entry = { + role: "user", + content: query_source === "search" ? 
query : query.split("\n", 2)[1]!, + }; + setEntries((prev) => [...prev, userEntry]); + disable(); + + const { result, followups } = await runSearch( + query, + query_source, + mode, + entries, + updateCurrent, + sessionId + ); + setCurrent(undefined); + + setEntries((prev) => [...prev, result]); + enable(followups || []); + scroll30(); + }; + + var last_entry = <>; + switch (current?.phase) { + case "semantic": + last_entry =

+        <p>Loading: Performing semantic search...</p>;
+      break;
+    case "prompt":
+      last_entry = <p>Loading: Creating prompt...</p>;
+      break;
+    case "llm":
+      last_entry = <p>Loading: Waiting for LLM...</p>;
+      break;
+    case "streaming":
+      updateCitations(citations, current);
+      last_entry = <AssistantEntry entry={current} />;
+      break;
+    case "followups":
+      last_entry = (
+        <>
+          <AssistantEntry entry={current} />
+          <p>Checking for followups...</p>
+        </>
+      );
+      break;
+  }
+
+  return (
    + {entries.map((entry, i) => ( + + ))} + + + {last_entry} +
+ ); +}; + +export default Chat; diff --git a/web/src/components/controls.tsx b/web/src/components/controls.tsx index d7fc17c..a92ae63 100644 --- a/web/src/components/controls.tsx +++ b/web/src/components/controls.tsx @@ -1,4 +1,4 @@ -export type Mode = "rookie" | "concise" | "default"; +import type { Mode } from "../types"; export const Controls = ({ mode, diff --git a/web/src/pages/index.tsx b/web/src/pages/index.tsx index aa584f4..e34d029 100644 --- a/web/src/pages/index.tsx +++ b/web/src/pages/index.tsx @@ -1,61 +1,17 @@ import { type NextPage } from "next"; import { useState, useEffect } from "react"; import Link from "next/link"; -import Image from "next/image"; -import Page from "../components/page"; -import { API_URL } from "../settings"; import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; -import type { - CurrentSearch, - Citation, - Entry, - UserEntry, - AssistantEntry as AssistantEntryType, - ErrorMessage, - StampyMessage, - Followup, -} from "../types"; -import { SearchBox } from "../components/searchbox"; -import { GlossarySpan } from "../components/glossary"; -import { Controls, Mode } from "../components/controls"; -import { AssistantEntry } from "../components/assistant"; -import { Entry as EntryTag } from "../components/entry"; +import type { Mode } from "../types"; +import Page from "../components/page"; +import Chat from "../components/chat"; +import { Controls } from "../components/controls"; const MAX_FOLLOWUPS = 4; -type State = - | { - state: "idle"; - } - | { - state: "loading"; - phase: "semantic" | "prompt" | "llm"; - citations: Citation[]; - } - | { - state: "streaming"; - response: AssistantEntryType; - }; - -// smooth-scroll to the bottom of the window if we're already less than 30% a screen away -// note: finicky interaction with "smooth" - maybe fix later. 
-function scroll30() { - if ( - document.documentElement.scrollHeight - window.scrollY > - window.innerHeight * 1.3 - ) - return; - window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" }); -} - const Home: NextPage = () => { - const [entries, setEntries] = useState([]); - const [current, setCurrent] = useState(); const [sessionId, setSessionId] = useState(""); - const [citations, setCitations] = useState([]); - - // [state, ready to save to localstorage] const [mode, setMode] = useState<[Mode, boolean]>(["default", false]); // store mode in localstorage @@ -70,100 +26,6 @@ const Home: NextPage = () => { setSessionId(crypto.randomUUID()); }, []); - const updateCurrent = (current: CurrentSearch) => { - setCurrent(current); - if (current?.phase === "streaming") { - scroll30(); - } - }; - - const updateCitations = ( - allCitations: Citation[], - current: CurrentSearch - ) => { - if (!current) return; - - const entryCitations = Array.from( - current.citationsMap.values() - ) as Citation[]; - if (!entryCitations.some((c) => !c.index)) { - // All of the entries citations have indexes, so there weren't any changes since the last check - return; - } - - // Get a mapping of all known citations, so as to reuse them if they appear again - const citationsMapping = Object.fromEntries( - allCitations.map((c) => [c.title + c.url, c.index]) - ); - - entryCitations.forEach((c) => { - const hash = c.title + c.url; - const index = citationsMapping[hash]; - if (index !== undefined) { - c.index = index; - } else { - c.index = allCitations.length + 1; - allCitations.push(c); - } - }); - setCitations(allCitations); - setCurrent(current); - }; - - const search = async ( - query: string, - query_source: "search" | "followups", - disable: () => void, - enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void - ) => { - // clear the query box, append to entries - const userEntry: Entry = { - role: "user", - content: query_source === "search" ? query : query.split("\n", 2)[1]!, - }; - setEntries((prev) => [...prev, userEntry]); - disable(); - - const { result, followups } = await runSearch( - query, - query_source, - mode[0], - entries, - updateCurrent, - sessionId - ); - setCurrent(undefined); - - setEntries((prev) => [...prev, result]); - enable(followups || []); - scroll30(); - }; - - var last_entry = <>; - switch (current?.phase) { - case "semantic": - last_entry =

-        <p>Loading: Performing semantic search...</p>;
-      break;
-    case "prompt":
-      last_entry = <p>Loading: Creating prompt...</p>;
-      break;
-    case "llm":
-      last_entry = <p>Loading: Waiting for LLM...</p>;
-      break;
-    case "streaming":
-      updateCitations(citations, current);
-      last_entry = <AssistantEntry entry={current} />;
-      break;
-    case "followups":
-      last_entry = (
-        <>
-          <AssistantEntry entry={current} />
-          <p>Checking for followups...</p>
- - ); - break; - } - return ( @@ -176,14 +38,7 @@ const Home: NextPage = () => { welcomed. -
    - {entries.map((entry, i) => ( - - ))} - - - {last_entry} -
+      <Chat sessionId={sessionId} mode={mode[0]} />
); }; diff --git a/web/src/types.ts b/web/src/types.ts index b7fe0c4..9bfede9 100644 --- a/web/src/types.ts +++ b/web/src/types.ts @@ -42,3 +42,5 @@ export type SearchResult = { result: Entry; }; export type CurrentSearch = (AssistantEntry & { phase?: string }) | undefined; + +export type Mode = "rookie" | "concise" | "default"; From ff55dd5da89a205036998802e4566424b30c88ed Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 1 Oct 2023 20:30:45 +0200 Subject: [PATCH 2/4] Basic playground --- web/src/components/chat.tsx | 32 +++- web/src/components/header.tsx | 4 +- web/src/components/searchbox.tsx | 8 +- web/src/pages/index.tsx | 2 +- web/src/pages/playground.tsx | 315 +++++++++++++++++++++++++++++++ web/src/types.ts | 14 ++ 6 files changed, 364 insertions(+), 11 deletions(-) create mode 100644 web/src/pages/playground.tsx diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx index 7701524..2e335da 100644 --- a/web/src/components/chat.tsx +++ b/web/src/components/chat.tsx @@ -1,11 +1,12 @@ import { useState, useEffect } from "react"; import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; + import type { CurrentSearch, Citation, Entry, AssistantEntry as AssistantEntryType, - Mode, + LLMSettings, Followup, } from "../types"; import { SearchBox } from "../components/searchbox"; @@ -39,7 +40,14 @@ function scroll30() { window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" }); } -const Chat = ({ sessionId, mode }: { sessionId: string; mode: Mode }) => { +type ChatParams = { + sessionId: string; + settings: LLMSettings; + onQuery?: (q: string) => any; + onNewEntry?: (history: Entry[]) => any; +}; + +const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => { const [entries, setEntries] = useState([]); const [current, setCurrent] = useState(); const [citations, setCitations] = useState([]); @@ -82,6 +90,16 @@ const Chat = ({ sessionId, mode }: { sessionId: string; mode: Mode }) => { setCurrent(current); }; + const addEntry = (entry: Entry) => { + setEntries((prev) => { + const entries = [...prev, entry]; + if (onNewEntry) { + onNewEntry(entries); + } + return entries; + }); + }; + const search = async ( query: string, query_source: "search" | "followups", @@ -93,20 +111,20 @@ const Chat = ({ sessionId, mode }: { sessionId: string; mode: Mode }) => { role: "user", content: query_source === "search" ? query : query.split("\n", 2)[1]!, }; - setEntries((prev) => [...prev, userEntry]); + addEntry(userEntry); disable(); const { result, followups } = await runSearch( query, query_source, - mode, + settings.mode, entries, updateCurrent, sessionId ); setCurrent(undefined); - setEntries((prev) => [...prev, result]); + addEntry(result); enable(followups || []); scroll30(); }; @@ -137,11 +155,11 @@ const Chat = ({ sessionId, mode }: { sessionId: string; mode: Mode }) => { } return ( -
    +
      {entries.map((entry, i) => ( ))} - + {last_entry}
    diff --git a/web/src/components/header.tsx b/web/src/components/header.tsx index 9f10fc7..1941361 100644 --- a/web/src/components/header.tsx +++ b/web/src/components/header.tsx @@ -3,7 +3,9 @@ import Link from "next/link"; import Image from "next/image"; import logo from "../logo.svg"; -const Header: React.FC<{ page: "index" | "semantic" }> = ({ page }) => { +const Header: React.FC<{ page: "index" | "semantic" | "playground" }> = ({ + page, +}) => { const sidebar = page === "index" ? ( diff --git a/web/src/components/searchbox.tsx b/web/src/components/searchbox.tsx index ed1733b..1cc422f 100644 --- a/web/src/components/searchbox.tsx +++ b/web/src/components/searchbox.tsx @@ -35,7 +35,8 @@ const SearchBoxInternal: React.FC<{ disable: () => void, enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void ) => void; -}> = ({ search }) => { + onQuery?: (q: string) => any; +}> = ({ search, onQuery }) => { const initial_query = initialQuestions[Math.floor(Math.random() * initialQuestions.length)] || ""; @@ -107,7 +108,10 @@ const SearchBoxInternal: React.FC<{ className="flex-1 resize-none border border-gray-300 px-1" ref={inputRef} value={query} - onChange={(e) => setQuery(e.target.value)} + onChange={(e) => { + setQuery(e.target.value); + onQuery && onQuery(e.target.value); + }} onKeyDown={(e) => { // if , blur the input box if (e.key === "Escape") e.currentTarget.blur(); diff --git a/web/src/pages/index.tsx b/web/src/pages/index.tsx index e34d029..41482d1 100644 --- a/web/src/pages/index.tsx +++ b/web/src/pages/index.tsx @@ -38,7 +38,7 @@ const Home: NextPage = () => { welcomed. - + ); }; diff --git a/web/src/pages/playground.tsx b/web/src/pages/playground.tsx new file mode 100644 index 0000000..fd00fff --- /dev/null +++ b/web/src/pages/playground.tsx @@ -0,0 +1,315 @@ +import type { NextPage } from "next"; +import { useState, useEffect, ChangeEvent } from "react"; +import TextareaAutosize from "react-textarea-autosize"; +import Head from "next/head"; +import Link from "next/link"; + +import { queryLLM, getStampyContent, runSearch } from "../hooks/useSearch"; +import type { Mode, Entry, LLMSettings } from "../types"; +import Header from "../components/header"; +import Chat from "../components/chat"; +import { Controls } from "../components/controls"; + +const MAX_FOLLOWUPS = 4; +const DEFAULT_PROMPTS = { + source: { + prefix: + "You are a helpful assistant knowledgeable about AI Alignment and Safety. " + + 'Please give a clear and coherent answer to the user\'s questions.(written after "Q:") ' + + "using the following sources. Each source is labeled with a letter. Feel free to " + + "use the sources in any order, and try to use multiple sources in your answers.\n\n", + suffix: + "\n\n" + + 'Before the question ("Q: "), there will be a history of previous questions and answers. ' + + "These sources only apply to the last question. Any sources used in previous answers " + + "are invalid.", + }, + question: + "In your answer, please cite any claims you make back to each source " + + "using the format: [a], [b], etc. If you use multiple sources to make a claim " + + 'cite all of them. For example: "AGI is concerning [c, d, e]."\n\n', + modes: { + default: "", + concise: + "Answer very concisely, getting to the crux of the matter in as " + + "few words as possible. Limit your answer to 1-2 sentences.\n\n", + rookie: + "This user is new to the field of AI Alignment and Safety - don't " + + "assume they know any technical terms or jargon. 
Still give a complete answer " + + "without patronizing the user, but take any extra time needed to " + + "explain new concepts or to illustrate your answer with examples. " + + "Put extra effort into explaining the intuition behind concepts " + + "rather than just giving a formal definition.\n\n", + }, +}; +const DEFAULT_SETTINGS = { + prompts: DEFAULT_PROMPTS, + mode: "default" as Mode, + completions: "gpt-3.5-turbo", + encoder: "cl100k_base", + topKBlocks: 10, // the number of blocks to use as citations + numTokens: 4095, + tokensBuffer: 50, // the number of tokens to leave as a buffer when calculating remaining tokens + historyFraction: 0.25, // the (approximate) fraction of num_tokens to use for history text before truncating + contextFraction: 0.5, // the (approximate) fraction of num_tokens to use for context text before truncating +}; +const COMPLETION_MODELS = ["gpt-3.5-turbo", "gpt-4"]; +const ENCODERS = ["cl100k_base"]; + +const updateIn = (obj, [head, ...rest]: string[], val: any) => { + if (!head) { + // No path provided - do nothing + } else if (!rest || rest.length == 0) { + obj[head] = val; + } else { + updateIn(obj[head], rest, val); + } + return obj; +}; + +type ChatSettingsParams = { + settings: LLMSettings; + updateSettings: (updater: (settings: LLMSettings) => LLMSettings) => void; +}; + +const ChatSettings = ({ settings, updateSettings }: ChatSettingsParams) => { + const update = (setting: string) => (event: ChangeEvent) => { + updateSettings((prev) => ({ + ...prev, + [setting]: (event.target as HTMLInputElement).value, + })); + }; + const between = + (setting: string, min?: number, max?: number, parser?) => + (event: ChangeEvent) => { + let num = parser((event.target as HTMLInputElement).value); + if (isNaN(num)) { + return; + } else if (min !== undefined && num < min) { + num = min; + } else if (max !== undefined && num > max) { + num = max; + } + updateSettings((prev) => ({ ...prev, [setting]: num })); + }; + const intBetween = (setting: string, min?: number, max?: number) => + between(setting, min, max, (v: any) => parseInt(v, 10)); + const floatBetween = (setting: string, min?: number, max?: number) => + between(setting, min, max, parseFloat); + return ( +
    +

    Models

    +
    + + +
    + +
    + + +
    + +

    Token options

    +
    + + +
    + +
    + + +
    + +

    Prompt options

    +
    + + +
    + +
    + + +
    + +
    + + +
    +
    + ); +}; + +type ChatPromptParams = { + settings: LLMSettings; + query: string; + history: Entry[]; + updateSettings: (updater: (settings: LLMSettings) => LLMSettings) => void; +}; + +const ChatPrompts = ({ + settings, + query, + history, + updateSettings, +}: ChatPromptParams) => { + const updatePrompt = + (...path: string[]) => + (event: ChangeEvent) => { + const newPrompts = { + ...updateIn( + settings.prompts, + path, + (event.target as HTMLInputElement).value + ), + }; + updateSettings((settings) => ({ ...settings, prompts: newPrompts })); + }; + return ( +
    +
    + Source prompt + +
    (This is where sources will be injected)
    + {history.length > 0 && ( + + )} +
    + {history.length > 0 && ( +
    + History + {history.map((entry) => ( +
    {entry.content}
    + ))} +
    + )} +
    + Question prompt + + +
    +
    Q: {query}
    +
    + ); +}; + +const Playground: NextPage = () => { + const [sessionId, setSessionId] = useState(""); + const [settings, updateSettings] = useState(DEFAULT_SETTINGS); + + const [query, setQuery] = useState(""); + const [history, setHistory] = useState([]); + + const setMode = (mode: [Mode, boolean]) => { + if (mode[1]) { + localStorage.setItem("chat_mode", mode[0]); + updateSettings((settings) => ({ ...settings, mode: mode[0] })); + } + }; + + // initial load + useEffect(() => { + const mode = (localStorage.getItem("chat_mode") as Mode) || "default"; + setMode([mode, true]); + setSessionId(crypto.randomUUID()); + }, []); + + return ( + <> + + AI Safety Info + +
    +
    + +
    + + + +
    +
    + + ); +}; + +export default Playground; diff --git a/web/src/types.ts b/web/src/types.ts index 9bfede9..f78374e 100644 --- a/web/src/types.ts +++ b/web/src/types.ts @@ -44,3 +44,17 @@ export type SearchResult = { export type CurrentSearch = (AssistantEntry & { phase?: string }) | undefined; export type Mode = "rookie" | "concise" | "default"; + +export type LLMSettings = { + prompts?: { + [key: string]: any; + }; + mode?: Mode; + completions?: string; + encoder?: string; + topKBlocks?: number; + numTokens?: number; + tokensBuffer?: number; + historyFraction?: number; + contextFraction?: number; +}; From 88a88406f589b7c43721e7b87d44c10e7ee88b4d Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 1 Oct 2023 22:02:41 +0200 Subject: [PATCH 3/4] Allow the backend to be configured --- api/main.py | 5 +- api/src/stampy_chat/chat.py | 124 ++++++++---------------- api/src/stampy_chat/settings.py | 147 +++++++++++++++++++++++++++++ api/tests/stampy_chat/test_chat.py | 62 ++++++++---- web/src/components/chat.tsx | 7 +- web/src/hooks/useSearch.ts | 13 +-- web/src/pages/playground.tsx | 15 ++- 7 files changed, 262 insertions(+), 111 deletions(-) create mode 100644 api/src/stampy_chat/settings.py diff --git a/api/main.py b/api/main.py index fc9f9ed..db85001 100644 --- a/api/main.py +++ b/api/main.py @@ -10,6 +10,7 @@ from stampy_chat.env import PINECONE_INDEX, FLASK_PORT from stampy_chat.get_blocks import get_top_k_blocks from stampy_chat.chat import talk_to_robot, talk_to_robot_simple +from stampy_chat.settings import Settings # ---------------------------------- web setup --------------------------------- @@ -44,11 +45,11 @@ def semantic(): def chat(): query = request.json.get('query') - mode = request.json.get('mode', 'default') session_id = request.json.get('sessionId') history = request.json.get('history', []) + settings = Settings(**request.json.get('settings', {})) - return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history, session_id)), mimetype='text/event-stream') + return Response(stream(talk_to_robot(PINECONE_INDEX, query, history, session_id, settings)), mimetype='text/event-stream') # ------------- simplified non-streaming chat for internal testing ------------- diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index a8eb69e..6208cf5 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -6,92 +6,45 @@ from typing import List, Dict import openai -import tiktoken -from stampy_chat.env import COMPLETIONS_MODEL +from stampy_chat import logging from stampy_chat.followups import multisearch_authored from stampy_chat.get_blocks import get_top_k_blocks, Block -from stampy_chat import logging +from stampy_chat.settings import Settings logger = logging.getLogger(__name__) -STANDARD_K = 20 if COMPLETIONS_MODEL == 'gpt-4' else 10 - -# parameters - -# NOTE: All this is approximate, there's bits I'm intentionally not counting. Leave a buffer beyond what you might expect. -NUM_TOKENS = 8191 if COMPLETIONS_MODEL == 'gpt-4' else 4095 -TOKENS_BUFFER = 50 # the number of tokens to leave as a buffer when calculating remaining tokens -HISTORY_FRACTION = 0.25 # the (approximate) fraction of num_tokens to use for history text before truncating -CONTEXT_FRACTION = 0.5 # the (approximate) fraction of num_tokens to use for context text before truncating - -ENCODER = tiktoken.get_encoding("cl100k_base") - -SOURCE_PROMPT = ( - "You are a helpful assistant knowledgeable about AI Alignment and Safety. 
" - "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") " - "using the following sources. Each source is labeled with a letter. Feel free to " - "use the sources in any order, and try to use multiple sources in your answers.\n\n" -) -SOURCE_PROMPT_SUFFIX = ( - "\n\n" - "Before the question (\"Q: \"), there will be a history of previous questions and answers. " - "These sources only apply to the last question. Any sources used in previous answers " - "are invalid." -) - -QUESTION_PROMPT = ( - "In your answer, please cite any claims you make back to each source " - "using the format: [a], [b], etc. If you use multiple sources to make a claim " - "cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n" -) -PROMPT_MODES = { - 'default': "", - "concise": ( - "Answer very concisely, getting to the crux of the matter in as " - "few words as possible. Limit your answer to 1-2 sentences.\n\n" - ), - "rookie": ( - "This user is new to the field of AI Alignment and Safety - don't " - "assume they know any technical terms or jargon. Still give a complete answer " - "without patronizing the user, but take any extra time needed to " - "explain new concepts or to illustrate your answer with examples. " - "Put extra effort into explaining the intuition behind concepts " - "rather than just giving a formal definition.\n\n" - ), -} - -# --------------------------------- prompt code -------------------------------- - - - # limit a string to a certain number of tokens -def cap(text: str, max_tokens: int) -> str: +def cap(text: str, max_tokens: int, encoder) -> str: if max_tokens <= 0: return "..." - encoded_text = ENCODER.encode(text) + encoded_text = encoder.encode(text) if len(encoded_text) <= max_tokens: return text - return ENCODER.decode(encoded_text[:max_tokens]) + " ..." + return encoder.decode(encoded_text[:max_tokens]) + " ..." Prompt = List[Dict[str, str]] -def prompt_context(source_prompt: str, context: List[Block], max_tokens: int) -> str: - token_count = len(ENCODER.encode(source_prompt)) +def prompt_context(context: List[Block], settings: Settings) -> str: + source_prompt = settings.source_prompt_prefix + max_tokens = settings.context_tokens + encoder = settings.encoder + + token_count = len(encoder.encode(source_prompt)) # Context from top-k blocks for i, block in enumerate(context): block_str = f"[{chr(ord('a') + i)}] {block.title} - {','.join(block.authors)} - {block.date}\n{block.text}\n\n" - block_tc = len(ENCODER.encode(block_str)) + block_tc = len(encoder.encode(block_str)) if token_count + block_tc > max_tokens: - source_prompt += cap(block_str, max_tokens - token_count) + source_prompt += cap(block_str, max_tokens - token_count, encoder) break else: source_prompt += block_str @@ -99,35 +52,34 @@ def prompt_context(source_prompt: str, context: List[Block], max_tokens: int) -> return source_prompt.strip() -def prompt_history(history: Prompt, max_tokens: int, n_items=10) -> Prompt: +def prompt_history(history: Prompt, settings: Settings) -> Prompt: + max_tokens = settings.history_tokens + encoder = settings.encoder token_count = 0 prompt = [] # Get the n_items last messages, starting from the last one. This is because it's assumed # that more recent messages are more important. 
The `-1` is because of how slicing works - messages = history[:-n_items - 1:-1] + messages = history[:-settings.maxHistory - 1:-1] for message in messages: if message["role"] == "user": prompt.append({"role": "user", "content": "Q: " + message["content"]}) - token_count += len(ENCODER.encode("Q: " + message["content"])) + token_count += len(encoder.encode("Q: " + message["content"])) else: content = message["content"] # censor all source letters into [x] content = re.sub(r"\[[0-9]+\]", "[x]", content) - content = cap(content, max_tokens - token_count) + content = cap(content, max_tokens - token_count, encoder) prompt.append({"role": "assistant", "content": content}) - token_count += len(ENCODER.encode(content)) + token_count += len(encoder.encode(content)) if token_count > max_tokens: break return prompt[::-1] -def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt: - if mode not in PROMPT_MODES: - raise ValueError("Invalid mode: " + mode) - +def construct_prompt(query: str, settings: Settings, history: Prompt, context: List[Block]) -> Prompt: # History takes the format: history=[ # {"role": "user", "content": "Die monster. You don’t belong in this world!"}, # {"role": "assistant", "content": "It was not by my hand I am once again given flesh. I was called here by humans who wished to pay me tribute."}, @@ -138,14 +90,14 @@ def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block # ] # Context from top-k blocks - source_prompt = prompt_context(SOURCE_PROMPT, context, int(NUM_TOKENS * CONTEXT_FRACTION)) + source_prompt = prompt_context(context, settings) if history: - source_prompt += SOURCE_PROMPT_SUFFIX + source_prompt += settings.source_prompt_suffix source_prompt = [{"role": "system", "content": source_prompt.strip()}] # Write a version of the last 10 messages into history, cutting things off when we hit the token limit. - history_prompt = prompt_history(history, int(NUM_TOKENS * HISTORY_FRACTION)) - question_prompt = [{"role": "user", "content": QUESTION_PROMPT + PROMPT_MODES[mode] + "Q: " + query}] + history_prompt = prompt_history(history, settings) + question_prompt = [{"role": "user", "content": settings.question_prompt(query)}] return source_prompt + history_prompt + question_prompt @@ -161,20 +113,21 @@ def check_openai_moderation(prompt: Prompt, query: str): raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.") -def remaining_tokens(prompt: Prompt): +def remaining_tokens(prompt: Prompt, settings: Settings): # Count number of tokens left for completion (-50 for a buffer) + encoder = settings.encoder used_tokens = sum([ - len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"])) + len(encoder.encode(message["content"]) + encoder.encode(message["role"])) for message in prompt ]) - return max(0, NUM_TOKENS - used_tokens - TOKENS_BUFFER) + return max(0, settings.numTokens - used_tokens - settings.tokensBuffer) -def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, session_id: str, k: int = STANDARD_K): +def talk_to_robot_internal(index, query: str, history: Prompt, session_id: str, settings: Settings=Settings()): try: # 1. 
Find the most relevant blocks from the Alignment Research Dataset yield {"state": "loading", "phase": "semantic"} - top_k_blocks = get_top_k_blocks(index, query, k) + top_k_blocks = get_top_k_blocks(index, query, settings.topKBlocks) yield { "state": "citations", @@ -186,14 +139,16 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio # 2. Generate a prompt yield {"state": "loading", "phase": "prompt"} - prompt = construct_prompt(query, mode, history, top_k_blocks) + prompt = construct_prompt(query, settings, history, top_k_blocks) # 3. Run both the standalone query and the full prompt through # moderation to see if it will be accepted by OpenAI's api check_openai_moderation(prompt, query) # 4. Count number of tokens left for completion (-50 for a buffer) - max_tokens_completion = remaining_tokens(prompt) + max_tokens_completion = remaining_tokens(prompt, settings) + if max_tokens_completion < 40: + raise ValueError(f"{max_tokens_completion} tokens left for the actual query after constructing the context - aborting, as that's not going to be enough") # 5. Answer the user query yield {"state": "loading", "phase": "llm"} @@ -201,7 +156,7 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio response = '' for chunk in openai.ChatCompletion.create( - model=COMPLETIONS_MODEL, + model=settings.completions, messages=prompt, max_tokens=max_tokens_completion, stream=True, @@ -239,18 +194,19 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio except Exception as e: logger.error(e) yield {'state': 'error', 'error': str(e)} + raise # convert talk_to_robot_internal from dict generator into json generator -def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K): - yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k)) +def talk_to_robot(index, query: str, history: List[Dict[str, str]], session_id: str, settings: Settings): + yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, history, session_id, settings)) # wayyy simplified api def talk_to_robot_simple(index, query: str): res = {'response': ''} - for block in talk_to_robot_internal(index, query, "default", []): + for block in talk_to_robot_internal(index, query, []): if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block: citations = {} for i, c in enumerate(block['citations']): diff --git a/api/src/stampy_chat/settings.py b/api/src/stampy_chat/settings.py new file mode 100644 index 0000000..70d6b1a --- /dev/null +++ b/api/src/stampy_chat/settings.py @@ -0,0 +1,147 @@ +import tiktoken + +from stampy_chat.env import COMPLETIONS_MODEL + + +SOURCE_PROMPT = ( + "You are a helpful assistant knowledgeable about AI Alignment and Safety. " + "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") " + "using the following sources. Each source is labeled with a letter. Feel free to " + "use the sources in any order, and try to use multiple sources in your answers.\n\n" +) +SOURCE_PROMPT_SUFFIX = ( + "\n\n" + "Before the question (\"Q: \"), there will be a history of previous questions and answers. " + "These sources only apply to the last question. Any sources used in previous answers " + "are invalid." +) + +QUESTION_PROMPT = ( + "In your answer, please cite any claims you make back to each source " + "using the format: [a], [b], etc. 
If you use multiple sources to make a claim " + "cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n" +) +PROMPT_MODES = { + 'default': "", + "concise": ( + "Answer very concisely, getting to the crux of the matter in as " + "few words as possible. Limit your answer to 1-2 sentences.\n\n" + ), + "rookie": ( + "This user is new to the field of AI Alignment and Safety - don't " + "assume they know any technical terms or jargon. Still give a complete answer " + "without patronizing the user, but take any extra time needed to " + "explain new concepts or to illustrate your answer with examples. " + "Put extra effort into explaining the intuition behind concepts " + "rather than just giving a formal definition.\n\n" + ), +} +DEFAULT_PROMPTS = { + 'source': { + 'prefix': SOURCE_PROMPT, + 'suffix': SOURCE_PROMPT_SUFFIX, + }, + 'question': QUESTION_PROMPT, + 'modes': PROMPT_MODES, +} + + +class Settings: + + encoders = {} + + def __init__( + self, + prompts=DEFAULT_PROMPTS, + mode='default', + completions=COMPLETIONS_MODEL, + encoder='cl100k_base', + topKBlocks=None, + numTokens=None, + tokensBuffer=50, + maxHistory=10, + historyFraction=0.25, + contextFraction=0.5, + **_kwargs, + ) -> None: + self.prompts = prompts + self.mode = mode + if self.mode_prompt is None: + raise ValueError("Invalid mode: " + mode) + + self.encoder = encoder + + self.set_completions(completions, numTokens, topKBlocks) + + self.tokensBuffer = tokensBuffer + """the number of tokens to leave as a buffer when calculating remaining tokens""" + + self.maxHistory = maxHistory + """the max number of previous interactions to use as the history""" + + self.historyFraction = historyFraction + """the (approximate) fraction of num_tokens to use for history text before truncating""" + + self.contextFraction = contextFraction + """the (approximate) fraction of num_tokens to use for context text before truncating""" + + def __repr__(self) -> str: + return f' { const { result, followups } = await runSearch( query, query_source, - settings.mode, + settings, entries, updateCurrent, sessionId @@ -152,6 +152,11 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => { ); break; + default: + last_entry = ( + + ); + break; } return ( diff --git a/web/src/hooks/useSearch.ts b/web/src/hooks/useSearch.ts index 776194f..61dec04 100644 --- a/web/src/hooks/useSearch.ts +++ b/web/src/hooks/useSearch.ts @@ -8,6 +8,7 @@ import type { Followup, CurrentSearch, SearchResult, + LLMSettings, } from "../types"; import { formatCitations, findCitations } from "../components/citations"; @@ -102,7 +103,7 @@ export const extractAnswer = async ( const fetchLLM = async ( sessionId: string, query: string, - mode: string, + settings: LLMSettings, history: HistoryEntry[] ): Promise => fetch(API_URL + "/chat", { @@ -114,18 +115,18 @@ const fetchLLM = async ( Accept: "text/event-stream", }, - body: JSON.stringify({ sessionId, query, mode, history }), + body: JSON.stringify({ sessionId, query, history, settings }), }); export const queryLLM = async ( query: string, - mode: string, + settings: LLMSettings, history: HistoryEntry[], setCurrent: (e?: CurrentSearch) => void, sessionId: string ): Promise => { // do SSE on a POST request. 
- const res = await fetchLLM(sessionId, query, mode, history); + const res = await fetchLLM(sessionId, query, settings, history); if (!res.ok) { return { result: { role: "error", content: "POST Error: " + res.status } }; @@ -194,7 +195,7 @@ export const getStampyContent = async ( export const runSearch = async ( query: string, query_source: "search" | "followups", - mode: string, + settings: LLMSettings, entries: Entry[], setCurrent: (c: CurrentSearch) => void, sessionId: string @@ -207,7 +208,7 @@ export const runSearch = async ( content: entry.content.trim(), })); - return await queryLLM(query, mode, history, setCurrent, sessionId); + return await queryLLM(query, settings, history, setCurrent, sessionId); } else { // ----------------- HUMAN AUTHORED CONTENT RETRIEVAL ------------------ const [questionId] = query.split("\n", 2); diff --git a/web/src/pages/playground.tsx b/web/src/pages/playground.tsx index fd00fff..b4c0fec 100644 --- a/web/src/pages/playground.tsx +++ b/web/src/pages/playground.tsx @@ -44,12 +44,13 @@ const DEFAULT_PROMPTS = { }; const DEFAULT_SETTINGS = { prompts: DEFAULT_PROMPTS, - mode: "default" as Mode, + mode: "default", completions: "gpt-3.5-turbo", encoder: "cl100k_base", topKBlocks: 10, // the number of blocks to use as citations numTokens: 4095, tokensBuffer: 50, // the number of tokens to leave as a buffer when calculating remaining tokens + maxHistory: 10, // the max number of previous items to use as history historyFraction: 0.25, // the (approximate) fraction of num_tokens to use for history text before truncating contextFraction: 0.5, // the (approximate) fraction of num_tokens to use for context text before truncating }; @@ -162,6 +163,18 @@ const ChatSettings = ({ settings, updateSettings }: ChatSettingsParams) => { /> +
    + + +
    +
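Note on the settings plumbing introduced above (an illustrative sketch, not part of the patches): the playground keeps an LLMSettings object on the frontend, fetchLLM() now posts it verbatim in the /chat request body, and main.py unpacks it with Settings(**request.json.get('settings', {})). The TypeScript sketch below shows what such a request could look like, using the defaults from DEFAULT_SETTINGS in playground.tsx; the example query strings, the bare "/chat" path (the real code prefixes API_URL), and the inline fetch call are assumptions made for illustration only.

// Illustrative sketch -- mirrors fetchLLM() in web/src/hooks/useSearch.ts and the
// keyword arguments accepted by stampy_chat.settings.Settings. Not part of the diff.

// Defaults taken from DEFAULT_SETTINGS in web/src/pages/playground.tsx.
const settings = {
  mode: "default",
  completions: "gpt-3.5-turbo",
  encoder: "cl100k_base",
  topKBlocks: 10,        // number of blocks to use as citations
  numTokens: 4095,       // total token budget for the model
  tokensBuffer: 50,      // tokens left over for the completion-size calculation
  maxHistory: 10,        // max number of previous messages sent as history
  historyFraction: 0.25, // approximate share of numTokens for history text
  contextFraction: 0.5,  // approximate share of numTokens for cited context
};

const history = [
  { role: "user", content: "What is AI alignment?" },
  { role: "assistant", content: "AI alignment is ... [a]" },
];

// Roughly equivalent to fetchLLM(sessionId, query, settings, history);
// the "/chat" path and the query text are placeholders.
async function sendChatRequest() {
  return fetch("/chat", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Accept: "text/event-stream",
    },
    body: JSON.stringify({
      sessionId: crypto.randomUUID(),
      query: "Why might alignment be hard?",
      history,
      settings,
    }),
  });
}

On the API side, Settings.__init__ takes these keys as keyword arguments and swallows anything unrecognised via **_kwargs, so clients that send extra fields -- or no settings at all, since request.json.get('settings', {}) falls back to an empty dict -- still get a Settings object built from the server defaults.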