diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index c14bd33..d22c6ac 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -64,13 +64,14 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: class PrefixedPrompt(BaseChatPromptTemplate): """A prompt that will prefix any messages with a system prompt, but only if messages provided.""" + transformer: Callable[[Any], BaseMessage] = lambda i: i messages_field: str prompt: str # the system prompt to be used def format_messages(self, **kwargs: Any) -> List[BaseMessage]: history = kwargs[self.messages_field] if history and self.prompt: - return [SystemMessage(content=self.prompt)] + history + return [SystemMessage(content=self.prompt)] + [self.transformer(i) for i in history] return [] @@ -132,6 +133,12 @@ def prune(self) -> None: pruned_memory, self.moving_summary_buffer ) + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + """ + Because of how wonderfully LangChain is written, this method was blowing up. + It's not needed, so it's getting the chop. + """ + class ModeratedChatPrompt(ChatPromptTemplate): """Wraps a prompt with an OpenAI moderation check which will raise an exception if fails.""" @@ -157,6 +164,49 @@ def get_model(**kwargs): return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs) +class LLMInputsChain(LLMChain): + + inputs: Dict[str, Any] = {} + + def _call(self, inputs: Dict[str, Any], run_manager=None): + self.inputs = inputs + return super()._call(inputs, run_manager) + + def _acall(self, inputs: Dict[str, Any], run_manager=None): + self.inputs = inputs + return super()._acall(inputs, run_manager) + + def create_outputs(self, llm_result) -> List[Dict[str, Any]]: + result = super().create_outputs(llm_result) + return [dict(self.inputs, **r) for r in result] + + +def make_history_summary(settings): + model = get_model( + streaming=False, + max_tokens=settings.maxHistorySummaryTokens, + model=settings.completions + ) + summary_prompt = PrefixedPrompt( + input_variables=['history'], + messages_field='history', + prompt=settings.history_summary_prompt, + transformer=lambda m: ChatMessage(**m) + ) + return LLMInputsChain( + llm=model, + verbose=False, + output_key='history_summary', + prompt=ModeratedChatPrompt.from_messages([ + summary_prompt, + ChatPromptTemplate.from_messages([ + ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'), + ]), + SystemMessage(content="Reply in one sentence only"), + ]), + ) + + def make_prompt(settings, chat_model, callbacks): """Create a proper prompt object will all the nessesery steps.""" # 1. Create the context prompt from items fetched from pinecone @@ -176,7 +226,7 @@ def make_prompt(settings, chat_model, callbacks): query_prompt = ChatPromptTemplate.from_messages( [ SystemMessage(content=settings.question_prompt), - ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'), + ChatMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'), ] ) @@ -247,7 +297,7 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin model=settings.completions ) - chain = LLMChain( + chain = make_history_summary(settings) | LLMChain( llm=chat_model, verbose=False, prompt=make_prompt(settings, chat_model, callbacks), diff --git a/api/src/stampy_chat/followups.py b/api/src/stampy_chat/followups.py index 079dd69..3dbe5d1 100644 --- a/api/src/stampy_chat/followups.py +++ b/api/src/stampy_chat/followups.py @@ -74,7 +74,7 @@ class Config: @property def input_keys(self) -> List[str]: - return ['query', 'text'] + return ['query', 'text', 'history_summary'] @property def output_keys(self) -> List[str]: diff --git a/api/src/stampy_chat/settings.py b/api/src/stampy_chat/settings.py index 8f8a737..53ff4d8 100644 --- a/api/src/stampy_chat/settings.py +++ b/api/src/stampy_chat/settings.py @@ -9,7 +9,7 @@ SOURCE_PROMPT = ( "You are a helpful assistant knowledgeable about AI Alignment and Safety. " - "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") " + "Please give a clear and coherent answer to the user's questions. (written after \"Q:\") " "using the following sources. Each source is labeled with a letter. Feel free to " "use the sources in any order, and try to use multiple sources in your answers.\n\n" ) @@ -19,6 +19,13 @@ "These sources only apply to the last question. Any sources used in previous answers " "are invalid." ) +HISTORY_SUMMARIZE_PROMPT = ( + "You are a helpful assistant knowledgeable about AI Alignment and Safety. " + "Please summarize the following chat history (written after \"H:\") in one " + "sentence so as to put the current questions (written after \"Q:\") in context. " + "Please keep things as terse as possible." + "\nH:" +) QUESTION_PROMPT = ( "In your answer, please cite any claims you make back to each source " @@ -47,6 +54,7 @@ DEFAULT_PROMPTS = { 'context': SOURCE_PROMPT, 'history': HISTORY_PROMPT, + 'history_summary': HISTORY_SUMMARIZE_PROMPT, 'question': QUESTION_PROMPT, 'modes': PROMPT_MODES, } @@ -72,8 +80,9 @@ def __init__( topKBlocks=None, maxNumTokens=None, min_response_tokens=10, - tokensBuffer=50, + tokensBuffer=100, maxHistory=10, + maxHistorySummaryTokens=200, historyFraction=0.25, contextFraction=0.5, **_kwargs, @@ -93,6 +102,9 @@ def __init__( self.maxHistory = maxHistory """the max number of previous interactions to use as the history""" + self.maxHistorySummaryTokens = maxHistorySummaryTokens + """the max number of tokens to be used on the history summary""" + self.historyFraction = historyFraction """the (approximate) fraction of num_tokens to use for history text before truncating""" @@ -153,6 +165,10 @@ def context_prompt(self): def history_prompt(self): return self.prompts['history'] + @property + def history_summary_prompt(self): + return self.prompts['history_summary'] + @property def mode_prompt(self): return self.prompts['modes'].get(self.mode, '') @@ -174,7 +190,7 @@ def history_tokens(self): @property def max_response_tokens(self): available_tokens = ( - self.maxNumTokens - + self.maxNumTokens - self.maxHistorySummaryTokens - self.context_tokens - len(self.encoder.encode(self.context_prompt)) - self.history_tokens - len(self.encoder.encode(self.history_prompt)) - len(self.encoder.encode(self.question_prompt)) diff --git a/web/src/components/chat.tsx b/web/src/components/chat.tsx index f21faae..4944f1c 100644 --- a/web/src/components/chat.tsx +++ b/web/src/components/chat.tsx @@ -63,6 +63,8 @@ export const ChatResponse = ({ return

Loading: Sending query...

; case "semantic": return

Loading: Performing semantic search...

; + case "history": + return

Loading: Processing history...

; case "context": return

Loading: Creating context...

; case "prompt": diff --git a/web/src/components/settings.tsx b/web/src/components/settings.tsx index 8e704ca..c48af7d 100644 --- a/web/src/components/settings.tsx +++ b/web/src/components/settings.tsx @@ -122,6 +122,14 @@ export const ChatSettings = ({ max={settings.maxNumTokens} updater={updateNum("tokensBuffer")} /> + +
+ History summary prompt + +
Source prompt