Merge pull request #128 from StampyAI/qa-endpoint
Qa endpoint
mruwnik committed Nov 8, 2023
2 parents e84f517 + 4823337 commit 4d602bf
Showing 10 changed files with 105 additions and 31 deletions.
3 changes: 2 additions & 1 deletion api/main.py
@@ -103,13 +103,14 @@ def human(id):
def ratings():
    session_id = request.json.get('sessionId')
    settings = request.json.get('settings', {})
+    comment = (request.json.get('comment') or '').strip() or None  # only save strings if not empty
    score = request.json.get('score')

    if not session_id or score is None:
        return Response('{"error": "missing params"}', 400, mimetype='application/json')

    with make_session() as s:
-        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings)))
+        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings), comment=comment))
        s.commit()

    return jsonify({'status': 'ok'})
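For illustration, a minimal sketch of exercising the updated endpoint; the field names and status codes come from the handler above, while the base URL and sample values are assumptions:

```python
import requests  # any HTTP client works; requests is just for the sketch

API_URL = "http://localhost:3001"  # hypothetical deployment URL

resp = requests.post(f"{API_URL}/ratings", json={
    "sessionId": "session-123",            # required; missing -> 400
    "score": 4,                            # required; missing -> 400
    "comment": "   ",                      # whitespace-only comments are stored as NULL
    "settings": {"completions": "gpt-4"},
})
print(resp.status_code, resp.json())       # 200 {'status': 'ok'}
```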
17 changes: 17 additions & 0 deletions api/src/stampy_chat/chat.py
@@ -108,6 +108,23 @@ def set_messages(self, history: List[dict]) -> None:
        for callback in self.callbacks:
            callback.on_memory_set_end(self.chat_memory)

+    def prune(self) -> None:
+        """Prune the buffer if it exceeds the max token limit.
+
+        This is the original Langchain version, copied with a fix for the case
+        where all messages are longer than max_token_limit.
+        """
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while buffer and curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = self.predict_new_summary(
+                pruned_memory, self.moving_summary_buffer
+            )


class ModeratedChatPrompt(ChatPromptTemplate):
    """Wraps a prompt with an OpenAI moderation check that raises an exception if the check fails."""
13 changes: 8 additions & 5 deletions api/src/stampy_chat/settings.py
@@ -4,7 +4,7 @@
from stampy_chat.env import COMPLETIONS_MODEL


-Model = namedtuple('Model', ['maxTokens', 'topKBlocks'])
+Model = namedtuple('Model', ['maxTokens', 'topKBlocks', 'maxCompletionTokens'])


SOURCE_PROMPT = (
@@ -51,9 +51,10 @@
    'modes': PROMPT_MODES,
}
MODELS = {
-    'gpt-3.5-turbo': Model(4097, 10),
-    'gpt-3.5-turbo-16k': Model(16385, 30),
-    'gpt-4': Model(8192, 20),
+    'gpt-3.5-turbo': Model(4097, 10, 4096),
+    'gpt-3.5-turbo-16k': Model(16385, 30, 4096),
+    'gpt-4': Model(8192, 20, 4096),
+    "gpt-4-1106-preview": Model(128000, 50, 4096),
    # 'gpt-4-32k': Model(32768, 30),
}

@@ -138,6 +139,8 @@ def set_completions(self, completions, maxNumTokens=None, topKBlocks=None):
        else:
            self.topKBlocks = MODELS[completions].topKBlocks

+        self.maxCompletionTokens = MODELS[completions].maxCompletionTokens

    @property
    def prompt_modes(self):
        return self.prompts['modes']
@@ -170,4 +173,4 @@ def history_tokens(self):

    @property
    def max_response_tokens(self):
-        return self.maxNumTokens - self.context_tokens - self.history_tokens
+        return min(self.maxNumTokens - self.context_tokens - self.history_tokens, self.maxCompletionTokens)
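A back-of-the-envelope check of what the new cap changes for `gpt-4-1106-preview`; the 0.5/0.25 fractions are illustrative defaults, not values from this diff:

```python
maxNumTokens = 128_000        # gpt-4-1106-preview context window
maxCompletionTokens = 4_096   # the model's completion ceiling
context_tokens = int(maxNumTokens * 0.5)    # assumed contextFraction
history_tokens = int(maxNumTokens * 0.25)   # assumed historyFraction

# The bare subtraction leaves a 32,000-token "budget" the model can never
# emit in a single completion; min() clamps it to the real ceiling.
uncapped = maxNumTokens - context_tokens - history_tokens
capped = min(uncapped, maxCompletionTokens)
assert (uncapped, capped) == (32_000, 4_096)
```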
2 changes: 1 addition & 1 deletion web/src/components/chat.tsx
@@ -185,7 +185,7 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => {
        return;
      } else if (
        i === entries.length - 1 &&
-        ["assistant", "stampy"].includes(entry.role)
+        ["assistant", "stampy", "error"].includes(entry.role)
      ) {
        const prev = entries[i - 1];
        if (prev !== undefined) setQuery(prev.content);
45 changes: 35 additions & 10 deletions web/src/components/settings.tsx
@@ -5,22 +5,47 @@ import type { Parseable, LLMSettings, Entry, Mode } from "../types";
import { MODELS, ENCODERS } from "../hooks/useSettings";
import { SectionHeader, NumberInput, Slider } from "../components/html";

+type ChatSettingsUpdate = [path: string[], value: any];
type ChatSettingsParams = {
  settings: LLMSettings;
-  changeSetting: (path: string[], value: any) => void;
+  changeSettings: (...v: ChatSettingsUpdate[]) => void;
};

export const ChatSettings = ({
  settings,
-  changeSetting,
+  changeSettings,
}: ChatSettingsParams) => {
  const changeVal = (field: string, value: any) =>
-    changeSetting([field], value);
+    changeSettings([[field], value]);
  const update = (field: string) => (event: ChangeEvent) =>
    changeVal(field, (event.target as HTMLInputElement).value);
  const updateNum = (field: string) => (num: Parseable) =>
    changeVal(field, num);

+  const updateTokenFraction = (field: string) => (num: Parseable) => {
+    // Calculate the fraction of the tokens taken by the buffer
+    const bufferFraction =
+      settings.tokensBuffer && settings.maxNumTokens
+        ? settings.tokensBuffer / settings.maxNumTokens
+        : 0;
+    const val = Math.min(parseFloat((num || 0).toString()), 1 - bufferFraction);
+
+    let context = settings.contextFraction || 0;
+    let history = settings.historyFraction || 0;
+
+    if (field === "contextFraction") {
+      history = Math.min(history, Math.max(0, 1 - val - bufferFraction));
+      context = val;
+    } else {
+      context = Math.min(context, Math.max(0, 1 - val - bufferFraction));
+      history = val;
+    }
+    changeSettings(
+      [["contextFraction"], context],
+      [["historyFraction"], history]
+    );
+  };

  return (
    <div
      className="chat-settings mx-5 grid w-[400px] flex-none grid-cols-4 gap-4 border-2 outline-black"
@@ -118,13 +143,13 @@ export const ChatSettings = ({
        value={settings.contextFraction}
        field="contextFraction"
        label="Approximate fraction of num_tokens to use for citations text before truncating"
-        updater={updateNum("contextFraction")}
+        updater={updateTokenFraction("contextFraction")}
      />
      <Slider
        value={settings.historyFraction}
        field="historyFraction"
        label="Approximate fraction of num_tokens to use for history text before truncating"
-        updater={updateNum("historyFraction")}
+        updater={updateTokenFraction("historyFraction")}
      />
    </div>
  );
@@ -134,22 +159,22 @@ type ChatPromptParams = {
  settings: LLMSettings;
  query: string;
  history: Entry[];
-  changeSetting: (path: string[], value: any) => void;
+  changeSettings: (...vals: ChatSettingsUpdate[]) => void;
};

export const ChatPrompts = ({
  settings,
  query,
  history,
-  changeSetting,
+  changeSettings,
}: ChatPromptParams) => {
  const updatePrompt =
    (...path: string[]) =>
    (event: ChangeEvent) =>
-      changeSetting(
-        ["prompts", ...path],
-        (event.target as HTMLInputElement).value
-      );
+      changeSettings([
+        ["prompts", ...path],
+        (event.target as HTMLInputElement).value,
+      ]);

  return (
    <div className="chat-prompts mx-5 w-[400px] flex-none border-2 p-5 outline-black">
28 changes: 22 additions & 6 deletions web/src/hooks/useSettings.ts
@@ -50,6 +50,7 @@ export const MODELS: { [key: string]: Model } = {
  "gpt-3.5-turbo": { maxNumTokens: 4095, topKBlocks: 10 },
  "gpt-3.5-turbo-16k": { maxNumTokens: 16385, topKBlocks: 30 },
  "gpt-4": { maxNumTokens: 8192, topKBlocks: 20 },
+  "gpt-4-1106-preview": { maxNumTokens: 128000, topKBlocks: 50 },
  /* 'gpt-4-32k': {maxNumTokens: 32768, topKBlocks: 30}, */
};
export const ENCODERS = ["cl100k_base"];
@@ -169,25 +170,39 @@ type ChatSettingsParams = {
  changeSetting: (path: string[], value: any) => void;
};

+type SettingsUpdatePair = [path: string[], val: any];

export default function useSettings() {
  const [settingsLoaded, setLoaded] = useState(false);
  const [settings, updateSettings] = useState<LLMSettings>(makeSettings({}));
  const router = useRouter();

-  const updateInUrl = (path: string[], value: any) =>
+  const updateInUrl = (vals: { [key: string]: any }) =>
    router.replace({
      pathname: router.pathname,
-      query: {
-        ...router.query,
-        [path.join(".")]: value.toString(),
-      },
+      query: { ...router.query, ...vals },
    });

  const changeSetting = (path: string[], value: any) => {
-    updateInUrl(path, value);
+    updateInUrl({ [path.join(".")]: value });
    updateSettings((settings) => ({ ...updateIn(settings, path, value) }));
  };

+  const changeSettings = (...items: SettingsUpdatePair[]) => {
+    updateInUrl(
+      items.reduce(
+        (acc, [path, val]) => ({ ...acc, [path.join(".")]: val }),
+        {}
+      )
+    );
+    updateSettings((settings) =>
+      items.reduce(
+        (acc, [path, val]) => ({ ...acc, ...updateIn(settings, path, val) }),
+        settings
+      )
+    );
+  };

  const setMode = (mode: Mode | undefined) => {
    if (mode) {
      updateSettings({ ...settings, mode: mode });
@@ -210,6 +225,7 @@ export default function useSettings() {
  return {
    settings,
    changeSetting,
+    changeSettings,
    setMode,
    settingsLoaded,
    randomize,
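The batched `changeSettings` folds all `[path, value]` pairs into one flat query object, so `router.replace` runs once instead of once per setting. A rough Python analogue of that fold (names and sample data are illustrative):

```python
from functools import reduce

def merge_into_query(query, items):
    """Fold [path, value] pairs into one flat dict with dotted keys."""
    return reduce(
        lambda acc, item: {**acc, ".".join(item[0]): item[1]},
        items,
        dict(query),
    )

merged = merge_into_query(
    {"mode": "default"},
    [(["contextFraction"], 0.5), (["historyFraction"], 0.25)],
)
assert merged == {"mode": "default", "contextFraction": 0.5, "historyFraction": 0.25}
```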
6 changes: 3 additions & 3 deletions web/src/pages/playground.tsx
@@ -13,7 +13,7 @@ const Playground: NextPage = () => {

  const [query, setQuery] = useState<string>("");
  const [history, setHistory] = useState<Entry[]>([]);
-  const { settings, changeSetting, setMode } = useSettings();
+  const { settings, changeSettings, setMode } = useSettings();

  // initial load
  useEffect(() => {
@@ -28,15 +28,15 @@
        settings={settings}
        query={query}
        history={history}
-        changeSetting={changeSetting}
+        changeSettings={changeSettings}
      />
      <Chat
        sessionId={sessionId}
        settings={settings}
        onQuery={setQuery}
        onNewEntry={setHistory}
      />
-      <ChatSettings settings={settings} changeSetting={changeSetting} />
+      <ChatSettings settings={settings} changeSettings={changeSettings} />
    </div>
  </Page>
);
11 changes: 9 additions & 2 deletions web/src/pages/qa.tsx
@@ -22,14 +22,15 @@ const MAX_FOLLOWUPS = 4;
export const saveRatings = async (
  sessionId: string,
  score: number,
+  comment: string | null,
  settings: LLMSettings
): Promise<any> =>
  fetch(API_URL + "/ratings", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
-    body: JSON.stringify({ sessionId, settings, score }),
+    body: JSON.stringify({ sessionId, settings, score, comment }),
  }).then((r) => r.json());

const Rater = ({
@@ -41,8 +42,10 @@ const Rater = ({
  sessionId: string;
  reset: () => void;
}) => {
+  const [comment, setComment] = useState<string | null>(null);
+
  const onRate = async (rate: number) => {
-    const res = await saveRatings(sessionId, rate, settings);
+    const res = await saveRatings(sessionId, rate, comment, settings);
    if (!res.error) reset();
  };

@@ -58,6 +61,10 @@
        ))}
        <span>Good</span>
      </div>
+      <textarea
+        placeholder="Add any comments here"
+        onChange={(e) => setComment(e.target.value)}
+      ></textarea>
    </div>
  );
};
6 changes: 3 additions & 3 deletions web/src/pages/tester.tsx
@@ -49,7 +49,7 @@ const Tester: NextPage = () => {
    initialQuestions.map((q, i) => ({ question: q, selected: true, index: i }))
  );

-  const { settings, changeSetting, setMode, settingsLoaded } = useSettings();
+  const { settings, changeSettings, setMode, settingsLoaded } = useSettings();

  /** Run a search for the given `question` and insert the query promise into it
   */
@@ -118,7 +118,7 @@
        settings={settings}
        query="<this is where the query will go>"
        history={[]}
-        changeSetting={changeSetting}
+        changeSettings={changeSettings}
      />
      <div className="chat-settings mx-5 w-[400px] flex-none gap-4 border-2 outline-black">
        {questions.map(({ question, selected }, i) => (
Expand Down Expand Up @@ -159,7 +159,7 @@ const Tester: NextPage = () => {
)}
</div>

<ChatSettings settings={settings} changeSetting={changeSetting} />
<ChatSettings settings={settings} changeSettings={changeSettings} />
</div>
</Page>
);
5 changes: 5 additions & 0 deletions web/src/styles/globals.css
@@ -65,3 +65,8 @@ ol {
  width: 2.5em;
  margin: 0.5em;
}

+.rate-container textarea {
+  width: 100%;
+  border: black solid 1px;
+}
