diff --git a/api/main.py b/api/main.py
index 912778f..d5bb0e8 100644
--- a/api/main.py
+++ b/api/main.py
@@ -103,13 +103,14 @@ def human(id):
 def ratings():
     session_id = request.json.get('sessionId')
     settings = request.json.get('settings', {})
+    comment = (request.json.get('comment') or '').strip() or None  # only save strings if not empty
     score = request.json.get('score')

     if not session_id or score is None:
         return Response('{"error": "missing params}', 400, mimetype='application/json')

     with make_session() as s:
-        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings)))
+        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings), comment=comment))
         s.commit()

     return jsonify({'status': 'ok'})
diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py
index a1f3786..d5774f2 100644
--- a/api/src/stampy_chat/chat.py
+++ b/api/src/stampy_chat/chat.py
@@ -108,6 +108,23 @@ def set_messages(self, history: List[dict]) -> None:
         for callback in self.callbacks:
             callback.on_memory_set_end(self.chat_memory)

+    def prune(self) -> None:
+        """Prune buffer if it exceeds max token limit.
+
+        This is the original Langchain version copied with a fix to handle the case when
+        all messages are longer than the max_token_limit
+        """
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while buffer and curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = self.predict_new_summary(
+                pruned_memory, self.moving_summary_buffer
+            )
+

 class ModeratedChatPrompt(ChatPromptTemplate):
     """Wraps a prompt with an OpenAI moderation check which will raise an exception if fails."""
diff --git a/api/src/stampy_chat/settings.py b/api/src/stampy_chat/settings.py
index 12aae79..50d42c4 100644
--- a/api/src/stampy_chat/settings.py
+++ b/api/src/stampy_chat/settings.py
@@ -4,7 +4,7 @@
 from stampy_chat.env import COMPLETIONS_MODEL


-Model = namedtuple('Model', ['maxTokens', 'topKBlocks'])
+Model = namedtuple('Model', ['maxTokens', 'topKBlocks', 'maxCompletionTokens'])


 SOURCE_PROMPT = (
@@ -51,9 +51,10 @@
     'modes': PROMPT_MODES,
 }
 MODELS = {
-    'gpt-3.5-turbo': Model(4097, 10),
-    'gpt-3.5-turbo-16k': Model(16385, 30),
-    'gpt-4': Model(8192, 20),
+    'gpt-3.5-turbo': Model(4097, 10, 4096),
+    'gpt-3.5-turbo-16k': Model(16385, 30, 4096),
+    'gpt-4': Model(8192, 20, 4096),
+    "gpt-4-1106-preview": Model(128000, 50, 4096),
     # 'gpt-4-32k': Model(32768, 30),
 }

@@ -138,6 +139,8 @@ def set_completions(self, completions, maxNumTokens=None, topKBlocks=None):
         else:
             self.topKBlocks = MODELS[completions].topKBlocks

+        self.maxCompletionTokens = MODELS[completions].maxCompletionTokens
+
     @property
     def prompt_modes(self):
         return self.prompts['modes']
@@ -170,4 +173,4 @@ def history_tokens(self):

     @property
     def max_response_tokens(self):
-        return self.maxNumTokens - self.context_tokens - self.history_tokens
+        return min(self.maxNumTokens - self.context_tokens - self.history_tokens, self.maxCompletionTokens)
diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx
index 8038c32..4b2de3d 100644
--- a/web/src/components/chat.tsx
+++ b/web/src/components/chat.tsx
@@ -185,7 +185,7 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => {
       return;
     } else if (
       i === entries.length - 1 &&
-      ["assistant", "stampy"].includes(entry.role)
+      ["assistant", "stampy", "error"].includes(entry.role)
     ) {
       const prev = entries[i - 1];
       if (prev !== undefined) setQuery(prev.content);
diff --git a/web/src/components/settings.tsx b/web/src/components/settings.tsx
index 6ad74e5..8e704ca 100644
--- a/web/src/components/settings.tsx
+++ b/web/src/components/settings.tsx
@@ -5,22 +5,47 @@ import type { Parseable, LLMSettings, Entry, Mode } from "../types";
 import { MODELS, ENCODERS } from "../hooks/useSettings";
 import { SectionHeader, NumberInput, Slider } from "../components/html";

+type ChatSettingsUpdate = [path: string[], value: any];
 type ChatSettingsParams = {
   settings: LLMSettings;
-  changeSetting: (path: string[], value: any) => void;
+  changeSettings: (...v: ChatSettingsUpdate[]) => void;
 };

 export const ChatSettings = ({
   settings,
-  changeSetting,
+  changeSettings,
 }: ChatSettingsParams) => {
   const changeVal = (field: string, value: any) =>
-    changeSetting([field], value);
+    changeSettings([[field], value]);
   const update = (field: string) => (event: ChangeEvent) =>
     changeVal(field, (event.target as HTMLInputElement).value);
   const updateNum = (field: string) => (num: Parseable) =>
     changeVal(field, num);
+  const updateTokenFraction = (field: string) => (num: Parseable) => {
+    // Calculate the fraction of the tokens taken by the buffer
+    const bufferFraction =
+      settings.tokensBuffer && settings.maxNumTokens
+        ? settings.tokensBuffer / settings.maxNumTokens
+        : 0;
+    const val = Math.min(parseFloat((num || 0).toString()), 1 - bufferFraction);
+
+    let context = settings.contextFraction || 0;
+    let history = settings.historyFraction || 0;
+
+    if (field == "contextFraction") {
+      history = Math.min(history, Math.max(0, 1 - val - bufferFraction));
+      context = val;
+    } else {
+      context = Math.min(context, Math.max(0, 1 - val - bufferFraction));
+      history = val;
+    }
+    changeSettings(
+      [["contextFraction"], context],
+      [["historyFraction"], history]
+    );
+  };
+
   return (
     <div