
Merge pull request #135 from StampyAI/history_summary
History summary
mruwnik authored Jan 15, 2024
2 parents 9c808ce + a8cc986 commit 2153039
Showing 6 changed files with 98 additions and 7 deletions.
56 changes: 53 additions & 3 deletions api/src/stampy_chat/chat.py
@@ -64,13 +64,14 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
class PrefixedPrompt(BaseChatPromptTemplate):
"""A prompt that will prefix any messages with a system prompt, but only if messages provided."""

transformer: Callable[[Any], BaseMessage] = lambda i: i
messages_field: str
prompt: str # the system prompt to be used

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
history = kwargs[self.messages_field]
if history and self.prompt:
return [SystemMessage(content=self.prompt)] + history
return [SystemMessage(content=self.prompt)] + [self.transformer(i) for i in history]
return []
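
For context, a minimal usage sketch of the new transformer hook (not part of the diff; it assumes the stampy_chat package is importable and uses langchain.schema message types): raw dict history entries are converted to ChatMessage objects before the system prompt is prepended.

from langchain.schema import ChatMessage
from stampy_chat.chat import PrefixedPrompt

prompt = PrefixedPrompt(
    input_variables=['history'],
    messages_field='history',
    prompt='Summarize the conversation so far.',  # illustrative system prompt
    transformer=lambda m: ChatMessage(**m),  # same transformer make_history_summary passes in below
)
messages = prompt.format_messages(history=[
    {'role': 'user', 'content': 'What is instrumental convergence?'},
    {'role': 'assistant', 'content': 'The tendency of different goals to favour similar subgoals.'},
])
# messages == [SystemMessage(...), ChatMessage(role='user', ...), ChatMessage(role='assistant', ...)]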


@@ -132,6 +133,12 @@ def prune(self) -> None:
pruned_memory, self.moving_summary_buffer
)

def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""
Because of how wonderfully LangChain is written, this method was blowing up.
It's not needed, so it's getting the chop.
"""


class ModeratedChatPrompt(ChatPromptTemplate):
"""Wraps a prompt with an OpenAI moderation check which will raise an exception if fails."""
@@ -157,6 +164,49 @@ def get_model(**kwargs):
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs)


class LLMInputsChain(LLMChain):

inputs: Dict[str, Any] = {}

def _call(self, inputs: Dict[str, Any], run_manager=None):
self.inputs = inputs
return super()._call(inputs, run_manager)

def _acall(self, inputs: Dict[str, Any], run_manager=None):
self.inputs = inputs
return super()._acall(inputs, run_manager)

def create_outputs(self, llm_result) -> List[Dict[str, Any]]:
result = super().create_outputs(llm_result)
return [dict(self.inputs, **r) for r in result]
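
The create_outputs override is what lets the original inputs survive into the chain's outputs; a standalone sketch in plain Python (not from the diff) of the merge it performs:

# dict(self.inputs, **r): the chain's inputs are copied into every output row,
# so downstream chains receive them alongside the newly generated key.
inputs = {'query': 'What is AI safety?', 'history': []}
llm_rows = [{'history_summary': 'The user is starting a new conversation.'}]  # illustrative LLM output
merged = [dict(inputs, **row) for row in llm_rows]
assert merged == [{
    'query': 'What is AI safety?',
    'history': [],
    'history_summary': 'The user is starting a new conversation.',
}]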


def make_history_summary(settings):
model = get_model(
streaming=False,
max_tokens=settings.maxHistorySummaryTokens,
model=settings.completions
)
summary_prompt = PrefixedPrompt(
input_variables=['history'],
messages_field='history',
prompt=settings.history_summary_prompt,
transformer=lambda m: ChatMessage(**m)
)
return LLMInputsChain(
llm=model,
verbose=False,
output_key='history_summary',
prompt=ModeratedChatPrompt.from_messages([
summary_prompt,
ChatPromptTemplate.from_messages([
ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'),
]),
SystemMessage(content="Reply in one sentence only"),
]),
)
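
A hedged usage sketch (not part of the diff; running it needs an OpenAI key, and it assumes Settings() can be constructed from its defaults): thanks to LLMInputsChain, the summary chain returns its one-sentence summary under history_summary while passing query and history through, which is what lets run_query pipe it straight into the main LLMChain and lets the rewritten question prompt render "Q: {history_summary}: {query}".

from stampy_chat.chat import make_history_summary
from stampy_chat.settings import Settings

settings = Settings()  # assumption: the defaults are sufficient here
summary_chain = make_history_summary(settings)
result = summary_chain.invoke({
    'query': 'Why is deceptive alignment hard to detect?',
    'history': [
        {'role': 'user', 'content': 'What is deceptive alignment?'},
        {'role': 'assistant', 'content': 'A model that behaves well only while it is being evaluated.'},
    ],
})
# result keeps 'query' and 'history' and adds 'history_summary' (a single sentence),
# ready to be piped into the main chain as run_query does below.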


def make_prompt(settings, chat_model, callbacks):
"""Create a proper prompt object will all the nessesery steps."""
# 1. Create the context prompt from items fetched from pinecone
Expand All @@ -176,7 +226,7 @@ def make_prompt(settings, chat_model, callbacks):
query_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=settings.question_prompt),
ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'),
ChatMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'),
]
)

@@ -247,7 +297,7 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
model=settings.completions
)

chain = LLMChain(
chain = make_history_summary(settings) | LLMChain(
llm=chat_model,
verbose=False,
prompt=make_prompt(settings, chat_model, callbacks),
2 changes: 1 addition & 1 deletion api/src/stampy_chat/followups.py
@@ -74,7 +74,7 @@ class Config:

@property
def input_keys(self) -> List[str]:
return ['query', 'text']
return ['query', 'text', 'history_summary']

@property
def output_keys(self) -> List[str]:
22 changes: 19 additions & 3 deletions api/src/stampy_chat/settings.py
@@ -9,7 +9,7 @@

SOURCE_PROMPT = (
"You are a helpful assistant knowledgeable about AI Alignment and Safety. "
"Please give a clear and coherent answer to the user's questions.(written after \"Q:\") "
"Please give a clear and coherent answer to the user's questions. (written after \"Q:\") "
"using the following sources. Each source is labeled with a letter. Feel free to "
"use the sources in any order, and try to use multiple sources in your answers.\n\n"
)
@@ -19,6 +19,13 @@
"These sources only apply to the last question. Any sources used in previous answers "
"are invalid."
)
HISTORY_SUMMARIZE_PROMPT = (
"You are a helpful assistant knowledgeable about AI Alignment and Safety. "
"Please summarize the following chat history (written after \"H:\") in one "
"sentence so as to put the current questions (written after \"Q:\") in context. "
"Please keep things as terse as possible."
"\nH:"
)

QUESTION_PROMPT = (
"In your answer, please cite any claims you make back to each source "
@@ -47,6 +54,7 @@
DEFAULT_PROMPTS = {
'context': SOURCE_PROMPT,
'history': HISTORY_PROMPT,
'history_summary': HISTORY_SUMMARIZE_PROMPT,
'question': QUESTION_PROMPT,
'modes': PROMPT_MODES,
}
@@ -72,8 +80,9 @@ def __init__(
topKBlocks=None,
maxNumTokens=None,
min_response_tokens=10,
tokensBuffer=50,
tokensBuffer=100,
maxHistory=10,
maxHistorySummaryTokens=200,
historyFraction=0.25,
contextFraction=0.5,
**_kwargs,
@@ -93,6 +102,9 @@ def __init__(
self.maxHistory = maxHistory
"""the max number of previous interactions to use as the history"""

self.maxHistorySummaryTokens = maxHistorySummaryTokens
"""the max number of tokens to be used on the history summary"""

self.historyFraction = historyFraction
"""the (approximate) fraction of num_tokens to use for history text before truncating"""

@@ -153,6 +165,10 @@ def context_prompt(self):
def history_prompt(self):
return self.prompts['history']

@property
def history_summary_prompt(self):
return self.prompts['history_summary']

@property
def mode_prompt(self):
return self.prompts['modes'].get(self.mode, '')
@@ -174,7 +190,7 @@ def history_tokens(self):
@property
def max_response_tokens(self):
available_tokens = (
self.maxNumTokens -
self.maxNumTokens - self.maxHistorySummaryTokens -
self.context_tokens - len(self.encoder.encode(self.context_prompt)) -
self.history_tokens - len(self.encoder.encode(self.history_prompt)) -
len(self.encoder.encode(self.question_prompt))
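To make the new token budget concrete, a rough back-of-the-envelope sketch with assumed numbers (the tail of the expression is hidden above and ignored here; the real property also encodes the actual prompt strings):

# Rough illustration of how the history-summary allowance now comes off the top
# of the response budget. Figures below are assumptions for the sake of arithmetic.
maxNumTokens = 4097                        # e.g. gpt-3.5-turbo context window
maxHistorySummaryTokens = 200              # new default introduced above
context_tokens = int(0.5 * maxNumTokens)   # contextFraction = 0.5, approximate
history_tokens = int(0.25 * maxNumTokens)  # historyFraction = 0.25, approximate
prompt_overhead = 250                      # made-up stand-in for the encoded prompt texts
available = (maxNumTokens - maxHistorySummaryTokens
             - context_tokens - history_tokens - prompt_overhead)
print(available)  # ~575 tokens left for the model's answer under these assumptions
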
2 changes: 2 additions & 0 deletions web/src/components/chat.tsx
@@ -63,6 +63,8 @@ export const ChatResponse = ({
return <p>Loading: Sending query...</p>;
case "semantic":
return <p>Loading: Performing semantic search...</p>;
case "history":
return <p>Loading: Processing history...</p>;
case "context":
return <p>Loading: Creating context...</p>;
case "prompt":
16 changes: 16 additions & 0 deletions web/src/components/settings.tsx
@@ -122,6 +122,14 @@ export const ChatSettings = ({
max={settings.maxNumTokens}
updater={updateNum("tokensBuffer")}
/>
<NumberInput
field="maxHistorySummaryTokens"
value={settings.maxHistorySummaryTokens}
label="The max number of tokens to use for the history summary"
min="0"
max={settings.maxNumTokens}
updater={updateNum("maxHistorySummaryTokens")}
/>

<SectionHeader text="Prompt options" />
<NumberInput
@@ -178,6 +186,14 @@ export const ChatPrompts = ({

return (
<div className="chat-prompts mx-5 w-[400px] flex-none border-2 p-5 outline-black">
<details>
<summary>History summary prompt</summary>
<TextareaAutosize
className="border-gray w-full border px-1"
value={settings?.prompts?.history_summary}
onChange={updatePrompt("history_summary")}
/>
</details>
<details open>
<summary>Source prompt</summary>
<TextareaAutosize
7 changes: 7 additions & 0 deletions web/src/hooks/useSettings.ts
@@ -21,6 +21,12 @@ const DEFAULT_PROMPTS = {
'Before the question ("Q: "), there will be a history of previous questions and answers. ' +
"These sources only apply to the last question. any sources used in previous answers " +
"are invalid.",
history_summary:
"You are a helpful assistant knowledgeable about AI Alignment and Safety. " +
'Please summarize the following chat history (written after "H:") in one ' +
'sentence so as to put the current questions (written after "Q:") in context. ' +
"Please keep things as terse as possible." +
"\nH:",
question:
"In your answer, please cite any claims you make back to each source " +
"using the format: [a], [b], etc. If you use multiple sources to make a claim " +
@@ -131,6 +137,7 @@ const SETTINGS_PARSERS = {
maxNumTokens: withDefault(MODELS["gpt-3.5-turbo"]?.maxNumTokens),
tokensBuffer: withDefault(50), // the number of tokens to leave as a buffer when calculating remaining tokens
maxHistory: withDefault(10), // the max number of previous items to use as history
maxHistorySummaryTokens: withDefault(200), // the max number of tokens to use in the history summary
historyFraction: withDefault(0.25), // the (approximate) fraction of num_tokens to use for history text before truncating
contextFraction: withDefault(0.5), // the (approximate) fraction of num_tokens to use for context text before truncating
};
