diff --git a/api/main.py b/api/main.py
index 912778f..d5bb0e8 100644
--- a/api/main.py
+++ b/api/main.py
@@ -103,13 +103,14 @@ def human(id):
 def ratings():
     session_id = request.json.get('sessionId')
     settings = request.json.get('settings', {})
+    comment = (request.json.get('comment') or '').strip() or None  # only store non-empty comments
     score = request.json.get('score')
 
     if not session_id or score is None:
         return Response('{"error": "missing params"}', 400, mimetype='application/json')
 
     with make_session() as s:
-        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings)))
+        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings), comment=comment))
         s.commit()
 
     return jsonify({'status': 'ok'})
diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py
index a1f3786..d5774f2 100644
--- a/api/src/stampy_chat/chat.py
+++ b/api/src/stampy_chat/chat.py
@@ -108,6 +108,23 @@ def set_messages(self, history: List[dict]) -> None:
         for callback in self.callbacks:
             callback.on_memory_set_end(self.chat_memory)
 
+    def prune(self) -> None:
+        """Prune the buffer if it exceeds the max token limit.
+
+        This is the original LangChain version, copied with a fix to handle the
+        case where all of the messages are longer than the max_token_limit.
+        """
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            # Checking `buffer` is the fix: stop once everything has been pruned,
+            # rather than popping from an empty list
+            while buffer and curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = self.predict_new_summary(
+                pruned_memory, self.moving_summary_buffer
+            )
+
 
 class ModeratedChatPrompt(ChatPromptTemplate):
     """Wraps a prompt with an OpenAI moderation check which will raise an exception if it fails."""
diff --git a/api/src/stampy_chat/settings.py b/api/src/stampy_chat/settings.py
index 12aae79..50d42c4 100644
--- a/api/src/stampy_chat/settings.py
+++ b/api/src/stampy_chat/settings.py
@@ -4,7 +4,7 @@
 
 from stampy_chat.env import COMPLETIONS_MODEL
 
-Model = namedtuple('Model', ['maxTokens', 'topKBlocks'])
+Model = namedtuple('Model', ['maxTokens', 'topKBlocks', 'maxCompletionTokens'])
 
 
 SOURCE_PROMPT = (
@@ -51,9 +51,10 @@
     'modes': PROMPT_MODES,
 }
 MODELS = {
-    'gpt-3.5-turbo': Model(4097, 10),
-    'gpt-3.5-turbo-16k': Model(16385, 30),
-    'gpt-4': Model(8192, 20),
+    'gpt-3.5-turbo': Model(4097, 10, 4096),
+    'gpt-3.5-turbo-16k': Model(16385, 30, 4096),
+    'gpt-4': Model(8192, 20, 4096),
+    'gpt-4-1106-preview': Model(128000, 50, 4096),
     # 'gpt-4-32k': Model(32768, 30, 4096),
 }
@@ -138,6 +139,8 @@ def set_completions(self, completions, maxNumTokens=None, topKBlocks=None):
         else:
             self.topKBlocks = MODELS[completions].topKBlocks
 
+        self.maxCompletionTokens = MODELS[completions].maxCompletionTokens
+
     @property
     def prompt_modes(self):
         return self.prompts['modes']
@@ -170,4 +173,4 @@ def history_tokens(self):
 
     @property
     def max_response_tokens(self):
-        return self.maxNumTokens - self.context_tokens - self.history_tokens
+        return min(self.maxNumTokens - self.context_tokens - self.history_tokens, self.maxCompletionTokens)
diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx
index 8038c32..4b2de3d 100644
--- a/web/src/components/chat.tsx
+++ b/web/src/components/chat.tsx
@@ -185,7 +185,7 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => {
       return;
     } else if (
       i === entries.length - 1 &&
-      ["assistant", "stampy"].includes(entry.role)
+      ["assistant", "stampy", "error"].includes(entry.role)
["assistant", "stampy", "error"].includes(entry.role) ) { const prev = entries[i - 1]; if (prev !== undefined) setQuery(prev.content); diff --git a/web/src/components/settings.tsx b/web/src/components/settings.tsx index 6ad74e5..8e704ca 100644 --- a/web/src/components/settings.tsx +++ b/web/src/components/settings.tsx @@ -5,22 +5,47 @@ import type { Parseable, LLMSettings, Entry, Mode } from "../types"; import { MODELS, ENCODERS } from "../hooks/useSettings"; import { SectionHeader, NumberInput, Slider } from "../components/html"; +type ChatSettingsUpdate = [path: string[], value: any]; type ChatSettingsParams = { settings: LLMSettings; - changeSetting: (path: string[], value: any) => void; + changeSettings: (...v: ChatSettingsUpdate[]) => void; }; export const ChatSettings = ({ settings, - changeSetting, + changeSettings, }: ChatSettingsParams) => { const changeVal = (field: string, value: any) => - changeSetting([field], value); + changeSettings([[field], value]); const update = (field: string) => (event: ChangeEvent) => changeVal(field, (event.target as HTMLInputElement).value); const updateNum = (field: string) => (num: Parseable) => changeVal(field, num); + const updateTokenFraction = (field: string) => (num: Parseable) => { + // Calculate the fraction of the tokens taken by the buffer + const bufferFraction = + settings.tokensBuffer && settings.maxNumTokens + ? settings.tokensBuffer / settings.maxNumTokens + : 0; + const val = Math.min(parseFloat((num || 0).toString()), 1 - bufferFraction); + + let context = settings.contextFraction || 0; + let history = settings.historyFraction || 0; + + if (field == "contextFraction") { + history = Math.min(history, Math.max(0, 1 - val - bufferFraction)); + context = val; + } else { + context = Math.min(context, Math.max(0, 1 - val - bufferFraction)); + history = val; + } + changeSettings( + [["contextFraction"], context], + [["historyFraction"], history] + ); + }; + return (
); @@ -134,22 +159,22 @@ type ChatPromptParams = { settings: LLMSettings; query: string; history: Entry[]; - changeSetting: (path: string[], value: any) => void; + changeSettings: (...vals: ChatSettingsUpdate[]) => void; }; export const ChatPrompts = ({ settings, query, history, - changeSetting, + changeSettings, }: ChatPromptParams) => { const updatePrompt = (...path: string[]) => (event: ChangeEvent) => - changeSetting( + changeSettings([ ["prompts", ...path], - (event.target as HTMLInputElement).value - ); + (event.target as HTMLInputElement).value, + ]); return (
diff --git a/web/src/hooks/useSettings.ts b/web/src/hooks/useSettings.ts
index 8632472..a1fa301 100644
--- a/web/src/hooks/useSettings.ts
+++ b/web/src/hooks/useSettings.ts
@@ -50,6 +50,7 @@ export const MODELS: { [key: string]: Model } = {
   "gpt-3.5-turbo": { maxNumTokens: 4095, topKBlocks: 10 },
   "gpt-3.5-turbo-16k": { maxNumTokens: 16385, topKBlocks: 30 },
   "gpt-4": { maxNumTokens: 8192, topKBlocks: 20 },
+  "gpt-4-1106-preview": { maxNumTokens: 128000, topKBlocks: 50 },
   /* 'gpt-4-32k': {maxNumTokens: 32768, topKBlocks: 30}, */
 };
 export const ENCODERS = ["cl100k_base"];
@@ -169,25 +170,39 @@ type ChatSettingsParams = {
   changeSetting: (path: string[], value: any) => void;
 };
 
+type SettingsUpdatePair = [path: string[], val: any];
+
 export default function useSettings() {
   const [settingsLoaded, setLoaded] = useState(false);
   const [settings, updateSettings] = useState(makeSettings({}));
   const router = useRouter();
 
-  const updateInUrl = (path: string[], value: any) =>
+  const updateInUrl = (vals: { [key: string]: any }) =>
     router.replace({
       pathname: router.pathname,
-      query: {
-        ...router.query,
-        [path.join(".")]: value.toString(),
-      },
+      query: { ...router.query, ...vals },
     });
 
   const changeSetting = (path: string[], value: any) => {
-    updateInUrl(path, value);
+    updateInUrl({ [path.join(".")]: value });
    updateSettings((settings) => ({ ...updateIn(settings, path, value) }));
   };
 
+  // Apply several updates at once, as a single URL replace and a single state update
+  const changeSettings = (...items: SettingsUpdatePair[]) => {
+    updateInUrl(
+      items.reduce(
+        (acc, [path, val]) => ({ ...acc, [path.join(".")]: val }),
+        {}
+      )
+    );
+    updateSettings((settings) =>
+      items.reduce(
+        (acc, [path, val]) => ({ ...acc, ...updateIn(acc, path, val) }),
+        settings
+      )
+    );
+  };
+
   const setMode = (mode: Mode | undefined) => {
     if (mode) {
       updateSettings({ ...settings, mode: mode });
@@ -210,6 +225,7 @@ export default function useSettings() {
   return {
     settings,
     changeSetting,
+    changeSettings,
     setMode,
     settingsLoaded,
     randomize,
diff --git a/web/src/pages/playground.tsx b/web/src/pages/playground.tsx
index a90ba0e..ad702fd 100644
--- a/web/src/pages/playground.tsx
+++ b/web/src/pages/playground.tsx
@@ -13,7 +13,7 @@ const Playground: NextPage = () => {
   const [query, setQuery] = useState("");
   const [history, setHistory] = useState([]);
 
-  const { settings, changeSetting, setMode } = useSettings();
+  const { settings, changeSettings, setMode } = useSettings();
 
   // initial load
   useEffect(() => {
@@ -28,7 +28,7 @@ const Playground: NextPage = () => {
         settings={settings}
         query={query}
         history={history}
-        changeSetting={changeSetting}
+        changeSettings={changeSettings}
       />
       <Chat
         sessionId={sessionId}
         settings={settings}
         onQuery={setQuery}
         onNewEntry={setHistory}
       />
-      <ChatSettings settings={settings} changeSetting={changeSetting} />
+      <ChatSettings settings={settings} changeSettings={changeSettings} />
     </div>
  );
diff --git a/web/src/pages/qa.tsx b/web/src/pages/qa.tsx
index a2d7bfa..19e5932 100644
--- a/web/src/pages/qa.tsx
+++ b/web/src/pages/qa.tsx
@@ -22,6 +22,7 @@ const MAX_FOLLOWUPS = 4;
 export const saveRatings = async (
   sessionId: string,
   score: number,
+  comment: string | null,
   settings: LLMSettings
 ): Promise<any> =>
   fetch(API_URL + "/ratings", {
@@ -29,7 +30,7 @@
     headers: {
       "Content-Type": "application/json",
     },
-    body: JSON.stringify({ sessionId, settings, score }),
+    body: JSON.stringify({ sessionId, settings, score, comment }),
   }).then((r) => r.json());
 
 const Rater = ({
@@ -41,8 +42,10 @@
   sessionId: string;
   reset: () => void;
 }) => {
+  const [comment, setComment] = useState<string | null>(null);
+
   const onRate = async (rate: number) => {
-    const res = await saveRatings(sessionId, rate, settings);
+    const res = await saveRatings(sessionId, rate, comment, settings);
     if (!res.error) reset();
   };
 
@@ -58,6 +61,10 @@ const Rater = ({
         ))}
         Good
+        <textarea
+          onChange={(event) => setComment(event.target.value)}
+        />
     );
 };
diff --git a/web/src/pages/tester.tsx b/web/src/pages/tester.tsx
index dca8ec0..5223dcb 100644
--- a/web/src/pages/tester.tsx
+++ b/web/src/pages/tester.tsx
@@ -49,7 +49,7 @@ const Tester: NextPage = () => {
     initialQuestions.map((q, i) => ({ question: q, selected: true, index: i }))
   );
 
-  const { settings, changeSetting, setMode, settingsLoaded } = useSettings();
+  const { settings, changeSettings, setMode, settingsLoaded } = useSettings();
 
   /** Run a search for the given `question` and insert the query promise into it */
@@ -118,7 +118,7 @@ const Tester: NextPage = () => {
         settings={settings}
         query=""
         history={[]}
-        changeSetting={changeSetting}
+        changeSettings={changeSettings}
       />
{questions.map(({ question, selected }, i) => ( @@ -159,7 +159,7 @@ const Tester: NextPage = () => { )}
-        <ChatSettings settings={settings} changeSetting={changeSetting} />
+        <ChatSettings settings={settings} changeSettings={changeSettings} />
     );
diff --git a/web/src/styles/globals.css b/web/src/styles/globals.css
index 854e94e..9b83fe8 100644
--- a/web/src/styles/globals.css
+++ b/web/src/styles/globals.css
@@ -65,3 +65,8 @@ ol {
   width: 2.5em;
   margin: 0.5em;
 }
+
+.rate-container textarea {
+  width: 100%;
+  border: 1px solid black;
+}