Chat options #111
@@ -6,128 +6,80 @@
 from typing import List, Dict

 import openai
 import tiktoken

-from stampy_chat.env import COMPLETIONS_MODEL
-from stampy_chat import logging
 from stampy_chat.followups import multisearch_authored
 from stampy_chat.get_blocks import get_top_k_blocks, Block
+from stampy_chat import logging
+from stampy_chat.settings import Settings


 logger = logging.getLogger(__name__)


-STANDARD_K = 20 if COMPLETIONS_MODEL == 'gpt-4' else 10
-
-# parameters
-
-# NOTE: All this is approximate, there's bits I'm intentionally not counting. Leave a buffer beyond what you might expect.
-NUM_TOKENS = 8191 if COMPLETIONS_MODEL == 'gpt-4' else 4095
-TOKENS_BUFFER = 50 # the number of tokens to leave as a buffer when calculating remaining tokens
-HISTORY_FRACTION = 0.25 # the (approximate) fraction of num_tokens to use for history text before truncating
-CONTEXT_FRACTION = 0.5 # the (approximate) fraction of num_tokens to use for context text before truncating
-
-ENCODER = tiktoken.get_encoding("cl100k_base")
-
-SOURCE_PROMPT = (
-    "You are a helpful assistant knowledgeable about AI Alignment and Safety. "
-    "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") "
-    "using the following sources. Each source is labeled with a letter. Feel free to "
-    "use the sources in any order, and try to use multiple sources in your answers.\n\n"
-)
-SOURCE_PROMPT_SUFFIX = (
-    "\n\n"
-    "Before the question (\"Q: \"), there will be a history of previous questions and answers. "
-    "These sources only apply to the last question. Any sources used in previous answers "
-    "are invalid."
-)
-
-QUESTION_PROMPT = (
-    "In your answer, please cite any claims you make back to each source "
-    "using the format: [a], [b], etc. If you use multiple sources to make a claim "
-    "cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n"
-)
-PROMPT_MODES = {
-    'default': "",
-    "concise": (
-        "Answer very concisely, getting to the crux of the matter in as "
-        "few words as possible. Limit your answer to 1-2 sentences.\n\n"
-    ),
-    "rookie": (
-        "This user is new to the field of AI Alignment and Safety - don't "
-        "assume they know any technical terms or jargon. Still give a complete answer "
-        "without patronizing the user, but take any extra time needed to "
-        "explain new concepts or to illustrate your answer with examples. "
-        "Put extra effort into explaining the intuition behind concepts "
-        "rather than just giving a formal definition.\n\n"
-    ),
-}

 # --------------------------------- prompt code --------------------------------


 # limit a string to a certain number of tokens
-def cap(text: str, max_tokens: int) -> str:
+def cap(text: str, max_tokens: int, encoder) -> str:

Review comment: the other changes in this file are basically to get it to use the settings object, rather than various constants.

     if max_tokens <= 0:
         return "..."

-    encoded_text = ENCODER.encode(text)
+    encoded_text = encoder.encode(text)

     if len(encoded_text) <= max_tokens:
         return text
-    return ENCODER.decode(encoded_text[:max_tokens]) + " ..."
+    return encoder.decode(encoded_text[:max_tokens]) + " ..."
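A quick sketch of how the reworked cap() might be exercised now that the encoder is injected rather than read from the module-level ENCODER constant. The cl100k_base encoding is only an assumption carried over from the old constant, and cap is the function defined just above (not imported here):

import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")  # assumed encoding, matching the old ENCODER constant
text = "A fairly long passage about AI alignment. " * 200

capped = cap(text, 100, encoder)        # cap() as defined in this diff
print(len(encoder.encode(capped)))      # roughly 100, plus a couple of tokens for the " ..." suffix
print(cap("short text", 100, encoder))  # short inputs come back unchanged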

 Prompt = List[Dict[str, str]]


-def prompt_context(source_prompt: str, context: List[Block], max_tokens: int) -> str:
-    token_count = len(ENCODER.encode(source_prompt))
+def prompt_context(context: List[Block], settings: Settings) -> str:
+    source_prompt = settings.source_prompt_prefix
+    max_tokens = settings.context_tokens
+    encoder = settings.encoder
+
+    token_count = len(encoder.encode(source_prompt))

     # Context from top-k blocks
     for i, block in enumerate(context):
         block_str = f"[{chr(ord('a') + i)}] {block.title} - {','.join(block.authors)} - {block.date}\n{block.text}\n\n"
-        block_tc = len(ENCODER.encode(block_str))
+        block_tc = len(encoder.encode(block_str))

         if token_count + block_tc > max_tokens:
-            source_prompt += cap(block_str, max_tokens - token_count)
+            source_prompt += cap(block_str, max_tokens - token_count, encoder)
             break
         else:
             source_prompt += block_str
             token_count += block_tc
     return source_prompt.strip()
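For illustration, roughly how the reworked prompt_context could be driven. The Block keyword arguments and the no-argument Settings() construction are assumptions based only on what this diff shows (settings must expose source_prompt_prefix, context_tokens and encoder), and the sample block data is invented:

from stampy_chat.get_blocks import Block
from stampy_chat.settings import Settings

blocks = [
    Block(title="Risks from Learned Optimization", authors=["Evan Hubinger"],
          date="2019-06-01", url="https://example.com", tags=None,
          text="Mesa-optimization can occur when a learned model is itself an optimizer ..."),
]
settings = Settings()

# prompt_context as defined above: the prefix text, then "[a] title - authors - date\ntext"
# per block, truncated once settings.context_tokens is exhausted.
context_text = prompt_context(blocks, settings)
print(context_text)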

-def prompt_history(history: Prompt, max_tokens: int, n_items=10) -> Prompt:
+def prompt_history(history: Prompt, settings: Settings) -> Prompt:
+    max_tokens = settings.history_tokens
+    encoder = settings.encoder
     token_count = 0
     prompt = []

     # Get the n_items last messages, starting from the last one. This is because it's assumed
     # that more recent messages are more important. The `-1` is because of how slicing works
-    messages = history[:-n_items - 1:-1]
+    messages = history[:-settings.maxHistory - 1:-1]
     for message in messages:
         if message["role"] == "user":
             prompt.append({"role": "user", "content": "Q: " + message["content"]})
-            token_count += len(ENCODER.encode("Q: " + message["content"]))
+            token_count += len(encoder.encode("Q: " + message["content"]))
         else:
             content = message["content"]
             # censor all source letters into [x]
             content = re.sub(r"\[[0-9]+\]", "[x]", content)
-            content = cap(content, max_tokens - token_count)
+            content = cap(content, max_tokens - token_count, encoder)

             prompt.append({"role": "assistant", "content": content})
-            token_count += len(ENCODER.encode(content))
+            token_count += len(encoder.encode(content))

         if token_count > max_tokens:
             break
     return prompt[::-1]
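The slicing in prompt_history is easy to misread, so here is a tiny standalone illustration of what history[:-maxHistory - 1:-1] and the closing prompt[::-1] do (maxHistory = 3 is just an example value):

history = ["m1", "m2", "m3", "m4", "m5"]
max_history = 3

# Walk backwards from the newest message, keeping at most max_history items.
newest_first = history[:-max_history - 1:-1]
print(newest_first)        # ['m5', 'm4', 'm3']

# prompt_history builds its list newest-first, then flips it back into
# chronological order before returning - the same as `return prompt[::-1]`.
print(newest_first[::-1])  # ['m3', 'm4', 'm5']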

-def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt:
-    if mode not in PROMPT_MODES:
-        raise ValueError("Invalid mode: " + mode)
+def construct_prompt(query: str, settings: Settings, history: Prompt, context: List[Block]) -> Prompt:
     # History takes the format: history=[
     #     {"role": "user", "content": "Die monster. You don’t belong in this world!"},
     #     {"role": "assistant", "content": "It was not by my hand I am once again given flesh. I was called here by humans who wished to pay me tribute."},

@@ -138,14 +90,14 @@ def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block
     #     ]

     # Context from top-k blocks
-    source_prompt = prompt_context(SOURCE_PROMPT, context, int(NUM_TOKENS * CONTEXT_FRACTION))
+    source_prompt = prompt_context(context, settings)
     if history:
-        source_prompt += SOURCE_PROMPT_SUFFIX
+        source_prompt += settings.source_prompt_suffix
     source_prompt = [{"role": "system", "content": source_prompt.strip()}]

     # Write a version of the last 10 messages into history, cutting things off when we hit the token limit.
-    history_prompt = prompt_history(history, int(NUM_TOKENS * HISTORY_FRACTION))
-    question_prompt = [{"role": "user", "content": QUESTION_PROMPT + PROMPT_MODES[mode] + "Q: " + query}]
+    history_prompt = prompt_history(history, settings)
+    question_prompt = [{"role": "user", "content": settings.question_prompt(query)}]

     return source_prompt + history_prompt + question_prompt
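The list construct_prompt returns is an OpenAI-style message list: one system message carrying the labelled sources, the truncated history, and finally the user turn produced by settings.question_prompt(query). Roughly, with contents abbreviated and the exact wording living in Settings:

prompt = [
    {"role": "system", "content": "You are a helpful assistant ... [a] Some source ... [b] Another source ..."},
    {"role": "user", "content": "Q: What is instrumental convergence?"},
    {"role": "assistant", "content": "Instrumental convergence is the idea that ... [x]"},
    {"role": "user", "content": "<question prompt text> Q: How does that relate to corrigibility?"},
]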

@@ -161,20 +113,21 @@ def check_openai_moderation(prompt: Prompt, query: str):
         raise ValueError("This conversation was rejected by OpenAI's moderation filter. Sorry.")


-def remaining_tokens(prompt: Prompt):
+def remaining_tokens(prompt: Prompt, settings: Settings):
     # Count number of tokens left for completion (-50 for a buffer)
+    encoder = settings.encoder
     used_tokens = sum([
-        len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"]))
+        len(encoder.encode(message["content"]) + encoder.encode(message["role"]))
         for message in prompt
     ])
-    return max(0, NUM_TOKENS - used_tokens - TOKENS_BUFFER)
+    return max(0, settings.numTokens - used_tokens - settings.tokensBuffer)
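A worked example of the budget arithmetic, using the old gpt-4 numbers (8191-token window, 50-token buffer) purely as illustrative values now that both come from Settings:

num_tokens = 8191    # settings.numTokens
tokens_buffer = 50   # settings.tokensBuffer
used_tokens = 6500   # hypothetical size of the assembled prompt (content plus role tokens)

max_tokens_completion = max(0, num_tokens - used_tokens - tokens_buffer)
print(max_tokens_completion)  # 1641 tokens left for the completion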

-def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, session_id: str, k: int = STANDARD_K):
+def talk_to_robot_internal(index, query: str, history: Prompt, session_id: str, settings: Settings=Settings()):
     try:
         # 1. Find the most relevant blocks from the Alignment Research Dataset
         yield {"state": "loading", "phase": "semantic"}
-        top_k_blocks = get_top_k_blocks(index, query, k)
+        top_k_blocks = get_top_k_blocks(index, query, settings.topKBlocks)

         yield {
             "state": "citations",

@@ -186,22 +139,24 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio
         # 2. Generate a prompt
         yield {"state": "loading", "phase": "prompt"}
-        prompt = construct_prompt(query, mode, history, top_k_blocks)
+        prompt = construct_prompt(query, settings, history, top_k_blocks)

         # 3. Run both the standalone query and the full prompt through
         # moderation to see if it will be accepted by OpenAI's api
         check_openai_moderation(prompt, query)

         # 4. Count number of tokens left for completion (-50 for a buffer)
-        max_tokens_completion = remaining_tokens(prompt)
+        max_tokens_completion = remaining_tokens(prompt, settings)
         if max_tokens_completion < 40:
             raise ValueError(f"{max_tokens_completion} tokens left for the actual query after constructing the context - aborting, as that's not going to be enough")

Review comment: should this be changed? I arbitrarily put 40, but 100 would probably also be too few.

Review comment: maybe it should be a class or instance variable? Also, what is the trade-off here? More tokens cost more per response, but have a better chance of giving the user what we want?

Review comment: There's already a … This prompt can get quite big, taking up a lot of the …
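One way to act on the comments above would be to promote the hard-coded 40 into the Settings object, as the second comment suggests. A sketch only; the field name and default are invented for illustration and are not part of this diff:

# Hypothetical Settings field (name and default made up):
#     min_response_tokens: int = 40
if max_tokens_completion < settings.min_response_tokens:
    raise ValueError(
        f"{max_tokens_completion} tokens left for the actual query after "
        "constructing the context - aborting, as that's not going to be enough")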

         # 5. Answer the user query
         yield {"state": "loading", "phase": "llm"}
         t1 = time.time()
         response = ''

         for chunk in openai.ChatCompletion.create(
-            model=COMPLETIONS_MODEL,
+            model=settings.completions,
             messages=prompt,
             max_tokens=max_tokens_completion,
             stream=True,

@@ -242,15 +197,15 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, sessio

 # convert talk_to_robot_internal from dict generator into json generator
-def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K):
-    yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k))
+def talk_to_robot(index, query: str, history: List[Dict[str, str]], session_id: str, settings: Settings):
+    yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, history, session_id, settings))
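Since talk_to_robot now yields JSON strings, one per state update, a caller (for example a streaming HTTP endpoint) can forward them as they arrive. A minimal consumption sketch; the index handle, session id and query are placeholders, and Settings() is assumed to be constructible with defaults, as in talk_to_robot_internal:

import json

from stampy_chat.settings import Settings

settings = Settings()
# `index` stands in for whatever vector-store handle the app already passes around.
for chunk in talk_to_robot(index, "What is AI alignment?", history=[], session_id="example-session", settings=settings):
    update = json.loads(chunk)
    if update.get("state") == "loading":
        print("phase:", update.get("phase"))  # semantic -> prompt -> llm
    elif update.get("state") == "citations":
        print("got citations")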

 # wayyy simplified api
 def talk_to_robot_simple(index, query: str):
     res = {'response': ''}

-    for block in talk_to_robot_internal(index, query, "default", []):
+    for block in talk_to_robot_internal(index, query, []):
         if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block:
             citations = {}
             for i, c in enumerate(block['citations']):

Second changed file (the get_blocks module, judging from the functions touched):

@@ -71,7 +71,7 @@ def parse_block(match) -> Block:
         date = date,
         url = metadata['url'],
         tags = metadata.get('tags'),
-        text = strip_block(metadata['text'])
+        text = metadata['text']
     )


@@ -144,13 +144,3 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
     logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')

     return join_blocks(blocks)
-
-
-# we add the title and authors inside the contents of the block, so that
-# searches for the title or author will be more likely to pull it up. This
-# strips it back out.
-def strip_block(text: str) -> str:
-    r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
-    if not r:
-        logger.warning("couldn't strip block:\n%s", text)
-    return r.group(1) if r else text

Review comment: this continuously raises warnings that it can't find the titles. Is it even used anymore? Wasn't that the whole point of adding the other fields?

Review comment: maybe just change the logic to …
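For reference, this is the behaviour being deleted: strip_block only recovers the quoted body when the text still carries the appended " - Title: ..." suffix, which is why it warns so often on data that never had one. A small standalone illustration of the regex, with the sample strings invented and the warning dropped for brevity:

import re

def strip_block(text: str) -> str:
    r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
    return r.group(1) if r else text

print(strip_block('"Some block text" - Title: My Post - Authors: Someone'))
# -> Some block text
print(strip_block("Plain block text with no embedded title"))
# -> Plain block text with no embedded title   (no match, returned unchanged)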
Review comment: all of these are moved over into the Settings class to be able to pass them around.
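To make the new call sites concrete, here is a rough sketch of the shape the Settings object appears to have, inferred only from the attributes this diff reads. The attribute names come from the diff; the types, defaults and prompt wording are guesses, and the real class lives in stampy_chat.settings:

import tiktoken
from dataclasses import dataclass, field


@dataclass
class SettingsSketch:
    """Illustrative stand-in for stampy_chat.settings.Settings - not the real implementation."""
    completions: str = "gpt-3.5-turbo"   # model name handed to openai.ChatCompletion.create
    numTokens: int = 4095                # total token budget for prompt plus completion
    tokensBuffer: int = 50               # safety margin kept unused
    topKBlocks: int = 10                 # how many blocks to fetch from the dataset
    maxHistory: int = 10                 # how many past messages to keep
    context_tokens: int = 2048           # budget for the sources section
    history_tokens: int = 1024           # budget for the history section
    source_prompt_prefix: str = "You are a helpful assistant knowledgeable about AI Alignment and Safety. ..."
    source_prompt_suffix: str = "\n\nBefore the question (\"Q: \"), there will be a history of previous questions and answers. ..."
    question_prompt_text: str = "In your answer, please cite any claims you make back to each source ..."
    encoder: tiktoken.Encoding = field(default_factory=lambda: tiktoken.get_encoding("cl100k_base"))

    def question_prompt(self, query: str) -> str:
        # Mirrors how the old code glued QUESTION_PROMPT, the mode text and "Q: " together.
        return self.question_prompt_text + "Q: " + query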