Skip to content

Commit

Permalink
add gpt-4o
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed May 16, 2024
1 parent d79f926 commit 9a12be3
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 32 deletions.
4 changes: 2 additions & 2 deletions api/src/stampy_chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from langchain.pydantic_v1 import Extra
from langchain.schema import AIMessage, BaseMessage, HumanMessage, PromptValue, SystemMessage

from stampy_chat.env import OPENAI_API_KEY, ANTHROPIC_API_KEY, COMPLETIONS_MODEL, LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2
from stampy_chat.env import OPENAI_API_KEY, ANTHROPIC_API_KEY, LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2, SUMMARY_MODEL
from stampy_chat.settings import Settings, MODELS, OPENAI, ANTRHROPIC
from stampy_chat.callbacks import StampyCallbackHandler, BroadcastCallbackHandler, LoggerCallbackHandler
from stampy_chat.followups import StampyChain
Expand Down Expand Up @@ -277,7 +277,7 @@ def make_prompt(settings, chat_model, callbacks):
def make_memory(settings, history, callbacks):
"""Create a memory object to store the chat history."""
memory = LimitedConversationSummaryBufferMemory(
llm=get_model(model=COMPLETIONS_MODEL), # used for summarization
llm=get_model(model=SUMMARY_MODEL), # used for summarization
max_token_limit=settings.history_tokens,
max_history=settings.maxHistory,
chat_memory=ChatMessageHistory(),
Expand Down
1 change: 1 addition & 0 deletions api/src/stampy_chat/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

### Models ###
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-ada-002")
SUMMARY_MODEL = os.environ.get("COMPLETIONS_MODEL", "claude-3-sonnet-20240229")
COMPLETIONS_MODEL = os.environ.get("COMPLETIONS_MODEL", "claude-3-opus-20240229")

### Pinecone ###
Expand Down
1 change: 1 addition & 0 deletions api/src/stampy_chat/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
'gpt-3.5-turbo-16k': Model(16385, 30, 4096, OPENAI),
'gpt-4': Model(8192, 20, 4096, OPENAI),
"gpt-4-turbo-preview": Model(128000, 50, 4096, OPENAI),
"gpt-4o": Model(128000, 50, 4096, OPENAI),
"claude-3-opus-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
"claude-3-sonnet-20240229": Model(200_000, 50, 4096, ANTRHROPIC),
"claude-3-haiku-20240307": Model(200_000, 50, 4096, ANTRHROPIC),
Expand Down
59 changes: 29 additions & 30 deletions api/tests/stampy_chat/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,19 @@ def selector():


@pytest.mark.parametrize('num, letter', (
(0, 'a'), (25, 'z'),
(26, '{'), # this is a basic ASCII translator, so too many citations will result in fun
(0, '1'), (25, '26'),
))
def test_ReferencesSelector_make_references(num, letter):
assert ReferencesSelector.make_reference(num) == letter


def test_ReferencesSelector_select_examples(selector):
assert selector.select_examples(input_variables={}) == [
{'bla': 'bla 0', 'id': '0', 'reference': 'a'},
{'bla': 'bla 1', 'id': '1', 'reference': 'b'},
{'bla': 'bla 2', 'id': '2', 'reference': 'c'},
{'bla': 'bla 3', 'id': '3', 'reference': 'd'},
{'bla': 'bla 4', 'id': '4', 'reference': 'e'},
{'bla': 'bla 0', 'id': '0', 'reference': '1'},
{'bla': 'bla 1', 'id': '1', 'reference': '2'},
{'bla': 'bla 2', 'id': '2', 'reference': '3'},
{'bla': 'bla 3', 'id': '3', 'reference': '4'},
{'bla': 'bla 4', 'id': '4', 'reference': '5'},
]


Expand All @@ -63,11 +62,11 @@ def test_ReferencesSelector_select_examples_callbacks(selector):
selector.callbacks = [callback]

expected_examples = [
{'bla': 'bla 0', 'id': '0', 'reference': 'a'},
{'bla': 'bla 1', 'id': '1', 'reference': 'b'},
{'bla': 'bla 2', 'id': '2', 'reference': 'c'},
{'bla': 'bla 3', 'id': '3', 'reference': 'd'},
{'bla': 'bla 4', 'id': '4', 'reference': 'e'},
{'bla': 'bla 0', 'id': '0', 'reference': '1'},
{'bla': 'bla 1', 'id': '1', 'reference': '2'},
{'bla': 'bla 2', 'id': '2', 'reference': '3'},
{'bla': 'bla 3', 'id': '3', 'reference': '4'},
{'bla': 'bla 4', 'id': '4', 'reference': '5'},
]
input_variables = {'var1': 'bla', 'var2': 'ble'}

Expand All @@ -84,11 +83,11 @@ def test_ReferencesSelector_select_examples_removes_duplicates(selector):
] * 5

assert selector.select_examples(input_variables={}) == [
{'bla': 'bla 0', 'id': '0', 'reference': 'a'},
{'bla': 'bla 1', 'id': '1', 'reference': 'b'},
{'bla': 'bla 2', 'id': '2', 'reference': 'c'},
{'bla': 'bla 3', 'id': '3', 'reference': 'd'},
{'bla': 'bla 4', 'id': '4', 'reference': 'e'},
{'bla': 'bla 0', 'id': '0', 'reference': '1'},
{'bla': 'bla 1', 'id': '1', 'reference': '2'},
{'bla': 'bla 2', 'id': '2', 'reference': '3'},
{'bla': 'bla 3', 'id': '3', 'reference': '4'},
{'bla': 'bla 4', 'id': '4', 'reference': '5'},
]


Expand All @@ -107,9 +106,9 @@ def calc_score(i):
]

assert selector.select_examples(input_variables={}) == [
{'bla': 'bla 0', 'id': '0', 'reference': 'a'},
{'bla': 'bla 2', 'id': '2', 'reference': 'b'},
{'bla': 'bla 4', 'id': '4', 'reference': 'c'},
{'bla': 'bla 0', 'id': '0', 'reference': '1'},
{'bla': 'bla 2', 'id': '2', 'reference': '2'},
{'bla': 'bla 4', 'id': '4', 'reference': '3'},
]


Expand All @@ -127,10 +126,10 @@ def searcher(query, *args, **kwargs):
Mock(content='last history item'),
]
assert selector.select_examples(input_variables={'query': 'queried value', 'history': history}) == [
{'bla': 'queried value', 'id': 'queried value', 'reference': 'a'},
{'bla': 'last history item', 'id': 'last history item', 'reference': 'b'},
{'bla': 'second history item', 'id': 'second history item', 'reference': 'c'},
{'bla': 'first history item', 'id': 'first history item', 'reference': 'd'}
{'bla': 'queried value', 'id': 'queried value', 'reference': '1'},
{'bla': 'last history item', 'id': 'last history item', 'reference': '2'},
{'bla': 'second history item', 'id': 'second history item', 'reference': '3'},
{'bla': 'first history item', 'id': 'first history item', 'reference': '4'}
]


Expand All @@ -150,12 +149,12 @@ def searcher(query, *args, **kwargs):
Mock(content='last history item'),
]
assert selector.select_examples(input_variables={'query': 'queried value', 'history': history}) == [
{'bla': 'queried value', 'id': 'queried value - 0', 'reference': 'a'},
{'bla': 'queried value', 'id': 'queried value - 1', 'reference': 'b'},
{'bla': 'queried value', 'id': 'queried value - 2', 'reference': 'c'},
{'bla': 'last history item', 'id': 'last history item - 0', 'reference': 'd'},
{'bla': 'last history item', 'id': 'last history item - 1', 'reference': 'e'},
{'bla': 'last history item', 'id': 'last history item - 2', 'reference': 'f'},
{'bla': 'queried value', 'id': 'queried value - 0', 'reference': '1'},
{'bla': 'queried value', 'id': 'queried value - 1', 'reference': '2'},
{'bla': 'queried value', 'id': 'queried value - 2', 'reference': '3'},
{'bla': 'last history item', 'id': 'last history item - 0', 'reference': '4'},
{'bla': 'last history item', 'id': 'last history item - 1', 'reference': '5'},
{'bla': 'last history item', 'id': 'last history item - 2', 'reference': '6'},
]


Expand Down
1 change: 1 addition & 0 deletions web/src/hooks/useSettings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export const MODELS: { [key: string]: Model } = {
"gpt-3.5-turbo-16k": { maxNumTokens: 16385, topKBlocks: 30 },
"gpt-4": { maxNumTokens: 8192, topKBlocks: 20 },
"gpt-4-turbo-preview": { maxNumTokens: 128000, topKBlocks: 50 },
"gpt-4o": { maxNumTokens: 128000, topKBlocks: 50 },
"claude-3-opus-20240229": { maxNumTokens: 200000, topKBlocks: 50 },
"claude-3-sonnet-20240229": { maxNumTokens: 200_000, topKBlocks: 50 },
"claude-3-haiku-20240307": { maxNumTokens: 200_000, topKBlocks: 50 },
Expand Down

0 comments on commit 9a12be3

Please sign in to comment.