Merge pull request #111 from StampyAI/chat-options
Chat options
mruwnik authored Oct 2, 2023
2 parents 60173a1 + 43f1aab commit 1b43c28
Showing 13 changed files with 814 additions and 274 deletions.
api/main.py (5 changes: 3 additions & 2 deletions)
@@ -10,6 +10,7 @@
from stampy_chat.env import PINECONE_INDEX, FLASK_PORT
from stampy_chat.get_blocks import get_top_k_blocks
from stampy_chat.chat import talk_to_robot, talk_to_robot_simple
from stampy_chat.settings import Settings


# ---------------------------------- web setup ---------------------------------
@@ -44,11 +45,11 @@ def semantic():
def chat():

query = request.json.get('query')
mode = request.json.get('mode', 'default')
session_id = request.json.get('sessionId')
history = request.json.get('history', [])
settings = Settings(**request.json.get('settings', {}))

return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history, session_id)), mimetype='text/event-stream')
return Response(stream(talk_to_robot(PINECONE_INDEX, query, history, session_id, settings)), mimetype='text/event-stream')


# ------------- simplified non-streaming chat for internal testing -------------
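With this change the `/chat` endpoint takes its configuration from a `settings` object in the request body rather than a `mode` string. A minimal sketch of a client call, assuming the API is reachable on a local port and that the keys inside `settings` match whatever the `Settings` constructor accepts (the host, port, and settings keys below are illustrative):

```python
import requests  # third-party HTTP client, assumed available

# 'query', 'sessionId', 'history' and 'settings' are the fields read by chat()
# in api/main.py above; the values and the keys inside 'settings' are made up.
payload = {
    "query": "Why is AI alignment considered difficult?",
    "sessionId": "abc123",
    "history": [],
    "settings": {"topKBlocks": 10},
}

# The endpoint streams its answer as 'text/event-stream'.
with requests.post("http://localhost:3001/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```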
api/src/stampy_chat/chat.py (123 changes: 39 additions & 84 deletions)
@@ -6,128 +6,80 @@
from typing import List, Dict

import openai
import tiktoken

from stampy_chat.env import COMPLETIONS_MODEL
from stampy_chat import logging
from stampy_chat.followups import multisearch_authored
from stampy_chat.get_blocks import get_top_k_blocks, Block
from stampy_chat import logging
from stampy_chat.settings import Settings


logger = logging.getLogger(__name__)


STANDARD_K = 20 if COMPLETIONS_MODEL == 'gpt-4' else 10

# parameters

# NOTE: All this is approximate, there's bits I'm intentionally not counting. Leave a buffer beyond what you might expect.
NUM_TOKENS = 8191 if COMPLETIONS_MODEL == 'gpt-4' else 4095
TOKENS_BUFFER = 50 # the number of tokens to leave as a buffer when calculating remaining tokens
HISTORY_FRACTION = 0.25 # the (approximate) fraction of num_tokens to use for history text before truncating
CONTEXT_FRACTION = 0.5 # the (approximate) fraction of num_tokens to use for context text before truncating

ENCODER = tiktoken.get_encoding("cl100k_base")

SOURCE_PROMPT = (
"You are a helpful assistant knowledgeable about AI Alignment and Safety. "
"Please give a clear and coherent answer to the user's questions.(written after \"Q:\") "
"using the following sources. Each source is labeled with a letter. Feel free to "
"use the sources in any order, and try to use multiple sources in your answers.\n\n"
)
SOURCE_PROMPT_SUFFIX = (
"\n\n"
"Before the question (\"Q: \"), there will be a history of previous questions and answers. "
"These sources only apply to the last question. Any sources used in previous answers "
"are invalid."
)

QUESTION_PROMPT = (
"In your answer, please cite any claims you make back to each source "
"using the format: [a], [b], etc. If you use multiple sources to make a claim "
"cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n"
)
PROMPT_MODES = {
'default': "",
"concise": (
"Answer very concisely, getting to the crux of the matter in as "
"few words as possible. Limit your answer to 1-2 sentences.\n\n"
),
"rookie": (
"This user is new to the field of AI Alignment and Safety - don't "
"assume they know any technical terms or jargon. Still give a complete answer "
"without patronizing the user, but take any extra time needed to "
"explain new concepts or to illustrate your answer with examples. "
"Put extra effort into explaining the intuition behind concepts "
"rather than just giving a formal definition.\n\n"
),
}

# --------------------------------- prompt code --------------------------------



# limit a string to a certain number of tokens
def cap(text: str, max_tokens: int) -> str:
def cap(text: str, max_tokens: int, encoder) -> str:
if max_tokens <= 0:
return "..."

encoded_text = ENCODER.encode(text)
encoded_text = encoder.encode(text)

if len(encoded_text) <= max_tokens:
return text
return ENCODER.decode(encoded_text[:max_tokens]) + " ..."
return encoder.decode(encoded_text[:max_tokens]) + " ..."


Prompt = List[Dict[str, str]]


def prompt_context(source_prompt: str, context: List[Block], max_tokens: int) -> str:
token_count = len(ENCODER.encode(source_prompt))
def prompt_context(context: List[Block], settings: Settings) -> str:
source_prompt = settings.source_prompt_prefix
max_tokens = settings.context_tokens
encoder = settings.encoder

token_count = len(encoder.encode(source_prompt))

# Context from top-k blocks
for i, block in enumerate(context):
block_str = f"[{chr(ord('a') + i)}] {block.title} - {','.join(block.authors)} - {block.date}\n{block.text}\n\n"
block_tc = len(ENCODER.encode(block_str))
block_tc = len(encoder.encode(block_str))

if token_count + block_tc > max_tokens:
source_prompt += cap(block_str, max_tokens - token_count)
source_prompt += cap(block_str, max_tokens - token_count, encoder)
break
else:
source_prompt += block_str
token_count += block_tc
return source_prompt.strip()


def prompt_history(history: Prompt, max_tokens: int, n_items=10) -> Prompt:
def prompt_history(history: Prompt, settings: Settings) -> Prompt:
max_tokens = settings.history_tokens
encoder = settings.encoder
token_count = 0
prompt = []

# Get the n_items last messages, starting from the last one. This is because it's assumed
# that more recent messages are more important. The `-1` is because of how slicing works
messages = history[:-n_items - 1:-1]
messages = history[:-settings.maxHistory - 1:-1]
for message in messages:
if message["role"] == "user":
prompt.append({"role": "user", "content": "Q: " + message["content"]})
token_count += len(ENCODER.encode("Q: " + message["content"]))
token_count += len(encoder.encode("Q: " + message["content"]))
else:
content = message["content"]
# censor all source letters into [x]
content = re.sub(r"\[[0-9]+\]", "[x]", content)
content = cap(content, max_tokens - token_count)
content = cap(content, max_tokens - token_count, encoder)

prompt.append({"role": "assistant", "content": content})
token_count += len(ENCODER.encode(content))
token_count += len(encoder.encode(content))

if token_count > max_tokens:
break
return prompt[::-1]


def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt:
if mode not in PROMPT_MODES:
raise ValueError("Invalid mode: " + mode)

def construct_prompt(query: str, settings: Settings, history: Prompt, context: List[Block]) -> Prompt:
# History takes the format: history=[
# {"role": "user", "content": "Die monster. You don’t belong in this world!"},
# {"role": "assistant", "content": "It was not by my hand I am once again given flesh. I was called here by humans who wished to pay me tribute."},
@@ -138,14 +90,14 @@ def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block
# ]

# Context from top-k blocks
source_prompt = prompt_context(SOURCE_PROMPT, context, int(NUM_TOKENS * CONTEXT_FRACTION))
source_prompt = prompt_context(context, settings)
if history:
source_prompt += SOURCE_PROMPT_SUFFIX
source_prompt += settings.source_prompt_suffix
source_prompt = [{"role": "system", "content": source_prompt.strip()}]

# Write a version of the last 10 messages into history, cutting things off when we hit the token limit.
history_prompt = prompt_history(history, int(NUM_TOKENS * HISTORY_FRACTION))
question_prompt = [{"role": "user", "content": QUESTION_PROMPT + PROMPT_MODES[mode] + "Q: " + query}]
history_prompt = prompt_history(history, settings)
question_prompt = [{"role": "user", "content": settings.question_prompt(query)}]

return source_prompt + history_prompt + question_prompt

Expand All @@ -161,20 +113,21 @@ def check_openai_moderation(prompt: Prompt, query: str):
raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.")


def remaining_tokens(prompt: Prompt):
def remaining_tokens(prompt: Prompt, settings: Settings):
# Count number of tokens left for completion (-50 for a buffer)
encoder = settings.encoder
used_tokens = sum([
len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"]))
len(encoder.encode(message["content"]) + encoder.encode(message["role"]))
for message in prompt
])
return max(0, NUM_TOKENS - used_tokens - TOKENS_BUFFER)
return max(0, settings.numTokens - used_tokens - settings.tokensBuffer)


def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, session_id: str, k: int = STANDARD_K):
def talk_to_robot_internal(index, query: str, history: Prompt, session_id: str, settings: Settings=Settings()):
try:
# 1. Find the most relevant blocks from the Alignment Research Dataset
yield {"state": "loading", "phase": "semantic"}
top_k_blocks = get_top_k_blocks(index, query, k)
top_k_blocks = get_top_k_blocks(index, query, settings.topKBlocks)

yield {
"state": "citations",
@@ -186,22 +139,24 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio

# 2. Generate a prompt
yield {"state": "loading", "phase": "prompt"}
prompt = construct_prompt(query, mode, history, top_k_blocks)
prompt = construct_prompt(query, settings, history, top_k_blocks)

# 3. Run both the standalone query and the full prompt through
# moderation to see if it will be accepted by OpenAI's api
check_openai_moderation(prompt, query)

# 4. Count number of tokens left for completion (-50 for a buffer)
max_tokens_completion = remaining_tokens(prompt)
max_tokens_completion = remaining_tokens(prompt, settings)
if max_tokens_completion < 40:
raise ValueError(f"{max_tokens_completion} tokens left for the actual query after constructing the context - aborting, as that's not going to be enough")

# 5. Answer the user query
yield {"state": "loading", "phase": "llm"}
t1 = time.time()
response = ''

for chunk in openai.ChatCompletion.create(
model=COMPLETIONS_MODEL,
model=settings.completions,
messages=prompt,
max_tokens=max_tokens_completion,
stream=True,
@@ -242,15 +197,15 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio


# convert talk_to_robot_internal from dict generator into json generator
def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K):
yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k))
def talk_to_robot(index, query: str, history: List[Dict[str, str]], session_id: str, settings: Settings):
yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, history, session_id, settings))


# wayyy simplified api
def talk_to_robot_simple(index, query: str):
res = {'response': ''}

for block in talk_to_robot_internal(index, query, "default", []):
for block in talk_to_robot_internal(index, query, []):
if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block:
citations = {}
for i, c in enumerate(block['citations']):
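The `Settings` object threaded through this diff is defined in `api/src/stampy_chat/settings.py`, one of the changed files not shown here. Its real implementation isn't visible above, but the attributes the new code relies on can be read off the call sites; a rough sketch of the assumed interface, with illustrative defaults, might look like this:

```python
from dataclasses import dataclass

import tiktoken


@dataclass
class Settings:
    """Sketch of the interface chat.py relies on; the names come from the
    call sites above, but the defaults and internal layout are illustrative."""
    completions: str = "gpt-3.5-turbo"   # model passed to openai.ChatCompletion.create
    numTokens: int = 4095                # total token budget for the model
    tokensBuffer: int = 50               # margin reserved for the completion
    topKBlocks: int = 10                 # blocks fetched from the Pinecone index
    maxHistory: int = 10                 # past messages considered for the prompt
    source_prompt_prefix: str = "You are a helpful assistant..."   # system prompt (truncated here)
    source_prompt_suffix: str = "...sources from previous answers are invalid."
    question_prompt_text: str = "In your answer, please cite any claims...\n\n"

    @property
    def encoder(self):
        # tokeniser used for all token counting
        return tiktoken.get_encoding("cl100k_base")

    @property
    def context_tokens(self) -> int:
        # roughly half the budget for source context (mirrors the old CONTEXT_FRACTION)
        return int(self.numTokens * 0.5)

    @property
    def history_tokens(self) -> int:
        # roughly a quarter of the budget for history (mirrors the old HISTORY_FRACTION)
        return int(self.numTokens * 0.25)

    def question_prompt(self, query: str) -> str:
        return self.question_prompt_text + "Q: " + query
```

Anything beyond those attribute names (the property split, the exact defaults, the prompt text) is a guess; see settings.py in the PR for the real definition.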
api/src/stampy_chat/get_blocks.py (12 changes: 1 addition & 11 deletions)
@@ -71,7 +71,7 @@ def parse_block(match) -> Block:
date = date,
url = metadata['url'],
tags = metadata.get('tags'),
text = strip_block(metadata['text'])
text = metadata['text']
)


@@ -144,13 +144,3 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')

return join_blocks(blocks)


# we add the title and authors inside the contents of the block, so that
# searches for the title or author will be more likely to pull it up. This
# strips it back out.
def strip_block(text: str) -> str:
r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
if not r:
logger.warning("couldn't strip block:\n%s", text)
return r.group(1) if r else text
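For reference, the removed `strip_block` helper stripped off the title/author suffix that gets baked into each block's text at indexing time; after this commit the block text is passed through unchanged. A small illustration of what the old helper did (the sample string is made up):

```python
import re

# Hypothetical block text in the '"<content>" - Title: ...' shape the old regex expected.
sample = '"Inner alignment failures are hard to detect." - Title: Risks from Learned Optimization - Authors: Hubinger et al.'

match = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", sample, re.DOTALL)
print(match.group(1))  # old behaviour: just the quoted content
print(sample)          # new behaviour: the full string is kept as-is
```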