Commit 491b188

Add anthropic support
mruwnik committed Mar 23, 2024
1 parent f96f275 commit 491b188
Showing 8 changed files with 1,252 additions and 1,069 deletions.
1 change: 1 addition & 0 deletions api/Pipfile
@@ -27,6 +27,7 @@ sqlalchemy = "*"
mysql-connector-python = "*"
langchain = "*"
transformers = "*"
langchain-anthropic = "*"

[dev-packages]
pytest = "*"
2,179 changes: 1,154 additions & 1,025 deletions api/Pipfile.lock

Large diffs are not rendered by default.

71 changes: 54 additions & 17 deletions api/src/stampy_chat/chat.py
@@ -1,20 +1,20 @@
import os
from typing import Any, Callable, Dict, List

from langchain.chains import LLMChain, OpenAIModerationChain
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain.memory import ChatMessageHistory, ConversationSummaryBufferMemory
from langchain.prompts import (
BaseChatPromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
FewShotChatMessagePromptTemplate
FewShotChatMessagePromptTemplate,
HumanMessagePromptTemplate
)
from langchain.pydantic_v1 import Extra
from langchain.schema import BaseMessage, ChatMessage, PromptValue, SystemMessage
from langchain.schema import AIMessage, BaseMessage, HumanMessage, PromptValue, SystemMessage

from stampy_chat.env import OPENAI_API_KEY, COMPLETIONS_MODEL, LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2
from stampy_chat.settings import Settings
from stampy_chat.env import OPENAI_API_KEY, ANTHROPIC_API_KEY, COMPLETIONS_MODEL, LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2
from stampy_chat.settings import Settings, MODELS, OPENAI, ANTHROPIC
from stampy_chat.callbacks import StampyCallbackHandler, BroadcastCallbackHandler, LoggerCallbackHandler
from stampy_chat.followups import StampyChain
from stampy_chat.citations import make_example_selector
@@ -61,8 +61,14 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
return messages


def ChatMessage(m):
if m['role'] == 'assistant':
return AIMessage(**m)
return HumanMessage(**m)


class PrefixedPrompt(BaseChatPromptTemplate):
"""A prompt that will prefix any messages with a system prompt, but only if messages provided."""
"""A prompt that will prefix any messages with a system prompt, but only if messages are provided."""

transformer: Callable[[Any], BaseMessage] = lambda i: i
messages_field: str
@@ -71,7 +77,7 @@ class PrefixedPrompt(BaseChatPromptTemplate):
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
history = kwargs[self.messages_field]
if history and self.prompt:
return [SystemMessage(content=self.prompt)] + [self.transformer(i) for i in history]
return [AIMessage(content=self.prompt)] + [self.transformer(i) for i in history]
return []


@@ -96,7 +102,7 @@ def set_messages(self, history: List[dict]) -> None:
for callback in self.callbacks:
callback.on_memory_set_start(history)

messages = [ChatMessage(**m) for m in history]
messages = [ChatMessage(m) for m in history if m.get('role') != 'deleted']
# If there are more than `max_history` messages, first summarize the older ones. If there
# are n messages (where n > max_history), then the first `n - max_history + 1` should be
# summarized and inserted as the first item in the history, so as to ensure there are
@@ -105,7 +111,7 @@ def set_messages(self, history: List[dict]) -> None:
offset = -self.max_history + 1

pruned = messages[:offset]
summary = ChatMessage(role='assistant', content=self.predict_new_summary(pruned, ''))
summary = AIMessage(role='assistant', content=self.predict_new_summary(pruned, ''))

messages = [summary] + messages[offset:]

@@ -160,8 +166,37 @@ def format_prompt(self, **kwargs: Any) -> PromptValue:
return prompt


class ChatAnthropicWrapper(ChatAnthropic):
"""Make sure the Anthropic endpoint can handle prompts.
Anthropic only allows alternating human - ai messages, so join them up first.
So much for langchain being plug'n'play...
"""
def _format_params(self, *args, **kwargs):
first = kwargs['messages'][0]
# Anthropic requires the first message to be either a system or human message
if isinstance(first, AIMessage):
first = SystemMessage(content=first.content)

messages = [first]
for m in kwargs['messages'][1:]:
if m.type != messages[-1].type:
messages.append(m)
else:
messages[-1].content += '\n\n' + m.content
kwargs['messages'] = messages
return super()._format_params(*args, **kwargs)


def get_model(**kwargs):
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs)
model = MODELS.get(kwargs.get('model'))
if not model:
raise ValueError("No model provided")
if model.publisher == ANTHROPIC:
return ChatAnthropicWrapper(anthropic_api_key=ANTHROPIC_API_KEY, **kwargs)
if model.publisher == OPENAI:
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs)
raise ValueError(f'Unsupported model: {kwargs.get("model")}')


class LLMInputsChain(LLMChain):
@@ -191,7 +226,7 @@ def make_history_summary(settings):
input_variables=['history'],
messages_field='history',
prompt=settings.history_summary_prompt,
transformer=lambda m: ChatMessage(**m)
transformer=ChatMessage,
)
return LLMInputsChain(
llm=model,
@@ -200,9 +235,9 @@ def make_history_summary(settings):
prompt=ModeratedChatPrompt.from_messages([
summary_prompt,
ChatPromptTemplate.from_messages([
ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'),
HumanMessagePromptTemplate.from_template(template='Q: {query}', role='user'),
]),
SystemMessage(content="Reply in one sentence only"),
HumanMessage(content="Reply in one sentence only"),
]),
)

@@ -225,8 +260,8 @@ def make_prompt(settings, chat_model, callbacks):
# 3. Construct the main query
query_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=settings.question_prompt),
ChatMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'),
HumanMessage(content=settings.question_prompt),
HumanMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'),
]
)

@@ -267,6 +302,8 @@ def merge_history(history):
messages = []
current_message = history[0]
for message in history[1:]:
if message.get('role') == 'deleted':
continue
if message.get('role') != current_message.get('role'):
messages.append(current_message)
current_message = message
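
Taken together, the chat.py changes make get_model() dispatch on the publisher recorded in MODELS, with ChatAnthropicWrapper flattening back-to-back same-role messages because the Anthropic API only accepts alternating human/assistant turns. A minimal usage sketch (the kwargs and message contents are illustrative, not taken from this commit):

from stampy_chat.chat import get_model

gpt = get_model(model="gpt-4")                      # routed to ChatOpenAI
claude = get_model(model="claude-3-opus-20240229")  # routed to ChatAnthropicWrapper
# get_model(model="not-a-model")                    # raises ValueError

# Inside the wrapper, a history such as
#   [AIMessage("context"), HumanMessage("part 1"), HumanMessage("part 2")]
# is reshaped to roughly
#   [SystemMessage("context"), HumanMessage("part 1\n\npart 2")]
# before the request is sent to the Anthropic endpoint.
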
7 changes: 5 additions & 2 deletions api/src/stampy_chat/citations.py
@@ -9,7 +9,7 @@
SemanticSimilarityExampleSelector
)
from langchain.pydantic_v1 import Extra
from langchain.vectorstores import Pinecone
from langchain_community.vectorstores import Pinecone

from stampy_chat.env import PINECONE_INDEX, PINECONE_NAMESPACE, OPENAI_API_KEY, REMOTE_CHAT_INSTANCE
from stampy_chat.callbacks import StampyCallbackHandler
@@ -89,7 +89,10 @@ def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
for item in history[::-1]:
if len(examples) >= self.k:
break
examples += self.fetch_docs({'answer': item.content})
if isinstance(item, dict):
examples += self.fetch_docs({'answer': item['content']})
else:
examples += self.fetch_docs({'answer': item.content})

examples = [
dict(
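
The new isinstance branch lets select_examples() accept history items either as plain dicts or as langchain message objects; both end up in the same fetch_docs call. A small illustration (contents made up):

from langchain.schema import AIMessage

history = [
    {'role': 'user', 'content': 'What is instrumental convergence?'},  # dict -> item['content']
    AIMessage(content='It is the tendency of different goals to...'),  # message -> item.content
]
# either shape is passed on as fetch_docs({'answer': <that text>})
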
7 changes: 4 additions & 3 deletions api/src/stampy_chat/env.py
@@ -1,6 +1,6 @@
import os
# import openai
import pinecone
from pinecone import Pinecone

if os.path.exists('.env'):
from dotenv import load_dotenv
@@ -16,6 +16,7 @@

### OpenAI ###
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY')

### Models ###
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-ada-002")
@@ -30,12 +31,12 @@

# Only init pinecone if we have an env value for it.
if PINECONE_API_KEY:
pinecone.init(
pc = Pinecone(
api_key = PINECONE_API_KEY,
environment = PINECONE_ENVIRONMENT,
)

PINECONE_INDEX = pinecone.Index(index_name=PINECONE_INDEX_NAME)
PINECONE_INDEX = pc.Index(PINECONE_INDEX_NAME)

### MySQL ###
user = os.environ.get("CHAT_DB_USER", "user")
19 changes: 13 additions & 6 deletions api/src/stampy_chat/settings.py
@@ -4,7 +4,7 @@
from stampy_chat.env import COMPLETIONS_MODEL


Model = namedtuple('Model', ['maxTokens', 'topKBlocks', 'maxCompletionTokens'])
Model = namedtuple('Model', ['maxTokens', 'topKBlocks', 'maxCompletionTokens', 'publisher'])


SOURCE_PROMPT = (
@@ -58,12 +58,19 @@
'question': QUESTION_PROMPT,
'modes': PROMPT_MODES,
}
OPENAI = 'openai'
ANTHROPIC = 'anthropic'
MODELS = {
'gpt-3.5-turbo': Model(4097, 10, 4096),
'gpt-3.5-turbo-16k': Model(16385, 30, 4096),
'gpt-4': Model(8192, 20, 4096),
"gpt-4-1106-preview": Model(128000, 50, 4096),
# 'gpt-4-32k': Model(32768, 30),
'gpt-3.5-turbo': Model(4097, 10, 4096, OPENAI),
'gpt-3.5-turbo-16k': Model(16385, 30, 4096, OPENAI),
'gpt-4': Model(8192, 20, 4096, OPENAI),
"gpt-4-turbo-preview": Model(128000, 50, 4096, OPENAI),
"claude-3-opus-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
"claude-3-sonnet-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
"claude-3-haiku-20240307": Model(200_000, 50, 4096, ANTRHROPIC),
"claude-2.1": Model(200_000, 50, 4096, ANTRHROPIC),
"claude-2.0": Model(100_000, 50, 4096, ANTRHROPIC),
"claude-instant-1.2": Model(100_000, 50, 4096, ANTRHROPIC),
}


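
With the extra publisher field, each MODELS entry carries enough information for chat.get_model() to pick the right client. A short sketch of reading one entry (values copied from the table above):

from stampy_chat.settings import MODELS, ANTHROPIC

model = MODELS['claude-3-opus-20240229']
model.maxTokens               # 200_000-token context window
model.topKBlocks              # 50 source blocks to retrieve
model.maxCompletionTokens     # 4096 tokens for the completion
model.publisher == ANTHROPIC  # True, so get_model() builds a ChatAnthropicWrapper
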
28 changes: 14 additions & 14 deletions api/tests/stampy_chat/test_chat.py
@@ -2,7 +2,7 @@
from langchain.llms.fake import FakeListLLM
from langchain.memory import ChatMessageHistory
from langchain.prompts import ChatPromptTemplate
from langchain.schema import ChatMessage, HumanMessage, SystemMessage
from langchain.schema import AIMessage, HumanMessage

from stampy_chat.settings import Settings
from stampy_chat.callbacks import StampyCallbackHandler
@@ -62,7 +62,7 @@ def test_PrefixedPrompt_format_messages():
prompt = PrefixedPrompt(messages_field='history', prompt='bla bla bla', input_variables=[])
history = [HumanMessage(content=f'human message {i}') for i in range(5)]
assert prompt.format_messages(history=history) == [
SystemMessage(content='bla bla bla'),
AIMessage(content='bla bla bla'),
HumanMessage(content='human message 0'),
HumanMessage(content='human message 1'),
HumanMessage(content='human message 2'),
@@ -94,8 +94,8 @@ def test_LimitedConversationSummaryBufferMemory_set():
{'content': 'bla bla bla', 'role': 'human'},
])
assert memory.chat_memory == ChatMessageHistory(messages=[
ChatMessage(content='a system message', role='system'),
ChatMessage(content='bla bla bla', role='human'),
HumanMessage(content='a system message', role='system'),
HumanMessage(content='bla bla bla', role='human'),
])


@@ -112,10 +112,10 @@ def test_LimitedConversationSummaryBufferMemory_set_more():
{'content': 'message 5 - should be kept', 'role': 'human'},
])
assert memory.chat_memory == ChatMessageHistory(messages=[
ChatMessage(content='this is a summary of what was before', role='assistant'),
ChatMessage(content='message 3 - should be kept', role='human'),
ChatMessage(content='message 4 - should be kept', role='human'),
ChatMessage(content='message 5 - should be kept', role='human'),
AIMessage(content='this is a summary of what was before', role='assistant'),
HumanMessage(content='message 3 - should be kept', role='human'),
HumanMessage(content='message 4 - should be kept', role='human'),
HumanMessage(content='message 5 - should be kept', role='human'),
])


@@ -137,8 +137,8 @@ def on_memory_set_end(self, messages):

memory.set_messages(history)
assert memory.chat_memory == ChatMessageHistory(messages=[
ChatMessage(content='a system message', role='system'),
ChatMessage(content='bla bla bla', role='human'),
HumanMessage(content='a system message', role='system'),
HumanMessage(content='bla bla bla', role='human'),
])
assert callback_calls == {
'start': history,
@@ -157,9 +157,9 @@ def test_make_memory_skips_deleted():
with patch('stampy_chat.chat.get_model', return_value=FakeListLLM(responses=[])):
mem = make_memory(Settings(), history, [])
assert mem.chat_memory == ChatMessageHistory(messages=[
ChatMessage(content='this should be kept', role='system'),
ChatMessage(content='as should this', role='human'),
ChatMessage(content='bla bla bla', role='assistant'),
HumanMessage(content='this should be kept', role='system'),
HumanMessage(content='as should this', role='human'),
AIMessage(content='bla bla bla', role='assistant'),
])


@@ -175,7 +175,7 @@ def test_merge_history_no_merges():
{'content': 'bla bla bla', 'role': 'assistant'},
{'content': 'remove me!!', 'role': 'deleted'},
]
assert merge_history(history) == history
assert merge_history(history) == [m for m in history if m['role'] != 'deleted']


def test_merge_history_merges():
9 changes: 7 additions & 2 deletions web/src/hooks/useSettings.ts
@@ -56,8 +56,13 @@ export const MODELS: { [key: string]: Model } = {
"gpt-3.5-turbo": { maxNumTokens: 4095, topKBlocks: 10 },
"gpt-3.5-turbo-16k": { maxNumTokens: 16385, topKBlocks: 30 },
"gpt-4": { maxNumTokens: 8192, topKBlocks: 20 },
"gpt-4-1106-preview": { maxNumTokens: 128000, topKBlocks: 50 },
/* 'gpt-4-32k': {maxNumTokens: 32768, topKBlocks: 30}, */
"gpt-4-turbo-preview": { maxNumTokens: 128000, topKBlocks: 50 },
"claude-3-opus-20240229": { maxNumTokens: 200000, topKBlocks: 50},
"claude-3-sonnet-20240229": { maxNumTokens: 200_000, topKBlocks: 50},
"claude-3-haiku-20240307": { maxNumTokens: 200_000, topKBlocks: 50},
"claude-2.1": { maxNumTokens: 200_000, topKBlocks: 50},
"claude-2.0": { maxNumTokens: 100_000, topKBlocks: 50},
"claude-instant-1.2": { maxNumTokens: 100_000, topKBlocks: 50},
};
export const ENCODERS = ["cl100k_base"];

