Chat options #111

Merged
merged 4 commits into from Oct 2, 2023
5 changes: 3 additions & 2 deletions api/main.py
@@ -10,6 +10,7 @@
from stampy_chat.env import PINECONE_INDEX, FLASK_PORT
from stampy_chat.get_blocks import get_top_k_blocks
from stampy_chat.chat import talk_to_robot, talk_to_robot_simple
from stampy_chat.settings import Settings


# ---------------------------------- web setup ---------------------------------
@@ -44,11 +45,11 @@ def semantic():
def chat():

query = request.json.get('query')
mode = request.json.get('mode', 'default')
session_id = request.json.get('sessionId')
history = request.json.get('history', [])
settings = Settings(**request.json.get('settings', {}))

return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history, session_id)), mimetype='text/event-stream')
return Response(stream(talk_to_robot(PINECONE_INDEX, query, history, session_id, settings)), mimetype='text/event-stream')
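For illustration, a client request to the updated `/chat` route could now look something like the sketch below: chat options travel in a nested `settings` object rather than the old top-level `mode` field. The payload values, the settings keys, and the port are assumptions for illustration, not taken from this PR.

```python
import requests

# Hypothetical call against the updated /chat route. The keys accepted by
# Settings(...) and the port number are assumptions for illustration only.
response = requests.post(
    "http://localhost:3001/chat",
    json={
        "query": "Why is AI alignment hard?",
        "sessionId": "abc123",
        "history": [],
        "settings": {"mode": "concise"},
    },
    stream=True,  # the route responds with a text/event-stream
)
for line in response.iter_lines():
    if line:
        print(line.decode())
```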


# ------------- simplified non-streaming chat for internal testing -------------
123 changes: 39 additions & 84 deletions api/src/stampy_chat/chat.py
@@ -6,128 +6,80 @@
from typing import List, Dict

import openai
import tiktoken

from stampy_chat.env import COMPLETIONS_MODEL
from stampy_chat import logging
from stampy_chat.followups import multisearch_authored
from stampy_chat.get_blocks import get_top_k_blocks, Block
from stampy_chat import logging
from stampy_chat.settings import Settings


logger = logging.getLogger(__name__)


STANDARD_K = 20 if COMPLETIONS_MODEL == 'gpt-4' else 10
Collaborator Author:
all of these are moved over into the Settings class to be able to pass them around
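To make that concrete, here is a rough sketch of what such a Settings class might look like, inferred purely from how `settings.*` is used later in this diff. The real implementation lives in `stampy_chat/settings.py` (not shown here); the constructor arguments, default values, and derived-field names below are assumptions.

```python
import tiktoken

# Placeholder prompt texts; the full versions are the constants removed below.
SOURCE_PROMPT = "You are a helpful assistant knowledgeable about AI Alignment and Safety. ..."
SOURCE_PROMPT_SUFFIX = "\n\nBefore the question (\"Q: \"), there will be a history of previous questions and answers. ..."
QUESTION_PROMPT = "In your answer, please cite any claims you make back to each source using the format: [a], [b], etc. ..."
PROMPT_MODES = {"default": "", "concise": "Answer very concisely. ...", "rookie": "This user is new to the field. ..."}


class Settings:
    def __init__(self, mode="default", completions="gpt-3.5-turbo", numTokens=4095,
                 tokensBuffer=50, maxHistory=10, topKBlocks=10,
                 contextFraction=0.5, historyFraction=0.25, **_extra):
        self.completions = completions
        self.numTokens = numTokens
        self.tokensBuffer = tokensBuffer
        self.maxHistory = maxHistory
        self.topKBlocks = topKBlocks
        self.encoder = tiktoken.get_encoding("cl100k_base")
        # rough per-section token budgets, replacing the old CONTEXT_FRACTION / HISTORY_FRACTION constants
        self.context_tokens = int(numTokens * contextFraction)
        self.history_tokens = int(numTokens * historyFraction)
        self.source_prompt_prefix = SOURCE_PROMPT
        self.source_prompt_suffix = SOURCE_PROMPT_SUFFIX
        self.mode_prompt = PROMPT_MODES.get(mode, "")

    def question_prompt(self, query: str) -> str:
        # mirrors the old QUESTION_PROMPT + PROMPT_MODES[mode] + "Q: " + query concatenation
        return QUESTION_PROMPT + self.mode_prompt + "Q: " + query
```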


# parameters

# NOTE: All this is approximate, there's bits I'm intentionally not counting. Leave a buffer beyond what you might expect.
NUM_TOKENS = 8191 if COMPLETIONS_MODEL == 'gpt-4' else 4095
TOKENS_BUFFER = 50 # the number of tokens to leave as a buffer when calculating remaining tokens
HISTORY_FRACTION = 0.25 # the (approximate) fraction of num_tokens to use for history text before truncating
CONTEXT_FRACTION = 0.5 # the (approximate) fraction of num_tokens to use for context text before truncating

ENCODER = tiktoken.get_encoding("cl100k_base")

SOURCE_PROMPT = (
"You are a helpful assistant knowledgeable about AI Alignment and Safety. "
"Please give a clear and coherent answer to the user's questions.(written after \"Q:\") "
"using the following sources. Each source is labeled with a letter. Feel free to "
"use the sources in any order, and try to use multiple sources in your answers.\n\n"
)
SOURCE_PROMPT_SUFFIX = (
"\n\n"
"Before the question (\"Q: \"), there will be a history of previous questions and answers. "
"These sources only apply to the last question. Any sources used in previous answers "
"are invalid."
)

QUESTION_PROMPT = (
"In your answer, please cite any claims you make back to each source "
"using the format: [a], [b], etc. If you use multiple sources to make a claim "
"cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n"
)
PROMPT_MODES = {
'default': "",
"concise": (
"Answer very concisely, getting to the crux of the matter in as "
"few words as possible. Limit your answer to 1-2 sentences.\n\n"
),
"rookie": (
"This user is new to the field of AI Alignment and Safety - don't "
"assume they know any technical terms or jargon. Still give a complete answer "
"without patronizing the user, but take any extra time needed to "
"explain new concepts or to illustrate your answer with examples. "
"Put extra effort into explaining the intuition behind concepts "
"rather than just giving a formal definition.\n\n"
),
}

# --------------------------------- prompt code --------------------------------



# limit a string to a certain number of tokens
def cap(text: str, max_tokens: int) -> str:
def cap(text: str, max_tokens: int, encoder) -> str:
Collaborator Author:
the other changes in this file are basically to get it to use the settings object, rather than various constants

if max_tokens <= 0:
return "..."

encoded_text = ENCODER.encode(text)
encoded_text = encoder.encode(text)

if len(encoded_text) <= max_tokens:
return text
return ENCODER.decode(encoded_text[:max_tokens]) + " ..."
return encoder.decode(encoded_text[:max_tokens]) + " ..."
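As a quick usage sketch of the new `cap` signature (assuming the `cap` defined above is in scope; `cl100k_base` matches the encoding the removed `ENCODER` constant used):

```python
import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")
text = "A fairly long sentence that we only want the first few tokens of."
print(cap(text, 5, encoder))   # first five tokens of the encoding, decoded, plus " ..."
print(cap(text, 0, encoder))   # a non-positive budget short-circuits to "..."
```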


Prompt = List[Dict[str, str]]


def prompt_context(source_prompt: str, context: List[Block], max_tokens: int) -> str:
token_count = len(ENCODER.encode(source_prompt))
def prompt_context(context: List[Block], settings: Settings) -> str:
source_prompt = settings.source_prompt_prefix
max_tokens = settings.context_tokens
encoder = settings.encoder

token_count = len(encoder.encode(source_prompt))

# Context from top-k blocks
for i, block in enumerate(context):
block_str = f"[{chr(ord('a') + i)}] {block.title} - {','.join(block.authors)} - {block.date}\n{block.text}\n\n"
block_tc = len(ENCODER.encode(block_str))
block_tc = len(encoder.encode(block_str))

if token_count + block_tc > max_tokens:
source_prompt += cap(block_str, max_tokens - token_count)
source_prompt += cap(block_str, max_tokens - token_count, encoder)
break
else:
source_prompt += block_str
token_count += block_tc
return source_prompt.strip()


def prompt_history(history: Prompt, max_tokens: int, n_items=10) -> Prompt:
def prompt_history(history: Prompt, settings: Settings) -> Prompt:
max_tokens = settings.history_tokens
encoder = settings.encoder
token_count = 0
prompt = []

# Get the n_items last messages, starting from the last one. This is because it's assumed
# that more recent messages are more important. The `-1` is because of how slicing works
messages = history[:-n_items - 1:-1]
messages = history[:-settings.maxHistory - 1:-1]
for message in messages:
if message["role"] == "user":
prompt.append({"role": "user", "content": "Q: " + message["content"]})
token_count += len(ENCODER.encode("Q: " + message["content"]))
token_count += len(encoder.encode("Q: " + message["content"]))
else:
content = message["content"]
# censor all source letters into [x]
content = re.sub(r"\[[0-9]+\]", "[x]", content)
content = cap(content, max_tokens - token_count)
content = cap(content, max_tokens - token_count, encoder)

prompt.append({"role": "assistant", "content": content})
token_count += len(ENCODER.encode(content))
token_count += len(encoder.encode(content))

if token_count > max_tokens:
break
return prompt[::-1]


def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt:
if mode not in PROMPT_MODES:
raise ValueError("Invalid mode: " + mode)

def construct_prompt(query: str, settings: Settings, history: Prompt, context: List[Block]) -> Prompt:
# History takes the format: history=[
# {"role": "user", "content": "Die monster. You don’t belong in this world!"},
# {"role": "assistant", "content": "It was not by my hand I am once again given flesh. I was called here by humans who wished to pay me tribute."},
@@ -138,14 +90,14 @@ def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block
# ]

# Context from top-k blocks
source_prompt = prompt_context(SOURCE_PROMPT, context, int(NUM_TOKENS * CONTEXT_FRACTION))
source_prompt = prompt_context(context, settings)
if history:
source_prompt += SOURCE_PROMPT_SUFFIX
source_prompt += settings.source_prompt_suffix
source_prompt = [{"role": "system", "content": source_prompt.strip()}]

# Write a version of the last 10 messages into history, cutting things off when we hit the token limit.
history_prompt = prompt_history(history, int(NUM_TOKENS * HISTORY_FRACTION))
question_prompt = [{"role": "user", "content": QUESTION_PROMPT + PROMPT_MODES[mode] + "Q: " + query}]
history_prompt = prompt_history(history, settings)
question_prompt = [{"role": "user", "content": settings.question_prompt(query)}]

return source_prompt + history_prompt + question_prompt

Expand All @@ -161,20 +113,21 @@ def check_openai_moderation(prompt: Prompt, query: str):
raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.")


def remaining_tokens(prompt: Prompt):
def remaining_tokens(prompt: Prompt, settings: Settings):
# Count number of tokens left for completion (-50 for a buffer)
encoder = settings.encoder
used_tokens = sum([
len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"]))
len(encoder.encode(message["content"]) + encoder.encode(message["role"]))
for message in prompt
])
return max(0, NUM_TOKENS - used_tokens - TOKENS_BUFFER)
return max(0, settings.numTokens - used_tokens - settings.tokensBuffer)


def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, session_id: str, k: int = STANDARD_K):
def talk_to_robot_internal(index, query: str, history: Prompt, session_id: str, settings: Settings=Settings()):
try:
# 1. Find the most relevant blocks from the Alignment Research Dataset
yield {"state": "loading", "phase": "semantic"}
top_k_blocks = get_top_k_blocks(index, query, k)
top_k_blocks = get_top_k_blocks(index, query, settings.topKBlocks)

yield {
"state": "citations",
@@ -186,22 +139,24 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio

# 2. Generate a prompt
yield {"state": "loading", "phase": "prompt"}
prompt = construct_prompt(query, mode, history, top_k_blocks)
prompt = construct_prompt(query, settings, history, top_k_blocks)

# 3. Run both the standalone query and the full prompt through
# moderation to see if it will be accepted by OpenAI's api
check_openai_moderation(prompt, query)

# 4. Count number of tokens left for completion (-50 for a buffer)
max_tokens_completion = remaining_tokens(prompt)
max_tokens_completion = remaining_tokens(prompt, settings)
if max_tokens_completion < 40:
raise ValueError(f"{max_tokens_completion} tokens left for the actual query after constructing the context - aborting, as that's not going to be enough")
Collaborator Author:
Should this be changed? I arbitrarily put 40, but 100 would probably also be too few.

Reviewer:
Maybe it should be a class or instance variable? Also, what is the trade-off here? More tokens potentially cost more per response, but give a better chance of producing what the user wants.

Collaborator Author:
There's already a settings.numTokens setting which should handle that. The problem here is that a whole lot of additional context is added to the prompt before it gets to this point. Specifically:

  1. settings.source_prompt_prefix
  2. up to settings.topKBlocks of citations from pinecone
  3. settings.source_prompt_suffix
  4. (optional) the last settings.maxHistory items from the conversation
  5. settings.question_prompt
  6. (optional) settings.mode_prompt to regulate the complexity level of the answer
  7. the actual query

This prompt can get quite big, taking up a lot of the settings.numTokens available tokens (in the worst case taking all of them), which limits the number left for the actual response. A separate setting might help, but won't solve the underlying problem, which would probably require some creative bookkeeping to limit the number of tokens available for each step. Steps 1, 3, 5, 6 and 7 should always be sent to the LLM (but even then they can take up all the tokens), so it's really only the context and history steps that should be limited.
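To illustrate the squeeze described above with toy numbers (the 4095-token limit comes from the removed `NUM_TOKENS` constant; the 400-token figure for the fixed prompt pieces is purely an assumption):

```python
num_tokens = 4095                        # settings.numTokens (gpt-3.5-sized limit)
tokens_buffer = 50                       # settings.tokensBuffer
context_budget = int(num_tokens * 0.5)   # old CONTEXT_FRACTION -> 2047 tokens of citations
history_budget = int(num_tokens * 0.25)  # old HISTORY_FRACTION -> 1023 tokens of history
fixed_prompt = 400                       # assumed size of prefix + suffix + question/mode prompts + query

# worst case: the citation and history budgets are both fully used
used = fixed_prompt + context_budget + history_budget
left_for_completion = max(0, num_tokens - used - tokens_buffer)
print(left_for_completion)               # 575 tokens left for the actual answer
```

With gpt-4's larger window the margin is bigger, but the shape of the problem is the same.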


# 5. Answer the user query
yield {"state": "loading", "phase": "llm"}
t1 = time.time()
response = ''

for chunk in openai.ChatCompletion.create(
model=COMPLETIONS_MODEL,
model=settings.completions,
messages=prompt,
max_tokens=max_tokens_completion,
stream=True,
@@ -242,15 +197,15 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio


# convert talk_to_robot_internal from dict generator into json generator
def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K):
yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k))
def talk_to_robot(index, query: str, history: List[Dict[str, str]], session_id: str, settings: Settings):
yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, history, session_id, settings))


# wayyy simplified api
def talk_to_robot_simple(index, query: str):
res = {'response': ''}

for block in talk_to_robot_internal(index, query, "default", []):
for block in talk_to_robot_internal(index, query, []):
if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block:
citations = {}
for i, c in enumerate(block['citations']):
12 changes: 1 addition & 11 deletions api/src/stampy_chat/get_blocks.py
@@ -71,7 +71,7 @@ def parse_block(match) -> Block:
date = date,
url = metadata['url'],
tags = metadata.get('tags'),
text = strip_block(metadata['text'])
text = metadata['text']
)


@@ -144,13 +144,3 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')

return join_blocks(blocks)


# we add the title and authors inside the contents of the block, so that
# searches for the title or author will be more likely to pull it up. This
# strips it back out.
def strip_block(text: str) -> str:
Collaborator Author:
this continuously raises warnings that it can't find the titles. Is it even used anymore? Wasn't that the whole point of adding the other fields?

Reviewer:
maybe just change the logic to `return r.group(1) if r else text`? No need for the warning. I'm good to delete this too.

r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
if not r:
logger.warning("couldn't strip block:\n%s", text)
return r.group(1) if r else text
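For reference, a minimal sketch of the simplification suggested above (keep the regex, drop the warning, and fall back to the unchanged text when the title pattern is not found):

```python
import re

def strip_block(text: str) -> str:
    # drop the leading quoted text ('"<text>" - Title: ...') when present,
    # otherwise return the block unchanged - no warning needed
    r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
    return r.group(1) if r else text
```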